!31939 modify_virtualdataset_for_master

Merge pull request !31939 from lilei/modify_virtualdataset_for_master
This commit is contained in:
i-robot 2022-03-28 11:42:46 +00:00 committed by Gitee
commit a391c78e9a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 25 additions and 3 deletions

View File

@ -131,6 +131,29 @@ static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
std::set<FuncGraphPtr> graph_sets;
if (!root->has_flag(parallel::kAutoParallel)) {
return graph_sets;
}
std::set<AnfNodePtr> input_parameters;
for (auto &anf_param : root->parameters()) {
auto param = anf_param->cast<ParameterPtr>();
if (!param->has_default()) {
input_parameters.insert(anf_param);
}
}
for (auto input_parameter : input_parameters) {
auto node_users_map = root->manager()->node_users();
auto node_users = node_users_map[input_parameter];
for (auto node_user : node_users) {
auto cnode = node_user.first->cast<CNodePtr>();
if (IsValueNode<FuncGraph>(cnode->inputs()[0])) {
graph_sets.insert(GetValueNode<FuncGraphPtr>(cnode->inputs()[0]));
}
if (IsValueNode<Primitive>(cnode->inputs()[0])) {
graph_sets.insert(cnode->func_graph());
}
}
}
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
@ -151,7 +174,6 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
graph_sets.insert(fun_graph);
}
}
graph_sets.insert(root);
return graph_sets;
}

View File

@ -284,11 +284,11 @@ class Cell(Cell_):
@pipeline_stage.setter
def pipeline_stage(self, value):
if not isinstance(value, int) or isinstance(value, bool):
raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
raise TypeError("For 'Cell', the property 'pipeline_stage' "
"must be int type, but got type : {}".format(type(value)))
if value < 0:
raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
raise ValueError("For 'Cell', the property 'pipeline_stage' "
"can not be less than 0, but got {}".format(value))
self._pipeline_stage = value
for item in self.trainable_params():