Skip to content

Commit

Permalink
[Keep materialization] Turn on meterialization (triton-lang#154)
Browse files Browse the repository at this point in the history
This commit modifies TypeConversion to allow conversion cast from
vector<> to tensor<> and avoid usage of option that skips
materialization.

Signed-off-by: Dmitrii Makarenko <[email protected]>
  • Loading branch information
Devjiu authored Sep 27, 2024
1 parent 1a39a50 commit 65e5849
Show file tree
Hide file tree
Showing 13 changed files with 30 additions and 34 deletions.
8 changes: 0 additions & 8 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,6 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertDebugOps();
#define GEN_PASS_REGISTRATION
#include "cpu/include/TritonToTritonCPU/Passes.h.inc"

inline LogicalResult applyPartialConversionNoBuildMaterializations(
Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config = ConversionConfig()) {
config.buildMaterializations = false;
return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
}

} // namespace cpu
} // namespace triton

Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ struct ConvertAtomicOps
patterns.add<AtomicRMWOpConversion>(typeConverter, context);
patterns.add<AtomicCASOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ struct ConvertControlFlowOps
RewritePatternSet patterns(context);
patterns.add<OpTypeConversion<scf::YieldOp>>(typeConverter, context);
patterns.add<OpTypeConversion<scf::ConditionOp>>(typeConverter, context);
if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}

Expand All @@ -180,8 +179,7 @@ struct ConvertControlFlowOps
{
RewritePatternSet patterns(context);
patterns.add<SCFIfPattern>(typeConverter, context);
if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}

Expand All @@ -197,8 +195,7 @@ struct ConvertControlFlowOps
RewritePatternSet patterns(context);
patterns.add<ForOpConversion>(typeConverter, context);
patterns.add<WhileOpConversion>(typeConverter, context);
if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
}
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ struct ConvertDebugOps
patterns.add<PrintOpConversion>(typeConverter, context);
patterns.add<AssertOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ struct ConvertDotOp : public triton::impl::ConvertDotOpBase<ConvertDotOp> {
RewritePatternSet patterns(context);
patterns.add<DotOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,7 @@ struct ConvertElemManipOps
patterns.add<CatOpConversion>(typeConverter, context);
patterns.add<SplitOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ struct ConvertElementwiseOps
patterns.add<ClampFOpConversion>(typeConverter, context);
patterns.add<FpToFpOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ struct ConvertHistogramOp
RewritePatternSet patterns(context);
patterns.add<HistogramOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -914,8 +914,7 @@ struct ConvertMemoryOps
patterns.add<StoreOpConversion>(axisInfoAnalysis, shapeInfoAnalysis,
pointerConverter, useScalarLoops, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase<ConvertPtrOps> {
patterns.add<PtrToIntOpConversion>(typeConverter, context);
patterns.add<IntToPtrOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,7 @@ struct ConvertReductionOp
patterns.add<ReduceOpConversion>(useReductionOp, useMultiDimReductionOp,
typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
3 changes: 1 addition & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ struct ConvertScanOp : public triton::impl::ConvertScanOpBase<ConvertScanOp> {
RewritePatternSet patterns(context);
patterns.add<ScanOpConversion>(typeConverter, context);

if (failed(applyPartialConversionNoBuildMaterializations(
mod, convTarget, std::move(patterns))))
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
return signalPassFailure();
}
};
Expand Down
17 changes: 17 additions & 0 deletions third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() {
return VectorType::get(tensorTy.getShape(), elemTy);
});

addArgumentMaterialization([&](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> std::optional<Value> {
if (isa<TensorType>(type))
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
.getResult(0);
llvm::errs() << "Inputs: ";
llvm::interleaveComma(inputs, llvm::errs());
llvm::errs() << "\n";
llvm::errs() << "Type: " << type << "\n";
llvm_unreachable("Unexpected argument materizalization");
});

// Converted ops produce vectors instead of tensors. Provide conversion
// here for users.
addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs,
Expand All @@ -29,6 +42,10 @@ TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() {
if (isa<VectorType>(type))
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
.getResult(0);
llvm::errs() << "Inputs: ";
llvm::interleaveComma(inputs, llvm::errs());
llvm::errs() << "\n";
llvm::errs() << "Type: " << type << "\n";
llvm_unreachable("Unexpected target materizalization");
});
}

0 comments on commit 65e5849

Please sign in to comment.