Extract api for FuncGraph & FuncGraphManager

This commit is contained in:
He Wei 2021-07-26 09:28:59 +08:00
parent ba3aa00e92
commit d035d24bb4
7 changed files with 148 additions and 30 deletions

View File

@ -314,6 +314,8 @@ elseif(WIN32)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/mindspore/core/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/abstract/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/base/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/base
@ -429,6 +431,8 @@ else()
PATTERN "train*" EXCLUDE PATTERN "delegate.h" EXCLUDE PATTERN "lite_session.h" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/mindspore/core/api/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/abstract/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/abstract
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/base/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/base

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_API_FUNC_GRAPH_H_
#define MINDSPORE_API_FUNC_GRAPH_H_
#include <vector>
#include <memory>
#include <string>
#include "api/ir/func_graph_manager.h"
namespace mindspore::api {
class FuncGraph {
public:
FuncGraph() = default;
virtual ~FuncGraph() = default;
virtual const std::vector<AnfNodePtr> get_inputs() const = 0;
virtual const std::vector<AnfNodePtr> &parameters() const = 0;
virtual void add_parameter(const ParameterPtr &p) = 0;
virtual AnfNodePtr output() const = 0;
virtual CNodePtr get_return() const = 0;
virtual void set_output(const AnfNodePtr &value, bool force_new_ret = false) = 0;
virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) = 0;
virtual bool has_attr(const std::string &key) const = 0;
virtual ValuePtr get_attr(const std::string &key) const = 0;
virtual void set_attr(const std::string &key, const ValuePtr &value) = 0;
virtual FuncGraphManagerPtr get_manager() const = 0;
};
} // namespace mindspore::api
#endif // MINDSPORE_API_FUNC_GRAPH_H_

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_API_FUNC_GRAPH_MANAGER_H_
#define MINDSPORE_API_FUNC_GRAPH_MANAGER_H_
#include <memory>
#include <utility>
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"
#include "ir/anf.h"
namespace mindspore::api {
class FuncGraph;
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
struct AnfNodeIndexPairHasher {
std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
return std::hash<const AnfNode *>{}(p1.first.get());
}
};
struct AnfNodeIndexPairEqual {
bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
return lhs == rhs;
}
};
using AnfNodeIndexSet = OrderedSet<std::pair<AnfNodePtr, int>, AnfNodeIndexPairHasher, AnfNodeIndexPairEqual>;
using NodeUsersMap = OrderedMap<AnfNodePtr, AnfNodeIndexSet>;
class FuncGraphManager {
public:
FuncGraphManager() = default;
virtual ~FuncGraphManager() = default;
virtual bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) = 0;
virtual void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) = 0;
virtual void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) = 0;
virtual const NodeUsersMap &node_users() const = 0;
static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
};
} // namespace mindspore::api
#endif // MINDSPORE_API_FUNC_GRAPH_MANAGER_H_

View File

