Skip to content

Commit

Permalink
Support tt.split for CPU.
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich committed Jun 20, 2024
1 parent 44035c9 commit a48bd7a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,6 +1886,7 @@ def kernel(X, Y, Z):
np.testing.assert_equal([10, 20], to_numpy(z))


@pytest.mark.cpu
@pytest.mark.interpreter
def test_split(device):

Expand All @@ -1908,6 +1909,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr):
np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2))


@pytest.mark.cpu
@pytest.mark.interpreter
def test_split_to_scalar(device):

Expand Down
41 changes: 41 additions & 0 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class ElemManipOpConversionTarget : public ConversionTarget {
addIllegalOp<triton::TransOp>();
addIllegalOp<triton::JoinOp>();
addIllegalOp<triton::CatOp>();
addIllegalOp<triton::SplitOp>();
}
};

Expand Down Expand Up @@ -166,6 +167,45 @@ struct CatOpConversion : public OpConversionPattern<triton::CatOp> {
}
};

struct SplitOpConversion : public OpConversionPattern<triton::SplitOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto src = rewriter.getRemappedValue(op.getSrc());
auto srcTy = cast<VectorType>(src.getType());
auto resTy = getTypeConverter()->convertType(op.getType(0));

SmallVector<Value> results;
if (srcTy.getRank() == 1) {
results.push_back(rewriter.create<vector::ExtractOp>(loc, src, 0));
results.push_back(rewriter.create<vector::ExtractOp>(loc, src, 1));
} else {
SmallVector<int64_t> tmpShape({srcTy.getNumElements()});
auto tmp = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get(tmpShape, srcTy.getElementType()), src);

SmallVector<int64_t> evenIndices;
SmallVector<int64_t> oddIndices;
for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) {
evenIndices.push_back(i);
oddIndices.push_back(i + 1);
}

Value res1 =
rewriter.create<vector::ShuffleOp>(loc, tmp, tmp, evenIndices);
Value res2 =
rewriter.create<vector::ShuffleOp>(loc, tmp, tmp, oddIndices);
results.push_back(rewriter.create<vector::ShapeCastOp>(loc, resTy, res1));
results.push_back(rewriter.create<vector::ShapeCastOp>(loc, resTy, res2));
}
rewriter.replaceOp(op, results);
return success();
}
};

struct ConvertElemManipOps
: public triton::impl::ConvertElemManipOpsBase<ConvertElemManipOps> {
using ConvertElemManipOpsBase::ConvertElemManipOpsBase;
Expand All @@ -187,6 +227,7 @@ struct ConvertElemManipOps
patterns.add<TransOpConversion>(typeConverter, context);
patterns.add<JoinOpConversion>(typeConverter, context);
patterns.add<CatOpConversion>(typeConverter, context);
patterns.add<SplitOpConversion>(typeConverter, context);

if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
Expand Down

0 comments on commit a48bd7a

Please sign in to comment.