forked from mindspore-Ecosystem/mindspore
Reactor pynative
Signed-off-by: zjun <zhangjun0@huawei.com> Modify real dynamic Signed-off-by: zjun <zhangjun0@huawei.com>
This commit is contained in:
parent
cb90c963a8
commit
ae3a9d2883
File diff suppressed because it is too large
Load Diff
|
@ -17,8 +17,8 @@
|
|||
#ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
|
||||
#define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_EXECUTE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
@ -68,118 +68,121 @@ struct GraphInfo {
|
|||
GraphInfo() = default;
|
||||
explicit GraphInfo(std::string id) : cell_id(std::move((id))) {}
|
||||
};
|
||||
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
|
||||
|
||||
class CellInfo {
|
||||
public:
|
||||
CellInfo() = default;
|
||||
~CellInfo() = default;
|
||||
CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
|
||||
: is_custom_bprop(custom_bprop),
|
||||
is_dynamic(has_dynamic),
|
||||
fg(std::move(foward_graph)),
|
||||
cell_id(std::move(cellid)),
|
||||
bprop_cell_id(std::move(bprop_id)) {}
|
||||
: is_custom_bprop_(custom_bprop),
|
||||
is_dynamic_(has_dynamic),
|
||||
fg_(std::move(foward_graph)),
|
||||
cell_id_(std::move(cellid)),
|
||||
bprop_cell_id_(std::move(bprop_id)) {}
|
||||
|
||||
bool is_grad{false}; // Derivative is calculated
|
||||
bool is_custom_bprop{false}; // Custom bprop
|
||||
bool is_dynamic{false}; // Set by has_dynamic_cell
|
||||
bool is_real_dynamic{false}; // Set by ops order
|
||||
size_t call_times{0};
|
||||
FuncGraphPtr fg{nullptr}; // Forward graph
|
||||
std::string cell_id;
|
||||
std::string bprop_cell_id;
|
||||
std::vector<std::string> cell_ops_info; // All ops info
|
||||
bool is_custom_bprop() const { return is_custom_bprop_; }
|
||||
void set_is_custom_bprop(bool is_custom_bprop) { is_custom_bprop_ = is_custom_bprop; }
|
||||
bool is_dynamic() const { return is_dynamic_; }
|
||||
void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }
|
||||
size_t call_times() const { return call_times_; }
|
||||
void set_call_times(size_t call_times) { call_times_ = call_times; }
|
||||
FuncGraphPtr fg() const { return fg_; }
|
||||
void set_fg(FuncGraphPtr fg) { fg_ = std::move(fg); }
|
||||
std::string &cell_id() { return cell_id_; }
|
||||
void set_cell_id(std::string cell_id) { cell_id_ = std::move(cell_id); }
|
||||
std::string &bprop_cell_id() { return bprop_cell_id_; }
|
||||
std::vector<std::string> &cell_ops_info() { return cell_ops_info_; }
|
||||
|
||||
private:
|
||||
bool is_custom_bprop_{false}; // Custom bprop
|
||||
bool is_dynamic_{false}; // Set by has_dynamic_cell
|
||||
size_t call_times_{0};
|
||||
FuncGraphPtr fg_{nullptr}; // Forward graph
|
||||
std::string cell_id_;
|
||||
std::string bprop_cell_id_;
|
||||
std::vector<std::string> cell_ops_info_; // All ops info
|
||||
};
|
||||
using CellInfoPtr = std::shared_ptr<CellInfo>;
|
||||
|
||||
class TopCellInfo {
|
||||
public:
|
||||
TopCellInfo() = default;
|
||||
~TopCellInfo() = default;
|
||||
TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
|
||||
: is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {}
|
||||
: is_topest_(topest), resource_(std::move(r)), df_builder_(std::move(df)), cell_id_(std::move(cellid)) {}
|
||||
|
||||
bool need_grad{true};
|
||||
bool is_topest{false};
|
||||
bool do_vm_compiled{false};
|
||||
bool forward_already_run{false};
|
||||
size_t top_cell_index{0};
|
||||
ResourcePtr resource{nullptr};
|
||||
FuncGraphPtr df_builder{nullptr};
|
||||
FuncGraphPtr bg{nullptr}; // Backward graph
|
||||
std::string cell_id;
|
||||
std::string sens_id;
|
||||
std::string weights_id;
|
||||
std::string input_args_id;
|
||||
};
|
||||
|
||||
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
|
||||
using CellInfoPtr = std::shared_ptr<CellInfo>;
|
||||
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
|
||||
|
||||
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
||||
public:
|
||||
static std::shared_ptr<PynativeExecutor> GetInstance() {
|
||||
std::lock_guard<std::mutex> i_lock(instance_lock_);
|
||||
if (executor_ == nullptr) {
|
||||
executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
|
||||
}
|
||||
return executor_;
|
||||
bool is_grad() const { return is_grad_; }
|
||||
void set_is_grad(bool is_grad) { is_grad_ = is_grad; }
|
||||
bool is_topest() const { return is_topest_; }
|
||||
bool vm_compiled() const { return vm_compiled_; }
|
||||
void set_vm_compiled(bool vm_compiled) { vm_compiled_ = vm_compiled; }
|
||||
bool need_grad() const { return need_grad_; }
|
||||
void set_need_grad(bool need_grad) { need_grad_ = need_grad; }
|
||||
bool has_dynamic_cell() const { return has_dynamic_cell_; }
|
||||
bool is_real_dynamic() const { return is_real_dynamic_; }
|
||||
void set_is_real_dynamic(bool is_real_dynamic) { is_real_dynamic_ = is_real_dynamic; }
|
||||
bool forward_already_run() const { return forward_already_run_; }
|
||||
void set_forward_already_run(bool set_forward_already_run) { forward_already_run_ = set_forward_already_run; }
|
||||
ResourcePtr resource() { return resource_; }
|
||||
FuncGraphPtr df_builder() { return df_builder_; }
|
||||
std::string &cell_id() { return cell_id_; }
|
||||
std::string &sens_id() { return sens_id_; }
|
||||
void set_sens_id(std::string sens_id) { sens_id_ = std::move(sens_id); }
|
||||
std::string &weights_id() { return weights_id_; }
|
||||
void set_weights_id(std::string weights_id) { weights_id_ = std::move(weights_id); }
|
||||
std::string &input_args_id() { return input_args_id_; }
|
||||
void set_input_args_id(std::string input_args_id) { input_args_id_ = std::move(input_args_id); }
|
||||
std::vector<CellInfoPtr> &cell_graph_list() { return cell_graph_list_; }
|
||||
void set_cell_graph_list(const std::vector<CellInfoPtr> &cell_graph_list) { cell_graph_list_ = cell_graph_list; }
|
||||
OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map() { return graph_info_map_; }
|
||||
void set_graph_info_map(const OrderedMap<FuncGraphPtr, GraphInfoPtr> &graph_info_map) {
|
||||
graph_info_map_ = graph_info_map;
|
||||
}
|
||||
void clear() {
|
||||
cell_graph_list_.clear();
|
||||
graph_info_map_.clear();
|
||||
}
|
||||
~PynativeExecutor();
|
||||
PynativeExecutor(const PynativeExecutor &) = delete;
|
||||
PynativeExecutor &operator=(const PynativeExecutor &) = delete;
|
||||
|
||||
bool need_replace_forward() const { return need_replace_forward_; }
|
||||
bool grad_flag() const { return grad_flag_; }
|
||||
void set_grad_flag(bool flag) { grad_flag_ = flag; }
|
||||
void EnterConstruct(const py::object &cell);
|
||||
void LeaveConstruct(const py::object &cell);
|
||||
|
||||
py::object RunOpInner(const OpExecInfoPtr &op_exec_info);
|
||||
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
|
||||
void NewGraph(const py::object &cell, const py::args &args);
|
||||
py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
|
||||
void RunInner(const py::object &cell, const py::tuple &args, const py::object &phase, py::object *ret);
|
||||
py::object CheckGraph(const py::object &cell, const py::args &args);
|
||||
py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
|
||||
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
|
||||
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
|
||||
|
||||
// Get info
|
||||
bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); }
|
||||
// Call by python
|
||||
void Clear(const std::string &flag = "");
|
||||
void Clean();
|
||||
// Abnormal existed
|
||||
void ClearRes();
|
||||
// Sync stream
|
||||
void Sync();
|
||||
|
||||
private:
|
||||
PynativeExecutor() = default;
|
||||
bool is_grad_{false}; // Derivative is calculated
|
||||
bool is_topest_{false};
|
||||
bool vm_compiled_{false};
|
||||
bool need_grad_{true};
|
||||
bool has_dynamic_cell_{false};
|
||||
bool is_real_dynamic_{false};
|
||||
bool forward_already_run_{false};
|
||||
ResourcePtr resource_{nullptr};
|
||||
FuncGraphPtr df_builder_{nullptr};
|
||||
std::string cell_id_;
|
||||
std::string sens_id_;
|
||||
std::string weights_id_;
|
||||
std::string input_args_id_;
|
||||
std::vector<CellInfoPtr> cell_graph_list_;
|
||||
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
|
||||
};
|
||||
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
|
||||
|
||||
template <typename T>
|
||||
void MapClear(T *map, const std::string &cell_id) {
|
||||
for (auto it = map->begin(); it != map->end();) {
|
||||
if (it->first.find(cell_id) != std::string::npos) {
|
||||
it = map->erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
class DynamicAnalysis;
|
||||
using DynamicAnalysisPtr = std::shared_ptr<DynamicAnalysis>;
|
||||
|
||||
template <typename T>
|
||||
void VectorClear(T *vec, const std::string &cell_id) {
|
||||
for (auto it = vec->begin(); it != vec->end();) {
|
||||
if ((*it)->cell_id.find(cell_id) != std::string::npos) {
|
||||
it = vec->erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
}
|
||||
class ForwardExecutor;
|
||||
using ForwardExecutorPtr = std::shared_ptr<ForwardExecutor>;
|
||||
using ForwardExecutorWeakPtr = std::weak_ptr<ForwardExecutor>;
|
||||
|
||||
class GradExecutor;
|
||||
using GradExecutorPtr = std::shared_ptr<GradExecutor>;
|
||||
using GradExecutorWeakPtr = std::weak_ptr<GradExecutor>;
|
||||
|
||||
class DynamicAnalysis {
|
||||
public:
|
||||
DynamicAnalysis() = default;
|
||||
~DynamicAnalysis() = default;
|
||||
|
||||
// Check cell struct
|
||||
bool IsDynamicCell(const py::object &cell);
|
||||
|
||||
private:
|
||||
std::string GetCellInfo(const py::object &cell);
|
||||
void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
|
||||
bool ParseBodyContext(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node,
|
||||
|
@ -191,97 +194,75 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
|
||||
std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
|
||||
parse::AstMainType type);
|
||||
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, size_t index);
|
||||
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, const std::string &op_name,
|
||||
size_t index);
|
||||
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
|
||||
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
|
||||
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
|
||||
// Run op
|
||||
|
||||
std::unordered_set<std::string> cell_input_args_;
|
||||
};
|
||||
|
||||
class GradExecutor {
|
||||
public:
|
||||
GradExecutor() = default;
|
||||
~GradExecutor() = default;
|
||||
explicit GradExecutor(const ForwardExecutorPtr &forward_executor = nullptr)
|
||||
: forward_executor_(ForwardExecutorWeakPtr(forward_executor)) {}
|
||||
|
||||
std::function<void(py::object *, const py::object &, const py::args &)> InitGraph = [this](auto &&PH1, auto &&PH2,
|
||||
auto &&PH3) {
|
||||
NewGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3));
|
||||
};
|
||||
std::function<void(py::object *, const py::object &, const py::object &, const py::args &)> LinkGraph =
|
||||
[this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) {
|
||||
EndGraphInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2),
|
||||
std::forward<decltype(PH3)>(PH3), std::forward<decltype(PH4)>(PH4));
|
||||
};
|
||||
std::function<void(py::object *, const GradOperationPtr &, const py::object &, const py::object &, const py::args &)>
|
||||
GradGraph = [this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4, auto &&PH5) {
|
||||
GradNetInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3),
|
||||
std::forward<decltype(PH4)>(PH4), std::forward<decltype(PH5)>(PH5));
|
||||
};
|
||||
std::function<void(py::object *, const py::object &, const py::tuple &, const py::object &)> RunGraph =
|
||||
[this](auto &&PH1, auto &&PH2, auto &&PH3, auto &&PH4) {
|
||||
RunGradGraph(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2), std::forward<decltype(PH3)>(PH3),
|
||||
std::forward<decltype(PH4)>(PH4));
|
||||
};
|
||||
|
||||
FuncGraphPtr curr_g() const;
|
||||
TopCellInfoPtr top_cell() const;
|
||||
bool TopCellIsDynamic();
|
||||
void set_top_cell(TopCellInfoPtr top_cell) { top_cell_ = std::move(top_cell); }
|
||||
bool grad_flag() const { return grad_flag_; }
|
||||
void set_grad_flag(bool flag) { grad_flag_ = flag; }
|
||||
bool in_grad_process() const { return in_grad_process_; }
|
||||
std::string top_cell_id() { return top_cell()->cell_id(); }
|
||||
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
|
||||
MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
|
||||
py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
|
||||
void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
|
||||
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
|
||||
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
|
||||
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
|
||||
PynativeStatusCode *const status);
|
||||
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
|
||||
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
|
||||
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
|
||||
abstract::AbstractBasePtrList *args_spec_list);
|
||||
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
|
||||
const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
|
||||
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
|
||||
bool *is_find);
|
||||
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||
TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false);
|
||||
void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
|
||||
|
||||
// Replace for grad graph
|
||||
ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple);
|
||||
void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map);
|
||||
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real);
|
||||
// Update the abstract and device address info of value node and tensors in bprop graph
|
||||
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
|
||||
void SaveTensorsInValueNode(const ResourcePtr &resource);
|
||||
void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
|
||||
void CleanPreMemoryInValueNode();
|
||||
py::object CheckGraph(const py::object &cell, const py::args &args);
|
||||
void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args, const py::object &phase);
|
||||
bool need_construct_graph() const { return !graph_stack_.empty() && grad_flag_; }
|
||||
void set_dynamic_analysis(DynamicAnalysisPtr dynamic_analysis) { dynamic_analysis_ = std::move(dynamic_analysis); }
|
||||
std::stack<FuncGraphPtr> &graph_stack() { return graph_stack_; }
|
||||
std::vector<TopCellInfoPtr> &top_cell_list() { return top_cell_list_; }
|
||||
bool need_replace_forward() const { return need_replace_forward_; }
|
||||
std::stack<std::string> &cell_op_info_stack() { return cell_op_info_stack_; }
|
||||
std::unordered_map<std::string, size_t> &op_index_map() { return op_index_map_; }
|
||||
std::unordered_map<std::string, std::string> &obj_to_forward_id() { return obj_to_forward_id_; }
|
||||
void ClearGrad(const py::object &cell, const py::args &args);
|
||||
void ClearRes();
|
||||
void ClearCellRes(const std::string &cell_id = "");
|
||||
|
||||
// Construct grad graph
|
||||
void PushCurrentGraphToStack();
|
||||
void PopGraphStack();
|
||||
void PushCurrentCellOpInfoToStack();
|
||||
void PopCurrentCellOpInfoFromStack();
|
||||
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
|
||||
ResourcePtr GetResource(const std::string &cell_id = "");
|
||||
void AddNestedGradOrder() { ++grad_order_; }
|
||||
void SubNestedGradOrder();
|
||||
bool IsNestedGrad() const;
|
||||
bool IsTopGraph(const std::string &cell_id);
|
||||
bool IsTopestGraph(const std::string &cell_id);
|
||||
bool IsBpropGraph(const std::string &cell_id);
|
||||
bool IsFirstGradStep(const std::string &cell_id);
|
||||
private:
|
||||
ForwardExecutorPtr forward() const;
|
||||
DynamicAnalysisPtr dynamic_analysis() const;
|
||||
bool grad_running() const { return grad_is_running_; }
|
||||
void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
|
||||
void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
|
||||
bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; }
|
||||
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
|
||||
bool CheckDynamicCell(const std::string &cell_id);
|
||||
bool CheckRealDynamicCell(const std::string &cell_id);
|
||||
bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
|
||||
bool is_grad);
|
||||
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
|
||||
bool need_cloned = false, bool is_grad = false);
|
||||
void ClearCnodeRes(const AnfNodePtr &node, std::unordered_set<AnfNodePtr> *node_set);
|
||||
void UpdateCellDynamic(const std::string &cell_id);
|
||||
bool CheckCellChanged(const std::string &cell_id);
|
||||
void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled);
|
||||
void ClearResidualRes(const std::string &cell_id);
|
||||
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
|
||||
void NewGraphInner(const py::object &cell, const py::args &args);
|
||||
void MakeNewTopGraph(const string &cell_id, const py::args &args);
|
||||
TopCellInfoPtr GetTopCell(const string &cell_id, bool find_nearest = false);
|
||||
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
|
||||
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
|
||||
const std::string &out_id, const py::args &args);
|
||||
bool EndBpropGraph(const string &cell_id);
|
||||
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
|
||||
const std::string &cell_id, const py::args &args);
|
||||
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
|
||||
py::object *sens = nullptr);
|
||||
void ClearDynamicTopRes(const std::string &cell_id, const FuncGraphPtr &df_builder);
|
||||
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
||||
const py::args &args);
|
||||
std::string GetCellId(const py::object &obj, const py::args &args);
|
||||
std::string GetTensorCellId(const std::string &cell_id);
|
||||
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
|
||||
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
|
||||
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
|
||||
size_t arg_size, const std::string &cell_id);
|
||||
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
|
||||
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
|
||||
void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
|
||||
|
||||
// Higher derivative
|
||||
bool IsNestedGrad() const;
|
||||
void AddNestedGradOrder() { ++grad_order_; }
|
||||
void SubNestedGradOrder();
|
||||
void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
|
||||
const std::string &cell_id);
|
||||
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
|
||||
|
@ -290,70 +271,222 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
|
||||
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
|
||||
|
||||
// Hold graph(forward and grad) info
|
||||
// Dynamic
|
||||
bool CheckDynamicCell(const std::string &cell_id);
|
||||
bool CheckRealDynamicCell(const std::string &cell_id);
|
||||
void ClearDynamicTopRes(const std::string &cell_id);
|
||||
|
||||
void PushCurrentGraphToStack();
|
||||
void PopGraphStack();
|
||||
void PushCurrentCellOpInfoToStack();
|
||||
void PopCurrentCellOpInfoFromStack();
|
||||
std::string GetCellOpInfo();
|
||||
void ReplaceCellOpInfoByCellId(const std::string &cell_id);
|
||||
|
||||
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
|
||||
ResourcePtr GetResource(const std::string &cell_id = "");
|
||||
bool IsFirstGradStep();
|
||||
bool IsTopGraph(const std::string &cell_id);
|
||||
bool IsTopestGraph(const std::string &cell_id);
|
||||
bool IsBpropGraph(const std::string &cell_id);
|
||||
bool IsGradBefore(const std::string &cell_id);
|
||||
bool CheckCellGraph(const std::string &cell_id);
|
||||
bool UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
|
||||
bool is_grad);
|
||||
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
|
||||
bool need_cloned = false, bool is_grad = false);
|
||||
bool CheckCellChanged(const std::string &cell_id);
|
||||
void UpdateTopCellInfo(const std::string &cell_id, bool vm_compiled);
|
||||
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
|
||||
void NewGraphInner(py::object *ret, const py::object &cell, const py::args &args);
|
||||
void MakeNewTopGraph(const string &cell_id, const py::args &args);
|
||||
void EndGraphInner(py::object *ret, const py::object &cell, const py::object &out, const py::args &args);
|
||||
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
|
||||
const std::string &out_id, const py::args &args);
|
||||
bool EndBpropGraph(const string &cell_id);
|
||||
FuncGraphPtr MakeGradGraph(const py::object &cell, const FuncGraphPtr &g, const ResourcePtr &r,
|
||||
const std::string &cell_id, const py::args &args);
|
||||
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
|
||||
py::object *sens = nullptr);
|
||||
void GradNetInner(py::object *ret, const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
|
||||
const py::args &args);
|
||||
void SetTopCellTensorId(const std::string &cell_id);
|
||||
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
|
||||
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
|
||||
void SetGradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
|
||||
size_t arg_size, const std::string &cell_id);
|
||||
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
|
||||
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
|
||||
void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
|
||||
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
|
||||
const std::vector<int64_t> &index_sequence, bool is_param = false);
|
||||
AnfNodePtr GetObjNode(const py::object &obj, const std::string &obj_id);
|
||||
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
|
||||
|
||||
// Memory clean between steps
|
||||
void ClearResidualRes(const std::string &cell_id);
|
||||
void ClearCnodeRes(const AnfNodePtr &node);
|
||||
void CleanPreMemoryInValueNode();
|
||||
void SaveTensorsInValueNode(const ResourcePtr &resource);
|
||||
void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
|
||||
|
||||
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
|
||||
graph_info_map_[g]->objects.push_back(obj);
|
||||
top_cell()->graph_info_map()[g]->objects.push_back(obj);
|
||||
}
|
||||
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
|
||||
bool is_param = false);
|
||||
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr ¶m) {
|
||||
graph_info_map_[g]->params[id] = param;
|
||||
top_cell()->graph_info_map()[g]->params[id] = param;
|
||||
}
|
||||
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
|
||||
int64_t index = -1) {
|
||||
graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
|
||||
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
|
||||
}
|
||||
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
|
||||
const std::vector<int64_t> &index) {
|
||||
graph_info_map_[g]->node_map[id] = std::make_pair(node, index);
|
||||
top_cell()->graph_info_map()[g]->node_map[id] = std::make_pair(node, index);
|
||||
}
|
||||
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
|
||||
const std::vector<int64_t> &index_sequence, bool is_param = false);
|
||||
|
||||
static std::shared_ptr<PynativeExecutor> executor_;
|
||||
static std::mutex instance_lock_;
|
||||
static int64_t graph_id_;
|
||||
private:
|
||||
size_t grad_order_{0};
|
||||
size_t top_cell_index_{0};
|
||||
std::string top_cell_id_;
|
||||
bool grad_flag_{false};
|
||||
bool in_bprop_process_{false};
|
||||
bool in_grad_process_{false};
|
||||
bool has_dynamic_cell_{false};
|
||||
bool grad_is_running_{false};
|
||||
bool need_replace_forward_{true};
|
||||
// The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script,
|
||||
// such as Resnet50(Cell),LeNet(Cell).This pointer is used to distinguish temporary primitives from global
|
||||
// primitives to control memory release. Global primitives are always created in top cell's '__init__' function and
|
||||
// temporary primitives are always created in other place.Temporary primitives will be released after executing top
|
||||
// cell's 'construct' function but global primitives will not.
|
||||
PyObject *top_cell_{nullptr};
|
||||
|
||||
// Used for construct grad graph
|
||||
bool grad_is_running_{false};
|
||||
FuncGraphPtr curr_g_{nullptr};
|
||||
// For clear pre top res
|
||||
TopCellInfoPtr pre_top_cell_{nullptr};
|
||||
TopCellInfoPtr top_cell_{nullptr};
|
||||
std::unordered_map<std::string, size_t> op_index_map_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
|
||||
std::unordered_set<tensor::TensorPtr> all_value_node_tensors_;
|
||||
std::unordered_map<std::string, std::string> obj_to_forward_id_;
|
||||
|
||||
// Records forwrad graph, the bottom is top graph
|
||||
std::stack<FuncGraphPtr> graph_stack_;
|
||||
// Records op info of every cell, the bottom is op info of top cell
|
||||
std::stack<std::string> cell_op_info_stack_;
|
||||
|
||||
// Use vector for keep order
|
||||
std::vector<CellInfoPtr> cell_graph_list_;
|
||||
std::vector<TopCellInfoPtr> top_cell_list_;
|
||||
std::unordered_set<std::string> cell_input_args_;
|
||||
// Record all info for all cells
|
||||
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
|
||||
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
|
||||
ForwardExecutorWeakPtr forward_executor_;
|
||||
DynamicAnalysisPtr dynamic_analysis_;
|
||||
};
|
||||
|
||||
class ForwardExecutor {
|
||||
public:
|
||||
ForwardExecutor() = default;
|
||||
~ForwardExecutor() = default;
|
||||
|
||||
std::function<void(py::object *, const OpExecInfoPtr &)> RunOpS = [this](auto &&PH1, auto &&PH2) {
|
||||
RunOpInner(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2));
|
||||
};
|
||||
|
||||
void RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_info);
|
||||
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
|
||||
void set_grad_executor(const GradExecutorPtr &grad_executor) { grad_executor_ = GradExecutorWeakPtr(grad_executor); }
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> &node_abs_map() { return node_abs_map_; }
|
||||
std::unordered_map<std::string, OpIndexWithTensorId> &cell_op_index_with_tensor_id() {
|
||||
return cell_op_index_with_tensor_id_;
|
||||
}
|
||||
std::unordered_map<std::string, TensorIdWithTensor> &cell_tensor_id_with_tensor() {
|
||||
return cell_tensor_id_with_tensor_;
|
||||
}
|
||||
void ClearRes();
|
||||
|
||||
private:
// Resolve a strong GradExecutor reference from the stored weak pointer.
GradExecutorPtr grad() const;
// Select and initialize the backend policy used to run this op.
MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
// Execution back ends: Python VM fallback vs. MindSpore backend; `status`
// reports the PynativeStatusCode outcome to the caller.
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                  PynativeStatusCode *status);
// Graph construction helpers: build the op CNode and collect input specs/masks.
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
                     abstract::AbstractBasePtrList *args_spec_list);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
                 abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
                                          const abstract::AbstractBasePtr &abs, const std::string &id, size_t index);
