support online infer

This commit is contained in:
zhengyuanhua 2022-09-01 19:24:40 +08:00
parent 2f7531af8f
commit 88a1f88e83
90 changed files with 893 additions and 303 deletions

View File

@ -411,11 +411,11 @@ if(PLATFORM_ARM64)
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR} install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_ACL) if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -650,11 +650,11 @@ elseif(PLATFORM_ARM32)
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR} install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_ACL) if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -840,11 +840,11 @@ else()
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR} install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
if(MSLITE_ENABLE_ACL) if(MSLITE_ENABLE_ACL)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})

View File

@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/common/optimizer/graph_optimizer.h"
namespace mindspore {
namespace opt {
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
if (pass_manager != nullptr) {
pass_managers_.push_back(pass_manager);
}
}
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
// cppcheck-suppress *
auto manager = Manage(func_graph, true);
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 0; i < pass_managers_.size(); ++i) {
const PassManagerPtr &pm = pass_managers_[i];
if (pm != nullptr && pm->Run(func_graph)) {
changed = true;
}
}
if (run_only_once_) {
break;
}
}
std::vector<FuncGraphPtr> func_graphs;
func_graphs.push_back(func_graph);
(void)TopoSort(func_graph->get_return());
return func_graph;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_
#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_
#include <string>
#include <vector>
#include "backend/common/optimizer/pass_manager.h"
#include "include/backend/visible.h"
namespace mindspore {
namespace opt {
class BACKEND_EXPORT GraphOptimizer {
public:
explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
virtual ~GraphOptimizer() = default;
void AddPassManager(const PassManagerPtr &pass_manager);
FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);
private:
const std::string name_ = "graph_optimizer";
std::vector<PassManagerPtr> pass_managers_{};
bool run_only_once_ = true;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_

View File

@ -28,7 +28,7 @@ class BACKEND_EXPORT NodePass : public Pass {
public: public:
explicit NodePass(const std::string &name) : Pass(name) {} explicit NodePass(const std::string &name) : Pass(name) {}
~NodePass() override = default; ~NodePass() override = default;
bool Run(const FuncGraphPtr &func_graph) final; virtual bool Run(const FuncGraphPtr &func_graph);
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
}; };
using NodePassPtr = std::shared_ptr<NodePass>; using NodePassPtr = std::shared_ptr<NodePass>;

View File

@ -117,37 +117,5 @@ std::vector<AnfNodePtr> MultipleOutputPatternProcessPass::GetOrigNodes() const {
} }
return orig_nodes; return orig_nodes;
} }
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
if (pass_manager != nullptr) {
pass_managers_.push_back(pass_manager);
}
}
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
MS_EXCEPTION_IF_NULL(func_graph);
run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
// cppcheck-suppress *
auto manager = Manage(func_graph, true);
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 0; i < pass_managers_.size(); ++i) {
const PassManagerPtr &pm = pass_managers_[i];
if (pm != nullptr && pm->Run(func_graph)) {
changed = true;
}
}
if (run_only_once_) {
break;
}
}
std::vector<FuncGraphPtr> func_graphs;
func_graphs.push_back(func_graph);
(void)TopoSort(func_graph->get_return());
return func_graph;
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -29,6 +29,7 @@
#include "ir/graph_utils.h" #include "ir/graph_utils.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "backend/common/optimizer/helper.h" #include "backend/common/optimizer/helper.h"
#include "backend/common/optimizer/graph_optimizer.h"
#include "include/backend/visible.h" #include "include/backend/visible.h"
namespace mindspore { namespace mindspore {
@ -77,20 +78,6 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
PrimitiveVarMapPtr child_primitive_vars_; PrimitiveVarMapPtr child_primitive_vars_;
EquivPtr child_equiv_; EquivPtr child_equiv_;
}; };
class BACKEND_EXPORT GraphOptimizer {
public:
explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
virtual ~GraphOptimizer() = default;
void AddPassManager(const PassManagerPtr &pass_manager);
FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);
private:
const std::string name_ = "graph_optimizer";
std::vector<PassManagerPtr> pass_managers_{};
bool run_only_once_ = true;
};
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -86,8 +86,6 @@ ShapeVector CacheManager::GetOutputShape(const AnfNodePtr &node, size_t index) {
return result; return result;
} }
const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
void PassManager::AddPass(const PassPtr &pass) { void PassManager::AddPass(const PassPtr &pass) {
if (pass != nullptr) { if (pass != nullptr) {
passes_.push_back(pass); passes_.push_back(pass);

View File

@ -48,7 +48,7 @@ class BACKEND_EXPORT PassManager {
: name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {} : name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
virtual ~PassManager() = default; virtual ~PassManager() = default;
// Get all the passes added by AddPass // Get all the passes added by AddPass
const std::vector<PassPtr> &Passes() const; const std::vector<PassPtr> &Passes() const { return passes_; }
// Add graph pass, the pass object will be freed when pass manager freed. // Add graph pass, the pass object will be freed when pass manager freed.
virtual void AddPass(const PassPtr &pass); virtual void AddPass(const PassPtr &pass);
// Run passes added in pass manager on the input graph // Run passes added in pass manager on the input graph

View File

@ -82,6 +82,7 @@ class TensorDefaultImpl : public MSTensor::Impl {
const std::string &Name() const override { return name_; } const std::string &Name() const override { return name_; }
enum DataType DataType() const override { return type_; } enum DataType DataType() const override { return type_; }
const std::vector<int64_t> &Shape() const override { return shape_; } const std::vector<int64_t> &Shape() const override { return shape_; }
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; }
std::shared_ptr<const void> Data() const override { std::shared_ptr<const void> Data() const override {
return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {}); return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {});
@ -115,6 +116,7 @@ class TensorReferenceImpl : public MSTensor::Impl {
const std::string &Name() const override { return name_; } const std::string &Name() const override { return name_; }
enum DataType DataType() const override { return type_; } enum DataType DataType() const override { return type_; }
const std::vector<int64_t> &Shape() const override { return shape_; } const std::vector<int64_t> &Shape() const override { return shape_; }
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; }
std::shared_ptr<const void> Data() const override { std::shared_ptr<const void> Data() const override {
return std::shared_ptr<const void>(data_, [](const void *) {}); return std::shared_ptr<const void>(data_, [](const void *) {});

View File

@ -44,6 +44,7 @@ class DETensor : public mindspore::MSTensor::Impl {
size_t DataSize() const override; size_t DataSize() const override;
const std::vector<int64_t> &Shape() const override; const std::vector<int64_t> &Shape() const override;
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; };
int64_t ElementNum() const; int64_t ElementNum() const;

View File

@ -47,6 +47,13 @@ class Factory {
(void)kernel_mod_creators_.emplace(name, creator); (void)kernel_mod_creators_.emplace(name, creator);
} }
void UnRegister(const std::string &name) {
auto iter = kernel_mod_creators_.find(name);
if (iter != kernel_mod_creators_.end()) {
kernel_mod_creators_.erase(iter);
}
}
std::shared_ptr<C> Create(const std::string &name) const { std::shared_ptr<C> Create(const std::string &name) const {
typename std::map<std::string, CreatorFunc>::const_iterator iter = kernel_mod_creators_.find(name); typename std::map<std::string, CreatorFunc>::const_iterator iter = kernel_mod_creators_.find(name);
if (iter != kernel_mod_creators_.cend()) { if (iter != kernel_mod_creators_.cend()) {

View File

@ -33,6 +33,7 @@ class MSTensor::Impl {
virtual const std::string &Name() const = 0; virtual const std::string &Name() const = 0;
virtual enum DataType DataType() const = 0; virtual enum DataType DataType() const = 0;
virtual const std::vector<int64_t> &Shape() const = 0; virtual const std::vector<int64_t> &Shape() const = 0;
virtual void SetShape(const std::vector<int64_t> &shape) = 0;
virtual std::shared_ptr<const void> Data() const = 0; virtual std::shared_ptr<const void> Data() const = 0;
virtual void *MutableData() = 0; virtual void *MutableData() = 0;

View File

@ -31,6 +31,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend/plugin/ascend_kernel_plugin.cc) ${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend/plugin/ascend_kernel_plugin.cc)
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC} set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/mindir_model_util.cc
${MSLITE_KERNEL_PLUGIN} ${MSLITE_KERNEL_PLUGIN}
${CMAKE_CURRENT_SOURCE_DIR}/../common/file_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/file_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cc
@ -56,6 +57,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
${CMAKE_CURRENT_SOURCE_DIR}/session/lite_infer_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/session/lite_infer_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc
${CMAKE_CURRENT_SOURCE_DIR}/convert/runtime_convert.cc
) )
# when cpu kernel is need # when cpu kernel is need
#if(NOT MSLITE_ENABLE_ACL) #if(NOT MSLITE_ENABLE_ACL)
@ -191,6 +193,10 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
add_subdirectory(delegate/tensorrt) add_subdirectory(delegate/tensorrt)
endif() endif()
if(MSLITE_ENABLE_CONVERTER)
add_subdirectory(convert)
endif()
set(TEST_CLOUD_INFER on) set(TEST_CLOUD_INFER on)
if(TEST_CLOUD_INFER) if(TEST_CLOUD_INFER)

View File

@ -0,0 +1,11 @@
include_directories(${TOP_DIR})
include_directories(${TOP_DIR}/mindspore/lite)
file(GLOB RUNTIME_CONVERT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
add_library(runtime_convert_plugin SHARED ${RUNTIME_CONVERT_SRC})
add_dependencies(runtime_convert_plugin fbs_inner_src)
if(MSLITE_ENABLE_CONVERTER AND NOT WIN32)
target_link_libraries(runtime_convert_plugin mindspore_converter)
endif()

View File

@ -0,0 +1,79 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "src/extendrt/convert/runtime_convert.h"
#include "tools/common/string_util.h"
#include "tools/converter/converter.h"
#include "tools/converter/cxx_api/converter_para.h"
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
const std::shared_ptr<mindspore::Context> &context) {
auto param = std::make_shared<mindspore::ConverterPara>();
if (param == nullptr) {
MS_LOG(ERROR) << "New ConverterPara failed";
return nullptr;
}
param->fmk_type = mindspore::converter::kFmkTypeMs;
param->input_data_type = mindspore::DataType::kTypeUnknown;
param->output_data_type = mindspore::DataType::kTypeUnknown;
param->weight_fp16 = false;
param->train_model = false;
param->export_mindir = mindspore::kMindIR;
param->enable_encryption = false;
auto device_list = context->MutableDeviceInfo();
for (auto &device : device_list) {
if (device->GetDeviceType() == mindspore::kAscend) {
auto ascend_info = device->Cast<mindspore::AscendDeviceInfo>();
std::string dynamic_batch_size = ascend_info->GetDynamicBatchSize();
if (!dynamic_batch_size.empty()) {
std::vector<std::string> batch_size_string = mindspore::lite::SplitStringToVector(dynamic_batch_size, ',');
for (const auto &item : batch_size_string) {
int32_t val;
if (mindspore::lite::ConvertIntNum(item, &val)) {
size_t tmp_val = static_cast<size_t>(val);
param->aclModelOptionCfgParam.dynamic_batch_size.push_back(tmp_val);
}
}
}
param->aclModelOptionCfgParam.offline = false;
param->aclModelOptionCfgParam.device_id = ascend_info->GetDeviceID();
param->aclModelOptionCfgParam.output_type = ascend_info->GetOutputType();
param->aclModelOptionCfgParam.input_shape_map = ascend_info->GetInputShapeMap();
param->aclModelOptionCfgParam.input_format = ascend_info->GetInputFormat();
param->aclModelOptionCfgParam.input_shape = ascend_info->GetInputShape();
param->aclModelOptionCfgParam.precision_mode = ascend_info->GetPrecisionMode();
param->aclModelOptionCfgParam.op_select_impl_mode = ascend_info->GetOpSelectImplMode();
param->aclModelOptionCfgParam.fusion_switch_config_file_path = ascend_info->GetFusionSwitchConfigPath();
param->aclModelOptionCfgParam.buffer_optimize = ascend_info->GetBufferOptimizeMode();
param->aclModelOptionCfgParam.insert_op_config_file_path = ascend_info->GetInsertOpConfigPath();
param->aclModelOptionCfgParam.dynamic_image_size = ascend_info->GetDynamicImageSize();
param->device = "Ascend";
param->no_fusion = false;
} else {
continue;
}
}
mindspore::lite::ConverterImpl cvt;
void *dst_buff;
auto ret = cvt.Convert(param, nullptr, model_buf, buf_size, &dst_buff, dst_size);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Convert model failed.";
return nullptr;
}
return dst_buff;
}

View File

@ -0,0 +1,32 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CONVERT_RUNTIME_CONVERT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_CONVERT_RUNTIME_CONVERT_H_
#include <stdio.h>
#include <string>
#include <memory>
#include "include/api/context.h"
#ifdef __cplusplus
extern "C" {
#endif
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
const std::shared_ptr<mindspore::Context> &context);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CONVERT_RUNTIME_CONVERT_H_

View File

@ -68,13 +68,15 @@ inline Status DLSoPath(const std::string &benchmark_so, const std::string &targe
return kSuccess; return kSuccess;
} }
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function) { inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function,
bool runtime_convert = false) {
// do dlopen and export functions from c_dataengine // do dlopen and export functions from c_dataengine
if (handle == nullptr) { if (handle == nullptr) {
MS_LOG(WARNING) << "Input parameter handle cannot be nullptr"; MS_LOG(WARNING) << "Input parameter handle cannot be nullptr";
return Status(kMEFailed, "Input parameter handle cannot be nullptr"); return Status(kMEFailed, "Input parameter handle cannot be nullptr");
} }
*handle = dlopen(dl_path.c_str(), RTLD_LAZY | RTLD_LOCAL); int mode = runtime_convert ? RTLD_GLOBAL : RTLD_LOCAL;
*handle = dlopen(dl_path.c_str(), RTLD_LAZY | mode);
auto get_dl_error = []() -> std::string { auto get_dl_error = []() -> std::string {
auto error = dlerror(); auto error = dlerror();

View File

@ -21,12 +21,38 @@
#include "extendrt/cxx_api/file_utils.h" #include "extendrt/cxx_api/file_utils.h"
#include "extendrt/utils/tensor_utils.h" #include "extendrt/utils/tensor_utils.h"
#include "mindspore/core/utils/ms_context.h" #include "mindspore/core/utils/ms_context.h"
#include "extendrt/mindir_loader/mindir_model/mindir_model_util.h"
#include "src/extendrt/convert/runtime_convert.h"
namespace mindspore { namespace mindspore {
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type, Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context) { const std::shared_ptr<Context> &model_context) {
const void *model_buff = model_data;
size_t model_size = data_size;
#ifndef _WIN32
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
MS_LOG(WARNING) << "Need runtime convert";
std::string plugin_path;
auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
if (ret != kSuccess) {
MS_LOG(WARNING) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
}
void *function = nullptr;
ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
if (ret != kSuccess) {
MS_LOG(WARNING) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
}
auto convert =
reinterpret_cast<void *(*)(const char *, const size_t &, size_t *, const std::shared_ptr<Context> &)>(function);
if (convert != nullptr) {
model_buff = convert(static_cast<const char *>(model_data), data_size, &model_size, model_context);
}
} else {
MS_LOG(WARNING) << "Not need runtime convert";
}
#endif
graph_ = std::make_shared<Graph>(); graph_ = std::make_shared<Graph>();
auto ret = Serialization::Load(model_data, data_size, model_type, graph_.get()); auto ret = Serialization::Load(model_buff, model_size, model_type, graph_.get());
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Serialization::Load model failed."; MS_LOG(ERROR) << "Serialization::Load model failed.";
return ret; return ret;
@ -47,7 +73,7 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice)); device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
}); });
} }
return session_->CompileGraph(graph_->graph_data_->GetFuncGraph(), model_data, data_size); return session_->CompileGraph(graph_->graph_data_->GetFuncGraph(), model_buff, model_size);
} }
Status ModelImpl::Build(const std::string &model_path, ModelType model_type, Status ModelImpl::Build(const std::string &model_path, ModelType model_type,

View File

@ -29,11 +29,21 @@
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "extendrt/infer_session.h" #include "extendrt/infer_session.h"
#ifndef _WIN32
#include <dlfcn.h>
#endif
namespace mindspore { namespace mindspore {
class ModelImpl { class ModelImpl {
public: public:
ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {} ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
~ModelImpl() = default; ~ModelImpl() {
#ifndef _WIN32
if (handle_ != nullptr) {
(void)dlclose(handle_);
handle_ = nullptr;
}
#endif
}
Status Build(const void *model_data, size_t data_size, ModelType model_type, Status Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context); const std::shared_ptr<Context> &model_context);
@ -56,6 +66,9 @@ class ModelImpl {
std::shared_ptr<Graph> graph_ = nullptr; std::shared_ptr<Graph> graph_ = nullptr;
std::shared_ptr<InferSession> session_ = nullptr; std::shared_ptr<InferSession> session_ = nullptr;
std::shared_ptr<Context> context_ = nullptr; std::shared_ptr<Context> context_ = nullptr;
#ifndef _WIN32
void *handle_ = nullptr;
#endif
}; };
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_MODEL_IMPL_H_ #endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_MODEL_IMPL_H_

View File

@ -247,7 +247,5 @@ bool CustomAscendKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
UpdateOutputAddr(outputs); UpdateOutputAddr(outputs);
return true; return true;
} }
MS_KERNEL_FACTORY_REG(KernelMod, CustomAscend, CustomAscendKernelMod);
} // namespace acl } // namespace acl
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -23,6 +23,7 @@
#include "ir/value.h" #include "ir/value.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/common.h"
namespace mindspore::infer::mindir { namespace mindspore::infer::mindir {
static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{ static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
@ -191,4 +192,34 @@ mindspore::TypeId MindirModelUtil::ProtoTypeToTypeId(int32_t proto_type) {
} }
return it->second; return it->second;
} }
bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size) {
bool need_runtime_convert = true;
mind_ir::ModelProto model_proto;
std::string str(static_cast<const char *>(model_data), data_size);
if (model_proto.ParseFromString(str)) {
mind_ir::GraphProto *graph_proto = model_proto.mutable_graph();
if (graph_proto != nullptr) {
for (int i = 0; i < graph_proto->attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = graph_proto->attribute(i);
if (attr_proto.has_name() && attr_proto.name() == lite::kIsOptimized) {
const int attr_type = static_cast<int>(attr_proto.type());
if (attr_type != mind_ir::AttributeProto_AttributeType_BOOL) {
MS_LOG(ERROR) << "The type of attr optimized value must be bool.";
return true;
}
if (static_cast<bool>(attr_proto.i())) {
need_runtime_convert = false;
MS_LOG(DEBUG) << "No need to online infer.";
}
break;
}
}
}
} else {
MS_LOG(WARNING) << "Not mindir model";
need_runtime_convert = false;
}
return need_runtime_convert;
}
} // namespace mindspore::infer::mindir } // namespace mindspore::infer::mindir

