forked from mindspore-Ecosystem/mindspore
!21527 fix_reshape_reshape_in_auto_parallel
Merge pull request !21527 from yao_yf/fix_reshape_reshape_in_auto_parallel
commit 530da3e37e

@@ -319,7 +319,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
     return false;
   }
   auto node_op_info = cnode->user_data<OperatorInfo>();
-  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
+  if (IsParallelCareNode(cnode) && (node_op_info != nullptr) && !IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
     *pre_operator_info = node_op_info;
     *out_index = 0;
     return true;
@@ -358,7 +358,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
 // Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
 // if reshape's output connect to several primitive, return the first layout found
 bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
-                                  size_t curr_depth) {
+                                  bool *is_next_reshape, size_t curr_depth) {
   if (curr_depth > MAX_RECURSIVE_DEPTH) {
     MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH;
     return false;
@@ -373,6 +373,10 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o
     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
       continue;
     }
+    if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
+      *is_next_reshape = true;
+      continue;
+    }
     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
     MS_EXCEPTION_IF_NULL(prim_anf_node);
     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
@@ -384,6 +388,7 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o
     auto op_info = use_apply->user_data<OperatorInfo>();
     if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
       MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
+      *is_next_reshape = false;
       *next_operator_info = op_info;
       *in_index = node_pair.second - 1;
       return true;
@@ -391,7 +396,7 @@ bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_o
     MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
                   << IsParallelCareNode(use_apply) << " " << (op_info != nullptr);

-    if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, ++curr_depth)) {
+    if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index, is_next_reshape, ++curr_depth)) {
       return true;
     }
   }
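The two FindReshapeNextNodeStraCosts hunks above add a boolean out-parameter: when a consumer of the Reshape is itself a Reshape, the search records that fact and keeps looking instead of returning it as the "next" operator. The following is a minimal standalone sketch of that search pattern, my own simplification rather than MindSpore code; Node, has_operator_info, kMaxRecursiveDepth and FindNextNonReshapeUser are invented stand-ins for the real CNode/OperatorInfo machinery.

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string prim;                // primitive name, e.g. "Reshape" or "MatMul"
  bool has_operator_info = false;  // stands in for user_data<OperatorInfo>()
  std::vector<Node *> users;       // consumers of this node's output
};

constexpr size_t kMaxRecursiveDepth = 100;

// Search the users of `node` for the first consumer that carries operator info,
// skipping consumers that are themselves Reshape and recording that fact in *is_next_reshape.
bool FindNextNonReshapeUser(const Node &node, const Node **next, bool *is_next_reshape, size_t depth) {
  if (depth > kMaxRecursiveDepth) {
    return false;
  }
  for (const Node *user : node.users) {
    if (user->prim == "Reshape") {
      *is_next_reshape = true;  // remember the Reshape and keep looking
      continue;
    }
    if (user->has_operator_info) {
      *is_next_reshape = false;
      *next = user;
      return true;
    }
    // Transparent node: search its own consumers recursively.
    if (FindNextNonReshapeUser(*user, next, is_next_reshape, depth + 1)) {
      return true;
    }
  }
  return false;
}

int main() {
  // reshape1 -> reshape2 with no parallel-care consumer: nothing is found and the flag stays set.
  Node reshape2{"Reshape"};
  Node reshape1{"Reshape", false, {&reshape2}};
  const Node *next = nullptr;
  bool is_next_reshape = false;
  bool found = FindNextNonReshapeUser(reshape1, &next, &is_next_reshape, 0);
  std::cout << "found=" << found << " is_next_reshape=" << is_next_reshape << std::endl;
  return 0;
}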

@@ -51,7 +51,7 @@ bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_op
                                  size_t curr_depth);

 bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index,
-                                  size_t curr_depth);
+                                  bool *is_next_reshape, size_t curr_depth);
 void SetUserAttrs(const std::unordered_map<std::string, ValuePtr> &origin_prim_attrs, PrimitivePtr self_prim);
 }  // namespace parallel
 }  // namespace mindspore

@@ -443,7 +443,8 @@ std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t) {

 Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                                           const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
-                                          int64_t out_index, int64_t in_index, bool is_prev_param) {
+                                          int64_t out_index, int64_t in_index, bool is_prev_param,
+                                          bool is_next_reshape) {
   is_generating_costs_ = true;
   for (auto pre_stra_cost : pre_stra_costs) {
     std::vector<TensorInfo> pre_out_tensor_infos;
@@ -466,7 +467,12 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
     }
     Strategys stra_inputs = {stra};
     StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
-    if (next_stra_costs.empty()) {
+    if (is_next_reshape) {
+      SetOutputLayout(pre_out_tensor_info.tensor_layout());
+      ResetQueueMember();
+      InferTensorInfoByLayout();
+      SetCostForReshape(reshape_stra);
+    } else if (next_stra_costs.empty()) {
       if (Init(nullptr) == FAILED) {
         MS_LOG(ERROR) << "Failure:operator reshape init failed";
         return FAILED;
@@ -481,6 +487,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
         return FAILED;
       }
       TensorInfo next_in_tensor_info = next_in_tensor_infos[LongToSize(in_index)];
+
       SetOutputLayout(next_in_tensor_info.tensor_layout());
       ResetQueueMember();
       InferTensorInfoByLayout();
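In GenetateStrategyCosts, the new is_next_reshape argument changes the branch order: if the next operator is also a Reshape, the reshape takes its output layout from the layout already derived from the previous operator, before the old next_stra_costs.empty() fallback is considered. Below is a condensed sketch of that decision order only; ChooseReshapeOutputLayout and Layout are invented names standing in for TensorLayout and the real member functions, not MindSpore API.

#include <iostream>
#include <string>
#include <vector>

using Layout = std::string;  // stand-in for TensorLayout

// Decide where the reshape's output layout comes from, in the order the patch introduces.
Layout ChooseReshapeOutputLayout(bool is_next_reshape, const std::vector<Layout> &next_input_layouts,
                                 const Layout &prev_output_layout, const Layout &default_layout) {
  if (is_next_reshape) {
    // The next node is also a Reshape: reuse the layout derived from the previous operator.
    return prev_output_layout;
  }
  if (next_input_layouts.empty()) {
    // No strategies from a next node at all: keep the default layout.
    return default_layout;
  }
  // Normal case: align the output with the next operator's input layout.
  return next_input_layouts.front();
}

int main() {
  std::cout << ChooseReshapeOutputLayout(true, {}, "prev_out", "default") << "\n";            // prev_out
  std::cout << ChooseReshapeOutputLayout(false, {}, "prev_out", "default") << "\n";           // default
  std::cout << ChooseReshapeOutputLayout(false, {"next_in"}, "prev_out", "default") << "\n";  // next_in
  return 0;
}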

@@ -60,7 +60,7 @@ class ReshapeInfo : public OperatorInfo {
   void set_next_operator_index(int64_t next_index) { next_operator_index_ = next_index; }
   Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
                                const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
-                               int64_t in_index, bool is_prev_param);
+                               int64_t in_index, bool is_prev_param, bool is_next_reshape);
   Status InitForCostModel(const StrategyPtr &strategy) override;
   Status GenerateStrategies(int64_t stage_id) override;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;

@@ -874,8 +874,9 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
     // get next node's strategy_cost_
     int64_t in_index = 0;
     OperatorInfoPtr next_operator_info;
+    bool is_next_reshape = false;
     std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
-    bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, 0);
+    bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index, &is_next_reshape, 0);
     if (!find_next_node) {
       MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
     }
@@ -890,8 +891,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
       reshape_info->set_next_operator_index(in_index);
     }
     bool is_prev_param = pre_node->isa<Parameter>();
-    if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
-        SUCCESS) {
+    if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param,
+                                            is_next_reshape) != SUCCESS) {
       MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!";
     }
   }

@@ -2233,7 +2233,7 @@ TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_
 }

 // if reshape's output connect to several primitive, return the first layout found
-std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
+std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
   MS_EXCEPTION_IF_NULL(cnode);
   MS_EXCEPTION_IF_NULL(cnode->func_graph());
   FuncGraphManagerPtr manager = cnode->func_graph()->manager();
@@ -2244,6 +2244,10 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
     if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
       continue;
     }
+    if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
+      *next_is_reshape = true;
+      continue;
+    }
     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
     MS_EXCEPTION_IF_NULL(prim_anf_node);
     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
@@ -2254,13 +2258,14 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
     }
     if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
       MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
+      *next_is_reshape = false;
       auto layout = GetInputLayoutFromCNode(node_pair);
       return std::make_shared<TensorLayout>(layout);
     }
     MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
                   << " " << use_apply->has_user_data<OperatorInfo>();

-    auto layout_ptr = FindNextLayout(use_apply);
+    auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
     if (layout_ptr) {
       return layout_ptr;
     }
@@ -2475,10 +2480,14 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
       reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
     }
-    auto next_layout_ptr = FindNextLayout(cnode);
+    bool is_next_reshape = false;
+    auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape);
     if (next_layout_ptr) {
       auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
       reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
+    } else if (is_next_reshape && prev_layout_ptr != nullptr) {
+      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+      reshape_info_ptr->SetOutputLayout(*prev_layout_ptr);
     }
     if (operator_info->Init(nullptr) == FAILED) {
       MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed";
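ReshapeInit mirrors the cost-model change at layout-propagation time: it initialises a flag, passes its address to FindNextLayout, and when no next layout is found but the flag reports that every reachable consumer was a Reshape, it reuses the previously found layout as the output layout. The caller-side sketch below uses invented stand-ins (FindNextLayoutStub, Layout) rather than the real FindNextLayout/ReshapeInfo API, purely to show the fallback order.

#include <iostream>
#include <memory>
#include <string>

using Layout = std::string;  // stand-in for TensorLayout

// Invented stand-in for FindNextLayout: here it finds nothing and reports that
// the consumers it reached were themselves Reshape nodes.
std::shared_ptr<Layout> FindNextLayoutStub(bool *next_is_reshape) {
  *next_is_reshape = true;
  return nullptr;
}

int main() {
  auto prev_layout = std::make_shared<Layout>("prev_layout");  // layout found on the input side

  bool is_next_reshape = false;
  auto next_layout = FindNextLayoutStub(&is_next_reshape);

  Layout output_layout = "unset";
  if (next_layout) {
    output_layout = *next_layout;  // normal case: take the next operator's layout
  } else if (is_next_reshape && prev_layout != nullptr) {
    output_layout = *prev_layout;  // reshape-into-reshape fallback added by the patch
  }
  std::cout << output_layout << std::endl;  // prints "prev_layout"
  return 0;
}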

@@ -323,3 +323,57 @@ def test_reshape_auto_7():
     net.set_auto_parallel()
     net.set_train()
     _executor.compile(net, x)
+
+def test_reshape_depend_reshape():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape1 = P.Reshape()
+            self.reshape2 = P.Reshape()
+            self.relu = P.ReLU()
+            self.depend = P.Depend()
+            self.mul = P.Mul().shard(((2, 4), (2, 4)))
+            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
+            self.add = P.Add().shard(((4, 2), (4, 2)))
+
+        def construct(self, x, y):
+            out1 = self.mul(x, self.mul_weight)
+            y = self.relu(y)
+            out2 = self.reshape1(y, (96, 32, 4))
+            out3 = self.depend(out2, out1)
+            out3 = self.reshape2(out3, (128, 96))
+            out = out1 + out3
+            return out
+
+    class NetWithLoss1(nn.Cell):
+        def __init__(self, network):
+            super(NetWithLoss1, self).__init__()
+            self.mean = P.ReduceMean(keep_dims=False)
+            self.network = network
+
+        def construct(self, x, y):
+            predict = self.network(x, y)
+            return self.mean(predict, ())
+
+    class GradWrap1(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap1, self).__init__()
+            self.network = network
+
+        def construct(self, x, y):
+            return grad_all(self.network)(x, y)
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
+    y = Tensor(np.ones([256, 48]), dtype=ms.float32)
+    net = GradWrap1(NetWithLoss1(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    net.set_train()
+    _executor.compile(net, x, y)
+    net_auto = GradWrap1(NetWithLoss1(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net_auto.set_auto_parallel()
+    net_auto.set_train()
+    _executor.compile(net_auto, x, y)