[stablehlo] Support aten.any and aten.all lowering (#3217)
Xinyu Yang authored Apr 25, 2024
1 parent 7be22bb commit 7030eac
Showing 4 changed files with 299 additions and 7 deletions.
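
This commit adds full-tensor reductions for aten.all and aten.any to both the TorchToLinalg and TorchToStablehlo paths, plus e2e coverage. As a quick illustration (not part of the commit) of the PyTorch semantics being lowered, both ops reduce across every dimension and return a zero-rank bool tensor:

    import torch

    x = torch.tensor([[1, 0, 3], [2, 4, 5]], dtype=torch.int32)
    print(torch.ops.aten.all(x))  # tensor(False) -- one element is zero
    print(torch.ops.aten.any(x))  # tensor(True)  -- at least one element is non-zero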
25 changes: 18 additions & 7 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

if (isa<AtenAllDimOp>(op)) {
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
}

if (isa<AtenAnyOp>(op)) {
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
}

op->emitError("unimplemented lowering in createInitElementForReduceOp");
return nullptr;
}
@@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenAllDimOp>(op)) {
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::AndIOp>(loc, self, result);
} else if (isa<AtenAnyOp>(op)) {
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::OrIOp>(loc, self, result);
}
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
return nullptr;
@@ -510,13 +519,13 @@ class ConvertReductionOp : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};

if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp, AtenNormScalarOp>(
op)) {
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

// `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each reduce
// along all the dimensions of the input tensor.
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);

@@ -715,6 +724,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMinDimOp>();
patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
target.addIllegalOp<AtenSumOp>();
target.addIllegalOp<AtenAnyOp>();
target.addIllegalOp<AtenAllOp>();
target.addIllegalOp<AtenSumDimIntListOp>();
target.addIllegalOp<AtenProdOp>();
target.addIllegalOp<AtenProdDimIntOp>();
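
For the Linalg path above, the lowering follows the existing reduce-op pattern: seed the reduction with the identity element (true for all, false for any) and fold each input element in with and/or in the payload. A minimal Python sketch of that pattern (an illustrative model, not the generated IR):

    from functools import reduce

    def all_reduce(values):
        # identity element is True; payload combines with logical AND
        return reduce(lambda acc, v: acc and bool(v), values, True)

    def any_reduce(values):
        # identity element is False; payload combines with logical OR
        return reduce(lambda acc, v: acc or bool(v), values, False)

    print(all_reduce([1, 0, 3]))  # False
    print(any_reduce([1, 0, 3]))  # True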
158 changes: 158 additions & 0 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -104,6 +104,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}

if (isa<AtenAllOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}

if (isa<AtenAnyOp>(op)) {
auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
constAttr);
}

op->emitError("unimplemented lowering in "
"createInitialValueForReduceOp");
return nullptr;
@@ -463,6 +475,150 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
}
} // namespace

// AtenAllOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
AtenAllOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();

// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();

if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}

SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}

Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());

auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value allResult = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace

// AtenAnyOp
namespace {
template <>
LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
AtenAnyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.getSelf();
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
if (!inputTy) {
return rewriter.notifyMatchFailure(
op, "only Tensor types supported in StableHLO");
}
auto inputElemTy = inputTy.getElementType();

// Currently, (u)int8 dtype is not supported
if (isa<mlir::IntegerType>(inputElemTy) &&
inputElemTy.getIntOrFloatBitWidth() == 8) {
return rewriter.notifyMatchFailure(
op, "IntegerType with bitwidth 8 unsupported in convertion from "
"AtenAllOp to StableHLO");
}
auto outTy = getTypeConverter()
->convertType(op.getType())
.template dyn_cast<RankedTensorType>();

if (inputElemTy != outTy.getElementType()) {
// Use output bool type as computation type.
auto dstElemTy = outTy.getElementType();
input =
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
inputTy = input.getType().dyn_cast<RankedTensorType>();
inputElemTy = inputTy.getElementType();
}

SmallVector<int64_t> dims;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
dims.push_back(i);
}

Value initValue =
createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
if (!initValue)
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));

Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());

block.addArgument(blockArgumentTy, op->getLoc());
block.addArgument(blockArgumentTy, op->getLoc());

auto *firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value anyResult = rewriter.create<stablehlo::OrOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
}

rewriter.replaceOpWithNewOp<tensor::CastOp>(
op, getTypeConverter()->convertType(op.getType()),
stablehloReduceOp.getResults());
return success();
}
} // namespace

// AtenProdOp
namespace {
template <>
@@ -1052,6 +1208,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
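
In the StableHLO path above, both lowerings first convert a non-bool input to the converted result element type, reduce over every dimension with stablehlo.and / stablehlo.or, and cast the zero-rank result to the converted output type. A small PyTorch-level sketch (illustrative only) of why the up-front bool conversion preserves the semantics:

    import torch

    x = torch.tensor([[0.0, 2.5], [1.0, 3.0]])
    as_bool = x.to(torch.bool)            # mirrors the stablehlo.convert step
    assert torch.all(x) == as_bool.all()  # AND-reduction over every dimension
    assert torch.any(x) == as_bool.any()  # OR-reduction over every dimension
    print(torch.all(x), torch.any(x))     # tensor(False) tensor(True)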
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1227,6 +1227,12 @@
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"ReduceAllFloatModule_basic",
"ReduceAllIntModule_basic",
"ReduceAllBoolModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceAnyIntModule_basic",
"ReduceAnyBoolModule_basic",
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
@@ -1813,6 +1819,8 @@
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"ReduceAllBoolModule_basic",
"ReduceAnyBoolModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
@@ -2721,6 +2729,7 @@
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAnyFloatModule_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
}
114 changes: 114 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -124,6 +124,120 @@ def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReduceAllFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllFloatModule())
def ReduceAllFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAllIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllIntModule())
def ReduceAllIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))

# ==============================================================================

class ReduceAllBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
])
def forward(self, a):
return torch.ops.aten.all(a)


@register_test_case(module_factory=lambda: ReduceAllBoolModule())
def ReduceAllBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=2).to(torch.bool))

# ==============================================================================

class ReduceAnyFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyFloatModule())
def ReduceAnyFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAnyIntModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyIntModule())
def ReduceAnyIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))

# ==============================================================================

class ReduceAnyBoolModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
])
def forward(self, a):
return torch.ops.aten.any(a)


@register_test_case(module_factory=lambda: ReduceAnyBoolModule())
def ReduceAnyBoolModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, high=2).to(torch.bool))

# ==============================================================================

class ReduceSumDimIntListFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
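
The new test modules mirror the existing reduction e2e cases: each wraps a single aten.all / aten.any call over float, int, or bool input and is registered under the names added to xfail_sets.py above. As a hedged usage sketch (run outside the harness, decorators omitted), one of the new cases behaves as follows:

    import torch

    class ReduceAnyBoolModule(torch.nn.Module):
        def forward(self, a):
            return torch.ops.aten.any(a)

    m = ReduceAnyBoolModule()
    sample = torch.randint(0, 2, (3, 4)).to(torch.bool)  # mirrors tu.randint(3, 4, high=2)
    print(m(sample))  # zero-rank bool tensor, e.g. tensor(True)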
