!9488 Fix DropoutGenMask output shape err of Transformer in Pynative mode

From: @zuochuanyong
Reviewed-by: @chujinjin,@kisnwang
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2020-12-04 14:23:36 +08:00 committed by Gitee
commit 89535c4612
1 changed file with 13 additions and 0 deletions


@@ -280,6 +280,19 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
const auto &attr_map = op_prim->evaluate_added_attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
                    [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
// Add output information (shape, type id) of the operator to graph_info to avoid cache misses
// caused by operators like DropoutGenMask, whose output depends on the input values: the input
// shapes can be identical while the values, and therefore the output shape, differ
auto abstr = op_exec_info->abstract;
MS_EXCEPTION_IF_NULL(abstr);
auto build_shape = abstr->BuildShape();
MS_EXCEPTION_IF_NULL(build_shape);
(void)graph_info.append(build_shape->ToString() + "_");
auto build_type = abstr->BuildType();
MS_EXCEPTION_IF_NULL(build_type);
(void)graph_info.append(std::to_string(build_type->type_id()) + "_");
return graph_info;
}
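
For context, below is a minimal standalone sketch of the caching idea behind this fix. It is not the MindSpore implementation; the names (OpRecord, BuildGraphInfo) and the example shapes are hypothetical. It only illustrates why appending the operator's output shape and type id to the single-op graph key keeps two DropoutGenMask-style calls with identical input shapes but different output shapes from colliding in the cache.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical record of one executed op: its input shapes plus the
// inferred (abstract) output shape and type id for this particular call.
struct OpRecord {
  std::string op_name;
  std::vector<std::vector<int>> input_shapes;
  std::string out_shape;  // for DropoutGenMask this depends on input values, not just input shapes
  int out_type_id = 0;
};

// Build a cache key the same way the patch does: inputs/attrs first,
// then the output shape and type id appended at the end.
std::string BuildGraphInfo(const OpRecord &op) {
  std::string info = op.op_name + "_";
  for (const auto &shape : op.input_shapes) {
    info += "[";
    for (int d : shape) info += std::to_string(d) + ",";
    info += "]_";
  }
  info += op.out_shape + "_";
  info += std::to_string(op.out_type_id) + "_";
  return info;
}

int main() {
  std::unordered_map<std::string, int> graph_cache;  // key -> compiled graph id

  // Two DropoutGenMask-like calls: identical input shapes, but the mask
  // length (output shape) differs because the input *values* differ.
  OpRecord a{"DropoutGenMask", {{3}}, "(1536,)", 2};
  OpRecord b{"DropoutGenMask", {{3}}, "(2048,)", 2};

  graph_cache.emplace(BuildGraphInfo(a), 1);
  graph_cache.emplace(BuildGraphInfo(b), 2);

  // Without the output shape/type in the key, both calls would map to the
  // same entry and the second would wrongly reuse the first compiled graph.
  std::cout << "cached graphs: " << graph_cache.size() << "\n";  // prints 2
  return 0;
}

The design choice in the patch is the same: the output shape and type id are cheap to obtain from the already-built abstract, and including them in graph_info makes the cache key sensitive to value-dependent outputs without special-casing DropoutGenMask.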