forked from mindspore-Ecosystem/mindspore
fix gpu cast fusion bug
This commit is contained in:
parent
e984f3ecce
commit
80ed8e0e5c
|
@ -30,8 +30,7 @@ const BaseRef ReplaceBNCastFusion::DefinePattern() const {
|
|||
VectorRef in_cast = VectorRef({prim::kPrimCast, x_});
|
||||
VectorRef fbn2 = VectorRef({prim::kPrimFusedBatchNorm, in_cast, scale_, bias_, mean_, var_});
|
||||
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2, index_});
|
||||
VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget});
|
||||
return out_cast;
|
||||
return tupleget;
|
||||
}
|
||||
|
||||
const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
|
@ -40,19 +39,9 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
auto tuple = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 1);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int item_idx = GetValue<int>(value_node->value());
|
||||
|
||||
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 0);
|
||||
auto fbn2 = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
auto x_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 0);
|
||||
auto x_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(x_after), 0);
|
||||
if (item_idx != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 1);
|
||||
auto bias = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2), 3);
|
||||
|
@ -65,14 +54,32 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
|
|||
MS_EXCEPTION_IF_NULL(bias);
|
||||
MS_EXCEPTION_IF_NULL(mean);
|
||||
MS_EXCEPTION_IF_NULL(var);
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
|
||||
manager->Replace(utils::cast<CNodePtr>(node), utils::cast<CNodePtr>(tuple));
|
||||
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
auto outlist = GetRealNodeUsedList(graph, fbn2);
|
||||
for (size_t i = 0; i < outlist->size(); i++) {
|
||||
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int item_idx = GetValue<int>(value_node->value());
|
||||
if (item_idx == 0) {
|
||||
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
|
||||
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
|
||||
return nullptr;
|
||||
}
|
||||
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
|
||||
outputs_type.push_back(kNumberTypeFloat16);
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
|
||||
}
|
||||
}
|
||||
|
||||
manager->Replace(utils::cast<CNodePtr>(x_after), utils::cast<CNodePtr>(x_before));
|
||||
outputs_type.clear();
|
||||
outputs_shape.clear();
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(fbn2);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2, i));
|
||||
|
@ -80,13 +87,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
|
|||
}
|
||||
outputs_type[0] = kNumberTypeFloat16;
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2.get());
|
||||
|
||||
outputs_type.clear();
|
||||
outputs_shape.clear();
|
||||
outputs_type.push_back(kNumberTypeFloat16);
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0));
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get());
|
||||
return tuple;
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,8 +30,7 @@ const BaseRef ReplaceBNGradCastFusion::DefinePattern() const {
|
|||
VectorRef dy_cast = VectorRef({prim::kPrimCast, dy_});
|
||||
VectorRef fbn2g = VectorRef({prim::kPrimFusedBatchNormGrad, dy_cast, x_, scale_, mean_, var_});
|
||||
VectorRef tupleget = VectorRef({prim::kPrimTupleGetItem, fbn2g, index_});
|
||||
VectorRef out_cast = VectorRef({prim::kPrimCast, tupleget});
|
||||
return out_cast;
|
||||
return tupleget;
|
||||
}
|
||||
|
||||
const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
|
@ -40,21 +39,16 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
auto tuple = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 1);
|
||||
MS_EXCEPTION_IF_NULL(index_node);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int item_idx = GetValue<int>(value_node->value());
|
||||
if (item_idx != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple), 0);
|
||||
auto fbn2g = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
|
||||
|
||||
auto dy_after = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 0);
|
||||
auto dy_before = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(dy_after), 0);
|
||||
auto x_ = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 1);
|
||||
|
||||
auto x_type = AnfAlgo::GetOutputInferDataType(x_, 0);
|
||||
// if x_type is fp32, the cast is necessary.
|
||||
if (x_type == kNumberTypeFloat32) {
|
||||
return nullptr;
|
||||
}
|
||||
auto scale = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 2);
|
||||
auto mean = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 3);
|
||||
auto var = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(fbn2g), 4);
|
||||
|
@ -66,13 +60,32 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
|
|||
MS_EXCEPTION_IF_NULL(x_);
|
||||
MS_EXCEPTION_IF_NULL(mean);
|
||||
MS_EXCEPTION_IF_NULL(var);
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
|
||||
manager->Replace(utils::cast<CNodePtr>(node), utils::cast<CNodePtr>(tuple));
|
||||
std::vector<TypeId> outputs_type;
|
||||
std::vector<std::vector<size_t>> outputs_shape;
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
auto outlist = GetRealNodeUsedList(graph, fbn2g);
|
||||
for (size_t i = 0; i < outlist->size(); i++) {
|
||||
auto index_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(outlist->at(i).first), 1);
|
||||
auto value_node = index_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
int item_idx = GetValue<int>(value_node->value());
|
||||
if (item_idx == 0) {
|
||||
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
|
||||
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
|
||||
return nullptr;
|
||||
}
|
||||
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
|
||||
outputs_type.push_back(kNumberTypeFloat16);
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(outlist->at(i).first, 0));
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, outlist->at(i).first.get());
|
||||
}
|
||||
}
|
||||
outputs_type.clear();
|
||||
outputs_shape.clear();
|
||||
manager->Replace(utils::cast<CNodePtr>(dy_after), utils::cast<CNodePtr>(dy_before));
|
||||
|
||||
auto output_num = AnfAlgo::GetOutputTensorNum(fbn2g);
|
||||
for (size_t i = 0; i < output_num; i++) {
|
||||
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(fbn2g, i));
|
||||
|
@ -80,12 +93,8 @@ const AnfNodePtr ReplaceBNGradCastFusion::Process(const FuncGraphPtr &graph, con
|
|||
}
|
||||
outputs_type[0] = kNumberTypeFloat16;
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fbn2g.get());
|
||||
outputs_type.clear();
|
||||
outputs_shape.clear();
|
||||
outputs_type.push_back(kNumberTypeFloat16);
|
||||
outputs_shape.push_back(AnfAlgo::GetOutputInferShape(tuple, 0));
|
||||
AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, tuple.get());
|
||||
return tuple;
|
||||
|
||||
return node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue