!49786 fix bug of tuple structural attr

Merge pull request !49786 from lianliguang/fix-bug-of-pyexec-dynamic-input
i-robot 2023-03-08 06:18:47 +00:00 committed by Gitee
commit 3f8c7219e6
8 changed files with 184 additions and 75 deletions

View File

@@ -20,6 +20,7 @@
#include "backend/common/pass/convert_list_to_tuple.h"
#include "backend/common/pass/eliminate_func_data_type.h"
#include "backend/common/pass/convert_const_input_to_attr.h"
#include "backend/common/pass/add_input_structural_for_py_execute.h"
#include "backend/common/pass/custom_op_const_input_to_attr.h"
#include "backend/common/pass/custom_op_reg_info_to_attr.h"
#include "backend/common/pass/convert_tuple_output_to_maketuple.h"
@@ -60,6 +61,7 @@ PassManagerPtr GetBackendCommonOptimizationPassManagerPtr(const FuncGraphPtr &gr
}
common_pm->AddPass(std::make_shared<FlattenConcatFission>());
common_pm->AddPass(std::make_shared<AddDropoutAttrs>());
common_pm->AddPass(std::make_shared<AddInputStructuralForPyExecute>());
return common_pm;
}

View File

@@ -833,10 +833,8 @@ std::pair<AbstractBasePtr, int> RectifyAbstractFromStructuralAttr(const ValuePtr
}
}
AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const PrimitivePtr &prim,
AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const ValuePtr &tuple_structural,
const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(prim);
auto tuple_structural = prim->GetAttr(kAttrTupleInputStructural);
if (tuple_structural == nullptr) {
return input_abstract;
}
@@ -858,6 +856,61 @@ AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const PrimitivePtr &
return rectifyed_abs_list;
}
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(prim);
auto dyn_input_list = prim->GetAttr(kAttrDynInputSizes);
if (dyn_input_list == nullptr) {
return input_abstract;
}
AbstractBasePtrList rectifyed_abs_list;
const int kNotDynamicFlag = -1;
auto dynamic_input_index = GetValue<std::vector<int64_t>>(dyn_input_list);
size_t input_index = 0;
for (auto item : dynamic_input_index) {
if (item == kNotDynamicFlag) {
if (input_index >= input_abstract.size()) {
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
<< " is out of range in input abstract " << input_abstract.size();
continue;
}
MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
<< " is out of range in input abstract " << input_abstract.size();
}
(void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
} else {
if (item < 0) {
MS_LOG(EXCEPTION) << "The dynamic input size check error the index should be -1 or positive number but got "
<< item;
}
AbstractBasePtrList dynamic_inputs_abs;
for (auto index = item; index > 0; --index) {
if (input_index >= input_abstract.size()) {
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
<< " is out of range in input abstract " << input_abstract.size();
continue;
}
MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
<< " is out of range in input abstract " << input_abstract.size();
}
(void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
}
(void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
}
}
return rectifyed_abs_list;
}
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &input_abstract) {
auto input_structural = prim->GetAttr(kAttrTupleInputStructural);
if (input_structural != nullptr) {
return RectifyAbstractFromTupleInputStructural(input_structural, input_abstract);
}
return RectifyAbstractFromDynamicInput(prim, input_abstract);
}
} // namespace
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
@@ -1046,7 +1099,7 @@ void CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spe
if (found.has_value()) {
auto infer = found.value();
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-shape implement for backend!");
auto infer_spec_list = RectifyAbstractFromTupleInputStructural(prim_clone, args_spec_list);
auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
if (common::AnfAlgo::IsDynamicSequence(cnode)) {
out_abs = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
} else {
@@ -1087,7 +1140,7 @@ AbstractBasePtr CppInferShapeAndType(const PrimitivePtr &prim, const AbstractBas
if (found.has_value()) {
auto infer = found.value();
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
auto infer_spec_list = RectifyAbstractFromTupleInputStructural(prim_clone, args_spec_list);
auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
auto ret = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
if (prim_clone != prim) {
*prim = *prim_clone;

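For context on the helpers.cc changes above: the new RectifyAbstract entry point prefers kAttrTupleInputStructural and falls back to RectifyAbstractFromDynamicInput, which regroups a flat abstract list according to kAttrDynInputSizes. Below is a minimal, self-contained C++ sketch of that regrouping, using plain strings in place of abstracts; all types and values here are illustrative stand-ins rather than MindSpore APIs, and the out-of-range tolerance for PyExecute is omitted.

// Toy regrouping in the spirit of RectifyAbstractFromDynamicInput:
// an entry of -1 in kAttrDynInputSizes forwards one abstract unchanged,
// while a positive entry n folds the next n abstracts into one tuple.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-ins for the input abstracts of a kernel with one dynamic input.
  const std::vector<std::string> input_abstract = {"a0", "a1", "a2", "a3", "a4"};
  // kAttrDynInputSizes: the second input is dynamic and consumes three abstracts.
  const std::vector<int64_t> dyn_input_sizes = {-1, 3, -1};

  std::vector<std::string> rectified;
  size_t input_index = 0;
  for (int64_t item : dyn_input_sizes) {
    if (item == -1) {
      // Plain input: forwarded as-is.
      rectified.push_back(input_abstract[input_index++]);
    } else {
      // Dynamic input: the next `item` abstracts become a single tuple abstract.
      std::string tuple = "(";
      for (int64_t k = 0; k < item; ++k) {
        tuple += input_abstract[input_index++];
        if (k + 1 < item) tuple += ", ";
      }
      rectified.push_back(tuple + ")");
    }
  }
  for (const auto &abs : rectified) std::cout << abs << '\n';  // a0  (a1, a2, a3)  a4
  return 0;
}

The three rectified entries then line up one-to-one with the primitive's declared inputs, which is what the backend infer functions consume.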
View File

@@ -0,0 +1,68 @@
/**
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/pass/add_input_structural_for_py_execute.h"
#include <memory>
#include <vector>
#include "utils/hash_set.h"
#include "backend/common/optimizer/const_input_to_attr.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
namespace opt {
namespace {
ValuePtr SetInputStructuralFromAbstract(const AbstractBasePtr &abs) {
if (abs->isa<abstract::AbstractSequence>()) {
auto seq_abs = abs->cast_ptr<abstract::AbstractSequence>();
std::vector<ValuePtr> structural;
for (size_t index = 0; index < seq_abs->size(); ++index) {
(void)structural.emplace_back(SetInputStructuralFromAbstract((*seq_abs)[index]));
}
return std::make_shared<ValueTuple>(structural);
}
return MakeValue<int64_t>(-1);
}
} // namespace
const BaseRef AddInputStructuralForPyExecute::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({prim::kPrimPyExecute, Xs});
}
const AnfNodePtr AddInputStructuralForPyExecute::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPyExecute)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (common::AnfAlgo::HasNodeAttr(kAttrTupleInputStructural, cnode)) {
return nullptr;
}
std::vector<ValuePtr> input_structurals;
for (size_t index = 1; index < cnode->size(); ++index) {
auto input_node = cnode->input(index);
auto abstract = input_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractMonad>()) {
(void)input_structurals.emplace_back(SetInputStructuralFromAbstract(abstract));
}
}
auto input_structural = std::make_shared<ValueTuple>(input_structurals);
common::AnfAlgo::SetNodeAttr(kAttrTupleInputStructural, input_structural, cnode);
return nullptr;
}
} // namespace opt
} // namespace mindspore

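The new pass above derives the kAttrTupleInputStructural value purely from input abstracts. A rough standalone sketch of the recursive rule in SetInputStructuralFromAbstract follows (ToyAbstract is an illustrative stand-in, not the MindSpore abstract classes): a non-sequence abstract contributes -1, and a sequence contributes a nested tuple of its elements' descriptors, so an input shaped like (Tensor, (Tensor, Tensor)) yields (-1, (-1, -1)).

// Toy version of the structural descriptor built by SetInputStructuralFromAbstract.
#include <iostream>
#include <string>
#include <vector>

// Stand-in for an abstract: either a leaf (tensor/scalar) or a sequence of abstracts.
struct ToyAbstract {
  bool is_sequence = false;
  std::vector<ToyAbstract> elements;
};

// Returns the textual form of the structural descriptor for one input.
std::string Structural(const ToyAbstract &abs) {
  if (!abs.is_sequence) return "-1";
  std::string out = "(";
  for (size_t i = 0; i < abs.elements.size(); ++i) {
    out += Structural(abs.elements[i]);
    if (i + 1 < abs.elements.size()) out += ", ";
  }
  return out + ")";
}

int main() {
  ToyAbstract tensor;                         // leaf -> "-1"
  ToyAbstract inner{true, {tensor, tensor}};  // (Tensor, Tensor)
  ToyAbstract input{true, {tensor, inner}};   // (Tensor, (Tensor, Tensor))
  std::cout << Structural(input) << '\n';     // prints (-1, (-1, -1))
  return 0;
}

In the real pass, the analogous ValueTuple is collected over every non-monad input and attached to the PyExecute CNode as kAttrTupleInputStructural, which RectifyAbstractFromTupleInputStructural later consumes during backend infer.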
View File

@@ -0,0 +1,34 @@
/**
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_INPUT_STRUCTURAL_OF_PY_EXECUTE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_INPUT_STRUCTURAL_OF_PY_EXECUTE_H_
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
class AddInputStructuralForPyExecute : public PatternProcessPass {
public:
explicit AddInputStructuralForPyExecute(bool multigraph = true)
: PatternProcessPass("insert_input_structural_for_py_execute", multigraph) {}
~AddInputStructuralForPyExecute() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_INPUT_STRUCTURAL_OF_PY_EXECUTE_H_

View File

@@ -1507,38 +1507,37 @@ void UnfoldKernelBuildInfo(const CNodePtr &kernel_node) {
kernel_node);
}
std::tuple<ValuePtr, int64_t, bool> CalOutputTupleSize(const AnfNodePtr &node) {
int64_t CalOutputTupleSize(const AnfNodePtr &node) {
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
bool skip = (is_bprop_cut && node->abstract()->isa<abstract::AbstractSparseTensor>());
if (skip) {
return std::make_tuple(MakeValue<int64_t>(-1), -1, false);
if (skip || !common::AnfAlgo::IsTupleOutput(node)) {
return -1;
}
const auto &real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(real_node);
if (build_info != nullptr) {
auto output_object = AnfAlgo::GetOutputKernelObjectType(real_node, 0);
if (output_object != kernel::KernelObjectType::TUPLE_UNFOLD) {
return std::make_tuple(MakeValue<int64_t>(-1), 1, false);
return -1;
}
}
auto output_size = AnfAlgo::GetOutputElementNum(node);
if (node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
std::vector<ValuePtr> dyn_input_size;
output_size = 0;
auto make_tuple = node->cast<CNodePtr>();
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
int64_t input_dyn_size = 0;
for (size_t j = 0; j < tuple_input_num; ++j) {
// Used for graph kernel.
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
auto [tuple_structual, input_size, is_dyn_input] = CalOutputTupleSize(dyn_input_node);
input_dyn_size += input_size;
(void)dyn_input_size.emplace_back(tuple_structual);
MS_LOG(DEBUG) << "Tuple structural:" << tuple_structual->ToString() << ", input size:" << input_size
<< ", is dyn size:" << is_dyn_input;
// Handle tuple nested scenes.
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
output_size += CalOutputTupleSize(dyn_input_node);
} else {
output_size++;
}
}
return std::make_tuple(std::make_shared<ValueTuple>(dyn_input_size), input_dyn_size, true);
}
output_size = output_size == 0 ? 1 : output_size;
return std::make_tuple(MakeValue<int64_t>(-1), SizeToLong(output_size), true);
return output_size == 0 ? -1 : output_size;
}
void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
@@ -1547,28 +1546,19 @@ void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial)) {
return;
}
std::vector<ValuePtr> tuple_placeholder; // Record Tuple Structural of the node input
std::vector<int64_t> dyn_input_size;
std::vector<int64_t> dyn_input_sizes;
auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
bool is_dyn_input = false;
for (size_t i = 0; i < input_num; ++i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
auto [input_structural, input_size, dyn_input] = CalOutputTupleSize(input_node);
is_dyn_input |= dyn_input;
tuple_placeholder.push_back(input_structural);
(void)dyn_input_size.emplace_back(input_size);
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
dyn_input_sizes.push_back(CalOutputTupleSize(input_node));
} else {
tuple_placeholder.push_back(MakeValue<int64_t>(-1));
(void)dyn_input_size.emplace_back(-1);
dyn_input_sizes.push_back(-1);
}
}
if (is_dyn_input) {
auto dyn_input_attr = std::make_shared<ValueTuple>(tuple_placeholder);
auto prim = GetCNodePrimitive(cnode);
prim->set_attr(kAttrTupleInputStructural, dyn_input_attr);
prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_input_size));
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
}
}

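To make the simplified CalOutputTupleSize above concrete, here is a toy sketch of the counting rule it now feeds into kAttrDynInputSizes: a non-tuple input contributes -1, while a (possibly nested) MakeTuple contributes its flattened element count. Only the MakeTuple path is modelled, and ToyNode is an illustrative stand-in rather than an AnfNodePtr.

// Toy flattened-count rule in the spirit of the new CalOutputTupleSize.
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for a node: either a single-output leaf or a MakeTuple over child nodes.
struct ToyNode {
  bool is_make_tuple = false;
  std::vector<ToyNode> inputs;
};

int64_t CalOutputTupleSize(const ToyNode &node) {
  if (!node.is_make_tuple) return -1;  // not a tuple output
  int64_t size = 0;
  for (const auto &in : node.inputs) {
    // Nested tuples are flattened into the parent's element count.
    size += in.is_make_tuple ? CalOutputTupleSize(in) : 1;
  }
  return size == 0 ? -1 : size;
}

int main() {
  ToyNode leaf;
  ToyNode inner{true, {leaf, leaf}};         // MakeTuple(y, z)
  ToyNode outer{true, {leaf, inner, leaf}};  // MakeTuple(x, MakeTuple(y, z), w)
  std::cout << CalOutputTupleSize(leaf) << ' ' << CalOutputTupleSize(outer) << '\n';  // -1 4
  return 0;
}

SetDynamicInputSizeAttr then records one such value per TUPLE_UNFOLD input and only sets kAttrDynInputSizes when at least one entry is non-negative; the tuple-structural bookkeeping that previously lived here has moved to the new backend pass.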
View File

@@ -457,7 +457,7 @@ BACKEND_EXPORT std::vector<KernelObjectType> TypeIdToKernelObjectTypeForTupleUnf
BACKEND_EXPORT TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type);
KernelObjectType StringToKernelObjectType(const std::string &object_type);
BACKEND_EXPORT void UnfoldKernelBuildInfo(const CNodePtr &kernel_node);
BACKEND_EXPORT std::tuple<ValuePtr, int64_t, bool> CalOutputTupleSize(const AnfNodePtr &node);
BACKEND_EXPORT int64_t CalOutputTupleSize(const AnfNodePtr &node);
BACKEND_EXPORT void SetDynamicInputSizeAttr(const CNodePtr &cnode);
BACKEND_EXPORT bool IsDynamicParamKernel(const std::string &op_name);
BACKEND_EXPORT std::pair<std::string, ExceptionType> KernelObjectTypeNotSupportWarning(const CNodePtr &kernel_node);

View File

@@ -31,18 +31,6 @@ namespace mindspore {
namespace pynative {
namespace PyNativeAlgo {
namespace {
ValuePtr GetInputStructural(const ValuePtr &input) {
if (!input->isa<ValueSequence>()) {
return MakeValue<int64_t>(-1);
}
auto seq = input->cast_ptr<ValueSequence>();
std::vector<ValuePtr> tuple_structural;
for (size_t i = 0; i < seq->size(); ++i) {
(void)tuple_structural.emplace_back(GetInputStructural((*seq)[i]));
}
return std::make_shared<ValueTuple>(tuple_structural);
}
void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
// Clone a new prim
MS_EXCEPTION_IF_NULL(op_run_info);
@@ -800,8 +788,6 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
// Get input tensors.
op_prim->BeginRecordAddAttr();
std::vector<ValuePtr> inputs_structural;
bool is_dyn_input = false;
for (size_t index = 0; index < op_run_info->input_size; ++index) {
const ValuePtr &input_object = op_run_info->input_value[index];
// convert const input to attr
@@ -823,16 +809,8 @@
(void)dyn_v.emplace_back(-1);
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
}
auto input_structural = GetInputStructural(input_object);
is_dyn_input = true;
(void)inputs_structural.emplace_back(input_structural);
} else {
(void)inputs_structural.emplace_back(MakeValue<int64_t>(-1));
}
}
if (is_dyn_input) {
op_prim->set_attr(kAttrTupleInputStructural, std::make_shared<ValueTuple>(inputs_structural));
}
op_prim->EndRecordAddAttr();
ReplaceValueNodeWithParameter(op_run_info, device_target);
ReplaceReduceAxis(op_run_info);

View File

@@ -1141,27 +1141,13 @@ void SetDynamicInputSizeAttrBeforeKernelSelect(const CNodePtr &cnode) {
return;
}
std::vector<int64_t> dyn_input_sizes;
std::vector<ValuePtr> inputs_structural;
size_t input_num = cnode->inputs().size() - 1;
bool is_dyn_input = false;
for (size_t i = 0; i < input_num; ++i) {
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
// Ascend uses the abstract to judge whether a node input is a dynamic input.
// GPU and CPU use the KernelObjectType to judge whether a node input is a dynamic input.
if (common::AnfAlgo::IsTupleOutput(input_node)) {
auto [input_structural, input_size, dyn_input] = kernel::CalOutputTupleSize(input_node);
is_dyn_input |= dyn_input;
dyn_input_sizes.push_back(input_size);
(void)inputs_structural.emplace_back(input_structural);
} else {
is_dyn_input |= false;
dyn_input_sizes.push_back(-1);
(void)inputs_structural.emplace_back(MakeValue<int64_t>(-1));
}
dyn_input_sizes.emplace_back(kernel::CalOutputTupleSize(input_node));
}
if (is_dyn_input) {
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
common::AnfAlgo::SetNodeAttr(kAttrTupleInputStructural, std::make_shared<ValueTuple>(inputs_structural), cnode);
}
}
@@ -1183,7 +1169,6 @@ void RefreshDynamicInputSizeAttr(const CNodePtr &cnode) {
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
} else {
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, cnode);
common::AnfAlgo::EraseNodeAttr(kAttrTupleInputStructural, cnode);
}
}
@@ -1389,7 +1374,6 @@ void HandleKernelSelectFailure(const KernelGraphPtr &graph, const CNodePtr &node
// and make wrong choose, for example, the TupleToTensor op
if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node)) {
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, node);
common::AnfAlgo::EraseNodeAttr(kAttrTupleInputStructural, node);
}
auto [cpu_msg, cpu_etype] = device::cpu::SetKernelInfoWithMsg(node);
if (cpu_msg.empty()) {