!1383 keep different attributes for cnode evaluation

Merge pull request !1383 from amongo/KeepPrimAttrInCNode
This commit is contained in:
mindspore-ci-bot 2020-05-27 13:02:34 +08:00 committed by Gitee
commit 5b9c145ff8
24 changed files with 800 additions and 324 deletions

View File

@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
auto ctx = node_cfg_->context(); auto ctx = node_cfg_->context();
auto engine = node_cfg_->engine(); auto engine = node_cfg_->engine();
auto cfg = engine->MakeConfig(node, ctx); auto cfg = engine->MakeConfig(node, ctx);
auto abs = engine->cache().GetValue(cfg); auto eval_result = engine->cache().GetValue(cfg);
if (abs == nullptr) { if (eval_result == nullptr || eval_result->abstract() == nullptr) {
return "Undefined"; return "Undefined";
} }
auto abs = eval_result->abstract();
auto dtype = abs->BuildType(); auto dtype = abs->BuildType();
auto shape = abs->BuildShape(); auto shape = abs->BuildShape();
std::ostringstream oss; std::ostringstream oss;

View File

@ -42,7 +42,11 @@ enum PrimType {
class Primitive : public Named { class Primitive : public Named {
public: public:
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {} : Named(name),
is_base_(is_base),
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false) {}
Primitive(const Primitive &prim) Primitive(const Primitive &prim)
: Named(prim), : Named(prim),
@ -50,14 +54,23 @@ class Primitive : public Named {
instance_name_(prim.instance_name_), instance_name_(prim.instance_name_),
is_base_(prim.is_base_), is_base_(prim.is_base_),
has_signature_(prim.has_signature_), has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_) {} prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false) {}
MS_DECLARE_PARENT(Primitive, Named); MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); } std::string ToString() const override { return name(); }
void BeginRecordAddAttr() {
evaluate_added_attrs_.clear();
record_evaluate_add_attr_ = true;
}
void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr; attrs_[name] = attr;
if (record_evaluate_add_attr_) {
evaluate_added_attrs_[name] = attr;
}
return *this; return *this;
} }
@ -80,6 +93,7 @@ class Primitive : public Named {
py::function hook() const { return hook_; } py::function hook() const { return hook_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); } bool HasAttr() const { return !attrs_.empty(); }
@ -106,6 +120,7 @@ class Primitive : public Named {
protected: protected:
std::unordered_map<std::string, ValuePtr> attrs_; std::unordered_map<std::string, ValuePtr> attrs_;
std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
private: private:
std::string instance_name_; std::string instance_name_;
@ -113,6 +128,7 @@ class Primitive : public Named {
bool is_base_; bool is_base_;
bool has_signature_; bool has_signature_;
PrimType prim_type_; PrimType prim_type_;
bool record_evaluate_add_attr_;
}; };
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {

View File

@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv
} }
subargs.push_back(AbstractJoin(l_ptr->elements())); subargs.push_back(AbstractJoin(l_ptr->elements()));
} }
AbstractBasePtr engin_exc = engine->Execute(fn, subargs); EvalResultPtr engin_exc = engine->Execute(fn, subargs);
AbstractBasePtrList result; AbstractBasePtrList result;
for (std::size_t i = 1; i < args_spec_list.size(); i++) { for (std::size_t i = 1; i < args_spec_list.size(); i++) {
result.push_back(engin_exc); result.push_back(engin_exc->abstract());
} }
return std::make_shared<AbstractList>(result); return std::make_shared<AbstractList>(result);
} }
@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
AbstractBasePtr list_type = AbstractJoin(lst->elements()); AbstractBasePtr list_type = AbstractJoin(lst->elements());
auto result1 = engine->Execute(fn, lst->elements()); auto result1 = engine->Execute(fn, lst->elements());
auto result2 = engine->Execute(fn, {dflt, list_type}); auto result2 = engine->Execute(fn, {dflt, list_type});
MS_EXCEPTION_IF_NULL(result1); MS_EXCEPTION_IF_NULL(result1->abstract());
return result1->Join(result2); MS_EXCEPTION_IF_NULL(result2->abstract());
return result1->abstract()->Join(result2->abstract());
} }
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
return sorted_nodes; return sorted_nodes;
} }
AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
std::size_t nargs = fg->parameters().size(); std::size_t nargs = fg->parameters().size();
@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const auto &arg = args_spec_list[i]; const auto &arg = args_spec_list[i];
const auto &node = parameters[i]; const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
engine->cache().set_value(conf, arg); engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
} }
const AnfNodePtr &func_node = fg->get_return(); const AnfNodePtr &func_node = fg->get_return();
@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const auto &node = *it; const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
ret_base = engine->GetEvaluatedValue(node_conf); ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
<< ", abstract: " << ret_base->ToString(); << ", abstract: " << ret_base->ToString();
} }
MS_EXCEPTION_IF_NULL(ret_base); MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString(); MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
return ret_base; return std::make_shared<EvalResult>(ret_base, nullptr);
} }
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
return cloned_func_graph; return cloned_func_graph;
} }
AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
AnfNodeConfigPtr out_conf) {
const std::string &evaluator_name = ToString(); const std::string &evaluator_name = ToString();
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
args_spec_list = NormalizeArgs(args_spec_list); args_spec_list = NormalizeArgs(args_spec_list);
args_spec_list = BroadenUndeterminedArgs(args_spec_list); args_spec_list = BroadenUndeterminedArgs(args_spec_list);
@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
auto iter = cache_->find(args_spec_list); auto iter = cache_->find(args_spec_list);
if (iter == cache_->end()) { if (iter == cache_->end()) {
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
AbstractBasePtr ret = Eval(engine, args_spec_list); EvalResultPtr ret = Eval(engine, args_spec_list);
if (ret == nullptr) { if (ret->abstract() == nullptr) {
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
} }
MS_EXCEPTION_IF_NULL(ret); MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
(*cache_)[args_spec_list] = ret; (*cache_)[args_spec_list] = ret;
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return ret; return ret;
} else { } else {
MS_EXCEPTION_IF_NULL(iter->second); MS_EXCEPTION_IF_NULL(iter->second);
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; MS_EXCEPTION_IF_NULL(iter->second->abstract());
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << ".";
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>()); trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return iter->second; return iter->second;
} }
} }
AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr) { AnfNodeConfigPtr) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
AbstractBasePtr ret = EvalPrim(engine, args_spec_list); EvalResultPtr ret = EvalPrim(engine, args_spec_list);
return ret; return ret;
} }
AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) { AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
if (args_conf_list.size() == 0) { if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Size should greater than 0"; MS_LOG(EXCEPTION) << "Size should greater than 0";
} }
AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
// No need to cache. // No need to cache.
return ret; return ret;
} }
AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
AbstractBasePtr ret = EvalPrim(args_conf_list); EvalResultPtr ret = EvalPrim(args_conf_list);
return ret; return ret;
} }
AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) { AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
// Don't lookup from cache, as different out_conf with same node but different context // Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive. // may add different entry to anfnode_config_map_, like getattr primitive.
(*cache_)[args_spec_list] = ret; (*cache_)[args_spec_list] = ret;
return ret; return ret;
} }
AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) { AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
MS_EXCEPTION_IF_NULL(cache_); MS_EXCEPTION_IF_NULL(cache_);
auto iter = cache_->find(args_spec_list); auto iter = cache_->find(args_spec_list);
@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
(*cache_)[args_spec_list] = ret; (*cache_)[args_spec_list] = ret;
return ret; return ret;
} }
AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
MS_EXCEPTION_IF_NULL(cache_); MS_EXCEPTION_IF_NULL(cache_);
auto iter = cache_->find(args_spec_list); auto iter = cache_->find(args_spec_list);
@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
} }
// Call the original evaluator, get the result: y = f(x) // Call the original evaluator, get the result: y = f(x)
AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr); EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList bparams; AbstractBasePtrList bparams;
@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
[](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams); AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final); AbstractFunctionPtr bprop =
std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
// J(f)(J(x)) return a tuple (y, bprop_f) // J(f)(J(x)) return a tuple (y, bprop_f)
AbstractBasePtrList jargs = {result, bprop}; AbstractBasePtrList jargs = {result->abstract(), bprop};
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs); AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
(*cache_)[args_spec_list] = jtuple; auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
return jtuple; (*cache_)[args_spec_list] = infer_reuslt;
return infer_reuslt;
} }
AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() != args_spec_list_.size()) { if (args_spec_list.size() != args_spec_list_.size()) {
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
<< ", arguments no: " << args_spec_list.size(); << ", arguments no: " << args_spec_list.size();
@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL
MS_EXCEPTION_IF_NULL(args_spec_list[i]); MS_EXCEPTION_IF_NULL(args_spec_list[i]);
(void)args_spec_list[i]->Join(args_spec_list_[i]); (void)args_spec_list[i]->Join(args_spec_list_[i]);
} }
return output_; return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
} }
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -29,21 +29,28 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
using EvaluatorCacheMap = using EvaluatorCacheMap =
std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>; std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>; using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
using EvaluatorAttrMap =
std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;
class Evaluator : public Base { class Evaluator : public Base {
public: public:
explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {} explicit Evaluator(const std::string &id)
: cache_(std::make_shared<EvaluatorCacheMap>()),
attr_cache_(std::make_shared<EvaluatorAttrMap>()),
identifier_(id) {}
~Evaluator() override = default; ~Evaluator() override = default;
MS_DECLARE_PARENT(Evaluator, Base); MS_DECLARE_PARENT(Evaluator, Base);
// difference between Run() and Eval(): // difference between Run() and Eval():
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
// Run() will modify cache_ member, so it cannot marked as const; // Run() will modify cache_ member, so it cannot marked as const;
virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
@ -58,9 +65,10 @@ class Evaluator : public Base {
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
EvaluatorCacheMapPtr &cache() { return cache_; } EvaluatorCacheMapPtr &cache() { return cache_; }
EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }
EvaluatorCacheMapPtr cache_; EvaluatorCacheMapPtr cache_;
EvaluatorAttrMapPtr attr_cache_;
std::string identifier_; std::string identifier_;
AnfNodeWeakPtr bound_node_; AnfNodeWeakPtr bound_node_;
@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator {
explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
~PrimEvaluator() override = default; ~PrimEvaluator() override = default;
MS_DECLARE_PARENT(PrimEvaluator, Evaluator); MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
}; };
@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~TrivialPrimEvaluator() override = default; ~TrivialPrimEvaluator() override = default;
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
}; };
class TransitionPrimEvaluator : public PrimEvaluator { class TransitionPrimEvaluator : public PrimEvaluator {
@ -90,9 +98,9 @@ class TransitionPrimEvaluator : public PrimEvaluator {
explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~TransitionPrimEvaluator() override = default; ~TransitionPrimEvaluator() override = default;
MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
// Parameter in_conf0 : the first element in args_conf_list; // Parameter in_conf0 : the first element in args_conf_list;
virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
}; };
@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~SymbolicPrimEvaluator() override = default; ~SymbolicPrimEvaluator() override = default;
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
}; };
// Evaluator will be stored in AnalysisEngine.constructors_ // Evaluator will be stored in AnalysisEngine.constructors_
@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator {
DummyEvaluator() : Evaluator("dummy") {} DummyEvaluator() : Evaluator("dummy") {}
~DummyEvaluator() override = default; ~DummyEvaluator() override = default;
MS_DECLARE_PARENT(DummyEvaluator, Evaluator); MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
}; };
// Wrap another evaluator to track a subset of uses. // Wrap another evaluator to track a subset of uses.
@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator {
bound_node_ = AnfNodeWeakPtr(node); bound_node_ = AnfNodeWeakPtr(node);
} }
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
AnfNodeConfigPtr out_conf) override;
std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
private: private:
@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
~BaseFuncGraphEvaluator() override = default; ~BaseFuncGraphEvaluator() override = default;
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator {
} }
bound_node_ = AnfNodeWeakPtr(node); bound_node_ = AnfNodeWeakPtr(node);
} }
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
} }
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
AnfNodeConfigPtr out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private: private:
@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator {
~VirtualEvaluator() override = default; ~VirtualEvaluator() override = default;
MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
std::string ToString() const override { return identifier_; } std::string ToString() const override { return identifier_; }
private: private:
@ -285,11 +292,11 @@ class JEvaluator : public Evaluator {
} }
bound_node_ = AnfNodeWeakPtr(node); bound_node_ = AnfNodeWeakPtr(node);
} }
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
} }
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
AnfNodeConfigPtr out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private: private:

