!35758 update InstanceNormGrad GPU kernel

Merge pull request !35758 from zhujingxuan/InstanceNormGrad
i-robot 2022-06-11 08:05:44 +00:00 committed by Gitee
commit bf41f2b534
10 changed files with 526 additions and 259 deletions

View File

@@ -31,10 +31,12 @@ bool InstanceNormGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
kernel_name_ = base_operator->name();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&x_desc_),
"For 'InstanceNormGpuKernelMod', it create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&y_desc_),
"For 'InstanceNormGpuKernelMod', it create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
"Create para desc failed");
"For 'InstanceNormGpuKernelMod', it create para desc failed");
auto kernel_ptr = std::dynamic_pointer_cast<ops::InstanceNorm>(base_operator);
epsilon_ = kernel_ptr->get_epsilon();
@@ -72,24 +74,24 @@ int InstanceNormGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
"For 'InstanceNormGpuKernelMod', it set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"Set y desc failed");
"For 'InstanceNormGpuKernelMod', it set y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
"For 'InstanceNormGpuKernelMod', it set para desc failed");
size_t para_size = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size),
"Get para size failed");
"For 'InstanceNormGpuKernelMod', it get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, z_desc_, y_desc_,
scale_bias_mean_var_desc_, nullptr, &workspace_size_),
"cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
"For 'InstanceNormGpuKernelMod', it launch cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
workspace_size_list_.clear();
workspace_size_list_ = {
@@ -134,7 +136,7 @@ bool InstanceNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &input
handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, scale_bias_mean_var_desc_,
ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, save_variance_addr, nullptr,
workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
"For 'InstanceNormGpuKernelMod', it launch cudnnBatchNormalizationForwardTrainingEx failed");
return true;
}

View File

@@ -33,10 +33,12 @@ class InstanceNormGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHe
public:
InstanceNormGpuKernelMod() = default;
~InstanceNormGpuKernelMod() override {
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(x_desc_),
"For 'InstanceNormGpuKernelMod', it destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(y_desc_),
"For 'InstanceNormGpuKernelMod', it destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
"Destroy para desc failed");
"For 'InstanceNormGpuKernelMod', it destroy para desc failed");
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,

View File

@@ -15,30 +15,163 @@
*/
#include "plugin/device/gpu/kernel/nn/instance_norm_grad_gpu_kernel.h"
#include <map>
#include <utility>
#include "mindspore/core/ops/instance_norm_grad.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
InstanceNormGradGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
InstanceNormGradGpuKernelMod, half)
namespace {
using KernelRunFunc = InstanceNormGradGpuKernelMod::KernelRunFunc;
constexpr auto kNCDims = 2;
} // namespace
bool InstanceNormGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&x_desc_),
"For 'InstanceNormGradGpuKernelMod', it create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&dy_desc_),
"For 'InstanceNormGradGpuKernelMod', it create dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&dx_desc_),
"For 'InstanceNormGradGpuKernelMod', it create dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&scale_bias_diff_desc_),
"For 'InstanceNormGradGpuKernelMod', it create para desc failed");
auto kernel_ptr = std::dynamic_pointer_cast<ops::InstanceNormGrad>(base_operator);
epsilon_ = kernel_ptr->get_epsilon();
beta_data_diff_ = kernel_ptr->get_inplace_algo() == "cover" ? 0 : 1;
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs.at(kIndex0)->GetDtype()));
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int InstanceNormGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input_x");
if (is_null_input_) {
return KRET_OK;
}
batch_ = input_shape[kIndex0];
channel_ = input_shape[kIndex1];
CheckTensorSize({input_shape});
int batch = 1;
int channel = SizeToInt(batch_) * SizeToInt(channel_);
int height = 1;
int width = std::accumulate(input_shape.begin() + kNCDims, input_shape.end(), 1, std::multiplies{});
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"For 'InstanceNormGradGpuKernelMod', it set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"For 'InstanceNormGradGpuKernelMod', it set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"For 'InstanceNormGradGpuKernelMod', it set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"For 'InstanceNormGradGpuKernelMod', it set para desc failed");
size_t para_size = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(scale_bias_diff_desc_, &para_size),
"For 'InstanceNormGradGpuKernelMod', it get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnGetBatchNormalizationBackwardExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, y_desc_, dy_desc_, dz_desc_,
dx_desc_, scale_bias_diff_desc_, activation_desc_,
&workspace_size_),
"For 'InstanceNormGradGpuKernelMod', it launch cudnnGetBatchNormalizationBackwardExWorkspaceSize failed");
workspace_size_list_.clear();
workspace_size_list_ = {
para_size, // ws gamma
para_size, // ws dgamma
para_size, // ws dbeta
workspace_size_,
};
return KRET_OK;
}
template <typename T>
bool InstanceNormGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto dy = GetDeviceAddress<T>(inputs, kIndex0);
auto x = GetDeviceAddress<T>(inputs, kIndex1);
auto gamma = GetDeviceAddress<float>(inputs, kIndex2);
auto save_mean = GetDeviceAddress<float>(inputs, kIndex3);
auto save_variance = GetDeviceAddress<float>(inputs, kIndex4);
void *beta = nullptr;
T *y = nullptr;
auto dx = GetDeviceAddress<T>(outputs, kIndex0);
auto dgamma = GetDeviceAddress<float>(outputs, kIndex1);
auto dbeta = GetDeviceAddress<float>(outputs, kIndex2);
T *dz = nullptr;
float *ws_gamma = GetDeviceAddress<float>(workspace, kIndex0);
float *ws_dgamma = GetDeviceAddress<float>(workspace, kIndex1);
float *ws_dbeta = GetDeviceAddress<float>(workspace, kIndex2);
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, kIndex3);
CopyMemDevice2Device(batch_, channel_, gamma, nullptr, nullptr, nullptr, ws_gamma, nullptr, nullptr, nullptr,
stream_ptr_);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream_ptr_),
"For 'InstanceNormGradGpuKernelMod', it launch cudaStreamSynchronized failed");
const float alpha_data_diff = 1;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
float *reserve_addr = nullptr;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta,
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"For 'InstanceNormGradGpuKernelMod', it launch cudnnBatchNormalizationBackwardEx failed");
ComputeMean(batch_, channel_, dgamma, dbeta, ws_dgamma, ws_dbeta, stream_ptr_);
return true;
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &InstanceNormGradGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
&InstanceNormGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
&InstanceNormGradGpuKernelMod::LaunchKernel<half>},
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InstanceNormGrad, InstanceNormGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore
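For reference, Resize above implements instance norm on top of cuDNN's batch norm by folding dimensions: the (N, C, d1, ..., dk) input is described to cuDNN as (1, N*C, 1, d1*...*dk), so the "per channel" statistics cuDNN computes are really per (sample, channel) pair. A minimal numpy sketch of that equivalence (editorial illustration; instance_norm_ref is a hypothetical name, not part of this commit):

import numpy as np

def instance_norm_ref(x, eps=1e-5):
    # Fold (N, C, ...) into (1, N*C, spatial), exactly as the kernel's Resize does.
    n, c = x.shape[:2]
    folded = x.reshape(1, n * c, -1)
    mean = folded.mean(axis=-1, keepdims=True)   # per (sample, channel) mean
    var = folded.var(axis=-1, keepdims=True)     # per (sample, channel) variance
    return ((folded - mean) / np.sqrt(var + eps)).reshape(x.shape)

x = np.random.randn(8, 4, 3, 4).astype(np.float32)
mu = x.reshape(8, 4, -1).mean(axis=-1)[..., None, None]
var = x.reshape(8, 4, -1).var(axis=-1)[..., None, None]
assert np.allclose(instance_norm_ref(x), (x - mu) / np.sqrt(var + 1e-5), atol=1e-5)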

View File

@@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
#include <map>
#include <utility>
#include <string>
#include <vector>
#include "include/common/utils/utils.h"
@@ -28,195 +30,67 @@
namespace mindspore {
namespace kernel {
template <typename T>
class InstanceNormGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class InstanceNormGradGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<InstanceNormGradGpuKernelMod> {
public:
InstanceNormGradGpuKernelMod()
: x_size_(0),
para_size_(0),
workspace_size_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
dy_desc_(nullptr),
dx_desc_(nullptr),
dz_desc_(nullptr),
scale_bias_diff_desc_(nullptr),
activation_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
beta_data_diff_(0) {}
~InstanceNormGradGpuKernelMod() override { DestroyResource(); }
InstanceNormGradGpuKernelMod() = default;
~InstanceNormGradGpuKernelMod() override {
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(x_desc_),
"For 'InstanceNormGradGpuKernelMod', it destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(dy_desc_),
"For 'InstanceNormGradGpuKernelMod', it destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(dx_desc_),
"For 'InstanceNormGradGpuKernelMod', it destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(scale_bias_diff_desc_),
"For 'InstanceNormGradGpuKernelMod', it destroy para desc failed");
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
auto dy = GetDeviceAddress<T>(inputs, 0);
auto x = GetDeviceAddress<T>(inputs, 1);
auto gamma = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
void *beta = nullptr;
T *y = nullptr;
auto dx = GetDeviceAddress<T>(outputs, 0);
auto dgamma = GetDeviceAddress<float>(outputs, 1);
auto dbeta = GetDeviceAddress<float>(outputs, 2);
T *dz = nullptr;
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
float *ws_dgamma = GetDeviceAddress<float>(workspace, 1);
float *ws_dbeta = GetDeviceAddress<float>(workspace, 2);
void *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 3);
size_t N = input_shape_[0];
size_t C = input_shape_[1];
CopyMemDevice2Device(N, C, gamma, nullptr, nullptr, nullptr, ws_gamma, nullptr, nullptr, nullptr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaStreamSynchronized failed");
const float alpha_data_diff = 1;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
float *reserve_addr = nullptr;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
scale_bias_diff_desc_, ws_gamma, beta, ws_dgamma, ws_dbeta, epsilon_, save_mean,
save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
ComputeMean(N, C, dgamma, dbeta, ws_dgamma, ws_dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
stream_ptr_ = reinterpret_cast<cudaStream_t>(stream_ptr);
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
InitResource();
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 5, but got " << input_num;
}
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape_.size() != 4) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input must be 4, but got "
<< input_shape_.size();
}
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
CheckTensorSize({input_shape_});
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
SetTensorDescriptor();
InitSizeLists();
return true;
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_diff_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size_), "Get x size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_diff_desc_, &para_size_),
"Get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle_, mode_, bn_ops_, x_desc_, y_desc_, dy_desc_, dz_desc_, dx_desc_,
scale_bias_diff_desc_, activation_desc_, &workspace_size_),
"cudnnGetBatchNormalizationBackwardExWorkspaceSize failed");
}
input_size_list_.push_back(x_size_);
input_size_list_.push_back(x_size_);
input_size_list_.push_back(input_shape_[1]);
input_size_list_.push_back(para_size_);
input_size_list_.push_back(para_size_);
output_size_list_.push_back(x_size_);
output_size_list_.push_back(x_size_);
output_size_list_.push_back(x_size_);
workspace_size_list_.push_back(para_size_); // ws gamma
workspace_size_list_.push_back(para_size_); // ws dgamma
workspace_size_list_.push_back(para_size_); // ws dbeta
workspace_size_list_.push_back(workspace_size_);
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_diff_desc_),
"Destroy para desc failed");
}
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void SetTensorDescriptor() {
int batch = 1;
int channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
int height = SizeToInt(input_shape_[2]);
int width = SizeToInt(input_shape_[3]);
cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
}
static constexpr cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
static constexpr cudnnBatchNormOps_t bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
size_t x_size_;
size_t para_size_;
size_t workspace_size_;
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
bool is_null_input_;
size_t batch_{0};
size_t channel_{0};
size_t workspace_size_{0};
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t dz_desc_;
cudnnTensorDescriptor_t scale_bias_diff_desc_;
cudnnActivationDescriptor_t activation_desc_;
double epsilon_{10e-5};
float beta_data_diff_{0};
bool is_null_input_{false};
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
float beta_data_diff_;
std::vector<size_t> input_shape_;
cudnnTensorDescriptor_t x_desc_{nullptr};
cudnnTensorDescriptor_t y_desc_{nullptr};
cudnnTensorDescriptor_t dy_desc_{nullptr};
cudnnTensorDescriptor_t dx_desc_{nullptr};
cudnnTensorDescriptor_t dz_desc_{nullptr};
cudnnTensorDescriptor_t scale_bias_diff_desc_{nullptr};
cudnnActivationDescriptor_t activation_desc_{nullptr};
cudnnHandle_t handle_{nullptr};
cudnnDataType_t cudnn_data_type_{CUDNN_DATA_FLOAT};
cudaStream_t stream_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -540,6 +540,7 @@ GVAR_DEF(PrimitivePtr, kPrimGroupConv2DGradInput, std::make_shared<Primitive>("G
GVAR_DEF(PrimitivePtr, kPrimBatchNorm, std::make_shared<Primitive>("BatchNorm"));
GVAR_DEF(PrimitivePtr, kPrimBatchNormGrad, std::make_shared<Primitive>("BatchNormGrad"));
GVAR_DEF(PrimitivePtr, kPrimInstanceNorm, std::make_shared<Primitive>("InstanceNorm"));
GVAR_DEF(PrimitivePtr, kPrimInstanceNormGrad, std::make_shared<Primitive>("InstanceNormGrad"));
GVAR_DEF(PrimitivePtr, kPrimSyncBatchNorm, std::make_shared<Primitive>("SyncBatchNorm"));
GVAR_DEF(PrimitivePtr, kPrimSyncBatchNormGrad, std::make_shared<Primitive>("SyncBatchNormGrad"));
GVAR_DEF(PrimitivePtr, kPrimBNTrainingReduce, std::make_shared<Primitive>("BNTrainingReduce"));

View File

@@ -0,0 +1,133 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/instance_norm_grad.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/primitive_c.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InstanceNormGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](auto arg) { return arg->BuildShape()->IsDynamic(); })) {
const auto x_shape_ptr = input_args[kInputIndex1]->BuildShape();
const auto gamma_shape_ptr = input_args[kInputIndex2]->BuildShape();
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{x_shape_ptr, gamma_shape_ptr, gamma_shape_ptr});
}
const auto prim_name = primitive->name();
const auto y_backprop_shape_ptr = input_args[kInputIndex0]->BuildShape();
const auto x_shape_ptr = input_args[kInputIndex1]->BuildShape();
const auto gamma_shape_ptr = input_args[kInputIndex2]->BuildShape();
const auto save_mean_shape_ptr = input_args[kInputIndex3]->BuildShape();
const auto save_variance_shape_ptr = input_args[kInputIndex4]->BuildShape();
auto y_backprop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(y_backprop_shape_ptr)[kShape];
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(x_shape_ptr)[kShape];
auto gamma_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(gamma_shape_ptr)[kShape];
auto save_mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(save_mean_shape_ptr)[kShape];
auto save_variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(save_variance_shape_ptr)[kShape];
(void)CheckAndConvertUtils::CheckValue<size_t>("x rank", x_shape.size(), kEqual, "y_backprop rank",
y_backprop_shape.size(), prim_name);
constexpr size_t minimum_input_x_rank = 3;
(void)CheckAndConvertUtils::CheckValue<size_t>("x rank", x_shape.size(), kGreaterEqual, minimum_input_x_rank,
prim_name);
(void)CheckAndConvertUtils::Check("x shape", x_shape, kEqual, y_backprop_shape, prim_name);
const size_t batch = x_shape[kInputIndex0];
const size_t channel = x_shape[kInputIndex1];
const size_t batch_channel = batch * channel;
(void)CheckAndConvertUtils::CheckValue<size_t>("gamma rank", gamma_shape.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("save_mean rank", save_mean_shape.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("save_variance rank", save_variance_shape.size(), kEqual, 1,
prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("gamma shape", gamma_shape[0], kEqual, "(C, )", channel, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("save_mean shape", save_mean_shape[0], kEqual, "(B*C, )",
batch_channel, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("save_variance shape", save_variance_shape[0], kEqual, "(B*C, )",
batch_channel, prim_name);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{x_shape_ptr, gamma_shape_ptr, gamma_shape_ptr});
}
TuplePtr InstanceNormGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
const auto prim_name = primitive->name();
const auto y_backprop_type = input_args[kInputIndex0]->BuildType();
const auto x_type = input_args[kInputIndex1]->BuildType();
const auto gamma_type = input_args[kInputIndex2]->BuildType();
const auto save_mean_type = input_args[kInputIndex3]->BuildType();
const auto save_variance_type = input_args[kInputIndex4]->BuildType();
const std::map<std::string, TypePtr> types = {
{"y_backprop", y_backprop_type},
{"x", x_type},
};
const auto type = CheckAndConvertUtils::CheckTensorTypeSame(types, {kFloat16, kFloat32}, prim_name);
const std::map<std::string, TypePtr> grad_types = {
{"gamma", gamma_type},
{"save_mean", save_mean_type},
{"save_variance", save_variance_type},
};
const auto grad_type = CheckAndConvertUtils::CheckTensorTypeSame(grad_types, {kFloat32}, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{type, grad_type, grad_type});
}
} // namespace
MIND_API_OPERATOR_IMPL(InstanceNormGrad, BaseOperator);
void InstanceNormGrad::Init(const float epsilon) { this->set_epsilon(epsilon); }
void InstanceNormGrad::set_epsilon(const float epsilon) { (void)this->AddAttr(kEpsilon, api::MakeValue(epsilon)); }
float InstanceNormGrad::get_epsilon() const {
auto value_ptr = GetAttr(kEpsilon);
return GetValue<float>(value_ptr);
}
void InstanceNormGrad::set_inplace_algo(const std::string inplace_algo = "cover") {
(void)this->AddAttr(kInplaceAlgo, api::MakeValue(inplace_algo));
}
std::string InstanceNormGrad::get_inplace_algo() const {
auto value_ptr = GetAttr(kInplaceAlgo);
if (value_ptr == nullptr) {
return "cover";
}
return GetValue<std::string>(value_ptr);
}
AbstractBasePtr InstanceNormGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr int64_t kInputNum = 5;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
auto type = InstanceNormGradInferType(primitive, input_args);
auto shape = InstanceNormGradInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(InstanceNormGrad, prim::kPrimInstanceNormGrad, InstanceNormGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
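The infer routine above pins down the operator's shape contract. Restated as a small hypothetical Python helper (editorial, not part of the commit):

def instance_norm_grad_shapes(x_shape):
    # Mirrors InstanceNormGradInferShape: x rank >= 3, gamma is (C,),
    # save_mean / save_variance are (B*C,), outputs are dx, dgamma, dbeta.
    assert len(x_shape) >= 3
    batch, channel = x_shape[0], x_shape[1]
    inputs = {
        "y_backprop": tuple(x_shape),        # same shape as x
        "x": tuple(x_shape),
        "gamma": (channel,),                 # (C,)
        "save_mean": (batch * channel,),     # (B*C,)
        "save_variance": (batch * channel,), # (B*C,)
    }
    outputs = (tuple(x_shape), (channel,), (channel,))
    return inputs, outputs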

View File

@@ -0,0 +1,64 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_INSTANCE_NORM_GRAD_H_
#define MINDSPORE_CORE_OPS_INSTANCE_NORM_GRAD_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameInstanceNormGrad = "InstanceNormGrad";
/// \brief InstanceNormGrad defines the InstanceNormGrad operator prototype.
class MIND_API InstanceNormGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(InstanceNormGrad);
/// \brief Constructor.
InstanceNormGrad() : BaseOperator(kNameInstanceNormGrad) {}
/// \brief Method to init the op's attributes
///
/// \param[in] epsilon Define a value added to the denominator for numerical stability.
void Init(const float epsilon = 0.00001);
/// \brief Method to set epsilon attribute.
///
/// \param[in] epsilon Define a value added to the denominator for numerical stability.
void set_epsilon(const float epsilon);
/// \brief Method to get epsilon attribute.
///
/// \return the epsilon value.
float get_epsilon() const;
/// \brief Method to set inplace_algo attribute.
///
/// \param[in] inplace_algo Define how the data gradient is written back: "cover" overwrites it, any other value accumulates into it.
void set_inplace_algo(const std::string inplace_algo);
/// \brief Method to get inplace_algo attribute.
///
/// \return the inplace_algo value.
std::string get_inplace_algo() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_INSTANCE_NORM_GRAD_H_

View File

@@ -286,6 +286,7 @@ constexpr auto kNumGroups = "num_groups";
constexpr auto kIndexing = "indexing";
constexpr auto kModulated = "modulated";
constexpr auto kAdjoint = "adjoint";
constexpr auto kInplaceAlgo = "inplace_algo";
enum Index : size_t {
kInputIndex0 = 0,

View File

@@ -613,12 +613,6 @@ class InstanceNormGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
outputs=['dx', 'bn_gamma', 'bn_beta'])
def infer_shape(self, y_backprop_shape, x_shape, gamma_shape, save_mean_shape, save_variance_shape):
return (x_shape, gamma_shape, gamma_shape)
def infer_dtype(self, y_backprop_type, x_type, gamma_type, save_mean_type, save_variance_type):
return (x_type, gamma_type, gamma_type)
class EinsumGrad(PrimitiveWithInfer):
"""Gradients of Einsum."""

View File

@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,96 +13,159 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
from mindspore import Tensor, Parameter, nn
import numpy as np
from mindspore import Tensor, nn, context, Parameter, ms_function
from mindspore.ops.composite import GradOperation
from mindspore.ops import functional as F
from mindspore.ops.operations.nn_ops import InstanceNorm
from mindspore.common.initializer import initializer
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class InstanceNormNet(nn.Cell):
def __init__(self, channel, epsilon=1e-5):
super(InstanceNormNet, self).__init__()
self.instance_norm = InstanceNorm(epsilon=epsilon)
self.gamma = Parameter(Tensor(np.ones([channel]), mindspore.float32), name="gamma")
self.beta = Parameter(Tensor(np.zeros([channel]), mindspore.float32), name="beta")
self.mean = Parameter(Tensor(np.zeros([channel]), mindspore.float32), name="mean")
self.variance = Parameter(Tensor(np.ones([channel]), mindspore.float32), name="variance")
class InstanceNormNd(nn.Cell):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, gamma_init='ones', beta_init='zeros'):
super(InstanceNormNd, self).__init__()
self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
self.gamma = Parameter(initializer(gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(beta_init, num_features), name="beta", requires_grad=affine)
self.instance_bn = InstanceNorm(epsilon=eps, momentum=momentum)
def construct(self, input_x):
out = self.instance_norm(input_x, self.gamma, self.beta, self.mean, self.variance)
return out[0]
def construct(self, x):
return self.instance_bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)[0]
def instance_norm_np(x, eps=1e-5):
shape = x.shape
b = shape[0]
c = shape[1]
x = x.reshape((b, c, -1))
mu = np.expand_dims(np.mean(x, axis=-1), axis=-1)
std = np.expand_dims(np.std(x, axis=-1), axis=-1)
result = (x - mu) / (std + eps)
return result.reshape(shape)
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network
@ms_function
def construct(self, input_x, grad):
return self.grad(self.network)(input_x, grad)
class Expected1d(nn.Cell):
def __init__(self, n, gamma_init=0.5, beta_init=0.5):
super(Expected1d, self).__init__()
self.ops = nn.BatchNorm2d(n, use_batch_statistics=True, gamma_init=gamma_init, beta_init=beta_init)
def construct(self, x):
shape = F.shape(x)
return F.reshape(self.ops(F.reshape(x, (1, -1, 1, shape[2]))), shape)
class Expected2d(nn.Cell):
def __init__(self, n, gamma_init=0.5, beta_init=0.5):
super(Expected2d, self).__init__()
self.ops = nn.BatchNorm2d(n, use_batch_statistics=True, gamma_init=gamma_init, beta_init=beta_init)
def construct(self, x):
shape = F.shape(x)
return F.reshape(self.ops(F.reshape(x, (1, -1, shape[2], shape[3]))), shape)
class Expected3d(nn.Cell):
def __init__(self, n, gamma_init=0.5, beta_init=0.5):
super(Expected3d, self).__init__()
self.ops = nn.BatchNorm3d(n, use_batch_statistics=True, gamma_init=gamma_init, beta_init=beta_init)
def construct(self, x):
shape = F.shape(x)
return F.reshape(self.ops(F.reshape(x, (1, -1, shape[2], shape[3], shape[4]))), shape)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("shape", [(8, 4, 5)])
@pytest.mark.parametrize("data_type, err", [(np.float16, 1e-3), (np.float32, 1e-4)])
def test_instancenorm_1d(shape, data_type, err):
@pytest.mark.parametrize("data_type", [np.float16, np.float32])
def test_instancenorm_1d(shape, data_type):
"""
Feature: InstanceNorm 1D operator.
Description: Compare forward and backward results with an equivalent BatchNorm on the reshaped input.
Expectation: The results match the BatchNorm-based reference.
"""
np.random.seed(0)
input_x_np = np.random.randn(np.prod(shape)).reshape(shape).astype(data_type)
input_x = Tensor(input_x_np)
net = InstanceNormNet(shape[1])
output = net(input_x)
expected = instance_norm_np(input_x_np)
assert np.allclose(output.asnumpy(), expected, atol=err, rtol=err)
x_np = Tensor(np.random.randn(*shape).astype(data_type))
grad = Tensor(np.random.randn(*shape).astype(data_type))
instance_op = InstanceNormNd(shape[1], gamma_init=0.5, beta_init=0.5)
expected_net = Expected1d(shape[0] * shape[1], gamma_init=0.5, beta_init=0.5)
result = instance_op(Tensor(x_np))
expected = expected_net(Tensor(x_np))
assert np.allclose(result.asnumpy(), expected.asnumpy())
instance_backward_net = Grad(instance_op)
expected_backward_net = Grad(expected_net)
result = instance_backward_net(Tensor(x_np), Tensor(grad))
expected = expected_backward_net(Tensor(x_np), Tensor(grad))
assert np.allclose(result[0].asnumpy(), expected[0].asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("shape", [(8, 4, 3, 4)])
@pytest.mark.parametrize("data_type, err", [(np.float16, 1e-3), (np.float32, 1e-4)])
def test_instancenorm_2d(shape, data_type, err):
@pytest.mark.parametrize("data_type", [np.float16, np.float32])
def test_instancenorm_2d(shape, data_type):
"""
Feature: InstanceNorm 2D operator.
Description: Compare forward and backward results with an equivalent BatchNorm on the reshaped input.
Expectation: The results match the BatchNorm-based reference.
"""
np.random.seed(0)
input_x_np = np.random.randn(np.prod(shape)).reshape(shape).astype(data_type)
input_x = Tensor(input_x_np)
net = InstanceNormNet(shape[1])
output = net(input_x)
expected = instance_norm_np(input_x_np)
assert np.allclose(output.asnumpy(), expected, atol=err, rtol=err)
x_np = Tensor(np.random.randn(*shape).astype(data_type))
grad = Tensor(np.random.randn(*shape).astype(data_type))
instance_op = InstanceNormNd(shape[1], gamma_init=0.5, beta_init=0.5)
expected_net = Expected2d(shape[0] * shape[1], gamma_init=0.5, beta_init=0.5)
result = instance_op(Tensor(x_np))
expected = expected_net(Tensor(x_np))
assert np.allclose(result.asnumpy(), expected.asnumpy())
instance_backward_net = Grad(instance_op)
expected_backward_net = Grad(expected_net)
result = instance_backward_net(Tensor(x_np), Tensor(grad))
expected = expected_backward_net(Tensor(x_np), Tensor(grad))
assert np.allclose(result[0].asnumpy(), expected[0].asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("shape", [(8, 4, 3, 4, 7)])
@pytest.mark.parametrize("data_type, err", [(np.float16, 1e-3), (np.float32, 1e-4)])
def test_instancenorm_3d(shape, data_type, err):
@pytest.mark.parametrize("data_type", [np.float16, np.float32])
def test_instancenorm_3d(shape, data_type):
"""
Feature: InstanceNorm 3D operator.
Description: Compare forward and backward results with an equivalent BatchNorm on the reshaped input.
Expectation: The results match the BatchNorm-based reference.
"""
np.random.seed(0)
input_x_np = np.random.randn(np.prod(shape)).reshape(shape).astype(data_type)
input_x = Tensor(input_x_np)
net = InstanceNormNet(shape[1])
output = net(input_x)
expected = instance_norm_np(input_x_np)
assert np.allclose(output.asnumpy(), expected, atol=err, rtol=err)
x_np = Tensor(np.random.randn(*shape).astype(data_type))
grad = Tensor(np.random.randn(*shape).astype(data_type))
instance_op = InstanceNormNd(shape[1], gamma_init=0.5, beta_init=0.5)
expected_net = Expected3d(shape[0] * shape[1], gamma_init=0.5, beta_init=0.5)
result = instance_op(Tensor(x_np))
expected = expected_net(Tensor(x_np))
assert np.allclose(result.asnumpy(), expected.asnumpy())
instance_backward_net = Grad(instance_op)
expected_backward_net = Grad(expected_net)
result = instance_backward_net(Tensor(x_np), Tensor(grad))
expected = expected_backward_net(Tensor(x_np), Tensor(grad))
assert np.allclose(result[0].asnumpy(), expected[0].asnumpy())
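For reference, the identity the Expected1d/2d/3d cells rely on, stated editorially in LaTeX: instance norm over (N, C, ...) is batch norm over the reshaped (1, N*C, ...) view,

\mathrm{IN}(x)_{n,c,\ell} = \gamma_c \, \frac{x_{n,c,\ell} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} + \beta_c,
\qquad
\mu_{n,c} = \frac{1}{L} \sum_{\ell=1}^{L} x_{n,c,\ell},
\quad
\sigma^2_{n,c} = \frac{1}{L} \sum_{\ell=1}^{L} \left(x_{n,c,\ell} - \mu_{n,c}\right)^2,

since treating each (n, c) pair as one of N*C channels of a single-sample batch yields exactly these statistics; that is why the forward and backward outputs of the two networks can be compared with allclose.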