add new op instancenorm2d

zhouyuanshen 2020-12-31 14:35:45 +08:00
parent e805d06499
commit 26f6daa850
13 changed files with 1010 additions and 4 deletions

View File

@@ -0,0 +1,90 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
__global__ void CopyMemKernel(const size_t thread_num, const size_t N, const size_t C,
float *gamma_addr, float *beta_addr,
float *running_mean_addr, float *running_variance_addr,
float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
size_t cur_addr = pos / (N * C);
size_t cur_local_index = pos % (N * C);
size_t local_index = 0;
switch (cur_addr) {
case 0:
if (!(gamma_addr && ws_gamma)) break;
local_index = cur_local_index % C;
ws_gamma[cur_local_index] = gamma_addr[local_index];
break;
case 1:
if (!(beta_addr && ws_beta)) break;
local_index = cur_local_index % C;
ws_beta[cur_local_index] = beta_addr[local_index];
break;
case 2:
if (!(running_mean_addr && ws_mean)) break;
local_index = cur_local_index % C;
ws_mean[cur_local_index] = running_mean_addr[local_index];
break;
default:
if (!(running_variance_addr && ws_var)) break;
local_index = cur_local_index % C;
ws_var[cur_local_index] = running_variance_addr[local_index];
}
}
}
void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, float *beta_addr,
float *running_mean_addr, float *running_variance_addr,
float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
cudaStream_t cuda_stream) {
size_t thread_num = N * C * 4;
CopyMemKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
thread_num, N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr,
ws_gamma, ws_beta, ws_mean, ws_var);
}
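// Average per-instance statistics over the batch dimension: the first C threads reduce
// save_mean and the next C threads reduce save_var, writing the per-channel result into
// the first C elements of each buffer.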
__global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
float *save_mean_addr, float *save_var_addr) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
size_t cur_addr = pos / C;
size_t cur_local_index = pos % C;
float tmp = 0;
if (cur_addr) {
for (size_t i = 0; i < N; i++) {
tmp += save_var_addr[i * C + cur_local_index];
}
save_var_addr[cur_local_index] = tmp / N;
} else {
for (size_t i = 0; i < N; i++) {
tmp += save_mean_addr[i * C + cur_local_index];
}
save_mean_addr[cur_local_index] = tmp / N;
}
}
}
void ComputeMean(const size_t N, const size_t C,
float *save_mean_addr, float *save_var_addr,
cudaStream_t cuda_stream) {
size_t thread_num = C * 2;
ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
thread_num, N, C, save_mean_addr, save_var_addr);
}

View File

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
void CopyMemDevice2Device(const size_t N, const size_t C,
float *gamma_addr, float *beta_addr, float *running_mean_addr, float *running_variance_addr,
float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
cudaStream_t cuda_stream);
void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(InstanceNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
InstanceNormGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(InstanceNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
InstanceNormGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,240 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/utils.h"
#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class InstanceNormGpuKernel : public GpuKernel {
public:
InstanceNormGpuKernel()
: input_x_size_(0),
input_z_size_(0),
para_size_(0),
output_size_(0),
workspace_size_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
is_training_(true),
epsilon_(10e-5),
exp_avg_factor_(0.1),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
z_desc_(nullptr),
scale_bias_mean_var_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
~InstanceNormGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
auto x_addr = GetDeviceAddress<T>(inputs, 0);
auto gamma_addr = GetDeviceAddress<float>(inputs, 1);
auto beta_addr = GetDeviceAddress<float>(inputs, 2);
auto running_mean_addr = GetDeviceAddress<float>(inputs, 3);
auto running_variance_addr = GetDeviceAddress<float>(inputs, 4);
T *z = nullptr;
auto y_addr = GetDeviceAddress<T>(outputs, 0);
auto save_mean_addr = GetDeviceAddress<float>(outputs, 1);
auto save_variance_addr = GetDeviceAddress<float>(outputs, 2);
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
float *ws_beta = GetDeviceAddress<float>(workspace, 1);
float *ws_mean = GetDeviceAddress<float>(workspace, 2);
float *ws_var = GetDeviceAddress<float>(workspace, 3);
T *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 4);
}
size_t N = input_shape_[0];
size_t C = input_shape_[1];
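// Expand the per-channel parameters to per-(n, c) workspace buffers so they match the
// (1, N * C, H, W) descriptors set up in SetTensorDescriptor.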
CopyMemDevice2Device(N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr, ws_gamma, ws_beta,
ws_mean, ws_var, reinterpret_cast<cudaStream_t>(stream_ptr));
const float alpha = 1;
const float beta = 0;
float *reserve_addr = nullptr;
if (is_training_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationForwardTrainingEx(
handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr,
scale_bias_mean_var_desc_, ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr,
save_variance_addr, nullptr, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnBatchNormalizationForwardInference(
handle_, mode_, &alpha, &beta, x_desc_, x_addr, y_desc_, y_addr,
scale_bias_mean_var_desc_, ws_gamma, ws_beta, ws_mean, ws_var, epsilon_),
"Kernel launch failed");
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
InitResource();
is_training_ = GetAttr<bool>(kernel_node, "is_training");
mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 5";
}
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape_.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << input_shape_.size() << ", InstanceNormGpuKernel should be 4";
}
is_null_input_ = CHECK_NULL_INPUT(input_shape_);
if (is_null_input_) {
MS_LOG(WARNING) << "InstanceNormGpuKernel input is null";
InitSizeLists();
return true;
}
SetTensorDescriptor();
InitSizeLists();
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
"Destroy para desc failed");
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_x_size_),
"Get input x size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size_),
"Get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size_),
"Get output size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, z_desc_, y_desc_,
scale_bias_mean_var_desc_, nullptr, &workspace_size_),
"cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
}
input_size_list_.push_back(input_x_size_); // input x
input_size_list_.push_back(input_shape_[1] * sizeof(float)); // gamma
input_size_list_.push_back(input_shape_[1] * sizeof(float)); // beta
input_size_list_.push_back(input_shape_[1] * sizeof(float)); // mean
input_size_list_.push_back(input_shape_[1] * sizeof(float)); // variance
output_size_list_.push_back(output_size_); // output
output_size_list_.push_back(para_size_); // save mean
output_size_list_.push_back(para_size_); // save variance
workspace_size_list_.push_back(para_size_); // ws gamma
workspace_size_list_.push_back(para_size_); // ws beta
workspace_size_list_.push_back(para_size_); // ws mean
workspace_size_list_.push_back(para_size_); // ws variance
workspace_size_list_.push_back(workspace_size_);
}
private:
void SetTensorDescriptor() {
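// Fold the batch dimension into the channel dimension: instance norm over (N, C, H, W)
// is computed as batch norm over (1, N * C, H, W), so each (n, c) feature map is
// normalized independently.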
cudnnTensorFormat_t cudnn_format;
int batch, channel, height, width;
batch = 1;
channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
height = SizeToInt(input_shape_[2]);
width = SizeToInt(input_shape_[3]);
cudnn_format = CUDNN_TENSOR_NCHW;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
}
size_t input_x_size_;
size_t input_z_size_;
size_t para_size_;
size_t output_size_;
size_t workspace_size_;
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
bool is_training_;
double epsilon_;
double exp_avg_factor_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;
cudnnTensorDescriptor_t z_desc_;
cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
std::vector<size_t> input_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
InstanceNormGradGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
InstanceNormGradGpuKernel, half)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,238 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
#include <string>
#include <vector>
#include "utils/utils.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class InstanceNormGradGpuKernel : public GpuKernel {
public:
InstanceNormGradGpuKernel()
: x_size_(0),
para_size_(0),
workspace_size_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
is_training_(true),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
dy_desc_(nullptr),
dx_desc_(nullptr),
dz_desc_(nullptr),
scale_bias_diff_desc_(nullptr),
activation_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
beta_data_diff_(0) {}
~InstanceNormGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
auto dy = GetDeviceAddress<T>(inputs, 0);
auto x = GetDeviceAddress<T>(inputs, 1);
auto gamma = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
void *beta = nullptr;
T *y = nullptr;
auto dx = GetDeviceAddress<T>(outputs, 0);
auto dgamma = GetDeviceAddress<float>(outputs, 1);
auto dbeta = GetDeviceAddress<float>(outputs, 2);
T *dz = nullptr;
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
void *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 1);
}
size_t N = input_shape_[0];
size_t C = input_shape_[1];
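// Only gamma needs expanding to (N * C) in the backward pass; null pointers are passed
// for the beta/mean/variance slots, which CopyMemKernel skips.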
CopyMemDevice2Device(N, C, gamma, nullptr, nullptr, nullptr, ws_gamma, nullptr, nullptr, nullptr,
reinterpret_cast<cudaStream_t>(stream_ptr));
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaStreamSynchronized failed");
const float alpha_data_diff = 1;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
float *reserve_addr = nullptr;
if (is_training_) {
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_, x,
y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma, dbeta,
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
MS_LOG(EXCEPTION) << "The backward of InstanceNorm operator in evaluation mode is not implemented yet.";
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
InitResource();
is_training_ = GetAttr<bool>(kernel_node, "is_training");
mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 5";
}
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape_.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << input_shape_.size() << ", InstanceNormGradGpuKernel should be 4";
}
is_null_input_ = CHECK_NULL_INPUT(input_shape_);
if (is_null_input_) {
MS_LOG(WARNING) << "InstanceNormGradGpuKernel input is null";
InitSizeLists();
return true;
}
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
SetTensorDescriptor();
InitSizeLists();
return true;
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_diff_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size_), "Get x size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_diff_desc_, &para_size_),
"Get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle_, mode_, bn_ops_, x_desc_, y_desc_, dy_desc_, dz_desc_, dx_desc_,
scale_bias_diff_desc_, activation_desc_, &workspace_size_),
"cudnnGetBatchNormalizationBackwardExWorkspaceSize failed");
}
input_size_list_.push_back(x_size_);
input_size_list_.push_back(x_size_);
input_size_list_.push_back(input_shape_[1] * sizeof(float)); // gamma
input_size_list_.push_back(para_size_);
input_size_list_.push_back(para_size_);
output_size_list_.push_back(x_size_);
output_size_list_.push_back(para_size_);
output_size_list_.push_back(para_size_);
workspace_size_list_.push_back(para_size_); // ws gamma
workspace_size_list_.push_back(workspace_size_);
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_diff_desc_),
"Destroy para desc failed");
}
private:
void SetTensorDescriptor() {
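// Same folding as the forward kernel: treat the (N, C, H, W) input as (1, N * C, H, W).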
int batch, channel, height, width;
batch = 1;
channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
height = SizeToInt(input_shape_[2]);
width = SizeToInt(input_shape_[3]);
cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set dx desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
}
size_t x_size_;
size_t para_size_;
size_t workspace_size_;
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
bool is_training_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t dz_desc_;
cudnnTensorDescriptor_t scale_bias_diff_desc_;
cudnnActivationDescriptor_t activation_desc_;
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
float beta_data_diff_;
std::vector<size_t> input_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_

View File

@@ -47,6 +47,7 @@ constexpr auto kBNGrad1OpName = "BNGrad1";
constexpr auto kBNGrad2OpName = "BNGrad2";
constexpr auto kBNGrad3OpName = "BNGrad3";
constexpr auto kFusedBatchNormEx = "FusedBatchNormEx";
constexpr auto kInstanceNorm = "InstanceNorm";
constexpr auto kFusedBatchNormExWithActivation = "FusedBatchNormExWithActivation";
constexpr auto kFusedBatchNormExWithAddAndActivation = "FusedBatchNormExWithAddAndActivation";
constexpr auto kFusedBatchNormGradEx = "FusedBatchNormGradEx";

View File

@@ -1,4 +1,4 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ from mindspore.communication import management
from mindspore.ops import _selected_ops
from ..cell import Cell
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'InstanceNorm2d']
class _BatchNorm(Cell):
@@ -705,6 +705,119 @@ class LayerNorm(Cell):
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
class InstanceNorm2d(Cell):
r"""
Instance normalization layer over a 4D input.
This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
of data and the learned parameters which can be described in the following formula.
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Note:
The formula for updating the running_mean and running_var is
:math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
Args:
num_features (int): `C` from an expected input of size (N, C, H, W).
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
momentum (float): A floating hyperparameter of the momentum for the
running_mean and running_var computation. Default: 0.1.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
'he_uniform', etc. Default: 'ones'.
use_batch_statistics (bool): If true, use the mean and variance of the current batch of data. If false,
use the specified moving mean and moving variance values. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
Outputs:
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`. Same type and
shape as the `input`.
Supported Platforms:
``GPU``
Raises:
ValueError: If `num_features` is less than 1 or `momentum` is not in range [0, 1].
Examples:
>>> net = nn.InstanceNorm2d(3)
>>> np.random.seed(0)
>>> input = Tensor(np.random.randint(0, 255, [2, 3, 2, 2]), mindspore.float32)
>>> output = net(input)
>>> print(output.shape)
(2, 3, 2, 2)
"""
@cell_attr_register
def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True,
input_dims='2d'):
super(InstanceNorm2d, self).__init__()
if num_features < 1:
raise ValueError("num_features must be at least 1")
if momentum < 0 or momentum > 1:
raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
self.use_batch_statistics = use_batch_statistics
self.num_features = num_features
self.eps = eps
self.input_dims = input_dims
self.moving_mean = Parameter(initializer(
moving_mean_init, num_features), name="mean", requires_grad=False)
self.moving_variance = Parameter(initializer(
moving_var_init, num_features), name="variance", requires_grad=False)
self.gamma = Parameter(initializer(
gamma_init, num_features), name="gamma", requires_grad=affine)
self.beta = Parameter(initializer(
beta_init, num_features), name="beta", requires_grad=affine)
self.shape = P.Shape()
self.momentum = momentum
self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics,
epsilon=self.eps,
momentum=self.momentum)
def _check_data_dim(self, x):
raise NotImplementedError
def construct(self, x):
_shape_check_bn(self.shape(x), self.input_dims)
return self.instance_bn(x,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
def extend_repr(self):
return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
class GroupNorm(Cell):
r"""
Group Normalization over a mini-batch of inputs.

View File

@@ -688,6 +688,24 @@ def get_bprop_fused_batch_norm_ex(self):
return bprop
@bprop_getters.register(P.InstanceNorm)
def get_bprop_instance_norm(self):
"""Grad definition for `InstanceNorm` operation."""
is_training = self.is_training
input_grad = G.InstanceNormGrad(is_training, self.epsilon, self.momentum)
def bprop(x, gamma, beta, mean, variance, out, dout):
saved_mean = out[1]
saved_variance = out[2]
out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
dx = out[0]
dgamma = out[1]
dbeta = out[2]
return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)
return bprop
@bprop_getters.register(P.BatchNorm)
def get_bprop_batch_norm(self):
"""Grad definition for `BatchNorm` operation."""

View File

@@ -63,8 +63,8 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
DropoutDoMask, Dropout, DropoutGenMask, Flatten,
FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
Gelu, FastGelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
LogSoftmax,
@@ -130,6 +130,7 @@ __all__ = [
'MaxPoolWithArgmax',
'FusedBatchNorm',
'FusedBatchNormEx',
'InstanceNorm',
'BNTrainingReduce',
'BNTrainingUpdate',
'BatchNorm',

View File

@@ -714,6 +714,21 @@ class FusedBatchNormGradEx(PrimitiveWithInfer):
return (x_type, scale_type, scale_type)
class InstanceNormGrad(PrimitiveWithInfer):
"""Gradients of InstanceNorm operation."""
@prim_attr_register
def __init__(self, is_training=True, epsilon=0.0, momentum=0.1):
self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
outputs=['dx', 'bn_gamma', 'bn_beta'])
def infer_shape(self, y_backprop_shape, x_shape, gamma_shape, save_mean_shape, save_variance_shape):
return (x_shape, gamma_shape, gamma_shape)
def infer_dtype(self, y_backprop_type, x_type, gamma_type, save_mean_type, save_variance_type):
return (x_type, gamma_type, gamma_type)
class UniqueGrad(Primitive):
"""Gradients of Unique operation."""

View File

@@ -859,6 +859,119 @@ class FusedBatchNormEx(PrimitiveWithInfer):
return (input_x, scale, scale, scale, scale, scale)
class InstanceNorm(PrimitiveWithInfer):
r"""
Instance normalization over a 4D input.
This operator applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
of data and the learned parameters which can be described in the following formula.
.. math::
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
Args:
is_training (bool): Is training or inference. Default: True.
epsilon (float): A small value added for numerical stability. Default: 1e-5.
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
(e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
Momentum value must be in [0, 1]. Default: 0.1.
data_format (str): The optional value for data format, is 'NCHW'. Default: "NCHW".
Inputs:
- **input_x** (Tensor) - The input of InstanceNorm, Tensor of shape :math:`(N, C, H, W)`,
data type: float16 or float32.
- **gamma** (Parameter) - scale, Tensor of shape :math:`(C,)`,
data type: float32.
- **beta** (Parameter) - bias, Tensor of shape :math:`(C,)`,
data type: float32.
- **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
- **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32.
Outputs:
Tuple of 3 Tensors: the normalized input and the updated parameters.
- **output_x** (Tensor) - The output of InstanceNorm, same type and shape as the `input_x`.
- **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(N \cdot C,)`, data type: float32.
- **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(N \cdot C,)`,
data type: float32.
Supported Platforms:
``GPU``
Raises:
TypeError: If any validator check fails.
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as ops
>>> class InstanceNormNet(nn.Cell):
...     def __init__(self):
...         super(InstanceNormNet, self).__init__()
...         self.instance_norm = ops.InstanceNorm()
...         self.gamma = Parameter(Tensor(np.ones([64]), mindspore.float32), name="gamma")
...         self.beta = Parameter(Tensor(np.ones([64]), mindspore.float32), name="beta")
...         self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
...         self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
...
...     def construct(self, input_x):
...         out = self.instance_norm(input_x, self.gamma, self.beta, self.mean, self.variance)
...         return out
...
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
>>> net = InstanceNormNet()
>>> output = net(input_x)
>>> result = output[0].shape
>>> print(result)
(128, 64, 32, 64)
"""
__mindspore_signature__ = (
sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
sig.make_sig('gamma', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('beta', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
)
@prim_attr_register
def __init__(self, is_training=True, epsilon=1e-5, momentum=0.1):
self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
outputs=['y', 'save_mean', 'save_variance'])
self.is_training = validator.check_bool(is_training, self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
def infer_shape(self, input_x, gamma, beta, mean, variance):
input_shape_norm = input_x
validator.check_equal_int(len(gamma), 1, "gamma rank", self.name)
validator.check("gamma shape", gamma, "beta shape", beta, Rel.EQ, self.name)
validator.check("gamma shape[0]", gamma[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "gamma shape", gamma, Rel.EQ, self.name)
# Build a fresh list so the gamma shape list is not mutated: save_mean has shape (N * C,).
save_mean_shape = [gamma[0] * input_shape_norm[0]]
return (input_x, save_mean_shape, save_mean_shape)
def infer_dtype(self, input_x, gamma, beta, mean, variance):
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"gamma": gamma, "beta": beta}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance}
valid_dtypes = [mstype.tensor_type(mstype.float32)]
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
return (input_x, gamma, gamma)
class BNTrainingReduce(PrimitiveWithInfer):
"""
For the BatchNorm operation this operator update the moving averages for training and is used in conjunction with

View File

@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import functional as F
from mindspore.ops.composite import GradOperation
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Grad(nn.Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network
@ms_function
def construct(self, input_x, grad):
return self.grad(self.network)(input_x, grad)
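# Reference implementation: instance norm over (N, C, H, W) equals batch norm over the
# reshaped (1, N * C, H, W) tensor, so BatchNorm2d(N * C) serves as the expected result.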
class Net(nn.Cell):
def __init__(self, n):
super(Net, self).__init__()
self.ops = nn.BatchNorm2d(n, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
def construct(self, x):
shape = F.shape(x)
return F.reshape(self.ops(F.reshape(x, (1, -1, shape[2], shape[3]))), shape)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_InstanceNorm2d_fp32():
x_np = np.random.randn(3, 3, 2, 2).astype(np.float32)
bn_instance_comp = Net(3 * 3)
bn_instance_op = nn.InstanceNorm2d(3, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
comp_out = bn_instance_comp(Tensor(x_np))
op_out = bn_instance_op(Tensor(x_np))
assert np.allclose(comp_out.asnumpy(), op_out.asnumpy())
sens = np.random.randn(3, 3, 2, 2).astype(np.float32)
bn_comp_backward_net = Grad(bn_instance_comp)
bn_op_backward_net = Grad(bn_instance_op)
output1 = bn_comp_backward_net(Tensor(x_np), Tensor(sens))
output2 = bn_op_backward_net(Tensor(x_np), Tensor(sens))
assert np.allclose(output1[0].asnumpy(), output2[0].asnumpy())