forked from mindspore-Ecosystem/mindspore
commit
85461bcdb3
|
@ -276,6 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SyncBnSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SyncBnGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/ascend/ir_fission/bn_split.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
@ -104,6 +105,36 @@ CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode
|
|||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
return make_tuple;
|
||||
}
|
||||
|
||||
// Builds the replacement sub-graph for a SyncBatchNormGrad node.
//
// The outputs produced by CreateOutputsOfUpdateGrad are each routed through
// CreateAllReduceAndMul, the synchronized values are handed to
// CreateOutputsOfReduceGrad, and the result is packed as
// MakeTuple(reduce_grad_output, synced[0], synced[1]).
// Raises a framework exception when either helper yields an unexpected number of outputs.
CNodePtr SyncBNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(cnode);

  std::vector<AnfNodePtr> update_grad_outputs;
  CreateOutputsOfUpdateGrad(func_graph, cnode, &update_grad_outputs);
  if (update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"
                      << " trace: " << trace::DumpSourceLines(cnode);
  }

  // Every per-device partial result is summed across the group and rescaled.
  std::vector<AnfNodePtr> synced_outputs;
  for (const auto &update_grad_output : update_grad_outputs) {
    synced_outputs.emplace_back(CreateAllReduceAndMul(func_graph, update_grad_output, cnode));
  }

  std::vector<AnfNodePtr> reduce_grad_outputs;
  CreateOutputsOfReduceGrad(func_graph, cnode, synced_outputs, &reduce_grad_outputs);
  if (reduce_grad_outputs.size() != 1) {
    MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"
                      << " trace: " << trace::DumpSourceLines(cnode);
  }

  std::vector<AnfNodePtr> tuple_inputs;
  tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  tuple_inputs.push_back(reduce_grad_outputs[0]);
  tuple_inputs.push_back(synced_outputs[0]);
  tuple_inputs.push_back(synced_outputs[1]);
  auto result_tuple = func_graph->NewCNode(tuple_inputs);
  MS_EXCEPTION_IF_NULL(result_tuple);
  return result_tuple;
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef BnGradSplit::DefinePattern() const {
|
||||
|
@ -120,5 +151,17 @@ const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfN
|
|||
}
|
||||
return BNGradSplitForTBE(func_graph, cnode);
|
||||
}
|
||||
|
||||
// Pattern: a SyncBatchNormGrad primitive followed by an arbitrary sequence of inputs.
const BaseRef SyncBnGradSplit::DefinePattern() const {
  VarPtr any_inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimSyncBatchNormGrad, any_inputs});
}
|
||||
|
||||
// Replaces the matched node with the sub-graph built by SyncBNGradSplitForTBE.
// `func_graph` and the casted CNode are null-checked inside that helper.
const AnfNodePtr SyncBnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  return SyncBNGradSplitForTBE(func_graph, node->cast<CNodePtr>());
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,14 @@ class BnGradSplit : public PatternProcessPass {
|
|||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class SyncBnGradSplit : public PatternProcessPass {
|
||||
public:
|
||||
explicit SyncBnGradSplit(bool multigraph = true) : PatternProcessPass("sync_bn_grad_split", multigraph) {}
|
||||
~SyncBnGradSplit() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -28,6 +30,9 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kReduceOpSum = "sum";
|
||||
constexpr auto kDeviceNum = "device_num";
|
||||
|
||||
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
|
||||
std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -117,8 +122,105 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr
|
|||
// Create BNTrainingUpdate node
|
||||
return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
|
||||
}
|
||||
|
||||
// Builds the replacement sub-graph for a SyncBatchNorm node.
//
// The outputs of CreateOutputsOfBNTrainingReduce are each routed through
// CreateAllReduceAndMul, and the synchronized values are passed to
// CreateOutputsOfBNTrainingUpdate, whose result is returned.
// Returns nullptr (leaving the graph untouched) when the node has too few
// inputs or the reduce node cannot be created.
AnfNodePtr SyncBNSplitForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);

  auto sync_bn = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(sync_bn);
  if (AnfAlgo::GetInputTensorNum(sync_bn) < kBnInputTensorNum) {
    MS_LOG(INFO) << "op[" << sync_bn->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs.";
    return nullptr;
  }

  // Create BNTrainingReduce node and collect its outputs.
  std::vector<AnfNodePtr> reduce_outputs;
  const bool reduce_created = CreateOutputsOfBNTrainingReduce(func_graph, sync_bn, &reduce_outputs);
  if (!reduce_created) {
    MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split";
    return nullptr;
  }
  if (reduce_outputs.size() != kBN1OutputNum) {
    MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail"
                      << " trace: " << trace::DumpSourceLines(node);
  }

  // Synchronize every reduce output across the communication group.
  std::vector<AnfNodePtr> synced_outputs;
  for (const auto &reduce_output : reduce_outputs) {
    synced_outputs.emplace_back(CreateAllReduceAndMul(func_graph, reduce_output, sync_bn));
  }

  // Create BNTrainingUpdate node fed by the synchronized values.
  return CreateOutputsOfBNTrainingUpdate(func_graph, sync_bn, synced_outputs);
}
|
||||
} // namespace
|
||||
|
||||
// Creates a scalar float32 value node holding 1 / device_num and registers it
// with the kernel graph.
//
// `graph` must be castable to a KernelGraph; `sync_bn_cnode` must carry the
// int64 attr "device_num". Raises a framework exception when the attr is
// missing or is not a positive number: a zero value would otherwise produce an
// infinite multiplier, and a negative one would flip the sign of the statistics.
AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph, const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);
  if (!AnfAlgo::HasNodeAttr(kDeviceNum, sync_bn_cnode)) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] does not have attr device_num.";
  }
  const auto device_num = AnfAlgo::GetNodeAttr<int64_t>(sync_bn_cnode, kDeviceNum);
  MS_LOG(INFO) << "device_num value: " << device_num;
  if (device_num <= 0) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has invalid attr device_num: " << device_num;
  }
  // Divide in double precision, then narrow once to the tensor's float32 type.
  const float device_num_reciprocal = static_cast<float>(1.0 / static_cast<double>(device_num));

  // An empty shape makes the tensor a scalar.
  std::vector<int64_t> device_num_shape = {};
  auto device_num_reciprocal_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, device_num_shape);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_tensor);
  auto data_ptr = device_num_reciprocal_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto *val = reinterpret_cast<float *>(data_ptr);
  *val = device_num_reciprocal;

  auto kernel_graph = graph->cast<KernelGraphPtr>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, device_num_shape);
  auto device_num_reciprocal_value = kernel_graph->NewValueNode(abstract, device_num_reciprocal_tensor);
  MS_EXCEPTION_IF_NULL(device_num_reciprocal_value);
  kernel_graph->AddValueNodeToGraph(device_num_reciprocal_value);
  return device_num_reciprocal_value;
}
|
||||
|
||||
// Builds Mul(AllReduce(allreduce_input), 1 / device_num) and returns the Mul node.
//
// The AllReduce uses op "sum", copies the communication group attr from
// `sync_bn_cnode`, and takes its fusion attr from the numeric id that follows
// the last "-op" in the node's full name. Both new nodes inherit the abstract
// and scope of `allreduce_input`.
// Raises a framework exception when the full name carries no numeric op id.
AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph, const AnfNodePtr &allreduce_input,
                                 const CNodePtr &sync_bn_cnode) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(allreduce_input);
  MS_EXCEPTION_IF_NULL(sync_bn_cnode);

  // create AllReduce
  std::vector<AnfNodePtr> allreduce_inputs = {NewValueNode(std::make_shared<Primitive>(kAllReduceOpName)),
                                              allreduce_input};
  auto allreduce = graph->NewCNode(allreduce_inputs);
  MS_EXCEPTION_IF_NULL(allreduce);
  allreduce->set_abstract(allreduce_input->abstract());
  allreduce->set_scope(allreduce_input->scope());
  AnfAlgo::SetNodeAttr(kAttrOp, MakeValue(kReduceOpSum), allreduce);
  AnfAlgo::CopyNodeAttr(kAttrGroup, sync_bn_cnode, allreduce);

  // use SyncBatchNorm's opid as AllReduce's fusion attr
  constexpr size_t kOpIdMarkerLen = 3;  // length of "-op"
  const std::string sync_bn_opname = sync_bn_cnode->fullname_with_scope();
  const auto opid_pos = sync_bn_opname.rfind("-op");
  if (opid_pos == std::string::npos) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid.";
  }
  const std::string opid_str = sync_bn_opname.substr(opid_pos + kOpIdMarkerLen);
  // Reject an empty or non-numeric suffix here so the failure names the op
  // instead of surfacing as a bare std::invalid_argument from the conversion.
  if (opid_str.find_first_of("0123456789") != 0) {
    MS_LOG(EXCEPTION) << "op[" << sync_bn_cnode->DebugString() << "] has no opid.";
  }
  // stoll (not stol) so the full int64_t range is parsed where long is 32-bit.
  int64_t opid = std::stoll(opid_str);
  // user defined fusion should be greater than 1
  if (opid < 2) {
    opid = opid - 2 + std::numeric_limits<int64_t>::max();
  }
  AnfAlgo::SetNodeAttr(kAttrFusion, MakeValue(opid), allreduce);

  // create Mul
  auto device_num_reciprocal_vnode = CreateValueNodeOfDeviceNumReciprocal(graph, sync_bn_cnode);
  std::vector<AnfNodePtr> mul_inputs = {NewValueNode(std::make_shared<Primitive>(kMulOpName)), allreduce,
                                        device_num_reciprocal_vnode};
  auto mul = graph->NewCNode(mul_inputs);
  MS_EXCEPTION_IF_NULL(mul);
  mul->set_abstract(allreduce_input->abstract());
  mul->set_scope(allreduce_input->scope());
  return mul;
}
|
||||
|
||||
const BaseRef BnSplit::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
MS_EXCEPTION_IF_NULL(Xs);
|
||||
|
@ -132,5 +234,14 @@ const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
}
|
||||
return SplitBatchNormForTBE(func_graph, node);
|
||||
}
|
||||
|
||||
// Pattern: a SyncBatchNorm primitive followed by an arbitrary sequence of inputs.
const BaseRef SyncBnSplit::DefinePattern() const {
  VarPtr any_inputs = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimSyncBatchNorm, any_inputs});
}
|
||||
|
||||
// Replaces the matched node with the sub-graph built by SyncBNSplitForTBE;
// a nullptr result from that helper is passed through unchanged.
const AnfNodePtr SyncBnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                      const EquivPtr &) const {
  return SyncBNSplitForTBE(func_graph, node);
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,19 @@ class BnSplit : public PatternProcessPass {
|
|||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class SyncBnSplit : public PatternProcessPass {
|
||||
public:
|
||||
explicit SyncBnSplit(bool multigraph = true) : PatternProcessPass("sync_bn_split", multigraph) {}
|
||||
~SyncBnSplit() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
// Creates a value node for 1 / device_num, read from the "device_num" attr of `sync_bn_cnode`.
AnfNodePtr CreateValueNodeOfDeviceNumReciprocal(const FuncGraphPtr &graph,
                                                const CNodePtr &sync_bn_cnode);

// Creates AllReduce(allreduce_input) followed by a Mul with the device_num
// reciprocal; group and fusion attrs are derived from `sync_bn_cnode`.
AnfNodePtr CreateAllReduceAndMul(const FuncGraphPtr &graph,
                                 const AnfNodePtr &allreduce_input,
                                 const CNodePtr &sync_bn_cnode);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_BN_SPLIT_H_
|
||||
|
|
|
@ -228,6 +228,8 @@ inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>(
|
|||
inline const PrimitivePtr kPrimFusedBatchNormGradEx = std::make_shared<Primitive>("FusedBatchNormGradEx");
|
||||
inline const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
|
||||
inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
|
||||
inline const PrimitivePtr kPrimSyncBatchNorm = std::make_shared<Primitive>("SyncBatchNorm");
|
||||
inline const PrimitivePtr kPrimSyncBatchNormGrad = std::make_shared<Primitive>("SyncBatchNormGrad");
|
||||
inline const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
|
||||
inline const PrimitivePtr kPrimReluGradV2 = std::make_shared<Primitive>("ReluGradV2");
|
||||
inline const PrimitivePtr kPrimRelu6Grad = std::make_shared<Primitive>("ReLU6Grad");
|
||||
|
|
|
@ -13,12 +13,17 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""normalization"""
|
||||
import itertools
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common._decorator import deprecated
|
||||
from mindspore.ops.primitive import constexpr
|
||||
import mindspore.context as context
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.communication.management import get_group_size, get_rank
|
||||
|
@ -26,8 +31,9 @@ from mindspore.communication import management
|
|||
from mindspore.ops import _selected_ops
|
||||
from ..cell import Cell
|
||||
|
||||
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'InstanceNorm2d']
|
||||
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'GroupNorm', 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']
|
||||
|
||||
SYNC_BN_GROUP_NAME = ""
|
||||
|
||||
class _BatchNorm(Cell):
|
||||
"""Batch Normalization base class."""
|
||||
|
@ -44,6 +50,7 @@ class _BatchNorm(Cell):
|
|||
moving_var_init='ones',
|
||||
use_batch_statistics=None,
|
||||
device_num_each_group=1,
|
||||
process_groups=0,
|
||||
input_dims='2d',
|
||||
data_format='NCHW'):
|
||||
super(_BatchNorm, self).__init__()
|
||||
|
@ -68,19 +75,47 @@ class _BatchNorm(Cell):
|
|||
gamma_init, num_features), name="gamma", requires_grad=affine)
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, num_features), name="beta", requires_grad=affine)
|
||||
self.group = validator.check_positive_int(device_num_each_group)
|
||||
self.group_device_num = validator.check_positive_int(device_num_each_group)
|
||||
self.process_groups = process_groups
|
||||
self.is_global = False
|
||||
if self.group != 1:
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
global SYNC_BN_GROUP_NAME
|
||||
# for GlobalBatchNorm
|
||||
if self.group_device_num != 1 and self.parallel_mode != context.ParallelMode.STAND_ALONE:
|
||||
self.rank_id = get_rank()
|
||||
self.rank_size = get_group_size()
|
||||
self.device_list = [i for i in range(0, self.rank_size)]
|
||||
self.rank_list = self.list_group(self.device_list, self.group)
|
||||
self.rank_list = self.list_group(self.device_list, self.group_device_num)
|
||||
self.rank_list_idx = len(self.rank_list)
|
||||
for i in range(self.rank_list_idx):
|
||||
if self.rank_id in self.rank_list[i] and self.group != 1:
|
||||
if self.rank_id in self.rank_list[i]:
|
||||
self.is_global = True
|
||||
management.create_group('group' + str(i), self.rank_list[i])
|
||||
self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group"+ str(i)
|
||||
management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])
|
||||
# for SyncBatchNorm
|
||||
if self.process_groups != 0 and self.parallel_mode != context.ParallelMode.STAND_ALONE:
|
||||
self.rank_id = get_rank()
|
||||
self.rank_size = get_group_size()
|
||||
if self.process_groups is not None:
|
||||
validator.check_isinstance("process_groups", self.process_groups, list)
|
||||
self._check_rank_ids(self.process_groups, self.rank_size)
|
||||
for i in range(len(self.process_groups)):
|
||||
validator.check_isinstance("process_groups[" + str(i) +"]", self.process_groups[i], list)
|
||||
self.group_device_num = len(self.process_groups[i])
|
||||
if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
|
||||
self.is_global = True
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
|
||||
management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])
|
||||
elif self.rank_size > 1:
|
||||
self.is_global = True
|
||||
self.group_device_num = self.rank_size
|
||||
self.device_list = [i for i in range(0, self.rank_size)]
|
||||
if SYNC_BN_GROUP_NAME == "":
|
||||
SYNC_BN_GROUP_NAME = "sync_bn_group0"
|
||||
management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
|
||||
|
||||
self.shape = P.Shape()
|
||||
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
||||
self.square = P.Square()
|
||||
|
@ -109,9 +144,12 @@ class _BatchNorm(Cell):
|
|||
self.bn_train = P.FusedBatchNorm(mode=1,
|
||||
epsilon=self.eps,
|
||||
momentum=self.momentum)
|
||||
if self.is_global:
|
||||
self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
|
||||
momentum=self.momentum,
|
||||
group=SYNC_BN_GROUP_NAME,
|
||||
device_num=self.group_device_num)
|
||||
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
|
||||
self.enable_global_sync = self.is_global and (self.is_ge_backend or\
|
||||
(self.is_graph_mode and self._target == "Ascend"))
|
||||
|
||||
data_parallel_strategy = ((1,), (1,))
|
||||
data_parallel_strategy_one = ((1,), ())
|
||||
|
@ -135,26 +173,13 @@ class _BatchNorm(Cell):
|
|||
group_list = [list(i) for i in world_rank_list]
|
||||
return group_list
|
||||
|
||||
def _global_sync(self, x, axes, re_shape):
    """
    Calculate the global batch normalization output.

    The per-device mean and mean-of-squares over `axes` are combined with
    `self.all_reduce` and divided by `self.group`; the variance is derived as
    E[x^2] - E[x]^2. The moving mean and moving variance are updated through
    the assign-sub ops, attached to the output with `F.depend`.

    Args:
        x (Tensor): Input data.
        axes: Axes passed to the mean reduction.
        re_shape: Shape used to reshape gamma, beta and the moving statistics
            so they can be combined with the reduced statistics.

    Returns:
        Tensor, the normalized, scaled and shifted output.
    """
    x_mean = self.reduce_mean(x, axes)
    x_mean_square = self.reduce_mean(self.square(x), axes)
    global_mean = self.all_reduce(x_mean) / self.group
    global_mean_square = self.all_reduce(x_mean_square) / self.group
    global_var = global_mean_square - self.square(global_mean)

    var_sqrt = self.sqrt(global_var + self.eps)
    normalized = (x - global_mean) / var_sqrt
    y = normalized * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)

    mean_delta = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
    tmp_mean = self.mul_mean(mean_delta, self.cast(self.momentum, self.dtype(mean_delta)))
    # The variance update must start from the moving variance; starting from
    # the moving mean would mix the two running statistics.
    var_delta = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
    tmp_variance = self.mul_var(var_delta, self.cast(self.momentum, self.dtype(var_delta)))

    y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean))))
    y = F.depend(y, self.assign_sub_var(self.moving_variance,
                                        self.reshape(tmp_variance, self.shape(self.moving_variance))))
    return y
|
||||
def _check_rank_ids(self, process_groups, rank_size):
    """
    Validate the rank ids listed in `process_groups`.

    Each rank id is range-checked against [0, rank_size) by the validator.

    Raises:
        ValueError: If a rank id appears more than once across all groups.
    """
    visited = set()
    for group in process_groups:
        for rank in group:
            validator.check_int_range(rank, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups")
            if rank in visited:
                raise ValueError("rank id in process_groups should not be duplicated.")
            visited.add(rank)
|
||||
|
||||
def construct(self, x):
|
||||
_shape_check_bn(self.shape(x), self.input_dims)
|
||||
|
@ -164,10 +189,6 @@ class _BatchNorm(Cell):
|
|||
flag = self.use_batch_statistics
|
||||
|
||||
if flag:
|
||||
if self.enable_global_sync:
|
||||
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
|
||||
return self._global_sync(x, axes, re_shape)
|
||||
|
||||
return self.bn_train(x,
|
||||
self.gamma,
|
||||
self.beta,
|
||||
|
@ -597,6 +618,7 @@ class GlobalBatchNorm(_BatchNorm):
|
|||
[ 20.9999895 241.9988 ]]]]
|
||||
"""
|
||||
|
||||
@deprecated("1.2", "SyncBatchNorm", True)
|
||||
def __init__(self,
|
||||
num_features,
|
||||
eps=1e-5,
|
||||
|
@ -619,8 +641,8 @@ class GlobalBatchNorm(_BatchNorm):
|
|||
use_batch_statistics,
|
||||
device_num_each_group,
|
||||
input_dims='both')
|
||||
self.group = validator.check_positive_int(device_num_each_group)
|
||||
if self.group <= 1:
|
||||
self.group_device_num = validator.check_positive_int(device_num_each_group)
|
||||
if self.group_device_num <= 1:
|
||||
raise ValueError("the number of group must be greater than 1.")
|
||||
|
||||
def _check_data_dim(self, x):
|
||||
|
@ -628,6 +650,121 @@ class GlobalBatchNorm(_BatchNorm):
|
|||
pass
|
||||
|
||||
|
||||
class SyncBatchNorm(_BatchNorm):
    r"""
    Sync Batch normalization layer over a N-dimension input.

    Sync Batch Normalization is cross device synchronized batch normalization. The implementation of Batch
    Normalization only normalizes the data within each device. Sync Batch normalization will normalize the input
    within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
    feature using a mini-batch of data and the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        Currently, SyncBatchNorm only supports 2D and 4D inputs.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
            use the mean value and variance value of specified value. If None, training process will use the mean and
            variance of current batch data and track the running mean and variance, eval process will use the running
            mean and variance. Default: None.
        process_groups (list): A list to divide devices into different sync groups, containing N sub-lists.
            Each sub-list contains int numbers identifying rank ids which need to be synchronized in the same
            group. All int values must be in [0, rank_size) and different from each other. Default: None, indicating
            synchronization across all devices.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `num_features` is not an int.
        TypeError: If `eps` is not a float.
        TypeError: If `process_groups` is not a list.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].
        ValueError: If a rank id appears more than once in `process_groups`.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import context
        >>> from mindspore.context import ParallelMode
        >>> from mindspore import nn, Tensor
        >>> from mindspore.common import dtype as mstype
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> context.reset_auto_parallel_context()
        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
        >>> np.random.seed(0)
        >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 2, 2]), mstype.float32)
        >>> output = sync_bn_op(input)
        >>> print(output)
        [[[[171.99915   46.999763]
           [116.99941  191.99904 ]]
          [[ 66.999664 250.99875 ]
           [194.99902  102.99948 ]]
          [[  8.999955 210.99895 ]
           [ 20.9999895 241.9988  ]]]]
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 process_groups=None):
        # input_dims='both' lets the shared base class accept 2D and 4D inputs.
        super(SyncBatchNorm, self).__init__(num_features,
                                            eps,
                                            momentum,
                                            affine,
                                            gamma_init,
                                            beta_init,
                                            moving_mean_init,
                                            moving_var_init,
                                            use_batch_statistics,
                                            process_groups=process_groups,
                                            input_dims='both')

    def _check_data_dim(self, x):
        # Performs no dimension check of its own.
        if x.dim == 0:
            pass
|
||||
|
||||
|
||||
class LayerNorm(Cell):
|
||||
r"""
|
||||
Applies Layer Normalization over a mini-batch of inputs.
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
from .. import operations as P
|
||||
from .. import composite as C
|
||||
from ..operations import _grad_ops as G
|
||||
from ..operations import _inner_ops as inner
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from .grad_base import bprop_getters
|
||||
|
||||
|
@ -64,5 +66,20 @@ def bprop_pqc(self):
|
|||
dx = t(dx, (1, 0))
|
||||
dy = C.tensor_dot(dout[0], out[2], ((0, 1), (0, 1)))
|
||||
return dx, dy
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(inner.SyncBatchNorm)
def get_bprop_sync_batch_norm(self):
    """Grad definition for `SyncBatchNorm` operation."""
    sync_bn_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num)

    def bprop(x, scale, b, mean, variance, out, dout):
        # Forward outputs 3 and 4 hold the statistics saved for the backward pass.
        grads = sync_bn_grad(dout[0], x, scale, out[3], out[4])
        # The moving mean and variance inputs receive zero gradients.
        return grads[0], grads[1], grads[2], zeros_like(mean), zeros_like(variance)

    return bprop
|
||||
|
|
|
@ -204,6 +204,24 @@ class BatchNormGrad(PrimitiveWithInfer):
|
|||
return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
|
||||
|
||||
|
||||
class SyncBatchNormGrad(PrimitiveWithInfer):
    """Performs grad of SyncBatchNorm operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, group="group0", device_num=2):
        # epsilon must lie in (0, 1]; group must be a str; device_num must be at least 2.
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        if not isinstance(group, str):
            raise TypeError("The group attr of SyncBatchNormGrad should be str.")
        validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)

    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape):
        # The incoming gradient must have the same shape as the forward input.
        validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
        grad_shapes = (x_shape, scale_shape, scale_shape)
        return grad_shapes

    def infer_dtype(self, y_backprop_type, x_type, scale_type, save_mean_shape, save_variance_shape):
        grad_types = (x_type, scale_type, scale_type)
        return grad_types
|
||||
|
||||
|
||||
class BiasAddGrad(PrimitiveWithInfer):
|
||||
"""Computes gradients of BiasAdd."""
|
||||
|
||||
|
|
|
@ -630,6 +630,7 @@ class GpuConvertToDynamicShape(PrimitiveWithCheck):
|
|||
def check_dtype(self, input_dtype):
|
||||
validator.check_subclass("input_dtype", input_dtype, mstype.tensor, self.name)
|
||||
|
||||
|
||||
class ErrorOnDynamicShapeInput(PrimitiveWithInfer):
|
||||
"""
|
||||
This op is used for dynamic shape testing. The only purpose of this operator is
|
||||
|
@ -724,3 +725,93 @@ class SequenceMask(PrimitiveWithCheck):
|
|||
def check_dtype(self, lengths_dtype, maxlen_dtype):
|
||||
validator.check_subclass("lengths_dtype", lengths_dtype, mstype.tensor, self.name)
|
||||
validator.check_subclass("maxlen", maxlen_dtype, mstype.number, self.name)
|
||||
|
||||
|
||||
class SyncBatchNorm(PrimitiveWithInfer):
|
||||
r"""
|
||||
Sync Batch Normalization for input data and updated parameters.
|
||||
|
||||
Sync Batch Normalization is cross device synchronized batch normalization. Batch Normalization is
|
||||
widely used in convolutional neural networks. This operation applies Batch Normalization over input
|
||||
to avoid internal covariate shift as described in the paper `Batch Normalization: Accelerating
|
||||
Deep Network Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
|
||||
It rescales and recenters the features using a mini-batch of data and the learned parameters which
|
||||
can be described in the following formula,
|
||||
|
||||
.. math::
|
||||
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
|
||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||
|
||||
Args:
|
||||
epsilon (float): A small value added for numerical stability. Default: 1e-5.
|
||||
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
|
||||
(e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`).
|
||||
Momentum value must be [0, 1]. Default: 0.1.
|
||||
group (str): The communication group to work on. Default: "sync_bn_group0".
|
||||
device_num (int): The number of devices in each group. Default: 2.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
|
||||
- **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
|
||||
- **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`.
|
||||
- **mean** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type.
|
||||
- **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `mean`.
|
||||
|
||||
Outputs:
|
||||
Tuple of 5 Tensor, the normalized inputs and the updated parameters.
|
||||
|
||||
- **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`.
|
||||
- **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> # This example should be run with multiple processes.
|
||||
>>> # Please refer to nn.SyncBatchNorm for direct use.
|
||||
>>> input_x = Tensor(np.ones([2, 2]), mindspore.float32)
|
||||
>>> scale = Tensor(np.ones([2]), mindspore.float32)
|
||||
>>> bias = Tensor(np.ones([2]), mindspore.float32)
|
||||
>>> mean = Tensor(np.ones([2]), mindspore.float32)
|
||||
>>> variance = Tensor(np.ones([2]), mindspore.float32)
|
||||
>>> sync_batch_norm = ops._inner_ops.SyncBatchNorm()
|
||||
>>> output = sync_batch_norm(input_x, scale, bias, mean, variance)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[2, 2], dtype=Float32, value=
|
||||
[[ 1.00000000e+00, 1.00000000e+00],
|
||||
[ 1.00000000e+00, 1.00000000e+00]]), Tensor(shape=[2], dtype=Float32, value=
|
||||
[ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
|
||||
[ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
|
||||
[ 1.00000000e+00, 1.00000000e+00]), Tensor(shape=[2], dtype=Float32, value=
|
||||
[ 1.00000000e+00, 1.00000000e+00]))
|
||||
"""
|
||||
|
||||
@prim_attr_register
def __init__(self, epsilon=1e-5, momentum=0.1, group="sync_bn_group0", device_num=2):
    """Validate the constructor arguments and declare the primitive's I/O names.

    Args:
        epsilon (float): Range-checked against 0 and 1 with `Rel.INC_RIGHT`.
        momentum (float): Range-checked against 0 and 1 with `Rel.INC_BOTH`.
        group (str): Must be a `str` instance.
        device_num (int): Checked against 2 with `Rel.GE`.
    """
    # The checks run in parameter order, so the first invalid argument is the one reported.
    validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
    validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
    validator.check_isinstance("group", group, str)
    validator.check_int(device_num, 2, Rel.GE, "device_num", self.name)

    input_names = ['x', 'scale', 'offset', 'mean', 'variance']
    output_names = ['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']
    self.init_prim_io_names(inputs=input_names, outputs=output_names)
|
||||
|
||||
def infer_shape(self, input_x, scale, bias, mean, variance):
    """Cross-check the parameter shapes and return the five output shapes.

    `scale` and `mean` must each be rank 1; `bias` must match `scale`,
    `variance` must match `mean`, and `mean` must match `scale`. The single
    dimension of `scale` must equal `input_x[1]`.

    Returns:
        tuple: `input_x` followed by `scale` repeated four times.
    """
    prim_name = self.name

    # scale / bias against each other and against the second dimension of the input.
    validator.check_equal_int(len(scale), 1, "scale rank", prim_name)
    validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, prim_name)
    validator.check("scale shape[0]", scale[0], "input_x channel", input_x[1], Rel.EQ, prim_name)

    # mean / variance against each other and against scale.
    validator.check_equal_int(len(mean), 1, "mean rank", prim_name)
    validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, prim_name)
    validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, prim_name)

    return (input_x,) + (scale,) * 4
|
||||
|
||||
def infer_dtype(self, input_x, scale, bias, mean, variance):
    """Validate the input dtypes and return the five output dtypes.

    `input_x` must be float16 or float32. `scale` and `bias` must share one
    of those dtypes, and so must `mean` and `variance`.

    Returns:
        tuple: `(input_x, scale, bias, input_x, input_x)`.
    """
    prim_name = self.name
    float_types = [mstype.float16, mstype.float32]

    validator.check_tensor_dtype_valid("input_x", input_x, float_types, prim_name)
    validator.check_tensors_dtypes_same_and_valid({"scale": scale, "bias": bias}, float_types, prim_name)
    validator.check_tensors_dtypes_same_and_valid({"mean": mean, "variance": variance}, float_types, prim_name)

    return (input_x, scale, bias, input_x, input_x)
|
||||
|
|
|
@ -100,5 +100,67 @@ TEST_F(TestHWBnGradSplit, test_bn_grad_split_tbe) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_grad_split", "after2");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
// Runs the SyncBnGradSplit pass on the "before" pattern of test_sync_bn_grad_split and
// checks that the optimized kernel graph equals the "after" pattern.
//
// Every check whose subject is dereferenced by a following statement is a fatal ASSERT_*:
// a non-fatal EXPECT_* would let the test continue into a null dereference.
TEST_F(TestHWBnGradSplit, test_sync_bn_grad_split_tbe) {
  get_py_fun_.SetDoResolve(true);
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "before");
  ASSERT_TRUE(g != nullptr);
  std::vector<int64_t> shp_x{1, 64, 112, 112};
  std::vector<int64_t> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, b_abstract, b_abstract, b_abstract};
  auto kernel_graph = GetKernelGraph(g, args_spec_list);
  ASSERT_NE(kernel_graph, nullptr);

  // get SyncBNGrad: walk return -> make_tuple -> make_tuple -> tuple_getitem -> SyncBNGrad
  CNodePtr ret = kernel_graph->get_return();
  ASSERT_NE(ret, nullptr);
  ASSERT_NE(ret->input(1), nullptr);
  ASSERT_TRUE(ret->input(1)->isa<CNode>());
  auto make_tuple1 = ret->input(1)->cast<CNodePtr>();
  ASSERT_NE(make_tuple1->input(1), nullptr);
  ASSERT_TRUE(make_tuple1->input(1)->isa<CNode>());
  auto make_tuple2 = make_tuple1->input(1)->cast<CNodePtr>();
  ASSERT_NE(make_tuple2->input(1), nullptr);
  ASSERT_TRUE(make_tuple2->input(1)->isa<CNode>());
  auto tuple_getitem = make_tuple2->input(1)->cast<CNodePtr>();
  ASSERT_NE(tuple_getitem->input(1), nullptr);
  ASSERT_TRUE(tuple_getitem->input(1)->isa<CNode>());
  auto bn_grad = tuple_getitem->input(1)->cast<CNodePtr>();

  // get param1
  ASSERT_NE(bn_grad->input(1), nullptr);
  auto param1 = bn_grad->input(1);

  // set kernel for param1
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder2;
  builder2.SetOutputsFormat({kOpFormat_NC1HWC0});
  builder2.SetOutputsDeviceType({kNumberTypeFloat32});
  AnfAlgo::SetSelectKernelBuildInfo(builder2.Build(), param1.get());

  // set kernel for SyncBNGrad
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
  builder1.SetInputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder1.SetOutputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder1.SetInputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder1.SetOutputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder1.SetKernelType(TBE_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), bn_grad.get());

  // do sync_bn_grad_split pass
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::SyncBnGradSplit>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(kernel_graph);

  // compare against the expected pattern
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_grad_split", "after");
  ASSERT_TRUE(g_after != nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -86,7 +86,7 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) {
|
|||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get());
|
||||
|
||||
// do bn_grad_split_pass
|
||||
// do bn_split_pass
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto pass = std::make_shared<opt::BnSplit>();
|
||||
|
@ -97,5 +97,54 @@ TEST_F(TestHWBnSplit, test_bn_split_tbe) {
|
|||
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_bn_split_tbe", "after");
|
||||
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||
}
|
||||
|
||||
// Runs the SyncBnSplit pass on the "before" pattern of test_sync_bn_split_tbe and checks
// that the optimized kernel graph equals the "after" pattern.
//
// Every check whose subject is dereferenced by a following statement is a fatal ASSERT_*:
// a non-fatal EXPECT_* would let the test continue into a null dereference.
TEST_F(TestHWBnSplit, test_sync_bn_split_tbe) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "before");
  ASSERT_TRUE(g != nullptr);
  std::vector<int64_t> shp_x{1, 64, 112, 112};
  std::vector<int64_t> shp_b{64};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  auto b_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_b);
  AbstractBasePtrList args_spec_list{x_abstract, b_abstract, b_abstract, b_abstract, b_abstract};
  auto kernel_graph = GetKernelGraph(g, args_spec_list);
  // The graph is dereferenced on the next statement, so it must be checked first.
  ASSERT_NE(kernel_graph, nullptr);

  // get kernel: walk return -> make_tuple -> tuple_getitem -> SyncBN
  auto ret = kernel_graph->get_return();
  ASSERT_NE(ret, nullptr);
  ASSERT_TRUE(ret->inputs().size() == 2);
  auto make_tuple = ret->input(1)->cast<CNodePtr>();
  ASSERT_NE(make_tuple, nullptr);
  ASSERT_TRUE(make_tuple->inputs().size() == 2);
  auto item0 = make_tuple->input(1)->cast<CNodePtr>();
  ASSERT_NE(item0, nullptr);
  ASSERT_TRUE(item0->inputs().size() == 3);
  auto bn = item0->input(1);
  ASSERT_NE(bn, nullptr);
  EXPECT_TRUE(bn->isa<CNode>());

  // set kernel for SyncBN
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder.SetOutputsFormat(
    {kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
  builder.SetInputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder.SetOutputsDeviceType(
    {kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32});
  builder.SetKernelType(KernelType::TBE_KERNEL);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), bn.get());

  // do sync_bn_split_pass
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pass = std::make_shared<opt::SyncBnSplit>();
  pm->AddPass(pass);
  optimizer->AddPassManager(pm);
  auto new_graph = optimizer->Optimize(kernel_graph);

  // compare against the expected pattern
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_sync_bn_split_tbe", "after");
  ASSERT_TRUE(g_after != nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,15 +16,21 @@
|
|||
from mindspore.ops import Primitive
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import _constants as Constants
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||
bn_grad = G.BatchNormGrad(is_training=True)
|
||||
sync_bn_grad = G.SyncBatchNormGrad()
|
||||
bn_grad1 = Primitive('BNGrad1')
|
||||
bn_grad2 = Primitive('BNGrad2')
|
||||
bn_grad3 = Primitive('BNGrad3')
|
||||
bn_training_update_grad = Primitive('BNTrainingUpdateGrad')
|
||||
bn_training_reduce_grad = Primitive('BNTrainingReduceGrad')
|
||||
allreduce = Primitive('AllReduce')
|
||||
mul = Primitive('Mul')
|
||||
mul_value = Tensor(0.5, mstype.float32)
|
||||
|
||||
|
||||
class FnDict:
|
||||
|
@ -85,3 +91,36 @@ def test_bn_grad_split(tag):
|
|||
return make_tuple(output)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_sync_bn_grad_split(tag):
    """Return the graph pattern of the SyncBatchNormGrad split test selected by `tag`.

    Two patterns are registered under their function names: `before` (a single
    `sync_bn_grad` call) and `after` (the `BNTrainingUpdateGrad` / `AllReduce` /
    `Mul` / `BNTrainingReduceGrad` form it is compared against).
    """
    patterns = FnDict()

    @patterns
    def before(i0, i1, i2, i3, i4):
        grad_outputs = sync_bn_grad(i0, i1, i2, i3, i4)
        first = tuple_getitem(grad_outputs, 0)
        second = tuple_getitem(grad_outputs, 1)
        third = tuple_getitem(grad_outputs, 2)
        return make_tuple(first, second, third)

    @patterns
    def after(i0, i1, i2, i3, i4):
        update_grad = bn_training_update_grad(i0, i1, i3, i4)
        # Each of the two update-grad outputs goes through AllReduce, then Mul by mul_value.
        reduced0 = allreduce(tuple_getitem(update_grad, 0))
        reduced1 = allreduce(tuple_getitem(update_grad, 1))
        scaled0 = mul(reduced0, mul_value)
        scaled1 = mul(reduced1, mul_value)
        reduce_grad = bn_training_reduce_grad(i0, i1, scaled0, scaled1, i2, i3, i4)
        split_outputs = make_tuple(reduce_grad, scaled0, scaled1)
        first = tuple_getitem(split_outputs, 0)
        second = tuple_getitem(split_outputs, 1)
        third = tuple_getitem(split_outputs, 2)
        regrouped = make_tuple(first, second, third)
        return make_tuple(regrouped)

    return patterns[tag]
|
||||
|
|
|
@ -15,16 +15,23 @@
|
|||
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.ops import _constants as Constants
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive(Constants.kTupleGetItem)
|
||||
bn = P.BatchNorm(is_training=True)
|
||||
sync_bn = inner.SyncBatchNorm()
|
||||
fused_bn1 = Primitive('FusedBN1')
|
||||
fused_bn2 = Primitive('FusedBN2')
|
||||
fused_bn3 = Primitive('FusedBN3')
|
||||
bn_training_reduce = Primitive('BNTrainingReduce')
|
||||
bn_training_update = Primitive('BNTrainingUpdate')
|
||||
allreduce = Primitive('AllReduce')
|
||||
mul = Primitive('Mul')
|
||||
mul_value = Tensor(0.5, mstype.float32)
|
||||
|
||||
|
||||
class FnDict:
|
||||
|
@ -89,3 +96,30 @@ def test_bn_split_tbe(tag):
|
|||
return make_tuple(output)
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_sync_bn_split_tbe(tag):
    """Return the graph pattern of the SyncBatchNorm split test selected by `tag`.

    Two patterns are registered under their function names: `before` (a single
    `sync_bn` call) and `after` (the `BNTrainingReduce` / `AllReduce` / `Mul` /
    `BNTrainingUpdate` form it is compared against).
    """
    patterns = FnDict()

    @patterns
    def before(x, scale, b, mean, variance):
        sync_bn_outputs = sync_bn(x, scale, b, mean, variance)
        return tuple_getitem(sync_bn_outputs, 0)

    @patterns
    def after(x, scale, b, mean, variance):
        reduce_outputs = bn_training_reduce(x)
        # Each of the two reduce outputs goes through AllReduce, then Mul by mul_value.
        reduced0 = allreduce(tuple_getitem(reduce_outputs, 0))
        reduced1 = allreduce(tuple_getitem(reduce_outputs, 1))
        scaled0 = mul(reduced0, mul_value)
        scaled1 = mul(reduced1, mul_value)
        update_outputs = bn_training_update(x, scaled0, scaled1, scale, b, mean, variance)
        first = tuple_getitem(update_outputs, 0)
        return make_tuple(first)

    return patterns[tag]
|
||||
|
|
|
@ -1755,6 +1755,16 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
|
||||
'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
|
||||
'skip': ['backward']}),
|
||||
('SyncBatchNorm', {
|
||||
'block': inner.SyncBatchNorm(),
|
||||
'desc_inputs': [[128, 64, 32, 32], [64], [64], [64], [64]],
|
||||
'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
|
||||
'skip': []}),
|
||||
('SyncBatchNormGrad', {
|
||||
'block': G.SyncBatchNormGrad(),
|
||||
'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
|
||||
'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
|
||||
'skip': ['backward']}),
|
||||
('TopK', {
|
||||
'block': P.TopK(),
|
||||
'desc_const': [5],
|
||||
|
|
Loading…
Reference in New Issue