Skip to content

Commit

Permalink
Fix isSigned and add float16 in PrintOp (triton-lang#191)
Browse files Browse the repository at this point in the history
* Fix isSigned in PrintOp

* Add float16 support for print

* Support float16 printing for old compilers
  • Loading branch information
minjang authored Dec 9, 2024
1 parent a74ee76 commit 958b889
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 10 deletions.
6 changes: 3 additions & 3 deletions third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter,

void createRuntimePrintCall(ConversionPatternRewriter &rewriter,
std::array<Value, 3> pid, StringRef prefix,
Value ptr, Type dtype, bool hex) {
Value ptr, Type dtype, bool isSigned, bool hex) {
assert(!prefix.empty());
auto loc = UnknownLoc::get(rewriter.getContext());
Value prefixValue = LLVM::addStringToModule(
Expand All @@ -205,7 +205,7 @@ void createRuntimePrintCall(ConversionPatternRewriter &rewriter,

allArgs.push_back(i32_val(dtype.getIntOrFloatBitWidth()));
allArgs.push_back(i32_val(dtype.isInteger()));
allArgs.push_back(i32_val(dtype.isSignedInteger()));
allArgs.push_back(i32_val(isSigned));
allArgs.push_back(i32_val(hex));

call(getOrAddPrintMemrefFuncDecl(rewriter), allArgs);
Expand Down Expand Up @@ -254,7 +254,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::cpu::PrintOp> {
createRuntimePrintCall(
rewriter, pid, op.getPrefix(), adaptor.getOperands()[0],
cast<UnrankedMemRefType>(op.getVal()[0].getType()).getElementType(),
op.getHex());
op.getIsSigned()[0], op.getHex());

rewriter.eraseOp(op);
return success();
Expand Down
8 changes: 5 additions & 3 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ struct PrintOpConversion : public OpConversionPattern<triton::PrintOp> {
return success();
}

for (auto operand : op.getOperands()) {
for (size_t i = 0; i < op.getNumOperands(); i++) {
Value operand = op.getOperands()[i];
auto isSigned = {op.getIsSigned()[i]};
if (!isa<RankedTensorType>(operand.getType())) {
rewriter.create<triton::cpu::PrintOp>(
loc, op.getPrefix(), op.getHex(),
rewriter.getRemappedValue(operand), false);
rewriter.getRemappedValue(operand), isSigned);
continue;
}

Expand All @@ -92,7 +94,7 @@ struct PrintOpConversion : public OpConversionPattern<triton::PrintOp> {
allocVal);

rewriter.create<triton::cpu::PrintOp>(loc, op.getPrefix(), op.getHex(),
allocUnrankedVal, false);
allocUnrankedVal, isSigned);

rewriter.create<memref::DeallocOp>(loc, allocVal);
}
Expand Down
62 changes: 58 additions & 4 deletions third_party/cpu/runtime/cpu_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include <string>
#include <vector>

#define __STDC_WANT_IEC_60559_TYPES_EXT__
#include <float.h>

#if defined(_MSC_VER)
#define EXPORT __declspec(dllexport)
#elif defined(__GNUC__)
Expand All @@ -24,6 +27,42 @@ const int MAX_FLOAT_WIDTH = 8;
const int FLOAT_PREC = 4;
const int ELEMS_PER_LINE = 8;

using FLOAT16 = struct _FLOAT16 {
#ifdef FLT16_MAX
_Float16 x;
#else
uint16_t x;
#endif

float toFloat32() const {
#ifdef FLT16_MAX
return static_cast<float>(x);
#else
// Based on https://gist.github.com/zhuker/b4bd1fb306c7b04975b712c37c4c4075
uint32_t t1;
uint32_t t2;
uint32_t t3;

t1 = x & 0x7fffu; // Non-sign bits
t2 = x & 0x8000u; // Sign bit
t3 = x & 0x7c00u; // Exponent

t1 <<= 13u; // Align mantissa on MSB
t2 <<= 16u; // Shift sign bit into position

t1 += 0x38000000; // Adjust bias

t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero

t1 |= t2; // Re-insert sign bit

float out;
*((uint32_t *)&out) = t1;
return out;
#endif
}
};

struct FormatInfo {
bool isInt;
bool isSigned;
Expand Down Expand Up @@ -91,6 +130,12 @@ std::pair<int /* numDigits */, bool /* isNegative */> computeDigitInfo(T val) {
return {digits, val < 0};
}

template <>
std::pair<int /* numDigits */, bool /* isNegative */>
computeDigitInfo<FLOAT16>(FLOAT16 val) {
return computeDigitInfo<float>(val.toFloat32());
}

template <typename T>
std::tuple<int, int, bool> computeDigitStats(const MemRefDescriptor<T> &desc) {
int maxIntDigits = 0;
Expand Down Expand Up @@ -177,6 +222,12 @@ void printFormattedElement<uint8_t>(std::stringstream &ss, uint8_t val,
printFormattedElement<uint16_t>(ss, val, formatInfo);
}

template <>
void printFormattedElement<FLOAT16>(std::stringstream &ss, FLOAT16 val,
const FormatInfo &formatInfo) {
printFormattedElement<float>(ss, val.toFloat32(), formatInfo);
}

template <typename T>
void printToStreamRecursive(const MemRefDescriptor<T> &desc,
std::stringstream &ss, const FormatInfo &formatInfo,
Expand Down Expand Up @@ -247,6 +298,10 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor,
printToStream(MemRefDescriptor<float>(rank, descriptor), ss,
partialFormat, linePrefix);
return;
case 16:
printToStream(MemRefDescriptor<FLOAT16>(rank, descriptor), ss,
partialFormat, linePrefix);
return;
default:
llvm_unreachable("Unsupported bitWidth");
}
Expand Down Expand Up @@ -325,14 +380,13 @@ EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond,
EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1,
int32_t pid2, const char *prefix,
UnrankedMemRefType memref, int32_t btw,
bool isInteger, bool isSignedInteger,
bool isInteger, bool isSigned,
bool asHex) {
std::stringstream ss;
ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix;
std::string linePrefix(ss.str().size(), ' ');

printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger,
isSignedInteger, asHex, linePrefix);
printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, isSigned,
asHex, linePrefix);
ss << "\n";
std::cout << ss.str() << std::flush;
}
Expand Down

0 comments on commit 958b889

Please sign in to comment.