forked from mindspore-Ecosystem/mindspore
!1383 keep different attributes for cnode evaluation
Merge pull request !1383 from amongo/KeepPrimAttrInCNode
commit 5b9c145ff8
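This merge changes what the static-analysis evaluators hand back: instead of a bare AbstractBasePtr, evaluation now returns an EvalResultPtr that bundles the inferred abstract value with the attributes a primitive added while its infer function ran, so those attributes can be kept per CNode and reused during specialization. The EvalResult class itself is defined outside the hunks below; the following is a minimal sketch of the interface the changed call sites rely on (a constructor taking an abstract plus an AttrValueMap, and abstract()/attribute() accessors), so the member names and layout here are assumptions, not the committed definition.

#include <memory>
#include <string>
#include <unordered_map>

class AbstractBase;  // MindSpore abstract value, stubbed for the sketch
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
class Value;         // MindSpore IR value, stubbed for the sketch
using ValuePtr = std::shared_ptr<Value>;

using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;

// Pairs the inferred abstract value with the attributes recorded during evaluation.
class EvalResult {
 public:
  EvalResult(const AbstractBasePtr &abs, const AttrValueMapPtr &attr) : abstract_(abs), attribute_(attr) {}
  const AbstractBasePtr &abstract() const { return abstract_; }
  const AttrValueMapPtr &attribute() const { return attribute_; }

 private:
  AbstractBasePtr abstract_;   // type/shape/value information for the node
  AttrValueMapPtr attribute_;  // attrs a primitive added while it was evaluated (may be nullptr)
};
using EvalResultPtr = std::shared_ptr<EvalResult>;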
@@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
  auto ctx = node_cfg_->context();
  auto engine = node_cfg_->engine();
  auto cfg = engine->MakeConfig(node, ctx);
  auto abs = engine->cache().GetValue(cfg);
  if (abs == nullptr) {
  auto eval_result = engine->cache().GetValue(cfg);
  if (eval_result == nullptr || eval_result->abstract() == nullptr) {
    return "Undefined";
  }
  auto abs = eval_result->abstract();
  auto dtype = abs->BuildType();
  auto shape = abs->BuildShape();
  std::ostringstream oss;
@@ -42,7 +42,11 @@ enum PrimType {
class Primitive : public Named {
 public:
  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
      : Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {}
      : Named(name),
        is_base_(is_base),
        has_signature_(false),
        prim_type_(prim_type),
        record_evaluate_add_attr_(false) {}

  Primitive(const Primitive &prim)
      : Named(prim),
@@ -50,14 +54,23 @@ class Primitive : public Named {
        instance_name_(prim.instance_name_),
        is_base_(prim.is_base_),
        has_signature_(prim.has_signature_),
        prim_type_(prim.prim_type_) {}
        prim_type_(prim.prim_type_),
        record_evaluate_add_attr_(false) {}

  MS_DECLARE_PARENT(Primitive, Named);

  abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
  std::string ToString() const override { return name(); }
  void BeginRecordAddAttr() {
    evaluate_added_attrs_.clear();
    record_evaluate_add_attr_ = true;
  }
  void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
  Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
    attrs_[name] = attr;
    if (record_evaluate_add_attr_) {
      evaluate_added_attrs_[name] = attr;
    }
    return *this;
  }

@@ -80,6 +93,7 @@ class Primitive : public Named {
  py::function hook() const { return hook_; }

  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
  std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; }

  // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
  bool HasAttr() const { return !attrs_.empty(); }
@@ -106,6 +120,7 @@ class Primitive : public Named {
 protected:
  std::unordered_map<std::string, ValuePtr> attrs_;
  std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;

 private:
  std::string instance_name_;
@@ -113,6 +128,7 @@ class Primitive : public Named {
  bool is_base_;
  bool has_signature_;
  PrimType prim_type_;
  bool record_evaluate_add_attr_;
};

inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
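The BeginRecordAddAttr()/EndRecordAddAttr() switch and the evaluate_added_attrs() accessor added above are consumed by the primitive evaluators later in this diff (StandardPrimEvaluator::EvalPrim and PythonPrimEvaluator::EvalPrim): recording is switched on around the infer call, and whatever the infer implementation adds via AddAttr() is captured separately from the primitive's pre-existing attributes and packed into the EvalResult. Below is a condensed usage sketch; the infer body is passed in as a callable because the real evaluators call eval_impl_ or the Python __infer__ method, and the wrapper function itself is illustrative rather than part of the commit.

#include <functional>
#include <memory>

// Runs an infer implementation while recording any attributes it adds to the primitive,
// then returns the abstract value together with only those newly added attributes.
EvalResultPtr EvalWithRecordedAttrs(
    const PrimitivePtr &prim, const AbstractBasePtrList &args,
    const std::function<AbstractBasePtr(const PrimitivePtr &, const AbstractBasePtrList &)> &infer_impl) {
  prim->BeginRecordAddAttr();                    // clear evaluate_added_attrs_ and start capturing AddAttr() calls
  AbstractBasePtr abs = infer_impl(prim, args);  // the infer body may call prim->AddAttr(name, value)
  prim->EndRecordAddAttr();                      // stop capturing
  auto added_attrs = prim->evaluate_added_attrs();
  return std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>(added_attrs));
}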
@@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv
    }
    subargs.push_back(AbstractJoin(l_ptr->elements()));
  }
  AbstractBasePtr engin_exc = engine->Execute(fn, subargs);
  EvalResultPtr engin_exc = engine->Execute(fn, subargs);
  AbstractBasePtrList result;
  for (std::size_t i = 1; i < args_spec_list.size(); i++) {
    result.push_back(engin_exc);
    result.push_back(engin_exc->abstract());
  }
  return std::make_shared<AbstractList>(result);
}

@@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
  AbstractBasePtr list_type = AbstractJoin(lst->elements());
  auto result1 = engine->Execute(fn, lst->elements());
  auto result2 = engine->Execute(fn, {dflt, list_type});
  MS_EXCEPTION_IF_NULL(result1);
  return result1->Join(result2);
  MS_EXCEPTION_IF_NULL(result1->abstract());
  MS_EXCEPTION_IF_NULL(result2->abstract());
  return result1->abstract()->Join(result2->abstract());
}

AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
  return sorted_nodes;
}

AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
  FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
  MS_EXCEPTION_IF_NULL(fg);
  std::size_t nargs = fg->parameters().size();
@@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
    const auto &arg = args_spec_list[i];
    const auto &node = parameters[i];
    AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
    engine->cache().set_value(conf, arg);
    engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
  }
  const AnfNodePtr &func_node = fg->get_return();

@@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
    const auto &node = *it;
    AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
    MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
    ret_base = engine->GetEvaluatedValue(node_conf);
    ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
    MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
                  << ", abstract: " << ret_base->ToString();
  }

  MS_EXCEPTION_IF_NULL(ret_base);
  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString();
  return ret_base;
  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
  return std::make_shared<EvalResult>(ret_base, nullptr);
}

AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
@@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
  return cloned_func_graph;
}

AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                               AnfNodeConfigPtr out_conf) {
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
  const std::string &evaluator_name = ToString();

  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  args_spec_list = NormalizeArgs(args_spec_list);
  args_spec_list = BroadenUndeterminedArgs(args_spec_list);
@@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
  auto iter = cache_->find(args_spec_list);
  if (iter == cache_->end()) {
    MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
    AbstractBasePtr ret = Eval(engine, args_spec_list);
    if (ret == nullptr) {
    EvalResultPtr ret = Eval(engine, args_spec_list);
    if (ret->abstract() == nullptr) {
      EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
      MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
    }
    MS_EXCEPTION_IF_NULL(ret);
    MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
    MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
    (*cache_)[args_spec_list] = ret;
    trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
    return ret;
  } else {
    MS_EXCEPTION_IF_NULL(iter->second);
    MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
    MS_EXCEPTION_IF_NULL(iter->second->abstract());
    MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << ".";
    trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
    return iter->second;
  }
}

AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                          AnfNodeConfigPtr) {
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                        AnfNodeConfigPtr) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
  EvalResultPtr ret = EvalPrim(engine, args_spec_list);
  return ret;
}

AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                             AnfNodeConfigPtr out_conf) {
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                           AnfNodeConfigPtr out_conf) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  if (args_conf_list.size() == 0) {
    MS_LOG(EXCEPTION) << "Size should greater than 0";
  }
  AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
  EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
  // No need to cache.
  return ret;
}

AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  AbstractBasePtr ret = EvalPrim(args_conf_list);
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  EvalResultPtr ret = EvalPrim(args_conf_list);
  return ret;
}

AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                      AnfNodeConfigPtr out_conf) {
EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                    AnfNodeConfigPtr out_conf) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
  EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
  // Don't lookup from cache, as different out_conf with same node but different context
  // may add different entry to anfnode_config_map_, like getattr primitive.
  (*cache_)[args_spec_list] = ret;
  return ret;
}

AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                       AnfNodeConfigPtr out_conf) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  MS_EXCEPTION_IF_NULL(cache_);
  auto iter = cache_->find(args_spec_list);
@@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP

  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
                       [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
  AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
  EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);

  (*cache_)[args_spec_list] = ret;
  return ret;
}

AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  MS_EXCEPTION_IF_NULL(cache_);
  auto iter = cache_->find(args_spec_list);
@@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
  }

  // Call the original evaluator, get the result: y = f(x)
  AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr);
  EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
  // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
  // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
  AbstractBasePtrList bparams;
@@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
    args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
    [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
  AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
  AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final);
  AbstractFunctionPtr bprop =
    std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);

  // J(f)(J(x)) return a tuple (y, bprop_f)
  AbstractBasePtrList jargs = {result, bprop};
  AbstractBasePtrList jargs = {result->abstract(), bprop};
  AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
  (*cache_)[args_spec_list] = jtuple;
  return jtuple;
  auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
  (*cache_)[args_spec_list] = infer_reuslt;
  return infer_reuslt;
}

AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() != args_spec_list_.size()) {
    MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
                      << ", arguments no: " << args_spec_list.size();
@@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL
    MS_EXCEPTION_IF_NULL(args_spec_list[i]);
    (void)args_spec_list[i]->Join(args_spec_list_[i]);
  }
  return output_;
  return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
}
} // namespace abstract
} // namespace mindspore
@@ -29,21 +29,28 @@
namespace mindspore {
namespace abstract {
using EvaluatorCacheMap =
  std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
  std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;

using EvaluatorAttrMap =
  std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;

class Evaluator : public Base {
 public:
  explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {}
  explicit Evaluator(const std::string &id)
      : cache_(std::make_shared<EvaluatorCacheMap>()),
        attr_cache_(std::make_shared<EvaluatorAttrMap>()),
        identifier_(id) {}
  ~Evaluator() override = default;
  MS_DECLARE_PARENT(Evaluator, Base);

  // difference between Run() and Eval():
  // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
  // Run() will modify cache_ member, so it cannot marked as const;
  virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
  virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);

  virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
  virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;

  virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }

@@ -58,9 +65,10 @@ class Evaluator : public Base {
  virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }

  EvaluatorCacheMapPtr &cache() { return cache_; }
  EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }

  EvaluatorCacheMapPtr cache_;
  EvaluatorAttrMapPtr attr_cache_;
  std::string identifier_;

  AnfNodeWeakPtr bound_node_;
@@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator {
  explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
  ~PrimEvaluator() override = default;
  MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  }
};
@@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
  explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
  ~TrivialPrimEvaluator() override = default;
  MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
};

class TransitionPrimEvaluator : public PrimEvaluator {
@@ -90,10 +98,10 @@ class TransitionPrimEvaluator : public PrimEvaluator {
  explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
  ~TransitionPrimEvaluator() override = default;
  MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  // Parameter in_conf0 : the first element in args_conf_list;
  virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                                   const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
  virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                                 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
};

class SymbolicPrimEvaluator : public PrimEvaluator {
@@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
  explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
  ~SymbolicPrimEvaluator() override = default;
  MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
  virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
};

// Evaluator will be stored in AnalysisEngine.constructors_
@@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator {
  DummyEvaluator() : Evaluator("dummy") {}
  ~DummyEvaluator() override = default;
  MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
};

// Wrap another evaluator to track a subset of uses.
@@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator {
    bound_node_ = AnfNodeWeakPtr(node);
  }

  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  }
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                      AnfNodeConfigPtr out_conf) override;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
  std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }

 private:
@@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
  ~BaseFuncGraphEvaluator() override = default;
  MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);

  AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;

  virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;

@@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator {
    }
    bound_node_ = AnfNodeWeakPtr(node);
  }
  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {

  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
    MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
  }

  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                      AnfNodeConfigPtr out_conf) override;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
  std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

 private:
@@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator {
  ~VirtualEvaluator() override = default;
  MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);

  AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
  std::string ToString() const override { return identifier_; }

 private:
@@ -285,11 +292,11 @@ class JEvaluator : public Evaluator {
    }
    bound_node_ = AnfNodeWeakPtr(node);
  }
  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {

  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
    MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
  }
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                      AnfNodeConfigPtr out_conf) override;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
  std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }

 private:
@@ -135,13 +135,17 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {

using mindspore::parse::PyObjectWrapper;

AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
  prim_->BeginRecordAddAttr();
  AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
  return abs_base;
  prim_->EndRecordAddAttr();
  auto added_attrs = prim_->evaluate_added_attrs();
  auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
  return infer_result;
}

AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                          AnfNodeConfigPtr out_conf) {
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                        AnfNodeConfigPtr out_conf) {
  AbstractBasePtrList args_spec_list;
  if (!prim_->isa<prim::DoSignaturePrimitive>()) {
    MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
@@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};

  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });

  ScopePtr scope = kDefaultScope;
  if (out_conf != nullptr) {
@@ -212,8 +216,8 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
  return graph_specialize_args;
}

AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                          AnfNodeConfigPtr out_conf) {
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                        AnfNodeConfigPtr out_conf) {
  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
  }
@@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config
  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
  AbstractBasePtrList args_spec_list;
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
  // get the forward graph
  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
  AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
@@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
} // end anonymous namespace

AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();

  const auto &iter = cache_->find(args);
@@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
    MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
  }
  auto infer_fuc = pyobj.attr("__infer__");

  prim_py_->BeginRecordAddAttr();
  py::dict output = infer_fuc(*py_args);
  prim_py_->EndRecordAddAttr();
  auto added_attrs = prim_py_->evaluate_added_attrs();
  MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
  auto res_spec = PyInferRes2Abstract(prim_py_, output);

  MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
  (*cache_)[args] = res_spec;
  return res_spec;
  auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
  (*cache_)[args] = infer_result;
  return infer_result;
}

AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
  // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
  if (nargs_ != args.size()) {
    MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
@@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const
  }

  AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
  return abs_base;
  return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
}

ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
@@ -553,8 +560,8 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
  manager->AddFuncGraph(func_graph);
}

AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
                                     const AnfNodeConfigPtr &old_conf) {
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
                                   const AnfNodeConfigPtr &old_conf) {
  MS_EXCEPTION_IF_NULL(old_conf);

  AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
@@ -585,9 +592,9 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
  return eng->ForwardConfig(old_conf, fn_conf);
}

AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
                                                    const AbstractBasePtrList &args_spec_list,
                                                    const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
                                                  const AbstractBasePtrList &args_spec_list,
                                                  const AnfNodeConfigPtr &out_conf) {
  // args_spec_list: same as StaticGetter
  if (args_spec_list.size() < 2) {
    MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
@@ -627,9 +634,9 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
  return eng->ForwardConfig(out_conf, fn_conf);
}

AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
                                                      const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
                                                      const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
                                                    const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
                                                    const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "args_spec_list is empty";
  }
@@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e

  AbstractBasePtr attr = cls->GetAttribute(item_name);
  if (attr != nullptr) {
    return attr;
    return std::make_shared<EvalResult>(attr, nullptr);
  }

  ValuePtr method = cls->GetMethod(item_name);
@@ -660,9 +667,9 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
  return StaticGetterInferred(converted_v, data_conf, out_conf);
}

AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
                                                      const TypePtr &data_type, const ConfigPtr &data_conf,
                                                      const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
                                                    const TypePtr &data_type, const ConfigPtr &data_conf,
                                                    const AnfNodeConfigPtr &out_conf) {
  MS_EXCEPTION_IF_NULL(item_v);
  MS_EXCEPTION_IF_NULL(data_type);
  // The method maybe a Primitive or Composite
@@ -689,8 +696,8 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e
  return StaticGetterInferred(converted_v, data_conf, out_conf);
}

AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                             const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                           const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
  // Inputs: namespace and its static function; or class and its member function
  CheckArgsSize("StaticGetter", args_spec_list, 2);

@@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
  EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
  ~EmbedEvaluator() override = default;
  MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
  AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
    // arg: free variable to be embedded
    if (args_conf_list.size() != 1) {
      MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
@@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
    AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
    MS_EXCEPTION_IF_NULL(node_conf);

    AbstractBasePtr x = node_conf->GetEvaluatedValue();
    AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
    x = SensitivityTransform(x);
    SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
    AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
    return abs_scalar;
    return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  }
};

@@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
  RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
  ~RefToEmbedEvaluator() override = default;
  MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
  AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
  EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
    if (args_conf_list.size() != 1) {
      MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
      return nullptr;
@@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
      MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
      return nullptr;
    }
    AbstractBasePtr abs = node_conf->GetEvaluatedValue();
    AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
    AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
    if (ref_abs == nullptr) {
      MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
@@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
    }
    auto refkey = key_value->cast<RefKeyPtr>();
    if (refkey == nullptr) {
      return std::make_shared<AbstractScalar>(type);
      return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
    }

    std::string name = refkey->tag();
@@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
    x = SensitivityTransform(x);
    std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
    std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
    return abs_scalar;
    return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
  }
};

@@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
  GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
  ~GetAttrEvaluator() override = default;
  MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
  AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                           const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                         const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
    // Inputs: data, item
    if (args_spec_list.size() != 2) {
      MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
    }
    AbstractBasePtr ret = nullptr;
    EvalResultPtr ret = nullptr;
    if (bound_node() != nullptr) {
      TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
      ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
@@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
  ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
  ~ResolveEvaluator() override = default;
  MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
  AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                           const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                         const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
    // Inputs: namespace, symbol
    if (args_spec_list.size() != 2) {
      MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
    }
    AbstractBasePtr ret = nullptr;
    EvalResultPtr ret = nullptr;
    if (bound_node() != nullptr) {
      TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
      ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
@@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
  CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
  ~CreateInstanceEvaluator() override = default;
  MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
  AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                           const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override {
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
                         const AnfNodeConfigPtr &out_conf) override {
    if (args_spec_list.empty()) {
      MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
    }
@@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
    }

    AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
    (*cache_)[args_spec_list] = ret;
    return ret;
    auto infer_result = std::make_shared<EvalResult>(ret, nullptr);
    (*cache_)[args_spec_list] = infer_result;
    return infer_result;
  }

  pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
@@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator {
 public:
  PartialEvaluator() : Evaluator("PartialEvaluator") {}
  ~PartialEvaluator() override = default;
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                      AnfNodeConfigPtr out_conf = nullptr) override {
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                    AnfNodeConfigPtr out_conf = nullptr) override {
    if (args_conf_list.size() == 0) {
      MS_LOG(EXCEPTION) << "Args size should be greater than 0";
    }

    MS_EXCEPTION_IF_NULL(out_conf);
    MS_EXCEPTION_IF_NULL(out_conf->node());

    auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
    auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract();
    AbstractBasePtrList args_spec_list{arg0_value};
    // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
    if (arg0_value->isa<AbstractError>()) {
      auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
      MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
                    << " as func is: " << arg0_value->ToString();
      (*cache_)[args_spec_list] = ret;
      return ret;
      auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
      (*cache_)[args_spec_list] = eval_result;
      return eval_result;
    }
    auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
    // Sometimes, node[0] in out_conf becomes phi0;
@@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator {
      }
    }

    (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
                         [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); });
    (void)std::transform(
      args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
      [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); });
    AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());

    auto cnode = out_conf->node()->cast<CNodePtr>();
@@ -989,16 +999,17 @@ class PartialEvaluator : public Evaluator {
    func->Visit(build_partial);

    auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
    (*cache_)[args_spec_list] = ret;
    return ret;
    auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
    (*cache_)[args_spec_list] = infer_result;
    return infer_result;
  }

  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  }

  AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
                                    const AnfNodeConfigPtr &out_conf = nullptr) const {
  EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
                                  const AnfNodeConfigPtr &out_conf = nullptr) const {
    MS_EXCEPTION_IF_NULL(out_conf);
    MS_EXCEPTION_IF_NULL(out_conf->node());
    auto cnode = out_conf->node()->cast<CNodePtr>();
@@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
      : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
  ~StandardPrimEvaluator() override = default;
  MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
  AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  PrimitivePtr prim() { return prim_; }

  std::string ToString() const override { return identifier_ + prim_->name(); }
@@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator {
      : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
  ~PythonPrimEvaluator() override = default;
  MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
  AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }

  std::string ToString() const override { return identifier_ + prim_py_->name(); }
@@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator {
 public:
  explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
  ~DoSignatureEvaluator() override = default;
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
                      AnfNodeConfigPtr out_config = nullptr) override;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
                    AnfNodeConfigPtr out_config = nullptr) override;

  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  }

@@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator {
 public:
  explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
  ~UnpackGraphEvaluator() override = default;
  AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
                      AnfNodeConfigPtr out_config = nullptr) override;
  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
                    AnfNodeConfigPtr out_config = nullptr) override;

  AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
  }

@@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator {
  ~UniformPrimEvaluator() override = default;
  MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);

  AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
  ValuePtr RunImpl(const ValuePtrList &args) const;

  // If eval_value_ is False, return broadened arguments.
@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
|
|||
if (conf->node()->intermediate_abstract()) {
|
||||
return conf->node()->intermediate_abstract();
|
||||
}
|
||||
return conf->GetEvaluatedValue();
|
||||
return conf->GetEvaluatedValue()->abstract();
|
||||
}
|
||||
|
||||
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
|
||||
|
@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() {
|
|||
|
||||
// Specialize CNode in func graphs
|
||||
void FuncGraphSpecializer::SecondPass() {
|
||||
for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
|
||||
for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) {
|
||||
if (node->isa<CNode>()) {
|
||||
ProcessCNode(node->cast<CNodePtr>());
|
||||
}
|
||||
|
@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
AnfNodeConfigPtr conf = MakeConfig(node);
|
||||
AnfNodePtr new_node = GetReplicatedNode(node);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
|
||||
if (new_node->func_graph() != specialized_func_graph_) {
|
||||
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
|
||||
<< ", new_node: " << new_node->DebugString()
|
||||
|
@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
|
||||
|
||||
if (node->isa<CNode>()) {
|
||||
auto attrs = conf->GetEvaluatedValue()->attribute();
|
||||
auto c_old = node->cast<CNodePtr>();
|
||||
auto c_new = new_node->cast<CNodePtr>();
|
||||
auto new_inputs = c_new->inputs();
|
||||
|
@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
|
|||
AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
|
||||
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
|
||||
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
|
||||
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
|
||||
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
|
||||
if (replace_node == nullptr) {
|
||||
replace_node = BuildReplacedNode(iconf);
|
||||
MS_EXCEPTION_IF_NULL(replace_node);
|
||||
|
@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
|
|||
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
|
||||
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
|
||||
}
|
||||
auto attrs = std::make_shared<AttrValueMap>();
|
||||
for (size_t i = 0; i < partial_closure->args().size(); i++) {
|
||||
auto old_node = cnode->input(i + 2);
|
||||
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
|
||||
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
|
||||
if (possibile_value_node != nullptr) {
|
||||
partial_node_list.push_back(possibile_value_node);
|
||||
} else {
|
||||
|
@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
const EvaluatorPtr &eval) {
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
|
||||
AbstractBasePtr ret = nullptr;
|
||||
EvalResultPtr ret = nullptr;
|
||||
AbstractBasePtrList broaded_argvals;
|
||||
for (auto &argvals_map : *evalcaches_[eval]) {
|
||||
auto argvals = argvals_map.first;
|
||||
|
@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
|
|||
|
||||
(*real)[broaded_argvals] = ret;
|
||||
evalcaches_[eval] = real;
|
||||
return std::make_pair(broaded_argvals, ret);
|
||||
return std::make_pair(broaded_argvals, ret->abstract());
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
|
||||
return std::make_pair(AbstractBasePtrList(), nullptr);
|
||||
|
@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
return;
|
||||
}
|
||||
specializer_->AddSeen(new_node);
|
||||
|
||||
auto new_inputs = new_node->inputs();
|
||||
if (new_inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
|
||||
|
@ -530,7 +530,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
}
|
||||
|
||||
if (CanSpecializeNode(func)) {
|
||||
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
|
||||
// for primitive node , we build the primitive node with infered attributes in the first pass
|
||||
// so we do not build replaced node again here in second pass
|
||||
if (IsValueNode<Primitive>(func)) {
|
||||
new_inputs[0] = func;
|
||||
} else {
|
||||
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < argvals.size();) {
|
||||
|
@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
|
|||
}
|
||||
i = next;
|
||||
}
|
||||
|
||||
new_node->set_inputs(new_inputs);
|
||||
}
|
||||
|
||||
|
@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
|
||||
EvaluatorCacheMap evaluator_cache_map = *eval->cache();
|
||||
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
|
||||
*result = std::make_pair(argvals, evaluator_cache_map[argvals]);
|
||||
*result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
|
||||
return kSpecializeSuccess;
|
||||
}
|
||||
DumpEvaluatorCache(evaluator_cache_map, argvals);
|
||||
|
@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
|
|||
MS_EXCEPTION_IF_NULL(choices);
|
||||
|
||||
if (choices->count(argvals)) {
|
||||
*result = std::make_pair(argvals, (*choices)[argvals]);
|
||||
*result = std::make_pair(argvals, (*choices)[argvals]->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices->size() == 1) {
|
||||
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
|
||||
*result = std::make_pair(choices->begin()->first, choices->begin()->second);
|
||||
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
|
||||
return kSpecializeSuccess;
|
||||
} else if (choices->empty()) {
|
||||
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
|
||||
|
@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
    return kSpecializeFindUniqueArgvalPoly;
  }
}

static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
  auto &prim_attrs = prim->attrs();
  bool is_attr_same = true;
  for (auto &item : *attrs) {
    auto itr = prim_attrs.find(item.first);
    if (itr != prim_attrs.end()) {
      if (!(*(itr->second) == *(item.second))) {
        is_attr_same = false;
        break;
      }
    } else {
      is_attr_same = false;
      break;
    }
  }
  if (!is_attr_same) {
    if (prim->isa<PrimitivePy>()) {
      PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
      auto clone_fn = prim_py->GetPyObj().attr("_clone");
      py::object new_obj = clone_fn();
      auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
      for (auto &item : *attrs) {
        cloned_prim->AddAttr(item.first, item.second);
      }
      return cloned_prim;
    }
    auto cloned_prim = std::make_shared<Primitive>(*prim);
    for (auto &item : *attrs) {
      cloned_prim->AddAttr(item.first, item.second);
    }
    return cloned_prim;
  }
  return prim;
}

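BuildPrimtiveValueWithAttributes clones the primitive only when the inferred attributes differ from the ones the primitive already carries; otherwise the original primitive is reused and no per-CNode copy is created. A self-contained sketch of just the comparison step, with simplified stand-in types (these names are illustrative, not the real API):

#include <memory>
#include <string>
#include <unordered_map>

// Stand-ins for ValuePtr and AttrValueMap, only so the sketch compiles on its own.
using ValuePtr = std::shared_ptr<std::string>;
using AttrMap = std::unordered_map<std::string, ValuePtr>;

// True when every inferred attribute already exists on the primitive with an
// equal value; only then can the primitive be shared instead of cloned.
bool AttrsAlreadyMatch(const AttrMap &prim_attrs, const AttrMap &inferred) {
  for (const auto &item : inferred) {
    auto it = prim_attrs.find(item.first);
    if (it == prim_attrs.end() || !(*it->second == *item.second)) {
      return false;
    }
  }
  return true;
}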
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                                        const AttrValueMapPtr &attrs) {
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(ival);

@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
  ValuePtr value = nullptr;
  if (abs->isa<PrimitiveAbstractClosure>()) {
    auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
    value = real_fn->prim();
    // For a primitive, check whether its attributes match the CNode's inferred attributes; if not, clone a new one.
    if (attrs != nullptr) {
      value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
    } else {
      value = real_fn->prim();
    }
  } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
    auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
    value = real_fn->meta_func_graph();

@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
  AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);

  // Build a value node if ival is constant and not any-value
  AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
  AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
                                    const AttrValueMapPtr &attrs);
  // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode in static analysis or just a
  // replicated node.
  AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
  return nullptr;
}

void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
  MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
                << ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString()
                << ", Pointer: " << arg.get();
  cache_[conf] = arg;
                << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
                << ", Pointer: " << result->abstract().get();
  cache_[conf] = result;

  // Set intermediate abstract value.
  if (IsIntermediateAbstract(arg)) {
  if (IsIntermediateAbstract(result->abstract())) {
    if (conf->node()->intermediate_abstract() == nullptr) {
      conf->node()->set_intermediate_abstract(arg);
      MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString();
      conf->node()->set_intermediate_abstract(result->abstract());
      MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
    } else {
      auto old_spec = conf->node()->intermediate_abstract();
      auto joined_spec = IntermediateJoin(arg, old_spec);
      auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
      conf->node()->set_intermediate_abstract(joined_spec);
      MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
                    << arg->ToString() << "\njoined_spec:\t"
                    << result->abstract()->ToString() << "\njoined_spec:\t"
                    << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
    }
  }
}

AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
  auto value = cache_.find(conf);
  if (value == cache_.end()) {
    return nullptr;
@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
  return eval->graph_context();
}

AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  auto value = cache_.GetValue(conf);
  if (value != nullptr) {
    MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", "
                  << value->ToString();
    MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get()
                  << ", " << value->abstract()->ToString();
    return value;
  }

@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
  return value;
}

AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  AnfNodePtr node = conf->node();
  AbstractBasePtr ret_abstract = nullptr;
  EvalResultPtr eval_result = nullptr;
#ifdef DEBUG
  compute_conf_stack_.push_back(node);
  std::ostringstream buffer;

@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->abstract() != nullptr) {
    MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
    ret_abstract = node->abstract();
    eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
  } else if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    ret_abstract = EvalValueNode(value_node, conf);
    eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr);
  } else if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    trace::TraceEvalCNodeEnter(conf);
    ret_abstract = EvalCNode(cnode, conf);
    eval_result = EvalCNode(cnode, conf);
    trace::TraceEvalCNodeLeave();
  } else {
    MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()

@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {

#ifdef DEBUG
  compute_conf_stack_.pop_back();
  if (ret_abstract == nullptr) {
  if (eval_result == nullptr) {
    MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
                      << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
  }
#endif
  MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString();
  return ret_abstract;
  MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
  return eval_result;
}

AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
  return ToAbstract(value_node->value(), conf->context(), conf);
}

AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
  MS_EXCEPTION_IF_NULL(conf);
  MS_EXCEPTION_IF_NULL(cnode);
  auto &inputs = cnode->inputs();

@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
  AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
  MS_EXCEPTION_IF_NULL(func_conf);
  // Keep it in a local variable, otherwise smart pointer will free it.
  AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue();
  AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract();
  if (maybe_func == nullptr) {
    MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
                      << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());

@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
  return ExecuteEvaluators(infs, conf, args_conf_list);
}

AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
  ConfigPtrList args_conf_list;
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
                       [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });

@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  return tracked_eval;
}

AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                                  const AnfNodeConfigPtr &out_conf,
                                                  const ConfigPtrList &args_conf_list) {
EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                                const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
  if (evaluators.size() == 1) {
    EvaluatorPtr eval = evaluators[0];
    MS_EXCEPTION_IF_NULL(eval);

@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
  return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
}

AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                                          const AnfNodeConfigPtr &out_conf,
                                                          const ConfigPtrList &args_conf_list) {
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                                        const AnfNodeConfigPtr &out_conf,
                                                        const ConfigPtrList &args_conf_list) {
  AbstractBasePtrList out_specs;
  if (!multi_poss_.count(evaluators[0])) {
    multi_poss_[evaluators[0]] = evaluators[1];

@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
                       [](const ConfigPtr &conf) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(conf);
                         return conf->GetEvaluatedValue();
                         return conf->GetEvaluatedValue()->abstract();
                       });
  for (auto eval : evaluators) {
    auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
      eval_trace_.push_back(current_inf);
      MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
      MS_EXCEPTION_IF_NULL(eval);
      auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
      MS_EXCEPTION_IF_NULL(out_spec);
      MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
      out_specs.push_back(out_spec);
      MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
      auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
      MS_EXCEPTION_IF_NULL(eval_result->abstract());
      MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
      out_specs.push_back(eval_result->abstract());
      eval_trace_.pop_back();
      if (eval_trace_.empty()) {
        multi_poss_.clear();

@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
        // Try to travel the latest undetermined.
        if (latest_entry != eval_trace_.rbegin()->first) {
          MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
          auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
          MS_EXCEPTION_IF_NULL(out_spec);
          MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString();
          return out_spec;
          auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
          MS_EXCEPTION_IF_NULL(eval_result->abstract());
          MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
                        << " return out_spec: " << eval_result->abstract()->ToString();
          return eval_result;
        }
      }
    }

@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
  if (out_specs.size() == 1) {
    MS_EXCEPTION_IF_NULL(out_specs[0]);
    // If only one result derived, then broaden it to avoid wrong constant propagation.
    return out_specs[0]->Broaden();
    return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
  }
  auto joined_spec = AbstractJoin(out_specs);
  MS_EXCEPTION_IF_NULL(joined_spec);
  MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
  return joined_spec;
  return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}

AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() {
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
  AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
  return engine_.lock()->GetEvaluatedValue(self);
}

@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
  return a;
}

AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
  auto evaluator = GetPrimEvaluator(primitive, nullptr);
  MS_EXCEPTION_IF_NULL(evaluator);
  if (!evaluator->isa<TrivialPrimEvaluator>()) {

@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
                      << evaluator->ToString();
  }
  auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
  auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs);
  return res_spec;
  auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
  return eval_result;
}
}  // namespace abstract
}  // namespace mindspore
@ -40,13 +40,33 @@

namespace mindspore {
namespace abstract {

// Define the attribute value map.
using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;

// The class to save an evaluated result: the abstract value and the attributes modified during evaluation.
class EvalResult : public Base {
 public:
  EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
  ~EvalResult() override = default;
  MS_DECLARE_PARENT(EvalResult, Base);
  AbstractBasePtr abstract() { return abstract_; }
  AttrValueMapPtr attribute() { return attribute_; }

 private:
  AbstractBasePtr abstract_;
  AttrValueMapPtr attribute_;
};

using EvalResultPtr = std::shared_ptr<EvalResult>;
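EvalResult is the new unit stored by the analysis caches: the inferred abstract value plus whatever attributes evaluation added. A usage sketch under the declarations above; the two placeholder pointers are assumptions, not values taken from the patch:

// Sketch only: `inferred_abstract` stands for an AbstractBasePtr produced by an
// evaluator and `dst_type_value` for a ValuePtr such as an inferred "DstT" attribute.
AbstractBasePtr inferred_abstract = nullptr;  // placeholder
ValuePtr dst_type_value = nullptr;            // placeholder

auto attrs = std::make_shared<AttrValueMap>();
(*attrs)["DstT"] = dst_type_value;

EvalResultPtr result = std::make_shared<EvalResult>(inferred_abstract, attrs);
AbstractBasePtr abs = result->abstract();     // the inferred abstract value
AttrValueMapPtr added = result->attribute();  // attributes added during evaluation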
// Superclass for AnfNodeConfig and VirtualConfig.
class Config : public Base {
 public:
  Config() = default;
  ~Config() override = default;
  MS_DECLARE_PARENT(Config, Base);
  virtual AbstractBasePtr GetEvaluatedValue() = 0;
  virtual EvalResultPtr GetEvaluatedValue() = 0;
};

// Config will be stored in AnalysisCache

@ -74,7 +94,7 @@ class AnfNodeConfig : public Config {
  ~AnfNodeConfig() override = default;
  MS_DECLARE_PARENT(AnfNodeConfig, Config);

  AbstractBasePtr GetEvaluatedValue() override;
  EvalResultPtr GetEvaluatedValue() override;

  AnalysisContextPtr context() const { return context_; }

@ -123,7 +143,9 @@ class VirtualConfig : public Config {

  ~VirtualConfig() override = default;
  MS_DECLARE_PARENT(VirtualConfig, Config);
  AbstractBasePtr GetEvaluatedValue() override { return abstract_; }
  EvalResultPtr GetEvaluatedValue() override {
    return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
  }

 private:
  AbstractBasePtr abstract_;

@ -135,11 +157,11 @@ class AnalysisCache {
  AnalysisCache() = default;
  ~AnalysisCache() = default;
  void Clear() { cache_.clear(); }
  void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
  AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf);
  void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
  EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);

 private:
  std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
  std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
};

using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;

@ -147,7 +169,7 @@ using AnfNodeConfigMap =
  std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;

struct AnalysisResult {
  AbstractBasePtr inferred;
  EvalResultPtr inferred;
  AnalysisContextPtr context;
};

@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
  // func_graph: The func_graph to analyze.
  // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
  AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
  AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
  EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
  // Return the Evaluator for the given function.
  EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);

  AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
  AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
  EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
  // Infer the result of fn(args).
  AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
  EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
  void Clear();
  void ClearEvaluatorCache();
  AnalysisCache &cache() { return cache_; }

@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {

  // Set the analysis result for orig to the result for new.
  // This sets an entry in anfnode_config_map from orig to new.
  AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
  EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
    // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
    (void)anfnode_config_map_.emplace(orig_conf, new_conf);
    MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()

@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {

  AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
                         const ConfigPtrList &args_conf_list);
  AbstractBasePtr Eval(const AnfNodeConfigPtr &conf);
  EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
  EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
  AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
                                    const ConfigPtrList &args_conf_list);
  AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
                                            const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list);
  EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
                                  const ConfigPtrList &args_conf_list);
  EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
                                          const ConfigPtrList &args_conf_list);

#ifdef DEBUG
  std::vector<AnfNodePtr> compute_conf_stack_;

@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
  return FromValueInside(MakeValue(value), broaden);
}

AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
}  // namespace abstract
}  // namespace mindspore
@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
      args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
    }
  }
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list);
  AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
  op_exec_info->abstract = infer_res;
}

@ -26,6 +26,8 @@
#include <list>
#include <string>
#include <fstream>
#include <queue>
#include <set>

#include "ir/visitor.h"
#include "utils/log_adapter.h"

@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
  return res;
}
// search the cnodes inside this graph only
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
  std::queue<CNodePtr> todo;
  todo.push(ret);
  std::vector<CNodePtr> sorted_nodes;
  auto seen = NewSeenGeneration();
  while (!todo.empty()) {
    CNodePtr top = todo.front();
    todo.pop();
    sorted_nodes.push_back(top);
    auto inputs = top->inputs();
    for (auto &item : inputs) {
      if (item->seen_ == seen) {
        continue;
      }

      if (item->isa<CNode>()) {
        todo.push(item->cast<CNodePtr>());
      }
      item->seen_ = seen;
    }
  }
  return sorted_nodes;
}

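BroadFirstSearchGraphCNodes walks CNodes breadth-first from the node it is given (typically a graph's return node) and does not descend into called graphs. A small usage sketch, assuming func_graph is a valid FuncGraphPtr:

// Sketch: collect this graph's CNodes in breadth-first order and visit them.
std::vector<CNodePtr> cnodes = BroadFirstSearchGraphCNodes(func_graph->get_return());
for (const auto &cnode : cnodes) {
  MS_LOG(DEBUG) << "Visit CNode: " << cnode->DebugString();
}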
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
  std::vector<AnfNodePtr> vecs;
  if (node == nullptr) {

@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                 const IncludeFunc &include = AlwaysInclude);

std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
class FuncGraphIndex {
 public:
  explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init ExpandDims"""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])

    def __infer__(self, x, axis):

@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer):
        # if primitive need setattr in __infer__ need add this flag
        """init Cast"""
        self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
        self.__setattr_flag__ = True

    def __infer__(self, x, t):
        src_type = x['dtype']

@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer):
    def __init__(self):
        """init Reshape"""
        self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
        self.__setattr_flag__ = True

    def __infer__(self, x, shape):
        shape_v = shape['value']

@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init Transpose"""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])

    def __infer__(self, x, perm):

@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init index_select"""
        self.__setattr_flag__ = True
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])

    def __infer__(self, params, indices, axis):

@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, axis=0):
        """init Tile"""
        self.__setattr_flag__ = True
        validator.check_value_type("axis", axis, [int], self.name)

    def __infer__(self, input_x):

@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, axis=0):
        """init Pack"""
        self.__setattr_flag__ = True
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, axis=0):
        """init Unpack"""
        self.__setattr_flag__ = True
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis

@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init"""
        self.__setattr_flag__ = True

    def infer_shape(self, cond_shape, x_shape, y_shape):
        if cond_shape != x_shape or x_shape != y_shape:
|
|||
@prim_attr_register
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
self.__setattr_flag__ = True
|
||||
cls_name = self.name
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
|
@ -596,7 +595,6 @@ class BatchMatMul(MatMul):
|
|||
@prim_attr_register
|
||||
def __init__(self, transpose_a=False, transpose_b=False):
|
||||
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
|
||||
self.__setattr_flag__ = True
|
||||
cls_name = self.name
|
||||
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
|
||||
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
||||
|
@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.__setattr_flag__ = True
|
||||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||
|
||||
def infer_shape(self, inputs):
|
||||
|
|
|
@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer):
        """init Conv2D"""
        self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
        self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
        self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
        self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('stride', self.stride)
        self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
        self.add_prim_attr('dilation', self.dilation)
        validator.check_value_type('pad', pad, (int,), self.name)

@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer):

        self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
        self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))

        out_channel = self.out_channel
        out_shape = [x_shape[0], out_channel, h_out, w_out]
        return out_shape
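The Conv2D change stores `stride` (and `dilation`) as a full 4-element attribute instead of rebuilding a 4-tuple from the first two elements. A rough Python illustration of that kind of normalization; _normalize_stride is a hypothetical helper written for this sketch, not the project's _check_positive_int_or_tuple:

def _normalize_stride(stride):
    """Hypothetical sketch: accept an int, 2-tuple or 4-tuple and return a 4-tuple."""
    if isinstance(stride, int):
        value = (1, 1, stride, stride)
    elif len(stride) == 2:
        value = (1, 1, stride[0], stride[1])
    elif len(stride) == 4:
        value = tuple(stride)
    else:
        raise ValueError("stride must be an int, a 2-tuple or a 4-tuple")
    if any(not isinstance(v, int) or v <= 0 for v in value):
        raise ValueError("stride values must be positive integers")
    return value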
@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype

from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.primitive import constexpr
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)


def test_cast_op_attr():
    class CastNet(nn.Cell):
        def __init__(self):
            super(CastNet, self).__init__()
            self.cast = P.Cast()
        def construct(self, x, t):
            return self.cast(x, t)

    class CastTypeTest(nn.Cell):
        def __init__(self, net):
            super(CastTypeTest, self).__init__()
            self.net = net
            self.cast = P.Cast()
        def construct(self, x, y, z):
            cast_op = self.cast
            t1 = cast_op(x, mstype.float32)
            t2 = cast_op(y, mstype.int32)
            cast_net = self.net
            t3 = cast_net(x, mstype.float16)
            t4 = cast_net(y, mstype.int32)
            t5 = cast_net(z, mstype.float16)
            return (t1, t2, t3, t4, t5)
    net = CastTypeTest(CastNet())
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
    out = net(t1, t2, t3)
    assert out[0].asnumpy().dtype == np.float32
    assert out[1].asnumpy().dtype == np.int32
    assert out[2].asnumpy().dtype == np.float16
    assert out[3].asnumpy().dtype == np.int32
    assert out[4].asnumpy().dtype == np.float16
@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }

@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }

@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }

@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
  auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tuple_tensor, slice};

  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }

@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) {
  AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
  AbstractBasePtrList args_spec_list = {tensor, slice};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred);
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }

@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }

@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }

@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) {
  AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
  AbstractBasePtrList args_spec_list = {tensor, start_index};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }

@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }

@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
  AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tensor, slice_tuple};

  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract array failed.";
  }

@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
  abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

  AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }

@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
  abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);

  AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }

@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) {
  auto tuple = std::make_shared<AbstractTuple>(eles);
  AbstractBasePtrList args_spec_list = {tuple};

  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred);
  AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
  if (ret == nullptr) {
    FAIL() << "Cast ret to abstract tuple failed.";
  }
@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
  AbstractBasePtr abstract_v2 = FromValue(2, false);
  AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2};
  AbstractBasePtr abstract_val = FromValue(10, false);
  cache[args_spec_list] = abstract_val;
  cache[args_spec_list] = std::make_shared<EvalResult>(abstract_val, std::make_shared<AttrValueMap>());

  auto iter = cache.find(args_spec_list);
  ASSERT_TRUE(iter != cache.end());
  ASSERT_TRUE(iter->second == abstract_val);
  ASSERT_TRUE(iter->second->abstract() == abstract_val);

  AbstractBasePtr abstract_v1_variant1 = FromValue(1, false);
  AbstractBasePtr abstract_v2_variant1 = FromValue(2, false);

@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {

  iter = cache.find(args_spec_list_variant1);
  ASSERT_TRUE(iter != cache.end());
  ASSERT_TRUE(iter->second == abstract_val);
  ASSERT_TRUE(iter->second->abstract() == abstract_val);

  AbstractBasePtr abstract_v1_variant2 = FromValue(1, false);
  AbstractBasePtr abstract_v2_variant2 = FromValue(3, false);

@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) {
  std::vector<int> shape = {2, 2, 6, 6};
  expected->set_shape(std::make_shared<Shape>(shape));

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  MS_LOG(INFO) << "expected: " << expected->ToString();

@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) {
  AbstractBasePtr abstract_x = FromValue(x, false);
  args_spec_list.push_back(abstract_x);

  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}

@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) {
  AbstractBasePtr abstract_x = FromValue(x, false);
  args_spec_list.push_back(abstract_x);

  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}

@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) {
  args_spec_list.push_back(abstract_x);
  args_spec_list.push_back(abstract_y);

  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}

@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) {
  args_spec_list.push_back(abstract_x);
  args_spec_list.push_back(abstract_y);

  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}

@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) {
  args_spec_list.push_back(abstract_x);
  args_spec_list.push_back(abstract_y);

  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}

@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) {
  args_spec_list.push_back(abstract_x);
  args_spec_list.push_back(abstract_y);

  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) {
|
|||
|
||||
auto prim_typeof = std::make_shared<Primitive>("typeof");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1);
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
res->dump();
|
||||
TypePtr res_value = res->GetValueTrack()->cast<TypePtr>();
|
||||
res_value->dump();
|
||||
|
@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) {
|
|||
|
||||
auto prim_list_map = std::make_shared<Primitive>("list_map");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3);
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)}));
|
||||
res->dump();
|
||||
MS_LOG(INFO) << "result res: " << res->ToString();
|
||||
|
@ -188,7 +188,7 @@ TEST_F(TestPrim, test_list_reduce) {
|
|||
|
||||
auto prim_list_reduce = std::make_shared<Primitive>("list_reduce");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3);
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
res->dump();
|
||||
TypePtr res_type = res->GetTypeTrack();
|
||||
res_type->dump();
|
||||
|
@ -205,7 +205,7 @@ TEST_F(TestPrim, test_scalar_to_array) {
|
|||
|
||||
auto prim_scalar_to_array = std::make_shared<Primitive>("scalar_to_array");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1);
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
res->dump();
|
||||
TypePtr res_type = res->BuildType();
|
||||
res_type->dump();
|
||||
|
@ -223,7 +223,7 @@ TEST_F(TestPrim, test_array_to_scalar) {
|
|||
|
||||
auto prim_array_to_scalar = std::make_shared<Primitive>("array_to_scalar");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1);
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
res->dump();
|
||||
TypePtr res_type = res->BuildType();
|
||||
res_type->dump();
|
||||
|
@ -239,7 +239,7 @@ TEST_F(TestPrim, test_J_1) {
|
|||
|
||||
auto prim_J = std::make_shared<Primitive>("J");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1);
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
AbstractJTaggedPtr res_J = dyn_cast<AbstractJTagged>(res);
|
||||
ASSERT_TRUE(res_J != nullptr);
|
||||
ASSERT_TRUE(*(res_J->element()) == *abstract_v1);
|
||||
|
@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) {
|
|||
int v1 = 1;
|
||||
AbstractBasePtr abstract_v1 = FromValue(v1, false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_v1};
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
res->dump();
|
||||
AbstractTuplePtr res_J = dyn_cast<AbstractTuple>(res);
|
||||
ASSERT_TRUE(res_J != nullptr);
|
||||
|
@ -302,7 +302,7 @@ TEST_F(TestPrim, test_dot) {
|
|||
|
||||
AbstractBasePtrList args_spec_list = {a1, a2};
|
||||
|
||||
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred);
|
||||
AbstractTensorPtr res = dyn_cast<AbstractTensor>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||
|
||||
ASSERT_TRUE(*(dyn_cast<Shape>(res->GetShapeTrack())) == *(dyn_cast<Shape>(expected->GetShapeTrack())));
|
||||
}
|
||||
|
@ -317,7 +317,7 @@ TEST_F(TestPrim, test_switch1) {
|
|||
AbstractBasePtr arg2 = FromValue(2, false);
|
||||
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
ASSERT_TRUE(*res == *arg1);
|
||||
}
|
||||
|
||||
|
@ -330,7 +330,7 @@ TEST_F(TestPrim, test_switch2) {
|
|||
AbstractBasePtr arg2 = FromValue(2, false);
|
||||
AbstractBasePtrList args_spec_list = {arg0, arg1, arg2};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
MS_LOG(INFO) << "make result res: " << res->ToString();
|
||||
MS_LOG(INFO) << "make result arg2: " << arg2->ToString();
|
||||
ASSERT_TRUE(*res == *arg2);
|
||||
|
@ -343,7 +343,7 @@ TEST_F(TestPrim, test_identity) {
|
|||
AbstractBasePtr abstract_v1 = FromValue(1, false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_v1};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
ASSERT_TRUE(*res == *abstract_v1);
|
||||
}
|
||||
|
||||
|
@ -357,7 +357,7 @@ TEST_F(TestPrim, test_broadcast_shape) {
|
|||
|
||||
AbstractBasePtrList args_spec_list = {a, b};
|
||||
|
||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred);
|
||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||
|
||||
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
||||
std::vector<ValuePtr> element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)};
|
||||
|
@ -377,7 +377,7 @@ TEST_F(TestPrim, test_partial) {
|
|||
AbstractBasePtr abstract_v2 = FromValue(1, false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2};
|
||||
auto expected = std::make_shared<PartialAbstractClosure>(
|
||||
std::make_shared<PrimitiveAbstractClosure>(prim::kPrimScalarAdd), fn_args_list);
|
||||
|
@ -392,7 +392,7 @@ TEST_F(TestPrim, test_env_setitem) {
|
|||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||
AbstractBasePtr abstract_x = FromValue(1, false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred;
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||
|
||||
|
@ -400,7 +400,7 @@ TEST_F(TestPrim, test_env_setitem) {
|
|||
AbstractBasePtr abstract_y = FromValue(2, false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
ASSERT_TRUE(*res == *exp);
|
||||
}
|
||||
|
@ -412,7 +412,7 @@ TEST_F(TestPrim, test_env_getitem) {
|
|||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||
AbstractBasePtr abstract_x = FromValue(1, false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred;
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||
|
||||
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||
|
||||
|
@ -420,7 +420,7 @@ TEST_F(TestPrim, test_env_getitem) {
|
|||
AbstractBasePtr abstract_y = FromValue(2, false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
ASSERT_TRUE(*res == *exp);
|
||||
|
||||
|
@ -429,7 +429,7 @@ TEST_F(TestPrim, test_env_getitem) {
|
|||
AbstractBasePtr abstract_z = FromValue(3, false);
|
||||
args_spec_list = {res, embed_x, abstract_z};
|
||||
|
||||
res = engine_->Run(graph_getitem, args_spec_list).inferred;
|
||||
res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract();
|
||||
|
||||
ASSERT_TRUE(*res == *abstract_x);
|
||||
}
|
||||
|
@ -442,7 +442,7 @@ TEST_F(TestPrim, test_env_add) {
|
|||
FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1);
|
||||
AbstractBasePtr abstract_x = FromValue(1, false);
|
||||
AbstractBasePtrList args_spec_list = {abstract_x};
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred;
|
||||
AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract();
|
||||
|
||||
FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3);
|
||||
|
||||
|
@ -450,19 +450,19 @@ TEST_F(TestPrim, test_env_add) {
|
|||
AbstractBasePtr abstract_y = FromValue(2, false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_y};
|
||||
|
||||
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred;
|
||||
AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||
AbstractBasePtr exp = std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
|
||||
ASSERT_TRUE(*abstract_e1 == *exp);
|
||||
|
||||
AbstractBasePtr abstract_z = FromValue(3, false);
|
||||
args_spec_list = {abstract_env, embed_x, abstract_z};
|
||||
|
||||
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred;
|
||||
AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract();
|
||||
ASSERT_TRUE(*abstract_e2 == *exp);
|
||||
|
||||
FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2);
|
||||
args_spec_list = {abstract_e1, abstract_e2};
|
||||
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract();
|
||||
|
||||
ASSERT_TRUE(*res == *exp);
|
||||
}
|
||||
|
@ -475,7 +475,7 @@ TEST_F(TestPrim, test_shape) {
|
|||
|
||||
AbstractBasePtrList args_spec_list = {a};
|
||||
|
||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred);
|
||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
||||
|
||||
std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)};
|
||||
|
@ -493,7 +493,7 @@ TEST_F(TestPrim, test_relu) {
|
|||
AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW
|
||||
AbstractBasePtrList args_spec_list = {expected};
|
||||
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
ASSERT_TRUE(*res == *expected);
|
||||
}
|
||||
|
||||
|
@ -507,7 +507,7 @@ TEST_F(TestPrim, test_relu2) {
|
|||
auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5});
|
||||
|
||||
AbstractBasePtrList args_spec_list = {arr};
|
||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
|
||||
AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
|
||||
auto res = dyn_cast<AbstractTensor>(ret);
|
||||
ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack()));
|
||||
}
|
||||
|
@@ -540,7 +540,7 @@ TEST_F(TestPrim, test_conv2d1) {
  std::vector<int> shape = {2, 64, 14, 14};
  expected->set_shape(std::make_shared<Shape>(shape));

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  MS_LOG(INFO) << "expected: " << expected->ToString();

@@ -558,7 +558,7 @@ TEST_F(TestPrim, test_conv2d) {
  auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3});

  AbstractBasePtrList args_spec_list = {input, weight};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  auto res = dyn_cast<AbstractTensor>(ret);
  auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16});
  MS_LOG(INFO) << "result: " << res->ToString();

@@ -574,7 +574,7 @@ TEST_F(TestPrim, test_conv2d_native) {
  auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3});

  AbstractBasePtrList args_spec_list = {input, weight};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  auto res = dyn_cast<AbstractTensor>(ret);
  auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16});
  MS_LOG(INFO) << "result: " << res->ToString();

@@ -590,7 +590,7 @@ TEST_F(TestPrim, test_biasAdd) {
  auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32});

  AbstractBasePtrList args_spec_list = {value, bias};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  auto res = dyn_cast<AbstractTensor>(ret);
  auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32});
  MS_LOG(INFO) << "result: " << res->ToString();

@@ -606,7 +606,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) {
  auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10});

  AbstractBasePtrList args_spec_list = {logits, labels};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_NE(ret, nullptr);
  auto res = dyn_cast<AbstractTuple>(ret);
  auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64});

@@ -636,7 +636,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) {
  auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10});

  AbstractBasePtrList args_spec_list = {logits, labels};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  auto res = dyn_cast<AbstractScalar>(ret);
  AbstractScalarPtr expected = std::make_shared<AbstractScalar>(kAnyValue, kFloat64);
  expected->set_type(UTPrimUtils::kF64);
@@ -690,7 +690,7 @@ TEST_F(TestPrim, test_fused_batch_norm) {
  AbstractBasePtr expected0 = abstract_inputs->Clone();
  AbstractBasePtr expected1 = abstract_scale->Clone();

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  MS_LOG(INFO) << "expected0: " << expected0->ToString();
  MS_LOG(INFO) << "expected1: " << expected1->ToString();

@@ -722,7 +722,7 @@ TEST_F(TestPrim, test_pooling) {
  inputs->set_shape(inputs_dims);
  AbstractBasePtr abstract_input = FromValue(inputs, false);
  AbstractBasePtrList args_spec_list = {abstract_input};
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();

  AbstractBasePtr expected = abstract_input->Clone()->Broaden();
  std::vector<int> expected_dims = {8, 64, 2, 2};

@@ -747,7 +747,7 @@ TEST_F(TestPrim, test_hastype) {
  auto prim = std::make_shared<Primitive>("hastype");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *expected);
}

@@ -761,7 +761,7 @@ TEST_F(TestPrim, test_array_len) {
  auto prim = std::make_shared<Primitive>("array_len");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *expected);
}

@@ -775,7 +775,7 @@ TEST_F(TestPrim, test_list_len) {
  auto prim = std::make_shared<Primitive>("list_len");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *expected);
}

@@ -789,7 +789,7 @@ TEST_F(TestPrim, test_tuple_len) {
  auto prim = std::make_shared<Primitive>("tuple_len");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *expected);
}

@@ -803,7 +803,7 @@ TEST_F(TestPrim, test_tuple_reversed) {
  auto prim = std::make_shared<Primitive>("tuple_reversed");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 1);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "expect=" << expected->ToString();
  ASSERT_TRUE(*res == *expected);
}
@@ -825,7 +825,7 @@ TEST_F(TestPrim, test_list_getitem) {
  auto prim = std::make_shared<Primitive>("list_getitem");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *elem);
}

@@ -844,7 +844,7 @@ TEST_F(TestPrim, test_list_setitem) {
  auto prim = std::make_shared<Primitive>("list_setitem");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  AbstractBasePtrList elems_exp = {elem1, elem2};
  auto expected = std::make_shared<AbstractList>(elems_exp);

@@ -866,7 +866,7 @@ TEST_F(TestPrim, test_list_append) {
  auto prim = std::make_shared<Primitive>("list_append");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  auto expected = std::make_shared<AbstractList>(AbstractBasePtrList({elem1, elem2}));
  MS_LOG(INFO) << "expected: " << expected->ToString();

@@ -890,7 +890,7 @@ TEST_F(TestPrim, test_tuple_setitem) {
  auto prim = std::make_shared<Primitive>("tuple_setitem");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 3);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  AbstractBasePtrList elems_exp = {elem1, elem2};
  auto expected = std::make_shared<AbstractTuple>(elems_exp);

@@ -916,7 +916,7 @@ TEST_F(TestPrim, test_make_list) {
  auto prim = std::make_shared<Primitive>("make_list");
  FuncGraphPtr func_graph = MakeFuncGraph(prim, 2);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(*res == *expected);
}

@@ -939,7 +939,7 @@ TEST_F(TestPrim, test_make_range) {
  AbstractBasePtrList elem_list({ele1, ele2, ele3});
  AbstractBasePtr expected = std::make_shared<AbstractTuple>(elem_list);

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "res=" << res->ToString();
  MS_LOG(INFO) << "expected=" << expected->ToString();
  ASSERT_TRUE(*res == *expected);
@@ -982,7 +982,7 @@ TEST_F(TestPrim, test_layernorm) {
  AbstractBasePtr expected1 = abstract_mean_var->Clone();
  AbstractBasePtr expected2 = abstract_mean_var->Clone();

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  MS_LOG(INFO) << "expected0: " << expected0->ToString();
  MS_LOG(INFO) << "expected1: " << expected1->ToString();

@@ -1028,7 +1028,7 @@ TEST_F(TestPrim, test_DropoutGenMask) {
  AbstractBasePtr expected = std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
                                                              std::make_shared<Shape>(std::vector<int>{79}));

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "res=" << res->ToString();
  MS_LOG(INFO) << "expected=" << expected->ToString();
  ASSERT_TRUE(*res == *expected);

@@ -1058,7 +1058,7 @@ TEST_F(TestPrim, test_dropout) {
  std::vector<int> shape = {2, 20, 32, 32};
  expected->set_shape(std::make_shared<Shape>(shape));

  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  MS_LOG(INFO) << "result: " << res->ToString();
  MS_LOG(INFO) << "expected: " << expected->ToString();

@@ -1079,7 +1079,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) {
  auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  AbstractBasePtrList args_spec_list = {x_input, y_input};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  auto res = dyn_cast<AbstractTuple>(ret);
  AbstractBasePtrList x_idx_list;
  auto r_x = std::make_shared<AbstractTuple>(x_idx_list);

@@ -1103,7 +1103,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) {
  auto x_input = std::make_shared<AbstractTuple>(x_arg_list);
  auto y_input = std::make_shared<AbstractTuple>(y_arg_list);
  AbstractBasePtrList args_spec_list = {x_input, y_input};
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  auto res = dyn_cast<AbstractTuple>(ret);
  AbstractBasePtrList x_idx_list({abstract::FromValue(1)});
  auto r_x = std::make_shared<AbstractTuple>(x_idx_list);

@@ -1128,7 +1128,7 @@ TEST_F(TestPrim, test_DictGetItem) {
  AbstractBasePtr key = abstract::FromValue("x");
  AbstractBasePtrList args_spec_list = {array_dict, key};

  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  AbstractTensorPtr expect = dyn_cast<AbstractTensor>(FromValue(tensor_map[0].second));

@@ -1147,7 +1147,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
  AbstractBasePtr key = abstract::FromValue("x");
  AbstractBasePtrList args_spec_list = {array_dict, key};

  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  AbstractTensorPtr tensor_ret = dyn_cast<AbstractTensor>(ret);
  AbstractTensorPtr expect = dyn_cast<AbstractTensor>(arr_x);
@@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {

  auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
  ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}

@@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) {
  MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
  AbstractBasePtr abstract_v1 = FromValue(1, false);
  args_spec_list.push_back(abstract_v1);
  AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract();
  ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());

  // now this test case failed randomly, have to debug.

@@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) {
  args_spec_list.clear();
  args_spec_list.push_back(abstract_v1);
  args_spec_list.push_back(abstract_v2);
  abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred;
  abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract();
  ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}

@@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
  AbstractBasePtr abstract_v2 = FromValue(v1, false);
  args_spec_list.push_back(abstract_v1);
  args_spec_list.push_back(abstract_v2);
  AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract();
  ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}

@@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {

  auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
  FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred;
  AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
  ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
  ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32);
}

@@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
  AbstractBasePtr base1 = FromValue(x1, false);
  AbstractBasePtr base2 = FromValue(x2, false);
  AbstractBasePtrList base_list = {base1, base2};
  auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list);
  auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract();
  MS_LOG(INFO) << "result spec: " << res->ToString();
  AbstractBasePtr exp = FromValue(x3, false);
  MS_LOG(INFO) << "result exp: " << exp->ToString();
@@ -446,7 +446,7 @@ void TestGraphEval::TearDown() {
TEST_F(TestGraphInfer, test_graph_infer_defaults) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(50), false);
  ASSERT_EQ(*res, *expect);
}

@@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(1), false);
  ASSERT_EQ(*res, *expect);
}

@@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
TEST_F(TestGraphInfer, test_graph_infer_vararg) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(9), false);
  ASSERT_EQ(*res, *expect);
}

@@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(48), false);
  ASSERT_EQ(*res, *expect);
}

@@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(7), false);
  ASSERT_EQ(*res, *expect);
}

@@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(46), false);
  ASSERT_EQ(*res, *expect);
}

@@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
  FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
  AbstractBasePtrList args_spec_list = {};
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
  AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
  AbstractBasePtr expect = FromValue(MakeValue(57), false);
  ASSERT_EQ(*res, *expect);
}
@@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config

from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution """

@@ -377,6 +378,21 @@ class StateNet(nn.Cell):
        return x


def test_conv2d_same_primitive():
    class Conv2DSameNet(nn.Cell):
        def __init__(self):
            super(Conv2DSameNet, self).__init__()
            self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
        def construct(self, x, y):
            r1 = self.conv1(x)
            r2 = self.conv2(y)
            return (r1, r2)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    net = Conv2DSameNet()
    out = net(t1, t2)


class ComparisonNet(nn.Cell):
    def __init__(self):
        """ ComparisonNet definition """
@@ -0,0 +1,276 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore

import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype

from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.primitive import constexpr

from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)


class FakeOp(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """"""

    def infer_shape(self, x, y):
        self.second_shape = y
        self.add_prim_attr("second_shape", y)
        return x

    def infer_dtype(self, x, y):
        return x
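
# Editor's note (illustrative sketch, not part of this commit): FakeOp records the
# shape of its second input while inferring, both on the Python object and, through
# add_prim_attr, as a primitive attribute. Because that attribute depends on the call
# site's input shape, two CNodes fed differently shaped second inputs must keep their
# own evaluated attributes; the tests below exercise this through graph compilation.
# The recording itself can be checked directly, assuming add_prim_attr accepts the
# plain shape list that infer_shape already passes to it.
def _sketch_fakeop_records_second_shape():
    op = FakeOp()
    op.infer_shape([1, 2], [3, 4])  # stores second_shape as a side effect
    assert op.second_shape == [3, 4]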


# test the normal case that should generate independent primitive because of different
# generated attributes after inference
def test_conv2d_same_primitive():
    class Conv2DSameNet(nn.Cell):
        def __init__(self):
            super(Conv2DSameNet, self).__init__()
            self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
        def construct(self, x, y):
            r1 = self.conv1(x)
            r2 = self.conv2(y)
            return (r1, r2)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    net = Conv2DSameNet()
    out = net(t1, t2)


# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_conv2d_op_with_arg():
    class Conv2dNet(nn.Cell):
        def __init__(self):
            super(Conv2dNet, self).__init__()
        def construct(self, op, x):
            return op(x)
    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
        def construct(self, x, y):
            conv_op = self.conv2
            a = self.opnet(conv_op, x)
            b = self.opnet(conv_op, y)
            return (a, b)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    net = OpsNet(Conv2dNet())
    out = net(t1, t2)
def test_conv2d_op_with_arg():
    class FackOpNet(nn.Cell):
        def __init__(self):
            super(FackOpNet, self).__init__()
            self.op = FakeOp()
        def construct(self, x, y):
            return self.op(x, y)
    class OpNet(nn.Cell):
        def __init__(self):
            super(OpNet, self).__init__()
        def construct(self, op, x, y):
            return op(x, y)
    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.op = FackOpNet()
        def construct(self, x, y):
            op = self.op
            a = self.opnet(op, x, y)
            b = self.opnet(op, y, x)
            return (a, b)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    net = OpsNet(OpNet())
    out = net(t1, t2)


def test_conv2d_op_with_arg_same_input():
    class FackOpNet(nn.Cell):
        def __init__(self):
            super(FackOpNet, self).__init__()
            self.op = FakeOp()
        def construct(self, x, y):
            return self.op(x, y)
    class OpNet(nn.Cell):
        def __init__(self):
            super(OpNet, self).__init__()
        def construct(self, op, x, y):
            return op(x, y)
    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.op = FackOpNet()
        def construct(self, x, y):
            op = self.op
            a = self.opnet(op, x, x)
            b = self.opnet(op, y, x)
            return (a, b)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    net = OpsNet(OpNet())
    out = net(t1, t2)
# test op with partial
def test_op_as_partial():
    class OpAsPartial(nn.Cell):
        def __init__(self):
            super(OpAsPartial, self).__init__()
            self.op = FakeOp()
        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            a = partial_op(y)
            b = partial_op(z)
            return a, b
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
    net = OpAsPartial()
    out = net(t1, t2, t3)


# test op with partial
def test_op_as_partial_inside():
    class OpAsPartial(nn.Cell):
        def __init__(self):
            super(OpAsPartial, self).__init__()
            self.op = FakeOp()
        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            a = partial_op(y)
            b = partial_op(z)
            return a, b
    class OuterNet(nn.Cell):
        def __init__(self):
            super(OuterNet, self).__init__()
            self.net = OpAsPartial()
        def construct(self, x, y, z):
            a,b = self.net(x, y, z)
            return a, b
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
    net = OuterNet()
    out = net(t1, t2, t3)


# test op with partial case 2
def test_op_as_partial_independent():
    class OpAsPartial(nn.Cell):
        def __init__(self):
            super(OpAsPartial, self).__init__()
            self.op = FakeOp()
        def construct(self, x, y, z):
            partial_op1 = F.partial(self.op, x)
            a = partial_op1(y)
            partial_op2 = F.partial(self.op, x)
            b = partial_op2(z)
            return a, b
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
    net = OpAsPartial()
    out = net(t1, t2, t3)
def test_nest_partial():
    class NestPartial(nn.Cell):
        def __init__(self):
            super(NestPartial, self).__init__()
            self.op = FakeOp()
        def construct(self, x, y, z):
            partial_op1 = F.partial(self.op)
            partial_op2 = F.partial(partial_op1, x)
            a = partial_op2(y)
            partial_op3 = F.partial(self.op)
            partial_op4 = F.partial(partial_op3, x)
            b = partial_op4(z)
            return a, b
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
    net = NestPartial()
    out = net(t1, t2, t3)


# high order argument
# op and op args as network arguments
def test_op_with_arg_as_input():
    class WithOpArgNet(nn.Cell):
        def __init__(self):
            super(WithOpArgNet, self).__init__()
        def construct(self, op, x, y):
            return op(x, y)
    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.opnet = net
            self.op = FakeOp()
        def construct(self, x, y, z):
            op = self.op
            a = self.opnet(op, x, z)
            b = self.opnet(op, x, y)
            return (a, b)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
    net = OpsNet(WithOpArgNet())
    out = net(t1, t2, t3)


# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_partial_as_arg():
    class PartialArgNet(nn.Cell):
        def __init__(self):
            super(PartialArgNet, self).__init__()
        def construct(self, partial_op, y):
            return partial_op(y)
    class OpsNet(nn.Cell):
        def __init__(self, net):
            super(OpsNet, self).__init__()
            self.partial_net = net
            self.op = FakeOp()
        def construct(self, x, y, z):
            partial_op = F.partial(self.op, x)
            a = self.partial_net(partial_op, z)
            b = self.partial_net(partial_op, y)
            return (a, b)
    t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
    t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
    t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
    net = OpsNet(PartialArgNet())
    out = net(t1, t2, t3)
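
In `test_conv2d_same_primitive` above, the two `nn.Conv2d` layers use pad_mode "same" with stride (1, 4) but see inputs of different widths (1918 vs. 3840), so the padding and output-shape attributes recorded while inferring one CNode cannot be reused for the other. A rough sanity check of the expected widths (editor's sketch, not part of the commit; it verifies only the arithmetic, not the compiled graph):

import math
# With "same" padding, output width is ceil(input_width / stride_w).
assert math.ceil(1918 / 4) == 480  # first branch
assert math.ceil(3840 / 4) == 960  # second branch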