diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index 00abc2b9ad4..5d0ff86fbb7 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { auto ctx = node_cfg_->context(); auto engine = node_cfg_->engine(); auto cfg = engine->MakeConfig(node, ctx); - auto abs = engine->cache().GetValue(cfg); - if (abs == nullptr) { + auto eval_result = engine->cache().GetValue(cfg); + if (eval_result == nullptr || eval_result->abstract() == nullptr) { return "Undefined"; } - + auto abs = eval_result->abstract(); auto dtype = abs->BuildType(); auto shape = abs->BuildShape(); std::ostringstream oss; diff --git a/mindspore/ccsrc/ir/primitive_base.h b/mindspore/ccsrc/ir/primitive_base.h index 1b347a30739..78623f8542d 100644 --- a/mindspore/ccsrc/ir/primitive_base.h +++ b/mindspore/ccsrc/ir/primitive_base.h @@ -42,7 +42,11 @@ enum PrimType { class Primitive : public Named { public: explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) - : Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {} + : Named(name), + is_base_(is_base), + has_signature_(false), + prim_type_(prim_type), + record_evaluate_add_attr_(false) {} Primitive(const Primitive &prim) : Named(prim), @@ -50,14 +54,23 @@ class Primitive : public Named { instance_name_(prim.instance_name_), is_base_(prim.is_base_), has_signature_(prim.has_signature_), - prim_type_(prim.prim_type_) {} + prim_type_(prim.prim_type_), + record_evaluate_add_attr_(false) {} MS_DECLARE_PARENT(Primitive, Named); abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); std::string ToString() const override { return name(); } + void BeginRecordAddAttr() { + evaluate_added_attrs_.clear(); + record_evaluate_add_attr_ = true; + } + void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; + if (record_evaluate_add_attr_) { + evaluate_added_attrs_[name] = attr; + } return *this; } @@ -80,6 +93,7 @@ class Primitive : public Named { py::function hook() const { return hook_; } const std::unordered_map &attrs() const { return attrs_; } + std::unordered_map &evaluate_added_attrs() { return evaluate_added_attrs_; } // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. bool HasAttr() const { return !attrs_.empty(); } @@ -106,6 +120,7 @@ class Primitive : public Named { protected: std::unordered_map attrs_; + std::unordered_map evaluate_added_attrs_; private: std::string instance_name_; @@ -113,6 +128,7 @@ class Primitive : public Named { bool is_base_; bool has_signature_; PrimType prim_type_; + bool record_evaluate_add_attr_; }; inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { diff --git a/mindspore/ccsrc/operator/prim_structures.cc b/mindspore/ccsrc/operator/prim_structures.cc index 31d2bff43d8..7b0bba98a59 100644 --- a/mindspore/ccsrc/operator/prim_structures.cc +++ b/mindspore/ccsrc/operator/prim_structures.cc @@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv } subargs.push_back(AbstractJoin(l_ptr->elements())); } - AbstractBasePtr engin_exc = engine->Execute(fn, subargs); + EvalResultPtr engin_exc = engine->Execute(fn, subargs); AbstractBasePtrList result; for (std::size_t i = 1; i < args_spec_list.size(); i++) { - result.push_back(engin_exc); + result.push_back(engin_exc->abstract()); } return std::make_shared(result); } @@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi AbstractBasePtr list_type = AbstractJoin(lst->elements()); auto result1 = engine->Execute(fn, lst->elements()); auto result2 = engine->Execute(fn, {dflt, list_type}); - MS_EXCEPTION_IF_NULL(result1); - return result1->Join(result2); + MS_EXCEPTION_IF_NULL(result1->abstract()); + MS_EXCEPTION_IF_NULL(result2->abstract()); + return result1->abstract()->Join(result2->abstract()); } AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc index fde8558d3dd..542abc62058 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc @@ -89,7 +89,7 @@ static std::vector FastShadowSort(const AnfNodePtr &ret_node) { return sorted_nodes; } -AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { +EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); MS_EXCEPTION_IF_NULL(fg); std::size_t nargs = fg->parameters().size(); @@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs const auto &arg = args_spec_list[i]; const auto &node = parameters[i]; AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); - engine->cache().set_value(conf, arg); + engine->cache().set_value(conf, std::make_shared(arg, nullptr)); } const AnfNodePtr &func_node = fg->get_return(); @@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs const auto &node = *it; AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); - ret_base = engine->GetEvaluatedValue(node_conf); + ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); } MS_EXCEPTION_IF_NULL(ret_base); - MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString(); - return ret_base; + MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString(); + return std::make_shared(ret_base, nullptr); } AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { @@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons return cloned_func_graph; } -AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { +EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { const std::string &evaluator_name = ToString(); AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); args_spec_list = NormalizeArgs(args_spec_list); args_spec_list = BroadenUndeterminedArgs(args_spec_list); @@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar auto iter = cache_->find(args_spec_list); if (iter == cache_->end()) { MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; - AbstractBasePtr ret = Eval(engine, args_spec_list); - if (ret == nullptr) { + EvalResultPtr ret = Eval(engine, args_spec_list); + if (ret->abstract() == nullptr) { EvalFailLogging(shared_from_base(), args_spec_list, out_conf); MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; } - MS_EXCEPTION_IF_NULL(ret); - MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << "."; + MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; (*cache_)[args_spec_list] = ret; trace::TraceGraphEvalLeave(shared_from_base()); return ret; } else { MS_EXCEPTION_IF_NULL(iter->second); - MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; + MS_EXCEPTION_IF_NULL(iter->second->abstract()); + MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << "."; trace::TraceGraphEvalLeave(shared_from_base()); return iter->second; } } -AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr) { +EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); - AbstractBasePtr ret = EvalPrim(engine, args_spec_list); + EvalResultPtr ret = EvalPrim(engine, args_spec_list); return ret; } -AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { +EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); if (args_conf_list.size() == 0) { MS_LOG(EXCEPTION) << "Size should greater than 0"; } - AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); + EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); // No need to cache. return ret; } -AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { - AbstractBasePtr ret = EvalPrim(args_conf_list); +EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { + EvalResultPtr ret = EvalPrim(args_conf_list); return ret; } -AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { +EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); - AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); + EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); // Don't lookup from cache, as different out_conf with same node but different context // may add different entry to anfnode_config_map_, like getattr primitive. (*cache_)[args_spec_list] = ret; return ret; } -AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { +EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); MS_EXCEPTION_IF_NULL(cache_); auto iter = cache_->find(args_spec_list); @@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); + EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); + (*cache_)[args_spec_list] = ret; return ret; } -AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { +EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), [](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); MS_EXCEPTION_IF_NULL(cache_); auto iter = cache_->find(args_spec_list); @@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a } // Call the original evaluator, get the result: y = f(x) - AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr); + EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr); // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) AbstractBasePtrList bparams; @@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); AbstractBasePtr bparams_final = std::make_shared(bparams); - AbstractFunctionPtr bprop = std::make_shared(SensitivityTransform(result), bparams_final); + AbstractFunctionPtr bprop = + std::make_shared(SensitivityTransform(result->abstract()), bparams_final); // J(f)(J(x)) return a tuple (y, bprop_f) - AbstractBasePtrList jargs = {result, bprop}; + AbstractBasePtrList jargs = {result->abstract(), bprop}; AbstractBasePtr jtuple = std::make_shared(jargs); - (*cache_)[args_spec_list] = jtuple; - return jtuple; + auto infer_reuslt = std::make_shared(jtuple, std::make_shared()); + (*cache_)[args_spec_list] = infer_reuslt; + return infer_reuslt; } -AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { +EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() != args_spec_list_.size()) { MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() << ", arguments no: " << args_spec_list.size(); @@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL MS_EXCEPTION_IF_NULL(args_spec_list[i]); (void)args_spec_list[i]->Join(args_spec_list_[i]); } - return output_; + return std::make_shared(output_, std::make_shared()); } } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h index 0f7e1fe0520..c7a004ac44e 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h +++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h @@ -29,21 +29,28 @@ namespace mindspore { namespace abstract { using EvaluatorCacheMap = - std::unordered_map; + std::unordered_map; using EvaluatorCacheMapPtr = std::shared_ptr; +using EvaluatorAttrMap = + std::unordered_map; +using EvaluatorAttrMapPtr = std::shared_ptr; + class Evaluator : public Base { public: - explicit Evaluator(const std::string &id) : cache_(std::make_shared()), identifier_(id) {} + explicit Evaluator(const std::string &id) + : cache_(std::make_shared()), + attr_cache_(std::make_shared()), + identifier_(id) {} ~Evaluator() override = default; MS_DECLARE_PARENT(Evaluator, Base); // difference between Run() and Eval(): // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. // Run() will modify cache_ member, so it cannot marked as const; - virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); + virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); - virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; + virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } @@ -58,9 +65,10 @@ class Evaluator : public Base { virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } EvaluatorCacheMapPtr &cache() { return cache_; } + EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } EvaluatorCacheMapPtr cache_; - + EvaluatorAttrMapPtr attr_cache_; std::string identifier_; AnfNodeWeakPtr bound_node_; @@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator { explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} ~PrimEvaluator() override = default; MS_DECLARE_PARENT(PrimEvaluator, Evaluator); - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } }; @@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator { explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} ~TrivialPrimEvaluator() override = default; MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; - virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; }; class TransitionPrimEvaluator : public PrimEvaluator { @@ -90,10 +98,10 @@ class TransitionPrimEvaluator : public PrimEvaluator { explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} ~TransitionPrimEvaluator() override = default; MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; // Parameter in_conf0 : the first element in args_conf_list; - virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; + virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; }; class SymbolicPrimEvaluator : public PrimEvaluator { @@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator { explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} ~SymbolicPrimEvaluator() override = default; MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; - virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; }; // Evaluator will be stored in AnalysisEngine.constructors_ @@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator { DummyEvaluator() : Evaluator("dummy") {} ~DummyEvaluator() override = default; MS_DECLARE_PARENT(DummyEvaluator, Evaluator); - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } }; // Wrap another evaluator to track a subset of uses. @@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator { bound_node_ = AnfNodeWeakPtr(node); } - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } private: @@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator { ~BaseFuncGraphEvaluator() override = default; MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); - AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; @@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator { } bound_node_ = AnfNodeWeakPtr(node); } - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; } - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } private: @@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator { ~VirtualEvaluator() override = default; MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); - AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; std::string ToString() const override { return identifier_; } private: @@ -285,11 +292,11 @@ class JEvaluator : public Evaluator { } bound_node_ = AnfNodeWeakPtr(node); } - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; } - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } private: diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index dbd08c5f58a..f8d53c71496 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -135,13 +135,17 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { using mindspore::parse::PyObjectWrapper; -AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { +EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + prim_->BeginRecordAddAttr(); AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); - return abs_base; + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + auto infer_result = std::make_shared(abs_base, std::make_shared(added_attrs)); + return infer_result; } -AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { +EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { AbstractBasePtrList args_spec_list; if (!prim_->isa()) { MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString(); @@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); ScopePtr scope = kDefaultScope; if (out_conf != nullptr) { @@ -212,8 +216,8 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s return graph_specialize_args; } -AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { +EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { if (out_conf->node() == nullptr || !out_conf->node()->isa()) { MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; } @@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; AbstractBasePtrList args_spec_list; (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); // get the forward graph MS_EXCEPTION_IF_NULL(args_spec_list[0]); AbstractFunctionPtr fn = args_spec_list[0]->cast(); @@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic } } // end anonymous namespace -AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { +EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); const auto &iter = cache_->find(args); @@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; } auto infer_fuc = pyobj.attr("__infer__"); - + prim_py_->BeginRecordAddAttr(); py::dict output = infer_fuc(*py_args); + prim_py_->EndRecordAddAttr(); + auto added_attrs = prim_py_->evaluate_added_attrs(); MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); auto res_spec = PyInferRes2Abstract(prim_py_, output); MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; - (*cache_)[args] = res_spec; - return res_spec; + auto infer_result = std::make_shared(res_spec, std::make_shared(added_attrs)); + (*cache_)[args] = infer_result; + return infer_result; } -AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { +EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. if (nargs_ != args.size()) { MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; @@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const } AbstractScalarPtr abs_base = std::make_shared(evaluated_value, ret_value_type); - return abs_base; + return std::make_shared(abs_base, std::make_shared()); } ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { @@ -553,8 +560,8 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun manager->AddFuncGraph(func_graph); } -AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &old_conf) { +EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &old_conf) { MS_EXCEPTION_IF_NULL(old_conf); AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); @@ -585,9 +592,9 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat return eng->ForwardConfig(old_conf, fn_conf); } -AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list, - const AnfNodeConfigPtr &out_conf) { +EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, + const AbstractBasePtrList &args_spec_list, + const AnfNodeConfigPtr &out_conf) { // args_spec_list: same as StaticGetter if (args_spec_list.size() < 2) { MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; @@ -627,9 +634,9 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng return eng->ForwardConfig(out_conf, fn_conf); } -AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, - const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { +EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, + const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, + const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { if (args_spec_list.empty()) { MS_LOG(EXCEPTION) << "args_spec_list is empty"; } @@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e AbstractBasePtr attr = cls->GetAttribute(item_name); if (attr != nullptr) { - return attr; + return std::make_shared(attr, nullptr); } ValuePtr method = cls->GetMethod(item_name); @@ -660,9 +667,9 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e return StaticGetterInferred(converted_v, data_conf, out_conf); } -AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, - const TypePtr &data_type, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &out_conf) { +EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, + const TypePtr &data_type, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &out_conf) { MS_EXCEPTION_IF_NULL(item_v); MS_EXCEPTION_IF_NULL(data_type); // The method maybe a Primitive or Composite @@ -689,8 +696,8 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e return StaticGetterInferred(converted_v, data_conf, out_conf); } -AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { +EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { // Inputs: namespace and its static function; or class and its member function CheckArgsSize("StaticGetter", args_spec_list, 2); @@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} ~EmbedEvaluator() override = default; MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); - AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { + EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { // arg: free variable to be embedded if (args_conf_list.size() != 1) { MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); @@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator { AnfNodeConfigPtr node_conf = dyn_cast(args_conf_list[0]); MS_EXCEPTION_IF_NULL(node_conf); - AbstractBasePtr x = node_conf->GetEvaluatedValue(); + AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); x = SensitivityTransform(x); SymbolicKeyInstancePtr key = std::make_shared(node_conf->node(), x); AbstractScalarPtr abs_scalar = std::make_shared(key, std::make_shared()); - return abs_scalar; + return std::make_shared(abs_scalar, std::make_shared()); } }; @@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} ~RefToEmbedEvaluator() override = default; MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); - AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override { + EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { if (args_conf_list.size() != 1) { MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); return nullptr; @@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; return nullptr; } - AbstractBasePtr abs = node_conf->GetEvaluatedValue(); + AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); AbstractRefPtr ref_abs = abs->cast(); if (ref_abs == nullptr) { MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref."; @@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { } auto refkey = key_value->cast(); if (refkey == nullptr) { - return std::make_shared(type); + return std::make_shared(std::make_shared(type), std::make_shared()); } std::string name = refkey->tag(); @@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { x = SensitivityTransform(x); std::shared_ptr key = std::make_shared(node, x); std::shared_ptr abs_scalar = std::make_shared(key, type); - return abs_scalar; + return std::make_shared(abs_scalar, std::make_shared()); } }; @@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator { GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} ~GetAttrEvaluator() override = default; MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); - AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { // Inputs: data, item if (args_spec_list.size() != 2) { MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); } - AbstractBasePtr ret = nullptr; + EvalResultPtr ret = nullptr; if (bound_node() != nullptr) { TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); @@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator { ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} ~ResolveEvaluator() override = default; MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); - AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { // Inputs: namespace, symbol if (args_spec_list.size() != 2) { MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); } - AbstractBasePtr ret = nullptr; + EvalResultPtr ret = nullptr; if (bound_node() != nullptr) { TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); @@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} ~CreateInstanceEvaluator() override = default; MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); - AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override { + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, + const AnfNodeConfigPtr &out_conf) override { if (args_spec_list.empty()) { MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; } @@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator { } AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); - (*cache_)[args_spec_list] = ret; - return ret; + auto infer_result = std::make_shared(ret, nullptr); + (*cache_)[args_spec_list] = infer_result; + return infer_result; } pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { @@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator { public: PartialEvaluator() : Evaluator("PartialEvaluator") {} ~PartialEvaluator() override = default; - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf = nullptr) override { + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf = nullptr) override { if (args_conf_list.size() == 0) { MS_LOG(EXCEPTION) << "Args size should be greater than 0"; } + MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf->node()); - - auto arg0_value = args_conf_list[0]->GetEvaluatedValue(); + auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); AbstractBasePtrList args_spec_list{arg0_value}; // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. if (arg0_value->isa()) { auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() << " as func is: " << arg0_value->ToString(); - (*cache_)[args_spec_list] = ret; - return ret; + auto eval_result = std::make_shared(ret, std::make_shared()); + (*cache_)[args_spec_list] = eval_result; + return eval_result; } auto func = CheckArg("partial", args_spec_list, 0); // Sometimes, node[0] in out_conf becomes phi0; @@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator { } } - (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); }); + (void)std::transform( + args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); auto cnode = out_conf->node()->cast(); @@ -989,16 +999,17 @@ class PartialEvaluator : public Evaluator { func->Visit(build_partial); auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); - (*cache_)[args_spec_list] = ret; - return ret; + auto infer_result = std::make_shared(ret, std::make_shared()); + (*cache_)[args_spec_list] = infer_result; + return infer_result; } - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } - AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, - const AnfNodeConfigPtr &out_conf = nullptr) const { + EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, + const AnfNodeConfigPtr &out_conf = nullptr) const { MS_EXCEPTION_IF_NULL(out_conf); MS_EXCEPTION_IF_NULL(out_conf->node()); auto cnode = out_conf->node()->cast(); diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index 91b47ddb7f6..22418180f76 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator { : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} ~StandardPrimEvaluator() override = default; MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); - AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; PrimitivePtr prim() { return prim_; } std::string ToString() const override { return identifier_ + prim_->name(); } @@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator { : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} ~PythonPrimEvaluator() override = default; MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); - AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; PrimitivePtr prim() { return dyn_cast(prim_py_); } std::string ToString() const override { return identifier_ + prim_py_->name(); } @@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator { public: explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} ~DoSignatureEvaluator() override = default; - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, + AnfNodeConfigPtr out_config = nullptr) override; - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } @@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator { public: explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} ~UnpackGraphEvaluator() override = default; - AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, + AnfNodeConfigPtr out_config = nullptr) override; - AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; } @@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator { ~UniformPrimEvaluator() override = default; MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); - AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; ValuePtr RunImpl(const ValuePtrList &args) const; // If eval_value_ is False, return broadened arguments. diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc index 6f67e7348c5..2a03eb6d5c1 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc @@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) { if (conf->node()->intermediate_abstract()) { return conf->node()->intermediate_abstract(); } - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); } AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { @@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() { // Specialize CNode in func graphs void FuncGraphSpecializer::SecondPass() { - for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) { + for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { if (node->isa()) { ProcessCNode(node->cast()); } @@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { AnfNodeConfigPtr conf = MakeConfig(node); AnfNodePtr new_node = GetReplicatedNode(node); MS_EXCEPTION_IF_NULL(new_node); - if (new_node->func_graph() != specialized_func_graph_) { MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() << ", new_node: " << new_node->DebugString() @@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); if (node->isa()) { + auto attrs = conf->GetEvaluatedValue()->attribute(); auto c_old = node->cast(); auto c_new = new_node->cast(); auto new_inputs = c_new->inputs(); @@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. - AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival); + AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); if (replace_node == nullptr) { replace_node = BuildReplacedNode(iconf); MS_EXCEPTION_IF_NULL(replace_node); @@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); } + auto attrs = std::make_shared(); for (size_t i = 0; i < partial_closure->args().size(); i++) { auto old_node = cnode->input(i + 2); - auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]); + auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); if (possibile_value_node != nullptr) { partial_node_list.push_back(possibile_value_node); } else { @@ -455,7 +456,7 @@ std::pair FuncGraphSpecializer::BuildFromB const EvaluatorPtr &eval) { MS_EXCEPTION_IF_NULL(eval); std::unordered_set choices; - AbstractBasePtr ret = nullptr; + EvalResultPtr ret = nullptr; AbstractBasePtrList broaded_argvals; for (auto &argvals_map : *evalcaches_[eval]) { auto argvals = argvals_map.first; @@ -478,7 +479,7 @@ std::pair FuncGraphSpecializer::BuildFromB (*real)[broaded_argvals] = ret; evalcaches_[eval] = real; - return std::make_pair(broaded_argvals, ret); + return std::make_pair(broaded_argvals, ret->abstract()); } else { MS_LOG(DEBUG) << "Choices.size: " << choices.size(); return std::make_pair(AbstractBasePtrList(), nullptr); @@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { return; } specializer_->AddSeen(new_node); - auto new_inputs = new_node->inputs(); if (new_inputs.empty()) { MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; @@ -530,7 +530,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { } if (CanSpecializeNode(func)) { - new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); + // for primitive node , we build the primitive node with infered attributes in the first pass + // so we do not build replaced node again here in second pass + if (IsValueNode(func)) { + new_inputs[0] = func; + } else { + new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); + } } for (size_t i = 0; i < argvals.size();) { @@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { } i = next; } - new_node->set_inputs(new_inputs); } @@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct EvaluatorCacheMap evaluator_cache_map = *eval->cache(); if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { - *result = std::make_pair(argvals, evaluator_cache_map[argvals]); + *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract()); return kSpecializeSuccess; } DumpEvaluatorCache(evaluator_cache_map, argvals); @@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct MS_EXCEPTION_IF_NULL(choices); if (choices->count(argvals)) { - *result = std::make_pair(argvals, (*choices)[argvals]); + *result = std::make_pair(argvals, (*choices)[argvals]->abstract()); return kSpecializeSuccess; } else if (choices->size() == 1) { MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; - *result = std::make_pair(choices->begin()->first, choices->begin()->second); + *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); return kSpecializeSuccess; } else if (choices->empty()) { MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; @@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct return kSpecializeFindUniqueArgvalPoly; } } +static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) { + auto &prim_attrs = prim->attrs(); + bool is_attr_same = true; + for (auto &item : *attrs) { + auto itr = prim_attrs.find(item.first); + if (itr != prim_attrs.end()) { + if (!(*(itr->second) == *(item.second))) { + is_attr_same = false; + break; + } + } else { + is_attr_same = false; + break; + } + } + if (!is_attr_same) { + if (prim->isa()) { + PrimitivePyPtr prim_py = prim->cast(); + auto clone_fn = prim_py->GetPyObj().attr("_clone"); + py::object new_obj = clone_fn(); + auto cloned_prim = new_obj.cast(); + for (auto &item : *attrs) { + cloned_prim->AddAttr(item.first, item.second); + } + return cloned_prim; + } + auto cloned_prim = std::make_shared(*prim); + for (auto &item : *attrs) { + cloned_prim->AddAttr(item.first, item.second); + } + return cloned_prim; + } + return prim; +} -AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) { +AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, + const AttrValueMapPtr &attrs) { MS_EXCEPTION_IF_NULL(origin_node); MS_EXCEPTION_IF_NULL(ival); @@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin ValuePtr value = nullptr; if (abs->isa()) { auto real_fn = dyn_cast(abs); - value = real_fn->prim(); + // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one + if (attrs != nullptr) { + value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); + } else { + value = real_fn->prim(); + } } else if (abs->isa()) { auto real_fn = dyn_cast(abs); value = real_fn->meta_func_graph(); diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h index e3c2027d418..b04978586d9 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h @@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_thisnode; it may be a replicated forwared CNode in static analysis or just a // replicated node. AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc index 464f8313e92..e024d11b474 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc @@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase return nullptr; } -void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) { +void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() - << ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString() - << ", Pointer: " << arg.get(); - cache_[conf] = arg; + << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() + << ", Pointer: " << result->abstract().get(); + cache_[conf] = result; // Set intermediate abstract value. - if (IsIntermediateAbstract(arg)) { + if (IsIntermediateAbstract(result->abstract())) { if (conf->node()->intermediate_abstract() == nullptr) { - conf->node()->set_intermediate_abstract(arg); - MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString(); + conf->node()->set_intermediate_abstract(result->abstract()); + MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); } else { auto old_spec = conf->node()->intermediate_abstract(); - auto joined_spec = IntermediateJoin(arg, old_spec); + auto joined_spec = IntermediateJoin(result->abstract(), old_spec); conf->node()->set_intermediate_abstract(joined_spec); MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" - << arg->ToString() << "\njoined_spec:\t" + << result->abstract()->ToString() << "\njoined_spec:\t" << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); } } } -AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { +EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { auto value = cache_.find(conf); if (value == cache_.end()) { return nullptr; @@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana return eval->graph_context(); } -AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { +EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); auto value = cache_.GetValue(conf); if (value != nullptr) { - MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", " - << value->ToString(); + MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() + << ", " << value->abstract()->ToString(); return value; } @@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) return value; } -AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { +EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); AnfNodePtr node = conf->node(); - AbstractBasePtr ret_abstract = nullptr; + EvalResultPtr eval_result = nullptr; #ifdef DEBUG compute_conf_stack_.push_back(node); std::ostringstream buffer; @@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(node); if (node->abstract() != nullptr) { MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); - ret_abstract = node->abstract(); + eval_result = std::make_shared(node->abstract(), std::make_shared()); } else if (node->isa()) { auto value_node = node->cast(); - ret_abstract = EvalValueNode(value_node, conf); + eval_result = std::make_shared(EvalValueNode(value_node, conf), nullptr); } else if (node->isa()) { auto cnode = node->cast(); trace::TraceEvalCNodeEnter(conf); - ret_abstract = EvalCNode(cnode, conf); + eval_result = EvalCNode(cnode, conf); trace::TraceEvalCNodeLeave(); } else { MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() @@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { #ifdef DEBUG compute_conf_stack_.pop_back(); - if (ret_abstract == nullptr) { + if (eval_result == nullptr) { MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); } #endif - MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString(); - return ret_abstract; + MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); + return eval_result; } AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { @@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co return ToAbstract(value_node->value(), conf->context(), conf); } -AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { +EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(cnode); auto &inputs = cnode->inputs(); @@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); MS_EXCEPTION_IF_NULL(func_conf); // Keep it in a local variable, otherwise smart pointer will free it. - AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue(); + AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); if (maybe_func == nullptr) { MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); @@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo return ExecuteEvaluators(infs, conf, args_conf_list); } -AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { +EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { ConfigPtrList args_conf_list; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); @@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { return tracked_eval; } -AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list) { +EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { if (evaluators.size() == 1) { EvaluatorPtr eval = evaluators[0]; MS_EXCEPTION_IF_NULL(eval); @@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list) { +EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list) { AbstractBasePtrList out_specs; if (!multi_poss_.count(evaluators[0])) { multi_poss_[evaluators[0]] = evaluators[1]; @@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue(); + return conf->GetEvaluatedValue()->abstract(); }); for (auto eval : evaluators) { auto fg_eval = eval->cast(); @@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorToString() << " ptr: " << eval.get(); MS_EXCEPTION_IF_NULL(eval); - auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(out_spec); - MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); - out_specs.push_back(out_spec); - MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString(); + auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); + MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); + out_specs.push_back(eval_result->abstract()); eval_trace_.pop_back(); if (eval_trace_.empty()) { multi_poss_.clear(); @@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorfirst) { MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); - auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(out_spec); - MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); - return out_spec; + auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); + MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() + << " return out_spec: " << eval_result->abstract()->ToString(); + return eval_result; } } } @@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorBroaden(); + return std::make_shared(out_specs[0]->Broaden(), std::make_shared()); } auto joined_spec = AbstractJoin(out_specs); MS_EXCEPTION_IF_NULL(joined_spec); MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); - return joined_spec; + return std::make_shared(joined_spec, std::make_shared()); } -AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() { +EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { AnfNodeConfigPtr self = shared_from_base(); return engine_.lock()->GetEvaluatedValue(self); } @@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { return a; } -AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { +EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { auto evaluator = GetPrimEvaluator(primitive, nullptr); MS_EXCEPTION_IF_NULL(evaluator); if (!evaluator->isa()) { @@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr << evaluator->ToString(); } auto trivial_evaluator = dyn_cast(evaluator); - auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs); - return res_spec; + auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); + return eval_result; } } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h index e8cccd3d824..7c547266028 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h @@ -40,13 +40,33 @@ namespace mindspore { namespace abstract { + +// define attribute value map +using AttrValueMap = std::unordered_map; +using AttrValueMapPtr = std::shared_ptr; + +// the class to save evaluated result: abstract value and modified attribute +class EvalResult : public Base { + public: + EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} + ~EvalResult() override = default; + MS_DECLARE_PARENT(EvalResult, Base); + AbstractBasePtr abstract() { return abstract_; } + AttrValueMapPtr attribute() { return attribute_; } + + private: + AbstractBasePtr abstract_; + AttrValueMapPtr attribute_; +}; + +using EvalResultPtr = std::shared_ptr; // Superclass for AnfNodeConfig and VirtualConfig. class Config : public Base { public: Config() = default; ~Config() override = default; MS_DECLARE_PARENT(Config, Base); - virtual AbstractBasePtr GetEvaluatedValue() = 0; + virtual EvalResultPtr GetEvaluatedValue() = 0; }; // Config will be stored in AnalysisCache @@ -74,7 +94,7 @@ class AnfNodeConfig : public Config { ~AnfNodeConfig() override = default; MS_DECLARE_PARENT(AnfNodeConfig, Config); - AbstractBasePtr GetEvaluatedValue() override; + EvalResultPtr GetEvaluatedValue() override; AnalysisContextPtr context() const { return context_; } @@ -123,7 +143,9 @@ class VirtualConfig : public Config { ~VirtualConfig() override = default; MS_DECLARE_PARENT(VirtualConfig, Config); - AbstractBasePtr GetEvaluatedValue() override { return abstract_; } + EvalResultPtr GetEvaluatedValue() override { + return std::make_shared(abstract_, std::make_shared()); + } private: AbstractBasePtr abstract_; @@ -135,11 +157,11 @@ class AnalysisCache { AnalysisCache() = default; ~AnalysisCache() = default; void Clear() { cache_.clear(); } - void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg); - AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf); + void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); + EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); private: - std::unordered_map cache_; + std::unordered_map cache_; }; using PrimEvaluatorMap = std::unordered_map; @@ -147,7 +169,7 @@ using AnfNodeConfigMap = std::unordered_map; struct AnalysisResult { - AbstractBasePtr inferred; + EvalResultPtr inferred; AnalysisContextPtr context; }; @@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this { // func_graph: The func_graph to analyze. // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); - AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); + EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); // Return the Evaluator for the given function. EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); - AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); + EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); // Infer the result of fn(args). - AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); + EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); void Clear(); void ClearEvaluatorCache(); AnalysisCache &cache() { return cache_; } @@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this { // Set the analysis result for orig to the result for new. // This sets an entry in anfnode_config_map from orig to new. - AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { + EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. (void)anfnode_config_map_.emplace(orig_conf, new_conf); MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() @@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this { AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, const ConfigPtrList &args_conf_list); - AbstractBasePtr Eval(const AnfNodeConfigPtr &conf); + EvalResultPtr Eval(const AnfNodeConfigPtr &conf); EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); - AbstractBasePtr ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list); - AbstractBasePtr ExecuteMultipleEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list); + EvalResultPtr ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); + EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); #ifdef DEBUG std::vector compute_conf_stack_; @@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { return FromValueInside(MakeValue(value), broaden); } -AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); +EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index 451b7f4b3e0..6f0a4e57905 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); } } - AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list); + AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); op_exec_info->abstract = infer_res; } diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc index 6a4ee58e307..60733642e78 100644 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ b/mindspore/ccsrc/utils/graph_utils.cc @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include "ir/visitor.h" #include "utils/log_adapter.h" @@ -223,6 +225,31 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c return res; } +// search the cnodes inside this graph only +std::vector BroadFirstSearchGraphCNodes(CNodePtr ret) { + std::queue todo; + todo.push(ret); + std::vector sorted_nodes; + auto seen = NewSeenGeneration(); + while (!todo.empty()) { + CNodePtr top = todo.front(); + todo.pop(); + sorted_nodes.push_back(top); + auto inputs = top->inputs(); + for (auto &item : inputs) { + if (item->seen_ == seen) { + continue; + } + + if (item->isa()) { + todo.push(item->cast()); + } + item->seen_ = seen; + } + } + return sorted_nodes; +} + std::vector SuccDeeper(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index d01335af829..7e8b36fb8e1 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -57,6 +57,7 @@ std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, const IncludeFunc &include = AlwaysInclude); +std::vector BroadFirstSearchGraphCNodes(CNodePtr ret); class FuncGraphIndex { public: explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 757335e2931..7fd2314e2b4 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init ExpandDims""" - self.__setattr_flag__ = True self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) def __infer__(self, x, axis): @@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer): # if primitive need setattr in __infer__ need add this flag """init Cast""" self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) - self.__setattr_flag__ = True def __infer__(self, x, t): src_type = x['dtype'] @@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer): def __init__(self): """init Reshape""" self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output']) - self.__setattr_flag__ = True def __infer__(self, x, shape): shape_v = shape['value'] @@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init Transpose""" - self.__setattr_flag__ = True self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output']) def __infer__(self, x, perm): @@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init index_select""" - self.__setattr_flag__ = True self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) def __infer__(self, params, indices, axis): @@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0): """init Tile""" - self.__setattr_flag__ = True validator.check_value_type("axis", axis, [int], self.name) def __infer__(self, input_x): @@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0): """init Pack""" - self.__setattr_flag__ = True validator.check_value_type("axis", axis, [int], self.name) self.axis = axis @@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0): """init Unpack""" - self.__setattr_flag__ = True validator.check_value_type("axis", axis, [int], self.name) self.axis = axis @@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init""" - self.__setattr_flag__ = True def infer_shape(self, cond_shape, x_shape, y_shape): if cond_shape != x_shape or x_shape != y_shape: diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index fdea0e3672f..40d01bab09d 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer): @prim_attr_register def __init__(self, transpose_a=False, transpose_b=False): self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) - self.__setattr_flag__ = True cls_name = self.name validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) @@ -596,7 +595,6 @@ class BatchMatMul(MatMul): @prim_attr_register def __init__(self, transpose_a=False, transpose_b=False): self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) - self.__setattr_flag__ = True cls_name = self.name validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) @@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer): @prim_attr_register def __init__(self): - self.__setattr_flag__ = True self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) def infer_shape(self, inputs): diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index a4e998589e1..6552d9f47cd 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer): """init Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name) - self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) + self.add_prim_attr('stride', self.stride) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) validator.check_value_type('pad', pad, (int,), self.name) @@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer): self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) - out_channel = self.out_channel out_shape = [x_shape[0], out_channel, h_out, w_out] return out_shape diff --git a/tests/st/ops/ascend/test_ops_infer.py b/tests/st/ops/ascend/test_ops_infer.py new file mode 100644 index 00000000000..43d56c865c7 --- /dev/null +++ b/tests/st/ops/ascend/test_ops_infer.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test nn ops """ +import functools +import numpy as np +import mindspore.nn as nn +import mindspore.context as context +import mindspore.common.dtype as mstype + +from mindspore import Tensor, Parameter +from mindspore.common.initializer import initializer +from mindspore.ops import Primitive +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import prim_attr_register, PrimitiveWithInfer +from mindspore.ops.primitive import constexpr +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + + +def test_cast_op_attr(): + class CastNet(nn.Cell): + def __init__(self): + super(CastNet, self).__init__() + self.cast = P.Cast() + def construct(self, x, t): + return self.cast(x, t) + + class CastTypeTest(nn.Cell): + def __init__(self, net): + super(CastTypeTest, self).__init__() + self.net = net + self.cast = P.Cast() + def construct(self, x, y, z): + cast_op = self.cast + t1 = cast_op(x, mstype.float32) + t2 = cast_op(y, mstype.int32) + cast_net = self.net + t3 = cast_net(x, mstype.float16) + t4 = cast_net(y, mstype.int32) + t5 = cast_net(z, mstype.float16) + return (t1, t2, t3, t4, t5) + net = CastTypeTest(CastNet()) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32)) + out = net(t1, t2, t3) + assert out[0].asnumpy().dtype == np.float32 + assert out[1].asnumpy().dtype == np.int32 + assert out[2].asnumpy().dtype == np.float16 + assert out[3].asnumpy().dtype == np.int32 + assert out[4].asnumpy().dtype == np.float16 diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index 2c4b9b71468..84e9fda9d2e 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) { auto slice = std::make_shared(start_index, stop_index, step); AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } @@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) { auto slice = std::make_shared(start_index, stop_index, step); AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } @@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) { auto slice = std::make_shared(start_index, stop_index, step); AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } @@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) { auto slice = std::make_shared(start_index, stop_index, step); AbstractBasePtrList args_spec_list = {tuple_tensor, slice}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } @@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) { AbstractSlicePtr slice = std::make_shared(start_index, stop_index, step); AbstractBasePtrList args_spec_list = {tensor, slice}; - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred); + AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract array failed."; } @@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) { AbstractTuplePtr slice_tuple = std::make_shared(eles); AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); + AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract array failed."; } @@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) { AbstractTuplePtr slice_tuple = std::make_shared(eles); AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); + AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract array failed."; } @@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) { AbstractScalarPtr start_index = std::make_shared(2); AbstractBasePtrList args_spec_list = {tensor, start_index}; - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); + AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract array failed."; } @@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) { AbstractTuplePtr slice_tuple = std::make_shared(eles); AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); + AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract array failed."; } @@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) { AbstractTuplePtr slice_tuple = std::make_shared(eles); AbstractBasePtrList args_spec_list = {tensor, slice_tuple}; - AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred); + AbstractTensorPtr ret = dyn_cast(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract array failed."; } @@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) { abstract::AbstractDictionaryPtr tensor_dict = std::make_shared(tensor_map); AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } @@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) { abstract::AbstractDictionaryPtr tensor_dict = std::make_shared(tensor_map); AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } @@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) { auto tuple = std::make_shared(eles); AbstractBasePtrList args_spec_list = {tuple}; - AbstractTuplePtr ret = dyn_cast(engine_->Run(zip_op_graph, args_spec_list).inferred); + AbstractTuplePtr ret = dyn_cast(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract()); if (ret == nullptr) { FAIL() << "Cast ret to abstract tuple failed."; } diff --git a/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc b/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc index 80acbe6ad57..eebe6c252ba 100644 --- a/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc @@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) { AbstractBasePtr abstract_v2 = FromValue(2, false); AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2}; AbstractBasePtr abstract_val = FromValue(10, false); - cache[args_spec_list] = abstract_val; + cache[args_spec_list] = std::make_shared(abstract_val, std::make_shared()); auto iter = cache.find(args_spec_list); ASSERT_TRUE(iter != cache.end()); - ASSERT_TRUE(iter->second == abstract_val); + ASSERT_TRUE(iter->second->abstract() == abstract_val); AbstractBasePtr abstract_v1_variant1 = FromValue(1, false); AbstractBasePtr abstract_v2_variant1 = FromValue(2, false); @@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) { iter = cache.find(args_spec_list_variant1); ASSERT_TRUE(iter != cache.end()); - ASSERT_TRUE(iter->second == abstract_val); + ASSERT_TRUE(iter->second->abstract() == abstract_val); AbstractBasePtr abstract_v1_variant2 = FromValue(1, false); AbstractBasePtr abstract_v2_variant2 = FromValue(3, false); @@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) { std::vector shape = {2, 2, 6, 6}; expected->set_shape(std::make_shared(shape)); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString(); @@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) { AbstractBasePtr abstract_x = FromValue(x, false); args_spec_list.push_back(abstract_x); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); } @@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) { AbstractBasePtr abstract_x = FromValue(x, false); args_spec_list.push_back(abstract_x); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32); } @@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) { args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_y); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); } @@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) { args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_y); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); } @@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) { args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_y); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); } @@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) { args_spec_list.push_back(abstract_x); args_spec_list.push_back(abstract_y); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64); } diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc index f54961af94f..532666b7a94 100644 --- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc @@ -139,7 +139,7 @@ TEST_F(TestPrim, test_typeof) { auto prim_typeof = std::make_shared("typeof"); FuncGraphPtr func_graph = MakeFuncGraph(prim_typeof, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); res->dump(); TypePtr res_value = res->GetValueTrack()->cast(); res_value->dump(); @@ -164,7 +164,7 @@ TEST_F(TestPrim, test_list_map) { auto prim_list_map = std::make_shared("list_map"); FuncGraphPtr func_graph = MakeFuncGraph(prim_list_map, 3); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto expected = std::make_shared(AbstractBasePtrList({FromValue(3, false), FromValue(3, false)})); res->dump(); MS_LOG(INFO) << "result res: " << res->ToString(); @@ -188,7 +188,7 @@ TEST_F(TestPrim, test_list_reduce) { auto prim_list_reduce = std::make_shared("list_reduce"); FuncGraphPtr func_graph = MakeFuncGraph(prim_list_reduce, 3); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); res->dump(); TypePtr res_type = res->GetTypeTrack(); res_type->dump(); @@ -205,7 +205,7 @@ TEST_F(TestPrim, test_scalar_to_array) { auto prim_scalar_to_array = std::make_shared("scalar_to_array"); FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_to_array, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); res->dump(); TypePtr res_type = res->BuildType(); res_type->dump(); @@ -223,7 +223,7 @@ TEST_F(TestPrim, test_array_to_scalar) { auto prim_array_to_scalar = std::make_shared("array_to_scalar"); FuncGraphPtr func_graph = MakeFuncGraph(prim_array_to_scalar, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); res->dump(); TypePtr res_type = res->BuildType(); res_type->dump(); @@ -239,7 +239,7 @@ TEST_F(TestPrim, test_J_1) { auto prim_J = std::make_shared("J"); FuncGraphPtr func_graph = MakeFuncGraph(prim_J, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractJTaggedPtr res_J = dyn_cast(res); ASSERT_TRUE(res_J != nullptr); ASSERT_TRUE(*(res_J->element()) == *abstract_v1); @@ -280,7 +280,7 @@ TEST_F(TestPrim, test_J_2) { int v1 = 1; AbstractBasePtr abstract_v1 = FromValue(v1, false); AbstractBasePtrList args_spec_list = {abstract_v1}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); res->dump(); AbstractTuplePtr res_J = dyn_cast(res); ASSERT_TRUE(res_J != nullptr); @@ -302,7 +302,7 @@ TEST_F(TestPrim, test_dot) { AbstractBasePtrList args_spec_list = {a1, a2}; - AbstractTensorPtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred); + AbstractTensorPtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred->abstract()); ASSERT_TRUE(*(dyn_cast(res->GetShapeTrack())) == *(dyn_cast(expected->GetShapeTrack()))); } @@ -317,7 +317,7 @@ TEST_F(TestPrim, test_switch1) { AbstractBasePtr arg2 = FromValue(2, false); AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *arg1); } @@ -330,7 +330,7 @@ TEST_F(TestPrim, test_switch2) { AbstractBasePtr arg2 = FromValue(2, false); AbstractBasePtrList args_spec_list = {arg0, arg1, arg2}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "make result res: " << res->ToString(); MS_LOG(INFO) << "make result arg2: " << arg2->ToString(); ASSERT_TRUE(*res == *arg2); @@ -343,7 +343,7 @@ TEST_F(TestPrim, test_identity) { AbstractBasePtr abstract_v1 = FromValue(1, false); AbstractBasePtrList args_spec_list = {abstract_v1}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *abstract_v1); } @@ -357,7 +357,7 @@ TEST_F(TestPrim, test_broadcast_shape) { AbstractBasePtrList args_spec_list = {a, b}; - AbstractTuplePtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred); + AbstractTuplePtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred->abstract()); auto ret = res->BuildValue()->cast()->value(); std::vector element_list = {MakeValue(Shape::SHP_ANY), MakeValue(Shape::SHP_ANY)}; @@ -377,7 +377,7 @@ TEST_F(TestPrim, test_partial) { AbstractBasePtr abstract_v2 = FromValue(1, false); AbstractBasePtrList args_spec_list = {abstract_add, abstract_v1, abstract_v2}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractBasePtrList fn_args_list = {abstract_v1, abstract_v2}; auto expected = std::make_shared( std::make_shared(prim::kPrimScalarAdd), fn_args_list); @@ -392,7 +392,7 @@ TEST_F(TestPrim, test_env_setitem) { FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); AbstractBasePtr abstract_x = FromValue(1, false); AbstractBasePtrList args_spec_list = {abstract_x}; - AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; + AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); FuncGraphPtr func_graph = MakeFuncGraph(prim::kPrimEnvSetItem, 3); @@ -400,7 +400,7 @@ TEST_F(TestPrim, test_env_setitem) { AbstractBasePtr abstract_y = FromValue(2, false); args_spec_list = {abstract_env, embed_x, abstract_y}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractBasePtr exp = std::make_shared(kAnyValue, std::make_shared()); ASSERT_TRUE(*res == *exp); } @@ -412,7 +412,7 @@ TEST_F(TestPrim, test_env_getitem) { FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); AbstractBasePtr abstract_x = FromValue(1, false); AbstractBasePtrList args_spec_list = {abstract_x}; - AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; + AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); @@ -420,7 +420,7 @@ TEST_F(TestPrim, test_env_getitem) { AbstractBasePtr abstract_y = FromValue(2, false); args_spec_list = {abstract_env, embed_x, abstract_y}; - AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); AbstractBasePtr exp = std::make_shared(kAnyValue, std::make_shared()); ASSERT_TRUE(*res == *exp); @@ -429,7 +429,7 @@ TEST_F(TestPrim, test_env_getitem) { AbstractBasePtr abstract_z = FromValue(3, false); args_spec_list = {res, embed_x, abstract_z}; - res = engine_->Run(graph_getitem, args_spec_list).inferred; + res = engine_->Run(graph_getitem, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *abstract_x); } @@ -442,7 +442,7 @@ TEST_F(TestPrim, test_env_add) { FuncGraphPtr graph_embed = MakeFuncGraph(prim::kPrimEmbed, 1); AbstractBasePtr abstract_x = FromValue(1, false); AbstractBasePtrList args_spec_list = {abstract_x}; - AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred; + AbstractBasePtr embed_x = engine_->Run(graph_embed, args_spec_list).inferred->abstract(); FuncGraphPtr graph_setitem = MakeFuncGraph(prim::kPrimEnvSetItem, 3); @@ -450,19 +450,19 @@ TEST_F(TestPrim, test_env_add) { AbstractBasePtr abstract_y = FromValue(2, false); args_spec_list = {abstract_env, embed_x, abstract_y}; - AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred; + AbstractBasePtr abstract_e1 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); AbstractBasePtr exp = std::make_shared(kAnyValue, std::make_shared()); ASSERT_TRUE(*abstract_e1 == *exp); AbstractBasePtr abstract_z = FromValue(3, false); args_spec_list = {abstract_env, embed_x, abstract_z}; - AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred; + AbstractBasePtr abstract_e2 = engine_->Run(graph_setitem, args_spec_list).inferred->abstract(); ASSERT_TRUE(*abstract_e2 == *exp); FuncGraphPtr graph_add = MakeFuncGraph(prim::kPrimEnvAdd, 2); args_spec_list = {abstract_e1, abstract_e2}; - AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph_add, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *exp); } @@ -475,7 +475,7 @@ TEST_F(TestPrim, test_shape) { AbstractBasePtrList args_spec_list = {a}; - AbstractTuplePtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred); + AbstractTuplePtr res = dyn_cast(engine_->Run(func_graph, args_spec_list).inferred->abstract()); auto ret = res->BuildValue()->cast()->value(); std::vector element_list = {MakeValue(2), MakeValue(3)}; @@ -493,7 +493,7 @@ TEST_F(TestPrim, test_relu) { AbstractBasePtr expected = UTPrimUtils::ArrayFloat64Of({2, 2, 2, 3}); // NCHW AbstractBasePtrList args_spec_list = {expected}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *expected); } @@ -507,7 +507,7 @@ TEST_F(TestPrim, test_relu2) { auto expected = ArrayOfTensor(UTPrimUtils::kF32, {3, 4, 5}); AbstractBasePtrList args_spec_list = {arr}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); ASSERT_TRUE(*(res->GetShapeTrack()) == *(expected->GetShapeTrack())); } @@ -540,7 +540,7 @@ TEST_F(TestPrim, test_conv2d1) { std::vector shape = {2, 64, 14, 14}; expected->set_shape(std::make_shared(shape)); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString(); @@ -558,7 +558,7 @@ TEST_F(TestPrim, test_conv2d) { auto weight = ArrayOfTensor(UTPrimUtils::kF32, {64, 32, 3, 3}); AbstractBasePtrList args_spec_list = {input, weight}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 64, 16, 16}); MS_LOG(INFO) << "result: " << res->ToString(); @@ -574,7 +574,7 @@ TEST_F(TestPrim, test_conv2d_native) { auto weight = ArrayOfTensor(UTPrimUtils::kF64, {3, 32, 3, 3}); AbstractBasePtrList args_spec_list = {input, weight}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); auto expected = ArrayOfTensor(UTPrimUtils::kF64, {10, 96, 16, 16}); MS_LOG(INFO) << "result: " << res->ToString(); @@ -590,7 +590,7 @@ TEST_F(TestPrim, test_biasAdd) { auto bias = ArrayOfTensor(UTPrimUtils::kF32, {32}); AbstractBasePtrList args_spec_list = {value, bias}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); auto expected = ArrayOfTensor(UTPrimUtils::kF32, {10, 32, 32, 32}); MS_LOG(INFO) << "result: " << res->ToString(); @@ -606,7 +606,7 @@ TEST_F(TestPrim, test_softmax_cross_entropy_with_logits) { auto labels = ArrayOfTensor(UTPrimUtils::kF32, {64, 10}); AbstractBasePtrList args_spec_list = {logits, labels}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_NE(ret, nullptr); auto res = dyn_cast(ret); auto loss = ArrayOfTensor(UTPrimUtils::kF32, {64}); @@ -636,7 +636,7 @@ TEST_F(TestPrim, test_tensor_to_scalar_prim) { auto labels = ArrayOfTensor(UTPrimUtils::kF64, {64, 10}); AbstractBasePtrList args_spec_list = {logits, labels}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); AbstractScalarPtr expected = std::make_shared(kAnyValue, kFloat64); expected->set_type(UTPrimUtils::kF64); @@ -690,7 +690,7 @@ TEST_F(TestPrim, test_fused_batch_norm) { AbstractBasePtr expected0 = abstract_inputs->Clone(); AbstractBasePtr expected1 = abstract_scale->Clone(); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "expected0: " << expected0->ToString(); MS_LOG(INFO) << "expected1: " << expected1->ToString(); @@ -722,7 +722,7 @@ TEST_F(TestPrim, test_pooling) { inputs->set_shape(inputs_dims); AbstractBasePtr abstract_input = FromValue(inputs, false); AbstractBasePtrList args_spec_list = {abstract_input}; - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractBasePtr expected = abstract_input->Clone()->Broaden(); std::vector expected_dims = {8, 64, 2, 2}; @@ -747,7 +747,7 @@ TEST_F(TestPrim, test_hastype) { auto prim = std::make_shared("hastype"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *expected); } @@ -761,7 +761,7 @@ TEST_F(TestPrim, test_array_len) { auto prim = std::make_shared("array_len"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *expected); } @@ -775,7 +775,7 @@ TEST_F(TestPrim, test_list_len) { auto prim = std::make_shared("list_len"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *expected); } @@ -789,7 +789,7 @@ TEST_F(TestPrim, test_tuple_len) { auto prim = std::make_shared("tuple_len"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *expected); } @@ -803,7 +803,7 @@ TEST_F(TestPrim, test_tuple_reversed) { auto prim = std::make_shared("tuple_reversed"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 1); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "expect=" << expected->ToString(); ASSERT_TRUE(*res == *expected); } @@ -825,7 +825,7 @@ TEST_F(TestPrim, test_list_getitem) { auto prim = std::make_shared("list_getitem"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *elem); } @@ -844,7 +844,7 @@ TEST_F(TestPrim, test_list_setitem) { auto prim = std::make_shared("list_setitem"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); AbstractBasePtrList elems_exp = {elem1, elem2}; auto expected = std::make_shared(elems_exp); @@ -866,7 +866,7 @@ TEST_F(TestPrim, test_list_append) { auto prim = std::make_shared("list_append"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); auto expected = std::make_shared(AbstractBasePtrList({elem1, elem2})); MS_LOG(INFO) << "expected: " << expected->ToString(); @@ -890,7 +890,7 @@ TEST_F(TestPrim, test_tuple_setitem) { auto prim = std::make_shared("tuple_setitem"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 3); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); AbstractBasePtrList elems_exp = {elem1, elem2}; auto expected = std::make_shared(elems_exp); @@ -916,7 +916,7 @@ TEST_F(TestPrim, test_make_list) { auto prim = std::make_shared("make_list"); FuncGraphPtr func_graph = MakeFuncGraph(prim, 2); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(*res == *expected); } @@ -939,7 +939,7 @@ TEST_F(TestPrim, test_make_range) { AbstractBasePtrList elem_list({ele1, ele2, ele3}); AbstractBasePtr expected = std::make_shared(elem_list); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "res=" << res->ToString(); MS_LOG(INFO) << "expected=" << expected->ToString(); ASSERT_TRUE(*res == *expected); @@ -982,7 +982,7 @@ TEST_F(TestPrim, test_layernorm) { AbstractBasePtr expected1 = abstract_mean_var->Clone(); AbstractBasePtr expected2 = abstract_mean_var->Clone(); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "expected0: " << expected0->ToString(); MS_LOG(INFO) << "expected1: " << expected1->ToString(); @@ -1028,7 +1028,7 @@ TEST_F(TestPrim, test_DropoutGenMask) { AbstractBasePtr expected = std::make_shared(std::make_shared(kAnyValue, kUInt8), std::make_shared(std::vector{79})); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "res=" << res->ToString(); MS_LOG(INFO) << "expected=" << expected->ToString(); ASSERT_TRUE(*res == *expected); @@ -1058,7 +1058,7 @@ TEST_F(TestPrim, test_dropout) { std::vector shape = {2, 20, 32, 32}; expected->set_shape(std::make_shared(shape)); - AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract(); MS_LOG(INFO) << "result: " << res->ToString(); MS_LOG(INFO) << "expected: " << expected->ToString(); @@ -1079,7 +1079,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_01_dim) { auto x_input = std::make_shared(x_arg_list); auto y_input = std::make_shared(y_arg_list); AbstractBasePtrList args_spec_list = {x_input, y_input}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); AbstractBasePtrList x_idx_list; auto r_x = std::make_shared(x_idx_list); @@ -1103,7 +1103,7 @@ TEST_F(TestPrim, test_BroadcastGradientArgs_1_dim) { auto x_input = std::make_shared(x_arg_list); auto y_input = std::make_shared(y_arg_list); AbstractBasePtrList args_spec_list = {x_input, y_input}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); auto res = dyn_cast(ret); AbstractBasePtrList x_idx_list({abstract::FromValue(1)}); auto r_x = std::make_shared(x_idx_list); @@ -1128,7 +1128,7 @@ TEST_F(TestPrim, test_DictGetItem) { AbstractBasePtr key = abstract::FromValue("x"); AbstractBasePtrList args_spec_list = {array_dict, key}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractTensorPtr tensor_ret = dyn_cast(ret); AbstractTensorPtr expect = dyn_cast(FromValue(tensor_map[0].second)); @@ -1147,7 +1147,7 @@ TEST_F(TestPrim, test_DictGetItem2) { AbstractBasePtr key = abstract::FromValue("x"); AbstractBasePtrList args_spec_list = {array_dict, key}; - AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr ret = engine_->Run(func_graph, args_spec_list).inferred->abstract(); AbstractTensorPtr tensor_ret = dyn_cast(ret); AbstractTensorPtr expect = dyn_cast(arr_x); diff --git a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc index c60ac75f5b2..510651c5f03 100644 --- a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc @@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) { auto prim_scalar_add = std::make_shared("scalar_add"); FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract(); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); } @@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) { MS_LOG(INFO) << "" << graph_f_->get_return()->ToString(); AbstractBasePtr abstract_v1 = FromValue(1, false); args_spec_list.push_back(abstract_v1); - AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract(); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); // now this test case failed randomly, have to debug. @@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) { args_spec_list.clear(); args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v2); - abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred; + abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract(); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); } @@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) { AbstractBasePtr abstract_v2 = FromValue(v1, false); args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v2); - AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract(); ASSERT_TRUE(abs_base_got.get() == abstract_v1.get()); } @@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { auto prim_scalar_add = std::make_shared("scalar_add"); FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add); - AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred; + AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract(); ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack())); ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32); } @@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) { AbstractBasePtr base1 = FromValue(x1, false); AbstractBasePtr base2 = FromValue(x2, false); AbstractBasePtrList base_list = {base1, base2}; - auto res = EvalOnePrim(std::make_shared("scalar_add"), base_list); + auto res = EvalOnePrim(std::make_shared("scalar_add"), base_list)->abstract(); MS_LOG(INFO) << "result spec: " << res->ToString(); AbstractBasePtr exp = FromValue(x3, false); MS_LOG(INFO) << "result exp: " << exp->ToString(); @@ -446,7 +446,7 @@ void TestGraphEval::TearDown() { TEST_F(TestGraphInfer, test_graph_infer_defaults) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(50), false); ASSERT_EQ(*res, *expect); } @@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) { TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(1), false); ASSERT_EQ(*res, *expect); } @@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) { TEST_F(TestGraphInfer, test_graph_infer_vararg) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(9), false); ASSERT_EQ(*res, *expect); } @@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) { TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(48), false); ASSERT_EQ(*res, *expect); } @@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) { TEST_F(TestGraphInfer, test_graph_infer_kwarg) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(7), false); ASSERT_EQ(*res, *expect); } @@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) { TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(46), false); ASSERT_EQ(*res, *expect); } @@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) { TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) { FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults"); AbstractBasePtrList args_spec_list = {}; - AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred; + AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract(); AbstractBasePtr expect = FromValue(MakeValue(57), false); ASSERT_EQ(*res, *expect); } diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 1216944fafc..3e258dba306 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config from ....mindspore_test_framework.pipeline.forward.verify_exception \ import pipeline_for_verify_exception_for_case_by_case_config - +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) def conv3x3(in_channels, out_channels, stride=1, padding=1): """3x3 convolution """ @@ -377,6 +378,21 @@ class StateNet(nn.Cell): return x +def test_conv2d_same_primitive(): + class Conv2DSameNet(nn.Cell): + def __init__(self): + super(Conv2DSameNet, self).__init__() + self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) + self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) + def construct(self, x, y): + r1 = self.conv1(x) + r2 = self.conv2(y) + return (r1, r2) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + net = Conv2DSameNet() + out = net(t1, t2) + class ComparisonNet(nn.Cell): def __init__(self): """ ComparisonNet definition """ diff --git a/tests/ut/python/ops/test_ops_attr_infer.py b/tests/ut/python/ops/test_ops_attr_infer.py new file mode 100644 index 00000000000..95b66e8ff0b --- /dev/null +++ b/tests/ut/python/ops/test_ops_attr_infer.py @@ -0,0 +1,276 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test nn ops """ +import functools +import numpy as np +import mindspore + +import mindspore.nn as nn +import mindspore.context as context +import mindspore.common.dtype as mstype + +from mindspore import Tensor, Parameter +from mindspore.common.initializer import initializer +from mindspore.ops import Primitive +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import prim_attr_register, PrimitiveWithInfer +from mindspore.ops.primitive import constexpr + +from ..ut_filter import non_graph_engine +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward \ + import pipeline_for_compile_forward_ge_graph_for_case_by_case_config +from ....mindspore_test_framework.pipeline.forward.verify_exception \ + import pipeline_for_verify_exception_for_case_by_case_config +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + +class FakeOp(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + """""" + def infer_shape(self, x, y): + self.second_shape = y + self.add_prim_attr("second_shape", y) + return x + + def infer_dtype(self, x, y): + return x + +# test the normal case that should generate independent primitive because of different +# generated attributes after inference +def test_conv2d_same_primitive(): + class Conv2DSameNet(nn.Cell): + def __init__(self): + super(Conv2DSameNet, self).__init__() + self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) + self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True) + def construct(self, x, y): + r1 = self.conv1(x) + r2 = self.conv2(y) + return (r1, r2) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + net = Conv2DSameNet() + out = net(t1, t2) + +# test cell as high order argument +# The graph with free variables used as argument is not supported yet +# because of the limit of inference specialize system +def Xtest_conv2d_op_with_arg(): + class Conv2dNet(nn.Cell): + def __init__(self): + super(Conv2dNet, self).__init__() + def construct(self, op, x): + return op(x) + class OpsNet(nn.Cell): + def __init__(self, net): + super(OpsNet, self).__init__() + self.opnet = net + self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True) + def construct(self, x, y): + conv_op = self.conv2 + a = self.opnet(conv_op, x) + b = self.opnet(conv_op, y) + return (a, b) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + net = OpsNet(Conv2dNet()) + out = net(t1, t2) + + +def test_conv2d_op_with_arg(): + class FackOpNet(nn.Cell): + def __init__(self): + super(FackOpNet, self).__init__() + self.op = FakeOp() + def construct(self, x, y): + return self.op(x, y) + class OpNet(nn.Cell): + def __init__(self): + super(OpNet, self).__init__() + def construct(self, op, x, y): + return op(x, y) + class OpsNet(nn.Cell): + def __init__(self, net): + super(OpsNet, self).__init__() + self.opnet = net + self.op = FackOpNet() + def construct(self, x, y): + op = self.op + a = self.opnet(op, x, y) + b = self.opnet(op, y, x) + return (a, b) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + net = OpsNet(OpNet()) + out = net(t1, t2) + + + +def test_conv2d_op_with_arg_same_input(): + class FackOpNet(nn.Cell): + def __init__(self): + super(FackOpNet, self).__init__() + self.op = FakeOp() + def construct(self, x, y): + return self.op(x, y) + class OpNet(nn.Cell): + def __init__(self): + super(OpNet, self).__init__() + def construct(self, op, x, y): + return op(x, y) + class OpsNet(nn.Cell): + def __init__(self, net): + super(OpsNet, self).__init__() + self.opnet = net + self.op = FackOpNet() + def construct(self, x, y): + op = self.op + a = self.opnet(op, x, x) + b = self.opnet(op, y, x) + return (a, b) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + net = OpsNet(OpNet()) + out = net(t1, t2) + +# test op with partial +def test_op_as_partial(): + class OpAsPartial(nn.Cell): + def __init__(self): + super(OpAsPartial, self).__init__() + self.op = FakeOp() + def construct(self, x, y, z): + partial_op = F.partial(self.op, x) + a = partial_op(y) + b = partial_op(z) + return a, b + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) + net = OpAsPartial() + out = net(t1, t2, t3) + +# test op with partial +def test_op_as_partial_inside(): + class OpAsPartial(nn.Cell): + def __init__(self): + super(OpAsPartial, self).__init__() + self.op = FakeOp() + def construct(self, x, y, z): + partial_op = F.partial(self.op, x) + a = partial_op(y) + b = partial_op(z) + return a, b + class OuterNet(nn.Cell): + def __init__(self): + super(OuterNet, self).__init__() + self.net = OpAsPartial() + def construct(self, x, y, z): + a,b = self.net(x, y, z) + return a, b + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) + net = OuterNet() + out = net(t1, t2, t3) + +# test op with partial case 2 +def test_op_as_partial_independent(): + class OpAsPartial(nn.Cell): + def __init__(self): + super(OpAsPartial, self).__init__() + self.op = FakeOp() + def construct(self, x, y, z): + partial_op1 = F.partial(self.op, x) + a = partial_op1(y) + partial_op2 = F.partial(self.op, x) + b = partial_op2(z) + return a, b + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) + net = OpAsPartial() + out = net(t1, t2, t3) + +def test_nest_partial(): + class NestPartial(nn.Cell): + def __init__(self): + super(NestPartial, self).__init__() + self.op = FakeOp() + def construct(self, x, y, z): + partial_op1 = F.partial(self.op) + partial_op2 = F.partial(partial_op1, x) + a = partial_op2(y) + partial_op3 = F.partial(self.op) + partial_op4 = F.partial(partial_op3, x) + b = partial_op4(z) + return a, b + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) + net = NestPartial() + out = net(t1, t2, t3) + +# high order argument +# op and op args as network arguments +def test_op_with_arg_as_input(): + class WithOpArgNet(nn.Cell): + def __init__(self): + super(WithOpArgNet, self).__init__() + def construct(self, op, x, y): + return op(x, y) + class OpsNet(nn.Cell): + def __init__(self, net): + super(OpsNet, self).__init__() + self.opnet = net + self.op = FakeOp() + def construct(self, x, y, z): + op = self.op + a = self.opnet(op, x, z) + b = self.opnet(op, x, y) + return (a, b) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) + net = OpsNet(WithOpArgNet()) + out = net(t1, t2, t3) + +# The partial application used as argument is not supported yet +# because of the limit of inference specialize system +def Xtest_partial_as_arg(): + class PartialArgNet(nn.Cell): + def __init__(self): + super(PartialArgNet, self).__init__() + def construct(self, partial_op, y): + return partial_op(y) + class OpsNet(nn.Cell): + def __init__(self, net): + super(OpsNet, self).__init__() + self.partial_net = net + self.op = FakeOp() + def construct(self, x, y, z): + partial_op = F.partial(self.op, x) + a = self.partial_net(partial_op, z) + b = self.partial_net(partial_op, y) + return (a, b) + t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32)) + t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32)) + t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32)) + net = OpsNet(PartialArgNet()) + out = net(t1, t2, t3)