!67876 Supply partial shape to full shape

Merge pull request !67876 from 刘崇鸣/supply_partial_shape_to_full_shape
i-robot 2024-04-03 02:19:35 +00:00 committed by Gitee
commit cb8fb912a2
4 changed files with 41 additions and 18 deletions

@@ -445,4 +445,5 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/selection_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/nn_pooling_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/nn_norm_ops_proto.cc:ge::CUST_IMPLEMT_INFERFUNC
mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cast.cc:aicpu::CastKernel::Compute
mindspore/mindspore/ccsrc/frontend/parallel/step_parallel.cc:mindspore::parallel::StepParallel

@@ -54,14 +54,7 @@ CNodePtr CreateShape(const AnfNodePtr &pre_cnode, const FuncGraphPtr &func_graph
return shape_cnode;
}
bool IsTargetOp(const CNodePtr &cnode, const std::string &target) {
RETURN_IF_FALSE(cnode != nullptr);
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
RETURN_IF_FALSE(value_node != nullptr);
auto prim = value_node->value()->cast<PrimitivePtr>();
RETURN_IF_FALSE(prim != nullptr);
return prim->name() == target;
}
inline bool IsTargetOp(const CNodePtr &cnode, const std::string &target) { return GetPrimName(cnode) == target; }
bool IsTupleGetItem(const CNodePtr &cnode) { return IsTargetOp(cnode, TUPLE_GETITEM_OP); }
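
For readers comparing the two forms of IsTargetOp: the deleted block resolved the primitive name by hand, while the one-line replacement delegates to GetPrimName. The standalone sketch below illustrates the equivalence using stand-in types (Primitive, CNode) and an assumed GetPrimName that folds in the null checks the deleted block spelled out; it is an illustration, not the helper's real definition.

#include <memory>
#include <string>

// Stand-in types for illustration only; the real code operates on CNodePtr,
// ValueNodePtr and PrimitivePtr from the MindSpore IR.
struct Primitive { std::string name; };
struct CNode { std::shared_ptr<Primitive> prim; };  // models input(0) holding the primitive
using CNodePtr = std::shared_ptr<CNode>;

// Assumed behaviour of GetPrimName: unwrap the primitive held by the node and
// return its name, tolerating null nodes (the checks the deleted block performed).
std::string GetPrimName(const CNodePtr &cnode) {
  if (cnode == nullptr || cnode->prim == nullptr) {
    return "";
  }
  return cnode->prim->name;
}

// The refactored form reduces to a plain name comparison.
inline bool IsTargetOp(const CNodePtr &cnode, const std::string &target) { return GetPrimName(cnode) == target; }

int main() {
  auto cast = std::make_shared<CNode>(CNode{std::make_shared<Primitive>(Primitive{"Cast"})});
  return IsTargetOp(cast, "Cast") ? 0 : 1;
}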
@@ -497,6 +490,10 @@ Status ConvertReshapeInputs(const OperatorParams &params,
}
Shape shape_vec = GetValue<Shape>(param.first.second);
MS_LOG(INFO) << "shape param = " << shape_vec;
size_t dynamic_axis_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
if (shape_vec.size() > 1 && dynamic_axis_cnt >= SIZE_TWO) {
MS_LOG(EXCEPTION) << "The shape of Reshape op has more than one -1, cannot be supported for now.";
}
if (!WhetherMatchingIsNeededForReshape(shape_vec, tensor_redistribution_from_cnode)) {
MS_LOG(INFO) << "No need to matching for " << shape_vec;
AnfNodePtr val = NewValueNode(param.first.second);
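
As an aside on the guard added in this hunk: it rejects Reshape target shapes that carry more than one dynamic axis. A minimal standalone version follows, assuming Shape is std::vector<int64_t> and SIZE_TWO equals 2 (both are assumptions for the illustration only).

#include <algorithm>
#include <cstdint>
#include <vector>

// Standalone illustration of the new guard: a Reshape target shape may carry
// at most one dynamic axis (-1).
bool HasTooManyDynamicAxes(const std::vector<int64_t> &shape_vec) {
  const auto dynamic_axis_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
  return shape_vec.size() > 1 && dynamic_axis_cnt >= 2;
}

int main() {
  // {-1, 4, -1} has two dynamic axes and would hit the exception branch above;
  // {-1, 4, 2} has only one and passes.
  return (HasTooManyDynamicAxes({-1, 4, -1}) && !HasTooManyDynamicAxes({-1, 4, 2})) ? 0 : 1;
}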
@@ -880,23 +877,39 @@ Status UpdatePartialShape(const CNodePtr &cnode) {
return Status::SUCCESS;
}
CNodePtr FindPreviousCareNode(const CNodePtr &current, int32_t depth = 0) {
if (depth == MAX_RECURSIVE_DEPTH) {
return nullptr;
}
auto prev = current->input(1);
// If prev is parameter maybe problem here.
auto cnode = prev->cast<CNodePtr>();
MS_EXCEPTION_IF_CHECK_FAIL(cnode != nullptr, "Input of node is parameter is not supported.");
if (!IsParallelCareNode(cnode) && (IsTargetOp(cnode, "Cast") || IsTupleGetItem(cnode))) {
return FindPreviousCareNode(cnode, depth + 1);
}
return cnode;
}
TensorInfo GetDistributeOperatorFromCNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
CNodePtr target_cnode = cnode;
if (IsTupleGetItem(cnode)) {
if (!IsParallelCareNode(cnode)) {
// keep search the previous node.
auto prev_node = FindPreviousNodeAndSkipTupleGetItem(cnode);
target_cnode = prev_node.first;
target_cnode = FindPreviousCareNode(cnode);
}
MS_EXCEPTION_IF_NULL(target_cnode);
if (!target_cnode->has_user_data<OperatorInfo>()) {
MS_LOG(EXCEPTION) << target_cnode->fullname_with_scope() << " has no operator info.";
MS_LOG(EXCEPTION) << "Found " << cnode->fullname_with_scope() << " previous node is "
<< target_cnode->fullname_with_scope() << " and it has no operator info.";
}
OperatorInfoPtr distribute_operator = GetDistributeOperator(target_cnode);
MS_EXCEPTION_IF_NULL(distribute_operator);
std::vector<TensorInfo> root_tensor_info = distribute_operator->outputs_tensor_info();
if (root_tensor_info.size() != 1) {
MS_LOG(EXCEPTION) << "Outputs number cannot be larger than 1.";
MS_LOG(INFO) << "Outputs number cannot be larger than 1, but " << target_cnode->fullname_with_scope() << " has "
<< root_tensor_info.size() << " outputs.";
}
return root_tensor_info[0];
}
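
The new FindPreviousCareNode is the key piece of this hunk: it walks backwards over ops the parallel pass treats as transparent (Cast, TupleGetItem) until it reaches a node that actually carries operator info, with a depth bound so a malformed graph cannot recurse forever. Below is a simplified, self-contained sketch of that walk, using stand-in node types and an assumed depth bound; the real code uses CNodePtr, IsParallelCareNode and MAX_RECURSIVE_DEPTH as shown above.

#include <memory>
#include <string>

// Stand-ins for illustration; "producer" plays the role of cnode->input(1).
struct Node {
  std::string op;
  std::shared_ptr<Node> producer;
  bool is_parallel_care;
};
using NodePtr = std::shared_ptr<Node>;

constexpr int kMaxRecursiveDepth = 10;  // assumed value of MAX_RECURSIVE_DEPTH

// Walk back over "transparent" ops (Cast, TupleGetItem) until a node the
// parallel pass cares about is reached, giving up after a bounded depth.
NodePtr FindPreviousCareNodeSketch(const NodePtr &current, int depth = 0) {
  if (current == nullptr || depth == kMaxRecursiveDepth) {
    return nullptr;
  }
  NodePtr prev = current->producer;
  if (prev == nullptr) {
    return nullptr;  // e.g. the producer is a Parameter, which is not handled here
  }
  if (!prev->is_parallel_care && (prev->op == "Cast" || prev->op == "TupleGetItem")) {
    return FindPreviousCareNodeSketch(prev, depth + 1);
  }
  return prev;
}

int main() {
  // MatMul (care node) -> Cast -> TupleGetItem -> Shape: starting from Shape,
  // the walk skips TupleGetItem and Cast and lands on MatMul.
  auto matmul = std::make_shared<Node>(Node{"MatMul", nullptr, true});
  auto cast = std::make_shared<Node>(Node{"Cast", matmul, false});
  auto getitem = std::make_shared<Node>(Node{"TupleGetItem", cast, false});
  auto shape = std::make_shared<Node>(Node{"Shape", getitem, false});
  auto care = FindPreviousCareNodeSketch(shape);
  return (care != nullptr && care->op == "MatMul") ? 0 : 1;
}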
@@ -921,7 +934,12 @@ Status UpdateShapeNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
if (shape_user == nullptr) {
continue;
}
MS_EXCEPTION_IF_CHECK_FAIL(IsTupleGetItem(shape_user), "Only support TupleGetItem here.");
if (IsReshapeOp(shape_user)) {
MS_LOG(WARNING) << "Won't supply shape for Reshape.";
continue;
}
MS_EXCEPTION_IF_CHECK_FAIL(IsTupleGetItem(shape_user),
"Only support TupleGetItem here, but got " + GetPrimName(shape_user));
int64_t index = GetTupleGetItemIndex(shape_user);
if (LongToSize(index) >= tensor_map.GetDimSize()) {
MS_LOG(ERROR) << "Index cannot be larger than tensor_map size.";
@@ -944,7 +962,7 @@ Status UpdateShapeNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
next_node.second, // shape_user_user[input_index] = scalar_mul_op
shape_user, // insert scalar_mul_op between previous and current
shape_user_user->func_graph(), // current func_graph
"instance_name", "", nullptr);
"update_partial_shape", "", nullptr);
}
}
return Status::SUCCESS;
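
Connecting this hunk back to the PR title: on a sharded tensor the Shape op observes the partial (per-device) dimension, and the ScalarMul node inserted between the TupleGetItem and its user scales that dimension back up to the full size. The numeric sketch below is an illustration only; in particular, deriving the multiplier from a plain per-dimension shard count is an assumption here, whereas the real code obtains it from the tensor layout.

#include <cstdint>
#include <vector>

// Hypothetical helper for illustration: scale each dimension of a partial
// (per-device) shape by its shard count to recover the full shape, which is
// the effect the inserted ScalarMul nodes have at graph level.
std::vector<int64_t> SupplyFullShape(const std::vector<int64_t> &partial_shape,
                                     const std::vector<int64_t> &shard_counts) {
  std::vector<int64_t> full_shape(partial_shape.size());
  for (size_t i = 0; i < partial_shape.size(); ++i) {
    full_shape[i] = partial_shape[i] * shard_counts[i];
  }
  return full_shape;
}

int main() {
  // Full shape [8, 1024] split 4-way on dim 0: each device sees [2, 1024].
  const std::vector<int64_t> partial_shape = {2, 1024};
  const std::vector<int64_t> shard_counts = {4, 1};
  const auto full_shape = SupplyFullShape(partial_shape, shard_counts);
  return (full_shape[0] == 8 && full_shape[1] == 1024) ? 0 : 1;
}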

@@ -683,8 +683,7 @@ Operator CreateScalarMulOp(int64_t scalar) {
OperatorAttrs operator_attrs;
OperatorParams operator_param;
constexpr size_t parameter_pos = 2;
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scalar);
ValuePtr scale_value = MakeValue(tensor_ptr);
ValuePtr scale_value = MakeValue(std::make_shared<Int64Imm>(scalar));
(void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);

@@ -3101,6 +3101,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// save strategy as checkpoint for multi-train
CheckpointStrategy(all_nodes, root);
if (MergeEntireShapeForDynamic(root) != Status::SUCCESS) {
MS_LOG(ERROR) << "Merge entire shape for dynamic shape failed.";
return false;
}
// ForwardCommunication BackwardCommunication TensorRedistribution
ParallelCommunication(root, all_nodes, manager);