Skip to content

Commit

Permalink
fix(ONNX): avoids resizing unsupported dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
bjacobgordon committed Jan 21, 2025
1 parent f6e3d3d commit a3daf2a
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 7 deletions.
145 changes: 141 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,77 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
return success();
}

Value scaleIdentityComparisonOpForFactorAtDimensionIn(
Value givenScaleFactors, int64_t givenDimension, OpBinder binder,
ConversionPatternRewriter &rewriter) {
auto typeOfScaleFactors =
cast<Torch::BaseTensorType>(givenScaleFactors.getType());

Type typeOfSelectionFromScaleFactors =
typeOfScaleFactors.getWithSizesAndDtype(
ArrayRef<int64_t>{1}, typeOfScaleFactors.getOptionalDtype());

auto loc = binder.getLoc();

Value zeroAsOp =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

Value scaleIdentityAsOp = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(1.0));

Value givenDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(givenDimension));

Type typeOfScaleFactor = rewriter.getType<Torch::FloatType>();

Value selectionFromScaleFactorsAsOp = rewriter.create<Torch::AtenSelectIntOp>(
loc, typeOfSelectionFromScaleFactors, givenScaleFactors, zeroAsOp,
givenDimensionAsOp);

Value scaleFactorAsOp = rewriter.create<Torch::AtenItemOp>(
loc, typeOfScaleFactor, selectionFromScaleFactorsAsOp);

Type typeOfComparisonResult = rewriter.getType<Torch::BoolType>();

return rewriter.create<Torch::AtenEqFloatOp>(
loc, typeOfComparisonResult, scaleFactorAsOp, scaleIdentityAsOp);
}

Value originalSizeComparisonOpForSizeAtDimensionIn(
Value givenTargetSizes, Value givenOriginalTensor, int64_t givenDimension,
OpBinder binder, ConversionPatternRewriter &rewriter) {
auto typeOfTargetSizes =
cast<Torch::BaseTensorType>(givenTargetSizes.getType());

Type typeOfSelectionFromTargetSizes = typeOfTargetSizes.getWithSizesAndDtype(
ArrayRef<int64_t>{1}, typeOfTargetSizes.getOptionalDtype());

auto loc = binder.getLoc();

Value zeroAsOp =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

Type typeOfTargetSize = rewriter.getType<Torch::IntType>();

Value givenDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(givenDimension));

Value selectionFromTargetSizesAsOp = rewriter.create<Torch::AtenSelectIntOp>(
loc, typeOfSelectionFromTargetSizes, givenTargetSizes, zeroAsOp,
givenDimensionAsOp);

Value targetSizeAsOp = rewriter.create<Torch::AtenItemOp>(
loc, typeOfTargetSize, selectionFromTargetSizesAsOp);

Value originalSizeAsOp = rewriter.create<Torch::AtenSizeIntOp>(
loc, givenOriginalTensor, givenDimensionAsOp);

Type typeOfComparisonResult = rewriter.getType<Torch::BoolType>();

return rewriter.create<Torch::AtenEqIntOp>(loc, typeOfComparisonResult,
targetSizeAsOp, originalSizeAsOp);
}

Value withUnsupportedDimensionsFilteredOut(
Value givenTransformationVector, OpBinder binder,
ConversionPatternRewriter &rewriter) {
Expand Down Expand Up @@ -2719,6 +2790,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
auto typeOfInputTensor =
cast<Torch::BaseTensorType>(inputTensor.getType());

auto sizesOfInputTensor = typeOfInputTensor.getSizes();
ArrayRef<int64_t> sizesOfOutputTensor = typeOfOutputTensor.getSizes();

int64_t const dimensionAssumedToBeBatch = 0;
int64_t const dimensionAssumedToBeChannel = 1;
int64_t nonResizableDimensions[] = {
dimensionAssumedToBeBatch,
dimensionAssumedToBeChannel,
};

auto unknownSize = Torch::kUnknownSize;

// Compile-time check for dimensions of static size
for (auto eachDimension : nonResizableDimensions) {
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension];
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension];

if (eachSizeOfInputTensor == unknownSize ||
eachSizeOfOutputTensor == unknownSize) {
continue;
} else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) {
continue;
}

auto scalingIntentErrorMessage =
"unsupported: non-trivial intent to scale dimension: " +
std::to_string(eachDimension);

return rewriter.notifyMatchFailure(binder.op,
scalingIntentErrorMessage);
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
Expand Down Expand Up @@ -2766,10 +2874,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
}

Value inputTensor = operands[0];
auto typeOfInputTensor =
cast<Torch::BaseTensorType>(inputTensor.getType());
auto sizesOfInputTensor = typeOfInputTensor.getSizes();
unsigned rankOfInputTensor = sizesOfInputTensor.size();

// supported modes:
Expand Down Expand Up @@ -2815,10 +2919,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

if (operands.size() < 4) {
Value proposedScaleFactorsAsOp = operands[2];

// run-time scale factor check for dynamic sizes
for (auto eachDimension : nonResizableDimensions) {
auto eachScaleIdentityComparisonAsOp =
scaleIdentityComparisonOpForFactorAtDimensionIn(
proposedScaleFactorsAsOp, eachDimension, binder, rewriter);

auto eachErrorMessage =
"Unsupported: non-trivial scale factor for dimension " +
std::to_string(eachDimension);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachScaleIdentityComparisonAsOp,
rewriter.getStringAttr(eachErrorMessage));
};

filteredScaleFactorsAsOp = withUnsupportedDimensionsFilteredOut(
proposedScaleFactorsAsOp, binder, rewriter);
} else {
Value proposedSizesAsOp = operands[3];

// run-time target size check for dynamic sizes
for (auto eachDimension : nonResizableDimensions) {
auto eachSizeComparisonAsOp =
originalSizeComparisonOpForSizeAtDimensionIn(
proposedSizesAsOp, inputTensor, eachDimension, binder,
rewriter);

auto eachErrorMessage =
"Unsupported: non-trivial resizing of dimension " +
std::to_string(eachDimension);

rewriter.create<Torch::RuntimeAssertOp>(
loc, eachSizeComparisonAsOp,
rewriter.getStringAttr(eachErrorMessage));
};

filteredSizesAsOp = withUnsupportedDimensionsFilteredOut(
proposedSizesAsOp, binder, rewriter);
}
Expand Down
6 changes: 3 additions & 3 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
Expand All @@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.coordinate_transformation_mode = "half_pixel",
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
Expand All @@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
Expand Down

0 comments on commit a3daf2a

Please sign in to comment.