forked from mindspore-Ecosystem/mindspore
!16547 Change keyvalue for prim_abs_list_, avoid mismatch
From: @zhangzhaoju Reviewed-by: @zh_qh,@ginfung,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
a76342f2cd
|
@ -36,14 +36,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
struct PrimitiveTotalEqual {
|
||||
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
MS_EXCEPTION_IF_NULL(t2);
|
||||
return *t1 == *t2;
|
||||
}
|
||||
};
|
||||
|
||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
class KPrim;
|
||||
extern KPrim g_k_prims;
|
||||
|
|
|
@ -878,7 +878,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
|
|||
auto prim = op_exec_info->py_primitive;
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
||||
auto temp = prim_abs_list_.find(prim->id());
|
||||
auto temp = prim_abs_list_.find(prim);
|
||||
if (temp != prim_abs_list_.end()) {
|
||||
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
|
||||
auto iter = temp->second.find(args_spec_list);
|
||||
|
@ -924,7 +924,7 @@ py::object ForwardExecutor::GetOpOutputObject(const OpExecInfoPtr &op_exec_info,
|
|||
|
||||
// Add output abstract info into cache, the const value needs to infer evert step
|
||||
if (!out_abstract_existed && !op_exec_info->is_dynamic_shape) {
|
||||
auto &out = prim_abs_list_[prim->id()];
|
||||
auto &out = prim_abs_list_[prim];
|
||||
out[args_spec_list].abs = op_exec_info->abstract;
|
||||
out[args_spec_list].attrs = prim->evaluate_added_attrs();
|
||||
}
|
||||
|
|
|
@ -331,7 +331,7 @@ class ForwardExecutor {
|
|||
|
||||
private:
|
||||
GradExecutorWeakPtr grad_executor_;
|
||||
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
|
||||
std::unordered_map<PrimitivePtr, AbstractListMap, PrimitiveHasher, PrimitiveTotalEqual> prim_abs_list_;
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
|
||||
// Used to cache cast struct
|
||||
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
|
||||
|
|
|
@ -156,5 +156,13 @@ struct PrimitiveHasher {
|
|||
return prim->Hash();
|
||||
}
|
||||
};
|
||||
|
||||
struct PrimitiveTotalEqual {
|
||||
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||
MS_EXCEPTION_IF_NULL(t1);
|
||||
MS_EXCEPTION_IF_NULL(t2);
|
||||
return *t1 == *t2;
|
||||
}
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_PRIMITIVE_H_
|
||||
|
|
Loading…
Reference in New Issue