!35264 implement CSRSparseMatrixToSparseTensor GPU

Merge pull request !35264 from huangmengxi/cuda_csr_op
i-robot 2022-06-23 12:06:01 +00:00 committed by Gitee
commit 9185ebeb8d
15 changed files with 759 additions and 6 deletions

View File

@@ -0,0 +1,22 @@
mindspore.ops.csr_to_coo
========================
.. py:function:: mindspore.ops.csr_to_coo(tensor)
Converts a CSRTensor to a COOTensor.
.. note::
Only 2-D CSRTensor is supported for now.
**Parameters:**
- **tensor** (CSRTensor) - A CSRTensor, which must be 2-D.
**Returns:**
A 2-D COOTensor, the COO-format representation of the original CSRTensor.
**Raises:**
- **TypeError** - If `tensor` is not a CSRTensor.
- **ValueError** - If `tensor` is not a 2-D CSRTensor.
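For reference, a minimal usage sketch of the new interface (the input values are illustrative and the expected result follows from the CSR-to-COO semantics described above, not from a captured run):

>>> import mindspore as ms
>>> from mindspore import Tensor, CSRTensor, ops
>>> indptr = Tensor([0, 1, 2, 2], dtype=ms.int32)   # rows 0 and 1 each hold one nonzero, row 2 is empty
>>> indices = Tensor([3, 0], dtype=ms.int32)
>>> values = Tensor([1., 2.], dtype=ms.float32)
>>> x = CSRTensor(indptr, indices, values, (3, 4))
>>> output = ops.csr_to_coo(x)   # expected COO indices [[0, 3], [1, 0]], values [1., 2.]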

View File

@@ -404,6 +404,7 @@ Sparse Operation
mindspore.ops.dense_to_sparse_coo
mindspore.ops.dense_to_sparse_csr
mindspore.ops.csr_to_coo
Parameter Operation Operators
----------------------------

View File

@@ -241,6 +241,16 @@ inline bool CheckNullInput(const std::vector<T> &input_shape) {
}
#define CHECK_NULL_INPUT(input_shape) mindspore::device::gpu::CheckNullInput(input_shape)
inline bool CheckShapePositive(const std::vector<int64_t> &input_shape) {
if (input_shape.size() != 0) {
if (std::all_of(input_shape.begin(), input_shape.end(), [](int64_t i) { return i > 0; })) {
return true;
}
}
return false;
}
#define CHECK_SHAPE_POSITIVE(input_shape) mindspore::device::gpu::CheckShapePositive(input_shape)
template <typename T>
inline std::string ConvertVectorToString(const std::vector<T> &value) {
std::stringstream ss;

View File

@@ -0,0 +1,87 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/device_ptr.h>
#include <thrust/scan.h>
#include <cuda_runtime.h>
#include "csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "include/cuda_fp16.h"
template <typename T>
using Complex = mindspore::utils::Complex<T>;
template <typename S>
__global__ void StackIndices2D(const S *row_indices, const S *col_indices, S *indices, int size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
indices[pos * 2] = static_cast<S>(__ldg(row_indices + pos));
indices[pos * 2 + 1] = static_cast<S>(__ldg(col_indices + pos));
}
}
template <typename S>
__device__ inline S BinarySearchRange(S *range, S n, S x) {
S left = 0;
S right = n - 1;
while (left < right) {
S mid = left + (right - left) / 2;
if (x < range[mid]) {
right = mid - 1;
} else if (range[mid + 1] <= x) {
left = mid + 1;
} else {
return mid;
}
}
return left;
}
template <typename S>
__global__ void StackIndices3D(const S *batch_pointers, const S *row_indices, const S *col_indices, S *indices,
int batch_size, int total_nnz) {
extern __shared__ S local_batch_ptr[];
for (size_t i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
local_batch_ptr[i] = batch_pointers[i];
}
__syncthreads();
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_nnz; pos += blockDim.x * gridDim.x) {
S batch_idx = BinarySearchRange(local_batch_ptr, static_cast<S>(batch_size), static_cast<S>(pos));
indices[pos * 3] = batch_idx;
indices[pos * 3 + 1] = static_cast<S>(__ldg(row_indices + pos));
indices[pos * 3 + 2] = static_cast<S>(__ldg(col_indices + pos));
}
}
template <typename S>
void CallStackIndices2D(const S *row_indices, const S *col_indices, S *indices, int size, cudaStream_t cuda_stream) {
StackIndices2D<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(row_indices, col_indices, indices, size);
return;
}
template <typename S>
void CallStackIndices3D(const S *batch_pointers, const S *row_indices, const S *col_indices, S *indices, int batch_size,
int total_nnz, size_t shared_memory_size, cudaStream_t cuda_stream) {
StackIndices3D<<<GET_BLOCKS(total_nnz), GET_THREADS, shared_memory_size, cuda_stream>>>
(batch_pointers, row_indices, col_indices, indices, batch_size, total_nnz);
return;
}
template CUDA_LIB_EXPORT void CallStackIndices2D<int>(const int *row_indices, const int *col_indices, int *indices,
int size, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallStackIndices3D<int>(const int *batch_pointers, const int *row_indices,
const int *col_indices, int *indices, int batch_size,
int total_nnz, size_t shared_memory_size,
cudaStream_t cuda_stream);
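The batched kernel above emits one (batch, row, col) triple per nonzero, resolving the batch component with a binary search over the batch pointer array. A host-side NumPy sketch of the same lookup (the helper name is illustrative, not part of this change):

import numpy as np

def stack_indices_3d(batch_pointers, row_indices, col_indices):
    # Nonzero positions in [batch_pointers[b], batch_pointers[b + 1]) belong to batch b,
    # which is what BinarySearchRange computes on the device.
    pos = np.arange(len(col_indices))
    batch_idx = np.searchsorted(batch_pointers, pos, side='right') - 1
    return np.stack([batch_idx, row_indices, col_indices], axis=1)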

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename S>
CUDA_LIB_EXPORT void CallStackIndices2D(const S *row_indices, const S *col_indices, S *indices, int size,
cudaStream_t cuda_stream);
template <typename S>
CUDA_LIB_EXPORT void CallStackIndices3D(const S *batch_pointers, const S *row_indices, const S *col_indices, S *indices,
int batch_size, int total_nnz, size_t shared_memory_size,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_CUH_

View File

@@ -0,0 +1,242 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/sparse/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.h"
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
bool CSRSparseMatrixToSparseTensorGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::CSRSparseMatrixToSparseTensor>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "cast CSRSparseMatrixToSparseTensor ops failed!";
return false;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCSRSparseMatrixToSparseTensorInputsNum, kernel_ptr->name());
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCSRSparseMatrixToSparseTensorOutputsNum, kernel_ptr->name());
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_ptr->name()
<< "', it does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return true;
}
void CSRSparseMatrixToSparseTensorGpuKernelMod::ResetResource() noexcept {
is_null_input_ = false;
input_dense_shape_size_ = 0;
input_batch_pointers_size_ = 0;
input_row_pointers_size_ = 0;
input_col_indices_size_ = 0;
input_values_size_ = 0;
output_indices_size_ = 0;
output_values_size_ = 0;
output_dense_shape_size_ = 0;
input_size_list_.clear();
workspace_size_list_.clear();
output_size_list_.clear();
}
int CSRSparseMatrixToSparseTensorGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others) {
ResetResource();
input_dense_shape_shapes_ = inputs[kIndex0]->GetShapeVector();
input_batch_pointers_shapes_ = inputs[kIndex1]->GetShapeVector();
input_row_pointers_shapes_ = inputs[kIndex2]->GetShapeVector();
input_col_indices_shapes_ = inputs[kIndex3]->GetShapeVector();
input_values_shapes_ = inputs[kIndex4]->GetShapeVector();
output_indices_shapes_ = outputs[kIndex0]->GetShapeVector();
output_values_shapes_ = outputs[kIndex1]->GetShapeVector();
output_dense_shape_shapes_ = outputs[kIndex2]->GetShapeVector();
if (!(CHECK_SHAPE_POSITIVE(input_dense_shape_shapes_) && CHECK_SHAPE_POSITIVE(input_batch_pointers_shapes_) &&
CHECK_SHAPE_POSITIVE(input_row_pointers_shapes_) && CHECK_SHAPE_POSITIVE(input_col_indices_shapes_) &&
CHECK_SHAPE_POSITIVE(input_values_shapes_) && CHECK_SHAPE_POSITIVE(output_indices_shapes_) &&
CHECK_SHAPE_POSITIVE(output_values_shapes_) && CHECK_SHAPE_POSITIVE(output_dense_shape_shapes_))) {
is_null_input_ = true;
InitSizeLists();
return 0;
}
MS_EXCEPTION_IF_CHECK_FAIL(!input_dense_shape_shapes_.empty(), "input_dense_shape_ should not be empty!");
MS_EXCEPTION_IF_CHECK_FAIL(!input_batch_pointers_shapes_.empty(), "input_batch_pointers_ should not be empty!");
MS_EXCEPTION_IF_CHECK_FAIL(!input_row_pointers_shapes_.empty(), "input_row_pointers_ should not be empty!");
MS_EXCEPTION_IF_CHECK_FAIL(!output_dense_shape_shapes_.empty(), "output_dense_shapes_ should not be empty!");
rank_ = input_dense_shape_shapes_[kIndex0];
is_batch_csr_ = (rank_ == kBatchCSR);
auto GetNums = [](const std::vector<int64_t> &shape) {
size_t res = 1;
for (const auto &sh : shape) {
res *= LongToSize(sh);
}
return res;
};
input_dense_shape_size_ = abstract::TypeIdSize(inputs[kIndex0]->GetDtype()) * GetNums(input_dense_shape_shapes_);
input_batch_pointers_size_ =
abstract::TypeIdSize(inputs[kIndex1]->GetDtype()) * GetNums(input_batch_pointers_shapes_);
input_row_pointers_size_ = abstract::TypeIdSize(inputs[kIndex2]->GetDtype()) * GetNums(input_row_pointers_shapes_);
input_col_indices_size_ = abstract::TypeIdSize(inputs[kIndex3]->GetDtype()) * GetNums(input_col_indices_shapes_);
input_values_size_ = abstract::TypeIdSize(inputs[kIndex4]->GetDtype()) * GetNums(input_values_shapes_);
output_indices_size_ = abstract::TypeIdSize(outputs[kIndex0]->GetDtype()) * GetNums(output_indices_shapes_);
output_values_size_ = abstract::TypeIdSize(outputs[kIndex1]->GetDtype()) * GetNums(output_values_shapes_);
output_dense_shape_size_ = abstract::TypeIdSize(outputs[kIndex2]->GetDtype()) * GetNums(output_dense_shape_shapes_);
InitSizeLists();
return 0;
}
void CSRSparseMatrixToSparseTensorGpuKernelMod::InitSizeLists() {
input_size_list_.push_back(input_dense_shape_size_);
input_size_list_.push_back(input_batch_pointers_size_);
input_size_list_.push_back(input_row_pointers_size_);
input_size_list_.push_back(input_col_indices_size_);
input_size_list_.push_back(input_values_size_);
workspace_size_list_.push_back(input_col_indices_size_);
output_size_list_.push_back(output_indices_size_);
output_size_list_.push_back(output_values_size_);
output_size_list_.push_back(output_dense_shape_size_);
}
template <typename T, typename S>
bool CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
S *csr_dense_shape_addr = GetDeviceAddress<S>(inputs, kIndex0);
S *csr_batch_pointers_addr = GetDeviceAddress<S>(inputs, kIndex1);
S *csr_row_pointers_addr = GetDeviceAddress<S>(inputs, kIndex2);
S *csr_col_indices_addr = GetDeviceAddress<S>(inputs, kIndex3);
T *csr_values_addr = GetDeviceAddress<T>(inputs, kIndex4);
S *sparse_row_indices_addr = GetDeviceAddress<S>(workspace, kIndex0);
S *sparse_indices_addr = GetDeviceAddress<S>(outputs, kIndex0);
T *sparse_values_addr = GetDeviceAddress<T>(outputs, kIndex1);
S *sparse_dense_shape_addr = GetDeviceAddress<S>(outputs, kIndex2);
std::vector<S> host_shape_pointers(input_dense_shape_shapes_[kIndex0], 0);
device::gpu::CudaDriver::CopyDeviceMemToHost(host_shape_pointers.data(), csr_dense_shape_addr, sizeof(S) * rank_);
size_t num_batches = (is_batch_csr_) ? host_shape_pointers[kIndex0] : 1;
auto total_nnz = input_col_indices_shapes_[kIndex0];
auto row_dim = is_batch_csr_ ? kIndex1 : kIndex0;
auto row_size = host_shape_pointers[row_dim];
if (!is_batch_csr_) {
cusparseXcsr2coo(handle_, csr_row_pointers_addr, total_nnz, row_size, sparse_row_indices_addr,
CUSPARSE_INDEX_BASE_ZERO);
CallStackIndices2D<S>(sparse_row_indices_addr, csr_col_indices_addr, sparse_indices_addr, total_nnz,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
std::vector<S> host_batch_pointers(input_batch_pointers_shapes_[kIndex0], 0);
device::gpu::CudaDriver::CopyDeviceMemToHost(host_batch_pointers.data(), csr_batch_pointers_addr,
sizeof(S) * (num_batches + 1));
int accum_nnz = 0;
for (size_t i = 0; i < num_batches; ++i) {
S *row_ind_ptr = csr_row_pointers_addr + i * (row_size + 1);
S nnz = host_batch_pointers[i + 1] - host_batch_pointers[i];
if (nnz != 0) {
cusparseXcsr2coo(handle_, row_ind_ptr, nnz, row_size, sparse_row_indices_addr + accum_nnz,
CUSPARSE_INDEX_BASE_ZERO);
}
accum_nnz += nnz;
}
if (accum_nnz > 0) {
size_t shared_memory_size = sizeof(S) * (num_batches + 1);
CallStackIndices3D<S>(csr_batch_pointers_addr, sparse_row_indices_addr, csr_col_indices_addr, sparse_indices_addr,
num_batches, total_nnz, shared_memory_size, reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
device::gpu::CudaDriver::CopyDeviceMemToDeviceAsync(sparse_values_addr, csr_values_addr, sizeof(T) * total_nnz,
reinterpret_cast<cudaStream_t>(stream_ptr));
device::gpu::CudaDriver::CopyDeviceMemToDeviceAsync(sparse_dense_shape_addr, csr_dense_shape_addr, sizeof(S) * rank_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
std::vector<KernelAttr> CSRSparseMatrixToSparseTensorGpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CSRSparseMatrixToSparseTensorFunc> &pair) { return pair.first; });
return support_list;
}
template <typename T>
using Complex = mindspore::utils::Complex<T>;
std::vector<std::pair<KernelAttr, CSRSparseMatrixToSparseTensorGpuKernelMod::CSRSparseMatrixToSparseTensorFunc>>
CSRSparseMatrixToSparseTensorGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
&CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel<float, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32),
&CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel<double, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt32),
&CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel<half, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeInt32),
&CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel<Complex<float>, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeInt32),
&CSRSparseMatrixToSparseTensorGpuKernelMod::LaunchKernel<Complex<double>, int>},
};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CSRSparseMatrixToSparseTensor, CSRSparseMatrixToSparseTensorGpuKernelMod);
} // namespace kernel
} // namespace mindspore
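For the non-batched path, cusparseXcsr2coo expands the compressed row pointers into one row index per nonzero before StackIndices2D interleaves them with the column indices. A NumPy sketch of that expansion (illustrative helper only):

import numpy as np

def csr_rows_to_coo_indices(row_pointers, col_indices):
    # np.diff(row_pointers) is the nonzero count of each row; repeating each row id
    # that many times mirrors what cusparseXcsr2coo produces on the device.
    nnz_per_row = np.diff(row_pointers)
    row_indices = np.repeat(np.arange(len(nnz_per_row)), nnz_per_row)
    return np.stack([row_indices, col_indices], axis=1)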

View File

@@ -0,0 +1,95 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <cusparse.h>
#include <map>
#include <memory>
#include <vector>
#include <utility>
#include "mindspore/core/ops/csr_sparse_matrix_to_sparse_tensor.h"
#include "plugin/device/gpu/hal/device/cuda_driver.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/csr_sparse_matrix_to_sparse_tensor_gpu_kernel.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
constexpr size_t kCSRSparseMatrixToSparseTensorInputsNum = 5;
constexpr size_t kCSRSparseMatrixToSparseTensorOutputsNum = 3;
constexpr size_t kBatchCSR = 3;
namespace mindspore {
namespace kernel {
class CSRSparseMatrixToSparseTensorGpuKernelMod : public NativeGpuKernelMod {
public:
CSRSparseMatrixToSparseTensorGpuKernelMod() {
ResetResource();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCuSparseHandle();
}
~CSRSparseMatrixToSparseTensorGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
MS_EXCEPTION_IF_NULL(kernel_func_);
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
void ResetResource() noexcept;
protected:
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &others) override;
void InitSizeLists();
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
std::vector<KernelAttr> GetOpSupport() override;
private:
using CSRSparseMatrixToSparseTensorFunc =
std::function<bool(CSRSparseMatrixToSparseTensorGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, CSRSparseMatrixToSparseTensorFunc>> func_list_;
CSRSparseMatrixToSparseTensorFunc kernel_func_;
cusparseHandle_t handle_{nullptr};
bool is_null_input_{false};
size_t input_dense_shape_size_{0};
size_t input_batch_pointers_size_{0};
size_t input_row_pointers_size_{0};
size_t input_col_indices_size_{0};
size_t input_values_size_{0};
size_t output_indices_size_{0};
size_t output_values_size_{0};
size_t output_dense_shape_size_{0};
std::vector<int64_t> input_dense_shape_shapes_;
std::vector<int64_t> input_batch_pointers_shapes_;
std::vector<int64_t> input_row_pointers_shapes_;
std::vector<int64_t> input_col_indices_shapes_;
std::vector<int64_t> input_values_shapes_;
std::vector<int64_t> output_indices_shapes_;
std::vector<int64_t> output_values_shapes_;
std::vector<int64_t> output_dense_shape_shapes_;
bool is_batch_csr_{false};
int rank_{0};
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPARSE_CSR_SPARSE_MATRIX_TO_SPARSE_TENSOR_GPU_KERNEL_H_

View File

@@ -15,7 +15,7 @@
"""vmap impl."""
from . import vmap_base, vmap_array_ops, vmap_grad_nn_ops, vmap_debug_ops, vmap_math_ops, vmap_nn_ops,\
vmap_image_ops, vmap_other_ops
vmap_image_ops, vmap_other_ops, vmap_sparse_ops
from .vmap_base import get_vmap_rule, vmap_monad_rule, _broadcast_by_axis, vmap_bind_all_none,\
vmap_unstack, vmap_general_output_process

View File

@ -0,0 +1,63 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""sparse_ops vmap impl."""
from ..operations.sparse_ops import DenseToCSRSparseMatrix, CSRSparseMatrixToSparseTensor
from ..primitive import Primitive
from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _raise_value_error
@vmap_rules_getters.register(CSRSparseMatrixToSparseTensor)
def get_csr_sparse_matrix_to_sparse_tensor_vmap_rule(prim, axis_size):
"""VmapRule for `CSRSparseMatrixToSparseTensor` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(shape_bdim, x_batch_pointers_bdim, x_row_pointers_bdim, x_col_indices_bdim, x_values_bdim):
is_all_none, result = vmap_general_preprocess(prim, shape_bdim, x_batch_pointers_bdim, x_row_pointers_bdim,
x_col_indices_bdim, x_values_bdim)
if not is_all_none:
_, shape_dim = shape_bdim
_, x_batch_pointers_dim = x_batch_pointers_bdim
_, x_row_pointers_dim = x_row_pointers_bdim
_, x_col_indices_dim = x_col_indices_bdim
_, x_values_dim = x_values_bdim
_raise_value_error("For operator in CSRSparseMatrixToSparseTensor, all axes for inputs should be None, but"
" got shape_dim: {}, x_batch_pointesr_dim: {}, x_row_pointers_dim: {},"
" x_col_indices_dim: {}, and x_values_dim: {}.".format(shape_dim, x_batch_pointers_dim,
x_row_pointers_dim,
x_col_indices_dim, x_values_dim))
return result
return vmap_rule
@vmap_rules_getters.register(DenseToCSRSparseMatrix)
def get_dense_to_csr_sparse_matrix_vmap_rule(prim, axis_size):
"""VmapRule for `DenseToCSRSparseMatrix` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
def vmap_rule(dense_input_bdim, indices_bdim):
is_all_none, result = vmap_general_preprocess(prim, dense_input_bdim, indices_bdim)
if not is_all_none:
_, dense_input_dim = dense_input_bdim
_, indices_dim = indices_bdim
_raise_value_error("For operator in DenseToCSRSparseMatrix, all axes for inputs should be None, but"
" got dense_input_dim: {}, indices_dim: {}.".format(dense_input_dim, indices_dim))
return result
return vmap_rule
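Both rules above only forward to the primitive when no input carries a vmap axis; any batched input is rejected. A plain-Python sketch of that decision (hypothetical helper, for illustration only):

def apply_sparse_vmap_rule(prim, *inputs_bdim):
    # Each argument is a (value, batch_dim) pair produced by vmap.
    if all(dim is None for _, dim in inputs_bdim):
        return prim(*(value for value, _ in inputs_bdim))
    raise ValueError("sparse vmap rules only support inputs whose batch axis is None")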

View File

@@ -259,6 +259,7 @@ from .linalg_func import (
from .sparse_func import (
dense_to_sparse_coo,
dense_to_sparse_csr,
csr_to_coo,
)
from .random_func import (
standard_laplace,

View File

@@ -15,14 +15,17 @@
"""Defines sparse operators with functional form."""
from ..operations.sparse_ops import DenseToCSRSparseMatrix
from ..operations.sparse_ops import DenseToCSRSparseMatrix, CSRSparseMatrixToSparseTensor
from ..operations.array_ops import GatherNd
from ...common import CSRTensor, COOTensor, Tensor
from ...common import dtype as mstype
from ..composite.multitype_ops._constexpr_utils import raise_value_error, raise_type_error
gather_nd = GatherNd()
dense_to_csr = DenseToCSRSparseMatrix()
csr_sparse_matrix_to_sparse_tensor = CSRSparseMatrixToSparseTensor()
batch_csr_pointers_empty = Tensor([0, -1], dtype=mstype.int32)
def dense_to_sparse_coo(tensor):
@@ -56,7 +59,7 @@ def dense_to_sparse_coo(tensor):
>>> print(output)
"""
if not isinstance(tensor, Tensor):
raise_type_error("For functional operator dense_to_sparse_coo, input argument msut be a Tensor.")
raise_type_error("For functional operator dense_to_sparse_coo, input argument must be a Tensor.")
if len(tensor.shape) != 2:
raise_value_error("Currently only support 2-D Tensor when converting to COOTensor.")
indices = tensor.nonzero().astype("int32")
@@ -96,16 +99,57 @@ def dense_to_sparse_csr(tensor):
>>> print(output)
"""
if not isinstance(tensor, Tensor):
raise_type_error("For functional operator dense_to_sparse_csr, input argument msut be a Tensor.")
raise_type_error("For functional operator dense_to_sparse_csr, input argument must be a Tensor.")
if len(tensor.shape) != 2:
raise_value_error("Currently only support 2-D Tensor when converting to CSRTensor.")
indices = tensor.nonzero().astype("int32")
_, _, indptr, indices, values = dense_to_csr(tensor, indices)
return CSRTensor(indptr, indices, values, tensor.shape)
def csr_to_coo(tensor):
"""
Converts a CSRTensor to COOTensor.
Note:
Only 2-D CSRTensor is supported for now.
Args:
tensor: A CSRTensor, must be 2-D.
Returns:
2D COOTensor, the input tensor stored in COO format.
Raises:
TypeError: If input is not a CSRTensor.
ValueError: If input tensor is not 2-D.
Supported Platforms:
``GPU``
Examples:
>>> from mindspore import Tensor, CSRTensor
>>> from mindspore import ops
>>> import mindspore as ms
>>> indptr = Tensor([0, 1, 2, 2], dtype=ms.int32)
>>> indices = Tensor([3, 0], dtype=ms.int32)
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> shape = (3, 4)
>>> x = CSRTensor(indptr, indices, values, shape)
>>> output = ops.csr_to_coo(x)
>>> print(output)
"""
if not isinstance(tensor, CSRTensor):
raise_type_error("For functional operator csr_to_coo, input argument must be a CSRTensor.")
if len(tensor.shape) != 2:
raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
shape = tensor.shape
indices, values, _ = csr_sparse_matrix_to_sparse_tensor(Tensor(shape, dtype=mstype.int32), batch_csr_pointers_empty,
tensor.indptr, tensor.indices, tensor.values)
return COOTensor(indices, values, shape)
__all__ = [
'dense_to_sparse_coo',
'dense_to_sparse_csr'
'dense_to_sparse_csr',
'csr_to_coo'
]
__all__.sort()
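One way to sanity-check the functional path added above is against SciPy's own CSR-to-COO conversion; a sketch assuming a GPU context where ops.csr_to_coo is supported (the variable names are illustrative):

import numpy as np
import scipy.sparse as sp
from mindspore import Tensor, CSRTensor, ops

dense = np.array([[0., 5., 0.], [7., 0., 0.]], dtype=np.float32)
csr = sp.csr_matrix(dense)
x = CSRTensor(Tensor(csr.indptr.astype(np.int32)), Tensor(csr.indices.astype(np.int32)),
              Tensor(csr.data), dense.shape)
out = ops.csr_to_coo(x)
ref = csr.tocoo()
assert np.array_equal(out.indices.asnumpy(), np.stack([ref.row, ref.col], axis=1))
assert np.allclose(out.values.asnumpy(), ref.data)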

View File

@@ -1011,6 +1011,7 @@ tensor_operator_registry.register('dense_to_sparse_csr', dense_to_sparse_csr)
tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
tensor_operator_registry.register('narrow', narrow)
tensor_operator_registry.register('sort', sort)
tensor_operator_registry.register('csr_to_coo', csr_to_coo)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('unsorted_segment_min', unsorted_segment_min)
tensor_operator_registry.register('unsorted_segment_max', unsorted_segment_max)
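This registration is what the tensor-level conversion goes through (assuming, as with the other registered entries, that CSRTensor.to_coo dispatches via tensor_operator_registry); a minimal sketch:

>>> import mindspore as ms
>>> from mindspore import Tensor, CSRTensor
>>> x = CSRTensor(Tensor([0, 1, 2], dtype=ms.int32), Tensor([0, 1], dtype=ms.int32),
...               Tensor([1., 2.], dtype=ms.float32), (2, 4))
>>> coo = x.to_coo()   # expected indices [[0, 0], [1, 1]], values [1., 2.]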

View File

@@ -224,6 +224,7 @@ class CSRSparseMatrixToSparseTensor(Primitive):
ValueError: If shape of `x_col_indices` is not corresponding to shape of `x_values`.
Supported Platforms:
``GPU``
Examples:
>>> x_dense_shape = Tensor(np.array([2, 2, 4]).astype(np.int64))

View File

@@ -0,0 +1,157 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import scipy.sparse
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor
def generate_data(shape, datatype="float32", indicetype="int32", density=0.2):
data_shape = shape[-2:]
shape_tensor = np.array(shape, dtype=indicetype)
is_batch_csr = len(shape) == 3
batch_size = shape[0] if is_batch_csr else 1
accum_nnz = 0
x_batch_pointers = np.array(0, dtype=indicetype)
coo_indices = []
for i in range(batch_size):
csr_matrix = scipy.sparse.random(data_shape[0], data_shape[1], format="csr",
density=density, dtype=indicetype)
row_pointers = np.asarray(csr_matrix.indptr, dtype=indicetype)
col_indices = np.asarray(csr_matrix.indices, dtype=indicetype)
values = np.asarray(csr_matrix.data, dtype=datatype)
coo_tensor = csr_matrix.tocoo()
indices = np.stack(
(np.asarray(coo_tensor.row, dtype=indicetype), np.asarray(coo_tensor.col, dtype=indicetype)), axis=1)
if is_batch_csr:
indices = np.insert(indices, 0, i, axis=1)
coo_indices.append(indices)
if i == 0:
x_row_pointers = row_pointers
x_col_indices = col_indices
x_values = values
else:
x_row_pointers = np.append(x_row_pointers, row_pointers)
x_col_indices = np.append(x_col_indices, col_indices)
x_values = np.append(x_values, values)
accum_nnz += csr_matrix.nnz
x_batch_pointers = np.append(x_batch_pointers, accum_nnz)
output_indices = np.concatenate(coo_indices)
x_batch_pointers = x_batch_pointers.astype(indicetype)
return ((shape_tensor, x_batch_pointers, x_row_pointers, x_col_indices, x_values),
(output_indices, x_values, shape_tensor))
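# For reference, the layout generate_data() produces for a tiny batched case
# (hand-worked illustrative values, not used by the tests below):
#   batch 0 = [[0, 5, 0], [7, 0, 0]]  ->  indptr [0, 1, 2], cols [1, 0], vals [5, 7]
#   batch 1 = [[0, 0, 0], [0, 0, 9]]  ->  indptr [0, 0, 1], cols [2],    vals [9]
#   x_dense_shape    = [2, 2, 3]
#   x_batch_pointers = [0, 2, 3]              # prefix sum of per-batch nnz
#   x_row_pointers   = [0, 1, 2, 0, 0, 1]     # per-batch indptr, concatenated
#   x_col_indices    = [1, 0, 2]
#   x_values         = [5, 7, 9]
#   expected COO indices = [[0, 0, 1], [0, 1, 0], [1, 1, 2]]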
def compare_res(res, expected):
assert len(res) == len(expected)
for r, e in zip(res, expected):
assert np.allclose(r.asnumpy(), e)
class CSRToCOONet(nn.Cell):
def __init__(self):
super(CSRToCOONet, self).__init__()
self.to_coo = CSRSparseMatrixToSparseTensor()
def construct(self, shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values):
return self.to_coo(shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values)
class DynamicShapeCSRToCOONet(nn.Cell):
def __init__(self):
super(DynamicShapeCSRToCOONet, self).__init__()
self.unique = P.Unique()
self.to_coo = CSRSparseMatrixToSparseTensor()
def construct(self, shape, x_batch_pointers, x_row_pointers, x_col_indices, x_values):
unique_col_indices, _ = self.unique(x_col_indices)
unique_values, _ = self.unique(x_values)
return self.to_coo(shape, x_batch_pointers, x_row_pointers, unique_col_indices, unique_values)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_2d_csr_to_coo():
"""
Feature: Test 2D CSR tensor to COO tensor.
Description: Test 2D CSR tensor (without batch dimension) to COO tensor.
Expectation: Success.
"""
inputs, expects = generate_data((5, 10))
input_tensors = [Tensor(x) for x in inputs]
net = CSRToCOONet()
outputs = net(*input_tensors)
compare_res(outputs, expects)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_csr_to_coo():
"""
Feature: Test 3D CSR tensor to COO tensor.
Description: Test 3D CSR tensor (with batch dimension) to COO tensor.
Expectation: Success.
"""
inputs, expects = generate_data((3, 5, 10))
input_tensors = [Tensor(x) for x in inputs]
net = CSRToCOONet()
outputs = net(*input_tensors)
compare_res(outputs, expects)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_csr_to_coo_fp64():
"""
Feature: Test 3D CSR tensor to COO tensor.
Description: Test 3D CSR tensor (with batch dimension, fp64) to COO tensor.
Expectation: Success.
"""
inputs, expects = generate_data((3, 5, 10), datatype="float64")
input_tensors = [Tensor(x) for x in inputs]
net = CSRToCOONet()
outputs = net(*input_tensors)
compare_res(outputs, expects)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_shape_csr_to_coo():
"""
Feature: Test dynamic shape.
Description: Test CSR tensor to COO tensor.
Expectation: Success.
"""
shape = (3, 10)
x_batch_pointers = Tensor([0, -1], dtype=mstype.int32)
indptr = Tensor([0, 2, 6, 9], dtype=mstype.int32)
indices = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=mstype.int32)
values = np.random.rand(9)
values = np.sort(values)
net = DynamicShapeCSRToCOONet()
outputs = net(Tensor(shape, dtype=mstype.int32), x_batch_pointers, indptr, indices,
Tensor(values, dtype=mstype.float32))
coo_indices = np.array([[0, 1], [0, 2], [1, 3], [1, 4], [1, 5], [1, 6], [2, 7], [2, 8], [2, 9]])
compare_res(outputs, (coo_indices, values, np.array(shape)))

View File

@@ -559,7 +559,7 @@ def test_dtype_csr_tensor():
def test_bprop():
"""
Feature: Test back-propagation with CSR-related Ops.
Description: Test CSRReduceSum, CSRMul, CSRDiv, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
Description: Test CSRReduceSum, CSRMul, CSRDiv, CSRMV.
Expectation: Success.
"""
if get_platform() != "linux":