diff --git a/mindspore/ccsrc/common.h b/mindspore/ccsrc/common.h index 6b882a15d4..635010cea8 100644 --- a/mindspore/ccsrc/common.h +++ b/mindspore/ccsrc/common.h @@ -25,7 +25,7 @@ #include "abstract/dshape.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/parse.h" #include "pipeline/jit/parse/parse_base.h" diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 5fcbe258ba..fbcb06629d 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -25,7 +25,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "abstract/dshape.h" #include "abstract/param_validator.h" #include "frontend/operator/cc_implementations.h" diff --git a/mindspore/ccsrc/frontend/operator/composite/map.cc b/mindspore/ccsrc/frontend/operator/composite/map.cc index a5f674187b..f49c19aa9c 100644 --- a/mindspore/ccsrc/frontend/operator/composite/map.cc +++ b/mindspore/ccsrc/frontend/operator/composite/map.cc @@ -23,7 +23,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "abstract/dshape.h" #include "pybind_api/api_register.h" #include "debug/trace.h" diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc index 16aa6f654b..28d3119bde 100644 --- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -25,7 +25,7 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "abstract/abstract_value.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "abstract/dshape.h" #include "abstract/param_validator.h" #include "frontend/operator/cc_implementations.h" diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index fe5871fe5e..7afb4037b2 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -23,7 +23,7 @@ #include "./common.h" #include "frontend/operator/ops.h" #include "frontend/operator/composite/do_signature.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "utils/graph_utils.h" #include "utils/log_adapter.h" #include "utils/profile.h" diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index b9e747a70b..b8543bad48 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -434,8 +434,30 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptrGetEvaluator(shared_from_this()); - return evaluator; + if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + return _GetEvaluatorFor(func->cast>()); + } else if (func->isa()) { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; + } else if (func->isa()) { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; + } else if (func->isa()) { + MS_LOG(EXCEPTION) << "A dummy function cannot eval"; + } else { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction"; + } + return nullptr; } EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index f909fcbd8f..e2837a7da0 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -35,7 +35,7 @@ #include "ir/anf.h" #include "ir/primitive_py.h" #include "abstract/analysis_context.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "pipeline/jit/parse/parse.h" namespace mindspore { diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc b/mindspore/core/abstract/abstract_function.cc similarity index 86% rename from mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc rename to mindspore/core/abstract/abstract_function.cc index 8bdb2a0c6c..402b9327c5 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc +++ b/mindspore/core/abstract/abstract_function.cc @@ -14,12 +14,10 @@ * limitations under the License. */ -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include -#include "pipeline/jit/static_analysis/static_analysis.h" - namespace mindspore { namespace abstract { class Evaluator; @@ -134,11 +132,6 @@ std::size_t AbstractFuncUnion::hash() const { return hash_sum; } -EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; @@ -152,11 +145,6 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } -EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; @@ -181,11 +169,6 @@ std::string FuncGraphAbstractClosure::ToString() const { return ss.str(); } -EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; @@ -229,11 +212,6 @@ std::size_t PartialAbstractClosure::hash() const { return hash_value; } -EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - std::string PartialAbstractClosure::ToString() const { std::ostringstream buffer; buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; @@ -244,11 +222,6 @@ std::string PartialAbstractClosure::ToString() const { return buffer.str(); } -EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; @@ -265,11 +238,6 @@ std::size_t JTransformedAbstractClosure::hash() const { return hash_value; } -EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; @@ -306,12 +274,6 @@ std::string VirtualAbstractClosure::ToString() const { return buffer.str(); } -EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - - return engine->_GetEvaluatorFor(shared_from_base()); -} - bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { if (!other.isa()) { return false; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h b/mindspore/core/abstract/abstract_function.h similarity index 91% rename from mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h rename to mindspore/core/abstract/abstract_function.h index 7887ac51ad..5e33384218 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -16,8 +16,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ -#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ +#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ +#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ #include #include @@ -35,10 +35,6 @@ class AbstractFuncAtom : public AbstractFunction { MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) AbstractFunctionPtr GetUnique() override { return shared_from_base(); } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { - MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; - } - AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; void Visit(std::function) const final; bool operator==(const AbstractFunction &other) const override; @@ -56,9 +52,6 @@ class AbstractFuncUnion : public AbstractFunction { std::string ToString() const override; AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { - MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; - } bool IsSuperSet(const AbstractFunctionPtr &other); AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; void Visit(std::function) const final; @@ -80,8 +73,6 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom { ~PrimitiveAbstractClosure() override = default; MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - PrimitivePtr prim() { return prim_; } AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } @@ -114,8 +105,6 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { ~FuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - FuncGraphPtr func_graph() { return func_graph_; } AnalysisContextPtr context() const override { return context_; } @@ -146,8 +135,6 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { AnalysisContextPtr context() const override { return kDummyAnalysisContext; } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - ScopePtr GetScope() { return scope_; } AbstractFunctionPtr Copy() const override { return std::make_shared(meta_func_graph_); } @@ -172,8 +159,6 @@ class PartialAbstractClosure : public AbstractFuncAtom { ~PartialAbstractClosure() override = default; MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - AbstractFunctionPtr fn() { return fn_; } AbstractBasePtrList args() { return args_spec_list_; } AnfNodePtr node() { return node_.lock(); } @@ -199,7 +184,6 @@ class JTransformedAbstractClosure : public AbstractFuncAtom { explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} ~JTransformedAbstractClosure() override = default; MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; AbstractFuncAtomPtr fn() { return fn_; } AbstractFunctionPtr Copy() const override { return std::make_shared(fn_); } @@ -224,8 +208,6 @@ class VirtualAbstractClosure : public AbstractFuncAtom { ~VirtualAbstractClosure() override = default; MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - AbstractBasePtrList args_spec_list() { return args_spec_list_; } AbstractBasePtr output() { return output_; } @@ -254,8 +236,6 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom { ~TypedPrimitiveAbstractClosure() override = default; MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - PrimitivePtr prim() { return prim_; } AbstractBasePtrList args_spec_list() { return args_spec_list_; } AbstractBasePtr output() { return output_; } @@ -280,8 +260,6 @@ class DummyAbstractClosure : public AbstractFuncAtom { ~DummyAbstractClosure() override = default; MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; } - AbstractFunctionPtr Copy() const override { return std::make_shared(); } bool operator==(const AbstractFunction &other) const override; @@ -300,4 +278,4 @@ struct AbstractFunctionEqual { }; } // namespace abstract } // namespace mindspore -#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ +#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index 0acc516728..d74837d6d2 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -193,7 +193,6 @@ class AbstractFunction : public AbstractBase { static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); - virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0; virtual AnfNodePtr tracking_id() const { return nullptr; } virtual void set_tracking_id(AnfNodePtr) {} virtual AnalysisContextPtr context() const { return nullptr; } diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index a2108998bc..9912e0c4e8 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -21,7 +21,7 @@ #include "frontend/operator/composite/composite.h" #include "frontend/operator/ops.h" #include "pipeline/jit/static_analysis/prim.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/abstract_function.h" #include "debug/trace.h" namespace mindspore {