!67876 Supply partial shape to full shape
Merge pull request !67876 from 刘崇鸣/supply_partial_shape_to_full_shape
Commit: cb8fb912a2
@@ -446,3 +446,4 @@ mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/nn_pooling_ops_proto.cc:ge::IMPLEMT_COMMON_INFERFUNC
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/customize/op_proto/nn_norm_ops_proto.cc:ge::CUST_IMPLEMT_INFERFUNC
 mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/cpu_kernel/ms_kernel/cast.cc:aicpu::CastKernel::Compute
+mindspore/mindspore/ccsrc/frontend/parallel/step_parallel.cc:mindspore::parallel::StepParallel
@@ -54,14 +54,7 @@ CNodePtr CreateShape(const AnfNodePtr &pre_cnode, const FuncGraphPtr &func_graph
   return shape_cnode;
 }
 
-bool IsTargetOp(const CNodePtr &cnode, const std::string &target) {
-  RETURN_IF_FALSE(cnode != nullptr);
-  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
-  RETURN_IF_FALSE(value_node != nullptr);
-  auto prim = value_node->value()->cast<PrimitivePtr>();
-  RETURN_IF_FALSE(prim != nullptr);
-  return prim->name() == target;
-}
+inline bool IsTargetOp(const CNodePtr &cnode, const std::string &target) { return GetPrimName(cnode) == target; }
 
 bool IsTupleGetItem(const CNodePtr &cnode) { return IsTargetOp(cnode, TUPLE_GETITEM_OP); }
 
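The refactor drops the hand-rolled primitive unwrapping and reduces the helper to a single name comparison through GetPrimName. A rough standalone illustration of that pattern follows; the `Node` type and the `GetPrimName` stub below are hypothetical stand-ins, not MindSpore's CNode/Primitive API:

```cpp
#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-in for a CNode: a node that records the name of the
// primitive it calls.
struct Node {
  std::string prim_name;
};

// Analogue of GetPrimName: empty string when no primitive is attached.
std::string GetPrimName(const std::shared_ptr<Node> &node) {
  return node ? node->prim_name : "";
}

// The refactored helper reduces to one name comparison.
inline bool IsTargetOp(const std::shared_ptr<Node> &node, const std::string &target) {
  return GetPrimName(node) == target;
}

int main() {
  auto getitem = std::make_shared<Node>(Node{"TupleGetItem"});
  std::cout << std::boolalpha << IsTargetOp(getitem, "TupleGetItem") << '\n';  // true
  std::cout << IsTargetOp(getitem, "Cast") << '\n';                            // false
  return 0;
}
```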
@@ -497,6 +490,10 @@ Status ConvertReshapeInputs(const OperatorParams &params,
   }
   Shape shape_vec = GetValue<Shape>(param.first.second);
   MS_LOG(INFO) << "shape param = " << shape_vec;
+  size_t dynamic_axis_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
+  if (shape_vec.size() > 1 && dynamic_axis_cnt >= SIZE_TWO) {
+    MS_LOG(EXCEPTION) << "The shape of Reshape op has more than one -1, cannot be supported for now.";
+  }
   if (!WhetherMatchingIsNeededForReshape(shape_vec, tensor_redistribution_from_cnode)) {
     MS_LOG(INFO) << "No need to matching for " << shape_vec;
     AnfNodePtr val = NewValueNode(param.first.second);
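The added guard rejects Reshape target shapes containing more than one -1, because only a single unknown axis can be recovered from the total element count. A minimal sketch of that inference rule, assuming a plain std::vector<int64_t> in place of the Shape alias and a hypothetical InferDynamicDim helper:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Infer the single -1 entry of a reshape target from the total element count.
// More than one -1 is ambiguous, which is exactly what the new check rejects.
std::vector<int64_t> InferDynamicDim(std::vector<int64_t> shape, int64_t total_elements) {
  const size_t dynamic_axis_cnt = std::count(shape.begin(), shape.end(), -1);
  if (dynamic_axis_cnt >= 2) {
    throw std::invalid_argument("The shape has more than one -1, cannot be inferred.");
  }
  if (dynamic_axis_cnt == 1) {
    int64_t known = 1;
    for (int64_t d : shape) {
      if (d != -1) known *= d;
    }
    std::replace(shape.begin(), shape.end(), int64_t{-1}, total_elements / known);
  }
  return shape;
}

int main() {
  for (int64_t d : InferDynamicDim({-1, 4}, 32)) std::cout << d << ' ';  // 8 4
  std::cout << '\n';
  try {
    InferDynamicDim({-1, -1, 2}, 32);  // two placeholders: rejected
  } catch (const std::exception &e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}
```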
@@ -880,23 +877,39 @@ Status UpdatePartialShape(const CNodePtr &cnode) {
   return Status::SUCCESS;
 }
 
+CNodePtr FindPreviousCareNode(const CNodePtr &current, int32_t depth = 0) {
+  if (depth == MAX_RECURSIVE_DEPTH) {
+    return nullptr;
+  }
+  auto prev = current->input(1);
+  // If prev is parameter maybe problem here.
+  auto cnode = prev->cast<CNodePtr>();
+  MS_EXCEPTION_IF_CHECK_FAIL(cnode != nullptr, "Input of node is parameter is not supported.");
+  if (!IsParallelCareNode(cnode) && (IsTargetOp(cnode, "Cast") || IsTupleGetItem(cnode))) {
+    return FindPreviousCareNode(cnode, depth + 1);
+  }
+  return cnode;
+}
+
 TensorInfo GetDistributeOperatorFromCNode(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
   CNodePtr target_cnode = cnode;
-  if (IsTupleGetItem(cnode)) {
+  if (!IsParallelCareNode(cnode)) {
     // keep search the previous node.
-    auto prev_node = FindPreviousNodeAndSkipTupleGetItem(cnode);
-    target_cnode = prev_node.first;
+    target_cnode = FindPreviousCareNode(cnode);
   }
   MS_EXCEPTION_IF_NULL(target_cnode);
   if (!target_cnode->has_user_data<OperatorInfo>()) {
-    MS_LOG(EXCEPTION) << target_cnode->fullname_with_scope() << " has no operator info.";
+    MS_LOG(EXCEPTION) << "Found " << cnode->fullname_with_scope() << " previous node is "
+                      << target_cnode->fullname_with_scope() << " and it has no operator info.";
   }
 
   OperatorInfoPtr distribute_operator = GetDistributeOperator(target_cnode);
   MS_EXCEPTION_IF_NULL(distribute_operator);
   std::vector<TensorInfo> root_tensor_info = distribute_operator->outputs_tensor_info();
   if (root_tensor_info.size() != 1) {
-    MS_LOG(EXCEPTION) << "Outputs number cannot be larger than 1.";
+    MS_LOG(INFO) << "Outputs number cannot be larger than 1, but " << target_cnode->fullname_with_scope() << " has "
+                 << root_tensor_info.size() << " outputs.";
   }
   return root_tensor_info[0];
 }
 
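The new FindPreviousCareNode walks backwards over nodes that carry no operator info of their own (Cast, TupleGetItem) until it reaches a parallel-care node, and gives up at MAX_RECURSIVE_DEPTH. A standalone sketch of that depth-limited walk-back, with a hypothetical Node type standing in for CNode/OperatorInfo:

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

constexpr int32_t kMaxRecursiveDepth = 3;

// Hypothetical stand-in for a CNode: an op name, a flag for "carries operator
// info" (a parallel-care node), and a single data input.
struct Node {
  std::string name;
  bool has_operator_info = false;
  std::shared_ptr<Node> input;
};
using NodePtr = std::shared_ptr<Node>;

bool IsPassThrough(const NodePtr &node) {
  return node->name == "Cast" || node->name == "TupleGetItem";
}

// Walk back over pass-through nodes until one with operator info is found,
// giving up after kMaxRecursiveDepth hops (mirrors the MAX_RECURSIVE_DEPTH cap).
NodePtr FindPreviousCareNode(const NodePtr &current, int32_t depth = 0) {
  if (depth == kMaxRecursiveDepth || current->input == nullptr) {
    return nullptr;
  }
  NodePtr prev = current->input;
  if (!prev->has_operator_info && IsPassThrough(prev)) {
    return FindPreviousCareNode(prev, depth + 1);
  }
  return prev;
}

int main() {
  auto matmul = std::make_shared<Node>(Node{"MatMul", true, nullptr});
  auto cast = std::make_shared<Node>(Node{"Cast", false, matmul});
  auto getitem = std::make_shared<Node>(Node{"TupleGetItem", false, cast});
  auto shape = std::make_shared<Node>(Node{"Shape", false, getitem});
  NodePtr care = FindPreviousCareNode(shape);
  std::cout << (care ? care->name : "not found") << '\n';  // MatMul
  return 0;
}
```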
@@ -921,7 +934,12 @@ Status UpdateShapeNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
     if (shape_user == nullptr) {
       continue;
     }
-    MS_EXCEPTION_IF_CHECK_FAIL(IsTupleGetItem(shape_user), "Only support TupleGetItem here.");
+    if (IsReshapeOp(shape_user)) {
+      MS_LOG(WARNING) << "Won't supply shape for Reshape.";
+      continue;
+    }
+    MS_EXCEPTION_IF_CHECK_FAIL(IsTupleGetItem(shape_user),
+                               "Only support TupleGetItem here, but got " + GetPrimName(shape_user));
     int64_t index = GetTupleGetItemIndex(shape_user);
     if (LongToSize(index) >= tensor_map.GetDimSize()) {
       MS_LOG(ERROR) << "Index cannot be larger than tensor_map size.";
@@ -944,7 +962,7 @@ Status UpdateShapeNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
                            next_node.second,               // shape_user_user[input_index] = scalar_mul_op
                            shape_user,                     // insert scalar_mul_op between previous and current
                            shape_user_user->func_graph(),  // current func_graph
-                           "instance_name", "", nullptr);
+                           "update_partial_shape", "", nullptr);
     }
   }
   return Status::SUCCESS;
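The renamed instance ("update_partial_shape") makes the inserted node easier to trace. The surrounding logic inserts a ScalarMul between the TupleGetItem on the Shape output and its user; reading the PR title, the intent is presumably to scale a sharded (partial) dimension back up to the full dimension. A rough standalone sketch of that idea, where the shard counts and the SupplyFullShape helper are illustrative assumptions, not code from this PR:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: given the shape of the local (sharded) tensor slice and
// the number of shards along each axis, recover the full shape by multiplying
// each sharded dimension, i.e. the role of the inserted scalar_mul_op.
std::vector<int64_t> SupplyFullShape(const std::vector<int64_t> &partial_shape,
                                     const std::vector<int64_t> &shards_per_axis) {
  std::vector<int64_t> full_shape(partial_shape.size());
  for (size_t i = 0; i < partial_shape.size(); ++i) {
    full_shape[i] = partial_shape[i] * shards_per_axis[i];
  }
  return full_shape;
}

int main() {
  // A [4, 1024] slice sharded 8 ways on axis 0 corresponds to a [32, 1024] tensor.
  for (int64_t d : SupplyFullShape({4, 1024}, {8, 1})) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}
```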
@@ -683,8 +683,7 @@ Operator CreateScalarMulOp(int64_t scalar) {
   OperatorAttrs operator_attrs;
   OperatorParams operator_param;
   constexpr size_t parameter_pos = 2;
-  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scalar);
-  ValuePtr scale_value = MakeValue(tensor_ptr);
+  ValuePtr scale_value = MakeValue(std::make_shared<Int64Imm>(scalar));
   (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
 
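The only change here is the scalar's encoding: an Int64 immediate instead of a 0-d Tensor wrapped through MakeValue. For orientation, a simplified standalone mock of the (attrs, params) packaging that CreateScalarMulOp builds; the Param/Attrs/Args aliases are illustrative stand-ins, not MindSpore's OperatorArgs types:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mirror of the OperatorParams/OperatorArgs layout above:
// each parameter is ((name, value), input position), and the operator bundles
// attributes with those positional parameters.
using Param = std::pair<std::pair<std::string, int64_t>, size_t>;
using Attrs = std::vector<std::pair<std::string, std::string>>;
using Args = std::pair<Attrs, std::vector<Param>>;

Args CreateScalarMulArgs(int64_t scalar) {
  Attrs attrs;                         // no attributes needed for a plain multiply
  std::vector<Param> params;
  constexpr size_t parameter_pos = 2;  // scalar feeds input slot 2, as in the diff
  params.emplace_back(std::make_pair(std::make_pair("y", scalar), parameter_pos));
  return std::make_pair(attrs, params);
}

int main() {
  Args args = CreateScalarMulArgs(8);
  std::cout << args.second[0].first.first << " = " << args.second[0].first.second
            << " at input " << args.second[0].second << '\n';  // y = 8 at input 2
  return 0;
}
```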
@@ -3101,6 +3101,11 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   // save strategy as checkpoint for multi-train
   CheckpointStrategy(all_nodes, root);
 
+  if (MergeEntireShapeForDynamic(root) != Status::SUCCESS) {
+    MS_LOG(ERROR) << "Merge entire shape for dynamic shape failed.";
+    return false;
+  }
+
   // ForwardCommunication BackwardCommunication TensorRedistribution
   ParallelCommunication(root, all_nodes, manager);
 
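The new pass is wired into StepParallel right after strategies are checkpointed, and a failure aborts the whole step before the communication passes run. A minimal sketch of that fail-fast status pattern, with a hypothetical Status enum and a stubbed MergeEntireShapeForDynamic:

```cpp
#include <iostream>

enum class Status { SUCCESS, FAILED };

// Hypothetical stub for the new pass: in the real code it rewrites Shape users
// so that partial (sharded) shapes are supplied as full shapes.
Status MergeEntireShapeForDynamic() { return Status::SUCCESS; }

// Mirrors the fail-fast wiring in StepParallel: log and bail out on failure so
// later passes never see a half-rewritten graph.
bool StepParallelSketch() {
  if (MergeEntireShapeForDynamic() != Status::SUCCESS) {
    std::cerr << "Merge entire shape for dynamic shape failed.\n";
    return false;
  }
  // ... ForwardCommunication / BackwardCommunication / TensorRedistribution ...
  return true;
}

int main() { std::cout << std::boolalpha << StepParallelSketch() << '\n'; }
```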