forked from mindspore-Ecosystem/mindspore
delete ENABLE_TUPLE_UNFOLD
This commit is contained in:
parent
390f6e35ba
commit
b294db6b05
|
@ -76,8 +76,6 @@ if(ENABLE_ASAN)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_compile_definitions(ENABLE_TUPLE_UNFOLD)
|
|
||||||
|
|
||||||
if(DEBUG_MODE)
|
if(DEBUG_MODE)
|
||||||
set(CMAKE_BUILD_TYPE "Debug")
|
set(CMAKE_BUILD_TYPE "Debug")
|
||||||
add_compile_definitions(MEM_REUSE_DEBUG)
|
add_compile_definitions(MEM_REUSE_DEBUG)
|
||||||
|
|
|
@ -55,14 +55,9 @@ PassManagerPtr GetBackendCommonOptimizationPassManagerPtr(const FuncGraphPtr &gr
|
||||||
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
|
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
|
||||||
common_pm->AddPass(std::make_shared<ConvertUnusedTupleParaToMakeTuple>());
|
common_pm->AddPass(std::make_shared<ConvertUnusedTupleParaToMakeTuple>());
|
||||||
common_pm->AddPass(std::make_shared<ConvertConstScalarToTensor>());
|
common_pm->AddPass(std::make_shared<ConvertConstScalarToTensor>());
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
MS_LOG(INFO) << "Enable tuple unfold.";
|
|
||||||
if (graph->has_flag(kAttrMutableKernel) || graph->has_flag(kFlagEnableRunGraphBySingleOp)) {
|
if (graph->has_flag(kAttrMutableKernel) || graph->has_flag(kFlagEnableRunGraphBySingleOp)) {
|
||||||
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
|
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
|
|
||||||
#endif
|
|
||||||
common_pm->AddPass(std::make_shared<FlattenConcatFission>());
|
common_pm->AddPass(std::make_shared<FlattenConcatFission>());
|
||||||
common_pm->AddPass(std::make_shared<AddDropoutAttrs>());
|
common_pm->AddPass(std::make_shared<AddDropoutAttrs>());
|
||||||
return common_pm;
|
return common_pm;
|
||||||
|
|
|
@ -49,25 +49,18 @@ AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNode
|
||||||
return make_tuple;
|
return make_tuple;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
bool IsKerenlGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
bool IsKerenlGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
const auto &outputs =
|
const auto &outputs =
|
||||||
common::AnfAlgo::GetAllOutputIndexByReturnTypes(func_graph->output(), {prim::kPrimTupleGetItem});
|
common::AnfAlgo::GetAllOutputIndexByReturnTypes(func_graph->output(), {prim::kPrimTupleGetItem});
|
||||||
return std::find_if(outputs.begin(), outputs.end(), [&node](const auto &output) { return output.first == node; }) !=
|
return std::find_if(outputs.begin(), outputs.end(), [&node](const auto &output) { return output.first == node; }) !=
|
||||||
outputs.end();
|
outputs.end();
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
bool IsNeedConvert(const FuncGraphPtr &func_graph, const AnfNodePtr &input) {
|
bool IsNeedConvert(const FuncGraphPtr &func_graph, const AnfNodePtr &input) {
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
return (input->Type() != nullptr && AnfUtils::IsRealKernel(input) && common::AnfAlgo::IsTupleOutput(input) &&
|
return (input->Type() != nullptr && AnfUtils::IsRealKernel(input) && common::AnfAlgo::IsTupleOutput(input) &&
|
||||||
!common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall) &&
|
!common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall) &&
|
||||||
(input->isa<Parameter>() || input->isa<ValueNode>() || IsKerenlGraphOutput(func_graph, input)));
|
(input->isa<Parameter>() || input->isa<ValueNode>() || IsKerenlGraphOutput(func_graph, input)));
|
||||||
#else
|
|
||||||
return (input->Type() != nullptr && AnfUtils::IsRealKernel(input) && common::AnfAlgo::IsTupleOutput(input) &&
|
|
||||||
!common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall));
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
|
@ -139,7 +139,6 @@ tensor::TensorPtr GetForwardOutputTensor(const AnfNodePtr &node) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
size_t GetOutputTensorNumByKernelInfo(const AnfNodePtr &node) {
|
size_t GetOutputTensorNumByKernelInfo(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node->kernel_info());
|
MS_EXCEPTION_IF_NULL(node->kernel_info());
|
||||||
|
@ -149,7 +148,6 @@ size_t GetOutputTensorNumByKernelInfo(const AnfNodePtr &node) {
|
||||||
MS_EXCEPTION_IF_NULL(build_info);
|
MS_EXCEPTION_IF_NULL(build_info);
|
||||||
return build_info->GetAllOutputDeviceTypes().size();
|
return build_info->GetAllOutputDeviceTypes().size();
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
|
AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
|
||||||
|
@ -196,7 +194,6 @@ void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
|
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
size_t res;
|
size_t res;
|
||||||
TypePtr type = node->Type();
|
TypePtr type = node->Type();
|
||||||
|
@ -225,13 +222,9 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
|
||||||
res = 1;
|
res = 1;
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
#else
|
|
||||||
return AnfUtils::GetOutputTensorNum(node);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t AnfRuntimeAlgorithm::GetOutputNumWithoutKernelInfo(const AnfNodePtr &node) {
|
size_t AnfRuntimeAlgorithm::GetOutputNumWithoutKernelInfo(const AnfNodePtr &node) {
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
const auto &kernel_info = node->kernel_info();
|
const auto &kernel_info = node->kernel_info();
|
||||||
if (kernel_info != nullptr) {
|
if (kernel_info != nullptr) {
|
||||||
|
@ -261,9 +254,6 @@ size_t AnfRuntimeAlgorithm::GetOutputNumWithoutKernelInfo(const AnfNodePtr &node
|
||||||
res = 1;
|
res = 1;
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
#else
|
|
||||||
return AnfUtils::GetOutputTensorNum(node);
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -82,13 +82,11 @@ void GetOutputDtypes(const CNodePtr &kernel_node, std::vector<TypeId> *output_ty
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
// Real tuple isn't expanded.
|
// Real tuple isn't expanded.
|
||||||
void GetOutputDtypesForRealTuple(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
|
void GetOutputDtypesForRealTuple(const CNodePtr &kernel_node, std::vector<TypeId> *output_types) {
|
||||||
TypeId dtype = common::AnfAlgo::GetOutputInferDataType(kernel_node, 0);
|
TypeId dtype = common::AnfAlgo::GetOutputInferDataType(kernel_node, 0);
|
||||||
(void)output_types->emplace_back(dtype);
|
(void)output_types->emplace_back(dtype);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
void GetOutputFormat(const CNodePtr &kernel_node, std::vector<std::string> *output_formats) {
|
void GetOutputFormat(const CNodePtr &kernel_node, std::vector<std::string> *output_formats) {
|
||||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||||
|
@ -544,7 +542,6 @@ bool SelectKernel(const CNodePtr &kernel_node, kernel::KernelAttr *selected_kern
|
||||||
bool input_matched = false;
|
bool input_matched = false;
|
||||||
for (auto kernel_attr : kernel_attrs) {
|
for (auto kernel_attr : kernel_attrs) {
|
||||||
output_types.clear();
|
output_types.clear();
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
// The real tuple and allsame don't fold the tuple.
|
// The real tuple and allsame don't fold the tuple.
|
||||||
if (kernel_attr.GetAllSame() ||
|
if (kernel_attr.GetAllSame() ||
|
||||||
(kernel_attr.GetOutputSize() != 0 && kernel_attr.GetOutputAttr(0).object_type == kObjectTypeTuple)) {
|
(kernel_attr.GetOutputSize() != 0 && kernel_attr.GetOutputAttr(0).object_type == kObjectTypeTuple)) {
|
||||||
|
@ -552,9 +549,6 @@ bool SelectKernel(const CNodePtr &kernel_node, kernel::KernelAttr *selected_kern
|
||||||
} else {
|
} else {
|
||||||
GetOutputDtypes(kernel_node, &output_types);
|
GetOutputDtypes(kernel_node, &output_types);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
GetOutputDtypes(kernel_node, &output_types);
|
|
||||||
#endif
|
|
||||||
MS_LOG(DEBUG) << "Select kernel for op: " << kernel_node->fullname_with_scope() << ", input types:" << input_types
|
MS_LOG(DEBUG) << "Select kernel for op: " << kernel_node->fullname_with_scope() << ", input types:" << input_types
|
||||||
<< ", output types:" << output_types;
|
<< ", output types:" << output_types;
|
||||||
|
|
||||||
|
@ -647,7 +641,6 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
|
||||||
// First select the kernel object types.
|
// First select the kernel object types.
|
||||||
std::vector<kernel::KernelAttr> object_selected_kernel_attrs;
|
std::vector<kernel::KernelAttr> object_selected_kernel_attrs;
|
||||||
const auto &kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(op_name);
|
const auto &kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(op_name);
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
if (kernel_attrs.empty()) {
|
if (kernel_attrs.empty()) {
|
||||||
return KernelNotSupportWarning(kernel_node, false);
|
return KernelNotSupportWarning(kernel_node, false);
|
||||||
} else if (kernel_attrs[0].GetSkipCheck()) {
|
} else if (kernel_attrs[0].GetSkipCheck()) {
|
||||||
|
@ -656,9 +649,6 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
|
||||||
!kernel::SelectKernelByObjectType(kernel_node, kernel_attrs, &object_selected_kernel_attrs, false)) {
|
!kernel::SelectKernelByObjectType(kernel_node, kernel_attrs, &object_selected_kernel_attrs, false)) {
|
||||||
return kernel::KernelObjectTypeNotSupportWarning(kernel_node);
|
return kernel::KernelObjectTypeNotSupportWarning(kernel_node);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
object_selected_kernel_attrs = kernel_attrs;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Second select the matched kernel attr.
|
// Second select the matched kernel attr.
|
||||||
kernel::KernelAttr selected_kernel_attr;
|
kernel::KernelAttr selected_kernel_attr;
|
||||||
|
|
|
@ -236,9 +236,7 @@ void CPUKernelExecutor::OptimizeGraphImpl(const KernelGraphPtr &graph) const {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto pm = std::make_shared<opt::PassManager>();
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>("insert_type_transform_op"));
|
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>("insert_type_transform_op"));
|
||||||
#endif
|
|
||||||
pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
|
pm->AddPass(std::make_shared<opt::InsertFormatTransformOpCPU>("insert_format_transform_op_cpu"));
|
||||||
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
|
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
|
||||||
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
|
pm->AddPass(std::make_shared<opt::InsertCastCPU>("insert_cast"));
|
||||||
|
|
|
@ -593,7 +593,6 @@ bool GetSelectKernelResult(const CNodePtr &kernel_node,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType kernel_type) {
|
bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||||
// Only the kernel nodes that register kernel attr can support the backoff.
|
// Only the kernel nodes that register kernel attr can support the backoff.
|
||||||
|
@ -631,7 +630,6 @@ bool GetSelectKernelObjectTypeResult(const CNodePtr &kernel_node, KernelType ker
|
||||||
kernel::SetKernelObjectTypeWithSelectedAttr(kernel_node, object_selected_kernel_attrs[0]);
|
kernel::SetKernelObjectTypeWithSelectedAttr(kernel_node, object_selected_kernel_attrs[0]);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kernel_node, KernelType kernel_type) {
|
std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kernel_node, KernelType kernel_type) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||||
|
@ -643,12 +641,10 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
|
||||||
}
|
}
|
||||||
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
bool selected = GetSelectKernelObjectTypeResult(kernel_node, kernel_type);
|
bool selected = GetSelectKernelObjectTypeResult(kernel_node, kernel_type);
|
||||||
if (!selected) {
|
if (!selected) {
|
||||||
return kernel::KernelObjectTypeNotSupportWarning(kernel_node);
|
return kernel::KernelObjectTypeNotSupportWarning(kernel_node);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
std::vector<std::string> inputs_format;
|
std::vector<std::string> inputs_format;
|
||||||
std::vector<TypeId> inputs_type;
|
std::vector<TypeId> inputs_type;
|
||||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
|
@ -659,21 +655,17 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
|
||||||
|
|
||||||
std::vector<std::string> outputs_format;
|
std::vector<std::string> outputs_format;
|
||||||
std::vector<TypeId> outputs_type;
|
std::vector<TypeId> outputs_type;
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
auto output_kernel_object_types = builder->Build()->GetAllOutputKernelObjectTypes();
|
auto output_kernel_object_types = builder->Build()->GetAllOutputKernelObjectTypes();
|
||||||
if (output_kernel_object_types.size() == 1 && output_kernel_object_types[0] == kernel::KernelObjectType::TUPLE) {
|
if (output_kernel_object_types.size() == 1 && output_kernel_object_types[0] == kernel::KernelObjectType::TUPLE) {
|
||||||
outputs_type = {common::AnfAlgo::GetOutputInferDataType(kernel_node, 0)};
|
outputs_type = {common::AnfAlgo::GetOutputInferDataType(kernel_node, 0)};
|
||||||
outputs_format = {kOpFormat_DEFAULT};
|
outputs_format = {kOpFormat_DEFAULT};
|
||||||
} else {
|
} else {
|
||||||
#endif
|
|
||||||
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
size_t output_num = AnfAlgo::GetOutputElementNum(kernel_node);
|
||||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||||
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
outputs_format.emplace_back(kOpFormat_DEFAULT);
|
||||||
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
outputs_type.push_back(common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
||||||
}
|
}
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
std::string origin_data_format = kOpFormat_DEFAULT;
|
std::string origin_data_format = kOpFormat_DEFAULT;
|
||||||
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
||||||
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
||||||
|
@ -683,12 +675,10 @@ std::pair<std::string, ExceptionType> SetKernelInfoWithMsg(const CNodePtr &kerne
|
||||||
builder->SetInputsDeviceType(inputs_type);
|
builder->SetInputsDeviceType(inputs_type);
|
||||||
builder->SetOutputsFormat(outputs_format);
|
builder->SetOutputsFormat(outputs_format);
|
||||||
builder->SetOutputsDeviceType(outputs_type);
|
builder->SetOutputsDeviceType(outputs_type);
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
kernel::UnfoldKernelBuildInfo(kernel_node);
|
kernel::UnfoldKernelBuildInfo(kernel_node);
|
||||||
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
|
if (!common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, kernel_node)) {
|
||||||
kernel::SetDynamicInputSizeAttr(kernel_node);
|
kernel::SetDynamicInputSizeAttr(kernel_node);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
MS_LOG(INFO) << kernel_node->fullname_with_scope() << " kernel attr info: "
|
MS_LOG(INFO) << kernel_node->fullname_with_scope() << " kernel attr info: "
|
||||||
<< kernel::FetchPrintInfoByKernelAttr(kernel::GetKernelAttrFromBuildInfo(builder->Build()));
|
<< kernel::FetchPrintInfoByKernelAttr(kernel::GetKernelAttrFromBuildInfo(builder->Build()));
|
||||||
|
|
||||||
|
|
|
@ -356,9 +356,7 @@ void GPUKernelExecutor::OptimizeGraphWithDeviceInfo(const KernelGraphPtr &graph)
|
||||||
// Graph optimization relevant to device data format
|
// Graph optimization relevant to device data format
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto pm = std::make_shared<opt::PassManager>();
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
#ifdef ENABLE_TUPLE_UNFOLD
|
|
||||||
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>("insert_type_transform_op"));
|
pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>("insert_type_transform_op"));
|
||||||
#endif
|
|
||||||
// ReplaceAddNFusion depends on the input expansion of AddN, so must be after the operator select.
|
// ReplaceAddNFusion depends on the input expansion of AddN, so must be after the operator select.
|
||||||
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
|
||||||
// PrintReduceFusion depends on the input expansion of Print, so must be after the operator select.
|
// PrintReduceFusion depends on the input expansion of Print, so must be after the operator select.
|
||||||
|
|
Loading…
Reference in New Issue