forked from mindspore-Ecosystem/mindspore
!16320 Local Response Norm GPU Op support
From: @tom__chen Reviewed-by: @robingrosman Signed-off-by: @robingrosman
commit 2db8656048
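In short, this patch adds GPU support for the LRN and LRNGrad operators: the kernels dispatch to cuDNN (cudnnLRNCrossChannelForward/Backward) when the attributes fall inside cuDNN's supported LRN range, and otherwise fall back to native NHWC CUDA kernels wrapped in NCHW<->NHWC transposes. A minimal usage sketch, mirroring the test added at the end of this diff (assumes a GPU build of MindSpore; the random input is illustrative only):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
# These attributes satisfy cuDNN's LRN limits, so this should take the cuDNN path.
lrn = P.LRN(depth_radius=2, bias=1.0, alpha=0.0001, beta=0.75)
x = Tensor(np.random.randn(1, 3, 2, 2).astype(np.float32))  # NCHW input, as the kernel expects
print(lrn(x))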
@ -0,0 +1,104 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/local_response_norm_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"

// scale[pos] = bias + alpha * sum of squares over the depth_radius window along the channel axis (NHWC layout).
template <typename T>
__global__ void ComputeScaleNHWC(const T *input, const int depth_radius, const float bias, const float alpha,
                                 const size_t channels, const size_t num_elements, float *scale) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < num_elements; pos += blockDim.x * gridDim.x) {
    const int posc = static_cast<int>(pos % channels);
    float sqr_sum = 0;
    for (int i = -depth_radius; i < depth_radius + 1; i++) {
      if (posc + i >= 0 && posc + i < static_cast<int>(channels)) {
        float a = static_cast<float>(input[pos + i]);
        sqr_sum += a * a;
      }
    }
    scale[pos] = bias + alpha * sqr_sum;
  }
  return;
}

// output[pos] = input[pos] * scale[pos]^(-beta).
template <typename T>
__global__ void LocalResponseNormNHWC(const T *input, const float *scale, const float beta, const size_t num_elements,
                                      T *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < num_elements; pos += blockDim.x * gridDim.x) {
    float z = expf(logf(scale[pos]) * -beta);
    output[pos] = input[pos] * static_cast<T>(z);
  }
  return;
}

// dx[pos] = dy[pos] * scale[pos]^(-beta) - 2 * alpha * beta * x[pos] * sum(dy * y / scale) over the channel window.
template <typename T>
__global__ void LocalResponseNormGradNHWC(const T *dy, const T *x, const T *y, const float *scale,
  const int depth_radius, const float alpha, const float beta, const float neg2_alpha_beta, const size_t channels,
  const size_t num_elements, T *dx) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < num_elements; pos += blockDim.x * gridDim.x) {
    const int posc = static_cast<int>(pos % channels);
    float ratio_sum = 0;
    for (int i = -depth_radius; i <= depth_radius; i++) {
      if (posc + i >= 0 && posc + i < static_cast<int>(channels)) {
        ratio_sum += static_cast<float>(dy[pos + i] * y[pos + i]) / scale[pos + i];
      }
    }
    float z = expf(logf(scale[pos]) * -beta);
    float ratio_2ab = ratio_sum * neg2_alpha_beta;
    dx[pos] = dy[pos] * static_cast<T>(z) + x[pos] * static_cast<T>(ratio_2ab);
  }
  return;
}

template <typename T>
void CalLocalResponseNormNHWC(const T *input, const int depth_radius, const float bias, const float alpha,
  const float beta, const size_t channels, const size_t num_elements, float *scale, T *output,
  cudaStream_t cuda_stream) {
  ComputeScaleNHWC<<<GET_BLOCKS(num_elements), GET_THREADS, 0, cuda_stream>>>(input, depth_radius, bias, alpha,
                                                                              channels, num_elements, scale);
  LocalResponseNormNHWC<<<GET_BLOCKS(num_elements), GET_THREADS, 0, cuda_stream>>>(input, scale, beta, num_elements,
                                                                                   output);
  return;
}

template <typename T>
void CalLocalResponseNormGradNHWC(const T *dy, const T *x, const T *y, const int depth_radius, const float bias,
  const float alpha, const float beta, const size_t channels, const size_t num_elements, float *scale, T *dx,
  cudaStream_t cuda_stream) {
  float neg2_alpha_beta = -2.0f * alpha * beta;
  ComputeScaleNHWC<<<GET_BLOCKS(num_elements), GET_THREADS, 0, cuda_stream>>>(x, depth_radius, bias, alpha, channels,
                                                                              num_elements, scale);
  LocalResponseNormGradNHWC<<<GET_BLOCKS(num_elements), GET_THREADS, 0, cuda_stream>>>(dy, x, y, scale, depth_radius,
    alpha, beta, neg2_alpha_beta, channels, num_elements, dx);
  return;
}

template void CalLocalResponseNormNHWC<float>(const float *input, const int depth_radius, const float bias,
  const float alpha, const float beta, const size_t channels, const size_t num_elements, float *scale, float *output,
  cudaStream_t cuda_stream);

template void CalLocalResponseNormNHWC<half>(const half *input, const int depth_radius, const float bias,
  const float alpha, const float beta, const size_t channels, const size_t num_elements, float *scale, half *output,
  cudaStream_t cuda_stream);

template void CalLocalResponseNormGradNHWC<float>(const float *dy, const float *x, const float *y,
  const int depth_radius, const float bias, const float alpha, const float beta, const size_t channels,
  const size_t num_elements, float *scale, float *dx, cudaStream_t cuda_stream);

template void CalLocalResponseNormGradNHWC<half>(const half *dy, const half *x, const half *y,
  const int depth_radius, const float bias, const float alpha, const float beta, const size_t channels,
  const size_t num_elements, float *scale, half *dx, cudaStream_t cuda_stream);
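For reference, a NumPy sketch (the name lrn_nhwc_reference is illustrative, not part of the patch) of what ComputeScaleNHWC and LocalResponseNormNHWC compute over the innermost (channel) axis of an NHWC tensor:

import numpy as np

def lrn_nhwc_reference(x, depth_radius, bias, alpha, beta):
    # x: float array of shape (N, H, W, C); channels are the innermost axis,
    # matching the flattened NHWC indexing used by the kernels above.
    c = x.shape[-1]
    pad = ((0, 0), (0, 0), (0, 0), (depth_radius, depth_radius))
    sq = np.pad(np.square(x), pad)
    # scale = bias + alpha * windowed sum of squares (ComputeScaleNHWC)
    scale = bias + alpha * sum(sq[..., i:i + c] for i in range(2 * depth_radius + 1))
    # output = input * scale^(-beta) (LocalResponseNormNHWC)
    return x * np.power(scale, -beta)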
@ -0,0 +1,29 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOCAL_RESPONSE_NORM_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOCAL_RESPONSE_NORM_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalLocalResponseNormNHWC(const T *input, const int depth_radius, const float bias, const float alpha,
  const float beta, const size_t channels, const size_t num_elements, float *scale, T *output,
  cudaStream_t cuda_stream);

template <typename T>
void CalLocalResponseNormGradNHWC(const T *dy, const T *x, const T *y, const int depth_radius, const float bias,
  const float alpha, const float beta, const size_t channels, const size_t num_elements, float *scale, T *dx,
  cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LOCAL_RESPONSE_NORM_H_
@ -0,0 +1,26 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/local_response_norm_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(LRN, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      LocalResponseNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(LRN, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      LocalResponseNormGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,271 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LOCAL_RESPONSE_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LOCAL_RESPONSE_NORM_GPU_KERNEL_H_

#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/local_response_norm_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh"
#include "utils/utils.h"

namespace mindspore {
namespace kernel {
template <typename T>
class LocalResponseNormGpuKernel : public GpuKernel {
 public:
  LocalResponseNormGpuKernel() { ResetResource(); }
  ~LocalResponseNormGpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    auto x = GetDeviceAddress<T>(inputs, 0);
    auto y = GetDeviceAddress<T>(outputs, 0);
    const float alpha = 1;
    const float beta = 0;

    if (use_native_) {
      // Workspace layout (see InitSizeLists): [0..3] shape/axis buffers, [4] NHWC input, [5] NHWC output, [6] scale.
      std::vector<size_t> to_nhwc_axis = {0, 2, 3, 1};
      std::vector<size_t> to_nchw_axis = {0, 3, 1, 2};
      size_t shape_size = 4 * sizeof(size_t);
      size_t *ws_input_shape = GetDeviceAddress<size_t>(workspace, 0);
      size_t *ws_transpose_shape = GetDeviceAddress<size_t>(workspace, 1);
      size_t *ws_to_nhwc_axis = GetDeviceAddress<size_t>(workspace, 2);
      size_t *ws_to_nchw_axis = GetDeviceAddress<size_t>(workspace, 3);
      T *ws_x = GetDeviceAddress<T>(workspace, 4);
      T *ws_y = GetDeviceAddress<T>(workspace, 5);
      float *ws_scale = GetDeviceAddress<float>(workspace, 6);

      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_input_shape, &input_shape_[0], shape_size, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync input_shape_ failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_transpose_shape, &transpose_shape_[0], shape_size,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync transpose_shape_ failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_to_nhwc_axis, &to_nhwc_axis[0], shape_size, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync to_nhwc_axis failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_to_nchw_axis, &to_nchw_axis[0], shape_size, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync to_nchw_axis failed");

      CalNCHW2NHWCInterface(num_elements_, 4, x, &input_shape_[0], &to_nhwc_axis[0], ws_input_shape, ws_to_nhwc_axis,
                            ws_x, reinterpret_cast<cudaStream_t>(stream_ptr));

      CalLocalResponseNormNHWC(ws_x, depth_radius_, bias_, alpha_, beta_, transpose_shape_[3], num_elements_, ws_scale,
                               ws_y, reinterpret_cast<cudaStream_t>(stream_ptr));

      CalNHWC2NCHWInterface(num_elements_, 4, ws_y, &transpose_shape_[0], &to_nchw_axis[0], ws_transpose_shape,
                            ws_to_nchw_axis, y, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      CHECK_CUDNN_RET_WITH_EXCEPT(
        kernel_node_,
        cudnnLRNCrossChannelForward(handle_, norm_desc_, lrn_mode_, &alpha, x_desc_, x, &beta, y_desc_, y),
        "cudnnLRNCrossChannelForward failed");
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    MS_EXCEPTION_IF_NULL(kernel_node);
    if (!CheckParam(kernel_node)) {
      return false;
    }

    depth_radius_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "depth_radius"));
    bias_ = GetAttr<float>(kernel_node, "bias");
    alpha_ = GetAttr<float>(kernel_node, "alpha");
    beta_ = GetAttr<float>(kernel_node, "beta");

    // Fall back to the native NHWC kernels when the attributes are outside the range cuDNN's LRN supports.
    use_native_ = false;
    unsigned int lrnN = 2 * depth_radius_ + 1;
    double lrnAlpha = lrnN * alpha_;
    if (lrnN < CUDNN_LRN_MIN_N || lrnN > CUDNN_LRN_MAX_N || bias_ < CUDNN_LRN_MIN_K || beta_ < CUDNN_LRN_MIN_BETA) {
      use_native_ = true;
    }
    InitResource();

    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "LocalResponseNormGpuKernel input is null";
      InitSizeLists();
      return true;
    }
    if (input_shape.size() != 4) {
      MS_LOG(EXCEPTION) << "Input tensor rank is " << input_shape.size()
                        << ", but LocalResponseNormGpuKernel only supports 4-D input.";
    }

    if (use_native_) {
      num_elements_ = 1;
      for (auto x : input_shape) {
        input_shape_.push_back(x);
        num_elements_ *= x;
      }
      transpose_shape_.push_back(input_shape_[0]);
      transpose_shape_.push_back(input_shape_[2]);
      transpose_shape_.push_back(input_shape_[3]);
      transpose_shape_.push_back(input_shape_[1]);
    } else {
      lrn_mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
      cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
      SetCUDNNDescriptors(input_shape, lrnN, lrnAlpha);
    }

    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    input_size_ = 0;
    output_size_ = 0;
    is_null_input_ = false;
    x_desc_ = nullptr;
    y_desc_ = nullptr;
    norm_desc_ = nullptr;
    lrn_mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
    handle_ = nullptr;
    cudnn_data_type_ = CUDNN_DATA_FLOAT;
    depth_radius_ = 0;
    bias_ = 0;
    alpha_ = 0;
    beta_ = 0;
    use_native_ = false;
    num_elements_ = 0;
    input_shape_.clear();
    transpose_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

  void DestroyResource() noexcept override {
    if (!use_native_) {
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyLRNDescriptor(norm_desc_), "Destroy LRN norm desc failed");
    }
  }

 protected:
  void InitResource() override {
    if (!use_native_) {
      handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateLRNDescriptor(&norm_desc_), "Create LRN norm desc failed");
    }
  }

  void InitSizeLists() override {
    if (!is_null_input_) {
      if (use_native_) {
        input_size_ = num_elements_ * sizeof(T);
        output_size_ = num_elements_ * sizeof(T);
        size_t shape_size = 4 * sizeof(size_t);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(input_size_);
        workspace_size_list_.push_back(input_size_);
        workspace_size_list_.push_back(num_elements_ * sizeof(float));
      } else {
        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size_),
                                    "Get input x size failed");
        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size_),
                                    "Get output y size failed");
      }
    }
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
  }

 private:
  bool CheckParam(const CNodePtr &kernel_node) {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but LocalResponseNormGpuKernel needs 1 input.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but LocalResponseNormGpuKernel needs 1 output.";
      return false;
    }
    return true;
  }

  void SetCUDNNDescriptors(const std::vector<size_t> &shape, int lrnN, double lrnAlpha) {
    cudnnTensorFormat_t cudnn_format;
    int batch, channel, height, width;
    batch = SizeToInt(shape[0]);
    channel = SizeToInt(shape[1]);
    height = SizeToInt(shape[2]);
    width = SizeToInt(shape[3]);
    cudnn_format = CUDNN_TENSOR_NCHW;
    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set x desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_, cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
      "Set y desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetLRNDescriptor(norm_desc_, lrnN, lrnAlpha, beta_, bias_),
                                "cudnnSetLRNDescriptor failed");
  }

  size_t input_size_;
  size_t output_size_;
  bool is_null_input_;
  cudnnTensorDescriptor_t x_desc_;
  cudnnTensorDescriptor_t y_desc_;
  cudnnLRNDescriptor_t norm_desc_;
  cudnnLRNMode_t lrn_mode_;
  cudnnHandle_t handle_;
  cudnnDataType_t cudnn_data_type_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
  bool use_native_;
  size_t num_elements_;
  std::vector<size_t> input_shape_;
  std::vector<size_t> transpose_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LOCAL_RESPONSE_NORM_GPU_KERNEL_H_
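The native path of Launch() above never normalizes NCHW data directly: it transposes the input to NHWC, runs CalLocalResponseNormNHWC along the now-innermost channel axis, and transposes the result back. A small NumPy sketch of that axis bookkeeping (shapes are illustrative only):

import numpy as np

x_nchw = np.zeros((2, 3, 4, 5), dtype=np.float32)  # input_shape_ = (N, C, H, W)
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))        # to_nhwc_axis = {0, 2, 3, 1} (CalNCHW2NHWCInterface)
assert x_nhwc.shape == (2, 4, 5, 3)                # transpose_shape_ = (N, H, W, C)
y_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))        # to_nchw_axis = {0, 3, 1, 2} (CalNHWC2NCHWInterface)
assert y_nchw.shape == x_nchw.shape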
@ -0,0 +1,36 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/nn/local_response_norm_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(LRNGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      LocalResponseNormGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(LRNGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      LocalResponseNormGradGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,307 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LOCAL_RESPONSE_NORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LOCAL_RESPONSE_NORM_GRAD_GPU_KERNEL_H_

#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/local_response_norm_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh"
#include "utils/utils.h"

namespace mindspore {
namespace kernel {
template <typename T>
class LocalResponseNormGradGpuKernel : public GpuKernel {
 public:
  LocalResponseNormGradGpuKernel() { ResetResource(); }
  ~LocalResponseNormGradGpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    auto dy = GetDeviceAddress<T>(inputs, 0);
    auto x = GetDeviceAddress<T>(inputs, 1);
    auto y = GetDeviceAddress<T>(inputs, 2);
    auto dx = GetDeviceAddress<T>(outputs, 0);
    const float alpha = 1;
    const float beta = 0;

    if (use_native_) {
      // Workspace layout (see InitSizeLists): [0..3] shape/axis buffers, [4..7] NHWC dy/x/y/dx, [8] scale.
      std::vector<size_t> to_nhwc_axis = {0, 2, 3, 1};
      std::vector<size_t> to_nchw_axis = {0, 3, 1, 2};
      size_t shape_size = 4 * sizeof(size_t);
      size_t *ws_input_shape = GetDeviceAddress<size_t>(workspace, 0);
      size_t *ws_transpose_shape = GetDeviceAddress<size_t>(workspace, 1);
      size_t *ws_to_nhwc_axis = GetDeviceAddress<size_t>(workspace, 2);
      size_t *ws_to_nchw_axis = GetDeviceAddress<size_t>(workspace, 3);
      T *ws_dy = GetDeviceAddress<T>(workspace, 4);
      T *ws_x = GetDeviceAddress<T>(workspace, 5);
      T *ws_y = GetDeviceAddress<T>(workspace, 6);
      T *ws_dx = GetDeviceAddress<T>(workspace, 7);
      float *ws_scale = GetDeviceAddress<float>(workspace, 8);

      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_input_shape, &input_shape_[0], shape_size, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync input_shape_ failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_transpose_shape, &transpose_shape_[0], shape_size,
                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync transpose_shape_ failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_to_nhwc_axis, &to_nhwc_axis[0], shape_size, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync to_nhwc_axis failed");
      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                                 cudaMemcpyAsync(ws_to_nchw_axis, &to_nchw_axis[0], shape_size, cudaMemcpyHostToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                 "cudaMemcpyAsync to_nchw_axis failed");

      CalNCHW2NHWCInterface(num_elements_, 4, dy, &input_shape_[0], &to_nhwc_axis[0], ws_input_shape, ws_to_nhwc_axis,
                            ws_dy, reinterpret_cast<cudaStream_t>(stream_ptr));
      CalNCHW2NHWCInterface(num_elements_, 4, x, &input_shape_[0], &to_nhwc_axis[0], ws_input_shape, ws_to_nhwc_axis,
                            ws_x, reinterpret_cast<cudaStream_t>(stream_ptr));
      CalNCHW2NHWCInterface(num_elements_, 4, y, &input_shape_[0], &to_nhwc_axis[0], ws_input_shape, ws_to_nhwc_axis,
                            ws_y, reinterpret_cast<cudaStream_t>(stream_ptr));

      CalLocalResponseNormGradNHWC(ws_dy, ws_x, ws_y, depth_radius_, bias_, alpha_, beta_, transpose_shape_[3],
                                   num_elements_, ws_scale, ws_dx, reinterpret_cast<cudaStream_t>(stream_ptr));

      CalNHWC2NCHWInterface(num_elements_, 4, ws_dx, &transpose_shape_[0], &to_nchw_axis[0], ws_transpose_shape,
                            ws_to_nchw_axis, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                  cudnnLRNCrossChannelBackward(handle_, norm_desc_, lrn_mode_, &alpha, y_desc_, y,
                                                               dy_desc_, dy, x_desc_, x, &beta, dx_desc_, dx),
                                  "cudnnLRNCrossChannelBackward failed");
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    kernel_node_ = kernel_node;
    MS_EXCEPTION_IF_NULL(kernel_node);
    if (!CheckParam(kernel_node)) {
      return false;
    }

    depth_radius_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "depth_radius"));
    bias_ = GetAttr<float>(kernel_node, "bias");
    alpha_ = GetAttr<float>(kernel_node, "alpha");
    beta_ = GetAttr<float>(kernel_node, "beta");

    // Fall back to the native NHWC kernels when the attributes are outside the range cuDNN's LRN supports.
    use_native_ = false;
    unsigned int lrnN = 2 * depth_radius_ + 1;
    double lrnAlpha = lrnN * alpha_;
    if (lrnN < CUDNN_LRN_MIN_N || lrnN > CUDNN_LRN_MAX_N || bias_ < CUDNN_LRN_MIN_K || beta_ < CUDNN_LRN_MIN_BETA) {
      use_native_ = true;
    }
    InitResource();

    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shape);
    if (is_null_input_) {
      MS_LOG(WARNING) << "LocalResponseNormGradGpuKernel input is null";
      InitSizeLists();
      return true;
    }
    if (input_shape.size() != 4) {
      MS_LOG(EXCEPTION) << "Input tensor rank is " << input_shape.size()
                        << ", but LocalResponseNormGradGpuKernel only supports 4-D input.";
    }

    if (use_native_) {
      num_elements_ = 1;
      for (auto x : input_shape) {
        input_shape_.push_back(x);
        num_elements_ *= x;
      }
      transpose_shape_.push_back(input_shape_[0]);
      transpose_shape_.push_back(input_shape_[2]);
      transpose_shape_.push_back(input_shape_[3]);
      transpose_shape_.push_back(input_shape_[1]);
    } else {
      lrn_mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
      cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
      SetCUDNNDescriptors(input_shape, lrnN, lrnAlpha);
    }

    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    input_size_ = 0;
    output_size_ = 0;
    is_null_input_ = false;
    dy_desc_ = nullptr;
    x_desc_ = nullptr;
    y_desc_ = nullptr;
    dx_desc_ = nullptr;
    norm_desc_ = nullptr;
    lrn_mode_ = CUDNN_LRN_CROSS_CHANNEL_DIM1;
    handle_ = nullptr;
    cudnn_data_type_ = CUDNN_DATA_FLOAT;
    depth_radius_ = 0;
    bias_ = 0;
    alpha_ = 0;
    beta_ = 0;
    use_native_ = false;
    num_elements_ = 0;
    input_shape_.clear();
    transpose_shape_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

  void DestroyResource() noexcept override {
    if (!use_native_) {
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
      CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyLRNDescriptor(norm_desc_), "Destroy LRN norm desc failed");
    }
  }

 protected:
  void InitResource() override {
    if (!use_native_) {
      handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateLRNDescriptor(&norm_desc_), "Create LRN norm desc failed");
    }
  }

  void InitSizeLists() override {
    if (!is_null_input_) {
      if (use_native_) {
        input_size_ = num_elements_ * sizeof(T);
        output_size_ = num_elements_ * sizeof(T);
        size_t shape_size = 4 * sizeof(size_t);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(shape_size);
        workspace_size_list_.push_back(input_size_);
        workspace_size_list_.push_back(input_size_);
        workspace_size_list_.push_back(input_size_);
        workspace_size_list_.push_back(input_size_);
        workspace_size_list_.push_back(num_elements_ * sizeof(float));
      } else {
        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_desc_, &input_size_),
                                    "Get input dy size failed");
        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size_),
                                    "Get input x size failed");
        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &input_size_),
                                    "Get input y size failed");
        CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dx_desc_, &output_size_),
                                    "Get output dx size failed");
      }
    }
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
  }

 private:
  bool CheckParam(const CNodePtr &kernel_node) {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 3) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but LocalResponseNormGradGpuKernel needs 3 inputs.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but LocalResponseNormGradGpuKernel needs 1 output.";
      return false;
    }
    return true;
  }

  void SetCUDNNDescriptors(const std::vector<size_t> &shape, int lrnN, double lrnAlpha) {
    int batch = SizeToInt(shape[0]);
    int channel = SizeToInt(shape[1]);
    int height = SizeToInt(shape[2]);
    int width = SizeToInt(shape[3]);

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
      "Set dy desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
      "Set x desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
      "Set y desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(
      kernel_node_,
      cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
      "Set dx desc failed");

    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetLRNDescriptor(norm_desc_, lrnN, lrnAlpha, beta_, bias_),
                                "cudnnSetLRNDescriptor failed");
  }

  size_t input_size_;
  size_t output_size_;
  bool is_null_input_;
  cudnnTensorDescriptor_t dy_desc_;
  cudnnTensorDescriptor_t x_desc_;
  cudnnTensorDescriptor_t y_desc_;
  cudnnTensorDescriptor_t dx_desc_;
  cudnnLRNDescriptor_t norm_desc_;
  cudnnLRNMode_t lrn_mode_;
  cudnnHandle_t handle_;
  cudnnDataType_t cudnn_data_type_;
  int depth_radius_;
  double bias_;
  double alpha_;
  double beta_;
  bool use_native_;
  size_t num_elements_;
  std::vector<size_t> input_shape_;
  std::vector<size_t> transpose_shape_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_LOCAL_RESPONSE_NORM_GRAD_GPU_KERNEL_H_
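For reference, a NumPy sketch (the name lrn_grad_nhwc_reference is illustrative, not part of the patch) of the math that LocalResponseNormGradNHWC applies to the NHWC-transposed dy, x and y: dx = dy * scale^(-beta) - 2 * alpha * beta * x * (windowed sum of dy * y / scale).

import numpy as np

def lrn_grad_nhwc_reference(dy, x, y, depth_radius, bias, alpha, beta):
    # dy, x, y: float arrays of shape (N, H, W, C); y is the LRN forward output of x.
    c = x.shape[-1]
    pad = ((0, 0), (0, 0), (0, 0), (depth_radius, depth_radius))
    # scale = bias + alpha * windowed sum of squares, recomputed as in ComputeScaleNHWC
    sq = np.pad(np.square(x), pad)
    scale = bias + alpha * sum(sq[..., i:i + c] for i in range(2 * depth_radius + 1))
    # ratio_sum = windowed sum of dy * y / scale, as in LocalResponseNormGradNHWC
    ratio = np.pad(dy * y / scale, pad)
    ratio_sum = sum(ratio[..., i:i + c] for i in range(2 * depth_radius + 1))
    return dy * np.power(scale, -beta) + x * (-2.0 * alpha * beta) * ratio_sum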
@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C


class MSLRNOpNet(nn.Cell):
    def __init__(self):
        super(MSLRNOpNet, self).__init__()
        self.lrn1 = P.LRN(depth_radius=2, bias=1.0, alpha=0.0001, beta=0.75)

    def construct(self, x):
        x = self.lrn1(x)
        return x


class MSGradNet(nn.Cell):
    def __init__(self, network):
        super(MSGradNet, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, x, dy):
        grad_op = self.grad(self.network, self.params)
        output = grad_op(x, dy)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lrn_ms():
    x = Tensor(np.array([[[[1.6243454, -0.6117564],
                           [-0.5281718, -1.0729686]],
                          [[0.86540765, -2.3015387],
                           [1.7448118, -0.7612069]],
                          [[0.3190391, -0.24937038],
                           [1.4621079, -2.0601406]]]]).astype(np.float32))
    y_exp = np.array([[[[1.6239204, -0.61149347],
                        [-0.5279556, -1.0724881]],
                       [[0.86518127, -2.3005495],
                        [1.7440975, -0.760866]],
                       [[0.31895563, -0.2492632],
                        [1.4615093, -2.059218]]]]).astype(np.float32)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = MSLRNOpNet()
    output = net(x)
    assert np.allclose(output.asnumpy(), y_exp)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = MSLRNOpNet()
    output = net(x)
    assert np.allclose(output.asnumpy(), y_exp)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_lrn_grad():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[[[1.6243454, -0.6117564],
                           [-0.5281718, -1.0729686]],
                          [[0.86540765, -2.3015387],
                           [1.7448118, -0.7612069]],
                          [[0.3190391, -0.24937038],
                           [1.4621079, -2.0601406]]]]).astype(np.float32))
    dy = Tensor(np.array([[[[-0.3224172, -0.38405436],
                            [1.1337694, -1.0998913]],
                           [[-0.1724282, -0.8778584],
                            [0.04221375, 0.58281523]],
                           [[-1.1006192, 1.1447237],
                            [0.9015907, 0.50249434]]]]).astype(np.float32))
    dx_exp = np.array([[[[-0.3220835, -0.3837087],
                         [1.133368, -1.0994467]],
                        [[-0.17225023, -0.8768017],
                         [0.04198911, 0.5825201]],
                        [[-1.1002823, 1.1443052],
                         [0.9010479, 0.50217706]]]]).astype(np.float32)
    net = MSLRNOpNet()
    grad_net = MSGradNet(net)
    grad_net.set_train(True)
    output = grad_net(x, dy)
    dx = output[0][0].asnumpy()
    assert np.allclose(dx, dx_exp, atol=1.0e-4, rtol=1.0e-4, equal_nan=True)