Handle interpreted nodes as input, set function parameters, and update local params if needed.

Zhang Qinghua 2021-09-11 17:10:07 +08:00
parent 728c449ebe
commit 9fbd118319
15 changed files with 325 additions and 188 deletions

View File

@@ -258,11 +258,15 @@ def get_obj_type(obj):
obj_type = RESOLVE_TYPE_CLASS_TYPE
elif _is_class_instance(obj):
obj_type = RESOLVE_TYPE_CLASS_INSTANCE
else:
# Raise a proper error if not using the fallback feature.
if support_fallback_ == '1':
obj_type = RESOLVE_TYPE_INVALID
else:
# For ndarray, just print its shape (the array may be too large and would print a lot of data to the screen).
is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
raise TypeError(f'Not supported for this object with type `{type(obj)}` and {"shape" if is_ndarray else "value"} '
f'`{obj.shape if is_ndarray else obj}`.')
raise TypeError(f'Not supported for this object with type `{type(obj)}` and '
f'{"shape" if is_ndarray else "value"} `{obj.shape if is_ndarray else obj}`.')
return obj_type

View File

@@ -120,7 +120,8 @@ bool PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph, const C
for (auto &node : orders) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->stage() > 0) {
auto stage_info = cnode->user_data<NodeStageInfo>();
if (stage_info != nullptr && stage_info->stage() > 0) {
continue;
}
if (IsValueNode<FuncGraph>(cnode->input(0))) {
@@ -214,7 +215,7 @@ void PipelineTransformer::Coloring() {
auto node_users = manager_->node_users()[node];
for (auto &user_pair : node_users) {
auto user_node = user_pair.first->cast<CNodePtr>();
user_node->set_stage(graph->stage());
user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
auto user_node_graph = user_node->func_graph();
if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
user_node_graph->set_stage(graph->stage());
@@ -238,19 +239,24 @@ void PipelineTransformer::BroadCastColoring() {
auto all_nodes = main_graph_->nodes();
auto node_users = manager_->node_users();
for (auto &node : all_nodes) {
if (!node->isa<CNode>() || node->stage() == -1) {
auto stage_info = node->user_data<NodeStageInfo>();
if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1) {
continue;
}
auto stage = node->stage();
auto stage = stage_info->stage();
for (auto &user_pair : node_users[node]) {
auto user_node = user_pair.first->cast<CNodePtr>();
auto user_node_stage = user_node->stage();
auto user_stage_info = user_node->user_data<NodeStageInfo>();
if (user_stage_info == nullptr) {
continue;
}
auto user_node_stage = user_stage_info->stage();
if (stage > user_node_stage) {
if (IsValueNode<FuncGraph>(user_node->input(0))) {
MS_LOG(EXCEPTION) << "The stage setting is incorrect. PreNode's stage:" << stage
<< " is larger than NextNode's stage:" << user_node_stage;
}
user_node->set_stage(stage);
user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
need_coloring = true;
}
}
@@ -431,11 +437,12 @@ std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
micro = MakeValue(int64_t(0));
}
auto user_stage = node->stage();
if (stage_ == *parameter_stage.begin()) {
if (graph->stage() == stage_) {
auto stage_info = node->user_data<NodeStageInfo>();
if (graph->stage() == stage_ || stage_info == nullptr) {
continue;
}
auto user_stage = stage_info->stage();
if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) {
continue;
}
@@ -469,7 +476,7 @@ void PipelineTransformer::ParameterColoring() {
}
if (graph != root_ && graph->stage() != -1) {
parameter_stage.insert(graph->stage());
parameter->set_stage(graph->stage());
parameter->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
}
}
auto param_info = parameter->cast<ParameterPtr>()->param_info();
@@ -762,7 +769,7 @@ bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) {
AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
int64_t user_stage, const ValuePtr &micro, size_t pos,
const std::vector<AnfNodePtr> ops) {
const std::vector<AnfNodePtr> &ops) {
MS_EXCEPTION_IF_NULL(node);
auto actual_node = ActualOp(node);
auto cnode = actual_node->cast<CNodePtr>();
@@ -811,29 +818,19 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
return send_out.depend;
}
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
std::vector<AnfNodePtr> receive_ops;
std::vector<AnfNodePtr> send_ops;
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
std::reverse(all_nodes.begin(), all_nodes.end());
auto stage_num = g_device_manager->stage_num();
if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
}
for (auto &node : all_nodes) {
if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
continue;
}
void PipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
std::vector<AnfNodePtr> *send_ops, std::vector<AnfNodePtr> *receive_ops) {
auto stage_info = node->user_data<NodeStageInfo>();
auto node_users = manager_->node_users()[node];
AnfNodePtr receive = nullptr;
for (auto &user_pair : node_users) {
auto user_node = user_pair.first;
auto node_stage = node->stage();
auto user_node_stage = user_node->stage();
auto node_stage = stage_info->stage();
auto user_stage_info = user_node->user_data<NodeStageInfo>();
if (user_stage_info == nullptr) {
continue;
}
auto user_node_stage = user_stage_info->stage();
if (node_stage != stage_ && user_node_stage != stage_) {
continue;
}
@@ -846,34 +843,33 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
if (node_stage == stage_) {
if (IsParameterGraph(node)) {
auto send_depend =
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, send_ops);
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, *send_ops);
if (!send_depend) {
continue;
}
send_ops.insert(send_ops.begin(), send_depend);
send_ops->insert(send_ops->begin(), send_depend);
continue;
}
if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
if (Reuse(node, user_node_stage, *send_ops, DEST_RANK)) {
continue;
}
auto send_out = InsertSend(graph, node, user_node_stage, node_stage, micro);
MS_EXCEPTION_IF_NULL(send_out.depend);
send_ops.push_back(send_out.depend);
send_ops->push_back(send_out.depend);
send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
} else {
if (!receive) {
if (IsParameterGraph(node)) {
receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second,
receive_ops);
receive =
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, *receive_ops);
if (!receive) {
continue;
}
receive_ops.push_back(receive);
receive_ops->push_back(receive);
} else {
receive =
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
receive_ops.push_back(receive);
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
receive_ops->push_back(receive);
}
} else {
manager_->SetEdge(user_node, user_pair.second, receive);
@@ -882,10 +878,32 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
continue;
}
if (node_stage > user_node_stage) {
MS_LOG(EXCEPTION) << "node_stage: " << node_stage
<< " must be smaller than user_node_stage: " << user_node_stage;
MS_LOG(EXCEPTION) << "node_stage: " << node_stage << " must be smaller than user_node_stage: " << user_node_stage;
}
}
}
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
std::vector<AnfNodePtr> send_ops;
std::vector<AnfNodePtr> receive_ops;
auto ret = graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
std::reverse(all_nodes.begin(), all_nodes.end());
auto stage_num = g_device_manager->stage_num();
if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
}
for (auto &node : all_nodes) {
auto stage_info = node->user_data<NodeStageInfo>();
if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
continue;
}
// Split into CutBorderForNode to reduce lizard cyclomatic complexity.
CutBorderForNode(graph, node, &send_ops, &receive_ops);
}
return std::make_pair(send_ops, receive_ops);
}
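The CutBorder refactor follows a standard recipe: the oversized loop body moves into CutBorderForNode, which takes its two output lists as pointer parameters so call sites flag mutation with '&', and the driver loop keeps its cyclomatic complexity low. A minimal sketch of that shape (names hypothetical, not MindSpore APIs):

#include <iostream>
#include <utility>
#include <vector>

// Hypothetical per-item worker in the style of CutBorderForNode: it appends to
// one of two result lists passed as pointers, so the mutation is visible at
// the call site via '&'.
void ClassifyItem(int item, std::vector<int> *evens, std::vector<int> *odds) {
  if (item % 2 == 0) {
    evens->push_back(item);
  } else {
    odds->push_back(item);
  }
}

// The driver keeps only iteration, like the slimmed-down CutBorder.
std::pair<std::vector<int>, std::vector<int>> Classify(const std::vector<int> &items) {
  std::vector<int> evens;
  std::vector<int> odds;
  for (auto item : items) {
    ClassifyItem(item, &evens, &odds);
  }
  return std::make_pair(evens, odds);
}

int main() {
  auto [evens, odds] = Classify({1, 2, 3, 4});
  std::cout << evens.size() << " evens, " << odds.size() << " odds\n";
  return 0;
}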

View File

@@ -66,7 +66,7 @@ class PipelineTransformer {
bool IsParameterGraph(const AnfNodePtr &node);
AnfNodeIndexSet GetActualOpUsers(const std::pair<AnfNodePtr, int> &node_pair, NodeUsersMap *node_users_map);
AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage,
const ValuePtr &micro, size_t pos, const std::vector<AnfNodePtr> ops);
const ValuePtr &micro, size_t pos, const std::vector<AnfNodePtr> &ops);
ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
std::vector<AnfNodePtr> HandleSharedParameter();
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
@@ -75,6 +75,8 @@ class PipelineTransformer {
int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
const AnfNodePtr &graph_param);
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *send_ops,
std::vector<AnfNodePtr> *receive_ops);
AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
const std::string &tag);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
@@ -99,6 +101,20 @@ class PipelineTransformer {
int64_t micro_size_ = 0;
std::vector<std::string> group_ = {};
};
class NodeStageInfo {
public:
explicit NodeStageInfo(int64_t stage) : stage_(stage) {}
~NodeStageInfo() = default;
int64_t stage() const { return stage_; }
// Key for user data.
constexpr static char key[] = "NodeStageInfo";
private:
int64_t stage_;
};
} // namespace parallel
} // namespace mindspore
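NodeStageInfo replaces the removed AnfNode::stage_ member (see the anf.h diff below): the stage now rides in the node's generic user-data store, keyed by the class's constexpr `key`, and readers must tolerate its absence. A self-contained sketch of the mechanism, with a simplified Node standing in for AnfNode (the real templates differ in detail):

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for AnfNode's user-data store; the real API is richer.
class Node {
 public:
  template <typename T>
  void set_user_data(const std::shared_ptr<T> &value) {
    user_data_[T::key] = value;
  }
  template <typename T>
  std::shared_ptr<T> user_data() const {
    auto it = user_data_.find(T::key);
    return it == user_data_.end() ? nullptr : std::static_pointer_cast<T>(it->second);
  }

 private:
  std::map<std::string, std::shared_ptr<void>> user_data_;
};

class NodeStageInfo {
 public:
  explicit NodeStageInfo(int64_t stage) : stage_(stage) {}
  int64_t stage() const { return stage_; }
  // Key for the user-data map.
  constexpr static char key[] = "NodeStageInfo";

 private:
  int64_t stage_;
};

// Read a node's stage defensively: nodes without stage info report -1,
// mirroring the old default of the removed stage_ member and the
// 'stage_info == nullptr' guards added throughout the .cc diff above.
int64_t GetStage(const Node &node) {
  auto info = node.user_data<NodeStageInfo>();
  return info == nullptr ? -1 : info->stage();
}

int main() {
  Node node;
  std::cout << GetStage(node) << "\n";  // -1: no stage info attached yet.
  node.set_user_data(std::make_shared<NodeStageInfo>(2));
  std::cout << GetStage(node) << "\n";  // 2
  return 0;
}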

View File

@@ -362,7 +362,17 @@ ValuePtr ConvertOtherObj(const py::object &obj) {
MS_LOG(DEBUG) << "name_space: " << res->ToString();
return res;
}
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
// Handle the RESOLVE_TYPE_INVALID case.
// The fallback feature is enabled by default.
// Changing the flag while the process is alive is not supported.
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
static const auto use_fallback = (support_fallback == "1");
if (use_fallback) {
auto res = std::make_shared<InterpretedObject>(obj, py::str(obj));
MS_LOG(DEBUG) << "Get interpreted object: " << res->ToString();
return res;
}
MS_LOG(ERROR) << "Resolve type is invalid, obj: " << py::str(obj);
return nullptr;
}
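With fallback enabled, an unresolvable object is wrapped as an InterpretedObject instead of failing. The env flag is read once into function-local statics, which is why changing it mid-process has no effect. A minimal sketch of that read-once pattern (GetEnvVar stands in for common::GetEnv):

#include <cstdlib>
#include <iostream>
#include <string>

// Stand-in for common::GetEnv; assumes it returns "" when the variable is unset.
std::string GetEnvVar(const char *name) {
  const char *value = std::getenv(name);
  return value == nullptr ? "" : std::string(value);
}

bool UseFallback() {
  // Read once and cache: function-local statics initialize on the first call,
  // so flipping the variable later in the process has no effect. That is the
  // "changing the flag while the process is alive is not supported" behavior.
  static const auto support_fallback = GetEnvVar("ENV_SUPPORT_FALLBACK");
  static const bool use_fallback = (support_fallback == "1");
  return use_fallback;
}

int main() {
  std::cout << std::boolalpha << UseFallback() << "\n";
  return 0;
}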

View File

@@ -72,6 +72,11 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
<< node->DebugString();
// The fallback feature is enabled by default.
// Changing the flag while the process is alive is not supported.
static const auto use_fallback = (parser_.support_fallback() == "1");
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
if (!is_new_name) {
// If a cnode variable with same name already existed but not used,
@@ -92,14 +97,21 @@
AddIsolatedNode(hidden_node);
}
iter->second = std::make_pair(node, false);
if (use_fallback) {
UpdateLocalPyParam(var_name, node);
}
} else {
if (use_fallback) {
AddLocalPyParam(var_name, node);
}
}
}
// Read variable from predecessors
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
MS_LOG(DEBUG) << "Read begin, var: " << var_name << ", block: " << ToString();
// Get var node if it is found
auto found = assigned_vars_.find(var);
auto found = assigned_vars_.find(var_name);
if (found != assigned_vars_.end()) {
auto &node = found->second.first;
MS_EXCEPTION_IF_NULL(node);
@@ -117,34 +129,40 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
if (prev_blocks_.size() == 1) {
auto block = prev_blocks_[0];
MS_EXCEPTION_IF_NULL(block);
auto res = block->ReadVariable(var);
MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
<< ",\nCurrent: " << py::str(global_py_params())
auto res = block->ReadVariable(var_name);
// The fallback feature is enabled by default.
// Changing the flag while the process is alive is not supported.
static const auto use_fallback = (parser_.support_fallback() == "1");
if (use_fallback) {
MS_LOG(DEBUG) << "Update global params of block: " << ToString()
<< ", with previous block: " << block->ToString() << ",\nCurrent: " << py::str(global_py_params())
<< "\nInsert: " << py::str(block->global_py_params());
CopyGlobalPyParam(block->global_py_params());
UpdateGlobalPyParam(block->global_py_params());
}
return res;
} else if (prev_blocks_.empty()) {
// Get namespace and make Resolve
auto it = var_to_resolve_.find(var);
auto it = var_to_resolve_.find(var_name);
if (it != var_to_resolve_.end()) {
return it->second;
}
MS_LOG(DEBUG) << "var: " << var;
auto tmp_node = MakeResolveSymbol(var);
var_to_resolve_[var] = tmp_node;
MS_LOG(DEBUG) << "var: " << var_name;
auto tmp_node = MakeResolveSymbol(var_name);
var_to_resolve_[var_name] = tmp_node;
return tmp_node;
}
}
// If there is more than one predecessor block, build a phi node.
auto debug_info = std::make_shared<NodeDebugInfo>();
debug_info->set_name(var);
debug_info->set_name(var_name);
TraceGuard guard(std::make_shared<TracePhi>(debug_info));
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
<< phi_param->ToString() << " for " << var;
<< phi_param->ToString() << " for " << var_name;
func_graph()->add_parameter(phi_param);
phi_nodes_[phi_param] = var;
WriteVariable(var, phi_param);
phi_nodes_[phi_param] = var_name;
WriteVariable(var_name, phi_param);
if (matured_) {
SetPhiArgument(phi_param);
}
@@ -337,19 +355,19 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
// Args: phi: This parameter node is functioning as a phi node.
bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
MS_EXCEPTION_IF_NULL(phi);
std::string var = phi_nodes_[phi];
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
std::string var_name = phi_nodes_[phi];
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var_name;
if (prev_blocks_.empty()) {
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var_name;
return false;
}
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
AnfNodePtr arg_node = SearchReplaceNode(var_name, phi);
if (arg_node != nullptr) {
arg_node->set_debug_info(phi->debug_info());
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
<< " can be replaced with " << arg_node->DebugString();
// Replace var with the new one. This is equivalent to the statement in the TR: "v0 is immediately replaced by v1."
WriteVariable(var, arg_node);
WriteVariable(var_name, arg_node);
removable_phis_[phi] = arg_node;
resolve_to_removable_phis_[arg_node] = phi;
// The following is equivalent to the statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
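The WriteVariable change hinges on std::map::emplace reporting in one call whether the name is fresh: a fresh name adds a local py-param for the fallback, an existing one updates it. A stripped-down sketch of that bookkeeping (string payloads stand in for AnfNodePtr):

#include <iostream>
#include <map>
#include <string>

// Sketch of the WriteVariable bookkeeping: emplace tells us in one shot
// whether the variable name is new (add a local py-param) or an overwrite
// (update the recorded node).
class Block {
 public:
  void WriteVariable(const std::string &var_name, const std::string &node) {
    auto [iter, is_new_name] = vars_.emplace(var_name, node);
    if (!is_new_name) {
      iter->second = node;  // Overwrite the stale binding.
      std::cout << "update local py param: " << var_name << "\n";
    } else {
      std::cout << "add local py param: " << var_name << "\n";
    }
  }

 private:
  std::map<std::string, std::string> vars_;
};

int main() {
  Block block;
  block.WriteVariable("x", "node_1");  // add
  block.WriteVariable("x", "node_2");  // update
  return 0;
}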

View File

@@ -86,7 +86,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
py::dict &global_py_params() { return global_py_params_; }
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
void CopyGlobalPyParam(const py::dict &symbols) {
void UpdateGlobalPyParam(const py::dict &symbols) {
for (auto &param : symbols) {
if (!global_py_params_.contains(param.first)) {
global_py_params_[param.first] = param.second;
@@ -101,6 +101,25 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
local_py_params_keys_.emplace_back(NewValueNode(name));
local_py_params_values_.emplace_back(node);
}
// Call this method only if you need to update a variable, usually on variable override.
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
auto iter = std::find_if(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(),
[&name](const AnfNodePtr &node) -> bool {
const auto value_node = dyn_cast<ValueNode>(node);
MS_EXCEPTION_IF_NULL(value_node);
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
MS_EXCEPTION_IF_NULL(str_imm);
return name == str_imm->value();
});
if (iter == local_py_params_keys_.cend()) {
MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if 'name' not exist.";
}
// Find the same position in 'values', and update the node.
auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
auto values_pos_iter = local_py_params_values_.begin() + distance;
MS_LOG(DEBUG) << "Update '" << name << "', " << (*values_pos_iter)->DebugString() << " -> " << node->DebugString();
*values_pos_iter = node;
}
private:
// Block graph
@@ -160,16 +179,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// x = x - 1 #This after block is a dead block
bool is_dead_block_{false};
};
class ScriptInfo {
public:
explicit ScriptInfo(const py::object &obj) : py_obj_(obj) {}
// Key for user data.
constexpr static char key[] = "ScriptInfo";
py::object py_obj_;
};
} // namespace parse
} // namespace mindspore
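UpdateLocalPyParam works because local py-params are kept as two parallel vectors rather than a map: find the key's index, then overwrite the same index among the values. The same technique in isolation (strings stand in for AnfNodePtr):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

class LocalParams {
 public:
  void Add(const std::string &name, const std::string &node) {
    keys_.push_back(name);
    values_.push_back(node);
  }
  // Mirrors UpdateLocalPyParam: only legal for names that already exist.
  void Update(const std::string &name, const std::string &node) {
    auto iter = std::find(keys_.cbegin(), keys_.cend(), name);
    if (iter == keys_.cend()) {
      throw std::logic_error("Only for updating; '" + name + "' does not exist.");
    }
    // The value sits at the same index in 'values_' as the key in 'keys_'.
    auto distance = std::distance(keys_.cbegin(), iter);
    values_[static_cast<size_t>(distance)] = node;
  }
  void Dump() const {
    for (size_t i = 0; i < keys_.size(); ++i) {
      std::cout << keys_[i] << " -> " << values_[i] << "\n";
    }
  }

 private:
  std::vector<std::string> keys_;
  std::vector<std::string> values_;
};

int main() {
  LocalParams params;
  params.Add("x", "node_1");
  params.Update("x", "node_2");
  params.Dump();  // x -> node_2
  return 0;
}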

View File

@@ -30,11 +30,11 @@
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "utils/ms_context.h"
#include "utils/interpret_node_recorder.h"
#include "debug/trace.h"
namespace mindspore {
namespace parse {
FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
(void)python_adapter::set_python_scoped();
@@ -472,15 +472,15 @@ void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const Functi
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast return";
MS_EXCEPTION_IF_NULL(block);
// Create return valuenode
AnfNodePtr return_value_node = NewValueNode(prim::kPrimReturn);
// Parse the return Statements value
py::object value = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr return_expr_node = ParseExprNode(block, value);
// Create the cnode
auto block_fg = block->func_graph();
CNodePtr return_node = block_fg->NewCNodeInOrder({return_value_node, return_expr_node});
block_fg->set_return(return_node);
// Parse the return statement's value.
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr return_expr_node = ParseExprNode(block, value_object);
// Check whether interpreting is needed.
return_expr_node = HandleInterpret(block, return_expr_node, value_object);
// Create the `return` CNode.
auto func_graph = block->func_graph();
CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
func_graph->set_return(return_cnode);
return block;
}
@@ -928,13 +928,13 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
py::list args = ast_->GetArgs(node);
auto block_fg = func_block->func_graph();
for (std::size_t i = 0; i < args.size(); i++) {
std::string arg = py::cast<std::string>(args[i].attr("arg"));
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
TraceGuard guard(GetLocation(args[i]));
auto para_node = std::make_shared<Parameter>(block_fg);
para_node->debug_info()->set_name(arg);
para_node->debug_info()->set_name(arg_name);
block_fg->add_parameter(para_node);
func_block->WriteVariable(arg, para_node);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
func_block->WriteVariable(arg_name, para_node);
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
}
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
@@ -1701,7 +1701,6 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
}
}
MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
block->AddLocalPyParam(name_id, assigned_node);
block->WriteVariable(name_id, assigned_node);
}
@@ -1827,19 +1826,13 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
// The fallback feature is enabled by default.
// Changing the flag while the process is alive is not supported.
static const auto use_fallback = (support_fallback() == "1");
if (!use_fallback) {
if (!use_fallback || !value_node->interpret()) {
return value_node;
}
AnfNodePtr interpreted_node = value_node;
if (value_node->interpret()) {
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
py::dict global_dict = block->global_py_params();
constexpr int recursive_level = 3;
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: " << script_text
<< ", value_node: " << value_node->DebugString(recursive_level)
<< ", global_dict: " << py::str(global_dict);
// Prepare global parameters.
py::dict global_dict = block->global_py_params();
ValuePtr globals_converted_value = nullptr;
if (!ConvertData(global_dict, &globals_converted_value)) {
MS_LOG(EXCEPTION) << "Convert data failed";
@@ -1849,13 +1842,19 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
auto [keys, values] = block->local_py_params();
auto local_dict_node = ParseDictByKeysAndValues(block, keys, values);
// Update the value node if it needs interpreting.
interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
constexpr int recursive_level = 3;
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text
<< "`,\nvalue_node: " << value_node->DebugString(recursive_level)
<< ",\nglobal_dict_node: " << global_dict_node->ToString()
<< ",\nlocal_dict_node: " << local_dict_node->ToString();
AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
// Print a hint for the user.
MS_LOG(ERROR) << "Found unsupported syntax in Graph mode; the code will fall back to the Python interpreter:"
auto line_info = trace::GetDebugInfo(value_node->debug_info());
MS_LOG(DEBUG) << "Found unsupported syntax in Graph mode; the code will fall back to the Python interpreter:"
<< "\n\n"
<< trace::GetDebugInfo(value_node->debug_info());
}
<< line_info;
InterpretNodeRecorder::GetInstance().Push(line_info);
return interpreted_node;
}
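After the rework, HandleInterpret is a single guard clause plus a straight-line body, rather than a body nested inside `if (value_node->interpret())`. The control-flow shape in isolation (ValueNode and the interpret-node construction are placeholders, not the real API):

#include <iostream>
#include <string>

// Placeholder node type; the real AnfNodePtr and block->MakeInterpret are richer.
struct ValueNode {
  bool interpret = false;
  std::string repr;
};

// Guard-clause shape of the reworked HandleInterpret: return early unless the
// fallback feature is enabled AND the node was flagged for interpretation.
ValueNode HandleInterpret(const ValueNode &value_node, bool use_fallback) {
  if (!use_fallback || !value_node.interpret) {
    return value_node;  // Nothing to do: hand the node back unchanged.
  }
  ValueNode interpreted = value_node;  // Stand-in for building the interpret node.
  interpreted.repr = "Interpret(" + value_node.repr + ")";
  return interpreted;
}

int main() {
  ValueNode node{true, "divmod(x, y)"};
  std::cout << HandleInterpret(node, true).repr << "\n";   // Interpret(divmod(x, y))
  std::cout << HandleInterpret(node, false).repr << "\n";  // divmod(x, y)
  return 0;
}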

View File

@@ -275,12 +275,12 @@ class Parser {
// so in FunctionBlock class we can use FunctionBlock* in member
// pre_blocks_ and jumps_ to break reference cycle.
std::vector<FunctionBlockPtr> func_block_list_;
using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
using StmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
using ExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
// Define the function map to parse AST statements.
std::map<std::string, pStmtFunc> stmt_method_map_;
std::map<std::string, StmtFunc> stmt_method_map_;
// Define the function map to parse AST expressions.
std::map<std::string, pExprFunc> expr_method_map_;
std::map<std::string, ExprFunc> expr_method_map_;
// Save current loops to support 'continue' and 'break' statements.
std::stack<Loop> loops_;
string max_for_loop_count_str_;
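StmtFunc and ExprFunc are aliases for pointers to Parser member functions; the two maps route each AST node type name to its handler. A compilable sketch of that dispatch style (handlers reduced to prints):

#include <iostream>
#include <map>
#include <string>

class Parser {
 public:
  Parser() {
    stmt_method_map_["Return"] = &Parser::ParseReturn;
    stmt_method_map_["Assign"] = &Parser::ParseAssign;
  }
  void ParseStatement(const std::string &ast_type) {
    auto it = stmt_method_map_.find(ast_type);
    if (it == stmt_method_map_.end()) {
      std::cout << "unsupported statement: " << ast_type << "\n";
      return;
    }
    (this->*(it->second))();  // Invoke through the member-function pointer.
  }

 private:
  void ParseReturn() { std::cout << "parse Return\n"; }
  void ParseAssign() { std::cout << "parse Assign\n"; }
  using StmtFunc = void (Parser::*)();
  std::map<std::string, StmtFunc> stmt_method_map_;
};

int main() {
  Parser parser;
  parser.ParseStatement("Return");
  parser.ParseStatement("Lambda");
  return 0;
}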

View File

@@ -119,8 +119,12 @@ class InterpretedObject : public PyObjectWrapper {
: PyObjectWrapper(obj, name) {}
~InterpretedObject() override = default;
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
abstract::AbstractBasePtr ToAbstract() override;
abstract::AbstractBasePtr ToAbstract() override {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<InterpretedObject>(),
std::make_shared<External>());
}
};
using InterpretedObjectPtr = std::shared_ptr<InterpretedObject>;
// The ClassObject class wraps a Python dataclass.
class ClassObject : public PyObjectWrapper {

View File

@@ -26,36 +26,38 @@
#include <algorithm>
#include <iomanip>
#include "pybind_api/pybind_patch.h"
#include "ir/param_info.h"
#include "pipeline/jit/pass.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "pipeline/jit/static_analysis/async_eval_result.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "debug/anf_ir_utils.h"
#include "debug/common.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "utils/config_manager.h"
#include "utils/convert_utils.h"
#include "utils/convert_utils_py.h"
#include "utils/context/context_extends.h"
#include "vm/segment_runner.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
#include "debug/trace.h"
#include "debug/draw.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "pybind_api/pybind_patch.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "load_mindir/load_model.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "runtime/hardware/device_context_manager.h"
#include "utils/crypto.h"
#include "utils/comm_manager.h"
#include "utils/interpret_node_recorder.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "debug/anf_ir_utils.h"
#include "debug/trace.h"
#include "debug/draw.h"
#include "debug/common.h"
#include "load_mindir/load_model.h"
#include "vm/segment_runner.h"
#include "backend/session/executor_manager.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/constants.h"
#include "ps/util.h"
@@ -722,9 +724,8 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
CheckArgsValid(args);
auto phase = py::cast<std::string>(phase_obj);
MS_LOG(INFO) << "Start compiling, phase: " << phase << ".";
MS_LOG(DEBUG) << "Compiling source: {" << py::str(source_obj)
<< "}\n\n Args: " << py::str(const_cast<py::tuple &>(args));
MS_LOG(INFO) << "Start compiling, phase: " << phase;
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args));
#ifdef ENABLE_GE
GetGeBackendPolicy();
@@ -1471,6 +1472,7 @@ void ClearResAtexit() {
parse::Parser::CleanParserResource();
parse::CleanDataClassToClassMap();
trace::ClearTraceStack();
InterpretNodeRecorder::GetInstance().Clear();
}
py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {

View File

@@ -205,6 +205,12 @@ static ValueNameToConverterVector value_name_to_converter = {
auto class_type = value->cast<parse::ClassTypePtr>();
return class_type->obj();
}},
// parse::InterpretedObject
{typeid(parse::InterpretedObject).name(),
[](const ValuePtr &value) -> py::object {
auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
return interpreted_object->obj();
}},
// None
{typeid(None).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
// AnyValue
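value_name_to_converter pairs typeid(T).name() with a conversion lambda, and the new entry simply unwraps the InterpretedObject's stored Python object. The same dispatch technique stripped of pybind11 (the Value hierarchy and string results are stand-ins):

#include <functional>
#include <iostream>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>

struct Value { virtual ~Value() = default; };
struct IntValue : Value { int v; explicit IntValue(int v) : v(v) {} };
struct StrValue : Value { std::string v; explicit StrValue(std::string v) : v(std::move(v)) {} };

// Table of (dynamic type name, converter), mirroring ValueNameToConverterVector.
using Converter = std::function<std::string(const Value &)>;
static const std::vector<std::pair<std::string, Converter>> converters = {
    {typeid(IntValue).name(),
     [](const Value &v) { return std::to_string(static_cast<const IntValue &>(v).v); }},
    {typeid(StrValue).name(),
     [](const Value &v) { return static_cast<const StrValue &>(v).v; }},
};

std::string Convert(const Value &value) {
  // typeid on a reference to a polymorphic type yields the dynamic type.
  const std::string name = typeid(value).name();
  for (const auto &entry : converters) {
    if (entry.first == name) {
      return entry.second(value);
    }
  }
  return "<unconvertible>";
}

int main() {
  IntValue i(42);
  StrValue s("hello");
  std::cout << Convert(i) << " " << Convert(s) << "\n";  // 42 hello
  return 0;
}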

View File

@@ -106,8 +106,6 @@ class MS_CORE_API AnfNode : public Base {
fullname_with_scope_(""),
hash_(std::hash<const AnfNode *>()),
kernel_info_(nullptr),
stage_(-1),
need_grad_(false),
interpret_(false),
interpreted_node_(nullptr) {
scope_ = ScopeManager::GetInstance().GetCurrentScope();
@@ -200,12 +198,6 @@
void CloneUserData(const AnfNodePtr &node) { user_data_ = node->user_data_; }
int64_t stage() { return stage_; }
void set_stage(const int &stage) { stage_ = stage; }
bool grad() { return need_grad_; }
void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
bool interpret() { return interpret_; }
void set_interpret(const bool &interpret) { interpret_ = interpret; }
@@ -226,8 +218,6 @@
ScopePtr scope_;
KernelInfoDevicePtr kernel_info_;
UserData user_data_;
int64_t stage_;
bool need_grad_;
bool interpret_;
AnfNodePtr interpreted_node_;
};

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_UTILS_INTERPRET_NODE_RECORDER_H_
#define MINDSPORE_CORE_UTILS_INTERPRET_NODE_RECORDER_H_
#include <vector>
#include <string>
namespace mindspore {
class InterpretNodeRecorder {
public:
explicit InterpretNodeRecorder(InterpretNodeRecorder &&) = delete;
explicit InterpretNodeRecorder(const InterpretNodeRecorder &) = delete;
void operator=(const InterpretNodeRecorder &) = delete;
void operator=(const InterpretNodeRecorder &&) = delete;
static InterpretNodeRecorder &GetInstance() {
static InterpretNodeRecorder instance;
return instance;
}
void Push(const std::string &line) { interpret_nodes_lines_.emplace_back(line); }
void Clear() { interpret_nodes_lines_.clear(); }
protected:
InterpretNodeRecorder() = default;
virtual ~InterpretNodeRecorder() = default;
private:
std::vector<std::string> interpret_nodes_lines_;
};
} // namespace mindspore
#endif  // MINDSPORE_CORE_UTILS_INTERPRET_NODE_RECORDER_H_
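The recorder is a Meyers singleton: HandleInterpret pushes the source line of every fallback-interpreted node, and ClearResAtexit wipes the list at teardown. A self-contained copy of the pattern showing that lifecycle:

#include <iostream>
#include <string>
#include <vector>

// Minimal copy of the recorder pattern: a Meyers singleton accumulates the
// source lines of fallback-interpreted nodes until teardown clears it.
class Recorder {
 public:
  static Recorder &GetInstance() {
    static Recorder instance;  // Thread-safe initialization since C++11.
    return instance;
  }
  void Push(const std::string &line) { lines_.emplace_back(line); }
  size_t Size() const { return lines_.size(); }
  void Clear() { lines_.clear(); }

 private:
  Recorder() = default;
  std::vector<std::string> lines_;
};

int main() {
  Recorder::GetInstance().Push("In file test.py(12): a = divmod(x, y)");
  std::cout << Recorder::GetInstance().Size() << "\n";  // 1
  Recorder::GetInstance().Clear();                      // Mirrors ClearResAtexit.
  std::cout << Recorder::GetInstance().Size() << "\n";  // 0
  return 0;
}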

View File

@@ -80,14 +80,28 @@ def test_np_fallback_func():
print(np_fallback_func())
# Test `return` interpret node.
@ms_function
def div_mod_func(x, y):
def div_mod_func1():
x = 8
y = 3
a = divmod(x, y)
return Tensor(a)
@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
def test_div_mod_func():
print(div_mod_func(8, 3)) # (2, 2)
def test_div_mod_func1():
print(div_mod_func1()) # (2, 2)
# Test interpret node with parameters as input.
@ms_function
def div_mod_func2(x, y):
a = divmod(x, y)
return Tensor(a)
@pytest.mark.skip(reason='Graph fallback feature is not supported yet')
def test_div_mod_func2():
print(div_mod_func2(8, 3)) # (2, 2)
# NameError: name 'Tensor' is not defined.