Fix nan issue for fp16 torch.randn/randn_like in ConvertAtenUniformOp (#3184)

For ops that lower through ConvertAtenUniformOp (e.g. torch.randn/randn_like), the fp16 datatype returns NaN values. Lowering [this repro](https://gist.github.com/aviator19941/1c65e658241dea6906ca423f9abaee69) results in NaNs; this PR fixes the issue.
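
A hedged sketch of the failure mode follows (a numpy stand-in, not torch-mlir's actual payload; `rand_bits`, `lo`, `hi`, and `scale` are illustrative values). The uniform payload turns a 64-bit random integer into a float via `uitofp(randomVal) * scale + min`. When every step runs in fp16, the integer overflows to inf and the scale underflows to 0, so the product is NaN; computing in f64 and truncating only the final result, as this patch does, stays finite.

```python
import numpy as np

# Hedged numpy stand-in for the linalg.generic payload (not the actual
# torch-mlir RNG): res = uitofp(randomVal) * scale + min
rand_bits = np.uint64(0xDEADBEEFCAFEBABE)  # illustrative 64-bit random value
lo, hi = 0.0, 1.0                          # assumed uniform range
scale = (hi - lo) / 2.0**64

# Before this patch: every step runs in the element type. In fp16 the
# integer overflows to inf (fp16 max is 65504) and the scale underflows
# to 0, so inf * 0 yields NaN.
bad = np.float16(np.float64(rand_bits)) * np.float16(scale) + np.float16(lo)

# After this patch: compute in f64 and truncate to fp16 only at the end.
good = np.float16(np.float64(rand_bits) * scale + lo)

print(bad, good)  # nan vs. a finite value in [0, 1)
```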
aviator19941 authored Apr 24, 2024
1 parent fab2696 commit 678c03b
Showing 1 changed file with 8 additions and 4 deletions.
lib/Conversion/TorchToLinalg/Random.cpp (8 additions, 4 deletions)

@@ -129,6 +129,7 @@ class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
     Value generator = adaptor.getGenerator();
     RankedTensorType resultType = self.getType().cast<RankedTensorType>();
     Type elemTy = resultType.getElementType();
+    Type f64Ty = rewriter.getF64Type();
 
     if (!isa<mlir::FloatType>(elemTy))
       return rewriter.notifyMatchFailure(op, "This op only support float type");
@@ -139,8 +140,8 @@ class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
                                        "generator is supported");
     // Get key, min and max used by `linalg.generic` compute payload.
     Value key = rewriter.create<TorchConversion::GetNextSeedOp>(loc);
-    Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
-    Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
+    Value min = convertScalarToDtype(rewriter, loc, from, f64Ty);
+    Value max = convertScalarToDtype(rewriter, loc, to, f64Ty);
 
     // Construct the `linalg.generic` op.
     auto resultRank = resultType.getRank();
@@ -179,11 +180,14 @@ class ConvertAtenUniformOp : public OpConversionPattern<AtenUniformOp> {
 
                   // res = cast(F64, tempN) * scale + min
                   Value updateFloat =
-                      b.create<arith::UIToFPOp>(loc, elemTy, randomVal);
+                      b.create<arith::UIToFPOp>(loc, f64Ty, randomVal);
                   Value updateScaled =
                       b.create<arith::MulFOp>(loc, updateFloat, scale);
                   Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
-                  b.create<linalg::YieldOp>(loc, res);
+                  Value truncRes = res;
+                  if (elemTy.isa<Float16Type, Float32Type>())
+                    truncRes = b.create<arith::TruncFOp>(loc, elemTy, res);
+                  b.create<linalg::YieldOp>(loc, truncRes);
                 })
                 .getResult(0);
 
