forked from mindspore-Ecosystem/mindspore
!49786 fix bug of tuple structual attr
Merge pull request !49786 from lianliguang/fix-bug-of-pyexec-dynamic-input
This commit is contained in:
commit
3f8c7219e6
|
@ -20,6 +20,7 @@
|
||||||
#include "backend/common/pass/convert_list_to_tuple.h"
|
#include "backend/common/pass/convert_list_to_tuple.h"
|
||||||
#include "backend/common/pass/eliminate_func_data_type.h"
|
#include "backend/common/pass/eliminate_func_data_type.h"
|
||||||
#include "backend/common/pass/convert_const_input_to_attr.h"
|
#include "backend/common/pass/convert_const_input_to_attr.h"
|
||||||
|
#include "backend/common/pass/add_input_structural_for_py_execute.h"
|
||||||
#include "backend/common/pass/custom_op_const_input_to_attr.h"
|
#include "backend/common/pass/custom_op_const_input_to_attr.h"
|
||||||
#include "backend/common/pass/custom_op_reg_info_to_attr.h"
|
#include "backend/common/pass/custom_op_reg_info_to_attr.h"
|
||||||
#include "backend/common/pass/convert_tuple_output_to_maketuple.h"
|
#include "backend/common/pass/convert_tuple_output_to_maketuple.h"
|
||||||
|
@ -60,6 +61,7 @@ PassManagerPtr GetBackendCommonOptimizationPassManagerPtr(const FuncGraphPtr &gr
|
||||||
}
|
}
|
||||||
common_pm->AddPass(std::make_shared<FlattenConcatFission>());
|
common_pm->AddPass(std::make_shared<FlattenConcatFission>());
|
||||||
common_pm->AddPass(std::make_shared<AddDropoutAttrs>());
|
common_pm->AddPass(std::make_shared<AddDropoutAttrs>());
|
||||||
|
common_pm->AddPass(std::make_shared<AddInputStructuralForPyExecute>());
|
||||||
return common_pm;
|
return common_pm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -833,10 +833,8 @@ std::pair<AbstractBasePtr, int> RectifyAbstractFromStructuralAttr(const ValuePtr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const PrimitivePtr &prim,
|
AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const ValuePtr &tuple_structural,
|
||||||
const AbstractBasePtrList &input_abstract) {
|
const AbstractBasePtrList &input_abstract) {
|
||||||
MS_EXCEPTION_IF_NULL(prim);
|
|
||||||
auto tuple_structural = prim->GetAttr(kAttrTupleInputStructural);
|
|
||||||
if (tuple_structural == nullptr) {
|
if (tuple_structural == nullptr) {
|
||||||
return input_abstract;
|
return input_abstract;
|
||||||
}
|
}
|
||||||
|
@ -858,6 +856,61 @@ AbstractBasePtrList RectifyAbstractFromTupleInputStructural(const PrimitivePtr &
|
||||||
|
|
||||||
return rectifyed_abs_list;
|
return rectifyed_abs_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
|
||||||
|
const AbstractBasePtrList &input_abstract) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto dyn_input_list = prim->GetAttr(kAttrDynInputSizes);
|
||||||
|
if (dyn_input_list == nullptr) {
|
||||||
|
return input_abstract;
|
||||||
|
}
|
||||||
|
AbstractBasePtrList rectifyed_abs_list;
|
||||||
|
const int kNotDynamicFlag = -1;
|
||||||
|
auto dynamic_input_index = GetValue<std::vector<int64_t>>(dyn_input_list);
|
||||||
|
size_t input_index = 0;
|
||||||
|
for (auto item : dynamic_input_index) {
|
||||||
|
if (item == kNotDynamicFlag) {
|
||||||
|
if (input_index >= input_abstract.size()) {
|
||||||
|
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
|
||||||
|
MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
|
||||||
|
<< " is out of range in input abstract " << input_abstract.size();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
|
||||||
|
<< " is out of range in input abstract " << input_abstract.size();
|
||||||
|
}
|
||||||
|
(void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
|
||||||
|
} else {
|
||||||
|
if (item < 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "The dynamic input size check error the index should be -1 or positive number but got "
|
||||||
|
<< item;
|
||||||
|
}
|
||||||
|
AbstractBasePtrList dynamic_inputs_abs;
|
||||||
|
for (auto index = item; index > 0; --index) {
|
||||||
|
if (input_index >= input_abstract.size()) {
|
||||||
|
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
|
||||||
|
MS_LOG(WARNING) << "For primitive \'PyExecute\', index " << input_index
|
||||||
|
<< " is out of range in input abstract " << input_abstract.size();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "For primitive \'" << prim->name() << "\', index " << input_index
|
||||||
|
<< " is out of range in input abstract " << input_abstract.size();
|
||||||
|
}
|
||||||
|
(void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
|
||||||
|
}
|
||||||
|
(void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rectifyed_abs_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtrList RectifyAbstract(const PrimitivePtr &prim, const AbstractBasePtrList &input_abstract) {
|
||||||
|
auto input_structural = prim->GetAttr(kAttrTupleInputStructural);
|
||||||
|
if (input_structural != nullptr) {
|
||||||
|
return RectifyAbstractFromTupleInputStructural(input_structural, input_abstract);
|
||||||
|
}
|
||||||
|
return RectifyAbstractFromDynamicInput(prim, input_abstract);
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||||
|
@ -1046,7 +1099,7 @@ void CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spe
|
||||||
if (found.has_value()) {
|
if (found.has_value()) {
|
||||||
auto infer = found.value();
|
auto infer = found.value();
|
||||||
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-shape implement for backend!");
|
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-shape implement for backend!");
|
||||||
auto infer_spec_list = RectifyAbstractFromTupleInputStructural(prim_clone, args_spec_list);
|
auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
|
||||||
if (common::AnfAlgo::IsDynamicSequence(cnode)) {
|
if (common::AnfAlgo::IsDynamicSequence(cnode)) {
|
||||||
out_abs = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
|
out_abs = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
|
||||||
} else {
|
} else {
|
||||||
|
@ -1087,7 +1140,7 @@ AbstractBasePtr CppInferShapeAndType(const PrimitivePtr &prim, const AbstractBas
|
||||||
if (found.has_value()) {
|
if (found.has_value()) {
|
||||||
auto infer = found.value();
|
auto infer = found.value();
|
||||||
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
|
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
|
||||||
auto infer_spec_list = RectifyAbstractFromTupleInputStructural(prim_clone, args_spec_list);
|
auto infer_spec_list = RectifyAbstract(prim_clone, args_spec_list);
|
||||||
auto ret = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
|
auto ret = infer.InferShapeAndType(nullptr, prim_clone, infer_spec_list);
|
||||||
if (prim_clone != prim) {
|
if (prim_clone != prim) {
|
||||||
*prim = *prim_clone;
|
*prim = *prim_clone;
|
||||||
|
|
|
@ -0,0 +1,68 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/common/pass/add_input_structural_for_py_execute.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "utils/hash_set.h"
|
||||||
|
#include "backend/common/optimizer/const_input_to_attr.h"
|
||||||
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
ValuePtr SetInputStructuralFromAbstract(const AbstractBasePtr &abs) {
|
||||||
|
if (abs->isa<abstract::AbstractSequence>()) {
|
||||||
|
auto seq_abs = abs->cast_ptr<abstract::AbstractSequence>();
|
||||||
|
std::vector<ValuePtr> structural;
|
||||||
|
for (size_t index = 0; index < seq_abs->size(); ++index) {
|
||||||
|
(void)structural.emplace_back(SetInputStructuralFromAbstract((*seq_abs)[index]));
|
||||||
|
}
|
||||||
|
return std::make_shared<ValueTuple>(structural);
|
||||||
|
}
|
||||||
|
return MakeValue<int64_t>(-1);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
const BaseRef AddInputStructuralForPyExecute::DefinePattern() const {
|
||||||
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
|
return VectorRef({prim::kPrimPyExecute, Xs});
|
||||||
|
}
|
||||||
|
const AnfNodePtr AddInputStructuralForPyExecute::Process(const FuncGraphPtr &, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPyExecute)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
if (common::AnfAlgo::HasNodeAttr(kAttrTupleInputStructural, cnode)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<ValuePtr> input_structurals;
|
||||||
|
for (size_t index = 1; index < cnode->size(); ++index) {
|
||||||
|
auto input_node = cnode->input(index);
|
||||||
|
auto abstract = input_node->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
if (!abstract->isa<abstract::AbstractMonad>()) {
|
||||||
|
(void)input_structurals.emplace_back(SetInputStructuralFromAbstract(abstract));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto input_structural = std::make_shared<ValueTuple>(input_structurals);
|
||||||
|
common::AnfAlgo::SetNodeAttr(kAttrTupleInputStructural, input_structural, cnode);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_INPUT_STRUCTURAL_OF_PY_EXECUTE_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_INPUT_STRUCTURAL_OF_PY_EXECUTE_H_
|
||||||
|
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "backend/common/optimizer/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class AddInputStructuralForPyExecute : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit AddInputStructuralForPyExecute(bool multigraph = true)
|
||||||
|
: PatternProcessPass("inset_input_structural_for_py_execute", multigraph) {}
|
||||||
|
~AddInputStructuralForPyExecute() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CUSTOM_OP_CONST_INPUT_TO_ATTR_H_
|
|
@ -1507,38 +1507,37 @@ void UnfoldKernelBuildInfo(const CNodePtr &kernel_node) {
|
||||||
kernel_node);
|
kernel_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<ValuePtr, int64_t, bool> CalOutputTupleSize(const AnfNodePtr &node) {
|
int64_t CalOutputTupleSize(const AnfNodePtr &node) {
|
||||||
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
|
bool is_bprop_cut = common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimBpropCut);
|
||||||
bool skip = (is_bprop_cut && node->abstract()->isa<abstract::AbstractSparseTensor>());
|
bool skip = (is_bprop_cut && node->abstract()->isa<abstract::AbstractSparseTensor>());
|
||||||
if (skip) {
|
if (skip || !common::AnfAlgo::IsTupleOutput(node)) {
|
||||||
return std::make_tuple(MakeValue<int64_t>(-1), -1, false);
|
return -1;
|
||||||
}
|
}
|
||||||
const auto &real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
const auto &real_node = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
|
||||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(real_node);
|
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(real_node);
|
||||||
if (build_info != nullptr) {
|
if (build_info != nullptr) {
|
||||||
auto output_object = AnfAlgo::GetOutputKernelObjectType(real_node, 0);
|
auto output_object = AnfAlgo::GetOutputKernelObjectType(real_node, 0);
|
||||||
if (output_object != kernel::KernelObjectType::TUPLE_UNFOLD) {
|
if (output_object != kernel::KernelObjectType::TUPLE_UNFOLD) {
|
||||||
return std::make_tuple(MakeValue<int64_t>(-1), 1, false);
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto output_size = AnfAlgo::GetOutputElementNum(node);
|
auto output_size = AnfAlgo::GetOutputElementNum(node);
|
||||||
if (node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
if (node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||||
std::vector<ValuePtr> dyn_input_size;
|
output_size = 0;
|
||||||
auto make_tuple = node->cast<CNodePtr>();
|
auto make_tuple = node->cast<CNodePtr>();
|
||||||
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
|
size_t tuple_input_num = common::AnfAlgo::GetInputTensorNum(make_tuple);
|
||||||
int64_t input_dyn_size = 0;
|
|
||||||
for (size_t j = 0; j < tuple_input_num; ++j) {
|
for (size_t j = 0; j < tuple_input_num; ++j) {
|
||||||
|
// using for graph kernel
|
||||||
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
|
auto dyn_input_node = common::AnfAlgo::GetInputNode(make_tuple, j);
|
||||||
auto [tuple_structual, input_size, is_dyn_input] = CalOutputTupleSize(dyn_input_node);
|
// Handle tuple nested scenes.
|
||||||
input_dyn_size += input_size;
|
if (dyn_input_node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(dyn_input_node, prim::kPrimMakeTuple)) {
|
||||||
(void)dyn_input_size.emplace_back(tuple_structual);
|
output_size += CalOutputTupleSize(dyn_input_node);
|
||||||
MS_LOG(DEBUG) << "Tuple structural:" << tuple_structual->ToString() << ", input size:" << input_size
|
} else {
|
||||||
<< ", is dyn size:" << is_dyn_input;
|
output_size++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::make_tuple(std::make_shared<ValueTuple>(dyn_input_size), input_dyn_size, true);
|
|
||||||
}
|
}
|
||||||
output_size = output_size == 0 ? 1 : output_size;
|
return output_size == 0 ? -1 : output_size;
|
||||||
return std::make_tuple(MakeValue<int64_t>(-1), SizeToLong(output_size), true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
|
void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
|
||||||
|
@ -1547,28 +1546,19 @@ void SetDynamicInputSizeAttr(const CNodePtr &cnode) {
|
||||||
common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial)) {
|
common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<ValuePtr> tuple_placeholder; // Record Tuple Structural of the node input
|
std::vector<int64_t> dyn_input_sizes;
|
||||||
std::vector<int64_t> dyn_input_size;
|
|
||||||
auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
|
auto input_obj_types = AnfAlgo::GetInputKernelObjectTypes(cnode);
|
||||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||||
bool is_dyn_input = false;
|
|
||||||
for (size_t i = 0; i < input_num; ++i) {
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
|
|
||||||
if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
|
if (i < input_obj_types.size() && input_obj_types[i] == kernel::KernelObjectType::TUPLE_UNFOLD) {
|
||||||
auto [input_structural, input_size, dyn_input] = CalOutputTupleSize(input_node);
|
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
|
||||||
is_dyn_input |= dyn_input;
|
dyn_input_sizes.push_back(CalOutputTupleSize(input_node));
|
||||||
tuple_placeholder.push_back(input_structural);
|
|
||||||
(void)dyn_input_size.emplace_back(input_size);
|
|
||||||
} else {
|
} else {
|
||||||
tuple_placeholder.push_back(MakeValue<int64_t>(-1));
|
dyn_input_sizes.push_back(-1);
|
||||||
(void)dyn_input_size.emplace_back(-1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (is_dyn_input) {
|
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
|
||||||
auto dyn_input_attr = std::make_shared<ValueTuple>(tuple_placeholder);
|
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
||||||
auto prim = GetCNodePrimitive(cnode);
|
|
||||||
prim->set_attr(kAttrTupleInputStructural, dyn_input_attr);
|
|
||||||
prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_input_size));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -457,7 +457,7 @@ BACKEND_EXPORT std::vector<KernelObjectType> TypeIdToKernelObjectTypeForTupleUnf
|
||||||
BACKEND_EXPORT TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type);
|
BACKEND_EXPORT TypeId KernelObjectTypeToTypeId(const KernelObjectType &object_type);
|
||||||
KernelObjectType StringToKernelObjectType(const std::string &object_type);
|
KernelObjectType StringToKernelObjectType(const std::string &object_type);
|
||||||
BACKEND_EXPORT void UnfoldKernelBuildInfo(const CNodePtr &kernel_node);
|
BACKEND_EXPORT void UnfoldKernelBuildInfo(const CNodePtr &kernel_node);
|
||||||
BACKEND_EXPORT std::tuple<ValuePtr, int64_t, bool> CalOutputTupleSize(const AnfNodePtr &node);
|
BACKEND_EXPORT int64_t CalOutputTupleSize(const AnfNodePtr &node);
|
||||||
BACKEND_EXPORT void SetDynamicInputSizeAttr(const CNodePtr &cnode);
|
BACKEND_EXPORT void SetDynamicInputSizeAttr(const CNodePtr &cnode);
|
||||||
BACKEND_EXPORT bool IsDynamicParamKernel(const std::string &op_name);
|
BACKEND_EXPORT bool IsDynamicParamKernel(const std::string &op_name);
|
||||||
BACKEND_EXPORT std::pair<std::string, ExceptionType> KernelObjectTypeNotSupportWarning(const CNodePtr &kernel_node);
|
BACKEND_EXPORT std::pair<std::string, ExceptionType> KernelObjectTypeNotSupportWarning(const CNodePtr &kernel_node);
|
||||||
|
|
|
@ -31,18 +31,6 @@ namespace mindspore {
|
||||||
namespace pynative {
|
namespace pynative {
|
||||||
namespace PyNativeAlgo {
|
namespace PyNativeAlgo {
|
||||||
namespace {
|
namespace {
|
||||||
ValuePtr GetInputStructural(const ValuePtr &input) {
|
|
||||||
if (!input->isa<ValueSequence>()) {
|
|
||||||
return MakeValue<int64_t>(-1);
|
|
||||||
}
|
|
||||||
auto seq = input->cast_ptr<ValueSequence>();
|
|
||||||
std::vector<ValuePtr> tuple_structural;
|
|
||||||
for (size_t i = 0; i < seq->size(); ++i) {
|
|
||||||
(void)tuple_structural.emplace_back(GetInputStructural((*seq)[i]));
|
|
||||||
}
|
|
||||||
return std::make_shared<ValueTuple>(tuple_structural);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
|
void ClonePrim(const FrontendOpRunInfoPtr &op_run_info) {
|
||||||
// Clone a new prim
|
// Clone a new prim
|
||||||
MS_EXCEPTION_IF_NULL(op_run_info);
|
MS_EXCEPTION_IF_NULL(op_run_info);
|
||||||
|
@ -800,8 +788,6 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
|
||||||
|
|
||||||
// Get input tensors.
|
// Get input tensors.
|
||||||
op_prim->BeginRecordAddAttr();
|
op_prim->BeginRecordAddAttr();
|
||||||
std::vector<ValuePtr> inputs_structural;
|
|
||||||
bool is_dyn_input = false;
|
|
||||||
for (size_t index = 0; index < op_run_info->input_size; ++index) {
|
for (size_t index = 0; index < op_run_info->input_size; ++index) {
|
||||||
const ValuePtr &input_object = op_run_info->input_value[index];
|
const ValuePtr &input_object = op_run_info->input_value[index];
|
||||||
// convert const input to attr
|
// convert const input to attr
|
||||||
|
@ -823,16 +809,8 @@ void DataConvert::GetInputTensor(const FrontendOpRunInfoPtr &op_run_info, const
|
||||||
(void)dyn_v.emplace_back(-1);
|
(void)dyn_v.emplace_back(-1);
|
||||||
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
|
op_prim->set_attr(kAttrDynInputSizes, MakeValue(dyn_v));
|
||||||
}
|
}
|
||||||
auto input_structural = GetInputStructural(input_object);
|
|
||||||
is_dyn_input = true;
|
|
||||||
(void)inputs_structural.emplace_back(input_structural);
|
|
||||||
} else {
|
|
||||||
(void)inputs_structural.emplace_back(MakeValue<int64_t>(-1));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (is_dyn_input) {
|
|
||||||
op_prim->set_attr(kAttrTupleInputStructural, std::make_shared<ValueTuple>(inputs_structural));
|
|
||||||
}
|
|
||||||
op_prim->EndRecordAddAttr();
|
op_prim->EndRecordAddAttr();
|
||||||
ReplaceValueNodeWithParameter(op_run_info, device_target);
|
ReplaceValueNodeWithParameter(op_run_info, device_target);
|
||||||
ReplaceReduceAxis(op_run_info);
|
ReplaceReduceAxis(op_run_info);
|
||||||
|
|
|
@ -1141,27 +1141,13 @@ void SetDynamicInputSizeAttrBeforeKernelSelect(const CNodePtr &cnode) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
std::vector<int64_t> dyn_input_sizes;
|
std::vector<int64_t> dyn_input_sizes;
|
||||||
std::vector<ValuePtr> inputs_structural;
|
|
||||||
size_t input_num = cnode->inputs().size() - 1;
|
size_t input_num = cnode->inputs().size() - 1;
|
||||||
bool is_dyn_input = false;
|
|
||||||
for (size_t i = 0; i < input_num; ++i) {
|
for (size_t i = 0; i < input_num; ++i) {
|
||||||
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
|
auto input_node = common::AnfAlgo::GetInputNode(cnode, i);
|
||||||
// Ascend using abstract to charge the node input if is dynamic input.
|
dyn_input_sizes.emplace_back(kernel::CalOutputTupleSize(input_node));
|
||||||
// GPU CPU using the KernelObjectType to charge the node input if is dynamic input.
|
|
||||||
if (common::AnfAlgo::IsTupleOutput(input_node)) {
|
|
||||||
auto [input_structural, input_size, dyn_input] = kernel::CalOutputTupleSize(input_node);
|
|
||||||
is_dyn_input |= dyn_input;
|
|
||||||
dyn_input_sizes.push_back(input_size);
|
|
||||||
(void)inputs_structural.emplace_back(input_structural);
|
|
||||||
} else {
|
|
||||||
is_dyn_input |= false;
|
|
||||||
dyn_input_sizes.push_back(-1);
|
|
||||||
(void)inputs_structural.emplace_back(MakeValue<int64_t>(-1));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (is_dyn_input) {
|
if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int64_t s) { return s >= 0; })) {
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrTupleInputStructural, std::make_shared<ValueTuple>(inputs_structural), cnode);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1183,7 +1169,6 @@ void RefreshDynamicInputSizeAttr(const CNodePtr &cnode) {
|
||||||
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
common::AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode);
|
||||||
} else {
|
} else {
|
||||||
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, cnode);
|
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, cnode);
|
||||||
common::AnfAlgo::EraseNodeAttr(kAttrTupleInputStructural, cnode);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1389,7 +1374,6 @@ void HandleKernelSelectFailure(const KernelGraphPtr &graph, const CNodePtr &node
|
||||||
// and make wrong choose, for example, the TupleToTensor op
|
// and make wrong choose, for example, the TupleToTensor op
|
||||||
if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node)) {
|
if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node)) {
|
||||||
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, node);
|
common::AnfAlgo::EraseNodeAttr(kAttrDynInputSizes, node);
|
||||||
common::AnfAlgo::EraseNodeAttr(kAttrTupleInputStructural, node);
|
|
||||||
}
|
}
|
||||||
auto [cpu_msg, cpu_etype] = device::cpu::SetKernelInfoWithMsg(node);
|
auto [cpu_msg, cpu_etype] = device::cpu::SetKernelInfoWithMsg(node);
|
||||||
if (cpu_msg.empty()) {
|
if (cpu_msg.empty()) {
|
||||||
|
|
Loading…
Reference in New Issue