fix(ONNX): avoids resizing fixed dimensions #3945
base: main
Conversation
Force-pushed bb3f80f to 6baa8d5
Force-pushed 6baa8d5 to ab7e021
- "result" -> "outputTensor" - "type" -> more like "blueprint" since it includes shape and element data type
Force-pushed ab7e021 to 7aec80b
I think the main structural question is about the need for adding the BaseTensorType method. If it were useful elsewhere (I have some doubts, since we would need to know too much about the two tensor shapes before using it, namely that they are both present and have the same rank), I would consider keeping it; however, the code is simplified here by not using it, and I suspect the same would be true in other places where it might be used.
auto this_dimensions = /**/ getSizes();
auto that_dimensions = that.getSizes();
BaseTensorType might not have sizes, and this will cause a crash when called. I would do:
- auto this_dimensions = /**/ getSizes();
- auto that_dimensions = that.getSizes();
+ auto selfSizes = getOptionalSizes();
+ auto otherSizes = other.getOptionalSizes();
Note also the variable naming and camel casing conventions. The variables self and other are used more typically than this and that in this codebase (easier to distinguish).
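For illustration, a minimal sketch of what the guarded comparison could look like inline, without the BaseTensorType helper (this is not the PR's code; it assumes getOptionalSizes() returns std::optional<ArrayRef<int64_t>> and that kUnknownSize marks dynamic dimensions):

// Sketch only, not the PR's code. Guard against missing sizes before
// comparing dimensions; dynamic dims compare to std::nullopt.
auto selfSizes = getOptionalSizes();
auto otherSizes = other.getOptionalSizes();
if (!selfSizes || !otherSizes || selfSizes->size() != otherSizes->size())
  return {}; // unranked or rank mismatch: nothing meaningful to compare
SmallVector<std::optional<bool>> comparison;
for (auto [selfDim, otherDim] : llvm::zip(*selfSizes, *otherSizes)) {
  if (selfDim == kUnknownSize || otherDim == kUnknownSize)
    comparison.push_back(std::nullopt); // unknown until runtime
  else
    comparison.push_back(selfDim == otherDim);
}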
@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
   /// Enable isa/dyn_cast for BaseTensorType.
   static bool classof(Type type);

+  /// The element-wise comparison of each dimension/size in `that` tensor
+  std::vector<std::optional<bool>>
Use SmallVector instead of std::vector. The methods are the same, and it is better for small containers like this.
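As a sketch of that suggestion (the method name below is hypothetical, since the real declaration is cut off in this hunk; only the return type change is the point):

/// The element-wise comparison of each dimension/size in `that` tensor
SmallVector<std::optional<bool>>
compareDimensionsWith(BaseTensorType that) const; // hypothetical name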
@@ -2686,12 +2686,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
       });
   patterns.onOp(
       "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
+        Torch::ValueTensorType outputTensor_blueprint;
I don't understand the renaming of this variable. This isn't a blueprint; it's the result type.
            outputTensor_blueprint);

        // Comparisons of the dimensions assumed to carry the batch and channel
        auto shapeComparisonForFixedDimensions =
nit: change Fixed -> Static
          return rewriter.notifyMatchFailure(
              binder.op, "Sizes for batch and channel dimensions must be "
                         "statically defined");
        }
We definitely do not want to constrain this conversion to static batch and channel dims. This was the reason for needing to write asserts into the helper function getValueList in the dynamic case.
It might be fine to just put runtime asserts in right after the last match failure. Something like:

Value inputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstZero);
Value inputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, input, cstOne);
Value outputDimZero = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstZero);
Value outputDimOne = rewriter.create<Torch::AtenSizeIntOp>(loc, output, cstOne);
Value cmpDimZero = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimZero, outputDimZero);
Value cmpDimOne = rewriter.create<Torch::AtenEqIntOp>(loc, inputDimOne, outputDimOne);
rewriter.create<Torch::RuntimeAssertOp>(loc, cmpDimZero, rewriter.getStringAttr("message"));
rewriter.create<Torch::RuntimeAssertOp>(loc, cmpDimOne, rewriter.getStringAttr("message"));
By the way, if one of the two dims has input/output sizes that are static and equal, then the corresponding assert will fold away, so there isn't a pressing need to check separately for static dims.
        for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) {
          if (eachDimensionComparison == std::nullopt) {
Since you need to loop over the result of the shape comparison anyway, it would be more efficient to not define the helper function at all, and do

for (int64_t dim = 0; dim < 2; dim++) {
  if (inputSizes[dim] == Torch::kUnknownSize ||
      outputSizes[dim] == Torch::kUnknownSize)
    continue; // you need to implement the runtime asserts, but at least still check the other dim if static.
  if (inputSizes[dim] != outputSizes[dim])
    return rewriter.notifyMatchFailure(
        binder.op, "static batch/channel dimension mismatch");
}
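Putting the two suggestions together, a minimal sketch of the full check might look like this (not the PR's actual code; names such as input, output, loc, inputSizes, and outputSizes are assumed to already be in scope in the surrounding conversion):

// Sketch only: reject static batch/channel mismatches at conversion time
// and fall back to runtime asserts when either side is dynamic.
for (int64_t dim = 0; dim < 2; dim++) {
  if (inputSizes[dim] == Torch::kUnknownSize ||
      outputSizes[dim] == Torch::kUnknownSize) {
    // At least one side is dynamic: defer the check to runtime.
    Value dimIndex = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(dim));
    Value inputDim =
        rewriter.create<Torch::AtenSizeIntOp>(loc, input, dimIndex);
    Value outputDim =
        rewriter.create<Torch::AtenSizeIntOp>(loc, output, dimIndex);
    Value dimsEqual =
        rewriter.create<Torch::AtenEqIntOp>(loc, inputDim, outputDim);
    rewriter.create<Torch::RuntimeAssertOp>(
        loc, dimsEqual,
        rewriter.getStringAttr(
            "Resize: batch and channel dimensions must not change"));
    continue;
  }
  // Both sides static: a mismatch means this pattern should not apply.
  if (inputSizes[dim] != outputSizes[dim])
    return rewriter.notifyMatchFailure(
        binder.op, "static batch/channel dimension mismatch");
}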