!3289 Move abstract function to core abstract folder.

Merge pull request !3289 from ZhangQinghua/master
This commit is contained in:
mindspore-ci-bot 2020-07-23 09:39:37 +08:00 committed by Gitee
commit a4f447af6c
11 changed files with 35 additions and 74 deletions

View File

@ -25,7 +25,7 @@
#include "abstract/dshape.h" #include "abstract/dshape.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/parse.h" #include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/parse/parse_base.h"

View File

@ -25,7 +25,7 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "abstract/dshape.h" #include "abstract/dshape.h"
#include "abstract/param_validator.h" #include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h" #include "frontend/operator/cc_implementations.h"

View File

@ -23,7 +23,7 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "abstract/dshape.h" #include "abstract/dshape.h"
#include "pybind_api/api_register.h" #include "pybind_api/api_register.h"
#include "debug/trace.h" #include "debug/trace.h"

View File

@ -25,7 +25,7 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "abstract/dshape.h" #include "abstract/dshape.h"
#include "abstract/param_validator.h" #include "abstract/param_validator.h"
#include "frontend/operator/cc_implementations.h" #include "frontend/operator/cc_implementations.h"

View File

@ -23,7 +23,7 @@
#include "./common.h" #include "./common.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "frontend/operator/composite/do_signature.h" #include "frontend/operator/composite/do_signature.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "utils/graph_utils.h" #include "utils/graph_utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/profile.h" #include "utils/profile.h"

View File

@ -434,8 +434,30 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimiti
// Forward to specific subclass of FunctionWrapper. // Forward to specific subclass of FunctionWrapper.
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
MS_EXCEPTION_IF_NULL(func); MS_EXCEPTION_IF_NULL(func);
EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this()); if (func->isa<PrimitiveAbstractClosure>()) {
return evaluator; return _GetEvaluatorFor(func->cast<std::shared_ptr<PrimitiveAbstractClosure>>());
} else if (func->isa<FuncGraphAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<FuncGraphAbstractClosure>>());
} else if (func->isa<MetaFuncGraphAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
} else if (func->isa<JTransformedAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
} else if (func->isa<VirtualAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
} else if (func->isa<PartialAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<PartialAbstractClosure>>());
} else if (func->isa<TypedPrimitiveAbstractClosure>()) {
return _GetEvaluatorFor(func->cast<std::shared_ptr<TypedPrimitiveAbstractClosure>>());
} else if (func->isa<AbstractFuncAtom>()) {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
} else if (func->isa<AbstractFuncUnion>()) {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
} else if (func->isa<DummyAbstractClosure>()) {
MS_LOG(EXCEPTION) << "A dummy function cannot eval";
} else {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
}
return nullptr;
} }
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {

View File

@ -35,7 +35,7 @@
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/primitive_py.h" #include "ir/primitive_py.h"
#include "abstract/analysis_context.h" #include "abstract/analysis_context.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "pipeline/jit/parse/parse.h" #include "pipeline/jit/parse/parse.h"
namespace mindspore { namespace mindspore {

View File

@ -14,12 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include <vector> #include <vector>
#include "pipeline/jit/static_analysis/static_analysis.h"
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
class Evaluator; class Evaluator;
@ -134,11 +132,6 @@ std::size_t AbstractFuncUnion::hash() const {
return hash_sum; return hash_sum;
} }
EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<PrimitiveAbstractClosure>());
}
bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<PrimitiveAbstractClosure>()) { if (!other.isa<PrimitiveAbstractClosure>()) {
return false; return false;
@ -152,11 +145,6 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); }
EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<FuncGraphAbstractClosure>());
}
bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<FuncGraphAbstractClosure>()) { if (!other.isa<FuncGraphAbstractClosure>()) {
return false; return false;
@ -181,11 +169,6 @@ std::string FuncGraphAbstractClosure::ToString() const {
return ss.str(); return ss.str();
} }
EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<MetaFuncGraphAbstractClosure>());
}
bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<MetaFuncGraphAbstractClosure>()) { if (!other.isa<MetaFuncGraphAbstractClosure>()) {
return false; return false;
@ -229,11 +212,6 @@ std::size_t PartialAbstractClosure::hash() const {
return hash_value; return hash_value;
} }
EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<PartialAbstractClosure>());
}
std::string PartialAbstractClosure::ToString() const { std::string PartialAbstractClosure::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
@ -244,11 +222,6 @@ std::string PartialAbstractClosure::ToString() const {
return buffer.str(); return buffer.str();
} }
EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<JTransformedAbstractClosure>());
}
bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<JTransformedAbstractClosure>()) { if (!other.isa<JTransformedAbstractClosure>()) {
return false; return false;
@ -265,11 +238,6 @@ std::size_t JTransformedAbstractClosure::hash() const {
return hash_value; return hash_value;
} }
EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<VirtualAbstractClosure>());
}
bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<VirtualAbstractClosure>()) { if (!other.isa<VirtualAbstractClosure>()) {
return false; return false;
@ -306,12 +274,6 @@ std::string VirtualAbstractClosure::ToString() const {
return buffer.str(); return buffer.str();
} }
EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
MS_EXCEPTION_IF_NULL(engine);
return engine->_GetEvaluatorFor(shared_from_base<TypedPrimitiveAbstractClosure>());
}
bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<TypedPrimitiveAbstractClosure>()) { if (!other.isa<TypedPrimitiveAbstractClosure>()) {
return false; return false;

View File

@ -16,8 +16,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
#include <memory> #include <memory>
#include <string> #include <string>
@ -35,10 +35,6 @@ class AbstractFuncAtom : public AbstractFunction {
MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)
AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); } AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
}
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final; void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
bool operator==(const AbstractFunction &other) const override; bool operator==(const AbstractFunction &other) const override;
@ -56,9 +52,6 @@ class AbstractFuncUnion : public AbstractFunction {
std::string ToString() const override; std::string ToString() const override;
AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
}
bool IsSuperSet(const AbstractFunctionPtr &other); bool IsSuperSet(const AbstractFunctionPtr &other);
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final; void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
@ -80,8 +73,6 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
~PrimitiveAbstractClosure() override = default; ~PrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
PrimitivePtr prim() { return prim_; } PrimitivePtr prim() { return prim_; }
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
@ -114,8 +105,6 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
~FuncGraphAbstractClosure() override = default; ~FuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
FuncGraphPtr func_graph() { return func_graph_; } FuncGraphPtr func_graph() { return func_graph_; }
AnalysisContextPtr context() const override { return context_; } AnalysisContextPtr context() const override { return context_; }
@ -146,8 +135,6 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
AnalysisContextPtr context() const override { return kDummyAnalysisContext; } AnalysisContextPtr context() const override { return kDummyAnalysisContext; }
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
ScopePtr GetScope() { return scope_; } ScopePtr GetScope() { return scope_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); } AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
@ -172,8 +159,6 @@ class PartialAbstractClosure : public AbstractFuncAtom {
~PartialAbstractClosure() override = default; ~PartialAbstractClosure() override = default;
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
AbstractFunctionPtr fn() { return fn_; } AbstractFunctionPtr fn() { return fn_; }
AbstractBasePtrList args() { return args_spec_list_; } AbstractBasePtrList args() { return args_spec_list_; }
AnfNodePtr node() { return node_.lock(); } AnfNodePtr node() { return node_.lock(); }
@ -199,7 +184,6 @@ class JTransformedAbstractClosure : public AbstractFuncAtom {
explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
~JTransformedAbstractClosure() override = default; ~JTransformedAbstractClosure() override = default;
MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
AbstractFuncAtomPtr fn() { return fn_; } AbstractFuncAtomPtr fn() { return fn_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); } AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
@ -224,8 +208,6 @@ class VirtualAbstractClosure : public AbstractFuncAtom {
~VirtualAbstractClosure() override = default; ~VirtualAbstractClosure() override = default;
MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
AbstractBasePtrList args_spec_list() { return args_spec_list_; } AbstractBasePtrList args_spec_list() { return args_spec_list_; }
AbstractBasePtr output() { return output_; } AbstractBasePtr output() { return output_; }
@ -254,8 +236,6 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
~TypedPrimitiveAbstractClosure() override = default; ~TypedPrimitiveAbstractClosure() override = default;
MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
PrimitivePtr prim() { return prim_; } PrimitivePtr prim() { return prim_; }
AbstractBasePtrList args_spec_list() { return args_spec_list_; } AbstractBasePtrList args_spec_list() { return args_spec_list_; }
AbstractBasePtr output() { return output_; } AbstractBasePtr output() { return output_; }
@ -280,8 +260,6 @@ class DummyAbstractClosure : public AbstractFuncAtom {
~DummyAbstractClosure() override = default; ~DummyAbstractClosure() override = default;
MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }
AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); } AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); }
bool operator==(const AbstractFunction &other) const override; bool operator==(const AbstractFunction &other) const override;
@ -300,4 +278,4 @@ struct AbstractFunctionEqual {
}; };
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_

View File

@ -193,7 +193,6 @@ class AbstractFunction : public AbstractBase {
static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list);
virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0;
virtual AnfNodePtr tracking_id() const { return nullptr; } virtual AnfNodePtr tracking_id() const { return nullptr; }
virtual void set_tracking_id(AnfNodePtr) {} virtual void set_tracking_id(AnfNodePtr) {}
virtual AnalysisContextPtr context() const { return nullptr; } virtual AnalysisContextPtr context() const { return nullptr; }

View File

@ -21,7 +21,7 @@
#include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/composite.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "pipeline/jit/static_analysis/prim.h" #include "pipeline/jit/static_analysis/prim.h"
#include "pipeline/jit/static_analysis/abstract_function.h" #include "abstract/abstract_function.h"
#include "debug/trace.h" #include "debug/trace.h"
namespace mindspore { namespace mindspore {