add scatter_nd_update op for gpu

This commit is contained in:
xcnick 2021-05-26 15:32:49 +08:00
parent ac7178b8b9
commit 86bf1fd890
6 changed files with 578 additions and 1 deletions

View File

@ -0,0 +1,134 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/scatter_nd_update_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterNdUpdateKernel, double, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterNdUpdateKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ScatterNdUpdateKernel, float, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ScatterNdUpdateKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterNdUpdateKernel, half, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ScatterNdUpdateKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterNdUpdateKernel, int, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
ScatterNdUpdateKernel, int, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
ScatterNdUpdateKernel, int16_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
ScatterNdUpdateKernel, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterNdUpdateKernel, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
ScatterNdUpdateKernel, uint8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterNdUpdateKernel, int8_t, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
ScatterNdUpdateKernel, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
ScatterNdUpdateKernel, bool, int)
MS_REG_GPU_KERNEL_TWO(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
ScatterNdUpdateKernel, bool, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,170 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_ND_UPDATE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_ND_UPDATE_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd_update_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class ScatterNdUpdateKernel : public GpuKernel {
public:
ScatterNdUpdateKernel() { ResetResource(); }
~ScatterNdUpdateKernel() {
if (indices_stride_ != nullptr) {
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
}
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
S *indices = GetDeviceAddress<S>(inputs, 1);
T *updates = GetDeviceAddress<T>(inputs, 2);
T *output = GetDeviceAddress<T>(outputs, 0);
const size_t indices_len = sizeof(S) * out_strides_.size();
void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
if (indices_stride_work == nullptr) {
MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
}
indices_stride_ = static_cast<S *>(indices_stride_work);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(indices_stride_, &out_strides_[0], indices_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync failed in ScatterNdUpdateGpuFwdKernel::Launch.");
CalScatterNdUpdate(unit_size_, num_units_, index_depth_, indices_stride_, indices, updates, input,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ScatterNdUpdate needs 3 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but ScatterNdUpdate has 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto updates_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
auto index_depth = indices_shape.back();
if (index_depth > input_shape.size()) {
MS_LOG(EXCEPTION) << "Value of last dimension of indices is greater than shape rank";
}
if (indices_shape.size() < 2) {
MS_LOG(EXCEPTION) << "Indices dimension less than 2";
}
if (updates_shape.size() != indices_shape.size() - 1 + input_shape.size() - index_depth) {
MS_LOG(EXCEPTION) << "Update, shape rank and indices rank inconsistent";
}
for (size_t i = 0; i < indices_shape.size() - 1; ++i) {
if (updates_shape[i] != indices_shape[i]) {
MS_LOG(EXCEPTION) << "Value of " << i << "th dimension of indices is not equal to that update";
}
}
indices_size_ = 1;
for (size_t i = 0; i < indices_shape.size(); i++) {
indices_size_ *= indices_shape[i];
}
input_size_ = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
updates_size_ = 1;
for (size_t i = 0; i < updates_shape.size(); i++) {
updates_size_ *= updates_shape[i];
}
index_depth_ = SizeToInt(index_depth);
unit_size_ = 1;
for (size_t i = indices_shape.size() - 1; i < updates_shape.size(); ++i) {
unit_size_ *= SizeToInt(updates_shape[i]);
}
num_units_ = 1;
num_units_ *= updates_shape[indices_shape.size() - 2];
for (int i = SizeToInt(indices_shape.size()) - 3; i >= 0; i--) {
num_units_ *= updates_shape[i];
}
int out_stride = 1;
out_strides_.push_back(out_stride);
for (int i = index_depth_ - 2; i >= 0; i--) {
out_stride *= input_shape[i + 1];
out_strides_.push_back(out_stride);
}
reverse(out_strides_.begin(), out_strides_.end());
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 0;
indices_size_ = 0;
updates_size_ = 0;
unit_size_ = 0;
num_units_ = 0;
index_depth_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
input_size_list_.push_back(indices_size_ * sizeof(S));
input_size_list_.push_back(updates_size_ * sizeof(T));
output_size_list_.push_back(input_size_ * sizeof(T));
}
private:
size_t input_size_;
size_t indices_size_;
size_t updates_size_;
size_t unit_size_;
size_t num_units_;
size_t index_depth_;
std::vector<S> out_strides_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
S *indices_stride_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SCATTER_ND_UPDATE_GPU_KERNEL_H_

View File

@ -0,0 +1,116 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd_update_impl.cuh"
template <typename T, typename S>
__global__ void ScatterNdUpdate(const size_t unit_size, const size_t index_depth, const size_t updates_size,
const S *out_strides, const S *indices, const T *updates, T *input) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < (updates_size);
read_index += blockDim.x * gridDim.x) {
size_t write_index = 0;
bool out_bound = false;
i = read_index / unit_size;
j = read_index % unit_size;
for (size_t k = 0; k < index_depth; k++) {
S indices_i = indices[i * index_depth + k];
out_bound |= indices_i < 0;
write_index += indices_i * out_strides[k] * unit_size;
}
write_index += j;
if (!out_bound) {
input[write_index] = updates[read_index];
}
}
}
template <typename T, typename S>
void CalScatterNdUpdate(const size_t &unit_size, const size_t &num_units, const size_t &index_depth,
const S *out_strides, const S *indices, const T *updates, T *input, cudaStream_t cuda_stream) {
const size_t updates_size = unit_size * num_units;
ScatterNdUpdate<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(unit_size, index_depth, updates_size,
out_strides, indices, updates, input);
}
template void CalScatterNdUpdate<double, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const double *updates, double *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<double, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const double *updates, double *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<float, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const float *updates, float *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<float, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const float *updates, float *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<half, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const half *updates, half *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<half, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const half *updates, half *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<int32_t, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const int32_t *updates, int32_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<int32_t, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const int32_t *updates, int32_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<int16_t, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const int16_t *updates, int16_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<int16_t, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const int16_t *updates, int16_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<uint8_t, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const uint8_t *updates, uint8_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<uint8_t, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const uint8_t *updates, uint8_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<int8_t, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const int8_t *updates, int8_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<int8_t, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const int8_t *updates, int8_t *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<bool, int64_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int64_t *out_strides,
const int64_t *indices, const bool *updates, bool *input,
cudaStream_t cuda_stream);
template void CalScatterNdUpdate<bool, int32_t>(const size_t &unit_size, const size_t &num_units,
const size_t &index_depth, const int32_t *out_strides,
const int32_t *indices, const bool *updates, bool *input,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ND_UPDATE_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ND_UPDATE_IMPL_CUH_
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
void CalScatterNdUpdate(const size_t &unit_size, const size_t &num_units, const size_t &index_depth,
const S *out_strides, const S *indices, const T *updates, T *input, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ND_UPDATE_IMPL_CUH_

View File

@ -3634,7 +3634,7 @@ class ScatterNdUpdate(_ScatterNdOp):
TypeError: If `use_locking` is not a bool.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])

View File

@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
import mindspore.common.dtype as mstype
import mindspore.ops as ops
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_op1():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = ops.ScatterNdUpdate()
self.x = Parameter(
Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x"
)
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_op2():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = ops.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
expect = [1, 11, 3, 10, 9, 6, 7, 12]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_op3():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = ops.ScatterNdUpdate()
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0], [2]]), mstype.int32)
update = Tensor(
np.array(
[
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
]
),
mstype.float32,
)
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
expect = [
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_traning
@pytest.mark.env_onecard
def test_op4():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = ops.ScatterNdUpdate()
self.x = Parameter(
Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x"
)
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 1]]), mstype.int32)
update = Tensor(np.array([1.0]), mstype.float32)
scatter_nd_update = ScatterNdUpdate()
out = scatter_nd_update(indices, update)
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
assert np.allclose(out.asnumpy(), np.array(expect, np.float))