// Infer the op output abstract; `is_find` reports whether a cached entry was hit
// — presumably against prim_abs_list_, confirm in the implementation.
void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list,
                         bool *is_find);
// Update the abstract and device address info of value node and tensors in bprop graph
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);

// Mix precision: cast parameters / inputs to the op's expected dtypes.
void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name,
                                        size_t index);
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
// NOTE(review): "Signatrue" is a long-standing typo for "Signature"; renaming
// would break callers, so it is documented rather than fixed here.
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
                     const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
private:
|
||||
GradExecutorWeakPtr grad_executor_;
|
||||
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
|
||||
// Used for runop and replace forward result of grad graph
|
||||
std::unordered_map<std::string, size_t> op_index_map_;
|
||||
std::unordered_map<std::string, std::string> obj_to_forward_id_;
|
||||
std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
|
||||
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
|
||||
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
|
||||
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
|
||||
std::unordered_set<tensor::TensorPtr> all_value_node_tensors_;
|
||||
};
|
||||
|
||||
// Singleton facade for PyNative mode: owns the ForwardExecutor/GradExecutor
// pair and routes the Python-facing API (op execution, grad-graph building,
// cache clearing) to them.
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
 public:
  // Thread-safe lazy accessor; first call also wires forward and grad
  // executors together.
  static std::shared_ptr<PynativeExecutor> GetInstance() {
    std::lock_guard<std::mutex> i_lock(instance_lock_);
    if (executor_ == nullptr) {
      // NOTE(review): new (std::nothrow) may yield nullptr, which is never
      // checked here or by callers — confirm allocation failure is handled
      // (or drop std::nothrow so failure throws).
      executor_ = std::shared_ptr<PynativeExecutor>(new (std::nothrow) PynativeExecutor());
      forward_executor_ = std::make_shared<ForwardExecutor>();
      grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
      grad_executor_->set_dynamic_analysis(std::make_shared<DynamicAnalysis>());
      // Weak back-reference; see ForwardExecutor::set_grad_executor.
      forward_executor_->set_grad_executor(grad_executor_);
    }
    return executor_;
  }
  ~PynativeExecutor() = default;
  // Singleton: non-copyable.
  PynativeExecutor(const PynativeExecutor &) = delete;
  PynativeExecutor &operator=(const PynativeExecutor &) = delete;

  // Track entering/leaving a python Cell's construct() (see py_top_cell_ below).
  void EnterConstruct(const py::object &cell);
  void LeaveConstruct(const py::object &cell);
  GradExecutorPtr grad_executor();
  ForwardExecutorPtr forward_executor();

  void set_grad_flag(bool flag);
  // Grad-graph lifecycle driven from Python: begin cell, end cell, build grad.
  void NewGraph(const py::object &cell, const py::args &args);
  void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
  void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
  py::object CheckGraph(const py::object &cell, const py::args &args);
  py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
  py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);

  // Used by graph clean
  bool GetIsDynamicCell();
  bool need_replace_forward() { return grad_executor()->need_replace_forward(); }
  // Cell destruct will call
  void ClearCell(const std::string &flag = "");
  void ClearGrad(const py::object &cell, const py::args &args);
  // Full reset when an abnormal state exists.
  void ClearRes();
  // Sync stream
  void Sync();

 private:
  PynativeExecutor() = default;

  static std::shared_ptr<PynativeExecutor> executor_;
  static std::mutex instance_lock_;
  static ForwardExecutorPtr forward_executor_;
  static GradExecutorPtr grad_executor_;
  // The pointer of top python Cell object, which is always the network (inherit class Cell) ran in python test
  // script, such as Resnet50(Cell), LeNet(Cell). This pointer is used to distinguish temporary primitives from
  // global primitives to control memory release. Global primitives are always created in top cell's '__init__'
  // function and temporary primitives are always created in other places. Temporary primitives will be released
  // after executing top cell's 'construct' function but global primitives will not.
  PyObject *py_top_cell_{nullptr};
};
|
||||
|
||||
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
|
||||
|
|
|
@ -367,8 +367,11 @@ class _PynativeExecutor:
|
|||
def grad(self, grad, obj, weights, *args, **kwargs):
    """Delegate grad-graph construction to the C++ executor.

    Keyword arguments are flattened (insertion order) and appended after the
    positional args, matching the C++ `grad_net` calling convention.
    """
    keyword_values = tuple(kwargs.values())
    self._executor.grad_net(grad, obj, weights, *args, *keyword_values)