@ -140,12 +140,12 @@ bool FuncGraph::has_flag(const std::string &key) {
return false;
}
bool FuncGraph::has_attr(const std::string &key) {
bool FuncGraph::has_attr(const std::string &key) const {
auto iter = attrs_.find(key);
return !(iter == attrs_.cend());
}
ValuePtr FuncGraph::get_attr(const std::string &key) {
ValuePtr FuncGraph::get_attr(const std::string &key) const {
auto iter = attrs_.find(key);
return iter == attrs_.cend() ? nullptr : iter->second;
}

View File

@ -38,6 +38,7 @@
#include "base/effect_info.h"
#include "ir/func_graph_cloner.h"
#include "abstract/abstract_value.h"
#include "api/ir/func_graph.h"
namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
@ -151,7 +152,7 @@ class FuncGraphBase : public Value {
MS_DECLARE_PARENT(FuncGraphBase, Value);
};
class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
public:
FuncGraph();
using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
@ -164,15 +165,15 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
abstract::AbstractBasePtr ToAbstract() override;
// get function graph inputs, but parameters
const std::vector<AnfNodePtr> get_inputs() const;
const std::vector<AnfNodePtr> get_inputs() const final;
// Return the graph's output, or nullptr if not yet deduced.
AnfNodePtr output() const;
void set_output(const AnfNodePtr &value, bool force_new_ret = false);
const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
const std::vector<AnfNodePtr> &parameters() const final { return parameters_; }
// Append
virtual ParameterPtr add_parameter();
void add_parameter(const ParameterPtr &p);
void add_parameter(const ParameterPtr &p) final;
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
// Prepend
virtual ParameterPtr InsertFrontParameter();
@ -183,8 +184,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
ParameterPtr AddWeightParameter(const std::string &name);
// Create a cnode with given inputs, bound to this graph.
virtual CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) override;
CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
// Create a cnode with given inputs, bound to this graph and push back to order list.
CNodePtr NewCNodeInOrder(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
@ -240,21 +241,23 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
bool has_attr(const std::string &key);
ValuePtr get_attr(const std::string &key);
void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
bool has_attr(const std::string &key) const final;
ValuePtr get_attr(const std::string &key) const final;
void set_attr(const std::string &key, const ValuePtr &value) final { attrs_[key] = value; }
std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
transforms_ = transforms;
}
CNodePtr get_return() const { return return_; }
CNodePtr get_return() const final { return return_; }
void set_return(const CNodePtr &cnode) { return_ = cnode; }
FuncGraphManagerPtr manager() const { return manager_.lock(); }
void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
std::string ToString() const override;
GraphDebugInfoPtr debug_info();
void set_debug_info(const GraphDebugInfoPtr &info) {

View File

@ -67,6 +67,10 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) {
return Manage(func_graphs, manage);
}
api::FuncGraphManagerPtr api::FuncGraphManager::Manage(const api::FuncGraphPtr &func_graph, bool manage) {
return mindspore::Manage(std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph), manage);
}
FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage)
: roots_(roots), is_manage_(manage) {
Reset();

View File

@ -34,11 +34,12 @@
#include "utils/signal.h"
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"
#include "ir/anf.h"
#include "ir/graph_utils.h"
#include "utils/counter.h"
#include "utils/hashing.h"
#include "base/base_ref.h"
#include "ir/anf.h"
#include "api/ir/func_graph_manager.h"
namespace mindspore {
struct Change;
@ -46,20 +47,9 @@ class FuncGraphTransaction;
class FuncGraphManager;
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
struct AnfNodeIndexPairHasher {
std::size_t operator()(const std::pair<AnfNodePtr, int> &p1) const {
return std::hash<const AnfNode *>{}(p1.first.get());
}
};
struct AnfNodeIndexPairEqual {
bool operator()(const std::pair<AnfNodePtr, int> &lhs, const std::pair<AnfNodePtr, int> &rhs) const {
return lhs == rhs;
}
};
using AnfNodeIndexSet = OrderedSet<std::pair<AnfNodePtr, int>, AnfNodeIndexPairHasher, AnfNodeIndexPairEqual>;
using AnfNodeIndexSet = api::AnfNodeIndexSet;
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
using NodeUsersMap = OrderedMap<AnfNodePtr, AnfNodeIndexSet>;
using NodeUsersMap = api::NodeUsersMap;
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
using EdgeTuple = std::pair<AnfNodePtr, std::pair<int, AnfNodePtr>>;
@ -294,7 +284,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
};
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager>, public api::FuncGraphManager {
public:
explicit FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool manage = true);
~FuncGraphManager() {
@ -314,9 +304,9 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) final;
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) final;
void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) final;
void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope);
FuncGraphTransaction Transact();
@ -332,6 +322,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
NodeUsersMap &node_users() { return node_users_; }
const NodeUsersMap &node_users() const final { return node_users_; }
FVTotalMap &free_variables_total() const;
FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const;