diff --git a/docs/api/api_python/ops/mindspore.ops.func_csr_to_coo.rst b/docs/api/api_python/ops/mindspore.ops.func_csr_to_coo.rst
new file mode 100644
index 00000000000..6cad5d91b8d
--- /dev/null
+++ b/docs/api/api_python/ops/mindspore.ops.func_csr_to_coo.rst
@@ -0,0 +1,22 @@
+mindspore.ops.csr_to_coo
+========================
+
+.. py:function:: mindspore.ops.csr_to_coo(tensor)
+
+    Converts a CSRTensor to a COOTensor.
+
+    .. note::
+        Only 2-D CSRTensor is supported for now.
+
+    **Parameters:**
+
+    - **tensor** (CSRTensor) - A CSR matrix, must be 2-D.
+
+    **Returns:**
+
+    A 2-D COOTensor, the COO representation of the original CSRTensor.
+
+    **Raises:**
+
+    - **TypeError** - If `tensor` is not a CSRTensor.
+    - **ValueError** - If `tensor` is not a 2-D CSRTensor.
diff --git a/docs/api/api_python_en/mindspore.ops.functional.rst b/docs/api/api_python_en/mindspore.ops.functional.rst
index 800112f9d51..5a15008c560 100644
--- a/docs/api/api_python_en/mindspore.ops.functional.rst
+++ b/docs/api/api_python_en/mindspore.ops.functional.rst
@@ -404,6 +404,7 @@ Sparse Operation
 
     mindspore.ops.dense_to_sparse_coo
     mindspore.ops.dense_to_sparse_csr
+    mindspore.ops.csr_to_coo
 
 Parameter Operation Oprators
 ----------------------------
diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_common.h b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_common.h
index 3e7e73b8bbe..2fc80e26d5d 100644
--- a/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_common.h
+++ b/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_common.h
@@ -241,6 +241,16 @@ inline bool CheckNullInput(const std::vector &input_shape) {
 }
 #define CHECK_NULL_INPUT(input_shape) mindspore::device::gpu::CheckNullInput(input_shape)
 
+inline bool CheckShapePositive(const std::vector<int64_t> &input_shape) {
+  if (input_shape.size() != 0) {
+    if (std::all_of(input_shape.begin(), input_shape.end(), [](int64_t i) { return i > 0; })) {
+      return true;
+    }
+  }
+  return false;
+}
+#define CHECK_SHAPE_POSITIVE(input_shape) mindspore::device::gpu::CheckShapePositive(input_shape)
+
 template <typename T>
 inline std::string ConvertVectorToString(const std::vector<T> &value) {
   std::stringstream ss;
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cu
new file mode 100644
index 00000000000..9bea22ee877
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cu
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include +#include +#include +#include "csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" +#include "include/cuda_fp16.h" + +template +using Complex = mindspore::utils::Complex; + +template +__global__ void StackIndices2D(const S *row_indices, const S *col_indices, S *indices, int size) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + indices[pos * 2] = static_cast(__ldg(row_indices + pos)); + indices[pos * 2 + 1] = static_cast(__ldg(col_indices + pos)); + } +} + +template +__device__ inline S BinarySearchRange(S *range, S n, S x) { + S left = 0; + S right = n - 1; + while (left < right) { + S mid = left + (right - left) / 2; + if (x < range[mid]) { + right = mid - 1; + } else if (range[mid + 1] <= x) { + left = mid + 1; + } else { + return mid; + } + } + return left; +} + +template +__global__ void StackIndices3D(const S *batch_pointers, const S *row_indices, const S *col_indices, S *indices, + int batch_size, int total_nnz) { + extern __shared__ S local_batch_ptr[]; + for (size_t i = threadIdx.x; i < batch_size + 1; i += blockDim.x) { + local_batch_ptr[i] = batch_pointers[i]; + } + __syncthreads(); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_nnz; pos += blockDim.x * gridDim.x) { + S batch_idx = BinarySearchRange(local_batch_ptr, static_cast(batch_size), static_cast(pos)); + indices[pos * 3] = batch_idx; + indices[pos * 3 + 1] = static_cast(__ldg(row_indices + pos)); + indices[pos * 3 + 2] = static_cast(__ldg(col_indices + pos)); + } +} + +template +void CallStackIndices2D(const S *row_indices, const S *col_indices, S *indices, int size, cudaStream_t cuda_stream) { + StackIndices2D<<>>(row_indices, col_indices, indices, size); + return; +} + +template +void CallStackIndices3D(const S *batch_pointers, const S *row_indices, const S *col_indices, S *indices, int batch_size, + int total_nnz, size_t shared_memory_size, cudaStream_t cuda_stream) { + StackIndices3D<<>> + (batch_pointers, row_indices, col_indices, indices, batch_size, total_nnz); + return; +} + +template CUDA_LIB_EXPORT void CallStackIndices2D(const int *row_indices, const int *col_indices, int *indices, + int size, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallStackIndices3D(const int *batch_pointers, const int *row_indices, + const int *col_indices, int *indices, int batch_size, + int total_nnz, size_t shared_memory_size, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh new file mode 100644 index 00000000000..0de7ffefe19 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_CUH_ +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +template +CUDA_LIB_EXPORT void CallStackIndices2D(const S *row_indices, const S *col_indices, S *indices, int size, + cudaStream_t cuda_stream); + +template +CUDA_LIB_EXPORT void CallStackIndices3D(const S *batch_pointers, const S *row_indices, const S *col_indices, S *indices, + int batch_size, int total_nnz, size_t shared_memory_size, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cc new file mode 100644 index 00000000000..6ea74080d7d --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.h" +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +namespace mindspore { +namespace kernel { +bool CSRSparseMatrixToSparseTensorGpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + auto kernel_ptr = std::dynamic_pointer_cast(base_operator); + if (!kernel_ptr) { + MS_LOG(ERROR) << "cast CSRSparseMatrixToSparseTensor ops failed!"; + return false; + } + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCSRSparseMatrixToSparseTensorInputsNum, kernel_ptr->name()); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCSRSparseMatrixToSparseTensorOutputsNum, kernel_ptr->name()); + + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(EXCEPTION) << "For '" << kernel_ptr->name() + << "', it does not support this kernel data type: " << kernel_attr; + } + kernel_func_ = func_list_[index].second; + return true; +} + +void CSRSparseMatrixToSparseTensorGpuKernelMod::ResetResource() noexcept { + is_null_input_ = false; + input_dense_shape_size_ = 0; + input_batch_pointers_size_ = 0; + input_row_pointers_size_ = 0; + input_col_indices_size_ = 0; + input_values_size_ = 0; + output_indices_size_ = 0; + output_values_size_ = 0; + output_dense_shape_size_ = 0; + input_size_list_.clear(); + workspace_size_list_.clear(); + output_size_list_.clear(); +} + +int CSRSparseMatrixToSparseTensorGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &others) { + ResetResource(); + input_dense_shape_shapes_ = inputs[kIndex0]->GetShapeVector(); + input_batch_pointers_shapes_ = inputs[kIndex1]->GetShapeVector(); + input_row_pointers_shapes_ = inputs[kIndex2]->GetShapeVector(); + input_col_indices_shapes_ = inputs[kIndex3]->GetShapeVector(); + input_values_shapes_ = inputs[kIndex4]->GetShapeVector(); + output_indices_shapes_ = outputs[kIndex0]->GetShapeVector(); + output_values_shapes_ = outputs[kIndex1]->GetShapeVector(); + output_dense_shape_shapes_ = outputs[kIndex2]->GetShapeVector(); + if (!(CHECK_SHAPE_POSITIVE(input_dense_shape_shapes_) && CHECK_SHAPE_POSITIVE(input_batch_pointers_shapes_) && + CHECK_SHAPE_POSITIVE(input_row_pointers_shapes_) && CHECK_SHAPE_POSITIVE(input_col_indices_shapes_) && + CHECK_SHAPE_POSITIVE(input_values_shapes_) && CHECK_SHAPE_POSITIVE(output_indices_shapes_) && + CHECK_SHAPE_POSITIVE(output_values_shapes_) && CHECK_SHAPE_POSITIVE(output_dense_shape_shapes_))) { + is_null_input_ = true; + InitSizeLists(); + return 0; + } + + MS_EXCEPTION_IF_CHECK_FAIL(!input_dense_shape_shapes_.empty(), "input_dense_shape_ should not be empty!"); + MS_EXCEPTION_IF_CHECK_FAIL(!input_batch_pointers_shapes_.empty(), "input_batch_pointers_ should not be empty!"); + MS_EXCEPTION_IF_CHECK_FAIL(!input_row_pointers_shapes_.empty(), "input_row_pointers_ should not be empty!"); + MS_EXCEPTION_IF_CHECK_FAIL(!output_dense_shape_shapes_.empty(), "output_dense_shapes_ should not be empty!"); + rank_ = input_dense_shape_shapes_[kIndex0]; + is_batch_csr_ = (rank_ == kBatchCSR) ? 
true : false; + + auto GetNums = [](const std::vector &shape) { + size_t res = 1; + for (const auto &sh : shape) { + res *= LongToSize(sh); + } + return res; + }; + input_dense_shape_size_ = abstract::TypeIdSize(inputs[kIndex0]->GetDtype()) * GetNums(input_dense_shape_shapes_); + input_batch_pointers_size_ = + abstract::TypeIdSize(inputs[kIndex1]->GetDtype()) * GetNums(input_batch_pointers_shapes_); + input_row_pointers_size_ = abstract::TypeIdSize(inputs[kIndex2]->GetDtype()) * GetNums(input_row_pointers_shapes_); + input_col_indices_size_ = abstract::TypeIdSize(inputs[kIndex3]->GetDtype()) * GetNums(input_col_indices_shapes_); + input_values_size_ = abstract::TypeIdSize(inputs[kIndex4]->GetDtype()) * GetNums(input_values_shapes_); + output_indices_size_ = abstract::TypeIdSize(outputs[kIndex0]->GetDtype()) * GetNums(output_indices_shapes_); + output_values_size_ = abstract::TypeIdSize(outputs[kIndex1]->GetDtype()) * GetNums(output_values_shapes_); + output_dense_shape_size_ = abstract::TypeIdSize(outputs[kIndex2]->GetDtype()) * GetNums(output_dense_shape_shapes_); + InitSizeLists(); + return 0; +} + +void CSRSparseMatrixToSparseTensorGpuKernelMod::InitSizeLists() { + input_size_list_.push_back(input_dense_shape_size_); + input_size_list_.push_back(input_batch_pointers_size_); + input_size_list_.push_back(input_row_pointers_size_); + input_size_list_.push_back(input_col_indices_size_); + input_size_list_.push_back(input_values_size_); + workspace_size_list_.push_back(input_col_indices_size_); + output_size_list_.push_back(output_indices_size_); + output_size_list_.push_back(output_values_size_); + output_size_list_.push_back(output_dense_shape_size_); +} + +template +bool CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + S *csr_dense_shape_addr = GetDeviceAddress(inputs, kIndex0); + S *csr_batch_pointers_addr = GetDeviceAddress(inputs, kIndex1); + S *csr_row_pointers_addr = GetDeviceAddress(inputs, kIndex2); + S *csr_col_indices_addr = GetDeviceAddress(inputs, kIndex3); + T *csr_values_addr = GetDeviceAddress(inputs, kIndex4); + S *sparse_row_indices_addr = GetDeviceAddress(workspace, kIndex0); + S *sparse_indices_addr = GetDeviceAddress(outputs, kIndex0); + T *sparse_values_addr = GetDeviceAddress(outputs, kIndex1); + S *sparse_dense_shape_addr = GetDeviceAddress(outputs, kIndex2); + + std::vector host_shape_pointers(input_dense_shape_shapes_[kIndex0], 0); + device::gpu::CudaDriver::CopyDeviceMemToHost(host_shape_pointers.data(), csr_dense_shape_addr, sizeof(S) * rank_); + size_t num_batches = (is_batch_csr_) ? host_shape_pointers[kIndex0] : 1; + auto total_nnz = input_col_indices_shapes_[kIndex0]; + auto row_dim = is_batch_csr_ ? 
kIndex1 : kIndex0; + auto row_size = host_shape_pointers[row_dim]; + + if (!is_batch_csr_) { + cusparseXcsr2coo(handle_, csr_row_pointers_addr, total_nnz, row_size, sparse_row_indices_addr, + CUSPARSE_INDEX_BASE_ZERO); + CallStackIndices2D(sparse_row_indices_addr, csr_col_indices_addr, sparse_indices_addr, total_nnz, + reinterpret_cast(stream_ptr)); + } else { + std::vector host_batch_pointers(input_batch_pointers_shapes_[kIndex0], 0); + device::gpu::CudaDriver::CopyDeviceMemToHost(host_batch_pointers.data(), csr_batch_pointers_addr, + sizeof(S) * (num_batches + 1)); + int accum_nnz = 0; + for (size_t i = 0; i < num_batches; ++i) { + S *row_ind_ptr = csr_row_pointers_addr + i * (row_size + 1); + S nnz = host_batch_pointers[i + 1] - host_batch_pointers[i]; + if (nnz != 0) { + cusparseXcsr2coo(handle_, row_ind_ptr, nnz, row_size, sparse_row_indices_addr + accum_nnz, + CUSPARSE_INDEX_BASE_ZERO); + } + accum_nnz += nnz; + } + if (accum_nnz > 0) { + size_t shared_memory_size = sizeof(S) * (num_batches + 1); + CallStackIndices3D(csr_batch_pointers_addr, sparse_row_indices_addr, csr_col_indices_addr, sparse_indices_addr, + num_batches, total_nnz, shared_memory_size, reinterpret_cast(stream_ptr)); + } + } + device::gpu::CudaDriver::CopyDeviceMemToDeviceAsync(sparse_values_addr, csr_values_addr, sizeof(T) * total_nnz, + reinterpret_cast(stream_ptr)); + device::gpu::CudaDriver::CopyDeviceMemToDeviceAsync(sparse_dense_shape_addr, csr_dense_shape_addr, sizeof(S) * rank_, + reinterpret_cast(stream_ptr)); + return true; +} + +std::vector CSRSparseMatrixToSparseTensorGpuKernelMod::GetOpSupport() { + static std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} + +template +using Complex = mindspore::utils::Complex; + +std::vector> + CSRSparseMatrixToSparseTensorGpuKernelMod::func_list_ = { + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeInt32), + &CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeInt32), + &CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeInt32), + &CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeInt32), + &CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel, int>}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + 
.AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeInt32), + &CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel, int>}, +}; + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTensorGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.h new file mode 100644 index 00000000000..cf8ed1e83f7 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.h @@ -0,0 +1,95 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/csr_sparse_matrix_to_sparse_tensor.h" +#include "plugin/device/gpu/hal/device/cuda_driver.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" + +constexpr size_t kCSRSparseMatrixToSparseTensorInputsNum = 5; +constexpr size_t kCSRSparseMatrixToSparseTensorOutputsNum = 3; +constexpr size_t kBatchCSR = 3; + +namespace mindspore { +namespace kernel { +class CSRSparseMatrixToSparseTensorGpuKernelMod : public NativeGpuKernelMod { + public: + CSRSparseMatrixToSparseTensorGpuKernelMod() { + ResetResource(); + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCuSparseHandle(); + } + ~CSRSparseMatrixToSparseTensorGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + MS_EXCEPTION_IF_NULL(kernel_func_); + return kernel_func_(this, inputs, workspace, outputs, stream_ptr); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + void ResetResource() noexcept; + + protected: + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &others) override; + void InitSizeLists(); + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + std::vector GetOpSupport() override; + + private: + using CSRSparseMatrixToSparseTensorFunc = + std::function &, + const std::vector &, const std::vector &, void *)>; + static std::vector> func_list_; + CSRSparseMatrixToSparseTensorFunc kernel_func_; + cusparseHandle_t handle_{nullptr}; + bool 
is_null_input_{false}; + size_t input_dense_shape_size_{0}; + size_t input_batch_pointers_size_{0}; + size_t input_row_pointers_size_{0}; + size_t input_col_indices_size_{0}; + size_t input_values_size_{0}; + size_t output_indices_size_{0}; + size_t output_values_size_{0}; + size_t output_dense_shape_size_{0}; + std::vector input_dense_shape_shapes_; + std::vector input_batch_pointers_shapes_; + std::vector input_row_pointers_shapes_; + std::vector input_col_indices_shapes_; + std::vector input_values_shapes_; + std::vector output_indices_shapes_; + std::vector output_values_shapes_; + std::vector output_dense_shape_shapes_; + bool is_batch_csr_{false}; + int rank_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_CSR_SPARSE_M diff --git a/mindspore/python/mindspore/ops/_vmap/__init__.py b/mindspore/python/mindspore/ops/_vmap/__init__.py index 4ad9dff27dd..47436809ff3 100644 --- a/mindspore/python/mindspore/ops/_vmap/__init__.py +++ b/mindspore/python/mindspore/ops/_vmap/__init__.py @@ -15,7 +15,7 @@ """vmap impl.""" from . import vmap_base, vmap_array_ops, vmap_grad_nn_ops, vmap_debug_ops, vmap_math_ops, vmap_nn_ops,\ - vmap_image_ops, vmap_other_ops + vmap_image_ops, vmap_other_ops, vmap_sparse_ops from .vmap_base import get_vmap_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none,\ vmap_unstack, vmap_general_output_process diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_sparse_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_sparse_ops.py new file mode 100644 index 00000000000..6c4dbefd4f2 --- /dev/null +++ b/mindspore/python/mindspore/ops/_vmap/vmap_sparse_ops.py @@ -0,0 +1,63 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""sparse_ops vmap impl."""
+
+from ..operations.sparse_ops import DenseToCSRSparseMatrix, CSRSparseMatrixToSparseTensor
+from ..primitive import Primitive
+from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
+
+
+@vmap_rules_getters.register(CSRSparseMatrixToSparseTensor)
+def get_csr_sparse_matrix_to_sparse_tensor_vmap_rule(prim, axis_size):
+    """VmapRule for `CSRSparseMatrixToSparseTensor` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    def vmap_rule(shape_bdim, x_batch_pointers_bdim, x_row_pointers_bdim, x_col_indices_bdim, x_values_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, shape_bdim, x_batch_pointers_bdim, x_row_pointers_bdim,
+                                                      x_col_indices_bdim, x_values_bdim)
+        if not is_all_none:
+            _, shape_dim = shape_bdim
+            _, x_batch_pointers_dim = x_batch_pointers_bdim
+            _, x_row_pointers_dim = x_row_pointers_bdim
+            _, x_col_indices_dim = x_col_indices_bdim
+            _, x_values_dim = x_values_bdim
+            _raise_value_error("For operator CSRSparseMatrixToSparseTensor, all axes for inputs should be None, but"
+                               " got shape_dim: {}, x_batch_pointers_dim: {}, x_row_pointers_dim: {},"
+                               " x_col_indices_dim: {}, and x_values_dim: {}.".format(shape_dim, x_batch_pointers_dim,
+                                                                                      x_row_pointers_dim,
+                                                                                      x_col_indices_dim, x_values_dim))
+        return result
+
+    return vmap_rule
+
+
+@vmap_rules_getters.register(DenseToCSRSparseMatrix)
+def get_dense_to_csr_sparse_matrix_vmap_rule(prim, axis_size):
+    """VmapRule for `DenseToCSRSparseMatrix` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+
+    def vmap_rule(dense_input_bdim, indices_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, dense_input_bdim, indices_bdim)
+        if not is_all_none:
+            _, dense_input_dim = dense_input_bdim
+            _, indices_dim = indices_bdim
+            _raise_value_error("For operator DenseToCSRSparseMatrix, all axes for inputs should be None, but"
+                               " got dense_input_dim: {}, indices_dim: {}.".format(dense_input_dim, indices_dim))
+        return result
+
+    return vmap_rule
diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py
index ace92741414..109b5b596a8 100644
--- a/mindspore/python/mindspore/ops/function/__init__.py
+++ b/mindspore/python/mindspore/ops/function/__init__.py
@@ -259,6 +259,7 @@ from .linalg_func import (
 from .sparse_func import (
     dense_to_sparse_coo,
     dense_to_sparse_csr,
+    csr_to_coo,
 )
 from .random_func import (
     standard_laplace,
diff --git a/mindspore/python/mindspore/ops/function/sparse_func.py b/mindspore/python/mindspore/ops/function/sparse_func.py
index d610922fc74..9c43ce02c68 100644
--- a/mindspore/python/mindspore/ops/function/sparse_func.py
+++ b/mindspore/python/mindspore/ops/function/sparse_func.py
@@ -15,14 +15,17 @@
 """Defines sparse operators with functional form."""
 
-from ..operations.sparse_ops import DenseToCSRSparseMatrix
+from ..operations.sparse_ops import DenseToCSRSparseMatrix, CSRSparseMatrixToSparseTensor
 from ..operations.array_ops import GatherNd
 from ...common import CSRTensor, COOTensor, Tensor
+from ...common import dtype as mstype
 from ..composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error
 
 gather_nd = GatherNd()
 dense_to_csr = DenseToCSRSparseMatrix()
+csr_sparse_matrix_to_sparse_tensor = CSRSparseMatrixToSparseTensor()
+batch_csr_pointers_empty = Tensor([0, -1], dtype=mstype.int32)
 
 
 def dense_to_sparse_coo(tensor):
@@ -56,7 +59,7 @@ def dense_to_sparse_coo(tensor):
     >>> print(output)
     """
     if not isinstance(tensor, Tensor):
-        raise_type_error("For functional operator dense_to_sparse_coo, input argument msut be a Tensor.")
+        raise_type_error("For functional operator dense_to_sparse_coo, input argument must be a Tensor.")
     if len(tensor.shape) != 2:
         raise_value_error("Currently only support 2-D Tensor when converting to COOTensor.")
     indices = tensor.nonzero().astype("int32")
@@ -96,16 +99,57 @@ def dense_to_sparse_csr(tensor):
     >>> print(output)
     """
     if not isinstance(tensor, Tensor):
-        raise_type_error("For functional operator dense_to_sparse_csr, input argument msut be a Tensor.")
+        raise_type_error("For functional operator dense_to_sparse_csr, input argument must be a Tensor.")
     if len(tensor.shape) != 2:
         raise_value_error("Currently only support 2-D Tensor when converting to CSRTensor.")
     indices = tensor.nonzero().astype("int32")
     _, _, indptr, indices, values = dense_to_csr(tensor, indices)
     return CSRTensor(indptr, indices, values, tensor.shape)
+
+
+def csr_to_coo(tensor):
+    """
+    Converts a CSRTensor to a COOTensor.
+
+    Note:
+        Only 2-D CSRTensor is supported for now.
+
+    Args:
+        tensor (CSRTensor): A CSRTensor, must be 2-D.
+
+    Returns:
+        2-D COOTensor, the input tensor stored in COO format.
+
+    Raises:
+        TypeError: If input is not a CSRTensor.
+        ValueError: If input tensor is not 2-D.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> from mindspore import Tensor, CSRTensor, ops
+        >>> import mindspore as ms
+        >>> indptr = Tensor([0, 1, 2, 2], dtype=ms.int32)
+        >>> indices = Tensor([1, 2], dtype=ms.int32)
+        >>> values = Tensor([1, 2], dtype=ms.float32)
+        >>> shape = (3, 4)
+        >>> x = CSRTensor(indptr, indices, values, shape)
+        >>> output = ops.csr_to_coo(x)
+        >>> print(output)
+    """
+    if not isinstance(tensor, CSRTensor):
+        raise_type_error("For functional operator csr_to_coo, input argument must be a CSRTensor.")
+    if len(tensor.shape) != 2:
+        raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
+    shape = tensor.shape
+    indices, values, _ = csr_sparse_matrix_to_sparse_tensor(Tensor(shape, dtype=mstype.int32), batch_csr_pointers_empty,
+                                                            tensor.indptr, tensor.indices, tensor.values)
+    return COOTensor(indices, values, shape)
+
+
 __all__ = [
     'dense_to_sparse_coo',
-    'dense_to_sparse_csr'
+    'dense_to_sparse_csr',
+    'csr_to_coo'
 ]
 __all__.sort()
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index 3c0c0abce66..6f45fd4509d 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -1011,6 +1011,7 @@ tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
 tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
 tensor_operator_registry.register('narrow', narrow)
 tensor_operator_registry.register('sort', sort)
+tensor_operator_registry.register('csr_to_coo', csr_to_coo)
 tensor_operator_registry.register('zeros', zeros)
 tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min)
 tensor_operator_registry.register('unsorted_segment_max', unsorted_segment_max)
diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py
index 358a17ca0b0..99c6fd87803 100644
--- a/mindspore/python/mindspore/ops/operations/sparse_ops.py
+++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py
@@ -224,6 +224,7 @@ class CSRSparseMatrixToSparseTensor(Primitive):
         ValueError: If shape of `x_col_indices` is not corresponding to shape of
`x_values`. Supported Platforms: + ``GPU`` Examples: >>> x_dense_shape = Tensor(np.array([2, 2, 4]).astype(np.int64)) diff --git a/tests/st/ops/gpu/test_sparse_csr_to_sparse_op.py b/tests/st/ops/gpu/test_sparse_csr_to_sparse_op.py new file mode 100644 index 00000000000..ccdd491f64a --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_csr_to_sparse_op.py @@ -0,0 +1,157 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import scipy.sparse +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor + + +def generate_data(shape, datatype="float32", indicetype="int32", density=0.2): + data_shape = shape[-2:] + shape_tensor = np.array(shape, dtype=indicetype) + is_batch_csr = len(shape) == 3 + batch_size = shape[0] if is_batch_csr else 1 + accum_nnz = 0 + x_batch_pointers = np.array(0, dtype=indicetype) + coo_indices = [] + for i in range(batch_size): + csr_matrix = scipy.sparse.random(data_shape[0], data_shape[1], format="csr", + density=density, dtype=indicetype) + row_pointers = np.asarray(csr_matrix.indptr, dtype=indicetype) + col_indices = np.asarray(csr_matrix.indices, dtype=indicetype) + values = np.asarray(csr_matrix.data, dtype=datatype) + coo_tensor = csr_matrix.tocoo() + indices = np.stack( + (np.asarray(coo_tensor.row, dtype=indicetype), np.asarray(coo_tensor.col, dtype=indicetype)), axis=1) + if is_batch_csr: + indices = np.insert(indices, 0, i, axis=1) + coo_indices.append(indices) + if i == 0: + x_row_pointers = row_pointers + x_col_indices = col_indices + x_values = values + else: + x_row_pointers = np.append(x_row_pointers, row_pointers) + x_col_indices = np.append(x_col_indices, col_indices) + x_values = np.append(x_values, values) + accum_nnz += csr_matrix.nnz + x_batch_pointers = np.append(x_batch_pointers, accum_nnz) + output_indices = np.concatenate(coo_indices) + x_batch_pointers = x_batch_pointers.astype(indicetype) + return ((shape_tensor, x_batch_pointers, x_row_pointers, x_col_indices, x_values), + (output_indices, x_values, shape_tensor)) + + +def compare_res(res, expected): + assert len(res) == len(expected) + for r, e in zip(res, expected): + assert np.allclose(r.asnumpy(), e) + + +class CSRToCOONet(nn.Cell): + def __init__(self): + super(CSRToCOONet, self).__init__() + self.to_coo = CSRSparseMatrixToSparseTensor() + + def construct(self, shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values): + return self.to_coo(shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values) + + +class DynamicShapeCSRToCOONet(nn.Cell): + def __init__(self): + super(DynamicShapeCSRToCOONet, self).__init__() + self.unique = P.Unique() + self.to_coo = CSRSparseMatrixToSparseTensor() + + def construct(self, shape, x_batch_pointers, x_row_pointers, 
x_col_indices, x_values):
+        unique_col_indices, _ = self.unique(x_col_indices)
+        unique_values, _ = self.unique(x_values)
+        return self.to_coo(shape, x_batch_pointers, x_row_pointers, unique_col_indices, unique_values)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_2d_csr_to_coo():
+    """
+    Feature: Test 2D CSR tensor to COO tensor.
+    Description: Test 2D CSR tensor (without batch dimension) to COO tensor.
+    Expectation: Success.
+    """
+    inputs, expects = generate_data((5, 10))
+    input_tensors = [Tensor(x) for x in inputs]
+    net = CSRToCOONet()
+    outputs = net(*input_tensors)
+    compare_res(outputs, expects)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_3d_csr_to_coo():
+    """
+    Feature: Test 3D CSR tensor to COO tensor.
+    Description: Test 3D CSR tensor (with batch dimension) to COO tensor.
+    Expectation: Success.
+    """
+    inputs, expects = generate_data((3, 5, 10))
+    input_tensors = [Tensor(x) for x in inputs]
+    net = CSRToCOONet()
+    outputs = net(*input_tensors)
+    compare_res(outputs, expects)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_3d_csr_to_coo_fp64():
+    """
+    Feature: Test 3D CSR tensor to COO tensor.
+    Description: Test 3D CSR tensor (with batch dimension, fp64) to COO tensor.
+    Expectation: Success.
+    """
+    inputs, expects = generate_data((3, 5, 10), datatype="float64")
+    input_tensors = [Tensor(x) for x in inputs]
+    net = CSRToCOONet()
+    outputs = net(*input_tensors)
+    compare_res(outputs, expects)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_shape_csr_to_coo():
+    """
+    Feature: Test dynamic shape.
+    Description: Test CSR tensor to COO tensor with dynamic shape inputs.
+    Expectation: Success.
+    """
+    shape = (3, 10)
+    x_batch_pointers = Tensor([0, -1], dtype=mstype.int32)
+    indptr = Tensor([0, 2, 6, 9], dtype=mstype.int32)
+    indices = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=mstype.int32)
+    values = np.random.rand(9)
+    values = np.sort(values)
+    net = DynamicShapeCSRToCOONet()
+    outputs = net(Tensor(shape, dtype=mstype.int32), x_batch_pointers, indptr, indices,
+                  Tensor(values, dtype=mstype.float32))
+    coo_indices = np.array([[0, 1], [0, 2], [1, 3], [1, 4], [1, 5], [1, 6], [2, 7], [2, 8], [2, 9]])
+    compare_res(outputs, (coo_indices, values, np.array(shape)))
diff --git a/tests/st/sparse/test_csr.py b/tests/st/sparse/test_csr.py
index 3ca772b9877..f6686ea9356 100644
--- a/tests/st/sparse/test_csr.py
+++ b/tests/st/sparse/test_csr.py
@@ -559,7 +559,7 @@ def test_dtype_csr_tensor():
 def test_bprop():
     """
     Feature: Test back-propagation with CSR-related Ops.
-    Description: Test CSRReduceSum, CSRMul, CSRDiv, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
+    Description: Test CSRReduceSum, CSRMul, CSRDiv, CSRMV.
     Expectation: Success.
     """
     if get_platform() != "linux":
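
For reviewers, a minimal usage sketch of the new `mindspore.ops.csr_to_coo` API follows. It is not part of the patch; it mirrors the docstring and the 2-D test above, and the small NumPy helper is only an illustrative reference for the index expansion the GPU kernel performs (cusparseXcsr2coo turns `indptr` into per-value row indices, and StackIndices2D interleaves them with `col_indices`). The expected result noted in the comments assumes the conversion semantics described in the docstring.

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, CSRTensor, ops


def csr_to_coo_indices_numpy(indptr, col_indices):
    """NumPy reference for the 2-D path: expand `indptr` into per-value row
    indices (the cusparseXcsr2coo step) and interleave them with the column
    indices (the StackIndices2D step)."""
    nnz_per_row = np.diff(indptr)                          # stored values in each row
    row_indices = np.repeat(np.arange(len(indptr) - 1), nnz_per_row)
    return np.stack((row_indices, col_indices), axis=1)   # shape (nnz, 2)


# Same matrix as the docstring example: entries (0, 1) = 1 and (1, 2) = 2 in a 3 x 4 matrix.
indptr = Tensor([0, 1, 2, 2], dtype=ms.int32)
indices = Tensor([1, 2], dtype=ms.int32)
values = Tensor([1, 2], dtype=ms.float32)
csr = CSRTensor(indptr, indices, values, (3, 4))

coo = ops.csr_to_coo(csr)  # dispatches to CSRSparseMatrixToSparseTensor on GPU
print(coo.indices)         # expected [[0, 1], [1, 2]]
print(csr_to_coo_indices_numpy(indptr.asnumpy(), indices.asnumpy()))
```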