forked from mindspore-Ecosystem/mindspore
IR operators of GPU and CPU are unified as batchnorm
This commit is contained in:
parent 172a28fe14
commit 87e41aaeee

@ -16,6 +16,10 @@ Previously MakeRefKey is an external interface that is not used, now make it an

Previously, the number of outputs of these operators differed across backends. To unify their definition, their outputs on the Ascend backend have been changed from multiple tensors to a single one.

##### `P.FusedBatchNorm`, `P.FusedBatchNormEx` deleted ([!12115](https://gitee.com/mindspore/mindspore/pulls/12115))

The `FusedBatchNorm` and `FusedBatchNormEx` interfaces have been deleted. Please use the `BatchNorm` operator instead.

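Below is a minimal migration sketch, not taken from this commit: it assumes the MindSpore 1.2-era Python API in which `ops.BatchNorm` carries an `is_training` attribute and returns a five-element tuple, so exact parameter names and defaults should be checked against the release you use.

```python
# Hypothetical migration sketch: replacing the deleted P.FusedBatchNorm with the
# unified P.BatchNorm primitive (assumed MindSpore 1.2-era API).
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# NCHW input and per-channel parameters; shapes are illustrative only.
x     = Tensor(np.random.randn(2, 3, 4, 4).astype(np.float32))
scale = Tensor(np.ones(3, np.float32))
bias  = Tensor(np.zeros(3, np.float32))
mean  = Tensor(np.zeros(3, np.float32))
var   = Tensor(np.ones(3, np.float32))

# Before (deleted): bn = P.FusedBatchNorm(epsilon=1e-5, momentum=0.1)
# After: training vs. inference is selected with the is_training attribute
# instead of choosing a separate fused operator.
bn = P.BatchNorm(is_training=True, epsilon=1e-5, momentum=0.1)
y, batch_mean, batch_var, reserve_1, reserve_2 = bn(x, scale, bias, mean, var)
print(y.shape)  # (2, 3, 4, 4)
```

For model code, the higher-level `nn.BatchNorm2d` cell remains the usual choice when running statistics should be tracked automatically; the primitive above is the direct replacement inside custom cells that previously called `P.FusedBatchNorm`.
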
# MindSpore 1.1.1 Release Notes
## MindSpore
@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@ -14,14 +14,14 @@
* limitations under the License.
*/
#include <string>
#include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace kernel {
void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
void BatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
size_t type_size = sizeof(float);

@ -30,16 +30,13 @@ void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
workspace_size_list_.emplace_back(tensor_size);
}

void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "FusedBatchNorm") {
momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
is_train = true;
}
is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) {
MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
}
batch_size = x_shape[0];
channel = x_shape[1];

@ -66,9 +63,9 @@ void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DST, x_desc);
}

bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
bool BatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 5 || outputs.empty()) {
MS_LOG(EXCEPTION) << "Error input output size!";
}

@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,18 +13,18 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_CPU_KERNEL_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class FusedBatchNormCPUKernel : public MKLCPUKernel {
|
||||
class BatchNormCPUKernel : public MKLCPUKernel {
|
||||
public:
|
||||
FusedBatchNormCPUKernel() = default;
|
||||
~FusedBatchNormCPUKernel() override = default;
|
||||
BatchNormCPUKernel() = default;
|
||||
~BatchNormCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
|
@ -43,20 +43,6 @@ class FusedBatchNormCPUKernel : public MKLCPUKernel {
|
|||
size_t nhw_size{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(FusedBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormCPUKernel)
|
||||
|
||||
MS_REG_CPU_KERNEL(BatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -69,7 +55,7 @@ MS_REG_CPU_KERNEL(BatchNorm,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormCPUKernel)
|
||||
BatchNormCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,7 +13,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_gard_cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/batch_norm_gard_cpu_kernel.h"
|
||||
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
|
@ -22,19 +22,20 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void FusedBatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
CPUKernel::InitInputOutputSize(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t type_size = sizeof(float);
|
||||
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t tensor_size = shape[1] * 2 * type_size;
|
||||
input_size_list_.pop_back();
|
||||
// [2, c] to store scale and bias
|
||||
workspace_size_list_.emplace_back(tensor_size);
|
||||
// [2, c] to store diff_scale and diff_bias
|
||||
workspace_size_list_.emplace_back(tensor_size);
|
||||
}
|
||||
|
||||
void FusedBatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (x_shape.size() != 4) {
|
||||
|
@ -72,25 +73,25 @@ void FusedBatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
AddArgument(DNNL_ARG_DIFF_SCALE_SHIFT, scale_bias_desc);
|
||||
}
|
||||
|
||||
bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool BatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.size() < 5 || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Error input output size!";
|
||||
}
|
||||
auto wksp_in = reinterpret_cast<float *>(workspace[0]->addr);
|
||||
auto scale_ret = memcpy_s(wksp_in, workspace[0]->size, inputs[2]->addr, inputs[2]->size);
|
||||
auto max_size = workspace[0]->size - inputs[2]->size;
|
||||
auto bias_ret = memcpy_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, inputs[3]->addr, inputs[3]->size);
|
||||
if (scale_ret != 0 || bias_ret != 0) {
|
||||
auto bias_ret = memset_s(wksp_in + (inputs[2]->size / sizeof(float)), max_size, 0., max_size);
|
||||
if (scale_ret != 0 && bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy_s error.";
|
||||
return false;
|
||||
}
|
||||
|
||||
SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_MEAN, inputs[4]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[5]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_MEAN, inputs[3]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_VARIANCE, inputs[4]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DIFF_SCALE_SHIFT, workspace[1]->addr);
|
||||
|
@ -99,7 +100,7 @@ bool FusedBatchNormGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &
|
|||
auto wksp_out = reinterpret_cast<float *>(workspace[1]->addr);
|
||||
auto diff_scale_ret = memcpy_s(outputs[1]->addr, outputs[1]->size, wksp_out, inputs[2]->size);
|
||||
auto diff_bias_ret =
|
||||
memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), inputs[3]->size);
|
||||
memcpy_s(outputs[2]->addr, outputs[2]->size, wksp_out + (outputs[1]->size / sizeof(float)), outputs[2]->size);
|
||||
if (diff_scale_ret != 0 || diff_bias_ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "Memcpy_s error.";
|
||||
return false;
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,18 +13,18 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_GRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_GRAD_CPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_GRAD_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BATCH_NORM_GRAD_CPU_KERNEL_H_
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class FusedBatchNormGradCPUKernel : public MKLCPUKernel {
|
||||
class BatchNormGradCPUKernel : public MKLCPUKernel {
|
||||
public:
|
||||
FusedBatchNormGradCPUKernel() = default;
|
||||
~FusedBatchNormGradCPUKernel() override = default;
|
||||
BatchNormGradCPUKernel() = default;
|
||||
~BatchNormGradCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
|
@ -42,7 +42,7 @@ class FusedBatchNormGradCPUKernel : public MKLCPUKernel {
|
|||
size_t nhw_size{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(FusedBatchNormGradCPU,
|
||||
MS_REG_CPU_KERNEL(BatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -53,7 +53,7 @@ MS_REG_CPU_KERNEL(FusedBatchNormGradCPU,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGradCPUKernel)
|
||||
BatchNormGradCPUKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/nn/batch_norm_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -29,10 +29,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormExGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
|
||||
BatchNormGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -43,11 +42,10 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormEx,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormExGpuKernel, half)
|
||||
BatchNormGpuKernel, half)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormWithActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -58,10 +56,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormExGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
|
||||
BatchNormGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormWithActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -72,11 +69,10 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormExGpuKernel, half)
|
||||
BatchNormGpuKernel, half)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormWithAddAndActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -88,10 +84,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormExGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
|
||||
BatchNormGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormWithAddAndActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
|
@ -103,8 +98,7 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormExWithAddAndActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormExGpuKernel, half)
|
||||
BatchNormGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GPU_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -27,10 +27,10 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class FusedBatchNormExGpuKernel : public GpuKernel {
|
||||
class BatchNormGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FusedBatchNormExGpuKernel() { ResetResource(); }
|
||||
~FusedBatchNormExGpuKernel() override { DestroyResource(); }
|
||||
BatchNormGpuKernel() { ResetResource(); }
|
||||
~BatchNormGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
|
@ -46,30 +46,38 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
auto x = GetDeviceAddress<T>(inputs, 0);
|
||||
auto scale = GetDeviceAddress<float>(inputs, 1);
|
||||
auto bias = GetDeviceAddress<float>(inputs, 2);
|
||||
auto runing_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto runnig_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
auto running_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto running_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
T *z = nullptr;
|
||||
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
|
||||
z = GetDeviceAddress<T>(inputs, 5);
|
||||
}
|
||||
|
||||
auto y = GetDeviceAddress<T>(outputs, 0);
|
||||
auto save_mean = GetDeviceAddress<float>(outputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(outputs, 4);
|
||||
auto reserve_addr = GetDeviceAddress<float>(outputs, 5);
|
||||
auto reserve_addr = GetDeviceAddress<float>(outputs, 2);
|
||||
T *workspace_addr = nullptr;
|
||||
if (workspace_size_ != 0) {
|
||||
workspace_addr = GetDeviceAddress<T>(workspace, 0);
|
||||
}
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnBatchNormalizationForwardTrainingEx(handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x, z_desc_, z, y_desc_,
|
||||
y, scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean,
|
||||
runnig_variance, epsilon_, save_mean, save_variance, activation_desc_,
|
||||
workspace_addr, workspace_size_, reserve_addr, reserve_size_),
|
||||
"Kernel launch failed");
|
||||
if (is_train_) {
|
||||
auto save_mean = GetDeviceAddress<float>(outputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(outputs, 4);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnBatchNormalizationForwardTrainingEx(
|
||||
handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x, z_desc_, z, y_desc_, y, scale_bias_mean_var_desc_, scale,
|
||||
bias, exp_avg_factor_, running_mean, running_variance, epsilon_, save_mean, save_variance, activation_desc_,
|
||||
workspace_addr, workspace_size_, reserve_addr, reserve_size_),
|
||||
"Kernel launch failed");
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnBatchNormalizationForwardInference(
|
||||
handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, scale_bias_mean_var_desc_,
|
||||
scale, bias, running_mean, running_variance, epsilon_),
|
||||
"Kernel launch failed");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -77,18 +85,22 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
kernel_node_ = kernel_node;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name == kFusedBatchNormEx) {
|
||||
if (kernel_name == kBatchNorm) {
|
||||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
|
||||
} else if (kernel_name == kFusedBatchNormExWithActivation) {
|
||||
} else if (kernel_name == kBatchNormWithActivation) {
|
||||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
|
||||
} else if (kernel_name == kFusedBatchNormExWithAddAndActivation) {
|
||||
} else if (kernel_name == kBatchNormWithAddAndActivation) {
|
||||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name;
|
||||
}
|
||||
|
||||
InitResource();
|
||||
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
|
||||
if (is_train_) {
|
||||
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
|
||||
} else {
|
||||
mode_ = CUDNN_BATCHNORM_SPATIAL;
|
||||
}
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
|
||||
|
||||
|
@ -106,11 +118,11 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormExGpuKernel should be 4";
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGpuKernel should be 4";
|
||||
}
|
||||
is_null_input_ = CHECK_NULL_INPUT(shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "FusedBatchNormExGpuKernel input is null";
|
||||
MS_LOG(WARNING) << "BatchNormGpuKernel input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -121,6 +133,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
}
|
||||
SetTensorDescriptor(format, shape);
|
||||
InitSizeLists();
|
||||
is_train_ = GetAttr<bool>(kernel_node, "is_training");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -135,6 +148,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
|
||||
epsilon_ = 10e-5;
|
||||
exp_avg_factor_ = 0.1;
|
||||
is_train_ = false;
|
||||
is_null_input_ = false;
|
||||
x_desc_ = nullptr;
|
||||
y_desc_ = nullptr;
|
||||
|
@ -215,11 +229,10 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
}
|
||||
|
||||
output_size_list_.push_back(output_size_); // output
|
||||
output_size_list_.push_back(reserve_size_); // reserve space
|
||||
output_size_list_.push_back(para_size_); // save scale
|
||||
output_size_list_.push_back(para_size_); // save bias
|
||||
output_size_list_.push_back(para_size_); // save mean
|
||||
output_size_list_.push_back(para_size_); // save variance
|
||||
output_size_list_.push_back(reserve_size_); // reserve space
|
||||
|
||||
workspace_size_list_.push_back(workspace_size_);
|
||||
}
|
||||
|
@ -280,6 +293,7 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
|
|||
cudnnBatchNormOps_t bn_ops_;
|
||||
double epsilon_;
|
||||
double exp_avg_factor_;
|
||||
bool is_train_;
|
||||
bool is_null_input_;
|
||||
cudnnTensorDescriptor_t x_desc_;
|
||||
cudnnTensorDescriptor_t y_desc_;
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,11 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/nn/batch_norm_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // dy
|
||||
.AddInputAttr(kNumberTypeFloat32) // x
|
||||
|
@ -29,8 +29,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
|
|||
.AddOutputAttr(kNumberTypeFloat32) // dx
|
||||
.AddOutputAttr(kNumberTypeFloat32) // dscale
|
||||
.AddOutputAttr(kNumberTypeFloat32), // dbias
|
||||
FusedBatchNormGradExGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
|
||||
BatchNormGradGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16) // dy
|
||||
.AddInputAttr(kNumberTypeFloat16) // x
|
||||
|
@ -41,9 +41,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
|
|||
.AddOutputAttr(kNumberTypeFloat16) // dx
|
||||
.AddOutputAttr(kNumberTypeFloat32) // dscale
|
||||
.AddOutputAttr(kNumberTypeFloat32), // dbias
|
||||
FusedBatchNormGradExGpuKernel, half)
|
||||
BatchNormGradGpuKernel, half)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGradWithActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // dy
|
||||
.AddInputAttr(kNumberTypeFloat32) // x
|
||||
|
@ -56,8 +56,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32) // dx
|
||||
.AddOutputAttr(kNumberTypeFloat32) // dscale
|
||||
.AddOutputAttr(kNumberTypeFloat32), // dbias
|
||||
FusedBatchNormGradExGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
|
||||
BatchNormGradGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGradWithActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16) // dy
|
||||
.AddInputAttr(kNumberTypeFloat16) // x
|
||||
|
@ -70,9 +70,9 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat16) // dx
|
||||
.AddOutputAttr(kNumberTypeFloat32) // dscale
|
||||
.AddOutputAttr(kNumberTypeFloat32), // dbias
|
||||
FusedBatchNormGradExGpuKernel, half)
|
||||
BatchNormGradGpuKernel, half)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGradWithAddAndActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32) // dy
|
||||
.AddInputAttr(kNumberTypeFloat32) // x
|
||||
|
@ -86,8 +86,8 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32) // dscale
|
||||
.AddOutputAttr(kNumberTypeFloat32) // dbias
|
||||
.AddOutputAttr(kNumberTypeFloat32), // dz
|
||||
FusedBatchNormGradExGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
|
||||
BatchNormGradGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGradWithAddAndActivation,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16) // dy
|
||||
.AddInputAttr(kNumberTypeFloat16) // x
|
||||
|
@ -101,6 +101,6 @@ MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
|
|||
.AddOutputAttr(kNumberTypeFloat32) // dscale
|
||||
.AddOutputAttr(kNumberTypeFloat32) // dbias
|
||||
.AddOutputAttr(kNumberTypeFloat16), // dz
|
||||
FusedBatchNormGradExGpuKernel, half)
|
||||
BatchNormGradGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCH_NORM_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -24,13 +24,14 @@
|
|||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
||||
class BatchNormGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FusedBatchNormGradExGpuKernel()
|
||||
BatchNormGradGpuKernel()
|
||||
: x_size_(0),
|
||||
para_size_(0),
|
||||
workspace_size_(0),
|
||||
|
@ -38,6 +39,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
mode_(CUDNN_BATCHNORM_SPATIAL),
|
||||
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
|
||||
epsilon_(10e-5),
|
||||
is_train_(false),
|
||||
is_null_input_(false),
|
||||
x_desc_(nullptr),
|
||||
y_desc_(nullptr),
|
||||
|
@ -49,7 +51,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
handle_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT),
|
||||
beta_data_diff_(0) {}
|
||||
~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }
|
||||
~BatchNormGradGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
|
@ -88,17 +90,22 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
if (workspace_size_ != 0) {
|
||||
workspace_addr = GetDeviceAddress<T>(workspace, 0);
|
||||
}
|
||||
|
||||
const float alpha_data_diff = 1;
|
||||
const float alpha_param_diff = 1;
|
||||
const float beta_param_diff = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnBatchNormalizationBackwardEx(
|
||||
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_, &alpha_param_diff,
|
||||
&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
|
||||
scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance,
|
||||
activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_),
|
||||
"Kernel launch failed");
|
||||
if (is_train_) {
|
||||
const float alpha_data_diff = 1;
|
||||
const float alpha_param_diff = 1;
|
||||
const float beta_param_diff = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnBatchNormalizationBackwardEx(handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff_,
|
||||
&alpha_param_diff, &beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy,
|
||||
dz_desc_, dz, dx_desc_, dx, scale_bias_diff_desc_, scale, bias, dscale, dbias,
|
||||
epsilon_, save_mean, save_variance, activation_desc_, workspace_addr,
|
||||
workspace_size_, reserve_addr, reserve_size_),
|
||||
"Kernel launch failed");
|
||||
} else {
|
||||
CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, dscale, dbias, epsilon_, batch_, channel_, height_,
|
||||
width_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -106,11 +113,11 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
kernel_node_ = kernel_node;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name == kFusedBatchNormGradEx) {
|
||||
if (kernel_name == kBatchNormGradOpName) {
|
||||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
|
||||
} else if (kernel_name == kFusedBatchNormGradExWithActivation) {
|
||||
} else if (kernel_name == kBatchNormGradWithActivation) {
|
||||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
|
||||
} else if (kernel_name == kFusedBatchNormGradExWithAddAndActivation) {
|
||||
} else if (kernel_name == kBatchNormGradWithAddAndActivation) {
|
||||
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name;
|
||||
|
@ -134,11 +141,11 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradExGpuKernel should be 4";
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
|
||||
}
|
||||
is_null_input_ = CHECK_NULL_INPUT(shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "FusedBatchNormGradExGpuKernel input is null";
|
||||
MS_LOG(WARNING) << "BatchNormGradGpuKernel input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -150,6 +157,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
|
||||
SetTensorDescriptor(format, shape);
|
||||
InitSizeLists();
|
||||
is_train_ = GetAttr<bool>(kernel_node, "is_training");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -225,50 +233,52 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
private:
|
||||
void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
|
||||
cudnnTensorFormat_t cudnn_format;
|
||||
int batch, channel, height, width;
|
||||
if (format == kOpFormat_NHWC) {
|
||||
batch = SizeToInt(shape[0]);
|
||||
height = SizeToInt(shape[1]);
|
||||
width = SizeToInt(shape[2]);
|
||||
channel = SizeToInt(shape[3]);
|
||||
batch_ = SizeToInt(shape[0]);
|
||||
height_ = SizeToInt(shape[1]);
|
||||
width_ = SizeToInt(shape[2]);
|
||||
channel_ = SizeToInt(shape[3]);
|
||||
cudnn_format = CUDNN_TENSOR_NHWC;
|
||||
} else {
|
||||
batch = SizeToInt(shape[0]);
|
||||
channel = SizeToInt(shape[1]);
|
||||
height = SizeToInt(shape[2]);
|
||||
width = SizeToInt(shape[3]);
|
||||
batch_ = SizeToInt(shape[0]);
|
||||
channel_ = SizeToInt(shape[1]);
|
||||
height_ = SizeToInt(shape[2]);
|
||||
width_ = SizeToInt(shape[3]);
|
||||
cudnn_format = CUDNN_TENSOR_NCHW;
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set x desc failed");
|
||||
|
||||
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
|
||||
"Set z desc failed");
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
|
||||
"Set dy desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_, cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
|
||||
"Set dx desc failed");
|
||||
|
||||
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
|
||||
cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set z desc failed");
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
|
||||
cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dy desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dx desc failed");
|
||||
|
||||
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set z desc failed");
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
|
||||
"Set para desc failed");
|
||||
|
||||
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
|
||||
|
@ -278,7 +288,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
"cudnnSetActivationDescriptor failed");
|
||||
}
|
||||
}
|
||||
|
||||
int batch_;
|
||||
int channel_;
|
||||
int height_;
|
||||
int width_;
|
||||
size_t x_size_;
|
||||
size_t para_size_;
|
||||
size_t workspace_size_;
|
||||
|
@ -286,6 +299,7 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
|
|||
cudnnBatchNormMode_t mode_;
|
||||
cudnnBatchNormOps_t bn_ops_;
|
||||
double epsilon_;
|
||||
bool is_train_;
|
||||
bool is_null_input_;
|
||||
|
||||
cudnnTensorDescriptor_t x_desc_;
|
|
@ -1,48 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/batchnorm_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
BatchNormGradGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
BatchNormGradGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,202 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_grad_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class BatchNormGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
BatchNormGradGpuKernel()
|
||||
: batch_(0),
|
||||
channel_(0),
|
||||
height_(0),
|
||||
width_(0),
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL),
|
||||
epsilon_(10e-5),
|
||||
is_null_input_(false),
|
||||
x_desc_(nullptr),
|
||||
dy_desc_(nullptr),
|
||||
dx_desc_(nullptr),
|
||||
scale_bias_desc_(nullptr),
|
||||
handle_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
|
||||
~BatchNormGradGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
auto dy = GetDeviceAddress<T>(inputs, 0);
|
||||
auto x = GetDeviceAddress<T>(inputs, 1);
|
||||
auto scale = GetDeviceAddress<float>(inputs, 2);
|
||||
auto save_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
auto dx = GetDeviceAddress<T>(outputs, 0);
|
||||
auto bn_scale = GetDeviceAddress<float>(outputs, 1);
|
||||
auto bn_bias = GetDeviceAddress<float>(outputs, 2);
|
||||
auto reserve_1 = GetDeviceAddress<T>(outputs, 3);
|
||||
auto reserve_2 = GetDeviceAddress<T>(outputs, 4);
|
||||
// For CI only, reserved vars can not be unused.
|
||||
MS_LOG(DEBUG) << reinterpret_cast<size_t>(reserve_1) << reinterpret_cast<size_t>(reserve_2); // NOLINT
|
||||
|
||||
if (is_training_) {
|
||||
const float alpha_data_diff = 1;
|
||||
const float beta_data_diff = 0;
|
||||
const float alpha_param_diff = 1;
|
||||
const float beta_param_diff = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
|
||||
&beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_,
|
||||
scale, bn_scale, bn_bias, epsilon_, save_mean, save_variance),
|
||||
"Kernel Launch Failed.");
|
||||
} else {
|
||||
CalBatchNormGrad(x, dy, scale, save_mean, save_variance, dx, bn_scale, bn_bias, epsilon_, batch_, channel_,
|
||||
height_, width_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
InitResource();
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 5) {
|
||||
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", BatchNormGradGpuKernel should be 5";
|
||||
}
|
||||
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", BatchNormGradGpuKernel should be 4";
|
||||
return false;
|
||||
}
|
||||
is_null_input_ = CHECK_NULL_INPUT(shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "BatchNormGradGpuKernel input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
batch_ = SizeToInt(shape[0]);
|
||||
channel_ = SizeToInt(shape[1]);
|
||||
height_ = SizeToInt(shape[2]);
|
||||
width_ = SizeToInt(shape[3]);
|
||||
|
||||
mode_ = CUDNN_BATCHNORM_SPATIAL;
|
||||
is_training_ = GetAttr<bool>(kernel_node, "is_training");
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set x desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dy desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
|
||||
"Set para desc failed");
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void DestroyResource() noexcept override {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_desc_),
|
||||
"Destroy para desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_desc_),
|
||||
"Create para desc failed");
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = 0;
|
||||
size_t para_size = 0;
|
||||
if (!is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size),
|
||||
"Get input size failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size),
|
||||
"Get input size failed");
|
||||
}
|
||||
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(para_size);
|
||||
input_size_list_.push_back(para_size);
|
||||
input_size_list_.push_back(para_size);
|
||||
|
||||
output_size_list_.push_back(input_size);
|
||||
output_size_list_.push_back(para_size);
|
||||
output_size_list_.push_back(para_size);
|
||||
output_size_list_.push_back(input_size);
|
||||
output_size_list_.push_back(input_size);
|
||||
}
|
||||
|
||||
private:
|
||||
int batch_;
|
||||
int channel_;
|
||||
int height_;
|
||||
int width_;
|
||||
|
||||
cudnnBatchNormMode_t mode_;
|
||||
bool is_training_;
|
||||
double epsilon_;
|
||||
bool is_null_input_;
|
||||
cudnnTensorDescriptor_t x_desc_;
|
||||
cudnnTensorDescriptor_t dy_desc_;
|
||||
cudnnTensorDescriptor_t dx_desc_;
|
||||
cudnnTensorDescriptor_t scale_bias_desc_;
|
||||
|
||||
cudnnHandle_t handle_;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BATCHNORM_GRAD_GPU_KERNEL_H_
|
|
@ -1,74 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(BatchNorm,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,204 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class FusedBatchNormGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FusedBatchNormGpuKernel()
|
||||
: batch_(0),
|
||||
channel_(0),
|
||||
height_(0),
|
||||
width_(0),
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL),
|
||||
epsilon_(10e-5),
|
||||
exp_avg_factor_(0.1),
|
||||
is_train_(false),
|
||||
is_null_input_(false),
|
||||
x_desc_(nullptr),
|
||||
y_desc_(nullptr),
|
||||
scale_bias_mean_var_desc_(nullptr),
|
||||
handle_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
|
||||
~FusedBatchNormGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
auto x = GetDeviceAddress<T>(inputs, 0);
|
||||
auto scale = GetDeviceAddress<float>(inputs, 1);
|
||||
auto bias = GetDeviceAddress<float>(inputs, 2);
|
||||
auto runing_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto runnig_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
auto y = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
if (is_train_) {
|
||||
auto save_mean = GetDeviceAddress<float>(outputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(outputs, 4);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y,
|
||||
scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean,
|
||||
runnig_variance, epsilon_, save_mean, save_variance),
|
||||
"Kernel launch failed");
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x,
|
||||
y_desc_, y, scale_bias_mean_var_desc_, scale,
|
||||
bias, runing_mean, runnig_variance, epsilon_),
|
||||
"Kernel launch failed");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
InitResource();
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 5) {
|
||||
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5";
|
||||
}
|
||||
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGpuKernel should be >= 4";
|
||||
}
|
||||
is_null_input_ = CHECK_NULL_INPUT(shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;
|
||||
auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
auto format_attr = GetAttr<std::string>(kernel_node, "format");
|
||||
if (format_attr == kOpFormat_NHWC) {
|
||||
format = kOpFormat_NHWC;
|
||||
cudnn_format = CUDNN_TENSOR_NHWC;
|
||||
}
|
||||
SetNCHW(shape, &batch_, &channel_, &height_, &width_, format);
|
||||
mode_ = CUDNN_BATCHNORM_SPATIAL;
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
// P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
|
||||
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (node_name == "FusedBatchNorm") {
|
||||
is_train_ = true;
|
||||
exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
|
||||
}
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set x desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set y desc failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, cudnn_format, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
|
||||
"Set para desc failed");
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void DestroyResource() noexcept override {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_),
|
||||
"Destroy para desc failed");
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_),
|
||||
"Create para desc failed");
|
||||
}
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = 0;
|
||||
size_t para_size = 0;
|
||||
size_t output_size = 0;
|
||||
if (!is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size),
|
||||
"Get input size failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, ¶_size),
|
||||
"Get para size failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_desc_, &output_size),
|
||||
"Get para size failed");
|
||||
}
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(para_size); // scale
|
||||
input_size_list_.push_back(para_size); // bias
|
||||
input_size_list_.push_back(para_size); // mean
|
||||
input_size_list_.push_back(para_size); // variance
|
||||
|
||||
output_size_list_.push_back(output_size);
|
||||
output_size_list_.push_back(para_size); // running mean
|
||||
output_size_list_.push_back(para_size); // running variance
|
||||
output_size_list_.push_back(para_size); // save mean
|
||||
output_size_list_.push_back(para_size); // save variance
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
int batch_;
|
||||
int channel_;
|
||||
int height_;
|
||||
int width_;
|
||||
cudnnBatchNormMode_t mode_;
|
||||
double epsilon_;
|
||||
double exp_avg_factor_;
|
||||
bool is_train_;
|
||||
bool is_null_input_;
|
||||
cudnnTensorDescriptor_t x_desc_;
|
||||
cudnnTensorDescriptor_t y_desc_;
|
||||
cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
|
||||
cudnnHandle_t handle_;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGradGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
FusedBatchNormGradGpuKernel, half)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,188 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class FusedBatchNormGradGpuKernel : public GpuKernel {
|
||||
public:
|
||||
FusedBatchNormGradGpuKernel()
|
||||
: batch_(0),
|
||||
channel_(0),
|
||||
height_(0),
|
||||
width_(0),
|
||||
mode_(CUDNN_BATCHNORM_SPATIAL),
|
||||
epsilon_(10e-5),
|
||||
is_null_input_(false),
|
||||
x_desc_(nullptr),
|
||||
dy_desc_(nullptr),
|
||||
dx_desc_(nullptr),
|
||||
scale_bias_desc_(nullptr),
|
||||
handle_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
|
||||
~FusedBatchNormGradGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
VARIABLE_NOT_USED(stream_ptr);
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
auto dy = GetDeviceAddress<T>(inputs, 0);
|
||||
auto x = GetDeviceAddress<T>(inputs, 1);
|
||||
auto scale = GetDeviceAddress<float>(inputs, 2);
|
||||
auto save_mean = GetDeviceAddress<float>(inputs, 3);
|
||||
auto save_variance = GetDeviceAddress<float>(inputs, 4);
|
||||
auto dx = GetDeviceAddress<T>(outputs, 0);
|
||||
auto bn_scale = GetDeviceAddress<float>(outputs, 1);
|
||||
auto bn_bias = GetDeviceAddress<float>(outputs, 2);
|
||||
|
||||
const float alpha_data_diff = 1;
|
||||
const float beta_data_diff = 0;
|
||||
const float alpha_param_diff = 1;
|
||||
const float beta_param_diff = 0;
|
||||
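// cudnnBatchNormalizationBackward computes dx together with the scale/bias gradients (bn_scale/bn_bias)
// from dy, using the mean/variance saved during the forward pass.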
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
|
||||
&beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale,
|
||||
bn_scale, bn_bias, epsilon_, save_mean, save_variance),
|
||||
"Kernel Launch Failed.");
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
InitResource();
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 5) {
|
||||
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5";
|
||||
}
|
||||
|
||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4";
|
||||
return false;
|
||||
}
|
||||
is_null_input_ = CHECK_NULL_INPUT(shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
batch_ = SizeToInt(shape[0]);
|
||||
channel_ = SizeToInt(shape[1]);
|
||||
height_ = SizeToInt(shape[2]);
|
||||
width_ = SizeToInt(shape[3]);
|
||||
|
||||
mode_ = CUDNN_BATCHNORM_SPATIAL;
|
||||
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set x desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dy desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
|
||||
"Set dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
|
||||
"Set para desc failed");
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void DestroyResource() noexcept override {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(scale_bias_desc_),
|
||||
"Destroy para desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() override {
|
||||
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&scale_bias_desc_),
|
||||
"Create para desc failed");
|
||||
}
|
||||
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = 0;
|
||||
size_t para_size = 0;
|
||||
if (!is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &input_size),
|
||||
"Get input size failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size),
|
||||
"Get input size failed");
|
||||
}
|
||||
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(input_size);
|
||||
input_size_list_.push_back(para_size);
|
||||
input_size_list_.push_back(para_size);
|
||||
input_size_list_.push_back(para_size);
|
||||
|
||||
output_size_list_.push_back(input_size);
|
||||
output_size_list_.push_back(para_size);
|
||||
output_size_list_.push_back(para_size);
|
||||
}
|
||||
|
||||
private:
|
||||
int batch_;
|
||||
int channel_;
|
||||
int height_;
|
||||
int width_;
|
||||
|
||||
cudnnBatchNormMode_t mode_;
|
||||
double epsilon_;
|
||||
bool is_null_input_;
|
||||
cudnnTensorDescriptor_t x_desc_;
|
||||
cudnnTensorDescriptor_t dy_desc_;
|
||||
cudnnTensorDescriptor_t dx_desc_;
|
||||
cudnnTensorDescriptor_t scale_bias_desc_;
|
||||
|
||||
cudnnHandle_t handle_;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_
|
|
@ -34,7 +34,12 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
auto bn_grad_inputs = bn_grad_node->inputs();
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
if (AnfAlgo::CheckPrimitiveType(bn_grad_node, prim::kPrimBatchNormGrad)) {
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
} else {
|
||||
CheckCNodeInputSize(bn_grad_node, kSyncBNGradInputTensorNum);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> bn_update_grad_inputs = {
|
||||
NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
|
||||
bn_grad_inputs[4], bn_grad_inputs[5]};
|
||||
|
@ -57,7 +62,12 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
auto bn_grad_inputs = bn_grad_node->inputs();
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
if (AnfAlgo::CheckPrimitiveType(bn_grad_node, prim::kPrimBatchNormGrad)) {
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum);
|
||||
} else {
|
||||
CheckCNodeInputSize(bn_grad_node, kSyncBNGradInputTensorNum);
|
||||
}
|
||||
|
||||
if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
|
||||
}
|
||||
|
@ -110,6 +120,7 @@ CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &c
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> bn_update_grad_outputs;
|
||||
|
||||
CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
|
||||
if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,7 +38,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(bn_cnode);
|
||||
if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) {
|
||||
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
|
||||
MS_LOG(INFO) << "BatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString();
|
||||
return false;
|
||||
}
|
||||
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
|
||||
|
@ -51,7 +51,7 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
|
|||
bn_training_reduce->set_kernel_info(kernel_info);
|
||||
std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
|
||||
if (bn_shape_i0.size() < kShape2dDims) {
|
||||
MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
|
||||
MS_LOG(INFO) << "The BatchNorm's first input's shape dims less than " << kShape2dDims;
|
||||
return false;
|
||||
}
|
||||
std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
|
||||
|
|
|
@ -33,7 +33,7 @@ CNodePtr CreateBatchNorm3DGrad(const FuncGraphPtr &graph, const CNodePtr &batchn
|
|||
MS_EXCEPTION_IF_NULL(batchnorm_grad);
|
||||
auto prim = std::make_shared<Primitive>(kBatchNorm3DGradOpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
|
||||
for (size_t i = 1; i < batchnorm_grad->size(); ++i) {
|
||||
for (size_t i = 1; i < batchnorm_grad->size() - 1; ++i) {
|
||||
inputs.push_back(batchnorm_grad->input(i));
|
||||
}
|
||||
auto new_node = graph->NewCNode(inputs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -56,7 +56,8 @@ constexpr size_t kBN1OutputNum = 2;
|
|||
constexpr size_t kBN2OutputNum = 3;
|
||||
constexpr size_t kBN3OutputNum = 1;
|
||||
|
||||
constexpr size_t kBNGradInputTensorNum = 5;
|
||||
constexpr size_t kBNGradInputTensorNum = 6;
|
||||
constexpr size_t kSyncBNGradInputTensorNum = 5;
|
||||
constexpr size_t kBNGradOutputNum = 3;
|
||||
|
||||
constexpr size_t kBNGrad1OutputNum = 3;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,8 +28,8 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef BatchNormAddReluFusion::DefinePattern() const {
|
||||
VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
|
||||
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
|
||||
VectorRef tensor_add = VectorRef({prim::kPrimAdd, tuple_get_item, z_});
|
||||
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
|
||||
return relu;
|
||||
|
@ -44,24 +44,24 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
|||
MS_EXCEPTION_IF_NULL(tensor_add);
|
||||
auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
|
||||
auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
auto format = GetValue<std::string>(format_attr);
|
||||
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
return nullptr;
|
||||
}
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0);
|
||||
if (shape.back() % kBNChannelMultipleFactor != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4);
|
||||
auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
|
@ -71,7 +71,7 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
|||
MS_EXCEPTION_IF_NULL(var);
|
||||
MS_EXCEPTION_IF_NULL(z);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation);
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
|
||||
auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
|
||||
|
@ -79,17 +79,17 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
|
|||
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
|
||||
return tuple_get_item;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -85,14 +85,14 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
|
|||
std::vector<AnfNodePtr> bn_add_relu_grad_output;
|
||||
CreateMultipleOutputsOfAnfNode(graph, bn_add_relu_grad, kBNAddReluGradOutputNum, &bn_add_relu_grad_output);
|
||||
if (bn_add_relu_grad_output.size() != kBNAddReluGradOutputNum) {
|
||||
MS_LOG(EXCEPTION) << "The output size of node " << kFusedBatchNormGradExWithAddAndActivation << " must be "
|
||||
MS_LOG(EXCEPTION) << "The output size of node " << kBatchNormGradWithAddAndActivation << " must be "
|
||||
<< kBNAddReluGradOutputNum << ", but it is " << bn_add_relu_grad_output.size();
|
||||
}
|
||||
|
||||
// Get bn outputs
|
||||
std::vector<AnfNodePtr> bn_outputs;
|
||||
if (!GetBatchNormOutputs(graph, bn_grad, &bn_outputs)) {
|
||||
MS_LOG(INFO) << "The " << prim::kPrimFusedBatchNormGradEx
|
||||
MS_LOG(INFO) << "The " << prim::kPrimBatchNormGrad
|
||||
<< " node should only have output 0, 1 and 2. The node should not be changed";
|
||||
return;
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
|
||||
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
|
||||
if (AnfAlgo::GetCNodeName(forward_node) != kBatchNormWithAddAndActivation) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -150,7 +150,7 @@ bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
|||
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
|
||||
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
|
||||
VectorRef batch_norm_grad =
|
||||
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
|
||||
VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
|
||||
return batch_norm_grad;
|
||||
}
|
||||
|
||||
|
@ -184,7 +184,7 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
|
|||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
|
||||
MS_EXCEPTION_IF_NULL(bias);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithAddAndActivation);
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormGradWithAddAndActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
|
||||
auto fused_batch_norm_add_relu_grad = graph->NewCNode(inputs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,8 +28,8 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef BatchNormReluFusion::DefinePattern() const {
|
||||
VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
|
||||
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
|
||||
VectorRef relu = VectorRef({prim::kPrimRelu, tuple_get});
|
||||
return relu;
|
||||
}
|
||||
|
@ -41,24 +41,24 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
|||
|
||||
auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
|
||||
auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
auto format = GetValue<std::string>(format_attr);
|
||||
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
return nullptr;
|
||||
}
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0);
|
||||
if (shape.back() % kBNChannelMultipleFactor != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
MS_EXCEPTION_IF_NULL(scale);
|
||||
|
@ -66,7 +66,7 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
|||
MS_EXCEPTION_IF_NULL(mean);
|
||||
MS_EXCEPTION_IF_NULL(var);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithActivation);
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormWithActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var};
|
||||
auto fused_batch_norm_with_relu = graph->NewCNode(inputs);
|
||||
|
@ -74,17 +74,17 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
|
|||
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_relu.get());
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_relu);
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_relu);
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_relu);
|
||||
manager->Replace(batch_norm, fused_batch_norm_with_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_with_relu);
|
||||
return tuple_get_item;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -31,7 +31,7 @@ namespace opt {
|
|||
const BaseRef BatchNormReluGradFusion::DefinePattern() const {
|
||||
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
|
||||
VectorRef batch_norm_grad =
|
||||
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
|
||||
VectorRef({prim::kPrimBatchNormGrad, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
|
||||
return batch_norm_grad;
|
||||
}
|
||||
|
||||
|
@ -82,7 +82,7 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
|
|||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
|
||||
MS_EXCEPTION_IF_NULL(bias);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormGradExWithActivation);
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormGradWithActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), dy, x, scale, save_mean, save_var, reserve, bias, y};
|
||||
auto fused_batch_norm_grad_with_relu = graph->NewCNode(inputs);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,8 +42,7 @@ struct AnfNodeIndex {
|
|||
};
|
||||
|
||||
// opname, output idx
|
||||
std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0},
|
||||
{kFusedBatchNormGradExWithAddAndActivation, 3}};
|
||||
std::map<string, uint32_t> kInplaceOpNames = {{kConv2DBackpropInputOpName, 0}, {kBatchNormGradWithAddAndActivation, 3}};
|
||||
|
||||
std::set<string> kSkipOpNames = {
|
||||
kTensorAddOpName,
|
||||
|
@ -51,7 +50,7 @@ std::set<string> kSkipOpNames = {
|
|||
|
||||
// opname, input idx
|
||||
std::map<string, uint32_t> kAggregatesOpNames = {
|
||||
{kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}};
|
||||
{kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kBatchNormGradWithAddAndActivation, 0}};
|
||||
|
||||
constexpr size_t inplace_node_size = 2;
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef PostBatchNormAddReluFusion::DefinePattern() const {
|
||||
VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_});
|
||||
VectorRef batch_norm = VectorRef({prim::kPrimBatchNorm, x_, scale_, bias_, mean_, var_});
|
||||
VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm, index_});
|
||||
VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item});
|
||||
VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add});
|
||||
return relu;
|
||||
|
@ -44,24 +44,24 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
|
|||
MS_EXCEPTION_IF_NULL(tensor_add);
|
||||
auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm_ex);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format");
|
||||
auto batch_norm = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
|
||||
MS_EXCEPTION_IF_NULL(batch_norm);
|
||||
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm)->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(format_attr);
|
||||
auto format = GetValue<std::string>(format_attr);
|
||||
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
if (AnfAlgo::GetInputFormat(batch_norm, 0) != kOpFormat_NHWC && format != "NHWC") {
|
||||
return nullptr;
|
||||
}
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0);
|
||||
auto shape = AnfAlgo::GetInputDeviceShape(batch_norm, 0);
|
||||
if (shape.back() % kBNChannelMultipleFactor != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm_ex), 4);
|
||||
auto x = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 0);
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(batch_norm), 4);
|
||||
auto z = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tensor_add), 0);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
|
@ -71,7 +71,7 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
|
|||
MS_EXCEPTION_IF_NULL(var);
|
||||
MS_EXCEPTION_IF_NULL(z);
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedBatchNormExWithAddAndActivation);
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormWithAddAndActivation);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x, scale, bias, mean, var, z};
|
||||
auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs);
|
||||
|
@ -79,17 +79,17 @@ const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph,
|
|||
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex);
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i));
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm, i));
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm, i));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get());
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
AnfAlgo::CopyNodeAttrs(batch_norm, fused_batch_norm_with_add_relu);
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu);
|
||||
manager->Replace(batch_norm, fused_batch_norm_with_add_relu);
|
||||
device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu);
|
||||
return tuple_get_item;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -445,9 +445,10 @@ size_t BestFitMemReuse::GetAllocatedSize() {
|
|||
bool BestFitMemReuse::IsRelease() {
|
||||
// unable_used_node contains the node types whose output tensors cannot be released,
// even if their refcount is equal to zero.
|
||||
std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
|
||||
prim::kPrimFusedBatchNorm->name(),
|
||||
prim::kPrimFusedBatchNormGrad->name()};
|
||||
std::unordered_set<std::string> unable_used_node = {
|
||||
prim::kPrimBatchNorm->name(),
|
||||
prim::kPrimBatchNormGrad->name(),
|
||||
};
|
||||
return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end();
|
||||
}
|
||||
|
||||
|
@ -494,7 +495,7 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
|
|||
#endif
|
||||
for (const auto &op_def_ptr : op_ptr_list_) {
|
||||
current_kernel_ = op_def_ptr;
|
||||
// releas pre_op_def
|
||||
// release pre_op_def
|
||||
if (pre_op != nullptr) {
|
||||
ReleasePreNodeWorkspace(pre_op.get());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -450,12 +450,7 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
|
|||
// Experimental support for 3D data (input_size == 3).
|
||||
if (input_size >= 1 && input_size <= 4) {
|
||||
if (dim == 0) {
|
||||
// Currently GPU version does not support partitioning ‘FusedBatchNormEx’ in its param tensors.
|
||||
if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
|
||||
s.push_back(1);
|
||||
} else {
|
||||
s.push_back(std::min(max_device_num, target_tensor_batch));
|
||||
}
|
||||
s.push_back(std::min(max_device_num, target_tensor_batch));
|
||||
} else {
|
||||
s.push_back(1);
|
||||
}
|
||||
|
@ -533,8 +528,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
|
|||
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
|
||||
} else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
|
||||
return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
|
||||
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
|
||||
(type == "FusedBatchNormEx") || (type == "Dropout") || (type == BATCH_MATMUL)) {
|
||||
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || (type == "Dropout") ||
|
||||
(type == BATCH_MATMUL)) {
|
||||
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
||||
} else {
|
||||
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,7 +46,6 @@ const std::map<std::string, OperatorType> DictOpType{
|
|||
{RESHAPE, OperatorType::kRecReshape},
|
||||
{BIAS_ADD, OperatorType::kRecBiasAdd},
|
||||
{BATCH_NORM, OperatorType::kRecBatchNorm},
|
||||
{FUSE_BATCH_NORM, OperatorType::kRecBatchNorm},
|
||||
{LAYER_NORM, OperatorType::kRecBatchNorm},
|
||||
{SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
|
||||
{ONEHOT, OperatorType::kRecOneHot},
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -53,9 +53,7 @@ enum MatchCountPriority : int {
|
|||
MATCH_COUNT_PRIORITY_END
|
||||
};
|
||||
const std::map<std::string, std::vector<std::string>> kNextOpFormatList = {
|
||||
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}},
|
||||
{prim::kPrimFusedBatchNorm->name(),
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}}};
|
||||
{prim::kPrimConv2D->name(), {kOpFormat_NC1HWC0, kOpFormat_FRAC_Z}}};
|
||||
|
||||
bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -233,6 +231,24 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
|
|||
return result;
|
||||
}
|
||||
|
||||
bool CheckHitTargetDtype(const std::map<TypeId, TypeId> &type_map, const TypeId &in_dtype, const TypeId &device_dtype,
|
||||
bool *flag) {
|
||||
auto iter = type_map.find(in_dtype);
|
||||
// if the infer dtype is not in type_map and does not equal the kernel info dtype, return false
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
||||
// or infer dtype not equal kernel info dtype, return false
|
||||
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
||||
*flag = true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
||||
const std::map<TypeId, TypeId> &type_map) {
|
||||
// filter out kernel info that does not support raise or reduce datatype
|
||||
|
@ -245,19 +261,9 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|||
if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
auto iter = type_map.find(in_dtype);
|
||||
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) {
|
||||
return false;
|
||||
}
|
||||
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
||||
// or infer dtype not equal kernel info dtype, return false
|
||||
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
||||
flag = true;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) {
|
||||
|
@ -266,19 +272,10 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|||
if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) {
|
||||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
auto iter = type_map.find(in_dtype);
|
||||
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
|
||||
if (!CheckHitTargetDtype(type_map, in_dtype, device_dtype, &flag)) {
|
||||
return false;
|
||||
}
|
||||
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
||||
// or infer dtype not equal kernel info dtype, return false
|
||||
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
||||
flag = true;
|
||||
}
|
||||
}
|
||||
if (flag) {
|
||||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
|
|
|
@ -101,9 +101,9 @@ namespace {
|
|||
std::vector<int> CheckRealOutput(const std::string &node_name, const size_t &output_size) {
|
||||
// define a vector containing real output number
|
||||
std::vector<int> real_outputs;
|
||||
// P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
|
||||
// P.BatchNorm is used for training and inference
|
||||
// can add the filter list for more operators here....
|
||||
if (node_name == "FusedBatchNorm" || node_name == "BatchNorm") {
|
||||
if (node_name == "BatchNorm") {
|
||||
MS_LOG(INFO) << "loading node named " << node_name;
|
||||
real_outputs.insert(real_outputs.end(), {0, 3, 4});
|
||||
} else {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -374,7 +374,7 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s
|
|||
if (kernel_name == prim::kPrimConv2D->name()) {
|
||||
conv_cnt++;
|
||||
}
|
||||
if (kernel_name == prim::kPrimFusedBatchNormEx->name()) {
|
||||
if (kernel_name == prim::kPrimBatchNorm->name()) {
|
||||
bn_cnt++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,12 +46,12 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
|
|||
{prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
|
||||
{kAvgPoolOpName, {{0}, {0}}},
|
||||
{kAvgPoolGradOpName, {{0, 1, 2}, {0}}},
|
||||
{kFusedBatchNormEx, {{0}, {0}}},
|
||||
{kFusedBatchNormExWithActivation, {{0}, {0}}},
|
||||
{kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}},
|
||||
{kFusedBatchNormGradEx, {{0, 1}, {0}}},
|
||||
{kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
|
||||
{kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
|
||||
{kBatchNorm, {{0}, {0}}},
|
||||
{kBatchNormWithActivation, {{0}, {0}}},
|
||||
{kBatchNormWithAddAndActivation, {{0, 5}, {0}}},
|
||||
{kBatchNormGradOpName, {{0, 1}, {0}}},
|
||||
{kBatchNormGradWithActivation, {{0, 1, 7}, {0}}},
|
||||
{kBatchNormGradWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
|
||||
{kBiasAddOpName, {{0}, {0}}},
|
||||
{prim::kPrimBiasAddGrad->name(), {{0}, {}}},
|
||||
// Format insensitive.
|
||||
|
|
|
@ -50,13 +50,12 @@ constexpr auto kFusedBN3OpName = "FusedBN3";
|
|||
constexpr auto kBNGrad1OpName = "BNGrad1";
|
||||
constexpr auto kBNGrad2OpName = "BNGrad2";
|
||||
constexpr auto kBNGrad3OpName = "BNGrad3";
|
||||
constexpr auto kFusedBatchNormEx = "FusedBatchNormEx";
|
||||
constexpr auto kBatchNorm = "BatchNorm";
|
||||
constexpr auto kInstanceNorm = "InstanceNorm";
|
||||
constexpr auto kFusedBatchNormExWithActivation = "FusedBatchNormExWithActivation";
|
||||
constexpr auto kFusedBatchNormExWithAddAndActivation = "FusedBatchNormExWithAddAndActivation";
|
||||
constexpr auto kFusedBatchNormGradEx = "FusedBatchNormGradEx";
|
||||
constexpr auto kFusedBatchNormGradExWithActivation = "FusedBatchNormGradExWithActivation";
|
||||
constexpr auto kFusedBatchNormGradExWithAddAndActivation = "FusedBatchNormGradExWithAddAndActivation";
|
||||
constexpr auto kBatchNormWithActivation = "BatchNormWithActivation";
|
||||
constexpr auto kBatchNormWithAddAndActivation = "BatchNormWithAddAndActivation";
|
||||
constexpr auto kBatchNormGradWithActivation = "BatchNormGradWithActivation";
|
||||
constexpr auto kBatchNormGradWithAddAndActivation = "BatchNormGradWithAddAndActivation";
|
||||
constexpr auto kClearZeroOpName = "ClearZero";
|
||||
constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
|
||||
constexpr auto kGetNextOpName = "GetNext";
|
||||
|
|
|
@ -45,14 +45,10 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -140,77 +140,6 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
|
|||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: five tensors(x, gamma, beta, mean, variance).
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 5);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString()
|
||||
<< ", arg1:" << args_spec_list[1]->ToString();
|
||||
FusedBatchNormCheckDim(primitive, args_spec_list);
|
||||
|
||||
auto input = args_spec_list[0];
|
||||
auto input_shape = dyn_cast<Shape>(input->GetShapeTrack());
|
||||
MS_EXCEPTION_IF_NULL(input_shape);
|
||||
const auto &input_shape_list = input_shape->shape();
|
||||
if (input_shape_list.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Input shape size should >= 2.";
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < args_spec_list.size(); ++i) {
|
||||
auto arg_shape = dyn_cast<Shape>(args_spec_list[i]->GetShapeTrack());
|
||||
MS_EXCEPTION_IF_NULL(arg_shape);
|
||||
const auto &arg_shape_list = arg_shape->shape();
|
||||
if (arg_shape_list.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Arg shape size should >= 1.";
|
||||
}
|
||||
if (arg_shape_list[0] != input_shape_list[1]) {
|
||||
MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0]
|
||||
<< ") should match the second dimension of tensor"
|
||||
" param[0](which is "
|
||||
<< input_shape_list[1] << ").";
|
||||
}
|
||||
}
|
||||
auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s");
|
||||
|
||||
AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
|
||||
for (size_t i = 1; i < args_spec_list.size(); ++i) {
|
||||
auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
|
||||
tensorPtrList.push_back(param);
|
||||
}
|
||||
(void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s");
|
||||
|
||||
// check validity;
|
||||
auto epsilon_value = primitive->GetAttr("epsilon");
|
||||
auto momentum_value = primitive->GetAttr("momentum");
|
||||
MS_EXCEPTION_IF_NULL(epsilon_value);
|
||||
MS_EXCEPTION_IF_NULL(momentum_value);
|
||||
if (!epsilon_value->isa<FP32Imm>() || !momentum_value->isa<FP32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "expect epsilon and momentum be float, but: epsilon: " << epsilon_value->ToString()
|
||||
<< ", momentum: " << momentum_value->ToString();
|
||||
}
|
||||
|
||||
auto epsilon = epsilon_value->cast<FP32ImmPtr>()->value();
|
||||
auto momentum = momentum_value->cast<FP32ImmPtr>()->value();
|
||||
|
||||
if (epsilon > 1.0f || epsilon <= 0.0f) {
|
||||
MS_LOG(EXCEPTION) << "expect epsilon is greater than 0 and less or equal than 1, but epsilon: " << epsilon;
|
||||
}
|
||||
if (momentum > 1.0f || momentum < 0.0f) {
|
||||
MS_LOG(EXCEPTION) << "expect momentum is great or equal than 0 and less or equal than 1, but epsilon: " << momentum;
|
||||
}
|
||||
|
||||
// Outputs: y, running_mean, running_variance, save_mean, save_inv_variance.
|
||||
AbstractBasePtr y = input->Broaden();
|
||||
AbstractBasePtr other = args_spec_list[1]->Broaden();
|
||||
MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString();
|
||||
|
||||
AbstractBasePtrList elements = {y, other, other, other, other};
|
||||
return std::make_shared<AbstractTuple>(elements);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -228,24 +157,8 @@ AbstractBasePtr InferImplSparseSoftmaxCrossEntropyWithLogits(const AnalysisEngin
|
|||
return std::make_shared<abstract::AbstractTensor>(type_tensor->element(), shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[3]);
|
||||
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 5);
|
||||
auto dx = args_spec_list[1]->Broaden();
|
||||
auto dscale = args_spec_list[2]->Broaden();
|
||||
auto dbias = args_spec_list[3]->Broaden();
|
||||
|
||||
AbstractBasePtrList rets = {dx, dscale, dbias};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: five tensors(x, gamma, beta, mean, variance).
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 5);
|
||||
|
@ -256,12 +169,21 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi
|
|||
ShapeVector x_min_shape = input_x->shape()->min_shape();
|
||||
ShapeVector x_max_shape = input_x->shape()->max_shape();
|
||||
CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
|
||||
if (x_shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "Input rank should 4.";
|
||||
|
||||
auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
(void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param x of BatchNorm should be");
|
||||
AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
|
||||
for (size_t i = 1; i < args_spec_list.size(); ++i) {
|
||||
auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
|
||||
tensorPtrList.push_back(param);
|
||||
}
|
||||
(void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32},
|
||||
"param gamma, beta, mean, variance of Batchnorm should be");
|
||||
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format = GetAndCheckFormat(data_format_ptr);
|
||||
|
||||
int64_t c_axis = 1;
|
||||
if (data_format == Format::NHWC) {
|
||||
c_axis = 3;
|
||||
|
@ -275,8 +197,8 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi
|
|||
MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size();
|
||||
}
|
||||
if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) {
|
||||
MS_LOG(EXCEPTION) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis << "]=" << x_shape[c_axis]
|
||||
<< ", but got " << arg_shape[0];
|
||||
MS_EXCEPTION(ValueError) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis
|
||||
<< "]=" << x_shape[c_axis] << ", but got " << arg_shape[0];
|
||||
}
|
||||
}
|
||||
AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
|
@ -288,7 +210,7 @@ AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const Primi
|
|||
AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
|
||||
ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape);
|
||||
AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr);
|
||||
AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma, output_gamma};
|
||||
AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
@@ -117,9 +117,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    // NN
    {prim::kPrimPooling, {InferImplPooling, true}},
    {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
    {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
    {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
    {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}},
    {prim::kPrimBatchNorm, {InferImplBatchNorm, true}},
    {prim::kPrimReluGrad, {InferImplReluGrad, true}},
    {prim::kPrimConv2D, {InferImplConv2D, true}},
    {prim::kPrimBiasAdd, {InferImplBiasAdd, true}},

@@ -219,13 +219,10 @@ inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoo
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");
inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
inline const PrimitivePtr kPrimFusedBatchNormEx = std::make_shared<Primitive>("FusedBatchNormEx");
inline const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
inline const PrimitivePtr kPrimFullConnection = std::make_shared<Primitive>("FullConnection");
inline const PrimitivePtr kPrimConv2DTranspose = std::make_shared<Primitive>("Conv2DTranspose");
inline const PrimitivePtr kPrimGroupConv2DGradInput = std::make_shared<Primitive>("GroupConv2DGradInput");
inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx");
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
inline const PrimitivePtr kPrimSyncBatchNorm = std::make_shared<Primitive>("SyncBatchNorm");

@@ -130,8 +130,6 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
  {"MaxPoolGradGradWithArgmax", FormatAndPadUpperAttrMap},
  {"BatchNorm", DataFormatMap},
  {"BatchNormGrad", DataFormatMap},
  {"FusedBatchNormEx", DataFormatMap},
  {"FusedBatchNormGradEx", DataFormatMap},
  {"BiasAdd", DataFormatMap},
  {"BiasAddGrad", DataFormatMap},
  {"BinaryCrossEntropy", ReductionMap},

@@ -139,25 +139,16 @@ class _BatchNorm(Cell):
        else:
            self.is_ge_backend = False

        if self._target == "Ascend":
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps,
                                        momentum=self.momentum,
                                        data_format=self.format)
        if self._target == "GPU":
            self.bn_train = P.FusedBatchNormEx(mode=1,
                                               epsilon=self.eps,
                                               momentum=self.momentum,
                                               data_format=self.format)
        if self._target == "CPU":
            self.bn_train = P.FusedBatchNorm(mode=1,
                                             epsilon=self.eps,
                                             momentum=self.momentum)
        self.bn_train = P.BatchNorm(is_training=True,
                                    epsilon=self.eps,
                                    momentum=self.momentum,
                                    data_format=self.format)
        if self.is_global:
            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                momentum=self.momentum,
                                                group=SYNC_BN_GROUP_NAME,
                                                device_num=self.group_device_num)

        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)

        data_parallel_strategy = ((1,), (1,))
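After this hunk every backend constructs the same training primitive, so the GPU- and CPU-specific `FusedBatchNormEx`/`FusedBatchNorm` branches disappear. A minimal sketch of the unified construction, with illustrative hyper-parameter values that are assumptions rather than values taken from the diff:

```python
# Sketch only: one BatchNorm primitive for training and one for inference,
# regardless of whether the target is Ascend, GPU or CPU.
from mindspore.ops import operations as P

eps, momentum, fmt = 1e-5, 0.1, "NCHW"   # assumed hyper-parameters
bn_train = P.BatchNorm(is_training=True, epsilon=eps, momentum=momentum, data_format=fmt)
bn_infer = P.BatchNorm(is_training=False, epsilon=eps, data_format=fmt)
```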
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -541,19 +541,9 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                       channel_axis=channel_axis,
                                       num_channels=out_channels,
                                       quant_dtype=quant_dtype)
        if self._target == "Ascend":
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps,
                                        momentum=self.momentum)
        if self._target == "GPU":
            self.bn_train = P.FusedBatchNormEx(mode=1,
                                               epsilon=self.eps,
                                               momentum=self.momentum,
                                               data_format=self.format)
        if self._target == "CPU":
            self.bn_train = P.FusedBatchNorm(mode=1,
                                             epsilon=self.eps,
                                             momentum=self.momentum)
        self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps,
                                    momentum=self.momentum, data_format=self.format)

        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        data_parallel_strategy = ((1,), (1,))
        data_parallel_strategy_one = ((1,), ())

@@ -647,49 +647,6 @@ def get_bprop_fast_gelu_2(self):

    return bprop


@bprop_getters.register(P.FusedBatchNorm)
def get_bprop_fused_batch_norm(self):
    """Grad definition for `FusedBatchNorm` operation."""
    input_grad = G.FusedBatchNormGrad(self.epsilon, self.momentum)
    target_cpu = False
    if self.target == "CPU":
        input_grad = G.FusedBatchNormGradCPU(self.epsilon, self.momentum)
        target_cpu = True

    def bprop(x, scale, b, mean, variance, out, dout):
        saved_mean = out[3]
        saved_variance = out[4]
        if target_cpu:
            out = input_grad(dout[0], x, scale, b, saved_mean, saved_variance)
        else:
            out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.FusedBatchNormEx)
def get_bprop_fused_batch_norm_ex(self):
    """Grad definition for `FusedBatchNormEx` operation."""
    input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum, self.format)

    def bprop(x, scale, b, mean, variance, out, dout):
        saved_mean = out[3]
        saved_variance = out[4]
        reserve = out[5]
        out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.InstanceNorm)
def get_bprop_instance_norm(self):
    """Grad definition for `InstanceNorm` operation."""
@@ -715,12 +672,14 @@ def get_bprop_batch_norm(self):

    def bprop(x, scale, b, mean, variance, out, dout):
        if is_training:
            saved_reserve_1 = out[3]
            saved_reserve_2 = out[4]
            saved_mean = out[3]
            saved_variance = out[4]
            reserve = out[2]
        else:
            saved_reserve_1 = mean
            saved_reserve_2 = variance
        out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
            saved_mean = mean
            saved_variance = variance
            reserve = out[2]
        out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
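The rewritten bprop makes `BatchNorm` the single differentiable normalization primitive on every backend: the grad op is now fed a reserve tensor as its sixth input and the moving statistics receive zero gradients. A small end-to-end sketch (shapes and the cell choice are assumptions, not taken from the diff) that exercises this gradient path:

```python
# Sketch under assumed shapes: differentiating a training-mode BatchNorm2d cell
# goes through the BatchNormGrad-based bprop defined above.
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P


class WithLoss(nn.Cell):
    def __init__(self, net):
        super(WithLoss, self).__init__()
        self.net = net
        self.reduce_sum = P.ReduceSum()

    def construct(self, x):
        return self.reduce_sum(self.net(x))


net = WithLoss(nn.BatchNorm2d(num_features=2))
net.set_train()
grad_all = C.GradOperation(get_all=True)
dx = grad_all(net)(Tensor(np.ones([1, 2, 4, 4]), mindspore.float32))[0]  # gradient w.r.t. the input
```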
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -48,12 +48,6 @@ class BiasAdd:
        pass


@op_selector
class FusedBatchNorm:
    def __call__(self, *args):
        pass


@op_selector
class ApplyMomentum:
    def __call__(self, *args):

@@ -65,9 +65,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     BiasAdd, Conv2D,
                     DepthwiseConv2dNative,
                     DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
                     FusedBatchNorm, FusedBatchNormEx, InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                     InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
                     GeLU, Gelu, FastGeLU, FastGelu, Elu,

                     GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
                     LogSoftmax, MaxPool3D,
                     MaxPool, DataFormatDimMap,
@@ -142,8 +141,6 @@ __all__ = [
    'Conv2D',
    'Flatten',
    'MaxPoolWithArgmax',
    'FusedBatchNorm',
    'FusedBatchNormEx',
    'BNTrainingReduce',
    'BNTrainingUpdate',
    'BatchNorm',

@@ -197,12 +197,12 @@ class BatchNormGrad(PrimitiveWithInfer):
        self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        self.data_format = validator.check_string(data_format, ['NCHW', 'NHWC', "NCDHW"], 'format', self.name)

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve):
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)
        return (x_shape, scale_shape, scale_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape, reserve):
        return (x_type, scale_type, scale_type)


class SyncBatchNormGrad(PrimitiveWithInfer):
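With this change `BatchNormGrad` declares six inputs (dy, x, scale, save_mean, save_variance, reserve) and exactly three outputs. A hedged sketch of that call contract; the `is_training` keyword, the shapes, and the assumption of a backend that provides a BatchNormGrad kernel are all illustrative, not taken from the diff:

```python
# Sketch only: feeding the grad primitive six tensors yields (dx, dscale, dbias);
# the former reserve/statistics pass-through outputs are gone.
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

bn_grad = G.BatchNormGrad(is_training=True, epsilon=1e-5)
dy = Tensor(np.ones([1, 2, 4, 4]), mindspore.float32)
x = Tensor(np.ones([1, 2, 4, 4]), mindspore.float32)
scale = Tensor(np.ones([2]), mindspore.float32)
save_mean = Tensor(np.zeros([2]), mindspore.float32)
save_variance = Tensor(np.ones([2]), mindspore.float32)
reserve = Tensor(np.ones([2]), mindspore.float32)
dx, dscale, dbias = bn_grad(dy, x, scale, save_mean, save_variance, reserve)
```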
@@ -708,53 +708,6 @@ class FlattenGrad(PrimitiveWithInfer):
        return out


class FusedBatchNormGrad(Primitive):
    """Gradients of FusedBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance'],
                                outputs=['dx', 'bn_scale', 'bn_bias'])

    def __call__(self, dy, x, scale, save_mean, save_inv_variance):
        raise NotImplementedError


class FusedBatchNormGradCPU(PrimitiveWithInfer):
    """Gradients of FusedBatchNorm operation for CPU."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1):
        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'bias', 'save_mean', 'save_inv_variance'],
                                outputs=['dx', 'bn_scale', 'bn_bias'])
        self.add_prim_attr('data_format', "NCHW")

    def infer_shape(self, dy_shape, x_shape, scale_shape, bias_shape, save_mean_shape, save_inv_variance_shape):
        return (x_shape, scale_shape, bias_shape)

    def infer_dtype(self, dy_type, x_type, scale_type, bias_type, save_mean_type, save_inv_variance_type):
        return (x_type, scale_type, bias_type)


class FusedBatchNormGradEx(PrimitiveWithInfer):
    """Gradients of FusedBatchNormEx operation."""

    @prim_attr_register
    def __init__(self, epsilon=0.0, momentum=0.1, data_format="NCHW"):
        self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'],
                                outputs=['dx', 'bn_scale', 'bn_bias'])
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError("NHWC format only support in GPU target.")
        self.add_prim_attr('data_format', self.format)

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape):
        return (x_shape, scale_shape, scale_shape)

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_type, save_variance_type, reserve_type):
        return (x_type, scale_type, scale_type)


class InstanceNormGrad(PrimitiveWithInfer):
    """Gradients of InstanceNorm operation."""

@ -816,221 +816,20 @@ class Tanh(PrimitiveWithInfer):
|
|||
|
||||
class FusedBatchNorm(Primitive):
|
||||
r"""
|
||||
FusedBatchNorm is a BatchNorm. Moving mean and moving variance will be computed instead of being loaded.
|
||||
|
||||
Batch Normalization is widely used in convolutional networks. This operation applies
|
||||
Batch Normalization over input to avoid internal covariate shift as described in the
|
||||
paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
|
||||
Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
||||
feature using a mini-batch of data and the learned parameters which can be described
|
||||
in the following formula.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
|
||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||
|
||||
Args:
|
||||
mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
|
||||
epsilon (float): A small value added for numerical stability. Default: 1e-5.
|
||||
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
|
||||
(e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
|
||||
Momentum value must be [0, 1]. Default: 0.1.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
- **scale** (Parameter) - Tensor of shape :math:`(C,)`.
|
||||
- **bias** (Parameter) - Tensor of shape :math:`(C,)`.
|
||||
- **mean** (Parameter) - Tensor of shape :math:`(C,)`.
|
||||
- **variance** (Parameter) - Tensor of shape :math:`(C,)`.
|
||||
|
||||
Outputs:
|
||||
Tuple of 5 Tensor, the normalized input and the updated parameters.
|
||||
|
||||
- **output_x** (Tensor) - The same type and shape as the `input_x`.
|
||||
- **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `mode` is not an int.
|
||||
TypeError: If `epsilon` or `momentum` is not a float.
|
||||
TypeError: If `output_x`, `updated_scale`, `updated_bias`, `updated_moving_mean` or `updated_moving_variance` is
|
||||
a Tensor.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops import operations as ops
|
||||
>>> class FusedBatchNormNet(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(FusedBatchNormNet, self).__init__()
|
||||
>>> self.fused_batch_norm = ops.FusedBatchNorm()
|
||||
>>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale")
|
||||
>>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias")
|
||||
>>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
|
||||
>>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
|
||||
>>>
|
||||
>>> def construct(self, input_x):
|
||||
>>> out = self.fused_batch_norm(input_x, self.scale, self.bias, self.mean, self.variance)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
||||
>>> net = FusedBatchNormNet()
|
||||
>>> output = net(input_x)
|
||||
>>> result = output[0].shape
|
||||
>>> print(result)
|
||||
(128, 64, 32, 64)
|
||||
The FusedBatchNorm interface is deprecated, please use the BatchNorm interface.
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
|
||||
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
|
||||
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
self._update_parameter = True
|
||||
self.target = context.get_context("device_target")
|
||||
raise TypeError("The FusedBatchNorm interface is deprecated, please use the BatchNorm interface.")
|
||||
|
||||
|
||||
class FusedBatchNormEx(PrimitiveWithCheck):
|
||||
r"""
|
||||
FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve)
|
||||
than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that
|
||||
moving mean and moving variance will be computed instead of being loaded. FusedBatchNormEx currently only
|
||||
supports 4D inputs.
|
||||
|
||||
Batch Normalization is widely used in convolutional networks. This operation applies
|
||||
Batch Normalization over input to avoid internal covariate shift as described in the
|
||||
paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
|
||||
Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
||||
feature using a mini-batch of data and the learned parameters which can be described
|
||||
in the following formula.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
|
||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||
|
||||
Args:
|
||||
mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
|
||||
epsilon (float): A small value added for numerical stability. Default: 1e-5.
|
||||
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
|
||||
(e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
|
||||
Momentum value must be [0, 1]. Default: 0.1.
|
||||
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
|
||||
Default: "NCHW".
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
|
||||
data type: float16 or float32.
|
||||
- **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
|
||||
data type: float32.
|
||||
- **bias** (Parameter) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`,
|
||||
data type: float32.
|
||||
- **mean** (Parameter) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **variance** (Parameter) - variance value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
|
||||
Outputs:
|
||||
Tuple of 6 Tensors, the normalized input, the updated parameters and reserve.
|
||||
|
||||
- **output_x** (Tensor) - The output of FusedBatchNormEx, same type and shape as the `input_x`.
|
||||
- **updated_scale** (Tensor) - Updated parameter scale, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **updated_bias** (Tensor) - Updated parameter bias, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(C,)`,
|
||||
data type: float32.
|
||||
- **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
|
||||
Raises:
|
||||
TypeError: If `mode` is not an int.
|
||||
TypeError: If neither `epsilon` nor `momentum` is a float.
|
||||
TypeError: If `data_format` is not a str.
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
TypeError: If dtype of `scale`, `bias`, `mean` or `variance` is not float32.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Parameter
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore.ops import operations as ops
|
||||
>>> class FusedBatchNormExNet(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(FusedBatchNormExNet, self).__init__()
|
||||
>>> self.fused_batch_norm_ex = ops.FusedBatchNormEx()
|
||||
>>> self.scale = Parameter(Tensor(np.ones([64]), mindspore.float32), name="scale")
|
||||
>>> self.bias = Parameter(Tensor(np.ones([64]), mindspore.float32), name="bias")
|
||||
>>> self.mean = Parameter(Tensor(np.ones([64]), mindspore.float32), name="mean")
|
||||
>>> self.variance = Parameter(Tensor(np.ones([64]), mindspore.float32), name="variance")
|
||||
>>>
|
||||
>>> def construct(self, input_x):
|
||||
>>> out = self.fused_batch_norm_ex(input_x, self.scale, self.bias, self.mean, self.variance)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
||||
>>> net = FusedBatchNormExNet()
|
||||
>>> output = net(input_x)
|
||||
>>> result = output[0].shape
|
||||
>>> print(result)
|
||||
(128, 64, 32, 64)
|
||||
The FusedBatchNormEx interface is deprecated, please use the BatchNorm interface.
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
|
||||
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
|
||||
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
|
||||
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
self._update_parameter = True
|
||||
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
||||
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
||||
raise ValueError("NHWC format only support in GPU target.")
|
||||
self.add_prim_attr('data_format', self.format)
|
||||
|
||||
def check_shape(self, input_x, scale, bias, mean, variance):
|
||||
input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2])
|
||||
validator.check_equal_int(len(input_shape_norm), 4, "x rank", self.name)
|
||||
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
|
||||
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
|
||||
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
|
||||
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
|
||||
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
|
||||
|
||||
def check_dtype(self, input_x, scale, bias, mean, variance):
|
||||
validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
|
||||
args = {"scale": scale, "bias": bias}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
|
||||
args_moving = {"mean": mean, "variance": variance}
|
||||
valid_dtypes = [mstype.tensor_type(mstype.float32)]
|
||||
validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
|
||||
raise TypeError("FusedBatchnormEx interface is deprecated, please use BatchNorm interface.")
|
||||
|
||||
|
||||
class InstanceNorm(PrimitiveWithInfer):
|
||||
|
@@ -1419,7 +1218,7 @@ class BatchNorm(PrimitiveWithInfer):
        else:
            args_moving = {"mean": mean, "variance": variance}
            validator.check_tensors_dtypes_same_and_valid(args_moving, [mstype.float16, mstype.float32], self.name)
        return (input_x, scale, bias, input_x, input_x)
        return (input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32)


class Conv2D(PrimitiveWithCheck):
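The rewritten return line tightens BatchNorm's `infer_dtype` contract: only the normalized output follows the input dtype, while the statistics outputs are reported as float32. A short sketch of what that means at call time, with illustrative shapes:

```python
# Sketch only: per the infer_dtype change, outs[0] keeps x's dtype while the
# batch statistics and reserve outputs are float32 tensors of shape (C,).
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

bn = P.BatchNorm(is_training=True, epsilon=1e-5)
x = Tensor(np.ones([1, 2, 4, 4]), mindspore.float32)
stats = [Tensor(np.ones([2]), mindspore.float32) for _ in range(4)]
outs = bn(x, *stats)   # (y, batch_mean, batch_variance, reserve_1, reserve_2)
assert outs[0].dtype == x.dtype
```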
@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -41,7 +41,7 @@ class Grad(nn.Cell):
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.bn = P.FusedBatchNorm()
        self.bn = P.BatchNorm()
        self.scale = Parameter(initializer('ones', [64]), name='scale')
        self.b = Parameter(initializer('zeros', [64]), name='b')
        self.mean = Parameter(initializer('ones', [64]), name='mean')
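For tests and user code the migration is mostly mechanical: the constructor swap above is usually the only change, because BatchNorm consumes the same five tensors. A hedged completion of such a cell; the construct body, the `is_training` flag and the variance parameter fall outside the shown hunk and are assumptions:

```python
# Sketch of a fully migrated cell; names follow the hunk above, the rest is assumed.
import mindspore.nn as nn
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P


class BatchNormNet(nn.Cell):
    def __init__(self):
        super(BatchNormNet, self).__init__()
        self.bn = P.BatchNorm(is_training=True)
        self.scale = Parameter(initializer('ones', [64]), name='scale')
        self.b = Parameter(initializer('zeros', [64]), name='b')
        self.mean = Parameter(initializer('ones', [64]), name='mean')
        self.variance = Parameter(initializer('ones', [64]), name='variance')

    def construct(self, x):
        return self.bn(x, self.scale, self.b, self.mean, self.variance)[0]
```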
@ -1,128 +0,0 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetFusedBatchNormEx(Cell):
|
||||
def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
|
||||
super(NetFusedBatchNormEx, self).__init__()
|
||||
self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
|
||||
self.moving_mean = Parameter(initializer(
|
||||
mean_init, num_features), name="mean", requires_grad=False)
|
||||
self.moving_variance = Parameter(initializer(
|
||||
var_init, num_features), name="variance", requires_grad=False)
|
||||
self.gamma = Parameter(initializer(
|
||||
gamma_init, num_features), name="gamma", requires_grad=True)
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, num_features), name="beta", requires_grad=True)
|
||||
self.dynshape = inner.GpuConvertToDynamicShape()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fused_bn_ex():
|
||||
x = np.array([[
|
||||
[[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
|
||||
[[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
|
||||
expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
|
||||
[-0.1471, 0.7706, 1.6882, 2.6059],
|
||||
[0.3118, 1.6882, 2.1471, 2.1471],
|
||||
[0.7706, 0.3118, 2.6059, -0.1471]],
|
||||
|
||||
[[0.9119, 1.8518, 1.3819, -0.0281],
|
||||
[-0.0281, 0.9119, 1.3819, 1.8518],
|
||||
[2.7918, 0.4419, -0.4981, 0.9119],
|
||||
[1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
|
||||
|
||||
weight = np.ones(2).astype(np.float32)
|
||||
bias = np.ones(2).astype(np.float32)
|
||||
moving_mean = np.ones(2).astype(np.float32)
|
||||
moving_var = np.ones(2).astype(np.float32)
|
||||
error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
bn_net = NetFusedBatchNormEx(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var))
|
||||
output_list = bn_net(Tensor(x))
|
||||
output = output_list[0]
|
||||
diff = output.asnumpy() - expect_output
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
class NetFusedBatchNormExDynamic(Cell):
|
||||
def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
|
||||
super(NetFusedBatchNormExDynamic, self).__init__()
|
||||
self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
|
||||
self.moving_mean = Parameter(initializer(
|
||||
mean_init, num_features), name="mean", requires_grad=False)
|
||||
self.moving_variance = Parameter(initializer(
|
||||
var_init, num_features), name="variance", requires_grad=False)
|
||||
self.gamma = Parameter(initializer(
|
||||
gamma_init, num_features), name="gamma", requires_grad=True)
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, num_features), name="beta", requires_grad=True)
|
||||
self.dynshape = inner.GpuConvertToDynamicShape()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.dynshape(x)
|
||||
x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fused_bn_ex_dynamic():
|
||||
x = np.array([[
|
||||
[[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
|
||||
[[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
|
||||
expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
|
||||
[-0.1471, 0.7706, 1.6882, 2.6059],
|
||||
[0.3118, 1.6882, 2.1471, 2.1471],
|
||||
[0.7706, 0.3118, 2.6059, -0.1471]],
|
||||
|
||||
[[0.9119, 1.8518, 1.3819, -0.0281],
|
||||
[-0.0281, 0.9119, 1.3819, 1.8518],
|
||||
[2.7918, 0.4419, -0.4981, 0.9119],
|
||||
[1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
|
||||
|
||||
weight = np.ones(2).astype(np.float32)
|
||||
bias = np.ones(2).astype(np.float32)
|
||||
moving_mean = np.ones(2).astype(np.float32)
|
||||
moving_var = np.ones(2).astype(np.float32)
|
||||
error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
bn_net = NetFusedBatchNormExDynamic(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var))
|
||||
output_list = bn_net(Tensor(x))
|
||||
output = output_list[0]
|
||||
diff = output.asnumpy() - expect_output
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
|
@ -414,25 +414,6 @@ TEST_F(TestOps, ReluTest) {
|
|||
ASSERT_EQ(prim->name(), kPrimRelu->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, FusedBatchNormTest) {
|
||||
auto prim = std::make_shared<Primitive>("FusedBatchNorm");
|
||||
ASSERT_EQ(prim->name(), kPrimFusedBatchNorm->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, FusedBatchNormAttrTest) {
|
||||
Primitive prim("FusedBatchNorm");
|
||||
prim.SetAttrs({
|
||||
{"epsilon", MakeValue(0.001f)},
|
||||
{"momentum", MakeValue(0.1f)},
|
||||
});
|
||||
ASSERT_EQ(prim.name(), kPrimFusedBatchNorm->name());
|
||||
|
||||
FP32Imm epsilon(0.001f);
|
||||
FP32Imm momentum(0.1f);
|
||||
ASSERT_EQ(*prim.GetAttr("epsilon"), epsilon);
|
||||
ASSERT_EQ(*prim.GetAttr("momentum"), momentum);
|
||||
}
|
||||
|
||||
TEST_F(TestOps, PoolingTest) {
|
||||
auto prim = std::make_shared<Primitive>("Pooling");
|
||||
ASSERT_EQ(prim->name(), kPrimPooling->name());
|
||||
|
|
|
@ -612,65 +612,6 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) {
|
|||
ASSERT_TRUE(*res == *expected);
|
||||
}
|
||||
|
||||
TEST_F(TestPrim, test_fused_batch_norm) {
|
||||
PrimitivePtr fused_batch_norm = prim::kPrimFusedBatchNorm;
|
||||
fused_batch_norm->AddAttr("epsilon", MakeValue(0.001f));
|
||||
fused_batch_norm->AddAttr("momentum", MakeValue(0.1f));
|
||||
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(fused_batch_norm, 5);
|
||||
|
||||
// NCHW
|
||||
std::vector<int64_t> inputs_dims = {128, 64, 32, 64};
|
||||
std::vector<int64_t> scale_dims = {64};
|
||||
std::vector<int64_t> offset_dims = {64};
|
||||
std::vector<int64_t> mean_dims = {64};
|
||||
std::vector<int64_t> variance_dims = {64};
|
||||
|
||||
tensor::TensorPtr inputs = std::make_shared<tensor::Tensor>();
|
||||
inputs->set_data_type(kNumberTypeFloat32);
|
||||
inputs->set_shape(inputs_dims);
|
||||
|
||||
tensor::TensorPtr scale = std::make_shared<tensor::Tensor>();
|
||||
scale->set_data_type(kNumberTypeFloat32);
|
||||
scale->set_shape(scale_dims);
|
||||
|
||||
tensor::TensorPtr offset = std::make_shared<tensor::Tensor>();
|
||||
offset->set_data_type(kNumberTypeFloat32);
|
||||
offset->set_shape(offset_dims);
|
||||
|
||||
tensor::TensorPtr mean = std::make_shared<tensor::Tensor>();
|
||||
mean->set_data_type(kNumberTypeFloat32);
|
||||
mean->set_shape(mean_dims);
|
||||
|
||||
tensor::TensorPtr variance = std::make_shared<tensor::Tensor>();
|
||||
variance->set_data_type(kNumberTypeFloat32);
|
||||
variance->set_shape(variance_dims);
|
||||
|
||||
AbstractBasePtr abstract_inputs = FromValue(inputs, true);
|
||||
AbstractBasePtr abstract_scale = FromValue(scale, true);
|
||||
AbstractBasePtr abstract_offset = FromValue(offset, true);
|
||||
AbstractBasePtr abstract_mean = FromValue(mean, true);
|
||||
AbstractBasePtr abstract_variance = FromValue(variance, true);
|
||||
AbstractBasePtrList args_spec_list = {abstract_inputs, abstract_scale, abstract_offset, abstract_mean,
|
||||
abstract_variance};
|
||||
|
||||
AbstractBasePtr expected0 = abstract_inputs->Clone();
|
||||
AbstractBasePtr expected1 = abstract_scale->Clone();
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
MS_LOG(INFO) << "result: " << res->ToString();
|
||||
MS_LOG(INFO) << "expected0: " << expected0->ToString();
|
||||
MS_LOG(INFO) << "expected1: " << expected1->ToString();
|
||||
|
||||
std::shared_ptr<AbstractTuple> abs_tuple = dyn_cast<AbstractTuple>(res);
|
||||
ASSERT_TRUE(abs_tuple != nullptr);
|
||||
ASSERT_TRUE(*abs_tuple->elements()[0] == *expected0);
|
||||
ASSERT_TRUE(*abs_tuple->elements()[1] == *expected1);
|
||||
ASSERT_TRUE(*abs_tuple->elements()[2] == *expected1);
|
||||
ASSERT_TRUE(*abs_tuple->elements()[3] == *expected1);
|
||||
ASSERT_TRUE(*abs_tuple->elements()[4] == *expected1);
|
||||
}
|
||||
|
||||
TEST_F(TestPrim, test_pooling) {
|
||||
PrimitivePtr pooling = prim::kPrimPooling;
|
||||
pooling->AddAttr("mode", MakeValue(std::string("avg")));
|
||||
|
|
|
@ -35,7 +35,7 @@ TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_fission) {
|
|||
std::vector<int64_t> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
for (size_t i = 0; i < 6; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
@ -56,7 +56,7 @@ TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission1)
|
|||
std::vector<int64_t> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
for (size_t i = 0; i < 6; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
@ -75,7 +75,7 @@ TEST_F(TestHWBatchNormGradInferFission, test_batch_norm_grad_infer_no_fission2)
|
|||
std::vector<int64_t> shp_x{32, 64, 112, 112};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
for (size_t i = 0; i < 6; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
|
|
|
@ -47,7 +47,7 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) {
|
|||
std::vector<int64_t> shp_b{64};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract};
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
|
||||
auto kernel_graph = GetKernelGraph(g, args_spec_list);
|
||||
EXPECT_NE(kernel_graph, nullptr);
|
||||
|
||||
|
@ -80,13 +80,17 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) {
|
|||
// set kernel for BNGrad
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
||||
builder1.SetInputsFormat(
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0,
|
||||
kOpFormat_NC1HWC0});
|
||||
builder1.SetOutputsFormat(
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
|
||||
{kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0,
|
||||
kOpFormat_NC1HWC0});
|
||||
builder1.SetInputsDeviceType(
|
||||
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
|
||||
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32,
|
||||
kNumberTypeFloat32});
|
||||
builder1.SetOutputsDeviceType(
|
||||
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
|
||||
{kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32,
|
||||
kNumberTypeFloat32});
|
||||
builder1.SetKernelType(TBE_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get());
|
||||
// do bn_grad_split pass
|
||||
|
|
|
@ -37,7 +37,7 @@ TEST_F(TestHWOptimizeBatchNormGrad2BNInferGrad, test_fusion) {
|
|||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int64_t> shp_y{64};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract};
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract, y_abstract};
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
|
@ -57,7 +57,7 @@ TEST_F(TestHWOptimizeBatchNormGrad2BNInferGrad, test_no_fusion) {
|
|||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
||||
std::vector<int64_t> shp_y{64};
|
||||
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract};
|
||||
AbstractBasePtrList args_spec_list{x_abstract, x_abstract, y_abstract, y_abstract, y_abstract, y_abstract};
|
||||
auto fg = GetKernelGraph(g, args_spec_list);
|
||||
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
|
||||
|
||||
|
|
|
@ -39,28 +39,28 @@ def test_batch_norm_grad_infer_fission(tag):
|
|||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(input0, input1, input2, input3, input4):
|
||||
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
|
||||
def before(input0, input1, input2, input3, input4, input5):
|
||||
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4, input5)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_is_training(input0, input1, input2, input3, input4):
|
||||
batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4)
|
||||
def before_is_training(input0, input1, input2, input3, input4, input5):
|
||||
batch_norm = BatchNormGradTraining(input0, input1, input2, input3, input4, input5)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def before_output3_not_null(input0, input1, input2, input3, input4):
|
||||
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 3))
|
||||
def before_output3_not_null(input0, input1, input2, input3, input4, input5):
|
||||
batch_norm = BatchNormGradInfer(input0, input1, input2, input3, input4, input5)
|
||||
outputs = make_tuple(tuple_getitem(batch_norm, 0), tuple_getitem(batch_norm, 1), tuple_getitem(batch_norm, 2))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def after(input0, input1, input2, input3, input4):
|
||||
def after(input0, input1, input2, input3, input4, input5):
|
||||
bn_infer_grad = BNInferGrad(input0, input2, input4)
|
||||
bn_training_update_grad = BNTrainingUpdateGrad(input0, input1, input3, input4)
|
||||
outputs = make_tuple(bn_infer_grad, tuple_getitem(bn_training_update_grad, 0),
|
||||
|
|
|
@ -38,19 +38,19 @@ def test_batchnormgrad_to_bninfergrad(tag):
|
|||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(input0, input1, input2, input3, input4):
|
||||
res = batch_norm_grad(input0, input1, input2, input3, input4)
|
||||
def before(input0, input1, input2, input3, input4, input5):
|
||||
res = batch_norm_grad(input0, input1, input2, input3, input4, input5)
|
||||
res = tuple_getitem(res, 0)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(input0, input1, input2, input3, input4):
|
||||
def after(input0, input1, input2, input3, input4, input5):
|
||||
res = bn_infer_grad(input0, input2, input4)
|
||||
return make_tuple(res)
|
||||
|
||||
@fns
|
||||
def no_fusion(input0, input1, input2, input3, input4):
|
||||
res = batch_norm_grad(input0, input1, input2, input3, input4)
|
||||
def no_fusion(input0, input1, input2, input3, input4, input5):
|
||||
res = batch_norm_grad(input0, input1, input2, input3, input4, input5)
|
||||
item0 = tuple_getitem(res, 0)
|
||||
item1 = tuple_getitem(res, 1)
|
||||
item2 = tuple_getitem(res, 2)
|
||||
|
|
|
@ -49,8 +49,8 @@ def test_bn_grad_split(tag):
|
|||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(i0, i1, i2, i3, i4):
|
||||
bn_grad_output = bn_grad(i0, i1, i2, i3, i4)
|
||||
def before(i0, i1, i2, i3, i4, i5):
|
||||
bn_grad_output = bn_grad(i0, i1, i2, i3, i4, i5)
|
||||
item0 = tuple_getitem(bn_grad_output, 0)
|
||||
item1 = tuple_getitem(bn_grad_output, 1)
|
||||
item2 = tuple_getitem(bn_grad_output, 2)
|
||||
|
@ -58,7 +58,7 @@ def test_bn_grad_split(tag):
|
|||
return output
|
||||
|
||||
@fns
|
||||
def after1(i0, i1, i2, i3, i4):
|
||||
def after1(i0, i1, i2, i3, i4, i5):
|
||||
bn_grad1_output = bn_grad1(i0, i1, i3)
|
||||
bn_grad1_item0 = tuple_getitem(bn_grad1_output, 0)
|
||||
bn_grad1_item1 = tuple_getitem(bn_grad1_output, 1)
|
||||
|
@ -78,7 +78,7 @@ def test_bn_grad_split(tag):
|
|||
return make_tuple(output)
|
||||
|
||||
@fns
|
||||
def after2(i0, i1, i2, i3, i4):
|
||||
def after2(i0, i1, i2, i3, i4, i5):
|
||||
bn_update_grad_output = bn_training_update_grad(i0, i1, i3, i4)
|
||||
update_item0 = tuple_getitem(bn_update_grad_output, 0)
|
||||
update_item1 = tuple_getitem(bn_update_grad_output, 1)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -27,8 +27,6 @@ make_tuple = Primitive('MakeTuple')
|
|||
four2five = Primitive('Four2Five')
|
||||
five2four = Primitive('Five2Four')
|
||||
cast = Primitive('Cast')
|
||||
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
|
||||
bn = P.FusedBatchNorm()
|
||||
relu = P.ReLU()
|
||||
|
||||
|
||||
|
@ -140,25 +138,6 @@ def test_eliminate_depend_input2(tag):
|
|||
return fns[tag]
|
||||
|
||||
|
||||
def test_opt_match(tag):
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def graph1(x, y):
|
||||
sum_add = add(x, y)
|
||||
output = make_tuple(sum_add)
|
||||
return output
|
||||
|
||||
@fns
|
||||
def graph2(x, w, scale, b, mean, variance):
|
||||
conv_output = conv(x, w)
|
||||
bn_output = bn(conv_output, scale, b, mean, variance)
|
||||
res = tuple_getitem(bn_output, 0)
|
||||
return res
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_func_graph_cse(tag):
|
||||
""" test_func_graph_cse """
|
||||
fns = FnDict()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
# Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import _constants as Constants
|
||||
|
||||
# pylint: disable=unused-variable
|
||||
|
@ -25,7 +24,6 @@ allreduce = P.AllReduce()
|
|||
allreduce.add_prim_attr('fusion', 1)
|
||||
make_tuple = Primitive("MakeTuple")
|
||||
conv = P.Conv2D(out_channel=64, kernel_size=7, mode=1, pad_mode="valid", pad=0, stride=1, dilation=1, group=1)
|
||||
bn = P.FusedBatchNorm()
|
||||
relu = P.ReLU()
|
||||
conv_bn1 = Primitive('ConvBN1')
|
||||
bn2_add_relu = Primitive('BN2AddRelu')
|
||||
|
@ -33,7 +31,6 @@ bn2_relu = Primitive('BN2Relu')
|
|||
fused_bn1 = Primitive('FusedBN1')
|
||||
fused_bn2 = Primitive('FusedBN2')
|
||||
fused_bn3 = Primitive('FusedBN3')
|
||||
bn_grad = G.FusedBatchNormGrad()
|
||||
bn_grad1 = Primitive('BNGrad1')
|
||||
bn_grad2 = Primitive('BNGrad2')
|
||||
bn_grad3 = Primitive('BNGrad3')
|
||||
|
@ -50,73 +47,6 @@ class FnDict:
|
|||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_bn_split(tag):
|
||||
""" test_split_bn_fusion """
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(x, scale, b, mean, variance):
|
||||
bn_output = bn(x, scale, b, mean, variance)
|
||||
item0 = tuple_getitem(bn_output, 0)
|
||||
return item0
|
||||
|
||||
@fns
|
||||
def after(x, scale, b, mean, variance):
|
||||
fused_bn1_output = fused_bn1(x)
|
||||
fused_bn2_input0 = tuple_getitem(fused_bn1_output, 0)
|
||||
fused_bn2_input1 = tuple_getitem(fused_bn1_output, 1)
|
||||
fused_bn2_output = fused_bn2(fused_bn2_input0, fused_bn2_input1, mean, variance)
|
||||
fused_bn3_input1 = tuple_getitem(fused_bn2_output, 0)
|
||||
fused_bn3_input2 = tuple_getitem(fused_bn2_output, 1)
|
||||
fused_bn3_output = fused_bn3(x, fused_bn3_input1, fused_bn3_input2, scale, b)
|
||||
output1 = tuple_getitem(fused_bn2_output, 2)
|
||||
output2 = tuple_getitem(fused_bn2_output, 3)
|
||||
output3 = tuple_getitem(fused_bn2_output, 0)
|
||||
output4 = tuple_getitem(fused_bn2_output, 1)
|
||||
output = make_tuple(fused_bn3_output, output1, output2, output3, output4)
|
||||
item0 = tuple_getitem(output, 0)
|
||||
return make_tuple(item0)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_bn_grad_split(tag):
|
||||
""" test_bn_grad_split """
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(dy, x, scale, save_mean, save_inv_variance):
|
||||
bn_grad_output = bn_grad(dy, x, scale, save_mean, save_inv_variance)
|
||||
item0 = tuple_getitem(bn_grad_output, 0)
|
||||
item1 = tuple_getitem(bn_grad_output, 1)
|
||||
item2 = tuple_getitem(bn_grad_output, 2)
|
||||
output = make_tuple(item0, item1, item2)
|
||||
res = tuple_getitem(output, 0)
|
||||
return res
|
||||
|
||||
@fns
|
||||
def after(i0, i1, i2, i3, i4):
|
||||
bn_grad1_output = bn_grad1(i0, i1, i3)
|
||||
bn_grad1_item0 = tuple_getitem(bn_grad1_output, 0)
|
||||
bn_grad1_item1 = tuple_getitem(bn_grad1_output, 1)
|
||||
bn_grad1_item2 = tuple_getitem(bn_grad1_output, 2)
|
||||
bn_grad2_output = bn_grad2(bn_grad1_item0, bn_grad1_item1, i4, i2)
|
||||
bn_grad2_item0 = tuple_getitem(bn_grad2_output, 0)
|
||||
bn_grad2_item1 = tuple_getitem(bn_grad2_output, 1)
|
||||
bn_grad2_item2 = tuple_getitem(bn_grad2_output, 2)
|
||||
bn_grad2_item3 = tuple_getitem(bn_grad2_output, 3)
|
||||
bn_grad2_item4 = tuple_getitem(bn_grad2_output, 4)
|
||||
bn_grad3_output = bn_grad3(i0, bn_grad2_item2, bn_grad2_item3, bn_grad2_item4, bn_grad1_item2)
|
||||
bn_grad_make_tuple = make_tuple(bn_grad3_output, bn_grad2_item0, bn_grad2_item1)
|
||||
item0 = tuple_getitem(bn_grad_make_tuple, 0)
|
||||
item1 = tuple_getitem(bn_grad_make_tuple, 1)
|
||||
item2 = tuple_getitem(bn_grad_make_tuple, 2)
|
||||
output = make_tuple(item0, item1, item2)
|
||||
return make_tuple(tuple_getitem(output, 0))
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_all_reduce_fusion_all(tag):
|
||||
""" test_all_reduce_fusion_all """
|
||||
fns = FnDict()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -221,7 +221,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputTensorNum) {
|
|||
auto kernel_graph = std::make_shared<KernelGraph>();
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
// test fused batch norm as input
|
||||
inputs.push_back(NewValueNode(prim::kPrimFusedBatchNorm));
|
||||
inputs.push_back(NewValueNode(prim::kPrimBatchNorm));
|
||||
auto bn = kernel_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn);
|
||||
std::vector<int64_t> shp{2, 32, 224, 224};
|
||||
|
@ -417,7 +417,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
|
|||
TEST_F(AnfRuntimeAlgorithmTest, GetOutputInferDataTypeTest) {
|
||||
auto kernel_graph = std::make_shared<KernelGraph>();
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(prim::kPrimFusedBatchNorm));
|
||||
inputs.push_back(NewValueNode(prim::kPrimBatchNorm));
|
||||
auto bn = kernel_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(bn);
|
||||
std::vector<int64_t> shp{2, 32, 224, 224};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -690,26 +690,6 @@ test_cases_for_verify_exception = [
|
|||
'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}),
|
||||
'desc_inputs': [0],
|
||||
}),
|
||||
('FusedBatchNorm_ValueError_1', {
|
||||
'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}),
|
||||
'desc_inputs': [0],
|
||||
}),
|
||||
('FusedBatchNorm_ValueError_2', {
|
||||
'block': (lambda _: P.FusedBatchNorm(mode=2, epsilon=1e-5, momentum=0.1), {'exception': ValueError}),
|
||||
'desc_inputs': [0],
|
||||
}),
|
||||
('FusedBatchNorm_ValueError_3', {
|
||||
'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=-1e-5, momentum=0.1), {'exception': ValueError}),
|
||||
'desc_inputs': [0],
|
||||
}),
|
||||
('FusedBatchNorm_ValueError_4', {
|
||||
'block': (lambda _: P.FusedBatchNorm(mode=0, epsilon=1e-5, momentum=-0.1), {'exception': ValueError}),
|
||||
'desc_inputs': [0],
|
||||
}),
|
||||
('FusedBatchNorm_ValueError_5', {
|
||||
'block': (lambda _: P.FusedBatchNorm(mode=1, epsilon=-0.001, momentum=0.0), {'exception': ValueError}),
|
||||
'desc_inputs': [0],
|
||||
}),
|
||||
('Softmax_ValueError_1', {
|
||||
'block': (lambda _: P.Softmax("1"), {'exception': TypeError}),
|
||||
'desc_inputs': [0],
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@@ -1749,11 +1749,6 @@ test_case_nn_ops = [
        'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
        'desc_bprop': [[2, 16], [16], [16]],
        'skip': ['backward']}),
    ('FusedBatchNormGrad', {
        'block': G.FusedBatchNormGrad(),
        'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
        'skip': ['backward']}),
    ('BatchNorm', {
        'block': P.BatchNorm(),
        'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
@@ -1761,8 +1756,8 @@ test_case_nn_ops = [
        'skip': []}),
    ('BatchNormGrad', {
        'block': G.BatchNormGrad(),
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
        'desc_bprop': [[128, 64, 32, 32], [64], [64]],
        'skip': ['backward']}),
    ('SyncBatchNorm', {
        'block': inner.SyncBatchNorm(),
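Note: with FusedBatchNormGrad gone, the BatchNormGrad case now takes six inputs; the extra tensor appears to be the reserve output produced by the unified BatchNorm (an assumption). A minimal sketch of exercising BatchNorm's backward pass through GradOperation instead of driving G.BatchNormGrad directly (assumes a backend with PyNative execution):

```python
import numpy as np
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P

class BNForward(nn.Cell):
    def __init__(self):
        super().__init__()
        self.bn = P.BatchNorm(is_training=True)

    def construct(self, x, scale, bias, mean, var):
        # BatchNorm returns a tuple; only the normalized output is differentiated here.
        return self.bn(x, scale, bias, mean, var)[0]

grad_all = C.GradOperation(get_all=True)
x = Tensor(np.random.randn(128, 64, 32, 32).astype(np.float32))
scale = Tensor(np.ones(64).astype(np.float32))
bias = Tensor(np.zeros(64).astype(np.float32))
mean = Tensor(np.zeros(64).astype(np.float32))
var = Tensor(np.ones(64).astype(np.float32))
grads = grad_all(BNForward())(x, scale, bias, mean, var)
```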
@@ -1,77 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
import mindspore.nn as nn
from tests.ut.python.ops.test_math_ops import VirtualLoss


grad_all = C.GradOperation(get_all=True)

class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


# model_parallel test
def test_two_matmul_batchnorm_ex():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super().__init__()
            self.matmul1 = P.BatchMatMul().shard(strategy1)
            self.norm = P.FusedBatchNormEx()
            self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
            self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
            self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
            self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
            self.matmul2 = P.BatchMatMul().shard(strategy2)

        def construct(self, x, y, b):
            out = self.matmul1(x, y)
            out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0]
            out = self.matmul2(out, b)
            return out

    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)
    strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1))
    strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
    net.set_auto_parallel()
    x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32)
    net.set_train()
    _executor.compile(net, x, y, b)
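Note: the deleted model-parallel case above exercised P.FusedBatchNormEx. A minimal sketch of how the same network could be written against the unified P.BatchNorm; the substitution is an assumption and is not part of this commit:

```python
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P

class Net(nn.Cell):
    """Same topology as the removed test net, with BatchNorm in place of FusedBatchNormEx."""
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.matmul1 = P.BatchMatMul().shard(strategy1)
        self.norm = P.BatchNorm(is_training=True)  # replaces P.FusedBatchNormEx (assumption)
        self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
        self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
        self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
        self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
        self.matmul2 = P.BatchMatMul().shard(strategy2)

    def construct(self, x, y, b):
        out = self.matmul1(x, y)
        out = self.norm(out, self.gamma, self.beta, self.mean, self.var)[0]
        return self.matmul2(out, b)
```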
@@ -1,260 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.common.dtype as DT
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.nn import WithLossCell
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


class FusedBatchNorm(nn.Cell):
    """Batch Normalization base class."""

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones'):
        super(FusedBatchNorm, self).__init__()
        if num_features < 1:
            raise ValueError("num_features must be at least 1")

        if momentum < 0 or momentum > 1:
            raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))

        self.num_features = num_features
        self.eps = eps
        self.momentum = Tensor(1.0 - momentum, DT.float32)
        self.gamma = Parameter(initializer(
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.moving_mean = Parameter(initializer(
            moving_mean_init, num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer(
            moving_var_init, num_features), name="variance", requires_grad=False)

        self.bn_train = P.BatchNorm(is_training=True,
                                    epsilon=self.eps)
        self.bn_infer = P.BatchNorm(is_training=False,
                                    epsilon=self.eps)
        self.sub_mean = P.Sub().shard(((1), (1)))
        self.sub_var = P.Sub().shard(((1), (1)))
        self.mul_mean = P.Mul().shard(((1,), ()))
        self.mul_var = P.Mul().shard(((1,), ()))
        self.assign_sub_mean = P.AssignSub().shard(((1,), (1,)))
        self.assign_sub_var = P.AssignSub().shard(((1), (1)))
        self.sub_mean2 = P.Sub().shard(((1), (1)))
        self.sub_var2 = P.Sub().shard(((1), (1)))

    def shard(self, strategy):
        self.bn_train.shard(strategy)
        self.bn_infer.shard(strategy)

    def _check_data_dim(self, x):
        raise NotImplementedError

    def construct(self, x):
        if self.training:
            y, batch_mean, batch_var, _, _ = \
                self.bn_train(x,
                              self.gamma,
                              self.beta,
                              self.moving_mean,
                              self.moving_variance)

            mean_sub = self.sub_mean(self.moving_mean, batch_mean)
            temp_mean = self.mul_mean(mean_sub, self.momentum)
            mean_sub2 = self.sub_var(self.moving_variance, batch_var)
            temp_variance = self.mul_var(mean_sub2, self.momentum)
            y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
            y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))

        else:
            y = self.bn_infer(x,
                              self.gamma,
                              self.beta,
                              self.moving_mean,
                              self.moving_variance)[0]
        return y

    def extend_repr(self):
        return 'num_features={}, eps={}, momentum={}, ' \
               'beta={}, gamma={}, ' \
               'moving_mean={}, moving_variance={} ' \
            .format(self.num_features,
                    self.eps,
                    self.momentum,
                    self.beta,
                    self.gamma,
                    self.moving_mean,
                    self.moving_variance)


class PReLU(nn.Cell):
    """
    PReLU activation function.

    Computes prelu value of a 4-dim tensor(NCHW).
    PReLU: out = max(0, A) + min(0, wA)

    Args:
        channel: Integer. The dimensionality of w. Default: 1.
        w: Float. The initial value of w. Default: 0.25.

    Returns:
        Tensor, has the same type as features.

    Examples:
        prelu = nn.PReLU(1, [np.float32(0.25)]) # or prelu = nn.PReLU(33, Tensor(np.random.rand(33), ms.float32)])
        input_data = Tensor(np.random.rand(1, 33, 4, 4), ms.float32)
        output = prelu.construct(input_data)
    """

    def __init__(self, channel=1, w=0.25):
        super(PReLU, self).__init__()
        if isinstance(w, (np.float32, float)):
            tmp = np.empty((channel,), dtype=np.float32)
            tmp.fill(w)
            w = tmp
        elif isinstance(w, (int, bool, complex, str)):
            raise TypeError("w only support input type float32 and float")

        if not isinstance(w, Tensor):
            w = Tensor(w)
        self.w = Parameter(initializer(w, [channel,]), name='a')
        self.prelu = P.PReLU()
        self.relu = P.ReLU().shard(((1)))

    def construct(self, x):
        self.w = self.relu(self.w)
        return self.prelu(x, self.w)


class BNNet(nn.Cell):
    def __init__(self):
        super(BNNet, self).__init__()
        self.bn = FusedBatchNorm(512)
        self.prelu = PReLU(512)

    def construct(self, x):
        x = self.bn(x)
        x = self.prelu(x)
        return x


def bn_net():
    return BNNet()


def bn_common(parallel_mode, train_flag, strategy_loss=None):
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    rank_size = 8

    predict = Tensor(np.ones([32, 512]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = bn_net()

    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    loss.softmax_cross_entropy.shard(strategy_loss)
    opt = Momentum(net.trainable_params(), learning_rate, momentum, 0.0001, 1024 * rank_size)

    if not train_flag:
        net = WithLossCell(net, loss)
        net.set_train()

    if parallel_mode == ParallelMode.DATA_PARALLEL:
        context.set_auto_parallel_context(parameter_broadcast=True)
    model = Model(net, loss, opt)
    if train_flag:
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    else:
        model._predict(predict, label)


def test_data_parallel():
    parallel_mode = ParallelMode.DATA_PARALLEL
    train_flag = True
    bn_common(parallel_mode, train_flag)


def auto_parallel():
    train_flag = True
    parallel_mode = ParallelMode.AUTO_PARALLEL
    bn_common(parallel_mode, train_flag)


def Xtest_data_parallel_predict():
    parallel_mode = ParallelMode.DATA_PARALLEL
    train_flag = False
    bn_common(parallel_mode, train_flag)


def Xtest_semi_auto_parallel_predict():
    train_flag = False
    parallel_mode = ParallelMode.SEMI_AUTO_PARALLEL
    bn_common(parallel_mode, train_flag)


def Xtest_auto_parallel_predict():
    train_flag = False
    parallel_mode = ParallelMode.AUTO_PARALLEL
    bn_common(parallel_mode, train_flag)


if __name__ == '__main__':
    auto_parallel()
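Note: the deleted helper cell above re-implemented fused batch normalization on top of P.BatchNorm. A minimal sketch of the BNNet from that file rebuilt on built-in cells, assuming nn.BatchNorm1d covers the 2-D [32, 512] activations used in bn_common:

```python
import mindspore.nn as nn

class BNNet(nn.Cell):
    """Same topology as the deleted BNNet, using built-in normalization and activation cells."""
    def __init__(self):
        super(BNNet, self).__init__()
        self.bn = nn.BatchNorm1d(512)  # replaces the hand-rolled FusedBatchNorm cell (assumption)
        self.prelu = nn.PReLU(512)

    def construct(self, x):
        return self.prelu(self.bn(x))
```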
@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -92,27 +92,6 @@ def vm_impl_tanh(self):
    return vm_impl


@vm_impl_getters.register(P.FusedBatchNorm)
def vm_impl_fused_batch_norm(self):
    """Generate vm_impl function for FusedBatchNorm"""

    def vm_impl(x, scale, b, mean, variance):
        # pylint: disable=unused-argument
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(x, scale, b, mean, \
                                                                      variance, \
                                                                      eps=self.epsilon, \
                                                                      momentum=self.momentum)
        return Tensor(out), Tensor(x_mean), Tensor(x_var), \
            Tensor(running_mean), Tensor(running_var)

    return vm_impl


@vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self):
    """Generate vm_impl function for BatchNorm"""
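Note: the body of the surviving vm_impl_batch_norm lies outside this hunk. A minimal sketch of what it plausibly contains, mirroring the deleted FusedBatchNorm implementation above; the unchanged vm.batch_norm signature and the epsilon-only call are assumptions:

```python
@vm_impl_getters.register(P.BatchNorm)
def vm_impl_batch_norm(self):
    """Generate vm_impl function for BatchNorm (sketch)."""
    def vm_impl(x, scale, b, mean, variance):
        # Convert device tensors to numpy, run the reference batch norm, wrap results back.
        x = x.asnumpy()
        scale = scale.asnumpy()
        b = b.asnumpy()
        mean = mean.asnumpy()
        variance = variance.asnumpy()
        out, x_mean, x_var, running_mean, running_var = vm.batch_norm(
            x, scale, b, mean, variance, eps=self.epsilon)
        return Tensor(out), Tensor(x_mean), Tensor(x_var), \
            Tensor(running_mean), Tensor(running_var)

    return vm_impl
```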
@@ -223,23 +202,6 @@ def vm_impl_avg_pool_grad(self):
    return vm_impl


# pylint: disable=function-redefined
@vm_impl_getters.register(G.FusedBatchNormGrad)
def vm_impl_fused_batch_norm_grad(self):
    """Generate vm_impl function for FusedBatchNormGrad"""

    def vm_impl(dy, x, scale, save_mean, save_inv_variance):
        dy = dy.asnumpy()
        x = x.asnumpy()
        scale = scale.asnumpy()
        save_mean = save_mean.asnumpy()
        save_inv_variance = save_inv_variance.asnumpy()
        dx, dscale, dshift = vm.batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
        return (Tensor(dx), Tensor(dscale), Tensor(dshift))

    return vm_impl


# pylint: disable=function-redefined
@vm_impl_getters.register(G.BatchNormGrad)
def vm_impl_fused_batch_norm_grad(self):