
fix(ONNX): avoids resizing fixed dimensions #3945

Draft
bjacobgordon wants to merge 7 commits into main from fix-onnx-adds-exceptions-enforcing-convention-in-resize-op
Conversation

bjacobgordon (Contributor)
No description provided.

bjacobgordon force-pushed the fix-onnx-adds-exceptions-enforcing-convention-in-resize-op branch 2 times, most recently from bb3f80f to 6baa8d5 on January 8, 2025 at 22:57
bjacobgordon force-pushed the fix-onnx-adds-exceptions-enforcing-convention-in-resize-op branch from 6baa8d5 to ab7e021 on January 8, 2025 at 23:02
bjacobgordon changed the title from "fix(ONNX): protects against mismatched dynamic meta dimensions" to "fix(ONNX): avoids resizing fixed dimensions" on January 9, 2025
bjacobgordon force-pushed the fix-onnx-adds-exceptions-enforcing-convention-in-resize-op branch from ab7e021 to 7aec80b on January 9, 2025 at 17:17
zjgarvey (Collaborator) left a comment:
I think the main structural question is about the need for adding the BaseTensorType method. If it were useful elsewhere (I have some doubts, since we would need to know too much about the two tensor shapes prior to using it, namely that they are present and that they have the same rank), I would consider keeping it; however, the code is simplified here by not using it, and I suspect the same would be true in other circumstances where it might be used.

Comment on lines +222 to +223
auto this_dimensions = /**/ getSizes();
auto that_dimensions = that.getSizes();
zjgarvey (Collaborator):
BaseTensorType might not have sizes, and getSizes() will crash when called in that case. I would do:

Suggested change:
-    auto this_dimensions = /**/ getSizes();
-    auto that_dimensions = that.getSizes();
+    auto selfSizes = getOptionalSizes();
+    auto otherSizes = other.getOptionalSizes();

Note also the variable-naming and camel-casing conventions. The variables self and other are used more typically than this and that in this codebase (easier to distinguish).
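
For context, a minimal sketch of how the suggested accessors would guard the comparison (this assumes the method body sits inside the Torch dialect sources; the early-return behavior and loop shown here are illustrative, not taken from the patch):

    // getOptionalSizes() returns std::nullopt instead of asserting, so tensors
    // without size information are handled gracefully.
    auto selfSizes = getOptionalSizes();
    auto otherSizes = other.getOptionalSizes();
    SmallVector<std::optional<bool>> comparisons;
    if (!selfSizes || !otherSizes || selfSizes->size() != otherSizes->size())
      return comparisons; // missing sizes or mismatched ranks: nothing to compare
    for (auto [selfDim, otherDim] : llvm::zip_equal(*selfSizes, *otherSizes)) {
      if (selfDim == kUnknownSize || otherDim == kUnknownSize)
        comparisons.push_back(std::nullopt); // dynamic on at least one side
      else
        comparisons.push_back(selfDim == otherDim);
    }
    return comparisons;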

@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
/// Enable isa/dyn_cast for BaseTensorType.
static bool classof(Type type);

/// The element-wise comparison of each dimension/size in `that` tensor
std::vector<std::optional<bool>>
zjgarvey (Collaborator):
Use SmallVector instead of std::vector. The methods are the same, and it is better for small containers like this.
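
As a sketch, the declaration with that change might read as follows (the method name below is a placeholder, since the actual name is not visible in this excerpt):

    /// The element-wise comparison of each dimension/size in `that` tensor
    llvm::SmallVector<std::optional<bool>>
    shapeComparisonAgainst(BaseTensorType that) const;

SmallVector keeps a handful of elements inline, so a comparison over a typical 4-D shape never touches the heap.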

@@ -2686,12 +2686,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
});
patterns.onOp(
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-    Torch::ValueTensorType resultType;
+    Torch::ValueTensorType outputTensor_blueprint;
zjgarvey (Collaborator):
I don't understand the renaming of this variable. This isn't a blueprint, it's the result type.

outputTensor_blueprint);

// Comparisons of the dimensions assumed to carry the batch and channel
auto shapeComparisonForFixedDimensions =
zjgarvey (Collaborator):
nit: change Fixed -> Static

Comment on lines +2735 to +2738
return rewriter.notifyMatchFailure(
binder.op, "Sizes for batch and channel dimensions must be "
"statically defined");
}
zjgarvey (Collaborator):
We definitely do not want to constrain this conversion to static batch and channel dims. This was the reason for needing to write asserts into the helper function getValueList in the dynamic case.

zjgarvey (Collaborator):
It might be fine to just put runtime asserts in right after the last match failure. Something like:

Value inputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstZero);
Value inputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstOne);
Value outputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstZero);
Value outputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstOne);
Value cmpDimZero = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimZero, outputDimZero);
Value cmpDimOne = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimOne, outputDimOne);
rewriter.create<Torch::RuntimeAssertOp>(loc, cmpDimZero, rewriter.getStringAttr("message"));
// same for DimOne

zjgarvey (Collaborator):
By the way, if one of the two dims has input/output sizes that are static and equal, then these asserts will fold away, so there isn't a pressing need to check again for static dims.

Comment on lines +2733 to +2734
for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) {
if (eachDimensionComparison == std::nullopt) {
zjgarvey (Collaborator):
Since you need to loop over the result of the shape comparison anyway, it would be more efficient to not define the helper function at all, and do

for (int64_t dim = 0; dim < 2; dim++) {
    if (inputSizes[dim] == Torch::kUnknownSize || outputSizes[dim] == Torch::kUnknownSize)
        continue; // you need to implement the runtime asserts, but at least still check the other dim if static.
    if (inputSizes[dim] != outputSizes[dim])
        return rewriter.notifyMatchFailure(...
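
Putting that together with the runtime asserts from the earlier comment, a fully closed version of the loop might look roughly like this (the failure message is illustrative, and inputSizes/outputSizes are assumed to be the static size arrays of the input and output tensor types):

    for (int64_t dim = 0; dim < 2; dim++) {
      if (inputSizes[dim] == Torch::kUnknownSize ||
          outputSizes[dim] == Torch::kUnknownSize)
        // Dynamic on at least one side: fall back to the runtime asserts
        // sketched above, but keep statically checking the other dim.
        continue;
      if (inputSizes[dim] != outputSizes[dim])
        return rewriter.notifyMatchFailure(
            binder.op, "static batch or channel sizes of input and output differ");
    }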
