Add BatchNorm fusion pattern with mixed precision
parent 94883f9b9c
commit 6e89ebe6f0
@@ -201,6 +201,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   } else {
     ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
     ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
   if (context_ptr->ir_fusion_flag()) {
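Note that the mixed-precision pass is registered immediately after the plain FusedBatchNormFusion, so BatchNorm subgraphs without extra Cast nodes are still claimed by the original pattern, and the Cast-aware pattern only fires on subgraphs the first one cannot match. A rough sketch of that ordering behavior (plain Python with illustrative names, not the MindSpore pass API):

# Illustrative only: passes are tried in registration order per subgraph.
passes = [
    ("fused_batch_norm_fusion", "BatchNorm>Sub>Mul>AssignSub"),
    ("fused_batch_norm_mix_precision_fusion", "BatchNorm>Cast>Sub>Mul>Cast>AssignSub"),
]

def run_passes(subgraphs):
    fused = []
    for graph in subgraphs:
        for _, pattern in passes:
            if graph == pattern:            # stand-in for real pattern matching
                graph = "BNTrainingUpdate"  # stand-in for the fused replacement
                break
        fused.append(graph)
    return fused

print(run_passes(["BatchNorm>Sub>Mul>AssignSub",
                  "BatchNorm>Cast>Sub>Mul>Cast>AssignSub"]))
# -> ['BNTrainingUpdate', 'BNTrainingUpdate']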
@@ -277,5 +277,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
   }
   return bn_training_update_outputs[0];
 }
+
+const BaseRef FusedBatchNormMixPrecisionFusion::DefinePattern() const {
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  VarPtr index0 = std::make_shared<CondVar>(IsC);
+  VarPtr index1 = std::make_shared<CondVar>(IsC);
+  VarPtr index2 = std::make_shared<CondVar>(IsC);
+  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
+  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
+  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
+  VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2});
+  VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_});
+  VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_});
+  VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1});
+  VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_});
+  VectorRef cast2 = VectorRef({prim::kPrimCast, mul0});
+  VectorRef cast3 = VectorRef({prim::kPrimCast, mul1});
+  VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2});
+  VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3});
+  VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0});
+  return VectorRef({prim::kPrimDepend, depend0, assign_sub1});
+}
 }  // namespace opt
 }  // namespace mindspore
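For reference, the subgraph this pattern encodes is the usual moving-average update of BatchNorm's running statistics, except that the running variables live in a different precision than the batch statistics — hence the Cast before each Sub and the Cast back before each AssignSub. A minimal numeric sketch of the same update (NumPy, illustrative names, fp16 variables assumed):

import numpy as np

def moving_average_update(running_stat, batch_stat, momentum=0.1):
    # Cast(var) + Sub: compute the difference in float32
    diff = running_stat.astype(np.float32) - batch_stat
    # Mul by the momentum constant (constant_input*_var_ in the pattern)
    delta = diff * momentum
    # cast2/cast3: back to the variable's own dtype, then AssignSub
    return running_stat - delta.astype(running_stat.dtype)

running_mean = np.zeros(64, dtype=np.float16)        # variable kept in fp16
batch_mean = np.random.rand(64).astype(np.float32)   # float32 batch statistic
running_mean = moving_average_update(running_mean, batch_mean)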
@@ -18,6 +18,7 @@
 #include <vector>
 #include <memory>
+#include <string>
 #include "pre_activate/common/optimizer.h"
 #include "utils/utils.h"

@@ -25,8 +26,8 @@ namespace mindspore {
 namespace opt {
 class FusedBatchNormFusion : public PatternProcessPass {
  public:
-  explicit FusedBatchNormFusion(bool multigraph = true)
-      : PatternProcessPass("fused_batch_norm_fusion", multigraph),
+  explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph),
         data_input0_var_(std::make_shared<Var>()),
         data_input1_var_(std::make_shared<Var>()),
         data_input2_var_(std::make_shared<Var>()),
@@ -39,7 +40,7 @@ class FusedBatchNormFusion : public PatternProcessPass {
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

- private:
+ protected:
   AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const;
   void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
@@ -59,6 +60,15 @@ class FusedBatchNormFusion : public PatternProcessPass {
   VarPtr constant_input1_var_;
   VarPtr batch_norm_var_;
 };
+
+class FusedBatchNormMixPrecisionFusion : public FusedBatchNormFusion {
+ public:
+  explicit FusedBatchNormMixPrecisionFusion(bool multigraph = true)
+      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
+
+  ~FusedBatchNormMixPrecisionFusion() override = default;
+  const BaseRef DefinePattern() const override;
+};
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_
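The constructor refactor above (a name parameter defaulting to the old pass name, and private: loosened to protected:) is what lets the new subclass inherit Process and the pattern Var members wholesale while overriding only DefinePattern. A rough Python analogy of the resulting class shape (hypothetical sketch mirroring the C++ names, not actual MindSpore code):

class FusedBatchNormFusion:
    # Base pass: owns the pattern variables and the replacement logic.
    def __init__(self, name="fused_batch_norm_fusion"):
        self.name = name

    def define_pattern(self):
        return "BatchNorm -> Sub -> Mul -> AssignSub"

    def process(self, graph, node):
        ...  # shared rewrite: emit BNTrainingReduce/BNTrainingUpdate

class FusedBatchNormMixPrecisionFusion(FusedBatchNormFusion):
    # Subclass: same replacement logic, Cast-aware pattern.
    def __init__(self):
        super().__init__("fused_batch_norm_mix_precision_fusion")

    def define_pattern(self):
        return "BatchNorm -> Cast -> Sub -> Mul -> Cast -> AssignSub"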
@@ -50,5 +50,28 @@ TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_fusion) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
+
+TEST_F(TestHWFusedBatchNormFusion, test_fused_batch_norm_mix_precision_fusion) {
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "before_mix_precision");
+  EXPECT_NE(g, nullptr);
+  std::vector<int> shp_x{32, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  std::vector<int> shp_y{64};
+  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
+  AbstractBasePtrList args_spec_list{x_abstract};
+  for (size_t i = 0; i < 6; ++i) {
+    args_spec_list.push_back(y_abstract);
+  }
+  auto kg = GetKernelGraph(g, args_spec_list);
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::FusedBatchNormMixPrecisionFusion>());
+  optimizer->AddPassManager(pm);
+  FuncGraphPtr new_graph = optimizer->Optimize(kg);
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_fused_batch_norm_fusion", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
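A note on the shapes in the new test: the before_mix_precision graph below takes seven parameters, so args_spec_list pairs one [32, 64, 112, 112] float32 abstract for the data input with six [64] float32 abstracts (pushed by the loop) for the per-channel inputs and the two moving-statistic variables.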
@@ -24,6 +24,7 @@ make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
 depend = Primitive('depend')
 BatchNorm = P.BatchNorm()
+Cast = P.Cast()
 BNTrainingReduce = Primitive('BNTrainingReduce')
 BNTrainingUpdate = Primitive('BNTrainingUpdate')
 constant0 = Tensor(0.1, mstype.float32)
@@ -59,6 +60,21 @@ def test_fused_batch_norm_fusion(tag):
         output = tuple_getitem(outputs, 0)
         return output

+    @fns
+    def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
+        batch_norm = BatchNorm(input0, input1, input2, input3, input4)
+        sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
+        sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
+        mul0 = Mul(sub0, constant0)
+        mul1 = Mul(sub1, constant1)
+        assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32))
+        assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32))
+        depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
+        depend1 = depend(depend0, assign_sub1)
+        outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
+        output = tuple_getitem(outputs, 0)
+        return output
+
     @fns
     def after(input0, input1, input2, input3, input4, var0, var1):
         bn_training_reduce = BNTrainingReduce(input0)