!5661 fix auto parallel reshape strategy setting when the reshape is the first operator
Merge pull request !5661 from yao_yf/auto_parallel_reshape_fix
commit ccc0ea60ee
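
Summary: in CostGraph::InitSelectedStrategy(), a Reshape whose input comes straight from a graph input or Parameter has no predecessor operator in the cost graph, so pre_iter matches no in-edge, the whole initialization block was skipped, and the reshape was left without a selected strategy or input layout. The patch detects that case (reshape_is_first_op: the recorded predecessor name is the reshape's own name), initializes the strategy from the reshape's own selected strategy, and takes the input layout from its own inputs_tensor_info(); inferring a strategy from the predecessor still runs only when a real in-edge exists. A regression test is added.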
@@ -1565,24 +1565,33 @@ Status CostGraph::InitSelectedStrategy() {
     auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
       return edge->next_operator()->name() == reshape_info->next_operator_name();
     });
-    if (pre_iter != in_edges.end()) {
+    bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
+    if (reshape_is_first_op) {
+      reshape_info->InitSelectedStrategy(reshape_info->selected_strategy());
+    }
+    if (pre_iter != in_edges.end() || reshape_is_first_op) {
       MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
       int32_t pre_index = reshape_info->pre_operator_index();
       TensorInfo pre_info;
-      if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
-        pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
+      std::shared_ptr<OperatorInfo> pre_op_info;
+      if (reshape_is_first_op) {
+        pre_op_info = reshape_info;
+        pre_info = pre_op_info->inputs_tensor_info()[pre_index];
       } else {
-        pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
+        pre_op_info = (*pre_iter)->prev_operator();
+        pre_info = pre_op_info->outputs_tensor_info()[pre_index];
       }
       reshape_info->SetInputLayout(pre_info.tensor_layout());
-      Dimensions stra = pre_info.InferStrategy();
-      if (stra.empty()) {
-        MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
+      if (pre_iter != in_edges.end()) {
+        Dimensions stra = pre_info.InferStrategy();
+        if (stra.empty()) {
+          MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
+        }
+        Strategys stra_inputs = {stra};
+        StrategyPtr reshape_stra =
+          std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
+        reshape_info->set_strategy(reshape_stra);
       }
-      Strategys stra_inputs = {stra};
-      StrategyPtr reshape_stra =
-        std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
-      reshape_info->set_strategy(reshape_stra);
     }
     if (next_iter != out_edges.end()) {
       MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
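
To make the failure mode concrete: the trigger is a Reshape that consumes a Parameter (or graph input) directly, so it has no in-edge in the cost graph and pre_iter never matches. Below is a minimal sketch of that pattern, condensed from the new test in the second hunk; the FirstOpReshape name and the trimmed-down net are illustrative only, not part of the patch.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.ops import operations as P


class FirstOpReshape(nn.Cell):
    # Illustrative condensation of test_reshape_auto_6's Net; not part of the patch.
    def __init__(self):
        super().__init__()
        self.reshape = P.Reshape()
        self.mul = P.Mul()
        self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

    def construct(self, y):
        # Reshape reads the Parameter directly: it is the first operator on
        # this path, so the cost graph records no predecessor for it.
        w = self.reshape(self.wide_w, (4, 1024))
        return self.mul(y, w)


# Before this fix, compiling such a net with parallel_mode="auto_parallel"
# left the reshape without a selected strategy or input layout.
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")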
@@ -245,3 +245,50 @@ def test_reshape_auto_5():
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x, y)
+
+def test_reshape_auto_6():
+    class NetWithLoss6(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss6, self).__init__()
+            self.loss = VirtualLoss()
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.loss(predict)
+
+    class GradWrap6(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap6, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return grad_all(self.network)(x, y)
+
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.relu = P.ReLU()
+            self.mul = P.Mul()
+            self.reshape = P.Reshape()
+            self.reduce_mean = P.ReduceMean()
+            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
+
+        def construct(self, x, y):
+            out1 = x + self.wide_w
+            w = self.reshape(self.wide_w, (4, 1024))
+            out1 = self.reduce_mean(out1, 1)
+            out1 = out1 - w
+            out2 = self.mul(y, w)
+            out = out1 + out2
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
+    y = Tensor(np.ones([4, 1024]), dtype=ms.float32)
+
+    net = GradWrap6(NetWithLoss6(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x, y)
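
The new test exercises exactly that path: self.wide_w is a Parameter fed straight into Reshape, so the reshape is the first operator on that branch, and a successful _executor.compile(net, x, y) under parallel_mode="auto_parallel" is the assertion. It can be selected on its own with pytest -k test_reshape_auto_6.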