
Compute a scalar pointer for vector load instead of extracting it from a tensor (#92)

* Compute a scalar pointer for vector load instead of extracting it from a tensor.

Signed-off-by: Ilya Enkovich <[email protected]>

* Add lit test for scalar ptr usage.

Signed-off-by: Ilya Enkovich <[email protected]>

---------

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored Aug 7, 2024
1 parent 49b6bd4 commit c4cac2e
Showing 4 changed files with 119 additions and 2 deletions.
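The idea behind the patch, in isolation: rather than materializing a whole vector of pointers and extracting one lane from it, walk the ops that defined the pointer tensor and rebuild the lane's value as a scalar expression. The toy C++ sketch below mirrors the canComputeScalarValue/computeScalarValue pair added to ConvertMemoryOps.cpp further down; it is an editor's illustration under invented assumptions (the Node type and the integer-address modeling are not part of the patch), not the patch itself.

// A minimal, self-contained sketch of the rewrite's data-flow walk.
// Node and the integer-address modeling are invented for this example.
#include <cassert>
#include <cstdint>
#include <memory>

struct Node {
  enum Kind { Splat, MakeRange, Add } kind;
  int64_t scalar = 0;             // Splat: the broadcast value.
  int64_t start = 0;              // MakeRange: the range start.
  std::shared_ptr<Node> lhs, rhs; // Add: the two lane-wise operands.
};

// Mirrors canComputeScalarValue: the walk succeeds only if every leaf is a
// splat or a range and every interior node is an addition.
bool canComputeScalarValue(const std::shared_ptr<Node> &n) {
  switch (n->kind) {
  case Node::Splat:
  case Node::MakeRange:
    return true;
  case Node::Add:
    return canComputeScalarValue(n->lhs) && canComputeScalarValue(n->rhs);
  }
  return false;
}

// Mirrors computeScalarValue: rebuild lane `index` without ever
// materializing the whole tensor.
int64_t computeScalarValue(const std::shared_ptr<Node> &n, int64_t index) {
  switch (n->kind) {
  case Node::Splat:
    return n->scalar;        // Every lane holds the broadcast scalar.
  case Node::MakeRange:
    return n->start + index; // Lane i of make_range is start + i.
  case Node::Add:
    return computeScalarValue(n->lhs, index) +
           computeScalarValue(n->rhs, index);
  }
  return 0; // Unreachable: all kinds are handled above.
}

int main() {
  // Model tt.addptr(tt.splat %base, tt.make_range 0..128), treating the
  // base pointer as a plain integer address (0x1000) for the example.
  auto base = std::make_shared<Node>(Node{Node::Splat, /*scalar=*/0x1000});
  auto offs = std::make_shared<Node>(Node{Node::MakeRange, 0, /*start=*/0});
  auto ptrs = std::make_shared<Node>(Node{Node::Add, 0, 0, base, offs});

  assert(canComputeScalarValue(ptrs));
  assert(computeScalarValue(ptrs, 0) == 0x1000); // Lane 0: base + 0.
  assert(computeScalarValue(ptrs, 5) == 0x1005); // Lane 5: base + 5.
  return 0;
}

The real pattern does the same walk over MLIR values and, as the diff below shows, falls back to a vector.extract when it meets an op it doesn't recognize.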
3 changes: 2 additions & 1 deletion .github/workflows/build-test.yml
@@ -84,7 +84,8 @@ jobs:
    python/test/unit/language/test_block_pointer.py \
    python/test/unit/language/test_conversions.py \
    python/test/unit/cpu/test_libdevice.py \
-   python/test/unit/cpu/test_libmvec.py
+   python/test/unit/cpu/test_libmvec.py \
+   python/test/unit/cpu/test_opt.py
- name: Run lit tests
run: |
37 changes: 37 additions & 0 deletions python/test/unit/cpu/test_opt.py
@@ -0,0 +1,37 @@
import os
import torch

import triton
import triton.language as tl


def is_interpreter():
return os.environ.get('TRITON_INTERPRET', '0') == '1'


def is_cpu():
return not is_interpreter() and \
triton.runtime.driver.active.get_current_target().backend == "cpu"


def is_x86():
return is_cpu() and \
triton.runtime.driver.active.get_current_target().arch == "x86_64"


def test_scalar_pointer_arith(device):

@triton.jit
def kernel(src, dst, BLOCK_SIZE: tl.constexpr):
offs = tl.arange(0, BLOCK_SIZE)
x = tl.load(src + offs)
tl.store(dst + offs, x)

src = torch.rand((128, ), dtype=torch.float32, device=device)
res = torch.empty_like(src)
meta = kernel[(1, )](src, res, BLOCK_SIZE=128)
assert (src == res).all()

# Check TTCIR doesn't have pointer extraction from a tensor.
ttcir = meta.asm["ttcir"]
assert ttcir.count("extract") == 0
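The check works because meta.asm["ttcir"] holds the kernel's TTCIR (Triton CPU IR) dump; once the scalar-pointer rewrite fires, no pointer is extracted from a tensor at that stage, so a simple substring count of "extract" is enough to catch a regression.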
23 changes: 23 additions & 0 deletions test/TritonCPU/convert-memory-ops.mlir
@@ -69,3 +69,26 @@ module {
tt.return
}
}

// -----

// Check that the pointer for a vector load/store is not extracted from a vector of pointers

// CHECK-LABEL: @scalar_ptrs
// CHECK-NOT: vector.extract {{.+}} : i64 from vector<128xi64>
// CHECK: {{.+}} = vector.load {{.+}} : memref<128xf32>, vector<128xf32>
// CHECK-NOT: vector.extract {{.+}} : i64 from vector<128xi64>
// CHECK: vector.store {{.+}}, {{.+}} : memref<128xf32>, vector<128xf32>

module {
tt.func public @scalar_ptrs(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
%3 = tt.load %2 : tensor<128x!tt.ptr<f32>>
%4 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
%5 = tt.addptr %4, %0 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
tt.store %5, %3 : tensor<128x!tt.ptr<f32>>
tt.return
}
}
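Tracing this input the way the new rewrite does (an illustration based on computeScalarValue below, not compiler output): %2 is a tt.addptr whose pointer operand %1 is a splat that folds to the scalar %arg0, and whose offset operand %0 is a make_range whose lane i folds to the constant i. Lane i's pointer is therefore rebuilt as tt.addptr %arg0, i, and the lowered vector.load needs no vector.extract, which is exactly what the CHECK-NOT lines assert.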
58 changes: 57 additions & 1 deletion third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp
@@ -48,14 +48,70 @@ struct MemoryOpConversion : public OpConversionPattern<OpT> {
Value extractScalarPointer(Location loc, Value ptrs,
ArrayRef<int64_t> indices,
ConversionPatternRewriter &rewriter) const {
// TODO: Analyze data flow and build scalar pointer computation code.
// If we build a vector of pointers and then extract a pointer from it, the
// compiler doesn't always optimize it down to a simple scalar pointer
// computation. Here we try to follow the data flow of the tensor to rebuild
// a scalar pointer, which yields more efficient code.
if (canComputeScalarValue(ptrs))
return computeScalarValue(ptrs, indices, rewriter);

// Fall back to a scalar pointer extraction from the vector.
Value ptr = rewriter.create<vector::ExtractOp>(
loc, rewriter.getRemappedValue(ptrs), indices);
auto ptrTy = dyn_cast<RankedTensorType>(ptrs.getType()).getElementType();
ptr = rewriter.create<IntToPtrOp>(loc, ptrTy, ptr);
return ptr;
}

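// Returns true if a single element of `vals` can be rebuilt as a scalar
// expression, i.e. the tensor is produced only by tt.addptr, arith.addi,
// tt.splat, and tt.make_range ops.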
bool canComputeScalarValue(Value vals) const {
if (auto def = vals.getDefiningOp<AddPtrOp>()) {
return canComputeScalarValue(def.getPtr()) &&
canComputeScalarValue(def.getOffset());
}

if (auto def = vals.getDefiningOp<arith::AddIOp>()) {
return canComputeScalarValue(def.getLhs()) &&
canComputeScalarValue(def.getRhs());
}

if (vals.getDefiningOp<SplatOp>() || vals.getDefiningOp<MakeRangeOp>()) {
return true;
}

return false;
}

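// Rebuild the tensor element at `indices` as a scalar expression by
// recursing through the same defining ops checked above.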
Value computeScalarValue(Value vals, ArrayRef<int64_t> indices,
ConversionPatternRewriter &rewriter) const {
if (auto def = vals.getDefiningOp<AddPtrOp>()) {
Value ptr = computeScalarValue(def.getPtr(), indices, rewriter);
Value offs = computeScalarValue(def.getOffset(), indices, rewriter);
return rewriter.create<AddPtrOp>(def.getLoc(), ptr.getType(), ptr, offs);
}

if (auto def = vals.getDefiningOp<arith::AddIOp>()) {
Value lhs = computeScalarValue(def.getLhs(), indices, rewriter);
Value rhs = computeScalarValue(def.getRhs(), indices, rewriter);
return rewriter.create<arith::AddIOp>(def.getLoc(), lhs.getType(), lhs,
rhs);
}

if (auto def = vals.getDefiningOp<SplatOp>()) {
return def.getSrc();
}

if (auto def = vals.getDefiningOp<MakeRangeOp>()) {
int32_t start = static_cast<int32_t>(def.getStart());
assert(indices.size() == 1);
Type elemTy = cast<RankedTensorType>(def.getType()).getElementType();
return rewriter.create<arith::ConstantOp>(
def.getLoc(), elemTy,
rewriter.getIntegerAttr(elemTy, start + indices[0]));
}

return Value();
}

Value extractMemRef(Location loc, Value ptr,
ConversionPatternRewriter &rewriter) const {
auto tensorTy = dyn_cast<RankedTensorType>(
