[torch-mlir][sparse] recognize sparse tensor conversion (#3226)
Sparse tensor conversions are represented by special aten operators.
This PR ensures these conversions are recognized and lowered (instead of
causing the full torch aten lowering to linalg to fail).
aartbik authored Apr 25, 2024
1 parent 9e2fe47 commit 4361178
Showing 3 changed files with 79 additions and 4 deletions.
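
In short, an operator-wrapped conversion such as torch.aten._to_csr now lowers to a single sparse_tensor.convert whose result type carries the target sparse encoding. A minimal before/after sketch (the #CSR alias, shape, and value names are illustrative only; the new test below shows the exact checked form):

  #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

  // Before: the conversion is hidden behind a generic torch.operator.
  %t = torch.operator "torch.aten._to_csr"(%x, %none, %none, %none)
      : (!torch.vtensor<[8,8],f32>, !torch.none, !torch.none, !torch.none)
      -> !torch.vtensor<[8,8],f32,#CSR>

  // After TorchToLinalg: a plain sparse_tensor.convert between builtin tensors.
  %d = torch_c.to_builtin_tensor %x : !torch.vtensor<[8,8],f32> -> tensor<8x8xf32>
  %c = sparse_tensor.convert %d : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
  %r = torch_c.from_builtin_tensor %c : tensor<8x8xf32, #CSR> -> !torch.vtensor<[8,8],f32,#CSR>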
42 changes: 42 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
@@ -2423,6 +2424,42 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
};
} // namespace

namespace {
class ConvertSparseOperatorOp : public OpConversionPattern<OperatorOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  static bool isSparsePrimitive(StringRef prim) {
    return llvm::find(legalizedNames, prim) != legalizedNames.end();
  }

  // Rewriting method.
  LogicalResult
  matchAndRewrite(OperatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isSparsePrimitive(op.getNameAttr()))
      return failure();
    // The conversion is completely specified by the information in the sparse
    // tensor type. Thus, we can rewrite all legalizedNames to the same
    // construct.
    RankedTensorType resultType = getTypeConverter()
                                      ->convertType(op->getResult(0).getType())
                                      .cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<sparse_tensor::ConvertOp>(
        op, resultType, adaptor.getOperands()[0]);
    return success();
  }

private:
  // The operators that legalize to sparse tensor conversions.
  static SmallVector<StringRef> legalizedNames;
};
// Static initializer.
SmallVector<StringRef> ConvertSparseOperatorOp::legalizedNames = {
    "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc",
    "torch.aten._to_bsr", "torch.aten._to_bsc",
};
} // namespace

void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
@@ -2469,4 +2506,9 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
  patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
  target.addIllegalOp<AtenDiagEmbedOp>();
  patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
  // Rewrite all special sparse conversions hidden as operators.
  target.addDynamicallyLegalOp<OperatorOp>([&](Torch::OperatorOp op) {
    return !ConvertSparseOperatorOp::isSparsePrimitive(op.getNameAttr());
  });
  patterns.add<ConvertSparseOperatorOp>(typeConverter, context);
}
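
The comment in matchAndRewrite is the crux: because the target layout is fully encoded in the sparse tensor type, all five legalizedNames funnel into the same sparse_tensor.convert construct. A sketch with illustrative encodings, where only the result type distinguishes the two lowerings:

  #CSC = #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed) }>
  #BSR = #sparse_tensor.encoding<{ map = (i, j) -> (i floordiv 2 : dense, j floordiv 2 : compressed, i mod 2 : dense, j mod 2 : dense) }>

  // torch.aten._to_csc and torch.aten._to_bsr both rewrite to the one construct:
  %csc = sparse_tensor.convert %d : tensor<8x8xf32> to tensor<8x8xf32, #CSC>
  %bsr = sparse_tensor.convert %d : tensor<8x8xf32> to tensor<8x8xf32, #BSR>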
9 changes: 5 additions & 4 deletions lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@@ -53,10 +54,10 @@ class ConvertTorchToLinalg
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target(*context);
    target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
                           cf::ControlFlowDialect, math::MathDialect,
                           tensor::TensorDialect, arith::ArithDialect,
                           complex::ComplexDialect>();
    target.addLegalDialect<
        linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect,
        math::MathDialect, sparse_tensor::SparseTensorDialect,
        tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>();
    target.addLegalOp<TorchConversion::GetNextSeedOp>();

    TypeConverter typeConverter;
32 changes: 32 additions & 0 deletions test/Conversion/TorchToLinalg/sparse.mlir
@@ -34,3 +34,35 @@ func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>,
       !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32>
  return %0 : !torch.vtensor<[8,8],f32>
}

// -----

#sparse = #sparse_tensor.encoding<{
  map = (d0, d1, d2, d3, d4) ->
    (d0 : compressed(nonunique),
     d1 : singleton(nonunique, soa),
     d2 : singleton(nonunique, soa),
     d3 : singleton(nonunique, soa),
     d4 : singleton(soa)
    ),
  posWidth = 64,
  crdWidth = 64
}>

// CHECK: #[[$ST:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3, d4) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(nonunique, soa), d3 : singleton(nonunique, soa), d4 : singleton(soa)), posWidth = 64, crdWidth = 64 }>
// CHECK-LABEL: func.func @activate(
// CHECK-SAME:  %[[A:.*]]: !torch.vtensor<[128,64,30,30,6],f32>)
// CHECK:       %[[D:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[128,64,30,30,6],f32> -> tensor<128x64x30x30x6xf32>
// CHECK:       %[[C:.*]] = sparse_tensor.convert %[[D]] : tensor<128x64x30x30x6xf32> to tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK:       %[[R:.*]] = torch_c.from_builtin_tensor %[[C]] : tensor<128x64x30x30x6xf32, #[[$ST]]>
// CHECK:       return %[[R]] : !torch.vtensor<[128,64,30,30,6],f32,#[[$ST]]>
func.func @activate(%arg0: !torch.vtensor<[128,64,30,30,6],f32>)
    -> !torch.vtensor<[128,64,30,30,6],f32,#sparse> {
  %none_0 = torch.constant.none
  %none_1 = torch.constant.none
  %none_2 = torch.constant.none
  %result = torch.operator "torch.aten._to_sparse"(%arg0, %none_0, %none_1, %none_2)
      : (!torch.vtensor<[128,64,30,30,6],f32>, !torch.none, !torch.none, !torch.none)
      -> !torch.vtensor<[128,64,30,30,6],f32,#sparse>
  return %result : !torch.vtensor<[128,64,30,30,6],f32,#sparse>
}
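
Assuming the file's RUN line follows the repository's usual pattern, the new case can be exercised directly with something like:

  torch-mlir-opt -convert-torch-to-linalg test/Conversion/TorchToLinalg/sparse.mlir

which should print the to_builtin_tensor / sparse_tensor.convert / from_builtin_tensor sequence matched by the CHECK lines above (the exact flags in the RUN line may differ).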