View File

@ -33,6 +33,7 @@ class MindirModelUtil {
static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto); static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto);
static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type); static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type);
static bool NeedRuntimeConvert(const void *model_data, size_t data_size);
}; };
} // namespace mindspore::infer::mindir } // namespace mindspore::infer::mindir

View File

@ -37,6 +37,10 @@ namespace mindspore {
const size_t tensor_max_size = 0x1000000; const size_t tensor_max_size = 0x1000000;
constexpr auto kNameCustomAscend = "CustomAscend"; constexpr auto kNameCustomAscend = "CustomAscend";
SingleOpInferSession::~SingleOpInferSession() {
kernel::Factory<kernel::KernelMod>::Instance().UnRegister(kNameCustomAscend);
}
Status SingleOpInferSession::AscendInit(const std::shared_ptr<Context> &context) { Status SingleOpInferSession::AscendInit(const std::shared_ptr<Context> &context) {
auto device_list = context->MutableDeviceInfo(); auto device_list = context->MutableDeviceInfo();
for (const auto &device_info : device_list) { for (const auto &device_info : device_list) {

View File

@ -26,7 +26,7 @@ namespace mindspore {
class SingleOpInferSession : public InferSession { class SingleOpInferSession : public InferSession {
public: public:
SingleOpInferSession() = default; SingleOpInferSession() = default;
virtual ~SingleOpInferSession() = default; ~SingleOpInferSession() override;
Status Init(const std::shared_ptr<Context> &context) override; Status Init(const std::shared_ptr<Context> &context) override;
Status AscendInit(const std::shared_ptr<Context> &context); Status AscendInit(const std::shared_ptr<Context> &context);
Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override; Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override;

View File

@ -32,11 +32,11 @@ bool EraseQuotes(std::string *input_string);
bool FindAndReplaceAll(std::string *input_str, const std::string &search, const std::string &replace); bool FindAndReplaceAll(std::string *input_str, const std::string &search, const std::string &replace);
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter); MS_API std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const std::string &delimiter); std::vector<std::string> SplitStringToVector(const std::string &raw_str, const std::string &delimiter);
bool ConvertIntNum(const std::string &str, int *value); MS_API bool ConvertIntNum(const std::string &str, int *value);
bool ConvertDoubleNum(const std::string &str, double *value); bool ConvertDoubleNum(const std::string &str, double *value);

View File

@ -11,7 +11,7 @@ set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
set(CCSRC_SRC set(CCSRC_SRC
${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc
${CCSRC_DIR}/backend/common/optimizer/visit.cc ${CCSRC_DIR}/backend/common/optimizer/visit.cc
${CCSRC_DIR}/backend/common/optimizer/optimizer.cc ${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc
) )
if(MSLITE_ENABLE_GRAPH_KERNEL) if(MSLITE_ENABLE_GRAPH_KERNEL)
@ -26,6 +26,7 @@ if(MSLITE_ENABLE_GRAPH_KERNEL)
${CCSRC_DIR}/common/graph_kernel/graph_kernel_flags.cc ${CCSRC_DIR}/common/graph_kernel/graph_kernel_flags.cc
${CCSRC_DIR}/kernel/akg/akg_kernel_json_generator.cc ${CCSRC_DIR}/kernel/akg/akg_kernel_json_generator.cc
${CCSRC_DIR}/backend/common/pass/getitem_tuple.cc ${CCSRC_DIR}/backend/common/pass/getitem_tuple.cc
${CCSRC_DIR}/backend/common/optimizer/optimizer.cc
) )
set(CCSRC_SRC set(CCSRC_SRC
${CCSRC_SRC} ${CCSRC_SRC}

View File

@ -58,7 +58,7 @@ Pass *AclPassPlugin::CreateAclPass(const std::shared_ptr<ConverterPara> &param)
void *function = nullptr; void *function = nullptr;
auto ret = DLSoOpen(real_path_, "CreateAclPass", &handle_, &function); auto ret = DLSoOpen(real_path_, "CreateAclPass", &handle_, &function);
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << real_path_; MS_LOG(ERROR) << "DLSoOpen failed, so path: " << real_path_ << ", ret: " << ret;
return nullptr; return nullptr;
} }
auto create_func = reinterpret_cast<mindspore::opt::Pass *(*)(const std::shared_ptr<ConverterPara> &)>(function); auto create_func = reinterpret_cast<mindspore::opt::Pass *(*)(const std::shared_ptr<ConverterPara> &)>(function);

View File

@ -35,6 +35,7 @@
#include "plugin/device/cpu/kernel/nnacl/op_base.h" #include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -57,9 +58,21 @@ constexpr size_t kDependInputNum = 3;
constexpr size_t kDependFirstInputIdx = 1; constexpr size_t kDependFirstInputIdx = 1;
constexpr size_t kTupleGetItemFirstInputIdx = 1; constexpr size_t kTupleGetItemFirstInputIdx = 1;
STATUS PreProcForMindIr(const FuncGraphPtr &func_graph) { return lite::RET_OK; } STATUS PreProcForMindIr(const FuncGraphPtr &func_graph, bool offline) { return lite::RET_OK; }
STATUS PreProcForTF(const FuncGraphPtr &func_graph) { STATUS PreProcForTF(const FuncGraphPtr &func_graph, bool offline) {
if (!offline) {
auto format_pass = std::make_shared<opt::SpecifyGraphInputFormat>(Format::NCHW, Format::NHWC);
MS_CHECK_TRUE_MSG(format_pass != nullptr, lite::RET_ERROR, "Make shared specify graph input format failed.");
if (!format_pass->Run(func_graph)) {
MS_LOG(ERROR) << "Run specify graph input format pass failed.";
return lite::RET_ERROR;
}
if (!lite::RunOptimizerPass(func_graph, {kToNHWCFormatPass, kDelRedundantTranspose})) {
MS_LOG(ERROR) << "To nhwc format failed.";
return lite::RET_ERROR;
}
}
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass})) { if (!lite::RunOptimizerPass(func_graph, {kInferShapePass})) {
MS_LOG(ERROR) << "Infer shape pass failed."; MS_LOG(ERROR) << "Infer shape pass failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
@ -85,7 +98,14 @@ STATUS PreProcForTF(const FuncGraphPtr &func_graph) {
return lite::RET_OK; return lite::RET_OK;
} }
STATUS PreProcForCaffe(const FuncGraphPtr &func_graph) { STATUS PreProcForCaffe(const FuncGraphPtr &func_graph, bool offline) {
if (!offline) {
if (!lite::RunOptimizerPass(func_graph, {kDelRedundantTranspose})) {
MS_LOG(ERROR) << "Del redundant transpose failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) { if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
MS_LOG(ERROR) << "To nchw format failed."; MS_LOG(ERROR) << "To nchw format failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
@ -93,7 +113,14 @@ STATUS PreProcForCaffe(const FuncGraphPtr &func_graph) {
return lite::RET_OK; return lite::RET_OK;
} }
STATUS PreProcForOnnx(const FuncGraphPtr &func_graph) { STATUS PreProcForOnnx(const FuncGraphPtr &func_graph, bool offline) {
if (!offline) {
if (!lite::RunOptimizerPass(func_graph, {kDelRedundantTranspose})) {
MS_LOG(ERROR) << "Del redundant transpose failed.";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) { if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
MS_LOG(ERROR) << "To nchw format failed."; MS_LOG(ERROR) << "To nchw format failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
@ -138,14 +165,14 @@ STATUS AclPassImpl::PreProcGraph(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "Common pass failed."; MS_LOG(ERROR) << "Common pass failed.";
return lite::RET_ERROR; return lite::RET_ERROR;
} }
std::map<converter::FmkType, std::function<STATUS(const FuncGraphPtr &)>> fmk_proc_func = { static std::map<converter::FmkType, std::function<STATUS(const FuncGraphPtr &, bool)>> fmk_proc_func = {
{converter::kFmkTypeMs, PreProcForMindIr}, {converter::kFmkTypeTf, PreProcForTF}, {converter::kFmkTypeMs, PreProcForMindIr}, {converter::kFmkTypeTf, PreProcForTF},
{converter::kFmkTypeCaffe, PreProcForCaffe}, {converter::kFmkTypeOnnx, PreProcForOnnx}, {converter::kFmkTypeCaffe, PreProcForCaffe}, {converter::kFmkTypeOnnx, PreProcForOnnx},
{converter::kFmkTypeTflite, PreProcForTF}, {converter::kFmkTypeTflite, PreProcForTF},
}; };
if (fmk_proc_func.find(fmk_type_) != fmk_proc_func.end()) { if (fmk_proc_func.find(fmk_type_) != fmk_proc_func.end()) {
auto func = fmk_proc_func.at(fmk_type_); auto func = fmk_proc_func.at(fmk_type_);
if (func(func_graph) != lite::RET_OK) { if (func(func_graph, user_options_cfg_.offline) != lite::RET_OK) {
MS_LOG(ERROR) << "Pre proc failed, fmk " << fmk_type_; MS_LOG(ERROR) << "Pre proc failed, fmk " << fmk_type_;
return lite::RET_ERROR; return lite::RET_ERROR;
} }
@ -164,10 +191,6 @@ STATUS AclPassImpl::PostProcGraph(const FuncGraphPtr &func_graph) {
auto manager = func_graph->manager(); auto manager = func_graph->manager();
MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is nullptr."); MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is nullptr.");
manager->Reset(); manager->Reset();
if (!user_options_cfg_.offline) {
MS_LOG(DEBUG) << "Online model infer no need to change to nhwc format.";
return lite::RET_OK;
}
if (fmk_type_ == converter::kFmkTypeTf) { if (fmk_type_ == converter::kFmkTypeTf) {
MS_LOG(DEBUG) << "Tf no need to change to nhwc format."; MS_LOG(DEBUG) << "Tf no need to change to nhwc format.";
return lite::RET_OK; return lite::RET_OK;

View File

@ -26,6 +26,7 @@
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/converter/optimizer_manager.h" #include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/pass_manager_extends.h"
#include "ir/primitive.h" #include "ir/primitive.h"
#include "tools/optimizer/fusion/affine_activation_fusion.h" #include "tools/optimizer/fusion/affine_activation_fusion.h"
#include "tools/optimizer/fusion/affine_fusion.h" #include "tools/optimizer/fusion/affine_fusion.h"
@ -111,6 +112,10 @@
using std::string; using std::string;
namespace mindspore::lite { namespace mindspore::lite {
namespace {
constexpr auto kOriginalFmkType = "original_fmk_type";
} // namespace
AnfTransform::AnfTransform() = default; AnfTransform::AnfTransform() = default;
AnfTransform::~AnfTransform() = default; AnfTransform::~AnfTransform() = default;
@ -195,7 +200,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(optimizer);
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); auto fusion_pm = std::make_shared<opt::LitePassManager>("anf fusion pass manager", false);
CHECK_NULL_RETURN(fusion_pm); CHECK_NULL_RETURN(fusion_pm);
// The training model only does the fusion of the inference part // The training model only does the fusion of the inference part
@ -306,7 +311,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shar
return RET_OK; return RET_OK;
} }
opt::Spliter::GetInstance()->RecordGraphInfo(old_graph); opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
auto parallel_pm = std::make_shared<opt::PassManager>("anf parallel pass manager", true); auto parallel_pm = std::make_shared<opt::LitePassManager>("anf parallel pass manager", true);
CHECK_NULL_RETURN(parallel_pm); CHECK_NULL_RETURN(parallel_pm);
// 2. preceding parallel pass // 2. preceding parallel pass
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>()); parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
@ -333,7 +338,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shar
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) { int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(optimizer);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
CHECK_NULL_RETURN(graph_pm); CHECK_NULL_RETURN(graph_pm);
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf || if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
param->fmk_type == converter::kFmkTypeOnnx) { param->fmk_type == converter::kFmkTypeOnnx) {
@ -377,7 +382,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(optimizer);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true); auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
CHECK_NULL_RETURN(convert_pm); CHECK_NULL_RETURN(convert_pm);
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model)); convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model)); convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
@ -393,7 +398,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) { int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false); auto const_fold_pm = std::make_shared<opt::LitePassManager>("const fold fusion pass manager", false);
CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(optimizer);
CHECK_NULL_RETURN(const_fold_pm); CHECK_NULL_RETURN(const_fold_pm);
const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model)); const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
@ -420,7 +425,7 @@ int RunDecreaseTransposePass(const FuncGraphPtr &old_graph, const std::shared_pt
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto decrease_trans_pm = std::make_shared<opt::PassManager>("decrease transpose fusion pass manager", false); auto decrease_trans_pm = std::make_shared<opt::LitePassManager>("decrease transpose fusion pass manager", false);
CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(optimizer);
CHECK_NULL_RETURN(decrease_trans_pm); CHECK_NULL_RETURN(decrease_trans_pm);
decrease_trans_pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>()); decrease_trans_pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
@ -566,44 +571,53 @@ bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const std::shared_
return true; return true;
} }
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, STATUS AnfTransform::ProcOnlineTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
const std::shared_ptr<ConverterPara> &param) { if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
MS_ASSERT(old_graph != nullptr); MS_LOG(WARNING) << "Run infershape opt pass failed.";
MS_ASSERT(param != nullptr); }
auto status = DoFormatForMindIR(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do format for mindir failed.";
return lite::RET_ERROR;
}
old_graph->set_attr(kOriginalFmkType, MakeValue(static_cast<int32_t>(param->fmk_type)));
return lite::RET_OK;
}
int AnfTransform::RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto status = RunConvertPass(old_graph, param); auto status = RunConvertPass(old_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed."; MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr; return RET_ERROR;
} }
if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) { if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN"; MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
return nullptr; return RET_ERROR;
} }
status = RunConstFoldPass(old_graph, param); status = RunConstFoldPass(old_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run const fold pass failed."; MS_LOG(ERROR) << "Run const fold pass failed.";
return nullptr; return RET_ERROR;
} }
if (!RunEliminateRedundantPass(old_graph, param)) { if (!RunEliminateRedundantPass(old_graph, param)) {
MS_LOG(ERROR) << "Run elimination of redundant pass failed."; MS_LOG(ERROR) << "Run elimination of redundant pass failed.";
return nullptr; return RET_ERROR;
} }
if (!param->no_fusion) { if (!param->no_fusion) {
status = RunFusionPass(old_graph, param); status = RunFusionPass(old_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run fusion pass failed."; MS_LOG(ERROR) << "Run fusion pass failed.";
return nullptr; return RET_ERROR;
} }
} }
if (!RunExternalPass(old_graph, registry::POSITION_END)) { if (!RunExternalPass(old_graph, registry::POSITION_END)) {
MS_LOG(ERROR) << "Run external pass failed, place is END"; MS_LOG(ERROR) << "Run external pass failed, place is END";
return nullptr; return RET_ERROR;
} }
if (!RunOptimizerPass(old_graph, {"InferShapePass"})) { if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
@ -616,18 +630,37 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
} }
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run transpose opt pass failed."; MS_LOG(ERROR) << "Run transpose opt pass failed.";
return nullptr; return RET_ERROR;
} }
status = RunGraphPass(old_graph, param); status = RunGraphPass(old_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed."; MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr; return RET_ERROR;
} }
status = RunParallelPass(old_graph, param); status = RunParallelPass(old_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed."; MS_LOG(ERROR) << "Run convert pass failed.";
return RET_ERROR;
}
return RET_OK;
}
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
const std::shared_ptr<ConverterPara> &param) {
MS_ASSERT(old_graph != nullptr);
MS_ASSERT(param != nullptr);
if (param->no_fusion && param->export_mindir == kMindIR) {
if (ProcOnlineTransform(old_graph, param) != lite::RET_OK) {
MS_LOG(ERROR) << "Proc online transform failed.";
return nullptr;
}
return old_graph;
}
if (RunPass(old_graph, param) != RET_OK) {
MS_LOG(ERROR) << "Proc online transform failed.";
return nullptr; return nullptr;
} }
@ -636,7 +669,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
return nullptr; return nullptr;
} }
status = QATTransform(old_graph, param); auto status = QATTransform(old_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do QATTransform failed."; MS_LOG(ERROR) << "Do QATTransform failed.";
return nullptr; return nullptr;
@ -700,6 +733,11 @@ bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param)
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> &param) { FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> &param) {
MS_CHECK_TRUE_MSG(main_graph != nullptr, nullptr, "Input func_graph is nullptr"); MS_CHECK_TRUE_MSG(main_graph != nullptr, nullptr, "Input func_graph is nullptr");
MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input converter param is nullptr"); MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input converter param is nullptr");
if (main_graph->has_attr(kOriginalFmkType)) {
auto val_ptr = main_graph->get_attr(kOriginalFmkType);
MS_CHECK_TRUE_MSG(val_ptr != nullptr, nullptr, "Val ptr is nullptr.");
param->fmk_type = static_cast<converter::FmkType>(GetValue<int32_t>(val_ptr));
}
if (!StoreBuiltinPass(param)) { if (!StoreBuiltinPass(param)) {
MS_LOG(ERROR) << "store pass failed."; MS_LOG(ERROR) << "store pass failed.";
return nullptr; return nullptr;

View File

@ -39,6 +39,8 @@ class AnfTransform {
private: private:
FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param); FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param); static int RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param); static int RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
@ -66,6 +68,8 @@ class AnfTransform {
static STATUS QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param); static STATUS QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param);
static bool CheckExternalExtension(const std::shared_ptr<ConverterPara> &param); static bool CheckExternalExtension(const std::shared_ptr<ConverterPara> &param);
static STATUS ProcOnlineTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -127,47 +127,11 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara>
return func_graph; return func_graph;
} }
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph,
const void *buf, const size_t &size) {
if (param == nullptr || buf == nullptr) {
MS_LOG(ERROR) << "Input param is nullptr";
return RET_ERROR;
}
auto graph = BuildFuncGraph(param, buf, size);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed.");
// funcgraph_transform
graph = funcgraph_transform_->Transform(graph, param);
MS_CHECK_TRUE_MSG(graph != nullptr, RET_ERROR, "Transform anf graph return nullptr.");
// export protobuf
if (param->export_mindir == kMindIR) {
auto status = UpdateFuncGraphInputAndOutputNames(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
return RET_ERROR;
}
status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Export to mindir proto failed";
return RET_ERROR;
} else {
MS_LOG(DEBUG) << "Export to mindir success";
return RET_OK;
}
}
*meta_graph = TransferFuncGraph(param, graph);
return RET_OK;
}
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph) { int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph) {
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "Input param is nullptr"; MS_LOG(ERROR) << "Input param is nullptr";
return RET_ERROR; return RET_ERROR;
} }
param->aclModelOptionCfgParam.om_file_path = param->output_file; param->aclModelOptionCfgParam.om_file_path = param->output_file;
if (!param->config_file.empty() || !param->config_param.empty()) { if (!param->config_file.empty() || !param->config_param.empty()) {
auto ret = InitConfigParam(param); auto ret = InitConfigParam(param);
@ -176,7 +140,6 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::
return RET_ERROR; return RET_ERROR;
} }
} }
// load plugin // load plugin
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders; static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
if (!param->plugins_path.empty()) { if (!param->plugins_path.empty()) {
@ -191,20 +154,19 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::
dl_loaders.emplace_back(dl_loader); dl_loaders.emplace_back(dl_loader);
} }
} }
auto graph = BuildFuncGraph(param); auto graph = BuildFuncGraph(param);
if (graph == nullptr) { return FuncGraphConvert(param, graph, meta_graph, false, nullptr, nullptr);
MS_LOG(ERROR) << "Parser/Import model return nullptr"; }
return RET_ERROR;
}
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed"); int ConverterImpl::FuncGraphConvert(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr graph,
// funcgraph transform schema::MetaGraphT **meta_graph, bool isRuntimeConvert, void **buff, size_t *size) {
graph = funcgraph_transform_->Transform(graph, param); if (param == nullptr || graph == nullptr) {
if (graph == nullptr) { MS_LOG(ERROR) << "Input param or graph is nullptr";
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return RET_ERROR; return RET_ERROR;
} }
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed");
graph = funcgraph_transform_->Transform(graph, param);
MS_CHECK_TRUE_MSG(graph != nullptr, RET_ERROR, "Transform anf graph return nullptr.");
// export protobuf // export protobuf
if (param->export_mindir == kMindIR) { if (param->export_mindir == kMindIR) {
@ -218,16 +180,15 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::
MS_LOG(ERROR) << "Update input and output names of funcgraph failed."; MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
return RET_ERROR; return RET_ERROR;
} }
status = MindIRSerialize(param, graph); status = MindIRSerialize(param, graph, isRuntimeConvert, buff, size);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Export to mindir failed"; MS_LOG(ERROR) << "Export to mindir failed";
return RET_ERROR; return RET_ERROR;
} else {
MS_LOG(DEBUG) << "Export to mindir success";
return RET_OK;
} }
} else { // fb
*meta_graph = TransferFuncGraph(param, graph);
} }
*meta_graph = TransferFuncGraph(param, graph); MS_LOG(DEBUG) << "FuncGraph convert success";
return RET_OK; return RET_OK;
} }

View File

@ -53,13 +53,19 @@ class ConverterImpl {
this->model_parser_ = nullptr; this->model_parser_ = nullptr;
} }
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph); int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph);
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, const void *buf, int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, const void *buff,
const size_t &size); const size_t &size, void **dst_buff, size_t *dst_size) {
auto graph = BuildFuncGraph(param, buff, size);
return FuncGraphConvert(param, graph, meta_graph, true, dst_buff, dst_size);
}
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph); int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph);
private: private:
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param); FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param);
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param, const void *buf, const size_t &size); FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param, const void *buf, const size_t &size);
int FuncGraphConvert(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr graph, schema::MetaGraphT **meta_graph,
bool isRuntimeConvert, void **buff, size_t *size);
schema::MetaGraphT *TransferFuncGraph(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr func_graph); schema::MetaGraphT *TransferFuncGraph(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr func_graph);
int InitConfigParam(const std::shared_ptr<ConverterPara> &param); int InitConfigParam(const std::shared_ptr<ConverterPara> &param);

View File

@ -26,6 +26,7 @@
#include "include/errorcode.h" #include "include/errorcode.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "tools/lite_exporter/anf_exporter.h" #include "tools/lite_exporter/anf_exporter.h"
#include "tools/optimizer/common/pass_manager_extends.h"
#include "tools/converter/graphdef_transform.h" #include "tools/converter/graphdef_transform.h"
#include "tools/converter/optimizer_manager.h" #include "tools/converter/optimizer_manager.h"
#include "tools/converter/parser/parser_utils.h" #include "tools/converter/parser/parser_utils.h"
@ -250,7 +251,7 @@ STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPar
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer); CHECK_NULL_RETURN(optimizer);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
CHECK_NULL_RETURN(graph_pm); CHECK_NULL_RETURN(graph_pm);
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf || if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
param->fmk_type == converter::kFmkTypeOnnx) { param->fmk_type == converter::kFmkTypeOnnx) {

View File

@ -35,6 +35,7 @@
#include "tools/converter/quantizer/quant_param_holder.h" #include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/format/to_format_base.h" #include "tools/optimizer/format/to_format_base.h"
#include "tools/optimizer/common/pass_manager_extends.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "ops/op_utils.h" #include "ops/op_utils.h"
#include "src/common/common.h" #include "src/common/common.h"
@ -89,7 +90,7 @@ int CommonAnfAdjust(const FuncGraphPtr &func_graph) {
{ {
auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>(); auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
MS_CHECK_TRUE_MSG(asylic_optimizer != nullptr, RET_NULL_PTR, "asylic_optimizer is nullptr."); MS_CHECK_TRUE_MSG(asylic_optimizer != nullptr, RET_NULL_PTR, "asylic_optimizer is nullptr.");
auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false); auto asylic_pm = std::make_shared<opt::LitePassManager>("asylic pass manager", false);
MS_CHECK_TRUE_MSG(asylic_pm != nullptr, RET_NULL_PTR, "asylic_pm is nullptr."); MS_CHECK_TRUE_MSG(asylic_pm != nullptr, RET_NULL_PTR, "asylic_pm is nullptr.");
// fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic // fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic

View File

@ -134,7 +134,7 @@ int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> &param, const Fu
MS_LOG(ERROR) << "error occur when check condition of saving together."; MS_LOG(ERROR) << "error occur when check condition of saving together.";
return ret; return ret;
} }
is_fusion_ = !param->no_fusion;
if (save_together_) { if (save_together_) {
ret = SaveMindIRTogether(); ret = SaveMindIRTogether();
} else { } else {
@ -383,12 +383,15 @@ int MindIRSerializer::IfSaveTogether(bool *save_together) {
} }
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) { int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
if (isRuntimeConvert_) {
return RET_OK;
}
mind_ir::GraphProto *graph_proto = model_proto->mutable_graph(); mind_ir::GraphProto *graph_proto = model_proto->mutable_graph();
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute(); mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
if (attr_proto != nullptr) { if (attr_proto != nullptr) {
attr_proto->set_name(kIsOptimized); attr_proto->set_name(kIsOptimized);
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
attr_proto->set_i(1); attr_proto->set_i(is_fusion_);
} }
auto realpath = Common::CreatePrefixPath(output_file, true); auto realpath = Common::CreatePrefixPath(output_file, true);
@ -414,8 +417,32 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
return RET_OK; return RET_OK;
} }
int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph) { int MindIRSerializer::GetBuffAndSize(void **buff, size_t *size) {
mindspore::lite::MindIRSerializer serializer; if (buff == nullptr || size == nullptr) {
return serializer.Save(param, func_graph); MS_LOG(ERROR) << "param is nullptr";
return RET_ERROR;
}
*size = model_proto_.ByteSize();
*buff = malloc(*size);
if (*buff == nullptr) {
MS_LOG(ERROR) << "Malloc fail";
return RET_ERROR;
}
model_proto_.SerializeToArray(*buff, *size);
return RET_OK;
}
int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph, bool isRuntimeConvert,
void **buff, size_t *size) {
mindspore::lite::MindIRSerializer serializer(isRuntimeConvert);
auto ret = serializer.Save(param, func_graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MindIR serialize fail";
return ret;
}
if (isRuntimeConvert) {
return serializer.GetBuffAndSize(buff, size);
}
return RET_OK;
} }
} // namespace mindspore::lite } // namespace mindspore::lite

View File

@ -31,7 +31,7 @@
namespace mindspore::lite { namespace mindspore::lite {
class MindIRSerializer { class MindIRSerializer {
public: public:
MindIRSerializer() = default; explicit MindIRSerializer(bool isRuntimeConvert) : isRuntimeConvert_(isRuntimeConvert) {}
virtual ~MindIRSerializer() { virtual ~MindIRSerializer() {
if (data_fs_ != nullptr) { if (data_fs_ != nullptr) {
data_fs_->close(); data_fs_->close();
@ -40,6 +40,7 @@ class MindIRSerializer {
} }
} }
int Save(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph); int Save(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph);
int GetBuffAndSize(void **buff, size_t *size);
private: private:
int ParserPath(const std::string &output_path); int ParserPath(const std::string &output_path);
@ -59,6 +60,8 @@ class MindIRSerializer {
int RemoveQuantParameterHolder(FuncGraphPtr func_graph); int RemoveQuantParameterHolder(FuncGraphPtr func_graph);
private: private:
bool isRuntimeConvert_ = false;
bool is_fusion_ = true;
std::string model_name_; std::string model_name_;
std::string save_path_; std::string save_path_;
std::string save_model_path_; std::string save_model_path_;
@ -72,6 +75,7 @@ class MindIRSerializer {
std::shared_ptr<system::FileSystem> fs_{}; std::shared_ptr<system::FileSystem> fs_{};
}; };
// export func_graph // export func_graph
int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph); int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph, bool isRuntimeConvert,
void **buff, size_t *size);
} // namespace mindspore::lite } // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_MINDIR_EXPORTER_MINDIR_SERIALIZER_H_ #endif // MINDSPORE_LITE_TOOLS_MINDIR_EXPORTER_MINDIR_SERIALIZER_H_

View File

@ -33,6 +33,7 @@
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "tools/converter/parser/parser_utils.h" #include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/helper.h"
#include "ops/op_utils.h" #include "ops/op_utils.h"
#include "ops/custom.h" #include "ops/custom.h"
@ -709,7 +710,7 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false; return false;
} }
auto output_node_list = GetRealNodeUsedList(graph, node); auto output_node_list = Helper::GetRealNodeUsedList(graph, node);
if (output_node_list == nullptr) { if (output_node_list == nullptr) {
MS_LOG(ERROR) << "output node list is nullptr"; MS_LOG(ERROR) << "output node list is nullptr";
return false; return false;

View File

@ -15,16 +15,15 @@
*/ */
#define USE_DEPRECATED_API #define USE_DEPRECATED_API
#include "backend/common/optimizer/helper.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "tools/optimizer/common/gllo_utils.h" #include <algorithm>
#include "tools/optimizer/common/helper.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { ValueNodePtr Helper::CreateValueNodeWithSexp(const BaseRef &sexp) {
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) { if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp)); return NewValueNode(utils::cast<int>(sexp));
} }
@ -40,7 +39,7 @@ ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
return nullptr; return nullptr;
} }
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) { CNodePtr Helper::CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) { if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph)); return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
} }
@ -50,7 +49,7 @@ CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const
return nullptr; return nullptr;
} }
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { VarNodePtr Helper::CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) { if (utils::isa<VarPtr>(graph)) {
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr); return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
@ -63,8 +62,8 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
return nullptr; return nullptr;
} }
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, AnfNodePtr Helper::HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) { bool multigraph) {
if (primitive_vars == nullptr) { if (primitive_vars == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr; return nullptr;
@ -90,6 +89,7 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
return CreateCNodeWithGraph(input_nodes, graph); return CreateCNodeWithGraph(input_nodes, graph);
} }
namespace {
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) { bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
auto a_value_node = a_node->cast<ValueNodePtr>(); auto a_value_node = a_node->cast<ValueNodePtr>();
auto b_value_node = b_node->cast<ValueNodePtr>(); auto b_value_node = b_node->cast<ValueNodePtr>();
@ -139,6 +139,7 @@ bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
} }
} }
} // namespace } // namespace
// not implement for lite, just for api compatible // not implement for lite, just for api compatible
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg, CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
const std::vector<AnfNodePtr> &orig_nodes) { const std::vector<AnfNodePtr> &orig_nodes) {
@ -152,6 +153,11 @@ CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::ve
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) { const AnfNodePtr &node) {
return Helper::GetRealNodeUsedList(graph, node);
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> Helper::GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
if (graph == nullptr || node == nullptr) { if (graph == nullptr || node == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr."; MS_LOG(ERROR) << "input parameter is nullptr.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@ -175,9 +181,8 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
return output_node_list; return output_node_list;
} }
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> Helper::GetRealNodeUsedListByOutputIdx(
const AnfNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index) {
size_t output_index) {
if (graph == nullptr || node == nullptr) { if (graph == nullptr || node == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr."; MS_LOG(ERROR) << "input parameter is nullptr.";
return nullptr; return nullptr;
@ -236,15 +241,12 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
return a == b; return a == b;
} }
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
// To matchCNode and Kernel's type return Helper::SexpToNode(sexp, graph, primitive_vars, multigraph);
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
return true;
}
return a.type() == b.type();
} }
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { AnfNodePtr Helper::SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
if (primitive_vars == nullptr) { if (primitive_vars == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_HELPER_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_HELPER_H_
#include <utility>
#include <memory>
#include <vector>
#include "backend/common/optimizer/helper.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
class Helper {
public:
static std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node);
static std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(
const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index);
static AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph);
private:
static ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp);
static CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph);
static VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph);
static AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_HELPER_H_

View File

@ -15,12 +15,10 @@
*/ */
#include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "tools/optimizer/common/helper.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
namespace mindspore::opt { namespace mindspore::opt {
MultiplePatternProcessPass::MultiplePatternProcessPass(const std::string &name, bool multigraph)
: NodePass(name), multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared<Visitor>())) {}
AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (patterns_.empty()) { if (patterns_.empty()) {
VarPtr fg = std::make_shared<Var>("RootG"); VarPtr fg = std::make_shared<Var>("RootG");
@ -29,7 +27,7 @@ AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const
for (const auto &pattern : patterns) { for (const auto &pattern : patterns) {
auto primitive_var = std::make_shared<PrimitiveVarMap>(); auto primitive_var = std::make_shared<PrimitiveVarMap>();
MS_CHECK_TRUE_RET(primitive_var != nullptr, nullptr); MS_CHECK_TRUE_RET(primitive_var != nullptr, nullptr);
this->patterns_[pattern.first] = (SexpToNode(pattern.second, fg, primitive_var.get(), multigraph_)); this->patterns_[pattern.first] = (Helper::SexpToNode(pattern.second, fg, primitive_var.get(), multigraph_));
this->primitive_var_maps_[pattern.first] = primitive_var; this->primitive_var_maps_[pattern.first] = primitive_var;
} }
} }

View File

@ -24,12 +24,14 @@
#include "backend/common/optimizer/node_pass.h" #include "backend/common/optimizer/node_pass.h"
#include "backend/common/optimizer/pattern_engine.h" #include "backend/common/optimizer/pattern_engine.h"
#include "backend/common/optimizer/helper.h" #include "backend/common/optimizer/helper.h"
#include "tools/optimizer/common/node_pass_extends.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class MultiplePatternProcessPass : public NodePass { class MultiplePatternProcessPass : public LiteNodePass {
public: public:
explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true); explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true)
: LiteNodePass(name), multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared<Visitor>())) {}
~MultiplePatternProcessPass() override = default; ~MultiplePatternProcessPass() override = default;
virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0; virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0;

View File

@ -15,7 +15,7 @@
*/ */
#define USE_DEPRECATED_API #define USE_DEPRECATED_API
#include "backend/common/optimizer/node_pass.h" #include "tools/optimizer/common/node_pass_extends.h"
#include <unordered_set> #include <unordered_set>
#include <deque> #include <deque>
#include <algorithm> #include <algorithm>
@ -27,6 +27,11 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
bool NodePass::Run(const FuncGraphPtr &func_graph) { bool NodePass::Run(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "stub func";
return false;
}
bool LiteNodePass::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false; return false;

View File

@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_NODE_PASS_EXTENDS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_NODE_PASS_EXTENDS_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/common/optimizer/node_pass.h"
namespace mindspore {
namespace opt {
class LiteNodePass : public NodePass {
public:
explicit LiteNodePass(const std::string &name) : NodePass(name) {}
~LiteNodePass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_NODE_PASS_EXTENDS_H_

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/common/optimizer/pass_manager.h" #include "tools/optimizer/common/pass_manager_extends.h"
#ifndef _MSC_VER #ifndef _MSC_VER
#include <sys/time.h> #include <sys/time.h>
#endif #endif
@ -27,8 +27,6 @@ namespace opt {
constexpr size_t kMaxRepassTimes = 12; constexpr size_t kMaxRepassTimes = 12;
constexpr uint64_t kUSecondInSecond = 1000000; constexpr uint64_t kUSecondInSecond = 1000000;
const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
void PassManager::AddPass(const PassPtr &pass) { void PassManager::AddPass(const PassPtr &pass) {
if (pass != nullptr) { if (pass != nullptr) {
passes_.push_back(pass); passes_.push_back(pass);
@ -36,6 +34,35 @@ void PassManager::AddPass(const PassPtr &pass) {
} }
bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const { bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
MS_LOG(ERROR) << "stub func";
return false;
}
std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
return std::string("hwopt_") + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
}
void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const {
MS_LOG(ERROR) << "stub func";
}
bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
MS_LOG(ERROR) << "stub func";
return false;
}
bool PassManager::Run(const FuncGraphPtr &func_graph) const {
MS_LOG(ERROR) << "stub func";
return false;
}
void LitePassManager::AddPass(const PassPtr &pass) {
if (pass != nullptr) {
passes_.push_back(pass);
}
}
bool LitePassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
bool changed = false; bool changed = false;
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now(); auto start_time = std::chrono::steady_clock::now();
@ -61,14 +88,11 @@ bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const
return changed; return changed;
} }
std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const { std::string LitePassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
return "hwopt_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name(); return "hwopt_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
} }
// not implement for lite, just for api compatible bool LitePassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const {}
bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return false; return false;
} }
@ -85,7 +109,7 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
return changed; return changed;
} }
bool PassManager::Run(const FuncGraphPtr &func_graph) const { bool LitePassManager::Run(const FuncGraphPtr &func_graph) const {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return false; return false;
} }

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PASS_MANAGER_EXTENDS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PASS_MANAGER_EXTENDS_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/common/optimizer/pass_manager.h"
namespace mindspore {
namespace opt {
class LitePassManager : public PassManager {
public:
explicit LitePassManager(const std::string &name = "pm", bool run_only_once = true)
: PassManager(name, run_only_once) {}
virtual ~LitePassManager() = default;
void AddPass(const PassPtr &pass) override;
bool Run(const FuncGraphPtr &func_graph) const override;
bool Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const override;
protected:
bool RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const override;
std::string GetPassFullname(size_t pass_id, const PassPtr &pass) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PASS_MANAGER_EXTENDS_H_

View File

@ -0,0 +1,54 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/common/pattern_process_pass_extends.h"
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "backend/common/optimizer/pass_manager.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "ir/manager.h"
#include "tools/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
void LitePatternProcessPass::Build() {
VarPtr fg = std::make_shared<Var>("RootG");
pattern_ = Helper::SexpToNode(DefinePattern(), fg, primitive_vars_.get(), multigraph_);
}
AnfNodePtr LitePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (pattern_ == nullptr) {
Build();
}
auto primitive = GetCNodePrimitive(pattern_);
if (primitive_vars_ == nullptr || equiv_ == nullptr) {
return nullptr;
}
if (IsPrimitiveCNode(node, primitive)) {
equiv_->clear();
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, equiv_);
if (equiv != nullptr && !equiv->empty()) {
return Process(func_graph, node, equiv);
}
}
return nullptr;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PATTERN_PROCESS_PASS_EXTENDS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PATTERN_PROCESS_PASS_EXTENDS_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/common/optimizer/pattern_engine.h"
#include "tools/optimizer/common/node_pass_extends.h"
namespace mindspore {
namespace opt {
class LitePatternProcessPass : public LiteNodePass {
public:
explicit LitePatternProcessPass(const std::string &name = "", bool multigraph = true)
: LiteNodePass(name),
multigraph_(multigraph),
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
primitive_vars_(std::make_shared<PrimitiveVarMap>()),
equiv_(std::make_shared<Equiv>()) {}
~LitePatternProcessPass() override = default;
virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
virtual const BaseRef DefinePattern() const = 0;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
private:
void Build();
AnfNodePtr pattern_ = nullptr;
bool multigraph_ = true;
PatternEngine pattern_engine_;
PrimitiveVarMapPtr primitive_vars_;
EquivPtr equiv_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PATTERN_PROCESS_PASS_EXTENDS_H_

View File

@ -34,6 +34,7 @@
#include "src/common/ops/anf_utils.h" #include "src/common/ops/anf_utils.h"
#include "src/litert/infer_manager.h" #include "src/litert/infer_manager.h"
#include "tools/optimizer/graph/lite_tensor_extractor.h" #include "tools/optimizer/graph/lite_tensor_extractor.h"
#include "tools/optimizer/common/helper.h"
using mindspore::lite::KernelRegistry; using mindspore::lite::KernelRegistry;
using mindspore::lite::Tensor; using mindspore::lite::Tensor;
@ -117,7 +118,7 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_NULL_PTR); MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_NULL_PTR);
if (output_tensors.size() != 1) { if (output_tensors.size() != 1) {
for (size_t k = 0; k < output_tensors.size(); k++) { for (size_t k = 0; k < output_tensors.size(); k++) {
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, k); auto used_node_list = Helper::GetRealNodeUsedListByOutputIdx(func_graph, cnode, k);
if (used_node_list->empty()) { if (used_node_list->empty()) {
MS_LOG(DEBUG) << "this output don't be used by other node."; MS_LOG(DEBUG) << "this output don't be used by other node.";
continue; continue;

View File

@ -17,14 +17,15 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/optimizer/fisson/fisson_util.h" #include "tools/optimizer/fisson/fisson_util.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class EliminateConcatSplit : public PatternProcessPass { class EliminateConcatSplit : public LitePatternProcessPass {
public: public:
explicit EliminateConcatSplit(bool multigraph = true) : PatternProcessPass("eliminate_concat_split", multigraph) {} explicit EliminateConcatSplit(bool multigraph = true)
: LitePatternProcessPass("eliminate_concat_split", multigraph) {}
~EliminateConcatSplit() override = default; ~EliminateConcatSplit() override = default;
private: private:

View File

@ -22,7 +22,6 @@
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "mindspore/ccsrc/include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/lite/include/context.h" #include "mindspore/lite/include/context.h"
#include "mindspore/lite/include/lite_types.h" #include "mindspore/lite/include/lite_types.h"

View File

@ -15,16 +15,16 @@
*/ */
#include "ir/anf.h" #include "ir/anf.h"
#include "mindspore/ccsrc/backend/common/optimizer/node_pass.h" #include "tools/optimizer/common/node_pass_extends.h"
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class IterNodeOutputs : public opt::NodePass { class IterNodeOutputs : public opt::LiteNodePass {
public: public:
IterNodeOutputs() : NodePass("iter_node_outputs") {} IterNodeOutputs() : LiteNodePass("iter_node_outputs") {}
~IterNodeOutputs() override = default; ~IterNodeOutputs() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
}; };

View File

@ -22,7 +22,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/optimizer/fisson/fisson_util.h" #include "tools/optimizer/fisson/fisson_util.h"
#include "tools/optimizer/parallel/split_strategy.h" #include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/multi_node_split.h" #include "tools/optimizer/parallel/multi_node_split.h"
@ -31,11 +31,11 @@ using mindspore::schema::PrimitiveType;
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class MultiConvSplitPass : public PatternProcessPass { class MultiConvSplitPass : public LitePatternProcessPass {
public: public:
explicit MultiConvSplitPass(std::unordered_map<std::string, SplitStrategy> strategys, int32_t fmk_type = -1, explicit MultiConvSplitPass(std::unordered_map<std::string, SplitStrategy> strategys, int32_t fmk_type = -1,
int32_t num = 3, bool multigraph = true) int32_t num = 3, bool multigraph = true)
: PatternProcessPass("multi_conv_split", multigraph), : LitePatternProcessPass("multi_conv_split", multigraph),
strategys_(std::move(strategys)), strategys_(std::move(strategys)),
fmk_type_(fmk_type), fmk_type_(fmk_type),
num_(num) {} num_(num) {}

View File

@ -15,16 +15,16 @@
*/ */
#include "ir/anf.h" #include "ir/anf.h"
#include "mindspore/ccsrc/backend/common/optimizer/node_pass.h" #include "tools/optimizer/common/node_pass_extends.h"
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class NodeOutShapes : public opt::NodePass { class NodeOutShapes : public opt::LiteNodePass {
public: public:
NodeOutShapes() : NodePass("node_out_shapes") {} NodeOutShapes() : LiteNodePass("node_out_shapes") {}
~NodeOutShapes() override = default; ~NodeOutShapes() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
}; };

View File

@ -18,13 +18,13 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ADD_CONCAT_ACTIVATION_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ADD_CONCAT_ACTIVATION_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore::opt { namespace mindspore::opt {
class AddConcatActivationFusion : public PatternProcessPass { class AddConcatActivationFusion : public LitePatternProcessPass {
public: public:
explicit AddConcatActivationFusion(bool multigraph = true, const std::string &name = "AddConcatActivationFusion") explicit AddConcatActivationFusion(bool multigraph = true, const std::string &name = "AddConcatActivationFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~AddConcatActivationFusion() override = default; ~AddConcatActivationFusion() override = default;
private: private:

View File

@ -34,6 +34,7 @@ const BaseRef AffineActivationFusion::DefinePattern() const {
const AnfNodePtr AffineActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr AffineActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const { const EquivPtr &equiv) const {
constexpr size_t kAnfPrimitiveIndex = 0;
if (func_graph == nullptr || node == nullptr || equiv == nullptr) { if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr; return nullptr;

View File

@ -19,14 +19,14 @@
#include <string> #include <string>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class AffineActivationFusion : public PatternProcessPass { class AffineActivationFusion : public LitePatternProcessPass {
public: public:
explicit AffineActivationFusion(bool multigraph = true, const std::string &name = "AffineActivationFusion") explicit AffineActivationFusion(bool multigraph = true, const std::string &name = "AffineActivationFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~AffineActivationFusion() override = default; ~AffineActivationFusion() override = default;
private: private:

View File

@ -44,6 +44,7 @@ const BaseRef AffineFusion::DefinePattern() const {
const AnfNodePtr AffineFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr AffineFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const { const EquivPtr &equiv) const {
constexpr size_t kAnfPrimitiveIndex = 0;
if (func_graph == nullptr || node == nullptr || equiv == nullptr) { if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr; return nullptr;

View File

@ -18,15 +18,15 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_AFFINE_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_AFFINE_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class AffineFusion : public PatternProcessPass { class AffineFusion : public LitePatternProcessPass {
public: public:
explicit AffineFusion(bool multigraph = true, const std::string &name = "AffineFusion") explicit AffineFusion(bool multigraph = true, const std::string &name = "AffineFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~AffineFusion() override = default; ~AffineFusion() override = default;
private: private:

View File

@ -17,14 +17,14 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class BatchMatMulFusion : public PatternProcessPass { class BatchMatMulFusion : public LitePatternProcessPass {
public: public:
explicit BatchMatMulFusion(bool multigraph = true) : PatternProcessPass("BatchMatMulFusion", multigraph) {} explicit BatchMatMulFusion(bool multigraph = true) : LitePatternProcessPass("BatchMatMulFusion", multigraph) {}
~BatchMatMulFusion() override = default; ~BatchMatMulFusion() override = default;
private: private:

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHNORM_TO_SCALE_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHNORM_TO_SCALE_FUSION_H_
#include <vector> #include <vector>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore::opt { namespace mindspore::opt {
class BatchNormToScaleFusion : public Pass { class BatchNormToScaleFusion : public Pass {

View File

@ -19,16 +19,16 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/converter/cxx_api/converter_para.h" #include "tools/converter/cxx_api/converter_para.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ConvActivationFusion : public PatternProcessPass { class ConvActivationFusion : public LitePatternProcessPass {
public: public:
explicit ConvActivationFusion(const std::shared_ptr<ConverterPara> &param, bool multigraph = true, explicit ConvActivationFusion(const std::shared_ptr<ConverterPara> &param, bool multigraph = true,
const std::string &name = "ConvActivationFusion") const std::string &name = "ConvActivationFusion")
: PatternProcessPass(name, multigraph), param_(param) {} : LitePatternProcessPass(name, multigraph), param_(param) {}
~ConvActivationFusion() override = default; ~ConvActivationFusion() override = default;

View File

@ -17,14 +17,14 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ConvBiasaddFusion : public PatternProcessPass { class ConvBiasaddFusion : public LitePatternProcessPass {
public: public:
explicit ConvBiasaddFusion(bool multigraph = true) : PatternProcessPass("ConvBiasaddFusion", multigraph) {} explicit ConvBiasaddFusion(bool multigraph = true) : LitePatternProcessPass("ConvBiasaddFusion", multigraph) {}
~ConvBiasaddFusion() override = default; ~ConvBiasaddFusion() override = default;
private: private:

View File

@ -19,13 +19,13 @@
#include <string> #include <string>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ConvConvFusion : public PatternProcessPass { class ConvConvFusion : public LitePatternProcessPass {
public: public:
explicit ConvConvFusion(bool multigraph = true) : PatternProcessPass("ConvConvFusion", multigraph) {} explicit ConvConvFusion(bool multigraph = true) : LitePatternProcessPass("ConvConvFusion", multigraph) {}
~ConvConvFusion() override = default; ~ConvConvFusion() override = default;
private: private:

View File

@ -18,15 +18,15 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/registry/converter_context.h" #include "include/registry/converter_context.h"
using mindspore::converter::FmkType; using mindspore::converter::FmkType;
namespace mindspore::opt { namespace mindspore::opt {
class ConvTransformFusion : public PatternProcessPass { class ConvTransformFusion : public LitePatternProcessPass {
public: public:
explicit ConvTransformFusion(bool multigraph = true, const std::string &name = "ConvTransformFusion") explicit ConvTransformFusion(bool multigraph = true, const std::string &name = "ConvTransformFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~ConvTransformFusion() override = default; ~ConvTransformFusion() override = default;
protected: protected:

View File

@ -18,14 +18,14 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLE_ACTIVATION_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLE_ACTIVATION_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ConvTupleActivationFusion : public PatternProcessPass { class ConvTupleActivationFusion : public LitePatternProcessPass {
public: public:
explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "ConvTupleActivationFusion") explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "ConvTupleActivationFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~ConvTupleActivationFusion() override = default; ~ConvTupleActivationFusion() override = default;
private: private:

View File

@ -16,12 +16,13 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore::opt { namespace mindspore::opt {
class ConvTupleGetItemFusion : public PatternProcessPass { class ConvTupleGetItemFusion : public LitePatternProcessPass {
public: public:
explicit ConvTupleGetItemFusion(const std::string &name = "ConvTupleGetItemFusion", bool multigraph = true) explicit ConvTupleGetItemFusion(const std::string &name = "ConvTupleGetItemFusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~ConvTupleGetItemFusion() override = default; ~ConvTupleGetItemFusion() override = default;
private: private:

View File

@ -19,7 +19,6 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "backend/common/optimizer/optimizer.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/multiple_pattern_process_pass.h"
namespace mindspore::opt { namespace mindspore::opt {

View File

@ -19,14 +19,13 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "backend/common/optimizer/optimizer.h"
#include "tools/optimizer/fusion/conv_transform_fusion.h" #include "tools/optimizer/fusion/conv_transform_fusion.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
namespace mindspore::opt { namespace mindspore::opt {
class FullConnectedFusion : public PatternProcessPass { class FullConnectedFusion : public LitePatternProcessPass {
public: public:
explicit FullConnectedFusion(bool multigraph = true) : PatternProcessPass("FullConnectedFusion", multigraph) {} explicit FullConnectedFusion(bool multigraph = true) : LitePatternProcessPass("FullConnectedFusion", multigraph) {}
~FullConnectedFusion() override = default; ~FullConnectedFusion() override = default;
private: private:

View File

@ -18,16 +18,16 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class GLUFusion : public PatternProcessPass { class GLUFusion : public LitePatternProcessPass {
public: public:
explicit GLUFusion(const std::string &name = "glu_fusion", bool multigraph = true) explicit GLUFusion(const std::string &name = "glu_fusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~GLUFusion() override = default; ~GLUFusion() override = default;

View File

@ -21,17 +21,17 @@
#include <string> #include <string>
#include <map> #include <map>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
/// fuse layer_norm or instance_norm into one operator /// fuse layer_norm or instance_norm into one operator
class GroupNormFusion : public PatternProcessPass { class GroupNormFusion : public LitePatternProcessPass {
public: public:
explicit GroupNormFusion(const std::string &name = "GroupNormFusion", bool multigraph = true) explicit GroupNormFusion(const std::string &name = "GroupNormFusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~GroupNormFusion() override = default; ~GroupNormFusion() override = default;

View File

@ -19,16 +19,16 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
#include "tools/converter/cxx_api/converter_para.h" #include "tools/converter/cxx_api/converter_para.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class MatMulActivationFusion : public PatternProcessPass { class MatMulActivationFusion : public LitePatternProcessPass {
public: public:
explicit MatMulActivationFusion(const std::shared_ptr<ConverterPara> &param, bool multigraph = true) explicit MatMulActivationFusion(const std::shared_ptr<ConverterPara> &param, bool multigraph = true)
: PatternProcessPass("MatMulActivationFusion", multigraph), param_(param) {} : LitePatternProcessPass("MatMulActivationFusion", multigraph), param_(param) {}
~MatMulActivationFusion() = default; ~MatMulActivationFusion() = default;
private: private:

View File

@ -19,7 +19,6 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "backend/common/optimizer/optimizer.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/multiple_pattern_process_pass.h"
namespace mindspore { namespace mindspore {

View File

@ -18,13 +18,13 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MATMUL_MUL_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MATMUL_MUL_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore::opt { namespace mindspore::opt {
class MatMulMulFusion : public PatternProcessPass { class MatMulMulFusion : public LitePatternProcessPass {
public: public:
explicit MatMulMulFusion(bool multigraph = true, const std::string &name = "MatMulMulFusion") explicit MatMulMulFusion(bool multigraph = true, const std::string &name = "MatMulMulFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~MatMulMulFusion() override = default; ~MatMulMulFusion() override = default;
private: private:

View File

@ -18,13 +18,13 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MUL_ACTIVATION_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MUL_ACTIVATION_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore::opt { namespace mindspore::opt {
class MulActivationFusion : public PatternProcessPass { class MulActivationFusion : public LitePatternProcessPass {
public: public:
explicit MulActivationFusion(bool multigraph = true, const std::string &name = "MulActivationFusion") explicit MulActivationFusion(bool multigraph = true, const std::string &name = "MulActivationFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~MulActivationFusion() = default; ~MulActivationFusion() = default;
private: private:

View File

@ -22,7 +22,7 @@
#include <string> #include <string>
#include <map> #include <map>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
@ -30,10 +30,10 @@ namespace mindspore {
namespace opt { namespace opt {
/// fuse layer_norm or instance_norm into one operator /// fuse layer_norm or instance_norm into one operator
class NormFusion : public PatternProcessPass { class NormFusion : public LitePatternProcessPass {
public: public:
explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true) explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true)
: PatternProcessPass(name, multigraph) { : LitePatternProcessPass(name, multigraph) {
InitShapeSizeInferFuncMap(); InitShapeSizeInferFuncMap();
} }

View File

@ -19,15 +19,15 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ReshapeReshapeFusion : public PatternProcessPass { class ReshapeReshapeFusion : public LitePatternProcessPass {
public: public:
explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "ReshapeReshapeFusion") explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "ReshapeReshapeFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~ReshapeReshapeFusion() override = default; ~ReshapeReshapeFusion() override = default;
private: private:

View File

@ -18,14 +18,14 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SCALE_ACTIVATION_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SCALE_ACTIVATION_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ScaleActivationFusion : public PatternProcessPass { class ScaleActivationFusion : public LitePatternProcessPass {
public: public:
explicit ScaleActivationFusion(bool multigraph = true, const std::string &name = "ScaleActivationFusion") explicit ScaleActivationFusion(bool multigraph = true, const std::string &name = "ScaleActivationFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~ScaleActivationFusion() override = default; ~ScaleActivationFusion() override = default;
private: private:

View File

@ -20,14 +20,14 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "ops/fusion/scale_fusion.h" #include "ops/fusion/scale_fusion.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt { namespace mindspore::opt {
class ScaleBaseFusion : public PatternProcessPass { class ScaleBaseFusion : public LitePatternProcessPass {
public: public:
explicit ScaleBaseFusion(std::string name, bool multigraph = true) : PatternProcessPass(name, multigraph) {} explicit ScaleBaseFusion(std::string name, bool multigraph = true) : LitePatternProcessPass(name, multigraph) {}
~ScaleBaseFusion() override = default; ~ScaleBaseFusion() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -19,13 +19,13 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore::opt { namespace mindspore::opt {
class ScaleScaleFusion : public PatternProcessPass { class ScaleScaleFusion : public LitePatternProcessPass {
public: public:
explicit ScaleScaleFusion(bool multigraph = true, const std::string &name = "ScaleScaleFusion") explicit ScaleScaleFusion(bool multigraph = true, const std::string &name = "ScaleScaleFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~ScaleScaleFusion() override = default; ~ScaleScaleFusion() override = default;
private: private:

View File

@ -17,14 +17,14 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_ #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class SigmoidMulFusion : public PatternProcessPass { class SigmoidMulFusion : public LitePatternProcessPass {
public: public:
explicit SigmoidMulFusion(bool multigraph = true) : PatternProcessPass("SigmoidMulFusion", multigraph) {} explicit SigmoidMulFusion(bool multigraph = true) : LitePatternProcessPass("SigmoidMulFusion", multigraph) {}
~SigmoidMulFusion() override = default; ~SigmoidMulFusion() override = default;
private: private:

View File

@ -18,15 +18,15 @@
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_FUSION_H_ #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_FUSION_H_
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class SqueezeFusion : public PatternProcessPass { class SqueezeFusion : public LitePatternProcessPass {
public: public:
explicit SqueezeFusion(bool multigraph = true, const std::string &name = "SqueezeFusion") explicit SqueezeFusion(bool multigraph = true, const std::string &name = "SqueezeFusion")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~SqueezeFusion() override = default; ~SqueezeFusion() override = default;
private: private:

View File

@ -20,7 +20,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "tools/converter/quantizer/quant_param_holder.h" #include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "ops/fusion/scale_fusion.h" #include "ops/fusion/scale_fusion.h"
@ -28,10 +28,9 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class TensorDotFusion : public PatternProcessPass { class TensorDotFusion : public LitePatternProcessPass {
public: public:
explicit TensorDotFusion(bool multigraph = true) : PatternProcessPass("TensorDotFusion", multigraph) {} explicit TensorDotFusion(bool multigraph = true) : LitePatternProcessPass("TensorDotFusion", multigraph) {}
~TensorDotFusion() override = default; ~TensorDotFusion() override = default;
private: private:

View File

@ -18,6 +18,7 @@
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include <memory> #include <memory>
#include <functional> #include <functional>
#include "tools/optimizer/common/helper.h"
#include "ops/concat.h" #include "ops/concat.h"
#include "ops/gru.h" #include "ops/gru.h"
#include "ops/split.h" #include "ops/split.h"
@ -300,7 +301,7 @@ AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr
MS_CHECK_TRUE_RET(is_return != nullptr, nullptr); MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
VectorRef return_ref = VectorRef({is_return, logicaland_ref}); VectorRef return_ref = VectorRef({is_return, logicaland_ref});
VarPtr is_fg = std::make_shared<Var>("RootG"); VarPtr is_fg = std::make_shared<Var>("RootG");
auto pattern = SexpToNode(return_ref, is_fg, primitive_vars.get(), true); auto pattern = Helper::SexpToNode(return_ref, is_fg, primitive_vars.get(), true);
return pattern; return pattern;
} }
@ -373,7 +374,7 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr
VarPtr is_fg = std::make_shared<Var>("RootG"); VarPtr is_fg = std::make_shared<Var>("RootG");
MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr); MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
auto pattern = SexpToNode(return_node, is_fg, primitive_vars.get(), true); auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
return pattern; return pattern;
} }

View File

@ -21,7 +21,7 @@
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@ -29,11 +29,11 @@ namespace mindspore {
namespace opt { namespace opt {
constexpr size_t kWhileUniqInputsLength = 6; constexpr size_t kWhileUniqInputsLength = 6;
// fuse tf 2.x bidirection_gru into MSLITE GRU // fuse tf 2.x bidirection_gru into MSLITE GRU
class TfBidirectionGruFusion : public PatternProcessPass { class TfBidirectionGruFusion : public LitePatternProcessPass {
public: public:
explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength, explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
const std::string &name = "TfBidirectionGruFusion", bool multi_graph = true) const std::string &name = "TfBidirectionGruFusion", bool multi_graph = true)
: PatternProcessPass(name, multi_graph), num_fw_vars_(num_fw_vars), num_bw_vars_(num_bw_vars) {} : LitePatternProcessPass(name, multi_graph), num_fw_vars_(num_fw_vars), num_bw_vars_(num_bw_vars) {}
~TfBidirectionGruFusion() override = default; ~TfBidirectionGruFusion() override = default;

View File

@ -23,6 +23,7 @@
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/common/helper.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
namespace mindspore { namespace mindspore {
@ -179,7 +180,7 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi
VarPtr is_fg = std::make_shared<Var>("RootG"); VarPtr is_fg = std::make_shared<Var>("RootG");
MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr); MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
auto pattern = SexpToNode(return_node, is_fg, primitive_vars.get(), true); auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
return pattern; return pattern;
} }

View File

@ -26,6 +26,7 @@
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/helper.h"
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
@ -169,7 +170,7 @@ bool TfliteLstmCellFusion::Init() const {
TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num, TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
int cond_nodes_num, int cond_cnodes_num, int body_nodes_num, int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
int body_cnodes_num) int body_cnodes_num)
: PatternProcessPass(name, multigraph) { : LitePatternProcessPass(name, multigraph) {
/* /*
* input vars for lstm while node * input vars for lstm while node
* 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_ * 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
@ -207,7 +208,7 @@ AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &p
VectorRef return_ref = VectorRef({is_return, logicaland_ref}); VectorRef return_ref = VectorRef({is_return, logicaland_ref});
VarPtr fg = std::make_shared<Var>("RootG"); VarPtr fg = std::make_shared<Var>("RootG");
MS_CHECK_TRUE_RET(fg != nullptr, nullptr); MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true); auto pattern = Helper::SexpToNode(return_ref, fg, primitive_vars.get(), true);
return pattern; return pattern;
} }
@ -278,7 +279,7 @@ AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &p
VarPtr fg = std::make_shared<Var>("RootG"); VarPtr fg = std::make_shared<Var>("RootG");
MS_CHECK_TRUE_RET(fg != nullptr, nullptr); MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true); auto pattern = Helper::SexpToNode(return_node, fg, primitive_vars.get(), true);
return pattern; return pattern;
} }

View File

@ -19,13 +19,13 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "include/errorcode.h" #include "include/errorcode.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class TfliteLstmCellFusion : public PatternProcessPass { class TfliteLstmCellFusion : public LitePatternProcessPass {
public: public:
explicit TfliteLstmCellFusion(const std::string &name = "TfliteLstmCellFusion", bool multigraph = true, explicit TfliteLstmCellFusion(const std::string &name = "TfliteLstmCellFusion", bool multigraph = true,
int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0, int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0,

View File

@ -19,14 +19,14 @@
#include <string> #include <string>
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "backend/common/optimizer/optimizer.h" #include "tools/optimizer/common/pattern_process_pass_extends.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class AddTensorArray : public PatternProcessPass { class AddTensorArray : public LitePatternProcessPass {
public: public:
explicit AddTensorArray(bool multigraph = true, const std::string &name = "add_tensor_array") explicit AddTensorArray(bool multigraph = true, const std::string &name = "add_tensor_array")
: PatternProcessPass(name, multigraph) {} : LitePatternProcessPass(name, multigraph) {}
~AddTensorArray() override = default; ~AddTensorArray() override = default;
const BaseRef DefinePattern() const override; const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -27,6 +27,7 @@
#include "ops/op_utils.h" #include "ops/op_utils.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/helper.h"
#include "backend/common/optimizer/helper.h" #include "backend/common/optimizer/helper.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
@ -1514,7 +1515,7 @@ bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
if (output_tensor_num > 1) { if (output_tensor_num > 1) {
continue; continue;
} }
auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node)); auto output_node_list = Helper::GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
if (output_node_list->size() > 1) { // referenced by multi nodes if (output_node_list->size() > 1) { // referenced by multi nodes
if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) { if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
this_time_changed = true; this_time_changed = true;

View File

@ -20,8 +20,8 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "ir/anf.h" #include "ir/anf.h"
#include "tools/optimizer/common/node_pass_extends.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "backend/common/optimizer/node_pass.h"
#include "tools/optimizer/parallel/split_strategy.h" #include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/operator_info.h" #include "tools/optimizer/parallel/operator_info.h"
@ -30,10 +30,10 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
class ParallelPass : public opt::NodePass { class ParallelPass : public opt::LiteNodePass {
public: public:
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> &strategys, const int32_t fmk_type) explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> &strategys, const int32_t fmk_type)
: NodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {} : LiteNodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {}
~ParallelPass() override = default; ~ParallelPass() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;