diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cu new file mode 100644 index 00000000000..f220e812c28 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cu @@ -0,0 +1,194 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sparse_to_dense_impl.cuh" +#include +#include +#include +#include +#include "include/cuda_fp16.h" + +template +__global__ void SetDefaultValue(const T default_value, const int64_t output_elements, T *output) { + for (size_t ops = blockIdx.x * blockDim.x + threadIdx.x; ops < output_elements; ops += blockDim.x * gridDim.x) { + output[ops] = default_value; + } +} + +template +void CallSetDefaultValue(const T default_value, const int64_t output_elements, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + SetDefaultValue<<>>( + default_value, output_elements, output); + return; +} + +template +__global__ void SparseToDense(const Index *indices, const T *vals, const int num_elems, const int num_vals, + const Index *output_shape, const int ndims, T *output) { + for (size_t ops = blockIdx.x * blockDim.x + threadIdx.x; ops < num_elems; ops += blockDim.x * gridDim.x) { + int64_t output_idx = indices[ops * ndims + ndims - 1]; + Index strides = 1; + for (int i = ndims - 2; i >= 0; i--) { + strides *= output_shape[i + 1]; + output_idx += indices[ops * ndims + i] * strides; + } + // If num_vals == 1, broadcast the scalar to the positions for non-zeros. + output[output_idx] = vals[(num_vals == 1) ? 
0 : ops]; + } + __syncthreads(); + return; +} + +template +void CallSparseToDense(const Index *indices, const T *vals, const int num_elems, const int num_vals, + const Index *output_shape, const int ndims, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream) { + SparseToDense<<>>( + indices, vals, num_elems, num_vals, output_shape, ndims, output); + return; +} + +template CUDA_LIB_EXPORT void CallSetDefaultValue(bool default_value, const int64_t output_elements, bool *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(int8_t default_value, const int64_t output_elements, + int8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(int16_t default_value, const int64_t output_elements, + int16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(int32_t default_value, const int64_t output_elements, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(int64_t default_value, const int64_t output_elements, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(uint8_t default_value, const int64_t output_elements, + uint8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(uint16_t default_value, const int64_t output_elements, + uint16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(half default_value, const int64_t output_elements, half *output, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(float default_value, const int64_t output_elements, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSetDefaultValue(double default_value, const int64_t output_elements, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const bool *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + bool *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const int8_t *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + int8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const int16_t *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + int16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const int32_t *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const int64_t *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const uint8_t *vals, + const 
int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + uint8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const uint16_t *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + uint16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const half *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + half *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const float *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int32_t *indices, const double *vals, + const int num_elems, const int num_vals, + const int32_t *output_shape, const int ndims, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const bool *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + bool *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const int8_t *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + int8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const int16_t *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + int16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const int32_t *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + int32_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const int64_t *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + int64_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const uint8_t *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + uint8_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const uint16_t *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + uint16_t *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const half *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + half *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CallSparseToDense(const int64_t *indices, const float *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + float *output, const uint32_t &device_id, + cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void 
CallSparseToDense(const int64_t *indices, const double *vals, + const int num_elems, const int num_vals, + const int64_t *output_shape, const int ndims, + double *output, const uint32_t &device_id, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cuh new file mode 100644 index 00000000000..1b1498d286f --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TO_DENSE_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TO_DENSE_CUH_ +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" +template +void CallSetDefaultValue(T default_value, const int64_t output_elements, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +template +void CallSparseToDense(const Index *indices, const T *vals, const int num_elems, const int num_vals, + const Index *output_shape, const int ndims, T *output, const uint32_t &device_id, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TO_DENSE_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.cc new file mode 100644 index 00000000000..e566fd51023 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.cc @@ -0,0 +1,399 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mindspore/core/ops/base_operator.h" +#include "mindspore/core/abstract/utils.h" +#include "plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseToDenseV2InputsNum = 4; +constexpr size_t kSparseToDenseV2OutputsNum = 1; +constexpr size_t kSparseToDenseV2First = 0; +constexpr size_t kSparseToDenseV2Second = 1; +constexpr size_t kSparseToDenseV2Third = 2; +constexpr size_t kSparseToDenseV2Fourth = 3; +constexpr size_t kSparseToDenseV2TwoDims = 2; +constexpr size_t kSparseToDenseV2OneDim = 1; +constexpr size_t kSparseToDenseV2ZeroDim = 0; +} // namespace + +bool SparseToDenseV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + auto kernel_ptr_ = std::dynamic_pointer_cast(base_operator); + kernel_name_ = kernel_ptr_->name(); + validate_indices_ = kernel_ptr_->get_validate_indices(); + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseToDenseV2InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseToDenseV2OutputsNum, kernel_name_); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + indice_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first); + value_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first); + return true; +} + +int SparseToDenseV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) { + for (const auto &input : inputs) { + // If any input shape contains -1, means input shape is dynamic, so just return do nothing. + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + ResetResource(); + indices_shape_ = std::vector(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(), + inputs.at(kIndex0)->GetDeviceShapeAdaptively().end()); + output_shape_ = std::vector(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(), + inputs.at(kIndex1)->GetDeviceShapeAdaptively().end()); + std::vector input_shape_values = std::vector(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(), + inputs.at(kIndex2)->GetDeviceShapeAdaptively().end()); + indices_dims_ = indices_shape_.size(); + ndims = indices_shape_.size() > 1 ? indices_shape_[1] : 1; + num_elems = indices_shape_.size() > 0 ? 
indices_shape_[0] : 1; + values_size_ = input_shape_values[0]; + output_elements = 1; + std::vector output_shape = outputs.at(kIndex0)->GetShapeVector(); + for (size_t i = 0; i < output_shape.size(); ++i) { + output_elements *= output_shape[i]; + } + input_elements_indices = std::accumulate(indices_shape_.begin(), indices_shape_.end(), 1, std::multiplies()); + input_elements_values = + std::accumulate(input_shape_values.begin(), input_shape_values.end(), 1, std::multiplies()); + input_elements_output_shape = + std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies()); + size_t input_size_indices = input_elements_indices * indice_size_; + size_t input_size_values = input_elements_values * value_size_; + size_t input_size_output_shape = input_elements_output_shape * indice_size_; + size_t output_size = output_elements * value_size_; + input_size_list_.push_back(input_size_indices); + input_size_list_.push_back(input_size_values); + input_size_list_.push_back(input_size_output_shape); + output_size_list_.push_back(output_size); + return KRET_OK; +} + +void SparseToDenseV2GpuKernelMod::ResetResource() noexcept { + output_elements = 1; + input_elements_indices = 0; + input_elements_values = 0; + input_elements_output_shape = 0; + is_null_input_ = false; + input_size_list_.clear(); + output_size_list_.clear(); +} + +template +void SparseToDenseV2GpuKernelMod::CheckValidateTwoDim(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseToDenseV2InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseToDenseV2OutputsNum, kernel_name_); + if (outputs[0]->size == 0) { + MS_LOG(WARNING) << "For '" << kernel_name_ << "', output memory size should be greater than 0, but got 0."; + } + I *input_indices = GetDeviceAddress(inputs, kIndex0); + I *indices_addr = reinterpret_cast(malloc(input_elements_indices * indice_size_)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_addr, input_indices, input_elements_indices * indice_size_, cudaMemcpyDeviceToHost, + reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync indices failed"); + + I *input_output_shape = GetDeviceAddress(inputs, kIndex1); + I *output_shape_addr = reinterpret_cast(malloc(input_elements_output_shape * indice_size_)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(output_shape_addr, input_output_shape, input_elements_output_shape * indice_size_, + cudaMemcpyDeviceToHost, reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync dense_shape failed"); + bool valid = true; + bool different = false; + bool increasing = true; + for (size_t k = 0; k < indices_shape_[1]; ++k) { + size_t index = k; + if (indices_addr[index] < 0 || indices_addr[index] >= output_shape_addr[index]) { + valid = false; + } + } + if (!valid) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is out of bounds."; + } + for (size_t i = 1; i < indices_shape_[0]; ++i) { + for (size_t j = 0; j < indices_shape_[1]; ++j) { + size_t index1 = i * indices_shape_[1] + j; + size_t index2 = (i - 1) * indices_shape_[1] + j; + if (indices_addr[index1] < 0 || indices_addr[index1] >= output_shape_addr[j]) { + valid = false; + } + I diff = indices_addr[index1] - indices_addr[index2]; + if (diff > 0) { + different = true; + } + if (!different && diff < 0) { + increasing = false; + } + } + if (!valid) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is out of bounds."; + } + if (!increasing) { + MS_LOG(EXCEPTION) << 
"For '" << kernel_name_ << "', the indices is out of order."; + } + if (!different) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is repeated"; + } + } +} + +template +void SparseToDenseV2GpuKernelMod::CheckValidateOneDim(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseToDenseV2InputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseToDenseV2OutputsNum, kernel_name_); + if (outputs[0]->size == 0) { + MS_LOG(WARNING) << "For '" << kernel_name_ << "', output memory size should be greater than 0, but got 0."; + } + I *input_indices = GetDeviceAddress(inputs, kIndex0); + I *indices_addr = reinterpret_cast(malloc(input_elements_indices * indice_size_)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(indices_addr, input_indices, input_elements_indices * indice_size_, cudaMemcpyDeviceToHost, + reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync indices failed"); + + I *input_output_shape = GetDeviceAddress(inputs, kIndex1); + I *output_shape_addr = reinterpret_cast(malloc(input_elements_output_shape * indice_size_)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(output_shape_addr, input_output_shape, input_elements_output_shape * indice_size_, + cudaMemcpyDeviceToHost, reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync dense_shape failed"); + bool valid = true; + bool different = false; + bool increasing = true; + if (indices_addr[0] < 0 || indices_addr[0] > output_shape_addr[0]) { + valid = false; + } + for (size_t i = 1; i < indices_shape_[0]; ++i) { + if (indices_addr[i] < 0 || indices_addr[i] >= output_shape_addr[0]) { + valid = false; + } + I diff = indices_addr[i] - indices_addr[i - 1]; + if (diff > 0) { + different = true; + } + if (!different && diff < 0) { + increasing = false; + } + if (!valid) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is out of bounds."; + } + if (!increasing) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is out of order."; + } + if (!different) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is repeated"; + } + } +} + +template +bool SparseToDenseV2GpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (validate_indices_ == true && indices_dims_ == kSparseToDenseV2TwoDims) { + (void)SparseToDenseV2GpuKernelMod::CheckValidateTwoDim(inputs, workspace, outputs); + } else if (validate_indices_ == true && indices_dims_ == kSparseToDenseV2OneDim) { + (void)SparseToDenseV2GpuKernelMod::CheckValidateOneDim(inputs, workspace, outputs); + } + I *input_indices = GetDeviceAddress(inputs, kIndex0); + I *input_output_shape = GetDeviceAddress(inputs, kIndex1); + T *input_values = GetDeviceAddress(inputs, kIndex2); + T *input_default_value = GetDeviceAddress(inputs, kIndex3); + T *output = GetDeviceAddress(outputs, kIndex0); + + T *default_value_data = reinterpret_cast(malloc(value_size_)); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(default_value_data, input_default_value, value_size_, cudaMemcpyDeviceToHost, + reinterpret_cast(cuda_stream_)), + "cudaMemcpyAsync default_value failed"); + + auto cuda_stream = reinterpret_cast(cuda_stream_); + CallSetDefaultValue(default_value_data[0], output_elements, output, device_id_, cuda_stream); + CallSparseToDense(input_indices, input_values, num_elems, input_elements_values, input_output_shape, ndims, output, + device_id_, 
cuda_stream); + return true; +} + +std::vector> + SparseToDenseV2GpuKernelMod::func_list_ = {{KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeBool), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeBool) + .AddOutputAttr(kNumberTypeBool), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + 
.AddOutputAttr(kNumberTypeInt64), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &SparseToDenseV2GpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + &SparseToDenseV2GpuKernelMod::LaunchKernel}}; + +std::vector SparseToDenseV2GpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { + return pair.first; + }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseToDenseV2, SparseToDenseV2GpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.h new file mode 100644 index 00000000000..55d061cf907 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.h @@ -0,0 +1,94 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_TO_DENSE_V2_GPU_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_TO_DENSE_V2_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ops/sparse_to_dense_v2.h" +#include "abstract/utils.h" +#include "plugin/factory/ms_factory.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cuh" + +namespace mindspore { +namespace kernel { +class SparseToDenseV2GpuKernelMod : public NativeGpuKernelMod { + public: + SparseToDenseV2GpuKernelMod() { ResetResource(); } + ~SparseToDenseV2GpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + return kernel_func_(this, inputs, workspace, outputs, stream_ptr); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + protected: + std::vector GetOpSupport() override; + + private: + void ResetResource() noexcept; + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + template + void CheckValidateOneDim(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + template + void CheckValidateTwoDim(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs); + + using SparseToDenseV2LaunchFunc = + std::function &, const std::vector &, + const std::vector &, void *)>; + static std::vector> func_list_; + SparseToDenseV2LaunchFunc kernel_func_{}; + size_t indice_size_{1}; + size_t value_size_{1}; + size_t input_elements_indices; + size_t input_elements_values; + size_t input_elements_output_shape; + size_t output_elements; + int ndims; + int num_elems; + bool is_null_input_{false}; + void *cuda_stream_{nullptr}; + bool validate_indices_{true}; + std::vector indices_shape_; + std::vector output_shape_; + size_t indices_dims_{0}; + size_t values_size_{0}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_TO_DENSE_V2_GPU_KERNEL_H_ diff --git a/mindspore/core/ops/sparse_to_dense_v2.cc b/mindspore/core/ops/sparse_to_dense_v2.cc index 6c17cc8f408..eea3f7716be 100644 --- a/mindspore/core/ops/sparse_to_dense_v2.cc +++ b/mindspore/core/ops/sparse_to_dense_v2.cc @@ -29,17 +29,14 @@ namespace mindspore { namespace ops { namespace { -namespace { -constexpr size_t kIndiceselement = 2; -constexpr size_t kOutShapeSize = 1; -constexpr size_t kValuesSize = 1; -constexpr size_t kDefaultSize = 0; -constexpr size_t kDefaultElem = 1; -} // namespace abstract::ShapePtr SparseToDenseV2InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); + const size_t Indiceselement = 2; + const size_t OutShapeSize = 1; + const size_t ValuesSize = 1; + const size_t DefaultSize = 0; auto max_length_ptr = primitive->GetAttr("max_length"); MS_EXCEPTION_IF_NULL(max_length_ptr); int64_t max_length = 
GetValue(max_length_ptr); @@ -51,26 +48,13 @@ abstract::ShapePtr SparseToDenseV2InferShape(const PrimitivePtr &primitive, auto output_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(output_shape_shape_ptr)[kShape]; auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(values_shape_ptr)[kShape]; auto default_value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(default_value_shape_ptr)[kShape]; - (void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kLessEqual, kIndiceselement, + (void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kLessEqual, Indiceselement, prim_name); - (void)CheckAndConvertUtils::CheckInteger("outshape dimension", output_shape_shape.size(), kEqual, kOutShapeSize, + (void)CheckAndConvertUtils::CheckInteger("outshape dimension", output_shape_shape.size(), kEqual, OutShapeSize, prim_name); - (void)CheckAndConvertUtils::CheckInteger("values dimension", values_shape.size(), kLessEqual, kValuesSize, prim_name); - (void)CheckAndConvertUtils::CheckInteger("default_value dimension", default_value_shape.size(), kEqual, kDefaultSize, + (void)CheckAndConvertUtils::CheckInteger("values dimension", values_shape.size(), kLessEqual, ValuesSize, prim_name); + (void)CheckAndConvertUtils::CheckInteger("default_value dimension", default_value_shape.size(), kEqual, DefaultSize, prim_name); - if (indices_shape.size() == 0) { - if (values_shape.size() != 0 && values_shape[0] != 1) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] is 1" - << " should match the the values element " << values_shape[0] << "."; - } - } else { - if (values_shape.size() != 0) { - if (indices_shape[0] != values_shape[0]) { - MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] " << indices_shape[0] - << " should match the the values element " << values_shape[0] << "."; - } - } - } size_t output_shape_numelement = output_shape_shape[0]; auto output_shape = input_args[1]->cast(); MS_EXCEPTION_IF_NULL(output_shape); @@ -85,6 +69,19 @@ abstract::ShapePtr SparseToDenseV2InferShape(const PrimitivePtr &primitive, MS_EXCEPTION_IF_NULL(output_shape_type_element); std::vector y_shape; if (!input_args[1]->BuildValue()->isa() && !input_args[1]->BuildValue()->isa()) { + if (indices_shape.size() == 0) { + if (values_shape.size() != 0 && values_shape[0] != 1) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] is 1" + << " should match the the values element " << values_shape[0] << "."; + } + } else { + if (values_shape.size() != 0) { + if (indices_shape[0] != values_shape[0]) { + MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] " << indices_shape[0] + << " should match the the values element " << values_shape[0] << "."; + } + } + } if (output_shape_type_element->type_id() == kNumberTypeInt32) { auto output_shape_data = reinterpret_cast(output_shape_tensor->data_c()); for (size_t i = 0; i < output_shape_numelement; ++i) { diff --git a/mindspore/python/mindspore/ops/_grad/grad_sparse.py b/mindspore/python/mindspore/ops/_grad/grad_sparse.py index 2817d8ab0cd..972c0cc1028 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_sparse.py +++ b/mindspore/python/mindspore/ops/_grad/grad_sparse.py @@ -25,6 +25,7 @@ from .. 
import operations as P from ..operations import _csr_ops from ..operations.sparse_ops import SparseAdd, CSRSparseMatrixToDense, CSRSparseMatrixToSparseTensor, \ DenseToCSRSparseMatrix +from ..operations.sparse_ops import SparseToDenseV2 # Unused parameters are placeholders. @@ -65,6 +66,19 @@ def get_bprop_sparse_to_dense(self): return bprop +@bprop_getters.register(SparseToDenseV2) +def get_bprop_sparse_to_dense_v2(self): + """Generate bprop for SparseToDenseV2""" + + def bprop(indices, output_shape, values, default_value, out, dout): + sparse_values_grad = F.gather_nd(dout, indices) + default_value_grad = F.reduce_sum(dout) - F.reduce_sum(sparse_values_grad) + result_all = (zeros_like(indices), zeros_like(output_shape), sparse_values_grad, default_value_grad) + return result_all + + return bprop + + @bprop_getters.register(P.SparseTensorDenseMatmul) def get_bprop_sparse_tensor_dense_matmul(self): """Generate bprop for SparseTensorDenseMatmul""" diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py index 92e0afb095d..d785cc6b4ef 100644 --- a/mindspore/python/mindspore/ops/operations/sparse_ops.py +++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py @@ -327,13 +327,13 @@ class SparseToDenseV2(Primitive): Raises: TypeError: If the dtype of `indices` is neither Int32 nor Int64. TypeError: If the dtype of `outputshape` is neither Int32 nor Int64. - ValueError: If the shape of `output_shape`, shape of `indices`, shape of - `default_value` and shape of `values` don't meet the parameter description. + ValueError: If the shape of `output_shape`, shape of `indices`, + shape of `default_value` and shape of `values` don't meet the parameter description. ValueError: If each Element of `output_shape` is not > 0. ValueError: If the shape[0] of `indices` don't match with the element of `values`. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32) diff --git a/tests/st/ops/gpu/test_sparse_to_dense_op.py b/tests/st/ops/gpu/test_sparse_to_dense_op.py new file mode 100644 index 00000000000..3bef8c9245d --- /dev/null +++ b/tests/st/ops/gpu/test_sparse_to_dense_op.py @@ -0,0 +1,77 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations.sparse_ops import SparseToDenseV2 +from mindspore.common.api import ms_function +import mindspore.common.dtype as mstype + + +class SparseToDenseNet(nn.Cell): + def __init__(self): + super(SparseToDenseNet, self).__init__() + self.sparsetodense = SparseToDenseV2() + + @ms_function + def construct(self, indices, output_shape, values, default_value): + return self.sparsetodense(indices, output_shape, values, default_value) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparsetodense_2d_int32(): + """ + Feature: Converts a sparse representation into a dense tensor. + Description: 2D , int32 + Expectation: success + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + indices = Tensor(np.array([[0, 1]]).astype(np.int32)) + output_shape = Tensor(np.array([2, 2]).astype(np.int32)) + values = Tensor(np.array([1]).astype(np.int32)) + default_value = Tensor(0, dtype=mstype.int32) + net = SparseToDenseNet() + output = net(indices, output_shape, values, default_value) + sparse_expect = np.array([[0, 1], + [0, 0]]).astype(np.int32) + assert (output.asnumpy() == sparse_expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sparsetodense_2d_double(): + """ + Feature: Converts a sparse representation into a dense tensor. + Description: 2D , double + Expectation: success + """ + for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]: + context.set_context(mode=mode, device_target="GPU") + indices = Tensor(np.array([[0, 1]]).astype(np.int32)) + output_shape = Tensor(np.array([2, 2]).astype(np.int32)) + values = Tensor(np.array([1.0]).astype(np.double)) + default_value = Tensor(0.0, dtype=mstype.double) + net = SparseToDenseNet() + output = net(indices, output_shape, values, default_value) + sparse_expect = np.array([[0.0, 1.0], + [0.0, 0.0]]).astype(np.double) + assert (output.asnumpy() == sparse_expect).all()
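The `SparseToDense` CUDA kernel in this patch flattens each N-d coordinate into a row-major offset over `output_shape` (last dimension varies fastest) and broadcasts `values` to every listed position when it holds a single element. Below is a minimal NumPy sketch of that indexing, handy for deriving expected test outputs; it is not part of the patch, and the helper name `sparse_to_dense_reference` is illustrative.

```python
import numpy as np


def sparse_to_dense_reference(indices, output_shape, values, default_value):
    # Mirror the CUDA kernel: row-major strides over output_shape, last dim fastest.
    indices = np.asarray(indices).reshape(len(indices), -1)   # (num_elems, ndims)
    values = np.asarray(values).reshape(-1)                   # scalar or (num_elems,)
    out = np.full(int(np.prod(output_shape)), default_value, dtype=values.dtype)
    strides = np.ones(len(output_shape), dtype=np.int64)
    for i in range(len(output_shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * output_shape[i + 1]
    for ops, idx in enumerate(indices):
        flat = int(np.dot(idx, strides))
        # If values has a single element, broadcast it to every listed position.
        out[flat] = values[0] if values.size == 1 else values[ops]
    return out.reshape(output_shape)


print(sparse_to_dense_reference([[0, 1]], (2, 2), [1], 0))
# [[0 1]
#  [0 0]]
```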
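When `validate_indices` is true, `CheckValidateTwoDim` copies `indices` and the dense shape back to the host and rejects out-of-bounds, out-of-order, or repeated rows. Roughly, the rule it aims to enforce is that every row is in bounds and that rows are strictly increasing in lexicographic order. A hedged NumPy sketch of that rule follows; the function name `validate_indices_2d` is illustrative and not part of the patch.

```python
import numpy as np


def validate_indices_2d(indices, output_shape):
    # Every coordinate must satisfy 0 <= indices[i][j] < output_shape[j].
    indices = np.asarray(indices)
    if np.any(indices < 0) or np.any(indices >= np.asarray(output_shape)):
        raise ValueError("the indices are out of bounds")
    # Rows must be strictly increasing lexicographically (no repeats, no reordering).
    for prev, cur in zip(indices[:-1], indices[1:]):
        if tuple(cur) == tuple(prev):
            raise ValueError("the indices are repeated")
        if tuple(cur) < tuple(prev):
            raise ValueError("the indices are out of order")


validate_indices_2d([[0, 1], [1, 2]], (2, 3))   # passes silently
```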
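The bprop registered for `SparseToDenseV2` in `grad_sparse.py` gathers `dout` at `indices` for the gradient of `values` and routes everything else to `default_value`, i.e. `sum(dout) - sum(values_grad)`, assuming each dense position is written at most once. A small NumPy check of that arithmetic, outside the patch itself; the variable names are illustrative.

```python
import numpy as np

indices = np.array([[0, 1], [1, 2]])
dout = np.arange(6, dtype=np.float64).reshape(2, 3)   # upstream gradient over the dense output

values_grad = dout[tuple(indices.T)]                  # same effect as gather_nd(dout, indices)
default_value_grad = dout.sum() - values_grad.sum()   # gradient flowing to the fill value

print(values_grad)          # [1. 5.]
print(default_value_grad)   # 9.0  (= 0 + 2 + 3 + 4)
```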