From e91516243e3ea35e4f35442c01ecd5add88a7246 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 20 Jun 2024 18:51:53 -0500 Subject: [PATCH] Support tt.split for CPU. (#30) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 2 + .../TritonToTritonCPU/ConvertElemManipOps.cpp | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index bacd2e35dba25..c07f6139ed13a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1886,6 +1886,7 @@ def kernel(X, Y, Z): np.testing.assert_equal([10, 20], to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_split(device): @@ -1908,6 +1909,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr): np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) +@pytest.mark.cpu @pytest.mark.interpreter def test_split_to_scalar(device): diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp index 99211ea90e417..a39a93e424463 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -48,6 +48,7 @@ class ElemManipOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -166,6 +167,45 @@ struct CatOpConversion : public OpConversionPattern { } }; +struct SplitOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = cast(src.getType()); + auto resTy = getTypeConverter()->convertType(op.getType(0)); + + SmallVector results; + if (srcTy.getRank() == 1) { + results.push_back(rewriter.create(loc, src, 0)); + results.push_back(rewriter.create(loc, src, 1)); + } else { + SmallVector tmpShape({srcTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, srcTy.getElementType()), src); + + SmallVector evenIndices; + SmallVector oddIndices; + for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) { + evenIndices.push_back(i); + oddIndices.push_back(i + 1); + } + + Value res1 = + rewriter.create(loc, tmp, tmp, evenIndices); + Value res2 = + rewriter.create(loc, tmp, tmp, oddIndices); + results.push_back(rewriter.create(loc, resTy, res1)); + results.push_back(rewriter.create(loc, resTy, res2)); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + struct ConvertElemManipOps : public triton::impl::ConvertElemManipOpsBase { using ConvertElemManipOpsBase::ConvertElemManipOpsBase; @@ -187,6 +227,7 @@ struct ConvertElemManipOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure();