forked from mindspore-Ecosystem/mindspore
!15695 dont insert VirtualOutput for scalar
From: @yao_yf Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsuteng
This commit is contained in:
commit
e3d54b7e3b
|
@ -1040,14 +1040,17 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
|
|||
cnode = node_pair.first->cast<CNodePtr>();
|
||||
last_indexs[last_node_index] = size_t(node_pair.second);
|
||||
}
|
||||
auto pre_node = cnode->input(last_indexs[last_node_index]);
|
||||
Shapes shape_outputs = GetNodeShape(pre_node);
|
||||
if (shape_outputs[0].empty()) {
|
||||
continue;
|
||||
}
|
||||
FuncGraphPtr func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
OperatorParams params;
|
||||
OperatorAttrs attrs;
|
||||
OperatorArgs args = std::make_pair(attrs, params);
|
||||
Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
|
||||
auto pre_node = cnode->input(last_indexs[last_node_index]);
|
||||
Shapes shape_outputs = GetNodeShape(pre_node);
|
||||
InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
|
||||
auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
|
||||
AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
|
||||
|
|
|
@ -97,6 +97,24 @@ class ReshapeMulNet(nn.Cell):
|
|||
out = self.mul(weight, self.mul_weight)
|
||||
return out
|
||||
|
||||
class ParallelMulNet(nn.Cell):
|
||||
def __init__(self, dense_in_channel=2048, dense_out_channel=250):
|
||||
super().__init__()
|
||||
weight_np = np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32)
|
||||
bias_np = np.full((dense_out_channel,), 0.01, dtype=np.float32)
|
||||
self.flat = nn.Flatten()
|
||||
self.dense = nn.Dense(in_channels=dense_in_channel,
|
||||
out_channels=dense_out_channel,
|
||||
weight_init=Tensor(weight_np),
|
||||
bias_init=Tensor(bias_np),
|
||||
has_bias=True)
|
||||
self.mul = P.Mul()
|
||||
def construct(self, inputs):
|
||||
x = self.flat(inputs)
|
||||
x = self.dense(x)
|
||||
x = self.mul(x, x)
|
||||
return x
|
||||
|
||||
def compile_graph(x, net):
|
||||
net.set_auto_parallel()
|
||||
net.set_train(False)
|
||||
|
@ -104,6 +122,13 @@ def compile_graph(x, net):
|
|||
strategies = _executor._get_shard_strategy(net)
|
||||
return strategies
|
||||
|
||||
def compile_graph_two_input(x, y, net):
|
||||
net.set_auto_parallel()
|
||||
net.set_train(False)
|
||||
_executor.compile(net, x, y, auto_parallel_mode=True)
|
||||
strategies = _executor._get_shard_strategy(net)
|
||||
return strategies
|
||||
|
||||
|
||||
def test_dense_relu_semi_auto():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
|
||||
|
@ -250,3 +275,33 @@ def test_reshape_mul_auto():
|
|||
for (k, v) in strategies.items():
|
||||
if re.search('VirtualOutput-op', k) is not None:
|
||||
assert v[0][0] == 1
|
||||
|
||||
def test_scalar_output_semi_auto():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
|
||||
net = ParallelMulNet()
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
|
||||
eval_net = nn.WithEvalCell(net, loss_fn)
|
||||
x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32)*0.01)
|
||||
label = Tensor(np.ones([4096, 250]).astype(np.float32)*0.01)
|
||||
strategies = compile_graph_two_input(x, label, eval_net)
|
||||
count = 0
|
||||
for (k, v) in strategies.items():
|
||||
if re.search('VirtualOutput-op', k) is not None:
|
||||
assert v[0][0] == 8
|
||||
count += 1
|
||||
assert count == 1
|
||||
|
||||
def test_scalar_output_auto():
|
||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
|
||||
net = ParallelMulNet()
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
|
||||
eval_net = nn.WithEvalCell(net, loss_fn)
|
||||
x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32)*0.01)
|
||||
label = Tensor(np.ones([4096, 250]).astype(np.float32)*0.01)
|
||||
strategies = compile_graph_two_input(x, label, eval_net)
|
||||
count = 0
|
||||
for (k, v) in strategies.items():
|
||||
if re.search('VirtualOutput-op', k) is not None:
|
||||
assert v[0][0] == 8
|
||||
count += 1
|
||||
assert count == 1
|
||||
|
|
Loading…
Reference in New Issue