!15695 don't insert VirtualOutput for scalar

From: @yao_yf
Reviewed-by: @yangzhenzhang, @stsuteng
Signed-off-by: @stsuteng
mindspore-ci-bot 2021-04-27 09:25:30 +08:00 committed by Gitee
commit e3d54b7e3b
2 changed files with 60 additions and 2 deletions
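
The change in a sentence: InsertVirtualOutput wraps the outputs of an eval/predict graph in VirtualOutput operators so their shard strategies are tracked, but a scalar output (for example a loss reduced with reduction='mean') has no dimension to shard, so the pass now skips it. A plain-numpy sketch (not part of the commit) of how such a scalar output arises, using the shapes from the new tests:

import numpy as np

logits = np.ones((4096, 250), dtype=np.float32)   # batched output: keeps a batch dim
loss = logits.mean()                              # reduction='mean' -> scalar
print(np.shape(logits))   # (4096, 250) -> VirtualOutput is still inserted
print(np.shape(loss))     # ()          -> empty shape, insertion is now skipped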


@@ -1040,14 +1040,17 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
       cnode = node_pair.first->cast<CNodePtr>();
       last_indexs[last_node_index] = size_t(node_pair.second);
     }
+    auto pre_node = cnode->input(last_indexs[last_node_index]);
+    Shapes shape_outputs = GetNodeShape(pre_node);
+    if (shape_outputs[0].empty()) {
+      continue;
+    }
     FuncGraphPtr func_graph = node->func_graph();
     MS_EXCEPTION_IF_NULL(func_graph);
     OperatorParams params;
     OperatorAttrs attrs;
     OperatorArgs args = std::make_pair(attrs, params);
     Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
-    auto pre_node = cnode->input(last_indexs[last_node_index]);
-    Shapes shape_outputs = GetNodeShape(pre_node);
     InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
     auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
     AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
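
The guard works because GetNodeShape returns one shape per output of pre_node, and a scalar reports an empty shape. A hypothetical Python rendering of the same check (the helper name is illustrative, not a MindSpore API):

def should_insert_virtual_output(shape_outputs):
    # shape_outputs plays the role of GetNodeShape(pre_node): a list with one
    # shape per node output. An empty first shape marks a scalar -> skip.
    return len(shape_outputs[0]) > 0

assert should_insert_virtual_output([[4096, 250]])   # tensor output: insert
assert not should_insert_virtual_output([[]])        # scalar output: skip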


@@ -97,6 +97,24 @@ class ReshapeMulNet(nn.Cell):
         out = self.mul(weight, self.mul_weight)
         return out
 
+class ParallelMulNet(nn.Cell):
+    def __init__(self, dense_in_channel=2048, dense_out_channel=250):
+        super().__init__()
+        weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
+        bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
+        self.flat = nn.Flatten()
+        self.dense = nn.Dense(in_channels=dense_in_channel,
+                              out_channels=dense_out_channel,
+                              weight_init=Tensor(weight_np),
+                              bias_init=Tensor(bias_np),
+                              has_bias=True)
+        self.mul = P.Mul()
+
+    def construct(self, inputs):
+        x = self.flat(inputs)
+        x = self.dense(x)
+        x = self.mul(x, x)
+        return x
+
 def compile_graph(x, net):
     net.set_auto_parallel()
     net.set_train(False)
@@ -104,6 +122,13 @@ def compile_graph(x, net):
     strategies = _executor._get_shard_strategy(net)
     return strategies
 
+def compile_graph_two_input(x, y, net):
+    net.set_auto_parallel()
+    net.set_train(False)
+    _executor.compile(net, x, y, auto_parallel_mode=True)
+    strategies = _executor._get_shard_strategy(net)
+    return strategies
+
 def test_dense_relu_semi_auto():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
@@ -250,3 +275,33 @@ def test_reshape_mul_auto():
     for (k, v) in strategies.items():
         if re.search('VirtualOutput-op', k) is not None:
             assert v[0][0] == 1
+
+def test_scalar_output_semi_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
+    net = ParallelMulNet()
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
+    eval_net = nn.WithEvalCell(net, loss_fn)
+    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([4096, 250]).astype(np.float32) * 0.01)
+    strategies = compile_graph_two_input(x, label, eval_net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+            count += 1
+    assert count == 1
+
+def test_scalar_output_auto():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
+    net = ParallelMulNet()
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
+    eval_net = nn.WithEvalCell(net, loss_fn)
+    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([4096, 250]).astype(np.float32) * 0.01)
+    strategies = compile_graph_two_input(x, label, eval_net)
+    count = 0
+    for (k, v) in strategies.items():
+        if re.search('VirtualOutput-op', k) is not None:
+            assert v[0][0] == 8
+            count += 1
+    assert count == 1
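
Both new tests read the result the same way: _get_shard_strategy maps each operator instance name to its per-input shard strategies, so v[0][0] == 8 checks that the VirtualOutput splits the first (batch) dimension across all 8 devices, and count == 1 checks that exactly one VirtualOutput was inserted; the scalar loss produced by reduction='mean' no longer contributes one. A toy run of the same counting loop over a hand-written dict of that shape (entries are made up for illustration):

import re

strategies = {
    "Default/network/VirtualOutput-op0": [[8, 1]],    # batched output, batch dim split 8 ways
    "Default/network/MatMul-op1": [[8, 1], [1, 1]],   # other ops are ignored by the loop
}
count = 0
for (k, v) in strategies.items():
    if re.search('VirtualOutput-op', k) is not None:
        assert v[0][0] == 8
        count += 1
assert count == 1   # only the tensor output got a VirtualOutput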