Enable a few more core tests for CPU. (#31)
* Enable test_enable_fp_fusion for CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

* Enable test_optimize_thread_locality for CPU.

Signed-off-by: Ilya Enkovich <[email protected]>

---------

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored Jun 20, 2024
1 parent 9538724 commit f9378e3
Showing 1 changed file with 9 additions and 5 deletions.
python/test/unit/language/test_core.py (9 additions, 5 deletions)
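For reference, assuming the cpu marker is wired into the project's pytest setup, the two tests enabled by this commit could be selected programmatically along these lines (paths and test names come from the diff; the invocation itself is illustrative):

import pytest

# Illustrative invocation: select only the CPU-marked tests touched by this commit.
pytest.main([
    "-m", "cpu",
    "-k", "test_enable_fp_fusion or test_optimize_thread_locality",
    "python/test/unit/language/test_core.py",
])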
@@ -2468,6 +2468,7 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
     assert (z_torch == z).all()


+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("op", ['sum', 'max', 'min'])
 @pytest.mark.parametrize("BLOCK_N", [32, 64, 128])
@@ -2513,7 +2514,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
     x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device)
     y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device)
     h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n)
-    if not is_interpreter():
+    if not is_interpreter() and not is_cpu():
         assert h.asm['ttgir'].count(
             '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work"
     y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True)
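The is_cpu() guard used here is not shown in the diff; presumably it mirrors the existing is_cuda()/is_interpreter() helpers in test_core.py. A minimal sketch, assuming the active driver exposes the target backend name the way the CUDA/HIP helpers rely on:

import triton

def is_cpu():
    # Hypothetical helper: true when the active Triton driver targets the CPU backend.
    # Assumes get_current_target() reports a 'backend' field, as it does for cuda/hip.
    target = triton.runtime.driver.active.get_current_target()
    return target is not None and target.backend == "cpu"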
@@ -5233,6 +5234,7 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device):
 # -----------------------


+@pytest.mark.cpu
 @pytest.mark.parametrize("enable_fp_fusion", [False, True])
 def test_enable_fp_fusion(enable_fp_fusion, device):
     if is_hip():
@@ -5249,10 +5251,12 @@ def mul_add(data):
     data = torch.randn((128, ), device=device, dtype=torch.float32)
     h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion)

-    if not is_cuda():
-        return
-    found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None
-    assert found_fma == enable_fp_fusion
+    if is_cuda():
+        found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None
+        assert found_fma == enable_fp_fusion
+    elif is_cpu():
+        found_fma = re.search(r'vfma', h.asm["asm"].decode('utf-8')) is not None
+        assert found_fma == enable_fp_fusion


 # -----------------------
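The hunk above elides the body of the mul_add kernel it compiles. A sketch of what such a kernel typically looks like for this test, so the FMA check reads in context (the exact body in test_core.py may differ):

import triton
import triton.language as tl

@triton.jit
def mul_add(data):
    # Multiply-add over 128 contiguous floats: the `a * b + c` pattern that the
    # backend may (or may not) contract into fused multiply-add instructions.
    ptrs = data + tl.arange(0, 128)
    tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0)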
