!11138 Optimize pynative dynamic grad graph

From: @zjun3021
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
mindspore-ci-bot 2021-01-11 20:43:29 +08:00 committed by Gitee
commit 4ed3491952
6 changed files with 578 additions and 301 deletions

File diff suppressed because it is too large

View File

@@ -69,32 +69,47 @@ struct GraphInfo {
explicit GraphInfo(std::string id) : cell_id(std::move((id))) {}
};
struct CellInfo {
bool is_grad{false}; // Derivative is calculated
bool is_custom_bprop{false}; // Custom bprop
FuncGraphPtr fg; // Forward graph
std::string cell_id;
std::string bprop_cell_id;
class CellInfo {
public:
CellInfo() = default;
CellInfo(bool isgrad, bool custom_bprop, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
: is_grad(isgrad),
is_custom_bprop(custom_bprop),
CellInfo(bool custom_bprop, bool has_dynamic, FuncGraphPtr foward_graph, std::string cellid, std::string bprop_id)
: is_custom_bprop(custom_bprop),
is_dynamic(has_dynamic),
fg(std::move(foward_graph)),
cell_id(std::move(cellid)),
bprop_cell_id(std::move(bprop_id)) {}
bool is_grad{false}; // Derivative is calculated
bool is_custom_bprop{false}; // Custom bprop
bool is_dynamic{false}; // Set by has_dynamic_cell
bool is_real_dynamic{false}; // Set by ops order
size_t call_times{0};
FuncGraphPtr fg{nullptr}; // Forward graph
std::string cell_id;
std::string bprop_cell_id;
std::vector<std::string> cell_ops_info; // All ops info
};
struct TopCellInfo {
ResourcePtr resource;
FuncGraphPtr df_builder;
FuncGraphPtr bg; // Backward graph
std::string cell_id;
bool is_dynamic_cell{false};
class TopCellInfo {
public:
TopCellInfo() = default;
TopCellInfo(ResourcePtr r, FuncGraphPtr df, FuncGraphPtr backward_graph, std::string cellid)
: resource(std::move(r)), df_builder(std::move(df)), bg(std::move(backward_graph)), cell_id(std::move(cellid)) {}
TopCellInfo(bool topest, ResourcePtr r, FuncGraphPtr df, std::string cellid)
: is_topest(topest), resource(std::move(r)), df_builder(std::move(df)), cell_id(std::move(cellid)) {}
bool is_topest{false};
bool do_vm_compiled{false};
ResourcePtr resource{nullptr};
FuncGraphPtr df_builder{nullptr};
FuncGraphPtr bg{nullptr}; // Backward graph
std::string cell_id;
std::string sens_id;
std::string weights_id;
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;
using CellInfoPtr = std::shared_ptr<CellInfo>;
using TopCellInfoPtr = std::shared_ptr<TopCellInfo>;
class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
public:
static std::shared_ptr<PynativeExecutor> GetInstance() {
@@ -119,11 +134,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void NewGraph(const py::object &cell, const py::args &args);
py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
py::object CheckGraph(const py::object &cell, const py::args &args);
py::object CheckAlreadyRun(const py::object &cell, const py::args &args);
void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
// Get info
bool GetIsDynamicCell() const { return dynamic_cell_; }
bool GetIsDynamicCell() { return CheckRealDynamicCell(top_cell_id_); }
// Call by python
void Clear(const std::string &flag = "");
void Clean();
@@ -149,7 +165,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
template <typename T>
void VectorClear(T *vec, const std::string &cell_id) {
for (auto it = vec->begin(); it != vec->end();) {
if (it->cell_id.find(cell_id) != std::string::npos) {
if ((*it)->cell_id.find(cell_id) != std::string::npos) {
it = vec->erase(it);
} else {
it++;
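Since cell_graph_list_ and top_cell_list_ now store std::shared_ptr elements rather than values, the iterator is dereferenced twice ((*it)->cell_id). A minimal self-contained sketch of this erase-while-iterating pattern, with Info and InfoPtr as hypothetical stand-ins for CellInfo and CellInfoPtr:

#include <memory>
#include <string>
#include <vector>

// Stand-in for the CellInfo/TopCellInfo records; only cell_id matters here.
struct Info {
  std::string cell_id;
};
using InfoPtr = std::shared_ptr<Info>;

// Erase every entry whose cell_id contains the given substring. erase()
// returns the next valid iterator, so the loop only advances manually
// when nothing was erased.
template <typename T>
void VectorClear(T *vec, const std::string &cell_id) {
  for (auto it = vec->begin(); it != vec->end();) {
    if ((*it)->cell_id.find(cell_id) != std::string::npos) {
      it = vec->erase(it);
    } else {
      ++it;
    }
  }
}

int main() {
  std::vector<InfoPtr> list{std::make_shared<Info>(Info{"net_a.0"}),
                            std::make_shared<Info>(Info{"net_b.0"})};
  VectorClear(&list, "net_a");
  return list.size() == 1 ? 0 : 1;  // only net_b.0 remains
}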
@@ -201,29 +217,39 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
void SaveTensorsInValueNode(const ResourcePtr &resource);
void SaveAllValueNodeTensors(const FuncGraphPtr &graph);
void CleanPreMemoryInValueNode(const std::string &cell_id);
void CleanPreMemoryInValueNode();
// Construct grad graph
void PushCurrentGraphToStack();
void PopGraphStack();
void PushCurrentCellOpInfoToStack();
void PopCurrentCellOpInfoFromStack();
FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
ResourcePtr GetResource(const std::string &cell_id = "");
void AddNestedGradOrder() { ++grad_order_; }
void SubNestedGradOrder();
bool IsNotNestedGrad() const;
bool IsNestedGrad() const;
bool IsTopGraph(const std::string &cell_id);
bool IsTopestGraph(const std::string &cell_id);
bool IsBpropGraph(const std::string &cell_id);
bool IsFirstGradStep(const std::string &cell_id);
bool grad_running() const { return grad_is_running_; }
void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; }
bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
bool CheckDynamicCell(const std::string &cell_id);
bool CheckRealDynamicCell(const std::string &cell_id);
void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
bool need_cloned = false, bool is_grad = false);
void ClearCnodeRes(const AnfNodePtr &node);
void UpdateCellDynamic(const std::string &cell_id);
bool CheckCellChanged(const std::string &cell_id);
void UpdateTopCellCompileInfo(const std::string &cell_id, bool vm_compiled);
void ClearResidualRes(const std::string &cell_id);
void DumpGraphIR(const std::string &filename, const FuncGraphPtr &graph);
void NewGraphInner(const py::object &cell, const py::args &args);
void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
void MakeNewTopGraph(const string &cell_id, const py::args &args);
void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
const std::string &out_id, const py::args &args);
@@ -232,38 +258,44 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
const std::string &cell_id, const py::args &args);
std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
py::object *sens = nullptr);
void ClearDynamicTopRes(const std::string &cell_id);
void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
const py::args &args);
std::string GetCellId(const py::object &obj, const py::args &args);
std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
std::string GetTensorCellId(const std::string &cell_id);
bool CheckGradParamsChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
size_t arg_size, const std::string &cell_id);
std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
void UpdateGraphInfoMap(const std::string &cell_id);
void ClearUselessRes(const FuncGraphPtr &df_builder, const py::object &cell, const std::string &cell_id);
void ReplaceGraphParams(const FuncGraphPtr &df_builder, const FuncGraphPtr &forward_graph,
const std::string &cell_id);
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens);
void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
void RecoverGraphParams(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
// Hold graph(forward and grad) info
std::string GetCellOpInfo();
void ReplaceCellOpInfoByCellId(const std::string &cell_id);
void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
graph_info_map_[g].objects.push_back(obj);
graph_info_map_[g]->objects.push_back(obj);
}
void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
bool is_param = false);
void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
graph_info_map_[g].params[id] = param;
graph_info_map_[g]->params[id] = param;
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
int64_t index = -1) {
graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
graph_info_map_[g]->node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
}
void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
const std::vector<int64_t> &index) {
graph_info_map_[g].node_map[id] = std::make_pair(node, index);
graph_info_map_[g]->node_map[id] = std::make_pair(node, index);
}
void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
const std::vector<int64_t> &index_sequence, bool is_param = false);
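With graph_info_map_ now holding GraphInfoPtr, the setters above dereference a shared_ptr; indexing with a key that has no entry yields a null pointer, so the GraphInfo must be created before any -> access. A hedged sketch of that pitfall, with std::map standing in for OrderedMap and an int key standing in for FuncGraphPtr:

#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct GraphInfo {
  explicit GraphInfo(std::string id) : cell_id(std::move(id)) {}
  std::string cell_id;
  std::vector<std::string> objects;
};
using GraphInfoPtr = std::shared_ptr<GraphInfo>;

int main() {
  // std::map stands in for MindSpore's OrderedMap here.
  std::map<int, GraphInfoPtr> graph_info_map;
  int g = 1;  // stand-in for a FuncGraphPtr key

  // operator[] on a pointer-valued map yields a null shared_ptr for a
  // missing key, so the entry must be populated before any `->` access.
  graph_info_map[g] = std::make_shared<GraphInfo>("cell_0");
  graph_info_map[g]->objects.push_back("op_0");
  assert(graph_info_map[g]->objects.size() == 1);
  return 0;
}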
@@ -274,7 +306,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
size_t grad_order_{0};
std::string top_cell_id_;
bool grad_flag_{false};
bool dynamic_cell_{false};
bool has_dynamic_cell_{false};
bool grad_is_running_{false};
bool need_replace_forward_{true};
// The pointer of the top python Cell object, which is always the network (inheriting Cell) run in the python test script,
@@ -288,16 +320,15 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr curr_g_{nullptr};
// Records forward graph, the bottom is the top graph
std::stack<FuncGraphPtr> graph_stack_;
// Records op info of every cell, the bottom is op info of top cell
std::stack<std::string> cell_op_info_stack_;
// Use vector to keep order
std::vector<CellInfo> cell_graph_list_;
std::vector<TopCellInfo> top_cell_list_;
std::vector<CellInfoPtr> cell_graph_list_;
std::vector<TopCellInfoPtr> top_cell_list_;
std::unordered_set<std::string> cell_input_args_;
std::unordered_map<std::string, bool> cell_dynamic_map_;
// Record all info for all cells
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
// key: cell_id, value: (sens_id, weights_id), cache for sens and weight change
std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
OrderedMap<FuncGraphPtr, GraphInfoPtr> graph_info_map_;
std::unordered_map<FuncGraphPtr, std::vector<std::pair<ParameterPtr, ParameterPtr>>> replace_weights_map_;
// Used for runop and replace forward result of grad graph
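The (sens_id, weights_id) cache moves from the standalone cell_sw_map_ into TopCellInfo itself, which CheckGradParamsChanged can consult per top cell. An illustrative sketch of such a change check, assuming the ids are identity strings of the current sens and weights (TopCell and GradParamsChanged are hypothetical stand-ins, not the method's actual body):

#include <string>

// Simplified stand-in for TopCellInfo's cached ids.
struct TopCell {
  std::string sens_id;
  std::string weights_id;
};

// Report whether sens/weights differ from the cached ids, then refresh
// the cache; a true result would force the grad graph to be rebuilt.
bool GradParamsChanged(TopCell *top, const std::string &sens_id, const std::string &weights_id) {
  const bool changed = top->sens_id != sens_id || top->weights_id != weights_id;
  top->sens_id = sens_id;
  top->weights_id = weights_id;
  return changed;
}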

View File

@@ -314,6 +314,9 @@ class _PynativeExecutor:
def check_graph(self, obj, *args, **kwargs):
return self._executor.check_graph(obj, *args, *(kwargs.values()))
def check_run(self, obj, *args, **kwargs):
return self._executor.check_run(obj, *args, *(kwargs.values()))
def grad(self, grad, obj, weights, *args, **kwargs):
self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))

View File

@@ -162,6 +162,14 @@ class OrderedMap {
return pos == map_data_.end() ? sequential_data_.end() : (pos->second);
}
ValueT at(const key_t &key) {
auto pos = map_data_.find(key);
if (pos == map_data_.end()) {
MS_LOG(EXCEPTION) << "Have no key " << key;
}
return pos->second->second;
}
// Remove the last element from the sequential_data_.
void pop_back() {
typename map_type::iterator pos = map_data_.find(sequential_data_.back().first);
@@ -192,6 +200,24 @@ class OrderedMap {
return 1;
}
void update(const key_t &old_key, const key_t &new_key) {
auto old_it = find(old_key);
if (old_it == end()) {
return;
}
auto new_it = find(new_key);
if (new_it == end()) {
old_it->first = new_key;
auto nh = map_data_.extract(old_key);
nh.key() = new_key;
map_data_.insert(std::move(nh));
return;
}
*old_it = *new_it;
(void)erase(old_key);
(void)erase(new_key);
}
private:
map_type map_data_;
sequential_type sequential_data_;
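The new update() method rekeys an entry in place: when new_key is unused, it rewrites the key stored in sequential_data_ and then rekeys map_data_ via the C++17 node-handle API (extract, mutate key(), reinsert), so the mapped value is never copied. A small sketch of that node-handle rekey on a plain std::unordered_map:

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> m{{"old_cell", 42}};

  // C++17 node handles: extract() detaches the node without copying the
  // value, the key can then be mutated, and insert() splices it back.
  auto nh = m.extract("old_cell");
  assert(!nh.empty());
  nh.key() = "new_cell";
  m.insert(std::move(nh));

  assert(m.count("old_cell") == 0);
  assert(m.at("new_cell") == 42);
  return 0;
}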

View File

@@ -68,7 +68,7 @@ class Cell(Cell_):
"""
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_already_run', '_params_list', '_tensor_list', '_phase',
'_parameter_layout_dict', '_params_list', '_tensor_list', '_phase',
'_auto_parallel_mode', '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix',
'_attr_synced', 'enable_hook', 'pynative', 'requires_grad',
'_auto_parallel_compile_and_run', 'cell_type']
@@ -105,15 +105,10 @@ class Cell(Cell_):
self._backward_hook = None
self.enable_hook = False
self._bprop_debug = False
self._already_run = False
self.cell_type = None
self._auto_parallel_compile_and_run = False
self._support_non_tensor_inputs = False
@property
def already_run(self):
return self._already_run
def __getstate__(self):
base = Cell_.__getstate__(self)
return base, self.__dict__
@@ -150,10 +145,6 @@ class Cell(Cell_):
# `<class 'xxxxxxx'>` to `xxxxxxx`
return str(self.__class__)[8:-2]
@already_run.setter
def already_run(self, value):
self._already_run = value
@property
def create_time(self):
return self._create_time
@@ -334,12 +325,10 @@ class Cell(Cell_):
for item in inputs:
if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.")
origin_grad = []
if self.requires_grad is True:
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(self, *inputs, **kwargs)
for cell in self.cells():
origin_grad.append(cell.requires_grad)
cell.set_grad(True)
else:
_pynative_exec.set_grad_flag(False)
@@ -363,9 +352,6 @@ class Cell(Cell_):
output = output.data
if self.requires_grad is True:
_pynative_exec.end_graph(self, output, *inputs, **kwargs)
for i, cell in enumerate(self.cells()):
cell.set_grad(origin_grad[i])
self._already_run = True
return output
def _add_attr(self, name, value):

View File

@@ -319,36 +319,30 @@ class GradOperation(GradOperation_):
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param)
self.grad_fn = None
self.fn = None
self.need_forward = False
def _pynative_forward_run(self, args, kwargs, fn):
""" Pynative forward run to build grad graph. """
new_kwargs = {}
new_kwargs = kwargs
if self.sens_param:
if not 'sens' in kwargs.keys():
args = args[:-1]
new_kwargs = kwargs
else:
for key, value in kwargs.items():
if key != 'sens':
new_kwargs[key] = value
new_kwargs = kwargs.copy()
new_kwargs.pop('sens')
for arg in args:
if not isinstance(arg, Tensor):
raise TypeError("grad inputs should be tensor in pynative mode")
if isinstance(fn, FunctionType):
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(fn, *args, **new_kwargs)
output = fn(*args, **new_kwargs)
_pynative_exec.end_graph(fn, output, *args, **new_kwargs)
if not _pynative_exec.check_run(fn, *args, **new_kwargs):
_pynative_exec.set_grad_flag(True)
_pynative_exec.new_graph(fn, *args, **new_kwargs)
output = fn(*args, **new_kwargs)
_pynative_exec.end_graph(fn, output, *args, **new_kwargs)
else:
if fn.already_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.already_run:
self.need_forward = True
if self.need_forward:
# Check if fn has run already
if not _pynative_exec.check_run(fn, *args, **new_kwargs):
fn.set_grad()
fn(*args, **new_kwargs)
fn.already_run = False
def __call__(self, fn, weights=None):
grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
@@ -367,7 +361,6 @@ class GradOperation(GradOperation_):
def after_grad(*args, **kwargs):
if _pynative_exec.check_graph(fn, *args, **kwargs):
print("Another grad step is running")
fn.already_run = False
self._pynative_forward_run(args, kwargs, fn)
_pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
out = _pynative_exec(fn, *args, **kwargs)