View File

@ -135,12 +135,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
using mindspore::parse::PyObjectWrapper; using mindspore::parse::PyObjectWrapper;
AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
return abs_base; prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
return infer_result;
} }
AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) { AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
if (!prim_->isa<prim::DoSignaturePrimitive>()) { if (!prim_->isa<prim::DoSignaturePrimitive>()) {
@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
ScopePtr scope = kDefaultScope; ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) { if (out_conf != nullptr) {
@ -212,7 +216,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
return graph_specialize_args; return graph_specialize_args;
} }
AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) { AnfNodeConfigPtr out_conf) {
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) { if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
AbstractBasePtrList args_spec_list; AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
// get the forward graph // get the forward graph
MS_EXCEPTION_IF_NULL(args_spec_list[0]); MS_EXCEPTION_IF_NULL(args_spec_list[0]);
AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>(); AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
} }
} // end anonymous namespace } // end anonymous namespace
AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
const auto &iter = cache_->find(args); const auto &iter = cache_->find(args);
@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
} }
auto infer_fuc = pyobj.attr("__infer__"); auto infer_fuc = pyobj.attr("__infer__");
prim_py_->BeginRecordAddAttr();
py::dict output = infer_fuc(*py_args); py::dict output = infer_fuc(*py_args);
prim_py_->EndRecordAddAttr();
auto added_attrs = prim_py_->evaluate_added_attrs();
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
auto res_spec = PyInferRes2Abstract(prim_py_, output); auto res_spec = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
(*cache_)[args] = res_spec; auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
return res_spec; (*cache_)[args] = infer_result;
return infer_result;
} }
AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
if (nargs_ != args.size()) { if (nargs_ != args.size()) {
MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const
} }
AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type); AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
return abs_base; return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
} }
ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
@ -553,7 +560,7 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
manager->AddFuncGraph(func_graph); manager->AddFuncGraph(func_graph);
} }
AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &old_conf) { const AnfNodeConfigPtr &old_conf) {
MS_EXCEPTION_IF_NULL(old_conf); MS_EXCEPTION_IF_NULL(old_conf);
@ -585,7 +592,7 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
return eng->ForwardConfig(old_conf, fn_conf); return eng->ForwardConfig(old_conf, fn_conf);
} }
AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list, const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) { const AnfNodeConfigPtr &out_conf) {
// args_spec_list: same as StaticGetter // args_spec_list: same as StaticGetter
@ -627,7 +634,7 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
return eng->ForwardConfig(out_conf, fn_conf); return eng->ForwardConfig(out_conf, fn_conf);
} }
AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
if (args_spec_list.empty()) { if (args_spec_list.empty()) {
@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
AbstractBasePtr attr = cls->GetAttribute(item_name); AbstractBasePtr attr = cls->GetAttribute(item_name);
if (attr != nullptr) { if (attr != nullptr) {
return attr; return std::make_shared<EvalResult>(attr, nullptr);
} }
ValuePtr method = cls->GetMethod(item_name); ValuePtr method = cls->GetMethod(item_name);
@ -660,7 +667,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
return StaticGetterInferred(converted_v, data_conf, out_conf); return StaticGetterInferred(converted_v, data_conf, out_conf);
} }
AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
const TypePtr &data_type, const ConfigPtr &data_conf, const TypePtr &data_type, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) { const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(item_v); MS_EXCEPTION_IF_NULL(item_v);
@ -689,7 +696,7 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e
return StaticGetterInferred(converted_v, data_conf, out_conf); return StaticGetterInferred(converted_v, data_conf, out_conf);
} }
AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function // Inputs: namespace and its static function; or class and its member function
CheckArgsSize("StaticGetter", args_spec_list, 2); CheckArgsSize("StaticGetter", args_spec_list, 2);
@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
~EmbedEvaluator() override = default; ~EmbedEvaluator() override = default;
MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
// arg: free variable to be embedded // arg: free variable to be embedded
if (args_conf_list.size() != 1) { if (args_conf_list.size() != 1) {
MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]); AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(node_conf); MS_EXCEPTION_IF_NULL(node_conf);
AbstractBasePtr x = node_conf->GetEvaluatedValue(); AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
x = SensitivityTransform(x); x = SensitivityTransform(x);
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x); SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>()); AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
return abs_scalar; return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
} }
}; };
@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
~RefToEmbedEvaluator() override = default; ~RefToEmbedEvaluator() override = default;
MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
if (args_conf_list.size() != 1) { if (args_conf_list.size() != 1) {
MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
return nullptr; return nullptr;
@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
return nullptr; return nullptr;
} }
AbstractBasePtr abs = node_conf->GetEvaluatedValue(); AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>(); AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) { if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
} }
auto refkey = key_value->cast<RefKeyPtr>(); auto refkey = key_value->cast<RefKeyPtr>();
if (refkey == nullptr) { if (refkey == nullptr) {
return std::make_shared<AbstractScalar>(type); return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
} }
std::string name = refkey->tag(); std::string name = refkey->tag();
@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
x = SensitivityTransform(x); x = SensitivityTransform(x);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x); std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type); std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
return abs_scalar; return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
} }
}; };
@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
~GetAttrEvaluator() override = default; ~GetAttrEvaluator() override = default;
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
// Inputs: data, item // Inputs: data, item
if (args_spec_list.size() != 2) { if (args_spec_list.size() != 2) {
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
} }
AbstractBasePtr ret = nullptr; EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) { if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
~ResolveEvaluator() override = default; ~ResolveEvaluator() override = default;
MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
// Inputs: namespace, symbol // Inputs: namespace, symbol
if (args_spec_list.size() != 2) { if (args_spec_list.size() != 2) {
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
} }
AbstractBasePtr ret = nullptr; EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) { if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
~CreateInstanceEvaluator() override = default; ~CreateInstanceEvaluator() override = default;
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) { if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
} }
@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
} }
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
(*cache_)[args_spec_list] = ret; auto infer_result = std::make_shared<EvalResult>(ret, nullptr);
return ret; (*cache_)[args_spec_list] = infer_result;
return infer_result;
} }
pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator {
public: public:
PartialEvaluator() : Evaluator("PartialEvaluator") {} PartialEvaluator() : Evaluator("PartialEvaluator") {}
~PartialEvaluator() override = default; ~PartialEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf = nullptr) override { AnfNodeConfigPtr out_conf = nullptr) override {
if (args_conf_list.size() == 0) { if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Args size should be greater than 0"; MS_LOG(EXCEPTION) << "Args size should be greater than 0";
} }
MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node()); MS_EXCEPTION_IF_NULL(out_conf->node());
auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract();
auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
AbstractBasePtrList args_spec_list{arg0_value}; AbstractBasePtrList args_spec_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if (arg0_value->isa<AbstractError>()) { if (arg0_value->isa<AbstractError>()) {
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node()); auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString(); << " as func is: " << arg0_value->ToString();
(*cache_)[args_spec_list] = ret; auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
return ret; (*cache_)[args_spec_list] = eval_result;
return eval_result;
} }
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0); auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
// Sometimes, node[0] in out_conf becomes phi0; // Sometimes, node[0] in out_conf becomes phi0;
@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator {
} }
} }
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); }); args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); });
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
auto cnode = out_conf->node()->cast<CNodePtr>(); auto cnode = out_conf->node()->cast<CNodePtr>();
@ -989,15 +999,16 @@ class PartialEvaluator : public Evaluator {
func->Visit(build_partial); func->Visit(build_partial);
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
(*cache_)[args_spec_list] = ret; auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
return ret; (*cache_)[args_spec_list] = infer_result;
return infer_result;
} }
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
const AnfNodeConfigPtr &out_conf = nullptr) const { const AnfNodeConfigPtr &out_conf = nullptr) const {
MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node()); MS_EXCEPTION_IF_NULL(out_conf->node());

