From 864044788f982351e3226f3a9b78032a39b55113 Mon Sep 17 00:00:00 2001 From: p-nian <1026078943@qq.com> Date: Wed, 24 Aug 2022 16:25:18 +0800 Subject: [PATCH] add scatterdiv/scattermul gpu --- .../gpu/kernel/arrays/scatter_gpu_kernel.cc | 255 +++++++++++++ .../gpu/kernel/arrays/scatter_gpu_kernel.h | 59 +++ .../cuda_impl/cuda_class/scatter_helper.h | 156 ++++++++ .../kernel/cuda_impl/cuda_ops/scatter_impl.cu | 356 ++++++++++++++++++ .../cuda_impl/cuda_ops/scatter_impl.cuh | 32 ++ mindspore/core/ops/scatter_div.h | 3 + .../mindspore/ops/operations/array_ops.py | 87 ++--- tests/st/ops/gpu/test_scatter_op.py | 84 +++++ 8 files changed, 990 insertions(+), 42 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cu create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh create mode 100644 tests/st/ops/gpu/test_scatter_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.cc new file mode 100644 index 00000000000..3fe64097347 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.cc @@ -0,0 +1,255 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h" +#include +#include +#include +#include "abstract/utils.h" + +namespace mindspore { +namespace kernel { +namespace { +template +std::unique_ptr CreateScatterKernelPtr(const std::string &kernel_name, + const uint32_t &device_id) { + return std::make_unique>(kernel_name, device_id); +} +using ScatterPtrCreatorFunc = + std::function(const std::string &, const uint32_t &)>; + +const std::vector> kernel_attr = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt16) + .AddOutputAttr(kNumberTypeInt16), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt16) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt16) + .AddOutputAttr(kNumberTypeUInt16), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt32) + 
.AddOutputAttr(kNumberTypeUInt32), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt32) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt32) + .AddOutputAttr(kNumberTypeUInt32), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeInt64), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeUInt64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeUInt64) + .AddOutputAttr(kNumberTypeUInt64), + CreateScatterKernelPtr}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeComplex64), + CreateScatterKernelPtr, int>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeComplex64) + .AddOutputAttr(kNumberTypeComplex64), + CreateScatterKernelPtr, int64_t>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeComplex128), + CreateScatterKernelPtr, int>}, + {KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeComplex128) + .AddOutputAttr(kNumberTypeComplex128), + CreateScatterKernelPtr, int64_t>}}; +} // namespace + +bool ScatterGpuKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector input_ptrs = ConvertPtrs(inputs); + std::vector work_ptrs = ConvertPtrs(workspace); + std::vector output_ptrs = ConvertPtrs(outputs); + if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) { + return false; + } + return true; +} + +bool ScatterGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + kernel_name_ = base_operator->GetPrim()->name(); + auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport()); + if (!is_match) { + return false; + } + helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_)); + helper_ptr_->SetKernelParam(attr_ptr_); + return true; +} + +int ScatterGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) { + for (const auto &input : inputs) { + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + std::vector> input_shapes; + std::vector> output_shapes; + std::vector inp_shape0 = inputs[kIndex0]->GetShapeVector(); + std::vector inp_shape1 = inputs[kIndex1]->GetShapeVector(); + std::vector inp_shape2 = inputs[kIndex2]->GetShapeVector(); + std::vector out_shape = outputs[kIndex0]->GetShapeVector(); + input_shapes.emplace_back(inp_shape0); + input_shapes.emplace_back(inp_shape1); + input_shapes.emplace_back(inp_shape2); + output_shapes.emplace_back(out_shape); + if (helper_ptr_->CalMemSize(input_shapes, 
output_shapes) == -1) { + return KRET_RESIZE_FAILED; + } + input_size_list_ = helper_ptr_->GetInputSizeList(); + output_size_list_ = helper_ptr_->GetOutputSizeList(); + workspace_size_list_ = helper_ptr_->GetWorkSizeList(); + return KRET_OK; +} + +std::vector ScatterGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list), + [](const std::pair &item) { return item.first; }); + return support_list; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ScatterDiv, ScatterGpuKernelMod); +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ScatterMul, ScatterGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h new file mode 100644 index 00000000000..1809880a506 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "mindspore/core/ops/scatter_div.h" +#include "mindspore/core/ops/scatter_mul.h" +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +namespace mindspore { +namespace kernel { +template +using Complex = mindspore::utils::Complex; +class ScatterGpuKernelMod : public NativeGpuKernelMod { + public: + ScatterGpuKernelMod() { attr_ptr_ = std::make_shared(); } + ~ScatterGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + int Resize( + const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost = std::map()) override; + std::vector GetOpSupport() override; + + private: + void *stream_ptr_; + std::unique_ptr helper_ptr_{nullptr}; + std::shared_ptr attr_ptr_{nullptr}; + BaseOperatorPtr base_operator_ = nullptr; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_FUNCTOR_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h new file mode 100644 index 00000000000..93b3b040d46 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h @@ -0,0 +1,156 
@@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_ +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh" + +namespace mindspore { +namespace cukernel { +static const std::map kScatterTypeMap = { + {"ScatterMul", SCATTER_MUL}, + {"ScatterDiv", SCATTER_DIV}, +}; +class ScatterAttr : public GpuKernelAttrBase { + public: + ScatterAttr() = default; + ~ScatterAttr() override = default; +}; + +template +class ScatterHelperGpuKernel : public GpuKernelHelperBase { + public: + explicit ScatterHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id) + : GpuKernelHelperBase(kernel_name, device_id) { + is_null_input_ = false; + } + + virtual ~ScatterHelperGpuKernel() = default; + int CalMemSize(const std::vector> &input_shapes, + const std::vector> &output_shapes) override { + constexpr size_t INPUT_NUM = 3; + constexpr size_t OUTPUT_NUM = 1; + ResetResource(); + int inp_flag = CalShapesSizeInBytes(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_); + if (inp_flag == -1) { + return inp_flag; + } + input_shape_ = input_shapes[0]; + indices_shape_ = input_shapes[1]; + first_dim_size_ = input_shape_[0]; + input_size_ = 1; + inner_size_ = 1; + for (int64_t i = 1; i < static_cast(input_shape_.size()); i++) { + inner_size_ *= input_shape_[i]; + } + input_size_ = input_shape_[0] * inner_size_; + indices_size_ = 1; + for (int64_t i = 0; i < static_cast(indices_shape_.size()); i++) { + indices_size_ *= indices_shape_[i]; + } + updates_size_ = 1; + updates_size_ = indices_size_ * inner_size_; + int out_flag = + CalShapesSizeInBytes(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_); + if (out_flag == -1) { + return out_flag; + } + is_null_input_ = (inp_flag == 1 || out_flag == 1); + return 0; + } + + int Process(const std::vector &input_ptrs, const std::vector &output_ptrs, + const std::vector &work_ptrs, void *cuda_stream) override { + if (is_null_input_) { + return 0; + } + auto iter = kScatterTypeMap.find(kernel_name_); + if (iter == kScatterTypeMap.end()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "Only support these scatter functors: ScatterMul, ScatterDiv " + << " currently, but got " << kernel_name_; + } else { + scatter_type_ = iter->second; + } + T *input = nullptr; + S *indices = nullptr; + T *updates = nullptr; + T *output = nullptr; + S size_limit = static_cast(first_dim_size_); + + int flag = GetDeviceAddress(input_ptrs, kIndex0, kernel_name_, &input); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, kIndex1, kernel_name_, &indices); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(input_ptrs, kIndex2, kernel_name_, 
&updates); + if (flag != 0) { + return flag; + } + flag = GetDeviceAddress(output_ptrs, kIndex0, kernel_name_, &output); + if (flag != 0) { + return flag; + } + // call cuda kernel + Scatter(scatter_type_, size_limit, inner_size_, indices_size_, indices, updates, input, device_id_, + reinterpret_cast(cuda_stream)); + + cudaError_t status = (cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice, + reinterpret_cast(cuda_stream))); + if (status != cudaSuccess) { + MS_LOG(ERROR) << "CUDA Error: " + << "cudaMemcpyAsync output failed" + << " | Error Number: " << status << " " << cudaGetErrorString(status); + } + return 0; + } + + void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override { + attr_ptr_ = std::dynamic_pointer_cast(kernel_attr); + } + + void ResetResource() noexcept override { + input_size_ = 0; + inner_size_ = 0; + indices_size_ = 0; + updates_size_ = 0; + input_size_list_.clear(); + output_size_list_.clear(); + work_size_list_.clear(); + } + + private: + ScatterType scatter_type_; + std::shared_ptr attr_ptr_; + std::vector input_shape_; + std::vector indices_shape_; + size_t first_dim_size_; + size_t input_size_; + size_t inner_size_; + size_t indices_size_; + size_t updates_size_; + bool is_null_input_; +}; +} // namespace cukernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ARGMAX_HELPER_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cu new file mode 100644 index 00000000000..fd34f3d99eb --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cu @@ -0,0 +1,356 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h" + +// Specializations of atomic div for complex types +__device__ inline Complex ScatterDivComplex(Complex* address, Complex val) { + auto ptr_addr = reinterpret_cast(address); + float addr_real = (*address).real(); + float addr_imag = (*address).imag(); + float temp = (pow(val.real(), 2) + pow(val.imag(), 2)); + + MsAtomicMul(ptr_addr, val.real()); + MsAtomicAdd(ptr_addr, addr_imag * val.imag()); + MsAtomicMul(ptr_addr + 1, val.real()); + MsAtomicSub(ptr_addr + 1, addr_real * val.imag()); + return Complex(MsAtomicDiv(ptr_addr, temp), + MsAtomicDiv(ptr_addr + 1, temp)); +} + +__device__ inline Complex ScatterDivComplex(Complex* address, Complex val) { + auto ptr_addr = reinterpret_cast(address); + double addr_real = (*address).real(); + double addr_imag = (*address).imag(); + double temp = (pow(val.real(), 2) + pow(val.imag(), 2)); + + MsAtomicMul(ptr_addr, val.real()); + MsAtomicAdd(ptr_addr, addr_imag * val.imag()); + MsAtomicMul(ptr_addr + 1, val.real()); + MsAtomicSub(ptr_addr + 1, addr_real * val.imag()); + return Complex(MsAtomicDiv(ptr_addr, temp), + MsAtomicDiv(ptr_addr + 1, temp)); +} + +// Specializations of atomic mul for complex types +__device__ inline Complex ScatterMulComplex(Complex* address, Complex val) { + auto ptr_addr = reinterpret_cast(address); + float addr_real = (*address).real(); + float addr_imag = (*address).imag(); + MsAtomicMul(ptr_addr, val.real()); + MsAtomicMul(ptr_addr + 1, val.real()); + return Complex(MsAtomicSub(ptr_addr, addr_imag * val.imag()), + MsAtomicAdd(ptr_addr + 1, addr_real * val.imag())); +} + +__device__ inline Complex ScatterMulComplex(Complex* address, Complex val) { + auto ptr_addr = reinterpret_cast(address); + double addr_real = (*address).real(); + double addr_imag = (*address).imag(); + MsAtomicMul(ptr_addr, val.real()); + MsAtomicMul(ptr_addr + 1, val.real()); + return Complex(MsAtomicSub(ptr_addr, addr_imag * val.imag()), + MsAtomicAdd(ptr_addr + 1, addr_real * val.imag())); +} + +template + __global__ void ScatterDivKernel(S size_limit, const size_t inner_size, const size_t updates_size, const S *indices, + const T *updates, T *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + MsAtomicDiv(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterDivKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices, + const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterDivComplex(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterDivKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size, + const int64_t *indices, const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < 
updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterDivComplex(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterDivKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices, + const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterDivComplex(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterDivKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size, + const int64_t *indices, const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterDivComplex(&input[current_pos], updates[pos]); + } +} + +template + __global__ void ScatterMulKernel(S size_limit, const size_t inner_size, const size_t updates_size, const S *indices, + const T *updates, T *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + MsAtomicMul(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterMulKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices, + const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterMulComplex(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterMulKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size, + const int64_t *indices, const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterMulComplex(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterMulKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices, + const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if 
(indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterMulComplex(&input[current_pos], updates[pos]); + } +} + +__global__ void ScatterMulKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size, + const int64_t *indices, const Complex *updates, Complex *input) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + if (indices[index] < 0 || indices[index] >= size_limit) { + continue; + } + const size_t current_pos = indices[index] * inner_size + offset; + ScatterMulComplex(&input[current_pos], updates[pos]); + } +} + +template +void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size, + const S *indices, const T *updates, T *input, const uint32_t &device_id, cudaStream_t cuda_stream) { + const size_t updates_size = inner_size * indices_size; + switch (func_type) { + case SCATTER_DIV: + return ScatterDivKernel<<>>( + size_limit, inner_size, updates_size, indices, updates, input); + case SCATTER_MUL: + return ScatterMulKernel<<>>( + size_limit, inner_size, updates_size, indices, updates, input); + default: + break; + } +} + +template +void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size, + const S *indices, const Complex *updates, Complex *input, const uint32_t &device_id, + cudaStream_t cuda_stream) { + const size_t updates_size = inner_size * indices_size; + switch (func_type) { + case SCATTER_DIV: + return ScatterDivKernel<<>>( + size_limit, inner_size, updates_size, indices, updates, input); + case SCATTER_MUL: + return ScatterMulKernel<<>>( + size_limit, inner_size, updates_size, indices, updates, input); + default: + break; + } +} + +template +void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size, + const S *indices, const Complex *updates, Complex *input, const uint32_t &device_id, + cudaStream_t cuda_stream) { + const size_t updates_size = inner_size * indices_size; + switch (func_type) { + case SCATTER_DIV: + return ScatterDivKernel<<>>( + size_limit, inner_size, updates_size, indices, updates, input); + case SCATTER_MUL: + return ScatterMulKernel<<>>( + size_limit, inner_size, updates_size, indices, updates, input); + default: + break; + } +} + +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const float *updates, float *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const float *updates, float *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const half *updates, half *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const half *updates, half *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template 
CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const double *updates, double *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const double *updates, double *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const int8_t *updates, int8_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const int8_t *updates, int8_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const unsigned char *updates, + unsigned char *input, const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const unsigned char *updates, + unsigned char *input, const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const int16_t *updates, int16_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const int16_t *updates, int16_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const uint16_t *updates, uint16_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const uint16_t *updates, uint16_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const int *updates, int *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const int *updates, int *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const uint32_t *updates, uint32_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, 
const uint32_t *updates, uint32_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const int64_t *updates, int64_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const int64_t *updates, int64_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const uint64_t *updates, uint64_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const uint64_t *updates, uint64_t *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter, int>(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const Complex *updates, Complex *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter, int64_t>(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const Complex *updates, Complex *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter, int>(enum ScatterType func_type, int size_limit, + const size_t &inner_size, const size_t &indices_size, + const int *indices, const Complex *updates, Complex *input, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void Scatter, int64_t>(enum ScatterType func_type, int64_t size_limit, + const size_t &inner_size, const size_t &indices_size, + const int64_t *indices, const Complex *updates, Complex *input, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh new file mode 100644 index 00000000000..dfac7bb7683 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_IMPL_CUH_
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
+
+enum ScatterType {
+  SCATTER_MUL = 0,
+  SCATTER_DIV,
+  SCATTER_INVALID_TYPE = 255
+};
+
+template <typename T, typename S>
+CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size,
+                             const size_t &indices_size, const S *indices, const T *updates, T *input,
+                             const uint32_t &device_id, cudaStream_t cuda_stream);
+
+#endif
diff --git a/mindspore/core/ops/scatter_div.h b/mindspore/core/ops/scatter_div.h
index bf26f175c34..42e97d7aa36 100644
--- a/mindspore/core/ops/scatter_div.h
+++ b/mindspore/core/ops/scatter_div.h
@@ -30,6 +30,9 @@ class MIND_API ScatterDiv : public BaseOperator {
   /// \brief Constructor.
   ScatterDiv() : BaseOperator(kNameScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
 };
+
+abstract::AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const std::vector<abstract::AbstractBasePtr> &input_args);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py
index c78ae5ee7ce..60a14e21363 100755
--- a/mindspore/python/mindspore/ops/operations/array_ops.py
+++ b/mindspore/python/mindspore/ops/operations/array_ops.py
@@ -4631,34 +4631,36 @@ class ScatterMul(_ScatterOpDynamic):
     Inputs:
         - **input_x** (Parameter) - The target tensor, with data type of Parameter.
           The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
-        - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
-        - **updates** (Tensor) - The tensor doing the min operation with `input_x`,
-          the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
+        - **indices** (Tensor) - The index to do multiply operation whose data type must be mstype.int32 or
+          mstype.int64.
+        - **updates** (Tensor) - The tensor doing the multiply operation with `input_x`,
+          the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.

     Outputs:
         Tensor, the updated `input_x`, has the same shape and type as `input_x`.

     Raises:
         TypeError: If `use_locking` is not a bool.
-        TypeError: If `indices` is not an int32.
-        ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
+        TypeError: If `indices` is not an int32 or an int64.
+        ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
         RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
                       is required when data type conversion of Parameter is not supported.
Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") - >>> indices = Tensor(np.array([0, 1]), mindspore.int32) - >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32) - >>> scatter_mul = ops.ScatterMul() + >>> from mindspore.ops import operations as op + >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x") + >>> indices = Tensor(np.array([0, 1]), mstype.int32) + >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mstype.float32) + >>> scatter_mul = op.ScatterMul() >>> output = scatter_mul(input_x, indices, updates) >>> print(output) [[2. 2. 2.] [4. 4. 4.]] >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized. - >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") + >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x") >>> # for indices = [[0, 1], [1, 1]] >>> # step 1: [0, 1] >>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0] @@ -4666,16 +4668,16 @@ class ScatterMul(_ScatterOpDynamic): >>> # step 2: [1, 1] >>> # input_x[1] = [6.0, 6.0, 6.0] * [7.0, 7.0, 7.0] = [42.0, 42.0, 42.0] >>> # input_x[1] = [42.0, 42.0, 42.0] * [9.0, 9.0, 9.0] = [378.0, 378.0, 378.0] - >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32) + >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32) >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], - ... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32) - >>> scatter_mul = ops.ScatterMul() + ... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32) + >>> scatter_mul = op.ScatterMul() >>> output = scatter_mul(input_x, indices, updates) >>> print(output) [[ 1. 1. 1.] [378. 378. 378.]] >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized. - >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") + >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x") >>> # for indices = [[1, 0], [1, 1]] >>> # step 1: [1, 0] >>> # input_x[0] = [1.0, 1.0, 1.0] * [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0] @@ -4683,16 +4685,16 @@ class ScatterMul(_ScatterOpDynamic): >>> # step 2: [1, 1] >>> # input_x[1] = [2.0, 2.0, 2.0] * [7.0, 7.0, 7.0] = [14.0, 14.0, 14.0] >>> # input_x[1] = [14.0, 14.0, 14.0] * [9.0, 9.0, 9.0] = [126.0, 126.0, 126.0] - >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32) + >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mstype.int32) >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], - ... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32) - >>> scatter_mul = ops.ScatterMul() + ... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32) + >>> scatter_mul = op.ScatterMul() >>> output = scatter_mul(input_x, indices, updates) >>> print(output) [[ 3. 3. 3.] [126. 126. 126.]] >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized. 
- >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") + >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x") >>> # for indices = [[0, 1], [0, 1]] >>> # step 1: [0, 1] >>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0] @@ -4700,10 +4702,10 @@ class ScatterMul(_ScatterOpDynamic): >>> # step 2: [0, 1] >>> # input_x[0] = [1.0, 1.0, 1.0] * [7.0, 7.0, 7.0] = [7.0, 7.0, 7.0] >>> # input_x[1] = [6.0, 6.0, 6.0] * [9.0, 9.0, 9.0] = [54.0, 54.0, 54.0] - >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32) + >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32) >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], - ... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32) - >>> scatter_mul = ops.ScatterMul() + ... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32) + >>> scatter_mul = op.ScatterMul() >>> output = scatter_mul(input_x, indices, updates) >>> print(output) [[ 7. 7. 7.] @@ -4718,7 +4720,7 @@ class ScatterDiv(_ScatterOpDynamic): Using given values to update tensor value through the div operation, along with the input indices. This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. - for each :math:`i, ..., j` in `indices.shape`: + for each `i, ..., j` in `indices.shape`: .. math:: @@ -4734,8 +4736,8 @@ class ScatterDiv(_ScatterOpDynamic): Inputs: - **input_x** (Parameter) - The target tensor, with data type of Parameter. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. - - **indices** (Tensor) - The index to do divide operation whose data type must be mindspore.int32 or - mindspore.int64. + - **indices** (Tensor) - The index to do divide operation whose data type must be mstype.int32 or + mstype.int64. - **updates** (Tensor) - The tensor doing the divide operation with `input_x`, the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`. @@ -4752,20 +4754,21 @@ class ScatterDiv(_ScatterOpDynamic): and `updates` is greater than 8 dimensions. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x") - >>> indices = Tensor(np.array([0, 1]), mindspore.int32) - >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32) - >>> scatter_div = ops.ScatterDiv() + >>> from mindspore.ops import operations as op + >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x") + >>> indices = Tensor(np.array([0, 1]), mstype.int32) + >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mstype.float32) + >>> scatter_div = op.ScatterDiv() >>> output = scatter_div(input_x, indices, updates) >>> print(output) [[3. 3. 3.] [1. 1. 1.]] >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized. >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0], - ... [315.0, 315.0, 315.0]]), mindspore.float32), name="x") + ... 
[315.0, 315.0, 315.0]]), mstype.float32), name="x") >>> # for indices = [[0, 1], [1, 1]] >>> # step 1: [0, 1] >>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0] @@ -4773,17 +4776,17 @@ class ScatterDiv(_ScatterOpDynamic): >>> # step 2: [1, 1] >>> # input_x[1] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0] >>> # input_x[1] = [21.0, 21.0, 21.0] / [7.0, 7.0, 7.0] = [3.0, 3.0, 3.0] - >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32) + >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32) >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], - ... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32) - >>> scatter_div = ops.ScatterDiv() + ... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32) + >>> scatter_div = op.ScatterDiv() >>> output = scatter_div(input_x, indices, updates) >>> print(output) [[105. 105. 105.] [ 3. 3. 3.]] >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized. >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0], - ... [315.0, 315.0, 315.0]]), mindspore.float32), name="x") + ... [315.0, 315.0, 315.0]]), mstype.float32), name="x") >>> # for indices = [[1, 0], [1, 1]] >>> # step 1: [1, 0] >>> # input_x[0] = [105.0, 105.0, 105.0] / [3.0, 3.0, 3.0] = [35.0, 35.0, 35.0] @@ -4791,17 +4794,17 @@ class ScatterDiv(_ScatterOpDynamic): >>> # step 2: [1, 1] >>> # input_x[1] = [315.0, 315.0, 315.0] / [5.0, 5.0, 5.0] = [63.0 63.0 63.0] >>> # input_x[1] = [63.0 63.0 63.0] / [7.0, 7.0, 7.0] = [9.0, 9.0, 9.0] - >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32) + >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mstype.int32) >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], - ... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32) - >>> scatter_div = ops.ScatterDiv() + ... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32) + >>> scatter_div = op.ScatterDiv() >>> output = scatter_div(input_x, indices, updates) >>> print(output) [[35. 35. 35.] [ 9. 9. 9.]] >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized. >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0], - ... [315.0, 315.0, 315.0]]), mindspore.float32), name="x") + ... [315.0, 315.0, 315.0]]), mstype.float32), name="x") >>> # for indices = [[0, 1], [0, 1]] >>> # step 1: [0, 1] >>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0] @@ -4809,10 +4812,10 @@ class ScatterDiv(_ScatterOpDynamic): >>> # step 2: [0, 1] >>> # input_x[0] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0] >>> # input_x[1] = [105.0, 105.0, 105.0] / [7.0, 7.0, 7.0] = [15.0, 15.0, 15.0] - >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32) + >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32) >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]], - ... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32) - >>> scatter_div = ops.ScatterDiv() + ... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32) + >>> scatter_div = op.ScatterDiv() >>> output = scatter_div(input_x, indices, updates) >>> print(output) [[21. 21. 21.] 
 [15. 15. 15.]]
diff --git a/tests/st/ops/gpu/test_scatter_op.py b/tests/st/ops/gpu/test_scatter_op.py
new file mode 100644
index 00000000000..4cff415c891
--- /dev/null
+++ b/tests/st/ops/gpu/test_scatter_op.py
@@ -0,0 +1,84 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.ops.operations as P
+from mindspore.nn import Cell
+from mindspore import Tensor
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore.common.dtype import pytype_to_dtype
+
+
+class ScatterDiv(Cell):
+    def __init__(self, input_shape, input_dtype, use_locking):
+        super().__init__()
+        self.op = P.ScatterDiv(use_locking)
+        self.inputdata = Parameter(initializer(1, input_shape, input_dtype), name="input")
+
+    def construct(self, indices, update):
+        self.op(self.inputdata, indices, update)
+        return self.inputdata
+
+
+class ScatterMul(Cell):
+    def __init__(self, input_shape, input_dtype, use_locking):
+        super().__init__()
+        self.op = P.ScatterMul(use_locking)
+        self.inputdata = Parameter(initializer(1, input_shape, input_dtype), name="input")
+
+    def construct(self, indices, update):
+        self.op(self.inputdata, indices, update)
+        return self.inputdata
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_scatter_func_small_float32():
+    """
+    Feature: ScatterDiv/ScatterMul gpu TEST.
+    Description: test case for ScatterDiv/ScatterMul
+    Expectation: The value and shape of output are the expected values.
+    """
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    input_shape = (2, 3)
+    input_dtype = np.float32
+    update_np = np.array(
+        [
+            [[23, 11, 15], [14, 36, 215]],
+            [[330, 9, 65], [10, 7, 39]]
+        ]
+    ).astype(np.float32)
+    indices_np = np.array([[0, 1], [0, 1]]).astype(np.int32)
+
+    # div
+    indices_me = Tensor(indices_np)
+    update_me = Tensor(update_np)
+    net = ScatterDiv(input_shape, pytype_to_dtype(input_dtype), use_locking=True)
+    out = net(indices_me, update_me)
+    expect = np.array([[0.00013175, 0.01010101, 0.00102564], [0.00714286, 0.00396825, 0.00011926]])
+    assert np.allclose(out.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001)
+
+    # mul
+    indices_me = Tensor(indices_np)
+    update_me = Tensor(update_np)
+    net = ScatterMul(input_shape, pytype_to_dtype(input_dtype), use_locking=True)
+    out = net(indices_me, update_me)
+    expect = np.array([[7590.0, 99.0, 975.0], [140.0, 252.0, 8385.0]])
+    assert np.allclose(out.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001)
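
Reviewer note (a sketch for reference, not part of the patch): the `expect` arrays asserted in
test_scatter_func_small_float32 can be reproduced on the host with plain NumPy. The helper name
`scatter_reference` below is hypothetical; it only illustrates the ScatterDiv/ScatterMul semantics
the GPU kernels implement, i.e. updates[i, ..., j, :] is divided into or multiplied into
input_x[indices[i, ..., j], :], with repeated indices applied cumulatively.

    import numpy as np

    def scatter_reference(input_x, indices, updates, op):
        # Flatten indices and the matching leading dims of updates, then apply each
        # update row to input_x[row] in iteration order (cumulative for repeated
        # indices), mirroring the kernel's per-element atomic div/mul.
        out = input_x.copy()
        flat_idx = indices.reshape(-1)
        flat_upd = updates.reshape((-1,) + input_x.shape[1:])
        for k, row in enumerate(flat_idx):
            if op == "div":
                out[row] /= flat_upd[k]
            else:  # "mul"
                out[row] *= flat_upd[k]
        return out

    input_x = np.ones((2, 3), np.float32)
    indices = np.array([[0, 1], [0, 1]], np.int32)
    updates = np.array([[[23, 11, 15], [14, 36, 215]],
                        [[330, 9, 65], [10, 7, 39]]], np.float32)

    print(scatter_reference(input_x, indices, updates, "div"))
    # [[1.3175e-04 1.0101e-02 1.0256e-03]
    #  [7.1429e-03 3.9683e-03 1.1926e-04]]
    print(scatter_reference(input_x, indices, updates, "mul"))
    # [[7590.   99.  975.]
    #  [ 140.  252. 8385.]]

These values match the tolerances asserted in the test above.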