Skip to content

Commit

Permalink
[IR] Add shift left/right operations for fixed points (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 authored Nov 3, 2024
1 parent d61322e commit 4d632c0
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 13 deletions.
22 changes: 17 additions & 5 deletions allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,15 +652,15 @@ def build_general_binop(ctx, node, lhs, rhs):
Float: RuntimeError,
Int: arith_d.ShLIOp,
UInt: arith_d.ShLIOp,
Fixed: RuntimeError,
UFixed: RuntimeError,
Fixed: allo_d.ShLFixedOp,
UFixed: allo_d.ShLFixedOp,
},
ast.RShift: {
Float: RuntimeError,
Int: arith_d.ShRSIOp,
UInt: arith_d.ShRUIOp,
Fixed: RuntimeError,
UFixed: RuntimeError,
Fixed: allo_d.ShRFixedOp,
UFixed: allo_d.ShRFixedOp,
},
ast.BitOr: {
Float: RuntimeError,
Expand All @@ -685,6 +685,12 @@ def build_general_binop(ctx, node, lhs, rhs):
},
}.get(type(node.op))
ty_cls = Int if isinstance(node.dtype, Index) else type(node.dtype)
if isinstance(node.op, (ast.LShift, ast.RShift)) and isinstance(
node.dtype, (Fixed, UFixed)
):
return opcls[ty_cls](
node.dtype.build(), lhs.result, rhs.result, ip=ctx.get_ip()
)
return opcls[ty_cls](lhs.result, rhs.result, ip=ctx.get_ip())

@staticmethod
Expand Down Expand Up @@ -730,8 +736,14 @@ def build_BinOp(ctx, node):
lhs = ASTTransformer.build_cast_op(
ctx, lhs, node.left.dtype, node.dtype, node.left.shape
)
if isinstance(node.op, (ast.LShift, ast.RShift)) and isinstance(
node.dtype, (Fixed, UFixed)
):
target_rhs_type = Int(32)
else:
target_rhs_type = node.dtype
rhs = ASTTransformer.build_cast_op(
ctx, rhs, node.right.dtype, node.dtype, node.right.shape
ctx, rhs, node.right.dtype, target_rhs_type, node.right.shape
)
lhs = ASTTransformer.build_broadcast_op(
ctx, lhs, node.dtype, node.left.shape, node.shape, node.dims[0]
Expand Down
13 changes: 12 additions & 1 deletion allo/ir/typing_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,13 +615,24 @@ def shift_rule():
(Int, Index): lambda t1, t2: t1,
}
uint_rules = {
(UInt, Int): lambda t1, t2: t1,
(UInt, UInt): lambda t1, t2: t1,
(UInt, Index): lambda t1, t2: t1,
}
index_rules = {
(Index, Index): lambda t1, t2: Index(),
}
return TypingRule([int_rules, uint_rules, index_rules], commutative=True)
fixed_rules = {
(Fixed, Int): lambda t1, t2: t1,
(Fixed, UInt): lambda t1, t2: t1,
(Fixed, Index): lambda t1, t2: t1,
}
ufixed_rules = {
(UFixed, Int): lambda t1, t2: t1,
(UFixed, UInt): lambda t1, t2: t1,
(UFixed, Index): lambda t1, t2: t1,
}
return TypingRule([int_rules, uint_rules, index_rules, fixed_rules, ufixed_rules])


def and_or_rule():
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/allo/Dialect/AlloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,30 @@ def DivFixedOp : FixedBinaryOp<"div_fixed"> {
let summary = "fixed point division operation";
}

def ShLFixedOp : Allo_Op<"shl_fixed"> {
let summary = "fixed point shift left operation";
let arguments = (ins
FixedLike:$lhs,
SignlessIntegerLike:$rhs
);
let results = (outs FixedLike:$result);
let assemblyFormat = [{
`(` $lhs `,` $rhs `)` `:` attr-dict `(` type($lhs) `,` type($rhs) `)` `->` type($result)
}];
}

