Skip to content

Commit

Permalink
[Torch] support AtenScalarImplicitOp canonicalize with float (#3231)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu authored Apr 25, 2024
1 parent 4361178 commit b0ba3de
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
18 changes: 13 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,11 +1957,19 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
Location loc = op.getLoc();
Value a = op.getA();
auto outType = op.getResult().getType();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
return failure();
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
return success();
Value scalarIntValue = getScalarIntValue(a, loc, rewriter);
if (scalarIntValue) {
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
scalarIntValue);
return success();
}
Value scalarFloatValue = getScalarFloatValue(a, loc, rewriter);
if (scalarFloatValue) {
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType,
scalarFloatValue);
return success();
}
return failure();
});
}

Expand Down
18 changes: 18 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2174,6 +2174,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[]
return %2 : !torch.vtensor<[],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> {
// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
// CHECK: return %[[CST]] : !torch.vtensor<[],si64>
Expand All @@ -2186,6 +2188,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso
return %2 : !torch.vtensor<[],si64>
}

// -----

// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number
Expand All @@ -2197,6 +2201,8 @@ func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.num
return %1 : !torch.number
}

// -----

// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number {
// CHECK: %int1 = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = torch.derefine %int1 : !torch.int to !torch.number
Expand All @@ -2209,6 +2215,18 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number

// -----

// CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
// CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00
// CHECK: %[[VAL_0:.*]] = torch.derefine %float1.000000e00 : !torch.float to !torch.number
// CHECK: return %[[VAL_0]] : !torch.number
func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d_float() -> !torch.number {
%0 = torch.vtensor.literal(dense<1.0> : tensor<f64>) : !torch.vtensor<[],f64>
%1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],f64> -> !torch.number
return %1 : !torch.number
}

// -----

// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: return %[[FLOAT1]] : !torch.float
Expand Down

0 comments on commit b0ba3de

Please sign in to comment.