!8346 Pynative support dynamic graph by using ast parse

From: @joylvliang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-11-16 09:04:57 +08:00 committed by Gitee
commit 2a8557de84
4 changed files with 176 additions and 32 deletions

View File

@ -90,6 +90,24 @@ const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
// define the common name
const char NAMED_PRIMITIVE_LEN[] = "len";
const char NAMED_PRIMITIVE_BODY[] = "body";
const char NAMED_PRIMITIVE_ASSIGN[] = "Assign";
const char NAMED_PRIMITIVE_FOR[] = "For";
const char NAMED_PRIMITIVE_IF[] = "If";
const char NAMED_PRIMITIVE_WHILE[] = "While";
const char NAMED_PRIMITIVE_VALUE[] = "value";
const char NAMED_PRIMITIVE_FUNC[] = "func";
const char NAMED_PRIMITIVE_TEST[] = "test";
const char NAMED_PRIMITIVE_LEFT[] = "left";
const char NAMED_PRIMITIVE_CALL[] = "Call";
const char NAMED_PRIMITIVE_SUBSCRIPT[] = "Subscript";
const char NAMED_PRIMITIVE_ATTRIBUTE[] = "Attribute";
const char NAMED_PRIMITIVE_COMPARE[] = "Compare";
const char NAMED_PRIMITIVE_NAMECONSTANT[] = "NameConstant";
const char NAMED_PRIMITIVE_COMPARATORS[] = "comparators";
const char NAMED_PRIMITIVE_SLICE[] = "slice";
const char NAMED_PRIMITIVE_NUM[] = "Num";
const char NAMED_PRIMITIVE_STR[] = "Str";
const char NAMED_PRIMITIVE_ITER[] = "iter";
const char NAMED_PRIMITIVE_NEXT[] = "next";
const char NAMED_PRIMITIVE_GETITEM[] = "getitem";
@ -104,6 +122,7 @@ const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call";
// define NAMED_PRIMITIVE_GETATTR "getattr"
// define python inline attr
const char PYTHON_GET_METHOD_LEN[] = "__len__";
const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__";
const char PYTHON_GET_OBJ_DESC[] = "__str__";

View File

@ -28,6 +28,7 @@
#include "pybind11/pybind11.h"
#include "ir/anf.h"
#include "pybind_api/ir/primitive_py.h"
#include "pipeline/jit/parse/parse.h"
#include "abstract/abstract_value.h"
namespace mindspore {
@ -65,6 +66,9 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE,
parse::NAMED_PRIMITIVE_NAMECONSTANT,
parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR};
} // namespace pynative
} // namespace mindspore

View File

@ -28,6 +28,7 @@
#include "pybind_api/ir/tensor_py.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/cell.h"
#include "ir/tensor.h"
#include "utils/any.h"
#include "utils/utils.h"
@ -36,10 +37,8 @@
#include "utils/config_manager.h"
#include "utils/convert_utils_py.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/operator/composite/do_signature.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/static_analysis/prim.h"
#include "backend/session/session_factory.h"
@ -69,8 +68,7 @@ const size_t PTR_LEN = 15;
const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "InsertGradientOf", "stop_gradient",
"mixed_precision_cast"};
namespace mindspore {
namespace pynative {
namespace mindspore::pynative {
static std::shared_ptr<session::SessionBasic> session = nullptr;
PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_;
@ -947,7 +945,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
if (op_mask) {
MS_LOG(DEBUG) << "Cell paramsters(weights)";
// get the parameter name from parameter object
auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name");
auto name_attr = parse::python_adapter::GetPyObjAttr(obj, "name");
if (py::isinstance<py::none>(name_attr)) {
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
@ -1221,7 +1219,7 @@ py::tuple PynativeExecutor::RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_e
MsBackendPolicy PynativeExecutor::InitEnv(const OpExecInfoPtr &op_exec_info) {
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);
parse::python_adapter::set_python_env_flag(true);
MsBackendPolicy backend_policy;
#if (!defined ENABLE_GE)
auto ms_context = MsContext::GetInstance();
@ -1350,18 +1348,136 @@ std::string PynativeExecutor::GetCellId(const py::object &cell, const py::args &
return cell_id;
}
std::string PynativeExecutor::ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type) {
MS_EXCEPTION_IF_NULL(ast);
if (py::isinstance<py::none>(node)) {
MS_LOG(DEBUG) << "Get none type node!";
return "";
}
auto node_type = ast->GetNodeType(node);
MS_EXCEPTION_IF_NULL(node_type);
// check node type
parse::AstMainType node_main_type = node_type->main_type();
if (node_main_type != type) {
MS_LOG(ERROR) << "Node type is wrong: " << node_main_type << ", it should be " << type;
return "";
}
std::string node_name = node_type->node_name();
MS_LOG(DEBUG) << "Ast node is " << node_name;
return node_name;
}
bool PynativeExecutor::ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse if/while expr";
py::object test_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_TEST);
const auto &node_name = ParseNodeName(ast, test_node, parse::AST_MAIN_TYPE_EXPR);
if (node_name == parse::NAMED_PRIMITIVE_COMPARE) {
py::object left_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_LEFT);
py::list comparators_node = parse::python_adapter::GetPyObjAttr(test_node, parse::NAMED_PRIMITIVE_COMPARATORS);
if (comparators_node.empty()) {
MS_LOG(DEBUG) << "Get comparators node falied!";
return false;
}
const auto &left = ParseNodeName(ast, left_node, parse::AST_MAIN_TYPE_EXPR);
const auto &right = ParseNodeName(ast, comparators_node[0], parse::AST_MAIN_TYPE_EXPR);
MS_LOG(DEBUG) << "left is " << left << " right is " << right;
if (unchanged_named_primitive.find(left) == unchanged_named_primitive.end() ||
unchanged_named_primitive.find(right) == unchanged_named_primitive.end()) {
return true;
}
}
return false;
}
bool PynativeExecutor::ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse assign expr";
py::object value_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_VALUE);
const auto &node_name = ParseNodeName(ast, value_node, parse::AST_MAIN_TYPE_EXPR);
if (node_name == parse::NAMED_PRIMITIVE_CALL) {
py::object func_node = parse::python_adapter::GetPyObjAttr(value_node, parse::NAMED_PRIMITIVE_FUNC);
const auto &func_name = ParseNodeName(ast, func_node, parse::AST_MAIN_TYPE_EXPR);
if (func_name == parse::NAMED_PRIMITIVE_SUBSCRIPT) {
py::object slice_node = parse::python_adapter::GetPyObjAttr(func_node, parse::NAMED_PRIMITIVE_SLICE);
py::object value_in_slice_node = parse::python_adapter::GetPyObjAttr(slice_node, parse::NAMED_PRIMITIVE_VALUE);
const auto &node_name_in_slice_node = ParseNodeName(ast, value_in_slice_node, parse::AST_MAIN_TYPE_EXPR);
return unchanged_named_primitive.find(node_name_in_slice_node) == unchanged_named_primitive.end();
}
}
return false;
}
bool PynativeExecutor::ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node) {
MS_LOG(DEBUG) << "Parse for expr";
py::object body_node = parse::python_adapter::GetPyObjAttr(node, parse::NAMED_PRIMITIVE_BODY);
py::int_ pcount = parse::python_adapter::CallPyObjMethod(body_node, parse::PYTHON_GET_METHOD_LEN);
size_t count = LongToSize(pcount);
MS_LOG(DEBUG) << "The for nodes count is " << count;
for (size_t i = 0; i < count; ++i) {
auto it = py::cast<py::list>(body_node)[i];
const auto &node_name = ParseNodeName(ast, it, parse::AST_MAIN_TYPE_STMT);
if (node_name == parse::NAMED_PRIMITIVE_ASSIGN && ParseAssignExprNode(ast, it)) {
return true;
}
}
return false;
}
bool PynativeExecutor::IsDynamicCell(const py::object &cell) {
std::string cell_info;
if (py::isinstance<Cell>(cell)) {
auto c_cell = py::cast<CellPtr>(cell);
MS_EXCEPTION_IF_NULL(c_cell);
cell_info = c_cell->ToString();
}
if (cell_info.find("nn.layer.basic.Dense") != string::npos) {
return false;
}
// using ast parse to check whether the construct of cell will be changed
auto ast = std::make_shared<parse::ParseAst>(cell);
bool success = ast->InitParseAstInfo(parse::PYTHON_MOD_GET_PARSE_METHOD);
if (!success) {
MS_LOG(ERROR) << "Parse code to ast tree failed";
return false;
}
py::object nodes = ast->GetAstNode();
py::object func_obj = parse::python_adapter::GetPyObjAttr(nodes, parse::NAMED_PRIMITIVE_BODY);
py::int_ pcount = parse::python_adapter::CallPyObjMethod(func_obj, parse::PYTHON_GET_METHOD_LEN);
size_t count = IntToSize(pcount);
MS_LOG(DEBUG) << "The nodes count is " << count;
bool ret = false;
for (size_t i = 0; i < count; ++i) {
auto node = py::cast<py::list>(func_obj)[i];
const auto &node_name = ParseNodeName(ast, node, parse::AST_MAIN_TYPE_STMT);
if (node_name == parse::NAMED_PRIMITIVE_ASSIGN) {
ret = ParseAssignExprNode(ast, node);
} else if (node_name == parse::NAMED_PRIMITIVE_FOR) {
ret = ParseForExprNode(ast, node);
} else if (node_name == parse::NAMED_PRIMITIVE_IF || node_name == parse::NAMED_PRIMITIVE_WHILE) {
ret = ParseIfWhileExprNode(ast, node);
}
if (ret) {
MS_LOG(INFO) << "Cur cell is dynamic";
break;
}
}
return ret;
}
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id;
if (!dynamic_shape && graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end()) {
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id;
// check whether cell needed to construct grad graph
if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) {
auto it = cell_resource_map_.find(cell_id);
if (it != cell_resource_map_.end()) {
resource_ = it->second;
MS_EXCEPTION_IF_NULL(resource_);
}
MS_LOG(DEBUG) << "Newgraph already compiled";
MS_LOG(DEBUG) << "Graph already compiled";
return;
}
// init resource for constructing forward graph and grad graph
auto g = std::make_shared<FuncGraph>();
if (graph_context_.empty()) {
MakeNewTopGraph(cell_id, args, g);
@ -1380,6 +1496,11 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
set_node_map(curr_g_, param, new_param, true);
set_node_map(curr_g_, param_obj, new_param);
}
// check whether the constrcut of cell will be changed
if (!dynamic_cell_) {
dynamic_cell_ = IsDynamicCell(cell);
MS_LOG(DEBUG) << "cell id: " << cell_id << ", is dynamic cell: " << dynamic_cell_;
}
}
void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g) {
@ -1391,23 +1512,14 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
}
}
}
if (dynamic_shape) {
auto it = df_builder_map_.find(cell_id);
if (it != df_builder_map_.end()) {
df_builder_map_.erase(cell_id);
}
auto ic = cell_resource_map_.find(cell_id);
if (ic != cell_resource_map_.end()) {
cell_resource_map_.erase(cell_id);
}
}
top_g_ = curr_g_ = g;
dynamic_cell_ = false;
// a df builder is built for every top function graph
df_builder_ = std::make_shared<FuncGraph>();
df_builder_map_.emplace(cell_id, std::make_pair(df_builder_, nullptr));
top_g_ = curr_g_ = g;
df_builder_map_[cell_id] = std::make_pair(df_builder_, nullptr);
resource_ = std::make_shared<pipeline::Resource>();
resource_->results()[pipeline::kPynativeGraphId] = graph_id_++;
cell_resource_map_.emplace(cell_id, resource_);
cell_resource_map_[cell_id] = resource_;
MS_LOG(DEBUG) << "New top graph for " << cell_id;
first_grad_step_ = true;
top_graph_cells_.emplace(cell_id);
@ -1452,12 +1564,11 @@ void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::objec
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) {
auto cell_id = GetCellId(cell, args);
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id;
if (!dynamic_shape && graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end()) {
if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) {
MS_LOG(DEBUG) << "Endgraph already compiled";
return;
}
cell_graph_map_.emplace(cell_id, std::make_pair(curr_g_, false));
cell_graph_map_[cell_id] = std::make_pair(curr_g_, false);
auto out_id = GetId(out);
// x =op1, y =op2, return (x, y)
if (graph_info_map_[curr_g_].node_map.find(out_id) == graph_info_map_[curr_g_].node_map.end()) {
@ -1563,8 +1674,8 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
std::string cell_id = CheckCellChanged(grad, cell, weights, args, &sens_weights_changed);
MS_LOG(DEBUG) << "GradNetInner cell_id " << cell_id;
if (!dynamic_shape && !sens_weights_changed.first && !sens_weights_changed.second &&
cell_graph_map_.find(cell_id) != cell_graph_map_.end() && cell_graph_map_[cell_id].second) {
if (!sens_weights_changed.first && !sens_weights_changed.second &&
cell_graph_map_.find(cell_id) != cell_graph_map_.end() && cell_graph_map_[cell_id].second && !dynamic_cell_) {
if (cell_resource_map_.find(cell_id) == cell_resource_map_.end()) {
MS_LOG(EXCEPTION) << "Can not find resource";
}
@ -1722,7 +1833,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
graph_info_map_[df_builder_].node_map.find(param_id) != graph_info_map_[df_builder_].node_map.end()) {
para_node = graph_info_map_[df_builder_].node_map[param_id].first;
} else {
auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(param, "name");
auto name_attr = parse::python_adapter::GetPyObjAttr(param, "name");
if (py::isinstance<py::none>(name_attr)) {
MS_LOG(EXCEPTION) << "Parameter object should have name attribute";
}
@ -1847,6 +1958,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
if (!flag.empty()) {
MS_LOG(DEBUG) << "Clear cell res";
MapClear<std::unordered_map<std::string, ResourcePtr>>(&cell_resource_map_, flag);
MapClear<std::unordered_map<std::string, bool>>(&cell_dynamic_map_, flag);
MapClear<std::unordered_map<std::string, std::pair<FuncGraphPtr, bool>>>(&cell_graph_map_, flag);
MapClear<std::unordered_map<std::string, std::pair<std::string, std::string>>>(&cell_sw_map_, flag);
MapClear<std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>>>(&df_builder_map_, flag);
@ -1892,6 +2004,7 @@ void PynativeExecutor::ClearRes() {
df_builder_map_.clear();
cell_graph_map_.clear();
cell_resource_map_.clear();
cell_dynamic_map_.clear();
node_abs_map_.clear();
top_graph_cells_.clear();
}
@ -1922,5 +2035,4 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) {
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
"Executor set grad flag.");
}));
} // namespace pynative
} // namespace mindspore
} // namespace mindspore::pynative

View File

@ -100,6 +100,14 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
private:
PynativeExecutor() = default;
// check cell struct
bool IsDynamicCell(const py::object &cell);
bool ParseIfWhileExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseAssignExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
parse::AstMainType type);
// run op
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
@ -158,9 +166,9 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static std::mutex instance_lock_;
static int64_t graph_id_;
bool grad_flag_{false};
bool dynamic_cell_{false};
bool first_grad_step_{false};
bool grad_is_running{false};
bool dynamic_shape{false};
// Used for construct grad graph
FuncGraphPtr top_g_{nullptr};
@ -173,6 +181,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// record all info of a graph
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::unordered_map<std::string, bool> cell_dynamic_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
std::unordered_map<std::string, std::pair<FuncGraphPtr, bool>> cell_graph_map_;
// key: cell_id, value: (send_id, weigths_id), cache for sens and weight change