fix ops tuple support (address review comments)

This commit is contained in:
ttudu 2023-02-01 09:52:46 +08:00
parent eabf510dd5
commit bd88a6e07c
8 changed files with 13 additions and 17 deletions

View File

@ -198,7 +198,7 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
std::vector<KernelObjectType> input_obj_type;
std::vector<KernelObjectType> output_obj_type;
GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
builder->SetKernelType(CPU_KERNEL);
builder->SetKernelType(UNKNOWN_KERNEL_TYPE);
builder->SetInputsKernelObjectType(input_obj_type);
builder->SetOutputsKernelObjectType(output_obj_type);

View File

@ -320,8 +320,7 @@ bool MatchObjectType(const kernel::KernelObjectType &node_object, const kernel::
return true;
}
if ((node_object == kernel::TUPLE || node_object == kernel::TUPLE_UNFOLD || node_object == kernel::SCALAR) &&
(kernel_object == kernel::TENSOR)) {
if (node_object == kernel::SCALAR && kernel_object == kernel::TENSOR) {
return true;
}
@ -392,7 +391,7 @@ bool MatchObjectType(const CNodePtr &cnode, const std::shared_ptr<kernel::Kernel
auto node_output_object_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
std::vector<kernel::KernelObjectType> new_output_object_types = {};
if (node_output_object_type == kObjectTypeTuple) {
if (node_output_object_type == kObjectTypeTuple && kernel_outputs_object_type[0] != kernel::KernelObjectType::TUPLE) {
auto tuple_abs = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abs);
auto items = tuple_abs->elements();

View File

@ -53,7 +53,9 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
if (is_fold) {
bool is_match = true;
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
is_match = false;
if (kernel_info->GetInputNum() != fold_input_tensor_num) {
is_match = false;
}
} else {
// compare input num
std::vector<int64_t> dyn_input_sizes =

View File

@ -223,7 +223,6 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {
void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<AscendConvertTupleInputToDynamicInput>());
ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());

View File

@ -657,11 +657,7 @@ void SelectCallInlineKernelInfo(const CNodePtr &node) {
for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
if (AnfAlgo::GetOutputObjectType(node, i) == TypeId::kObjectTypeTuple) {
output_object_types.push_back(kernel::KernelObjectType::TUPLE_UNFOLD);
} else {
output_object_types.push_back(kernel::KernelObjectType::TENSOR);
}
output_object_types.push_back(kernel::KernelObjectType::TENSOR);
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);

View File

@ -38,9 +38,9 @@ const AnfNodePtr AscendConvertTupleInputToDynamicInput::Process(const FuncGraphP
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// this pass should be in front of concat_fission, pack_fission, addn_fission, since the input should be unfold before
// this passes.
the auto_monad pass should run before this pass
// since the input should be unfolded before some functions, this pass should be in front of concat_fission,
// pack_fission, addn_fission, and HandleControlFlow
bool is_communication_op = common::AnfAlgo::IsCommunicationOp(node);
static const PrimitiveSet need_unfold_node = {prim::kPrimAddN, prim::kPrimConcatD, prim::kPrimPack,
prim::kPrimStack, prim::kPrimCallInline, prim::kPrimPrint,

View File

@ -15,7 +15,7 @@
from mindspore.ops import operations as P
add = P.Add()
addn = P.AddN()
mul = P.Mul()
def add_net(x1, x2, x3, x4, x5):
@ -23,5 +23,5 @@ def add_net(x1, x2, x3, x4, x5):
sum2 = add(sum1, x3)
sum3 = add(sum2, x4)
sum4 = add(sum3, x5)
ret = addn((sum4, sum1, sum2))
ret = mul(sum4, sum1)
return ret

View File

@ -48,7 +48,7 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
auto tensor_infos = analyzer->GetMemUsageTensorInfos();
ASSERT_EQ(5, kernel_infos.size());
ASSERT_EQ(16, tensor_infos.size());
ASSERT_EQ(15, tensor_infos.size());
for (size_t i = 0; i < kernel_infos.size(); ++i) {
ASSERT_NE(nullptr, analyzer->GetMemUsageKernelInfo(i));
}