Extract api for FuncGraph & FuncGraphManager
This commit is contained in:
parent
ba3aa00e92
commit
d035d24bb4
|
@ -314,6 +314,8 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/abstract/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/base/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/base
|
||||
|
@ -429,6 +431,8 @@ else()
|
|||
PATTERN "train*" EXCLUDE PATTERN "delegate.h" EXCLUDE PATTERN "lite_session.h" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/abstract/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/base/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/base
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_API_FUNC_GRAPH_H_
|
||||
#define MINDSPORE_API_FUNC_GRAPH_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "api/ir/func_graph_manager.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
|
||||
class FuncGraph {
|
||||
public:
|
||||
FuncGraph() = default;
|
||||
virtual ~FuncGraph() = default;
|
||||
|
||||
virtual const std::vector<AnfNodePtr> get_inputs() const = 0;
|
||||
virtual const std::vector<AnfNodePtr> ¶meters() const = 0;
|
||||
virtual void add_parameter(const ParameterPtr &p) = 0;
|
||||
|
||||
virtual AnfNodePtr output() const = 0;
|
||||
virtual CNodePtr get_return() const = 0;
|
||||
virtual void set_output(const AnfNodePtr &value, bool force_new_ret = false) = 0;
|
||||
|
||||
virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) = 0;
|
||||
|
||||
virtual bool has_attr(const std::string &key) const = 0;
|
||||
virtual ValuePtr get_attr(const std::string &key) const = 0;
|
||||
virtual void set_attr(const std::string &key, const ValuePtr &value) = 0;
|
||||
|
||||
virtual FuncGraphManagerPtr get_manager() const = 0;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_API_FUNC_GRAPH_H_
|
|
@ -0,0 +1,65 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_API_FUNC_GRAPH_MANAGER_H_
|
||||
#define MINDSPORE_API_FUNC_GRAPH_MANAGER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/ordered_set.h"
|
||||
#include "utils/ordered_map.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
|
||||
class FuncGraph;
|
||||
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
|
||||
|
||||
class FuncGraphManager;
|
||||
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
|
||||
|
||||
struct AnfNodeIndexPairHasher {
|
||||
std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
|
||||
return std::hash<const AnfNode *>{}(p1.first.get());
|
||||
}
|
||||
};
|
||||
|
||||
struct AnfNodeIndexPairEqual {
|
||||
bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
using AnfNodeIndexSet = OrderedSet<std::pair<AnfNodePtr, int>, AnfNodeIndexPairHasher, AnfNodeIndexPairEqual>;
|
||||
using NodeUsersMap = OrderedMap<AnfNodePtr, AnfNodeIndexSet>;
|
||||
|
||||
class FuncGraphManager {
|
||||
public:
|
||||
FuncGraphManager() = default;
|
||||
virtual ~FuncGraphManager() = default;
|
||||
|
||||
virtual bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) = 0;
|
||||
virtual void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) = 0;
|
||||
virtual void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) = 0;
|
||||
virtual const NodeUsersMap &node_users() const = 0;
|
||||
|
||||
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
|
||||
};
|
||||
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_API_FUNC_GRAPH_MANAGER_H_
|
|
@ -140,12 +140,12 @@ bool FuncGraph::has_flag(const std::string &key) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool FuncGraph::has_attr(const std::string &key) {
|
||||
bool FuncGraph::has_attr(const std::string &key) const {
|
||||
auto iter = attrs_.find(key);
|
||||
return !(iter == attrs_.cend());
|
||||
}
|
||||
|
||||
ValuePtr FuncGraph::get_attr(const std::string &key) {
|
||||
ValuePtr FuncGraph::get_attr(const std::string &key) const {
|
||||
auto iter = attrs_.find(key);
|
||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "base/effect_info.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "api/ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
|
||||
|
@ -151,7 +152,7 @@ class FuncGraphBase : public Value {
|
|||
MS_DECLARE_PARENT(FuncGraphBase, Value);
|
||||
};
|
||||
|
||||
class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
||||
class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
|
||||
public:
|
||||
FuncGraph();
|
||||
using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
|
||||
|
@ -164,15 +165,15 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
// get function graph inputs, but parameters
|
||||
const std::vector<AnfNodePtr> get_inputs() const;
|
||||
const std::vector<AnfNodePtr> get_inputs() const final;
|
||||
// Return the graph's output, or nullptr if not yet deduced.
|
||||
AnfNodePtr output() const;
|
||||
void set_output(const AnfNodePtr &value, bool force_new_ret = false);
|
||||
|
||||
const std::vector<AnfNodePtr> ¶meters() const { return parameters_; }
|
||||
const std::vector<AnfNodePtr> ¶meters() const final { return parameters_; }
|
||||
// Append
|
||||
virtual ParameterPtr add_parameter();
|
||||
void add_parameter(const ParameterPtr &p);
|
||||
void add_parameter(const ParameterPtr &p) final;
|
||||
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
|
||||
// Prepend
|
||||
virtual ParameterPtr InsertFrontParameter();
|
||||
|
@ -183,8 +184,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
ParameterPtr AddWeightParameter(const std::string &name);
|
||||
|
||||
// Create a cnode with given inputs, bound to this graph.
|
||||
virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
|
||||
virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
|
||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) override;
|
||||
CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
|
||||
|
||||
// Create a cnode with given inputs, bound to this graph and push back to order list.
|
||||
CNodePtr NewCNodeInOrder(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
|
||||
|
@ -240,21 +241,23 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
|
||||
void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
|
||||
|
||||
bool has_attr(const std::string &key);
|
||||
ValuePtr get_attr(const std::string &key);
|
||||
void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
|
||||
bool has_attr(const std::string &key) const final;
|
||||
ValuePtr get_attr(const std::string &key) const final;
|
||||
void set_attr(const std::string &key, const ValuePtr &value) final { attrs_[key] = value; }
|
||||
|
||||
std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
|
||||
void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
|
||||
transforms_ = transforms;
|
||||
}
|
||||
|
||||
CNodePtr get_return() const { return return_; }
|
||||
CNodePtr get_return() const final { return return_; }
|
||||
void set_return(const CNodePtr &cnode) { return_ = cnode; }
|
||||
|
||||
FuncGraphManagerPtr manager() const { return manager_.lock(); }
|
||||
void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
|
||||
|
||||
api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
|
||||
|
||||
std::string ToString() const override;
|
||||
GraphDebugInfoPtr debug_info();
|
||||
void set_debug_info(const GraphDebugInfoPtr &info) {
|
||||
|
|
|
@ -67,6 +67,10 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
|
|||
return Manage(func_graphs, manage);
|
||||
}
|
||||
|
||||
api::FuncGraphManagerPtr api::FuncGraphManager::Manage(const api::FuncGraphPtr &func_graph, bool manage) {
|
||||
return mindspore::Manage(std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph), manage);
|
||||
}
|
||||
|
||||
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
|
||||
: roots_(roots), is_manage_(manage) {
|
||||
Reset();
|
||||
|
|
|
@ -34,11 +34,12 @@
|
|||
#include "utils/signal.h"
|
||||
#include "utils/ordered_set.h"
|
||||
#include "utils/ordered_map.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "utils/counter.h"
|
||||
#include "utils/hashing.h"
|
||||
#include "base/base_ref.h"
|
||||
#include "ir/anf.h"
|
||||
#include "api/ir/func_graph_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
struct Change;
|
||||
|
@ -46,20 +47,9 @@ class FuncGraphTransaction;
|
|||
class FuncGraphManager;
|
||||
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
|
||||
|
||||
struct AnfNodeIndexPairHasher {
|
||||
std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
|
||||
return std::hash<const AnfNode *>{}(p1.first.get());
|
||||
}
|
||||
};
|
||||
|
||||
struct AnfNodeIndexPairEqual {
|
||||
bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
using AnfNodeIndexSet = OrderedSet<std::pair<AnfNodePtr, int>, AnfNodeIndexPairHasher, AnfNodeIndexPairEqual>;
|
||||
using AnfNodeIndexSet = api::AnfNodeIndexSet;
|
||||
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
|
||||
using NodeUsersMap = OrderedMap<AnfNodePtr, AnfNodeIndexSet>;
|
||||
using NodeUsersMap = api::NodeUsersMap;
|
||||
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
|
||||
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
|
||||
using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>;
|
||||
|
@ -294,7 +284,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
|
|||
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
|
||||
};
|
||||
|
||||
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
||||
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
|
||||
public:
|
||||
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
|
||||
~FuncGraphManager() {
|
||||
|
@ -314,9 +304,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
|||
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
|
||||
void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter);
|
||||
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
|
||||
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
|
||||
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
|
||||
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) final;
|
||||
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) final;
|
||||
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) final;
|
||||
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope);
|
||||
|
||||
FuncGraphTransaction Transact();
|
||||
|
@ -332,6 +322,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
|||
|
||||
NodeUsersMap &node_users() { return node_users_; }
|
||||
|
||||
const NodeUsersMap &node_users() const final { return node_users_; }
|
||||
|
||||
FVTotalMap &free_variables_total() const;
|
||||
|
||||
FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;
|
||||
|
|
Loading…
Reference in New Issue