!48281 fix ops support tuple review
Merge pull request !48281 from TuDouNi/tuple_commit_my
This commit is contained in: commit 58a5017a1b
@@ -198,7 +198,7 @@ void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type) {
   std::vector<KernelObjectType> input_obj_type;
   std::vector<KernelObjectType> output_obj_type;
   GenerateKernelObjectTypeForNewCNode(cnode, &input_obj_type, &output_obj_type);
-  builder->SetKernelType(CPU_KERNEL);
+  builder->SetKernelType(UNKNOWN_KERNEL_TYPE);
   builder->SetInputsKernelObjectType(input_obj_type);
   builder->SetOutputsKernelObjectType(output_obj_type);

@@ -320,8 +320,7 @@ bool MatchObjectType(const kernel::KernelObjectType &node_object, const kernel::
     return true;
   }

-  if ((node_object == kernel::TUPLE || node_object == kernel::TUPLE_UNFOLD || node_object == kernel::SCALAR) &&
-      (kernel_object == kernel::TENSOR)) {
+  if (node_object == kernel::SCALAR && kernel_object == kernel::TENSOR) {
     return true;
   }

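For clarity, a minimal standalone sketch of the narrowed rule (the enum and function below are hypothetical stand-ins, not MindSpore's kernel::KernelObjectType API): previously a TUPLE, TUPLE_UNFOLD, or SCALAR node object was allowed to match a TENSOR kernel object; after this change only SCALAR may.

```cpp
#include <iostream>

// Hypothetical stand-in for kernel::KernelObjectType, for illustration only.
enum class ObjType { TENSOR, SCALAR, TUPLE, TUPLE_UNFOLD };

// Sketch of the narrowed rule in the hunk above: only a SCALAR node object may
// still fall through to a TENSOR kernel object.
bool MatchScalarToTensor(ObjType node_object, ObjType kernel_object) {
  return node_object == ObjType::SCALAR && kernel_object == ObjType::TENSOR;
}

int main() {
  std::cout << MatchScalarToTensor(ObjType::SCALAR, ObjType::TENSOR) << "\n";  // 1: still accepted
  std::cout << MatchScalarToTensor(ObjType::TUPLE, ObjType::TENSOR) << "\n";   // 0: no longer accepted
  return 0;
}
```
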
@@ -392,7 +391,7 @@ bool MatchObjectType(const CNodePtr &cnode, const std::shared_ptr<kernel::Kernel
   auto node_output_object_type = AnfAlgo::GetAbstractObjectType(cnode->abstract());
   std::vector<kernel::KernelObjectType> new_output_object_types = {};

-  if (node_output_object_type == kObjectTypeTuple) {
+  if (node_output_object_type == kObjectTypeTuple && kernel_outputs_object_type[0] != kernel::KernelObjectType::TUPLE) {
     auto tuple_abs = cnode->abstract()->cast<abstract::AbstractTuplePtr>();
     MS_EXCEPTION_IF_NULL(tuple_abs);
     auto items = tuple_abs->elements();

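A minimal sketch of the tightened condition, using hypothetical stand-in types rather than the actual MindSpore API: the tuple output is only expanded into per-element object types when the selected kernel does not itself declare a TUPLE output.

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the node's abstract object type and the kernel's object type.
enum class AbstractType { TUPLE, TENSOR };
enum class KernelObjType { TUPLE, TENSOR };

// Sketch of the condition in the hunk above: unfold the tuple output only when the
// chosen kernel does not already produce a TUPLE object.
bool ShouldUnfoldTupleOutput(AbstractType node_output_object_type,
                             const std::vector<KernelObjType> &kernel_outputs_object_type) {
  return node_output_object_type == AbstractType::TUPLE &&
         kernel_outputs_object_type[0] != KernelObjType::TUPLE;
}

int main() {
  std::cout << ShouldUnfoldTupleOutput(AbstractType::TUPLE, {KernelObjType::TENSOR}) << "\n";  // 1: unfold
  std::cout << ShouldUnfoldTupleOutput(AbstractType::TUPLE, {KernelObjType::TUPLE}) << "\n";   // 0: keep as tuple
  return 0;
}
```
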
@@ -53,7 +53,9 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
     if (is_fold) {
       bool is_match = true;
       if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
-        is_match = false;
+        if (kernel_info->GetInputNum() != fold_input_tensor_num) {
+          is_match = false;
+        }
       } else {
         // compare input num
         std::vector<int64_t> dyn_input_sizes =

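To make the new nesting easier to read, here is a standalone sketch of the revised fold check (the parameters are hypothetical stand-ins, and the dyn_input_sizes branch is elided as in the hunk): a kernel info without the kAttrDynInputSizes attribute is now rejected only when its declared input count differs from the folded input tensor count, instead of always.

```cpp
#include <cstddef>
#include <iostream>

// Sketch of the revised rule in the hunk above, using hypothetical stand-in parameters.
bool FoldInputsMatch(bool has_dyn_input_sizes, size_t kernel_input_num, size_t fold_input_tensor_num) {
  bool is_match = true;
  if (!has_dyn_input_sizes) {
    if (kernel_input_num != fold_input_tensor_num) {
      is_match = false;
    }
  } else {
    // compare input num against dyn_input_sizes (elided, unchanged by the hunk)
  }
  return is_match;
}

int main() {
  std::cout << FoldInputsMatch(false, 3, 3) << "\n";  // 1: counts agree, kernel info kept
  std::cout << FoldInputsMatch(false, 2, 3) << "\n";  // 0: counts differ, kernel info filtered out
  return 0;
}
```
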
@@ -223,7 +223,6 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) {

 void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
   MS_EXCEPTION_IF_NULL(ir_fusion_pm);
-  ir_fusion_pm->AddPass(std::make_shared<AscendConvertTupleInputToDynamicInput>());
   ir_fusion_pm->AddPass(std::make_shared<UnsortedSegmentSumReplace>());
   ir_fusion_pm->AddPass(std::make_shared<SingleBatchNormFission>());
   ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());

@@ -657,11 +657,7 @@ void SelectCallInlineKernelInfo(const CNodePtr &node) {
   for (size_t i = 0; i < AnfUtils::GetOutputTensorNum(node); ++i) {
     output_formats.push_back(AnfAlgo::GetOutputFormat(sub_ret, i));
     output_types.push_back(common::AnfAlgo::GetOutputInferDataType(sub_ret, i));
-    if (AnfAlgo::GetOutputObjectType(node, i) == TypeId::kObjectTypeTuple) {
-      output_object_types.push_back(kernel::KernelObjectType::TUPLE_UNFOLD);
-    } else {
-      output_object_types.push_back(kernel::KernelObjectType::TENSOR);
-    }
+    output_object_types.push_back(kernel::KernelObjectType::TENSOR);
   }
   auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
   MS_EXCEPTION_IF_NULL(builder);

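Sketched standalone for comparison (hypothetical stand-in types, not the actual SelectCallInlineKernelInfo code): the removed branch mapped tuple-typed outputs to TUPLE_UNFOLD, while the new code assigns TENSOR to every output of the inlined call.

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-in for kernel::KernelObjectType.
enum class KernelObjType { TENSOR, TUPLE_UNFOLD };

int main() {
  // Sketch of the simplified loop in the hunk above: every output object type is now TENSOR,
  // regardless of whether the output's inferred type is a tuple.
  const size_t output_num = 3;  // hypothetical number of outputs of the inlined call
  std::vector<KernelObjType> output_object_types;
  for (size_t i = 0; i < output_num; ++i) {
    output_object_types.push_back(KernelObjType::TENSOR);
  }
  std::cout << "assigned " << output_object_types.size() << " TENSOR output object types\n";
  return 0;
}
```
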
@@ -38,9 +38,9 @@ const AnfNodePtr AscendConvertTupleInputToDynamicInput::Process(const FuncGraphP
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  // this pass should be in front of concat_fission, pack_fission, addn_fission, since the input should be unfold before
-  // this passes.
-  // the auto_monad pass should before this pass
+  // since the input should be unfold before some function, this pass should be in front of concat_fission,
+  // pack_fission, addn_fission, and HandleControlFlow
+
   bool is_communication_op = common::AnfAlgo::IsCommunicationOp(node);
   static const PrimitiveSet need_unfold_node = {prim::kPrimAddN, prim::kPrimConcatD, prim::kPrimPack,
                                                 prim::kPrimStack, prim::kPrimCallInline, prim::kPrimPrint,

@@ -15,7 +15,7 @@
 from mindspore.ops import operations as P

 add = P.Add()
-addn = P.AddN()
+mul = P.Mul()


 def add_net(x1, x2, x3, x4, x5):

@@ -23,5 +23,5 @@ def add_net(x1, x2, x3, x4, x5):
     sum2 = add(sum1, x3)
     sum3 = add(sum2, x4)
     sum4 = add(sum3, x5)
-    ret = addn((sum4, sum1, sum2))
+    ret = mul(sum4, sum1)
     return ret

@@ -48,7 +48,7 @@ TEST_F(TestMemUsageAnalyzer, test_mem_usage_analyzer) {
   auto tensor_infos = analyzer->GetMemUsageTensorInfos();

   ASSERT_EQ(5, kernel_infos.size());
-  ASSERT_EQ(16, tensor_infos.size());
+  ASSERT_EQ(15, tensor_infos.size());
   for (size_t i = 0; i < kernel_infos.size(); ++i) {
     ASSERT_NE(nullptr, analyzer->GetMemUsageKernelInfo(i));
   }