forked from mindspore-Ecosystem/mindspore
!16668 add gpu op cumprod and tensor_scatter_add
From: @yanglf1121 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
b75c694032
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/tensor_scatter_add_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
TensorScatterAddGpuFwdKernel, half, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
TensorScatterAddGpuFwdKernel, float, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
TensorScatterAddGpuFwdKernel, double, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
TensorScatterAddGpuFwdKernel, char, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
TensorScatterAddGpuFwdKernel, uchar, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
TensorScatterAddGpuFwdKernel, int, int)
|
||||
MS_REG_GPU_KERNEL_TWO(TensorScatterAdd,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
TensorScatterAddGpuFwdKernel, double, int64_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,208 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_ADD_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_ADD_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_add.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class TensorScatterAddGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
TensorScatterAddGpuFwdKernel()
|
||||
: input_size_(1),
|
||||
update_size_(1),
|
||||
indices_size_(1),
|
||||
output_size_(1),
|
||||
block_size_(1),
|
||||
indices_stride_(nullptr),
|
||||
work_shape_(nullptr),
|
||||
indices_dim_0_(0),
|
||||
indices_dim_1_(0),
|
||||
memcpy_flag_(false) {}
|
||||
~TensorScatterAddGpuFwdKernel() {
|
||||
if (indices_stride_ != nullptr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
|
||||
}
|
||||
if (work_shape_ != nullptr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
S *indices = GetDeviceAddress<S>(inputs, 1);
|
||||
T *update = GetDeviceAddress<T>(inputs, 2);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
if (!memcpy_flag_) {
|
||||
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
|
||||
const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(indices_stride_, &vec_indices_stride_[0], indices_len,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpy failed in TensorScatterAddGpuFwdKernel::Launch.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpy failed in TensorScatterAddGpuFwdKernel::Launch.");
|
||||
memcpy_flag_ = true;
|
||||
}
|
||||
|
||||
const size_t update_size = update_size_ / sizeof(T);
|
||||
const size_t output_size = output_size_ / sizeof(T);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(&output[0], &input[0], input_size_, cudaMemcpyDeviceToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync output failed");
|
||||
|
||||
TensorScatterAdd(input, indices, update, output, block_size_, update_size, output_size, indices_dim_0_,
|
||||
indices_dim_1_, indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
memcpy_flag_ = false;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 3) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but TensorScatterAdd needs 3 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but TensorScatterAdd has 1 output.";
|
||||
return false;
|
||||
}
|
||||
|
||||
update_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
std::vector<size_t> shape_me = input_shapes_;
|
||||
(void)std::transform(shape_me.begin(), shape_me.end(), std::back_inserter(vec_work_shape_),
|
||||
[](const size_t &value) { return static_cast<S>(value); });
|
||||
|
||||
GetSize();
|
||||
|
||||
const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
|
||||
void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
|
||||
if (indices_stride_work == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
|
||||
}
|
||||
indices_stride_ = static_cast<S *>(indices_stride_work);
|
||||
|
||||
const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
|
||||
void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
|
||||
if (work_shape_work == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
|
||||
}
|
||||
work_shape_ = static_cast<S *>(work_shape_work);
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(indices_size_);
|
||||
input_size_list_.push_back(update_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
return;
|
||||
}
|
||||
|
||||
void GetSize() {
|
||||
input_size_ = sizeof(T);
|
||||
for (size_t i = 0; i < input_shapes_.size(); i++) {
|
||||
input_size_ *= input_shapes_[i];
|
||||
}
|
||||
|
||||
indices_size_ = sizeof(S);
|
||||
for (size_t i = 0; i < indices_shapes_.size(); i++) {
|
||||
indices_size_ *= indices_shapes_[i];
|
||||
}
|
||||
update_size_ = sizeof(T);
|
||||
for (size_t i = 0; i < update_shapes_.size(); i++) {
|
||||
update_size_ *= update_shapes_[i];
|
||||
}
|
||||
output_size_ = sizeof(T);
|
||||
for (size_t i = 0; i < output_shapes_.size(); i++) {
|
||||
output_size_ *= output_shapes_[i];
|
||||
}
|
||||
|
||||
// calculate indices dim 0/1
|
||||
indices_dim_0_ = indices_shapes_[0];
|
||||
indices_dim_1_ = indices_shapes_[indices_shapes_.size() - 1];
|
||||
|
||||
// calculate block_size
|
||||
for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
|
||||
block_size_ *= output_shapes_[i];
|
||||
}
|
||||
|
||||
// calculate indices_stride
|
||||
vec_indices_stride_.resize(indices_dim_1_, 0);
|
||||
vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
|
||||
|
||||
for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
|
||||
vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> update_shapes_;
|
||||
std::vector<size_t> indices_shapes_;
|
||||
std::vector<size_t> input_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
std::vector<S> vec_indices_stride_;
|
||||
std::vector<S> vec_work_shape_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
size_t input_size_;
|
||||
size_t update_size_;
|
||||
size_t indices_size_;
|
||||
size_t output_size_;
|
||||
size_t block_size_;
|
||||
|
||||
S *indices_stride_;
|
||||
S *work_shape_;
|
||||
size_t indices_dim_0_;
|
||||
size_t indices_dim_1_;
|
||||
bool memcpy_flag_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TENSOR_SCATTER_ADD_GPU_KERNEL_H
|
|
@ -0,0 +1,155 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cumprod_impl.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void Copy(T *input, T *output, size_t size) {
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < size; write_index += step) {
|
||||
input[write_index] = output[write_index];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void LeftMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
size_t read_index = j * stride2 + offset;
|
||||
if (j == 0) {
|
||||
output[read_index] = 0;
|
||||
} else {
|
||||
size_t read_index2 = (j - 1) * stride2 + offset;
|
||||
output[read_index] = input[read_index2];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void RightMove(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
for (int j = dim1 - 1; j >= 0; --j) {
|
||||
size_t read_index = j * stride2 + offset;
|
||||
if (j == dim1 - 1) {
|
||||
output[read_index] = 0;
|
||||
} else {
|
||||
size_t read_index2 = (j + 1) * stride2 + offset;
|
||||
output[read_index] = input[read_index2];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
__global__ void CumProdKernelReverse(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
for (int j = dim1 - 1; j >= 0; --j) {
|
||||
size_t read_index = j * stride2 + offset;
|
||||
if (j == dim1 - 1) {
|
||||
output[read_index] = input[read_index];
|
||||
} else {
|
||||
size_t read_index2 = (j + 1) * stride2 + offset;
|
||||
output[read_index] = output[read_index2] * input[read_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CumProdKernel(const T *input, T *output, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2) {
|
||||
size_t num = dim0 * dim2;
|
||||
size_t i, k, offset;
|
||||
size_t step = blockDim.x * gridDim.x;
|
||||
for (size_t write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; write_index += step) {
|
||||
i = write_index / dim2 % dim0;
|
||||
k = write_index % dim2;
|
||||
offset = i * stride + k;
|
||||
for (size_t j = 0; j < dim1; ++j) {
|
||||
size_t read_index = j * stride2 + offset;
|
||||
if (j == 0) {
|
||||
output[read_index] = input[read_index];
|
||||
} else {
|
||||
size_t read_index2 = (j - 1) * stride2 + offset;
|
||||
output[read_index] = output[read_index2] * input[read_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
void CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream) {
|
||||
int size = dim0 * dim2;
|
||||
if (exclusive_) {
|
||||
if (reverse_) {
|
||||
RightMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
|
||||
CumProdKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride,
|
||||
stride2);
|
||||
} else {
|
||||
LeftMove<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
Copy<<<GET_BLOCKS(size * dim1), GET_THREADS, 0, stream>>>(workspace, output, size * dim1);
|
||||
CumProdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(workspace, output, dim0, dim1, dim2, stride, stride2);
|
||||
}
|
||||
} else {
|
||||
if (reverse_) {
|
||||
CumProdKernelReverse<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride,
|
||||
stride2);
|
||||
} else {
|
||||
CumProdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, output, dim0, dim1, dim2, stride, stride2);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template void CumProd<uint8_t>(const uint8_t *input, uint8_t *output, uint8_t *workspace, size_t dim0, size_t dim1,
|
||||
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
|
||||
cudaStream_t stream);
|
||||
template void CumProd<int8_t>(const int8_t *input, int8_t *output, int8_t *workspace, size_t dim0, size_t dim1,
|
||||
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
|
||||
cudaStream_t stream);
|
||||
template void CumProd<int32_t>(const int32_t *input, int32_t *output, int32_t *workspace, size_t dim0, size_t dim1,
|
||||
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
|
||||
cudaStream_t stream);
|
||||
template void CumProd<double>(const double *input, double *output, double *workspace, size_t dim0, size_t dim1,
|
||||
size_t dim2, size_t stride, size_t stride2, bool exclusive_, bool reverse_,
|
||||
cudaStream_t stream);
|
||||
template void CumProd<float>(const float *input, float *output, float *workspace, size_t dim0, size_t dim1, size_t dim2,
|
||||
size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
|
||||
template void CumProd<half>(const half *input, half *output, half *workspace, size_t dim0, size_t dim1, size_t dim2,
|
||||
size_t stride, size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
|
|
@ -0,0 +1,22 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
||||
template <typename T>
|
||||
void CumProd(const T *input, T *output, T *workspace, size_t dim0, size_t dim1, size_t dim2, size_t stride,
|
||||
size_t stride2, bool exclusive_, bool reverse_, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_IMPL_CUH_
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_add.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void TensorScatterAddKernel(T *input, S *indices, T *update, T *output, const size_t block_size,
|
||||
const size_t input_size, const size_t output_size, const size_t indices_dim_0,
|
||||
const size_t indices_dim_1, S *indices_stride, S *work_shape) {
|
||||
int i, j;
|
||||
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
|
||||
read_index += blockDim.x * gridDim.x) {
|
||||
size_t write_index = 0;
|
||||
bool out_bound = false;
|
||||
|
||||
i = read_index / block_size;
|
||||
j = read_index % block_size;
|
||||
|
||||
for (size_t k = 0; k < indices_dim_1; k++) {
|
||||
S indices_i = indices[i * indices_dim_1 + k];
|
||||
out_bound |= indices_i >= work_shape[k];
|
||||
write_index += indices_i * indices_stride[k];
|
||||
}
|
||||
|
||||
write_index += j;
|
||||
out_bound |= write_index >= output_size;
|
||||
|
||||
if (!out_bound) {
|
||||
MsAtomicAdd(&output[write_index], update[read_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void TensorScatterAdd(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||
S *indices_stride, S *work_shape, cudaStream_t stream) {
|
||||
TensorScatterAddKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
|
||||
input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride,
|
||||
work_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
template void TensorScatterAdd<half, int>(half *input, int *indices, half *update, half *output,
|
||||
const size_t &block_size, const size_t &input_size, const size_t &output_size,
|
||||
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
|
||||
int *work_shape, cudaStream_t stream);
|
||||
template void TensorScatterAdd<float, int>(float *input, int *indices, float *update, float *output,
|
||||
const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||
cudaStream_t stream);
|
||||
template void TensorScatterAdd<double, int>(double *input, int *indices, double *update, double *output,
|
||||
const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
||||
cudaStream_t stream);
|
||||
template void TensorScatterAdd<char, int>(char *input, int *indices, char *update, char *output,
|
||||
const size_t &block_size, const size_t &input_size, const size_t &output_size,
|
||||
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
|
||||
int *work_shape, cudaStream_t stream);
|
||||
template void TensorScatterAdd<unsigned char, int>(unsigned char *input, int *indices, unsigned char *update,
|
||||
unsigned char *output, const size_t &block_size,
|
||||
const size_t &input_size, const size_t &output_size,
|
||||
const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||
int *indices_stride, int *work_shape, cudaStream_t stream);
|
||||
template void TensorScatterAdd<int, int>(int *input, int *indices, int *update, int *output, const size_t &block_size,
|
||||
const size_t &input_size, const size_t &output_size,
|
||||
const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride,
|
||||
int *work_shape, cudaStream_t stream);
|
||||
template void TensorScatterAdd<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
|
||||
const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0,
|
||||
const size_t &indices_dim_1, int64_t *indices_stride,
|
||||
int64_t *work_shape, cudaStream_t stream);
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_ADD_IMPL_CUH
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_ADD_IMPL_CUH
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
void TensorScatterAdd(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1,
|
||||
S *indices_stride, S *work_shape, cudaStream_t stream);
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_ADD_IMPL_CUH
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/math/cumprod_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
CumProdGpuKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
CumProdGpuKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
CumProdGpuKernel, int32_t)
|
||||
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CumProdGpuKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CumProdGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(CumProd, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CumProdGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,108 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/cumprod_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CumProdGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CumProdGpuKernel() : exclusive_(false), reverse_(false), axis_(0), input_size_0_(0), stride_(0), stride2_(0) {}
|
||||
~CumProdGpuKernel() = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
T *ws_addr = GetDeviceAddress<T>(workspace, 0);
|
||||
CumProd(input_addr, output_addr, ws_addr, dims_[0], dims_[1], dims_[2], stride_, stride2_, exclusive_, reverse_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but CumProdGpuKernel needs 1.";
|
||||
return false;
|
||||
}
|
||||
input_size_0_ = sizeof(T);
|
||||
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
|
||||
exclusive_ = GetAttr<bool>(kernel_node, "exclusive");
|
||||
reverse_ = GetAttr<bool>(kernel_node, "reverse");
|
||||
int input_dim_length = SizeToInt(shape_.size());
|
||||
if (axis_ >= input_dim_length) {
|
||||
MS_LOG(EXCEPTION) << "Axis out of bounds.";
|
||||
}
|
||||
while (axis_ < 0) {
|
||||
axis_ += input_dim_length;
|
||||
}
|
||||
for (size_t i = 0; i < shape_.size(); i++) {
|
||||
input_size_0_ *= shape_[i];
|
||||
}
|
||||
Reshape();
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_0_);
|
||||
output_size_list_.push_back(input_size_0_);
|
||||
workspace_size_list_.push_back(input_size_0_);
|
||||
}
|
||||
|
||||
private:
|
||||
void Reshape() {
|
||||
dims_[0] = 1;
|
||||
dims_[1] = shape_[IntToSize(axis_)];
|
||||
dims_[2] = 1;
|
||||
for (size_t i = 0; i < IntToSize(axis_); i++) {
|
||||
dims_[0] *= shape_[i];
|
||||
}
|
||||
for (size_t i = IntToSize(axis_) + 1; i < shape_.size(); i++) {
|
||||
dims_[2] *= shape_[i];
|
||||
}
|
||||
stride_ = dims_[1] * dims_[2];
|
||||
stride2_ = dims_[2];
|
||||
return;
|
||||
}
|
||||
bool exclusive_;
|
||||
bool reverse_;
|
||||
int axis_;
|
||||
size_t input_size_0_;
|
||||
size_t stride_;
|
||||
size_t stride2_;
|
||||
size_t dims_[3] = {};
|
||||
std::vector<size_t> shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUMSUM_GPU_KERNEL_H_
|
|
@ -757,6 +757,18 @@ def get_bprop_tensor_scatter_update(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.TensorScatterAdd)
|
||||
def get_bprop_tensor_scatter_add(self):
|
||||
"""Generate bprop for TensorScatterAdd"""
|
||||
gather_nd = P.GatherNd()
|
||||
|
||||
def bprop(x, indices, update, out, dout):
|
||||
update_grad = gather_nd(dout, indices)
|
||||
return dout, zeros_like(indices), update_grad
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.ScatterMax)
|
||||
def get_bprop_scatter_max(self):
|
||||
"""Generate bprop for ScatterMax"""
|
||||
|
|
|
@ -188,6 +188,9 @@ bool_eq = Primitive("bool_eq")
|
|||
logical_and = P.LogicalAnd()
|
||||
logical_or = P.LogicalOr()
|
||||
logical_not = P.LogicalNot()
|
||||
cumsum = P.CumSum()
|
||||
cumprod = P.CumProd()
|
||||
tensor_scatter_add = P.TensorScatterAdd()
|
||||
array_to_scalar = Primitive('array_to_scalar')
|
||||
is_ = Primitive("is_")
|
||||
is_not = Primitive("is_not")
|
||||
|
|
|
@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
|
|||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, DynamicShape, Size, Slice, Split, TransShape, ParallelConcat, Padding, UniqueWithPad,
|
||||
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Sort,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate, TensorScatterAdd, EditDistance, Sort,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
|
||||
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
|
||||
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
|
||||
|
@ -308,6 +308,7 @@ __all__ = [
|
|||
'MirrorPad',
|
||||
'GatherNd',
|
||||
'TensorScatterUpdate',
|
||||
'TensorScatterAdd',
|
||||
'ScatterUpdate',
|
||||
'ScatterNdUpdate',
|
||||
'Floor',
|
||||
|
|
|
@ -3469,8 +3469,6 @@ class TensorScatterUpdate(PrimitiveWithInfer):
|
|||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
validator.check('the dimension of x', len(x_shape),
|
||||
'the dimension of indices', indices_shape[-1], Rel.GE)
|
||||
if indices_shape[:-1] + x_shape[indices_shape[-1]:] != value_shape:
|
||||
raise ValueError("For 'TensorScatterUpdate', input value are not match with input indices.")
|
||||
return x_shape
|
||||
|
@ -3482,6 +3480,66 @@ class TensorScatterUpdate(PrimitiveWithInfer):
|
|||
return x_dtype
|
||||
|
||||
|
||||
class TensorScatterAdd(PrimitiveWithInfer):
|
||||
"""
|
||||
Creates a new tensor by adding the values from the positions in `input_x` indicicated by
|
||||
`indices`, with values from `update`. When multiple values are given for the same
|
||||
index, the updated result will be the sum of all values. This operation is almost
|
||||
equivalent to using ScatterNdAdd, except that the updates are applied on `Tensor`
|
||||
instead of `Parameter`.
|
||||
|
||||
The last axis of `indices` is the depth of each index vectors. For each index vector,
|
||||
there must be a corresponding value in `update`. The shape of `update` should be
|
||||
equal to the shape of `input_x[indices]`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The target tensor. The dimension of input_x must be no less than indices.shape[-1].
|
||||
- **indices** (Tensor) - The index of input tensor whose data type is int32 or int64.
|
||||
The rank must be at least 2.
|
||||
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
|
||||
and update.shape should be equal to indices.shape[:-1] + input_x.shape[indices.shape[-1]:].
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape and type as `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `indices` is neither int32 nor int64.
|
||||
ValueError: If length of shape of `input_x` is less than the last dimension of shape of `indices`.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32)
|
||||
>>> indices = Tensor(np.array([[0, 0], [0, 0]]), mindspore.int32)
|
||||
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
|
||||
>>> op = ops.TensorScatterAdd()
|
||||
>>> output = op(input_x, indices, update)
|
||||
>>> print(output)
|
||||
[[ 3.1 0.3 3.6]
|
||||
[ 0.4 0.5 -3.2]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize TensorScatterAdd"""
|
||||
self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y'])
|
||||
|
||||
def infer_shape(self, x_shape, indices_shape, value_shape):
|
||||
if len(indices_shape) < 2:
|
||||
raise ValueError("For 'TensorScatterAdd', the rank of the indices must > 2.")
|
||||
update_shape = indices_shape[:-1] + x_shape[indices_shape[-1]:]
|
||||
if update_shape != value_shape:
|
||||
raise ValueError("For 'TensorScatterAdd', input value are not match with input indices.")
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
|
||||
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
|
||||
args = {"x": x_dtype, "value": value_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
||||
class ScatterUpdate(_ScatterOp_Dynamic):
|
||||
r"""
|
||||
Updates tensor values by using input indices and value.
|
||||
|
|
|
@ -719,7 +719,7 @@ class CumProd(PrimitiveWithInfer):
|
|||
ValueError: If `axis` is None.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> a, b, c, = 1, 2, 3
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
def cum_prod(nptype):
|
||||
context.set_context(device_target='GPU')
|
||||
x0 = np.random.rand(2, 3, 4, 4).astype(nptype)
|
||||
axis0 = 3
|
||||
|
||||
x1 = np.random.rand(2, 3, 4, 4).astype(nptype)
|
||||
axis1 = 3
|
||||
|
||||
x2 = np.random.rand(2, 3, 1, 4).astype(nptype)
|
||||
axis2 = 2
|
||||
|
||||
x3 = np.random.rand(2, 3, 1, 4).astype(nptype)
|
||||
axis3 = 2
|
||||
|
||||
x4 = np.random.rand(2, 3, 4, 4).astype(nptype)
|
||||
axis4 = 1
|
||||
|
||||
x5 = np.random.rand(2, 3).astype(nptype)
|
||||
axis5 = 1
|
||||
|
||||
x6 = np.random.rand(1, 1, 1, 1).astype(nptype)
|
||||
axis6 = 0
|
||||
|
||||
class CumProd(nn.Cell):
|
||||
def __init__(self, nptype):
|
||||
super(CumProd, self).__init__()
|
||||
|
||||
self.x0 = Tensor(x0)
|
||||
self.axis0 = axis0
|
||||
|
||||
self.x1 = Tensor(x1)
|
||||
self.axis1 = axis1
|
||||
|
||||
self.x2 = Tensor(x2)
|
||||
self.axis2 = axis2
|
||||
|
||||
self.x3 = Tensor(x3)
|
||||
self.axis3 = axis3
|
||||
|
||||
self.x4 = Tensor(x4)
|
||||
self.axis4 = axis4
|
||||
|
||||
self.x5 = Tensor(x5)
|
||||
self.axis5 = axis5
|
||||
|
||||
self.x6 = Tensor(x6)
|
||||
self.axis6 = axis6
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
return (P.CumProd()(self.x0, self.axis0),
|
||||
P.CumProd()(self.x1, self.axis1),
|
||||
P.CumProd()(self.x2, self.axis2),
|
||||
P.CumProd()(self.x3, self.axis3),
|
||||
P.CumProd()(self.x4, self.axis4),
|
||||
P.CumProd()(self.x5, self.axis5),
|
||||
P.CumProd()(self.x6, self.axis6))
|
||||
|
||||
|
||||
cumprod = CumProd(nptype)
|
||||
output = cumprod()
|
||||
|
||||
expect0 = np.cumprod(x0, axis=axis0)
|
||||
diff0 = abs(output[0].asnumpy() - expect0)
|
||||
error0 = np.ones(shape=expect0.shape) * 1.0e-5
|
||||
assert np.all(diff0 < error0)
|
||||
assert output[0].shape == expect0.shape
|
||||
|
||||
expect1 = np.cumprod(x1, axis=axis1)
|
||||
diff1 = abs(output[1].asnumpy() - expect1)
|
||||
error1 = np.ones(shape=expect1.shape) * 1.0e-5
|
||||
assert np.all(diff1 < error1)
|
||||
assert output[1].shape == expect1.shape
|
||||
|
||||
expect2 = np.cumprod(x2, axis=axis2)
|
||||
diff2 = abs(output[2].asnumpy() - expect2)
|
||||
error2 = np.ones(shape=expect2.shape) * 1.0e-5
|
||||
assert np.all(diff2 < error2)
|
||||
assert output[2].shape == expect2.shape
|
||||
|
||||
expect3 = np.cumprod(x3, axis=axis3)
|
||||
diff3 = abs(output[3].asnumpy() - expect3)
|
||||
error3 = np.ones(shape=expect3.shape) * 1.0e-5
|
||||
assert np.all(diff3 < error3)
|
||||
assert output[3].shape == expect3.shape
|
||||
|
||||
expect4 = np.cumprod(x4, axis=axis4)
|
||||
diff4 = abs(output[4].asnumpy() - expect4)
|
||||
error4 = np.ones(shape=expect4.shape) * 1.0e-5
|
||||
assert np.all(diff4 < error4)
|
||||
assert output[4].shape == expect4.shape
|
||||
|
||||
expect5 = np.cumprod(x5, axis=axis5)
|
||||
diff5 = abs(output[5].asnumpy() - expect5)
|
||||
error5 = np.ones(shape=expect5.shape) * 1.0e-5
|
||||
assert np.all(diff5 < error5)
|
||||
assert output[5].shape == expect5.shape
|
||||
|
||||
expect6 = np.cumprod(x6, axis=axis6)
|
||||
diff6 = abs(output[6].asnumpy() - expect6)
|
||||
error6 = np.ones(shape=expect6.shape) * 1.0e-5
|
||||
assert np.all(diff6 < error6)
|
||||
assert output[6].shape == expect6.shape
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cum_prod_uint8():
|
||||
cum_prod(np.uint8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cum_prod_int8():
|
||||
cum_prod(np.int8)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cum_prod_int32():
|
||||
cum_prod(np.int32)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cum_prod_float16():
|
||||
cum_prod(np.float16)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_cum_prod_float32():
|
||||
cum_prod(np.float32)
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.scatter_add = P.TensorScatterAdd()
|
||||
|
||||
def construct(self, x, indices, update):
|
||||
return self.scatter_add(x, indices, update)
|
||||
|
||||
|
||||
def scatter_net(x, indices, update):
|
||||
scatter_add = Net()
|
||||
return scatter_add(Tensor(x), Tensor(indices), Tensor(update)).asnumpy()
|
||||
|
||||
def numpy_scatter_add(x, indices, update):
|
||||
indices = np.expand_dims(indices, -1) if indices.ndim == 1 else indices
|
||||
for idx, up in zip(indices, update):
|
||||
idx = tuple(idx.tolist())
|
||||
x[idx] += up
|
||||
return x
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_scatter():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
# indices 2-d, each index points to single value
|
||||
arr_input = np.arange(21).reshape(3, 7).astype(np.float32)
|
||||
arr_indices = np.array([[0, 1], [1, 1], [0, 2], [0, 2], [2, 1]]).astype(np.int32)
|
||||
arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(np.float32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
# indices 2-d, each index points to single value
|
||||
arr_input = np.arange(24).reshape(4, 2, 3).astype(np.float32)
|
||||
arr_indices = np.array([[0, 0, 0], [1, 1, 1], [0, 1, 1], [3, 0, 1]]).astype(np.int32)
|
||||
arr_update = np.array([-1, -2, -3, -4]).astype(np.float32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
# indices 2-d, each index points to a slice, and each value points to a single element in the slice
|
||||
arr_input = np.zeros((3, 3)).astype(np.float32)
|
||||
arr_indices = np.array([[0], [2], [1]]).astype(np.int32)
|
||||
arr_update = np.array([[-1, 4, 3], [-2, 0, 1], [-3, 1, 2]]).astype(np.float32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
arr_input = np.arange(21).reshape(3, 7).astype(np.float32)
|
||||
arr_indices = np.array([[0, 1], [1, 1], [0, 5], [0, 2], [2, 1]]).astype(np.int32)
|
||||
arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(np.float32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
arr_input = np.arange(24).reshape(4, 2, 3).astype(np.float32)
|
||||
arr_indices = np.array([[0, 0, 0], [1, 1, 1], [0, 1, 1], [3, 0, 1]]).astype(np.int32)
|
||||
arr_update = np.array([-1, -2, -3, -4]).astype(np.float32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = numpy_scatter_add(arr_input, arr_indices, arr_update)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
# Below are from test_tensor_scatter_update.py
|
||||
arr_input = np.arange(25).reshape(5, 5).astype(np.float32)
|
||||
arr_indices = np.array([[[0, 0],
|
||||
[1, 1],
|
||||
[2, 2],
|
||||
[3, 3],
|
||||
[4, 4]],
|
||||
[[0, 4],
|
||||
[1, 3],
|
||||
[2, 2],
|
||||
[3, 1],
|
||||
[4, 0]]]).astype(np.int32)
|
||||
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = np.array([[11, 1, 2, 3, 70],
|
||||
[5, 28, 7, 85, 9],
|
||||
[10, 11, 78, 13, 14],
|
||||
[15, 115, 17, 62, 19],
|
||||
[120, 21, 22, 23, 79]]).astype(np.float32)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
arr_input = np.arange(25).reshape(5, 5).astype(np.float64)
|
||||
arr_indices = np.array([[[0, 0],
|
||||
[1, 1],
|
||||
[2, 2],
|
||||
[3, 3],
|
||||
[4, 4]],
|
||||
[[0, 4],
|
||||
[1, 3],
|
||||
[2, 2],
|
||||
[3, 1],
|
||||
[4, 0]]]).astype(np.int64)
|
||||
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float64)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = np.array([[11, 1, 2, 3, 70],
|
||||
[5, 28, 7, 85, 9],
|
||||
[10, 11, 78, 13, 14],
|
||||
[15, 115, 17, 62, 19],
|
||||
[120, 21, 22, 23, 79]]).astype(np.float64)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
||||
|
||||
arr_input = np.arange(25).reshape(5, 5).astype(np.int32)
|
||||
arr_indices = np.array([[[0, 0],
|
||||
[1, 1],
|
||||
[2, 2],
|
||||
[3, 3],
|
||||
[4, 4]],
|
||||
[[0, 4],
|
||||
[1, 3],
|
||||
[2, 2],
|
||||
[3, 1],
|
||||
[4, 0]]]).astype(np.int32)
|
||||
arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.int32)
|
||||
out = scatter_net(arr_input, arr_indices, arr_update)
|
||||
expected = np.array([[11, 1, 2, 3, 70],
|
||||
[5, 28, 7, 85, 9],
|
||||
[10, 11, 78, 13, 14],
|
||||
[15, 115, 17, 62, 19],
|
||||
[120, 21, 22, 23, 79]]).astype(np.int32)
|
||||
np.testing.assert_allclose(out, expected, rtol=1e-6)
|
Loading…
Reference in New Issue