|
||||
|
||||
def clear(self, cell_id=""):
    """Ask the C++ executor to clear cached resources for `cell_id`.

    # presumably cell_id is str(id(cell)) as passed from Cell.__del__ — verify.
    """
    self._executor.clear(cell_id)
|
||||
def del_cell(self, cell_id=""):
    """Forward a cell-destruction notification to the C++ executor.

    # presumably cell_id is str(id(cell)) as passed from Cell.__del__ — verify.
    """
    self._executor.clear_cell(cell_id)
|
||||
|
||||
def clear_grad(self, obj, *args, **kwargs):
    """Release gradient resources held for `obj` on the C++ executor.

    Keyword argument values are flattened after the positionals, mirroring
    the convention used by `grad`.
    """
    flattened = (*args, *kwargs.values())
    self._executor.clear_grad(obj, *flattened)
|
||||
|
||||
def sync(self):
    """Forward a stream-synchronization request to the C++ executor."""
    self._executor.sync()
|
||||
|
|
|
@ -267,7 +267,7 @@ class Cell(Cell_):
|
|||
|
||||
def __del__(self):
    # Finalizer: release PyNative and graph-compile resources tied to this Cell.
    # NOTE(review): `context.get_context is not None` tests the *function object*,
    # not its result — presumably a guard against partially torn-down modules at
    # interpreter shutdown; confirm before simplifying.
    if context.get_context is not None and context.get_context("mode") == context.PYNATIVE_MODE:
        _pynative_exec.clear(str(id(self)))
        _pynative_exec.del_cell(str(id(self)))
    # Graph-mode compiled network resources are keyed by creation time.
    if hasattr(self, "_create_time"):
        _executor.del_net_res(str(self._create_time))
|
||||
|
||||
|
|
|
@ -370,7 +370,7 @@ class GradOperation(GradOperation_):
|
|||
self._pynative_forward_run(args, kwargs, fn)
|
||||
_pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
|
||||
out = _pynative_exec(fn, *args, **kwargs)
|
||||
_pynative_exec.clear()
|
||||
_pynative_exec.clear_grad(fn, *args, **kwargs)
|
||||
return out
|
||||
self.grad_fn = after_grad
|
||||
self.fn = fn
|
||||
|
|
|
@ -65,7 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
|
|||
py::none py_none;
|
||||
py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
|
||||
py::list args_input = args[PY_INPUTS];
|
||||
return PynativeExecutor::GetInstance()->GenerateOpExecInfo(args);
|
||||
return PynativeExecutor::GetInstance()->forward_executor()->GenerateOpExecInfo(args);
|
||||
}
|
||||
|
||||
TEST_F(TestPynativeExecute, TestCreateContext) {
|
||||
|
|
Loading…
Reference in New Issue