Skip to content

Commit

Permalink
Fix incorrect casts in mask optimization. (#101)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich authored Aug 8, 2024
1 parent 24dc81f commit f03dfc4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/test/unit/cpu/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,18 @@ def kernel(src, dst, size, TILE_SIZE: tl.constexpr):
else:
assert masked_loads == 1
assert masked_stores == 1


# Regression test: the masks optimization used to hit a compilation failure
# on the element-wise expression (x + 15) // 16 used in the kernel below.
def test_vec_cdiv(device):
    """Launch a single-program kernel that computes (x + 15) // 16 over 16 int32 values.

    No assertion is made on the stored result; the test only requires that the
    kernel launch completes. The `device` argument is accepted but not
    referenced in the body.
    """

    @triton.jit
    def kernel(in_ptr, out_ptr):
        lanes = tl.arange(0, 16)
        vals = tl.load(in_ptr + lanes)
        # Keep the add-then-floor-divide form exactly as written: this is the
        # expression the regression test is built around.
        rounded = (vals + 15) // 16
        tl.store(out_ptr + lanes, rounded)

    src = torch.zeros((16, ), dtype=torch.int32)
    dst = torch.empty_like(src)
    kernel[(1, )](src, dst)
5 changes: 5 additions & 0 deletions third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ struct CdivToDiv : public OpRewritePattern<arith::DivSIOp> {
LogicalResult matchAndRewrite(arith::DivSIOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

// This pattern only handles scalar ops; bail out when the result is a vector.
if (isa<VectorType>(op.getType()))
return failure();

Value lhs = op.getLhs();
Value rhs = op.getRhs();
auto addOpDef = lhs.getDefiningOp<arith::AddIOp>();
Expand Down

0 comments on commit f03dfc4

Please sign in to comment.