forked from OSSInnovation/mindspore
!3289 Move abstract function to core abstract folder.
Merge pull request !3289 from ZhangQinghua/master
This commit is contained in:
commit
a4f447af6c
|
@ -25,7 +25,7 @@
|
|||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "frontend/operator/cc_implementations.h"
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "debug/trace.h"
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "frontend/operator/cc_implementations.h"
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "./common.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/do_signature.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "utils/graph_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/profile.h"
|
||||
|
|
|
@ -434,8 +434,30 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<TypedPrimiti
|
|||
// Forward to specific subclass of FunctionWrapper.
|
||||
EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this());
|
||||
return evaluator;
|
||||
if (func->isa<PrimitiveAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<PrimitiveAbstractClosure>>());
|
||||
} else if (func->isa<FuncGraphAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<FuncGraphAbstractClosure>>());
|
||||
} else if (func->isa<MetaFuncGraphAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<MetaFuncGraphAbstractClosure>>());
|
||||
} else if (func->isa<JTransformedAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<JTransformedAbstractClosure>>());
|
||||
} else if (func->isa<VirtualAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<VirtualAbstractClosure>>());
|
||||
} else if (func->isa<PartialAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<PartialAbstractClosure>>());
|
||||
} else if (func->isa<TypedPrimitiveAbstractClosure>()) {
|
||||
return _GetEvaluatorFor(func->cast<std::shared_ptr<TypedPrimitiveAbstractClosure>>());
|
||||
} else if (func->isa<AbstractFuncAtom>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
|
||||
} else if (func->isa<AbstractFuncUnion>()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
|
||||
} else if (func->isa<DummyAbstractClosure>()) {
|
||||
MS_LOG(EXCEPTION) << "A dummy function cannot eval";
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/primitive_py.h"
|
||||
#include "abstract/analysis_context.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -14,12 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
class Evaluator;
|
||||
|
@ -134,11 +132,6 @@ std::size_t AbstractFuncUnion::hash() const {
|
|||
return hash_sum;
|
||||
}
|
||||
|
||||
EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
return engine->_GetEvaluatorFor(shared_from_base<PrimitiveAbstractClosure>());
|
||||
}
|
||||
|
||||
bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<PrimitiveAbstractClosure>()) {
|
||||
return false;
|
||||
|
@ -152,11 +145,6 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
|
|||
|
||||
std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); }
|
||||
|
||||
EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
return engine->_GetEvaluatorFor(shared_from_base<FuncGraphAbstractClosure>());
|
||||
}
|
||||
|
||||
bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<FuncGraphAbstractClosure>()) {
|
||||
return false;
|
||||
|
@ -181,11 +169,6 @@ std::string FuncGraphAbstractClosure::ToString() const {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
return engine->_GetEvaluatorFor(shared_from_base<MetaFuncGraphAbstractClosure>());
|
||||
}
|
||||
|
||||
bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<MetaFuncGraphAbstractClosure>()) {
|
||||
return false;
|
||||
|
@ -229,11 +212,6 @@ std::size_t PartialAbstractClosure::hash() const {
|
|||
return hash_value;
|
||||
}
|
||||
|
||||
EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
return engine->_GetEvaluatorFor(shared_from_base<PartialAbstractClosure>());
|
||||
}
|
||||
|
||||
std::string PartialAbstractClosure::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << "PartialAbstractClosure(" << fn_->ToString() << "(";
|
||||
|
@ -244,11 +222,6 @@ std::string PartialAbstractClosure::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
return engine->_GetEvaluatorFor(shared_from_base<JTransformedAbstractClosure>());
|
||||
}
|
||||
|
||||
bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<JTransformedAbstractClosure>()) {
|
||||
return false;
|
||||
|
@ -265,11 +238,6 @@ std::size_t JTransformedAbstractClosure::hash() const {
|
|||
return hash_value;
|
||||
}
|
||||
|
||||
EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
return engine->_GetEvaluatorFor(shared_from_base<VirtualAbstractClosure>());
|
||||
}
|
||||
|
||||
bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<VirtualAbstractClosure>()) {
|
||||
return false;
|
||||
|
@ -306,12 +274,6 @@ std::string VirtualAbstractClosure::ToString() const {
|
|||
return buffer.str();
|
||||
}
|
||||
|
||||
EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
|
||||
return engine->_GetEvaluatorFor(shared_from_base<TypedPrimitiveAbstractClosure>());
|
||||
}
|
||||
|
||||
bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
|
||||
if (!other.isa<TypedPrimitiveAbstractClosure>()) {
|
||||
return false;
|
|
@ -16,8 +16,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
|
||||
#ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
|
||||
#define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
@ -35,10 +35,6 @@ class AbstractFuncAtom : public AbstractFunction {
|
|||
MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction)
|
||||
|
||||
AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); }
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
|
||||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom";
|
||||
}
|
||||
|
||||
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
|
||||
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
|
||||
bool operator==(const AbstractFunction &other) const override;
|
||||
|
@ -56,9 +52,6 @@ class AbstractFuncUnion : public AbstractFunction {
|
|||
std::string ToString() const override;
|
||||
|
||||
AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; }
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override {
|
||||
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion";
|
||||
}
|
||||
bool IsSuperSet(const AbstractFunctionPtr &other);
|
||||
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
|
||||
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
|
||||
|
@ -80,8 +73,6 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
|
|||
~PrimitiveAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
PrimitivePtr prim() { return prim_; }
|
||||
|
||||
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
|
||||
|
@ -114,8 +105,6 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
|
|||
~FuncGraphAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
FuncGraphPtr func_graph() { return func_graph_; }
|
||||
|
||||
AnalysisContextPtr context() const override { return context_; }
|
||||
|
@ -146,8 +135,6 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
|
|||
|
||||
AnalysisContextPtr context() const override { return kDummyAnalysisContext; }
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
ScopePtr GetScope() { return scope_; }
|
||||
|
||||
AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
|
||||
|
@ -172,8 +159,6 @@ class PartialAbstractClosure : public AbstractFuncAtom {
|
|||
~PartialAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
AbstractFunctionPtr fn() { return fn_; }
|
||||
AbstractBasePtrList args() { return args_spec_list_; }
|
||||
AnfNodePtr node() { return node_.lock(); }
|
||||
|
@ -199,7 +184,6 @@ class JTransformedAbstractClosure : public AbstractFuncAtom {
|
|||
explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {}
|
||||
~JTransformedAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom)
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
AbstractFuncAtomPtr fn() { return fn_; }
|
||||
AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); }
|
||||
|
@ -224,8 +208,6 @@ class VirtualAbstractClosure : public AbstractFuncAtom {
|
|||
~VirtualAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
AbstractBasePtrList args_spec_list() { return args_spec_list_; }
|
||||
|
||||
AbstractBasePtr output() { return output_; }
|
||||
|
@ -254,8 +236,6 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
|
|||
~TypedPrimitiveAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override;
|
||||
|
||||
PrimitivePtr prim() { return prim_; }
|
||||
AbstractBasePtrList args_spec_list() { return args_spec_list_; }
|
||||
AbstractBasePtr output() { return output_; }
|
||||
|
@ -280,8 +260,6 @@ class DummyAbstractClosure : public AbstractFuncAtom {
|
|||
~DummyAbstractClosure() override = default;
|
||||
MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
|
||||
|
||||
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }
|
||||
|
||||
AbstractFunctionPtr Copy() const override { return std::make_shared<DummyAbstractClosure>(); }
|
||||
bool operator==(const AbstractFunction &other) const override;
|
||||
|
||||
|
@ -300,4 +278,4 @@ struct AbstractFunctionEqual {
|
|||
};
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_
|
|
@ -193,7 +193,6 @@ class AbstractFunction : public AbstractBase {
|
|||
|
||||
static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list);
|
||||
|
||||
virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0;
|
||||
virtual AnfNodePtr tracking_id() const { return nullptr; }
|
||||
virtual void set_tracking_id(AnfNodePtr) {}
|
||||
virtual AnalysisContextPtr context() const { return nullptr; }
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "frontend/operator/composite/composite.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
Loading…
Reference in New Issue