diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu new file mode 100644 index 00000000000..50300b1d6cd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cu @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size, + const T *dy, const T *x, const T *w, T *dx, T *dw) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + size_t index = 0; + if (weight_size != 1) { + index = (pos / per_channel_size) % weight_size; + } + T threshold = static_cast(0); + dx[pos] = pos[x] <= threshold ? 
w[index] * dy[pos] : dy[pos]; + if (pos[x] < threshold) { + MsAtomicAdd(dw + index, x[pos] * dy[pos]); + } + } +} + +template +__global__ void InitDwData(size_t weight_size, T *dw) { + T init_value = static_cast(0); + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < weight_size; i += blockDim.x * gridDim.x) { + dw[i] = init_value; + } +} + + +template +void CalPReLUGrad(size_t size, size_t weight_size, size_t per_channel_size, + const T *dy, const T *x, const T *w, T *dx, T *dw, cudaStream_t cuda_stream) { + InitDwData<<>>(weight_size, dw); + CalPReLUGradKernel<<>>(size, weight_size, per_channel_size, + dy, x, w, dx, dw); + return; +} + +template void CalPReLUGrad(size_t, size_t, size_t, const float *, const float *, const float *, float *, float *, + cudaStream_t); +template void CalPReLUGrad(size_t, size_t, size_t, const half *, const half *, const half *, half *, half *, + cudaStream_t); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh new file mode 100644 index 00000000000..2c08332e61f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalPReLUGrad(size_t input_size, size_t weight_size, size_t per_channel_size, + const T *dy, const T *x, const T *w, T *dx, T *dw, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cu new file mode 100644 index 00000000000..66c975edf25 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cu @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh" + +template +__global__ void CalPReLUKernel(size_t size, size_t weight_size, size_t per_channel_size, + const T *input_addr, const T *weight_addr, T *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + size_t index = 0; + if (weight_size != 1) { + index = (pos / per_channel_size) % weight_size; + } + T threshold = static_cast(0); + output_addr[pos] = input_addr[pos] < threshold ? 
weight_addr[index] * input_addr[pos] : input_addr[pos]; + } +} + +template +void CalPReLU(size_t size, size_t weight_size, size_t per_channel_size, + const T *input_addr, const T *weight_addr, T *output_addr, cudaStream_t cuda_stream) { + CalPReLUKernel<<>>(size, weight_size, per_channel_size, + input_addr, weight_addr, output_addr); + return; +} + +template void CalPReLU(size_t, size_t, size_t, const float *, const float *, float *, cudaStream_t); +template void CalPReLU(size_t, size_t, size_t, const half *, const half *, half *, cudaStream_t); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh new file mode 100644 index 00000000000..31bb38a8a21 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalPReLU(size_t input_size, size_t weight_size, size_t per_channel_size, + const T *input_addr, const T *weight_addr, T *output_addr, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu index 4c9af52e5da..9f9dab7ae3d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu @@ -30,26 +30,6 @@ void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) { return; } -template -__global__ void PReluChannelSharedGradKernel(size_t size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr) { - T zero = static_cast(0); - T w = w_addr[0]; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - T dy = dy_addr[pos]; - T x = x_addr[pos]; - dx_addr[pos] = x > zero ? dy : w * dy; - dwc_addr[pos] = x > zero ? 
zero : x * dy; - } -} - -template -void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr, - cudaStream_t cuda_stream) { - PReluChannelSharedGradKernel<<>>(input_size, dy_addr, x_addr, - w_addr, dx_addr, dwc_addr); - return; -} - template void CalReLUGrad(int size, double *dy, double *y, double *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream); @@ -58,7 +38,3 @@ template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaSt template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream); template void CalReLUGrad(int size, uint8_t *dy, uint8_t *y, uint8_t *dx, cudaStream_t cuda_stream); -template void PReluChannelSharedGrad(size_t input_size, float *dy_addr, float *x_addr, float *w_addr, float *dx_addr, - float *dwc_addr, cudaStream_t cuda_stream); -template void PReluChannelSharedGrad(size_t input_size, half *dy_addr, half *x_addr, half *w_addr, half *dx_addr, - half *dwc_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh index c55ce4d823b..91840aa67fe 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh @@ -20,8 +20,4 @@ #include "runtime/device/gpu/cuda_common.h" template void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream); - -template -void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr, - cudaStream_t cuda_stream); #endif // 
MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu index 14f4f359083..04dbfc50599 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu @@ -96,18 +96,3 @@ template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *ma template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx, cudaStream_t cuda_stream); -template -__global__ void CalPReLUKernel(int size, T *input_addr, T *weight_addr, T *output_addr) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output_addr[pos] = input_addr[pos] > static_cast(0) ? input_addr[pos] : *weight_addr * input_addr[pos]; - } -} - -template -void CalPReLU(int size, T *input_addr, T *weight_addr, T *output_addr, cudaStream_t cuda_stream) { - CalPReLUKernel<<>>(size, input_addr, weight_addr, output_addr); - return; -} - -template void CalPReLU(int size, float *input_addr, float *weight_addr, float *output_addr, cudaStream_t cuda_stream); -template void CalPReLU(int size, half *input_addr, half *weight_addr, half *output_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh index 0a4b4e9ebe8..7918395f6f5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh @@ -25,7 +25,4 @@ template void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream); template void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream); - -template -void CalPReLU(int input_size, T *input_addr, T *weight_addr, T 
*output_addr, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h index 1e0a1eef007..15db15c1247 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_gpu_kernel.h @@ -19,93 +19,97 @@ #include #include -#include +#include + #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh" namespace mindspore { namespace kernel { template class PReLUGpuKernel : public GpuKernel { public: - PReLUGpuKernel() { ResetResource(); } - ~PReLUGpuKernel() override {} + PReLUGpuKernel() = default; + ~PReLUGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input = GetDeviceAddress(inputs, 0); - T *weight = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); + auto *input = GetDeviceAddress(inputs, 0); + auto *weight = GetDeviceAddress(inputs, 1); + auto *output = GetDeviceAddress(outputs, 0); - const int size = input_size_ / sizeof(T); - CalPReLU(size, input, weight, output, reinterpret_cast(stream_ptr)); + CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output, + reinterpret_cast(stream_ptr)); return true; } + bool Init(const CNodePtr &kernel_node) override { + ResetResource(); size_t input_num = 
AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 2."; + MS_LOG(ERROR) << "PReLU needs 2 inputs, but got " << input_num; return false; } - auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(ERROR) << "PReLUGpuFwdKernel input is null."; + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "ReLU should have 1 output, but got " << input_num; return false; } - size_t size = 1; - for (size_t i = 0; i < input_shape.size(); i++) { - size *= input_shape[i]; - } - input_size_ = size * sizeof(T); - auto weight_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(weight_shape); - if (is_null_input_) { - MS_LOG(ERROR) << "PReLUGpuFwdKernel weight is null."; - return false; + auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>()); + size_t input_rank = input_shape.size(); + size_t channel_num; + if (input_rank == 0) { + channel_num = 1; + per_channel_length_ = 1; + } else if (input_rank == 1) { + channel_num = 1; + per_channel_length_ = input_shape[0]; + } else { + channel_num = input_shape[1]; + per_channel_length_ = std::accumulate(input_shape.begin() + 2, input_shape.end(), size_t(1), std::multiplies<>()); } - size = 1; - for (size_t i = 0; i < weight_shape.size(); i++) { - size *= weight_shape[i]; - } - weight_size_ = size * sizeof(T); + auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + if (weight_shape.size() != 1 && weight_shape[0] != 1 && weight_shape[0] != channel_num) { + MS_LOG(EXCEPTION) << "PReLU requires the rank of weight should be 1, and the elements number should be " + "1 or channels number " + << channel_num << ", but got 
weight shape " << weight_shape; + } + weight_length_ = weight_shape[0]; InitSizeLists(); return true; } void ResetResource() noexcept override { - is_null_input_ = false; + input_length_ = 0; + weight_length_ = 0; + per_channel_length_ = 0; input_size_list_.clear(); output_size_list_.clear(); workspace_size_list_.clear(); - input_size_ = 0; - workspace_size_ = 0; } protected: void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - workspace_size_list_.push_back(workspace_size_); + size_t data_size = sizeof(T); + input_size_list_.push_back(input_length_ * data_size); + input_size_list_.push_back(weight_length_ * data_size); + output_size_list_.push_back(input_length_ * data_size); } private: - bool is_null_input_; + size_t input_length_{0}; + size_t weight_length_{0}; + size_t per_channel_length_{0}; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; - size_t input_size_; - size_t weight_size_; - size_t workspace_size_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.cc similarity index 89% rename from mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc rename to mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.cc index b7b1bb0cf1a..24fa3714849 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h" +#include "backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.h" namespace mindspore { namespace kernel { @@ -25,7 +25,7 @@ MS_REG_GPU_KERNEL_ONE(PReLUGrad, .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), - PReLUGpuGradKernel, float) + PReLUGradGpuKernel, float) MS_REG_GPU_KERNEL_ONE(PReLUGrad, KernelAttr() .AddInputAttr(kNumberTypeFloat16) @@ -33,6 +33,6 @@ MS_REG_GPU_KERNEL_ONE(PReLUGrad, .AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), - PReLUGpuGradKernel, half) + PReLUGradGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.h new file mode 100644 index 00000000000..637409672a3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_GPU_KERNEL_H_ + +#include +#include +#include + +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class PReLUGradGpuKernel : public GpuKernel { + public: + PReLUGradGpuKernel() = default; + ~PReLUGradGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + auto *dy = GetDeviceAddress(inputs, 0); + auto *x = GetDeviceAddress(inputs, 1); + auto *w = GetDeviceAddress(inputs, 2); + auto *dx = GetDeviceAddress(outputs, 0); + auto *dw = GetDeviceAddress(outputs, 1); + + CalPReLUGrad(input_length_, weight_length_, per_channel_length_, dy, x, w, dx, dw, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + ResetResource(); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "ReLUGrad needs 3 inputs, but got " << input_num; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(ERROR) << "ReLUGrad should have 2 outputs, but got " << input_num; + return false; + } + + auto x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), size_t(1), std::multiplies<>()); + size_t x_rank = x_shape.size(); + size_t channel_num; + if (x_rank == 0) { + channel_num 
= 1; + per_channel_length_ = 1; + } else if (x_rank == 1) { + channel_num = 1; + per_channel_length_ = x_shape[0]; + } else { + channel_num = x_shape[1]; + per_channel_length_ = std::accumulate(x_shape.begin() + 2, x_shape.end(), size_t(1), std::multiplies<>()); + } + + auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + if (weight_shape.size() != 1 && weight_shape[0] != 1 && weight_shape[0] != channel_num) { + MS_LOG(EXCEPTION) << "PReLUGrad requires the rank of weight should be 1, and the elements number should be " + "1 or channels number " + << channel_num << ", but got weight shape " << weight_shape; + } + weight_length_ = weight_shape[0]; + InitSizeLists(); + return true; + } + + void ResetResource() noexcept override { + input_length_ = 0; + weight_length_ = 0; + per_channel_length_ = 0; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + + protected: + void InitSizeLists() override { + size_t data_size = sizeof(T); + input_size_list_.push_back(input_length_ * data_size); + input_size_list_.push_back(input_length_ * data_size); + input_size_list_.push_back(weight_length_ * data_size); + output_size_list_.push_back(input_length_ * data_size); + output_size_list_.push_back(weight_length_ * data_size); + } + + private: + size_t input_length_{0}; + size_t weight_length_{0}; + size_t per_channel_length_{0}; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h deleted file mode 100644 index 21f36285d76..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * 
Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_ - -#include -#include -#include "backend/kernel_compiler/gpu/gpu_kernel.h" -#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/kernel_constants.h" -#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class PReLUGpuGradKernel : public GpuKernel { - public: - PReLUGpuGradKernel() - : data_format_(kOpFormat_NCDHW), - input_size_(0), - weight_size_(0), - reduce_workspace_size_(0), - spatial_count_(1), - is_null_input_(false), - channel_shared_(false), - channel_last_(false) {} - ~PReLUGpuGradKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *dy_addr = GetDeviceAddress(inputs, 0); - T *x_addr = GetDeviceAddress(inputs, 1); - T *w_addr = GetDeviceAddress(inputs, 2); - T *dx_addr = GetDeviceAddress(outputs, 0); - T *dw_addr = GetDeviceAddress(outputs, 1); - T 
*dw_collector_addr = GetDeviceAddress(workspace, 0); - T *reduce_workspace_addr = GetDeviceAddress(workspace, 1); - - PReluChannelSharedGrad(input_size_ / sizeof(T), dy_addr, x_addr, w_addr, dx_addr, dw_collector_addr, - reinterpret_cast(stream_ptr)); - - if (data_type_ == CUDNN_DATA_DOUBLE) { - T alpha = static_cast(1.0f); - T beta = static_cast(0.0f); - CHECK_CUDNN_RET_WITH_EXCEPT( - kernel_node_, - cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr, - reduce_workspace_size_, &alpha, grad_weight_collector_descriptor_, dw_collector_addr, &beta, - grad_weight_descriptor_, dw_addr), - "cudnnReduceTensor failed."); - } else { - const float alphaf = static_cast(1.0f); - const float betaf = static_cast(0.0f); - CHECK_CUDNN_RET_WITH_EXCEPT( - kernel_node_, - cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr, - reduce_workspace_size_, &alphaf, grad_weight_collector_descriptor_, dw_collector_addr, &betaf, - grad_weight_descriptor_, dw_addr), - "cudnnReduceTensor failed."); - } - return true; - } - - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_), - "cudnnCreateReduceTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_collector_descriptor_), - "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_descriptor_), - "cudnnCreateTensorDescriptor failed."); - } - - void DestroyResource() noexcept override { - CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_), - "cudnnDestroyReduceTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_collector_descriptor_), - "cudnnDestroyTensorDescriptor 
failed."); - CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - } - - bool Init(const CNodePtr &kernel_node) override { - kernel_node_ = kernel_node; - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null."; - } - for (size_t i = 0; i < input_shape.size(); ++i) { - input_size_ *= input_shape[i]; - } - weight_size_ = sizeof(T); - auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - is_null_input_ = CHECK_NULL_INPUT(weight_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null."; - } - for (auto dim : weight_shape) { - weight_size_ *= dim; - } - channel_shared_ = (weight_shape[0] == 1); - if (!channel_shared_) { - MS_LOG(WARNING) - << "PReLUGpuBwdKernel shares weight for all channels, but the given weight tensor has more than one element."; - } - - spatial_count_ = 1; - if (channel_last_) { - for (size_t i = 1; i < input_shape.size() - 1; ++i) { - spatial_count_ *= input_shape[i]; - } - } else { - for (size_t i = 2; i < input_shape.size(); ++i) { - spatial_count_ *= input_shape[i]; - } - } - - data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - int input_dim_length = input_shape.size(); - std::vector reduce_out_shape(input_dim_length, 1); - if (channel_last_) { - reduce_out_shape[input_dim_length - 1] = weight_shape[0]; - } else { - reduce_out_shape[1] = weight_shape[0]; - } - InitResource(); - CudnnSetTensorNdDescriptor(reduce_out_shape, grad_weight_descriptor_, data_type_, kernel_node_); - CudnnSetTensorNdDescriptor(input_shape, grad_weight_collector_descriptor_, data_type_, kernel_node_); - cudnnDataType_t comp_type = (data_type_ == CUDNN_DATA_DOUBLE) ? 
CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT; - CHECK_CUDNN_RET_WITH_EXCEPT( - kernel_node_, - cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, CUDNN_REDUCE_TENSOR_ADD, comp_type, - CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), - "cudnnSetReduceTensorDescriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - input_size_list_.push_back(weight_size_); - output_size_list_.push_back(input_size_); - output_size_list_.push_back(weight_size_); - CHECK_CUDNN_RET_WITH_EXCEPT( - kernel_node_, - cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, grad_weight_collector_descriptor_, - grad_weight_descriptor_, &reduce_workspace_size_), - "cudnnGetReductionWorkspaceSize failed."); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(reduce_workspace_size_); - } - - private: - cudnnHandle_t cudnn_handle_; - cudnnDataType_t data_type_; - cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_; - cudnnTensorDescriptor_t grad_weight_collector_descriptor_; - cudnnTensorDescriptor_t grad_weight_descriptor_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - std::string data_format_ = kOpFormat_NCDHW; - size_t input_size_; - size_t weight_size_; - size_t reduce_workspace_size_; - size_t spatial_count_; - bool is_null_input_ = false; - bool channel_shared_ = false; - bool channel_last_ = false; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc index c835ed14ecd..c70a53f819e 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.cc @@ -40,7 +40,7 @@ static bool CheckStridedSlice(const CNodePtr &cnode) { } } // check reduction on the last dimension - if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) { + if (GetCNodeFuncName(cnode) == kStridedSliceOpName && AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) { auto shrink_axis_mask = static_cast(AnfAlgo::GetNodeAttr(cnode, kAttrShrinkAxisMask)); AnfNodePtr input = cnode->input(1); int input_dims = 0; diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index c9aa72988bd..8ad278500d2 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -14,15 +14,15 @@ # ============================================================================ """activation""" import numpy as np -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.common.parameter import Parameter -from mindspore.common.initializer import initializer -from mindspore.common.tensor import Tensor -from mindspore._extends import cell_attr_register -from mindspore._checkparam import Validator as validator -from ..cell import Cell +from mindspore._checkparam import Validator as validator +from mindspore._extends import cell_attr_register +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from ..cell import Cell __all__ = ['Softmax', 'LogSoftmax', @@ -548,22 +548,24 @@ class PReLU(Cell): Activation_function#/media/File:Activation_prelu.svg>`_. Args: - channel (int): The dimension of input. Default: 1. - w (Union[float, list, Tensor]): The initial value of w. Default: 0.25. + channel (int): The elements number of parameter. 
+ It could be an int, and the value is 1 or the channels number of input tensor `x`. Default: 1. + w (Union[float, list, Tensor]): The initial value of parameter. It could be a float, a float list or + a tensor that has the same dtype as the input tensor `x`. Default: 0.25. Inputs: - **x** (Tensor) - The input of PReLU with data type of float16 or float32. The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions. Outputs: - Tensor, with the same type and shape as the `x`. + Tensor, with the same dtype and shape as the `x`. Raises: TypeError: If `channel` is not an int. - TypeError: If `w` is not one of float, list, Tensor. + TypeError: If `w` is not one of a float, a float list, a float Tensor. TypeError: If dtype of `x` is neither float16 nor float32. + ValueError: If the `x` is a 0-D or 1-D Tensor on Ascend. ValueError: If `channel` is less than 1. - ValueError: If length of shape of `x` is equal to 1. Supported Platforms: ``Ascend`` ``GPU`` @@ -582,24 +584,34 @@ class PReLU(Cell): """Initialize PReLU.""" super(PReLU, self).__init__() validator.check_positive_int(channel, 'channel', self.cls_name) - if isinstance(w, (np.float32, float)): + if isinstance(w, (float, np.float32)): tmp = np.empty((channel,), dtype=np.float32) tmp.fill(w) - w = Tensor(tmp) + w = Tensor(tmp, dtype=mstype.float32) elif isinstance(w, list): - w = Tensor(w) - - if not isinstance(w, Tensor): - raise TypeError("w only support np.float32, float, list or Tensor type.") - - self.w = Parameter(initializer(w, [channel]), name='a') + if len(w) != channel: + raise ValueError(f"When the 'w' is a list, the length should be equal to the channel, " + f"but got the length {len(w)}, the channel {channel}") + for i in w: + if not isinstance(i, (float, np.float32)): + raise ValueError(f"When the 'w' is a list, the all elements should be float, but got {w}") + w = Tensor(w, dtype=mstype.float32) + elif isinstance(w, Tensor): + if w.dtype not in (mstype.float16, mstype.float32): +
raise ValueError(f"When the 'w' is a tensor, the dtype should be float16 or float32, but got {w.dtype}") + if len(w.shape) != 1 or w.shape[0] != channel: + raise ValueError(f"When the 'w' is a tensor, the rank should be 1, and the elements number " + f"should be equal to the channel, but got w shape {w}, the channel {channel}") + else: + raise TypeError(f"The 'w' only supported float list and tensor, but got {type(w)}") + self.w = Parameter(w, name='a') self.prelu = P.PReLU() self.relu = P.ReLU() self.assign = P.Assign() def construct(self, x): u = self.relu(self.w) - v = self.prelu(x, u) + v = self.prelu(x, F.cast(u, x.dtype)) if self.training: self.assign(self.w, u) return v diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 6a7ae2ce6a6..02aaa469f77 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -1544,8 +1544,6 @@ class PReLUGrad(PrimitiveWithInfer): pass def infer_shape(self, y_backprop_shape, a_shape, w_shape): - if len(a_shape) == 1: - raise ValueError(f'For \'{self.name}\' input_x rank 1 is not supported.') return y_backprop_shape, w_shape def infer_dtype(self, y_backprop_dtype, a_dtype, w_dtype): diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 05d4577d8ec..c105487f46c 100755 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -2151,6 +2151,7 @@ class Conv2DTranspose(Conv2DBackpropInput): >>> print(output.shape) (10, 32, 32, 32) """ + @prim_attr_register def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0, pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"): @@ -3638,16 +3639,18 @@ class PReLU(PrimitiveWithInfer): .. math:: prelu(x_i)= \max(0, x_i) + \min(0, w * x_i), - where :math:`x_i` is an element of an channel of the input. + where :math:`x_i` is an element of an channel of the input, `w` is the weight of the channel. 
Note: - 1-dimensional input_x is not supported. + 0-D or 1-D input_x is not supported on Ascend. Inputs: - - **input_x** (Tensor) - Float tensor, representing the output of the preview layer. - With data type of float16 or float32. - - **weight** (Tensor) - Float Tensor, w > 0, there are only two shapes are legitimate, - 1 or the number of channels of the input. With data type of float16 or float32. + - **input_x** (Tensor) - The first input tensor. The data type is float16 or float32. + Represents the output of the previous layer. + - **weight** (Tensor) - The second input tensor. The data type is float16 or float32. + Only two shapes are legitimate: 1 or the number of channels of the `input_x`. + Channel dim is the 2nd dim of input. When input is 0-D or 1-D tensor, the number of channels is 1. + Outputs: Tensor, with the same type as `input_x`. @@ -3656,9 +3659,9 @@ class PReLU(PrimitiveWithInfer): Raises: TypeError: If dtype of `input_x` or `weight` is neither float16 nor float32. - TypeError: If `input_x` or `weight` is not a Tensor. - ValueError: If length of shape of `input_x` is equal to 1. - ValueError: If length of shape of `weight` is not equal to 1. + TypeError: If the `input_x` or the `weight` is not a Tensor. + ValueError: If the `input_x` is a 0-D or 1-D Tensor on Ascend. + ValueError: If the `weight` is not a 1-D Tensor. Supported Platforms: ``Ascend`` ``GPU`` @@ -3677,12 +3680,17 @@ class PReLU(PrimitiveWithInfer): ... result = self.prelu(input_x, weight) ... return result ...
- >>> input_x = Tensor(np.random.randint(-3, 3, (2, 3, 2)), mindspore.float32) + >>> input_x = Tensor(np.arange(-6, 6).reshape((2, 3, 2)), mindspore.float32) >>> weight = Tensor(np.array([0.1, 0.6, -0.3]), mindspore.float32) >>> net = Net() >>> output = net(input_x, weight) - >>> print(output.shape) - (2, 3, 2) + >>> print(output) + [[[-0.60 -0.50] + [-2.40 -1.80] + [ 0.60 0.30]] + [[ 0.00 1.00] + [ 2.00 3.00] + [ 4.0 5.00]]] """ @prim_attr_register @@ -3691,25 +3699,29 @@ class PReLU(PrimitiveWithInfer): def infer_shape(self, input_x_shape, weight_shape): input_x_dim = len(input_x_shape) + if input_x_dim in (0, 1): + if context.get_context("device_target") == "Ascend": + raise ValueError(f"For '{self.name}', the 0-D or 1-D 'input_x' is not supported on Ascend.") + channel_num = 1 + else: + channel_num = input_x_shape[1] + weight_dim = len(weight_shape) - - if input_x_dim == 1: - raise ValueError(f'For \'{self.name}\' input_x rank 1 is not supported.') - if weight_dim != 1: - raise ValueError(f'For \'{self.name}\' weight_dim must be 1, while weight_dim is {weight_dim}.') - - if weight_shape[0] != input_x_shape[1] and weight_shape[0] != 1: - raise ValueError(f'For \'{self.name}\' channel of input_x and weight must be matched,' - f' while channel of input_x is {input_x_shape[1]},' - f' weight_shape[0] is {weight_shape[0]}.') - + raise ValueError(f"For '{self.name}', the weight dimension should be 1, while got {weight_dim}.") + if weight_shape[0] != 1 and weight_shape[0] != channel_num: + raise ValueError(f"For '{self.name}', the weight shape should be (1,) or " + f"matched with input channel ({channel_num},), but got {weight_shape}") return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): valid_dtypes = (mstype.float16, mstype.float32) - validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name) - validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name) + args = {"input_x": input_x_dtype, 
"weight": weight_dtype} + if context.get_context("device_target") == "GPU": + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) + else: + validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name) + validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name) return input_x_dtype @@ -7876,6 +7888,7 @@ class AvgPool3D(Primitive): [[[[[ 5. 6.]]] [[[17. 18.]]]]] """ + @prim_attr_register def __init__(self, kernel_size=1, strides=1, pad_mode="valid", pad=0, ceil_mode=False, count_include_pad=True, divisor_override=0, data_format="NCDHW"): @@ -8399,7 +8412,6 @@ class CTCLossV2Grad(Primitive): self.add_prim_attr("zero_infinity", zero_infinity) - class Conv3DTranspose(PrimitiveWithInfer): r""" Computes a 3D transposed convolution, which is also known as a deconvolution diff --git a/tests/st/ops/gpu/test_prelu_grad_op.py b/tests/st/ops/gpu/test_prelu_grad_op.py deleted file mode 100644 index 1442d730055..00000000000 --- a/tests/st/ops/gpu/test_prelu_grad_op.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -import numpy as np -import pytest - -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -from mindspore.ops.operations import _grad_ops as G - - -class NetPReLUGrad(nn.Cell): - def __init__(self): - super(NetPReLUGrad, self).__init__() - self.prelu_grad = G.PReLUGrad() - - def construct(self, dout, x, w): - return self.prelu_grad(dout, x, w) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_prelu_grad_fp32_channel_shared(): - dout = Tensor(np.ones(shape=[2, 2, 2, 3]).astype(np.float32)) - x = Tensor(np.arange(-5, 19).reshape(2, 2, 2, 3).astype(np.float32)) - w = Tensor(np.array([-0.5]).astype(np.float32)) - expect_dx = np.array([[[[-0.5000, -0.5000, -0.5000], - [-0.5000, -0.5000, -0.5000]], - [[1.0000, 1.0000, 1.0000], - [1.0000, 1.0000, 1.0000]]], - [[[1.0000, 1.0000, 1.0000], - [1.0000, 1.0000, 1.0000]], - [[1.0000, 1.0000, 1.0000], - [1.0000, 1.0000, 1.0000]]]]).astype(np.float32) - expect_dw = np.array([-15.]).astype(np.float32) - - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - prelu_grad = NetPReLUGrad() - dx, dw = prelu_grad(dout, x, w) - assert (dx.asnumpy() == expect_dx).all() - assert (dw.asnumpy() == expect_dw).all() - - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - prelu_grad = NetPReLUGrad() - dx, dw = prelu_grad(dout, x, w) - assert (dx.asnumpy() == expect_dx).all() - assert (dw.asnumpy() == expect_dw).all() diff --git a/tests/st/ops/gpu/test_prelu_op.py b/tests/st/ops/gpu/test_prelu_op.py index bf6dabad9de..930f9e08b36 100644 --- a/tests/st/ops/gpu/test_prelu_op.py +++ b/tests/st/ops/gpu/test_prelu_op.py @@ -20,55 +20,215 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.common import dtype 
as mstype -class NetPReLU(nn.Cell): + +class PReLUOpNet(nn.Cell): def __init__(self): - super(NetPReLU, self).__init__() + super(PReLUOpNet, self).__init__() self.prelu = P.PReLU() def construct(self, x, weight): return self.prelu(x, weight) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_prelu_float16(): - weight = Tensor(np.array([0.25]).astype(np.float16)) - x = Tensor(np.array([[[[-1, 1, 10], - [1, -1, 1], - [10, 1, -1]]]]).astype(np.float16)) - expect = np.array([[[[-0.25, 1, 10,], - [1, -0.25, 1,], - [10, 1, -0.25]]]]).astype(np.float16) - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - prelu = NetPReLU() - output = prelu(x, weight) - assert (output.asnumpy() == expect).all() +class PReLUOpGradNet(nn.Cell): + def __init__(self, net): + super(PReLUOpGradNet, self).__init__() + self.forward = net + self.grad = C.GradOperation(get_all=True, sens_param=False) + + def construct(self, x, weight): + return self.grad(self.forward)(x, weight) + + +def judge_result_correct(result, expect): + result = result.asnumpy() + expect = expect.asnumpy() + assert result.dtype == expect.dtype + assert result.shape == expect.shape + assert np.allclose(result, expect, rtol=1.e-2) + + +def test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode): + context.set_context(mode=mode) + prelu_forward = PReLUOpNet() + prelu_backward = PReLUOpGradNet(prelu_forward) + forward_output = prelu_forward(x, weight) + judge_result_correct(forward_output, expect_forward) + + backward_output = prelu_backward(x, weight) + assert len(backward_output) == 2 + judge_result_correct(backward_output[0], expect_dx) + judge_result_correct(backward_output[1], expect_dw) - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - prelu = NetPReLU() - output = prelu(x, weight) - assert (output.asnumpy() == expect).all() @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def 
test_prelu_float32(): - weight = Tensor(np.array([0.25]).astype(np.float32)) - x = Tensor(np.array([[[[-1, 1, 10], - [1, -1, 1], - [10, 1, -1]]]]).astype(np.float32)) - expect = np.array([[[[-0.25, 1, 10,], - [1, -0.25, 1,], - [10, 1, -0.25]]]]).astype(np.float32) +def test_prelu_single_weight(): + context.set_context(device_target="GPU") + dtypes = [mstype.float16, mstype.float32] + modes = [context.GRAPH_MODE, context.GRAPH_MODE] - context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - prelu = NetPReLU() - output = prelu(x, weight) - assert (output.asnumpy() == expect).all() + x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.7 + weight = np.array([0.6]) + expect_forward = np.where(x >= 0, x, weight * x) + expect_dx = np.where(x > 0, 1, weight) + expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,)) - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - prelu = NetPReLU() - output = prelu(x, weight) - assert (output.asnumpy() == expect).all() + for dtype in dtypes: + for mode in modes: + x = Tensor(x, dtype) + weight = Tensor(weight, dtype) + expect_forward = Tensor(expect_forward, dtype) + expect_dx = Tensor(expect_dx, dtype) + expect_dw = Tensor(expect_dw, dtype) + test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_prelu_multiple_weight(): + context.set_context(device_target="GPU") + dtypes = [mstype.float16, mstype.float32] + modes = [context.GRAPH_MODE, context.GRAPH_MODE] + + x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.6 + weight = np.array([0.2, 0.3, 0.4]) + expect_forward = np.array([[[[-1.20, -1.08, -0.96], + [-0.84, -0.72, -0.60]], + [[-0.72, -0.54, -0.36], + [-0.18, 0.00, 0.60]], + [[1.20, 1.80, 2.40], + [3.00, 3.60, 4.20]]], + [[[4.80, 5.40, 6.00], + [6.60, 7.20, 7.80]], + [[8.40, 9.00, 9.60], + [10.20, 10.80, 11.40]], + [[12.00, 12.60, 13.20], + [13.80, 14.40, 15.00]]]]) + expect_dx = 
np.array([[[[0.2, 0.2, 0.2], + [0.2, 0.2, 0.2]], + [[0.3, 0.3, 0.3], + [0.3, 0.3, 1.0]], + [[1.0, 1.0, 1.0], + [1.0, 1.0, 1.0]]], + [[[1.0, 1.0, 1.0], + [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], + [1.0, 1.0, 1.0]], + [[1.0, 1.0, 1.0], + [1.0, 1.0, 1.0]]]]) + expect_dw = np.array([-27.0, -6.0, 0.0]) + + for dtype in dtypes: + for mode in modes: + x = Tensor(x, dtype) + weight = Tensor(weight, dtype) + expect_forward = Tensor(expect_forward, dtype) + expect_dx = Tensor(expect_dx, dtype) + expect_dw = Tensor(expect_dw, dtype) + test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_prelu_single_weight_0_D(): + context.set_context(device_target="GPU") + dtypes = [mstype.float16, mstype.float32] + modes = [context.GRAPH_MODE, context.GRAPH_MODE] + + x = np.array(-0.8) + weight = np.array([0.6]) + expect_forward = np.array(-0.48) + expect_dx = np.array(0.6) + expect_dw = np.array([-0.8]) + + for dtype in dtypes: + for mode in modes: + x = Tensor(x, dtype) + weight = Tensor(weight, dtype) + expect_forward = Tensor(expect_forward, dtype) + expect_dx = Tensor(expect_dx, dtype) + expect_dw = Tensor(expect_dw, dtype) + test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_prelu_single_weight_1_D(): + context.set_context(device_target="GPU") + dtypes = [mstype.float16, mstype.float32] + modes = [context.GRAPH_MODE, context.GRAPH_MODE] + + x = np.arange(-10, 26).reshape((36,)) * 0.7 + weight = np.array([0.6]) + expect_forward = np.where(x >= 0, x, weight * x) + expect_dx = np.where(x > 0, 1, weight) + expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,)) + + for dtype in dtypes: + for mode in modes: + x = Tensor(x, dtype) + weight = Tensor(weight, dtype) + expect_forward = Tensor(expect_forward, dtype) + expect_dx = Tensor(expect_dx, dtype) + expect_dw = 
Tensor(expect_dw, dtype) + test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_prelu_single_weight_2_D(): + context.set_context(device_target="GPU") + dtypes = [mstype.float16, mstype.float32] + modes = [context.GRAPH_MODE, context.GRAPH_MODE] + + x = np.arange(-10, 26).reshape((4, 9)) * 0.7 + weight = np.array([0.6]) + expect_forward = np.where(x >= 0, x, weight * x) + expect_dx = np.where(x > 0, 1, weight) + expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,)) + + for dtype in dtypes: + for mode in modes: + x = Tensor(x, dtype) + weight = Tensor(weight, dtype) + expect_forward = Tensor(expect_forward, dtype) + expect_dx = Tensor(expect_dx, dtype) + expect_dw = Tensor(expect_dw, dtype) + test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_prelu_multiple_weight_2_D(): + context.set_context(device_target="GPU") + dtypes = [mstype.float16, mstype.float32] + modes = [context.GRAPH_MODE, context.GRAPH_MODE] + + x = np.arange(-6, 6).reshape((3, 4)) * 0.6 + weight = np.array([0.2, 0.4, 0.7, 0.9]) + expect_forward = np.array([[-0.72, -1.20, -1.68, -1.62], + [-0.24, -0.24, 0.00, 0.60], + [1.20, 1.80, 2.40, 3.00]]) + expect_dx = np.array([[0.2, 0.4, 0.7, 0.9], + [0.2, 0.4, 0.7, 1.0], + [1.0, 1.0, 1.0, 1.0]]) + expect_dw = np.array([-4.8, -3.6, -2.4, -1.8]) + + for dtype in dtypes: + for mode in modes: + x = Tensor(x, dtype) + weight = Tensor(weight, dtype) + expect_forward = Tensor(expect_forward, dtype) + expect_dx = Tensor(expect_dx, dtype) + expect_dw = Tensor(expect_dw, dtype) + test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)