From 410855269afe7989e5f9f1c4f4f94723941dbd92 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 3 Jan 2025 23:09:32 -0500 Subject: [PATCH 1/3] Add fp16 --- allo/backend/llvm.py | 11 ++++++- allo/utils.py | 7 +++++ mlir/lib/Translation/EmitTapaHLS.cpp | 4 ++- mlir/lib/Translation/EmitVivadoHLS.cpp | 6 +++- tests/test_bitop.py | 28 ++++++++++++++++-- tests/test_types.py | 40 ++++++++++++++++++++++++++ 6 files changed, 91 insertions(+), 5 deletions(-) diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py index 4cd3023f..97c0bda2 100644 --- a/allo/backend/llvm.py +++ b/allo/backend/llvm.py @@ -155,7 +155,10 @@ def __call__(self, *args): f"Input type mismatch: {target_in_type} vs f32. Please use NumPy array" " to wrap the data to avoid possible result mismatch" ).warn() - if target_in_type == "f32": + if target_in_type == "f16": + c_float_p = ctypes.c_int16 * 1 + arg = np.float16(arg).view(np.int16) + elif target_in_type == "f32": c_float_p = ctypes.c_float * 1 else: # f64 c_float_p = ctypes.c_double * 1 @@ -317,6 +320,8 @@ def __call__(self, *args): ret = struct_array_to_int_array( ret, bitwidth, result_type[0] == "i" ) + elif result_type == "f16": + ret = np.array(ret, dtype=np.int16).view(np.float16) elif result_type.startswith("fixed") or result_type.startswith( "ufixed" ): @@ -333,6 +338,8 @@ def __call__(self, *args): # INVOKE self.execution_engine.invoke(self.top_func_name, *arg_ptrs, return_ptr) ret = return_ptr[0] + if result_type == "f16": + ret = np.int16(ret).view(np.float16) else: # multiple returns, assume all memref # INVOKE self.execution_engine.invoke(self.top_func_name, return_ptr, *arg_ptrs) @@ -349,6 +356,8 @@ def __call__(self, *args): ret_i = struct_array_to_int_array( np_arr, bitwidth, res_type[0] == "i" ) + elif result_type == "f16": + ret_i = np.array(np_arr, dtype=np.int16).view(np.float16) elif res_type.startswith("fixed") or res_type.startswith("ufixed"): bitwidth, frac = get_bitwidth_and_frac_from_fixed(res_type) ret_i = 
struct_array_to_int_array( diff --git a/allo/utils.py b/allo/utils.py index dcbc3142..79a11a51 100644 --- a/allo/utils.py +++ b/allo/utils.py @@ -9,6 +9,7 @@ RankedTensorType, IntegerType, IndexType, + F16Type, F32Type, F64Type, ) @@ -18,6 +19,7 @@ np_supported_types = { + "f16": np.float16, "f32": np.float32, "f64": np.float64, "i8": np.int8, @@ -33,6 +35,9 @@ ctype_map = { + # ctypes.c_float16 does not exist + # similar implementation in _mlir/runtime/np_to_memref.py/F16 + "f16": ctypes.c_int16, "f32": ctypes.c_float, "f64": ctypes.c_double, "i8": ctypes.c_int8, @@ -152,6 +157,8 @@ def get_dtype_and_shape_from_type(dtype): return "index", tuple() if IntegerType.isinstance(dtype): return str(IntegerType(dtype)), tuple() + if F16Type.isinstance(dtype): + return str(F16Type(dtype)), tuple() if F32Type.isinstance(dtype): return str(F32Type(dtype)), tuple() if F64Type.isinstance(dtype): diff --git a/mlir/lib/Translation/EmitTapaHLS.cpp b/mlir/lib/Translation/EmitTapaHLS.cpp index 42f97a9b..3d6d1a01 100644 --- a/mlir/lib/Translation/EmitTapaHLS.cpp +++ b/mlir/lib/Translation/EmitTapaHLS.cpp @@ -36,7 +36,9 @@ static SmallString<16> getTypeName(Type valType) { valType = arrayType.getElementType(); // Handle float types. - if (valType.isa<Float32Type>()) + if (valType.isa<Float16Type>()) + return SmallString<16>("half"); + else if (valType.isa<Float32Type>()) return SmallString<16>("float"); else if (valType.isa<Float64Type>()) return SmallString<16>("double"); diff --git a/mlir/lib/Translation/EmitVivadoHLS.cpp b/mlir/lib/Translation/EmitVivadoHLS.cpp index 976cfb4a..2c7c3598 100644 --- a/mlir/lib/Translation/EmitVivadoHLS.cpp +++ b/mlir/lib/Translation/EmitVivadoHLS.cpp @@ -35,7 +35,11 @@ static SmallString<16> getTypeName(Type valType) { valType = arrayType.getElementType(); // Handle float types. 
- if (valType.isa<Float32Type>()) + if (valType.isa<Float16Type>()) + // Page 222: + // https://www.amd.com/content/dam/xilinx/support/documents/sw_manuals/xilinx2020_2/ug902-vivado-high-level-synthesis.pdf + return SmallString<16>("half"); + else if (valType.isa<Float32Type>()) return SmallString<16>("float"); else if (valType.isa<Float64Type>()) return SmallString<16>("double"); diff --git a/tests/test_bitop.py b/tests/test_bitop.py index 5fed1a9d..bc3b078a 100644 --- a/tests/test_bitop.py +++ b/tests/test_bitop.py @@ -4,7 +4,7 @@ import pytest import numpy as np import allo -from allo.ir.types import uint1, uint2, int32, uint8, uint32, UInt, float32 +from allo.ir.types import uint1, uint2, int32, uint8, uint32, UInt, float16, float32 def test_scalar(): @@ -125,7 +125,7 @@ def kernel(A: int32, B: int32[11]): assert bin(1234) == "0b" + "".join([str(np_B[i]) for i in range(10, -1, -1)]) -def test_bitcast_uint2float(): +def test_bitcast_uint2float32(): def kernel(A: uint32[10, 10]) -> float32[10, 10]: B: float32[10, 10] for i, j in allo.grid(10, 10): @@ -146,6 +146,30 @@ def kernel(A: uint32[10, 10]) -> float32[10, 10]: print("Passed!") +def test_bitcast_uint2float16(): + def kernel(A: int32[10, 10]) -> float16[10, 10]: + B: float16[10, 10] + for i, j in allo.grid(10, 10): + B[i, j] = A[i, j][0:16].bitcast() + return B + + s = allo.customize(kernel) + print(s.module) + mod = s.build() + + A_np = np.random.randint(100, size=(10, 10)).astype(np.int32) + B_np = mod(A_np) + answer = np.frombuffer(A_np.astype(np.int16).tobytes(), np.float16).reshape( + (10, 10) + ) + assert np.array_equal(B_np, answer) + + code = str(s.build(target="vhls")) + print(code) + assert "union" in code and "half" in code + print("Passed!") + + def test_bitcast_float2uint(): def kernel(A: float32[10, 10]) -> uint32[10, 10]: B: uint32[10, 10] diff --git a/tests/test_types.py b/tests/test_types.py index 0c50a3a0..9c2ea25c 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,6 +13,7 @@ bool, uint1, int32, + float16, float32, index, ) @@ 
-322,6 +323,45 @@ def kernel[Ty]() -> int32: print(s.module) +def test_fp16(): + def kernel(a: float16) -> float16: + return a + 1 + + s = allo.customize(kernel) + assert "f16" in str(s.module) + mod = s.build() + assert mod(1.0) == kernel(1.0) + + +def test_fp16_array(): + def kernel(A: float16[10]) -> float16[10]: + B: float16[10] + for i in range(10): + B[i] = A[i] + 1 + return B + + s = allo.customize(kernel) + assert "f16" in str(s.module) + mod = s.build() + A = np.random.rand(10).astype(np.float16) + B = mod(A) + np.testing.assert_allclose(B, A + 1, rtol=1e-5) + + +def test_fp16_array_inplace(): + def kernel(A: float16[10]): + for i in range(10): + A[i] += 1 + + s = allo.customize(kernel) + assert "f16" in str(s.module) + mod = s.build() + A = np.random.rand(10).astype(np.float16) + res = A + 1 + mod(A) + np.testing.assert_allclose(A, res, rtol=1e-5) + + def test_select_typing(): def kernel(flt: float32, itg: int32) -> float32: # if correctly typed, the select should have float32 result From 2435eed24f42d8ed3b30a00118b6971cf01766d3 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 3 Jan 2025 23:10:58 -0500 Subject: [PATCH 2/3] Fix pylint --- allo/backend/llvm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py index 97c0bda2..4d92c70c 100644 --- a/allo/backend/llvm.py +++ b/allo/backend/llvm.py @@ -1,6 +1,6 @@ # Copyright Allo authors. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=no-name-in-module, inconsistent-return-statements +# pylint: disable=no-name-in-module, inconsistent-return-statements, too-many-function-args import os import ctypes From 89ac26176469ab8c065062379978677306562c42 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 3 Jan 2025 23:26:53 -0500 Subject: [PATCH 3/3] Fix pytest --- allo/backend/llvm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py index 4d92c70c..54c9b721 100644 --- a/allo/backend/llvm.py +++ b/allo/backend/llvm.py @@ -356,7 +356,7 @@ def __call__(self, *args): ret_i = struct_array_to_int_array( np_arr, bitwidth, res_type[0] == "i" ) - elif result_type == "f16": + elif res_type == "f16": ret_i = np.array(np_arr, dtype=np.int16).view(np.float16) elif res_type.startswith("fixed") or res_type.startswith("ufixed"): bitwidth, frac = get_bitwidth_and_frac_from_fixed(res_type)