forked from mindspore-Ecosystem/mindspore
support online infer
This commit is contained in:
parent
2f7531af8f
commit
88a1f88e83
|
@ -411,11 +411,11 @@ if(PLATFORM_ARM64)
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
|
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
|
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
|
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so
|
||||||
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
if(MSLITE_ENABLE_ACL)
|
if(MSLITE_ENABLE_ACL)
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
@ -650,11 +650,11 @@ elseif(PLATFORM_ARM32)
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
|
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
|
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
|
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so
|
||||||
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
if(MSLITE_ENABLE_ACL)
|
if(MSLITE_ENABLE_ACL)
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
@ -840,11 +840,11 @@ else()
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
|
install(FILES ${EXTENDRT_BUILD_DIR}/delegate/graph_executor/litert/${MINDSPORE_GE_LITERT_LIB_NAME}.so
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
|
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
|
install(FILES ${glog_LIBPATH}/libmindspore_glog.so.0.4.0 DESTINATION ${RUNTIME_LIB_DIR}
|
||||||
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
|
RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
install(TARGETS mindspore_core DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so
|
||||||
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
if(MSLITE_ENABLE_ACL)
|
if(MSLITE_ENABLE_ACL)
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so
|
||||||
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/common/optimizer/graph_optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
|
||||||
|
if (pass_manager != nullptr) {
|
||||||
|
pass_managers_.push_back(pass_manager);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
|
||||||
|
run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
|
||||||
|
// cppcheck-suppress *
|
||||||
|
auto manager = Manage(func_graph, true);
|
||||||
|
|
||||||
|
bool changed = true;
|
||||||
|
while (changed) {
|
||||||
|
changed = false;
|
||||||
|
for (size_t i = 0; i < pass_managers_.size(); ++i) {
|
||||||
|
const PassManagerPtr &pm = pass_managers_[i];
|
||||||
|
if (pm != nullptr && pm->Run(func_graph)) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (run_only_once_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<FuncGraphPtr> func_graphs;
|
||||||
|
func_graphs.push_back(func_graph);
|
||||||
|
(void)TopoSort(func_graph->get_return());
|
||||||
|
return func_graph;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/common/optimizer/pass_manager.h"
|
||||||
|
#include "include/backend/visible.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class BACKEND_EXPORT GraphOptimizer {
|
||||||
|
public:
|
||||||
|
explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
|
||||||
|
virtual ~GraphOptimizer() = default;
|
||||||
|
|
||||||
|
void AddPassManager(const PassManagerPtr &pass_manager);
|
||||||
|
FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string name_ = "graph_optimizer";
|
||||||
|
std::vector<PassManagerPtr> pass_managers_{};
|
||||||
|
bool run_only_once_ = true;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_
|
|
@ -28,7 +28,7 @@ class BACKEND_EXPORT NodePass : public Pass {
|
||||||
public:
|
public:
|
||||||
explicit NodePass(const std::string &name) : Pass(name) {}
|
explicit NodePass(const std::string &name) : Pass(name) {}
|
||||||
~NodePass() override = default;
|
~NodePass() override = default;
|
||||||
bool Run(const FuncGraphPtr &func_graph) final;
|
virtual bool Run(const FuncGraphPtr &func_graph);
|
||||||
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
|
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
|
||||||
};
|
};
|
||||||
using NodePassPtr = std::shared_ptr<NodePass>;
|
using NodePassPtr = std::shared_ptr<NodePass>;
|
||||||
|
|
|
@ -117,37 +117,5 @@ std::vector<AnfNodePtr> MultipleOutputPatternProcessPass::GetOrigNodes() const {
|
||||||
}
|
}
|
||||||
return orig_nodes;
|
return orig_nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
|
|
||||||
if (pass_manager != nullptr) {
|
|
||||||
pass_managers_.push_back(pass_manager);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) {
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once;
|
|
||||||
// cppcheck-suppress *
|
|
||||||
auto manager = Manage(func_graph, true);
|
|
||||||
|
|
||||||
bool changed = true;
|
|
||||||
while (changed) {
|
|
||||||
changed = false;
|
|
||||||
for (size_t i = 0; i < pass_managers_.size(); ++i) {
|
|
||||||
const PassManagerPtr &pm = pass_managers_[i];
|
|
||||||
if (pm != nullptr && pm->Run(func_graph)) {
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (run_only_once_) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<FuncGraphPtr> func_graphs;
|
|
||||||
func_graphs.push_back(func_graph);
|
|
||||||
(void)TopoSort(func_graph->get_return());
|
|
||||||
return func_graph;
|
|
||||||
}
|
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include "ir/graph_utils.h"
|
#include "ir/graph_utils.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
#include "backend/common/optimizer/helper.h"
|
#include "backend/common/optimizer/helper.h"
|
||||||
|
#include "backend/common/optimizer/graph_optimizer.h"
|
||||||
#include "include/backend/visible.h"
|
#include "include/backend/visible.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -77,20 +78,6 @@ class MultipleOutputPatternProcessPass : public PatternProcessPass {
|
||||||
PrimitiveVarMapPtr child_primitive_vars_;
|
PrimitiveVarMapPtr child_primitive_vars_;
|
||||||
EquivPtr child_equiv_;
|
EquivPtr child_equiv_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BACKEND_EXPORT GraphOptimizer {
|
|
||||||
public:
|
|
||||||
explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
|
|
||||||
virtual ~GraphOptimizer() = default;
|
|
||||||
|
|
||||||
void AddPassManager(const PassManagerPtr &pass_manager);
|
|
||||||
FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true);
|
|
||||||
|
|
||||||
private:
|
|
||||||
const std::string name_ = "graph_optimizer";
|
|
||||||
std::vector<PassManagerPtr> pass_managers_{};
|
|
||||||
bool run_only_once_ = true;
|
|
||||||
};
|
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -86,8 +86,6 @@ ShapeVector CacheManager::GetOutputShape(const AnfNodePtr &node, size_t index) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
|
|
||||||
|
|
||||||
void PassManager::AddPass(const PassPtr &pass) {
|
void PassManager::AddPass(const PassPtr &pass) {
|
||||||
if (pass != nullptr) {
|
if (pass != nullptr) {
|
||||||
passes_.push_back(pass);
|
passes_.push_back(pass);
|
||||||
|
|
|
@ -48,7 +48,7 @@ class BACKEND_EXPORT PassManager {
|
||||||
: name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
|
: name_(name), passes_{}, run_only_once_(run_only_once), cache_manager_(std::make_shared<CacheManager>()) {}
|
||||||
virtual ~PassManager() = default;
|
virtual ~PassManager() = default;
|
||||||
// Get all the passes added by AddPass
|
// Get all the passes added by AddPass
|
||||||
const std::vector<PassPtr> &Passes() const;
|
const std::vector<PassPtr> &Passes() const { return passes_; }
|
||||||
// Add graph pass, the pass object will be freed when pass manager freed.
|
// Add graph pass, the pass object will be freed when pass manager freed.
|
||||||
virtual void AddPass(const PassPtr &pass);
|
virtual void AddPass(const PassPtr &pass);
|
||||||
// Run passes added in pass manager on the input graph
|
// Run passes added in pass manager on the input graph
|
||||||
|
|
|
@ -82,6 +82,7 @@ class TensorDefaultImpl : public MSTensor::Impl {
|
||||||
const std::string &Name() const override { return name_; }
|
const std::string &Name() const override { return name_; }
|
||||||
enum DataType DataType() const override { return type_; }
|
enum DataType DataType() const override { return type_; }
|
||||||
const std::vector<int64_t> &Shape() const override { return shape_; }
|
const std::vector<int64_t> &Shape() const override { return shape_; }
|
||||||
|
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; }
|
||||||
|
|
||||||
std::shared_ptr<const void> Data() const override {
|
std::shared_ptr<const void> Data() const override {
|
||||||
return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {});
|
return std::shared_ptr<const void>(buffer_.Data(), [](const void *) {});
|
||||||
|
@ -115,6 +116,7 @@ class TensorReferenceImpl : public MSTensor::Impl {
|
||||||
const std::string &Name() const override { return name_; }
|
const std::string &Name() const override { return name_; }
|
||||||
enum DataType DataType() const override { return type_; }
|
enum DataType DataType() const override { return type_; }
|
||||||
const std::vector<int64_t> &Shape() const override { return shape_; }
|
const std::vector<int64_t> &Shape() const override { return shape_; }
|
||||||
|
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; }
|
||||||
|
|
||||||
std::shared_ptr<const void> Data() const override {
|
std::shared_ptr<const void> Data() const override {
|
||||||
return std::shared_ptr<const void>(data_, [](const void *) {});
|
return std::shared_ptr<const void>(data_, [](const void *) {});
|
||||||
|
|
|
@ -44,6 +44,7 @@ class DETensor : public mindspore::MSTensor::Impl {
|
||||||
size_t DataSize() const override;
|
size_t DataSize() const override;
|
||||||
|
|
||||||
const std::vector<int64_t> &Shape() const override;
|
const std::vector<int64_t> &Shape() const override;
|
||||||
|
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; };
|
||||||
|
|
||||||
int64_t ElementNum() const;
|
int64_t ElementNum() const;
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,13 @@ class Factory {
|
||||||
(void)kernel_mod_creators_.emplace(name, creator);
|
(void)kernel_mod_creators_.emplace(name, creator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UnRegister(const std::string &name) {
|
||||||
|
auto iter = kernel_mod_creators_.find(name);
|
||||||
|
if (iter != kernel_mod_creators_.end()) {
|
||||||
|
kernel_mod_creators_.erase(iter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<C> Create(const std::string &name) const {
|
std::shared_ptr<C> Create(const std::string &name) const {
|
||||||
typename std::map<std::string, CreatorFunc>::const_iterator iter = kernel_mod_creators_.find(name);
|
typename std::map<std::string, CreatorFunc>::const_iterator iter = kernel_mod_creators_.find(name);
|
||||||
if (iter != kernel_mod_creators_.cend()) {
|
if (iter != kernel_mod_creators_.cend()) {
|
||||||
|
|
|
@ -33,6 +33,7 @@ class MSTensor::Impl {
|
||||||
virtual const std::string &Name() const = 0;
|
virtual const std::string &Name() const = 0;
|
||||||
virtual enum DataType DataType() const = 0;
|
virtual enum DataType DataType() const = 0;
|
||||||
virtual const std::vector<int64_t> &Shape() const = 0;
|
virtual const std::vector<int64_t> &Shape() const = 0;
|
||||||
|
virtual void SetShape(const std::vector<int64_t> &shape) = 0;
|
||||||
|
|
||||||
virtual std::shared_ptr<const void> Data() const = 0;
|
virtual std::shared_ptr<const void> Data() const = 0;
|
||||||
virtual void *MutableData() = 0;
|
virtual void *MutableData() = 0;
|
||||||
|
|
|
@ -31,6 +31,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend/plugin/ascend_kernel_plugin.cc)
|
${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend/plugin/ascend_kernel_plugin.cc)
|
||||||
|
|
||||||
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
|
set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC}
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/mindir_loader/mindir_model/mindir_model_util.cc
|
||||||
${MSLITE_KERNEL_PLUGIN}
|
${MSLITE_KERNEL_PLUGIN}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/file_utils.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/file_utils.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/../common/utils.cc
|
||||||
|
@ -56,6 +57,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/session/lite_infer_session.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/session/lite_infer_session.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/convert/runtime_convert.cc
|
||||||
)
|
)
|
||||||
# when cpu kernel is need
|
# when cpu kernel is need
|
||||||
#if(NOT MSLITE_ENABLE_ACL)
|
#if(NOT MSLITE_ENABLE_ACL)
|
||||||
|
@ -191,6 +193,10 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
|
||||||
add_subdirectory(delegate/tensorrt)
|
add_subdirectory(delegate/tensorrt)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(MSLITE_ENABLE_CONVERTER)
|
||||||
|
add_subdirectory(convert)
|
||||||
|
endif()
|
||||||
|
|
||||||
set(TEST_CLOUD_INFER on)
|
set(TEST_CLOUD_INFER on)
|
||||||
|
|
||||||
if(TEST_CLOUD_INFER)
|
if(TEST_CLOUD_INFER)
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
include_directories(${TOP_DIR})
|
||||||
|
include_directories(${TOP_DIR}/mindspore/lite)
|
||||||
|
|
||||||
|
file(GLOB RUNTIME_CONVERT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
|
||||||
|
|
||||||
|
add_library(runtime_convert_plugin SHARED ${RUNTIME_CONVERT_SRC})
|
||||||
|
add_dependencies(runtime_convert_plugin fbs_inner_src)
|
||||||
|
|
||||||
|
if(MSLITE_ENABLE_CONVERTER AND NOT WIN32)
|
||||||
|
target_link_libraries(runtime_convert_plugin mindspore_converter)
|
||||||
|
endif()
|
|
@ -0,0 +1,79 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include <vector>
|
||||||
|
#include "src/extendrt/convert/runtime_convert.h"
|
||||||
|
#include "tools/common/string_util.h"
|
||||||
|
#include "tools/converter/converter.h"
|
||||||
|
#include "tools/converter/cxx_api/converter_para.h"
|
||||||
|
|
||||||
|
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
|
||||||
|
const std::shared_ptr<mindspore::Context> &context) {
|
||||||
|
auto param = std::make_shared<mindspore::ConverterPara>();
|
||||||
|
if (param == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "New ConverterPara failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
param->fmk_type = mindspore::converter::kFmkTypeMs;
|
||||||
|
param->input_data_type = mindspore::DataType::kTypeUnknown;
|
||||||
|
param->output_data_type = mindspore::DataType::kTypeUnknown;
|
||||||
|
param->weight_fp16 = false;
|
||||||
|
param->train_model = false;
|
||||||
|
param->export_mindir = mindspore::kMindIR;
|
||||||
|
param->enable_encryption = false;
|
||||||
|
|
||||||
|
auto device_list = context->MutableDeviceInfo();
|
||||||
|
for (auto &device : device_list) {
|
||||||
|
if (device->GetDeviceType() == mindspore::kAscend) {
|
||||||
|
auto ascend_info = device->Cast<mindspore::AscendDeviceInfo>();
|
||||||
|
std::string dynamic_batch_size = ascend_info->GetDynamicBatchSize();
|
||||||
|
if (!dynamic_batch_size.empty()) {
|
||||||
|
std::vector<std::string> batch_size_string = mindspore::lite::SplitStringToVector(dynamic_batch_size, ',');
|
||||||
|
for (const auto &item : batch_size_string) {
|
||||||
|
int32_t val;
|
||||||
|
if (mindspore::lite::ConvertIntNum(item, &val)) {
|
||||||
|
size_t tmp_val = static_cast<size_t>(val);
|
||||||
|
param->aclModelOptionCfgParam.dynamic_batch_size.push_back(tmp_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
param->aclModelOptionCfgParam.offline = false;
|
||||||
|
param->aclModelOptionCfgParam.device_id = ascend_info->GetDeviceID();
|
||||||
|
param->aclModelOptionCfgParam.output_type = ascend_info->GetOutputType();
|
||||||
|
param->aclModelOptionCfgParam.input_shape_map = ascend_info->GetInputShapeMap();
|
||||||
|
param->aclModelOptionCfgParam.input_format = ascend_info->GetInputFormat();
|
||||||
|
param->aclModelOptionCfgParam.input_shape = ascend_info->GetInputShape();
|
||||||
|
param->aclModelOptionCfgParam.precision_mode = ascend_info->GetPrecisionMode();
|
||||||
|
param->aclModelOptionCfgParam.op_select_impl_mode = ascend_info->GetOpSelectImplMode();
|
||||||
|
param->aclModelOptionCfgParam.fusion_switch_config_file_path = ascend_info->GetFusionSwitchConfigPath();
|
||||||
|
param->aclModelOptionCfgParam.buffer_optimize = ascend_info->GetBufferOptimizeMode();
|
||||||
|
param->aclModelOptionCfgParam.insert_op_config_file_path = ascend_info->GetInsertOpConfigPath();
|
||||||
|
param->aclModelOptionCfgParam.dynamic_image_size = ascend_info->GetDynamicImageSize();
|
||||||
|
param->device = "Ascend";
|
||||||
|
param->no_fusion = false;
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mindspore::lite::ConverterImpl cvt;
|
||||||
|
void *dst_buff;
|
||||||
|
auto ret = cvt.Convert(param, nullptr, model_buf, buf_size, &dst_buff, dst_size);
|
||||||
|
if (ret != mindspore::lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Convert model failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return dst_buff;
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CONVERT_RUNTIME_CONVERT_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_EXTENDRT_CONVERT_RUNTIME_CONVERT_H_
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "include/api/context.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
|
||||||
|
const std::shared_ptr<mindspore::Context> &context);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CONVERT_RUNTIME_CONVERT_H_
|
|
@ -68,13 +68,15 @@ inline Status DLSoPath(const std::string &benchmark_so, const std::string &targe
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function) {
|
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function,
|
||||||
|
bool runtime_convert = false) {
|
||||||
// do dlopen and export functions from c_dataengine
|
// do dlopen and export functions from c_dataengine
|
||||||
if (handle == nullptr) {
|
if (handle == nullptr) {
|
||||||
MS_LOG(WARNING) << "Input parameter handle cannot be nullptr";
|
MS_LOG(WARNING) << "Input parameter handle cannot be nullptr";
|
||||||
return Status(kMEFailed, "Input parameter handle cannot be nullptr");
|
return Status(kMEFailed, "Input parameter handle cannot be nullptr");
|
||||||
}
|
}
|
||||||
*handle = dlopen(dl_path.c_str(), RTLD_LAZY | RTLD_LOCAL);
|
int mode = runtime_convert ? RTLD_GLOBAL : RTLD_LOCAL;
|
||||||
|
*handle = dlopen(dl_path.c_str(), RTLD_LAZY | mode);
|
||||||
|
|
||||||
auto get_dl_error = []() -> std::string {
|
auto get_dl_error = []() -> std::string {
|
||||||
auto error = dlerror();
|
auto error = dlerror();
|
||||||
|
|
|
@ -21,12 +21,38 @@
|
||||||
#include "extendrt/cxx_api/file_utils.h"
|
#include "extendrt/cxx_api/file_utils.h"
|
||||||
#include "extendrt/utils/tensor_utils.h"
|
#include "extendrt/utils/tensor_utils.h"
|
||||||
#include "mindspore/core/utils/ms_context.h"
|
#include "mindspore/core/utils/ms_context.h"
|
||||||
|
#include "extendrt/mindir_loader/mindir_model/mindir_model_util.h"
|
||||||
|
#include "src/extendrt/convert/runtime_convert.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context) {
|
const std::shared_ptr<Context> &model_context) {
|
||||||
|
const void *model_buff = model_data;
|
||||||
|
size_t model_size = data_size;
|
||||||
|
#ifndef _WIN32
|
||||||
|
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
|
||||||
|
MS_LOG(WARNING) << "Need runtime convert";
|
||||||
|
std::string plugin_path;
|
||||||
|
auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(WARNING) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
|
||||||
|
}
|
||||||
|
void *function = nullptr;
|
||||||
|
ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
|
||||||
|
if (ret != kSuccess) {
|
||||||
|
MS_LOG(WARNING) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
|
||||||
|
}
|
||||||
|
auto convert =
|
||||||
|
reinterpret_cast<void *(*)(const char *, const size_t &, size_t *, const std::shared_ptr<Context> &)>(function);
|
||||||
|
if (convert != nullptr) {
|
||||||
|
model_buff = convert(static_cast<const char *>(model_data), data_size, &model_size, model_context);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "Not need runtime convert";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
graph_ = std::make_shared<Graph>();
|
graph_ = std::make_shared<Graph>();
|
||||||
auto ret = Serialization::Load(model_data, data_size, model_type, graph_.get());
|
auto ret = Serialization::Load(model_buff, model_size, model_type, graph_.get());
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Serialization::Load model failed.";
|
MS_LOG(ERROR) << "Serialization::Load model failed.";
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -47,7 +73,7 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
|
||||||
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
|
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
return session_->CompileGraph(graph_->graph_data_->GetFuncGraph(), model_data, data_size);
|
return session_->CompileGraph(graph_->graph_data_->GetFuncGraph(), model_buff, model_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
|
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
|
||||||
|
|
|
@ -29,11 +29,21 @@
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "extendrt/infer_session.h"
|
#include "extendrt/infer_session.h"
|
||||||
|
#ifndef _WIN32
|
||||||
|
#include <dlfcn.h>
|
||||||
|
#endif
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelImpl {
|
class ModelImpl {
|
||||||
public:
|
public:
|
||||||
ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
|
ModelImpl() : graph_(nullptr), session_(nullptr), context_(nullptr) {}
|
||||||
~ModelImpl() = default;
|
~ModelImpl() {
|
||||||
|
#ifndef _WIN32
|
||||||
|
if (handle_ != nullptr) {
|
||||||
|
(void)dlclose(handle_);
|
||||||
|
handle_ = nullptr;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
Status Build(const void *model_data, size_t data_size, ModelType model_type,
|
||||||
const std::shared_ptr<Context> &model_context);
|
const std::shared_ptr<Context> &model_context);
|
||||||
|
@ -56,6 +66,9 @@ class ModelImpl {
|
||||||
std::shared_ptr<Graph> graph_ = nullptr;
|
std::shared_ptr<Graph> graph_ = nullptr;
|
||||||
std::shared_ptr<InferSession> session_ = nullptr;
|
std::shared_ptr<InferSession> session_ = nullptr;
|
||||||
std::shared_ptr<Context> context_ = nullptr;
|
std::shared_ptr<Context> context_ = nullptr;
|
||||||
|
#ifndef _WIN32
|
||||||
|
void *handle_ = nullptr;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_MODEL_IMPL_H_
|
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_MODEL_MODEL_IMPL_H_
|
||||||
|
|
|
@ -247,7 +247,5 @@ bool CustomAscendKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
|
||||||
UpdateOutputAddr(outputs);
|
UpdateOutputAddr(outputs);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_KERNEL_FACTORY_REG(KernelMod, CustomAscend, CustomAscendKernelMod);
|
|
||||||
} // namespace acl
|
} // namespace acl
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ir/value.h"
|
#include "ir/value.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
#include "src/common/common.h"
|
||||||
|
|
||||||
namespace mindspore::infer::mindir {
|
namespace mindspore::infer::mindir {
|
||||||
static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
|
static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
|
||||||
|
@ -191,4 +192,34 @@ mindspore::TypeId MindirModelUtil::ProtoTypeToTypeId(int32_t proto_type) {
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool MindirModelUtil::NeedRuntimeConvert(const void *model_data, size_t data_size) {
|
||||||
|
bool need_runtime_convert = true;
|
||||||
|
mind_ir::ModelProto model_proto;
|
||||||
|
std::string str(static_cast<const char *>(model_data), data_size);
|
||||||
|
if (model_proto.ParseFromString(str)) {
|
||||||
|
mind_ir::GraphProto *graph_proto = model_proto.mutable_graph();
|
||||||
|
if (graph_proto != nullptr) {
|
||||||
|
for (int i = 0; i < graph_proto->attribute_size(); ++i) {
|
||||||
|
const mind_ir::AttributeProto &attr_proto = graph_proto->attribute(i);
|
||||||
|
if (attr_proto.has_name() && attr_proto.name() == lite::kIsOptimized) {
|
||||||
|
const int attr_type = static_cast<int>(attr_proto.type());
|
||||||
|
if (attr_type != mind_ir::AttributeProto_AttributeType_BOOL) {
|
||||||
|
MS_LOG(ERROR) << "The type of attr optimized value must be bool.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (static_cast<bool>(attr_proto.i())) {
|
||||||
|
need_runtime_convert = false;
|
||||||
|
MS_LOG(DEBUG) << "No need to online infer.";
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "Not mindir model";
|
||||||
|
need_runtime_convert = false;
|
||||||
|
}
|
||||||
|
return need_runtime_convert;
|
||||||
|
}
|
||||||
} // namespace mindspore::infer::mindir
|
} // namespace mindspore::infer::mindir
|
||||||
|
|
|
@ -33,6 +33,7 @@ class MindirModelUtil {
|
||||||
static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto);
|
static mindspore::ValuePtr MakeValueFromScalarAttribute(const mind_ir::AttributeProto &attr_proto);
|
||||||
|
|
||||||
static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type);
|
static mindspore::TypeId ProtoTypeToTypeId(int32_t proto_type);
|
||||||
|
static bool NeedRuntimeConvert(const void *model_data, size_t data_size);
|
||||||
};
|
};
|
||||||
} // namespace mindspore::infer::mindir
|
} // namespace mindspore::infer::mindir
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,10 @@ namespace mindspore {
|
||||||
const size_t tensor_max_size = 0x1000000;
|
const size_t tensor_max_size = 0x1000000;
|
||||||
constexpr auto kNameCustomAscend = "CustomAscend";
|
constexpr auto kNameCustomAscend = "CustomAscend";
|
||||||
|
|
||||||
|
SingleOpInferSession::~SingleOpInferSession() {
|
||||||
|
kernel::Factory<kernel::KernelMod>::Instance().UnRegister(kNameCustomAscend);
|
||||||
|
}
|
||||||
|
|
||||||
Status SingleOpInferSession::AscendInit(const std::shared_ptr<Context> &context) {
|
Status SingleOpInferSession::AscendInit(const std::shared_ptr<Context> &context) {
|
||||||
auto device_list = context->MutableDeviceInfo();
|
auto device_list = context->MutableDeviceInfo();
|
||||||
for (const auto &device_info : device_list) {
|
for (const auto &device_info : device_list) {
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace mindspore {
|
||||||
class SingleOpInferSession : public InferSession {
|
class SingleOpInferSession : public InferSession {
|
||||||
public:
|
public:
|
||||||
SingleOpInferSession() = default;
|
SingleOpInferSession() = default;
|
||||||
virtual ~SingleOpInferSession() = default;
|
~SingleOpInferSession() override;
|
||||||
Status Init(const std::shared_ptr<Context> &context) override;
|
Status Init(const std::shared_ptr<Context> &context) override;
|
||||||
Status AscendInit(const std::shared_ptr<Context> &context);
|
Status AscendInit(const std::shared_ptr<Context> &context);
|
||||||
Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override;
|
Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0) override;
|
||||||
|
|
|
@ -32,11 +32,11 @@ bool EraseQuotes(std::string *input_string);
|
||||||
|
|
||||||
bool FindAndReplaceAll(std::string *input_str, const std::string &search, const std::string &replace);
|
bool FindAndReplaceAll(std::string *input_str, const std::string &search, const std::string &replace);
|
||||||
|
|
||||||
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);
|
MS_API std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);
|
||||||
|
|
||||||
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const std::string &delimiter);
|
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const std::string &delimiter);
|
||||||
|
|
||||||
bool ConvertIntNum(const std::string &str, int *value);
|
MS_API bool ConvertIntNum(const std::string &str, int *value);
|
||||||
|
|
||||||
bool ConvertDoubleNum(const std::string &str, double *value);
|
bool ConvertDoubleNum(const std::string &str, double *value);
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||||
set(CCSRC_SRC
|
set(CCSRC_SRC
|
||||||
${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc
|
${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc
|
||||||
${CCSRC_DIR}/backend/common/optimizer/visit.cc
|
${CCSRC_DIR}/backend/common/optimizer/visit.cc
|
||||||
${CCSRC_DIR}/backend/common/optimizer/optimizer.cc
|
${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
if(MSLITE_ENABLE_GRAPH_KERNEL)
|
if(MSLITE_ENABLE_GRAPH_KERNEL)
|
||||||
|
@ -26,6 +26,7 @@ if(MSLITE_ENABLE_GRAPH_KERNEL)
|
||||||
${CCSRC_DIR}/common/graph_kernel/graph_kernel_flags.cc
|
${CCSRC_DIR}/common/graph_kernel/graph_kernel_flags.cc
|
||||||
${CCSRC_DIR}/kernel/akg/akg_kernel_json_generator.cc
|
${CCSRC_DIR}/kernel/akg/akg_kernel_json_generator.cc
|
||||||
${CCSRC_DIR}/backend/common/pass/getitem_tuple.cc
|
${CCSRC_DIR}/backend/common/pass/getitem_tuple.cc
|
||||||
|
${CCSRC_DIR}/backend/common/optimizer/optimizer.cc
|
||||||
)
|
)
|
||||||
set(CCSRC_SRC
|
set(CCSRC_SRC
|
||||||
${CCSRC_SRC}
|
${CCSRC_SRC}
|
||||||
|
|
|
@ -58,7 +58,7 @@ Pass *AclPassPlugin::CreateAclPass(const std::shared_ptr<ConverterPara> ¶m)
|
||||||
void *function = nullptr;
|
void *function = nullptr;
|
||||||
auto ret = DLSoOpen(real_path_, "CreateAclPass", &handle_, &function);
|
auto ret = DLSoOpen(real_path_, "CreateAclPass", &handle_, &function);
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << real_path_;
|
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << real_path_ << ", ret: " << ret;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto create_func = reinterpret_cast<mindspore::opt::Pass *(*)(const std::shared_ptr<ConverterPara> &)>(function);
|
auto create_func = reinterpret_cast<mindspore::opt::Pass *(*)(const std::shared_ptr<ConverterPara> &)>(function);
|
||||||
|
|
|
@ -35,6 +35,7 @@
|
||||||
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "tools/optimizer/graph/specify_graph_input_format.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -57,9 +58,21 @@ constexpr size_t kDependInputNum = 3;
|
||||||
constexpr size_t kDependFirstInputIdx = 1;
|
constexpr size_t kDependFirstInputIdx = 1;
|
||||||
constexpr size_t kTupleGetItemFirstInputIdx = 1;
|
constexpr size_t kTupleGetItemFirstInputIdx = 1;
|
||||||
|
|
||||||
STATUS PreProcForMindIr(const FuncGraphPtr &func_graph) { return lite::RET_OK; }
|
STATUS PreProcForMindIr(const FuncGraphPtr &func_graph, bool offline) { return lite::RET_OK; }
|
||||||
|
|
||||||
STATUS PreProcForTF(const FuncGraphPtr &func_graph) {
|
STATUS PreProcForTF(const FuncGraphPtr &func_graph, bool offline) {
|
||||||
|
if (!offline) {
|
||||||
|
auto format_pass = std::make_shared<opt::SpecifyGraphInputFormat>(Format::NCHW, Format::NHWC);
|
||||||
|
MS_CHECK_TRUE_MSG(format_pass != nullptr, lite::RET_ERROR, "Make shared specify graph input format failed.");
|
||||||
|
if (!format_pass->Run(func_graph)) {
|
||||||
|
MS_LOG(ERROR) << "Run specify graph input format pass failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
if (!lite::RunOptimizerPass(func_graph, {kToNHWCFormatPass, kDelRedundantTranspose})) {
|
||||||
|
MS_LOG(ERROR) << "To nhwc format failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass})) {
|
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass})) {
|
||||||
MS_LOG(ERROR) << "Infer shape pass failed.";
|
MS_LOG(ERROR) << "Infer shape pass failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
@ -85,7 +98,14 @@ STATUS PreProcForTF(const FuncGraphPtr &func_graph) {
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS PreProcForCaffe(const FuncGraphPtr &func_graph) {
|
STATUS PreProcForCaffe(const FuncGraphPtr &func_graph, bool offline) {
|
||||||
|
if (!offline) {
|
||||||
|
if (!lite::RunOptimizerPass(func_graph, {kDelRedundantTranspose})) {
|
||||||
|
MS_LOG(ERROR) << "Del redundant transpose failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
||||||
MS_LOG(ERROR) << "To nchw format failed.";
|
MS_LOG(ERROR) << "To nchw format failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
@ -93,7 +113,14 @@ STATUS PreProcForCaffe(const FuncGraphPtr &func_graph) {
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS PreProcForOnnx(const FuncGraphPtr &func_graph) {
|
STATUS PreProcForOnnx(const FuncGraphPtr &func_graph, bool offline) {
|
||||||
|
if (!offline) {
|
||||||
|
if (!lite::RunOptimizerPass(func_graph, {kDelRedundantTranspose})) {
|
||||||
|
MS_LOG(ERROR) << "Del redundant transpose failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
if (!lite::RunOptimizerPass(func_graph, {kInferShapePass, kToNCHWFormatPass, kDelRedundantTranspose})) {
|
||||||
MS_LOG(ERROR) << "To nchw format failed.";
|
MS_LOG(ERROR) << "To nchw format failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
@ -138,14 +165,14 @@ STATUS AclPassImpl::PreProcGraph(const FuncGraphPtr &func_graph) {
|
||||||
MS_LOG(ERROR) << "Common pass failed.";
|
MS_LOG(ERROR) << "Common pass failed.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
std::map<converter::FmkType, std::function<STATUS(const FuncGraphPtr &)>> fmk_proc_func = {
|
static std::map<converter::FmkType, std::function<STATUS(const FuncGraphPtr &, bool)>> fmk_proc_func = {
|
||||||
{converter::kFmkTypeMs, PreProcForMindIr}, {converter::kFmkTypeTf, PreProcForTF},
|
{converter::kFmkTypeMs, PreProcForMindIr}, {converter::kFmkTypeTf, PreProcForTF},
|
||||||
{converter::kFmkTypeCaffe, PreProcForCaffe}, {converter::kFmkTypeOnnx, PreProcForOnnx},
|
{converter::kFmkTypeCaffe, PreProcForCaffe}, {converter::kFmkTypeOnnx, PreProcForOnnx},
|
||||||
{converter::kFmkTypeTflite, PreProcForTF},
|
{converter::kFmkTypeTflite, PreProcForTF},
|
||||||
};
|
};
|
||||||
if (fmk_proc_func.find(fmk_type_) != fmk_proc_func.end()) {
|
if (fmk_proc_func.find(fmk_type_) != fmk_proc_func.end()) {
|
||||||
auto func = fmk_proc_func.at(fmk_type_);
|
auto func = fmk_proc_func.at(fmk_type_);
|
||||||
if (func(func_graph) != lite::RET_OK) {
|
if (func(func_graph, user_options_cfg_.offline) != lite::RET_OK) {
|
||||||
MS_LOG(ERROR) << "Pre proc failed, fmk " << fmk_type_;
|
MS_LOG(ERROR) << "Pre proc failed, fmk " << fmk_type_;
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -164,10 +191,6 @@ STATUS AclPassImpl::PostProcGraph(const FuncGraphPtr &func_graph) {
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is nullptr.");
|
MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is nullptr.");
|
||||||
manager->Reset();
|
manager->Reset();
|
||||||
if (!user_options_cfg_.offline) {
|
|
||||||
MS_LOG(DEBUG) << "Online model infer no need to change to nhwc format.";
|
|
||||||
return lite::RET_OK;
|
|
||||||
}
|
|
||||||
if (fmk_type_ == converter::kFmkTypeTf) {
|
if (fmk_type_ == converter::kFmkTypeTf) {
|
||||||
MS_LOG(DEBUG) << "Tf no need to change to nhwc format.";
|
MS_LOG(DEBUG) << "Tf no need to change to nhwc format.";
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "tools/converter/optimizer_manager.h"
|
#include "tools/converter/optimizer_manager.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "tools/optimizer/common/pass_manager_extends.h"
|
||||||
#include "ir/primitive.h"
|
#include "ir/primitive.h"
|
||||||
#include "tools/optimizer/fusion/affine_activation_fusion.h"
|
#include "tools/optimizer/fusion/affine_activation_fusion.h"
|
||||||
#include "tools/optimizer/fusion/affine_fusion.h"
|
#include "tools/optimizer/fusion/affine_fusion.h"
|
||||||
|
@ -111,6 +112,10 @@
|
||||||
|
|
||||||
using std::string;
|
using std::string;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
namespace {
|
||||||
|
constexpr auto kOriginalFmkType = "original_fmk_type";
|
||||||
|
} // namespace
|
||||||
|
|
||||||
AnfTransform::AnfTransform() = default;
|
AnfTransform::AnfTransform() = default;
|
||||||
|
|
||||||
AnfTransform::~AnfTransform() = default;
|
AnfTransform::~AnfTransform() = default;
|
||||||
|
@ -195,7 +200,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
|
||||||
}
|
}
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
CHECK_NULL_RETURN(optimizer);
|
CHECK_NULL_RETURN(optimizer);
|
||||||
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
|
auto fusion_pm = std::make_shared<opt::LitePassManager>("anf fusion pass manager", false);
|
||||||
CHECK_NULL_RETURN(fusion_pm);
|
CHECK_NULL_RETURN(fusion_pm);
|
||||||
|
|
||||||
// The training model only does the fusion of the inference part
|
// The training model only does the fusion of the inference part
|
||||||
|
@ -306,7 +311,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shar
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
|
opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
|
||||||
auto parallel_pm = std::make_shared<opt::PassManager>("anf parallel pass manager", true);
|
auto parallel_pm = std::make_shared<opt::LitePassManager>("anf parallel pass manager", true);
|
||||||
CHECK_NULL_RETURN(parallel_pm);
|
CHECK_NULL_RETURN(parallel_pm);
|
||||||
// 2. preceding parallel pass
|
// 2. preceding parallel pass
|
||||||
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
|
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
|
||||||
|
@ -333,7 +338,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shar
|
||||||
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
CHECK_NULL_RETURN(optimizer);
|
CHECK_NULL_RETURN(optimizer);
|
||||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
|
||||||
CHECK_NULL_RETURN(graph_pm);
|
CHECK_NULL_RETURN(graph_pm);
|
||||||
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
|
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
|
||||||
param->fmk_type == converter::kFmkTypeOnnx) {
|
param->fmk_type == converter::kFmkTypeOnnx) {
|
||||||
|
@ -377,7 +382,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
|
||||||
}
|
}
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
CHECK_NULL_RETURN(optimizer);
|
CHECK_NULL_RETURN(optimizer);
|
||||||
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
|
auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
|
||||||
CHECK_NULL_RETURN(convert_pm);
|
CHECK_NULL_RETURN(convert_pm);
|
||||||
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
|
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
|
||||||
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
||||||
|
@ -393,7 +398,7 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::share
|
||||||
|
|
||||||
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
|
auto const_fold_pm = std::make_shared<opt::LitePassManager>("const fold fusion pass manager", false);
|
||||||
CHECK_NULL_RETURN(optimizer);
|
CHECK_NULL_RETURN(optimizer);
|
||||||
CHECK_NULL_RETURN(const_fold_pm);
|
CHECK_NULL_RETURN(const_fold_pm);
|
||||||
const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
||||||
|
@ -420,7 +425,7 @@ int RunDecreaseTransposePass(const FuncGraphPtr &old_graph, const std::shared_pt
|
||||||
}
|
}
|
||||||
|
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto decrease_trans_pm = std::make_shared<opt::PassManager>("decrease transpose fusion pass manager", false);
|
auto decrease_trans_pm = std::make_shared<opt::LitePassManager>("decrease transpose fusion pass manager", false);
|
||||||
CHECK_NULL_RETURN(optimizer);
|
CHECK_NULL_RETURN(optimizer);
|
||||||
CHECK_NULL_RETURN(decrease_trans_pm);
|
CHECK_NULL_RETURN(decrease_trans_pm);
|
||||||
decrease_trans_pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
|
decrease_trans_pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
|
||||||
|
@ -566,44 +571,53 @@ bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const std::shared_
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
STATUS AnfTransform::ProcOnlineTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
const std::shared_ptr<ConverterPara> ¶m) {
|
if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
|
||||||
MS_ASSERT(old_graph != nullptr);
|
MS_LOG(WARNING) << "Run infershape opt pass failed.";
|
||||||
MS_ASSERT(param != nullptr);
|
}
|
||||||
|
auto status = DoFormatForMindIR(old_graph, param);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Do format for mindir failed.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
old_graph->set_attr(kOriginalFmkType, MakeValue(static_cast<int32_t>(param->fmk_type)));
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int AnfTransform::RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
auto status = RunConvertPass(old_graph, param);
|
auto status = RunConvertPass(old_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run convert pass failed.";
|
MS_LOG(ERROR) << "Run convert pass failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
|
if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
|
||||||
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
|
MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = RunConstFoldPass(old_graph, param);
|
status = RunConstFoldPass(old_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run const fold pass failed.";
|
MS_LOG(ERROR) << "Run const fold pass failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RunEliminateRedundantPass(old_graph, param)) {
|
if (!RunEliminateRedundantPass(old_graph, param)) {
|
||||||
MS_LOG(ERROR) << "Run elimination of redundant pass failed.";
|
MS_LOG(ERROR) << "Run elimination of redundant pass failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!param->no_fusion) {
|
if (!param->no_fusion) {
|
||||||
status = RunFusionPass(old_graph, param);
|
status = RunFusionPass(old_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run fusion pass failed.";
|
MS_LOG(ERROR) << "Run fusion pass failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RunExternalPass(old_graph, registry::POSITION_END)) {
|
if (!RunExternalPass(old_graph, registry::POSITION_END)) {
|
||||||
MS_LOG(ERROR) << "Run external pass failed, place is END";
|
MS_LOG(ERROR) << "Run external pass failed, place is END";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
|
if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
|
||||||
|
@ -616,18 +630,37 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
||||||
}
|
}
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
MS_LOG(ERROR) << "Run transpose opt pass failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = RunGraphPass(old_graph, param);
|
status = RunGraphPass(old_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run convert pass failed.";
|
MS_LOG(ERROR) << "Run convert pass failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = RunParallelPass(old_graph, param);
|
status = RunParallelPass(old_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run convert pass failed.";
|
MS_LOG(ERROR) << "Run convert pass failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
||||||
|
const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
|
MS_ASSERT(old_graph != nullptr);
|
||||||
|
MS_ASSERT(param != nullptr);
|
||||||
|
if (param->no_fusion && param->export_mindir == kMindIR) {
|
||||||
|
if (ProcOnlineTransform(old_graph, param) != lite::RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Proc online transform failed.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return old_graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (RunPass(old_graph, param) != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Proc online transform failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -636,7 +669,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = QATTransform(old_graph, param);
|
auto status = QATTransform(old_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Do QATTransform failed.";
|
MS_LOG(ERROR) << "Do QATTransform failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -700,6 +733,11 @@ bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> ¶m)
|
||||||
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
MS_CHECK_TRUE_MSG(main_graph != nullptr, nullptr, "Input func_graph is nullptr");
|
MS_CHECK_TRUE_MSG(main_graph != nullptr, nullptr, "Input func_graph is nullptr");
|
||||||
MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input converter param is nullptr");
|
MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input converter param is nullptr");
|
||||||
|
if (main_graph->has_attr(kOriginalFmkType)) {
|
||||||
|
auto val_ptr = main_graph->get_attr(kOriginalFmkType);
|
||||||
|
MS_CHECK_TRUE_MSG(val_ptr != nullptr, nullptr, "Val ptr is nullptr.");
|
||||||
|
param->fmk_type = static_cast<converter::FmkType>(GetValue<int32_t>(val_ptr));
|
||||||
|
}
|
||||||
if (!StoreBuiltinPass(param)) {
|
if (!StoreBuiltinPass(param)) {
|
||||||
MS_LOG(ERROR) << "store pass failed.";
|
MS_LOG(ERROR) << "store pass failed.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -39,6 +39,8 @@ class AnfTransform {
|
||||||
private:
|
private:
|
||||||
FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
|
||||||
|
static int RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
|
||||||
static int RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
static int RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
|
||||||
static int RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
static int RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
@ -66,6 +68,8 @@ class AnfTransform {
|
||||||
static STATUS QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> ¶m);
|
static STATUS QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
|
||||||
static bool CheckExternalExtension(const std::shared_ptr<ConverterPara> ¶m);
|
static bool CheckExternalExtension(const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
|
||||||
|
static STATUS ProcOnlineTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m);
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -127,47 +127,11 @@ FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara>
|
||||||
return func_graph;
|
return func_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph,
|
|
||||||
const void *buf, const size_t &size) {
|
|
||||||
if (param == nullptr || buf == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Input param is nullptr";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto graph = BuildFuncGraph(param, buf, size);
|
|
||||||
if (graph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed.");
|
|
||||||
// funcgraph_transform
|
|
||||||
graph = funcgraph_transform_->Transform(graph, param);
|
|
||||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_ERROR, "Transform anf graph return nullptr.");
|
|
||||||
// export protobuf
|
|
||||||
if (param->export_mindir == kMindIR) {
|
|
||||||
auto status = UpdateFuncGraphInputAndOutputNames(graph);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
status = MindIRSerialize(param, graph);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Export to mindir proto failed";
|
|
||||||
return RET_ERROR;
|
|
||||||
} else {
|
|
||||||
MS_LOG(DEBUG) << "Export to mindir success";
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*meta_graph = TransferFuncGraph(param, graph);
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph) {
|
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph) {
|
||||||
if (param == nullptr) {
|
if (param == nullptr) {
|
||||||
MS_LOG(ERROR) << "Input param is nullptr";
|
MS_LOG(ERROR) << "Input param is nullptr";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
param->aclModelOptionCfgParam.om_file_path = param->output_file;
|
param->aclModelOptionCfgParam.om_file_path = param->output_file;
|
||||||
if (!param->config_file.empty() || !param->config_param.empty()) {
|
if (!param->config_file.empty() || !param->config_param.empty()) {
|
||||||
auto ret = InitConfigParam(param);
|
auto ret = InitConfigParam(param);
|
||||||
|
@ -176,7 +140,6 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// load plugin
|
// load plugin
|
||||||
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
|
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
|
||||||
if (!param->plugins_path.empty()) {
|
if (!param->plugins_path.empty()) {
|
||||||
|
@ -191,20 +154,19 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::
|
||||||
dl_loaders.emplace_back(dl_loader);
|
dl_loaders.emplace_back(dl_loader);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto graph = BuildFuncGraph(param);
|
auto graph = BuildFuncGraph(param);
|
||||||
if (graph == nullptr) {
|
return FuncGraphConvert(param, graph, meta_graph, false, nullptr, nullptr);
|
||||||
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
}
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed");
|
int ConverterImpl::FuncGraphConvert(const std::shared_ptr<ConverterPara> ¶m, FuncGraphPtr graph,
|
||||||
// funcgraph transform
|
schema::MetaGraphT **meta_graph, bool isRuntimeConvert, void **buff, size_t *size) {
|
||||||
graph = funcgraph_transform_->Transform(graph, param);
|
if (param == nullptr || graph == nullptr) {
|
||||||
if (graph == nullptr) {
|
MS_LOG(ERROR) << "Input param or graph is nullptr";
|
||||||
MS_LOG(ERROR) << "Transform anf graph return nullptr";
|
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, RET_ERROR, "funcgraph_transform init failed");
|
||||||
|
graph = funcgraph_transform_->Transform(graph, param);
|
||||||
|
MS_CHECK_TRUE_MSG(graph != nullptr, RET_ERROR, "Transform anf graph return nullptr.");
|
||||||
|
|
||||||
// export protobuf
|
// export protobuf
|
||||||
if (param->export_mindir == kMindIR) {
|
if (param->export_mindir == kMindIR) {
|
||||||
|
@ -218,16 +180,15 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, schema::
|
||||||
MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
|
MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
status = MindIRSerialize(param, graph);
|
status = MindIRSerialize(param, graph, isRuntimeConvert, buff, size);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Export to mindir failed";
|
MS_LOG(ERROR) << "Export to mindir failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
} else {
|
|
||||||
MS_LOG(DEBUG) << "Export to mindir success";
|
|
||||||
return RET_OK;
|
|
||||||
}
|
}
|
||||||
|
} else { // fb
|
||||||
|
*meta_graph = TransferFuncGraph(param, graph);
|
||||||
}
|
}
|
||||||
*meta_graph = TransferFuncGraph(param, graph);
|
MS_LOG(DEBUG) << "FuncGraph convert success";
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,13 +53,19 @@ class ConverterImpl {
|
||||||
this->model_parser_ = nullptr;
|
this->model_parser_ = nullptr;
|
||||||
}
|
}
|
||||||
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph);
|
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph);
|
||||||
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph, const void *buf,
|
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph, const void *buff,
|
||||||
const size_t &size);
|
const size_t &size, void **dst_buff, size_t *dst_size) {
|
||||||
|
auto graph = BuildFuncGraph(param, buff, size);
|
||||||
|
return FuncGraphConvert(param, graph, meta_graph, true, dst_buff, dst_size);
|
||||||
|
}
|
||||||
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph);
|
int Convert(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> ¶m);
|
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> ¶m);
|
||||||
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> ¶m, const void *buf, const size_t &size);
|
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> ¶m, const void *buf, const size_t &size);
|
||||||
|
int FuncGraphConvert(const std::shared_ptr<ConverterPara> ¶m, FuncGraphPtr graph, schema::MetaGraphT **meta_graph,
|
||||||
|
bool isRuntimeConvert, void **buff, size_t *size);
|
||||||
|
|
||||||
schema::MetaGraphT *TransferFuncGraph(const std::shared_ptr<ConverterPara> ¶m, FuncGraphPtr func_graph);
|
schema::MetaGraphT *TransferFuncGraph(const std::shared_ptr<ConverterPara> ¶m, FuncGraphPtr func_graph);
|
||||||
|
|
||||||
int InitConfigParam(const std::shared_ptr<ConverterPara> ¶m);
|
int InitConfigParam(const std::shared_ptr<ConverterPara> ¶m);
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/lite_exporter/anf_exporter.h"
|
#include "tools/lite_exporter/anf_exporter.h"
|
||||||
|
#include "tools/optimizer/common/pass_manager_extends.h"
|
||||||
#include "tools/converter/graphdef_transform.h"
|
#include "tools/converter/graphdef_transform.h"
|
||||||
#include "tools/converter/optimizer_manager.h"
|
#include "tools/converter/optimizer_manager.h"
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
|
@ -250,7 +251,7 @@ STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPar
|
||||||
}
|
}
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
CHECK_NULL_RETURN(optimizer);
|
CHECK_NULL_RETURN(optimizer);
|
||||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
|
||||||
CHECK_NULL_RETURN(graph_pm);
|
CHECK_NULL_RETURN(graph_pm);
|
||||||
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
|
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
|
||||||
param->fmk_type == converter::kFmkTypeOnnx) {
|
param->fmk_type == converter::kFmkTypeOnnx) {
|
||||||
|
|
|
@ -35,6 +35,7 @@
|
||||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/optimizer/format/to_format_base.h"
|
#include "tools/optimizer/format/to_format_base.h"
|
||||||
|
#include "tools/optimizer/common/pass_manager_extends.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "src/common/common.h"
|
#include "src/common/common.h"
|
||||||
|
@ -89,7 +90,7 @@ int CommonAnfAdjust(const FuncGraphPtr &func_graph) {
|
||||||
{
|
{
|
||||||
auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
|
auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
MS_CHECK_TRUE_MSG(asylic_optimizer != nullptr, RET_NULL_PTR, "asylic_optimizer is nullptr.");
|
MS_CHECK_TRUE_MSG(asylic_optimizer != nullptr, RET_NULL_PTR, "asylic_optimizer is nullptr.");
|
||||||
auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false);
|
auto asylic_pm = std::make_shared<opt::LitePassManager>("asylic pass manager", false);
|
||||||
MS_CHECK_TRUE_MSG(asylic_pm != nullptr, RET_NULL_PTR, "asylic_pm is nullptr.");
|
MS_CHECK_TRUE_MSG(asylic_pm != nullptr, RET_NULL_PTR, "asylic_pm is nullptr.");
|
||||||
|
|
||||||
// fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
|
// fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
|
||||||
|
|
|
@ -134,7 +134,7 @@ int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> ¶m, const Fu
|
||||||
MS_LOG(ERROR) << "error occur when check condition of saving together.";
|
MS_LOG(ERROR) << "error occur when check condition of saving together.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
is_fusion_ = !param->no_fusion;
|
||||||
if (save_together_) {
|
if (save_together_) {
|
||||||
ret = SaveMindIRTogether();
|
ret = SaveMindIRTogether();
|
||||||
} else {
|
} else {
|
||||||
|
@ -383,12 +383,15 @@ int MindIRSerializer::IfSaveTogether(bool *save_together) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
|
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
|
||||||
|
if (isRuntimeConvert_) {
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
mind_ir::GraphProto *graph_proto = model_proto->mutable_graph();
|
mind_ir::GraphProto *graph_proto = model_proto->mutable_graph();
|
||||||
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
|
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
|
||||||
if (attr_proto != nullptr) {
|
if (attr_proto != nullptr) {
|
||||||
attr_proto->set_name(kIsOptimized);
|
attr_proto->set_name(kIsOptimized);
|
||||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
||||||
attr_proto->set_i(1);
|
attr_proto->set_i(is_fusion_);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto realpath = Common::CreatePrefixPath(output_file, true);
|
auto realpath = Common::CreatePrefixPath(output_file, true);
|
||||||
|
@ -414,8 +417,32 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph) {
|
int MindIRSerializer::GetBuffAndSize(void **buff, size_t *size) {
|
||||||
mindspore::lite::MindIRSerializer serializer;
|
if (buff == nullptr || size == nullptr) {
|
||||||
return serializer.Save(param, func_graph);
|
MS_LOG(ERROR) << "param is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
*size = model_proto_.ByteSize();
|
||||||
|
*buff = malloc(*size);
|
||||||
|
if (*buff == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Malloc fail";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
model_proto_.SerializeToArray(*buff, *size);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, bool isRuntimeConvert,
|
||||||
|
void **buff, size_t *size) {
|
||||||
|
mindspore::lite::MindIRSerializer serializer(isRuntimeConvert);
|
||||||
|
auto ret = serializer.Save(param, func_graph);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "MindIR serialize fail";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
if (isRuntimeConvert) {
|
||||||
|
return serializer.GetBuffAndSize(buff, size);
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -31,7 +31,7 @@
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
class MindIRSerializer {
|
class MindIRSerializer {
|
||||||
public:
|
public:
|
||||||
MindIRSerializer() = default;
|
explicit MindIRSerializer(bool isRuntimeConvert) : isRuntimeConvert_(isRuntimeConvert) {}
|
||||||
virtual ~MindIRSerializer() {
|
virtual ~MindIRSerializer() {
|
||||||
if (data_fs_ != nullptr) {
|
if (data_fs_ != nullptr) {
|
||||||
data_fs_->close();
|
data_fs_->close();
|
||||||
|
@ -40,6 +40,7 @@ class MindIRSerializer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph);
|
int Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph);
|
||||||
|
int GetBuffAndSize(void **buff, size_t *size);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int ParserPath(const std::string &output_path);
|
int ParserPath(const std::string &output_path);
|
||||||
|
@ -59,6 +60,8 @@ class MindIRSerializer {
|
||||||
int RemoveQuantParameterHolder(FuncGraphPtr func_graph);
|
int RemoveQuantParameterHolder(FuncGraphPtr func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
bool isRuntimeConvert_ = false;
|
||||||
|
bool is_fusion_ = true;
|
||||||
std::string model_name_;
|
std::string model_name_;
|
||||||
std::string save_path_;
|
std::string save_path_;
|
||||||
std::string save_model_path_;
|
std::string save_model_path_;
|
||||||
|
@ -72,6 +75,7 @@ class MindIRSerializer {
|
||||||
std::shared_ptr<system::FileSystem> fs_{};
|
std::shared_ptr<system::FileSystem> fs_{};
|
||||||
};
|
};
|
||||||
// export func_graph
|
// export func_graph
|
||||||
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph);
|
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, bool isRuntimeConvert,
|
||||||
|
void **buff, size_t *size);
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
#endif // MINDSPORE_LITE_TOOLS_MINDIR_EXPORTER_MINDIR_SERIALIZER_H_
|
#endif // MINDSPORE_LITE_TOOLS_MINDIR_EXPORTER_MINDIR_SERIALIZER_H_
|
||||||
|
|
|
@ -33,6 +33,7 @@
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
#include "tools/converter/parser/parser_utils.h"
|
#include "tools/converter/parser/parser_utils.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "ops/custom.h"
|
#include "ops/custom.h"
|
||||||
|
|
||||||
|
@ -709,7 +710,7 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto output_node_list = GetRealNodeUsedList(graph, node);
|
auto output_node_list = Helper::GetRealNodeUsedList(graph, node);
|
||||||
if (output_node_list == nullptr) {
|
if (output_node_list == nullptr) {
|
||||||
MS_LOG(ERROR) << "output node list is nullptr";
|
MS_LOG(ERROR) << "output node list is nullptr";
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -15,16 +15,15 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#define USE_DEPRECATED_API
|
#define USE_DEPRECATED_API
|
||||||
#include "backend/common/optimizer/helper.h"
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include <algorithm>
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
ValueNodePtr Helper::CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
|
||||||
if (utils::isa<int>(sexp)) {
|
if (utils::isa<int>(sexp)) {
|
||||||
return NewValueNode(utils::cast<int>(sexp));
|
return NewValueNode(utils::cast<int>(sexp));
|
||||||
}
|
}
|
||||||
|
@ -40,7 +39,7 @@ ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
CNodePtr Helper::CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
||||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||||
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
||||||
}
|
}
|
||||||
|
@ -50,7 +49,7 @@ CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
VarNodePtr Helper::CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||||
if (utils::isa<VarPtr>(graph)) {
|
if (utils::isa<VarPtr>(graph)) {
|
||||||
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
||||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
||||||
|
@ -63,8 +62,8 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
AnfNodePtr Helper::HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||||
bool multigraph) {
|
bool multigraph) {
|
||||||
if (primitive_vars == nullptr) {
|
if (primitive_vars == nullptr) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -90,6 +89,7 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, Primitive
|
||||||
return CreateCNodeWithGraph(input_nodes, graph);
|
return CreateCNodeWithGraph(input_nodes, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
||||||
auto a_value_node = a_node->cast<ValueNodePtr>();
|
auto a_value_node = a_node->cast<ValueNodePtr>();
|
||||||
auto b_value_node = b_node->cast<ValueNodePtr>();
|
auto b_value_node = b_node->cast<ValueNodePtr>();
|
||||||
|
@ -139,6 +139,7 @@ bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// not implement for lite, just for api compatible
|
// not implement for lite, just for api compatible
|
||||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
|
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
|
||||||
const std::vector<AnfNodePtr> &orig_nodes) {
|
const std::vector<AnfNodePtr> &orig_nodes) {
|
||||||
|
@ -152,6 +153,11 @@ CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::ve
|
||||||
|
|
||||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||||
const AnfNodePtr &node) {
|
const AnfNodePtr &node) {
|
||||||
|
return Helper::GetRealNodeUsedList(graph, node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> Helper::GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||||
|
const AnfNodePtr &node) {
|
||||||
if (graph == nullptr || node == nullptr) {
|
if (graph == nullptr || node == nullptr) {
|
||||||
MS_LOG(ERROR) << "input parameter is nullptr.";
|
MS_LOG(ERROR) << "input parameter is nullptr.";
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
@ -175,9 +181,8 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
|
||||||
return output_node_list;
|
return output_node_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
|
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> Helper::GetRealNodeUsedListByOutputIdx(
|
||||||
const AnfNodePtr &node,
|
const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index) {
|
||||||
size_t output_index) {
|
|
||||||
if (graph == nullptr || node == nullptr) {
|
if (graph == nullptr || node == nullptr) {
|
||||||
MS_LOG(ERROR) << "input parameter is nullptr.";
|
MS_LOG(ERROR) << "input parameter is nullptr.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -236,15 +241,12 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||||
return a == b;
|
return a == b;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||||
// To matchCNode and Kernel's type
|
return Helper::SexpToNode(sexp, graph, primitive_vars, multigraph);
|
||||||
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return a.type() == b.type();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
AnfNodePtr Helper::SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||||
|
bool multigraph) {
|
||||||
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||||
if (primitive_vars == nullptr) {
|
if (primitive_vars == nullptr) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_HELPER_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_HELPER_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/common/optimizer/helper.h"
|
||||||
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class Helper {
|
||||||
|
public:
|
||||||
|
static std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||||
|
const AnfNodePtr &node);
|
||||||
|
static std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(
|
||||||
|
const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index);
|
||||||
|
static AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||||
|
bool multigraph);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp);
|
||||||
|
static CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph);
|
||||||
|
static VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph);
|
||||||
|
static AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||||
|
bool multigraph);
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_HELPER_H_
|
|
@ -15,12 +15,10 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
MultiplePatternProcessPass::MultiplePatternProcessPass(const std::string &name, bool multigraph)
|
|
||||||
: NodePass(name), multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared<Visitor>())) {}
|
|
||||||
|
|
||||||
AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
if (patterns_.empty()) {
|
if (patterns_.empty()) {
|
||||||
VarPtr fg = std::make_shared<Var>("RootG");
|
VarPtr fg = std::make_shared<Var>("RootG");
|
||||||
|
@ -29,7 +27,7 @@ AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const
|
||||||
for (const auto &pattern : patterns) {
|
for (const auto &pattern : patterns) {
|
||||||
auto primitive_var = std::make_shared<PrimitiveVarMap>();
|
auto primitive_var = std::make_shared<PrimitiveVarMap>();
|
||||||
MS_CHECK_TRUE_RET(primitive_var != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(primitive_var != nullptr, nullptr);
|
||||||
this->patterns_[pattern.first] = (SexpToNode(pattern.second, fg, primitive_var.get(), multigraph_));
|
this->patterns_[pattern.first] = (Helper::SexpToNode(pattern.second, fg, primitive_var.get(), multigraph_));
|
||||||
this->primitive_var_maps_[pattern.first] = primitive_var;
|
this->primitive_var_maps_[pattern.first] = primitive_var;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,12 +24,14 @@
|
||||||
#include "backend/common/optimizer/node_pass.h"
|
#include "backend/common/optimizer/node_pass.h"
|
||||||
#include "backend/common/optimizer/pattern_engine.h"
|
#include "backend/common/optimizer/pattern_engine.h"
|
||||||
#include "backend/common/optimizer/helper.h"
|
#include "backend/common/optimizer/helper.h"
|
||||||
|
#include "tools/optimizer/common/node_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class MultiplePatternProcessPass : public NodePass {
|
class MultiplePatternProcessPass : public LiteNodePass {
|
||||||
public:
|
public:
|
||||||
explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true);
|
explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true)
|
||||||
|
: LiteNodePass(name), multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared<Visitor>())) {}
|
||||||
~MultiplePatternProcessPass() override = default;
|
~MultiplePatternProcessPass() override = default;
|
||||||
virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
|
virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
|
||||||
virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0;
|
virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0;
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#define USE_DEPRECATED_API
|
#define USE_DEPRECATED_API
|
||||||
#include "backend/common/optimizer/node_pass.h"
|
#include "tools/optimizer/common/node_pass_extends.h"
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -27,6 +27,11 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
MS_LOG(ERROR) << "stub func";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LiteNodePass::Run(const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_NODE_PASS_EXTENDS_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_NODE_PASS_EXTENDS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/common/optimizer/node_pass.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class LiteNodePass : public NodePass {
|
||||||
|
public:
|
||||||
|
explicit LiteNodePass(const std::string &name) : NodePass(name) {}
|
||||||
|
~LiteNodePass() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_NODE_PASS_EXTENDS_H_
|
|
@ -13,7 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "backend/common/optimizer/pass_manager.h"
|
#include "tools/optimizer/common/pass_manager_extends.h"
|
||||||
#ifndef _MSC_VER
|
#ifndef _MSC_VER
|
||||||
#include <sys/time.h>
|
#include <sys/time.h>
|
||||||
#endif
|
#endif
|
||||||
|
@ -27,8 +27,6 @@ namespace opt {
|
||||||
constexpr size_t kMaxRepassTimes = 12;
|
constexpr size_t kMaxRepassTimes = 12;
|
||||||
constexpr uint64_t kUSecondInSecond = 1000000;
|
constexpr uint64_t kUSecondInSecond = 1000000;
|
||||||
|
|
||||||
const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
|
|
||||||
|
|
||||||
void PassManager::AddPass(const PassPtr &pass) {
|
void PassManager::AddPass(const PassPtr &pass) {
|
||||||
if (pass != nullptr) {
|
if (pass != nullptr) {
|
||||||
passes_.push_back(pass);
|
passes_.push_back(pass);
|
||||||
|
@ -36,6 +34,35 @@ void PassManager::AddPass(const PassPtr &pass) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
|
bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
|
||||||
|
MS_LOG(ERROR) << "stub func";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
|
||||||
|
return std::string("hwopt_") + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const {
|
||||||
|
MS_LOG(ERROR) << "stub func";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
|
||||||
|
MS_LOG(ERROR) << "stub func";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PassManager::Run(const FuncGraphPtr &func_graph) const {
|
||||||
|
MS_LOG(ERROR) << "stub func";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LitePassManager::AddPass(const PassPtr &pass) {
|
||||||
|
if (pass != nullptr) {
|
||||||
|
passes_.push_back(pass);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool LitePassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
#if defined(_WIN32) || defined(_WIN64)
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
auto start_time = std::chrono::steady_clock::now();
|
auto start_time = std::chrono::steady_clock::now();
|
||||||
|
@ -61,14 +88,11 @@ bool PassManager::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
|
std::string LitePassManager::GetPassFullname(size_t pass_id, const PassPtr &pass) const {
|
||||||
return "hwopt_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
|
return "hwopt_" + name() + "_" + std::to_string(pass_id) + "_" + pass->name();
|
||||||
}
|
}
|
||||||
|
|
||||||
// not implement for lite, just for api compatible
|
bool LitePassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
|
||||||
void PassManager::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const {}
|
|
||||||
|
|
||||||
bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
|
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -85,7 +109,7 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PassManager::Run(const FuncGraphPtr &func_graph) const {
|
bool LitePassManager::Run(const FuncGraphPtr &func_graph) const {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PASS_MANAGER_EXTENDS_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PASS_MANAGER_EXTENDS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/common/optimizer/pass_manager.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class LitePassManager : public PassManager {
|
||||||
|
public:
|
||||||
|
explicit LitePassManager(const std::string &name = "pm", bool run_only_once = true)
|
||||||
|
: PassManager(name, run_only_once) {}
|
||||||
|
virtual ~LitePassManager() = default;
|
||||||
|
void AddPass(const PassPtr &pass) override;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) const override;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const override;
|
||||||
|
std::string GetPassFullname(size_t pass_id, const PassPtr &pass) const override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PASS_MANAGER_EXTENDS_H_
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "backend/common/optimizer/pass_manager.h"
|
||||||
|
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||||
|
#include "include/common/utils/anfalgo.h"
|
||||||
|
#include "ir/manager.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
void LitePatternProcessPass::Build() {
|
||||||
|
VarPtr fg = std::make_shared<Var>("RootG");
|
||||||
|
pattern_ = Helper::SexpToNode(DefinePattern(), fg, primitive_vars_.get(), multigraph_);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr LitePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
|
if (pattern_ == nullptr) {
|
||||||
|
Build();
|
||||||
|
}
|
||||||
|
auto primitive = GetCNodePrimitive(pattern_);
|
||||||
|
if (primitive_vars_ == nullptr || equiv_ == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (IsPrimitiveCNode(node, primitive)) {
|
||||||
|
equiv_->clear();
|
||||||
|
EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, equiv_);
|
||||||
|
if (equiv != nullptr && !equiv->empty()) {
|
||||||
|
return Process(func_graph, node, equiv);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PATTERN_PROCESS_PASS_EXTENDS_H_
|
||||||
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PATTERN_PROCESS_PASS_EXTENDS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/common/optimizer/pattern_engine.h"
|
||||||
|
#include "tools/optimizer/common/node_pass_extends.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class LitePatternProcessPass : public LiteNodePass {
|
||||||
|
public:
|
||||||
|
explicit LitePatternProcessPass(const std::string &name = "", bool multigraph = true)
|
||||||
|
: LiteNodePass(name),
|
||||||
|
multigraph_(multigraph),
|
||||||
|
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||||
|
primitive_vars_(std::make_shared<PrimitiveVarMap>()),
|
||||||
|
equiv_(std::make_shared<Equiv>()) {}
|
||||||
|
~LitePatternProcessPass() override = default;
|
||||||
|
virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
|
||||||
|
virtual const BaseRef DefinePattern() const = 0;
|
||||||
|
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Build();
|
||||||
|
|
||||||
|
AnfNodePtr pattern_ = nullptr;
|
||||||
|
bool multigraph_ = true;
|
||||||
|
PatternEngine pattern_engine_;
|
||||||
|
PrimitiveVarMapPtr primitive_vars_;
|
||||||
|
EquivPtr equiv_;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_PATTERN_PROCESS_PASS_EXTENDS_H_
|
|
@ -34,6 +34,7 @@
|
||||||
#include "src/common/ops/anf_utils.h"
|
#include "src/common/ops/anf_utils.h"
|
||||||
#include "src/litert/infer_manager.h"
|
#include "src/litert/infer_manager.h"
|
||||||
#include "tools/optimizer/graph/lite_tensor_extractor.h"
|
#include "tools/optimizer/graph/lite_tensor_extractor.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
|
|
||||||
using mindspore::lite::KernelRegistry;
|
using mindspore::lite::KernelRegistry;
|
||||||
using mindspore::lite::Tensor;
|
using mindspore::lite::Tensor;
|
||||||
|
@ -117,7 +118,7 @@ lite::STATUS ReplaceCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||||
MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_NULL_PTR);
|
MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_NULL_PTR);
|
||||||
if (output_tensors.size() != 1) {
|
if (output_tensors.size() != 1) {
|
||||||
for (size_t k = 0; k < output_tensors.size(); k++) {
|
for (size_t k = 0; k < output_tensors.size(); k++) {
|
||||||
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, k);
|
auto used_node_list = Helper::GetRealNodeUsedListByOutputIdx(func_graph, cnode, k);
|
||||||
if (used_node_list->empty()) {
|
if (used_node_list->empty()) {
|
||||||
MS_LOG(DEBUG) << "this output don't be used by other node.";
|
MS_LOG(DEBUG) << "this output don't be used by other node.";
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -17,14 +17,15 @@
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ELIMINATE_CONCAT_SPLIT_H_
|
||||||
|
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/optimizer/fisson/fisson_util.h"
|
#include "tools/optimizer/fisson/fisson_util.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class EliminateConcatSplit : public PatternProcessPass {
|
class EliminateConcatSplit : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit EliminateConcatSplit(bool multigraph = true) : PatternProcessPass("eliminate_concat_split", multigraph) {}
|
explicit EliminateConcatSplit(bool multigraph = true)
|
||||||
|
: LitePatternProcessPass("eliminate_concat_split", multigraph) {}
|
||||||
~EliminateConcatSplit() override = default;
|
~EliminateConcatSplit() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "mindspore/ccsrc/include/common/utils/utils.h"
|
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "mindspore/lite/include/context.h"
|
#include "mindspore/lite/include/context.h"
|
||||||
#include "mindspore/lite/include/lite_types.h"
|
#include "mindspore/lite/include/lite_types.h"
|
||||||
|
|
|
@ -15,16 +15,16 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "mindspore/ccsrc/backend/common/optimizer/node_pass.h"
|
#include "tools/optimizer/common/node_pass_extends.h"
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_ITER_NODE_OUTPUTS_H_
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class IterNodeOutputs : public opt::NodePass {
|
class IterNodeOutputs : public opt::LiteNodePass {
|
||||||
public:
|
public:
|
||||||
IterNodeOutputs() : NodePass("iter_node_outputs") {}
|
IterNodeOutputs() : LiteNodePass("iter_node_outputs") {}
|
||||||
~IterNodeOutputs() override = default;
|
~IterNodeOutputs() override = default;
|
||||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/optimizer/fisson/fisson_util.h"
|
#include "tools/optimizer/fisson/fisson_util.h"
|
||||||
#include "tools/optimizer/parallel/split_strategy.h"
|
#include "tools/optimizer/parallel/split_strategy.h"
|
||||||
#include "tools/optimizer/parallel/multi_node_split.h"
|
#include "tools/optimizer/parallel/multi_node_split.h"
|
||||||
|
@ -31,11 +31,11 @@ using mindspore::schema::PrimitiveType;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
|
||||||
class MultiConvSplitPass : public PatternProcessPass {
|
class MultiConvSplitPass : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit MultiConvSplitPass(std::unordered_map<std::string, SplitStrategy> strategys, int32_t fmk_type = -1,
|
explicit MultiConvSplitPass(std::unordered_map<std::string, SplitStrategy> strategys, int32_t fmk_type = -1,
|
||||||
int32_t num = 3, bool multigraph = true)
|
int32_t num = 3, bool multigraph = true)
|
||||||
: PatternProcessPass("multi_conv_split", multigraph),
|
: LitePatternProcessPass("multi_conv_split", multigraph),
|
||||||
strategys_(std::move(strategys)),
|
strategys_(std::move(strategys)),
|
||||||
fmk_type_(fmk_type),
|
fmk_type_(fmk_type),
|
||||||
num_(num) {}
|
num_(num) {}
|
||||||
|
|
|
@ -15,16 +15,16 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "mindspore/ccsrc/backend/common/optimizer/node_pass.h"
|
#include "tools/optimizer/common/node_pass_extends.h"
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FISSON_NODE_OUT_SHAPES_H_
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class NodeOutShapes : public opt::NodePass {
|
class NodeOutShapes : public opt::LiteNodePass {
|
||||||
public:
|
public:
|
||||||
NodeOutShapes() : NodePass("node_out_shapes") {}
|
NodeOutShapes() : LiteNodePass("node_out_shapes") {}
|
||||||
~NodeOutShapes() override = default;
|
~NodeOutShapes() override = default;
|
||||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -18,13 +18,13 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ADD_CONCAT_ACTIVATION_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ADD_CONCAT_ACTIVATION_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class AddConcatActivationFusion : public PatternProcessPass {
|
class AddConcatActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit AddConcatActivationFusion(bool multigraph = true, const std::string &name = "AddConcatActivationFusion")
|
explicit AddConcatActivationFusion(bool multigraph = true, const std::string &name = "AddConcatActivationFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~AddConcatActivationFusion() override = default;
|
~AddConcatActivationFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -34,6 +34,7 @@ const BaseRef AffineActivationFusion::DefinePattern() const {
|
||||||
|
|
||||||
const AnfNodePtr AffineActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
const AnfNodePtr AffineActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const EquivPtr &equiv) const {
|
const EquivPtr &equiv) const {
|
||||||
|
constexpr size_t kAnfPrimitiveIndex = 0;
|
||||||
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -19,14 +19,14 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class AffineActivationFusion : public PatternProcessPass {
|
class AffineActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit AffineActivationFusion(bool multigraph = true, const std::string &name = "AffineActivationFusion")
|
explicit AffineActivationFusion(bool multigraph = true, const std::string &name = "AffineActivationFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~AffineActivationFusion() override = default;
|
~AffineActivationFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -44,6 +44,7 @@ const BaseRef AffineFusion::DefinePattern() const {
|
||||||
|
|
||||||
const AnfNodePtr AffineFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
const AnfNodePtr AffineFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
const EquivPtr &equiv) const {
|
const EquivPtr &equiv) const {
|
||||||
|
constexpr size_t kAnfPrimitiveIndex = 0;
|
||||||
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -18,15 +18,15 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_AFFINE_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_AFFINE_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class AffineFusion : public PatternProcessPass {
|
class AffineFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit AffineFusion(bool multigraph = true, const std::string &name = "AffineFusion")
|
explicit AffineFusion(bool multigraph = true, const std::string &name = "AffineFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~AffineFusion() override = default;
|
~AffineFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -17,14 +17,14 @@
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHMATMUL_FUSION_H_
|
||||||
|
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class BatchMatMulFusion : public PatternProcessPass {
|
class BatchMatMulFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit BatchMatMulFusion(bool multigraph = true) : PatternProcessPass("BatchMatMulFusion", multigraph) {}
|
explicit BatchMatMulFusion(bool multigraph = true) : LitePatternProcessPass("BatchMatMulFusion", multigraph) {}
|
||||||
~BatchMatMulFusion() override = default;
|
~BatchMatMulFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHNORM_TO_SCALE_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_BATCHNORM_TO_SCALE_FUSION_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class BatchNormToScaleFusion : public Pass {
|
class BatchNormToScaleFusion : public Pass {
|
||||||
|
|
|
@ -19,16 +19,16 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/converter/cxx_api/converter_para.h"
|
#include "tools/converter/cxx_api/converter_para.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ConvActivationFusion : public PatternProcessPass {
|
class ConvActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ConvActivationFusion(const std::shared_ptr<ConverterPara> ¶m, bool multigraph = true,
|
explicit ConvActivationFusion(const std::shared_ptr<ConverterPara> ¶m, bool multigraph = true,
|
||||||
const std::string &name = "ConvActivationFusion")
|
const std::string &name = "ConvActivationFusion")
|
||||||
: PatternProcessPass(name, multigraph), param_(param) {}
|
: LitePatternProcessPass(name, multigraph), param_(param) {}
|
||||||
|
|
||||||
~ConvActivationFusion() override = default;
|
~ConvActivationFusion() override = default;
|
||||||
|
|
||||||
|
|
|
@ -17,14 +17,14 @@
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_BIASADD_FUSION_H_
|
||||||
|
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ConvBiasaddFusion : public PatternProcessPass {
|
class ConvBiasaddFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ConvBiasaddFusion(bool multigraph = true) : PatternProcessPass("ConvBiasaddFusion", multigraph) {}
|
explicit ConvBiasaddFusion(bool multigraph = true) : LitePatternProcessPass("ConvBiasaddFusion", multigraph) {}
|
||||||
~ConvBiasaddFusion() override = default;
|
~ConvBiasaddFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -19,13 +19,13 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ConvConvFusion : public PatternProcessPass {
|
class ConvConvFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ConvConvFusion(bool multigraph = true) : PatternProcessPass("ConvConvFusion", multigraph) {}
|
explicit ConvConvFusion(bool multigraph = true) : LitePatternProcessPass("ConvConvFusion", multigraph) {}
|
||||||
~ConvConvFusion() override = default;
|
~ConvConvFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,15 +18,15 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "include/registry/converter_context.h"
|
#include "include/registry/converter_context.h"
|
||||||
|
|
||||||
using mindspore::converter::FmkType;
|
using mindspore::converter::FmkType;
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class ConvTransformFusion : public PatternProcessPass {
|
class ConvTransformFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ConvTransformFusion(bool multigraph = true, const std::string &name = "ConvTransformFusion")
|
explicit ConvTransformFusion(bool multigraph = true, const std::string &name = "ConvTransformFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~ConvTransformFusion() override = default;
|
~ConvTransformFusion() override = default;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -18,14 +18,14 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLE_ACTIVATION_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLE_ACTIVATION_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ConvTupleActivationFusion : public PatternProcessPass {
|
class ConvTupleActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "ConvTupleActivationFusion")
|
explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "ConvTupleActivationFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~ConvTupleActivationFusion() override = default;
|
~ConvTupleActivationFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -16,12 +16,13 @@
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_TUPLEGETITEM_FUSION_H_
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class ConvTupleGetItemFusion : public PatternProcessPass {
|
class ConvTupleGetItemFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ConvTupleGetItemFusion(const std::string &name = "ConvTupleGetItemFusion", bool multigraph = true)
|
explicit ConvTupleGetItemFusion(const std::string &name = "ConvTupleGetItemFusion", bool multigraph = true)
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~ConvTupleGetItemFusion() override = default;
|
~ConvTupleGetItemFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
|
||||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
|
|
|
@ -19,14 +19,13 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
|
||||||
#include "tools/optimizer/fusion/conv_transform_fusion.h"
|
#include "tools/optimizer/fusion/conv_transform_fusion.h"
|
||||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class FullConnectedFusion : public PatternProcessPass {
|
class FullConnectedFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit FullConnectedFusion(bool multigraph = true) : PatternProcessPass("FullConnectedFusion", multigraph) {}
|
explicit FullConnectedFusion(bool multigraph = true) : LitePatternProcessPass("FullConnectedFusion", multigraph) {}
|
||||||
~FullConnectedFusion() override = default;
|
~FullConnectedFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,16 +18,16 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class GLUFusion : public PatternProcessPass {
|
class GLUFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit GLUFusion(const std::string &name = "glu_fusion", bool multigraph = true)
|
explicit GLUFusion(const std::string &name = "glu_fusion", bool multigraph = true)
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
|
|
||||||
~GLUFusion() override = default;
|
~GLUFusion() override = default;
|
||||||
|
|
||||||
|
|
|
@ -21,17 +21,17 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
/// fuse layer_norm or instance_norm into one operator
|
/// fuse layer_norm or instance_norm into one operator
|
||||||
class GroupNormFusion : public PatternProcessPass {
|
class GroupNormFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit GroupNormFusion(const std::string &name = "GroupNormFusion", bool multigraph = true)
|
explicit GroupNormFusion(const std::string &name = "GroupNormFusion", bool multigraph = true)
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
|
|
||||||
~GroupNormFusion() override = default;
|
~GroupNormFusion() override = default;
|
||||||
|
|
||||||
|
|
|
@ -19,16 +19,16 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
#include "tools/converter/cxx_api/converter_para.h"
|
#include "tools/converter/cxx_api/converter_para.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class MatMulActivationFusion : public PatternProcessPass {
|
class MatMulActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit MatMulActivationFusion(const std::shared_ptr<ConverterPara> ¶m, bool multigraph = true)
|
explicit MatMulActivationFusion(const std::shared_ptr<ConverterPara> ¶m, bool multigraph = true)
|
||||||
: PatternProcessPass("MatMulActivationFusion", multigraph), param_(param) {}
|
: LitePatternProcessPass("MatMulActivationFusion", multigraph), param_(param) {}
|
||||||
~MatMulActivationFusion() = default;
|
~MatMulActivationFusion() = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
|
||||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -18,13 +18,13 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MATMUL_MUL_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MATMUL_MUL_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class MatMulMulFusion : public PatternProcessPass {
|
class MatMulMulFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit MatMulMulFusion(bool multigraph = true, const std::string &name = "MatMulMulFusion")
|
explicit MatMulMulFusion(bool multigraph = true, const std::string &name = "MatMulMulFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~MatMulMulFusion() override = default;
|
~MatMulMulFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,13 +18,13 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MUL_ACTIVATION_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MUL_ACTIVATION_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class MulActivationFusion : public PatternProcessPass {
|
class MulActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit MulActivationFusion(bool multigraph = true, const std::string &name = "MulActivationFusion")
|
explicit MulActivationFusion(bool multigraph = true, const std::string &name = "MulActivationFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~MulActivationFusion() = default;
|
~MulActivationFusion() = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
|
@ -30,10 +30,10 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
|
||||||
/// fuse layer_norm or instance_norm into one operator
|
/// fuse layer_norm or instance_norm into one operator
|
||||||
class NormFusion : public PatternProcessPass {
|
class NormFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true)
|
explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true)
|
||||||
: PatternProcessPass(name, multigraph) {
|
: LitePatternProcessPass(name, multigraph) {
|
||||||
InitShapeSizeInferFuncMap();
|
InitShapeSizeInferFuncMap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,15 +19,15 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ReshapeReshapeFusion : public PatternProcessPass {
|
class ReshapeReshapeFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "ReshapeReshapeFusion")
|
explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "ReshapeReshapeFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~ReshapeReshapeFusion() override = default;
|
~ReshapeReshapeFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,14 +18,14 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SCALE_ACTIVATION_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SCALE_ACTIVATION_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ScaleActivationFusion : public PatternProcessPass {
|
class ScaleActivationFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ScaleActivationFusion(bool multigraph = true, const std::string &name = "ScaleActivationFusion")
|
explicit ScaleActivationFusion(bool multigraph = true, const std::string &name = "ScaleActivationFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~ScaleActivationFusion() override = default;
|
~ScaleActivationFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -20,14 +20,14 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "ops/fusion/scale_fusion.h"
|
#include "ops/fusion/scale_fusion.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class ScaleBaseFusion : public PatternProcessPass {
|
class ScaleBaseFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ScaleBaseFusion(std::string name, bool multigraph = true) : PatternProcessPass(name, multigraph) {}
|
explicit ScaleBaseFusion(std::string name, bool multigraph = true) : LitePatternProcessPass(name, multigraph) {}
|
||||||
~ScaleBaseFusion() override = default;
|
~ScaleBaseFusion() override = default;
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
|
|
|
@ -19,13 +19,13 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class ScaleScaleFusion : public PatternProcessPass {
|
class ScaleScaleFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit ScaleScaleFusion(bool multigraph = true, const std::string &name = "ScaleScaleFusion")
|
explicit ScaleScaleFusion(bool multigraph = true, const std::string &name = "ScaleScaleFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~ScaleScaleFusion() override = default;
|
~ScaleScaleFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -17,14 +17,14 @@
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_
|
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SIGMOID_MUL_FUSION_H_
|
||||||
|
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class SigmoidMulFusion : public PatternProcessPass {
|
class SigmoidMulFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit SigmoidMulFusion(bool multigraph = true) : PatternProcessPass("SigmoidMulFusion", multigraph) {}
|
explicit SigmoidMulFusion(bool multigraph = true) : LitePatternProcessPass("SigmoidMulFusion", multigraph) {}
|
||||||
~SigmoidMulFusion() override = default;
|
~SigmoidMulFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,15 +18,15 @@
|
||||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_FUSION_H_
|
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_SQUEEZE_FUSION_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class SqueezeFusion : public PatternProcessPass {
|
class SqueezeFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit SqueezeFusion(bool multigraph = true, const std::string &name = "SqueezeFusion")
|
explicit SqueezeFusion(bool multigraph = true, const std::string &name = "SqueezeFusion")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~SqueezeFusion() override = default;
|
~SqueezeFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||||
#include "ops/fusion/scale_fusion.h"
|
#include "ops/fusion/scale_fusion.h"
|
||||||
|
@ -28,10 +28,9 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class TensorDotFusion : public PatternProcessPass {
|
class TensorDotFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit TensorDotFusion(bool multigraph = true) : PatternProcessPass("TensorDotFusion", multigraph) {}
|
explicit TensorDotFusion(bool multigraph = true) : LitePatternProcessPass("TensorDotFusion", multigraph) {}
|
||||||
|
|
||||||
~TensorDotFusion() override = default;
|
~TensorDotFusion() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "ops/concat.h"
|
#include "ops/concat.h"
|
||||||
#include "ops/gru.h"
|
#include "ops/gru.h"
|
||||||
#include "ops/split.h"
|
#include "ops/split.h"
|
||||||
|
@ -300,7 +301,7 @@ AnfNodePtr TfBidirectionGruFusion::GetCondGraphPattern(const PrimitiveVarMapPtr
|
||||||
MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(is_return != nullptr, nullptr);
|
||||||
VectorRef return_ref = VectorRef({is_return, logicaland_ref});
|
VectorRef return_ref = VectorRef({is_return, logicaland_ref});
|
||||||
VarPtr is_fg = std::make_shared<Var>("RootG");
|
VarPtr is_fg = std::make_shared<Var>("RootG");
|
||||||
auto pattern = SexpToNode(return_ref, is_fg, primitive_vars.get(), true);
|
auto pattern = Helper::SexpToNode(return_ref, is_fg, primitive_vars.get(), true);
|
||||||
return pattern;
|
return pattern;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -373,7 +374,7 @@ AnfNodePtr TfBidirectionGruFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr
|
||||||
|
|
||||||
VarPtr is_fg = std::make_shared<Var>("RootG");
|
VarPtr is_fg = std::make_shared<Var>("RootG");
|
||||||
MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
|
||||||
auto pattern = SexpToNode(return_node, is_fg, primitive_vars.get(), true);
|
auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
|
||||||
return pattern;
|
return pattern;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
|
@ -29,11 +29,11 @@ namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
constexpr size_t kWhileUniqInputsLength = 6;
|
constexpr size_t kWhileUniqInputsLength = 6;
|
||||||
// fuse tf 2.x bidirection_gru into MSLITE GRU
|
// fuse tf 2.x bidirection_gru into MSLITE GRU
|
||||||
class TfBidirectionGruFusion : public PatternProcessPass {
|
class TfBidirectionGruFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
|
explicit TfBidirectionGruFusion(int num_fw_vars = kWhileUniqInputsLength, int num_bw_vars = kWhileUniqInputsLength,
|
||||||
const std::string &name = "TfBidirectionGruFusion", bool multi_graph = true)
|
const std::string &name = "TfBidirectionGruFusion", bool multi_graph = true)
|
||||||
: PatternProcessPass(name, multi_graph), num_fw_vars_(num_fw_vars), num_bw_vars_(num_bw_vars) {}
|
: LitePatternProcessPass(name, multi_graph), num_fw_vars_(num_fw_vars), num_bw_vars_(num_bw_vars) {}
|
||||||
|
|
||||||
~TfBidirectionGruFusion() override = default;
|
~TfBidirectionGruFusion() override = default;
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -179,7 +180,7 @@ AnfNodePtr TfLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &primi
|
||||||
|
|
||||||
VarPtr is_fg = std::make_shared<Var>("RootG");
|
VarPtr is_fg = std::make_shared<Var>("RootG");
|
||||||
MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(is_fg != nullptr, nullptr);
|
||||||
auto pattern = SexpToNode(return_node, is_fg, primitive_vars.get(), true);
|
auto pattern = Helper::SexpToNode(return_node, is_fg, primitive_vars.get(), true);
|
||||||
return pattern;
|
return pattern;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "securec/include/securec.h"
|
#include "securec/include/securec.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
|
@ -169,7 +170,7 @@ bool TfliteLstmCellFusion::Init() const {
|
||||||
TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
|
TfliteLstmCellFusion::TfliteLstmCellFusion(const std::string &name, bool multigraph, int input_length, int var_num,
|
||||||
int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
|
int cond_nodes_num, int cond_cnodes_num, int body_nodes_num,
|
||||||
int body_cnodes_num)
|
int body_cnodes_num)
|
||||||
: PatternProcessPass(name, multigraph) {
|
: LitePatternProcessPass(name, multigraph) {
|
||||||
/*
|
/*
|
||||||
* input vars for lstm while node
|
* input vars for lstm while node
|
||||||
* 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
|
* 0:cond_ 1:body_ 2:time_ 3:limit1_ 4:output_ 5:cell_ 6:hidden_ 7:limit2_ 8:input_
|
||||||
|
@ -207,7 +208,7 @@ AnfNodePtr TfliteLstmCellFusion::GetCondGraphPattern(const PrimitiveVarMapPtr &p
|
||||||
VectorRef return_ref = VectorRef({is_return, logicaland_ref});
|
VectorRef return_ref = VectorRef({is_return, logicaland_ref});
|
||||||
VarPtr fg = std::make_shared<Var>("RootG");
|
VarPtr fg = std::make_shared<Var>("RootG");
|
||||||
MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
|
||||||
auto pattern = SexpToNode(return_ref, fg, primitive_vars.get(), true);
|
auto pattern = Helper::SexpToNode(return_ref, fg, primitive_vars.get(), true);
|
||||||
return pattern;
|
return pattern;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -278,7 +279,7 @@ AnfNodePtr TfliteLstmCellFusion::GetBodyGraphPattern(const PrimitiveVarMapPtr &p
|
||||||
|
|
||||||
VarPtr fg = std::make_shared<Var>("RootG");
|
VarPtr fg = std::make_shared<Var>("RootG");
|
||||||
MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(fg != nullptr, nullptr);
|
||||||
auto pattern = SexpToNode(return_node, fg, primitive_vars.get(), true);
|
auto pattern = Helper::SexpToNode(return_node, fg, primitive_vars.get(), true);
|
||||||
return pattern;
|
return pattern;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,13 +19,13 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
#include "include/common/utils/utils.h"
|
#include "include/common/utils/utils.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class TfliteLstmCellFusion : public PatternProcessPass {
|
class TfliteLstmCellFusion : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit TfliteLstmCellFusion(const std::string &name = "TfliteLstmCellFusion", bool multigraph = true,
|
explicit TfliteLstmCellFusion(const std::string &name = "TfliteLstmCellFusion", bool multigraph = true,
|
||||||
int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0,
|
int input_length = 0, int var_num = 0, int cond_nodes_num = 0, int cond_cnodes_num = 0,
|
||||||
|
|
|
@ -19,14 +19,14 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "backend/common/optimizer/optimizer.h"
|
#include "tools/optimizer/common/pattern_process_pass_extends.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class AddTensorArray : public PatternProcessPass {
|
class AddTensorArray : public LitePatternProcessPass {
|
||||||
public:
|
public:
|
||||||
explicit AddTensorArray(bool multigraph = true, const std::string &name = "add_tensor_array")
|
explicit AddTensorArray(bool multigraph = true, const std::string &name = "add_tensor_array")
|
||||||
: PatternProcessPass(name, multigraph) {}
|
: LitePatternProcessPass(name, multigraph) {}
|
||||||
~AddTensorArray() override = default;
|
~AddTensorArray() override = default;
|
||||||
const BaseRef DefinePattern() const override;
|
const BaseRef DefinePattern() const override;
|
||||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "tools/optimizer/common/helper.h"
|
||||||
#include "backend/common/optimizer/helper.h"
|
#include "backend/common/optimizer/helper.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
@ -1514,7 +1515,7 @@ bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
|
||||||
if (output_tensor_num > 1) {
|
if (output_tensor_num > 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
|
auto output_node_list = Helper::GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
|
||||||
if (output_node_list->size() > 1) { // referenced by multi nodes
|
if (output_node_list->size() > 1) { // referenced by multi nodes
|
||||||
if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
|
if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
|
||||||
this_time_changed = true;
|
this_time_changed = true;
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
|
#include "tools/optimizer/common/node_pass_extends.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "backend/common/optimizer/node_pass.h"
|
|
||||||
#include "tools/optimizer/parallel/split_strategy.h"
|
#include "tools/optimizer/parallel/split_strategy.h"
|
||||||
#include "tools/optimizer/parallel/operator_info.h"
|
#include "tools/optimizer/parallel/operator_info.h"
|
||||||
|
|
||||||
|
@ -30,10 +30,10 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
class ParallelPass : public opt::NodePass {
|
class ParallelPass : public opt::LiteNodePass {
|
||||||
public:
|
public:
|
||||||
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> &strategys, const int32_t fmk_type)
|
explicit ParallelPass(const std::unordered_map<std::string, SplitStrategy> &strategys, const int32_t fmk_type)
|
||||||
: NodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {}
|
: LiteNodePass("parallel_pass"), split_strategys_(strategys), fmk_type_(fmk_type) {}
|
||||||
~ParallelPass() override = default;
|
~ParallelPass() override = default;
|
||||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue