!16668 add gpu op cumprod and tensor_scatter_add

From: @yanglf1121
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
mindspore-ci-bot 2021-05-27 10:04:28 +08:00 committed by Gitee
commit b75c694032
15 changed files with 1095 additions and 4 deletions


@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/tensor_scatter_add_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
TensorScatterAddGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
TensorScatterAddGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
TensorScatterAddGpuFwdKernel, double, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
TensorScatterAddGpuFwdKernel, char, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
TensorScatterAddGpuFwdKernel, uchar, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
TensorScatterAddGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
TensorScatterAddGpuFwdKernel, double, int64_t)
} // namespace kernel
} // namespace mindspore
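These registrations cover float16/float32/float64/int8/uint8/int32 data paired with int32 indices, plus float64 with int64 indices. A minimal usage sketch that exercises the float32/int32 combination (illustrative only; assumes a GPU build of MindSpore):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target="GPU")
x = Tensor(np.zeros((3, 3), np.float32))
indices = Tensor(np.array([[0, 1], [0, 1]], np.int32))  # duplicate index: contributions are summed
update = Tensor(np.array([1.0, 2.5], np.float32))
print(P.TensorScatterAdd()(x, indices, update))         # element (0, 1) becomes 3.5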


@ -0,0 +1,208 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_ADD_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_ADD_GPU_KERNEL_H
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_add.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class TensorScatterAddGpuFwdKernel : public GpuKernel {
public:
TensorScatterAddGpuFwdKernel()
: input_size_(1),
update_size_(1),
indices_size_(1),
output_size_(1),
block_size_(1),
indices_stride_(nullptr),
work_shape_(nullptr),
indices_dim_0_(0),
indices_dim_1_(0),
memcpy_flag_(false) {}
~TensorScatterAddGpuFwdKernel() {
if (indices_stride_ != nullptr) {
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
}
if (work_shape_ != nullptr) {
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
}
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
T *input = GetDeviceAddress<T>(inputs, 0);
S *indices = GetDeviceAddress<S>(inputs, 1);
T *update = GetDeviceAddress<T>(inputs, 2);
T *output = GetDeviceAddress<T>(outputs, 0);
if (!memcpy_flag_) {
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(indices_stride_, &vec_indices_stride_[0], indices_len,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpy failed in TensorScatterAddGpuFwdKernel::Launch.");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpy failed in TensorScatterAddGpuFwdKernel::Launch.");
memcpy_flag_ = true;
}
const size_t update_size = update_size_ / sizeof(T);
const size_t output_size = output_size_ / sizeof(T);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(&output[0], &input[0], input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
TensorScatterAdd(input, indices, update, output, block_size_, update_size, output_size, indices_dim_0_,
indices_dim_1_, indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
memcpy_flag_ = false;
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but TensorScatterAdd needs 3 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but TensorScatterAdd has 1 output.";
return false;
}
update_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
std::vector<size_t> shape_me = input_shapes_;
(void)std::transform(shape_me.begin(), shape_me.end(), std::back_inserter(vec_work_shape_),
[](const size_t &value) { return static_cast<S>(value); });
GetSize();
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
if (indices_stride_work == nullptr) {
MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
}
indices_stride_ = static_cast<S *>(indices_stride_work);
const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
if (work_shape_work == nullptr) {
MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
}
work_shape_ = static_cast<S *>(work_shape_work);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
input_size_list_.push_back(indices_size_);
input_size_list_.push_back(update_size_);
output_size_list_.push_back(output_size_);
return;
}
void GetSize() {
input_size_ = sizeof(T);
for (size_t i = 0; i < input_shapes_.size(); i++) {
input_size_ *= input_shapes_[i];
}
indices_size_ = sizeof(S);
for (size_t i = 0; i < indices_shapes_.size(); i++) {
indices_size_ *= indices_shapes_[i];
}
update_size_ = sizeof(T);
for (size_t i = 0; i < update_shapes_.size(); i++) {
update_size_ *= update_shapes_[i];
}
output_size_ = sizeof(T);
for (size_t i = 0; i < output_shapes_.size(); i++) {
output_size_ *= output_shapes_[i];
}
// calculate indices dim 0/1
indices_dim_0_ = indices_shapes_[0];
indices_dim_1_ = indices_shapes_[indices_shapes_.size() - 1];
// calculate block_size
for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
block_size_ *= output_shapes_[i];
}
// calculate indices_stride
vec_indices_stride_.resize(indices_dim_1_, 0);
vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
}
}
private:
std::vector<size_t> update_shapes_;
std::vector<size_t> indices_shapes_;
std::vector<size_t> input_shapes_;
std::vector<size_t> output_shapes_;
std::vector<S> vec_indices_stride_;
std::vector<S> vec_work_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t update_size_;
size_t indices_size_;
size_t output_size_;
size_t block_size_;
S *indices_stride_;
S *work_shape_;
size_t indices_dim_0_;
size_t indices_dim_1_;
bool memcpy_flag_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_ADD_GPU_KERNEL_H
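GetSize() above derives block_size_ (the number of output elements addressed by one index vector) and vec_indices_stride_ (row-major strides over the leading output dimensions). A host-side NumPy sketch of the same arithmetic, with a hypothetical helper name, for reference:

import numpy as np

def scatter_add_layout(output_shape, indices_shape):
    index_depth = indices_shape[-1]  # depth of one index vector
    block_size = int(np.prod(output_shape[index_depth:], dtype=np.int64))
    stride = [0] * index_depth
    stride[-1] = block_size
    for i in range(index_depth - 1, 0, -1):
        stride[i - 1] = stride[i] * output_shape[i]
    return block_size, stride

print(scatter_add_layout((4, 2, 3), (5, 3)))  # -> (1, [6, 3, 1])
print(scatter_add_layout((3, 3), (3, 1)))     # -> (3, [3])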


@ -0,0 +1,155 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cumprod_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void Copy(T *input, T *output, size_t size) {
size_t step = blockDim.x * gridDim.x;
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) {
input[write_index] = output[write_index];
}
}
template <typename T>
__global__ void LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2) {
size_t num = dim0 * dim2;
size_t i, k, offset;
size_t step = blockDim.x * gridDim.x;
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
i = write_index / dim2 % dim0;
k = write_index % dim2;
offset = i * stride + k;
for (size_t j = 0; j < dim1; ++j) {
size_t read_index = j * stride2 + offset;
if (j == 0) {
output[read_index] = 1;  // pad with the multiplicative identity for exclusive mode
} else {
size_t read_index2 = (j - 1) * stride2 + offset;
output[read_index] = input[read_index2];
}
}
}
}
template <typename T>
__global__ void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2) {
size_t num = dim0 * dim2;
size_t i, k, offset;
size_t step = blockDim.x * gridDim.x;
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
i = write_index / dim2 % dim0;
k = write_index % dim2;
offset = i * stride + k;
for (int j = dim1 - 1; j >= 0; --j) {
size_t read_index = j * stride2 + offset;
if (j == dim1 - 1) {
output[read_index] = 1;  // pad with the multiplicative identity for exclusive mode
} else {
size_t read_index2 = (j + 1) * stride2 + offset;
output[read_index] = input[read_index2];
}
}
}
}
template <typename T>
__global__ void CumProdKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2) {
size_t num = dim0 * dim2;
size_t i, k, offset;
size_t step = blockDim.x * gridDim.x;
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
i = write_index / dim2 % dim0;
k = write_index % dim2;
offset = i * stride + k;
for (int j = dim1 - 1; j >= 0; --j) {
size_t read_index = j * stride2 + offset;
if (j == dim1 - 1) {
output[read_index] = input[read_index];
} else {
size_t read_index2 = (j + 1) * stride2 + offset;
output[read_index] = output[read_index2] * input[read_index];
}
}
}
}
template <typename T>
__global__ void CumProdKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2) {
size_t num = dim0 * dim2;
size_t i, k, offset;
size_t step = blockDim.x * gridDim.x;
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
i = write_index / dim2 % dim0;
k = write_index % dim2;
offset = i * stride + k;
for (size_t j = 0; j < dim1; ++j) {
size_t read_index = j * stride2 + offset;
if (j == 0) {
output[read_index] = input[read_index];
} else {
size_t read_index2 = (j - 1) * stride2 + offset;
output[read_index] = output[read_index2] * input[read_index];
}
}
}
}
template <typename T>
void CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream) {
int size = dim0 * dim2;
if (exclusive_) {
if (reverse_) {
RightMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
CumProdKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride,
stride2);
} else {
LeftMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
CumProdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride, stride2);
}
} else {
if (reverse_) {
CumProdKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride,
stride2);
} else {
CumProdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
}
}
return;
}
template void CumProd<uint8_t>(const uint8_t *input, uint8_t *output, uint8_t *workspace, size_t dim0, size_t dim1,
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
cudaStream_t stream);
template void CumProd<int8_t>(const int8_t *input, int8_t *output, int8_t *workspace, size_t dim0, size_t dim1,
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
cudaStream_t stream);
template void CumProd<int32_t>(const int32_t *input, int32_t *output, int32_t *workspace, size_t dim0, size_t dim1,
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
cudaStream_t stream);
template void CumProd<double>(const double *input, double *output, double *workspace, size_t dim0, size_t dim1,
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
cudaStream_t stream);
template void CumProd<float>(const float *input, float *output, float *workspace, size_t dim0, size_t dim1, size_t dim2,
size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
template void CumProd<half>(const half *input, half *output, half *workspace, size_t dim0, size_t dim1, size_t dim2,
size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
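For reference, LeftMove/RightMove perform the shift-and-pad step of exclusive mode and CumProdKernel/CumProdKernelReverse the running product; the four (exclusive, reverse) combinations match the NumPy sketch below (cumprod_reference is illustrative and not part of this commit):

import numpy as np

def cumprod_reference(x, axis, exclusive=False, reverse=False):
    if reverse:
        x = np.flip(x, axis)
    if exclusive:
        x = np.roll(x, 1, axis)   # shift by one along the axis (np.roll returns a copy)
        lead = [slice(None)] * x.ndim
        lead[axis] = 0
        x[tuple(lead)] = 1        # pad with the multiplicative identity
    out = np.cumprod(x, axis=axis)
    return np.flip(out, axis) if reverse else out

x = np.array([2.0, 3.0, 4.0])
print(cumprod_reference(x, 0))                                # [ 2.  6. 24.]
print(cumprod_reference(x, 0, exclusive=True))                # [1. 2. 6.]
print(cumprod_reference(x, 0, reverse=True))                  # [24. 12.  4.]
print(cumprod_reference(x, 0, exclusive=True, reverse=True))  # [12.  4.  1.]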


@ -0,0 +1,22 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMPROD_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMPROD_IMPL_CUH_
template <typename T>
void CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMPROD_IMPL_CUH_


@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_add.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
__global__ void TensorScatterAddKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
int i, j;
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
read_index += blockDim.x * gridDim.x) {
size_t write_index = 0;
bool out_bound = false;
i = read_index / block_size;
j = read_index % block_size;
for (size_t k = 0; k < indices_dim_1; k++) {
S indices_i = indices[i * indices_dim_1 + k];
out_bound |= indices_i >= work_shape[k];
write_index += indices_i * indices_stride[k];
}
write_index += j;
out_bound |= write_index >= output_size;
if (!out_bound) {
MsAtomicAdd(&output[write_index], update[read_index]);
}
}
}
template <typename T, typename S>
void TensorScatterAdd(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
S *indices_stride, S *work_shape, cudaStream_t stream) {
TensorScatterAddKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride,
work_shape);
return;
}
template void TensorScatterAdd<half, int>(half *input, int *indices, half *update, half *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, cudaStream_t stream);
template void TensorScatterAdd<float, int>(float *input, int *indices, float *update, float *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterAdd<double, int>(double *input, int *indices, double *update, double *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterAdd<char, int>(char *input, int *indices, char *update, char *output,
const size_t &block_size, const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, cudaStream_t stream);
template void TensorScatterAdd<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
unsigned char *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1,
int *indices_stride, int *work_shape, cudaStream_t stream);
template void TensorScatterAdd<int, int>(int *input, int *indices, int *update, int *output, const size_t &block_size,
const size_t &input_size, const size_t &output_size,
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
int *work_shape, cudaStream_t stream);
template void TensorScatterAdd<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, cudaStream_t stream);
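TensorScatterAddKernel maps each flat position of update to a write position in output using the precomputed strides, skips index vectors that fall outside work_shape, and accumulates with MsAtomicAdd so duplicate indices are summed. The write-index arithmetic in a host-side NumPy sketch (hypothetical helper name):

import numpy as np

def write_index(read_index, indices, indices_stride, work_shape, block_size, output_size):
    i, j = divmod(read_index, block_size)          # which index vector / offset inside its block
    vec = indices[i]
    if np.any(vec >= np.asarray(work_shape)):      # per-dimension bound check
        return None
    widx = int(np.dot(vec, indices_stride)) + j
    return widx if widx < output_size else None    # final bound check

indices = np.array([[0, 1], [2, 2]], np.int32)     # into a (3, 3) output
print(write_index(0, indices, [3, 1], (3, 3), 1, 9))  # 1 -> output[0, 1]
print(write_index(1, indices, [3, 1], (3, 3), 1, 9))  # 8 -> output[2, 2]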


@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_ADD_IMPL_CUH
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_ADD_IMPL_CUH
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
void TensorScatterAdd(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
S *indices_stride, S *work_shape, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_ADD_IMPL_CUH


@ -0,0 +1,34 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/cumprod_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
CumProdGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
CumProdGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CumProdGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CumProdGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CumProdGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
CumProdGpuKernel, half)
} // namespace kernel
} // namespace mindspore
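A quick usage sketch of the CumProd primitive with its exclusive and reverse attributes (illustrative only; assumes a GPU build of MindSpore):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target="GPU")
x = Tensor(np.array([2.0, 3.0, 4.0], np.float32))
print(P.CumProd()(x, 0))                # [ 2.  6. 24.]
print(P.CumProd(exclusive=True)(x, 0))  # [1. 2. 6.]
print(P.CumProd(reverse=True)(x, 0))    # [24. 12.  4.]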


@ -0,0 +1,108 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMPROD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMPROD_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/cumprod_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class CumProdGpuKernel : public GpuKernel {
public:
CumProdGpuKernel() : exclusive_(false), reverse_(false), axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
~CumProdGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
T *ws_addr = GetDeviceAddress<T>(workspace, 0);
CumProd(input_addr, output_addr, ws_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, exclusive_, reverse_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumProdGpuKernel needs 1.";
return false;
}
input_size_0_ = sizeof(T);
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
exclusive_ = GetAttr<bool>(kernel_node, "exclusive");
reverse_ = GetAttr<bool>(kernel_node, "reverse");
int input_dim_length = SizeToInt(shape_.size());
if (axis_ >= input_dim_length) {
MS_LOG(EXCEPTION) << "Axis out of bounds.";
}
while (axis_ < 0) {
axis_ += input_dim_length;
}
for (size_t i = 0; i < shape_.size(); i++) {
input_size_0_ *= shape_[i];
}
Reshape();
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_0_);
output_size_list_.push_back(input_size_0_);
workspace_size_list_.push_back(input_size_0_);
}
private:
void Reshape() {
dims_[0] = 1;
dims_[1] = shape_[IntToSize(axis_)];
dims_[2] = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
dims_[0] *= shape_[i];
}
for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) {
dims_[2] *= shape_[i];
}
stride_ = dims_[1] * dims_[2];
stride2_ = dims_[2];
return;
}
bool exclusive_;
bool reverse_;
int axis_;
size_t input_size_0_;
size_t stride_;
size_t stride2_;
size_t dims_[3] = {};
std::vector<size_t> shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMPROD_GPU_KERNEL_H_
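Reshape() collapses an input of any rank into an (outer, axis, inner) triple so the kernel can treat it as 3-D: stride_ steps between consecutive outer slices and stride2_ between consecutive elements along the cumprod axis. The same factorization in a few lines of NumPy (illustrative):

import numpy as np

shape, axis = (2, 3, 4, 4), 2
dim0 = int(np.prod(shape[:axis], dtype=np.int64))      # outer: dims before the axis
dim1 = shape[axis]                                      # length of the cumprod axis
dim2 = int(np.prod(shape[axis + 1:], dtype=np.int64))  # inner: dims after the axis
stride, stride2 = dim1 * dim2, dim2
print(dim0, dim1, dim2, stride, stride2)                # 6 4 4 16 4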


@ -757,6 +757,18 @@ def get_bprop_tensor_scatter_update(self):
return bprop
@bprop_getters.register(P.TensorScatterAdd)
def get_bprop_tensor_scatter_add(self):
"""Generate bprop for TensorScatterAdd"""
gather_nd = P.GatherNd()
def bprop(x, indices, update, out, dout):
update_grad = gather_nd(dout, indices)
return dout, zeros_like(indices), update_grad
return bprop
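Since TensorScatterAdd is linear in both x and update, the gradient w.r.t. x is dout passed through unchanged, and the gradient w.r.t. update is dout gathered at indices, which is what GatherNd does above. A tiny NumPy check of that identity (illustrative):

import numpy as np

dout = np.arange(9.0).reshape(3, 3)
indices = np.array([[0, 1], [2, 2]])
update_grad = dout[tuple(indices.T)]  # same gather as GatherNd(dout, indices) for full-depth indices
print(update_grad)                    # [1. 8.]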
@bprop_getters.register(P.ScatterMax)
def get_bprop_scatter_max(self):
"""Generate bprop for ScatterMax"""


@ -188,6 +188,9 @@ bool_eq = Primitive("bool_eq")
logical_and = P.LogicalAnd()
logical_or = P.LogicalOr()
logical_not = P.LogicalNot()
cumsum = P.CumSum()
cumprod = P.CumProd()
tensor_scatter_add = P.TensorScatterAdd()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")


@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, UniqueWithPad,
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort,
Squeeze, StridedSlice, Tile, TensorScatterUpdate, TensorScatterAdd, EditDistance, Sort,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
@ -308,6 +308,7 @@ __all__ = [
'MirrorPad',
'GatherNd',
'TensorScatterUpdate',
'TensorScatterAdd',
'ScatterUpdate',
'ScatterNdUpdate',
'Floor',


@ -3469,8 +3469,6 @@ class TensorScatterUpdate(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE)
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
return x_shape
@ -3482,6 +3480,66 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_dtype
class TensorScatterAdd(PrimitiveWithInfer):
"""
Creates a new tensor by adding the values from `update` to the positions in `input_x`
indicated by `indices`. When multiple values are given for the same index, the result
at that position is the sum of all of them. This operation is almost equivalent to
ScatterNdAdd, except that the updates are applied to a `Tensor` instead of a `Parameter`.
The last axis of `indices` is the depth of each index vector. For each index vector,
there must be a corresponding value in `update`. The shape of `update` should be
equal to the shape of `input_x[indices]`.
Inputs:
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
- **indices** (Tensor) - The indices into the input tensor, with int32 or int64 data type.
The rank must be at least 2.
- **update** (Tensor) - The tensor used to update the input tensor; it has the same dtype as
`input_x`, and update.shape should equal indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
Outputs:
Tensor, has the same shape and type as `input_x`.
Raises:
TypeError: If dtype of `indices` is neither int32 nor int64.
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
Supported Platforms:
``GPU``
Examples:
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
>>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = ops.TensorScatterAdd()
>>> output = op(input_x, indices, update)
>>> print(output)
[[ 3.1 0.3 3.6]
[ 0.4 0.5 -3.2]]
"""
@prim_attr_register
def __init__(self):
"""Initialize TensorScatterAdd"""
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
def infer_shape(self, x_shape, indices_shape, value_shape):
if len(indices_shape) < 2:
raise ValueError("For 'TensorScatterAdd', the rank of 'indices' must be at least 2.")
update_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:]
if update_shape != value_shape:
raise ValueError("For 'TensorScatterAdd', the shape of 'update' does not match the shapes of 'input_x' and 'indices'.")
return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
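The shape rule enforced by infer_shape above reduces to update.shape == indices.shape[:-1] + x.shape[indices.shape[-1]:]; a two-line illustration:

x_shape, indices_shape = (4, 2, 3), (5, 2)
print(indices_shape[:-1] + x_shape[indices_shape[-1]:])  # (5, 3): the required shape of update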
class ScatterUpdate(_ScatterOp_Dynamic):
r"""
Updates tensor values by using input indices and value.


@ -719,7 +719,7 @@ class CumProd(PrimitiveWithInfer):
ValueError: If `axis` is None.
Supported Platforms:
``Ascend``
``Ascend`` ``GPU``
Examples:
>>> a, b, c, = 1, 2, 3


@ -0,0 +1,157 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import operations as P
def cum_prod(nptype):
context.set_context(device_target='GPU')
x0 = np.random.rand(2, 3, 4, 4).astype(nptype)
axis0 = 3
x1 = np.random.rand(2, 3, 4, 4).astype(nptype)
axis1 = 3
x2 = np.random.rand(2, 3, 1, 4).astype(nptype)
axis2 = 2
x3 = np.random.rand(2, 3, 1, 4).astype(nptype)
axis3 = 2
x4 = np.random.rand(2, 3, 4, 4).astype(nptype)
axis4 = 1
x5 = np.random.rand(2, 3).astype(nptype)
axis5 = 1
x6 = np.random.rand(1, 1, 1, 1).astype(nptype)
axis6 = 0
class CumProd(nn.Cell):
def __init__(self, nptype):
super(CumProd, self).__init__()
self.x0 = Tensor(x0)
self.axis0 = axis0
self.x1 = Tensor(x1)
self.axis1 = axis1
self.x2 = Tensor(x2)
self.axis2 = axis2
self.x3 = Tensor(x3)
self.axis3 = axis3
self.x4 = Tensor(x4)
self.axis4 = axis4
self.x5 = Tensor(x5)
self.axis5 = axis5
self.x6 = Tensor(x6)
self.axis6 = axis6
@ms_function
def construct(self):
return (P.CumProd()(self.x0, self.axis0),
P.CumProd()(self.x1, self.axis1),
P.CumProd()(self.x2, self.axis2),
P.CumProd()(self.x3, self.axis3),
P.CumProd()(self.x4, self.axis4),
P.CumProd()(self.x5, self.axis5),
P.CumProd()(self.x6, self.axis6))
cumprod = CumProd(nptype)
output = cumprod()
expect0 = np.cumprod(x0, axis=axis0)
diff0 = abs(output[0].asnumpy() - expect0)
error0 = np.ones(shape=expect0.shape) * 1.0e-5
assert np.all(diff0 < error0)
assert output[0].shape == expect0.shape
expect1 = np.cumprod(x1, axis=axis1)
diff1 = abs(output[1].asnumpy() - expect1)
error1 = np.ones(shape=expect1.shape) * 1.0e-5
assert np.all(diff1 < error1)
assert output[1].shape == expect1.shape
expect2 = np.cumprod(x2, axis=axis2)
diff2 = abs(output[2].asnumpy() - expect2)
error2 = np.ones(shape=expect2.shape) * 1.0e-5
assert np.all(diff2 < error2)
assert output[2].shape == expect2.shape
expect3 = np.cumprod(x3, axis=axis3)
diff3 = abs(output[3].asnumpy() - expect3)
error3 = np.ones(shape=expect3.shape) * 1.0e-5
assert np.all(diff3 < error3)
assert output[3].shape == expect3.shape
expect4 = np.cumprod(x4, axis=axis4)
diff4 = abs(output[4].asnumpy() - expect4)
error4 = np.ones(shape=expect4.shape) * 1.0e-5
assert np.all(diff4 < error4)
assert output[4].shape == expect4.shape
expect5 = np.cumprod(x5, axis=axis5)
diff5 = abs(output[5].asnumpy() - expect5)
error5 = np.ones(shape=expect5.shape) * 1.0e-5
assert np.all(diff5 < error5)
assert output[5].shape == expect5.shape
expect6 = np.cumprod(x6, axis=axis6)
diff6 = abs(output[6].asnumpy() - expect6)
error6 = np.ones(shape=expect6.shape) * 1.0e-5
assert np.all(diff6 < error6)
assert output[6].shape == expect6.shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cum_prod_uint8():
cum_prod(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cum_prod_int8():
cum_prod(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cum_prod_int32():
cum_prod(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cum_prod_float16():
cum_prod(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cum_prod_float32():
cum_prod(np.float32)


@ -0,0 +1,146 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.scatter_add = P.TensorScatterAdd()
def construct(self, x, indices, update):
return self.scatter_add(x, indices, update)
def scatter_net(x, indices, update):
scatter_add = Net()
return scatter_add(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
def numpy_scatter_add(x, indices, update):
indices = np.expand_dims(indices, -1) if indices.ndim == 1 else indices
for idx, up in zip(indices, update):
idx = tuple(idx.tolist())
x[idx] += up
return x
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_scatter():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# indices 2-d, each index points to single value
arr_input = np.arange(21).reshape(3, 7).astype(np.float32)
arr_indices = np.array([[0, 1], [1, 1], [0, 2], [0, 2], [2, 1]]).astype(np.int32)
arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(np.float32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
np.testing.assert_allclose(out, expected, rtol=1e-6)
# indices 2-d, each index points to single value
arr_input = np.arange(24).reshape(4, 2, 3).astype(np.float32)
arr_indices = np.array([[0, 0, 0], [1, 1, 1], [0, 1, 1], [3, 0, 1]]).astype(np.int32)
arr_update = np.array([-1, -2, -3, -4]).astype(np.float32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
np.testing.assert_allclose(out, expected, rtol=1e-6)
# indices 2-d, each index points to a slice, and each value points to a single element in the slice
arr_input = np.zeros((3, 3)).astype(np.float32)
arr_indices = np.array([[0], [2], [1]]).astype(np.int32)
arr_update = np.array([[-1, 4, 3], [-2, 0, 1], [-3, 1, 2]]).astype(np.float32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
np.testing.assert_allclose(out, expected, rtol=1e-6)
arr_input = np.arange(21).reshape(3, 7).astype(np.float32)
arr_indices = np.array([[0, 1], [1, 1], [0, 5], [0, 2], [2, 1]]).astype(np.int32)
arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(np.float32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
np.testing.assert_allclose(out, expected, rtol=1e-6)
arr_input = np.arange(24).reshape(4, 2, 3).astype(np.float32)
arr_indices = np.array([[0, 0, 0], [1, 1, 1], [0, 1, 1], [3, 0, 1]]).astype(np.int32)
arr_update = np.array([-1, -2, -3, -4]).astype(np.float32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
np.testing.assert_allclose(out, expected, rtol=1e-6)
# Below are from test_tensor_scatter_update.py
arr_input = np.arange(25).reshape(5, 5).astype(np.float32)
arr_indices = np.array([[[0, 0],
[1, 1],
[2, 2],
[3, 3],
[4, 4]],
[[0, 4],
[1, 3],
[2, 2],
[3, 1],
[4, 0]]]).astype(np.int32)
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = np.array([[11, 1, 2, 3, 70],
[5, 28, 7, 85, 9],
[10, 11, 78, 13, 14],
[15, 115, 17, 62, 19],
[120, 21, 22, 23, 79]]).astype(np.float32)
np.testing.assert_allclose(out, expected, rtol=1e-6)
arr_input = np.arange(25).reshape(5, 5).astype(np.float64)
arr_indices = np.array([[[0, 0],
[1, 1],
[2, 2],
[3, 3],
[4, 4]],
[[0, 4],
[1, 3],
[2, 2],
[3, 1],
[4, 0]]]).astype(np.int64)
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float64)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = np.array([[11, 1, 2, 3, 70],
[5, 28, 7, 85, 9],
[10, 11, 78, 13, 14],
[15, 115, 17, 62, 19],
[120, 21, 22, 23, 79]]).astype(np.float64)
np.testing.assert_allclose(out, expected, rtol=1e-6)
arr_input = np.arange(25).reshape(5, 5).astype(np.int32)
arr_indices = np.array([[[0, 0],
[1, 1],
[2, 2],
[3, 3],
[4, 4]],
[[0, 4],
[1, 3],
[2, 2],
[3, 1],
[4, 0]]]).astype(np.int32)
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.int32)
out = scatter_net(arr_input, arr_indices, arr_update)
expected = np.array([[11, 1, 2, 3, 70],
[5, 28, 7, 85, 9],
[10, 11, 78, 13, 14],
[15, 115, 17, 62, 19],
[120, 21, 22, 23, 79]]).astype(np.int32)
np.testing.assert_allclose(out, expected, rtol=1e-6)