View File

@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
: TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
~StandardPrimEvaluator() override = default; ~StandardPrimEvaluator() override = default;
MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
PrimitivePtr prim() { return prim_; } PrimitivePtr prim() { return prim_; }
std::string ToString() const override { return identifier_ + prim_->name(); } std::string ToString() const override { return identifier_ + prim_->name(); }
@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator {
: TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
~PythonPrimEvaluator() override = default; ~PythonPrimEvaluator() override = default;
MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); } PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
std::string ToString() const override { return identifier_ + prim_py_->name(); } std::string ToString() const override { return identifier_ + prim_py_->name(); }
@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator {
public: public:
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
~DoSignatureEvaluator() override = default; ~DoSignatureEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override; AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator {
public: public:
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
~UnpackGraphEvaluator() override = default; ~UnpackGraphEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override; AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator {
~UniformPrimEvaluator() override = default; ~UniformPrimEvaluator() override = default;
MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
ValuePtr RunImpl(const ValuePtrList &args) const; ValuePtr RunImpl(const ValuePtrList &args) const;
// If eval_value_ is False, return broadened arguments. // If eval_value_ is False, return broadened arguments.

View File

@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
if (conf->node()->intermediate_abstract()) { if (conf->node()->intermediate_abstract()) {
return conf->node()->intermediate_abstract(); return conf->node()->intermediate_abstract();
} }
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
} }
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() {
// Specialize CNode in func graphs // Specialize CNode in func graphs
void FuncGraphSpecializer::SecondPass() { void FuncGraphSpecializer::SecondPass() {
for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) { for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) {
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
ProcessCNode(node->cast<CNodePtr>()); ProcessCNode(node->cast<CNodePtr>());
} }
@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AnfNodeConfigPtr conf = MakeConfig(node); AnfNodeConfigPtr conf = MakeConfig(node);
AnfNodePtr new_node = GetReplicatedNode(node); AnfNodePtr new_node = GetReplicatedNode(node);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);
if (new_node->func_graph() != specialized_func_graph_) { if (new_node->func_graph() != specialized_func_graph_) {
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
<< ", new_node: " << new_node->DebugString() << ", new_node: " << new_node->DebugString()
@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto attrs = conf->GetEvaluatedValue()->attribute();
auto c_old = node->cast<CNodePtr>(); auto c_old = node->cast<CNodePtr>();
auto c_new = new_node->cast<CNodePtr>(); auto c_new = new_node->cast<CNodePtr>();
auto new_inputs = c_new->inputs(); auto new_inputs = c_new->inputs();
@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival); AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
if (replace_node == nullptr) { if (replace_node == nullptr) {
replace_node = BuildReplacedNode(iconf); replace_node = BuildReplacedNode(iconf);
MS_EXCEPTION_IF_NULL(replace_node); MS_EXCEPTION_IF_NULL(replace_node);
@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
} }
auto attrs = std::make_shared<AttrValueMap>();
for (size_t i = 0; i < partial_closure->args().size(); i++) { for (size_t i = 0; i < partial_closure->args().size(); i++) {
auto old_node = cnode->input(i + 2); auto old_node = cnode->input(i + 2);
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]); auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
if (possibile_value_node != nullptr) { if (possibile_value_node != nullptr) {
partial_node_list.push_back(possibile_value_node); partial_node_list.push_back(possibile_value_node);
} else { } else {
@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
const EvaluatorPtr &eval) { const EvaluatorPtr &eval) {
MS_EXCEPTION_IF_NULL(eval); MS_EXCEPTION_IF_NULL(eval);
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices; std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
AbstractBasePtr ret = nullptr; EvalResultPtr ret = nullptr;
AbstractBasePtrList broaded_argvals; AbstractBasePtrList broaded_argvals;
for (auto &argvals_map : *evalcaches_[eval]) { for (auto &argvals_map : *evalcaches_[eval]) {
auto argvals = argvals_map.first; auto argvals = argvals_map.first;
@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
(*real)[broaded_argvals] = ret; (*real)[broaded_argvals] = ret;
evalcaches_[eval] = real; evalcaches_[eval] = real;
return std::make_pair(broaded_argvals, ret); return std::make_pair(broaded_argvals, ret->abstract());
} else { } else {
MS_LOG(DEBUG) << "Choices.size: " << choices.size(); MS_LOG(DEBUG) << "Choices.size: " << choices.size();
return std::make_pair(AbstractBasePtrList(), nullptr); return std::make_pair(AbstractBasePtrList(), nullptr);
@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
return; return;
} }
specializer_->AddSeen(new_node); specializer_->AddSeen(new_node);
auto new_inputs = new_node->inputs(); auto new_inputs = new_node->inputs();
if (new_inputs.empty()) { if (new_inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
@ -530,8 +530,14 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
} }
if (CanSpecializeNode(func)) { if (CanSpecializeNode(func)) {
// for primitive node , we build the primitive node with infered attributes in the first pass
// so we do not build replaced node again here in second pass
if (IsValueNode<Primitive>(func)) {
new_inputs[0] = func;
} else {
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
} }
}
for (size_t i = 0; i < argvals.size();) { for (size_t i = 0; i < argvals.size();) {
size_t next = i + 1; size_t next = i + 1;
@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
} }
i = next; i = next;
} }
new_node->set_inputs(new_inputs); new_node->set_inputs(new_inputs);
} }
@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
EvaluatorCacheMap evaluator_cache_map = *eval->cache(); EvaluatorCacheMap evaluator_cache_map = *eval->cache();
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
*result = std::make_pair(argvals, evaluator_cache_map[argvals]); *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
return kSpecializeSuccess; return kSpecializeSuccess;
} }
DumpEvaluatorCache(evaluator_cache_map, argvals); DumpEvaluatorCache(evaluator_cache_map, argvals);
@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
MS_EXCEPTION_IF_NULL(choices); MS_EXCEPTION_IF_NULL(choices);
if (choices->count(argvals)) { if (choices->count(argvals)) {
*result = std::make_pair(argvals, (*choices)[argvals]); *result = std::make_pair(argvals, (*choices)[argvals]->abstract());
return kSpecializeSuccess; return kSpecializeSuccess;
} else if (choices->size() == 1) { } else if (choices->size() == 1) {
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
*result = std::make_pair(choices->begin()->first, choices->begin()->second); *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
return kSpecializeSuccess; return kSpecializeSuccess;
} else if (choices->empty()) { } else if (choices->empty()) {
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
return kSpecializeFindUniqueArgvalPoly; return kSpecializeFindUniqueArgvalPoly;
} }
} }
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
auto &prim_attrs = prim->attrs();
bool is_attr_same = true;
for (auto &item : *attrs) {
auto itr = prim_attrs.find(item.first);
if (itr != prim_attrs.end()) {
if (!(*(itr->second) == *(item.second))) {
is_attr_same = false;
break;
}
} else {
is_attr_same = false;
break;
}
}
if (!is_attr_same) {
if (prim->isa<PrimitivePy>()) {
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
auto clone_fn = prim_py->GetPyObj().attr("_clone");
py::object new_obj = clone_fn();
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
for (auto &item : *attrs) {
cloned_prim->AddAttr(item.first, item.second);
}
return cloned_prim;
}
auto cloned_prim = std::make_shared<Primitive>(*prim);
for (auto &item : *attrs) {
cloned_prim->AddAttr(item.first, item.second);
}
return cloned_prim;
}
return prim;
}
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) { AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs) {
MS_EXCEPTION_IF_NULL(origin_node); MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(ival); MS_EXCEPTION_IF_NULL(ival);
@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
ValuePtr value = nullptr; ValuePtr value = nullptr;
if (abs->isa<PrimitiveAbstractClosure>()) { if (abs->isa<PrimitiveAbstractClosure>()) {
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs); auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
if (attrs != nullptr) {
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
} else {
value = real_fn->prim(); value = real_fn->prim();
}
} else if (abs->isa<MetaFuncGraphAbstractClosure>()) { } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs); auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
value = real_fn->meta_func_graph(); value = real_fn->meta_func_graph();

