diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2735304fb540..de9dbc157d20 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1501,26 +1501,33 @@ def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device): # triton kernel @triton.jit - def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): + def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): off0 = tl.arange(0, SHAPE0) off1 = tl.arange(0, SHAPE1) x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :]) z = tl.sum(x, axis=AXIS) if AXIS == 1: - tl.atomic_add(Z + off0, z) + old = tl.atomic_add(Z + off0, z) + tl.store(OLD + off0, old) else: - tl.atomic_add(Z + off1, z) + old = tl.atomic_add(Z + off1, z) + tl.store(OLD + off1, old) rs = RandomState(17) x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs) - # reference result - z_ref = np.sum(x, axis=axis, keepdims=False) + z_shape = (shape0, ) if axis == 1 else (shape1, ) + z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs) + old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str)) + # reference results + z_ref = z + np.sum(x, axis=axis, keepdims=False) + old_ref = np.copy(z) # triton result x_tri = to_triton(x, device=device) - z_shape = (shape0, ) if axis == 1 else (shape1, ) - z_tri = to_triton(np.zeros(z_shape, dtype=getattr(np, dtype_x_str)), device=device) - kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + z_tri = to_triton(z, device=device) + old_tri = to_triton(old, device=device) + kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) + np.testing.assert_equal(old_ref, to_numpy(old_tri)) @pytest.mark.cpu diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp index 473e97ec8d77..bab0cd94c57e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp @@ -72,7 +72,7 @@ struct AtomicRMWOpConversion : public OpConversionPattern { auto ptrTy = cast(op.getPtr().getType()).getElementType(); auto vecTy = cast(vals.getType()); auto strides = computeStrides(vecTy.getShape()); - auto res = + Value res = rewriter.create(loc, rewriter.getZeroAttr(vecTy)); int64_t numElems = vecTy.getNumElements(); for (int64_t idx = 0; idx < numElems; ++idx) { @@ -97,7 +97,7 @@ struct AtomicRMWOpConversion : public OpConversionPattern { // Elements with const false mask are skipped. if (resElem) { - rewriter.create(loc, resElem, res, indices); + res = rewriter.create(loc, resElem, res, indices); } }