Add SparseToDenseV2

al_raya 2022-08-16 18:18:46 +08:00
parent 69202935e2
commit 0caaa95873
8 changed files with 832 additions and 27 deletions


@@ -0,0 +1,194 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sparse_to_dense_impl.cuh"
#include <thrust/execution_policy.h>
#include <thrust/device_vector.h>
#include <limits>
#include <algorithm>
#include "include/cuda_fp16.h"
template <typename T>
__global__ void SetDefaultValue(const T default_value, const int64_t output_elements, T *output) {
for (size_t ops = blockIdx.x * blockDim.x + threadIdx.x; ops < output_elements; ops += blockDim.x * gridDim.x) {
output[ops] = default_value;
}
}
template <typename T>
void CallSetDefaultValue(const T default_value, const int64_t output_elements, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
SetDefaultValue<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
default_value, output_elements, output);
return;
}
template <typename T, typename Index>
__global__ void SparseToDense(const Index *indices, const T *vals, const int num_elems, const int num_vals,
const Index *output_shape, const int ndims, T *output) {
for (size_t ops = blockIdx.x * blockDim.x + threadIdx.x; ops < num_elems; ops += blockDim.x * gridDim.x) {
int64_t output_idx = indices[ops * ndims + ndims - 1];
Index strides = 1;
for (int i = ndims - 2; i >= 0; i--) {
strides *= output_shape[i + 1];
output_idx += indices[ops * ndims + i] * strides;
}
    // If num_vals == 1, broadcast the single value to every position specified by `indices`.
output[output_idx] = vals[(num_vals == 1) ? 0 : ops];
}
return;
}
template <typename T, typename Index>
void CallSparseToDense(const Index *indices, const T *vals, const int num_elems, const int num_vals,
const Index *output_shape, const int ndims, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
SparseToDense<<<CUDA_BLOCKS(device_id, num_elems), CUDA_THREADS(device_id), 0, cuda_stream>>>(
indices, vals, num_elems, num_vals, output_shape, ndims, output);
return;
}
template CUDA_LIB_EXPORT void CallSetDefaultValue<bool>(bool default_value, const int64_t output_elements, bool *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<int8_t>(int8_t default_value, const int64_t output_elements,
int8_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<int16_t>(int16_t default_value, const int64_t output_elements,
int16_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<int32_t>(int32_t default_value, const int64_t output_elements,
int32_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<int64_t>(int64_t default_value, const int64_t output_elements,
int64_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<uint8_t>(uint8_t default_value, const int64_t output_elements,
uint8_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<uint16_t>(uint16_t default_value, const int64_t output_elements,
uint16_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<half>(half default_value, const int64_t output_elements, half *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<float>(float default_value, const int64_t output_elements,
float *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSetDefaultValue<double>(double default_value, const int64_t output_elements,
double *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<bool, int32_t>(const int32_t *indices, const bool *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
bool *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int8_t, int32_t>(const int32_t *indices, const int8_t *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
int8_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int16_t, int32_t>(const int32_t *indices, const int16_t *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
int16_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int32_t, int32_t>(const int32_t *indices, const int32_t *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
int32_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int64_t, int32_t>(const int32_t *indices, const int64_t *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
int64_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<uint8_t, int32_t>(const int32_t *indices, const uint8_t *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
uint8_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<uint16_t, int32_t>(const int32_t *indices, const uint16_t *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
uint16_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<half, int32_t>(const int32_t *indices, const half *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
half *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<float, int32_t>(const int32_t *indices, const float *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
float *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<double, int32_t>(const int32_t *indices, const double *vals,
const int num_elems, const int num_vals,
const int32_t *output_shape, const int ndims,
double *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<bool, int64_t>(const int64_t *indices, const bool *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
bool *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int8_t, int64_t>(const int64_t *indices, const int8_t *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
int8_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int16_t, int64_t>(const int64_t *indices, const int16_t *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
int16_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int32_t, int64_t>(const int64_t *indices, const int32_t *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
int32_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<int64_t, int64_t>(const int64_t *indices, const int64_t *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
int64_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<uint8_t, int64_t>(const int64_t *indices, const uint8_t *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
uint8_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<uint16_t, int64_t>(const int64_t *indices, const uint16_t *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
uint16_t *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<half, int64_t>(const int64_t *indices, const half *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
half *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<float, int64_t>(const int64_t *indices, const float *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
float *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CallSparseToDense<double, int64_t>(const int64_t *indices, const double *vals,
const int num_elems, const int num_vals,
const int64_t *output_shape, const int ndims,
double *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
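For reference, the index arithmetic in the SparseToDense kernel above (innermost stride 1, strides accumulated outward) is ordinary row-major raveling. A minimal NumPy sketch of the same computation; illustration only, the names below are not part of this commit:

import numpy as np

def sparse_to_dense_ref(indices, dense_shape, values, default_value):
    # indices: [num_elems, ndims]; values: [num_elems] or [1] (broadcast), as in the kernel.
    out = np.full(dense_shape, default_value)
    flat = out.reshape(-1)
    for i, idx in enumerate(indices):
        # Row-major flattening: the innermost dimension has stride 1.
        flat[np.ravel_multi_index(tuple(idx), dense_shape)] = values[i if len(values) > 1 else 0]
    return out

# sparse_to_dense_ref(np.array([[0, 1], [1, 2]]), (2, 3), np.array([7, 8]), 0)
# -> [[0 7 0]
#     [0 0 8]]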


@@ -0,0 +1,30 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TO_DENSE_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TO_DENSE_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
void CallSetDefaultValue(T default_value, const int64_t output_elements, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template <typename T, typename Index>
void CallSparseToDense(const Index *indices, const T *vals, const int num_elems, const int num_vals,
const Index *output_shape, const int ndims, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPARSE_TO_DENSE_CUH_


@@ -0,0 +1,399 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/core/ops/base_operator.h"
#include "mindspore/core/abstract/utils.h"
#include "plugin/device/gpu/kernel/sparse/sparse_to_dense_v2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseToDenseV2InputsNum = 4;
constexpr size_t kSparseToDenseV2OutputsNum = 1;
constexpr size_t kSparseToDenseV2First = 0;
constexpr size_t kSparseToDenseV2Second = 1;
constexpr size_t kSparseToDenseV2Third = 2;
constexpr size_t kSparseToDenseV2Fourth = 3;
constexpr size_t kSparseToDenseV2TwoDims = 2;
constexpr size_t kSparseToDenseV2OneDim = 1;
constexpr size_t kSparseToDenseV2ZeroDim = 0;
} // namespace
bool SparseToDenseV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::SparseToDenseV2>(base_operator);
kernel_name_ = kernel_ptr_->name();
validate_indices_ = kernel_ptr_->get_validate_indices();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseToDenseV2InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseToDenseV2OutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
indice_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
value_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex2).first);
return true;
}
int SparseToDenseV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
    // If any input shape contains -1, the input shape is dynamic, so just return and do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
indices_shape_ = std::vector<size_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
output_shape_ = std::vector<size_t>(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
std::vector<size_t> input_shape_values = std::vector<size_t>(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
indices_dims_ = indices_shape_.size();
ndims = indices_shape_.size() > 1 ? indices_shape_[1] : 1;
num_elems = indices_shape_.size() > 0 ? indices_shape_[0] : 1;
values_size_ = input_shape_values[0];
output_elements = 1;
std::vector<int64_t> output_shape = outputs.at(kIndex0)->GetShapeVector();
for (size_t i = 0; i < output_shape.size(); ++i) {
output_elements *= output_shape[i];
}
input_elements_indices = std::accumulate(indices_shape_.begin(), indices_shape_.end(), 1, std::multiplies<size_t>());
input_elements_values =
std::accumulate(input_shape_values.begin(), input_shape_values.end(), 1, std::multiplies<size_t>());
input_elements_output_shape =
std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<size_t>());
size_t input_size_indices = input_elements_indices * indice_size_;
size_t input_size_values = input_elements_values * value_size_;
size_t input_size_output_shape = input_elements_output_shape * indice_size_;
size_t output_size = output_elements * value_size_;
input_size_list_.push_back(input_size_indices);
input_size_list_.push_back(input_size_values);
input_size_list_.push_back(input_size_output_shape);
output_size_list_.push_back(output_size);
return KRET_OK;
}
void SparseToDenseV2GpuKernelMod::ResetResource() noexcept {
output_elements = 1;
input_elements_indices = 0;
input_elements_values = 0;
input_elements_output_shape = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
}
template <typename I, typename T>
void SparseToDenseV2GpuKernelMod::CheckValidateTwoDim(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseToDenseV2InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseToDenseV2OutputsNum, kernel_name_);
if (outputs[0]->size == 0) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', output memory size should be greater than 0, but got 0.";
}
I *input_indices = GetDeviceAddress<I>(inputs, kIndex0);
I *indices_addr = reinterpret_cast<I *>(malloc(input_elements_indices * indice_size_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(indices_addr, input_indices, input_elements_indices * indice_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync indices failed");
I *input_output_shape = GetDeviceAddress<I>(inputs, kIndex1);
I *output_shape_addr = reinterpret_cast<I *>(malloc(input_elements_output_shape * indice_size_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_shape_addr, input_output_shape, input_elements_output_shape * indice_size_,
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync dense_shape failed");
bool valid = true;
bool different = false;
bool increasing = true;
for (size_t k = 0; k < indices_shape_[1]; ++k) {
size_t index = k;
if (indices_addr[index] < 0 || indices_addr[index] >= output_shape_addr[index]) {
valid = false;
}
}
if (!valid) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices is out of bounds.";
}
for (size_t i = 1; i < indices_shape_[0]; ++i) {
for (size_t j = 0; j < indices_shape_[1]; ++j) {
size_t index1 = i * indices_shape_[1] + j;
size_t index2 = (i - 1) * indices_shape_[1] + j;
if (indices_addr[index1] < 0 || indices_addr[index1] >= output_shape_addr[j]) {
valid = false;
}
I diff = indices_addr[index1] - indices_addr[index2];
if (diff > 0) {
different = true;
}
if (!different && diff < 0) {
increasing = false;
}
}
    if (!valid) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices are out of bounds.";
    }
    if (!increasing) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices are out of order.";
    }
    if (!different) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices are repeated.";
    }
  }
  free(indices_addr);
  free(output_shape_addr);
}
template <typename I, typename T>
void SparseToDenseV2GpuKernelMod::CheckValidateOneDim(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseToDenseV2InputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseToDenseV2OutputsNum, kernel_name_);
if (outputs[0]->size == 0) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', output memory size should be greater than 0, but got 0.";
}
I *input_indices = GetDeviceAddress<I>(inputs, kIndex0);
I *indices_addr = reinterpret_cast<I *>(malloc(input_elements_indices * indice_size_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(indices_addr, input_indices, input_elements_indices * indice_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync indices failed");
I *input_output_shape = GetDeviceAddress<I>(inputs, kIndex1);
I *output_shape_addr = reinterpret_cast<I *>(malloc(input_elements_output_shape * indice_size_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output_shape_addr, input_output_shape, input_elements_output_shape * indice_size_,
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync dense_shape failed");
bool valid = true;
bool different = false;
bool increasing = true;
  if (indices_addr[0] < 0 || indices_addr[0] >= output_shape_addr[0]) {
valid = false;
}
for (size_t i = 1; i < indices_shape_[0]; ++i) {
if (indices_addr[i] < 0 || indices_addr[i] >= output_shape_addr[0]) {
valid = false;
}
I diff = indices_addr[i] - indices_addr[i - 1];
if (diff > 0) {
different = true;
}
if (!different && diff < 0) {
increasing = false;
}
    if (!valid) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices are out of bounds.";
    }
    if (!increasing) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices are out of order.";
    }
    if (!different) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the indices are repeated.";
    }
  }
  free(indices_addr);
  free(output_shape_addr);
}
template <typename I, typename T>
bool SparseToDenseV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  cuda_stream_ = stream_ptr;
  if (validate_indices_ && indices_dims_ == kSparseToDenseV2TwoDims) {
    (void)SparseToDenseV2GpuKernelMod::CheckValidateTwoDim<I, T>(inputs, workspace, outputs);
  } else if (validate_indices_ && indices_dims_ == kSparseToDenseV2OneDim) {
    (void)SparseToDenseV2GpuKernelMod::CheckValidateOneDim<I, T>(inputs, workspace, outputs);
  }
I *input_indices = GetDeviceAddress<I>(inputs, kIndex0);
I *input_output_shape = GetDeviceAddress<I>(inputs, kIndex1);
T *input_values = GetDeviceAddress<T>(inputs, kIndex2);
T *input_default_value = GetDeviceAddress<T>(inputs, kIndex3);
T *output = GetDeviceAddress<T>(outputs, kIndex0);
T *default_value_data = reinterpret_cast<T *>(malloc(value_size_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(default_value_data, input_default_value, value_size_, cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync default_value failed");
auto cuda_stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
  CallSetDefaultValue(default_value_data[0], output_elements, output, device_id_, cuda_stream);
  free(default_value_data);
CallSparseToDense(input_indices, input_values, num_elems, input_elements_values, input_output_shape, ndims, output,
device_id_, cuda_stream);
return true;
}
std::vector<std::pair<KernelAttr, SparseToDenseV2GpuKernelMod::SparseToDenseV2LaunchFunc>>
SparseToDenseV2GpuKernelMod::func_list_ = {{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, bool>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int32_t, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, bool>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, half>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseToDenseV2GpuKernelMod::LaunchKernel<int64_t, double>}};
std::vector<KernelAttr> SparseToDenseV2GpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SparseToDenseV2GpuKernelMod::SparseToDenseV2LaunchFunc> &pair) {
return pair.first;
});
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SparseToDenseV2, SparseToDenseV2GpuKernelMod);
} // namespace kernel
} // namespace mindspore
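As a reading aid for CheckValidateOneDim/CheckValidateTwoDim above: with validate_indices_ enabled they reject indices that fall outside the dense shape, are not in ascending lexicographic order, or repeat a previous row. A rough NumPy equivalent of the 2-D case, given only as a sketch and not part of this commit:

import numpy as np

def validate_indices_ref(indices, dense_shape):
    # indices: [num_elems, ndims]; dense_shape: [ndims] (2-D indices case).
    if np.any(indices < 0) or np.any(indices >= np.asarray(dense_shape)):
        raise ValueError("indices are out of bounds")
    for prev, cur in zip(indices[:-1], indices[1:]):
        if tuple(cur) < tuple(prev):
            raise ValueError("indices are out of order")
        if tuple(cur) == tuple(prev):
            raise ValueError("indices are repeated")

# validate_indices_ref(np.array([[0, 1], [1, 2]]), [2, 3])  # passes silently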


@@ -0,0 +1,94 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_TO_DENSE_V2_GPU_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_TO_DENSE_V2_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/sparse_to_dense_v2.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sparse_to_dense_impl.cuh"
namespace mindspore {
namespace kernel {
class SparseToDenseV2GpuKernelMod : public NativeGpuKernelMod {
public:
SparseToDenseV2GpuKernelMod() { ResetResource(); }
~SparseToDenseV2GpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void ResetResource() noexcept;
template <typename I, typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
template <typename I, typename T>
void CheckValidateOneDim(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
template <typename I, typename T>
void CheckValidateTwoDim(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using SparseToDenseV2LaunchFunc =
std::function<bool(SparseToDenseV2GpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, SparseToDenseV2LaunchFunc>> func_list_;
SparseToDenseV2LaunchFunc kernel_func_{};
size_t indice_size_{1};
size_t value_size_{1};
size_t input_elements_indices;
size_t input_elements_values;
size_t input_elements_output_shape;
size_t output_elements;
int ndims;
int num_elems;
bool is_null_input_{false};
void *cuda_stream_{nullptr};
bool validate_indices_{true};
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
size_t indices_dims_{0};
size_t values_size_{0};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_SPARSE_SPARSE_TO_DENSE_V2_GPU_KERNEL_H_


@@ -29,17 +29,14 @@
namespace mindspore {
namespace ops {
namespace {
namespace {
constexpr size_t kIndiceselement = 2;
constexpr size_t kOutShapeSize = 1;
constexpr size_t kValuesSize = 1;
constexpr size_t kDefaultSize = 0;
constexpr size_t kDefaultElem = 1;
} // namespace
abstract::ShapePtr SparseToDenseV2InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const size_t Indiceselement = 2;
const size_t OutShapeSize = 1;
const size_t ValuesSize = 1;
const size_t DefaultSize = 0;
auto max_length_ptr = primitive->GetAttr("max_length");
MS_EXCEPTION_IF_NULL(max_length_ptr);
int64_t max_length = GetValue<int64_t>(max_length_ptr);
@@ -51,26 +48,13 @@ abstract::ShapePtr SparseToDenseV2InferShape(const PrimitivePtr &primitive,
auto output_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(output_shape_shape_ptr)[kShape];
auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(values_shape_ptr)[kShape];
auto default_value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(default_value_shape_ptr)[kShape];
(void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kLessEqual, kIndiceselement,
(void)CheckAndConvertUtils::CheckInteger("indices dimension", indices_shape.size(), kLessEqual, Indiceselement,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("outshape dimension", output_shape_shape.size(), kEqual, kOutShapeSize,
(void)CheckAndConvertUtils::CheckInteger("outshape dimension", output_shape_shape.size(), kEqual, OutShapeSize,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("values dimension", values_shape.size(), kLessEqual, kValuesSize, prim_name);
(void)CheckAndConvertUtils::CheckInteger("default_value dimension", default_value_shape.size(), kEqual, kDefaultSize,
(void)CheckAndConvertUtils::CheckInteger("values dimension", values_shape.size(), kLessEqual, ValuesSize, prim_name);
(void)CheckAndConvertUtils::CheckInteger("default_value dimension", default_value_shape.size(), kEqual, DefaultSize,
prim_name);
if (indices_shape.size() == 0) {
if (values_shape.size() != 0 && values_shape[0] != 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] is 1"
<< " should match the the values element " << values_shape[0] << ".";
}
} else {
if (values_shape.size() != 0) {
if (indices_shape[0] != values_shape[0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] " << indices_shape[0]
<< " should match the the values element " << values_shape[0] << ".";
}
}
}
size_t output_shape_numelement = output_shape_shape[0];
auto output_shape = input_args[1]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(output_shape);
@@ -85,6 +69,19 @@ abstract::ShapePtr SparseToDenseV2InferShape(const PrimitivePtr &primitive,
MS_EXCEPTION_IF_NULL(output_shape_type_element);
std::vector<int64_t> y_shape;
if (!input_args[1]->BuildValue()->isa<AnyValue>() && !input_args[1]->BuildValue()->isa<None>()) {
if (indices_shape.size() == 0) {
if (values_shape.size() != 0 && values_shape[0] != 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] is 1"
<< " should match the the values element " << values_shape[0] << ".";
}
} else {
if (values_shape.size() != 0) {
if (indices_shape[0] != values_shape[0]) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the indices_shape[0] " << indices_shape[0]
<< " should match the the values element " << values_shape[0] << ".";
}
}
}
if (output_shape_type_element->type_id() == kNumberTypeInt32) {
auto output_shape_data = reinterpret_cast<int32_t *>(output_shape_tensor->data_c());
for (size_t i = 0; i < output_shape_numelement; ++i) {


@@ -25,6 +25,7 @@ from .. import operations as P
from ..operations import _csr_ops
from ..operations.sparse_ops import SparseAdd, CSRSparseMatrixToDense, CSRSparseMatrixToSparseTensor, \
DenseToCSRSparseMatrix
from ..operations.sparse_ops import SparseToDenseV2
# Unused parameters are placeholders.
@@ -65,6 +66,19 @@ def get_bprop_sparse_to_dense(self):
return bprop
@bprop_getters.register(SparseToDenseV2)
def get_bprop_sparse_to_dense_v2(self):
"""Generate bprop for SparseToDenseV2"""
def bprop(indices, output_shape, values, default_value, out, dout):
sparse_values_grad = F.gather_nd(dout, indices)
default_value_grad = F.reduce_sum(dout) - F.reduce_sum(sparse_values_grad)
result_all = (zeros_like(indices), zeros_like(output_shape), sparse_values_grad, default_value_grad)
return result_all
return bprop
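# Note on the formulas above: dout splits between the positions written from `values`
# and the positions filled with `default_value`, so gather_nd(dout, indices) is the
# values gradient and reduce_sum(dout) minus the sum of the gathered part is the
# default_value gradient. Quick NumPy check of that identity (illustrative only):
#   dout = np.array([[0.1, 0.2], [0.3, 0.4]]); indices = np.array([[0, 1]])
#   values_grad = dout[tuple(indices.T)]                   # [0.2]
#   default_value_grad = dout.sum() - values_grad.sum()    # 0.8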
@bprop_getters.register(P.SparseTensorDenseMatmul)
def get_bprop_sparse_tensor_dense_matmul(self):
"""Generate bprop for SparseTensorDenseMatmul"""


@@ -327,13 +327,13 @@ class SparseToDenseV2(Primitive):
Raises:
TypeError: If the dtype of `indices` is neither Int32 nor Int64.
TypeError: If the dtype of `outputshape` is neither Int32 nor Int64.
ValueError: If the shape of `output_shape`, shape of `indices`, shape of
`default_value` and shape of `values` don't meet the parameter description.
ValueError: If the shape of `output_shape`, shape of `indices`,
shape of `default_value` and shape of `values` don't meet the parameter description.
        ValueError: If each element of `output_shape` is not > 0.
        ValueError: If shape[0] of `indices` does not match the number of elements of `values`.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)


@@ -0,0 +1,77 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations.sparse_ops import SparseToDenseV2
from mindspore.common.api import ms_function
import mindspore.common.dtype as mstype
class SparseToDenseNet(nn.Cell):
def __init__(self):
super(SparseToDenseNet, self).__init__()
self.sparsetodense = SparseToDenseV2()
@ms_function
def construct(self, indices, output_shape, values, default_value):
return self.sparsetodense(indices, output_shape, values, default_value)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparsetodense_2d_int32():
"""
Feature: Converts a sparse representation into a dense tensor.
    Description: 2D, int32
Expectation: success
"""
for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
context.set_context(mode=mode, device_target="GPU")
indices = Tensor(np.array([[0, 1]]).astype(np.int32))
output_shape = Tensor(np.array([2, 2]).astype(np.int32))
values = Tensor(np.array([1]).astype(np.int32))
default_value = Tensor(0, dtype=mstype.int32)
net = SparseToDenseNet()
output = net(indices, output_shape, values, default_value)
sparse_expect = np.array([[0, 1],
[0, 0]]).astype(np.int32)
assert (output.asnumpy() == sparse_expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparsetodense_2d_double():
"""
Feature: Converts a sparse representation into a dense tensor.
    Description: 2D, double
Expectation: success
"""
for mode in [context.PYNATIVE_MODE, context.GRAPH_MODE]:
context.set_context(mode=mode, device_target="GPU")
indices = Tensor(np.array([[0, 1]]).astype(np.int32))
output_shape = Tensor(np.array([2, 2]).astype(np.int32))
values = Tensor(np.array([1.0]).astype(np.double))
default_value = Tensor(0.0, dtype=mstype.double)
net = SparseToDenseNet()
output = net(indices, output_shape, values, default_value)
sparse_expect = np.array([[0.0, 1.0],
[0.0, 0.0]]).astype(np.double)
assert (output.asnumpy() == sparse_expect).all()
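A further case exercising the 1-D indices path and a non-zero default_value would follow the same pattern; the sketch below assumes that path mirrors TensorFlow's sparse_to_dense semantics and is not part of this commit:

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sparsetodense_1d_default_value():
    """
    Feature: Converts a sparse representation into a dense tensor.
    Description: 1D indices, non-zero default_value (illustrative sketch).
    Expectation: unspecified positions are filled with default_value.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    indices = Tensor(np.array([1, 3]).astype(np.int32))
    output_shape = Tensor(np.array([5]).astype(np.int32))
    values = Tensor(np.array([7, 8]).astype(np.int32))
    default_value = Tensor(-1, dtype=mstype.int32)
    net = SparseToDenseNet()
    output = net(indices, output_shape, values, default_value)
    sparse_expect = np.array([-1, 7, -1, 8, -1]).astype(np.int32)
    assert (output.asnumpy() == sparse_expect).all()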