def ShRFixedOp : Allo_Op<"shr_fixed"> {
let summary = "fixed point shift right operation";
let arguments = (ins
FixedLike:$lhs,
SignlessIntegerLike:$rhs
);
let results = (outs FixedLike:$result);
let assemblyFormat = [{
`(` $lhs `,` $rhs `)` `:` attr-dict `(` type($lhs) `,` type($rhs) `)` `->` type($result)
}];
}

def CmpFixedOp : Allo_Op<"cmp_fixed", [NoMemoryEffect, SameTypeOperands, TypesMatchWith<
"result type has i1 element type and same shape as operands",
"lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
Expand Down
13 changes: 9 additions & 4 deletions mlir/include/allo/Dialect/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,21 @@ class HLSCppVisitorBase {
arith::MaximumFOp, arith::MinimumFOp,
// Logical expressions.
arith::XOrIOp, arith::AndIOp, arith::OrIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, allo::GetIntBitOp, allo::SetIntBitOp,
allo::GetIntSliceOp, allo::SetIntSliceOp, allo::BitReverseOp,
arith::ShRSIOp, arith::ShRUIOp, allo::GetIntBitOp,
allo::SetIntBitOp, allo::GetIntSliceOp, allo::SetIntSliceOp,
allo::BitReverseOp,
// Special operations.
func::CallOp, func::ReturnOp, arith::SelectOp, arith::ConstantOp,
arith::TruncIOp, arith::TruncFOp, arith::ExtUIOp, arith::ExtSIOp,
arith::ExtFOp, arith::IndexCastOp, arith::UIToFPOp, arith::SIToFPOp,
arith::FPToSIOp, arith::FPToUIOp, arith::BitcastOp,
allo::FixedToFloatOp, allo::FloatToFixedOp, allo::IntToFixedOp,
allo::FixedToIntOp, allo::FixedToFixedOp, UnrealizedConversionCastOp,
allo::FixedToIntOp, allo::FixedToFixedOp,
UnrealizedConversionCastOp,
// Allo operations.
allo::CreateLoopHandleOp, allo::CreateOpHandleOp, allo::AddFixedOp,
allo::SubFixedOp, allo::MulFixedOp, allo::DivFixedOp, allo::CmpFixedOp,
allo::SubFixedOp, allo::MulFixedOp, allo::DivFixedOp,
allo::CmpFixedOp, allo::ShLFixedOp, allo::ShRFixedOp,
allo::MinFixedOp, allo::MaxFixedOp, allo::PrintOp>(
[&](auto opNode) -> ResultType {
return thisCast->visitOp(opNode, args...);
Expand Down Expand Up @@ -239,6 +242,8 @@ class HLSCppVisitorBase {
HANDLE(allo::MulFixedOp);
HANDLE(allo::DivFixedOp);
HANDLE(allo::CmpFixedOp);
HANDLE(allo::ShLFixedOp);
HANDLE(allo::ShRFixedOp);
HANDLE(allo::MinFixedOp);
HANDLE(allo::MaxFixedOp);

Expand Down
41 changes: 39 additions & 2 deletions mlir/lib/Conversion/FixedPointToInteger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,39 @@ void lowerFixedDiv(DivFixedOp &op) {
}
}

// Lower ShLFixedOp to ShLIOp
// https://docs.amd.com/r/en-US/ug1399-vitis-hls/Class-Methods-Operators-and-Data-Members
void lowerFixedShL(ShLFixedOp &op) {
OpBuilder rewriter(op);
Type t = op->getOperand(0).getType();
FixedTypeInfo ti = getFixedPointInfo(t);
auto lhs = op->getOperand(0);
int sh_width = op->getOperand(1).getType().cast<IntegerType>().getWidth();
Value rhs =
castIntegerWidth(op->getContext(), rewriter, op->getLoc(),
op->getOperand(1), sh_width, ti.width, ti.isSigned);
Type newType = IntegerType::get(op.getContext(), ti.width);
arith::ShLIOp newOp =
rewriter.create<arith::ShLIOp>(op->getLoc(), newType, lhs, rhs);
op->replaceAllUsesWith(newOp);
}

// Lower ShRFixedOp to ShRIOp
void lowerFixedShR(ShRFixedOp &op) {
OpBuilder rewriter(op);
Type t = op->getOperand(0).getType();
FixedTypeInfo ti = getFixedPointInfo(t);
auto lhs = op->getOperand(0);
int sh_width = op->getOperand(1).getType().cast<IntegerType>().getWidth();
Value rhs =
castIntegerWidth(op->getContext(), rewriter, op->getLoc(),
op->getOperand(1), sh_width, ti.width, ti.isSigned);
Type newType = IntegerType::get(op.getContext(), ti.width);
arith::ShRSIOp newOp =
rewriter.create<arith::ShRSIOp>(op->getLoc(), newType, lhs, rhs);
op->replaceAllUsesWith(newOp);
}

// Lower CmpFixedOp to CmpIOp
void lowerFixedCmp(CmpFixedOp &op) {
OpBuilder rewriter(op);
Expand Down Expand Up @@ -817,6 +850,10 @@ void visitOperation(Operation &op) {
lowerFixedDiv(new_op);
} else if (auto new_op = dyn_cast<CmpFixedOp>(op)) {
lowerFixedCmp(new_op);
} else if (auto new_op = dyn_cast<ShLFixedOp>(op)) {
lowerFixedShL(new_op);
} else if (auto new_op = dyn_cast<ShRFixedOp>(op)) {
lowerFixedShR(new_op);
} else if (auto new_op = dyn_cast<MinFixedOp>(op)) {
lowerFixedMin(new_op);
} else if (auto new_op = dyn_cast<MaxFixedOp>(op)) {
Expand Down Expand Up @@ -852,8 +889,8 @@ void visitBlock(Block &block) {
Operation &op = *it;
visitOperation(op);
if (llvm::isa<AddFixedOp, SubFixedOp, MulFixedOp, DivFixedOp, CmpFixedOp,
MinFixedOp, MaxFixedOp, IntToFixedOp, FixedToIntOp,
FloatToFixedOp, FixedToFloatOp, FixedToFixedOp,
ShLFixedOp, ShRFixedOp, MinFixedOp, MaxFixedOp, IntToFixedOp,
FixedToIntOp, FloatToFixedOp, FixedToFloatOp, FixedToFixedOp,
GetGlobalFixedOp>(op)) {
opToRemove.push_back(&op);
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Translation/EmitVivadoHLS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,12 @@ class ExprVisitor : public HLSCppVisitorBase<ExprVisitor, bool> {
return emitter.emitBinary(op, "/"), true;
}
bool visitOp(allo::CmpFixedOp op);
bool visitOp(allo::ShLFixedOp op) {
return emitter.emitBinary(op, "<<"), true;
}
bool visitOp(allo::ShRFixedOp op) {
return emitter.emitBinary(op, ">>"), true;
}
bool visitOp(allo::MinFixedOp op) {
return emitter.emitMaxMin(op, "min"), true;
}
Expand Down
19 changes: 18 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,23 @@ def kernel(A: Ty) -> int32:
# FIXME: FixedType kernels cannot be lowered


def test_fixed_shift():
# Example from https://docs.amd.com/r/en-US/ug1399-vitis-hls/Class-Methods-Operators-and-Data-Members
def kernel(A: Fixed(8, 3)[2]) -> float32[2]:
shl: Fixed(25, 10) = A[0] << 2
shr: Fixed(25, 10) = A[1] >> 2
res: float32[2] = 0.0
res[0] = float(shl)
res[1] = float(shr)
return res

s = allo.customize(kernel)
print(s.module)
mod = s.build()
A = np.array([5.375, 5.375], dtype=np.float32)
np.testing.assert_allclose(mod(A), [-10.5, 1.25], rtol=1e-5)


def test_dynamic_type():
def kernel[Ty]() -> int32:
A: int32 = Ty.bits
Expand Down Expand Up @@ -484,7 +501,7 @@ def kernel(A: bool[16]) -> bool[16]:
s = allo.customize(kernel)
print(s.module)
mod = s.build()
np_A = np.random.randint(0, 2, size=(16)).astype(np.bool)
np_A = np.random.randint(0, 2, size=(16)).astype(np.bool_)
np_B = mod(np_A)
np.testing.assert_array_equal(np_A, np_B)

Expand Down

0 comments on commit 4d632c0

Please sign in to comment.