[Pytests] Add several suites #106

Merged 2 commits on Aug 14, 2024
15 changes: 13 additions & 2 deletions .github/workflows/build-test.yml
@@ -83,10 +83,21 @@ jobs:
python/test/unit/language/test_annotations.py \
python/test/unit/language/test_block_pointer.py \
python/test/unit/language/test_conversions.py \
python/test/unit/language/test_compile_errors.py \
python/test/unit/language/test_decorator.py \
python/test/unit/language/test_pipeliner.py \
python/test/unit/language/test_random.py \
python/test/unit/language/test_standard.py \
python/test/unit/runtime/test_bindings.py \
python/test/unit/runtime/test_driver.py \
python/test/unit/runtime/test_jit.py \
python/test/unit/runtime/test_launch.py \
python/test/unit/runtime/test_subproc.py \
python/test/unit/runtime/test_autotuner.py \
python/test/unit/runtime/test_cache.py \
python/test/unit/cpu/test_libdevice.py \
python/test/unit/cpu/test_libmvec.py \
python/test/unit/cpu/test_opt.py \
python/test/unit/runtime/test_autotuner.py
python/test/unit/cpu/test_opt.py

- name: Run lit tests
run: |
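The workflow hunk above folds the new language, runtime, and CPU suites into the existing CI pytest invocation; the two repeated paths at the end appear to be the 2 deleted lines of the old list. For local iteration, a minimal sketch of an equivalent runner (the file subset and flags are illustrative, not taken from the workflow):

```python
# Hypothetical local runner mirroring part of the CI suite list above.
import sys

import pytest

SUITES = [
    "python/test/unit/language/test_pipeliner.py",
    "python/test/unit/runtime/test_cache.py",
    "python/test/unit/cpu/test_libdevice.py",
]

if __name__ == "__main__":
    # pytest.main accepts the same arguments as the CLI and returns an exit code.
    sys.exit(pytest.main(["-s", *SUITES]))
```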
6 changes: 5 additions & 1 deletion python/test/unit/language/test_pipeliner.py
@@ -5,6 +5,7 @@
import triton
import triton.language as tl
import triton.tools.experimental_descriptor
from test_core import is_cpu


def is_cuda():
@@ -127,7 +128,10 @@ def test_pipeline_matmul(device):
handler = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
NUM_STAGES=NUM_STAGES)
ref_out = torch.matmul(a, b)
if is_cpu():
    ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16)
else:
    ref_out = torch.matmul(a, b)
atol = 1e-2 if is_hip_mi200() else None
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
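The pipeliner change computes the fp16 reference on CPU through an fp32 upcast, presumably because eager-mode fp16 matmul support and accumulation on CPU differ from CUDA. A standalone sketch of the same pattern (the helper name is hypothetical, not part of the PR):

```python
import torch

def reference_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # On CPU, compute the half-precision reference in fp32 and cast back,
    # mirroring the is_cpu() branch added above.
    if a.device.type == "cpu" and a.dtype in (torch.float16, torch.bfloat16):
        return torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(a.dtype)
    return torch.matmul(a, b)
```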
8 changes: 5 additions & 3 deletions python/test/unit/language/test_standard.py
@@ -3,7 +3,7 @@
import torch
import triton.language as tl

from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random
from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random, is_cpu

# ---------------
# test maximum/minimum ops
@@ -26,7 +26,8 @@ def test_maximum_minium(dtype, op, device):


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize(
    "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_sort(M, N, descending, dtype_str, device):
@@ -54,7 +55,8 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize(
    "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_flip(M, N, dtype_str, device):

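Both parametrize changes above shrink the largest sort/flip shapes on CPU. Note that is_cpu() runs at collection time, when pytest evaluates the decorator arguments, so the shape list is fixed once per session. A hypothetical equivalent that makes the gating explicit:

```python
from test_core import is_cpu  # import added by this PR

# Shape lists copied from the diff above; the helper name is illustrative.
def sort_flip_shapes():
    if is_cpu():
        return [[1, 128], [8, 64], [64, 16], [128, 8]]
    return [[1, 512], [8, 64], [256, 16], [512, 8]]
```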
33 changes: 18 additions & 15 deletions python/test/unit/runtime/test_cache.py
@@ -8,7 +8,7 @@

import triton
import triton.language as tl
from triton.runtime.jit import JITFunction
from triton.runtime.jit import JITFunction, get_device_key


@triton.jit
@@ -193,12 +193,12 @@ def kernel(X, i: tl.int32):

x = torch.empty(1, dtype=torch.int32, device=device)

device = torch.cuda.current_device()
device_key = get_device_key()
kernel[(1, )](x, 1)
kernel[(1, )](x, 8)
kernel[(1, )](x, 16)
kernel[(1, )](x, 17)
assert len(kernel.cache[device]) == 3
assert len(kernel.cache[device_key]) == 3


GLOBAL_DEFAULT_ARG = 1
@@ -221,7 +221,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
kernel[(1, )](x)
assert x == torch.ones_like(x)

device = torch.cuda.current_device()
device = get_device_key()
assert len(kernel.cache[device]) == 1


@@ -414,7 +414,7 @@ def kernel_add(a, b, o, N: tl.constexpr):
torch.randn(32, dtype=torch.float32, device=device),
32,
]
device = torch.cuda.current_device()
device = get_device_key()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
@@ -424,25 +424,28 @@ def kernel_add(a, b, o, N: tl.constexpr):
assert len(kernel_add.cache[device]) == 1


def test_jit_debug() -> None:
def test_jit_debug(device) -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
if device == "cpu":
    pytest.skip('Device Assert is not yet supported on CPU')

device_key = get_device_key()
assert len(kernel_add.cache[device_key]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
assert len(kernel_add.cache[device_key]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 2
assert len(kernel_add.cache[device_key]) == 2
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 3
bins = list(kernel_add.cache[device].values())
assert len(kernel_add.cache[device_key]) == 3
bins = list(kernel_add.cache[device_key].values())
assert bins[2].asm['ttir'] != bins[1].asm['ttir']


@@ -452,13 +455,13 @@ def add_fn(a, b, o, N: tl.constexpr):
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))


def test_jit_noinline() -> None:
def test_jit_noinline(device) -> None:

@triton.jit
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)

device = torch.cuda.current_device()
device = get_device_key()
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
@@ -502,7 +505,7 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr):
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

device = torch.cuda.current_device()
device = get_device_key()

# get the serialized specialization data
specialization_data = None
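The test_cache.py edits all replace torch.cuda.current_device(), a bare integer, with the composite key that JITFunction.cache is now indexed by. A sketch of the resulting access pattern (the helper function name is hypothetical; the keying follows the diff):

```python
from triton.runtime.jit import get_device_key

def cached_variant_count(jit_fn) -> int:
    # JITFunction.cache is keyed by "<backend>:<device>", so CPU and CUDA
    # entries for the same integer device id land in separate buckets.
    return len(jit_fn.cache[get_device_key()])
```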
6 changes: 3 additions & 3 deletions python/test/unit/runtime/test_launch.py
@@ -43,7 +43,7 @@ def kernel(x):
assert used_hook


def test_memory_leak() -> None:
def test_memory_leak(device) -> None:

@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
@@ -57,8 +57,8 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):

tracemalloc.start()
try:
inp = torch.randn(10, device='cuda')
out = torch.randn(10, device='cuda')
inp = torch.randn(10, device=device)
out = torch.randn(10, device=device)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
begin, _ = tracemalloc.get_traced_memory()
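test_memory_leak now takes the device fixture and allocates its tensors on it instead of hardcoding 'cuda', so the same test can run under the CPU backend. A hypothetical conftest-style fixture showing how such an argument can be injected (the repository's actual fixture may differ):

```python
import pytest

@pytest.fixture
def device(request):
    # Fall back to "cuda" when no --device option is registered.
    return request.config.getoption("--device", default="cuda")
```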
12 changes: 9 additions & 3 deletions python/triton/runtime/jit.py
@@ -437,6 +437,12 @@ def create_function_from_signature(sig, kparams):
type_canonicalisation_dict[v] = v


def get_device_key():
target = driver.active.get_current_target()
device = driver.active.get_current_device()
return f"{target.backend}:{device}"


class JITFunction(KernelInterface[T]):
# Hook for inspecting compiled functions and modules
cache_hook = None
@@ -605,7 +611,7 @@ def run(self, *args, grid, warmup, **kwargs):
bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)

# compute cache key
device_key = f"{target.backend}:{device}"
device_key = get_device_key()
key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
kernel = self.cache[device_key].get(key, None)

@@ -757,7 +763,7 @@ def preload(self, specialization_data):
from ..compiler import AttrsDescriptor, compile, ASTSource
import json
import triton.language as tl
device = driver.active.get_current_device()
device_key = get_device_key()
deserialized_obj = json.loads(specialization_data)
if deserialized_obj['name'] != self.fn.__name__:
raise RuntimeError(
@@ -774,7 +780,7 @@
}
key = deserialized_obj['key']
kernel = compile(src, None, options)
self.cache[device][key] = kernel
self.cache[device_key][key] = kernel
return kernel

# we do not parse `src` in the constructor because
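get_device_key folds the active backend into the cache key, so switching backends within one process cannot alias another backend's compiled binaries for the same device index. A toy model of the behavior (the key values are illustrative; real keys come from driver.active at runtime):

```python
from collections import defaultdict

cache = defaultdict(dict)

def toy_device_key(backend: str, device: int) -> str:
    # Same shape as the key built in get_device_key above.
    return f"{backend}:{device}"

cache[toy_device_key("cuda", 0)]["kernel_a"] = "cubin"
cache[toy_device_key("cpu", 0)]["kernel_a"] = "shared object"
assert len(cache) == 2  # the two backends no longer share a bucket
```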