fix auto parallel multigraph bug

lichenever 2020-09-16 16:05:47 +08:00
parent 811704a30a
commit 6b2a9de09f
3 changed files with 89 additions and 11 deletions

View File

@@ -1269,15 +1269,17 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   } else {
     AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
     for (auto &param_pair : param_sub_set) {
-      CNodePtr graph_cnode = param_pair.first->cast<CNodePtr>();
-      if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa<CNode>()) {
+      CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
+      AnfNodePtr graph_value_node;
+      if (param_cnode->input(0)->isa<CNode>()) {
+        graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
+      } else {
+        graph_value_node = param_cnode->input(0);
+      }
+      if (!IsValueNode<FuncGraph>(graph_value_node)) {
         continue;
       }
-      CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
-      if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
-        continue;
-      }
-      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
+      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
       auto parameters = graph_sub->parameters();
       if (IntToSize(param_pair.second - 1) >= parameters.size()) {
         MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
@@ -1864,7 +1866,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
+    MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
+    return nullptr;
   }
   // return -> loss
@@ -2069,6 +2072,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
       auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
       MS_LOG(DEBUG) << "Find the forward graph success";
       graph_set.insert(graph);
+      auto manager = graph->manager();
+      MS_EXCEPTION_IF_NULL(manager);
+      auto graph_used = manager->func_graphs_used_total(graph);
+      for (auto &sub_graph : graph_used) {
+        graph_set.insert(sub_graph);
+      }
     }
   }
   return graph_set;
@@ -2423,7 +2432,7 @@ void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
 void MarkForwardCNode(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
   auto all_nodes = root->nodes();
-  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  auto graph_set = FindForwardGraphByRootNodes(all_nodes);
   if (graph_set.empty()) {
     MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";

View File

@@ -145,10 +145,10 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
 Status ParallelInit();
-std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
+std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
 using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;

View File

@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore.common.api import _executor
from mindspore import Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
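
# Cell with a user-defined bprop; a net that calls it is compiled into more
# than one graph, which is the multi-graph case this test exercises.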
class TwoInputBprop(Cell):
    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        return self.op(x, y)

    def bprop(self, x, y, out, dout):
        return x * 5, y * 8
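
# Net that combines the custom-bprop Mul cell with FloorDiv and shards both
# operators when strategies are given.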
class ParallelFloorDivBpropNet(Cell):
    def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
        super().__init__()
        mul_np = np.full(mul_size, 0.5, dtype=np.float32)
        floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
        self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
        self.mul = TwoInputBprop()
        self.floor_div = P.FloorDiv()
        if strategy is not None:
            self.mul.op.shard(strategy2)
            self.floor_div.shard(strategy)

    def construct(self, inputs, label):
        x = self.mul(inputs, self.mul_weight)
        x = self.floor_div(x, self.floordiv_weight)
        return x

inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
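
# Wrapping the net in TrainOneStepCell and compiling it is enough to run the
# parallel pass; the test only checks that compilation succeeds.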
def compile_net(net):
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    _executor.compile(train_net, inputs_, label_)
    context.reset_auto_parallel_context()
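
# Four devices under semi_auto_parallel; the same ((4, 1), (4, 1)) strategy is
# applied to both the Mul (inside the custom-bprop cell) and the FloorDiv.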
def test_net():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
    strategy = ((4, 1), (4, 1))
    net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
    compile_net(net)