!21527 fix_reshape_reshape_in_auto_parallel

Merge pull request !21527 from yao_yf/fix_reshape_reshape_in_auto_parallel
i-robot 2021-08-13 08:29:15 +00:00 committed by Gitee
commit 530da3e37e
7 changed files with 89 additions and 13 deletions

View File

@@ -319,7 +319,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
return false;
}
auto node_op_info = cnode->user_data<OperatorInfo>();
-if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
+if (IsParallelCareNode(cnode) && (node_op_info != nullptr) && !IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
*pre_operator_info = node_op_info;
*out_index = 0;
return true;
@@ -358,7 +358,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
// Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
// if reshape's output connect to several primitive, return the first layout found
bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
-size_t curr_depth) {
+bool *is_next_reshape, size_t curr_depth) {
if (curr_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH;
return false;
@@ -373,6 +373,10 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
continue;
}
+if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
+  *is_next_reshape = true;
+  continue;
+}
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
@@ -384,6 +388,7 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o
auto op_info = use_apply->user_data<OperatorInfo>();
if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
+*is_next_reshape = false;
*next_operator_info = op_info;
*in_index = node_pair.second - 1;
return true;
@@ -391,7 +396,7 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o
MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
<< IsParallelCareNode(use_apply) << " " << (op_info != nullptr);
-if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, ++curr_depth)) {
+if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, is_next_reshape, ++curr_depth)) {
return true;
}
}

View File

@@ -51,7 +51,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
size_t curr_depth);
bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
-size_t curr_depth);
+bool *is_next_reshape, size_t curr_depth);
void SetUserAttrs(const std::unordered_map<std::string, ValuePtr> &origin_prim_attrs, PrimitivePtr self_prim);
} // namespace parallel
} // namespace mindspore

View File

@@ -443,7 +443,8 @@ std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t) {
Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
-int64_t out_index, int64_t in_index, bool is_prev_param) {
+int64_t out_index, int64_t in_index, bool is_prev_param,
+bool is_next_reshape) {
is_generating_costs_ = true;
for (auto pre_stra_cost : pre_stra_costs) {
std::vector<TensorInfo> pre_out_tensor_infos;
@@ -466,7 +467,12 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
}
Strategys stra_inputs = {stra};
StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
-if (next_stra_costs.empty()) {
+if (is_next_reshape) {
+  SetOutputLayout(pre_out_tensor_info.tensor_layout());
+  ResetQueueMember();
+  InferTensorInfoByLayout();
+  SetCostForReshape(reshape_stra);
+} else if (next_stra_costs.empty()) {
if (Init(nullptr) == FAILED) {
MS_LOG(ERROR) << "Failure:operator reshape init failed";
return FAILED;
@@ -481,6 +487,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
return FAILED;
}
TensorInfo next_in_tensor_info = next_in_tensor_infos[LongToSize(in_index)];
SetOutputLayout(next_in_tensor_info.tensor_layout());
ResetQueueMember();
InferTensorInfoByLayout();

View File

@@ -60,7 +60,7 @@ class ReshapeInfo : public OperatorInfo {
void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; }
Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
-int64_t in_index, bool is_prev_param);
+int64_t in_index, bool is_prev_param, bool is_next_reshape);
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;

View File

@@ -874,8 +874,9 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
// get next node's strategy_cost_
int64_t in_index = 0;
OperatorInfoPtr next_operator_info;
+bool is_next_reshape = false;
std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
-bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, 0);
+bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, &is_next_reshape, 0);
if (!find_next_node) {
MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
}
@@ -890,8 +891,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
reshape_info->set_next_operator_index(in_index);
}
bool is_prev_param = pre_node->isa<Parameter>();
-if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
-    SUCCESS) {
+if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
+    is_next_reshape) != SUCCESS) {
MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!";
}
}

View File

@@ -2233,7 +2233,7 @@ TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_
}
// if reshape's output connect to several primitive, return the first layout found
-std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
+std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cnode->func_graph());
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
@@ -2244,6 +2244,10 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
continue;
}
+if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
+  *next_is_reshape = true;
+  continue;
+}
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
@@ -2254,13 +2258,14 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
}
if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
+*next_is_reshape = false;
auto layout = GetInputLayoutFromCNode(node_pair);
return std::make_shared<TensorLayout>(layout);
}
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << use_apply->has_user_data<OperatorInfo>();
-auto layout_ptr = FindNextLayout(use_apply);
+auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
if (layout_ptr) {
return layout_ptr;
}
@@ -2475,10 +2480,14 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
}
-auto next_layout_ptr = FindNextLayout(cnode);
+bool is_next_reshape = false;
+auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape);
if (next_layout_ptr) {
auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
+} else if (is_next_reshape && prev_layout_ptr != nullptr) {
+  auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+  reshape_info_ptr->SetOutputLayout(*prev_layout_ptr);
}
if (operator_info->Init(nullptr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed";

View File

@@ -323,3 +323,57 @@ def test_reshape_auto_7():
    net.set_auto_parallel()
    net.set_train()
    _executor.compile(net, x)
+
+
+def test_reshape_depend_reshape():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape1 = P.Reshape()
+            self.reshape2 = P.Reshape()
+            self.relu = P.ReLU()
+            self.depend = P.Depend()
+            self.mul = P.Mul().shard(((2, 4), (2, 4)))
+            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
+            self.add = P.Add().shard(((4, 2), (4, 2)))
+
+        def construct(self, x, y):
+            out1 = self.mul(x, self.mul_weight)
+            y = self.relu(y)
+            out2 = self.reshape1(y, (96, 32, 4))
+            out3 = self.depend(out2, out1)
+            out3 = self.reshape2(out3, (128, 96))
+            out = out1 + out3
+            return out
+
+    class NetWithLoss1(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss1, self).__init__()
+            self.mean = P.ReduceMean(keep_dims=False)
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.mean(predict, ())
+
+    class GradWrap1(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap1, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return grad_all(self.network)(x, y)
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
+    y = Tensor(np.ones([256, 48]), dtype=ms.float32)
+    net = GradWrap1(NetWithLoss1(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    net.set_train()
+    _executor.compile(net, x, y)
+    net_auto = GradWrap1(NetWithLoss1(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net_auto.set_auto_parallel()
+    net_auto.set_train()
+    _executor.compile(net_auto, x, y)