View File

@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
// Build a value node if ival is constant and not any-value // Build a value node if ival is constant and not any-value
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival); AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs);
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
// replicated node. // replicated node.
AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);

View File

@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
return nullptr; return nullptr;
} }
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) { void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
<< ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString() << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
<< ", Pointer: " << arg.get(); << ", Pointer: " << result->abstract().get();
cache_[conf] = arg; cache_[conf] = result;
// Set intermediate abstract value. // Set intermediate abstract value.
if (IsIntermediateAbstract(arg)) { if (IsIntermediateAbstract(result->abstract())) {
if (conf->node()->intermediate_abstract() == nullptr) { if (conf->node()->intermediate_abstract() == nullptr) {
conf->node()->set_intermediate_abstract(arg); conf->node()->set_intermediate_abstract(result->abstract());
MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString(); MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
} else { } else {
auto old_spec = conf->node()->intermediate_abstract(); auto old_spec = conf->node()->intermediate_abstract();
auto joined_spec = IntermediateJoin(arg, old_spec); auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
conf->node()->set_intermediate_abstract(joined_spec); conf->node()->set_intermediate_abstract(joined_spec);
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
<< arg->ToString() << "\njoined_spec:\t" << result->abstract()->ToString() << "\njoined_spec:\t"
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
} }
} }
} }
AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
auto value = cache_.find(conf); auto value = cache_.find(conf);
if (value == cache_.end()) { if (value == cache_.end()) {
return nullptr; return nullptr;
@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
return eval->graph_context(); return eval->graph_context();
} }
AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
auto value = cache_.GetValue(conf); auto value = cache_.GetValue(conf);
if (value != nullptr) { if (value != nullptr) {
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", " MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get()
<< value->ToString(); << ", " << value->abstract()->ToString();
return value; return value;
} }
@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
return value; return value;
} }
AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
AnfNodePtr node = conf->node(); AnfNodePtr node = conf->node();
AbstractBasePtr ret_abstract = nullptr; EvalResultPtr eval_result = nullptr;
#ifdef DEBUG #ifdef DEBUG
compute_conf_stack_.push_back(node); compute_conf_stack_.push_back(node);
std::ostringstream buffer; std::ostringstream buffer;
@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->abstract() != nullptr) { if (node->abstract() != nullptr) {
MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
ret_abstract = node->abstract(); eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
} else if (node->isa<ValueNode>()) { } else if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>(); auto value_node = node->cast<ValueNodePtr>();
ret_abstract = EvalValueNode(value_node, conf); eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr);
} else if (node->isa<CNode>()) { } else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
trace::TraceEvalCNodeEnter(conf); trace::TraceEvalCNodeEnter(conf);
ret_abstract = EvalCNode(cnode, conf); eval_result = EvalCNode(cnode, conf);
trace::TraceEvalCNodeLeave(); trace::TraceEvalCNodeLeave();
} else { } else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
#ifdef DEBUG #ifdef DEBUG
compute_conf_stack_.pop_back(); compute_conf_stack_.pop_back();
if (ret_abstract == nullptr) { if (eval_result == nullptr) {
MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} }
#endif #endif
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString(); MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
return ret_abstract; return eval_result;
} }
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return ToAbstract(value_node->value(), conf->context(), conf); return ToAbstract(value_node->value(), conf->context(), conf);
} }
AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs(); auto &inputs = cnode->inputs();
@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
MS_EXCEPTION_IF_NULL(func_conf); MS_EXCEPTION_IF_NULL(func_conf);
// Keep it in a local variable, otherwise smart pointer will free it. // Keep it in a local variable, otherwise smart pointer will free it.
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue(); AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract();
if (maybe_func == nullptr) { if (maybe_func == nullptr) {
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
return ExecuteEvaluators(infs, conf, args_conf_list); return ExecuteEvaluators(infs, conf, args_conf_list);
} }
AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
ConfigPtrList args_conf_list; ConfigPtrList args_conf_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); }); [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
return tracked_eval; return tracked_eval;
} }
AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
const ConfigPtrList &args_conf_list) {
if (evaluators.size() == 1) { if (evaluators.size() == 1) {
EvaluatorPtr eval = evaluators[0]; EvaluatorPtr eval = evaluators[0];
MS_EXCEPTION_IF_NULL(eval); MS_EXCEPTION_IF_NULL(eval);
@ -465,7 +464,7 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
} }
AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) { const ConfigPtrList &args_conf_list) {
AbstractBasePtrList out_specs; AbstractBasePtrList out_specs;
@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr { [](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue(); return conf->GetEvaluatedValue()->abstract();
}); });
for (auto eval : evaluators) { for (auto eval : evaluators) {
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
eval_trace_.push_back(current_inf); eval_trace_.push_back(current_inf);
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
MS_EXCEPTION_IF_NULL(eval); MS_EXCEPTION_IF_NULL(eval);
auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(out_spec); MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
out_specs.push_back(out_spec); out_specs.push_back(eval_result->abstract());
MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
eval_trace_.pop_back(); eval_trace_.pop_back();
if (eval_trace_.empty()) { if (eval_trace_.empty()) {
multi_poss_.clear(); multi_poss_.clear();
@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
// Try to travel the latest undetermined. // Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->first) { if (latest_entry != eval_trace_.rbegin()->first) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(out_spec); MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
return out_spec; << " return out_spec: " << eval_result->abstract()->ToString();
return eval_result;
} }
} }
} }
@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
if (out_specs.size() == 1) { if (out_specs.size() == 1) {
MS_EXCEPTION_IF_NULL(out_specs[0]); MS_EXCEPTION_IF_NULL(out_specs[0]);
// If only one result derived, then broaden it to avoid wrong constant propagation. // If only one result derived, then broaden it to avoid wrong constant propagation.
return out_specs[0]->Broaden(); return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
} }
auto joined_spec = AbstractJoin(out_specs); auto joined_spec = AbstractJoin(out_specs);
MS_EXCEPTION_IF_NULL(joined_spec); MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
return joined_spec; return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
} }
AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() { EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>(); AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
return engine_.lock()->GetEvaluatedValue(self); return engine_.lock()->GetEvaluatedValue(self);
} }
@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
return a; return a;
} }
AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
auto evaluator = GetPrimEvaluator(primitive, nullptr); auto evaluator = GetPrimEvaluator(primitive, nullptr);
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
if (!evaluator->isa<TrivialPrimEvaluator>()) { if (!evaluator->isa<TrivialPrimEvaluator>()) {
@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
<< evaluator->ToString(); << evaluator->ToString();
} }
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator); auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs); auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
return res_spec; return eval_result;
} }
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -40,13 +40,33 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
// define attribute value map
using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
// the class to save evaluated result: abstract value and modified attribute
class EvalResult : public Base {
public:
EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
~EvalResult() override = default;
MS_DECLARE_PARENT(EvalResult, Base);
AbstractBasePtr abstract() { return abstract_; }
AttrValueMapPtr attribute() { return attribute_; }
private:
AbstractBasePtr abstract_;
AttrValueMapPtr attribute_;
};
using EvalResultPtr = std::shared_ptr<EvalResult>;
// Superclass for AnfNodeConfig and VirtualConfig. // Superclass for AnfNodeConfig and VirtualConfig.
class Config : public Base { class Config : public Base {
public: public:
Config() = default; Config() = default;
~Config() override = default; ~Config() override = default;
MS_DECLARE_PARENT(Config, Base); MS_DECLARE_PARENT(Config, Base);
virtual AbstractBasePtr GetEvaluatedValue() = 0; virtual EvalResultPtr GetEvaluatedValue() = 0;
}; };
// Config will be stored in AnalysisCache // Config will be stored in AnalysisCache
@ -74,7 +94,7 @@ class AnfNodeConfig : public Config {
~AnfNodeConfig() override = default; ~AnfNodeConfig() override = default;
MS_DECLARE_PARENT(AnfNodeConfig, Config); MS_DECLARE_PARENT(AnfNodeConfig, Config);
AbstractBasePtr GetEvaluatedValue() override; EvalResultPtr GetEvaluatedValue() override;
AnalysisContextPtr context() const { return context_; } AnalysisContextPtr context() const { return context_; }
@ -123,7 +143,9 @@ class VirtualConfig : public Config {
~VirtualConfig() override = default; ~VirtualConfig() override = default;
MS_DECLARE_PARENT(VirtualConfig, Config); MS_DECLARE_PARENT(VirtualConfig, Config);
AbstractBasePtr GetEvaluatedValue() override { return abstract_; } EvalResultPtr GetEvaluatedValue() override {
return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
}
private: private:
AbstractBasePtr abstract_; AbstractBasePtr abstract_;
@ -135,11 +157,11 @@ class AnalysisCache {
AnalysisCache() = default; AnalysisCache() = default;
~AnalysisCache() = default; ~AnalysisCache() = default;
void Clear() { cache_.clear(); } void Clear() { cache_.clear(); }
void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg); void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf); EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
private: private:
std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_; std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
}; };
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>; using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
@ -147,7 +169,7 @@ using AnfNodeConfigMap =
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>; std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
struct AnalysisResult { struct AnalysisResult {
AbstractBasePtr inferred; EvalResultPtr inferred;
AnalysisContextPtr context; AnalysisContextPtr context;
}; };
@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// func_graph: The func_graph to analyze. // func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
// Return the Evaluator for the given function. // Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
// Infer the result of fn(args). // Infer the result of fn(args).
AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear(); void Clear();
void ClearEvaluatorCache(); void ClearEvaluatorCache();
AnalysisCache &cache() { return cache_; } AnalysisCache &cache() { return cache_; }
@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// Set the analysis result for orig to the result for new. // Set the analysis result for orig to the result for new.
// This sets an entry in anfnode_config_map from orig to new. // This sets an entry in anfnode_config_map from orig to new.
AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
(void)anfnode_config_map_.emplace(orig_conf, new_conf); (void)anfnode_config_map_.emplace(orig_conf, new_conf);
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list); const ConfigPtrList &args_conf_list);
AbstractBasePtr Eval(const AnfNodeConfigPtr &conf); EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list); const ConfigPtrList &args_conf_list);
AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list);
#ifdef DEBUG #ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_; std::vector<AnfNodePtr> compute_conf_stack_;
@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return FromValueInside(MakeValue(value), broaden); return FromValueInside(MakeValue(value), broaden);
} }
AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
} }
} }
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list); AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
op_exec_info->abstract = infer_res; op_exec_info->abstract = infer_res;
} }

