!19286 add_PReLUGrad_for_GPU

Merge pull request !19286 from zhangbuxue/add_PReLuGrad_for_GPU
This commit is contained in:
i-robot 2021-07-08 02:01:45 +00:00 committed by Gitee
commit 4a75354555
18 changed files with 589 additions and 436 deletions

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void CalPReLUGradKernel(size_t size, size_t weight_size, size_t per_channel_size,
const T *dy, const T *x, const T *w, T *dx, T *dw) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
size_t index = 0;
if (weight_size != 1) {
index = (pos / per_channel_size) % weight_size;
}
T threshold = static_cast<T>(0);
dx[pos] = x[pos] <= threshold ? w[index] * dy[pos] : dy[pos];
if (x[pos] < threshold) {
MsAtomicAdd(dw + index, x[pos] * dy[pos]);
}
}
}
template <typename T>
__global__ void InitDwData(size_t weight_size, T *dw) {
T init_value = static_cast<T>(0);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < weight_size; i += blockDim.x * gridDim.x) {
dw[i] = init_value;
}
}
template <typename T>
void CalPReLUGrad(size_t size, size_t weight_size, size_t per_channel_size,
const T *dy, const T *x, const T *w, T *dx, T *dw, cudaStream_t cuda_stream) {
InitDwData<<<GET_BLOCKS(weight_size), GET_THREADS, 0, cuda_stream>>>(weight_size, dw);
CalPReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, weight_size, per_channel_size,
dy, x, w, dx, dw);
return;
}
template void CalPReLUGrad(size_t, size_t, size_t, const float *, const float *, const float *, float *, float *,
cudaStream_t);
template void CalPReLUGrad(size_t, size_t, size_t, const half *, const half *, const half *, half *, half *,
cudaStream_t);

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalPReLUGrad(size_t input_size, size_t weight_size, size_t per_channel_size,
const T *dy, const T *x, const T *w, T *dx, T *dw, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_GRAD_H_

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh"
template <typename T>
__global__ void CalPReLUKernel(size_t size, size_t weight_size, size_t per_channel_size,
const T *input_addr, const T *weight_addr, T *output_addr) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
size_t index = 0;
if (weight_size != 1) {
index = (pos / per_channel_size) % weight_size;
}
T threshold = static_cast<T>(0);
output_addr[pos] = input_addr[pos] < threshold ? weight_addr[index] * input_addr[pos] : input_addr[pos];
}
}
template <typename T>
void CalPReLU(size_t size, size_t weight_size, size_t per_channel_size,
const T *input_addr, const T *weight_addr, T *output_addr, cudaStream_t cuda_stream) {
CalPReLUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, weight_size, per_channel_size,
input_addr, weight_addr, output_addr);
return;
}
template void CalPReLU(size_t, size_t, size_t, const float *, const float *, float *, cudaStream_t);
template void CalPReLU(size_t, size_t, size_t, const half *, const half *, half *, cudaStream_t);
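The forward kernel uses the same per-channel weight lookup; a minimal NumPy sketch under the same assumptions (illustrative only; `prelu_reference` is a hypothetical name):

import numpy as np

def prelu_reference(x, w):
    # Mirrors CalPReLUKernel: output[pos] = x[pos] < 0 ? w[index] * x[pos] : x[pos].
    if w.shape[0] == 1:
        w_full = w[0]  # channel-shared slope (required when x is 0-D or 1-D)
    else:
        # Broadcasting over dim 1 is equivalent to the kernel's flat-index
        # lookup: index = (pos / per_channel_size) % weight_size.
        w_full = w.reshape(1, -1, *([1] * (x.ndim - 2)))
    return np.where(x < 0, w_full * x, x)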

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalPReLU(size_t input_size, size_t weight_size, size_t per_channel_size,
const T *input_addr, const T *weight_addr, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_PRELU_H_

View File

@@ -30,26 +30,6 @@ void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
return;
}
template <typename T>
__global__ void PReluChannelSharedGradKernel(size_t size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr) {
T zero = static_cast<T>(0);
T w = w_addr[0];
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
T dy = dy_addr[pos];
T x = x_addr[pos];
dx_addr[pos] = x > zero ? dy : w * dy;
dwc_addr[pos] = x > zero ? zero : x * dy;
}
}
template <typename T>
void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr,
cudaStream_t cuda_stream) {
PReluChannelSharedGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input_size, dy_addr, x_addr,
w_addr, dx_addr, dwc_addr);
return;
}
template void CalReLUGrad(int size, double *dy, double *y, double *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
@@ -58,7 +38,3 @@ template void CalReLUGrad(int size, int16_t *dy, int16_t *y, int16_t *dx, cudaSt
template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
template void CalReLUGrad(int size, uint8_t *dy, uint8_t *y, uint8_t *dx, cudaStream_t cuda_stream);
template void PReluChannelSharedGrad(size_t input_size, float *dy_addr, float *x_addr, float *w_addr, float *dx_addr,
float *dwc_addr, cudaStream_t cuda_stream);
template void PReluChannelSharedGrad(size_t input_size, half *dy_addr, half *x_addr, half *w_addr, half *dx_addr,
half *dwc_addr, cudaStream_t cuda_stream);

View File

@@ -20,8 +20,4 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
template <typename T>
void PReluChannelSharedGrad(size_t input_size, T *dy_addr, T *x_addr, T *w_addr, T *dx_addr, T *dwc_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_

View File

@@ -96,18 +96,3 @@ template void ReluGradV2(const size_t num, const int64_t *dy, const uint32_t *ma
template void ReluGradV2(const size_t num, const uint8_t *dy, const uint32_t *mask, uint8_t *dx,
cudaStream_t cuda_stream);
template <typename T>
__global__ void CalPReLUKernel(int size, T *input_addr, T *weight_addr, T *output_addr) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output_addr[pos] = input_addr[pos] > static_cast<T>(0) ? input_addr[pos] : *weight_addr * input_addr[pos];
}
}
template <typename T>
void CalPReLU(int size, T *input_addr, T *weight_addr, T *output_addr, cudaStream_t cuda_stream) {
CalPReLUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, weight_addr, output_addr);
return;
}
template void CalPReLU(int size, float *input_addr, float *weight_addr, float *output_addr, cudaStream_t cuda_stream);
template void CalPReLU(int size, half *input_addr, half *weight_addr, half *output_addr, cudaStream_t cuda_stream);

View File

@@ -25,7 +25,4 @@ template <typename T>
void ReluV2(const size_t num, const T *x, T *y, uint32_t *mask, cudaStream_t cuda_stream);
template <typename T>
void ReluGradV2(const size_t num, const T *dy, const uint32_t *mask, T *dx, cudaStream_t cuda_stream);
template <typename T>
void CalPReLU(int input_size, T *input_addr, T *weight_addr, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_H_

View File

@@ -19,93 +19,97 @@
#include <vector>
#include <map>
#include <string>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/prelu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class PReLUGpuKernel : public GpuKernel {
public:
PReLUGpuKernel() { ResetResource(); }
~PReLUGpuKernel() override {}
PReLUGpuKernel() = default;
~PReLUGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *weight = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
auto *input = GetDeviceAddress<T>(inputs, 0);
auto *weight = GetDeviceAddress<T>(inputs, 1);
auto *output = GetDeviceAddress<T>(outputs, 0);
const int size = input_size_ / sizeof(T);
CalPReLU(size, input, weight, output, reinterpret_cast<cudaStream_t>(stream_ptr));
CalPReLU(input_length_, weight_length_, per_channel_length_, input, weight, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
ResetResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but ReLUGpuFwdKernel needs 2.";
MS_LOG(ERROR) << "PReLU needs 2 inputs, but got " << input_num;
return false;
}
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(ERROR) << "PReLUGpuFwdKernel input is null.";
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "ReLU should have 1 output, but got " << input_num;
return false;
}
size_t size = 1;
for (size_t i = 0; i < input_shape.size(); i++) {
size *= input_shape[i];
}
input_size_ = size * sizeof(T);
auto weight_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(weight_shape);
if (is_null_input_) {
MS_LOG(ERROR) << "PReLUGpuFwdKernel weight is null.";
return false;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_length_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<>());
size_t input_rank = input_shape.size();
size_t channel_num;
if (input_rank == 0) {
channel_num = 1;
per_channel_length_ = 1;
} else if (input_rank == 1) {
channel_num = 1;
per_channel_length_ = input_shape[0];
} else {
channel_num = input_shape[1];
per_channel_length_ = std::accumulate(input_shape.begin() + 2, input_shape.end(), size_t(1), std::multiplies<>());
}
size = 1;
for (size_t i = 0; i < weight_shape.size(); i++) {
size *= weight_shape[i];
}
weight_size_ = size * sizeof(T);
auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (weight_shape.size() != 1 || (weight_shape[0] != 1 && weight_shape[0] != channel_num)) {
MS_LOG(EXCEPTION) << "PReLU requires the weight to be a 1-D tensor whose element number is "
"1 or the channel number "
<< channel_num << ", but got weight shape " << weight_shape;
}
weight_length_ = weight_shape[0];
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
is_null_input_ = false;
input_length_ = 0;
weight_length_ = 0;
per_channel_length_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
input_size_ = 0;
workspace_size_ = 0;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(input_size_);
workspace_size_list_.push_back(workspace_size_);
size_t data_size = sizeof(T);
input_size_list_.push_back(input_length_ * data_size);
input_size_list_.push_back(weight_length_ * data_size);
output_size_list_.push_back(input_length_ * data_size);
}
private:
bool is_null_input_;
size_t input_length_{0};
size_t weight_length_{0};
size_t per_channel_length_{0};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t weight_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore
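The shape bookkeeping in Init() reduces to the following sketch (illustrative Python; `prelu_shape_params` is a hypothetical helper mirroring the rank-0/1 special cases above):

from functools import reduce
from operator import mul

def prelu_shape_params(input_shape):
    # Rank 0/1 inputs are treated as a single channel; otherwise dim 1 is
    # the channel dim and per_channel_length is the product of the dims
    # after it.
    input_length = reduce(mul, input_shape, 1)
    rank = len(input_shape)
    if rank == 0:
        return input_length, 1, 1
    if rank == 1:
        return input_length, 1, input_shape[0]
    return input_length, input_shape[1], reduce(mul, input_shape[2:], 1)

For example, prelu_shape_params([2, 3, 2, 3]) gives (36, 3, 6), so a weight with 1 or 3 elements is then accepted.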

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/prelu_grad_kernel.h"
#include "backend/kernel_compiler/gpu/nn/prelu_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
@@ -25,7 +25,7 @@ MS_REG_GPU_KERNEL_ONE(PReLUGrad,
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
PReLUGpuGradKernel, float)
PReLUGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(PReLUGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
@@ -33,6 +33,6 @@ MS_REG_GPU_KERNEL_ONE(PReLUGrad,
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
PReLUGpuGradKernel, half)
PReLUGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,121 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_GPU_KERNEL_H_
#include <vector>
#include <map>
#include <functional>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/prelu_grad_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class PReLUGradGpuKernel : public GpuKernel {
public:
PReLUGradGpuKernel() = default;
~PReLUGradGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto *dy = GetDeviceAddress<T>(inputs, 0);
auto *x = GetDeviceAddress<T>(inputs, 1);
auto *w = GetDeviceAddress<T>(inputs, 2);
auto *dx = GetDeviceAddress<T>(outputs, 0);
auto *dw = GetDeviceAddress<T>(outputs, 1);
CalPReLUGrad(input_length_, weight_length_, per_channel_length_, dy, x, w, dx, dw,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
ResetResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "ReLUGrad needs 3 inputs, but got " << input_num;
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) {
MS_LOG(ERROR) << "ReLUGrad should have 2 outputs, but got " << input_num;
return false;
}
auto x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
input_length_ = std::accumulate(x_shape.begin(), x_shape.end(), size_t(1), std::multiplies<>());
size_t x_rank = x_shape.size();
size_t channel_num;
if (x_rank == 0) {
channel_num = 1;
per_channel_length_ = 1;
} else if (x_rank == 1) {
channel_num = 1;
per_channel_length_ = x_shape[0];
} else {
channel_num = x_shape[1];
per_channel_length_ = std::accumulate(x_shape.begin() + 2, x_shape.end(), size_t(1), std::multiplies<>());
}
auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
if (weight_shape.size() != 1 || (weight_shape[0] != 1 && weight_shape[0] != channel_num)) {
MS_LOG(EXCEPTION) << "PReLUGrad requires the weight to be a 1-D tensor whose element number is "
"1 or the channel number "
<< channel_num << ", but got weight shape " << weight_shape;
}
weight_length_ = weight_shape[0];
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_length_ = 0;
weight_length_ = 0;
per_channel_length_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
size_t data_size = sizeof(T);
input_size_list_.push_back(input_length_ * data_size);
input_size_list_.push_back(input_length_ * data_size);
input_size_list_.push_back(weight_length_ * data_size);
output_size_list_.push_back(input_length_ * data_size);
output_size_list_.push_back(weight_length_ * data_size);
}
private:
size_t input_length_{0};
size_t weight_length_{0};
size_t per_channel_length_{0};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_GPU_KERNEL_H_

View File

@@ -1,196 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class PReLUGpuGradKernel : public GpuKernel {
public:
PReLUGpuGradKernel()
: data_format_(kOpFormat_NCDHW),
input_size_(0),
weight_size_(0),
reduce_workspace_size_(0),
spatial_count_(1),
is_null_input_(false),
channel_shared_(false),
channel_last_(false) {}
~PReLUGpuGradKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *dy_addr = GetDeviceAddress<T>(inputs, 0);
T *x_addr = GetDeviceAddress<T>(inputs, 1);
T *w_addr = GetDeviceAddress<T>(inputs, 2);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
T *dw_addr = GetDeviceAddress<T>(outputs, 1);
T *dw_collector_addr = GetDeviceAddress<T>(workspace, 0);
T *reduce_workspace_addr = GetDeviceAddress<T>(workspace, 1);
PReluChannelSharedGrad(input_size_ / sizeof(T), dy_addr, x_addr, w_addr, dx_addr, dw_collector_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (data_type_ == CUDNN_DATA_DOUBLE) {
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
reduce_workspace_size_, &alpha, grad_weight_collector_descriptor_, dw_collector_addr, &beta,
grad_weight_descriptor_, dw_addr),
"cudnnReduceTensor failed.");
} else {
const float alphaf = static_cast<float>(1.0f);
const float betaf = static_cast<float>(0.0f);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, reduce_workspace_addr,
reduce_workspace_size_, &alphaf, grad_weight_collector_descriptor_, dw_collector_addr, &betaf,
grad_weight_descriptor_, dw_addr),
"cudnnReduceTensor failed.");
}
return true;
}
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_),
"cudnnCreateReduceTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_collector_descriptor_),
"cudnnCreateTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&grad_weight_descriptor_),
"cudnnCreateTensorDescriptor failed.");
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
"cudnnDestroyReduceTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_collector_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(grad_weight_descriptor_),
"cudnnDestroyTensorDescriptor failed.");
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
input_size_ = sizeof(T);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null.";
}
for (size_t i = 0; i < input_shape.size(); ++i) {
input_size_ *= input_shape[i];
}
weight_size_ = sizeof(T);
auto weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
is_null_input_ = CHECK_NULL_INPUT(weight_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "PReLUGpuBwdKernel input is null.";
}
for (auto dim : weight_shape) {
weight_size_ *= dim;
}
channel_shared_ = (weight_shape[0] == 1);
if (!channel_shared_) {
MS_LOG(WARNING)
<< "PReLUGpuBwdKernel shares weight for all channels, but the given weight tensor has more than one element.";
}
spatial_count_ = 1;
if (channel_last_) {
for (size_t i = 1; i < input_shape.size() - 1; ++i) {
spatial_count_ *= input_shape[i];
}
} else {
for (size_t i = 2; i < input_shape.size(); ++i) {
spatial_count_ *= input_shape[i];
}
}
data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
int input_dim_length = input_shape.size();
std::vector<size_t> reduce_out_shape(input_dim_length, 1);
if (channel_last_) {
reduce_out_shape[input_dim_length - 1] = weight_shape[0];
} else {
reduce_out_shape[1] = weight_shape[0];
}
InitResource();
CudnnSetTensorNdDescriptor(reduce_out_shape, grad_weight_descriptor_, data_type_, kernel_node_);
CudnnSetTensorNdDescriptor(input_shape, grad_weight_collector_descriptor_, data_type_, kernel_node_);
cudnnDataType_t comp_type = (data_type_ == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, CUDNN_REDUCE_TENSOR_ADD, comp_type,
CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
"cudnnSetReduceTensorDescriptor failed");
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
input_size_list_.push_back(input_size_);
input_size_list_.push_back(weight_size_);
output_size_list_.push_back(input_size_);
output_size_list_.push_back(weight_size_);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, grad_weight_collector_descriptor_,
grad_weight_descriptor_, &reduce_workspace_size_),
"cudnnGetReductionWorkspaceSize failed.");
workspace_size_list_.push_back(input_size_);
workspace_size_list_.push_back(reduce_workspace_size_);
}
private:
cudnnHandle_t cudnn_handle_;
cudnnDataType_t data_type_;
cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_;
cudnnTensorDescriptor_t grad_weight_collector_descriptor_;
cudnnTensorDescriptor_t grad_weight_descriptor_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::string data_format_ = kOpFormat_NCDHW;
size_t input_size_;
size_t weight_size_;
size_t reduce_workspace_size_;
size_t spatial_count_;
bool is_null_input_ = false;
bool channel_shared_ = false;
bool channel_last_ = false;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_PRELU_GRAD_KERNEL_H_

View File

@@ -40,7 +40,7 @@ static bool CheckStridedSlice(const CNodePtr &cnode) {
}
}
// check reduction on the last dimension
if (AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) {
if (GetCNodeFuncName(cnode) == kStridedSliceOpName && AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, cnode)) {
auto shrink_axis_mask = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrShrinkAxisMask));
AnfNodePtr input = cnode->input(1);
int input_dims = 0;

View File

@@ -14,15 +14,15 @@
# ============================================================================
"""activation"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._extends import cell_attr_register
from mindspore._checkparam import Validator as validator
from ..cell import Cell
from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from ..cell import Cell
__all__ = ['Softmax',
'LogSoftmax',
@@ -548,22 +548,24 @@ class PReLU(Cell):
Activation_function#/media/File:Activation_prelu.svg>`_.
Args:
channel (int): The dimension of input. Default: 1.
w (Union[float, list, Tensor]): The initial value of w. Default: 0.25.
channel (int): The number of elements in the parameter.
It must be 1 or the number of channels of the input tensor `x`. Default: 1.
w (Union[float, list, Tensor]): The initial value of the parameter. It could be a float, a list of
floats, or a tensor with the same dtype as the input tensor `x`. Default: 0.25.
Inputs:
- **x** (Tensor) - The input of PReLU with data type of float16 or float32.
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
Outputs:
Tensor, with the same type and shape as the `x`.
Tensor, with the same dtype and shape as `x`.
Raises:
TypeError: If `channel` is not an int.
TypeError: If `w` is not one of float, list, Tensor.
TypeError: If `w` is not one of: a float, a list of floats, or a float Tensor.
TypeError: If dtype of `x` is neither float16 nor float32.
ValueError: If `x` is a 0-D or 1-D Tensor on Ascend.
ValueError: If `channel` is less than 1.
ValueError: If length of shape of `x` is equal to 1.
Supported Platforms:
``Ascend`` ``GPU``
@@ -582,24 +584,34 @@
"""Initialize PReLU."""
super(PReLU, self).__init__()
validator.check_positive_int(channel, 'channel', self.cls_name)
if isinstance(w, (np.float32, float)):
if isinstance(w, (float, np.float32)):
tmp = np.empty((channel,), dtype=np.float32)
tmp.fill(w)
w = Tensor(tmp)
w = Tensor(tmp, dtype=mstype.float32)
elif isinstance(w, list):
w = Tensor(w)
if not isinstance(w, Tensor):
raise TypeError("w only support np.float32, float, list or Tensor type.")
self.w = Parameter(initializer(w, [channel]), name='a')
if len(w) != channel:
raise ValueError(f"When the 'w' is a list, the length should be equal to the channel, "
f"but got the length {len(w)}, the channel {channel}")
for i in w:
if not isinstance(i, (float, np.float32)):
raise ValueError(f"When the 'w' is a list, the all elements should be float, but got {w}")
w = Tensor(w, dtype=mstype.float32)
elif isinstance(w, Tensor):
if w.dtype not in (mstype.float16, mstype.float32):
raise ValueError(f"When the 'w' is a tensor, the dtype should be float16 or float32, but got {w.dtype}")
if len(w.shape) != 1 or w.shape[0] != channel:
raise ValueError(f"When the 'w' is a tensor, the rank should be 1, and the elements number "
f"should be equal to the channel, but got w shape {w}, the channel {channel}")
else:
raise TypeError(f"The 'w' only supported float list and tensor, but got {type(w)}")
self.w = Parameter(w, name='a')
self.prelu = P.PReLU()
self.relu = P.ReLU()
self.assign = P.Assign()
def construct(self, x):
u = self.relu(self.w)
v = self.prelu(x, u)
v = self.prelu(x, F.cast(u, x.dtype))
if self.training:
self.assign(self.w, u)
return v
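The construct above clamps the learned slope through ReLU for the forward pass and, during training, writes the clamped value back so the stored parameter stays non-negative. A rough NumPy sketch of one step (illustrative; `prelu_cell_step` is a hypothetical name, shown for a channel-shared slope of shape (1,)):

import numpy as np

def prelu_cell_step(x, w, training=True):
    u = np.maximum(w, 0)              # u = self.relu(self.w)
    v = np.where(x < 0, u[0] * x, x)  # v = self.prelu(x, u)
    if training:
        w[...] = u                    # self.assign(self.w, u)
    return v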

View File

@@ -1544,8 +1544,6 @@ class PReLUGrad(PrimitiveWithInfer):
pass
def infer_shape(self, y_backprop_shape, a_shape, w_shape):
if len(a_shape) == 1:
raise ValueError(f'For \'{self.name}\' input_x rank 1 is not supported.')
return y_backprop_shape, w_shape
def infer_dtype(self, y_backprop_dtype, a_dtype, w_dtype):

View File

@@ -2151,6 +2151,7 @@ class Conv2DTranspose(Conv2DBackpropInput):
>>> print(output.shape)
(10, 32, 32, 32)
"""
@prim_attr_register
def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0,
pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"):
@@ -3638,16 +3639,18 @@ class PReLU(PrimitiveWithInfer):
.. math::
prelu(x_i)= \max(0, x_i) + \min(0, w * x_i),
where :math:`x_i` is an element of a channel of the input.
where :math:`x_i` is an element of a channel of the input, and `w` is the weight of that channel.
Note:
1-dimensional input_x is not supported.
0-D or 1-D input_x is not supported on Ascend.
Inputs:
- **input_x** (Tensor) - Float tensor, representing the output of the previous layer.
With data type of float16 or float32.
- **weight** (Tensor) - Float Tensor, w > 0, there are only two shapes are legitimate,
1 or the number of channels of the input. With data type of float16 or float32.
- **input_x** (Tensor) - The first input tensor. The data type is float16 or float32.
Represents the output of the previous layer.
- **weight** (Tensor) - The second input tensor. The data type is float16 or float32.
Only two shapes are legitimate: 1, or the number of channels of `input_x`.
The channel dim is the 2nd dim of the input. When the input is a 0-D or 1-D tensor, the number of channels is 1.
Outputs:
Tensor, with the same type as `input_x`.
@@ -3656,9 +3659,9 @@
Raises:
TypeError: If dtype of `input_x` or `weight` is neither float16 nor float32.
TypeError: If `input_x` or `weight` is not a Tensor.
ValueError: If length of shape of `input_x` is equal to 1.
ValueError: If length of shape of `weight` is not equal to 1.
TypeError: If `input_x` or `weight` is not a Tensor.
ValueError: If `input_x` is a 0-D or 1-D Tensor on Ascend.
ValueError: If `weight` is not a 1-D Tensor.
Supported Platforms:
``Ascend`` ``GPU``
@@ -3677,12 +3680,17 @@
... result = self.prelu(input_x, weight)
... return result
...
>>> input_x = Tensor(np.random.randint(-3, 3, (2, 3, 2)), mindspore.float32)
>>> input_x = Tensor(np.arange(-6, 6).reshape((2, 3, 2)), mindspore.float32)
>>> weight = Tensor(np.array([0.1, 0.6, -0.3]), mindspore.float32)
>>> net = Net()
>>> output = net(input_x, weight)
>>> print(output.shape)
(2, 3, 2)
>>> print(output)
[[[-0.60 -0.50]
[-2.40 -1.80]
[ 0.60 0.30]]
[[ 0.00 1.00]
[ 2.00 3.00]
[ 4.00 5.00]]]
"""
@prim_attr_register
@@ -3691,25 +3699,29 @@
def infer_shape(self, input_x_shape, weight_shape):
input_x_dim = len(input_x_shape)
if input_x_dim in (0, 1):
if context.get_context("device_target") == "Ascend":
raise ValueError(f"For '{self.name}', the 0-D or 1-D 'input_x' is not supported on Ascend.")
channel_num = 1
else:
channel_num = input_x_shape[1]
weight_dim = len(weight_shape)
if input_x_dim == 1:
raise ValueError(f'For \'{self.name}\' input_x rank 1 is not supported.')
if weight_dim != 1:
raise ValueError(f'For \'{self.name}\' weight_dim must be 1, while weight_dim is {weight_dim}.')
if weight_shape[0] != input_x_shape[1] and weight_shape[0] != 1:
raise ValueError(f'For \'{self.name}\' channel of input_x and weight must be matched,'
f' while channel of input_x is {input_x_shape[1]},'
f' weight_shape[0] is {weight_shape[0]}.')
raise ValueError(f"For '{self.name}', the weight dimension should be 1, while got {weight_dim}.")
if weight_shape[0] != 1 and weight_shape[0] != channel_num:
raise ValueError(f"For '{self.name}', the weight shape should be (1,) or "
f"matched with input channel ({channel_num},), but got {weight_shape}")
return input_x_shape
def infer_dtype(self, input_x_dtype, weight_dtype):
valid_dtypes = (mstype.float16, mstype.float32)
validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
args = {"input_x": input_x_dtype, "weight": weight_dtype}
if context.get_context("device_target") == "GPU":
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
else:
validator.check_tensor_dtype_valid("input_x", input_x_dtype, valid_dtypes, self.name)
validator.check_tensor_dtype_valid("weight", weight_dtype, valid_dtypes, self.name)
return input_x_dtype
@@ -7876,6 +7888,7 @@ class AvgPool3D(Primitive):
[[[[[ 5. 6.]]]
[[[17. 18.]]]]]
"""
@prim_attr_register
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", pad=0, ceil_mode=False,
count_include_pad=True, divisor_override=0, data_format="NCDHW"):
@@ -8399,7 +8412,6 @@ class CTCLossV2Grad(Primitive):
self.add_prim_attr("zero_infinity", zero_infinity)
class Conv3DTranspose(PrimitiveWithInfer):
r"""
Computes a 3D transposed convolution, which is also known as a deconvolution

View File

@@ -1,61 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
class NetPReLUGrad(nn.Cell):
def __init__(self):
super(NetPReLUGrad, self).__init__()
self.prelu_grad = G.PReLUGrad()
def construct(self, dout, x, w):
return self.prelu_grad(dout, x, w)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_grad_fp32_channel_shared():
dout = Tensor(np.ones(shape=[2, 2, 2, 3]).astype(np.float32))
x = Tensor(np.arange(-5, 19).reshape(2, 2, 2, 3).astype(np.float32))
w = Tensor(np.array([-0.5]).astype(np.float32))
expect_dx = np.array([[[[-0.5000, -0.5000, -0.5000],
[-0.5000, -0.5000, -0.5000]],
[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]]],
[[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]],
[[1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000]]]]).astype(np.float32)
expect_dw = np.array([-15.]).astype(np.float32)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
prelu_grad = NetPReLUGrad()
dx, dw = prelu_grad(dout, x, w)
assert (dx.asnumpy() == expect_dx).all()
assert (dw.asnumpy() == expect_dw).all()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
prelu_grad = NetPReLUGrad()
dx, dw = prelu_grad(dout, x, w)
assert (dx.asnumpy() == expect_dx).all()
assert (dw.asnumpy() == expect_dw).all()

View File

@@ -20,55 +20,215 @@ import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
class NetPReLU(nn.Cell):
class PReLUOpNet(nn.Cell):
def __init__(self):
super(NetPReLU, self).__init__()
super(PReLUOpNet, self).__init__()
self.prelu = P.PReLU()
def construct(self, x, weight):
return self.prelu(x, weight)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_float16():
weight = Tensor(np.array([0.25]).astype(np.float16))
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float16))
expect = np.array([[[[-0.25, 1, 10,],
[1, -0.25, 1,],
[10, 1, -0.25]]]]).astype(np.float16)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
class PReLUOpGradNet(nn.Cell):
def __init__(self, net):
super(PReLUOpGradNet, self).__init__()
self.forward = net
self.grad = C.GradOperation(get_all=True, sens_param=False)
def construct(self, x, weight):
return self.grad(self.forward)(x, weight)
def judge_result_correct(result, expect):
result = result.asnumpy()
expect = expect.asnumpy()
assert result.dtype == expect.dtype
assert result.shape == expect.shape
assert np.allclose(result, expect, rtol=1.e-2)
def test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode):
context.set_context(mode=mode)
prelu_forward = PReLUOpNet()
prelu_backward = PReLUOpGradNet(prelu_forward)
forward_output = prelu_forward(x, weight)
judge_result_correct(forward_output, expect_forward)
backward_output = prelu_backward(x, weight)
assert len(backward_output) == 2
judge_result_correct(backward_output[0], expect_dx)
judge_result_correct(backward_output[1], expect_dw)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_float32():
weight = Tensor(np.array([0.25]).astype(np.float32))
x = Tensor(np.array([[[[-1, 1, 10],
[1, -1, 1],
[10, 1, -1]]]]).astype(np.float32))
expect = np.array([[[[-0.25, 1, 10,],
[1, -0.25, 1,],
[10, 1, -0.25]]]]).astype(np.float32)
def test_prelu_single_weight():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.PYNATIVE_MODE]
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.7
weight = np.array([0.6])
expect_forward = np.where(x >= 0, x, weight * x)
expect_dx = np.where(x > 0, 1, weight)
expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,))
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
prelu = NetPReLU()
output = prelu(x, weight)
assert (output.asnumpy() == expect).all()
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_multiple_weight():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.PYNATIVE_MODE]
x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.6
weight = np.array([0.2, 0.3, 0.4])
expect_forward = np.array([[[[-1.20, -1.08, -0.96],
[-0.84, -0.72, -0.60]],
[[-0.72, -0.54, -0.36],
[-0.18, 0.00, 0.60]],
[[1.20, 1.80, 2.40],
[3.00, 3.60, 4.20]]],
[[[4.80, 5.40, 6.00],
[6.60, 7.20, 7.80]],
[[8.40, 9.00, 9.60],
[10.20, 10.80, 11.40]],
[[12.00, 12.60, 13.20],
[13.80, 14.40, 15.00]]]])
expect_dx = np.array([[[[0.2, 0.2, 0.2],
[0.2, 0.2, 0.2]],
[[0.3, 0.3, 0.3],
[0.3, 0.3, 1.0]],
[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]]],
[[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]],
[[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]]]])
expect_dw = np.array([-27.0, -6.0, 0.0])
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight_0_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.PYNATIVE_MODE]
x = np.array(-0.8)
weight = np.array([0.6])
expect_forward = np.array(-0.48)
expect_dx = np.array(0.6)
expect_dw = np.array([-0.8])
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight_1_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.PYNATIVE_MODE]
x = np.arange(-10, 26).reshape((36,)) * 0.7
weight = np.array([0.6])
expect_forward = np.where(x >= 0, x, weight * x)
expect_dx = np.where(x > 0, 1, weight)
expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,))
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_single_weight_2_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.PYNATIVE_MODE]
x = np.arange(-10, 26).reshape((4, 9)) * 0.7
weight = np.array([0.6])
expect_forward = np.where(x >= 0, x, weight * x)
expect_dx = np.where(x > 0, 1, weight)
expect_dw = np.sum(np.where(x >= 0, 0, x)).reshape((1,))
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_prelu_multiple_weight_2_D():
context.set_context(device_target="GPU")
dtypes = [mstype.float16, mstype.float32]
modes = [context.GRAPH_MODE, context.PYNATIVE_MODE]
x = np.arange(-6, 6).reshape((3, 4)) * 0.6
weight = np.array([0.2, 0.4, 0.7, 0.9])
expect_forward = np.array([[-0.72, -1.20, -1.68, -1.62],
[-0.24, -0.24, 0.00, 0.60],
[1.20, 1.80, 2.40, 3.00]])
expect_dx = np.array([[0.2, 0.4, 0.7, 0.9],
[0.2, 0.4, 0.7, 1.0],
[1.0, 1.0, 1.0, 1.0]])
expect_dw = np.array([-4.8, -3.6, -2.4, -1.8])
for dtype in dtypes:
for mode in modes:
x = Tensor(x, dtype)
weight = Tensor(weight, dtype)
expect_forward = Tensor(expect_forward, dtype)
expect_dx = Tensor(expect_dx, dtype)
expect_dw = Tensor(expect_dw, dtype)
test_prelu(x, weight, expect_forward, expect_dx, expect_dw, mode)
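The hard-coded expectations in the multi-weight tests follow from the same reference formulas used in the single-weight ones; for instance, for test_prelu_multiple_weight (illustrative check):

import numpy as np
x = np.arange(-10, 26).reshape((2, 3, 2, 3)) * 0.6
w = np.array([0.2, 0.3, 0.4])
expect_dw = np.where(x < 0, x, 0).sum(axis=(0, 2, 3))   # -> [-27., -6., 0.]
expect_dx = np.where(x > 0, 1.0, w.reshape(1, 3, 1, 1))  # broadcast per channel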