!16547 Change keyvalue for prim_abs_list_, avoid mismatch

From: @zhangzhaoju
Reviewed-by: @zh_qh,@ginfung,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-05-24 08:53:11 +08:00 committed by Gitee
commit a76342f2cd
4 changed files with 11 additions and 11 deletions

View File

@ -36,14 +36,6 @@
namespace mindspore {
namespace ad {
struct PrimitiveTotalEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return *t1 == *t2;
}
};
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim;
extern KPrim g_k_prims;

View File

@ -878,7 +878,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
auto prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(prim);
auto temp = prim_abs_list_.find(prim->id());
auto temp = prim_abs_list_.find(prim);
if (temp != prim_abs_list_.end()) {
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
auto iter = temp->second.find(args_spec_list);
@ -924,7 +924,7 @@ py::object ForwardExecutor::GetOpOutputObject(const OpExecInfoPtr &op_exec_info,
// Add output abstract info into cache, the const value needs to infer evert step
if (!out_abstract_existed && !op_exec_info->is_dynamic_shape) {
auto &out = prim_abs_list_[prim->id()];
auto &out = prim_abs_list_[prim];
out[args_spec_list].abs = op_exec_info->abstract;
out[args_spec_list].attrs = prim->evaluate_added_attrs();
}

View File

@ -331,7 +331,7 @@ class ForwardExecutor {
private:
GradExecutorWeakPtr grad_executor_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
std::unordered_map<PrimitivePtr, AbstractListMap, PrimitiveHasher, PrimitiveTotalEqual> prim_abs_list_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
// Used to cache cast struct
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;

View File

@ -156,5 +156,13 @@ struct PrimitiveHasher {
return prim->Hash();
}
};
struct PrimitiveTotalEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return *t1 == *t2;
}
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_PRIMITIVE_H_