!31939 modify_virtualdataset_for_master
Merge pull request !31939 from lilei/modify_virtualdataset_for_master
This commit is contained in:
commit
a391c78e9a
|
@ -131,6 +131,29 @@ static CNodePtr CreateVirtualDataset(const FuncGraphPtr &func_graph) {
|
|||
|
||||
static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
|
||||
std::set<FuncGraphPtr> graph_sets;
|
||||
if (!root->has_flag(parallel::kAutoParallel)) {
|
||||
return graph_sets;
|
||||
}
|
||||
std::set<AnfNodePtr> input_parameters;
|
||||
for (auto &anf_param : root->parameters()) {
|
||||
auto param = anf_param->cast<ParameterPtr>();
|
||||
if (!param->has_default()) {
|
||||
input_parameters.insert(anf_param);
|
||||
}
|
||||
}
|
||||
for (auto input_parameter : input_parameters) {
|
||||
auto node_users_map = root->manager()->node_users();
|
||||
auto node_users = node_users_map[input_parameter];
|
||||
for (auto node_user : node_users) {
|
||||
auto cnode = node_user.first->cast<CNodePtr>();
|
||||
if (IsValueNode<FuncGraph>(cnode->inputs()[0])) {
|
||||
graph_sets.insert(GetValueNode<FuncGraphPtr>(cnode->inputs()[0]));
|
||||
}
|
||||
if (IsValueNode<Primitive>(cnode->inputs()[0])) {
|
||||
graph_sets.insert(cnode->func_graph());
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
|
@ -151,7 +174,6 @@ static std::set<FuncGraphPtr> FindForwardGraph(const FuncGraphPtr &root, const s
|
|||
graph_sets.insert(fun_graph);
|
||||
}
|
||||
}
|
||||
graph_sets.insert(root);
|
||||
return graph_sets;
|
||||
}
|
||||
|
||||
|
|
|
@ -284,11 +284,11 @@ class Cell(Cell_):
|
|||
@pipeline_stage.setter
|
||||
def pipeline_stage(self, value):
|
||||
if not isinstance(value, int) or isinstance(value, bool):
|
||||
raise TypeError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
||||
raise TypeError("For 'Cell', the property 'pipeline_stage' "
|
||||
"must be int type, but got type : {}".format(type(value)))
|
||||
|
||||
if value < 0:
|
||||
raise ValueError("For 'context.set_auto_parallel_context', the argument 'pipeline_stages' "
|
||||
raise ValueError("For 'Cell', the property 'pipeline_stage' "
|
||||
"can not be less than 0, but got {}".format(value))
|
||||
self._pipeline_stage = value
|
||||
for item in self.trainable_params():
|
||||
|
|
Loading…
Reference in New Issue