!11138 Optimize pynative dynamic grad graph
From: @zjun3021 Reviewed-by: @zhoufeng54,@chujinjin Signed-off-by: @chujinjin
commit 4ed3491952

File diff suppressed because it is too large.
@@ -69,32 +69,47 @@ struct GraphInfo {
   explicit GraphInfo(std::string id) : cell_id(std::move((id))) {}
 };
 
-struct CellInfo {
-  bool is_grad{false};          // Derivative is calculated
-  bool is_custom_bprop{false};  // Custom bprop
-  FuncGraphPtr fg;              // Forward graph
-  std::string cell_id;
-  std::string bprop_cell_id;
+class CellInfo {
+ public:
   CellInfo() = default;
-  CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
-      : is_grad(isgrad),
-        is_custom_bprop(custom_bprop),
+  CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
+      : is_custom_bprop(custom_bprop),
+        is_dynamic(has_dynamic),
         fg(std::move(foward_graph)),
         cell_id(std::move(cellid)),
         bprop_cell_id(std::move(bprop_id)) {}
+
+  bool is_grad{false};          // Derivative is calculated
+  bool is_custom_bprop{false};  // Custom bprop
+  bool is_dynamic{false};       // Set by has_dynamic_cell
+  bool is_real_dynamic{false};  // Set by ops order
+  size_t call_times{0};
+  FuncGraphPtr fg{nullptr};  // Forward graph
+  std::string cell_id;
+  std::string bprop_cell_id;
+  std::vector<std::string> cell_ops_info;  // All ops info
 };
 
-struct TopCellInfo {
-  ResourcePtr resource;
-  FuncGraphPtr df_builder;
-  FuncGraphPtr bg;  // Backward graph
-  std::string cell_id;
-  bool is_dynamic_cell{false};
+class TopCellInfo {
+ public:
   TopCellInfo() = default;
-  TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid)
-      : resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {}
+  TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
+      : is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {}
+
+  bool is_topest{false};
+  bool do_vm_compiled{false};
+  ResourcePtr resource{nullptr};
+  FuncGraphPtr df_builder{nullptr};
+  FuncGraphPtr bg{nullptr};  // Backward graph
+  std::string cell_id;
+  std::string sens_id;
+  std::string weights_id;
 };
 
+using GraphInfoPtr = std::shared_ptr<GraphInfo>;
+using CellInfoPtr = std::shared_ptr<CellInfo>;
+using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
+
 class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  public:
   static std::shared_ptr<PynativeExecutor> GetInstance() {
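
A minimal standalone sketch (not from the patch; FuncGraph is stubbed as an empty type) of what the new CellInfoPtr alias buys: the executor's bookkeeping lists now hold shared handles, so copying a list shares its entries and an in-place update is visible through every handle, instead of duplicating strings and graph pointers.

#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct FuncGraph {};  // stub for the real IR graph type
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

class CellInfo {
 public:
  CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr graph, std::string cellid, std::string bprop_id)
      : is_custom_bprop(custom_bprop),
        is_dynamic(has_dynamic),
        fg(std::move(graph)),
        cell_id(std::move(cellid)),
        bprop_cell_id(std::move(bprop_id)) {}

  bool is_custom_bprop{false};
  bool is_dynamic{false};
  FuncGraphPtr fg{nullptr};
  std::string cell_id;
  std::string bprop_cell_id;
};
using CellInfoPtr = std::shared_ptr<CellInfo>;

int main() {
  std::vector<CellInfoPtr> cell_graph_list;
  cell_graph_list.push_back(
      std::make_shared<CellInfo>(false, true, std::make_shared<FuncGraph>(), "cell_0", ""));
  // Copying the list copies pointers, not CellInfo objects, so an update
  // made through one handle is visible through the other.
  std::vector<CellInfoPtr> alias = cell_graph_list;
  alias[0]->is_dynamic = false;
  assert(!cell_graph_list[0]->is_dynamic);
  return 0;
}
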
@@ -119,11 +134,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void NewGraph(const py::object &cell, const py::args &args);
   py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
   py::object CheckGraph(const py::object &cell, const py::args &args);
+  py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
   void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
 
   // Get info
-  bool GetIsDynamicCell() const { return dynamic_cell_; }
+  bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); }
   // Call by python
   void Clear(const std::string &flag = "");
   void Clean();
@@ -149,7 +165,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   template <typename T>
   void VectorClear(T *vec, const std::string &cell_id) {
     for (auto it = vec->begin(); it != vec->end();) {
-      if (it->cell_id.find(cell_id) != std::string::npos) {
+      if ((*it)->cell_id.find(cell_id) != std::string::npos) {
        it = vec->erase(it);
       } else {
         it++;
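
This one-line change follows from the container change later in the diff (std::vector<CellInfo> becomes std::vector<CellInfoPtr>): the iterator now yields a shared_ptr, hence (*it)-> rather than it->. A self-contained sketch of the same erase-while-iterating pattern, with CellInfo cut down to the one field the loop reads:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct CellInfo {
  std::string cell_id;
};
using CellInfoPtr = std::shared_ptr<CellInfo>;

// Erase every entry whose cell_id contains the given substring.
// vec->erase(it) returns the next valid iterator, so the loop only
// advances manually in the keep branch.
void VectorClear(std::vector<CellInfoPtr> *vec, const std::string &cell_id) {
  for (auto it = vec->begin(); it != vec->end();) {
    if ((*it)->cell_id.find(cell_id) != std::string::npos) {
      it = vec->erase(it);
    } else {
      ++it;
    }
  }
}

int main() {
  std::vector<CellInfoPtr> cells{std::make_shared<CellInfo>(CellInfo{"net_a_0"}),
                                 std::make_shared<CellInfo>(CellInfo{"net_b_0"})};
  VectorClear(&cells, "net_a");
  std::cout << cells.size() << std::endl;  // prints 1
  return 0;
}
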
@@ -201,29 +217,39 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
   void SaveTensorsInValueNode(const ResourcePtr &resource);
+  void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
-  void CleanPreMemoryInValueNode(const std::string &cell_id);
+  void CleanPreMemoryInValueNode();
 
   // Construct grad graph
   void PushCurrentGraphToStack();
   void PopGraphStack();
+  void PushCurrentCellOpInfoToStack();
+  void PopCurrentCellOpInfoFromStack();
   FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
   ResourcePtr GetResource(const std::string &cell_id = "");
   void AddNestedGradOrder() { ++grad_order_; }
   void SubNestedGradOrder();
-  bool IsNotNestedGrad() const;
+  bool IsNestedGrad() const;
   bool IsTopGraph(const std::string &cell_id);
+  bool IsTopestGraph(const std::string &cell_id);
   bool IsBpropGraph(const std::string &cell_id);
   bool IsFirstGradStep(const std::string &cell_id);
   bool grad_running() const { return grad_is_running_; }
   void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
   void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
   bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; }
   bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
+  bool CheckDynamicCell(const std::string &cell_id);
+  bool CheckRealDynamicCell(const std::string &cell_id);
   void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                        bool need_cloned = false, bool is_grad = false);
+  void ClearCnodeRes(const AnfNodePtr &node);
+  void UpdateCellDynamic(const std::string &cell_id);
+  bool CheckCellChanged(const std::string &cell_id);
+  void UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled);
   void ClearResidualRes(const std::string &cell_id);
   void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
   void NewGraphInner(const py::object &cell, const py::args &args);
-  void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
+  void MakeNewTopGraph(const string &cell_id, const py::args &args);
   void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
   void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
                        const std::string &out_id, const py::args &args);
@@ -232,38 +258,44 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                            const std::string &cell_id, const py::args &args);
   std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
                             py::object *sens = nullptr);
+  void ClearDynamicTopRes(const std::string &cell_id);
   void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                     const py::args &args);
   std::string GetCellId(const py::object &obj, const py::args &args);
-  std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
+  std::string GetTensorCellId(const std::string &cell_id);
+  bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
   void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
   void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
                  size_t arg_size, const std::string &cell_id);
   std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
   abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
+  void UpdateGraphInfoMap(const std::string &cell_id);
   void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
   void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
                           const std::string &cell_id);
   void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
   void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
                        const py::object &out, bool has_sens);
   void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
   void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
   bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
 
   // Hold graph(forward and grad) info
+  std::string GetCellOpInfo();
+  void ReplaceCellOpInfoByCellId(const std::string &cell_id);
   void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
-    graph_info_map_[g].objects.push_back(obj);
+    graph_info_map_[g]->objects.push_back(obj);
   }
   void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
                                   bool is_param = false);
   void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
-    graph_info_map_[g].params[id] = param;
+    graph_info_map_[g]->params[id] = param;
   }
   void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                 int64_t index = -1) {
-    graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
+    graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
   }
   void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
                                 const std::vector<int64_t> &index) {
-    graph_info_map_[g].node_map[id] = std::make_pair(node, index);
+    graph_info_map_[g]->node_map[id] = std::make_pair(node, index);
   }
   void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
                                       const std::vector<int64_t> &index_sequence, bool is_param = false);
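
A side effect of graph_info_map_ holding GraphInfoPtr rather than GraphInfo by value (see the member change further down): operator[] default-constructs a null shared_ptr, so setters of the shape graph_info_map_[g]->params[id] are only safe once the entry has been populated. A toy illustration, with std::unordered_map and an int key standing in for the project's OrderedMap and FuncGraphPtr:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct GraphInfo {
  std::string cell_id;
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;

int main() {
  std::unordered_map<int, GraphInfoPtr> graph_info_map;
  // graph_info_map[1]->cell_id = "x";  // would dereference a null pointer
  graph_info_map[1] = std::make_shared<GraphInfo>(GraphInfo{"cell_1"});
  graph_info_map[1]->cell_id += "_grad";  // safe: the entry exists now
  std::cout << graph_info_map[1]->cell_id << std::endl;  // cell_1_grad
  return 0;
}
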
@@ -274,7 +306,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   size_t grad_order_{0};
   std::string top_cell_id_;
   bool grad_flag_{false};
-  bool dynamic_cell_{false};
+  bool has_dynamic_cell_{false};
   bool grad_is_running_{false};
   bool need_replace_forward_{true};
   // The pointer of top python Cell object, which is always the network(inherit class Cell) ran in python test script,
@@ -288,16 +320,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   FuncGraphPtr curr_g_{nullptr};
   // Records forward graph, the bottom is top graph
   std::stack<FuncGraphPtr> graph_stack_;
+  // Records op info of every cell, the bottom is op info of top cell
+  std::stack<std::string> cell_op_info_stack_;
 
   // Use vector to keep order
-  std::vector<CellInfo> cell_graph_list_;
-  std::vector<TopCellInfo> top_cell_list_;
+  std::vector<CellInfoPtr> cell_graph_list_;
+  std::vector<TopCellInfoPtr> top_cell_list_;
   std::unordered_set<std::string> cell_input_args_;
-  std::unordered_map<std::string, bool> cell_dynamic_map_;
   // Record all info for all cells
-  std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
-  // key: cell_id, value: (sens_id, weights_id), cache for sens and weight change
-  std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
+  OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
+  std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
 
   // Used for runop and replace forward result of grad graph
@@ -314,6 +314,9 @@ class _PynativeExecutor:
     def check_graph(self, obj, *args, **kwargs):
         return self._executor.check_graph(obj, *args, *(kwargs.values()))
 
+    def check_run(self, obj, *args, **kwargs):
+        return self._executor.check_run(obj, *args, *(kwargs.values()))
+
     def grad(self, grad, obj, weights, *args, **kwargs):
         self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
@@ -162,6 +162,14 @@ class OrderedMap {
     return pos == map_data_.end() ? sequential_data_.end() : (pos->second);
   }
 
+  ValueT at(const key_t &key) {
+    auto pos = map_data_.find(key);
+    if (pos == map_data_.end()) {
+      MS_LOG(EXCEPTION) << "Have no key " << key;
+    }
+    return pos->second->second;
+  }
+
   // Remove the last element from the sequential_data_.
   void pop_back() {
     typename map_type::iterator pos = map_data_.find(sequential_data_.back().first);
@@ -192,6 +200,24 @@ class OrderedMap {
     return 1;
   }
 
+  void update(const key_t &old_key, const key_t &new_key) {
+    auto old_it = find(old_key);
+    if (old_it == end()) {
+      return;
+    }
+    auto new_it = find(new_key);
+    if (new_it == end()) {
+      old_it->first = new_key;
+      auto nh = map_data_.extract(old_key);
+      nh.key() = new_key;
+      map_data_.insert(std::move(nh));
+      return;
+    }
+    *old_it = *new_it;
+    (void)erase(old_key);
+    (void)erase(new_key);
+  }
+
  private:
   map_type map_data_;
   sequential_type sequential_data_;
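
The middle branch of update() re-keys an existing entry without copying its value, using the C++17 node-handle API (map_data_.extract, mutate nh.key(), insert). The same maneuver on a plain std::map, as a runnable sketch:

#include <cassert>
#include <map>
#include <string>
#include <utility>

int main() {
  std::map<std::string, int> m{{"old_cell", 1}, {"other", 2}};
  auto nh = m.extract("old_cell");  // the node handle now owns the element
  nh.key() = "new_cell";            // mutate the key in place, no value copy
  m.insert(std::move(nh));          // splice it back into the map
  assert(m.count("old_cell") == 0 && m.at("new_cell") == 1);
  return 0;
}
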
@@ -68,7 +68,7 @@ class Cell(Cell_):
     """
     IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
                    '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
-                   '_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
+                   '_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
                    '_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
                    '_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
                    '_auto_parallel_compile_and_run', 'cell_type']
@@ -105,15 +105,10 @@ class Cell(Cell_):
         self._backward_hook = None
         self.enable_hook = False
         self._bprop_debug = False
-        self._already_run = False
         self.cell_type = None
         self._auto_parallel_compile_and_run = False
         self._support_non_tensor_inputs = False
 
-    @property
-    def already_run(self):
-        return self._already_run
-
     def __getstate__(self):
         base = Cell_.__getstate__(self)
         return base, self.__dict__
@@ -150,10 +145,6 @@ class Cell(Cell_):
         # `<class 'xxxxxxx'>` to `xxxxxxx`
         return str(self.__class__)[8:-2]
 
-    @already_run.setter
-    def already_run(self, value):
-        self._already_run = value
-
     @property
     def create_time(self):
         return self._create_time
@@ -334,12 +325,10 @@ class Cell(Cell_):
         for item in inputs:
             if isinstance(item, numpy.ndarray):
                 raise TypeError("cell inputs should not be numpy array.")
-        origin_grad = []
         if self.requires_grad is True:
             _pynative_exec.set_grad_flag(True)
             _pynative_exec.new_graph(self, *inputs, **kwargs)
             for cell in self.cells():
-                origin_grad.append(cell.requires_grad)
                 cell.set_grad(True)
         else:
             _pynative_exec.set_grad_flag(False)
@@ -363,9 +352,6 @@ class Cell(Cell_):
             output = output.data
         if self.requires_grad is True:
             _pynative_exec.end_graph(self, output, *inputs, **kwargs)
-            for i, cell in enumerate(self.cells()):
-                cell.set_grad(origin_grad[i])
-        self._already_run = True
         return output
 
     def _add_attr(self, name, value):
@@ -319,36 +319,30 @@ class GradOperation(GradOperation_):
         GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param)
         self.grad_fn = None
         self.fn = None
-        self.need_forward = False
 
     def _pynative_forward_run(self, args, kwargs, fn):
         """ Pynative forward run to build grad graph. """
-        new_kwargs = {}
+        new_kwargs = kwargs
         if self.sens_param:
             if not 'sens' in kwargs.keys():
                 args = args[:-1]
-                new_kwargs = kwargs
             else:
-                for key, value in kwargs.items():
-                    if key != 'sens':
-                        new_kwargs[key] = value
+                new_kwargs = kwargs.copy()
+                new_kwargs.pop('sens')
         for arg in args:
             if not isinstance(arg, Tensor):
                 raise TypeError("grad inputs should be tensor in pynative mode")
         if isinstance(fn, FunctionType):
-            _pynative_exec.set_grad_flag(True)
-            _pynative_exec.new_graph(fn, *args, **new_kwargs)
-            output = fn(*args, **new_kwargs)
-            _pynative_exec.end_graph(fn, output, *args, **new_kwargs)
+            if not _pynative_exec.check_run(fn, *args, **new_kwargs):
+                _pynative_exec.set_grad_flag(True)
+                _pynative_exec.new_graph(fn, *args, **new_kwargs)
+                output = fn(*args, **new_kwargs)
+                _pynative_exec.end_graph(fn, output, *args, **new_kwargs)
         else:
-            if fn.already_run and not fn.requires_grad:
-                raise ValueError("obj must set_grad.")
-            if not fn.already_run:
-                self.need_forward = True
-            if self.need_forward:
+            # Check if fn has run already
+            if not _pynative_exec.check_run(fn, *args, **new_kwargs):
                 fn.set_grad()
                 fn(*args, **new_kwargs)
-                fn.already_run = False
 
     def __call__(self, fn, weights=None):
         grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
@@ -367,7 +361,6 @@ class GradOperation(GradOperation_):
             def after_grad(*args, **kwargs):
                 if _pynative_exec.check_graph(fn, *args, **kwargs):
                     print("Another grad step is running")
-                fn.already_run = False
                 self._pynative_forward_run(args, kwargs, fn)
                 _pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
                 out = _pynative_exec(fn, *args, **kwargs)