Skip to content

Commit

Permalink
[cpu] Get more of test_random.py working (#77)
Browse files Browse the repository at this point in the history
Two changes:

1. Lower mulhiui to mulhi_extended rather than a bunch of ext / trunc
   instructions. This greatly simplifies the code, as well as automatically gets
   it working for 64-bit inputs, which the previous implementation did not handle
   correctly
2. Only emit extsi if the result bitwidth is larger than the input bitwidth.
   Otherwise it fails validation.

This gets the int64 tests in test_random.py passing.

Fixes #71.
  • Loading branch information
int3 authored Jul 26, 2024
1 parent 932383c commit f3a642c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
18 changes: 3 additions & 15 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "cpu/include/TritonToTritonCPU/Passes.h"

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -90,21 +91,8 @@ struct MulhiUIOpConversion : public OpConversionPattern<triton::MulhiUIOp> {
auto loc = op.getLoc();
auto lhs = rewriter.getRemappedValue(op.getX());
auto rhs = rewriter.getRemappedValue(op.getY());

Type extUITy = toInt64(lhs.getType());
Type truncITy = toInt32(lhs.getType());
Value cst32 = intCst(loc, extUITy, 32LL, rewriter);
auto lhsTy = getElementTypeOrSelf(lhs.getType());
auto rhsTy = getElementTypeOrSelf(rhs.getType());
if (lhsTy.getIntOrFloatBitWidth() < 64) {
lhs = rewriter.create<arith::ExtUIOp>(loc, extUITy, lhs);
}
if (rhsTy.getIntOrFloatBitWidth() < 64) {
rhs = rewriter.create<arith::ExtUIOp>(loc, extUITy, rhs);
}
Value res = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
res = rewriter.create<arith::ShRUIOp>(loc, res, cst32);
res = rewriter.create<arith::TruncIOp>(loc, truncITy, res);
Value res =
rewriter.create<arith::MulUIExtendedOp>(loc, lhs, rhs).getHigh();
rewriter.replaceOp(op, res);
return success();
}
Expand Down
4 changes: 3 additions & 1 deletion third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,16 @@ struct AddPtrOpConversion : public OpConversionPattern<triton::AddPtrOp> {
assert(isa<VectorType>(offset.getType()));
assert(isa<VectorType>(ptr.getType()));
VectorType offsetTy = cast<VectorType>(offset.getType());
VectorType ptrTy = cast<VectorType>(ptr.getType());
// Build scale vector. i1 elements take 1 byte.
Value scale = rewriter.create<arith::ConstantOp>(
loc, offsetTy,
SplatElementsAttr::get(
offsetTy, rewriter.getIntegerAttr(offsetTy.getElementType(),
(elemBitWidth + 7) / 8)));
offset = rewriter.create<arith::MulIOp>(loc, offset, scale);
offset = rewriter.create<arith::ExtSIOp>(loc, ptr.getType(), offset);
if (offsetTy.getElementTypeBitWidth() < ptrTy.getElementTypeBitWidth())
offset = rewriter.create<arith::ExtSIOp>(loc, ptr.getType(), offset);
rewriter.replaceOpWithNewOp<arith::AddIOp>(op, ptr.getType(), ptr, offset);
return success();
}
Expand Down

0 comments on commit f3a642c

Please sign in to comment.