!36136 [assistant][ops] Add New GPU operator ScatterDiv/ScatterMul
Merge pull request !36136 from 彭念/scatterdiv/scattermul
commit f8d8fbb47d
@@ -0,0 +1,255 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/scatter_gpu_kernel.h"
#include <utility>
#include <memory>
#include <algorithm>
#include "abstract/utils.h"

namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateScatterKernelPtr(const std::string &kernel_name,
                                                                      const uint32_t &device_id) {
  return std::make_unique<cukernel::ScatterHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using ScatterPtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, ScatterPtrCreatorFunc>> kernel_attr = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   CreateScatterKernelPtr<float, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeFloat32),
   CreateScatterKernelPtr<float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   CreateScatterKernelPtr<half, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeFloat16),
   CreateScatterKernelPtr<half, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   CreateScatterKernelPtr<double, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat64)
     .AddOutputAttr(kNumberTypeFloat64),
   CreateScatterKernelPtr<double, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt8)
     .AddOutputAttr(kNumberTypeInt8),
   CreateScatterKernelPtr<int8_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt8)
     .AddOutputAttr(kNumberTypeInt8),
   CreateScatterKernelPtr<int8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt8)
     .AddOutputAttr(kNumberTypeUInt8),
   CreateScatterKernelPtr<uint8_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt8)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeUInt8)
     .AddOutputAttr(kNumberTypeUInt8),
   CreateScatterKernelPtr<uint8_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt16)
     .AddOutputAttr(kNumberTypeInt16),
   CreateScatterKernelPtr<int16_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt16)
     .AddOutputAttr(kNumberTypeInt16),
   CreateScatterKernelPtr<int16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt16)
     .AddOutputAttr(kNumberTypeUInt16),
   CreateScatterKernelPtr<uint16_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeUInt16)
     .AddOutputAttr(kNumberTypeUInt16),
   CreateScatterKernelPtr<uint16_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   CreateScatterKernelPtr<int, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddOutputAttr(kNumberTypeInt32),
   CreateScatterKernelPtr<int, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   CreateScatterKernelPtr<uint32_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeUInt32)
     .AddOutputAttr(kNumberTypeUInt32),
   CreateScatterKernelPtr<uint32_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   CreateScatterKernelPtr<int64_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddOutputAttr(kNumberTypeInt64),
   CreateScatterKernelPtr<int64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeUInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   CreateScatterKernelPtr<uint64_t, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeUInt64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeUInt64)
     .AddOutputAttr(kNumberTypeUInt64),
   CreateScatterKernelPtr<uint64_t, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex64)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeComplex64)
     .AddOutputAttr(kNumberTypeComplex64),
   CreateScatterKernelPtr<Complex<float>, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex64)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeComplex64)
     .AddOutputAttr(kNumberTypeComplex64),
   CreateScatterKernelPtr<Complex<float>, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex128)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeComplex128)
     .AddOutputAttr(kNumberTypeComplex128),
   CreateScatterKernelPtr<Complex<double>, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeComplex128)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeComplex128)
     .AddOutputAttr(kNumberTypeComplex128),
   CreateScatterKernelPtr<Complex<double>, int64_t>}};
}  // namespace

bool ScatterGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                 const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool ScatterGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                               const std::vector<KernelTensorPtr> &outputs) {
  kernel_name_ = base_operator->GetPrim()->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
  helper_ptr_->SetKernelParam(attr_ptr_);
  return true;
}

int ScatterGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs,
                                const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  for (const auto &input : inputs) {
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> inp_shape0 = inputs[kIndex0]->GetShapeVector();
  std::vector<int64_t> inp_shape1 = inputs[kIndex1]->GetShapeVector();
  std::vector<int64_t> inp_shape2 = inputs[kIndex2]->GetShapeVector();
  std::vector<int64_t> out_shape = outputs[kIndex0]->GetShapeVector();
  input_shapes.emplace_back(inp_shape0);
  input_shapes.emplace_back(inp_shape1);
  input_shapes.emplace_back(inp_shape2);
  output_shapes.emplace_back(out_shape);
  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }
  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

std::vector<KernelAttr> ScatterGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, ScatterPtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ScatterDiv, ScatterGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ScatterMul, ScatterGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
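With the two factory registrations above, the existing ScatterDiv/ScatterMul primitives resolve to this kernel whenever a network runs on GPU. A minimal sketch of exercising the new kernels from Python, mirroring the ScatterDiv docstring example later in this change (assumes a GPU-enabled MindSpore build):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, Parameter, ops

# Route the primitives to the new GPU kernels.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
indices = Tensor(np.array([0, 1]), mindspore.int32)
updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)

scatter_div = ops.ScatterDiv()
print(scatter_div(input_x, indices, updates))
# Expected: [[3. 3. 3.]
#            [1. 1. 1.]]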
@@ -0,0 +1,59 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <map>
#include <memory>
#include "mindspore/core/ops/scatter_div.h"
#include "mindspore/core/ops/scatter_mul.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/scatter_helper.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

namespace mindspore {
namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
class ScatterGpuKernelMod : public NativeGpuKernelMod {
 public:
  ScatterGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::ScatterAttr>(); }
  ~ScatterGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
  std::vector<KernelAttr> GetOpSupport() override;

 private:
  void *stream_ptr_;
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
  std::shared_ptr<cukernel::ScatterAttr> attr_ptr_{nullptr};
  BaseOperatorPtr base_operator_ = nullptr;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_GPU_KERNEL_H_
@@ -0,0 +1,156 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include <map>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh"

namespace mindspore {
namespace cukernel {
static const std::map<std::string, ScatterType> kScatterTypeMap = {
  {"ScatterMul", SCATTER_MUL},
  {"ScatterDiv", SCATTER_DIV},
};
class ScatterAttr : public GpuKernelAttrBase {
 public:
  ScatterAttr() = default;
  ~ScatterAttr() override = default;
};

template <typename T, typename S>
class ScatterHelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit ScatterHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    is_null_input_ = false;
  }

  virtual ~ScatterHelperGpuKernel() = default;
  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    constexpr size_t INPUT_NUM = 3;
    constexpr size_t OUTPUT_NUM = 1;
    ResetResource();
    int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
    if (inp_flag == -1) {
      return inp_flag;
    }
    input_shape_ = input_shapes[0];
    indices_shape_ = input_shapes[1];
    first_dim_size_ = input_shape_[0];
    input_size_ = 1;
    inner_size_ = 1;
    for (int64_t i = 1; i < static_cast<int64_t>(input_shape_.size()); i++) {
      inner_size_ *= input_shape_[i];
    }
    input_size_ = input_shape_[0] * inner_size_;
    indices_size_ = 1;
    for (int64_t i = 0; i < static_cast<int64_t>(indices_shape_.size()); i++) {
      indices_size_ *= indices_shape_[i];
    }
    updates_size_ = 1;
    updates_size_ = indices_size_ * inner_size_;
    int out_flag =
      CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
    if (out_flag == -1) {
      return out_flag;
    }
    is_null_input_ = (inp_flag == 1 || out_flag == 1);
    return 0;
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    if (is_null_input_) {
      return 0;
    }
    auto iter = kScatterTypeMap.find(kernel_name_);
    if (iter == kScatterTypeMap.end()) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', only the scatter functors ScatterMul and ScatterDiv are "
                    << "supported currently, but got " << kernel_name_;
    } else {
      scatter_type_ = iter->second;
    }
    T *input = nullptr;
    S *indices = nullptr;
    T *updates = nullptr;
    T *output = nullptr;
    S size_limit = static_cast<S>(first_dim_size_);

    int flag = GetDeviceAddress<T>(input_ptrs, kIndex0, kernel_name_, &input);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<S>(input_ptrs, kIndex1, kernel_name_, &indices);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<T>(input_ptrs, kIndex2, kernel_name_, &updates);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<T>(output_ptrs, kIndex0, kernel_name_, &output);
    if (flag != 0) {
      return flag;
    }
    // call cuda kernel
    Scatter(scatter_type_, size_limit, inner_size_, indices_size_, indices, updates, input, device_id_,
            reinterpret_cast<cudaStream_t>(cuda_stream));

    cudaError_t status = (cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
                                          reinterpret_cast<cudaStream_t>(cuda_stream)));
    if (status != cudaSuccess) {
      MS_LOG(ERROR) << "CUDA Error: "
                    << "cudaMemcpyAsync output failed"
                    << " | Error Number: " << status << " " << cudaGetErrorString(status);
    }
    return 0;
  }

  void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
    attr_ptr_ = std::dynamic_pointer_cast<ScatterAttr>(kernel_attr);
  }

  void ResetResource() noexcept override {
    input_size_ = 0;
    inner_size_ = 0;
    indices_size_ = 0;
    updates_size_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    work_size_list_.clear();
  }

 private:
  ScatterType scatter_type_;
  std::shared_ptr<ScatterAttr> attr_ptr_;
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> indices_shape_;
  size_t first_dim_size_;
  size_t input_size_;
  size_t inner_size_;
  size_t indices_size_;
  size_t updates_size_;
  bool is_null_input_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SCATTER_HELPER_H_
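CalMemSize above reduces the three input shapes to a few scalars: inner_size_ is the product of all input dimensions after the first, indices_size_ is the total number of indices, and updates_size_ = indices_size_ * inner_size_ is how many elements the kernel loops over. A quick numpy sketch of the same bookkeeping (the concrete shape values here are hypothetical, chosen only to satisfy updates.shape == indices.shape + input_x.shape[1:]):

import numpy as np

input_shape = (4, 2, 3)    # input_x: dim 0 is the dimension being indexed
indices_shape = (5,)       # each index selects one "row" of input_x

inner_size = int(np.prod(input_shape[1:]))   # 6  - elements per indexed row
indices_size = int(np.prod(indices_shape))   # 5  - number of indices
updates_size = indices_size * inner_size     # 30 - update elements the kernel processes
input_size = input_shape[0] * inner_size     # 24 - elements later copied from input to output

print(inner_size, indices_size, updates_size, input_size)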
@@ -0,0 +1,356 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/scatter_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/util.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

// Specializations of atomic div for complex types
__device__ inline Complex<float> ScatterDivComplex(Complex<float>* address, Complex<float> val) {
  auto ptr_addr = reinterpret_cast<float*>(address);
  float addr_real = (*address).real();
  float addr_imag = (*address).imag();
  float temp = (pow(val.real(), 2) + pow(val.imag(), 2));

  MsAtomicMul(ptr_addr, val.real());
  MsAtomicAdd(ptr_addr, addr_imag * val.imag());
  MsAtomicMul(ptr_addr + 1, val.real());
  MsAtomicSub(ptr_addr + 1, addr_real * val.imag());
  return Complex<float>(MsAtomicDiv(ptr_addr, temp),
                        MsAtomicDiv(ptr_addr + 1, temp));
}

__device__ inline Complex<double> ScatterDivComplex(Complex<double>* address, Complex<double> val) {
  auto ptr_addr = reinterpret_cast<double*>(address);
  double addr_real = (*address).real();
  double addr_imag = (*address).imag();
  double temp = (pow(val.real(), 2) + pow(val.imag(), 2));

  MsAtomicMul(ptr_addr, val.real());
  MsAtomicAdd(ptr_addr, addr_imag * val.imag());
  MsAtomicMul(ptr_addr + 1, val.real());
  MsAtomicSub(ptr_addr + 1, addr_real * val.imag());
  return Complex<double>(MsAtomicDiv(ptr_addr, temp),
                         MsAtomicDiv(ptr_addr + 1, temp));
}

// Specializations of atomic mul for complex types
__device__ inline Complex<float> ScatterMulComplex(Complex<float>* address, Complex<float> val) {
  auto ptr_addr = reinterpret_cast<float*>(address);
  float addr_real = (*address).real();
  float addr_imag = (*address).imag();
  MsAtomicMul(ptr_addr, val.real());
  MsAtomicMul(ptr_addr + 1, val.real());
  return Complex<float>(MsAtomicSub(ptr_addr, addr_imag * val.imag()),
                        MsAtomicAdd(ptr_addr + 1, addr_real * val.imag()));
}

__device__ inline Complex<double> ScatterMulComplex(Complex<double>* address, Complex<double> val) {
  auto ptr_addr = reinterpret_cast<double*>(address);
  double addr_real = (*address).real();
  double addr_imag = (*address).imag();
  MsAtomicMul(ptr_addr, val.real());
  MsAtomicMul(ptr_addr + 1, val.real());
  return Complex<double>(MsAtomicSub(ptr_addr, addr_imag * val.imag()),
                         MsAtomicAdd(ptr_addr + 1, addr_real * val.imag()));
}

template <typename T, typename S>
__global__ void ScatterDivKernel(S size_limit, const size_t inner_size, const size_t updates_size, const S *indices,
                                 const T *updates, T *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    MsAtomicDiv(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterDivKernel(int size_limit, const size_t inner_size, const size_t updates_size,
                                 const int *indices, const Complex<float> *updates, Complex<float> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterDivComplex(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterDivKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
                                 const int64_t *indices, const Complex<float> *updates, Complex<float> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterDivComplex(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterDivKernel(int size_limit, const size_t inner_size, const size_t updates_size,
                                 const int *indices, const Complex<double> *updates, Complex<double> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterDivComplex(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterDivKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
                                 const int64_t *indices, const Complex<double> *updates, Complex<double> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterDivComplex(&input[current_pos], updates[pos]);
  }
}

template <typename T, typename S>
__global__ void ScatterMulKernel(S size_limit, const size_t inner_size, const size_t updates_size, const S *indices,
                                 const T *updates, T *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    MsAtomicMul(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterMulKernel(int size_limit, const size_t inner_size, const size_t updates_size,
                                 const int *indices, const Complex<float> *updates, Complex<float> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterMulComplex(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterMulKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
                                 const int64_t *indices, const Complex<float> *updates, Complex<float> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterMulComplex(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterMulKernel(int size_limit, const size_t inner_size, const size_t updates_size,
                                 const int *indices, const Complex<double> *updates, Complex<double> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterMulComplex(&input[current_pos], updates[pos]);
  }
}

__global__ void ScatterMulKernel(int64_t size_limit, const size_t inner_size, const size_t updates_size,
                                 const int64_t *indices, const Complex<double> *updates, Complex<double> *input) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    if (indices[index] < 0 || indices[index] >= size_limit) {
      continue;
    }
    const size_t current_pos = indices[index] * inner_size + offset;
    ScatterMulComplex(&input[current_pos], updates[pos]);
  }
}

template <typename T, typename S>
void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size,
             const S *indices, const T *updates, T *input, const uint32_t &device_id, cudaStream_t cuda_stream) {
  const size_t updates_size = inner_size * indices_size;
  switch (func_type) {
    case SCATTER_DIV:
      return ScatterDivKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_limit, inner_size, updates_size, indices, updates, input);
    case SCATTER_MUL:
      return ScatterMulKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_limit, inner_size, updates_size, indices, updates, input);
    default:
      break;
  }
}

template <typename S>
void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size,
             const S *indices, const Complex<float> *updates, Complex<float> *input, const uint32_t &device_id,
             cudaStream_t cuda_stream) {
  const size_t updates_size = inner_size * indices_size;
  switch (func_type) {
    case SCATTER_DIV:
      return ScatterDivKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_limit, inner_size, updates_size, indices, updates, input);
    case SCATTER_MUL:
      return ScatterMulKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_limit, inner_size, updates_size, indices, updates, input);
    default:
      break;
  }
}

template <typename S>
void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size, const size_t &indices_size,
             const S *indices, const Complex<double> *updates, Complex<double> *input, const uint32_t &device_id,
             cudaStream_t cuda_stream) {
  const size_t updates_size = inner_size * indices_size;
  switch (func_type) {
    case SCATTER_DIV:
      return ScatterDivKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_limit, inner_size, updates_size, indices, updates, input);
    case SCATTER_MUL:
      return ScatterMulKernel<<<CUDA_BLOCKS(device_id, updates_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
        size_limit, inner_size, updates_size, indices, updates, input);
    default:
      break;
  }
}

template CUDA_LIB_EXPORT void Scatter<float, int>(enum ScatterType func_type, int size_limit,
                                                  const size_t &inner_size, const size_t &indices_size,
                                                  const int *indices, const float *updates, float *input,
                                                  const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<float, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                      const size_t &inner_size, const size_t &indices_size,
                                                      const int64_t *indices, const float *updates, float *input,
                                                      const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<half, int>(enum ScatterType func_type, int size_limit,
                                                 const size_t &inner_size, const size_t &indices_size,
                                                 const int *indices, const half *updates, half *input,
                                                 const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<half, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                     const size_t &inner_size, const size_t &indices_size,
                                                     const int64_t *indices, const half *updates, half *input,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<double, int>(enum ScatterType func_type, int size_limit,
                                                   const size_t &inner_size, const size_t &indices_size,
                                                   const int *indices, const double *updates, double *input,
                                                   const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<double, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                       const size_t &inner_size, const size_t &indices_size,
                                                       const int64_t *indices, const double *updates, double *input,
                                                       const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int8_t, int>(enum ScatterType func_type, int size_limit,
                                                   const size_t &inner_size, const size_t &indices_size,
                                                   const int *indices, const int8_t *updates, int8_t *input,
                                                   const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int8_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                       const size_t &inner_size, const size_t &indices_size,
                                                       const int64_t *indices, const int8_t *updates, int8_t *input,
                                                       const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<unsigned char, int>(enum ScatterType func_type, int size_limit,
                                                          const size_t &inner_size, const size_t &indices_size,
                                                          const int *indices, const unsigned char *updates,
                                                          unsigned char *input, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<unsigned char, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                              const size_t &inner_size, const size_t &indices_size,
                                                              const int64_t *indices, const unsigned char *updates,
                                                              unsigned char *input, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int16_t, int>(enum ScatterType func_type, int size_limit,
                                                    const size_t &inner_size, const size_t &indices_size,
                                                    const int *indices, const int16_t *updates, int16_t *input,
                                                    const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int16_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                        const size_t &inner_size, const size_t &indices_size,
                                                        const int64_t *indices, const int16_t *updates, int16_t *input,
                                                        const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint16_t, int>(enum ScatterType func_type, int size_limit,
                                                     const size_t &inner_size, const size_t &indices_size,
                                                     const int *indices, const uint16_t *updates, uint16_t *input,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint16_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                         const size_t &inner_size, const size_t &indices_size,
                                                         const int64_t *indices, const uint16_t *updates, uint16_t *input,
                                                         const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int, int>(enum ScatterType func_type, int size_limit,
                                                const size_t &inner_size, const size_t &indices_size,
                                                const int *indices, const int *updates, int *input,
                                                const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                    const size_t &inner_size, const size_t &indices_size,
                                                    const int64_t *indices, const int *updates, int *input,
                                                    const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint32_t, int>(enum ScatterType func_type, int size_limit,
                                                     const size_t &inner_size, const size_t &indices_size,
                                                     const int *indices, const uint32_t *updates, uint32_t *input,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint32_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                         const size_t &inner_size, const size_t &indices_size,
                                                         const int64_t *indices, const uint32_t *updates, uint32_t *input,
                                                         const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int64_t, int>(enum ScatterType func_type, int size_limit,
                                                    const size_t &inner_size, const size_t &indices_size,
                                                    const int *indices, const int64_t *updates, int64_t *input,
                                                    const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<int64_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                        const size_t &inner_size, const size_t &indices_size,
                                                        const int64_t *indices, const int64_t *updates, int64_t *input,
                                                        const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint64_t, int>(enum ScatterType func_type, int size_limit,
                                                     const size_t &inner_size, const size_t &indices_size,
                                                     const int *indices, const uint64_t *updates, uint64_t *input,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<uint64_t, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                         const size_t &inner_size, const size_t &indices_size,
                                                         const int64_t *indices, const uint64_t *updates, uint64_t *input,
                                                         const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<float>, int>(enum ScatterType func_type, int size_limit,
                                                           const size_t &inner_size, const size_t &indices_size,
                                                           const int *indices, const Complex<float> *updates, Complex<float> *input,
                                                           const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<float>, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                               const size_t &inner_size, const size_t &indices_size,
                                                               const int64_t *indices, const Complex<float> *updates, Complex<float> *input,
                                                               const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<double>, int>(enum ScatterType func_type, int size_limit,
                                                            const size_t &inner_size, const size_t &indices_size,
                                                            const int *indices, const Complex<double> *updates, Complex<double> *input,
                                                            const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Scatter<Complex<double>, int64_t>(enum ScatterType func_type, int64_t size_limit,
                                                                const size_t &inner_size, const size_t &indices_size,
                                                                const int64_t *indices, const Complex<double> *updates, Complex<double> *input,
                                                                const uint32_t &device_id, cudaStream_t cuda_stream);
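A note on the complex specializations: ScatterDivComplex and ScatterMulComplex near the top of this file update the real and imaginary halves of the stored value z = a + bi through a sequence of real-valued atomic operations. The identities they realize are the standard complex product and quotient (stated here for reference; they are not spelled out in the patch itself):

(a + bi)(c + di) = (ac - bd) + (ad + bc)\,i

\frac{a + bi}{c + di} = \frac{ac + bd}{c^2 + d^2} + \frac{bc - ad}{c^2 + d^2}\,i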
@@ -0,0 +1,32 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SCATTER_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

enum ScatterType {
  SCATTER_MUL = 0,
  SCATTER_DIV,
  SCATTER_INVALID_TYPE = 255
};

template <typename T, typename S>
CUDA_LIB_EXPORT void Scatter(enum ScatterType func_type, S size_limit, const size_t &inner_size,
                             const size_t &indices_size, const S *indices, const T *updates, T *input,
                             const uint32_t &device_id, cudaStream_t cuda_stream);

#endif
@@ -30,6 +30,9 @@ class MIND_API ScatterDiv : public BaseOperator {
  /// \brief Constructor.
  ScatterDiv() : BaseOperator(kNameScatterDiv) { InitIOName({"input_x", "indices", "updates"}, {"output"}); }
};

abstract::AbstractBasePtr ScatterDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                          const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore
@@ -4677,34 +4677,36 @@ class ScatterMul(_ScatterOpDynamic):
    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
-        - **indices** (Tensor) - The index to do min operation whose data type must be mindspore.int32.
-        - **updates** (Tensor) - The tensor doing the min operation with `input_x`,
-          the data type is same as `input_x`, the shape is `indices.shape + x.shape[1:]`.
+        - **indices** (Tensor) - The index to do multiply operation whose data type must be mstype.int32 or
+          mstype.int64.
+        - **updates** (Tensor) - The tensor doing the multiply operation with `input_x`,
+          the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.

    Outputs:
        Tensor, the updated `input_x`, has the same shape and type as `input_x`.

    Raises:
        TypeError: If `use_locking` is not a bool.
-        TypeError: If `indices` is not an int32.
-        ValueError: If the shape of `updates` is not equal to `indices.shape + x.shape[1:]`.
+        TypeError: If `indices` is not an int32 or an int64.
+        ValueError: If the shape of `updates` is not equal to `indices.shape + input_x.shape[1:]`.
        RuntimeError: If the data type of `input_x` and `updates` conversion of Parameter
                      is required when data type conversion of Parameter is not supported.

    Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

    Examples:
-        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
-        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
-        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
-        >>> scatter_mul = ops.ScatterMul()
+        >>> from mindspore.ops import operations as op
+        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
+        >>> indices = Tensor(np.array([0, 1]), mstype.int32)
+        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mstype.float32)
+        >>> scatter_mul = op.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[2. 2. 2.]
         [4. 4. 4.]]
        >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
-        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
+        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
        >>> # for indices = [[0, 1], [1, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]

@@ -4712,16 +4714,16 @@ class ScatterMul(_ScatterOpDynamic):
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [6.0, 6.0, 6.0] * [7.0, 7.0, 7.0] = [42.0, 42.0, 42.0]
        >>> # input_x[1] = [42.0, 42.0, 42.0] * [9.0, 9.0, 9.0] = [378.0, 378.0, 378.0]
-        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
+        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
-        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
-        >>> scatter_mul = ops.ScatterMul()
+        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32)
+        >>> scatter_mul = op.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[ 1. 1. 1.]
         [378. 378. 378.]]
        >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
-        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
+        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
        >>> # for indices = [[1, 0], [1, 1]]
        >>> # step 1: [1, 0]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [3.0, 3.0, 3.0] = [3.0, 3.0, 3.0]

@@ -4729,16 +4731,16 @@ class ScatterMul(_ScatterOpDynamic):
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [2.0, 2.0, 2.0] * [7.0, 7.0, 7.0] = [14.0, 14.0, 14.0]
        >>> # input_x[1] = [14.0, 14.0, 14.0] * [9.0, 9.0, 9.0] = [126.0, 126.0, 126.0]
-        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
+        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mstype.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
-        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
-        >>> scatter_mul = ops.ScatterMul()
+        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32)
+        >>> scatter_mul = op.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[ 3. 3. 3.]
         [126. 126. 126.]]
        >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
-        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
+        >>> input_x = Parameter(Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
        >>> # for indices = [[0, 1], [0, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [1.0, 1.0, 1.0] = [1.0, 1.0, 1.0]

@@ -4746,10 +4748,10 @@ class ScatterMul(_ScatterOpDynamic):
        >>> # step 2: [0, 1]
        >>> # input_x[0] = [1.0, 1.0, 1.0] * [7.0, 7.0, 7.0] = [7.0, 7.0, 7.0]
        >>> # input_x[1] = [6.0, 6.0, 6.0] * [9.0, 9.0, 9.0] = [54.0, 54.0, 54.0]
-        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
+        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
-        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mindspore.float32)
-        >>> scatter_mul = ops.ScatterMul()
+        ...                            [[7.0, 7.0, 7.0], [9.0, 9.0, 9.0]]]), mstype.float32)
+        >>> scatter_mul = op.ScatterMul()
        >>> output = scatter_mul(input_x, indices, updates)
        >>> print(output)
        [[ 7. 7. 7.]

@@ -4764,7 +4766,7 @@ class ScatterDiv(_ScatterOpDynamic):
    Using given values to update tensor value through the div operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

-    for each :math:`i, ..., j` in `indices.shape`:
+    for each `i, ..., j` in `indices.shape`:

    .. math::

@@ -4780,8 +4782,8 @@ class ScatterDiv(_ScatterOpDynamic):
    Inputs:
        - **input_x** (Parameter) - The target tensor, with data type of Parameter.
          The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
-        - **indices** (Tensor) - The index to do divide operation whose data type must be mindspore.int32 or
-          mindspore.int64.
+        - **indices** (Tensor) - The index to do divide operation whose data type must be mstype.int32 or
+          mstype.int64.
        - **updates** (Tensor) - The tensor doing the divide operation with `input_x`,
          the data type is same as `input_x`, the shape is `indices.shape + input_x.shape[1:]`.

@@ -4798,20 +4800,21 @@ class ScatterDiv(_ScatterOpDynamic):
        and `updates` is greater than 8 dimensions.

    Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

    Examples:
-        >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mindspore.float32), name="x")
-        >>> indices = Tensor(np.array([0, 1]), mindspore.int32)
-        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mindspore.float32)
-        >>> scatter_div = ops.ScatterDiv()
+        >>> from mindspore.ops import operations as op
+        >>> input_x = Parameter(Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mstype.float32), name="x")
+        >>> indices = Tensor(np.array([0, 1]), mstype.int32)
+        >>> updates = Tensor(np.array([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]), mstype.float32)
+        >>> scatter_div = op.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[3. 3. 3.]
         [1. 1. 1.]]
        >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
-        ...                                      [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
+        ...                                      [315.0, 315.0, 315.0]]), mstype.float32), name="x")
        >>> # for indices = [[0, 1], [1, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0]

@@ -4819,17 +4822,17 @@ class ScatterDiv(_ScatterOpDynamic):
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0]
        >>> # input_x[1] = [21.0, 21.0, 21.0] / [7.0, 7.0, 7.0] = [3.0, 3.0, 3.0]
-        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mindspore.int32)
+        >>> indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
-        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
-        >>> scatter_div = ops.ScatterDiv()
+        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32)
+        >>> scatter_div = op.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[105. 105. 105.]
         [ 3. 3. 3.]]
        >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
-        ...                                      [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
+        ...                                      [315.0, 315.0, 315.0]]), mstype.float32), name="x")
        >>> # for indices = [[1, 0], [1, 1]]
        >>> # step 1: [1, 0]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [3.0, 3.0, 3.0] = [35.0, 35.0, 35.0]

@@ -4837,17 +4840,17 @@ class ScatterDiv(_ScatterOpDynamic):
        >>> # step 2: [1, 1]
        >>> # input_x[1] = [315.0, 315.0, 315.0] / [5.0, 5.0, 5.0] = [63.0 63.0 63.0]
        >>> # input_x[1] = [63.0 63.0 63.0] / [7.0, 7.0, 7.0] = [9.0, 9.0, 9.0]
-        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mindspore.int32)
+        >>> indices = Tensor(np.array([[1, 0], [1, 1]]), mstype.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
-        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
-        >>> scatter_div = ops.ScatterDiv()
+        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32)
+        >>> scatter_div = op.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[35. 35. 35.]
         [ 9. 9. 9.]]
        >>> # for input_x will be updated after the operation is completed. input_x need to be re-initialized.
        >>> input_x = Parameter(Tensor(np.array([[105.0, 105.0, 105.0],
-        ...                                      [315.0, 315.0, 315.0]]), mindspore.float32), name="x")
+        ...                                      [315.0, 315.0, 315.0]]), mstype.float32), name="x")
        >>> # for indices = [[0, 1], [0, 1]]
        >>> # step 1: [0, 1]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [1.0, 1.0, 1.0] = [105.0, 105.0, 105.0]

@@ -4855,10 +4858,10 @@ class ScatterDiv(_ScatterOpDynamic):
        >>> # step 2: [0, 1]
        >>> # input_x[0] = [105.0, 105.0, 105.0] / [5.0, 5.0, 5.0] = [21.0, 21.0, 21.0]
        >>> # input_x[1] = [105.0, 105.0, 105.0] / [7.0, 7.0, 7.0] = [15.0, 15.0, 15.0]
-        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mindspore.int32)
+        >>> indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]],
-        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mindspore.float32)
-        >>> scatter_div = ops.ScatterDiv()
+        ...                            [[5.0, 5.0, 5.0], [7.0, 7.0, 7.0]]]), mstype.float32)
+        >>> scatter_div = op.ScatterDiv()
        >>> output = scatter_div(input_x, indices, updates)
        >>> print(output)
        [[21. 21. 21.]
@@ -0,0 +1,84 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
import mindspore.ops.operations as P
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.dtype import pytype_to_dtype


class ScatterDiv(Cell):
    def __init__(self, input_shape, input_dtype, use_locking):
        super().__init__()
        self.op = P.ScatterDiv(use_locking)
        self.inputdata = Parameter(initializer(1, input_shape, input_dtype), name="input")

    def construct(self, indices, update):
        self.op(self.inputdata, indices, update)
        return self.inputdata


class ScatterMul(Cell):
    def __init__(self, input_shape, input_dtype, use_locking):
        super().__init__()
        self.op = P.ScatterMul(use_locking)
        self.inputdata = Parameter(initializer(1, input_shape, input_dtype), name="input")

    def construct(self, indices, update):
        self.op(self.inputdata, indices, update)
        return self.inputdata


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter_func_small_float32():
    """
    Feature: ScatterDiv/ScatterMul gpu TEST.
    Description: test case for ScatterDiv/ScatterMul
    Expectation: The value and shape of output are the expected values.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    input_shape = (2, 3)
    input_dtype = np.float32
    update_np = np.array(
        [
            [[23, 11, 15], [14, 36, 215]],
            [[330, 9, 65], [10, 7, 39]]
        ]
    ).astype(np.float32)
    indices_np = np.array([[0, 1], [0, 1]]).astype(np.int32)

    # div
    indices_me = Tensor(indices_np)
    update_me = Tensor(update_np)
    net = ScatterDiv(input_shape, pytype_to_dtype(input_dtype), use_locking=True)
    out = net(indices_me, update_me)
    expect = np.array([[0.00013175, 0.01010101, 0.00102564], [0.00714286, 0.00396825, 0.00011926]])
    assert np.allclose(out.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001)

    # mul
    indices_me = Tensor(indices_np)
    update_me = Tensor(update_np)
    net = ScatterMul(input_shape, pytype_to_dtype(input_dtype), use_locking=True)
    out = net(indices_me, update_me)
    expect = np.array([[7590.0, 99.0, 975.0], [140.0, 252.0, 8385.0]])
    assert np.allclose(out.asnumpy(), expect.astype(np.float32), 0.0001, 0.0001)
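The expected arrays in the test follow directly from applying each update row in turn: inputdata starts at all ones, and indices [[0, 1], [0, 1]] send both update blocks to rows 0 and 1, so every element is divided (or multiplied) by two update values. A small numpy sketch that reproduces the expected tensors (illustrative only, not part of the test file):

import numpy as np

update_np = np.array([[[23, 11, 15], [14, 36, 215]],
                      [[330, 9, 65], [10, 7, 39]]], dtype=np.float32)
indices_np = np.array([[0, 1], [0, 1]])

div_expect = np.ones((2, 3), dtype=np.float32)
mul_expect = np.ones((2, 3), dtype=np.float32)
# Walk the updates in index order, applying each row to the row it targets.
for idx, row in zip(indices_np.reshape(-1), update_np.reshape(-1, 3)):
    div_expect[idx] /= row
    mul_expect[idx] *= row

print(div_expect)  # ~[[0.00013175 0.01010101 0.00102564] [0.00714286 0.00396825 0.00011926]]
print(mul_expect)  # [[7590.   99.  975.] [ 140.  252. 8385.]]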