Skip to content

Commit

Permalink
atomic_rmw ops should return original value (#95)
Browse files Browse the repository at this point in the history
Previously, the result of each `vector.insert` was discarded, so the lowered op returned the all-zeros initial vector instead of the original values.
  • Loading branch information
int3 authored Aug 7, 2024
1 parent cb4f717 commit 49b6bd4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
23 changes: 15 additions & 8 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,26 +1501,33 @@ def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
# triton kernel

@triton.jit
def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
z = tl.sum(x, axis=AXIS)
if AXIS == 1:
tl.atomic_add(Z + off0, z)
old = tl.atomic_add(Z + off0, z)
tl.store(OLD + off0, old)
else:
tl.atomic_add(Z + off1, z)
old = tl.atomic_add(Z + off1, z)
tl.store(OLD + off1, old)

rs = RandomState(17)
x = numpy_random((shape0, shape1), dtype_str=dtype_x_str, rs=rs)
# reference result
z_ref = np.sum(x, axis=axis, keepdims=False)
z_shape = (shape0, ) if axis == 1 else (shape1, )
z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
# reference results
z_ref = z + np.sum(x, axis=axis, keepdims=False)
old_ref = np.copy(z)
# triton result
x_tri = to_triton(x, device=device)
z_shape = (shape0, ) if axis == 1 else (shape1, )
z_tri = to_triton(np.zeros(z_shape, dtype=getattr(np, dtype_x_str)), device=device)
kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas)
z_tri = to_triton(z, device=device)
old_tri = to_triton(old, device=device)
kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
np.testing.assert_equal(old_ref, to_numpy(old_tri))


@pytest.mark.cpu
Expand Down
4 changes: 2 additions & 2 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct AtomicRMWOpConversion : public OpConversionPattern<triton::AtomicRMWOp> {
auto ptrTy = cast<RankedTensorType>(op.getPtr().getType()).getElementType();
auto vecTy = cast<VectorType>(vals.getType());
auto strides = computeStrides(vecTy.getShape());
auto res =
Value res =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));
int64_t numElems = vecTy.getNumElements();
for (int64_t idx = 0; idx < numElems; ++idx) {
Expand All @@ -97,7 +97,7 @@ struct AtomicRMWOpConversion : public OpConversionPattern<triton::AtomicRMWOp> {

// Elements with const false mask are skipped.
if (resElem) {
rewriter.create<vector::InsertOp>(loc, resElem, res, indices);
res = rewriter.create<vector::InsertOp>(loc, resElem, res, indices);
}
}

Expand Down

0 comments on commit 49b6bd4

Please sign in to comment.