forked from mindspore-Ecosystem/mindspore
!1383 keep different attributes for cnode evaluation
Merge pull request !1383 from amongo/KeepPrimAttrInCNode
This commit is contained in:
commit
5b9c145ff8
|
@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
|
||||||
auto ctx = node_cfg_->context();
|
auto ctx = node_cfg_->context();
|
||||||
auto engine = node_cfg_->engine();
|
auto engine = node_cfg_->engine();
|
||||||
auto cfg = engine->MakeConfig(node, ctx);
|
auto cfg = engine->MakeConfig(node, ctx);
|
||||||
auto abs = engine->cache().GetValue(cfg);
|
auto eval_result = engine->cache().GetValue(cfg);
|
||||||
if (abs == nullptr) {
|
if (eval_result == nullptr || eval_result->abstract() == nullptr) {
|
||||||
return "Undefined";
|
return "Undefined";
|
||||||
}
|
}
|
||||||
|
auto abs = eval_result->abstract();
|
||||||
auto dtype = abs->BuildType();
|
auto dtype = abs->BuildType();
|
||||||
auto shape = abs->BuildShape();
|
auto shape = abs->BuildShape();
|
||||||
std::ostringstream oss;
|
std::ostringstream oss;
|
||||||
|
|
|
@ -42,7 +42,11 @@ enum PrimType {
|
||||||
class Primitive : public Named {
|
class Primitive : public Named {
|
||||||
public:
|
public:
|
||||||
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
|
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
|
||||||
: Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {}
|
: Named(name),
|
||||||
|
is_base_(is_base),
|
||||||
|
has_signature_(false),
|
||||||
|
prim_type_(prim_type),
|
||||||
|
record_evaluate_add_attr_(false) {}
|
||||||
|
|
||||||
Primitive(const Primitive &prim)
|
Primitive(const Primitive &prim)
|
||||||
: Named(prim),
|
: Named(prim),
|
||||||
|
@ -50,14 +54,23 @@ class Primitive : public Named {
|
||||||
instance_name_(prim.instance_name_),
|
instance_name_(prim.instance_name_),
|
||||||
is_base_(prim.is_base_),
|
is_base_(prim.is_base_),
|
||||||
has_signature_(prim.has_signature_),
|
has_signature_(prim.has_signature_),
|
||||||
prim_type_(prim.prim_type_) {}
|
prim_type_(prim.prim_type_),
|
||||||
|
record_evaluate_add_attr_(false) {}
|
||||||
|
|
||||||
MS_DECLARE_PARENT(Primitive, Named);
|
MS_DECLARE_PARENT(Primitive, Named);
|
||||||
|
|
||||||
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
|
||||||
std::string ToString() const override { return name(); }
|
std::string ToString() const override { return name(); }
|
||||||
|
void BeginRecordAddAttr() {
|
||||||
|
evaluate_added_attrs_.clear();
|
||||||
|
record_evaluate_add_attr_ = true;
|
||||||
|
}
|
||||||
|
void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
|
||||||
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
|
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
|
||||||
attrs_[name] = attr;
|
attrs_[name] = attr;
|
||||||
|
if (record_evaluate_add_attr_) {
|
||||||
|
evaluate_added_attrs_[name] = attr;
|
||||||
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,6 +93,7 @@ class Primitive : public Named {
|
||||||
py::function hook() const { return hook_; }
|
py::function hook() const { return hook_; }
|
||||||
|
|
||||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||||
|
std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; }
|
||||||
|
|
||||||
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
|
||||||
bool HasAttr() const { return !attrs_.empty(); }
|
bool HasAttr() const { return !attrs_.empty(); }
|
||||||
|
@ -106,6 +120,7 @@ class Primitive : public Named {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||||
|
std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string instance_name_;
|
std::string instance_name_;
|
||||||
|
@ -113,6 +128,7 @@ class Primitive : public Named {
|
||||||
bool is_base_;
|
bool is_base_;
|
||||||
bool has_signature_;
|
bool has_signature_;
|
||||||
PrimType prim_type_;
|
PrimType prim_type_;
|
||||||
|
bool record_evaluate_add_attr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
|
||||||
|
|
|
@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv
|
||||||
}
|
}
|
||||||
subargs.push_back(AbstractJoin(l_ptr->elements()));
|
subargs.push_back(AbstractJoin(l_ptr->elements()));
|
||||||
}
|
}
|
||||||
AbstractBasePtr engin_exc = engine->Execute(fn, subargs);
|
EvalResultPtr engin_exc = engine->Execute(fn, subargs);
|
||||||
AbstractBasePtrList result;
|
AbstractBasePtrList result;
|
||||||
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
|
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
|
||||||
result.push_back(engin_exc);
|
result.push_back(engin_exc->abstract());
|
||||||
}
|
}
|
||||||
return std::make_shared<AbstractList>(result);
|
return std::make_shared<AbstractList>(result);
|
||||||
}
|
}
|
||||||
|
@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
|
||||||
AbstractBasePtr list_type = AbstractJoin(lst->elements());
|
AbstractBasePtr list_type = AbstractJoin(lst->elements());
|
||||||
auto result1 = engine->Execute(fn, lst->elements());
|
auto result1 = engine->Execute(fn, lst->elements());
|
||||||
auto result2 = engine->Execute(fn, {dflt, list_type});
|
auto result2 = engine->Execute(fn, {dflt, list_type});
|
||||||
MS_EXCEPTION_IF_NULL(result1);
|
MS_EXCEPTION_IF_NULL(result1->abstract());
|
||||||
return result1->Join(result2);
|
MS_EXCEPTION_IF_NULL(result2->abstract());
|
||||||
|
return result1->abstract()->Join(result2->abstract());
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
|
|
@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
|
||||||
return sorted_nodes;
|
return sorted_nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
|
||||||
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
|
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
std::size_t nargs = fg->parameters().size();
|
std::size_t nargs = fg->parameters().size();
|
||||||
|
@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
|
||||||
const auto &arg = args_spec_list[i];
|
const auto &arg = args_spec_list[i];
|
||||||
const auto &node = parameters[i];
|
const auto &node = parameters[i];
|
||||||
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
|
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
|
||||||
engine->cache().set_value(conf, arg);
|
engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
|
||||||
}
|
}
|
||||||
const AnfNodePtr &func_node = fg->get_return();
|
const AnfNodePtr &func_node = fg->get_return();
|
||||||
|
|
||||||
|
@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
|
||||||
const auto &node = *it;
|
const auto &node = *it;
|
||||||
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
|
||||||
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
|
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
|
||||||
ret_base = engine->GetEvaluatedValue(node_conf);
|
ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
|
||||||
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
|
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
|
||||||
<< ", abstract: " << ret_base->ToString();
|
<< ", abstract: " << ret_base->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(ret_base);
|
MS_EXCEPTION_IF_NULL(ret_base);
|
||||||
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString();
|
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
|
||||||
return ret_base;
|
return std::make_shared<EvalResult>(ret_base, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
||||||
|
@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
|
||||||
return cloned_func_graph;
|
return cloned_func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
|
||||||
AnfNodeConfigPtr out_conf) {
|
|
||||||
const std::string &evaluator_name = ToString();
|
const std::string &evaluator_name = ToString();
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
args_spec_list = NormalizeArgs(args_spec_list);
|
args_spec_list = NormalizeArgs(args_spec_list);
|
||||||
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
||||||
|
@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
|
||||||
auto iter = cache_->find(args_spec_list);
|
auto iter = cache_->find(args_spec_list);
|
||||||
if (iter == cache_->end()) {
|
if (iter == cache_->end()) {
|
||||||
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
|
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
|
||||||
AbstractBasePtr ret = Eval(engine, args_spec_list);
|
EvalResultPtr ret = Eval(engine, args_spec_list);
|
||||||
if (ret == nullptr) {
|
if (ret->abstract() == nullptr) {
|
||||||
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
|
||||||
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(ret);
|
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
|
||||||
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
|
|
||||||
(*cache_)[args_spec_list] = ret;
|
(*cache_)[args_spec_list] = ret;
|
||||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||||
return ret;
|
return ret;
|
||||||
} else {
|
} else {
|
||||||
MS_EXCEPTION_IF_NULL(iter->second);
|
MS_EXCEPTION_IF_NULL(iter->second);
|
||||||
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
|
MS_EXCEPTION_IF_NULL(iter->second->abstract());
|
||||||
|
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << ".";
|
||||||
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr) {
|
AnfNodeConfigPtr) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
|
EvalResultPtr ret = EvalPrim(engine, args_spec_list);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf) {
|
AnfNodeConfigPtr out_conf) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
if (args_conf_list.size() == 0) {
|
if (args_conf_list.size() == 0) {
|
||||||
MS_LOG(EXCEPTION) << "Size should greater than 0";
|
MS_LOG(EXCEPTION) << "Size should greater than 0";
|
||||||
}
|
}
|
||||||
AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
|
EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
|
||||||
// No need to cache.
|
// No need to cache.
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
|
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
|
||||||
AbstractBasePtr ret = EvalPrim(args_conf_list);
|
EvalResultPtr ret = EvalPrim(args_conf_list);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf) {
|
AnfNodeConfigPtr out_conf) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
|
EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
|
||||||
// Don't lookup from cache, as different out_conf with same node but different context
|
// Don't lookup from cache, as different out_conf with same node but different context
|
||||||
// may add different entry to anfnode_config_map_, like getattr primitive.
|
// may add different entry to anfnode_config_map_, like getattr primitive.
|
||||||
(*cache_)[args_spec_list] = ret;
|
(*cache_)[args_spec_list] = ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf) {
|
AnfNodeConfigPtr out_conf) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
MS_EXCEPTION_IF_NULL(cache_);
|
MS_EXCEPTION_IF_NULL(cache_);
|
||||||
auto iter = cache_->find(args_spec_list);
|
auto iter = cache_->find(args_spec_list);
|
||||||
|
@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP
|
||||||
|
|
||||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
|
||||||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||||
AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
|
EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
|
||||||
|
|
||||||
(*cache_)[args_spec_list] = ret;
|
(*cache_)[args_spec_list] = ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
|
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
MS_EXCEPTION_IF_NULL(cache_);
|
MS_EXCEPTION_IF_NULL(cache_);
|
||||||
auto iter = cache_->find(args_spec_list);
|
auto iter = cache_->find(args_spec_list);
|
||||||
|
@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the original evaluator, get the result: y = f(x)
|
// Call the original evaluator, get the result: y = f(x)
|
||||||
AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr);
|
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
|
||||||
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
|
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
|
||||||
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
|
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
|
||||||
AbstractBasePtrList bparams;
|
AbstractBasePtrList bparams;
|
||||||
|
@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
|
||||||
args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
|
args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
|
||||||
[](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
|
[](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
|
||||||
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
|
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
|
||||||
AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final);
|
AbstractFunctionPtr bprop =
|
||||||
|
std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
|
||||||
|
|
||||||
// J(f)(J(x)) return a tuple (y, bprop_f)
|
// J(f)(J(x)) return a tuple (y, bprop_f)
|
||||||
AbstractBasePtrList jargs = {result, bprop};
|
AbstractBasePtrList jargs = {result->abstract(), bprop};
|
||||||
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
|
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
|
||||||
(*cache_)[args_spec_list] = jtuple;
|
auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
|
||||||
return jtuple;
|
(*cache_)[args_spec_list] = infer_reuslt;
|
||||||
|
return infer_reuslt;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
|
EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
|
||||||
if (args_spec_list.size() != args_spec_list_.size()) {
|
if (args_spec_list.size() != args_spec_list_.size()) {
|
||||||
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
|
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
|
||||||
<< ", arguments no: " << args_spec_list.size();
|
<< ", arguments no: " << args_spec_list.size();
|
||||||
|
@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
|
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
|
||||||
(void)args_spec_list[i]->Join(args_spec_list_[i]);
|
(void)args_spec_list[i]->Join(args_spec_list_[i]);
|
||||||
}
|
}
|
||||||
return output_;
|
return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -29,21 +29,28 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
using EvaluatorCacheMap =
|
using EvaluatorCacheMap =
|
||||||
std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||||
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
|
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
|
||||||
|
|
||||||
|
using EvaluatorAttrMap =
|
||||||
|
std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
|
||||||
|
using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;
|
||||||
|
|
||||||
class Evaluator : public Base {
|
class Evaluator : public Base {
|
||||||
public:
|
public:
|
||||||
explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {}
|
explicit Evaluator(const std::string &id)
|
||||||
|
: cache_(std::make_shared<EvaluatorCacheMap>()),
|
||||||
|
attr_cache_(std::make_shared<EvaluatorAttrMap>()),
|
||||||
|
identifier_(id) {}
|
||||||
~Evaluator() override = default;
|
~Evaluator() override = default;
|
||||||
MS_DECLARE_PARENT(Evaluator, Base);
|
MS_DECLARE_PARENT(Evaluator, Base);
|
||||||
|
|
||||||
// difference between Run() and Eval():
|
// difference between Run() and Eval():
|
||||||
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
|
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
|
||||||
// Run() will modify cache_ member, so it cannot marked as const;
|
// Run() will modify cache_ member, so it cannot marked as const;
|
||||||
virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
|
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
|
||||||
|
|
||||||
virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
|
virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
|
||||||
|
|
||||||
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
|
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
|
||||||
|
|
||||||
|
@ -58,9 +65,10 @@ class Evaluator : public Base {
|
||||||
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
|
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
|
||||||
|
|
||||||
EvaluatorCacheMapPtr &cache() { return cache_; }
|
EvaluatorCacheMapPtr &cache() { return cache_; }
|
||||||
|
EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }
|
||||||
|
|
||||||
EvaluatorCacheMapPtr cache_;
|
EvaluatorCacheMapPtr cache_;
|
||||||
|
EvaluatorAttrMapPtr attr_cache_;
|
||||||
std::string identifier_;
|
std::string identifier_;
|
||||||
|
|
||||||
AnfNodeWeakPtr bound_node_;
|
AnfNodeWeakPtr bound_node_;
|
||||||
|
@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator {
|
||||||
explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
|
explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
|
||||||
~PrimEvaluator() override = default;
|
~PrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
|
MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
|
||||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
|
||||||
explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||||
~TrivialPrimEvaluator() override = default;
|
~TrivialPrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
|
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
||||||
virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
|
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TransitionPrimEvaluator : public PrimEvaluator {
|
class TransitionPrimEvaluator : public PrimEvaluator {
|
||||||
|
@ -90,9 +98,9 @@ class TransitionPrimEvaluator : public PrimEvaluator {
|
||||||
explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||||
~TransitionPrimEvaluator() override = default;
|
~TransitionPrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
|
MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
||||||
// Parameter in_conf0 : the first element in args_conf_list;
|
// Parameter in_conf0 : the first element in args_conf_list;
|
||||||
virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
|
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
|
||||||
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
|
||||||
~SymbolicPrimEvaluator() override = default;
|
~SymbolicPrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
|
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
|
||||||
virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
|
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Evaluator will be stored in AnalysisEngine.constructors_
|
// Evaluator will be stored in AnalysisEngine.constructors_
|
||||||
|
@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator {
|
||||||
DummyEvaluator() : Evaluator("dummy") {}
|
DummyEvaluator() : Evaluator("dummy") {}
|
||||||
~DummyEvaluator() override = default;
|
~DummyEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
|
MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wrap another evaluator to track a subset of uses.
|
// Wrap another evaluator to track a subset of uses.
|
||||||
|
@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator {
|
||||||
bound_node_ = AnfNodeWeakPtr(node);
|
bound_node_ = AnfNodeWeakPtr(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
|
||||||
AnfNodeConfigPtr out_conf) override;
|
|
||||||
std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
|
std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
||||||
~BaseFuncGraphEvaluator() override = default;
|
~BaseFuncGraphEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
|
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
|
||||||
|
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
|
EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
|
||||||
|
|
||||||
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
|
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
|
||||||
|
|
||||||
|
@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator {
|
||||||
}
|
}
|
||||||
bound_node_ = AnfNodeWeakPtr(node);
|
bound_node_ = AnfNodeWeakPtr(node);
|
||||||
}
|
}
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
|
||||||
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
|
||||||
AnfNodeConfigPtr out_conf) override;
|
|
||||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator {
|
||||||
~VirtualEvaluator() override = default;
|
~VirtualEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
|
MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
|
||||||
|
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
|
EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
|
||||||
std::string ToString() const override { return identifier_; }
|
std::string ToString() const override { return identifier_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -285,11 +292,11 @@ class JEvaluator : public Evaluator {
|
||||||
}
|
}
|
||||||
bound_node_ = AnfNodeWeakPtr(node);
|
bound_node_ = AnfNodeWeakPtr(node);
|
||||||
}
|
}
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
|
||||||
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
|
||||||
AnfNodeConfigPtr out_conf) override;
|
|
||||||
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -135,12 +135,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
|
|
||||||
using mindspore::parse::PyObjectWrapper;
|
using mindspore::parse::PyObjectWrapper;
|
||||||
|
|
||||||
AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
|
||||||
|
prim_->BeginRecordAddAttr();
|
||||||
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
|
||||||
return abs_base;
|
prim_->EndRecordAddAttr();
|
||||||
|
auto added_attrs = prim_->evaluate_added_attrs();
|
||||||
|
auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
|
||||||
|
return infer_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf) {
|
AnfNodeConfigPtr out_conf) {
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
if (!prim_->isa<prim::DoSignaturePrimitive>()) {
|
if (!prim_->isa<prim::DoSignaturePrimitive>()) {
|
||||||
|
@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
|
||||||
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
||||||
|
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
|
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
|
||||||
|
|
||||||
ScopePtr scope = kDefaultScope;
|
ScopePtr scope = kDefaultScope;
|
||||||
if (out_conf != nullptr) {
|
if (out_conf != nullptr) {
|
||||||
|
@ -212,7 +216,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
|
||||||
return graph_specialize_args;
|
return graph_specialize_args;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf) {
|
AnfNodeConfigPtr out_conf) {
|
||||||
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
|
||||||
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
|
||||||
|
@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config
|
||||||
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
|
||||||
AbstractBasePtrList args_spec_list;
|
AbstractBasePtrList args_spec_list;
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
|
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
|
||||||
// get the forward graph
|
// get the forward graph
|
||||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||||
AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
|
AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
|
||||||
|
@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
||||||
}
|
}
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||||
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
|
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
|
||||||
|
|
||||||
const auto &iter = cache_->find(args);
|
const auto &iter = cache_->find(args);
|
||||||
|
@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
|
||||||
MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
|
MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
|
||||||
}
|
}
|
||||||
auto infer_fuc = pyobj.attr("__infer__");
|
auto infer_fuc = pyobj.attr("__infer__");
|
||||||
|
prim_py_->BeginRecordAddAttr();
|
||||||
py::dict output = infer_fuc(*py_args);
|
py::dict output = infer_fuc(*py_args);
|
||||||
|
prim_py_->EndRecordAddAttr();
|
||||||
|
auto added_attrs = prim_py_->evaluate_added_attrs();
|
||||||
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
|
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
|
||||||
auto res_spec = PyInferRes2Abstract(prim_py_, output);
|
auto res_spec = PyInferRes2Abstract(prim_py_, output);
|
||||||
|
|
||||||
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
|
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
|
||||||
(*cache_)[args] = res_spec;
|
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
|
||||||
return res_spec;
|
(*cache_)[args] = infer_result;
|
||||||
|
return infer_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||||
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
|
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
|
||||||
if (nargs_ != args.size()) {
|
if (nargs_ != args.size()) {
|
||||||
MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
|
MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
|
||||||
|
@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
|
AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
|
||||||
return abs_base;
|
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
|
ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
|
||||||
|
@ -553,7 +560,7 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
|
||||||
manager->AddFuncGraph(func_graph);
|
manager->AddFuncGraph(func_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
|
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
|
||||||
const AnfNodeConfigPtr &old_conf) {
|
const AnfNodeConfigPtr &old_conf) {
|
||||||
MS_EXCEPTION_IF_NULL(old_conf);
|
MS_EXCEPTION_IF_NULL(old_conf);
|
||||||
|
|
||||||
|
@ -585,7 +592,7 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
|
||||||
return eng->ForwardConfig(old_conf, fn_conf);
|
return eng->ForwardConfig(old_conf, fn_conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
|
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
|
||||||
const AbstractBasePtrList &args_spec_list,
|
const AbstractBasePtrList &args_spec_list,
|
||||||
const AnfNodeConfigPtr &out_conf) {
|
const AnfNodeConfigPtr &out_conf) {
|
||||||
// args_spec_list: same as StaticGetter
|
// args_spec_list: same as StaticGetter
|
||||||
|
@ -627,7 +634,7 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
|
||||||
return eng->ForwardConfig(out_conf, fn_conf);
|
return eng->ForwardConfig(out_conf, fn_conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
|
EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||||
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
|
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
|
||||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||||
if (args_spec_list.empty()) {
|
if (args_spec_list.empty()) {
|
||||||
|
@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
|
||||||
|
|
||||||
AbstractBasePtr attr = cls->GetAttribute(item_name);
|
AbstractBasePtr attr = cls->GetAttribute(item_name);
|
||||||
if (attr != nullptr) {
|
if (attr != nullptr) {
|
||||||
return attr;
|
return std::make_shared<EvalResult>(attr, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
ValuePtr method = cls->GetMethod(item_name);
|
ValuePtr method = cls->GetMethod(item_name);
|
||||||
|
@ -660,7 +667,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
|
||||||
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
|
EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
|
||||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||||
const AnfNodeConfigPtr &out_conf) {
|
const AnfNodeConfigPtr &out_conf) {
|
||||||
MS_EXCEPTION_IF_NULL(item_v);
|
MS_EXCEPTION_IF_NULL(item_v);
|
||||||
|
@ -689,7 +696,7 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e
|
||||||
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||||
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
|
||||||
// Inputs: namespace and its static function; or class and its member function
|
// Inputs: namespace and its static function; or class and its member function
|
||||||
CheckArgsSize("StaticGetter", args_spec_list, 2);
|
CheckArgsSize("StaticGetter", args_spec_list, 2);
|
||||||
|
@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
|
EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
|
||||||
~EmbedEvaluator() override = default;
|
~EmbedEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
|
MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
|
EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
|
||||||
// arg: free variable to be embedded
|
// arg: free variable to be embedded
|
||||||
if (args_conf_list.size() != 1) {
|
if (args_conf_list.size() != 1) {
|
||||||
MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
|
MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
|
||||||
|
@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
|
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
|
||||||
MS_EXCEPTION_IF_NULL(node_conf);
|
MS_EXCEPTION_IF_NULL(node_conf);
|
||||||
|
|
||||||
AbstractBasePtr x = node_conf->GetEvaluatedValue();
|
AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
|
||||||
x = SensitivityTransform(x);
|
x = SensitivityTransform(x);
|
||||||
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
|
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
|
||||||
AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
|
AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
|
||||||
return abs_scalar;
|
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
|
RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
|
||||||
~RefToEmbedEvaluator() override = default;
|
~RefToEmbedEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
|
MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
|
EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
|
||||||
if (args_conf_list.size() != 1) {
|
if (args_conf_list.size() != 1) {
|
||||||
MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
|
MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
|
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
AbstractBasePtr abs = node_conf->GetEvaluatedValue();
|
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
|
||||||
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
|
||||||
if (ref_abs == nullptr) {
|
if (ref_abs == nullptr) {
|
||||||
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
|
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
|
||||||
|
@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
}
|
}
|
||||||
auto refkey = key_value->cast<RefKeyPtr>();
|
auto refkey = key_value->cast<RefKeyPtr>();
|
||||||
if (refkey == nullptr) {
|
if (refkey == nullptr) {
|
||||||
return std::make_shared<AbstractScalar>(type);
|
return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string name = refkey->tag();
|
std::string name = refkey->tag();
|
||||||
|
@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
|
||||||
x = SensitivityTransform(x);
|
x = SensitivityTransform(x);
|
||||||
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
|
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
|
||||||
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
|
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
|
||||||
return abs_scalar;
|
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
||||||
GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
|
GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
|
||||||
~GetAttrEvaluator() override = default;
|
~GetAttrEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
|
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||||
// Inputs: data, item
|
// Inputs: data, item
|
||||||
if (args_spec_list.size() != 2) {
|
if (args_spec_list.size() != 2) {
|
||||||
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
||||||
}
|
}
|
||||||
AbstractBasePtr ret = nullptr;
|
EvalResultPtr ret = nullptr;
|
||||||
if (bound_node() != nullptr) {
|
if (bound_node() != nullptr) {
|
||||||
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
|
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
|
||||||
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
||||||
|
@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
|
||||||
ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
|
ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
|
||||||
~ResolveEvaluator() override = default;
|
~ResolveEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
|
MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||||
// Inputs: namespace, symbol
|
// Inputs: namespace, symbol
|
||||||
if (args_spec_list.size() != 2) {
|
if (args_spec_list.size() != 2) {
|
||||||
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
||||||
}
|
}
|
||||||
AbstractBasePtr ret = nullptr;
|
EvalResultPtr ret = nullptr;
|
||||||
if (bound_node() != nullptr) {
|
if (bound_node() != nullptr) {
|
||||||
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
|
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
|
||||||
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
||||||
|
@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
||||||
CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
|
CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
|
||||||
~CreateInstanceEvaluator() override = default;
|
~CreateInstanceEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
|
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||||
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override {
|
const AnfNodeConfigPtr &out_conf) override {
|
||||||
if (args_spec_list.empty()) {
|
if (args_spec_list.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
||||||
}
|
}
|
||||||
|
@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
|
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
|
||||||
(*cache_)[args_spec_list] = ret;
|
auto infer_result = std::make_shared<EvalResult>(ret, nullptr);
|
||||||
return ret;
|
(*cache_)[args_spec_list] = infer_result;
|
||||||
|
return infer_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
|
pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
|
||||||
|
@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator {
|
||||||
public:
|
public:
|
||||||
PartialEvaluator() : Evaluator("PartialEvaluator") {}
|
PartialEvaluator() : Evaluator("PartialEvaluator") {}
|
||||||
~PartialEvaluator() override = default;
|
~PartialEvaluator() override = default;
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||||
AnfNodeConfigPtr out_conf = nullptr) override {
|
AnfNodeConfigPtr out_conf = nullptr) override {
|
||||||
if (args_conf_list.size() == 0) {
|
if (args_conf_list.size() == 0) {
|
||||||
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
|
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(out_conf);
|
MS_EXCEPTION_IF_NULL(out_conf);
|
||||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||||
|
auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract();
|
||||||
auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
|
|
||||||
AbstractBasePtrList args_spec_list{arg0_value};
|
AbstractBasePtrList args_spec_list{arg0_value};
|
||||||
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
|
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
|
||||||
if (arg0_value->isa<AbstractError>()) {
|
if (arg0_value->isa<AbstractError>()) {
|
||||||
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
|
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
|
||||||
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
|
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
|
||||||
<< " as func is: " << arg0_value->ToString();
|
<< " as func is: " << arg0_value->ToString();
|
||||||
(*cache_)[args_spec_list] = ret;
|
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||||
return ret;
|
(*cache_)[args_spec_list] = eval_result;
|
||||||
|
return eval_result;
|
||||||
}
|
}
|
||||||
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
|
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
|
||||||
// Sometimes, node[0] in out_conf becomes phi0;
|
// Sometimes, node[0] in out_conf becomes phi0;
|
||||||
|
@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(
|
||||||
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); });
|
args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
|
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); });
|
||||||
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
|
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
|
||||||
|
|
||||||
auto cnode = out_conf->node()->cast<CNodePtr>();
|
auto cnode = out_conf->node()->cast<CNodePtr>();
|
||||||
|
@ -989,15 +999,16 @@ class PartialEvaluator : public Evaluator {
|
||||||
func->Visit(build_partial);
|
func->Visit(build_partial);
|
||||||
|
|
||||||
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
|
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
|
||||||
(*cache_)[args_spec_list] = ret;
|
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||||
return ret;
|
(*cache_)[args_spec_list] = infer_result;
|
||||||
|
return infer_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
|
EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
|
||||||
const AnfNodeConfigPtr &out_conf = nullptr) const {
|
const AnfNodeConfigPtr &out_conf = nullptr) const {
|
||||||
MS_EXCEPTION_IF_NULL(out_conf);
|
MS_EXCEPTION_IF_NULL(out_conf);
|
||||||
MS_EXCEPTION_IF_NULL(out_conf->node());
|
MS_EXCEPTION_IF_NULL(out_conf->node());
|
||||||
|
|
|
@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
|
||||||
: TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
|
: TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
|
||||||
~StandardPrimEvaluator() override = default;
|
~StandardPrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
|
MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
|
||||||
PrimitivePtr prim() { return prim_; }
|
PrimitivePtr prim() { return prim_; }
|
||||||
|
|
||||||
std::string ToString() const override { return identifier_ + prim_->name(); }
|
std::string ToString() const override { return identifier_ + prim_->name(); }
|
||||||
|
@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator {
|
||||||
: TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
|
: TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
|
||||||
~PythonPrimEvaluator() override = default;
|
~PythonPrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
|
MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
|
||||||
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
|
||||||
PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
|
PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
|
||||||
|
|
||||||
std::string ToString() const override { return identifier_ + prim_py_->name(); }
|
std::string ToString() const override { return identifier_ + prim_py_->name(); }
|
||||||
|
@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator {
|
||||||
public:
|
public:
|
||||||
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
|
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
|
||||||
~DoSignatureEvaluator() override = default;
|
~DoSignatureEvaluator() override = default;
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
||||||
AnfNodeConfigPtr out_config = nullptr) override;
|
AnfNodeConfigPtr out_config = nullptr) override;
|
||||||
|
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator {
|
||||||
public:
|
public:
|
||||||
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
|
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
|
||||||
~UnpackGraphEvaluator() override = default;
|
~UnpackGraphEvaluator() override = default;
|
||||||
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
|
||||||
AnfNodeConfigPtr out_config = nullptr) override;
|
AnfNodeConfigPtr out_config = nullptr) override;
|
||||||
|
|
||||||
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
|
||||||
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator {
|
||||||
~UniformPrimEvaluator() override = default;
|
~UniformPrimEvaluator() override = default;
|
||||||
MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
|
MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
|
||||||
|
|
||||||
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
|
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
|
||||||
ValuePtr RunImpl(const ValuePtrList &args) const;
|
ValuePtr RunImpl(const ValuePtrList &args) const;
|
||||||
|
|
||||||
// If eval_value_ is False, return broadened arguments.
|
// If eval_value_ is False, return broadened arguments.
|
||||||
|
|
|
@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
|
||||||
if (conf->node()->intermediate_abstract()) {
|
if (conf->node()->intermediate_abstract()) {
|
||||||
return conf->node()->intermediate_abstract();
|
return conf->node()->intermediate_abstract();
|
||||||
}
|
}
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
|
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
|
||||||
|
@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() {
|
||||||
|
|
||||||
// Specialize CNode in func graphs
|
// Specialize CNode in func graphs
|
||||||
void FuncGraphSpecializer::SecondPass() {
|
void FuncGraphSpecializer::SecondPass() {
|
||||||
for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
|
for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) {
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
ProcessCNode(node->cast<CNodePtr>());
|
ProcessCNode(node->cast<CNodePtr>());
|
||||||
}
|
}
|
||||||
|
@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
||||||
AnfNodeConfigPtr conf = MakeConfig(node);
|
AnfNodeConfigPtr conf = MakeConfig(node);
|
||||||
AnfNodePtr new_node = GetReplicatedNode(node);
|
AnfNodePtr new_node = GetReplicatedNode(node);
|
||||||
MS_EXCEPTION_IF_NULL(new_node);
|
MS_EXCEPTION_IF_NULL(new_node);
|
||||||
|
|
||||||
if (new_node->func_graph() != specialized_func_graph_) {
|
if (new_node->func_graph() != specialized_func_graph_) {
|
||||||
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
|
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
|
||||||
<< ", new_node: " << new_node->DebugString()
|
<< ", new_node: " << new_node->DebugString()
|
||||||
|
@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
||||||
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
|
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
|
||||||
|
|
||||||
if (node->isa<CNode>()) {
|
if (node->isa<CNode>()) {
|
||||||
|
auto attrs = conf->GetEvaluatedValue()->attribute();
|
||||||
auto c_old = node->cast<CNodePtr>();
|
auto c_old = node->cast<CNodePtr>();
|
||||||
auto c_new = new_node->cast<CNodePtr>();
|
auto c_new = new_node->cast<CNodePtr>();
|
||||||
auto new_inputs = c_new->inputs();
|
auto new_inputs = c_new->inputs();
|
||||||
|
@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
||||||
AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
|
AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
|
||||||
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
|
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
|
||||||
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
|
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
|
||||||
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
|
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
|
||||||
if (replace_node == nullptr) {
|
if (replace_node == nullptr) {
|
||||||
replace_node = BuildReplacedNode(iconf);
|
replace_node = BuildReplacedNode(iconf);
|
||||||
MS_EXCEPTION_IF_NULL(replace_node);
|
MS_EXCEPTION_IF_NULL(replace_node);
|
||||||
|
@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
|
||||||
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
|
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
|
||||||
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
|
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
|
||||||
}
|
}
|
||||||
|
auto attrs = std::make_shared<AttrValueMap>();
|
||||||
for (size_t i = 0; i < partial_closure->args().size(); i++) {
|
for (size_t i = 0; i < partial_closure->args().size(); i++) {
|
||||||
auto old_node = cnode->input(i + 2);
|
auto old_node = cnode->input(i + 2);
|
||||||
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
|
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
|
||||||
if (possibile_value_node != nullptr) {
|
if (possibile_value_node != nullptr) {
|
||||||
partial_node_list.push_back(possibile_value_node);
|
partial_node_list.push_back(possibile_value_node);
|
||||||
} else {
|
} else {
|
||||||
|
@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
||||||
const EvaluatorPtr &eval) {
|
const EvaluatorPtr &eval) {
|
||||||
MS_EXCEPTION_IF_NULL(eval);
|
MS_EXCEPTION_IF_NULL(eval);
|
||||||
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
|
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
|
||||||
AbstractBasePtr ret = nullptr;
|
EvalResultPtr ret = nullptr;
|
||||||
AbstractBasePtrList broaded_argvals;
|
AbstractBasePtrList broaded_argvals;
|
||||||
for (auto &argvals_map : *evalcaches_[eval]) {
|
for (auto &argvals_map : *evalcaches_[eval]) {
|
||||||
auto argvals = argvals_map.first;
|
auto argvals = argvals_map.first;
|
||||||
|
@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
||||||
|
|
||||||
(*real)[broaded_argvals] = ret;
|
(*real)[broaded_argvals] = ret;
|
||||||
evalcaches_[eval] = real;
|
evalcaches_[eval] = real;
|
||||||
return std::make_pair(broaded_argvals, ret);
|
return std::make_pair(broaded_argvals, ret->abstract());
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
|
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
|
||||||
return std::make_pair(AbstractBasePtrList(), nullptr);
|
return std::make_pair(AbstractBasePtrList(), nullptr);
|
||||||
|
@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
specializer_->AddSeen(new_node);
|
specializer_->AddSeen(new_node);
|
||||||
|
|
||||||
auto new_inputs = new_node->inputs();
|
auto new_inputs = new_node->inputs();
|
||||||
if (new_inputs.empty()) {
|
if (new_inputs.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
|
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
|
||||||
|
@ -530,8 +530,14 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (CanSpecializeNode(func)) {
|
if (CanSpecializeNode(func)) {
|
||||||
|
// for primitive node , we build the primitive node with infered attributes in the first pass
|
||||||
|
// so we do not build replaced node again here in second pass
|
||||||
|
if (IsValueNode<Primitive>(func)) {
|
||||||
|
new_inputs[0] = func;
|
||||||
|
} else {
|
||||||
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
|
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < argvals.size();) {
|
for (size_t i = 0; i < argvals.size();) {
|
||||||
size_t next = i + 1;
|
size_t next = i + 1;
|
||||||
|
@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
||||||
}
|
}
|
||||||
i = next;
|
i = next;
|
||||||
}
|
}
|
||||||
|
|
||||||
new_node->set_inputs(new_inputs);
|
new_node->set_inputs(new_inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
||||||
|
|
||||||
EvaluatorCacheMap evaluator_cache_map = *eval->cache();
|
EvaluatorCacheMap evaluator_cache_map = *eval->cache();
|
||||||
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
|
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
|
||||||
*result = std::make_pair(argvals, evaluator_cache_map[argvals]);
|
*result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
|
||||||
return kSpecializeSuccess;
|
return kSpecializeSuccess;
|
||||||
}
|
}
|
||||||
DumpEvaluatorCache(evaluator_cache_map, argvals);
|
DumpEvaluatorCache(evaluator_cache_map, argvals);
|
||||||
|
@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
||||||
MS_EXCEPTION_IF_NULL(choices);
|
MS_EXCEPTION_IF_NULL(choices);
|
||||||
|
|
||||||
if (choices->count(argvals)) {
|
if (choices->count(argvals)) {
|
||||||
*result = std::make_pair(argvals, (*choices)[argvals]);
|
*result = std::make_pair(argvals, (*choices)[argvals]->abstract());
|
||||||
return kSpecializeSuccess;
|
return kSpecializeSuccess;
|
||||||
} else if (choices->size() == 1) {
|
} else if (choices->size() == 1) {
|
||||||
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
|
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
|
||||||
*result = std::make_pair(choices->begin()->first, choices->begin()->second);
|
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
|
||||||
return kSpecializeSuccess;
|
return kSpecializeSuccess;
|
||||||
} else if (choices->empty()) {
|
} else if (choices->empty()) {
|
||||||
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
|
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
|
||||||
|
@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
||||||
return kSpecializeFindUniqueArgvalPoly;
|
return kSpecializeFindUniqueArgvalPoly;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
|
||||||
|
auto &prim_attrs = prim->attrs();
|
||||||
|
bool is_attr_same = true;
|
||||||
|
for (auto &item : *attrs) {
|
||||||
|
auto itr = prim_attrs.find(item.first);
|
||||||
|
if (itr != prim_attrs.end()) {
|
||||||
|
if (!(*(itr->second) == *(item.second))) {
|
||||||
|
is_attr_same = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
is_attr_same = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!is_attr_same) {
|
||||||
|
if (prim->isa<PrimitivePy>()) {
|
||||||
|
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
|
||||||
|
auto clone_fn = prim_py->GetPyObj().attr("_clone");
|
||||||
|
py::object new_obj = clone_fn();
|
||||||
|
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
|
||||||
|
for (auto &item : *attrs) {
|
||||||
|
cloned_prim->AddAttr(item.first, item.second);
|
||||||
|
}
|
||||||
|
return cloned_prim;
|
||||||
|
}
|
||||||
|
auto cloned_prim = std::make_shared<Primitive>(*prim);
|
||||||
|
for (auto &item : *attrs) {
|
||||||
|
cloned_prim->AddAttr(item.first, item.second);
|
||||||
|
}
|
||||||
|
return cloned_prim;
|
||||||
|
}
|
||||||
|
return prim;
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
|
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
|
||||||
|
const AttrValueMapPtr &attrs) {
|
||||||
MS_EXCEPTION_IF_NULL(origin_node);
|
MS_EXCEPTION_IF_NULL(origin_node);
|
||||||
MS_EXCEPTION_IF_NULL(ival);
|
MS_EXCEPTION_IF_NULL(ival);
|
||||||
|
|
||||||
|
@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
|
||||||
ValuePtr value = nullptr;
|
ValuePtr value = nullptr;
|
||||||
if (abs->isa<PrimitiveAbstractClosure>()) {
|
if (abs->isa<PrimitiveAbstractClosure>()) {
|
||||||
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
|
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
|
||||||
|
// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
|
||||||
|
if (attrs != nullptr) {
|
||||||
|
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
|
||||||
|
} else {
|
||||||
value = real_fn->prim();
|
value = real_fn->prim();
|
||||||
|
}
|
||||||
} else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
|
} else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
|
||||||
auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
|
auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
|
||||||
value = real_fn->meta_func_graph();
|
value = real_fn->meta_func_graph();
|
||||||
|
|
|
@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
|
||||||
AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
|
AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
|
||||||
|
|
||||||
// Build a value node if ival is constant and not any-value
|
// Build a value node if ival is constant and not any-value
|
||||||
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
|
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
|
||||||
|
const AttrValueMapPtr &attrs);
|
||||||
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
|
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
|
||||||
// replicated node.
|
// replicated node.
|
||||||
AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
|
AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
|
||||||
|
|
|
@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
|
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
|
||||||
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
|
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
|
||||||
<< ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString()
|
<< ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
|
||||||
<< ", Pointer: " << arg.get();
|
<< ", Pointer: " << result->abstract().get();
|
||||||
cache_[conf] = arg;
|
cache_[conf] = result;
|
||||||
|
|
||||||
// Set intermediate abstract value.
|
// Set intermediate abstract value.
|
||||||
if (IsIntermediateAbstract(arg)) {
|
if (IsIntermediateAbstract(result->abstract())) {
|
||||||
if (conf->node()->intermediate_abstract() == nullptr) {
|
if (conf->node()->intermediate_abstract() == nullptr) {
|
||||||
conf->node()->set_intermediate_abstract(arg);
|
conf->node()->set_intermediate_abstract(result->abstract());
|
||||||
MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString();
|
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
|
||||||
} else {
|
} else {
|
||||||
auto old_spec = conf->node()->intermediate_abstract();
|
auto old_spec = conf->node()->intermediate_abstract();
|
||||||
auto joined_spec = IntermediateJoin(arg, old_spec);
|
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
|
||||||
conf->node()->set_intermediate_abstract(joined_spec);
|
conf->node()->set_intermediate_abstract(joined_spec);
|
||||||
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
|
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
|
||||||
<< arg->ToString() << "\njoined_spec:\t"
|
<< result->abstract()->ToString() << "\njoined_spec:\t"
|
||||||
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
|
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
|
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
|
||||||
auto value = cache_.find(conf);
|
auto value = cache_.find(conf);
|
||||||
if (value == cache_.end()) {
|
if (value == cache_.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
|
||||||
return eval->graph_context();
|
return eval->graph_context();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
|
EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
auto value = cache_.GetValue(conf);
|
auto value = cache_.GetValue(conf);
|
||||||
if (value != nullptr) {
|
if (value != nullptr) {
|
||||||
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", "
|
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get()
|
||||||
<< value->ToString();
|
<< ", " << value->abstract()->ToString();
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
AnfNodePtr node = conf->node();
|
AnfNodePtr node = conf->node();
|
||||||
AbstractBasePtr ret_abstract = nullptr;
|
EvalResultPtr eval_result = nullptr;
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
compute_conf_stack_.push_back(node);
|
compute_conf_stack_.push_back(node);
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
|
@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (node->abstract() != nullptr) {
|
if (node->abstract() != nullptr) {
|
||||||
MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
|
MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
|
||||||
ret_abstract = node->abstract();
|
eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
|
||||||
} else if (node->isa<ValueNode>()) {
|
} else if (node->isa<ValueNode>()) {
|
||||||
auto value_node = node->cast<ValueNodePtr>();
|
auto value_node = node->cast<ValueNodePtr>();
|
||||||
ret_abstract = EvalValueNode(value_node, conf);
|
eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr);
|
||||||
} else if (node->isa<CNode>()) {
|
} else if (node->isa<CNode>()) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
trace::TraceEvalCNodeEnter(conf);
|
trace::TraceEvalCNodeEnter(conf);
|
||||||
ret_abstract = EvalCNode(cnode, conf);
|
eval_result = EvalCNode(cnode, conf);
|
||||||
trace::TraceEvalCNodeLeave();
|
trace::TraceEvalCNodeLeave();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
|
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
|
||||||
|
@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
|
||||||
|
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
compute_conf_stack_.pop_back();
|
compute_conf_stack_.pop_back();
|
||||||
if (ret_abstract == nullptr) {
|
if (eval_result == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
|
MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
|
||||||
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString();
|
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
|
||||||
return ret_abstract;
|
return eval_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
|
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
|
||||||
|
@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
|
||||||
return ToAbstract(value_node->value(), conf->context(), conf);
|
return ToAbstract(value_node->value(), conf->context(), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
auto &inputs = cnode->inputs();
|
auto &inputs = cnode->inputs();
|
||||||
|
@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
|
||||||
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
|
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
|
||||||
MS_EXCEPTION_IF_NULL(func_conf);
|
MS_EXCEPTION_IF_NULL(func_conf);
|
||||||
// Keep it in a local variable, otherwise smart pointer will free it.
|
// Keep it in a local variable, otherwise smart pointer will free it.
|
||||||
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue();
|
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract();
|
||||||
if (maybe_func == nullptr) {
|
if (maybe_func == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
|
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
|
||||||
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
|
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
|
||||||
|
@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
|
||||||
return ExecuteEvaluators(infs, conf, args_conf_list);
|
return ExecuteEvaluators(infs, conf, args_conf_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
|
EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
|
||||||
ConfigPtrList args_conf_list;
|
ConfigPtrList args_conf_list;
|
||||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
|
||||||
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
|
||||||
|
@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
||||||
return tracked_eval;
|
return tracked_eval;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
||||||
const AnfNodeConfigPtr &out_conf,
|
const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
|
||||||
const ConfigPtrList &args_conf_list) {
|
|
||||||
if (evaluators.size() == 1) {
|
if (evaluators.size() == 1) {
|
||||||
EvaluatorPtr eval = evaluators[0];
|
EvaluatorPtr eval = evaluators[0];
|
||||||
MS_EXCEPTION_IF_NULL(eval);
|
MS_EXCEPTION_IF_NULL(eval);
|
||||||
|
@ -465,7 +464,7 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
|
||||||
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
|
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
||||||
const AnfNodeConfigPtr &out_conf,
|
const AnfNodeConfigPtr &out_conf,
|
||||||
const ConfigPtrList &args_conf_list) {
|
const ConfigPtrList &args_conf_list) {
|
||||||
AbstractBasePtrList out_specs;
|
AbstractBasePtrList out_specs;
|
||||||
|
@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
||||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||||
MS_EXCEPTION_IF_NULL(conf);
|
MS_EXCEPTION_IF_NULL(conf);
|
||||||
return conf->GetEvaluatedValue();
|
return conf->GetEvaluatedValue()->abstract();
|
||||||
});
|
});
|
||||||
for (auto eval : evaluators) {
|
for (auto eval : evaluators) {
|
||||||
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
|
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
|
||||||
|
@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
||||||
eval_trace_.push_back(current_inf);
|
eval_trace_.push_back(current_inf);
|
||||||
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
|
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
|
||||||
MS_EXCEPTION_IF_NULL(eval);
|
MS_EXCEPTION_IF_NULL(eval);
|
||||||
auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
|
||||||
MS_EXCEPTION_IF_NULL(out_spec);
|
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||||
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
|
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
|
||||||
out_specs.push_back(out_spec);
|
out_specs.push_back(eval_result->abstract());
|
||||||
MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
|
|
||||||
eval_trace_.pop_back();
|
eval_trace_.pop_back();
|
||||||
if (eval_trace_.empty()) {
|
if (eval_trace_.empty()) {
|
||||||
multi_poss_.clear();
|
multi_poss_.clear();
|
||||||
|
@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
||||||
// Try to travel the latest undetermined.
|
// Try to travel the latest undetermined.
|
||||||
if (latest_entry != eval_trace_.rbegin()->first) {
|
if (latest_entry != eval_trace_.rbegin()->first) {
|
||||||
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
|
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
|
||||||
auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
|
||||||
MS_EXCEPTION_IF_NULL(out_spec);
|
MS_EXCEPTION_IF_NULL(eval_result->abstract());
|
||||||
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString();
|
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
|
||||||
return out_spec;
|
<< " return out_spec: " << eval_result->abstract()->ToString();
|
||||||
|
return eval_result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
|
||||||
if (out_specs.size() == 1) {
|
if (out_specs.size() == 1) {
|
||||||
MS_EXCEPTION_IF_NULL(out_specs[0]);
|
MS_EXCEPTION_IF_NULL(out_specs[0]);
|
||||||
// If only one result derived, then broaden it to avoid wrong constant propagation.
|
// If only one result derived, then broaden it to avoid wrong constant propagation.
|
||||||
return out_specs[0]->Broaden();
|
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
auto joined_spec = AbstractJoin(out_specs);
|
auto joined_spec = AbstractJoin(out_specs);
|
||||||
MS_EXCEPTION_IF_NULL(joined_spec);
|
MS_EXCEPTION_IF_NULL(joined_spec);
|
||||||
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
|
||||||
return joined_spec;
|
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() {
|
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
|
||||||
AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
|
AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
|
||||||
return engine_.lock()->GetEvaluatedValue(self);
|
return engine_.lock()->GetEvaluatedValue(self);
|
||||||
}
|
}
|
||||||
|
@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
|
||||||
auto evaluator = GetPrimEvaluator(primitive, nullptr);
|
auto evaluator = GetPrimEvaluator(primitive, nullptr);
|
||||||
MS_EXCEPTION_IF_NULL(evaluator);
|
MS_EXCEPTION_IF_NULL(evaluator);
|
||||||
if (!evaluator->isa<TrivialPrimEvaluator>()) {
|
if (!evaluator->isa<TrivialPrimEvaluator>()) {
|
||||||
|
@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
|
||||||
<< evaluator->ToString();
|
<< evaluator->ToString();
|
||||||
}
|
}
|
||||||
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
|
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
|
||||||
auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs);
|
auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
|
||||||
return res_spec;
|
return eval_result;
|
||||||
}
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -40,13 +40,33 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
|
||||||
|
// define attribute value map
|
||||||
|
using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
|
||||||
|
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
|
||||||
|
|
||||||
|
// the class to save evaluated result: abstract value and modified attribute
|
||||||
|
class EvalResult : public Base {
|
||||||
|
public:
|
||||||
|
EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
|
||||||
|
~EvalResult() override = default;
|
||||||
|
MS_DECLARE_PARENT(EvalResult, Base);
|
||||||
|
AbstractBasePtr abstract() { return abstract_; }
|
||||||
|
AttrValueMapPtr attribute() { return attribute_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractBasePtr abstract_;
|
||||||
|
AttrValueMapPtr attribute_;
|
||||||
|
};
|
||||||
|
|
||||||
|
using EvalResultPtr = std::shared_ptr<EvalResult>;
|
||||||
// Superclass for AnfNodeConfig and VirtualConfig.
|
// Superclass for AnfNodeConfig and VirtualConfig.
|
||||||
class Config : public Base {
|
class Config : public Base {
|
||||||
public:
|
public:
|
||||||
Config() = default;
|
Config() = default;
|
||||||
~Config() override = default;
|
~Config() override = default;
|
||||||
MS_DECLARE_PARENT(Config, Base);
|
MS_DECLARE_PARENT(Config, Base);
|
||||||
virtual AbstractBasePtr GetEvaluatedValue() = 0;
|
virtual EvalResultPtr GetEvaluatedValue() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Config will be stored in AnalysisCache
|
// Config will be stored in AnalysisCache
|
||||||
|
@ -74,7 +94,7 @@ class AnfNodeConfig : public Config {
|
||||||
~AnfNodeConfig() override = default;
|
~AnfNodeConfig() override = default;
|
||||||
MS_DECLARE_PARENT(AnfNodeConfig, Config);
|
MS_DECLARE_PARENT(AnfNodeConfig, Config);
|
||||||
|
|
||||||
AbstractBasePtr GetEvaluatedValue() override;
|
EvalResultPtr GetEvaluatedValue() override;
|
||||||
|
|
||||||
AnalysisContextPtr context() const { return context_; }
|
AnalysisContextPtr context() const { return context_; }
|
||||||
|
|
||||||
|
@ -123,7 +143,9 @@ class VirtualConfig : public Config {
|
||||||
|
|
||||||
~VirtualConfig() override = default;
|
~VirtualConfig() override = default;
|
||||||
MS_DECLARE_PARENT(VirtualConfig, Config);
|
MS_DECLARE_PARENT(VirtualConfig, Config);
|
||||||
AbstractBasePtr GetEvaluatedValue() override { return abstract_; }
|
EvalResultPtr GetEvaluatedValue() override {
|
||||||
|
return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AbstractBasePtr abstract_;
|
AbstractBasePtr abstract_;
|
||||||
|
@ -135,11 +157,11 @@ class AnalysisCache {
|
||||||
AnalysisCache() = default;
|
AnalysisCache() = default;
|
||||||
~AnalysisCache() = default;
|
~AnalysisCache() = default;
|
||||||
void Clear() { cache_.clear(); }
|
void Clear() { cache_.clear(); }
|
||||||
void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
|
void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
|
||||||
AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf);
|
EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
|
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
|
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
|
||||||
|
@ -147,7 +169,7 @@ using AnfNodeConfigMap =
|
||||||
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
|
||||||
|
|
||||||
struct AnalysisResult {
|
struct AnalysisResult {
|
||||||
AbstractBasePtr inferred;
|
EvalResultPtr inferred;
|
||||||
AnalysisContextPtr context;
|
AnalysisContextPtr context;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||||
// func_graph: The func_graph to analyze.
|
// func_graph: The func_graph to analyze.
|
||||||
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
|
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
|
||||||
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
|
EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
|
||||||
// Return the Evaluator for the given function.
|
// Return the Evaluator for the given function.
|
||||||
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
||||||
|
|
||||||
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
|
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
|
||||||
AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
|
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
|
||||||
// Infer the result of fn(args).
|
// Infer the result of fn(args).
|
||||||
AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
|
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
|
||||||
void Clear();
|
void Clear();
|
||||||
void ClearEvaluatorCache();
|
void ClearEvaluatorCache();
|
||||||
AnalysisCache &cache() { return cache_; }
|
AnalysisCache &cache() { return cache_; }
|
||||||
|
@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||||
|
|
||||||
// Set the analysis result for orig to the result for new.
|
// Set the analysis result for orig to the result for new.
|
||||||
// This sets an entry in anfnode_config_map from orig to new.
|
// This sets an entry in anfnode_config_map from orig to new.
|
||||||
AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
|
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
|
||||||
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
|
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
|
||||||
(void)anfnode_config_map_.emplace(orig_conf, new_conf);
|
(void)anfnode_config_map_.emplace(orig_conf, new_conf);
|
||||||
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
|
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
|
||||||
|
@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
||||||
|
|
||||||
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
|
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
|
||||||
const ConfigPtrList &args_conf_list);
|
const ConfigPtrList &args_conf_list);
|
||||||
AbstractBasePtr Eval(const AnfNodeConfigPtr &conf);
|
EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
|
||||||
EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
|
||||||
AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
|
EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
|
||||||
|
const ConfigPtrList &args_conf_list);
|
||||||
|
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
|
||||||
const ConfigPtrList &args_conf_list);
|
const ConfigPtrList &args_conf_list);
|
||||||
AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
|
|
||||||
const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list);
|
|
||||||
|
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
std::vector<AnfNodePtr> compute_conf_stack_;
|
std::vector<AnfNodePtr> compute_conf_stack_;
|
||||||
|
@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
|
||||||
return FromValueInside(MakeValue(value), broaden);
|
return FromValueInside(MakeValue(value), broaden);
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
|
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
|
||||||
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
|
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list);
|
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
|
||||||
op_exec_info->abstract = infer_res;
|
op_exec_info->abstract = infer_res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,8 @@
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <queue>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
#include "ir/visitor.h"
|
#include "ir/visitor.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// search the cnodes inside this graph only
|
||||||
|
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
|
||||||
|
std::queue<CNodePtr> todo;
|
||||||
|
todo.push(ret);
|
||||||
|
std::vector<CNodePtr> sorted_nodes;
|
||||||
|
auto seen = NewSeenGeneration();
|
||||||
|
while (!todo.empty()) {
|
||||||
|
CNodePtr top = todo.front();
|
||||||
|
todo.pop();
|
||||||
|
sorted_nodes.push_back(top);
|
||||||
|
auto inputs = top->inputs();
|
||||||
|
for (auto &item : inputs) {
|
||||||
|
if (item->seen_ == seen) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (item->isa<CNode>()) {
|
||||||
|
todo.push(item->cast<CNodePtr>());
|
||||||
|
}
|
||||||
|
item->seen_ = seen;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sorted_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
||||||
std::vector<AnfNodePtr> vecs;
|
std::vector<AnfNodePtr> vecs;
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
|
|
|
@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
|
||||||
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
|
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
|
||||||
const IncludeFunc &include = AlwaysInclude);
|
const IncludeFunc &include = AlwaysInclude);
|
||||||
|
|
||||||
|
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
|
||||||
class FuncGraphIndex {
|
class FuncGraphIndex {
|
||||||
public:
|
public:
|
||||||
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
|
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
|
||||||
|
|
|
@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init ExpandDims"""
|
"""init ExpandDims"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, x, axis):
|
def __infer__(self, x, axis):
|
||||||
|
@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer):
|
||||||
# if primitive need setattr in __infer__ need add this flag
|
# if primitive need setattr in __infer__ need add this flag
|
||||||
"""init Cast"""
|
"""init Cast"""
|
||||||
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
|
||||||
self.__setattr_flag__ = True
|
|
||||||
|
|
||||||
def __infer__(self, x, t):
|
def __infer__(self, x, t):
|
||||||
src_type = x['dtype']
|
src_type = x['dtype']
|
||||||
|
@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init Reshape"""
|
"""init Reshape"""
|
||||||
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
|
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
|
||||||
self.__setattr_flag__ = True
|
|
||||||
|
|
||||||
def __infer__(self, x, shape):
|
def __infer__(self, x, shape):
|
||||||
shape_v = shape['value']
|
shape_v = shape['value']
|
||||||
|
@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init Transpose"""
|
"""init Transpose"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, x, perm):
|
def __infer__(self, x, perm):
|
||||||
|
@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init index_select"""
|
"""init index_select"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, params, indices, axis):
|
def __infer__(self, params, indices, axis):
|
||||||
|
@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, axis=0):
|
def __init__(self, axis=0):
|
||||||
"""init Tile"""
|
"""init Tile"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
validator.check_value_type("axis", axis, [int], self.name)
|
validator.check_value_type("axis", axis, [int], self.name)
|
||||||
|
|
||||||
def __infer__(self, input_x):
|
def __infer__(self, input_x):
|
||||||
|
@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, axis=0):
|
def __init__(self, axis=0):
|
||||||
"""init Pack"""
|
"""init Pack"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
validator.check_value_type("axis", axis, [int], self.name)
|
validator.check_value_type("axis", axis, [int], self.name)
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
|
|
||||||
|
@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, axis=0):
|
def __init__(self, axis=0):
|
||||||
"""init Unpack"""
|
"""init Unpack"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
validator.check_value_type("axis", axis, [int], self.name)
|
validator.check_value_type("axis", axis, [int], self.name)
|
||||||
self.axis = axis
|
self.axis = axis
|
||||||
|
|
||||||
|
@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""init"""
|
"""init"""
|
||||||
self.__setattr_flag__ = True
|
|
||||||
|
|
||||||
def infer_shape(self, cond_shape, x_shape, y_shape):
|
def infer_shape(self, cond_shape, x_shape, y_shape):
|
||||||
if cond_shape != x_shape or x_shape != y_shape:
|
if cond_shape != x_shape or x_shape != y_shape:
|
||||||
|
|
|
@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, transpose_a=False, transpose_b=False):
|
def __init__(self, transpose_a=False, transpose_b=False):
|
||||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||||
self.__setattr_flag__ = True
|
|
||||||
cls_name = self.name
|
cls_name = self.name
|
||||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||||
|
@ -596,7 +595,6 @@ class BatchMatMul(MatMul):
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, transpose_a=False, transpose_b=False):
|
def __init__(self, transpose_a=False, transpose_b=False):
|
||||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||||
self.__setattr_flag__ = True
|
|
||||||
cls_name = self.name
|
cls_name = self.name
|
||||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||||
|
@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.__setattr_flag__ = True
|
|
||||||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||||
|
|
||||||
def infer_shape(self, inputs):
|
def infer_shape(self, inputs):
|
||||||
|
|
|
@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer):
|
||||||
"""init Conv2D"""
|
"""init Conv2D"""
|
||||||
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
|
||||||
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
||||||
self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
|
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
|
||||||
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
|
self.add_prim_attr('stride', self.stride)
|
||||||
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
||||||
self.add_prim_attr('dilation', self.dilation)
|
self.add_prim_attr('dilation', self.dilation)
|
||||||
validator.check_value_type('pad', pad, (int,), self.name)
|
validator.check_value_type('pad', pad, (int,), self.name)
|
||||||
|
@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer):
|
||||||
|
|
||||||
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
|
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
|
||||||
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
|
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
|
||||||
|
|
||||||
out_channel = self.out_channel
|
out_channel = self.out_channel
|
||||||
out_shape = [x_shape[0], out_channel, h_out, w_out]
|
out_shape = [x_shape[0], out_channel, h_out, w_out]
|
||||||
return out_shape
|
return out_shape
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test nn ops """
|
||||||
|
import functools
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
|
||||||
|
from mindspore.ops.primitive import constexpr
|
||||||
|
from mindspore import context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cast_op_attr():
|
||||||
|
class CastNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CastNet, self).__init__()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
def construct(self, x, t):
|
||||||
|
return self.cast(x, t)
|
||||||
|
|
||||||
|
class CastTypeTest(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(CastTypeTest, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.cast = P.Cast()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
cast_op = self.cast
|
||||||
|
t1 = cast_op(x, mstype.float32)
|
||||||
|
t2 = cast_op(y, mstype.int32)
|
||||||
|
cast_net = self.net
|
||||||
|
t3 = cast_net(x, mstype.float16)
|
||||||
|
t4 = cast_net(y, mstype.int32)
|
||||||
|
t5 = cast_net(z, mstype.float16)
|
||||||
|
return (t1, t2, t3, t4, t5)
|
||||||
|
net = CastTypeTest(CastNet())
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
|
||||||
|
out = net(t1, t2, t3)
|
||||||
|
assert out[0].asnumpy().dtype == np.float32
|
||||||
|
assert out[1].asnumpy().dtype == np.int32
|
||||||
|
assert out[2].asnumpy().dtype == np.float16
|
||||||
|
assert out[3].asnumpy().dtype == np.int32
|
||||||
|
assert out[4].asnumpy().dtype == np.float16
|
|
@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
|
||||||
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
||||||
|
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
|
||||||
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
||||||
|
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
|
||||||
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
||||||
|
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
|
||||||
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
|
||||||
|
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) {
|
||||||
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
AbstractBasePtrList args_spec_list = {tensor, slice};
|
AbstractBasePtrList args_spec_list = {tensor, slice};
|
||||||
|
|
||||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred);
|
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract array failed.";
|
FAIL() << "Cast ret to abstract array failed.";
|
||||||
}
|
}
|
||||||
|
@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
|
||||||
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
||||||
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
||||||
|
|
||||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
|
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract array failed.";
|
FAIL() << "Cast ret to abstract array failed.";
|
||||||
}
|
}
|
||||||
|
@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
|
||||||
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
||||||
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
||||||
|
|
||||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
|
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract array failed.";
|
FAIL() << "Cast ret to abstract array failed.";
|
||||||
}
|
}
|
||||||
|
@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) {
|
||||||
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
|
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
|
||||||
AbstractBasePtrList args_spec_list = {tensor, start_index};
|
AbstractBasePtrList args_spec_list = {tensor, start_index};
|
||||||
|
|
||||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
|
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract array failed.";
|
FAIL() << "Cast ret to abstract array failed.";
|
||||||
}
|
}
|
||||||
|
@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
|
||||||
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
||||||
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
||||||
|
|
||||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
|
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract array failed.";
|
FAIL() << "Cast ret to abstract array failed.";
|
||||||
}
|
}
|
||||||
|
@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
|
||||||
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
|
||||||
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
|
||||||
|
|
||||||
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
|
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract array failed.";
|
FAIL() << "Cast ret to abstract array failed.";
|
||||||
}
|
}
|
||||||
|
@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
|
||||||
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
|
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
|
AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
|
||||||
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
|
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
|
AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) {
|
||||||
auto tuple = std::make_shared<AbstractTuple>(eles);
|
auto tuple = std::make_shared<AbstractTuple>(eles);
|
||||||
AbstractBasePtrList args_spec_list = {tuple};
|
AbstractBasePtrList args_spec_list = {tuple};
|
||||||
|
|
||||||
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred);
|
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
|
||||||
if (ret == nullptr) {
|
if (ret == nullptr) {
|
||||||
FAIL() << "Cast ret to abstract tuple failed.";
|
FAIL() << "Cast ret to abstract tuple failed.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
|
||||||
AbstractBasePtr abstract_v2 = FromValue(2, false);
|
AbstractBasePtr abstract_v2 = FromValue(2, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2};
|
AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2};
|
||||||
AbstractBasePtr abstract_val = FromValue(10, false);
|
AbstractBasePtr abstract_val = FromValue(10, false);
|
||||||
cache[args_spec_list] = abstract_val;
|
cache[args_spec_list] = std::make_shared<EvalResult>(abstract_val, std::make_shared<AttrValueMap>());
|
||||||
|
|
||||||
auto iter = cache.find(args_spec_list);
|
auto iter = cache.find(args_spec_list);
|
||||||
ASSERT_TRUE(iter != cache.end());
|
ASSERT_TRUE(iter != cache.end());
|
||||||
ASSERT_TRUE(iter->second == abstract_val);
|
ASSERT_TRUE(iter->second->abstract() == abstract_val);
|
||||||
|
|
||||||
AbstractBasePtr abstract_v1_variant1 = FromValue(1, false);
|
AbstractBasePtr abstract_v1_variant1 = FromValue(1, false);
|
||||||
AbstractBasePtr abstract_v2_variant1 = FromValue(2, false);
|
AbstractBasePtr abstract_v2_variant1 = FromValue(2, false);
|
||||||
|
@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
|
||||||
|
|
||||||
iter = cache.find(args_spec_list_variant1);
|
iter = cache.find(args_spec_list_variant1);
|
||||||
ASSERT_TRUE(iter != cache.end());
|
ASSERT_TRUE(iter != cache.end());
|
||||||
ASSERT_TRUE(iter->second == abstract_val);
|
ASSERT_TRUE(iter->second->abstract() == abstract_val);
|
||||||
|
|
||||||
AbstractBasePtr abstract_v1_variant2 = FromValue(1, false);
|
AbstractBasePtr abstract_v1_variant2 = FromValue(1, false);
|
||||||
AbstractBasePtr abstract_v2_variant2 = FromValue(3, false);
|
AbstractBasePtr abstract_v2_variant2 = FromValue(3, false);
|
||||||
|
@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) {
|
||||||
std::vector<int> shape = {2, 2, 6, 6};
|
std::vector<int> shape = {2, 2, 6, 6};
|
||||||
expected->set_shape(std::make_shared<Shape>(shape));
|
expected->set_shape(std::make_shared<Shape>(shape));
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
MS_LOG(INFO) << "expected: " << expected->ToString();
|
MS_LOG(INFO) << "expected: " << expected->ToString();
|
||||||
|
|
||||||
|
@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) {
|
||||||
AbstractBasePtr abstract_x = FromValue(x, false);
|
AbstractBasePtr abstract_x = FromValue(x, false);
|
||||||
args_spec_list.push_back(abstract_x);
|
args_spec_list.push_back(abstract_x);
|
||||||
|
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
|
||||||
}
|
}
|
||||||
|
@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) {
|
||||||
AbstractBasePtr abstract_x = FromValue(x, false);
|
AbstractBasePtr abstract_x = FromValue(x, false);
|
||||||
args_spec_list.push_back(abstract_x);
|
args_spec_list.push_back(abstract_x);
|
||||||
|
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
|
||||||
}
|
}
|
||||||
|
@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) {
|
||||||
args_spec_list.push_back(abstract_x);
|
args_spec_list.push_back(abstract_x);
|
||||||
args_spec_list.push_back(abstract_y);
|
args_spec_list.push_back(abstract_y);
|
||||||
|
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) {
|
||||||
args_spec_list.push_back(abstract_x);
|
args_spec_list.push_back(abstract_x);
|
||||||
args_spec_list.push_back(abstract_y);
|
args_spec_list.push_back(abstract_y);
|
||||||
|
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
||||||
}
|
}
|
||||||
|
@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) {
|
||||||
args_spec_list.push_back(abstract_x);
|
args_spec_list.push_back(abstract_x);
|
||||||
args_spec_list.push_back(abstract_y);
|
args_spec_list.push_back(abstract_y);
|
||||||
|
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
||||||
}
|
}
|
||||||
|
@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) {
|
||||||
args_spec_list.push_back(abstract_x);
|
args_spec_list.push_back(abstract_x);
|
||||||
args_spec_list.push_back(abstract_y);
|
args_spec_list.push_back(abstract_y);
|
||||||
|
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
|
||||||
}
|
}
|
||||||
|
|
|
@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) {
|
||||||
|
|
||||||
auto prim_typeof = std::make_shared<Primitive>("typeof");
|
auto prim_typeof = std::make_shared<Primitive>("typeof");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
res->dump();
|
res->dump();
|
||||||
TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
|
TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
|
||||||
res_value->dump();
|
res_value->dump();
|
||||||
|
@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) {
|
||||||
|
|
||||||
auto prim_list_map = std::make_shared<Primitive>("list_map");
|
auto prim_list_map = std::make_shared<Primitive>("list_map");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)}));
|
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)}));
|
||||||
res->dump();
|
res->dump();
|
||||||
MS_LOG(INFO) << "result res: " << res->ToString();
|
MS_LOG(INFO) << "result res: " << res->ToString();
|
||||||
|
@ -188,7 +188,7 @@ TEST_F(TestPrim, test_list_reduce) {
|
||||||
|
|
||||||
auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
|
auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
res->dump();
|
res->dump();
|
||||||
TypePtr res_type = res->GetTypeTrack();
|
TypePtr res_type = res->GetTypeTrack();
|
||||||
res_type->dump();
|
res_type->dump();
|
||||||
|
@ -205,7 +205,7 @@ TEST_F(TestPrim, test_scalar_to_array) {
|
||||||
|
|
||||||
auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
|
auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
res->dump();
|
res->dump();
|
||||||
TypePtr res_type = res->BuildType();
|
TypePtr res_type = res->BuildType();
|
||||||
res_type->dump();
|
res_type->dump();
|
||||||
|
@ -223,7 +223,7 @@ TEST_F(TestPrim, test_array_to_scalar) {
|
||||||
|
|
||||||
auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
|
auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
res->dump();
|
res->dump();
|
||||||
TypePtr res_type = res->BuildType();
|
TypePtr res_type = res->BuildType();
|
||||||
res_type->dump();
|
res_type->dump();
|
||||||
|
@ -239,7 +239,7 @@ TEST_F(TestPrim, test_J_1) {
|
||||||
|
|
||||||
auto prim_J = std::make_shared<Primitive>("J");
|
auto prim_J = std::make_shared<Primitive>("J");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
|
AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
|
||||||
ASSERT_TRUE(res_J != nullptr);
|
ASSERT_TRUE(res_J != nullptr);
|
||||||
ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
|
ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
|
||||||
|
@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) {
|
||||||
int v1 = 1;
|
int v1 = 1;
|
||||||
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_v1};
|
AbstractBasePtrList args_spec_list = {abstract_v1};
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
res->dump();
|
res->dump();
|
||||||
AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
|
AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
|
||||||
ASSERT_TRUE(res_J != nullptr);
|
ASSERT_TRUE(res_J != nullptr);
|
||||||
|
@ -302,7 +302,7 @@ TEST_F(TestPrim, test_dot) {
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {a1, a2};
|
AbstractBasePtrList args_spec_list = {a1, a2};
|
||||||
|
|
||||||
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred);
|
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||||
|
|
||||||
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
|
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
|
||||||
}
|
}
|
||||||
|
@ -317,7 +317,7 @@ TEST_F(TestPrim, test_switch1) {
|
||||||
AbstractBasePtr arg2 = FromValue(2, false);
|
AbstractBasePtr arg2 = FromValue(2, false);
|
||||||
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
|
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *arg1);
|
ASSERT_TRUE(*res == *arg1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -330,7 +330,7 @@ TEST_F(TestPrim, test_switch2) {
|
||||||
AbstractBasePtr arg2 = FromValue(2, false);
|
AbstractBasePtr arg2 = FromValue(2, false);
|
||||||
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
|
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "make result res: " << res->ToString();
|
MS_LOG(INFO) << "make result res: " << res->ToString();
|
||||||
MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
|
MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
|
||||||
ASSERT_TRUE(*res == *arg2);
|
ASSERT_TRUE(*res == *arg2);
|
||||||
|
@ -343,7 +343,7 @@ TEST_F(TestPrim, test_identity) {
|
||||||
AbstractBasePtr abstract_v1 = FromValue(1, false);
|
AbstractBasePtr abstract_v1 = FromValue(1, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_v1};
|
AbstractBasePtrList args_spec_list = {abstract_v1};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *abstract_v1);
|
ASSERT_TRUE(*res == *abstract_v1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ TEST_F(TestPrim, test_broadcast_shape) {
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {a, b};
|
AbstractBasePtrList args_spec_list = {a, b};
|
||||||
|
|
||||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred);
|
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||||
|
|
||||||
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
||||||
std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
|
std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
|
||||||
|
@ -377,7 +377,7 @@ TEST_F(TestPrim, test_partial) {
|
||||||
AbstractBasePtr abstract_v2 = FromValue(1, false);
|
AbstractBasePtr abstract_v2 = FromValue(1, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
|
AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
|
AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
|
||||||
auto expected = std::make_shared<PartialAbstractClosure>(
|
auto expected = std::make_shared<PartialAbstractClosure>(
|
||||||
std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
|
std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
|
||||||
|
@ -392,7 +392,7 @@ TEST_F(TestPrim, test_env_setitem) {
|
||||||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||||
AbstractBasePtr abstract_x = FromValue(1, false);
|
AbstractBasePtr abstract_x = FromValue(1, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred;
|
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||||
|
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||||
|
|
||||||
|
@ -400,7 +400,7 @@ TEST_F(TestPrim, test_env_setitem) {
|
||||||
AbstractBasePtr abstract_y = FromValue(2, false);
|
AbstractBasePtr abstract_y = FromValue(2, false);
|
||||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||||
ASSERT_TRUE(*res == *exp);
|
ASSERT_TRUE(*res == *exp);
|
||||||
}
|
}
|
||||||
|
@ -412,7 +412,7 @@ TEST_F(TestPrim, test_env_getitem) {
|
||||||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||||
AbstractBasePtr abstract_x = FromValue(1, false);
|
AbstractBasePtr abstract_x = FromValue(1, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred;
|
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||||
|
|
||||||
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||||
|
|
||||||
|
@ -420,7 +420,7 @@ TEST_F(TestPrim, test_env_getitem) {
|
||||||
AbstractBasePtr abstract_y = FromValue(2, false);
|
AbstractBasePtr abstract_y = FromValue(2, false);
|
||||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||||
ASSERT_TRUE(*res == *exp);
|
ASSERT_TRUE(*res == *exp);
|
||||||
|
|
||||||
|
@ -429,7 +429,7 @@ TEST_F(TestPrim, test_env_getitem) {
|
||||||
AbstractBasePtr abstract_z = FromValue(3, false);
|
AbstractBasePtr abstract_z = FromValue(3, false);
|
||||||
args_spec_list = {res, embed_x, abstract_z};
|
args_spec_list = {res, embed_x, abstract_z};
|
||||||
|
|
||||||
res = engine_->Run(graph_getitem, args_spec_list).inferred;
|
res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
|
||||||
|
|
||||||
ASSERT_TRUE(*res == *abstract_x);
|
ASSERT_TRUE(*res == *abstract_x);
|
||||||
}
|
}
|
||||||
|
@ -442,7 +442,7 @@ TEST_F(TestPrim, test_env_add) {
|
||||||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||||
AbstractBasePtr abstract_x = FromValue(1, false);
|
AbstractBasePtr abstract_x = FromValue(1, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred;
|
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||||
|
|
||||||
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||||
|
|
||||||
|
@ -450,19 +450,19 @@ TEST_F(TestPrim, test_env_add) {
|
||||||
AbstractBasePtr abstract_y = FromValue(2, false);
|
AbstractBasePtr abstract_y = FromValue(2, false);
|
||||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||||
|
|
||||||
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred;
|
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||||
ASSERT_TRUE(*abstract_e1 == *exp);
|
ASSERT_TRUE(*abstract_e1 == *exp);
|
||||||
|
|
||||||
AbstractBasePtr abstract_z = FromValue(3, false);
|
AbstractBasePtr abstract_z = FromValue(3, false);
|
||||||
args_spec_list = {abstract_env, embed_x, abstract_z};
|
args_spec_list = {abstract_env, embed_x, abstract_z};
|
||||||
|
|
||||||
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred;
|
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*abstract_e2 == *exp);
|
ASSERT_TRUE(*abstract_e2 == *exp);
|
||||||
|
|
||||||
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
|
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
|
||||||
args_spec_list = {abstract_e1, abstract_e2};
|
args_spec_list = {abstract_e1, abstract_e2};
|
||||||
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
|
||||||
|
|
||||||
ASSERT_TRUE(*res == *exp);
|
ASSERT_TRUE(*res == *exp);
|
||||||
}
|
}
|
||||||
|
@ -475,7 +475,7 @@ TEST_F(TestPrim, test_shape) {
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {a};
|
AbstractBasePtrList args_spec_list = {a};
|
||||||
|
|
||||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred);
|
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||||
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
||||||
|
|
||||||
std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)};
|
std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)};
|
||||||
|
@ -493,7 +493,7 @@ TEST_F(TestPrim, test_relu) {
|
||||||
AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
|
AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
|
||||||
AbstractBasePtrList args_spec_list = {expected};
|
AbstractBasePtrList args_spec_list = {expected};
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -507,7 +507,7 @@ TEST_F(TestPrim, test_relu2) {
|
||||||
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
|
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {arr};
|
AbstractBasePtrList args_spec_list = {arr};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractTensor>(ret);
|
auto res = dyn_cast<AbstractTensor>(ret);
|
||||||
ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
|
ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
|
||||||
}
|
}
|
||||||
|
@ -540,7 +540,7 @@ TEST_F(TestPrim, test_conv2d1) {
|
||||||
std::vector<int> shape = {2, 64, 14, 14};
|
std::vector<int> shape = {2, 64, 14, 14};
|
||||||
expected->set_shape(std::make_shared<Shape>(shape));
|
expected->set_shape(std::make_shared<Shape>(shape));
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
MS_LOG(INFO) << "expected: " << expected->ToString();
|
MS_LOG(INFO) << "expected: " << expected->ToString();
|
||||||
|
|
||||||
|
@ -558,7 +558,7 @@ TEST_F(TestPrim, test_conv2d) {
|
||||||
auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
|
auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {input, weight};
|
AbstractBasePtrList args_spec_list = {input, weight};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractTensor>(ret);
|
auto res = dyn_cast<AbstractTensor>(ret);
|
||||||
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
|
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
|
@ -574,7 +574,7 @@ TEST_F(TestPrim, test_conv2d_native) {
|
||||||
auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
|
auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {input, weight};
|
AbstractBasePtrList args_spec_list = {input, weight};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractTensor>(ret);
|
auto res = dyn_cast<AbstractTensor>(ret);
|
||||||
auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
|
auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
|
@ -590,7 +590,7 @@ TEST_F(TestPrim, test_biasAdd) {
|
||||||
auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
|
auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {value, bias};
|
AbstractBasePtrList args_spec_list = {value, bias};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractTensor>(ret);
|
auto res = dyn_cast<AbstractTensor>(ret);
|
||||||
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
|
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
|
@ -606,7 +606,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
|
||||||
auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
|
auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {logits, labels};
|
AbstractBasePtrList args_spec_list = {logits, labels};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_NE(ret, nullptr);
|
ASSERT_NE(ret, nullptr);
|
||||||
auto res = dyn_cast<AbstractTuple>(ret);
|
auto res = dyn_cast<AbstractTuple>(ret);
|
||||||
auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
|
auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});
|
||||||
|
@ -636,7 +636,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) {
|
||||||
auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
|
auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});
|
||||||
|
|
||||||
AbstractBasePtrList args_spec_list = {logits, labels};
|
AbstractBasePtrList args_spec_list = {logits, labels};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractScalar>(ret);
|
auto res = dyn_cast<AbstractScalar>(ret);
|
||||||
AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
|
AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
|
||||||
expected->set_type(UTPrimUtils::kF64);
|
expected->set_type(UTPrimUtils::kF64);
|
||||||
|
@ -690,7 +690,7 @@ TEST_F(TestPrim, test_fused_batch_norm) {
|
||||||
AbstractBasePtr expected0 = abstract_inputs->Clone();
|
AbstractBasePtr expected0 = abstract_inputs->Clone();
|
||||||
AbstractBasePtr expected1 = abstract_scale->Clone();
|
AbstractBasePtr expected1 = abstract_scale->Clone();
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
MS_LOG(INFO) << "expected0: " << expected0->ToString();
|
MS_LOG(INFO) << "expected0: " << expected0->ToString();
|
||||||
MS_LOG(INFO) << "expected1: " << expected1->ToString();
|
MS_LOG(INFO) << "expected1: " << expected1->ToString();
|
||||||
|
@ -722,7 +722,7 @@ TEST_F(TestPrim, test_pooling) {
|
||||||
inputs->set_shape(inputs_dims);
|
inputs->set_shape(inputs_dims);
|
||||||
AbstractBasePtr abstract_input = FromValue(inputs, false);
|
AbstractBasePtr abstract_input = FromValue(inputs, false);
|
||||||
AbstractBasePtrList args_spec_list = {abstract_input};
|
AbstractBasePtrList args_spec_list = {abstract_input};
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
|
|
||||||
AbstractBasePtr expected = abstract_input->Clone()->Broaden();
|
AbstractBasePtr expected = abstract_input->Clone()->Broaden();
|
||||||
std::vector<int> expected_dims = {8, 64, 2, 2};
|
std::vector<int> expected_dims = {8, 64, 2, 2};
|
||||||
|
@ -747,7 +747,7 @@ TEST_F(TestPrim, test_hastype) {
|
||||||
auto prim = std::make_shared<Primitive>("hastype");
|
auto prim = std::make_shared<Primitive>("hastype");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -761,7 +761,7 @@ TEST_F(TestPrim, test_array_len) {
|
||||||
auto prim = std::make_shared<Primitive>("array_len");
|
auto prim = std::make_shared<Primitive>("array_len");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -775,7 +775,7 @@ TEST_F(TestPrim, test_list_len) {
|
||||||
auto prim = std::make_shared<Primitive>("list_len");
|
auto prim = std::make_shared<Primitive>("list_len");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -789,7 +789,7 @@ TEST_F(TestPrim, test_tuple_len) {
|
||||||
auto prim = std::make_shared<Primitive>("tuple_len");
|
auto prim = std::make_shared<Primitive>("tuple_len");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -803,7 +803,7 @@ TEST_F(TestPrim, test_tuple_reversed) {
|
||||||
auto prim = std::make_shared<Primitive>("tuple_reversed");
|
auto prim = std::make_shared<Primitive>("tuple_reversed");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "expect=" << expected->ToString();
|
MS_LOG(INFO) << "expect=" << expected->ToString();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
@ -825,7 +825,7 @@ TEST_F(TestPrim, test_list_getitem) {
|
||||||
auto prim = std::make_shared<Primitive>("list_getitem");
|
auto prim = std::make_shared<Primitive>("list_getitem");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *elem);
|
ASSERT_TRUE(*res == *elem);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -844,7 +844,7 @@ TEST_F(TestPrim, test_list_setitem) {
|
||||||
auto prim = std::make_shared<Primitive>("list_setitem");
|
auto prim = std::make_shared<Primitive>("list_setitem");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
AbstractBasePtrList elems_exp = {elem1, elem2};
|
AbstractBasePtrList elems_exp = {elem1, elem2};
|
||||||
auto expected = std::make_shared<AbstractList>(elems_exp);
|
auto expected = std::make_shared<AbstractList>(elems_exp);
|
||||||
|
@ -866,7 +866,7 @@ TEST_F(TestPrim, test_list_append) {
|
||||||
auto prim = std::make_shared<Primitive>("list_append");
|
auto prim = std::make_shared<Primitive>("list_append");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
|
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
|
||||||
MS_LOG(INFO) << "expected: " << expected->ToString();
|
MS_LOG(INFO) << "expected: " << expected->ToString();
|
||||||
|
@ -890,7 +890,7 @@ TEST_F(TestPrim, test_tuple_setitem) {
|
||||||
auto prim = std::make_shared<Primitive>("tuple_setitem");
|
auto prim = std::make_shared<Primitive>("tuple_setitem");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
AbstractBasePtrList elems_exp = {elem1, elem2};
|
AbstractBasePtrList elems_exp = {elem1, elem2};
|
||||||
auto expected = std::make_shared<AbstractTuple>(elems_exp);
|
auto expected = std::make_shared<AbstractTuple>(elems_exp);
|
||||||
|
@ -916,7 +916,7 @@ TEST_F(TestPrim, test_make_list) {
|
||||||
auto prim = std::make_shared<Primitive>("make_list");
|
auto prim = std::make_shared<Primitive>("make_list");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -939,7 +939,7 @@ TEST_F(TestPrim, test_make_range) {
|
||||||
AbstractBasePtrList elem_list({ele1, ele2, ele3});
|
AbstractBasePtrList elem_list({ele1, ele2, ele3});
|
||||||
AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
|
AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "res=" << res->ToString();
|
MS_LOG(INFO) << "res=" << res->ToString();
|
||||||
MS_LOG(INFO) << "expected=" << expected->ToString();
|
MS_LOG(INFO) << "expected=" << expected->ToString();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
|
@ -982,7 +982,7 @@ TEST_F(TestPrim, test_layernorm) {
|
||||||
AbstractBasePtr expected1 = abstract_mean_var->Clone();
|
AbstractBasePtr expected1 = abstract_mean_var->Clone();
|
||||||
AbstractBasePtr expected2 = abstract_mean_var->Clone();
|
AbstractBasePtr expected2 = abstract_mean_var->Clone();
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
MS_LOG(INFO) << "expected0: " << expected0->ToString();
|
MS_LOG(INFO) << "expected0: " << expected0->ToString();
|
||||||
MS_LOG(INFO) << "expected1: " << expected1->ToString();
|
MS_LOG(INFO) << "expected1: " << expected1->ToString();
|
||||||
|
@ -1028,7 +1028,7 @@ TEST_F(TestPrim, test_DropoutGenMask) {
|
||||||
AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
|
AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
|
||||||
std::make_shared<Shape>(std::vector<int>{79}));
|
std::make_shared<Shape>(std::vector<int>{79}));
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "res=" << res->ToString();
|
MS_LOG(INFO) << "res=" << res->ToString();
|
||||||
MS_LOG(INFO) << "expected=" << expected->ToString();
|
MS_LOG(INFO) << "expected=" << expected->ToString();
|
||||||
ASSERT_TRUE(*res == *expected);
|
ASSERT_TRUE(*res == *expected);
|
||||||
|
@ -1058,7 +1058,7 @@ TEST_F(TestPrim, test_dropout) {
|
||||||
std::vector<int> shape = {2, 20, 32, 32};
|
std::vector<int> shape = {2, 20, 32, 32};
|
||||||
expected->set_shape(std::make_shared<Shape>(shape));
|
expected->set_shape(std::make_shared<Shape>(shape));
|
||||||
|
|
||||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
MS_LOG(INFO) << "result: " << res->ToString();
|
MS_LOG(INFO) << "result: " << res->ToString();
|
||||||
MS_LOG(INFO) << "expected: " << expected->ToString();
|
MS_LOG(INFO) << "expected: " << expected->ToString();
|
||||||
|
|
||||||
|
@ -1079,7 +1079,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
|
||||||
auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
|
auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
|
||||||
auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
|
auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
|
||||||
AbstractBasePtrList args_spec_list = {x_input, y_input};
|
AbstractBasePtrList args_spec_list = {x_input, y_input};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractTuple>(ret);
|
auto res = dyn_cast<AbstractTuple>(ret);
|
||||||
AbstractBasePtrList x_idx_list;
|
AbstractBasePtrList x_idx_list;
|
||||||
auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
|
auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
|
||||||
|
@ -1103,7 +1103,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
|
||||||
auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
|
auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
|
||||||
auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
|
auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
|
||||||
AbstractBasePtrList args_spec_list = {x_input, y_input};
|
AbstractBasePtrList args_spec_list = {x_input, y_input};
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
auto res = dyn_cast<AbstractTuple>(ret);
|
auto res = dyn_cast<AbstractTuple>(ret);
|
||||||
AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
|
AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
|
||||||
auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
|
auto r_x = std::make_shared<AbstractTuple>(x_idx_list);
|
||||||
|
@ -1128,7 +1128,7 @@ TEST_F(TestPrim, test_DictGetItem) {
|
||||||
AbstractBasePtr key = abstract::FromValue("x");
|
AbstractBasePtr key = abstract::FromValue("x");
|
||||||
AbstractBasePtrList args_spec_list = {array_dict, key};
|
AbstractBasePtrList args_spec_list = {array_dict, key};
|
||||||
|
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
|
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
|
||||||
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
|
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));
|
||||||
|
|
||||||
|
@ -1147,7 +1147,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
|
||||||
AbstractBasePtr key = abstract::FromValue("x");
|
AbstractBasePtr key = abstract::FromValue("x");
|
||||||
AbstractBasePtrList args_spec_list = {array_dict, key};
|
AbstractBasePtrList args_spec_list = {array_dict, key};
|
||||||
|
|
||||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
|
AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
|
||||||
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
|
AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
|
||||||
|
|
||||||
|
|
|
@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
|
||||||
|
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) {
|
||||||
MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
|
MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
|
||||||
AbstractBasePtr abstract_v1 = FromValue(1, false);
|
AbstractBasePtr abstract_v1 = FromValue(1, false);
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
||||||
|
|
||||||
// now this test case failed randomly, have to debug.
|
// now this test case failed randomly, have to debug.
|
||||||
|
@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) {
|
||||||
args_spec_list.clear();
|
args_spec_list.clear();
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
args_spec_list.push_back(abstract_v2);
|
args_spec_list.push_back(abstract_v2);
|
||||||
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred;
|
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
|
||||||
AbstractBasePtr abstract_v2 = FromValue(v1, false);
|
AbstractBasePtr abstract_v2 = FromValue(v1, false);
|
||||||
args_spec_list.push_back(abstract_v1);
|
args_spec_list.push_back(abstract_v1);
|
||||||
args_spec_list.push_back(abstract_v2);
|
args_spec_list.push_back(abstract_v2);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract();
|
||||||
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
|
||||||
|
|
||||||
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
|
||||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
|
||||||
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred;
|
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
|
||||||
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
|
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
|
||||||
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32);
|
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32);
|
||||||
}
|
}
|
||||||
|
@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
|
||||||
AbstractBasePtr base1 = FromValue(x1, false);
|
AbstractBasePtr base1 = FromValue(x1, false);
|
||||||
AbstractBasePtr base2 = FromValue(x2, false);
|
AbstractBasePtr base2 = FromValue(x2, false);
|
||||||
AbstractBasePtrList base_list = {base1, base2};
|
AbstractBasePtrList base_list = {base1, base2};
|
||||||
auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list);
|
auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract();
|
||||||
MS_LOG(INFO) << "result spec: " << res->ToString();
|
MS_LOG(INFO) << "result spec: " << res->ToString();
|
||||||
AbstractBasePtr exp = FromValue(x3, false);
|
AbstractBasePtr exp = FromValue(x3, false);
|
||||||
MS_LOG(INFO) << "result exp: " << exp->ToString();
|
MS_LOG(INFO) << "result exp: " << exp->ToString();
|
||||||
|
@ -446,7 +446,7 @@ void TestGraphEval::TearDown() {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_defaults) {
|
TEST_F(TestGraphInfer, test_graph_infer_defaults) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(50), false);
|
AbstractBasePtr expect = FromValue(MakeValue(50), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
|
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(1), false);
|
AbstractBasePtr expect = FromValue(MakeValue(1), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_vararg) {
|
TEST_F(TestGraphInfer, test_graph_infer_vararg) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(9), false);
|
AbstractBasePtr expect = FromValue(MakeValue(9), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
|
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(48), false);
|
AbstractBasePtr expect = FromValue(MakeValue(48), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
|
TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(7), false);
|
AbstractBasePtr expect = FromValue(MakeValue(7), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
|
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(46), false);
|
AbstractBasePtr expect = FromValue(MakeValue(46), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
|
||||||
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
|
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
|
||||||
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
|
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
|
||||||
AbstractBasePtrList args_spec_list = {};
|
AbstractBasePtrList args_spec_list = {};
|
||||||
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
|
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
|
||||||
AbstractBasePtr expect = FromValue(MakeValue(57), false);
|
AbstractBasePtr expect = FromValue(MakeValue(57), false);
|
||||||
ASSERT_EQ(*res, *expect);
|
ASSERT_EQ(*res, *expect);
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||||
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||||
from ....mindspore_test_framework.pipeline.forward.verify_exception \
|
from ....mindspore_test_framework.pipeline.forward.verify_exception \
|
||||||
import pipeline_for_verify_exception_for_case_by_case_config
|
import pipeline_for_verify_exception_for_case_by_case_config
|
||||||
|
from mindspore import context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
|
||||||
def conv3x3(in_channels, out_channels, stride=1, padding=1):
|
def conv3x3(in_channels, out_channels, stride=1, padding=1):
|
||||||
"""3x3 convolution """
|
"""3x3 convolution """
|
||||||
|
@ -377,6 +378,21 @@ class StateNet(nn.Cell):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_conv2d_same_primitive():
|
||||||
|
class Conv2DSameNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Conv2DSameNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
|
||||||
|
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
|
||||||
|
def construct(self, x, y):
|
||||||
|
r1 = self.conv1(x)
|
||||||
|
r2 = self.conv2(y)
|
||||||
|
return (r1, r2)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
net = Conv2DSameNet()
|
||||||
|
out = net(t1, t2)
|
||||||
|
|
||||||
class ComparisonNet(nn.Cell):
|
class ComparisonNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
""" ComparisonNet definition """
|
""" ComparisonNet definition """
|
||||||
|
|
|
@ -0,0 +1,276 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" test nn ops """
|
||||||
|
import functools
|
||||||
|
import numpy as np
|
||||||
|
import mindspore
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
|
||||||
|
from mindspore.ops.primitive import constexpr
|
||||||
|
|
||||||
|
from ..ut_filter import non_graph_engine
|
||||||
|
from ....mindspore_test_framework.mindspore_test import mindspore_test
|
||||||
|
from ....mindspore_test_framework.pipeline.forward.compile_forward \
|
||||||
|
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
|
||||||
|
from ....mindspore_test_framework.pipeline.forward.verify_exception \
|
||||||
|
import pipeline_for_verify_exception_for_case_by_case_config
|
||||||
|
from mindspore import context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||||
|
|
||||||
|
class FakeOp(PrimitiveWithInfer):
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
""""""
|
||||||
|
def infer_shape(self, x, y):
|
||||||
|
self.second_shape = y
|
||||||
|
self.add_prim_attr("second_shape", y)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def infer_dtype(self, x, y):
|
||||||
|
return x
|
||||||
|
|
||||||
|
# test the normal case that should generate independent primitive because of different
|
||||||
|
# generated attributes after inference
|
||||||
|
def test_conv2d_same_primitive():
|
||||||
|
class Conv2DSameNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Conv2DSameNet, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
|
||||||
|
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
|
||||||
|
def construct(self, x, y):
|
||||||
|
r1 = self.conv1(x)
|
||||||
|
r2 = self.conv2(y)
|
||||||
|
return (r1, r2)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
net = Conv2DSameNet()
|
||||||
|
out = net(t1, t2)
|
||||||
|
|
||||||
|
# test cell as high order argument
|
||||||
|
# The graph with free variables used as argument is not supported yet
|
||||||
|
# because of the limit of inference specialize system
|
||||||
|
def Xtest_conv2d_op_with_arg():
|
||||||
|
class Conv2dNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Conv2dNet, self).__init__()
|
||||||
|
def construct(self, op, x):
|
||||||
|
return op(x)
|
||||||
|
class OpsNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(OpsNet, self).__init__()
|
||||||
|
self.opnet = net
|
||||||
|
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
|
||||||
|
def construct(self, x, y):
|
||||||
|
conv_op = self.conv2
|
||||||
|
a = self.opnet(conv_op, x)
|
||||||
|
b = self.opnet(conv_op, y)
|
||||||
|
return (a, b)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
net = OpsNet(Conv2dNet())
|
||||||
|
out = net(t1, t2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conv2d_op_with_arg():
|
||||||
|
class FackOpNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(FackOpNet, self).__init__()
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.op(x, y)
|
||||||
|
class OpNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(OpNet, self).__init__()
|
||||||
|
def construct(self, op, x, y):
|
||||||
|
return op(x, y)
|
||||||
|
class OpsNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(OpsNet, self).__init__()
|
||||||
|
self.opnet = net
|
||||||
|
self.op = FackOpNet()
|
||||||
|
def construct(self, x, y):
|
||||||
|
op = self.op
|
||||||
|
a = self.opnet(op, x, y)
|
||||||
|
b = self.opnet(op, y, x)
|
||||||
|
return (a, b)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
net = OpsNet(OpNet())
|
||||||
|
out = net(t1, t2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_conv2d_op_with_arg_same_input():
|
||||||
|
class FackOpNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(FackOpNet, self).__init__()
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y):
|
||||||
|
return self.op(x, y)
|
||||||
|
class OpNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(OpNet, self).__init__()
|
||||||
|
def construct(self, op, x, y):
|
||||||
|
return op(x, y)
|
||||||
|
class OpsNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(OpsNet, self).__init__()
|
||||||
|
self.opnet = net
|
||||||
|
self.op = FackOpNet()
|
||||||
|
def construct(self, x, y):
|
||||||
|
op = self.op
|
||||||
|
a = self.opnet(op, x, x)
|
||||||
|
b = self.opnet(op, y, x)
|
||||||
|
return (a, b)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
net = OpsNet(OpNet())
|
||||||
|
out = net(t1, t2)
|
||||||
|
|
||||||
|
# test op with partial
|
||||||
|
def test_op_as_partial():
|
||||||
|
class OpAsPartial(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(OpAsPartial, self).__init__()
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
partial_op = F.partial(self.op, x)
|
||||||
|
a = partial_op(y)
|
||||||
|
b = partial_op(z)
|
||||||
|
return a, b
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
|
||||||
|
net = OpAsPartial()
|
||||||
|
out = net(t1, t2, t3)
|
||||||
|
|
||||||
|
# test op with partial
|
||||||
|
def test_op_as_partial_inside():
|
||||||
|
class OpAsPartial(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(OpAsPartial, self).__init__()
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
partial_op = F.partial(self.op, x)
|
||||||
|
a = partial_op(y)
|
||||||
|
b = partial_op(z)
|
||||||
|
return a, b
|
||||||
|
class OuterNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(OuterNet, self).__init__()
|
||||||
|
self.net = OpAsPartial()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
a,b = self.net(x, y, z)
|
||||||
|
return a, b
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
|
||||||
|
net = OuterNet()
|
||||||
|
out = net(t1, t2, t3)
|
||||||
|
|
||||||
|
# test op with partial case 2
|
||||||
|
def test_op_as_partial_independent():
|
||||||
|
class OpAsPartial(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(OpAsPartial, self).__init__()
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
partial_op1 = F.partial(self.op, x)
|
||||||
|
a = partial_op1(y)
|
||||||
|
partial_op2 = F.partial(self.op, x)
|
||||||
|
b = partial_op2(z)
|
||||||
|
return a, b
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
|
||||||
|
net = OpAsPartial()
|
||||||
|
out = net(t1, t2, t3)
|
||||||
|
|
||||||
|
def test_nest_partial():
|
||||||
|
class NestPartial(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NestPartial, self).__init__()
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
partial_op1 = F.partial(self.op)
|
||||||
|
partial_op2 = F.partial(partial_op1, x)
|
||||||
|
a = partial_op2(y)
|
||||||
|
partial_op3 = F.partial(self.op)
|
||||||
|
partial_op4 = F.partial(partial_op3, x)
|
||||||
|
b = partial_op4(z)
|
||||||
|
return a, b
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
|
||||||
|
net = NestPartial()
|
||||||
|
out = net(t1, t2, t3)
|
||||||
|
|
||||||
|
# high order argument
|
||||||
|
# op and op args as network arguments
|
||||||
|
def test_op_with_arg_as_input():
|
||||||
|
class WithOpArgNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(WithOpArgNet, self).__init__()
|
||||||
|
def construct(self, op, x, y):
|
||||||
|
return op(x, y)
|
||||||
|
class OpsNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(OpsNet, self).__init__()
|
||||||
|
self.opnet = net
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
op = self.op
|
||||||
|
a = self.opnet(op, x, z)
|
||||||
|
b = self.opnet(op, x, y)
|
||||||
|
return (a, b)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
|
||||||
|
net = OpsNet(WithOpArgNet())
|
||||||
|
out = net(t1, t2, t3)
|
||||||
|
|
||||||
|
# The partial application used as argument is not supported yet
|
||||||
|
# because of the limit of inference specialize system
|
||||||
|
def Xtest_partial_as_arg():
|
||||||
|
class PartialArgNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(PartialArgNet, self).__init__()
|
||||||
|
def construct(self, partial_op, y):
|
||||||
|
return partial_op(y)
|
||||||
|
class OpsNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(OpsNet, self).__init__()
|
||||||
|
self.partial_net = net
|
||||||
|
self.op = FakeOp()
|
||||||
|
def construct(self, x, y, z):
|
||||||
|
partial_op = F.partial(self.op, x)
|
||||||
|
a = self.partial_net(partial_op, z)
|
||||||
|
b = self.partial_net(partial_op, y)
|
||||||
|
return (a, b)
|
||||||
|
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
|
||||||
|
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
|
||||||
|
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
|
||||||
|
net = OpsNet(PartialArgNet())
|
||||||
|
out = net(t1, t2, t3)
|
Loading…
Reference in New Issue