diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu
new file mode 100644
index 00000000000..7af68dc46ff
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cu
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
+
+__global__ void CopyMemKernel(const size_t thread_num, const size_t N, const size_t C,
+                              float *gamma_addr, float *beta_addr,
+                              float *running_mean_addr, float *running_variance_addr,
+                              float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
+    size_t cur_addr = pos / (N * C);
+    size_t cur_local_index = pos % (N * C);
+    size_t local_index = 0;
+    switch (cur_addr) {
+      case 0:
+        if (!(gamma_addr && ws_gamma)) break;
+        local_index = cur_local_index % C;
+        ws_gamma[cur_local_index] = gamma_addr[local_index];
+        break;
+      case 1:
+        if (!(beta_addr && ws_beta)) break;
+        local_index = cur_local_index % C;
+        ws_beta[cur_local_index] = beta_addr[local_index];
+        break;
+      case 2:
+        if (!(running_mean_addr && ws_mean)) break;
+        local_index = cur_local_index % C;
+        ws_mean[cur_local_index] = running_mean_addr[local_index];
+        break;
+      default:
+        if (!(running_variance_addr && ws_var)) break;
+        local_index = cur_local_index % C;
+        ws_var[cur_local_index] = running_variance_addr[local_index];
+    }
+  }
+  return;
+}
+
+void CopyMemDevice2Device(const size_t N, const size_t C, float *gamma_addr, float *beta_addr,
+                          float *running_mean_addr, float *running_variance_addr,
+                          float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
+                          cudaStream_t cuda_stream) {
+  size_t thread_num = N * C * 4;
+  CopyMemKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
+    thread_num, N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr,
+    ws_gamma, ws_beta, ws_mean, ws_var);
+}
+
+__global__ void ComputeMeanKernel(const size_t thread_num, const size_t N, const size_t C,
+                                  float *save_mean_addr, float *save_var_addr) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < thread_num; pos += gridDim.x * blockDim.x) {
+    size_t cur_addr = pos / C;
+    size_t cur_local_index = pos % C;
+    float tmp = 0;
+    if (cur_addr) {
+      for (size_t i = 0; i < N; i++) {
+        tmp += save_var_addr[i * C + cur_local_index];
+      }
+      save_var_addr[cur_local_index] = tmp / N;
+    } else {
+      for (size_t i = 0; i < N; i++) {
+        tmp += save_mean_addr[i * C + cur_local_index];
+      }
+      save_mean_addr[cur_local_index] = tmp / N;
+    }
+  }
+  return;
+}
+
+void ComputeMean(const size_t N, const size_t C,
+                 float *save_mean_addr, float *save_var_addr,
+                 cudaStream_t cuda_stream) {
+  size_t thread_num = C * 2;
+  ComputeMeanKernel<<<GET_BLOCKS(thread_num), GET_THREADS, 0, cuda_stream>>>(
+    thread_num, N, C, save_mean_addr, save_var_addr);
+}
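Reviewer note: these two kernels exist because the implementation hands cuDNN a batch-norm problem with N * C "channels": the per-channel parameters (length C) must first be tiled to length N * C, and the per-instance statistics must afterwards be reduced back over N. A minimal NumPy sketch of the two kernels' semantics (function names are illustrative, not part of the patch):

import numpy as np

def copy_mem_device2device(gamma, beta, running_mean, running_var, n):
    """Reference for CopyMemKernel: tile each length-C vector to length N * C."""
    return tuple(np.tile(p, n) for p in (gamma, beta, running_mean, running_var))

def compute_mean(save_mean, save_var, n, c):
    """Reference for ComputeMeanKernel: average length-N*C stats back to length C."""
    return (save_mean.reshape(n, c).mean(axis=0),
            save_var.reshape(n, c).mean(axis=0))

n, c = 2, 3
gamma = np.arange(c, dtype=np.float32)
ws_gamma = copy_mem_device2device(gamma, gamma, gamma, gamma, n)[0]
assert np.array_equal(ws_gamma, np.array([0, 1, 2, 0, 1, 2], dtype=np.float32))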
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh
new file mode 100644
index 00000000000..053d529cb03
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+void CopyMemDevice2Device(const size_t N, const size_t C,
+                          float *gamma_addr, float *beta_addr, float *running_mean_addr, float *running_variance_addr,
+                          float *ws_gamma, float *ws_beta, float *ws_mean, float *ws_var,
+                          cudaStream_t cuda_stream);
+void ComputeMean(const size_t N, const size_t C, float *save_mean_addr, float *save_var_addr,
+                 cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_INSTANCE_NORM_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.cc
new file mode 100644
index 00000000000..1415ee709eb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(InstanceNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)   // x
+                        .AddInputAttr(kNumberTypeFloat32)   // gamma
+                        .AddInputAttr(kNumberTypeFloat32)   // beta
+                        .AddInputAttr(kNumberTypeFloat32)   // mean
+                        .AddInputAttr(kNumberTypeFloat32)   // variance
+                        .AddOutputAttr(kNumberTypeFloat32)  // y
+                        .AddOutputAttr(kNumberTypeFloat32)  // save_mean
+                        .AddOutputAttr(kNumberTypeFloat32), // save_variance
+                      InstanceNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(InstanceNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)   // x
+                        .AddInputAttr(kNumberTypeFloat32)   // gamma
+                        .AddInputAttr(kNumberTypeFloat32)   // beta
+                        .AddInputAttr(kNumberTypeFloat32)   // mean
+                        .AddInputAttr(kNumberTypeFloat32)   // variance
+                        .AddOutputAttr(kNumberTypeFloat16)  // y
+                        .AddOutputAttr(kNumberTypeFloat32)  // save_mean
+                        .AddOutputAttr(kNumberTypeFloat32), // save_variance
+                      InstanceNormGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
new file mode 100644
index 00000000000..98883ac447b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h
@@ -0,0 +1,240 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
+
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "utils/utils.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class InstanceNormGpuKernel : public GpuKernel {
+ public:
+  InstanceNormGpuKernel()
+      : input_x_size_(0),
+        input_z_size_(0),
+        para_size_(0),
+        output_size_(0),
+        workspace_size_(0),
+        mode_(CUDNN_BATCHNORM_SPATIAL),
+        bn_ops_(CUDNN_BATCHNORM_OPS_BN),
+        is_training_(true),
+        epsilon_(10e-5),
+        exp_avg_factor_(0.1),
+        is_null_input_(false),
+        x_desc_(nullptr),
+        y_desc_(nullptr),
+        z_desc_(nullptr),
+        scale_bias_mean_var_desc_(nullptr),
+        handle_(nullptr),
+        cudnn_data_type_(CUDNN_DATA_FLOAT) {}
+  ~InstanceNormGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    VARIABLE_NOT_USED(stream_ptr);
+    if (is_null_input_) {
+      return true;
+    }
+    auto x_addr = GetDeviceAddress<T>(inputs, 0);
+    auto gamma_addr = GetDeviceAddress<float>(inputs, 1);
+    auto beta_addr = GetDeviceAddress<float>(inputs, 2);
+    auto running_mean_addr = GetDeviceAddress<float>(inputs, 3);
+    auto running_variance_addr = GetDeviceAddress<float>(inputs, 4);
+    T *z = nullptr;
+
+    auto y_addr = GetDeviceAddress<T>(outputs, 0);
+    auto save_mean_addr = GetDeviceAddress<float>(outputs, 1);
+    auto save_variance_addr = GetDeviceAddress<float>(outputs, 2);
+
+    float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
+    float *ws_beta = GetDeviceAddress<float>(workspace, 1);
+    float *ws_mean = GetDeviceAddress<float>(workspace, 2);
+    float *ws_var = GetDeviceAddress<float>(workspace, 3);
+    T *workspace_addr = nullptr;
+    if (workspace_size_ != 0) {
+      workspace_addr = GetDeviceAddress<T>(workspace, 4);
+    }
+
+    size_t N = input_shape_[0];
+    size_t C = input_shape_[1];
+    CopyMemDevice2Device(N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr, ws_gamma, ws_beta,
+                         ws_mean, ws_var, reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    const float alpha = 1;
+    const float beta = 0;
+    float *reserve_addr = nullptr;
+    if (is_training_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnBatchNormalizationForwardTrainingEx(
+          handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr,
+          scale_bias_mean_var_desc_, ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr,
+          save_variance_addr, nullptr, workspace_addr, workspace_size_, reserve_addr, 0),
+        "Kernel launch failed");
+    } else {
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                  cudnnBatchNormalizationForwardInference(
+                                    handle_, mode_, &alpha, &beta, x_desc_, x_addr, y_desc_, y_addr,
+                                    scale_bias_mean_var_desc_, ws_gamma, ws_beta, ws_mean, ws_var, epsilon_),
+                                  "Kernel launch failed");
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
+
+    InitResource();
+    is_training_ = GetAttr<bool>(kernel_node, "is_training");
+    mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
+    epsilon_ = GetAttr<float>(kernel_node, "epsilon");
+    exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
+
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 5) {
+      MS_LOG(EXCEPTION) << "input tensor num is " << input_num << ", but " << kernel_name << " needs 5 inputs";
+    }
+    input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    if (input_shape_.size() != 4) {
+      MS_LOG(EXCEPTION) << "input tensor rank is " << input_shape_.size()
+                        << ", but InstanceNormGpuKernel needs a 4-D input";
+    }
+    is_null_input_ = CHECK_NULL_INPUT(input_shape_);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "InstanceNormGpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+    SetTensorDescriptor();
+    InitSizeLists();
+    return true;
+  }
+
+  void DestroyResource() noexcept override {
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
+                               "Destroy para desc failed");
+  }
+
+ protected:
+  void InitResource() override {
+    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
+                                "Create para desc failed");
+  }
+
+  void InitSizeLists() override {
+    if (!is_null_input_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_x_size_),
+                                  "Get input x size failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size_),
+                                  "Get para size failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size_),
+                                  "Get output size failed");
+
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, z_desc_, y_desc_,
+                                                                 scale_bias_mean_var_desc_, nullptr, &workspace_size_),
+        "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
+    }
+
+    input_size_list_.push_back(input_x_size_);    // input x
+    input_size_list_.push_back(input_shape_[1]);  // gamma
+    input_size_list_.push_back(input_shape_[1]);  // beta
+    input_size_list_.push_back(input_shape_[1]);  // mean
+    input_size_list_.push_back(input_shape_[1]);  // variance
+
+    output_size_list_.push_back(output_size_);  // output
+    output_size_list_.push_back(para_size_);    // save mean
+    output_size_list_.push_back(para_size_);    // save variance
+
+    workspace_size_list_.push_back(para_size_);  // ws gamma
+    workspace_size_list_.push_back(para_size_);  // ws beta
+    workspace_size_list_.push_back(para_size_);  // ws mean
+    workspace_size_list_.push_back(para_size_);  // ws variance
+    workspace_size_list_.push_back(workspace_size_);
+  }
+
+ private:
+  void SetTensorDescriptor() {
+    cudnnTensorFormat_t cudnn_format;
+    int batch, channel, height, width;
+    batch = 1;
+    channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
+    height = SizeToInt(input_shape_[2]);
+    width = SizeToInt(input_shape_[3]);
+    cudnn_format = CUDNN_TENSOR_NCHW;
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+      "Set x desc failed");
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_, cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+      "Set y desc failed");
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
+      "Set para desc failed");
+  }
+
+  size_t input_x_size_;
+  size_t input_z_size_;
+  size_t para_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  cudnnBatchNormMode_t mode_;
+  cudnnBatchNormOps_t bn_ops_;
+  bool is_training_;
+  double epsilon_;
+  double exp_avg_factor_;
+  bool is_null_input_;
+  cudnnTensorDescriptor_t x_desc_;
+  cudnnTensorDescriptor_t y_desc_;
+  cudnnTensorDescriptor_t z_desc_;
+  cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
+
+  cudnnHandle_t handle_;
+  cudnnDataType_t cudnn_data_type_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
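Reviewer note: SetTensorDescriptor above is the core trick of this kernel. The (N, C, H, W) input is presented to cuDNN as a (1, N*C, H, W) tensor, so spatial batch normalization over the fused axis normalizes each (sample, channel) plane independently, which is exactly instance normalization. A minimal NumPy sketch of that equivalence (reference only, not part of the patch):

import numpy as np

def batch_norm_train(x, eps=1e-5):
    # Per-channel normalization over (batch, H, W), as cuDNN spatial BN does.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def instance_norm(x, eps=1e-5):
    n, c, h, w = x.shape
    # Fold N into the channel axis, batch-normalize, and unfold.
    return batch_norm_train(x.reshape(1, n * c, h, w), eps).reshape(n, c, h, w)

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
ref = (x - x.mean(axis=(2, 3), keepdims=True)) / np.sqrt(x.var(axis=(2, 3), keepdims=True) + 1e-5)
assert np.allclose(instance_norm(x), ref, atol=1e-5)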
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.cc
new file mode 100644
index 00000000000..8d035f0d5a2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)   // dy
+                        .AddInputAttr(kNumberTypeFloat32)   // x
+                        .AddInputAttr(kNumberTypeFloat32)   // scale
+                        .AddInputAttr(kNumberTypeFloat32)   // save_mean
+                        .AddInputAttr(kNumberTypeFloat32)   // save_variance
+                        .AddOutputAttr(kNumberTypeFloat32)  // dx
+                        .AddOutputAttr(kNumberTypeFloat32)  // dscale
+                        .AddOutputAttr(kNumberTypeFloat32), // dbias
+                      InstanceNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(InstanceNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)   // dy
+                        .AddInputAttr(kNumberTypeFloat16)   // x
+                        .AddInputAttr(kNumberTypeFloat32)   // scale
+                        .AddInputAttr(kNumberTypeFloat32)   // save_mean
+                        .AddInputAttr(kNumberTypeFloat32)   // save_variance
+                        .AddOutputAttr(kNumberTypeFloat16)  // dx
+                        .AddOutputAttr(kNumberTypeFloat32)  // dscale
+                        .AddOutputAttr(kNumberTypeFloat32), // dbias
+                      InstanceNormGradGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
new file mode 100644
index 00000000000..a153d321bc8
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_grad_gpu_kernel.h
@@ -0,0 +1,238 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
+
+#include <string>
+#include <vector>
+#include "utils/utils.h"
+
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/instance_norm_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class InstanceNormGradGpuKernel : public GpuKernel {
+ public:
+  InstanceNormGradGpuKernel()
+      : x_size_(0),
+        para_size_(0),
+        workspace_size_(0),
+        mode_(CUDNN_BATCHNORM_SPATIAL),
+        bn_ops_(CUDNN_BATCHNORM_OPS_BN),
+        epsilon_(10e-5),
+        is_training_(true),
+        is_null_input_(false),
+        x_desc_(nullptr),
+        y_desc_(nullptr),
+        dy_desc_(nullptr),
+        dx_desc_(nullptr),
+        dz_desc_(nullptr),
+        scale_bias_diff_desc_(nullptr),
+        activation_desc_(nullptr),
+        handle_(nullptr),
+        cudnn_data_type_(CUDNN_DATA_FLOAT),
+        beta_data_diff_(0) {}
+  ~InstanceNormGradGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    VARIABLE_NOT_USED(stream_ptr);
+    if (is_null_input_) {
+      return true;
+    }
+    auto dy = GetDeviceAddress<T>(inputs, 0);
+    auto x = GetDeviceAddress<T>(inputs, 1);
+    auto gamma = GetDeviceAddress<float>(inputs, 2);
+    auto save_mean = GetDeviceAddress<float>(inputs, 3);
+    auto save_variance = GetDeviceAddress<float>(inputs, 4);
+    void *beta = nullptr;
+    T *y = nullptr;
+
+    auto dx = GetDeviceAddress<T>(outputs, 0);
+    auto dgamma = GetDeviceAddress<float>(outputs, 1);
+    auto dbeta = GetDeviceAddress<float>(outputs, 2);
+    T *dz = nullptr;
+
+    float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
+    void *workspace_addr = nullptr;
+    if (workspace_size_ != 0) {
+      workspace_addr = GetDeviceAddress<T>(workspace, 1);
+    }
+
+    size_t N = input_shape_[0];
+    size_t C = input_shape_[1];
+    CopyMemDevice2Device(N, C, gamma, nullptr, nullptr, nullptr, ws_gamma, nullptr, nullptr, nullptr,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaStreamSynchronize failed");
+
+    const float alpha_data_diff = 1;
+    const float alpha_param_diff = 1;
+    const float beta_param_diff = 0;
+    float *reserve_addr = nullptr;
+    if (is_training_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudnnBatchNormalizationBackwardEx(
+          handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff, &beta_param_diff, x_desc_,
+          x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, ws_gamma, beta, dgamma,
+          dbeta, epsilon_, save_mean, save_variance, activation_desc_, workspace_addr, workspace_size_, reserve_addr,
+          0),
+        "Kernel launch failed");
+      ComputeMean(N, C, dgamma, dbeta, reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      MS_LOG(EXCEPTION) << "The backward of InstanceNorm operator in evaluation mode is not implemented yet.";
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    kernel_node_ = kernel_node;
+    MS_EXCEPTION_IF_NULL(kernel_node);
+    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
+
+    InitResource();
+    is_training_ = GetAttr<bool>(kernel_node, "is_training");
+    mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
+    epsilon_ = GetAttr<float>(kernel_node, "epsilon");
+
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 5) {
+      MS_LOG(EXCEPTION) << "input tensor num is " << input_num << ", but " << kernel_name << " needs 5 inputs";
+    }
+
+    input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    if (input_shape_.size() != 4) {
+      MS_LOG(EXCEPTION) << "input tensor rank is " << input_shape_.size()
+                        << ", but InstanceNormGradGpuKernel needs a 4-D input";
+    }
+    is_null_input_ = CHECK_NULL_INPUT(input_shape_);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "InstanceNormGradGpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+    beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
+    SetTensorDescriptor();
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_diff_desc_),
+                                "Create para desc failed");
+  }
+
+  void InitSizeLists() override {
+    if (!is_null_input_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size_), "Get x size failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_diff_desc_, &para_size_),
+                                  "Get para size failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                  cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+                                    handle_, mode_, bn_ops_, x_desc_, y_desc_, dy_desc_, dz_desc_, dx_desc_,
+                                    scale_bias_diff_desc_, activation_desc_, &workspace_size_),
+                                  "cudnnGetBatchNormalizationBackwardExWorkspaceSize failed");
+    }
+
+    input_size_list_.push_back(x_size_);
+    input_size_list_.push_back(x_size_);
+    input_size_list_.push_back(input_shape_[1]);
+    input_size_list_.push_back(para_size_);
+    input_size_list_.push_back(para_size_);
+
+    output_size_list_.push_back(x_size_);
+    output_size_list_.push_back(para_size_);
+    output_size_list_.push_back(para_size_);
+
+    workspace_size_list_.push_back(para_size_);  // ws gamma
+    workspace_size_list_.push_back(workspace_size_);
+  }
+  void DestroyResource() noexcept override {
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_diff_desc_),
+                               "Destroy para desc failed");
+  }
+
+ private:
+  void SetTensorDescriptor() {
+    int batch, channel, height, width;
+    batch = 1;
+    channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
+    height = SizeToInt(input_shape_[2]);
+    width = SizeToInt(input_shape_[3]);
+    cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+      "Set x desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+      "Set dy desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
+      "Set dx desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
+      "Set para desc failed");
+  }
+
+  size_t x_size_;
+  size_t para_size_;
+  size_t workspace_size_;
+  cudnnBatchNormMode_t mode_;
+  cudnnBatchNormOps_t bn_ops_;
+  double epsilon_;
+  bool is_training_;
+  bool is_null_input_;
+
+  cudnnTensorDescriptor_t x_desc_;
+  cudnnTensorDescriptor_t y_desc_;
+  cudnnTensorDescriptor_t dy_desc_;
+  cudnnTensorDescriptor_t dx_desc_;
+  cudnnTensorDescriptor_t dz_desc_;
+  cudnnTensorDescriptor_t scale_bias_diff_desc_;
+  cudnnActivationDescriptor_t activation_desc_;
+
+  cudnnHandle_t handle_;
+  cudnnDataType_t cudnn_data_type_;
+  float beta_data_diff_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index ce73041d25c..c278e7249b7 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -47,6 +47,7 @@ constexpr auto kBNGrad1OpName = "BNGrad1";
 constexpr auto kBNGrad2OpName = "BNGrad2";
 constexpr auto kBNGrad3OpName = "BNGrad3";
 constexpr auto kFusedBatchNormEx = "FusedBatchNormEx";
+constexpr auto kInstanceNorm = "InstanceNorm";
 constexpr auto kFusedBatchNormExWithActivation = "FusedBatchNormExWithActivation";
 constexpr auto kFusedBatchNormExWithAddAndActivation = "FusedBatchNormExWithAddAndActivation";
 constexpr auto kFusedBatchNormGradEx = "FusedBatchNormGradEx";
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 0d1bae5a19f..83063aae022 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@ from mindspore.communication import management
 from mindspore.ops import _selected_ops
 from ..cell import Cell
 
-__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm']
+__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'InstanceNorm2d']
 
 
 class _BatchNorm(Cell):
@@ -705,6 +705,119 @@ class LayerNorm(Cell):
         self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
 
 
+class InstanceNorm2d(Cell):
+    r"""
+    Instance normalization layer over a 4D input.
+
+    This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
+    Fast Stylization <https://arxiv.org/abs/1607.08022>`_.
+    It rescales and recenters the features using a mini-batch of data and the learned parameters,
+    as described in the following formula:
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Note:
+        The formula for updating the running_mean and running_var is
+        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
+
+    Args:
+        num_features (int): `C` from an expected input of size (N, C, H, W).
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        momentum (float): A floating hyperparameter of the momentum for the
+            running_mean and running_var computation. Default: 0.1.
+        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
+            use the stored moving mean and moving variance. Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
+
+    Outputs:
+        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`. Same type and
+        shape as the `input`.
+
+    Supported Platforms:
+        ``GPU``
+
+    Raises:
+        ValueError: If `num_features` is less than 1 or `momentum` is not in range [0, 1].
+
+    Examples:
+        >>> net = nn.InstanceNorm2d(3)
+        >>> np.random.seed(0)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3, 2, 2]), mindspore.float32)
+        >>> output = net(input)
+        >>> print(output.shape)
+        (2, 3, 2, 2)
+    """
+
+    @cell_attr_register
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.1,
+                 affine=True,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 moving_mean_init='zeros',
+                 moving_var_init='ones',
+                 use_batch_statistics=True,
+                 input_dims='2d'):
+        super(InstanceNorm2d, self).__init__()
+        if num_features < 1:
+            raise ValueError("num_features must be at least 1")
+
+        if momentum < 0 or momentum > 1:
+            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
+        self.use_batch_statistics = use_batch_statistics
+        self.num_features = num_features
+        self.eps = eps
+        self.input_dims = input_dims
+        self.moving_mean = Parameter(initializer(
+            moving_mean_init, num_features), name="mean", requires_grad=False)
+        self.moving_variance = Parameter(initializer(
+            moving_var_init, num_features), name="variance", requires_grad=False)
+        self.gamma = Parameter(initializer(
+            gamma_init, num_features), name="gamma", requires_grad=affine)
+        self.beta = Parameter(initializer(
+            beta_init, num_features), name="beta", requires_grad=affine)
+
+        self.shape = P.Shape()
+        self.momentum = momentum
+        self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics,
+                                          epsilon=self.eps,
+                                          momentum=self.momentum)
+
+    def _check_data_dim(self, x):
+        raise NotImplementedError
+
+    def construct(self, x):
+        _shape_check_bn(self.shape(x), self.input_dims)
+        return self.instance_bn(x,
+                                self.gamma,
+                                self.beta,
+                                self.moving_mean,
+                                self.moving_variance)[0]
+
+    def extend_repr(self):
+        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
+            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
+
+
 class GroupNorm(Cell):
     r"""
     Group Normalization over a mini-batch of inputs.
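Reviewer note: instance norm produces one mean/variance pair per (sample, channel) plane, but the layer stores only per-channel moving statistics, so the per-plane statistics are averaged over the batch before the momentum blend described in the Note above. A NumPy sketch of that update, following the documented formula (reference only, not code from this patch):

import numpy as np

def update_running_stats(x, running_mean, running_var, momentum=0.1):
    # Per-(sample, channel) statistics over the spatial axes, shape (N, C).
    mean_nc = x.mean(axis=(2, 3))
    var_nc = x.var(axis=(2, 3))
    # Average over the batch to get per-channel statistics, then blend with
    # the stored values: new = (1 - momentum) * observed + momentum * old.
    new_mean = (1 - momentum) * mean_nc.mean(axis=0) + momentum * running_mean
    new_var = (1 - momentum) * var_nc.mean(axis=0) + momentum * running_var
    return new_mean, new_var

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
rm, rv = update_running_stats(x, np.zeros(3, np.float32), np.ones(3, np.float32))
print(rm.shape, rv.shape)  # (3,) (3,)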
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 24a5a4d0f43..049e6d0f0d0 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -688,6 +688,24 @@ def get_bprop_fused_batch_norm_ex(self):
     return bprop
 
 
+@bprop_getters.register(P.InstanceNorm)
+def get_bprop_instance_norm(self):
+    """Grad definition for `InstanceNorm` operation."""
+    is_training = self.is_training
+    input_grad = G.InstanceNormGrad(is_training, self.epsilon, self.momentum)
+
+    def bprop(x, gamma, beta, mean, variance, out, dout):
+        saved_mean = out[1]
+        saved_variance = out[2]
+        out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
+        dx = out[0]
+        dgamma = out[1]
+        dbeta = out[2]
+        return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)
+
+    return bprop
+
+
 @bprop_getters.register(P.BatchNorm)
 def get_bprop_batch_norm(self):
     """Grad definition for `BatchNorm` operation."""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index fc545a895b7..cdc9edd6613 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -63,8 +63,8 @@ from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, U
 from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam, ApplyMomentum,
                      BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative,
-                     DropoutDoMask, Dropout,
-                     DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
+                     DropoutDoMask, Dropout, DropoutGenMask, Flatten,
+                     FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                      Gelu, FastGelu, Elu,
                      GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder, LogSoftmax,
@@ -130,6 +130,7 @@ __all__ = [
     'MaxPoolWithArgmax',
     'FusedBatchNorm',
     'FusedBatchNormEx',
+    'InstanceNorm',
     'BNTrainingReduce',
     'BNTrainingUpdate',
     'BatchNorm',
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 6cce3c7f24a..ef27e0bc426 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -714,6 +714,21 @@ class FusedBatchNormGradEx(PrimitiveWithInfer):
         return (x_type, scale_type, scale_type)
 
 
+class InstanceNormGrad(PrimitiveWithInfer):
+    """Gradients of InstanceNorm operation."""
+
+    @prim_attr_register
+    def __init__(self, is_training=True, epsilon=0.0, momentum=0.1):
+        self.init_prim_io_names(inputs=['dy', 'x', 'gamma', 'save_mean', 'save_variance'],
+                                outputs=['dx', 'bn_gamma', 'bn_beta'])
+
+    def infer_shape(self, y_backprop_shape, x_shape, gamma_shape, save_mean_shape, save_variance_shape):
+        return (x_shape, gamma_shape, gamma_shape)
+
+    def infer_dtype(self, y_backprop_type, x_type, gamma_type, save_mean_type, save_variance_type):
+        return (x_type, gamma_type, gamma_type)
+
+
 class UniqueGrad(Primitive):
     """Gradients of Unique operation."""
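Reviewer note: the bprop above feeds the saved per-instance statistics from the forward outputs into InstanceNormGrad, and returns zeros for the moving mean/variance slots because those parameters are updated in place rather than differentiated. For reference, the textbook instance-norm backward in NumPy looks like the sketch below (illustrative only; note that the GPU kernel averages the per-instance parameter gradients over N via ComputeMean, whereas a plain sum is the textbook reduction):

import numpy as np

def instance_norm_bprop(dy, x, gamma, eps=1e-5):
    """Textbook instance-norm backward over (N, C, H, W); reference only."""
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    # Parameter gradients, reduced over batch and space.
    dgamma = (dy * x_hat).sum(axis=(0, 2, 3))
    dbeta = dy.sum(axis=(0, 2, 3))
    # Input gradient for per-plane normalization.
    g = gamma.reshape(1, -1, 1, 1)
    m1 = dy.mean(axis=(2, 3), keepdims=True)
    m2 = (dy * x_hat).mean(axis=(2, 3), keepdims=True)
    dx = g / np.sqrt(var + eps) * (dy - m1 - x_hat * m2)
    return dx, dgamma, dbeta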
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 4337eeef01c..1cfa38cad4a 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -859,6 +859,119 @@ class FusedBatchNormEx(PrimitiveWithInfer):
         return (input_x, scale, scale, scale, scale, scale)
 
 
+class InstanceNorm(PrimitiveWithInfer):
+    r"""
+    Instance normalization over a 4D input.
+
+    This operator applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
+    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
+    Fast Stylization <https://arxiv.org/abs/1607.08022>`_.
+    It rescales and recenters the features using a mini-batch of data and the learned parameters,
+    as described in the following formula:
+
+    .. math::
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
+
+    where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
+
+    Args:
+        is_training (bool): Is training or inference. Default: True.
+        epsilon (float): A small value added for numerical stability. Default: 1e-5.
+        momentum (float): The hyper parameter to compute moving average for running_mean and running_var
+            (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
+            The momentum value must be in [0, 1]. Default: 0.1.
+        data_format (str): The optional data format; only 'NCHW' is supported. Default: "NCHW".
+
+    Inputs:
+        - **input_x** (Tensor) - The input of InstanceNorm, Tensor of shape :math:`(N, C, H, W)`,
+          data type: float16 or float32.
+        - **gamma** (Parameter) - scale, Tensor of shape :math:`(C,)`,
+          data type: float32.
+        - **beta** (Parameter) - bias, Tensor of shape :math:`(C,)`,
+          data type: float32.
+        - **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
+        - **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32.
+
+    Outputs:
+        Tuple of 3 Tensors: the normalized input and the updated parameters.
+
+        - **output_x** (Tensor) - The output of InstanceNorm, same type and shape as the `input_x`.
+        - **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(N \times C,)`,
+          data type: float32.
+        - **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(N \times C,)`,
+          data type: float32.
+
+    Supported Platforms:
+        ``GPU``
+
+    Raises:
+        TypeError: If any validator check fails.
+
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as ops
+        >>> class InstanceNormNet(nn.Cell):
+        ...     def __init__(self):
+        ...         super(InstanceNormNet, self).__init__()
+        ...         self.instance_norm = ops.InstanceNorm()
+        ...         self.gamma = Parameter(Tensor(np.ones([64]), mindspore.float32), name="gamma")
+        ...         self.beta = Parameter(Tensor(np.ones([64]), mindspore.float32), name="beta")
+        ...         self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
+        ...         self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
+        ...
+        ...     def construct(self, input_x):
+        ...         out = self.instance_norm(input_x, self.gamma, self.beta, self.mean, self.variance)
+        ...         return out
+        ...
+        >>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
+        >>> net = InstanceNormNet()
+        >>> output = net(input_x)
+        >>> result = output[0].shape
+        >>> print(result)
+        (128, 64, 32, 64)
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
+        sig.make_sig('gamma', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('beta', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+        sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
+    )
+
+    @prim_attr_register
+    def __init__(self, is_training=True, epsilon=1e-5, momentum=0.1):
+        self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
+                                outputs=['y', 'save_mean', 'save_variance'])
+        self.is_training = validator.check_bool(is_training, self.name)
+        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
+        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
+        self._update_parameter = True
+
+    def infer_shape(self, input_x, gamma, beta, mean, variance):
+        input_shape_norm = input_x
+        validator.check_equal_int(len(gamma), 1, "gamma rank", self.name)
+        validator.check("gamma shape", gamma, "beta shape", beta, Rel.EQ, self.name)
+        validator.check("gamma shape[0]", gamma[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
+        validator.check_equal_int(len(mean), 1, "mean rank", self.name)
+
+        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
+        validator.check("mean shape", mean, "gamma shape", gamma, Rel.EQ, self.name)
+        save_mean_shape = gamma
+        save_mean_shape[0] = save_mean_shape[0] * input_shape_norm[0]
+        return (input_x, save_mean_shape, save_mean_shape)
+
+    def infer_dtype(self, input_x, gamma, beta, mean, variance):
+        validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
+        args = {"gamma": gamma, "beta": beta}
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
+        args_moving = {"mean": mean, "variance": variance}
+        valid_dtypes = [mstype.tensor_type(mstype.float32)]
+        validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
+        return (input_x, gamma, gamma)
+
+
 class BNTrainingReduce(PrimitiveWithInfer):
     """
     For the BatchNorm operation this operator update the moving averages for training and is used in conjunction with
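Reviewer note: `infer_shape` above is why the updated statistics have shape :math:`(N \times C,)`: the saved shape starts from `gamma` (length C) and is scaled by the batch dimension. A toy trace of the shape arithmetic (illustrative only):

# Shapes only, mirroring infer_shape: input (N, C, H, W), gamma (C,).
input_x = [128, 64, 32, 64]
gamma = [64]
save_mean_shape = gamma[:]  # the patch writes `save_mean_shape = gamma`, which
                            # aliases and mutates the incoming list; copied here for clarity
save_mean_shape[0] *= input_x[0]
print(save_mean_shape)  # [8192], i.e. N * C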
diff --git a/tests/st/ops/gpu/test_instancenorm2d.py b/tests/st/ops/gpu/test_instancenorm2d.py
new file mode 100644
index 00000000000..bdf9fa548db
--- /dev/null
+++ b/tests/st/ops/gpu/test_instancenorm2d.py
@@ -0,0 +1,62 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.ops import functional as F
+from mindspore.ops.composite import GradOperation
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(get_all=True, sens_param=True)
+        self.network = network
+
+    @ms_function
+    def construct(self, input_x, grad):
+        return self.grad(self.network)(input_x, grad)
+
+
+class Net(nn.Cell):
+    def __init__(self, n):
+        super(Net, self).__init__()
+        self.ops = nn.BatchNorm2d(n, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
+
+    def construct(self, x):
+        shape = F.shape(x)
+        return F.reshape(self.ops(F.reshape(x, (1, -1, shape[2], shape[3]))), shape)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_InstanceNorm2d_fp32():
+    x_np = np.random.randn(3, 3, 2, 2).astype(np.float32)
+    bn_instance_comp = Net(3 * 3)
+    bn_instance_op = nn.InstanceNorm2d(3, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
+    comp_out = bn_instance_comp(Tensor(x_np))
+    op_out = bn_instance_op(Tensor(x_np))
+    assert np.allclose(comp_out.asnumpy(), op_out.asnumpy())
+
+    sens = np.random.randn(3, 3, 2, 2).astype(np.float32)
+    bn_comp_backward_net = Grad(bn_instance_comp)
+    bn_op_backward_net = Grad(bn_instance_op)
+    output1 = bn_comp_backward_net(Tensor(x_np), Tensor(sens))
+    output2 = bn_op_backward_net(Tensor(x_np), Tensor(sens))
+    assert np.allclose(output1[0].asnumpy(), output2[0].asnumpy())
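Reviewer note: the test validates InstanceNorm2d against the reshape-to-BatchNorm2d composition for both the forward result and dx. The same forward check can also be written against a framework-free NumPy oracle (a sketch; the constant 0.5 folds in the test's gamma_init/beta_init, and the tolerance is an assumption):

import numpy as np

def instance_norm_ref(x, gamma=0.5, beta=0.5, eps=1e-5):
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x_np = np.random.randn(3, 3, 2, 2).astype(np.float32)
expected = instance_norm_ref(x_np)
# Compare against op_out.asnumpy() from the test body, e.g.:
# assert np.allclose(op_out.asnumpy(), expected, atol=1e-4)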