update InstanceNorm GPU kernel

zhujingxuan 2022-06-08 17:16:59 +08:00
parent 764656a010
commit 2bf0a95a25
8 changed files with 383 additions and 212 deletions

View File

@@ -110,6 +110,9 @@ int ScatterNdFunctorGPUKernelMod::Resize(const BaseOperatorPtr &base_operator,
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") ||
CHECK_SHAPE_NULL(indices_shape, kernel_name_, "indices") ||
CHECK_SHAPE_NULL(updates_shape, kernel_name_, "updates");
if (is_null_input_) {
return KRET_OK;
}
if (indices_shape.size() < kMinIndiceRank) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the dimension of 'indices' must be at least 2, but got "

View File

@@ -15,30 +15,154 @@
*/
#include "plugin/device/gpu/kernel/nn/instance_norm_gpu_kernel.h"
#include <map>
#include <utility>
#include "mindspore/core/ops/instance_norm.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(InstanceNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
InstanceNormGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(InstanceNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
InstanceNormGpuKernelMod, half)
namespace {
using KernelRunFunc = InstanceNormGpuKernelMod::KernelRunFunc;
constexpr auto kNCDims = 2;
} // namespace
bool InstanceNormGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
"Create para desc failed");
auto kernel_ptr = std::dynamic_pointer_cast<ops::InstanceNorm>(base_operator);
MS_EXCEPTION_IF_NULL(kernel_ptr);
epsilon_ = kernel_ptr->get_epsilon();
exp_avg_factor_ = kernel_ptr->get_momentum();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs.at(kIndex0)->GetDtype()));
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int InstanceNormGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input_x");
if (is_null_input_) {
return KRET_OK;
}
batch_ = input_shape[kIndex0];
channel_ = input_shape[kIndex1];
CheckTensorSize({input_shape});
const int batch = 1;
const int channel = SizeToInt(batch_) * SizeToInt(channel_);
const int height = 1;
const int width = std::accumulate(input_shape.begin() + kNCDims, input_shape.end(), 1, std::multiplies{});
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch, channel, height, width),
"Set y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
size_t para_size = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size),
"Get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, z_desc_, y_desc_,
scale_bias_mean_var_desc_, nullptr, &workspace_size_),
"cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
workspace_size_list_.clear();
workspace_size_list_ = {
para_size, // ws gamma
para_size, // ws beta
para_size, // ws mean
para_size, // ws variance
workspace_size_,
};
return KRET_OK;
}
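// The Resize logic above realizes InstanceNorm with cuDNN's spatial batch norm
// by folding the batch axis into the channel axis: an input of shape
// (N, C, d1, ..., dk) is described to cuDNN as (1, N * C, 1, d1 * ... * dk),
// so statistics are computed independently per (sample, channel) pair, which
// is exactly the instance-norm semantics. A minimal host-side sketch of the
// folding, with illustrative names that are not part of this commit:
namespace {
struct FoldedDims {
int batch, channel, height, width;
};
FoldedDims FoldForInstanceNorm(const std::vector<size_t> &shape) {
const int spatial = std::accumulate(shape.begin() + kNCDims, shape.end(), 1, std::multiplies{});
return {1, SizeToInt(shape[0]) * SizeToInt(shape[1]), 1, spatial};
}
// FoldForInstanceNorm({8, 4, 3, 4}) yields {1, 32, 1, 12}.
}  // namespace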
template <typename T>
bool InstanceNormGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto x_addr = GetDeviceAddress<T>(inputs, kIndex0);
auto gamma_addr = GetDeviceAddress<float>(inputs, kIndex1);
auto beta_addr = GetDeviceAddress<float>(inputs, kIndex2);
auto running_mean_addr = GetDeviceAddress<float>(inputs, kIndex3);
auto running_variance_addr = GetDeviceAddress<float>(inputs, kIndex4);
T *z = nullptr;
auto y_addr = GetDeviceAddress<T>(outputs, kIndex0);
auto save_mean_addr = GetDeviceAddress<float>(outputs, kIndex1);
auto save_variance_addr = GetDeviceAddress<float>(outputs, kIndex2);
float *ws_gamma = GetDeviceAddress<float>(workspace, kIndex0);
float *ws_beta = GetDeviceAddress<float>(workspace, kIndex1);
float *ws_mean = GetDeviceAddress<float>(workspace, kIndex2);
float *ws_var = GetDeviceAddress<float>(workspace, kIndex3);
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, kIndex4);
CopyMemDevice2Device(batch_, channel_, gamma_addr, beta_addr, running_mean_addr, running_variance_addr, ws_gamma,
ws_beta, ws_mean, ws_var, stream_ptr_);
const float alpha = 1;
const float beta = 0;
float *reserve_addr = nullptr;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnBatchNormalizationForwardTrainingEx(
handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, scale_bias_mean_var_desc_,
ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, save_variance_addr, nullptr,
workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
return true;
}
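// CopyMemDevice2Device above bridges a layout mismatch introduced by the
// folding: the operator's gamma/beta/mean/variance inputs have shape (C,),
// while the folded batch norm expects one value per (sample, channel), i.e.
// shape (N * C,), so each parameter vector is tiled batch_ times into the
// ws_* workspace buffers before the cuDNN call. A plain host-side sketch of
// that tiling; the real routine operates on device memory, and the name
// TileParam is illustrative only:
namespace {
std::vector<float> TileParam(const std::vector<float> &param, size_t batch) {
std::vector<float> tiled;
tiled.reserve(param.size() * batch);
for (size_t n = 0; n < batch; ++n) {
// One contiguous copy of the (C,) vector per sample.
(void)tiled.insert(tiled.end(), param.begin(), param.end());
}
return tiled;
}
// Tiling a 4-element gamma with batch 8 yields the 32-element buffer that
// matches the (1, N * C, 1, W) descriptors set in Resize.
}  // namespace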
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &InstanceNormGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&InstanceNormGpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&InstanceNormGpuKernelMod::LaunchKernel<half>},
};
return func_list;
}
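// The func_list above pairs each supported KernelAttr signature with the
// matching LaunchKernel<T> instantiation. MatchKernelFunc, called from Init,
// selects the entry whose dtypes match the actual inputs and caches it in
// kernel_func_, so Launch reduces to one indirect call with no per-launch
// type dispatch. Note that only x and y switch between float32 and float16;
// gamma, beta, mean, and variance stay float32 in both entries.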
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InstanceNorm, InstanceNormGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_INSTANCE_NORM_GPU_KERNEL_H_
#include <map>
#include <utility>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
@@ -27,193 +29,60 @@
namespace mindspore {
namespace kernel {
constexpr size_t kInputXDimSize = 4;
template <typename T>
class InstanceNormGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class InstanceNormGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<InstanceNormGpuKernelMod> {
public:
InstanceNormGpuKernelMod()
: input_x_size_(0),
input_z_size_(0),
para_size_(0),
output_size_(0),
workspace_size_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
exp_avg_factor_(0.1),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
z_desc_(nullptr),
scale_bias_mean_var_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
~InstanceNormGpuKernelMod() override { DestroyResource(); }
InstanceNormGpuKernelMod() = default;
~InstanceNormGpuKernelMod() override {
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR_NOTRACE(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
"Destroy para desc failed");
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
auto x_addr = GetDeviceAddress<T>(inputs, 0);
auto gamma_addr = GetDeviceAddress<float>(inputs, 1);
auto beta_addr = GetDeviceAddress<float>(inputs, 2);
auto running_mean_addr = GetDeviceAddress<float>(inputs, 3);
auto running_variance_addr = GetDeviceAddress<float>(inputs, 4);
T *z = nullptr;
auto y_addr = GetDeviceAddress<T>(outputs, 0);
auto save_mean_addr = GetDeviceAddress<float>(outputs, 1);
auto save_variance_addr = GetDeviceAddress<float>(outputs, 2);
float *ws_gamma = GetDeviceAddress<float>(workspace, 0);
float *ws_beta = GetDeviceAddress<float>(workspace, 1);
float *ws_mean = GetDeviceAddress<float>(workspace, 2);
float *ws_var = GetDeviceAddress<float>(workspace, 3);
T *workspace_addr = GetPossiblyNullDeviceAddress<T>(workspace, 4);
size_t N = input_shape_[0];
size_t C = input_shape_[1];
CopyMemDevice2Device(N, C, gamma_addr, beta_addr, running_mean_addr, running_variance_addr, ws_gamma, ws_beta,
ws_mean, ws_var, reinterpret_cast<cudaStream_t>(stream_ptr));
const float alpha = 1;
const float beta = 0;
float *reserve_addr = nullptr;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnBatchNormalizationForwardTrainingEx(
handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, scale_bias_mean_var_desc_,
ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, save_variance_addr, nullptr,
workspace_addr, workspace_size_, reserve_addr, 0),
"Kernel launch failed");
return true;
stream_ptr_ = reinterpret_cast<cudaStream_t>(stream_ptr);
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
InitResource();
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 5) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 5, but got " << input_num;
}
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (input_shape_.size() != kInputXDimSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of input_x must be 4, but got "
<< input_shape_.size();
}
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input_x");
if (is_null_input_) {
InitSizeLists();
return true;
}
CheckTensorSize({input_shape_});
SetTensorDescriptor();
InitSizeLists();
return true;
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
"Destroy para desc failed");
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
"Create para desc failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_x_size_),
"Get input x size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, &para_size_),
"Get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size_),
"Get output size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(handle_, mode_, bn_ops_, x_desc_, z_desc_, y_desc_,
scale_bias_mean_var_desc_, nullptr, &workspace_size_),
"cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize failed");
}
input_size_list_.push_back(input_x_size_); // input x
input_size_list_.push_back(input_shape_[1]); // gamma
input_size_list_.push_back(input_shape_[1]); // beta
input_size_list_.push_back(input_shape_[1]); // mean
input_size_list_.push_back(input_shape_[1]); // variance
output_size_list_.push_back(output_size_); // output
output_size_list_.push_back(para_size_); // save mean
output_size_list_.push_back(para_size_); // save variance
workspace_size_list_.push_back(para_size_); // ws gamma
workspace_size_list_.push_back(para_size_); // ws beta
workspace_size_list_.push_back(para_size_); // ws mean
workspace_size_list_.push_back(para_size_); // ws variance
workspace_size_list_.push_back(workspace_size_);
}
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
void SetTensorDescriptor() {
cudnnTensorFormat_t cudnn_format;
int batch = 1;
int channel = SizeToInt(input_shape_[0]) * SizeToInt(input_shape_[1]);
int height = SizeToInt(input_shape_[2]);
int width = SizeToInt(input_shape_[3]);
cudnn_format = CUDNN_TENSOR_NCHW;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
static constexpr cudnnBatchNormOps_t bn_ops_{CUDNN_BATCHNORM_OPS_BN};
static constexpr cudnnBatchNormMode_t mode_{CUDNN_BATCHNORM_SPATIAL_PERSISTENT};
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set y desc failed");
size_t batch_{0};
size_t channel_{0};
size_t workspace_size_{0};
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
}
double epsilon_{10e-5};
double exp_avg_factor_{0.1};
bool is_null_input_{false};
cudnnTensorDescriptor_t x_desc_{nullptr};
cudnnTensorDescriptor_t y_desc_{nullptr};
cudnnTensorDescriptor_t z_desc_{nullptr};
cudnnTensorDescriptor_t scale_bias_mean_var_desc_{nullptr};
size_t input_x_size_;
size_t input_z_size_;
size_t para_size_;
size_t output_size_;
size_t workspace_size_;
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
double exp_avg_factor_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;
cudnnTensorDescriptor_t z_desc_;
cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
cudnnHandle_t handle_{nullptr};
cudnnDataType_t cudnn_data_type_{CUDNN_DATA_FLOAT};
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
std::vector<size_t> input_shape_;
cudaStream_t stream_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -538,6 +538,7 @@ GVAR_DEF(PrimitivePtr, kPrimRoll, std::make_shared<Primitive>(kRoll));
GVAR_DEF(PrimitivePtr, kPrimGroupConv2DGradInput, std::make_shared<Primitive>("GroupConv2DGradInput"));
GVAR_DEF(PrimitivePtr, kPrimBatchNorm, std::make_shared<Primitive>("BatchNorm"));
GVAR_DEF(PrimitivePtr, kPrimBatchNormGrad, std::make_shared<Primitive>("BatchNormGrad"));
GVAR_DEF(PrimitivePtr, kPrimInstanceNorm, std::make_shared<Primitive>("InstanceNorm"));
GVAR_DEF(PrimitivePtr, kPrimSyncBatchNorm, std::make_shared<Primitive>("SyncBatchNorm"));
GVAR_DEF(PrimitivePtr, kPrimSyncBatchNormGrad, std::make_shared<Primitive>("SyncBatchNormGrad"));
GVAR_DEF(PrimitivePtr, kPrimBNTrainingReduce, std::make_shared<Primitive>("BNTrainingReduce"));

View File

@@ -28,6 +28,70 @@
namespace mindspore {
namespace ops {
namespace {
abstract::TupleShapePtr InstanceNormInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
const auto input_x_shape_ptr = input_args[kInputIndex0]->BuildShape();
const auto gamma_shape_ptr = input_args[kInputIndex1]->BuildShape();
const auto beta_shape_ptr = input_args[kInputIndex2]->BuildShape();
const auto mean_shape_ptr = input_args[kInputIndex3]->BuildShape();
const auto variance_shape_ptr = input_args[kInputIndex4]->BuildShape();
if (input_x_shape_ptr->IsDynamic() || gamma_shape_ptr->IsDynamic() || beta_shape_ptr->IsDynamic() ||
mean_shape_ptr->IsDynamic() || variance_shape_ptr->IsDynamic()) {
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{input_x_shape_ptr, mean_shape_ptr, mean_shape_ptr});
}
auto input_x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_x_shape_ptr)[kShape];
auto gamma_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(gamma_shape_ptr)[kShape];
auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(beta_shape_ptr)[kShape];
auto mean_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(mean_shape_ptr)[kShape];
auto variance_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(variance_shape_ptr)[kShape];
constexpr size_t minimum_input_x_rank = 3;
(void)CheckAndConvertUtils::CheckValue<size_t>("input_x rank", input_x_shape.size(), kGreaterEqual,
minimum_input_x_rank, prim_name);
const size_t batch = input_x_shape[kInputIndex0];
const size_t channel = input_x_shape[kInputIndex1];
(void)CheckAndConvertUtils::CheckValue<size_t>("gamma rank", gamma_shape.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("beta rank", beta_shape.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("mean rank", mean_shape.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("variance rank", variance_shape.size(), kEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("gamma shape", gamma_shape[0], kEqual, "(C, )", channel, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("beta shape", beta_shape[0], kEqual, "(C, )", channel, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("mean shape", mean_shape[0], kEqual, "(C, )", channel, prim_name);
(void)CheckAndConvertUtils::CheckValue<size_t>("variance shape", variance_shape[0], kEqual, "(C, )", channel,
prim_name);
const int64_t batch_channel = SizeToLong(batch * channel);
abstract::ShapePtr save_mean_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{batch_channel});
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{input_x_shape_ptr, save_mean_shape, save_mean_shape});
}
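// A worked example of the shape rule above, using assumed values: for
// input_x of shape (N, C, L) = (8, 4, 5) and gamma/beta/mean/variance of
// shape (C,) = (4,), the op yields output_y of shape (8, 4, 5) and
// save_mean/save_variance of shape (N * C,) = (32,), one statistic per
// (sample, channel) instance, matching the folded layout used by the GPU
// kernel.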
TuplePtr InstanceNormInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
const auto prim_name = primitive->name();
const auto input_x = input_args[kInputIndex0]->BuildType();
const auto gamma = input_args[kInputIndex1]->BuildType();
const auto beta = input_args[kInputIndex2]->BuildType();
const auto mean = input_args[kInputIndex3]->BuildType();
const auto variance = input_args[kInputIndex4]->BuildType();
(void)CheckAndConvertUtils::CheckTypeValid("input_x", input_x, {kFloat16, kFloat32}, prim_name);
const std::map<std::string, TypePtr> types = {
{"gamma", gamma},
{"beta", beta},
{"mean", mean},
{"variance", variance},
};
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, {kFloat32}, prim_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{input_x, gamma, gamma});
}
} // namespace
MIND_API_OPERATOR_IMPL(InstanceNorm, BaseOperator);
void InstanceNorm::Init(const float epsilon) { this->set_epsilon(epsilon); }
@@ -36,7 +100,21 @@ float InstanceNorm::get_epsilon() const {
auto value_ptr = GetAttr(kEpsilon);
return GetValue<float>(value_ptr);
}
void InstanceNorm::set_momentum(const float momentum) { (void)this->AddAttr(kMomentum, api::MakeValue(momentum)); }
float InstanceNorm::get_momentum() const {
auto value_ptr = GetAttr(kMomentum);
return GetValue<float>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameInstanceNorm, InstanceNorm);
AbstractBasePtr InstanceNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr int64_t kInputNum = 5;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputNum, primitive->name());
auto type = InstanceNormInferType(primitive, input_args);
auto shape = InstanceNormInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(InstanceNorm, prim::kPrimInstanceNorm, InstanceNormInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -47,6 +47,16 @@ class MIND_API InstanceNorm : public BaseOperator {
///
/// \return the epsilon value.
float get_epsilon() const;
/// \brief Method to set momentum attribute.
///
/// \param[in] momentum Define the momentum used to update the running mean and variance.
void set_momentum(const float momentum);
/// \brief Method to get momentum attribute.
///
/// \return the momentum value.
float get_momentum() const;
};
} // namespace ops
} // namespace mindspore

View File

@@ -1056,28 +1056,6 @@ class InstanceNorm(PrimitiveWithInfer):
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
def infer_shape(self, input_x, gamma, beta, mean, variance):
input_shape_norm = input_x
validator.check_equal_int(len(gamma), 1, "gamma rank", self.name)
validator.check("gamma shape", gamma, "beta shape", beta, Rel.EQ, self.name)
validator.check("gamma shape[0]", gamma[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "gamma shape", gamma, Rel.EQ, self.name)
save_mean_shape = gamma
save_mean_shape[0] = save_mean_shape[0] * input_shape_norm[0]
return input_x, save_mean_shape, save_mean_shape
def infer_dtype(self, input_x, gamma, beta, mean, variance):
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
args = {"gamma": gamma, "beta": beta}
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
args_moving = {"mean": mean, "variance": variance}
valid_dtypes = [mstype.tensor_type(mstype.float32)]
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
return input_x, gamma, gamma
class BNTrainingReduce(Primitive):
"""

View File

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
from mindspore import Tensor, Parameter, nn
from mindspore.ops.operations.nn_ops import InstanceNorm
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class InstanceNormNet(nn.Cell):
def __init__(self, channel, epsilon=1e-5):
super(InstanceNormNet, self).__init__()
self.instance_norm = InstanceNorm(epsilon=epsilon)
self.gamma = Parameter(Tensor(np.ones([channel]), mindspore.float32), name="gamma")
self.beta = Parameter(Tensor(np.zeros([channel]), mindspore.float32), name="beta")
self.mean = Parameter(Tensor(np.zeros([channel]), mindspore.float32), name="mean")
self.variance = Parameter(Tensor(np.ones([channel]), mindspore.float32), name="variance")
def construct(self, input_x):
out = self.instance_norm(input_x, self.gamma, self.beta, self.mean, self.variance)
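# InstanceNorm returns the tuple (output_y, save_mean, save_variance);
# only the normalized output is checked here.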
return out[0]
def instance_norm_np(x, eps=1e-5):
shape = x.shape
b = shape[0]
c = shape[1]
x = x.reshape((b, c, -1))
mu = np.expand_dims(np.mean(x, axis=-1), axis=-1)
std = np.expand_dims(np.std(x, axis=-1), axis=-1)
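# Note: this reference divides by (std + eps) rather than the canonical
# sqrt(var + eps); with eps = 1e-5 the two agree to well within the
# tolerances used by the tests below.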
result = (x - mu) / (std + eps)
return result.reshape(shape)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("shape", [(8, 4, 5)])
@pytest.mark.parametrize("data_type, err", [(np.float16, 1e-3), (np.float32, 1e-4)])
def test_instancenorm_1d(shape, data_type, err):
"""
Feature: InstanceNorm 1D operator.
Description: Compatible with instance_norm_np.
Expectation: The result matches numpy implementation.
"""
np.random.seed(0)
input_x_np = np.random.randn(np.prod(shape)).reshape(shape).astype(data_type)
input_x = Tensor(input_x_np)
net = InstanceNormNet(shape[1])
output = net(input_x)
expected = instance_norm_np(input_x_np)
assert np.allclose(output.asnumpy(), expected, atol=err, rtol=err)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("shape", [(8, 4, 3, 4)])
@pytest.mark.parametrize("data_type, err", [(np.float16, 1e-3), (np.float32, 1e-4)])
def test_instancenorm_2d(shape, data_type, err):
"""
Feature: InstanceNorm 2D operator.
Description: Compatible with instance_norm_np.
Expectation: The result matches numpy implementation.
"""
np.random.seed(0)
input_x_np = np.random.randn(np.prod(shape)).reshape(shape).astype(data_type)
input_x = Tensor(input_x_np)
net = InstanceNormNet(shape[1])
output = net(input_x)
expected = instance_norm_np(input_x_np)
assert np.allclose(output.asnumpy(), expected, atol=err, rtol=err)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("shape", [(8, 4, 3, 4, 7)])
@pytest.mark.parametrize("data_type, err", [(np.float16, 1e-3), (np.float32, 1e-4)])
def test_instancenorm_3d(shape, data_type, err):
"""
Feature: InstanceNorm 3D operator.
Description: Compatible with instance_norm_np.
Expectation: The result matches numpy implementation.
"""
np.random.seed(0)
input_x_np = np.random.randn(np.prod(shape)).reshape(shape).astype(data_type)
input_x = Tensor(input_x_np)
net = InstanceNormNet(shape[1])
output = net(input_x)
expected = instance_norm_np(input_x_np)
assert np.allclose(output.asnumpy(), expected, atol=err, rtol=err)