forked from mindspore-Ecosystem/mindspore
fix auto parallel reshape strategy set when it is first operator
This commit is contained in:
parent
03093778df
commit
755f381406
|
@ -1565,24 +1565,33 @@ Status CostGraph::InitSelectedStrategy() {
|
|||
auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
|
||||
return edge->next_operator()->name() == reshape_info->next_operator_name();
|
||||
});
|
||||
if (pre_iter != in_edges.end()) {
|
||||
bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
|
||||
if (reshape_is_first_op) {
|
||||
reshape_info->InitSelectedStrategy(reshape_info->selected_strategy());
|
||||
}
|
||||
if (pre_iter != in_edges.end() || reshape_is_first_op) {
|
||||
MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
|
||||
int32_t pre_index = reshape_info->pre_operator_index();
|
||||
TensorInfo pre_info;
|
||||
if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
|
||||
pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
|
||||
std::shared_ptr<OperatorInfo> pre_op_info;
|
||||
if (reshape_is_first_op) {
|
||||
pre_op_info = reshape_info;
|
||||
pre_info = pre_op_info->inputs_tensor_info()[pre_index];
|
||||
} else {
|
||||
pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
|
||||
pre_op_info = (*pre_iter)->prev_operator();
|
||||
pre_info = pre_op_info->outputs_tensor_info()[pre_index];
|
||||
}
|
||||
reshape_info->SetInputLayout(pre_info.tensor_layout());
|
||||
Dimensions stra = pre_info.InferStrategy();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
|
||||
if (pre_iter != in_edges.end()) {
|
||||
Dimensions stra = pre_info.InferStrategy();
|
||||
if (stra.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
|
||||
}
|
||||
Strategys stra_inputs = {stra};
|
||||
StrategyPtr reshape_stra =
|
||||
std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
|
||||
reshape_info->set_strategy(reshape_stra);
|
||||
}
|
||||
Strategys stra_inputs = {stra};
|
||||
StrategyPtr reshape_stra =
|
||||
std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
|
||||
reshape_info->set_strategy(reshape_stra);
|
||||
}
|
||||
if (next_iter != out_edges.end()) {
|
||||
MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
|
||||
|
|
|
@ -245,3 +245,50 @@ def test_reshape_auto_5():
|
|||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x, y)
|
||||
|
||||
def test_reshape_auto_6():
|
||||
class NetWithLoss6(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(NetWithLoss6, self).__init__()
|
||||
self.loss = VirtualLoss()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
predict = self.network(x, y)
|
||||
return self.loss(predict)
|
||||
|
||||
class GradWrap6(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap6, self).__init__()
|
||||
self.network = network
|
||||
|
||||
def construct(self, x, y):
|
||||
return grad_all(self.network)(x, y)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.relu = P.ReLU()
|
||||
self.mul = P.Mul()
|
||||
self.reshape = P.Reshape()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")
|
||||
|
||||
def construct(self, x, y):
|
||||
out1 = x + self.wide_w
|
||||
w = self.reshape(self.wide_w, (4, 1024))
|
||||
out1 = self.reduce_mean(out1, 1)
|
||||
out1 = out1 - w
|
||||
out2 = self.mul(y, w)
|
||||
out = out1 + out2
|
||||
return out
|
||||
|
||||
size = 8
|
||||
context.set_auto_parallel_context(device_num=size, global_rank=0)
|
||||
x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
|
||||
y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
|
||||
|
||||
net = GradWrap6(NetWithLoss6(Net()))
|
||||
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
||||
net.set_auto_parallel()
|
||||
_executor.compile(net, x, y)
|
||||
|
|
Loading…
Reference in New Issue