!36136 [assistant][ops] Add New GPU operator ScatterDiv/ScatterMul

Merge pull request !36136 from 彭念/scatterdiv/scattermul
This commit is contained in:
i-robot 2022-09-17 00:40:45 +00:00 committed by Gitee
commit f8d8fbb47d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 990 additions and 42 deletions

View File

@ -0,0 +1,255 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h"
#include <utility>
#include <memory>
#include <algorithm>
#include "abstract/utils.h"
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateScatterKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::ScatterHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using ScatterPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, ScatterPtrCreatorFunc>> kernel_attr = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateScatterKernelPtr<float, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CreateScatterKernelPtr<float, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
CreateScatterKernelPtr<half, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
CreateScatterKernelPtr<half, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
CreateScatterKernelPtr<double, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
CreateScatterKernelPtr<double, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
CreateScatterKernelPtr<int8_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
CreateScatterKernelPtr<int8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
CreateScatterKernelPtr<uint8_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
CreateScatterKernelPtr<uint8_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
CreateScatterKernelPtr<int16_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
CreateScatterKernelPtr<int16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
CreateScatterKernelPtr<uint16_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
CreateScatterKernelPtr<uint16_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateScatterKernelPtr<int, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
CreateScatterKernelPtr<int, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
CreateScatterKernelPtr<uint32_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
CreateScatterKernelPtr<uint32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateScatterKernelPtr<int64_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
CreateScatterKernelPtr<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
CreateScatterKernelPtr<uint64_t, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
CreateScatterKernelPtr<uint64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
CreateScatterKernelPtr<Complex<float>, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
CreateScatterKernelPtr<Complex<float>, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
CreateScatterKernelPtr<Complex<double>, int>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
CreateScatterKernelPtr<Complex<double>, int64_t>}};
} // namespace
bool ScatterGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool ScatterGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->GetPrim()->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
helper_ptr_->SetKernelParam(attr_ptr_);
return true;
}
int ScatterGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> inp_shape0 = inputs[kIndex0]->GetShapeVector();
std::vector<int64_t> inp_shape1 = inputs[kIndex1]->GetShapeVector();
std::vector<int64_t> inp_shape2 = inputs[kIndex2]->GetShapeVector();
std::vector<int64_t> out_shape = outputs[kIndex0]->GetShapeVector();
input_shapes.emplace_back(inp_shape0);
input_shapes.emplace_back(inp_shape1);
input_shapes.emplace_back(inp_shape2);
output_shapes.emplace_back(out_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
return KRET_OK;
}
std::vector<KernelAttr> ScatterGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ScatterPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ScatterDiv, ScatterGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ScatterMul, ScatterGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <map>
#include <memory>
#include "mindspore/core/ops/scatter_div.h"
#include "mindspore/core/ops/scatter_mul.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
class ScatterGpuKernelMod : public NativeGpuKernelMod {
public:
ScatterGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::ScatterAttr>(); }
~ScatterGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
void *stream_ptr_;
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::shared_ptr<cukernel::ScatterAttr> attr_ptr_{nullptr};
BaseOperatorPtr base_operator_ = nullptr;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_FUNCTOR_GPU_KERNEL_H_

View File

