fix auto parallel mutigrpah bug
This commit is contained in:
parent
811704a30a
commit
6b2a9de09f
|
@ -1269,15 +1269,17 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
|
|||
} else {
|
||||
AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
|
||||
for (auto ¶m_pair : param_sub_set) {
|
||||
CNodePtr graph_cnode = param_pair.first->cast<CNodePtr>();
|
||||
if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa<CNode>()) {
|
||||
CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
|
||||
AnfNodePtr graph_value_node;
|
||||
if (param_cnode->input(0)->isa<CNode>()) {
|
||||
graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
|
||||
} else {
|
||||
graph_value_node = param_cnode->input(0);
|
||||
}
|
||||
if (!IsValueNode<FuncGraph>(graph_value_node)) {
|
||||
continue;
|
||||
}
|
||||
CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
|
||||
if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
|
||||
continue;
|
||||
}
|
||||
FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
|
||||
FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
|
||||
auto parameters = graph_sub->parameters();
|
||||
if (IntToSize(param_pair.second - 1) >= parameters.size()) {
|
||||
MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
|
||||
|
@ -1864,7 +1866,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
|
||||
// return -> make_tuple
|
||||
if (current_prim->name() == MAKE_TUPLE) {
|
||||
MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
|
||||
MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// return -> loss
|
||||
|
@ -2069,6 +2072,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
|
|||
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
|
||||
MS_LOG(DEBUG) << "Find the forward graph success";
|
||||
graph_set.insert(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto graph_used = manager->func_graphs_used_total(graph);
|
||||
for (auto &sub_graph : graph_used) {
|
||||
graph_set.insert(sub_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
return graph_set;
|
||||
|
@ -2423,7 +2432,7 @@ void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
void MarkForwardCNode(const FuncGraphPtr &root) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
auto all_nodes = root->nodes();
|
||||
std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
|
||||
auto graph_set = FindForwardGraphByRootNodes(all_nodes);
|
||||
|
||||
if (graph_set.empty()) {
|
||||
MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
|
||||
|
|
|
@ -145,10 +145,10 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
|
|||
|
||||
Status ParallelInit();
|
||||
|
||||
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
|
||||
|
||||
std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
|
||||
|
||||
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
|
||||
|
||||
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
|
||||
|
||||
using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.context as context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class TwoInputBprop(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.Mul()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.op(x, y)
|
||||
|
||||
def bprop(self, x, y, out, dout):
|
||||
return x * 5, y * 8
|
||||
|
||||
|
||||
class ParallelFloorDivBpropNet(Cell):
|
||||
def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
|
||||
super().__init__()
|
||||
mul_np = np.full(mul_size, 0.5, dtype=np.float32)
|
||||
floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
|
||||
self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
|
||||
self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
|
||||
self.mul = TwoInputBprop()
|
||||
self.floor_div = P.FloorDiv()
|
||||
if strategy is not None:
|
||||
self.mul.op.shard(strategy2)
|
||||
self.floor_div.shard(strategy)
|
||||
|
||||
def construct(self, inputs, label):
|
||||
x = self.mul(inputs, self.mul_weight)
|
||||
x = self.floor_div(x, self.floordiv_weight)
|
||||
return x
|
||||
|
||||
|
||||
inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
|
||||
label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
|
||||
def compile_net(net):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
_executor.compile(train_net, inputs_, label_)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_net():
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
|
||||
strategy = ((4, 1), (4, 1))
|
||||
net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue