forked from mindspore-Ecosystem/mindspore
!15695 don't insert VirtualOutput for scalar
From: @yao_yf Reviewed-by: @yangzhenzhang,@stsuteng Signed-off-by: @stsuteng
This commit is contained in:
commit
e3d54b7e3b
|
@ -1040,14 +1040,17 @@ void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr>
|
||||||
cnode = node_pair.first->cast<CNodePtr>();
|
cnode = node_pair.first->cast<CNodePtr>();
|
||||||
last_indexs[last_node_index] = size_t(node_pair.second);
|
last_indexs[last_node_index] = size_t(node_pair.second);
|
||||||
}
|
}
|
||||||
|
auto pre_node = cnode->input(last_indexs[last_node_index]);
|
||||||
|
Shapes shape_outputs = GetNodeShape(pre_node);
|
||||||
|
if (shape_outputs[0].empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
FuncGraphPtr func_graph = node->func_graph();
|
FuncGraphPtr func_graph = node->func_graph();
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
OperatorParams params;
|
OperatorParams params;
|
||||||
OperatorAttrs attrs;
|
OperatorAttrs attrs;
|
||||||
OperatorArgs args = std::make_pair(attrs, params);
|
OperatorArgs args = std::make_pair(attrs, params);
|
||||||
Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
|
Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
|
||||||
auto pre_node = cnode->input(last_indexs[last_node_index]);
|
|
||||||
Shapes shape_outputs = GetNodeShape(pre_node);
|
|
||||||
InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
|
InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
|
||||||
auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
|
auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
|
||||||
AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
|
AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
|
||||||
|
|
|
@ -97,6 +97,24 @@ class ReshapeMulNet(nn.Cell):
|
||||||
out = self.mul(weight, self.mul_weight)
|
out = self.mul(weight, self.mul_weight)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class ParallelMulNet(nn.Cell):
    """Flatten -> Dense -> Mul network used by the scalar-output parallel tests.

    Args:
        dense_in_channel: value passed to ``nn.Dense`` as ``in_channels``.
        dense_out_channel: value passed to ``nn.Dense`` as ``out_channels``.
    """

    def __init__(self, dense_in_channel=2048, dense_out_channel=250):
        super().__init__()
        # Dense weight and bias are constant-filled with 0.01 (float32).
        dense_weight = Tensor(np.full((dense_out_channel, dense_in_channel), 0.01, dtype=np.float32))
        dense_bias = Tensor(np.full((dense_out_channel,), 0.01, dtype=np.float32))
        self.flat = nn.Flatten()
        self.dense = nn.Dense(
            in_channels=dense_in_channel,
            out_channels=dense_out_channel,
            weight_init=dense_weight,
            bias_init=dense_bias,
            has_bias=True,
        )
        self.mul = P.Mul()

    def construct(self, inputs):
        """Flatten ``inputs``, apply the dense layer, then feed that result to both inputs of Mul."""
        flattened = self.flat(inputs)
        projected = self.dense(flattened)
        return self.mul(projected, projected)
|
||||||
|
|
||||||
def compile_graph(x, net):
|
def compile_graph(x, net):
|
||||||
net.set_auto_parallel()
|
net.set_auto_parallel()
|
||||||
net.set_train(False)
|
net.set_train(False)
|
||||||
|
@ -104,6 +122,13 @@ def compile_graph(x, net):
|
||||||
strategies = _executor._get_shard_strategy(net)
|
strategies = _executor._get_shard_strategy(net)
|
||||||
return strategies
|
return strategies
|
||||||
|
|
||||||
|
def compile_graph_two_input(x, y, net):
    """Compile ``net`` on two inputs in auto-parallel mode and return its shard strategies.

    Args:
        x: first input handed to ``_executor.compile``.
        y: second input handed to ``_executor.compile``.
        net: the cell to compile; it is switched to auto-parallel and to
            non-training mode before compilation.

    Returns:
        Whatever ``_executor._get_shard_strategy(net)`` reports after compilation.
    """
    net.set_auto_parallel()
    net.set_train(False)
    _executor.compile(net, x, y, auto_parallel_mode=True)
    return _executor._get_shard_strategy(net)
|
||||||
|
|
||||||
|
|
||||||
def test_dense_relu_semi_auto():
|
def test_dense_relu_semi_auto():
|
||||||
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
|
||||||
|
@ -250,3 +275,33 @@ def test_reshape_mul_auto():
|
||||||
for (k, v) in strategies.items():
|
for (k, v) in strategies.items():
|
||||||
if re.search('VirtualOutput-op', k) is not None:
|
if re.search('VirtualOutput-op', k) is not None:
|
||||||
assert v[0][0] == 1
|
assert v[0][0] == 1
|
||||||
|
|
||||||
|
def test_scalar_output_semi_auto():
    """Check VirtualOutput strategies for an eval net with a mean-reduced loss (semi_auto_parallel).

    Compiles ``ParallelMulNet`` wrapped in ``nn.WithEvalCell`` and expects exactly one
    strategy entry whose key matches 'VirtualOutput-op', with ``[0][0] == 8``.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    eval_net = nn.WithEvalCell(ParallelMulNet(), nn.SoftmaxCrossEntropyWithLogits(reduction='mean'))
    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([4096, 250]).astype(np.float32) * 0.01)
    strategies = compile_graph_two_input(x, label, eval_net)

    # Keep only the strategies that belong to VirtualOutput operators.
    virtual_output_strategies = [
        strategy for op_name, strategy in strategies.items()
        if re.search('VirtualOutput-op', op_name) is not None
    ]
    for strategy in virtual_output_strategies:
        assert strategy[0][0] == 8
    assert len(virtual_output_strategies) == 1
|
||||||
|
|
||||||
|
def test_scalar_output_auto():
    """Check VirtualOutput strategies for an eval net with a mean-reduced loss (auto_parallel).

    Compiles ``ParallelMulNet`` wrapped in ``nn.WithEvalCell`` and expects exactly one
    strategy entry whose key matches 'VirtualOutput-op', with ``[0][0] == 8``.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    eval_net = nn.WithEvalCell(ParallelMulNet(), nn.SoftmaxCrossEntropyWithLogits(reduction='mean'))
    x = Tensor(np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([4096, 250]).astype(np.float32) * 0.01)
    strategies = compile_graph_two_input(x, label, eval_net)

    # Keep only the strategies that belong to VirtualOutput operators.
    virtual_output_strategies = [
        strategy for op_name, strategy in strategies.items()
        if re.search('VirtualOutput-op', op_name) is not None
    ]
    for strategy in virtual_output_strategies:
        assert strategy[0][0] == 8
    assert len(virtual_output_strategies) == 1
|
||||||
|
|
Loading…
Reference in New Issue