forked from mindspore-Ecosystem/mindspore
Handle interpreted node as input, set function Parameters and update local params if need.
This commit is contained in:
parent
728c449ebe
commit
9fbd118319
|
@ -258,11 +258,15 @@ def get_obj_type(obj):
|
|||
obj_type = RESOLVE_TYPE_CLASS_TYPE
|
||||
elif _is_class_instance(obj):
|
||||
obj_type = RESOLVE_TYPE_CLASS_INSTANCE
|
||||
else:
|
||||
# Raise a proper error if not using Fallback feature.
|
||||
if support_fallback_ == '1':
|
||||
obj_type = RESOLVE_TYPE_INVALID
|
||||
else:
|
||||
# here for ndarray, just print its shape (in case of the array to large and print many data in screen)
|
||||
is_ndarray = type(obj).__name__ == 'ndarray' and hasattr(obj, 'shape')
|
||||
raise TypeError(f'Not support for this object with type `{type(obj)}` and {"shape" if is_ndarray else "value"} '
|
||||
f'`{obj.shape if is_ndarray else obj}`.')
|
||||
raise TypeError(f'Not support for this object with type `{type(obj)}` and '
|
||||
f'{"shape" if is_ndarray else "value"} `{obj.shape if is_ndarray else obj}`.')
|
||||
return obj_type
|
||||
|
||||
|
||||
|
|
|
@ -120,7 +120,8 @@ bool PipelineTransformer::LabelParameterStart(const FuncGraphPtr &graph, const C
|
|||
for (auto &node : orders) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->stage() > 0) {
|
||||
auto stage_info = cnode->user_data<NodeStageInfo>();
|
||||
if (stage_info != nullptr && stage_info->stage() > 0) {
|
||||
continue;
|
||||
}
|
||||
if (IsValueNode<FuncGraph>(cnode->input(0))) {
|
||||
|
@ -214,7 +215,7 @@ void PipelineTransformer::Coloring() {
|
|||
auto node_users = manager_->node_users()[node];
|
||||
for (auto &user_pair : node_users) {
|
||||
auto user_node = user_pair.first->cast<CNodePtr>();
|
||||
user_node->set_stage(graph->stage());
|
||||
user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
|
||||
auto user_node_graph = user_node->func_graph();
|
||||
if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
|
||||
user_node_graph->set_stage(graph->stage());
|
||||
|
@ -238,19 +239,24 @@ void PipelineTransformer::BroadCastColoring() {
|
|||
auto all_nodes = main_graph_->nodes();
|
||||
auto node_users = manager_->node_users();
|
||||
for (auto &node : all_nodes) {
|
||||
if (!node->isa<CNode>() || node->stage() == -1) {
|
||||
auto stage_info = node->user_data<NodeStageInfo>();
|
||||
if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1) {
|
||||
continue;
|
||||
}
|
||||
auto stage = node->stage();
|
||||
auto stage = stage_info->stage();
|
||||
for (auto &user_pair : node_users[node]) {
|
||||
auto user_node = user_pair.first->cast<CNodePtr>();
|
||||
auto user_node_stage = user_node->stage();
|
||||
auto user_stage_info = user_node->user_data<NodeStageInfo>();
|
||||
if (user_stage_info == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto user_node_stage = user_stage_info->stage();
|
||||
if (stage > user_node_stage) {
|
||||
if (IsValueNode<FuncGraph>(user_node->input(0))) {
|
||||
MS_LOG(EXCEPTION) << "The stage setting is incorrect. PreNode's stage:" << stage
|
||||
<< " is larger than NextNode's stage:" << user_node_stage;
|
||||
}
|
||||
user_node->set_stage(stage);
|
||||
user_node->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(stage));
|
||||
need_coloring = true;
|
||||
}
|
||||
}
|
||||
|
@ -431,11 +437,12 @@ std::vector<AnfNodePtr> PipelineTransformer::HandleSharedParameter() {
|
|||
MS_LOG(INFO) << "parameter: " << parameter->ToString() << " doesn't have micro batch";
|
||||
micro = MakeValue(int64_t(0));
|
||||
}
|
||||
auto user_stage = node->stage();
|
||||
if (stage_ == *parameter_stage.begin()) {
|
||||
if (graph->stage() == stage_) {
|
||||
auto stage_info = node->user_data<NodeStageInfo>();
|
||||
if (graph->stage() == stage_ || stage_info == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto user_stage = stage_info->stage();
|
||||
if (Reuse(parameter, user_stage, make_tuple_input, DEST_RANK)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -469,7 +476,7 @@ void PipelineTransformer::ParameterColoring() {
|
|||
}
|
||||
if (graph != root_ && graph->stage() != -1) {
|
||||
parameter_stage.insert(graph->stage());
|
||||
parameter->set_stage(graph->stage());
|
||||
parameter->set_user_data<NodeStageInfo>(std::make_shared<NodeStageInfo>(graph->stage()));
|
||||
}
|
||||
}
|
||||
auto param_info = parameter->cast<ParameterPtr>()->param_info();
|
||||
|
@ -762,7 +769,7 @@ bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) {
|
|||
|
||||
AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
|
||||
int64_t user_stage, const ValuePtr µ, size_t pos,
|
||||
const std::vector<AnfNodePtr> ops) {
|
||||
const std::vector<AnfNodePtr> &ops) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto actual_node = ActualOp(node);
|
||||
auto cnode = actual_node->cast<CNodePtr>();
|
||||
|
@ -811,29 +818,19 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
|
|||
return send_out.depend;
|
||||
}
|
||||
|
||||
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
|
||||
OperatorAttrs depend_attrs;
|
||||
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
|
||||
std::vector<AnfNodePtr> receive_ops;
|
||||
std::vector<AnfNodePtr> send_ops;
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
|
||||
std::reverse(all_nodes.begin(), all_nodes.end());
|
||||
auto stage_num = g_device_manager->stage_num();
|
||||
if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
|
||||
MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
|
||||
}
|
||||
for (auto &node : all_nodes) {
|
||||
if (!node->isa<CNode>() || node->stage() == -1 || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
void PipelineTransformer::CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
std::vector<AnfNodePtr> *send_ops, std::vector<AnfNodePtr> *receive_ops) {
|
||||
auto stage_info = node->user_data<NodeStageInfo>();
|
||||
auto node_users = manager_->node_users()[node];
|
||||
AnfNodePtr receive = nullptr;
|
||||
for (auto &user_pair : node_users) {
|
||||
auto user_node = user_pair.first;
|
||||
auto node_stage = node->stage();
|
||||
auto user_node_stage = user_node->stage();
|
||||
auto node_stage = stage_info->stage();
|
||||
auto user_stage_info = user_node->user_data<NodeStageInfo>();
|
||||
if (user_stage_info == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto user_node_stage = user_stage_info->stage();
|
||||
if (node_stage != stage_ && user_node_stage != stage_) {
|
||||
continue;
|
||||
}
|
||||
|
@ -846,34 +843,33 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
|
|||
if (node_stage == stage_) {
|
||||
if (IsParameterGraph(node)) {
|
||||
auto send_depend =
|
||||
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, send_ops);
|
||||
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, *send_ops);
|
||||
if (!send_depend) {
|
||||
continue;
|
||||
}
|
||||
send_ops.insert(send_ops.begin(), send_depend);
|
||||
send_ops->insert(send_ops->begin(), send_depend);
|
||||
continue;
|
||||
}
|
||||
if (Reuse(node, user_node_stage, send_ops, DEST_RANK)) {
|
||||
if (Reuse(node, user_node_stage, *send_ops, DEST_RANK)) {
|
||||
continue;
|
||||
}
|
||||
auto send_out = InsertSend(graph, node, user_node_stage, node_stage, micro);
|
||||
MS_EXCEPTION_IF_NULL(send_out.depend);
|
||||
send_ops.push_back(send_out.depend);
|
||||
send_ops->push_back(send_out.depend);
|
||||
send_out.depend->set_user_data<Type>(DTYPE, send_out.type);
|
||||
send_out.depend->set_user_data<ValueList>(SHAPE, send_out.shape);
|
||||
} else {
|
||||
if (!receive) {
|
||||
if (IsParameterGraph(node)) {
|
||||
receive = HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second,
|
||||
receive_ops);
|
||||
receive =
|
||||
HandleParameterGraph(node, user_node, node_stage, user_node_stage, micro, user_pair.second, *receive_ops);
|
||||
if (!receive) {
|
||||
continue;
|
||||
}
|
||||
receive_ops.push_back(receive);
|
||||
receive_ops->push_back(receive);
|
||||
} else {
|
||||
receive =
|
||||
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
|
||||
receive_ops.push_back(receive);
|
||||
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage, micro, node);
|
||||
receive_ops->push_back(receive);
|
||||
}
|
||||
} else {
|
||||
manager_->SetEdge(user_node, user_pair.second, receive);
|
||||
|
@ -882,11 +878,33 @@ std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer:
|
|||
continue;
|
||||
}
|
||||
if (node_stage > user_node_stage) {
|
||||
MS_LOG(EXCEPTION) << "node_stage: " << node_stage
|
||||
<< " must be smaller than user_node_stage: " << user_node_stage;
|
||||
MS_LOG(EXCEPTION) << "node_stage: " << node_stage << " must be smaller than user_node_stage: " << user_node_stage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
|
||||
OperatorAttrs depend_attrs;
|
||||
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, DEPEND);
|
||||
std::vector<AnfNodePtr> send_ops;
|
||||
std::vector<AnfNodePtr> receive_ops;
|
||||
auto ret = graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
|
||||
std::reverse(all_nodes.begin(), all_nodes.end());
|
||||
auto stage_num = g_device_manager->stage_num();
|
||||
if (root_->has_flag(TRAINING) && (stage_num > micro_size_)) {
|
||||
MS_LOG(EXCEPTION) << "MicroBatch size: " << micro_size_ << " can't less than stage num: " << stage_num;
|
||||
}
|
||||
for (auto &node : all_nodes) {
|
||||
auto stage_info = node->user_data<NodeStageInfo>();
|
||||
if (!node->isa<CNode>() || stage_info == nullptr || stage_info->stage() == -1 ||
|
||||
IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
// Modify for lizard cyclomatic complexity.
|
||||
CutBorderForNode(graph, node, &send_ops, &receive_ops);
|
||||
}
|
||||
return std::make_pair(send_ops, receive_ops);
|
||||
}
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ class PipelineTransformer {
|
|||
bool IsParameterGraph(const AnfNodePtr &node);
|
||||
AnfNodeIndexSet GetActualOpUsers(const std::pair<AnfNodePtr, int> &node_pair, NodeUsersMap *node_users_map);
|
||||
AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage,
|
||||
const ValuePtr µ, size_t pos, const std::vector<AnfNodePtr> ops);
|
||||
const ValuePtr µ, size_t pos, const std::vector<AnfNodePtr> &ops);
|
||||
ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
|
||||
std::vector<AnfNodePtr> HandleSharedParameter();
|
||||
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int64_t user_node_stage,
|
||||
|
@ -75,6 +75,8 @@ class PipelineTransformer {
|
|||
int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
|
||||
const AnfNodePtr &graph_param);
|
||||
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
|
||||
void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *send_ops,
|
||||
std::vector<AnfNodePtr> *receive_ops);
|
||||
AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
|
||||
const std::string &tag);
|
||||
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
|
||||
|
@ -99,6 +101,20 @@ class PipelineTransformer {
|
|||
int64_t micro_size_ = 0;
|
||||
std::vector<std::string> group_ = {};
|
||||
};
|
||||
|
||||
class NodeStageInfo {
|
||||
public:
|
||||
explicit NodeStageInfo(int64_t stage) : stage_(stage) {}
|
||||
~NodeStageInfo() = default;
|
||||
|
||||
int64_t stage() const { return stage_; }
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "NodeStageInfo";
|
||||
|
||||
private:
|
||||
int64_t stage_;
|
||||
};
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -362,7 +362,17 @@ ValuePtr ConvertOtherObj(const py::object &obj) {
|
|||
MS_LOG(DEBUG) << "name_space: " << res->ToString();
|
||||
return res;
|
||||
}
|
||||
MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj));
|
||||
// Start RESOLVE_TYPE_INVALID...
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto support_fallback = common::GetEnv("ENV_SUPPORT_FALLBACK");
|
||||
static const auto use_fallback = (support_fallback == "1");
|
||||
if (use_fallback) {
|
||||
auto res = std::make_shared<InterpretedObject>(obj, py::str(obj));
|
||||
MS_LOG(DEBUG) << "Get interpreted object: " << res->ToString();
|
||||
return res;
|
||||
}
|
||||
MS_LOG(ERROR) << "Resolve type is invalid, obj: " << py::str(obj);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -72,6 +72,11 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name << "` with node "
|
||||
<< node->DebugString();
|
||||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "1");
|
||||
|
||||
auto [iter, is_new_name] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
|
||||
if (!is_new_name) {
|
||||
// If a cnode variable with same name already existed but not used,
|
||||
|
@ -92,14 +97,21 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
|
|||
AddIsolatedNode(hidden_node);
|
||||
}
|
||||
iter->second = std::make_pair(node, false);
|
||||
if (use_fallback) {
|
||||
UpdateLocalPyParam(var_name, node);
|
||||
}
|
||||
} else {
|
||||
if (use_fallback) {
|
||||
AddLocalPyParam(var_name, node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read variable from predecessors
|
||||
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
||||
MS_LOG(DEBUG) << "Read begin, var: " << var << ", block: " << ToString();
|
||||
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
|
||||
MS_LOG(DEBUG) << "Read begin, var: " << var_name << ", block: " << ToString();
|
||||
// Get var node if it is found
|
||||
auto found = assigned_vars_.find(var);
|
||||
auto found = assigned_vars_.find(var_name);
|
||||
if (found != assigned_vars_.end()) {
|
||||
auto &node = found->second.first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -117,34 +129,40 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
|
|||
if (prev_blocks_.size() == 1) {
|
||||
auto block = prev_blocks_[0];
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
auto res = block->ReadVariable(var);
|
||||
MS_LOG(INFO) << "Update global params of block: " << ToString() << ", with previous block: " << block->ToString()
|
||||
<< ",\nCurrent: " << py::str(global_py_params())
|
||||
auto res = block->ReadVariable(var_name);
|
||||
|
||||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (parser_.support_fallback() == "1");
|
||||
if (use_fallback) {
|
||||
MS_LOG(DEBUG) << "Update global params of block: " << ToString()
|
||||
<< ", with previous block: " << block->ToString() << ",\nCurrent: " << py::str(global_py_params())
|
||||
<< "\nInsert: " << py::str(block->global_py_params());
|
||||
CopyGlobalPyParam(block->global_py_params());
|
||||
UpdateGlobalPyParam(block->global_py_params());
|
||||
}
|
||||
return res;
|
||||
} else if (prev_blocks_.empty()) {
|
||||
// Get namespace and make Resolve
|
||||
auto it = var_to_resolve_.find(var);
|
||||
auto it = var_to_resolve_.find(var_name);
|
||||
if (it != var_to_resolve_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
MS_LOG(DEBUG) << "var: " << var;
|
||||
auto tmp_node = MakeResolveSymbol(var);
|
||||
var_to_resolve_[var] = tmp_node;
|
||||
MS_LOG(DEBUG) << "var: " << var_name;
|
||||
auto tmp_node = MakeResolveSymbol(var_name);
|
||||
var_to_resolve_[var_name] = tmp_node;
|
||||
return tmp_node;
|
||||
}
|
||||
}
|
||||
// If have more than one predecessor blocks then build a phi node.
|
||||
auto debug_info = std::make_shared<NodeDebugInfo>();
|
||||
debug_info->set_name(var);
|
||||
debug_info->set_name(var_name);
|
||||
TraceGuard guard(std::make_shared<TracePhi>(debug_info));
|
||||
ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
|
||||
MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
|
||||
<< phi_param->ToString() << " for " << var;
|
||||
<< phi_param->ToString() << " for " << var_name;
|
||||
func_graph()->add_parameter(phi_param);
|
||||
phi_nodes_[phi_param] = var;
|
||||
WriteVariable(var, phi_param);
|
||||
phi_nodes_[phi_param] = var_name;
|
||||
WriteVariable(var_name, phi_param);
|
||||
if (matured_) {
|
||||
SetPhiArgument(phi_param);
|
||||
}
|
||||
|
@ -337,19 +355,19 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame
|
|||
// Args: phi: This parameter node is functioning as a phi node.
|
||||
bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) {
|
||||
MS_EXCEPTION_IF_NULL(phi);
|
||||
std::string var = phi_nodes_[phi];
|
||||
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var;
|
||||
std::string var_name = phi_nodes_[phi];
|
||||
MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var_name;
|
||||
if (prev_blocks_.empty()) {
|
||||
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var;
|
||||
MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var_name;
|
||||
return false;
|
||||
}
|
||||
AnfNodePtr arg_node = SearchReplaceNode(var, phi);
|
||||
AnfNodePtr arg_node = SearchReplaceNode(var_name, phi);
|
||||
if (arg_node != nullptr) {
|
||||
arg_node->set_debug_info(phi->debug_info());
|
||||
MS_LOG(DEBUG) << "graph " << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " phi " << phi->ToString()
|
||||
<< " can be replaced with " << arg_node->DebugString();
|
||||
// Replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1."
|
||||
WriteVariable(var, arg_node);
|
||||
WriteVariable(var_name, arg_node);
|
||||
removable_phis_[phi] = arg_node;
|
||||
resolve_to_removable_phis_[arg_node] = phi;
|
||||
// The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized
|
||||
|
|
|
@ -86,7 +86,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
py::dict &global_py_params() { return global_py_params_; }
|
||||
void set_global_py_params(const py::dict &symbols) { global_py_params_ = symbols; }
|
||||
void AddGlobalPyParam(const std::string &name, const py::object &obj) { global_py_params_[py::str(name)] = obj; }
|
||||
void CopyGlobalPyParam(const py::dict &symbols) {
|
||||
void UpdateGlobalPyParam(const py::dict &symbols) {
|
||||
for (auto ¶m : symbols) {
|
||||
if (!global_py_params_.contains(param.first)) {
|
||||
global_py_params_[param.first] = param.second;
|
||||
|
@ -101,6 +101,25 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
local_py_params_keys_.emplace_back(NewValueNode(name));
|
||||
local_py_params_values_.emplace_back(node);
|
||||
}
|
||||
// Call this methon only if you need update a variable. Usually variable override.
|
||||
void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) {
|
||||
auto iter = std::find_if(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(),
|
||||
[&name](const AnfNodePtr node) -> bool {
|
||||
const auto value_node = dyn_cast<ValueNode>(node);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value());
|
||||
MS_EXCEPTION_IF_NULL(str_imm);
|
||||
return name == str_imm->value();
|
||||
});
|
||||
if (iter == local_py_params_keys_.cend()) {
|
||||
MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if 'name' not exist.";
|
||||
}
|
||||
// Find the same position in 'values', and update the node.
|
||||
auto distance = std::distance(local_py_params_keys_.cbegin(), iter);
|
||||
auto values_pos_iter = local_py_params_values_.begin() + distance;
|
||||
MS_LOG(DEBUG) << "Update '" << name << "', " << (*values_pos_iter)->DebugString() << " -> " << node->DebugString();
|
||||
*values_pos_iter = node;
|
||||
}
|
||||
|
||||
private:
|
||||
// Block graph
|
||||
|
@ -160,16 +179,6 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
|
|||
// x = x - 1 #This after block is a dead block
|
||||
bool is_dead_block_{false};
|
||||
};
|
||||
|
||||
class ScriptInfo {
|
||||
public:
|
||||
explicit ScriptInfo(const py::object &obj) : py_obj_(obj) {}
|
||||
|
||||
// Key for user data.
|
||||
constexpr static char key[] = "ScriptInfo";
|
||||
|
||||
py::object py_obj_;
|
||||
};
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -30,11 +30,11 @@
|
|||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/interpret_node_recorder.h"
|
||||
#include "debug/trace.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parse {
|
||||
|
||||
FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
|
||||
(void)python_adapter::set_python_scoped();
|
||||
|
||||
|
@ -472,15 +472,15 @@ void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const Functi
|
|||
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast return";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// Create return valuenode
|
||||
AnfNodePtr return_value_node = NewValueNode(prim::kPrimReturn);
|
||||
// Parse the return Statements value
|
||||
py::object value = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr return_expr_node = ParseExprNode(block, value);
|
||||
// Create the cnode
|
||||
auto block_fg = block->func_graph();
|
||||
CNodePtr return_node = block_fg->NewCNodeInOrder({return_value_node, return_expr_node});
|
||||
block_fg->set_return(return_node);
|
||||
// Parse the return Statements value.
|
||||
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr return_expr_node = ParseExprNode(block, value_object);
|
||||
// Check if need interpreting.
|
||||
return_expr_node = HandleInterpret(block, return_expr_node, value_object);
|
||||
// Create the `return` CNode.
|
||||
auto func_graph = block->func_graph();
|
||||
CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
|
||||
func_graph->set_return(return_cnode);
|
||||
return block;
|
||||
}
|
||||
|
||||
|
@ -928,13 +928,13 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
|
|||
py::list args = ast_->GetArgs(node);
|
||||
auto block_fg = func_block->func_graph();
|
||||
for (std::size_t i = 0; i < args.size(); i++) {
|
||||
std::string arg = py::cast<std::string>(args[i].attr("arg"));
|
||||
std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
|
||||
TraceGuard guard(GetLocation(args[i]));
|
||||
auto para_node = std::make_shared<Parameter>(block_fg);
|
||||
para_node->debug_info()->set_name(arg);
|
||||
para_node->debug_info()->set_name(arg_name);
|
||||
block_fg->add_parameter(para_node);
|
||||
func_block->WriteVariable(arg, para_node);
|
||||
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
|
||||
func_block->WriteVariable(arg_name, para_node);
|
||||
MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
|
||||
}
|
||||
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
|
@ -1701,7 +1701,6 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
|
|||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
|
||||
block->AddLocalPyParam(name_id, assigned_node);
|
||||
block->WriteVariable(name_id, assigned_node);
|
||||
}
|
||||
|
||||
|
@ -1827,19 +1826,13 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
// The fallback feature is enabled in default.
|
||||
// Not support change the flag during the process is alive.
|
||||
static const auto use_fallback = (support_fallback() == "1");
|
||||
if (!use_fallback) {
|
||||
if (!use_fallback || !value_node->interpret()) {
|
||||
return value_node;
|
||||
}
|
||||
|
||||
AnfNodePtr interpreted_node = value_node;
|
||||
if (value_node->interpret()) {
|
||||
const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
|
||||
py::dict global_dict = block->global_py_params();
|
||||
constexpr int recursive_level = 3;
|
||||
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: " << script_text
|
||||
<< ", value_node: " << value_node->DebugString(recursive_level)
|
||||
<< ", global_dict: " << py::str(global_dict);
|
||||
// Prepare global parameters.
|
||||
py::dict global_dict = block->global_py_params();
|
||||
ValuePtr globals_converted_value = nullptr;
|
||||
if (!ConvertData(global_dict, &globals_converted_value)) {
|
||||
MS_LOG(EXCEPTION) << "Convert data failed";
|
||||
|
@ -1849,13 +1842,19 @@ AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodeP
|
|||
auto [keys, values] = block->local_py_params();
|
||||
auto local_dict_node = ParseDictByKeysAndValues(block, keys, values);
|
||||
// Update the valued node if it need interpreting.
|
||||
interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
|
||||
constexpr int recursive_level = 3;
|
||||
MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script_text
|
||||
<< "`,\nvalue_node: " << value_node->DebugString(recursive_level)
|
||||
<< ",\nglobal_dict_node: " << global_dict_node->ToString()
|
||||
<< ",\nlocal_dict_node: " << local_dict_node->ToString();
|
||||
AnfNodePtr interpreted_node = block->MakeInterpret(script_text, global_dict_node, local_dict_node, value_node);
|
||||
|
||||
// Print a hint for user.
|
||||
MS_LOG(ERROR) << "Found unsupported syntax in Graph mode, those codes would be fell back to Python interpreter:"
|
||||
auto line_info = trace::GetDebugInfo(value_node->debug_info());
|
||||
MS_LOG(DEBUG) << "Found unsupported syntax in Graph mode, those codes would be fell back to Python interpreter:"
|
||||
<< "\n\n"
|
||||
<< trace::GetDebugInfo(value_node->debug_info());
|
||||
}
|
||||
<< line_info;
|
||||
InterpretNodeRecorder::GetInstance().Push(line_info);
|
||||
return interpreted_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -275,12 +275,12 @@ class Parser {
|
|||
// so in FunctionBlock class we can use FunctionBlock* in member
|
||||
// pre_blocks_ and jumps_ to break reference cycle.
|
||||
std::vector<FunctionBlockPtr> func_block_list_;
|
||||
using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
|
||||
using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
|
||||
using StmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
|
||||
using ExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Define the function map to parse ast Statement
|
||||
std::map<std::string, pStmtFunc> stmt_method_map_;
|
||||
std::map<std::string, StmtFunc> stmt_method_map_;
|
||||
// Define the function map to parse ast expression
|
||||
std::map<std::string, pExprFunc> expr_method_map_;
|
||||
std::map<std::string, ExprFunc> expr_method_map_;
|
||||
// Save current loops to support 'continue', 'break' statement.
|
||||
std::stack<Loop> loops_;
|
||||
string max_for_loop_count_str_;
|
||||
|
|
|
@ -119,8 +119,12 @@ class InterpretedObject : public PyObjectWrapper {
|
|||
: PyObjectWrapper(obj, name) {}
|
||||
~InterpretedObject() override = default;
|
||||
MS_DECLARE_PARENT(InterpretedObject, PyObjectWrapper);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<InterpretedObject>(),
|
||||
std::make_shared<External>());
|
||||
}
|
||||
};
|
||||
using InterpretedObjectPtr = std::shared_ptr<InterpretedObject>;
|
||||
|
||||
// ClassObject class wrappers dataclass
|
||||
class ClassObject : public PyObjectWrapper {
|
||||
|
|
|
@ -26,36 +26,38 @@
|
|||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
|
||||
#include "pybind_api/pybind_patch.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "pipeline/jit/pass.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/optimizer/ad/dfunctor.h"
|
||||
#include "pipeline/jit/static_analysis/async_eval_result.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
#include "debug/anf_ir_utils.h"
|
||||
#include "debug/common.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#include "frontend/optimizer/ad/dfunctor.h"
|
||||
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/parallel/graph_util/get_parallel_info.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/convert_utils_py.h"
|
||||
#include "utils/context/context_extends.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/parallel/graph_util/get_parallel_info.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "debug/trace.h"
|
||||
#include "debug/draw.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
#include "frontend/optimizer/py_pass_manager.h"
|
||||
#include "pybind_api/pybind_patch.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/info.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "utils/crypto.h"
|
||||
#include "utils/comm_manager.h"
|
||||
#include "utils/interpret_node_recorder.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "debug/dump_proto.h"
|
||||
#include "debug/anf_ir_utils.h"
|
||||
#include "debug/trace.h"
|
||||
#include "debug/draw.h"
|
||||
#include "debug/common.h"
|
||||
#include "load_mindir/load_model.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "backend/session/executor_manager.h"
|
||||
#include "runtime/hardware/device_context_manager.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
#include "ps/constants.h"
|
||||
#include "ps/util.h"
|
||||
|
@ -722,9 +724,8 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
|
|||
CheckArgsValid(args);
|
||||
|
||||
auto phase = py::cast<std::string>(phase_obj);
|
||||
MS_LOG(INFO) << "Start compiling, phase: " << phase << ".";
|
||||
MS_LOG(DEBUG) << "Compiling source: {" << py::str(source_obj)
|
||||
<< "}\n\n Args: " << py::str(const_cast<py::tuple &>(args));
|
||||
MS_LOG(INFO) << "Start compiling, phase: " << phase;
|
||||
MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args));
|
||||
|
||||
#ifdef ENABLE_GE
|
||||
GetGeBackendPolicy();
|
||||
|
@ -1471,6 +1472,7 @@ void ClearResAtexit() {
|
|||
parse::Parser::CleanParserResource();
|
||||
parse::CleanDataClassToClassMap();
|
||||
trace::ClearTraceStack();
|
||||
InterpretNodeRecorder::GetInstance().Clear();
|
||||
}
|
||||
|
||||
py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {
|
||||
|
|
|
@ -205,6 +205,12 @@ static ValueNameToConverterVector value_name_to_converter = {
|
|||
auto class_type = value->cast<parse::ClassTypePtr>();
|
||||
return class_type->obj();
|
||||
}},
|
||||
// parse::InterpretedObject
|
||||
{typeid(parse::InterpretedObject).name(),
|
||||
[](const ValuePtr &value) -> py::object {
|
||||
auto interpreted_object = value->cast<parse::InterpretedObjectPtr>();
|
||||
return interpreted_object->obj();
|
||||
}},
|
||||
// None
|
||||
{typeid(None).name(), [](const ValuePtr &value) -> py::object { return py::none(); }},
|
||||
// AnyValue
|
||||
|
|
|
@ -106,8 +106,6 @@ class MS_CORE_API AnfNode : public Base {
|
|||
fullname_with_scope_(""),
|
||||
hash_(std::hash<const AnfNode *>()),
|
||||
kernel_info_(nullptr),
|
||||
stage_(-1),
|
||||
need_grad_(false),
|
||||
interpret_(false),
|
||||
interpreted_node_(nullptr) {
|
||||
scope_ = ScopeManager::GetInstance().GetCurrentScope();
|
||||
|
@ -200,12 +198,6 @@ class MS_CORE_API AnfNode : public Base {
|
|||
|
||||
void CloneUserData(const AnfNodePtr &node) { user_data_ = node->user_data_; }
|
||||
|
||||
int64_t stage() { return stage_; }
|
||||
void set_stage(const int &stage) { stage_ = stage; }
|
||||
|
||||
bool grad() { return need_grad_; }
|
||||
void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
|
||||
|
||||
bool interpret() { return interpret_; }
|
||||
void set_interpret(const bool &interpret) { interpret_ = interpret; }
|
||||
|
||||
|
@ -226,8 +218,6 @@ class MS_CORE_API AnfNode : public Base {
|
|||
ScopePtr scope_;
|
||||
KernelInfoDevicePtr kernel_info_;
|
||||
UserData user_data_;
|
||||
int64_t stage_;
|
||||
bool need_grad_;
|
||||
bool interpret_;
|
||||
AnfNodePtr interpreted_node_;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_UTILS_InterpretNodeRecorder_H_
|
||||
#define MINDSPORE_CORE_UTILS_InterpretNodeRecorder_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
class InterpretNodeRecorder {
|
||||
public:
|
||||
explicit InterpretNodeRecorder(InterpretNodeRecorder &&) = delete;
|
||||
explicit InterpretNodeRecorder(const InterpretNodeRecorder &) = delete;
|
||||
void operator=(const InterpretNodeRecorder &) = delete;
|
||||
void operator=(const InterpretNodeRecorder &&) = delete;
|
||||
static InterpretNodeRecorder &GetInstance() {
|
||||
static InterpretNodeRecorder instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Push(const std::string &line) { interpret_nodes_lines_.emplace_back(line); }
|
||||
|
||||
void Clear() { interpret_nodes_lines_.clear(); }
|
||||
|
||||
protected:
|
||||
InterpretNodeRecorder() = default;
|
||||
virtual ~InterpretNodeRecorder() = default;
|
||||
|
||||
private:
|
||||
std::vector<std::string> interpret_nodes_lines_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_UTILS_InterpretNodeRecorder_H_
|
|
@ -80,14 +80,28 @@ def test_np_fallback_func():
|
|||
print(np_fallback_func())
|
||||
|
||||
|
||||
# Test `return` interpret node.
|
||||
@ms_function
|
||||
def div_mod_func(x, y):
|
||||
def div_mod_func1():
|
||||
x = 8
|
||||
y = 3
|
||||
a = divmod(x, y)
|
||||
return Tensor(a)
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func():
|
||||
print(div_mod_func(8, 3)) # (2, 2)
|
||||
def test_div_mod_func1():
|
||||
print(div_mod_func1()) # (2, 2)
|
||||
|
||||
|
||||
# Test interpret node with parameters as input.
|
||||
@ms_function
|
||||
def div_mod_func2(x, y):
|
||||
a = divmod(x, y)
|
||||
return Tensor(a)
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_div_mod_func2():
|
||||
print(div_mod_func2(8, 3)) # (2, 2)
|
||||
|
||||
|
||||
# NameError: name 'Tensor' is not defined.
|
||||
|
|
Loading…
Reference in New Issue