From a52cd685fca1ae3464781bea82d65bf2eccc70ae Mon Sep 17 00:00:00 2001 From: tom__chen Date: Mon, 9 Nov 2020 16:27:34 -0500 Subject: [PATCH] change repeat_element op to a composite op --- .../gpu/arrays/repeat_elements_gpu_kernel.cc | 28 -- .../gpu/arrays/repeat_elements_gpu_kernel.h | 161 --------- .../arrays/repeat_elements_grad_gpu_kernel.cc | 29 -- .../arrays/repeat_elements_grad_gpu_kernel.h | 119 ------- .../cuda_impl/repeat_elements_grad_impl.cu | 48 --- .../cuda_impl/repeat_elements_grad_impl.cuh | 26 -- .../gpu/cuda_impl/repeat_elements_impl.cu | 318 ----------------- .../gpu/cuda_impl/repeat_elements_impl.cuh | 52 --- mindspore/ops/_grad/grad_array_ops.py | 10 - mindspore/ops/composite/__init__.py | 4 +- mindspore/ops/composite/array_ops.py | 100 ++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/_grad_ops.py | 21 -- mindspore/ops/operations/array_ops.py | 50 --- .../ops/gpu/test_repeat_elements_grad_op.py | 321 ------------------ tests/st/ops/gpu/test_repeat_elements_op.py | 7 +- .../python/parallel/test_repeat_elements.py | 86 ----- 17 files changed, 108 insertions(+), 1275 deletions(-) delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh create mode 100644 mindspore/ops/composite/array_ops.py delete mode 100644 tests/st/ops/gpu/test_repeat_elements_grad_op.py delete mode 100644 tests/ut/python/parallel/test_repeat_elements.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc deleted file mode 100644 index 1d0ab8a36fa..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.cc +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include "backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(RepeatElements, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - RepeatElementsGpuKernel, half) - -MS_REG_GPU_KERNEL_ONE(RepeatElements, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - RepeatElementsGpuKernel, int32_t) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h deleted file mode 100644 index e8ca831b0ed..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_gpu_kernel.h +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_ - -#include "backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh" - -#include - -#include -#include - -#include "backend/kernel_compiler/gpu/gpu_kernel.h" -#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class RepeatElementsGpuKernel : public GpuKernel { - public: - RepeatElementsGpuKernel() : rep_(1), axis_(0), input_size_(1), output_size_(0) {} - ~RepeatElementsGpuKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *input_device_address = GetDeviceAddress(inputs, 0); - T *output_device_address = GetDeviceAddress(outputs, 0); - - switch (input_dim_) { - case 1: - CalRepeatElements1d(input_device_address, rep_, axis_, output_device_address, output_size_, - reinterpret_cast(stream_ptr)); - break; - case 2: - CalRepeatElements2d(input_device_address, input_shape_[1], rep_, axis_, output_device_address, output_shape_[1], - output_size_, reinterpret_cast(stream_ptr)); - break; - case 3: - CalRepeatElements3d(input_device_address, input_shape_[1], input_shape_[2], rep_, axis_, output_device_address, - output_shape_[1], output_shape_[2], output_size_, - reinterpret_cast(stream_ptr)); - break; - case 4: - CalRepeatElements4d(input_device_address, input_shape_[1], input_shape_[2], input_shape_[3], rep_, axis_, - output_device_address, output_shape_[1], output_shape_[2], output_shape_[3], output_size_, - reinterpret_cast(stream_ptr)); - break; - case 5: - CalRepeatElements5d(input_device_address, input_shape_[1], input_shape_[2], input_shape_[3], 
input_shape_[4], - rep_, axis_, output_device_address, output_shape_[1], output_shape_[2], output_shape_[3], - output_shape_[4], output_size_, reinterpret_cast(stream_ptr)); - break; - default: - int *input_shape_device_address = GetDeviceAddress(workspace, 0); - int *output_shape_device_address = GetDeviceAddress(workspace, 1); - int *input_shape_cumulative_product_device_address = GetDeviceAddress(workspace, 2); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(input_shape_device_address, input_shape_.data(), workspace_size_list_[0], - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(output_shape_device_address, output_shape_.data(), workspace_size_list_[1], - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync output_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(input_shape_cumulative_product_device_address, input_shape_cumulative_product_.data(), - workspace_size_list_[2], cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape_cumulative_product_device_address failed"); - - CalRepeatElements(input_device_address, input_dim_, input_shape_device_address, - input_shape_cumulative_product_device_address, rep_, axis_, output_device_address, - output_shape_device_address, output_size_, reinterpret_cast(stream_ptr)); - break; - } - - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_count != 1) { - MS_LOG(EXCEPTION) << input_count << " arguments were provided, but RepeatElementsGpuKernel expects 1."; - } - - std::vector temp_input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_dim_ = temp_input_shape.size(); - for (size_t e : temp_input_shape) { - input_size_ *= e; - input_shape_.push_back(e); - } - - int cumulative_product = 1; - for (size_t i = input_dim_ - 1; i > 0; i--) { - cumulative_product *= input_shape_[i]; - input_shape_cumulative_product_.push_back(cumulative_product); - } - std::reverse(input_shape_cumulative_product_.begin(), input_shape_cumulative_product_.end()); - - axis_ = static_cast(GetAttr(kernel_node, "axis")); - if (axis_ < 0) { - axis_ += input_dim_; - } - - rep_ = static_cast(GetAttr(kernel_node, "rep")); - output_size_ = input_size_ * rep_; - output_shape_ = input_shape_; - output_shape_[axis_] *= rep_; - - InitSizeLists(); - - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_ * sizeof(T)); - output_size_list_.push_back(output_size_ * sizeof(T)); - - // workspaces for input shape, output shape and cumulative sum - workspace_size_list_.push_back(input_dim_ * sizeof(int)); - workspace_size_list_.push_back(input_dim_ * sizeof(int)); - workspace_size_list_.push_back((input_dim_ - 1) * sizeof(int)); - } - - private: - int rep_; - int axis_; - int input_dim_; - std::vector input_shape_; - std::vector input_shape_cumulative_product_; - std::vector output_shape_; - - size_t input_size_; - size_t output_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc deleted 
file mode 100644 index c3f364ad6f5..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(RepeatElementsGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - RepeatElementsGradGpuKernel, half) - -MS_REG_GPU_KERNEL_ONE(RepeatElementsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - RepeatElementsGradGpuKernel, int32_t) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h deleted file mode 100644 index 27f5b2cceb4..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/repeat_elements_grad_gpu_kernel.h +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_ - -#include "backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh" - -#include - -#include -#include - -#include "backend/kernel_compiler/gpu/gpu_kernel.h" -#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class RepeatElementsGradGpuKernel : public GpuKernel { - public: - RepeatElementsGradGpuKernel() - : rep_(1), axis_(0), input_size_(1), output_size_(0), outer_size_(1), repeat_dim_size_(1), inner_size_(1) {} - ~RepeatElementsGradGpuKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *dy = GetDeviceAddress(inputs, 0); - T *dx = GetDeviceAddress(outputs, 0); - - CalRepeatElementsGrad(dy, rep_, dx, outer_size_, repeat_dim_size_, inner_size_, - reinterpret_cast(stream_ptr)); - - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_count != 1) { - MS_LOG(EXCEPTION) << input_count << " arguments were provided, but RepeatElementGradGpuKernel expects 1."; - } - - std::vector dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - int dy_dim = dy_shape.size(); - - axis_ = static_cast(GetAttr(kernel_node, "axis")); - if (axis_ < 0) { - axis_ += dy_dim; - } - rep_ = static_cast(GetAttr(kernel_node, "rep")); - if (axis_ >= dy_dim) { - axis_ = dy_dim - 1; - rep_ = 1; - } - - for (int i = 0; i < dy_dim; i++) { - auto e = dy_shape[i]; - input_size_ *= e; - input_shape_.push_back(e); - if (i < axis_) { - outer_size_ *= e; - } else if (i > axis_) { - inner_size_ *= e; - } else { - repeat_dim_size_ = e / rep_; - } - } - - output_size_ = input_size_ / rep_; - output_shape_ = input_shape_; - output_shape_[axis_] /= rep_; - - InitSizeLists(); - - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_ * sizeof(T)); - output_size_list_.push_back(output_size_ * sizeof(T)); - } - - private: - int rep_; - int axis_; - size_t input_size_; - size_t output_size_; - int outer_size_; - int repeat_dim_size_; - int inner_size_; - std::vector input_shape_; - std::vector output_shape_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_REPEAT_ELEMENTS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu deleted file mode 100644 index 4c125e6ed7d..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cu +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "repeat_elements_grad_impl.cuh" -#include "runtime/device/gpu/cuda_common.h" - -template -__global__ void RepeatElementsGrad(const int dx_size, const T *dy, const int rep, T *dx, const int outer_size, - const int repeat_dim_size, const int inner_size) { - for (size_t t_id = blockIdx.x * blockDim.x + threadIdx.x; t_id < dx_size; t_id += gridDim.x * blockDim.x) { - int inner_id = t_id % inner_size; - int repeat_dim_id = t_id / inner_size % repeat_dim_size; - int outer_id = t_id / inner_size / repeat_dim_size; - T dx_i = static_cast(0); - for (int i = 0; i < rep; i++) { - dx_i += dy[(outer_id * rep * repeat_dim_size * inner_size) + (repeat_dim_id * rep * inner_size) + - (i * inner_size) + inner_id]; - } - dx[t_id] = dx_i; - } -} - -template -void CalRepeatElementsGrad(const T *dy, const int rep, T *dx, const int outer_size, const int repeat_dim_size, - const int inner_size, cudaStream_t cuda_stream) { - const int dx_size = outer_size * repeat_dim_size * inner_size; - RepeatElementsGrad<<>>(dx_size, dy, rep, dx, outer_size, - repeat_dim_size, inner_size); -} - -template void CalRepeatElementsGrad(const int *dy, const int rep, int *dx, const int outer_size, - const int repeat_dim_size, const int inner_size, cudaStream_t cuda_stream); -template void CalRepeatElementsGrad(const half *dy, const int rep, half *dx, const int outer_size, - const int repeat_dim_size, const int inner_size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh deleted file mode 100644 index 0cb46f1bf6b..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_grad_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_ - -#include - -template -void CalRepeatElementsGrad(const T *dy, const int rep, T *dx, const int outer_size, const int repeat_dim_size, - const int inner_size, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu deleted file mode 100644 index c95f2b6e70f..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cu +++ /dev/null @@ -1,318 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "repeat_elements_impl.cuh" -#include "runtime/device/gpu/cuda_common.h" - -template -__global__ void RepeatElements1d(const T *input, const int rep, const int axis, T *output, - const int output_size) { - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - int copied_value_index = gt_id / rep; - output[gt_id] = input[copied_value_index]; - } -} - -template -__global__ void RepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output, - const int output_d1, const int output_size) { - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - int global_array_index = gt_id; - - int index_d1 = global_array_index % output_d1; - global_array_index -= index_d1; - global_array_index /= output_d1; - - int index_d0 = global_array_index; - - switch (axis) { - case 0: - index_d0 /= rep; - break; - case 1: - index_d1 /= rep; - break; - } - - const int term0 = index_d0 * input_d1; - const int copied_value_index = term0 + index_d1; - output[gt_id] = input[copied_value_index]; - } -} - -template -__global__ void RepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis, - T *output, const int output_d1, const int output_d2, const int output_size) { - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - int global_array_index = gt_id; - - int index_d2 = global_array_index % output_d2; - global_array_index -= index_d2; - global_array_index /= output_d2; - - int index_d1 = global_array_index % output_d1; - global_array_index -= index_d1; - global_array_index /= output_d1; - - int index_d0 = global_array_index; - - switch (axis) { - case 0: - index_d0 /= rep; - break; - case 1: - index_d1 /= rep; - break; - case 2: - index_d2 /= rep; - break; - default: - asm("trap;"); - } - - const int term0 = index_d0 * input_d1 * input_d2; - const int term1 = index_d1 * input_d2; - const int copied_value_index = term0 + term1 + index_d2; - 
output[gt_id] = input[copied_value_index]; - } -} - -template -__global__ void RepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3, - const int rep, const int axis, T *output, const int output_d1, const int output_d2, - const int output_d3, const int output_size) { - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - int global_array_index = gt_id; - - int index_d3 = global_array_index % output_d3; - global_array_index -= index_d3; - global_array_index /= output_d3; - - int index_d2 = global_array_index % output_d2; - global_array_index -= index_d2; - global_array_index /= output_d2; - - int index_d1 = global_array_index % output_d1; - global_array_index -= index_d1; - global_array_index /= output_d1; - - int index_d0 = global_array_index; - - switch (axis) { - case 0: - index_d0 /= rep; - break; - case 1: - index_d1 /= rep; - break; - case 2: - index_d2 /= rep; - break; - case 3: - index_d3 /= rep; - break; - } - - const int term0 = index_d0 * input_d1 * input_d2 * input_d3; - const int term1 = index_d1 * input_d2 * input_d3; - const int term2 = index_d2 * input_d3; - const int copied_value_index = term0 + term1 + term2 + index_d3; - output[gt_id] = input[copied_value_index]; - } -} - -template -__global__ void RepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3, - const int input_d4, const int rep, const int axis, T *output, const int output_d1, - const int output_d2, const int output_d3, const int output_d4, const int output_size) { - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - int global_array_index = gt_id; - - int index_d4 = global_array_index % output_d4; - global_array_index -= index_d4; - global_array_index /= output_d4; - - int index_d3 = global_array_index % output_d3; - global_array_index -= index_d3; - global_array_index /= output_d3; - - int index_d2 = global_array_index % output_d2; - global_array_index -= index_d2; - global_array_index /= output_d2; - - int index_d1 = global_array_index % output_d1; - global_array_index -= index_d1; - global_array_index /= output_d1; - - int index_d0 = global_array_index; - - switch (axis) { - case 0: - index_d0 /= rep; - break; - case 1: - index_d1 /= rep; - break; - case 2: - index_d2 /= rep; - break; - case 3: - index_d3 /= rep; - break; - case 4: - index_d4 /= rep; - break; - } - - const int term0 = index_d0 * input_d1 * input_d2 * input_d3 * input_d4; - const int term1 = index_d1 * input_d2 * input_d3 * input_d4; - const int term2 = index_d2 * input_d3 * input_d4; - const int term3 = index_d3 * input_d4; - const int copied_value_index = term0 + term1 + term2 + term3 + index_d4; - output[gt_id] = input[copied_value_index]; - } -} - -template -__global__ void RepeatElements(const T *input, const int input_dim, const int* const input_shape, - const int* const coefficients, const int rep, const int axis, T *output, - const int* const output_shape, const int output_size) { - for (size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x; gt_id < output_size; gt_id += gridDim.x * blockDim.x) { - int index_tuple[REPEAT_ELEMENTS_MAX_INPUT_DIM]; - - int global_array_index = gt_id; - for (size_t i = input_dim - 1; i > 0; i--) { - int coordinate = global_array_index % output_shape[i]; - index_tuple[i] = coordinate; - global_array_index -= coordinate; - global_array_index /= output_shape[i]; - } - index_tuple[0] = global_array_index; - - 
index_tuple[axis] /= rep; - - int copied_value_index = 0; - for (size_t i = 0; i < input_dim - 1; i++) { - copied_value_index += index_tuple[i] * coefficients[i]; - } - copied_value_index += index_tuple[input_dim - 1]; - - output[gt_id] = input[copied_value_index]; - } -} - -template -void CalRepeatElements1d( - const T *input, const int rep, const int axis, T *output, const int output_size, cudaStream_t cuda_stream) { - RepeatElements1d<<>>(input, rep, axis, output, output_size); -} - -template -void CalRepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output, - const int output_d1, const int output_size, cudaStream_t cuda_stream) { - RepeatElements2d<<>>(input, input_d1, rep, axis, output, - output_d1, output_size); -} - -template -void CalRepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis, - T *output, const int output_d1, const int output_d2, const int output_size, - cudaStream_t cuda_stream) { - RepeatElements3d<<>>(input, input_d1, input_d2, rep, axis, - output, output_d1, output_d2, output_size); -} - -template -void CalRepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int rep, - const int axis, T *output, const int output_d1, const int output_d2, const int output_d3, - const int output_size, cudaStream_t cuda_stream) { - RepeatElements4d<<>>(input, input_d1, input_d2, input_d3, rep, - axis, output, output_d1, output_d2, - output_d3, output_size); -} - -template -void CalRepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int input_d4, - const int rep, const int axis, T *output, const int output_d1, const int output_d2, - const int output_d3, const int output_d4, const int output_size, cudaStream_t cuda_stream) { - RepeatElements5d<<>>(input, input_d1, input_d2, input_d3, - input_d4, rep, axis, output, output_d1, - output_d2, output_d3, output_d4, - output_size); -} - -template -void CalRepeatElements(const T *input, const int input_dim, const int* const input_shape, - const int* const input_shape_cumulative_product, const int rep, const int axis, T *output, - const int* const output_shape, const int output_size, cudaStream_t cuda_stream) { - RepeatElements<<>>(input, input_dim, input_shape, - input_shape_cumulative_product, rep, axis, - output, output_shape, output_size); -} - -// int32 -template void CalRepeatElements1d( - const int *input, const int rep, const int axis, int *output, const int output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements2d(const int *input, const int input_d1, const int rep, const int axis, int *output, - const int output_d1, const int output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements3d(const int *input, const int input_d1, const int input_d2, const int rep, - const int axis, int *output, const int output_d1, const int output_d2, - const int output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements4d(const int *input, const int input_d1, const int input_d2, const int input_d3, - const int rep, const int axis, int *output, const int output_d1, - const int output_d2, const int output_d3, const int output_size, - cudaStream_t cuda_stream); - -template void CalRepeatElements5d(const int *input, const int input_d1, const int input_d2, const int input_d3, - const int input_d4, const int rep, const int axis, int *output, - const int output_d1, const int output_d2, const int output_d3, - const int output_d4, const int 
output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements(const int *input, const int input_dim, const int* const input_shape, - const int* const input_shape_cumulative_product, const int rep, const int axis, - int *output, const int* const output_shape, const int output_size, - cudaStream_t cuda_stream); - -// float16 -template void CalRepeatElements1d( - const half *input, const int rep, const int axis, half *output, const int output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements2d(const half *input, const int input_d1, const int rep, const int axis, - half *output, const int output_d1, const int output_size, - cudaStream_t cuda_stream); - -template void CalRepeatElements3d(const half *input, const int input_d1, const int input_d2, const int rep, - const int axis, half *output, const int output_d1, const int output_d2, - const int output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements4d(const half *input, const int input_d1, const int input_d2, const int input_d3, - const int rep, const int axis, half *output, const int output_d1, - const int output_d2, const int output_d3, const int output_size, - cudaStream_t cuda_stream); - -template void CalRepeatElements5d(const half *input, const int input_d1, const int input_d2, const int input_d3, - const int input_d4, const int rep, const int axis, half *output, - const int output_d1, const int output_d2, const int output_d3, - const int output_d4, const int output_size, cudaStream_t cuda_stream); - -template void CalRepeatElements(const half *input, const int input_dim, const int* const input_shape, - const int* const input_shape_cumulative_product, const int rep, const int axis, - half *output, const int* const output_shape, const int output_size, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh deleted file mode 100644 index 34221c38454..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/repeat_elements_impl.cuh +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_ - -#include - -#define REPEAT_ELEMENTS_MAX_INPUT_DIM 100 - -template -void CalRepeatElements1d( - const T *input, const int rep, const int axis, T *output, const int output_size, cudaStream_t cuda_stream); - -template -void CalRepeatElements2d(const T *input, const int input_d1, const int rep, const int axis, T *output, - const int output_d1, const int output_size, cudaStream_t cuda_stream); - -template -void CalRepeatElements3d(const T *input, const int input_d1, const int input_d2, const int rep, const int axis, - T *output, const int output_d1, const int output_d2, const int output_size, - cudaStream_t cuda_stream); - -template -void CalRepeatElements4d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int rep, - const int axis, T *output, const int output_d1, const int output_d2, const int output_d3, - const int output_size, cudaStream_t cuda_stream); - -template -void CalRepeatElements5d(const T *input, const int input_d1, const int input_d2, const int input_d3, const int input_d4, - const int rep, const int axis, T *output, const int output_d1, const int output_d2, - const int output_d3, const int output_d4, const int output_size, cudaStream_t cuda_stream); - -template -void CalRepeatElements(const T *input, const int input_dim, const int* const input_shape, - const int* const input_shape_cumulative_product, const int rep, const int axis, T *output, - const int* const output_shape, const int output_size, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_REPEAT_ELEMENTS_H_ diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index c08813aa7c4..631ae14672a 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -882,13 +882,3 @@ def get_bprop_unique(self): dx = op(dout, out) return (dx,) return bprop - - -@bprop_getters.register(P.RepeatElements) -def get_bprop_repeat_elements(self): - """Generate bprop for RepeatElements""" - op = G.RepeatElementsGrad(self.rep, self.axis) - def bprop(x, y, dy): - dx = op(dy) - return (dx,) - return bprop diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index d29e9c075bd..fa37434b572 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -28,6 +28,7 @@ from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.zeros_like_impl import zeros_like from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial from .math_ops import count_nonzero, TensorDot +from .array_ops import repeat_elements __all__ = [ @@ -51,4 +52,5 @@ __all__ = [ 'clip_by_value', 'clip_by_global_norm', 'count_nonzero', - 'TensorDot'] + 'TensorDot', + 'repeat_elements'] diff --git a/mindspore/ops/composite/array_ops.py b/mindspore/ops/composite/array_ops.py new file mode 100644 index 00000000000..e68878d6fc4 --- /dev/null +++ b/mindspore/ops/composite/array_ops.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""array Operations.""" +from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils +from mindspore.common import dtype as mstype +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel +from mindspore.ops.primitive import constexpr +from mindspore.ops import functional as F +from .. import operations as P + + +@constexpr +def _check_is_int(arg_value, arg_name, op_name): + arg_value = validator.check_is_int(arg_value, arg_name, op_name) + return arg_value + + +@constexpr +def _check_positive_int(arg_value, arg_name, op_name): + arg_value = validator.check_positive_int(arg_value, arg_name, op_name) + return arg_value + + +@constexpr +def _check_axis_range(arg_value, limit, arg_name, op_name): + arg_value = validator.check_int_range(arg_value, -limit, limit, Rel.INC_LEFT, arg_name, op_name) + return arg_value + + +@constexpr +def _cal_repeat_dims(x_rank, rep, expand_axis): + rep_dims = [1] * (x_rank + 1) + rep_dims[expand_axis] = rep + return tuple(rep_dims) + + +@constexpr +def _cal_reshape(x_shape, rep, axis): + x_reshape = list(x_shape) + x_reshape[axis] *= rep + return tuple(x_reshape) + + +def repeat_elements(x, rep, axis=0): + """ + Repeat elements of a tensor along an axis, like np.repeat. + + Args: + x (Tensor): The tensor to repeat values for. + rep (int): The number of times to repeat, must be positive. + axis (int): The axis along which to repeat, default 0. + + Outputs: + One tensor with values repeated along the specified axis.
If x has shape + (s1, s2, ..., sn) and axis is i, the output will have shape (s1, s2, ..., si * rep, ..., sn) + + Examples: + >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) + >>> output = C.repeat_elements(x, rep = 2, axis = 0) + >>> print(output) + [[0 1 2] + [0 1 2] + [3 4 5] + [3 4 5]] + """ + const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x') + rep = _check_positive_int(rep, "rep", "repeat_elements") + axis = _check_is_int(axis, "axis", "repeat_elements") + + shape_op = P.Shape() + rank_op = P.Rank() + tile_op = P.Tile() + expand_dims_op = P.ExpandDims() + reshape_op = P.Reshape() + + x_rank = rank_op(x) + axis = _check_axis_range(axis, x_rank, "axis", "repeat_elements") + + expand_axis = axis + 1 + x_expand = expand_dims_op(x, expand_axis) + rep_dims = _cal_repeat_dims(x_rank, rep, expand_axis) + x_expand = tile_op(x_expand, rep_dims) + x_shape = shape_op(x) + x_reshape = _cal_reshape(x_shape, rep, axis) + x_rep = reshape_op(x_expand, x_reshape) + + return x_rep diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index dab748e1976..ccf5e766117 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, - Unique, GatherD, Identity, RepeatElements) + Unique, GatherD, Identity) from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, Send, Receive, @@ -388,7 +388,6 @@ __all__ = [ "Pull", "ReLUV2", "SparseToDense", - "RepeatElements", ] __all__.sort() diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index d60d4ad38b5..49dd10aa761 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1912,24 +1912,3 @@ class LRNGrad(PrimitiveWithInfer): def infer_shape(self, grads, x, y): return x - - -class RepeatElementsGrad(PrimitiveWithInfer): - """Gradients of RepeatElements operation.""" - - @prim_attr_register - def __init__(self, rep, axis=0): - self.init_prim_io_names(inputs=['dy'], outputs=['dx']) - validator.check_value_type("rep", rep, [int], self.name) - validator.check_value_type("axis", axis, [int], self.name) - self.rep = rep - self.axis = axis - - def infer_dtype(self, dy_type): - validator.check_type_name("dy_type", dy_type, [mstype.float16, mstype.float32, mstype.int32], self.name) - return dy_type - - def infer_shape(self, dy_shape): - dx_shape = dy_shape - dx_shape[self.axis] = dy_shape[self.axis] // self.rep - return dx_shape diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 4a871f6fc7f..4cdee350868 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -4361,53 +4361,3 @@ class Identity(PrimitiveWithInfer): 'dtype': x['dtype'], 'value': None} return out - - -class RepeatElements(PrimitiveWithInfer): - """ - Repeat elements of a tensor along an axis, like np.repeat. - - Args: - rep (int): The number of times to repeat, must be positive, required. - axis (int): The axis along which to repeat, default 0.
- - Inputs: - - **x** (Tensor) - The tensor to repeat values for. Must be of type int32 or float16. - - Outputs: - One tensor with values repeated along the specified axis. If x has shape - (s1, s2, ..., sn) and axis is i, the output will have shape (s1, s2, ..., si * rep, ..., sn) - - - Examples: - >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) - >>> repeat_elements = P.RepeatElements(rep = 2, axis = 0) - >>> output = repeat_elements(x) - >>> print(output) - [[0 1 2] - [0 1 2] - [3 4 5] - [3 4 5]] - """ - - @prim_attr_register - def __init__(self, rep, axis=0): - self.init_prim_io_names(inputs=["x"], outputs=["output"]) - - validator.check_value_type("rep", rep, [int], self.name) - self.rep = rep - - validator.check_value_type("axis", axis, [int], self.name) - self.axis = axis - - def infer_shape(self, x_shape): - validator.check("rep", self.rep, "", 0, Rel.GT, self.name) - validator.check("axis", self.axis, "dimension of x", len(x_shape), Rel.LT, self.name) - validator.check("axis", self.axis, "negative dimension of x", -len(x_shape), Rel.GE, self.name) - - x_shape[self.axis] *= self.rep - return x_shape - - def infer_dtype(self, x_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name) - return x_dtype diff --git a/tests/st/ops/gpu/test_repeat_elements_grad_op.py b/tests/st/ops/gpu/test_repeat_elements_grad_op.py deleted file mode 100644 index 038ee115ec5..00000000000 --- a/tests/st/ops/gpu/test_repeat_elements_grad_op.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -import numpy as np -import pytest - -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -import mindspore.context as context - - -class RepeatElementsNet(nn.Cell): - def __init__(self, rep, axis): - super(RepeatElementsNet, self).__init__() - self.repeat_elements = P.RepeatElements(rep, axis) - - def construct(self, x): - return self.repeat_elements(x) - - -class RepeatElementsGradNet(nn.Cell): - def __init__(self, rep, axis): - super(RepeatElementsGradNet, self).__init__() - self.repeat_elements_grad = G.RepeatElementsGrad(rep, axis) - - def construct(self, dy): - return self.repeat_elements_grad(dy) - - -def repeat_elements(x, rep, axis): - repeat_elements_net = RepeatElementsNet(rep, axis) - return repeat_elements_net(Tensor(x.astype(np.int32))).asnumpy() - - -def repeat_elements_grad(dy, rep, axis): - repeat_elements_grad_net = RepeatElementsGradNet(rep, axis) - return repeat_elements_grad_net(Tensor(dy.astype(np.int32))).asnumpy() - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_1d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1) - - ms_out = repeat_elements_grad(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_1d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1, 2) - - y = repeat_elements(a, 5, 0) - print(y) - ms_out = repeat_elements_grad(y, 5, 0) - print(ms_out) - np.testing.assert_array_equal(a*5, ms_out) - - y = repeat_elements(a, 513, 0) - ms_out = repeat_elements_grad(y, 513, 0) - np.testing.assert_array_equal(a*513, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_1d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24) - - ms_out = repeat_elements_grad(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_1d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(4) - - y = repeat_elements(a, 3, 0) - ms_out = repeat_elements_grad(y, 3, 0) - np.testing.assert_array_equal(a*3, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_2d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1) - - ms_out = repeat_elements_grad(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_2d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1) - - y = repeat_elements(a, 13, 0) - ms_out = repeat_elements_grad(y, 13, 0) - np.testing.assert_array_equal(a*13, ms_out) - - y = 
repeat_elements(a, 13, 1) - ms_out = repeat_elements_grad(y, 13, 1) - np.testing.assert_array_equal(a*13, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_2d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24).reshape(12, 2) - - ms_out = repeat_elements_grad(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_2d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(24).reshape(8, 3) - - y = repeat_elements(a, 23, 0) - ms_out = repeat_elements_grad(y, 23, 0) - np.testing.assert_array_equal(a*23, ms_out) - - y = repeat_elements(a, 23, 1) - ms_out = repeat_elements_grad(y, 23, 1) - np.testing.assert_array_equal(a*23, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_5d_one_element_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1, 1) - - ms_out = repeat_elements_grad(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 4) - np_out = a.repeat(1, 4) - np.testing.assert_array_equal(np_out, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_5d_one_element_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1).reshape(1, 1, 1, 1, 1) - - y = repeat_elements(a, 19, 0) - ms_out = repeat_elements_grad(y, 19, 0) - np.testing.assert_array_equal(a, ms_out) - - y = repeat_elements(a, 19, 1) - ms_out = repeat_elements_grad(y, 19, 1) - np.testing.assert_array_equal(a, ms_out) - - y = repeat_elements(a, 19, 2) - ms_out = repeat_elements_grad(y, 19, 2) - np.testing.assert_array_equal(a, ms_out) - - y = repeat_elements(a, 19, 3) - ms_out = repeat_elements_grad(y, 19, 3) - np.testing.assert_array_equal(a, ms_out) - - y = repeat_elements(a, 19, 4) - ms_out = repeat_elements_grad(y, 19, 4) - np.testing.assert_array_equal(a, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_5d_rep_1(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(224).reshape(8, 2, 1, 7, 2) - - ms_out = repeat_elements_grad(a, 1, 0) - np_out = a.repeat(1, 0) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 1) - np_out = a.repeat(1, 1) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 2) - np_out = a.repeat(1, 2) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 3) - np_out = a.repeat(1, 3) - np.testing.assert_array_equal(np_out, ms_out) - - ms_out = repeat_elements_grad(a, 1, 4) - np_out = a.repeat(1, 4) - 
np.testing.assert_array_equal(np_out, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_5d_rep_many(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(224).reshape(1, 7, 4, 4, 2) - - y = repeat_elements(a, 7, 0) - ms_out = repeat_elements_grad(y, 7, 0) - np.testing.assert_array_equal(a*7, ms_out) - - y = repeat_elements(a, 7, 1) - ms_out = repeat_elements_grad(y, 7, 1) - np.testing.assert_array_equal(a*7, ms_out) - - y = repeat_elements(a, 7, 2) - ms_out = repeat_elements_grad(y, 7, 2) - np.testing.assert_array_equal(a*7, ms_out) - - y = repeat_elements(a, 7, 3) - ms_out = repeat_elements_grad(y, 7, 3) - np.testing.assert_array_equal(a*7, ms_out) - - y = repeat_elements(a, 7, 4) - ms_out = repeat_elements_grad(y, 7, 4) - np.testing.assert_array_equal(a*7, ms_out) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_repeat_elements_grad_half(): - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 1, 1, 4, 3) - - y = repeat_elements(a, 4, 0) - ms_out = repeat_elements_grad(y, 4, 0) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 1) - ms_out = repeat_elements_grad(y, 4, 1) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 2) - ms_out = repeat_elements_grad(y, 4, 2) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 3) - ms_out = repeat_elements_grad(y, 4, 3) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 4) - ms_out = repeat_elements_grad(y, 4, 4) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 5) - ms_out = repeat_elements_grad(y, 4, 5) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 6) - ms_out = repeat_elements_grad(y, 4, 6) - np.testing.assert_array_equal(a*4, ms_out) - - y = repeat_elements(a, 4, 7) - ms_out = repeat_elements_grad(y, 4, 7) - np.testing.assert_array_equal(a*4, ms_out) diff --git a/tests/st/ops/gpu/test_repeat_elements_op.py b/tests/st/ops/gpu/test_repeat_elements_op.py index 74941a604d3..e373f5c4ead 100644 --- a/tests/st/ops/gpu/test_repeat_elements_op.py +++ b/tests/st/ops/gpu/test_repeat_elements_op.py @@ -17,17 +17,18 @@ import numpy as np import pytest from mindspore import Tensor -from mindspore.ops import operations as P +from mindspore.ops import composite as C import mindspore.nn as nn import mindspore.context as context class RepeatElementsNet(nn.Cell): def __init__(self, rep, axis): super(RepeatElementsNet, self).__init__() - self.repeat_elements = P.RepeatElements(rep, axis) + self.rep = rep + self.axis = axis def construct(self, x): - return self.repeat_elements(x) + return C.repeat_elements(x, self.rep, self.axis) def repeat_elements(x, rep, axis): diff --git a/tests/ut/python/parallel/test_repeat_elements.py b/tests/ut/python/parallel/test_repeat_elements.py deleted file mode 100644 index aff2fb3f56c..00000000000 --- a/tests/ut/python/parallel/test_repeat_elements.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - -import mindspore as ms -from mindspore import context, Tensor, Parameter -from mindspore.common.api import _executor -from mindspore.nn import Cell, TrainOneStepCell, Momentum -from mindspore.ops import operations as P - - -class Net(Cell): - def __init__(self, mul_weight, strategy1=None, strategy2=None): - super().__init__() - self.mul = P.Mul().shard(strategy1) - self.repeat = P.RepeatElements(rep=2, axis=1).shard(strategy2) - self.mul_weight = Parameter(mul_weight, "w1") - - def construct(self, x, b): - out = self.mul(x, self.mul_weight) - out = self.repeat(out) - return out - - -_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) -_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) -_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32) - - -def compile_net(net): - optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) - train_net = TrainOneStepCell(net, optimizer) - train_net.set_auto_parallel() - train_net.set_train() - _executor.compile(train_net, _x, _b) - context.reset_auto_parallel_context() - - -def test_repeat_elements_data_parallel(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((16, 1, 1), (16, 1, 1)) - strategy2 = ((16, 1, 1),) - net = Net(_w1, strategy1, strategy2) - compile_net(net) - - -def test_repeat_elements_model_parallel(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((1, 1, 16), (1, 1, 16)) - strategy2 = ((1, 1, 16),) - net = Net(_w1, strategy1, strategy2) - compile_net(net) - - -def test_repeat_elements_hybrid_parallel(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((2, 2, 4), (2, 2, 4)) - strategy2 = ((2, 2, 4),) - net = Net(_w1, strategy1, strategy2) - compile_net(net) - - -def test_repeat_elements_auto_parallel(): - context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0) - net = Net(_w1) - compile_net(net) - - -def test_repeat_elements_repeat_calc(): - context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0) - strategy1 = ((2, 2, 4), (2, 2, 4)) - strategy2 = ((1, 2, 2),) - net = Net(_w1, strategy1, strategy2) - compile_net(net)
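Reviewer note (illustration only, not part of the patch): the new composite repeat_elements builds the repeat out of ExpandDims, Tile and Reshape instead of a dedicated CUDA kernel. The short NumPy sketch below mirrors that expand -> tile -> reshape identity and checks it against np.repeat; the function name repeat_elements_sketch and the use of NumPy here are illustrative assumptions, not MindSpore API.

import numpy as np

def repeat_elements_sketch(x, rep, axis=0):
    # Insert a length-1 axis right after `axis`, tile it `rep` times,
    # then fold it back so the size of `axis` is multiplied by `rep`.
    x_expand = np.expand_dims(x, axis + 1)
    rep_dims = [1] * (x.ndim + 1)
    rep_dims[axis + 1] = rep
    x_tiled = np.tile(x_expand, rep_dims)
    out_shape = list(x.shape)
    out_shape[axis] *= rep
    return x_tiled.reshape(out_shape)

x = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.int32)
assert np.array_equal(repeat_elements_sketch(x, 2, 0), np.repeat(x, 2, axis=0))
assert np.array_equal(repeat_elements_sketch(x, 3, 1), np.repeat(x, 3, axis=1))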