fix gpu cast fusion bug

This commit is contained in:
VectorSL 2020-07-18 17:20:17 +08:00
parent e984f3ecce
commit 80ed8e0e5c
2 changed files with 60 additions and 50 deletions

View File

@ -30,8 +30,7 @@ const BaseRef ReplaceBNCastFusion::DefinePattern() const {
VectorRef in_cast = VectorRef({prim::kPrimCast, x_});
VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_});
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_});
VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget});
return out_cast;
return tupleget;
}
const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
@ -40,19 +39,9 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto tuple = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 1);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 0);
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
if (item_idx != 0) {
return nullptr;
}
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 1);
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 3);
@ -65,14 +54,32 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL(bias);
MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
manager->Replace(utils::cast<CNodePtr>(node), utils::cast<CNodePtr>(tuple));
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto outlist = GetRealNodeUsedList(graph, fbn2);
for (size_t i = 0; i < outlist->size(); i++) {
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
return nullptr;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
}
}
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
outputs_type.clear();
outputs_shape.clear();
auto output_num = AnfAlgo::GetOutputTensorNum(fbn2);
for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i));
@ -80,13 +87,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
}
outputs_type[0] = kNumberTypeFloat16;
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get());
outputs_type.clear();
outputs_shape.clear();
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get());
return tuple;
return node;
}
} // namespace opt
} // namespace mindspore

View File

@ -30,8 +30,7 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_});
VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_});
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_});
VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget});
return out_cast;
return tupleget;
}
const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
@ -40,21 +39,16 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
auto tuple = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 1);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx != 0) {
return nullptr;
}
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 0);
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0);
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0);
// if x_type is fp32, the cast is necessary.
if (x_type == kNumberTypeFloat32) {
return nullptr;
}
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 2);
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 3);
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 4);
@ -66,13 +60,32 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(x_);
MS_EXCEPTION_IF_NULL(mean);
MS_EXCEPTION_IF_NULL(var);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
manager->Replace(utils::cast<CNodePtr>(node), utils::cast<CNodePtr>(tuple));
std::vector<TypeId> outputs_type;
std::vector<std::vector<size_t>> outputs_shape;
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto outlist = GetRealNodeUsedList(graph, fbn2g);
for (size_t i = 0; i < outlist->size(); i++) {
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int item_idx = GetValue<int>(value_node->value());
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
return nullptr;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
}
}
outputs_type.clear();
outputs_shape.clear();
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g);
for (size_t i = 0; i < output_num; i++) {
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i));
@ -80,12 +93,8 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
}
outputs_type[0] = kNumberTypeFloat16;
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get());
outputs_type.clear();
outputs_shape.clear();
outputs_type.push_back(kNumberTypeFloat16);
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0));
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get());
return tuple;
return node;
}
} // namespace opt
} // namespace mindspore