@ -0,0 +1,156 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include <map>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh"
namespace mindspore {
namespace cukernel {
static const std::map<std::string, ScatterType> kScatterTypeMap = {
{"ScatterMul", SCATTER_MUL},
{"ScatterDiv", SCATTER_DIV},
};
class ScatterAttr : public GpuKernelAttrBase {
public:
ScatterAttr() = default;
~ScatterAttr() override = default;
};
template <typename T, typename S>
class ScatterHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit ScatterHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
is_null_input_ = false;
}
virtual ~ScatterHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t INPUT_NUM = 3;
constexpr size_t OUTPUT_NUM = 1;
ResetResource();
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
if (inp_flag == -1) {
return inp_flag;
}
input_shape_ = input_shapes[0];
indices_shape_ = input_shapes[1];
first_dim_size_ = input_shape_[0];
input_size_ = 1;
inner_size_ = 1;
for (int64_t i = 1; i < static_cast<int64_t>(input_shape_.size()); i++) {
inner_size_ *= input_shape_[i];
}
input_size_ = input_shape_[0] * inner_size_;
indices_size_ = 1;
for (int64_t i = 0; i < static_cast<int64_t>(indices_shape_.size()); i++) {
indices_size_ *= indices_shape_[i];
}
updates_size_ = 1;
updates_size_ = indices_size_ * inner_size_;
int out_flag =
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag == -1) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
return 0;
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
auto iter = kScatterTypeMap.find(kernel_name_);
if (iter == kScatterTypeMap.end()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "Only support these scatter functors: ScatterMul, ScatterDiv "
<< " currently, but got " << kernel_name_;
} else {
scatter_type_ = iter->second;
}
T *input = nullptr;
S *indices = nullptr;
T *updates = nullptr;
T *output = nullptr;
S size_limit = static_cast<S>(first_dim_size_);
int flag = GetDeviceAddress<T>(input_ptrs, kIndex0, kernel_name_, &input);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<S>(input_ptrs, kIndex1, kernel_name_, &indices);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, kIndex2, kernel_name_, &updates);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, kIndex0, kernel_name_, &output);
if (flag != 0) {
return flag;
}
// call cuda kernel
Scatter(scatter_type_, size_limit, inner_size_, indices_size_, indices, updates, input, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
cudaError_t status = (cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream)));
if (status != cudaSuccess) {
MS_LOG(ERROR) << "CUDA Error: "
<< "cudaMemcpyAsync output failed"
<< " | Error Number: " << status << " " << cudaGetErrorString(status);
}
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<ScatterAttr>(kernel_attr);
}
void ResetResource() noexcept override {
input_size_ = 0;
inner_size_ = 0;
indices_size_ = 0;
updates_size_ = 0;
input_size_list_.clear();
output_size_list_.clear();
work_size_list_.clear();
}
private:
ScatterType scatter_type_;
std::shared_ptr<ScatterAttr> attr_ptr_;
std::vector<int64_t> input_shape_;
std::vector<int64_t> indices_shape_;
size_t first_dim_size_;
size_t input_size_;
size_t inner_size_;
size_t indices_size_;
size_t updates_size_;
bool is_null_input_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_ARGMAX_HELPER_H_

View File

