diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 02777923303a..8f9c592a411f 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -49,7 +49,8 @@ def kernel_dot(Z): def test_compile_in_forked_subproc(fresh_triton_cache) -> None: config = AttrsDescriptor.from_hints({0: 16}) - assert multiprocessing.get_start_method() == 'fork' + # This can be either fork or spawn, depending on the platform. + assert multiprocessing.get_start_method() in ["fork", "spawn"] proc = multiprocessing.Process(target=compile_fn_dot, args=(config, )) proc.start() proc.join() @@ -92,7 +93,9 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None: # stage 2.p shutil.rmtree(fresh_triton_cache) - assert multiprocessing.get_start_method() == 'fork' + # This can be either fork or spawn, depending on the platform. + assert multiprocessing.get_start_method() in ["fork", "spawn"] + proc = multiprocessing.Process(target=compile_empty_kernel_with_gc, args=(config, )) # stage 3.c