View File

@ -26,6 +26,8 @@
#include <list> #include <list>
#include <string> #include <string>
#include <fstream> #include <fstream>
#include <queue>
#include <set>
#include "ir/visitor.h" #include "ir/visitor.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
return res; return res;
} }
// search the cnodes inside this graph only
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
std::queue<CNodePtr> todo;
todo.push(ret);
std::vector<CNodePtr> sorted_nodes;
auto seen = NewSeenGeneration();
while (!todo.empty()) {
CNodePtr top = todo.front();
todo.pop();
sorted_nodes.push_back(top);
auto inputs = top->inputs();
for (auto &item : inputs) {
if (item->seen_ == seen) {
continue;
}
if (item->isa<CNode>()) {
todo.push(item->cast<CNodePtr>());
}
item->seen_ = seen;
}
}
return sorted_nodes;
}
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
std::vector<AnfNodePtr> vecs; std::vector<AnfNodePtr> vecs;
if (node == nullptr) { if (node == nullptr) {

View File

@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
const IncludeFunc &include = AlwaysInclude); const IncludeFunc &include = AlwaysInclude);
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
class FuncGraphIndex { class FuncGraphIndex {
public: public:
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,

View File

@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init ExpandDims""" """init ExpandDims"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
def __infer__(self, x, axis): def __infer__(self, x, axis):
@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer):
# if primitive need setattr in __infer__ need add this flag # if primitive need setattr in __infer__ need add this flag
"""init Cast""" """init Cast"""
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
self.__setattr_flag__ = True
def __infer__(self, x, t): def __infer__(self, x, t):
src_type = x['dtype'] src_type = x['dtype']
@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer):
def __init__(self): def __init__(self):
"""init Reshape""" """init Reshape"""
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output']) self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
self.__setattr_flag__ = True
def __infer__(self, x, shape): def __infer__(self, x, shape):
shape_v = shape['value'] shape_v = shape['value']
@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init Transpose""" """init Transpose"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
def __infer__(self, x, perm): def __infer__(self, x, perm):
@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init index_select""" """init index_select"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __infer__(self, params, indices, axis): def __infer__(self, params, indices, axis):
@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, axis=0): def __init__(self, axis=0):
"""init Tile""" """init Tile"""
self.__setattr_flag__ = True
validator.check_value_type("axis", axis, [int], self.name) validator.check_value_type("axis", axis, [int], self.name)
def __infer__(self, input_x): def __infer__(self, input_x):
@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, axis=0): def __init__(self, axis=0):
"""init Pack""" """init Pack"""
self.__setattr_flag__ = True
validator.check_value_type("axis", axis, [int], self.name) validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis self.axis = axis
@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, axis=0): def __init__(self, axis=0):
"""init Unpack""" """init Unpack"""
self.__setattr_flag__ = True
validator.check_value_type("axis", axis, [int], self.name) validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis self.axis = axis
@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init""" """init"""
self.__setattr_flag__ = True
def infer_shape(self, cond_shape, x_shape, y_shape): def infer_shape(self, cond_shape, x_shape, y_shape):
if cond_shape != x_shape or x_shape != y_shape: if cond_shape != x_shape or x_shape != y_shape:

View File

@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False): def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True
cls_name = self.name cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
@ -596,7 +595,6 @@ class BatchMatMul(MatMul):
@prim_attr_register @prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False): def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True
cls_name = self.name cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def infer_shape(self, inputs): def infer_shape(self, inputs):

View File

@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer):
"""init Conv2D""" """init Conv2D"""
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation) self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name) validator.check_value_type('pad', pad, (int,), self.name)
@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer):
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel out_channel = self.out_channel
out_shape = [x_shape[0], out_channel, h_out, w_out] out_shape = [x_shape[0], out_channel, h_out, w_out]
return out_shape return out_shape

View File

@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.primitive import constexpr
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def test_cast_op_attr():
class CastNet(nn.Cell):
def __init__(self):
super(CastNet, self).__init__()
self.cast = P.Cast()
def construct(self, x, t):
return self.cast(x, t)
class CastTypeTest(nn.Cell):
def __init__(self, net):
super(CastTypeTest, self).__init__()
self.net = net
self.cast = P.Cast()
def construct(self, x, y, z):
cast_op = self.cast
t1 = cast_op(x, mstype.float32)
t2 = cast_op(y, mstype.int32)
cast_net = self.net
t3 = cast_net(x, mstype.float16)
t4 = cast_net(y, mstype.int32)
t5 = cast_net(z, mstype.float16)
return (t1, t2, t3, t4, t5)
net = CastTypeTest(CastNet())
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
out = net(t1, t2, t3)
assert out[0].asnumpy().dtype == np.float32
assert out[1].asnumpy().dtype == np.int32
assert out[2].asnumpy().dtype == np.float16
assert out[3].asnumpy().dtype == np.int32
assert out[4].asnumpy().dtype == np.float16

View File

@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }
@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }
@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }
@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }
@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) {
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step); AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tensor, slice}; AbstractBasePtrList args_spec_list = {tensor, slice};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred); AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed."; FAIL() << "Cast ret to abstract array failed.";
} }
@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed."; FAIL() << "Cast ret to abstract array failed.";
} }
@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed."; FAIL() << "Cast ret to abstract array failed.";
} }
@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) {
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2); AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
AbstractBasePtrList args_spec_list = {tensor, start_index}; AbstractBasePtrList args_spec_list = {tensor, start_index};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed."; FAIL() << "Cast ret to abstract array failed.";
} }
@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed."; FAIL() << "Cast ret to abstract array failed.";
} }
@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles); AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed."; FAIL() << "Cast ret to abstract array failed.";
} }
@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict}; AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }
@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map); abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple}; AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }
@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) {
auto tuple = std::make_shared<AbstractTuple>(eles); auto tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tuple}; AbstractBasePtrList args_spec_list = {tuple};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred); AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
if (ret == nullptr) { if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed."; FAIL() << "Cast ret to abstract tuple failed.";
} }

View File

@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
AbstractBasePtr abstract_v2 = FromValue(2, false); AbstractBasePtr abstract_v2 = FromValue(2, false);
AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2}; AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2};
AbstractBasePtr abstract_val = FromValue(10, false); AbstractBasePtr abstract_val = FromValue(10, false);
cache[args_spec_list] = abstract_val; cache[args_spec_list] = std::make_shared<EvalResult>(abstract_val, std::make_shared<AttrValueMap>());
auto iter = cache.find(args_spec_list); auto iter = cache.find(args_spec_list);
ASSERT_TRUE(iter != cache.end()); ASSERT_TRUE(iter != cache.end());
ASSERT_TRUE(iter->second == abstract_val); ASSERT_TRUE(iter->second->abstract() == abstract_val);
AbstractBasePtr abstract_v1_variant1 = FromValue(1, false); AbstractBasePtr abstract_v1_variant1 = FromValue(1, false);
AbstractBasePtr abstract_v2_variant1 = FromValue(2, false); AbstractBasePtr abstract_v2_variant1 = FromValue(2, false);
@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
iter = cache.find(args_spec_list_variant1); iter = cache.find(args_spec_list_variant1);
ASSERT_TRUE(iter != cache.end()); ASSERT_TRUE(iter != cache.end());
ASSERT_TRUE(iter->second == abstract_val); ASSERT_TRUE(iter->second->abstract() == abstract_val);
AbstractBasePtr abstract_v1_variant2 = FromValue(1, false); AbstractBasePtr abstract_v1_variant2 = FromValue(1, false);
AbstractBasePtr abstract_v2_variant2 = FromValue(3, false); AbstractBasePtr abstract_v2_variant2 = FromValue(3, false);
@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) {
std::vector<int> shape = {2, 2, 6, 6}; std::vector<int> shape = {2, 2, 6, 6};
expected->set_shape(std::make_shared<Shape>(shape)); expected->set_shape(std::make_shared<Shape>(shape));
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString();
@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) {
AbstractBasePtr abstract_x = FromValue(x, false); AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_x);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
} }
@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) {
AbstractBasePtr abstract_x = FromValue(x, false); AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_x);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
} }
@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) {
args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y); args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
} }
@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) {
args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y); args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
} }
@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) {
args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y); args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
} }
@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) {
args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y); args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
} }

View File

@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) {
auto prim_typeof = std::make_shared<Primitive>("typeof"); auto prim_typeof = std::make_shared<Primitive>("typeof");
FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
res->dump(); res->dump();
TypePtr res_value = res->GetValueTrack()->cast<TypePtr>(); TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
res_value->dump(); res_value->dump();
@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) {
auto prim_list_map = std::make_shared<Primitive>("list_map"); auto prim_list_map = std::make_shared<Primitive>("list_map");
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3); FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)})); auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)}));
res->dump(); res->dump();
MS_LOG(INFO) << "result res: " << res->ToString(); MS_LOG(INFO) << "result res: " << res->ToString();
@ -188,7 +188,7 @@ TEST_F(TestPrim, test_list_reduce) {
auto prim_list_reduce = std::make_shared<Primitive>("list_reduce"); auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3); FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
res->dump(); res->dump();
TypePtr res_type = res->GetTypeTrack(); TypePtr res_type = res->GetTypeTrack();
res_type->dump(); res_type->dump();
@ -205,7 +205,7 @@ TEST_F(TestPrim, test_scalar_to_array) {
auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array"); auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
res->dump(); res->dump();
TypePtr res_type = res->BuildType(); TypePtr res_type = res->BuildType();
res_type->dump(); res_type->dump();
@ -223,7 +223,7 @@ TEST_F(TestPrim, test_array_to_scalar) {
auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar"); auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
res->dump(); res->dump();
TypePtr res_type = res->BuildType(); TypePtr res_type = res->BuildType();
res_type->dump(); res_type->dump();
@ -239,7 +239,7 @@ TEST_F(TestPrim, test_J_1) {
auto prim_J = std::make_shared<Primitive>("J"); auto prim_J = std::make_shared<Primitive>("J");
FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res); AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
ASSERT_TRUE(res_J != nullptr); ASSERT_TRUE(res_J != nullptr);
ASSERT_TRUE(*(res_J->element()) == *abstract_v1); ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) {
int v1 = 1; int v1 = 1;
AbstractBasePtr abstract_v1 = FromValue(v1, false); AbstractBasePtr abstract_v1 = FromValue(v1, false);
AbstractBasePtrList args_spec_list = {abstract_v1}; AbstractBasePtrList args_spec_list = {abstract_v1};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
res->dump(); res->dump();
AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res); AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
ASSERT_TRUE(res_J != nullptr); ASSERT_TRUE(res_J != nullptr);
@ -302,7 +302,7 @@ TEST_F(TestPrim, test_dot) {
AbstractBasePtrList args_spec_list = {a1, a2}; AbstractBasePtrList args_spec_list = {a1, a2};
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred); AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack()))); ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
} }
@ -317,7 +317,7 @@ TEST_F(TestPrim, test_switch1) {
AbstractBasePtr arg2 = FromValue(2, false); AbstractBasePtr arg2 = FromValue(2, false);
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *arg1); ASSERT_TRUE(*res == *arg1);
} }
@ -330,7 +330,7 @@ TEST_F(TestPrim, test_switch2) {
AbstractBasePtr arg2 = FromValue(2, false); AbstractBasePtr arg2 = FromValue(2, false);
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "make result res: " << res->ToString(); MS_LOG(INFO) << "make result res: " << res->ToString();
MS_LOG(INFO) << "make result arg2: " << arg2->ToString(); MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
ASSERT_TRUE(*res == *arg2); ASSERT_TRUE(*res == *arg2);
@ -343,7 +343,7 @@ TEST_F(TestPrim, test_identity) {
AbstractBasePtr abstract_v1 = FromValue(1, false); AbstractBasePtr abstract_v1 = FromValue(1, false);
AbstractBasePtrList args_spec_list = {abstract_v1}; AbstractBasePtrList args_spec_list = {abstract_v1};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *abstract_v1); ASSERT_TRUE(*res == *abstract_v1);
} }
@ -357,7 +357,7 @@ TEST_F(TestPrim, test_broadcast_shape) {
AbstractBasePtrList args_spec_list = {a, b}; AbstractBasePtrList args_spec_list = {a, b};
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred); AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)}; std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
@ -377,7 +377,7 @@ TEST_F(TestPrim, test_partial) {
AbstractBasePtr abstract_v2 = FromValue(1, false); AbstractBasePtr abstract_v2 = FromValue(1, false);
AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2}; AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2}; AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
auto expected = std::make_shared<PartialAbstractClosure>( auto expected = std::make_shared<PartialAbstractClosure>(
std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list); std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
@ -392,7 +392,7 @@ TEST_F(TestPrim, test_env_setitem) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(1, false); AbstractBasePtr abstract_x = FromValue(1, false);
AbstractBasePtrList args_spec_list = {abstract_x}; AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3); FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
@ -400,7 +400,7 @@ TEST_F(TestPrim, test_env_setitem) {
AbstractBasePtr abstract_y = FromValue(2, false); AbstractBasePtr abstract_y = FromValue(2, false);
args_spec_list = {abstract_env, embed_x, abstract_y}; args_spec_list = {abstract_env, embed_x, abstract_y};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*res == *exp); ASSERT_TRUE(*res == *exp);
} }
@ -412,7 +412,7 @@ TEST_F(TestPrim, test_env_getitem) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(1, false); AbstractBasePtr abstract_x = FromValue(1, false);
AbstractBasePtrList args_spec_list = {abstract_x}; AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
@ -420,7 +420,7 @@ TEST_F(TestPrim, test_env_getitem) {
AbstractBasePtr abstract_y = FromValue(2, false); AbstractBasePtr abstract_y = FromValue(2, false);
args_spec_list = {abstract_env, embed_x, abstract_y}; args_spec_list = {abstract_env, embed_x, abstract_y};
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*res == *exp); ASSERT_TRUE(*res == *exp);
@ -429,7 +429,7 @@ TEST_F(TestPrim, test_env_getitem) {
AbstractBasePtr abstract_z = FromValue(3, false); AbstractBasePtr abstract_z = FromValue(3, false);
args_spec_list = {res, embed_x, abstract_z}; args_spec_list = {res, embed_x, abstract_z};
res = engine_->Run(graph_getitem, args_spec_list).inferred; res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *abstract_x); ASSERT_TRUE(*res == *abstract_x);
} }
@ -442,7 +442,7 @@ TEST_F(TestPrim, test_env_add) {
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
AbstractBasePtr abstract_x = FromValue(1, false); AbstractBasePtr abstract_x = FromValue(1, false);
AbstractBasePtrList args_spec_list = {abstract_x}; AbstractBasePtrList args_spec_list = {abstract_x};
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
@ -450,19 +450,19 @@ TEST_F(TestPrim, test_env_add) {
AbstractBasePtr abstract_y = FromValue(2, false); AbstractBasePtr abstract_y = FromValue(2, false);
args_spec_list = {abstract_env, embed_x, abstract_y}; args_spec_list = {abstract_env, embed_x, abstract_y};
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred; AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>()); AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
ASSERT_TRUE(*abstract_e1 == *exp); ASSERT_TRUE(*abstract_e1 == *exp);
AbstractBasePtr abstract_z = FromValue(3, false); AbstractBasePtr abstract_z = FromValue(3, false);
args_spec_list = {abstract_env, embed_x, abstract_z}; args_spec_list = {abstract_env, embed_x, abstract_z};
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred; AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
ASSERT_TRUE(*abstract_e2 == *exp); ASSERT_TRUE(*abstract_e2 == *exp);
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2); FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
args_spec_list = {abstract_e1, abstract_e2}; args_spec_list = {abstract_e1, abstract_e2};
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *exp); ASSERT_TRUE(*res == *exp);
} }
@ -475,7 +475,7 @@ TEST_F(TestPrim, test_shape) {
AbstractBasePtrList args_spec_list = {a}; AbstractBasePtrList args_spec_list = {a};
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred); AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value(); auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)}; std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)};
@ -493,7 +493,7 @@ TEST_F(TestPrim, test_relu) {
AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
AbstractBasePtrList args_spec_list = {expected}; AbstractBasePtrList args_spec_list = {expected};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -507,7 +507,7 @@ TEST_F(TestPrim, test_relu2) {
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5}); auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
AbstractBasePtrList args_spec_list = {arr}; AbstractBasePtrList args_spec_list = {arr};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractTensor>(ret); auto res = dyn_cast<AbstractTensor>(ret);
ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack())); ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
} }
@ -540,7 +540,7 @@ TEST_F(TestPrim, test_conv2d1) {
std::vector<int> shape = {2, 64, 14, 14}; std::vector<int> shape = {2, 64, 14, 14};
expected->set_shape(std::make_shared<Shape>(shape)); expected->set_shape(std::make_shared<Shape>(shape));
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString();
@ -558,7 +558,7 @@ TEST_F(TestPrim, test_conv2d) {
auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3}); auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
AbstractBasePtrList args_spec_list = {input, weight}; AbstractBasePtrList args_spec_list = {input, weight};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractTensor>(ret); auto res = dyn_cast<AbstractTensor>(ret);
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16}); auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
@ -574,7 +574,7 @@ TEST_F(TestPrim, test_conv2d_native) {
auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3}); auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
AbstractBasePtrList args_spec_list = {input, weight}; AbstractBasePtrList args_spec_list = {input, weight};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractTensor>(ret); auto res = dyn_cast<AbstractTensor>(ret);
auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16}); auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
@ -590,7 +590,7 @@ TEST_F(TestPrim, test_biasAdd) {
auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32}); auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
AbstractBasePtrList args_spec_list = {value, bias}; AbstractBasePtrList args_spec_list = {value, bias};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractTensor>(ret); auto res = dyn_cast<AbstractTensor>(ret);
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32}); auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
@ -606,7 +606,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10}); auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
AbstractBasePtrList args_spec_list = {logits, labels}; AbstractBasePtrList args_spec_list = {logits, labels};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_NE(ret, nullptr); ASSERT_NE(ret, nullptr);
auto res = dyn_cast<AbstractTuple>(ret); auto res = dyn_cast<AbstractTuple>(ret);
auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64}); auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
@ -636,7 +636,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) {
auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10}); auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
AbstractBasePtrList args_spec_list = {logits, labels}; AbstractBasePtrList args_spec_list = {logits, labels};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractScalar>(ret); auto res = dyn_cast<AbstractScalar>(ret);
AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64); AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
expected->set_type(UTPrimUtils::kF64); expected->set_type(UTPrimUtils::kF64);
@ -690,7 +690,7 @@ TEST_F(TestPrim, test_fused_batch_norm) {
AbstractBasePtr expected0 = abstract_inputs->Clone(); AbstractBasePtr expected0 = abstract_inputs->Clone();
AbstractBasePtr expected1 = abstract_scale->Clone(); AbstractBasePtr expected1 = abstract_scale->Clone();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected0: " << expected0->ToString(); MS_LOG(INFO) << "expected0: " << expected0->ToString();
MS_LOG(INFO) << "expected1: " << expected1->ToString(); MS_LOG(INFO) << "expected1: " << expected1->ToString();
@ -722,7 +722,7 @@ TEST_F(TestPrim, test_pooling) {
inputs->set_shape(inputs_dims); inputs->set_shape(inputs_dims);
AbstractBasePtr abstract_input = FromValue(inputs, false); AbstractBasePtr abstract_input = FromValue(inputs, false);
AbstractBasePtrList args_spec_list = {abstract_input}; AbstractBasePtrList args_spec_list = {abstract_input};
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractBasePtr expected = abstract_input->Clone()->Broaden(); AbstractBasePtr expected = abstract_input->Clone()->Broaden();
std::vector<int> expected_dims = {8, 64, 2, 2}; std::vector<int> expected_dims = {8, 64, 2, 2};
@ -747,7 +747,7 @@ TEST_F(TestPrim, test_hastype) {
auto prim = std::make_shared<Primitive>("hastype"); auto prim = std::make_shared<Primitive>("hastype");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -761,7 +761,7 @@ TEST_F(TestPrim, test_array_len) {
auto prim = std::make_shared<Primitive>("array_len"); auto prim = std::make_shared<Primitive>("array_len");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -775,7 +775,7 @@ TEST_F(TestPrim, test_list_len) {
auto prim = std::make_shared<Primitive>("list_len"); auto prim = std::make_shared<Primitive>("list_len");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -789,7 +789,7 @@ TEST_F(TestPrim, test_tuple_len) {
auto prim = std::make_shared<Primitive>("tuple_len"); auto prim = std::make_shared<Primitive>("tuple_len");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -803,7 +803,7 @@ TEST_F(TestPrim, test_tuple_reversed) {
auto prim = std::make_shared<Primitive>("tuple_reversed"); auto prim = std::make_shared<Primitive>("tuple_reversed");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "expect=" << expected->ToString(); MS_LOG(INFO) << "expect=" << expected->ToString();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -825,7 +825,7 @@ TEST_F(TestPrim, test_list_getitem) {
auto prim = std::make_shared<Primitive>("list_getitem"); auto prim = std::make_shared<Primitive>("list_getitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *elem); ASSERT_TRUE(*res == *elem);
} }
@ -844,7 +844,7 @@ TEST_F(TestPrim, test_list_setitem) {
auto prim = std::make_shared<Primitive>("list_setitem"); auto prim = std::make_shared<Primitive>("list_setitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
AbstractBasePtrList elems_exp = {elem1, elem2}; AbstractBasePtrList elems_exp = {elem1, elem2};
auto expected = std::make_shared<AbstractList>(elems_exp); auto expected = std::make_shared<AbstractList>(elems_exp);
@ -866,7 +866,7 @@ TEST_F(TestPrim, test_list_append) {
auto prim = std::make_shared<Primitive>("list_append"); auto prim = std::make_shared<Primitive>("list_append");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2})); auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
MS_LOG(INFO) << "expected: " << expected->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString();
@ -890,7 +890,7 @@ TEST_F(TestPrim, test_tuple_setitem) {
auto prim = std::make_shared<Primitive>("tuple_setitem"); auto prim = std::make_shared<Primitive>("tuple_setitem");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
AbstractBasePtrList elems_exp = {elem1, elem2}; AbstractBasePtrList elems_exp = {elem1, elem2};
auto expected = std::make_shared<AbstractTuple>(elems_exp); auto expected = std::make_shared<AbstractTuple>(elems_exp);
@ -916,7 +916,7 @@ TEST_F(TestPrim, test_make_list) {
auto prim = std::make_shared<Primitive>("make_list"); auto prim = std::make_shared<Primitive>("make_list");
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
} }
@ -939,7 +939,7 @@ TEST_F(TestPrim, test_make_range) {
AbstractBasePtrList elem_list({ele1, ele2, ele3}); AbstractBasePtrList elem_list({ele1, ele2, ele3});
AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list); AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "res=" << res->ToString(); MS_LOG(INFO) << "res=" << res->ToString();
MS_LOG(INFO) << "expected=" << expected->ToString(); MS_LOG(INFO) << "expected=" << expected->ToString();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
@ -982,7 +982,7 @@ TEST_F(TestPrim, test_layernorm) {
AbstractBasePtr expected1 = abstract_mean_var->Clone(); AbstractBasePtr expected1 = abstract_mean_var->Clone();
AbstractBasePtr expected2 = abstract_mean_var->Clone(); AbstractBasePtr expected2 = abstract_mean_var->Clone();
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected0: " << expected0->ToString(); MS_LOG(INFO) << "expected0: " << expected0->ToString();
MS_LOG(INFO) << "expected1: " << expected1->ToString(); MS_LOG(INFO) << "expected1: " << expected1->ToString();
@ -1028,7 +1028,7 @@ TEST_F(TestPrim, test_DropoutGenMask) {
AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8), AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
std::make_shared<Shape>(std::vector<int>{79})); std::make_shared<Shape>(std::vector<int>{79}));
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "res=" << res->ToString(); MS_LOG(INFO) << "res=" << res->ToString();
MS_LOG(INFO) << "expected=" << expected->ToString(); MS_LOG(INFO) << "expected=" << expected->ToString();
ASSERT_TRUE(*res == *expected); ASSERT_TRUE(*res == *expected);
@ -1058,7 +1058,7 @@ TEST_F(TestPrim, test_dropout) {
std::vector<int> shape = {2, 20, 32, 32}; std::vector<int> shape = {2, 20, 32, 32};
expected->set_shape(std::make_shared<Shape>(shape)); expected->set_shape(std::make_shared<Shape>(shape));
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString();
@ -1079,7 +1079,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
auto x_input = std::make_shared<AbstractTuple>(x_arg_list); auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
auto y_input = std::make_shared<AbstractTuple>(y_arg_list); auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
AbstractBasePtrList args_spec_list = {x_input, y_input}; AbstractBasePtrList args_spec_list = {x_input, y_input};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractTuple>(ret); auto res = dyn_cast<AbstractTuple>(ret);
AbstractBasePtrList x_idx_list; AbstractBasePtrList x_idx_list;
auto r_x = std::make_shared<AbstractTuple>(x_idx_list); auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
@ -1103,7 +1103,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
auto x_input = std::make_shared<AbstractTuple>(x_arg_list); auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
auto y_input = std::make_shared<AbstractTuple>(y_arg_list); auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
AbstractBasePtrList args_spec_list = {x_input, y_input}; AbstractBasePtrList args_spec_list = {x_input, y_input};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
auto res = dyn_cast<AbstractTuple>(ret); auto res = dyn_cast<AbstractTuple>(ret);
AbstractBasePtrList x_idx_list({abstract::FromValue(1)}); AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
auto r_x = std::make_shared<AbstractTuple>(x_idx_list); auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
@ -1128,7 +1128,7 @@ TEST_F(TestPrim, test_DictGetItem) {
AbstractBasePtr key = abstract::FromValue("x"); AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key}; AbstractBasePtrList args_spec_list = {array_dict, key};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second)); AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
@ -1147,7 +1147,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
AbstractBasePtr key = abstract::FromValue("x"); AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key}; AbstractBasePtrList args_spec_list = {array_dict, key};
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret); AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x); AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);

View File

@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
} }
@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) {
MS_LOG(INFO) << "" << graph_f_->get_return()->ToString(); MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
AbstractBasePtr abstract_v1 = FromValue(1, false); AbstractBasePtr abstract_v1 = FromValue(1, false);
args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v1);
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
// now this test case failed randomly, have to debug. // now this test case failed randomly, have to debug.
@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) {
args_spec_list.clear(); args_spec_list.clear();
args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2); args_spec_list.push_back(abstract_v2);
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred; abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
} }
@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
AbstractBasePtr abstract_v2 = FromValue(v1, false); AbstractBasePtr abstract_v2 = FromValue(v1, false);
args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2); args_spec_list.push_back(abstract_v2);
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
} }
@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add"); auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred; AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack())); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32);
} }
@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
AbstractBasePtr base1 = FromValue(x1, false); AbstractBasePtr base1 = FromValue(x1, false);
AbstractBasePtr base2 = FromValue(x2, false); AbstractBasePtr base2 = FromValue(x2, false);
AbstractBasePtrList base_list = {base1, base2}; AbstractBasePtrList base_list = {base1, base2};
auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list); auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract();
MS_LOG(INFO) << "result spec: " << res->ToString(); MS_LOG(INFO) << "result spec: " << res->ToString();
AbstractBasePtr exp = FromValue(x3, false); AbstractBasePtr exp = FromValue(x3, false);
MS_LOG(INFO) << "result exp: " << exp->ToString(); MS_LOG(INFO) << "result exp: " << exp->ToString();
@ -446,7 +446,7 @@ void TestGraphEval::TearDown() {
TEST_F(TestGraphInfer, test_graph_infer_defaults) { TEST_F(TestGraphInfer, test_graph_infer_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(50), false); AbstractBasePtr expect = FromValue(MakeValue(50), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }
@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(1), false); AbstractBasePtr expect = FromValue(MakeValue(1), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }
@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
TEST_F(TestGraphInfer, test_graph_infer_vararg) { TEST_F(TestGraphInfer, test_graph_infer_vararg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(9), false); AbstractBasePtr expect = FromValue(MakeValue(9), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }
@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(48), false); AbstractBasePtr expect = FromValue(MakeValue(48), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }
@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
TEST_F(TestGraphInfer, test_graph_infer_kwarg) { TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(7), false); AbstractBasePtr expect = FromValue(MakeValue(7), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }
@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(46), false); AbstractBasePtr expect = FromValue(MakeValue(46), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }
@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) { TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults"); FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
AbstractBasePtrList args_spec_list = {}; AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(57), false); AbstractBasePtr expect = FromValue(MakeValue(57), false);
ASSERT_EQ(*res, *expect); ASSERT_EQ(*res, *expect);
} }

View File

@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \ from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config import pipeline_for_verify_exception_for_case_by_case_config
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def conv3x3(in_channels, out_channels, stride=1, padding=1): def conv3x3(in_channels, out_channels, stride=1, padding=1):
"""3x3 convolution """ """3x3 convolution """
@ -377,6 +378,21 @@ class StateNet(nn.Cell):
return x return x
def test_conv2d_same_primitive():
class Conv2DSameNet(nn.Cell):
def __init__(self):
super(Conv2DSameNet, self).__init__()
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
def construct(self, x, y):
r1 = self.conv1(x)
r2 = self.conv2(y)
return (r1, r2)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = Conv2DSameNet()
out = net(t1, t2)
class ComparisonNet(nn.Cell): class ComparisonNet(nn.Cell):
def __init__(self): def __init__(self):
""" ComparisonNet definition """ """ ComparisonNet definition """

View File

@ -0,0 +1,276 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.primitive import constexpr
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class FakeOp(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
""""""
def infer_shape(self, x, y):
self.second_shape = y
self.add_prim_attr("second_shape", y)
return x
def infer_dtype(self, x, y):
return x
# test the normal case that should generate independent primitive because of different
# generated attributes after inference
def test_conv2d_same_primitive():
class Conv2DSameNet(nn.Cell):
def __init__(self):
super(Conv2DSameNet, self).__init__()
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
def construct(self, x, y):
r1 = self.conv1(x)
r2 = self.conv2(y)
return (r1, r2)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = Conv2DSameNet()
out = net(t1, t2)
# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_conv2d_op_with_arg():
class Conv2dNet(nn.Cell):
def __init__(self):
super(Conv2dNet, self).__init__()
def construct(self, op, x):
return op(x)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
def construct(self, x, y):
conv_op = self.conv2
a = self.opnet(conv_op, x)
b = self.opnet(conv_op, y)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = OpsNet(Conv2dNet())
out = net(t1, t2)
def test_conv2d_op_with_arg():
class FackOpNet(nn.Cell):
def __init__(self):
super(FackOpNet, self).__init__()
self.op = FakeOp()
def construct(self, x, y):
return self.op(x, y)
class OpNet(nn.Cell):
def __init__(self):
super(OpNet, self).__init__()
def construct(self, op, x, y):
return op(x, y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.op = FackOpNet()
def construct(self, x, y):
op = self.op
a = self.opnet(op, x, y)
b = self.opnet(op, y, x)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = OpsNet(OpNet())
out = net(t1, t2)
def test_conv2d_op_with_arg_same_input():
class FackOpNet(nn.Cell):
def __init__(self):
super(FackOpNet, self).__init__()
self.op = FakeOp()
def construct(self, x, y):
return self.op(x, y)
class OpNet(nn.Cell):
def __init__(self):
super(OpNet, self).__init__()
def construct(self, op, x, y):
return op(x, y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.op = FackOpNet()
def construct(self, x, y):
op = self.op
a = self.opnet(op, x, x)
b = self.opnet(op, y, x)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = OpsNet(OpNet())
out = net(t1, t2)
# test op with partial
def test_op_as_partial():
class OpAsPartial(nn.Cell):
def __init__(self):
super(OpAsPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op = F.partial(self.op, x)
a = partial_op(y)
b = partial_op(z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpAsPartial()
out = net(t1, t2, t3)
# test op with partial
def test_op_as_partial_inside():
class OpAsPartial(nn.Cell):
def __init__(self):
super(OpAsPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op = F.partial(self.op, x)
a = partial_op(y)
b = partial_op(z)
return a, b
class OuterNet(nn.Cell):
def __init__(self):
super(OuterNet, self).__init__()
self.net = OpAsPartial()
def construct(self, x, y, z):
a,b = self.net(x, y, z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OuterNet()
out = net(t1, t2, t3)
# test op with partial case 2
def test_op_as_partial_independent():
class OpAsPartial(nn.Cell):
def __init__(self):
super(OpAsPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op1 = F.partial(self.op, x)
a = partial_op1(y)
partial_op2 = F.partial(self.op, x)
b = partial_op2(z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpAsPartial()
out = net(t1, t2, t3)
def test_nest_partial():
class NestPartial(nn.Cell):
def __init__(self):
super(NestPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op1 = F.partial(self.op)
partial_op2 = F.partial(partial_op1, x)
a = partial_op2(y)
partial_op3 = F.partial(self.op)
partial_op4 = F.partial(partial_op3, x)
b = partial_op4(z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = NestPartial()
out = net(t1, t2, t3)
# high order argument
# op and op args as network arguments
def test_op_with_arg_as_input():
class WithOpArgNet(nn.Cell):
def __init__(self):
super(WithOpArgNet, self).__init__()
def construct(self, op, x, y):
return op(x, y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.op = FakeOp()
def construct(self, x, y, z):
op = self.op
a = self.opnet(op, x, z)
b = self.opnet(op, x, y)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpsNet(WithOpArgNet())
out = net(t1, t2, t3)
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_partial_as_arg():
class PartialArgNet(nn.Cell):
def __init__(self):
super(PartialArgNet, self).__init__()
def construct(self, partial_op, y):
return partial_op(y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.partial_net = net
self.op = FakeOp()
def construct(self, x, y, z):
partial_op = F.partial(self.op, x)
a = self.partial_net(partial_op, z)
b = self.partial_net(partial_op, y)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpsNet(PartialArgNet())
out = net(t1, t2, t3)