!5661 fix auto parallel reshape strategy set when it is first operator

Merge pull request !5661 from yao_yf/auto_parallel_reshape_fix
This commit is contained in:
mindspore-ci-bot 2020-09-02 15:42:47 +08:00 committed by Gitee
commit ccc0ea60ee
2 changed files with 67 additions and 11 deletions

View File

@ -1565,16 +1565,24 @@ Status CostGraph::InitSelectedStrategy() {
auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
return edge->next_operator()->name() == reshape_info->next_operator_name();
});
if (pre_iter != in_edges.end()) {
bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
if (reshape_is_first_op) {
reshape_info->InitSelectedStrategy(reshape_info->selected_strategy());
}
if (pre_iter != in_edges.end() || reshape_is_first_op) {
MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
int32_t pre_index = reshape_info->pre_operator_index();
TensorInfo pre_info;
if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
std::shared_ptr<OperatorInfo> pre_op_info;
if (reshape_is_first_op) {
pre_op_info = reshape_info;
pre_info = pre_op_info->inputs_tensor_info()[pre_index];
} else {
pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
pre_op_info = (*pre_iter)->prev_operator();
pre_info = pre_op_info->outputs_tensor_info()[pre_index];
}
reshape_info->SetInputLayout(pre_info.tensor_layout());
if (pre_iter != in_edges.end()) {
Dimensions stra = pre_info.InferStrategy();
if (stra.empty()) {
MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
@ -1584,6 +1592,7 @@ Status CostGraph::InitSelectedStrategy() {
std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
reshape_info->set_strategy(reshape_stra);
}
}
if (next_iter != out_edges.end()) {
MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
int32_t next_index = reshape_info->next_operator_index();

View File

@ -245,3 +245,50 @@ def test_reshape_auto_5():
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y)
def test_reshape_auto_6():
class NetWithLoss6(nn.Cell):
def __init__(self, network):
super(NetWithLoss6, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap6(nn.Cell):
def __init__(self, network):
super(GradWrap6, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class Net(nn.Cell):
def __init__(self):
super().__init__()
self.relu = P.ReLU()
self.mul = P.Mul()
self.reshape = P.Reshape()
self.reduce_mean = P.ReduceMean()
self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
def construct(self, x, y):
out1 = x + self.wide_w
w = self.reshape(self.wide_w, (4, 1024))
out1 = self.reduce_mean(out1, 1)
out1 = out1 - w
out2 = self.mul(y, w)
out = out1 + out2
return out
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0)
x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
net = GradWrap6(NetWithLoss6(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x, y)