diff --git a/include/matx/core/storage.h b/include/matx/core/storage.h index a460e8c3..789a783b 100644 --- a/include/matx/core/storage.h +++ b/include/matx/core/storage.h @@ -189,7 +189,7 @@ namespace matx * * @return Size of allocation */ - size_t size() const + __MATX_INLINE__ size_t size() const { return size_; } diff --git a/include/matx/core/tensor.h b/include/matx/core/tensor.h index be5ffece..00227a65 100644 --- a/include/matx/core/tensor.h +++ b/include/matx/core/tensor.h @@ -107,7 +107,7 @@ class tensor_t : public detail::tensor_impl_t { * @param rhs Object to copy from */ __MATX_HOST__ tensor_t(tensor_t const &rhs) noexcept - : detail::tensor_impl_t{rhs.ldata_, rhs.desc_}, storage_(rhs.storage_) + : detail::tensor_impl_t{rhs.Data(), rhs.desc_}, storage_(rhs.storage_) { } /** @@ -116,7 +116,7 @@ class tensor_t : public detail::tensor_impl_t { * @param rhs Object to move from */ __MATX_HOST__ tensor_t(tensor_t &&rhs) noexcept - : detail::tensor_impl_t{rhs.ldata_, std::move(rhs.desc_)}, storage_(std::move(rhs.storage_)) + : detail::tensor_impl_t{rhs.Data(), std::move(rhs.desc_)}, storage_(std::move(rhs.storage_)) { } @@ -131,7 +131,7 @@ class tensor_t : public detail::tensor_impl_t { */ __MATX_HOST__ void Shallow(const self_type &rhs) noexcept { - this->ldata_ = rhs.ldata_; + this->SetData(rhs.Data()); storage_ = rhs.storage_; this->desc_ = rhs.desc_; } @@ -149,7 +149,9 @@ class tensor_t : public detail::tensor_impl_t { { using std::swap; - std::swap(lhs.ldata_, rhs.ldata_); + auto tmpdata = lhs.Data(); + lhs.SetData(rhs.Data()); + rhs.SetData(tmpdata); swap(lhs.storage_, rhs.storage_); swap(lhs.desc_, rhs.desc_); } @@ -651,7 +653,7 @@ class tensor_t : public detail::tensor_impl_t { // Copy descriptor and call ctor with shape Desc new_desc{std::forward(shape)}; - return tensor_t{storage_, std::move(new_desc), this->ldata_}; + return tensor_t{storage_, std::move(new_desc), this->Data()}; } /** @@ -710,7 +712,7 @@ class tensor_t : public detail::tensor_impl_t 
{ "To get a reshaped view the tensor must be compact"); DefaultDescriptor desc{std::move(tshape)}; - return tensor_t{storage_, std::move(desc), this->ldata_}; + return tensor_t{storage_, std::move(desc), this->Data()}; } /** @@ -739,7 +741,7 @@ class tensor_t : public detail::tensor_impl_t { int dev; cudaGetDevice(&dev); - cudaMemPrefetchAsync(this->ldata_, this->desc_.TotalSize() * sizeof(T), dev, stream); + cudaMemPrefetchAsync(this->Data(), this->desc_.TotalSize() * sizeof(T), dev, stream); } /** @@ -756,7 +758,7 @@ class tensor_t : public detail::tensor_impl_t { { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) - cudaMemPrefetchAsync(this->ldata_, this->desc_.TotalSize() * sizeof(T), cudaCpuDeviceId, + cudaMemPrefetchAsync(this->Data(), this->desc_.TotalSize() * sizeof(T), cudaCpuDeviceId, stream); } @@ -776,7 +778,7 @@ class tensor_t : public detail::tensor_impl_t { static_assert(is_complex_v, "RealView() only works with complex types"); using Type = typename U::value_type; - Type *data = reinterpret_cast(this->ldata_); + Type *data = reinterpret_cast(this->Data()); cuda::std::array strides; #pragma unroll @@ -821,7 +823,7 @@ class tensor_t : public detail::tensor_impl_t { static_assert(is_complex_v, "ImagView() only works with complex types"); using Type = typename U::value_type; - Type *data = reinterpret_cast(this->ldata_) + 1; + Type *data = reinterpret_cast(this->Data()) + 1; cuda::std::array strides; #pragma unroll for (int i = 0; i < RANK; i++) { @@ -859,7 +861,7 @@ class tensor_t : public detail::tensor_impl_t { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) auto new_desc = this->PermuteImpl(dims); - return tensor_t{storage_, std::move(new_desc), this->ldata_}; + return tensor_t{storage_, std::move(new_desc), this->Data()}; } @@ -904,14 +906,6 @@ class tensor_t : public detail::tensor_impl_t { return Permute(tdims); } - /** - * Get the underlying local data pointer from the view - * - * @returns Underlying data pointer of type T - * - */ - __MATX_HOST__ 
__MATX_INLINE__ T *Data() const noexcept { return this->ldata_; } - /** * Set the underlying data pointer from the view * @@ -933,7 +927,7 @@ class tensor_t : public detail::tensor_impl_t { { this->desc_.InitFromShape(std::forward(shape)); storage_.SetData(data); - this->ldata_ = data; + this->SetData(data); } /** @@ -953,7 +947,7 @@ class tensor_t : public detail::tensor_impl_t { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) storage_.SetData(data); - this->ldata_ = data; + this->SetData(data); } /** @@ -973,7 +967,7 @@ class tensor_t : public detail::tensor_impl_t { Reset(T *const data, T *const ldata) noexcept { storage_.SetData(data); - this->ldata_ = ldata; + this->SetData(ldata); } @@ -1035,7 +1029,7 @@ class tensor_t : public detail::tensor_impl_t { OverlapView(const cuda::std::array &windows, const cuda::std::array &strides) const { auto new_desc = this->template OverlapViewImpl(windows, strides); - return tensor_t{storage_, std::move(new_desc), this->ldata_}; + return tensor_t{storage_, std::move(new_desc), this->Data()}; } /** @@ -1069,7 +1063,7 @@ class tensor_t : public detail::tensor_impl_t { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) auto new_desc = this->template CloneImpl(clones); - return tensor_t{storage_, std::move(new_desc), this->ldata_}; + return tensor_t{storage_, std::move(new_desc), this->Data()}; } template @@ -1080,7 +1074,7 @@ class tensor_t : public detail::tensor_impl_t { __MATX_INLINE__ __MATX_HOST__ bool IsManagedPointer() { bool managed; - const CUresult retval = cuPointerGetAttribute(&managed, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)Data()); + const CUresult retval = cuPointerGetAttribute(&managed, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)this->Data()); MATX_ASSERT(retval == CUDA_SUCCESS, matxNotSupported); return managed; } @@ -1454,12 +1448,12 @@ class tensor_t : public detail::tensor_impl_t { int dev_ord; void *data[2] = {&mem_type, &dev_ord}; - t->data = static_cast(this->ldata_); + t->data = 
static_cast(this->Data()); t->device.device_id = 0; // Determine where this memory resides - auto kind = GetPointerKind(this->ldata_); - auto mem_res = cuPointerGetAttributes(sizeof(attr)/sizeof(attr[0]), attr, data, reinterpret_cast(this->ldata_)); + auto kind = GetPointerKind(this->Data()); + auto mem_res = cuPointerGetAttributes(sizeof(attr)/sizeof(attr[0]), attr, data, reinterpret_cast(this->Data())); MATX_ASSERT_STR_EXP(mem_res, CUDA_SUCCESS, matxCudaError, "Error returned from cuPointerGetAttributes"); if (kind == MATX_INVALID_MEMORY) { if (mem_type == CU_MEMORYTYPE_DEVICE) { diff --git a/include/matx/core/tensor_impl.h b/include/matx/core/tensor_impl.h index 6c00062b..928f632e 100644 --- a/include/matx/core/tensor_impl.h +++ b/include/matx/core/tensor_impl.h @@ -49,6 +49,22 @@ namespace matx { namespace detail { +template +struct DenseTensorData { + using dense_data = bool; + T *ldata_; +}; +template +struct SparseTensorData { + using sparse_data = bool; + using crd_type = CRD; + using pos_type = POS; + static constexpr int LVL = Lvl; + + T *ldata_; + CRD *crd_[LVL]; + POS *pos_[LVL]; +}; /** * @brief Bare implementation of tensor class @@ -66,7 +82,7 @@ namespace detail { * @tparam T Type of tensor * @tparam RANK Rank of tensor */ -template > +template , typename TensorData = DenseTensorData> class tensor_impl_t { public: // Type specifier for reflection on class @@ -75,10 +91,11 @@ class tensor_impl_t { using tensor_view = bool; using tensor_impl = bool; using desc_type = Desc; + using data_type = TensorData; using shape_type = typename Desc::shape_type; using stride_type = typename Desc::stride_type; using matxoplvalue = bool; - using self_type = tensor_impl_t; + using self_type = tensor_impl_t; // Type specifier for signaling this is a matx operation using matxop = bool; @@ -108,7 +125,7 @@ class tensor_impl_t { { using std::swap; - swap(lhs.ldata_, rhs.ldata_); + swap(lhs.data_, rhs.data_); swap(lhs.desc_, rhs.desc_); } @@ -125,10 +142,10 @@ class 
tensor_impl_t { * @param data * Data pointer */ - tensor_impl_t(T *const data) : ldata_(data) { + tensor_impl_t(T *const data) { + data_.ldata_ = data; static_assert(RANK == 0, "tensor_impl_t with single pointer parameter must be a rank 0 tensor"); } - /** * Constructor for a rank-1 and above tensor. * @@ -152,8 +169,7 @@ class tensor_impl_t { template __MATX_INLINE__ tensor_impl_t(ShapeType &&shape, StrideType &&strides) : desc_(std::forward(shape), std::forward(strides)) - { - } + {} /** * Constructor for a rank-1 and above tensor using a user pointer and shape @@ -168,8 +184,9 @@ class tensor_impl_t { */ template ::type>, bool> = true> __MATX_INLINE__ tensor_impl_t(T *const ldata, ShapeType &&shape) - : ldata_(ldata), desc_(std::forward(shape)) + : desc_(std::forward(shape)) { + data_.ldata_ = ldata; } @@ -191,8 +208,9 @@ class tensor_impl_t { __MATX_INLINE__ tensor_impl_t(T *const ldata, ShapeType &&shape, StrideType &&strides) - : ldata_(ldata), desc_(std::forward(shape), std::forward(strides)) + : desc_(std::forward(shape), std::forward(strides)) { + data_.ldata_ = ldata; } @@ -214,8 +232,9 @@ MATX_IGNORE_WARNING_PUSH_GCC("-Wmaybe-uninitialized") template ::type>, bool> = true> __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ tensor_impl_t(T *const ldata, DescriptorType &&desc) - : ldata_(ldata), desc_{std::forward(desc)} + : desc_{std::forward(desc)} { + data_.ldata_ = ldata; } MATX_IGNORE_WARNING_POP_GCC @@ -237,7 +256,7 @@ MATX_IGNORE_WARNING_POP_GCC __MATX_HOST__ void Shallow(const self_type &rhs) noexcept { - ldata_ = rhs.ldata_; + data_.ldata_ = rhs.Data(); desc_ = rhs.desc_; } @@ -253,7 +272,7 @@ MATX_IGNORE_WARNING_POP_GCC */ [[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto copy(const tensor_impl_t &op) { - ldata_ = op.ldata_; + data_.ldata_ = op.Data(); desc_ = op.desc_; } @@ -639,7 +658,7 @@ MATX_IGNORE_WARNING_POP_GCC cuda::std::array n = {}; cuda::std::array s = {}; - T *data = ldata_; + T *tmpdata = data_.ldata_; int d = 0; 
[[maybe_unused]] int end_count = 0; @@ -677,7 +696,7 @@ MATX_IGNORE_WARNING_POP_GCC "Requested slice start index out of bounds"); // offset by first - data += first * Stride(i); + tmpdata += first * Stride(i); if constexpr (N > 0) { if (end != matxDropDim) { @@ -702,7 +721,7 @@ MATX_IGNORE_WARNING_POP_GCC MATX_ASSERT_STR(d == N, matxInvalidDim, "Number of indices must match the target rank to slice to"); - return cuda::std::make_tuple(tensor_desc_t{std::move(n), std::move(s)}, data); + return cuda::std::make_tuple(tensor_desc_t{std::move(n), std::move(s)}, tmpdata); } @@ -770,7 +789,7 @@ MATX_IGNORE_WARNING_POP_GCC auto new_desc = CloneImpl(clones); - return tensor_impl_t{this->ldata_, std::move(new_desc)}; + return tensor_impl_t{this->data_.ldata_, std::move(new_desc)}; } __MATX_INLINE__ auto PermuteImpl(const cuda::std::array &dims) const @@ -800,7 +819,7 @@ MATX_IGNORE_WARNING_POP_GCC __MATX_INLINE__ auto Permute(const cuda::std::array &dims) const { auto new_desc = PermuteImpl(dims); - return tensor_impl_t{this->ldata_, std::move(new_desc)}; + return tensor_impl_t{this->data_.ldata_, std::move(new_desc)}; } template @@ -845,7 +864,7 @@ MATX_IGNORE_WARNING_POP_GCC OverlapView(const cuda::std::array &windows, const cuda::std::array &strides) const { auto new_desc = OverlapViewImpl(windows, strides); - return tensor_impl_t{this->ldata_, std::move(new_desc)}; + return tensor_impl_t{this->data_.ldata_, std::move(new_desc)}; } template @@ -870,7 +889,7 @@ MATX_IGNORE_WARNING_POP_GCC template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* GetPointer(Is... 
indices) const noexcept { - return ldata_ + GetValC<0, Is...>(cuda::std::make_tuple(indices...)); + return data_.ldata_ + GetValC<0, Is...>(cuda::std::make_tuple(indices...)); } /** @@ -922,9 +941,9 @@ MATX_IGNORE_WARNING_POP_GCC { static_assert(sizeof...(Is) == M, "Number of indices of operator() must match rank of tensor"); #ifndef NDEBUG - assert(ldata_ != nullptr); + assert(data_.ldata_ != nullptr); #endif - return *(ldata_ + GetValC<0, Is...>(cuda::std::make_tuple(indices...))); + return *(data_.ldata_ + GetValC<0, Is...>(cuda::std::make_tuple(indices...))); } /** @@ -942,9 +961,9 @@ MATX_IGNORE_WARNING_POP_GCC { static_assert(sizeof...(Is) == M, "Number of indices of operator() must match rank of tensor"); #ifndef NDEBUG - assert(ldata_ != nullptr); + assert(data_.ldata_ != nullptr); #endif - return *(ldata_ + GetVal<0, Is...>(cuda::std::make_tuple(indices...))); + return *(data_.ldata_ + GetVal<0, Is...>(cuda::std::make_tuple(indices...))); } /** @@ -973,30 +992,6 @@ MATX_IGNORE_WARNING_POP_GCC }, idx); } - - /** - * operator() setter with an array index - * - * @returns value in tensor - * - */ - // template = 1, bool> = true> - // __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T &operator()(const cuda::std::array &idx) noexcept - // { - // if constexpr (RANK == 1) { - // return this->operator()(idx[0]); - // } - // else if constexpr (RANK == 2) { - // return this->operator()(idx[0], idx[1]); - // } - // else if constexpr (RANK == 3) { - // return this->operator()(idx[0], idx[1], idx[2]); - // } - // else { - // return this->operator()(idx[0], idx[1], idx[2], idx[3]); - // } - // } - /** * Get the rank of the tensor * @@ -1057,7 +1052,7 @@ MATX_IGNORE_WARNING_POP_GCC __MATX_INLINE__ __MATX_HOST__ auto Bytes() const noexcept { - return TotalSize() * sizeof(*ldata_); + return TotalSize() * sizeof(*data_.ldata_); } /** @@ -1066,9 +1061,28 @@ MATX_IGNORE_WARNING_POP_GCC * @return data pointer */ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Data() const 
noexcept { - return ldata_; + return data_.ldata_; } + /** + * @brief Set data pointer + * + * @param data Pointer to new data pointer + */ + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ void SetData(T *data) noexcept { + data_.ldata_ = data; + } + + template , int> = 0> + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto CRDData(int l) const noexcept { + return data_.crd_[l]; + } + template , int> = 0> + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto POSData(int l) const noexcept{ + return data_.pos_[l]; + } /** * @brief Set local data pointer @@ -1077,7 +1091,17 @@ MATX_IGNORE_WARNING_POP_GCC * Data pointer to set */ void SetLocalData(T* data) { - ldata_ = data; + data_.ldata_ = data; + } + + template + auto SetSparseData(T* data, + typename U::crd_type* crd[U::LVL], + typename U::pos_type* pos[U::LVL]) + -> std::enable_if_t, void> { + data_.ldata_ = data; + memcpy(data_.crd_, crd, U::LVL*sizeof(crd[0])); + memcpy(data_.pos_, pos, U::LVL*sizeof(pos[0])); } template @@ -1092,7 +1116,7 @@ MATX_IGNORE_WARNING_POP_GCC protected: - T *ldata_; + TensorData data_; Desc desc_; }; diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h index 3c66019b..9db69df1 100644 --- a/include/matx/core/type_utils.h +++ b/include/matx/core/type_utils.h @@ -95,7 +95,7 @@ struct remove_cvref { template using remove_cvref_t = typename remove_cvref::type; -template class tensor_impl_t; +template class tensor_impl_t; template class tensor_t; namespace detail { @@ -581,6 +581,39 @@ struct is_matx_storage_container inline constexpr bool is_matx_storage_container_v = detail::is_matx_storage_container::type>::value; +namespace detail { +template +struct is_sparse_data : std::false_type { +}; +template +struct is_sparse_data> + : std::true_type { +}; +} +/** + * @brief Determine if a type is a MatX sparse data type + * + * @tparam T Type to test + */ +template +inline constexpr bool is_sparse_data_v = detail::is_sparse_data::type>::value; +namespace detail { +template 
+struct is_sparse_tensor : std::false_type { +}; +template +struct is_sparse_tensor> + : std::true_type { +}; +} +/** + * @brief Determine if a type is a MatX sparse tensor type + * + * @tparam T Type to test + */ +template +inline constexpr bool is_sparse_tensor_v = detail::is_sparse_tensor::type>::value; + namespace detail { template @@ -801,7 +834,7 @@ constexpr cuda::std::array, N> to_array(T (&a)[N]) } template class tensor_t; -template class tensor_impl_t; +template class tensor_impl_t; // Traits for casting down to impl tensor conditionally template struct base_type { @@ -810,7 +843,7 @@ struct base_type { template struct base_type>> { - using type = tensor_impl_t; + using type = tensor_impl_t; }; template using base_type_t = typename base_type::type>::type; diff --git a/include/matx/operators/set.h b/include/matx/operators/set.h index d74d47a5..811181f4 100644 --- a/include/matx/operators/set.h +++ b/include/matx/operators/set.h @@ -44,7 +44,7 @@ template class BaseOp; ///< Base operator type namespace detail { -template class tensor_impl_t; ///< Tensor implementation type +template class tensor_impl_t; ///< Tensor implementation type /**