@ -0,0 +1,356 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
// Specializations of atomic div for complex types
__device__ inline Complex<float> ScatterDivComplex(Complex<float>* address, Complex<float> val) {
auto ptr_addr = reinterpret_cast<float*>(address);
float addr_real = (*address).real();
float addr_imag = (*address).imag();
float temp = (pow(val.real(), 2) + pow(val.imag(), 2));
MsAtomicMul(ptr_addr, val.real());
MsAtomicAdd(ptr_addr, addr_imag * val.imag());
MsAtomicMul(ptr_addr + 1, val.real());
MsAtomicSub(ptr_addr + 1, addr_real * val.imag());
return Complex<float>(MsAtomicDiv(ptr_addr, temp),
MsAtomicDiv(ptr_addr + 1, temp));
}
__device__ inline Complex<double> ScatterDivComplex(Complex<double>* address, Complex<double> val) {
auto ptr_addr = reinterpret_cast<double*>(address);
double addr_real = (*address).real();
double addr_imag = (*address).imag();
double temp = (pow(val.real(), 2) + pow(val.imag(), 2));
MsAtomicMul(ptr_addr, val.real());
MsAtomicAdd(ptr_addr, addr_imag * val.imag());
MsAtomicMul(ptr_addr + 1, val.real());
MsAtomicSub(ptr_addr + 1, addr_real * val.imag());
return Complex<double>(MsAtomicDiv(ptr_addr, temp),
MsAtomicDiv(ptr_addr + 1, temp));
}
// Specializations of atomic mul for complex types
__device__ inline Complex<float> ScatterMulComplex(Complex<float>* address, Complex<float> val) {
auto ptr_addr = reinterpret_cast<float*>(address);
float addr_real = (*address).real();
float addr_imag = (*address).imag();
MsAtomicMul(ptr_addr, val.real());
MsAtomicMul(ptr_addr + 1, val.real());
return Complex<float>(MsAtomicSub(ptr_addr, addr_imag * val.imag()),
MsAtomicAdd(ptr_addr + 1, addr_real * val.imag()));
}
__device__ inline Complex<double> ScatterMulComplex(Complex<double>* address, Complex<double> val) {
auto ptr_addr = reinterpret_cast<double*>(address);
double addr_real = (*address).real();
double addr_imag = (*address).imag();
MsAtomicMul(ptr_addr, val.real());
MsAtomicMul(ptr_addr + 1, val.real());
return Complex<double>(MsAtomicSub(ptr_addr, addr_imag * val.imag()),
MsAtomicAdd(ptr_addr + 1, addr_real * val.imag()));
}
template <typename T, typename S>
__global__ void ScatterDivKernel(S size_limit, const size_t inner_size, const size_t updates_size, const S *indices,
const T *updates, T *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
MsAtomicDiv(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterDivKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices,
const Complex<float> *updates, Complex<float> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterDivComplex(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterDivKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
const int64_t *indices, const Complex<float> *updates, Complex<float> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterDivComplex(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterDivKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices,
const Complex<double> *updates, Complex<double> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterDivComplex(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterDivKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
const int64_t *indices, const Complex<double> *updates, Complex<double> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterDivComplex(&input[current_pos], updates[pos]);
}
}
template <typename T, typename S>
__global__ void ScatterMulKernel(S size_limit, const size_t inner_size, const size_t updates_size, const S *indices,
const T *updates, T *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
MsAtomicMul(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterMulKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices,
const Complex<float> *updates, Complex<float> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterMulComplex(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterMulKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
const int64_t *indices, const Complex<float> *updates, Complex<float> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterMulComplex(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterMulKernel(int size_limit, const size_t inner_size, const size_t updates_size, const int *indices,
const Complex<double> *updates, Complex<double> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterMulComplex(&input[current_pos], updates[pos]);
}
}
__global__ void ScatterMulKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
const int64_t *indices, const Complex<double> *updates, Complex<double> *input) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
if (indices[index] < 0 || indices[index] >= size_limit) {
continue;
}
const size_t current_pos = indices[index] * inner_size + offset;
ScatterMulComplex(&input[current_pos], updates[pos]);
}
}
template <typename T, typename S>
void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size,
const S *indices, const T *updates, T *input, const uint32_t &device_id, cudaStream_t cuda_stream) {
const size_t updates_size = inner_size * indices_size;
switch (func_type) {
case SCATTER_DIV:
return ScatterDivKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_limit, inner_size, updates_size, indices, updates, input);
case SCATTER_MUL:
return ScatterMulKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_limit, inner_size, updates_size, indices, updates, input);
default:
break;
}
}
template <typename S>
void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size,
const S *indices, const Complex<float> *updates, Complex<float> *input, const uint32_t &device_id,
cudaStream_t cuda_stream) {
const size_t updates_size = inner_size * indices_size;
switch (func_type) {
case SCATTER_DIV:
return ScatterDivKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_limit, inner_size, updates_size, indices, updates, input);
case SCATTER_MUL:
return ScatterMulKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_limit, inner_size, updates_size, indices, updates, input);
default:
break;
}
}
template <typename S>
void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size,
const S *indices, const Complex<double> *updates, Complex<double> *input, const uint32_t &device_id,
cudaStream_t cuda_stream) {
const size_t updates_size = inner_size * indices_size;
switch (func_type) {
case SCATTER_DIV:
return ScatterDivKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_limit, inner_size, updates_size, indices, updates, input);
case SCATTER_MUL:
return ScatterMulKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
size_limit, inner_size, updates_size, indices, updates, input);
default:
break;
}
}
template CUDA_LIB_EXPORT void Scatter<float, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const float *updates, float *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<float, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const float *updates, float *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<half, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const half *updates, half *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<half, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const half *updates, half *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<double, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const double *updates, double *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<double, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const double *updates, double *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int8_t, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const int8_t *updates, int8_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int8_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const int8_t *updates, int8_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<unsigned char, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const unsigned char *updates,
unsigned char *input, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<unsigned char, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const unsigned char *updates,
unsigned char *input, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int16_t, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const int16_t *updates, int16_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int16_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const int16_t *updates, int16_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint16_t, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const uint16_t *updates, uint16_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint16_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const uint16_t *updates, uint16_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const int *updates, int *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const int *updates, int *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint32_t, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const uint32_t *updates, uint32_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint32_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const uint32_t *updates, uint32_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int64_t, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const int64_t *updates, int64_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int64_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const int64_t *updates, int64_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint64_t, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const uint64_t *updates, uint64_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint64_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const uint64_t *updates, uint64_t *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<float>, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const Complex<float> *updates, Complex<float> *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<float>, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const Complex<float> *updates, Complex<float> *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<double>, int>(enum ScatterType func_type, int size_limit,
const size_t &inner_size, const size_t &indices_size,
const int *indices, const Complex<double> *updates, Complex<double> *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<double>, int64_t>(enum ScatterType func_type, int64_t size_limit,
const size_t &inner_size, const size_t &indices_size,
const int64_t *indices, const Complex<double> *updates, Complex<double> *input,
const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
enum ScatterType {
SCATTER_MUL = 0,
SCATTER_DIV,
SCATTER_INVALID_TYPE = 255
};
template <typename T, typename S>
CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size,
const size_t &indices_size, const S *indices, const T *updates, T *input,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif

View File

@ -30,6 +30,9 @@ class MIND_API ScatterDiv : public BaseOperator {
/// \brief Constructor.
ScatterDiv() : BaseOperator(kNameScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
};
abstract::AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore

View File

@ -4677,34 +4677,36 @@ class ScatterMul(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
- **updates** (Tensor) - The tensor doing the min operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
- **indices** (Tensor) - The index to do multiply operation whose data type must be mstype.int32 or
mstype.int64.
- **updates** (Tensor) - The tensor doing the multiply operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
Outputs:
Tensor, the updated `input_x`, has the same shape and type as `input_x`.
Raises:
TypeError: If `use_locking` is not a bool.
TypeError: If `indices` is not an int32.
ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
TypeError: If `indices` is not an int32 or an int64.
ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
is required when data type conversion of Parameter is not supported.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
>>> scatter_mul = ops.ScatterMul()
>>> from mindspore.ops import operations as op
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mstype.int32)
>>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mstype.float32)
>>> scatter_mul = op.ScatterMul()
>>> output = scatter_mul(input_x, indices, updates)
>>> print(output)
[[2. 2. 2.]
[4. 4. 4.]]
>>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
>>> # for indices = [[0, 1], [1, 1]]
>>> # step 1: [0, 1]
>>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
@ -4712,16 +4714,16 @@ class ScatterMul(_ScatterOpDynamic):
>>> # step 2: [1, 1]
>>> # input_x[1] = [6.0, 6.0, 6.0] * [7.0, 7.0, 7.0] = [42.0, 42.0, 42.0]
>>> # input_x[1] = [42.0, 42.0, 42.0] * [9.0, 9.0, 9.0] = [378.0, 378.0, 378.0]
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
>>> scatter_mul = ops.ScatterMul()
... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32)
>>> scatter_mul = op.ScatterMul()
>>> output = scatter_mul(input_x, indices, updates)
>>> print(output)
[[ 1. 1. 1.]
[378. 378. 378.]]
>>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
>>> # for indices = [[1, 0], [1, 1]]
>>> # step 1: [1, 0]
>>> # input_x[0] = [1.0, 1.0, 1.0] * [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0]
@ -4729,16 +4731,16 @@ class ScatterMul(_ScatterOpDynamic):
>>> # step 2: [1, 1]
>>> # input_x[1] = [2.0, 2.0, 2.0] * [7.0, 7.0, 7.0] = [14.0, 14.0, 14.0]
>>> # input_x[1] = [14.0, 14.0, 14.0] * [9.0, 9.0, 9.0] = [126.0, 126.0, 126.0]
>>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[1, 0], [1, 1]]), mstype.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
>>> scatter_mul = ops.ScatterMul()
... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32)
>>> scatter_mul = op.ScatterMul()
>>> output = scatter_mul(input_x, indices, updates)
>>> print(output)
[[ 3. 3. 3.]
[126. 126. 126.]]
>>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
>>> # for indices = [[0, 1], [0, 1]]
>>> # step 1: [0, 1]
>>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]
@ -4746,10 +4748,10 @@ class ScatterMul(_ScatterOpDynamic):
>>> # step 2: [0, 1]
>>> # input_x[0] = [1.0, 1.0, 1.0] * [7.0, 7.0, 7.0] = [7.0, 7.0, 7.0]
>>> # input_x[1] = [6.0, 6.0, 6.0] * [9.0, 9.0, 9.0] = [54.0, 54.0, 54.0]
>>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
>>> scatter_mul = ops.ScatterMul()
... [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32)
>>> scatter_mul = op.ScatterMul()
>>> output = scatter_mul(input_x, indices, updates)
>>> print(output)
[[ 7. 7. 7.]
@ -4764,7 +4766,7 @@ class ScatterDiv(_ScatterOpDynamic):
Using given values to update tensor value through the div operation, along with the input indices.
This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.
for each :math:`i, ..., j` in `indices.shape`:
for each `i, ..., j` in `indices.shape`:
.. math::
@ -4780,8 +4782,8 @@ class ScatterDiv(_ScatterOpDynamic):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **indices** (Tensor) - The index to do divide operation whose data type must be mindspore.int32 or
mindspore.int64.
- **indices** (Tensor) - The index to do divide operation whose data type must be mstype.int32 or
mstype.int64.
- **updates** (Tensor) - The tensor doing the divide operation with `input_x`,
the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.
@ -4798,20 +4800,21 @@ class ScatterDiv(_ScatterOpDynamic):
and `updates` is greater than 8 dimensions.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mindspore.int32)
>>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
>>> scatter_div = ops.ScatterDiv()
>>> from mindspore.ops import operations as op
>>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
>>> indices = Tensor(np.array([0, 1]), mstype.int32)
>>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mstype.float32)
>>> scatter_div = op.ScatterDiv()
>>> output = scatter_div(input_x, indices, updates)
>>> print(output)
[[3. 3. 3.]
[1. 1. 1.]]
>>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
>>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
... [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
... [315.0, 315.0, 315.0]]), mstype.float32), name="x")
>>> # for indices = [[0, 1], [1, 1]]
>>> # step 1: [0, 1]
>>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0]
@ -4819,17 +4822,17 @@ class ScatterDiv(_ScatterOpDynamic):
>>> # step 2: [1, 1]
>>> # input_x[1] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0]
>>> # input_x[1] = [21.0, 21.0, 21.0] / [7.0, 7.0, 7.0] = [3.0, 3.0, 3.0]
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
>>> scatter_div = ops.ScatterDiv()
... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32)
>>> scatter_div = op.ScatterDiv()
>>> output = scatter_div(input_x, indices, updates)
>>> print(output)
[[105. 105. 105.]
[ 3. 3. 3.]]
>>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
>>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
... [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
... [315.0, 315.0, 315.0]]), mstype.float32), name="x")
>>> # for indices = [[1, 0], [1, 1]]
>>> # step 1: [1, 0]
>>> # input_x[0] = [105.0, 105.0, 105.0] / [3.0, 3.0, 3.0] = [35.0, 35.0, 35.0]
@ -4837,17 +4840,17 @@ class ScatterDiv(_ScatterOpDynamic):
>>> # step 2: [1, 1]
>>> # input_x[1] = [315.0, 315.0, 315.0] / [5.0, 5.0, 5.0] = [63.0 63.0 63.0]
>>> # input_x[1] = [63.0 63.0 63.0] / [7.0, 7.0, 7.0] = [9.0, 9.0, 9.0]
>>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[1, 0], [1, 1]]), mstype.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
>>> scatter_div = ops.ScatterDiv()
... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32)
>>> scatter_div = op.ScatterDiv()
>>> output = scatter_div(input_x, indices, updates)
>>> print(output)
[[35. 35. 35.]
[ 9. 9. 9.]]
>>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
>>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
... [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
... [315.0, 315.0, 315.0]]), mstype.float32), name="x")
>>> # for indices = [[0, 1], [0, 1]]
>>> # step 1: [0, 1]
>>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0]
@ -4855,10 +4858,10 @@ class ScatterDiv(_ScatterOpDynamic):
>>> # step 2: [0, 1]
>>> # input_x[0] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0]
>>> # input_x[1] = [105.0, 105.0, 105.0] / [7.0, 7.0, 7.0] = [15.0, 15.0, 15.0]
>>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32)
>>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
>>> scatter_div = ops.ScatterDiv()
... [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32)
>>> scatter_div = op.ScatterDiv()
>>> output = scatter_div(input_x, indices, updates)
>>> print(output)
[[21. 21. 21.]

View File

@ -0,0 +1,84 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.dtype import pytype_to_dtype
class ScatterDiv(Cell):
def __init__(self, input_shape, input_dtype, use_locking):
super().__init__()
self.op = P.ScatterDiv(use_locking)
self.inputdata = Parameter(initializer(1, input_shape, input_dtype), name="input")
def construct(self, indices, update):
self.op(self.inputdata, indices, update)
return self.inputdata
class ScatterMul(Cell):
def __init__(self, input_shape, input_dtype, use_locking):
super().__init__()
self.op = P.ScatterMul(use_locking)
self.inputdata = Parameter(initializer(1, input_shape, input_dtype), name="input")
def construct(self, indices, update):
self.op(self.inputdata, indices, update)
return self.inputdata
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_small_float32():
"""
Feature: ScatterDiv/ScatterMul gpu TEST.
Description: test case for ScatterDiv/ScatterMul
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_shape = (2, 3)
input_dtype = np.float32
update_np = np.array(
[
[[23, 11, 15], [14, 36, 215]],
[[330, 9, 65], [10, 7, 39]]
]
).astype(np.float32)
indices_np = np.array([[0, 1], [0, 1]]).astype(np.int32)
# div
indices_me = Tensor(indices_np)
update_me = Tensor(update_np)
net = ScatterDiv(input_shape, pytype_to_dtype(input_dtype), use_locking=True)
out = net(indices_me, update_me)
expect = np.array([[0.00013175, 0.01010101, 0.00102564], [0.00714286, 0.00396825, 0.00011926]])
assert np.allclose(out.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001)
# mul
indices_me = Tensor(indices_np)
update_me = Tensor(update_np)
net = ScatterMul(input_shape, pytype_to_dtype(input_dtype), use_locking=True)
out = net(indices_me, update_me)
expect = np.array([[7590.0, 99.0, 975.0], [140.0, 252.0, 8385.0]])
assert np.allclose(out.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001)