open ConverterContext and open special nodes

This commit is contained in:
xuanyue 2021-12-17 15:12:31 +08:00
parent cf04d2eb66
commit 1dd50a0a6f
34 changed files with 628 additions and 329 deletions

View File

@ -503,8 +503,6 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${UTILS_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils install(FILES ${UTILS_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
COMPONENT ${RUNTIME_COMPONENT_NAME}) COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "schema_generated.h" EXCLUDE) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "schema_generated.h" EXCLUDE)
install(DIRECTORY ${flatbuffers_INC}/ DESTINATION ${CONVERTER_ROOT_DIR}/include/third_party install(DIRECTORY ${flatbuffers_INC}/ DESTINATION ${CONVERTER_ROOT_DIR}/include/third_party

View File

@ -103,6 +103,26 @@ inline std::unordered_map<std::string, std::string> UnorderedMapCharToString(
return ret; return ret;
} }
inline std::map<std::vector<char>, std::vector<char>> MapStringToVectorChar(
const std::map<std::string, std::string> &s) {
std::map<std::vector<char>, std::vector<char>> ret;
std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
return std::pair<std::vector<char>, std::vector<char>>(std::vector<char>(str.first.begin(), str.first.end()),
std::vector<char>(str.second.begin(), str.second.end()));
});
return ret;
}
inline std::map<std::string, std::string> MapVectorCharToString(
const std::map<std::vector<char>, std::vector<char>> &c) {
std::map<std::string, std::string> ret;
std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
return std::pair<std::string, std::string>(std::string(ch.first.begin(), ch.first.end()),
std::string(ch.second.begin(), ch.second.end()));
});
return ret;
}
inline std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ClassIndexStringToChar( inline std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ClassIndexStringToChar(
const std::vector<std::pair<std::string, std::vector<int32_t>>> &s) { const std::vector<std::pair<std::string, std::vector<int32_t>>> &s) {
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ret; std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ret;

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/make_tuple.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameMakeTuple, MakeTuple);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
#define MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
#include "ops/primitive_c.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMakeTuple = "MakeTuple";
/// \brief MakeTuple op is used to pack multiple nodes into a whole, which is only used in FuncGraph.
class MS_CORE_API MakeTuple : public PrimitiveC {
public:
/// \brief Constructor.
MakeTuple() : PrimitiveC(kNameMakeTuple) {}
/// \brief Destructor.
~MakeTuple() = default;
MS_DECLARE_PARENT(MakeTuple, PrimitiveC);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAKE_TUPLE_H_

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/return.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameReturn, Return);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_RETURN_H_
#define MINDSPORE_CORE_OPS_RETURN_H_
#include "ops/primitive_c.h"
namespace mindspore {
namespace ops {
constexpr auto kNameReturn = "Return";
/// \brief Return op is the output node, which is only used in FuncGraph.
class MS_CORE_API Return : public PrimitiveC {
public:
/// \brief Constructor.
Return() : PrimitiveC(kNameReturn) {}
/// \brief Destructor.
~Return() = default;
MS_DECLARE_PARENT(Return, PrimitiveC);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_RETURN_H_

View File

@ -0,0 +1,23 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/tuple_get_item.h"
namespace mindspore {
namespace ops {
REGISTER_PRIMITIVE_C(kNameTupleGetItem, TupleGetItem);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
#define MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
#include "ops/primitive_c.h"
namespace mindspore {
namespace ops {
constexpr auto kNameTupleGetItem = "TupleGetItem";
/// \brief TupleGetItem op is added to the multi-output node to describe which output of the node, which is only used
/// in FuncGraph.
class MS_CORE_API TupleGetItem : public PrimitiveC {
public:
/// \brief Constructor.
TupleGetItem() : PrimitiveC(kNameTupleGetItem) {}
/// \brief Destructor.
~TupleGetItem() = default;
MS_DECLARE_PARENT(TupleGetItem, PrimitiveC);
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_

View File

@ -151,7 +151,7 @@ build_lite() {
CMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake CMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake
fi fi
BRANCH_NAME=nnie_3516_master BRANCH_NAME=nnie_3516_master_2
if [[ ("${MSLITE_REGISTRY_DEVICE}" == "Hi3516D" || "${TOOLCHAIN_NAME}" == "himix200") && "${local_lite_platform}" == "arm32" ]]; then if [[ ("${MSLITE_REGISTRY_DEVICE}" == "Hi3516D" || "${TOOLCHAIN_NAME}" == "himix200") && "${local_lite_platform}" == "arm32" ]]; then
TOOLCHAIN_NAME="himix200" TOOLCHAIN_NAME="himix200"
MSLITE_REGISTRY_DEVICE=Hi3516D MSLITE_REGISTRY_DEVICE=Hi3516D

View File

@ -21,6 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "include/lite_utils.h" #include "include/lite_utils.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore { namespace mindspore {
namespace converter { namespace converter {
@ -53,12 +54,30 @@ class MS_API ConverterContext {
/// \brief Static method to set exported model's output name as needed by users. /// \brief Static method to set exported model's output name as needed by users.
/// ///
/// \param[in] output_names Define model's output name, the order of which is consistent with the original model. /// \param[in] output_names Define model's output name, the order of which is consistent with the original model.
static void SetGraphOutputTensorNames(const std::vector<std::string> &output_names); static void SetGraphOutputTensorNames(const std::vector<std::string> &output_names) {
SetGraphOutputTensorNames(VectorStringToChar(output_names));
}
/// \brief Static method to obtain the outputs' name. /// \brief Static method to obtain the outputs' name.
/// ///
/// \return the outputs' name. /// \return the outputs' name.
static std::vector<std::string> GetGraphOutputTensorNames(); static std::vector<std::string> GetGraphOutputTensorNames() {
return VectorCharToString(GetGraphOutputTensorNamesInChar());
}
/// \brief Static method to get configure information which is used only by external extension.
///
/// \param[in] section Define config section name.
///
/// \return config key-value map.
static std::map<std::string, std::string> GetConfigInfo(const std::string &section) {
return MapVectorCharToString(GetConfigInfo(StringToChar(section)));
}
private:
static void SetGraphOutputTensorNames(const std::vector<std::vector<char>> &&output_names);
static std::vector<std::vector<char>> GetGraphOutputTensorNamesInChar();
static std::map<std::vector<char>, std::vector<char>> GetConfigInfo(const std::vector<char> &&section);
}; };
} // namespace converter } // namespace converter
} // namespace mindspore } // namespace mindspore

View File

@ -18,7 +18,8 @@
#include <memory> #include <memory>
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "tools/converter/ops/ops_def.h" #include "ops/make_tuple.h"
#include "ops/return.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ops/fusion/conv2d_fusion.h" #include "ops/fusion/conv2d_fusion.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h" #include "backend/kernel_compiler/cpu/nnacl/op_base.h"
@ -93,7 +94,7 @@ CNodePtr FusionInoutTest::AddReturn(const FuncGraphPtr &graph, const std::vector
if (return_inputs.size() == 1) { if (return_inputs.size() == 1) {
return_input = return_inputs.front(); return_input = return_inputs.front();
} else { } else {
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed"; MS_LOG(ERROR) << "new MakeTuple failed";
return nullptr; return nullptr;
@ -107,7 +108,7 @@ CNodePtr FusionInoutTest::AddReturn(const FuncGraphPtr &graph, const std::vector
return_input = return_input_cnode; return_input = return_input_cnode;
} }
auto return_prim = std::make_shared<lite::Return>(); auto return_prim = std::make_shared<ops::Return>();
MS_CHECK_TRUE_MSG(return_prim != nullptr, nullptr, "create return primitivec failed"); MS_CHECK_TRUE_MSG(return_prim != nullptr, nullptr, "create return primitivec failed");
auto return_cnode = graph->NewCNode(return_prim, {return_input}); auto return_cnode = graph->NewCNode(return_prim, {return_input});
MS_CHECK_TRUE_MSG(return_cnode != nullptr, nullptr, "create Return failed"); MS_CHECK_TRUE_MSG(return_cnode != nullptr, nullptr, "create Return failed");

View File

@ -29,8 +29,10 @@
#include "ops/call.h" #include "ops/call.h"
#include "ops/control_depend.h" #include "ops/control_depend.h"
#include "ops/depend.h" #include "ops/depend.h"
#include "tools/converter/ops/ops_def.h"
#include "ops/quant_dtype_cast.h" #include "ops/quant_dtype_cast.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "tools/converter/quant_param_holder.h" #include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/quantizer/bitpacking.h" #include "tools/converter/quantizer/bitpacking.h"
@ -428,8 +430,8 @@ int AnfExporter::SetTailCallForNonOutput() {
} }
bool AnfExporter::CaseToContinue(const string &prim_name) { bool AnfExporter::CaseToContinue(const string &prim_name) {
return prim_name == mindspore::ops::kNameDepend || prim_name == mindspore::lite::kNameTupleGetItem || return prim_name == mindspore::ops::kNameDepend || prim_name == mindspore::ops::kNameTupleGetItem ||
prim_name == mindspore::lite::kNameMakeTuple || prim_name == "make_tuple"; prim_name == mindspore::ops::kNameMakeTuple || prim_name == "make_tuple";
} }
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
@ -464,7 +466,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
break; break;
} }
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) { if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
node->name = mindspore::lite::kNameReturn; node->name = mindspore::ops::kNameReturn;
ret = SetSubGraphOutputIndex(cnode, subgraph_index, meta_graphT, node.get()); ret = SetSubGraphOutputIndex(cnode, subgraph_index, meta_graphT, node.get());
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpOutputN failed"; MS_LOG(ERROR) << "SetOpOutputN failed";

View File

@ -27,8 +27,8 @@
#include "tools/common/node_util.h" #include "tools/common/node_util.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "tools/converter/ops/ops_def.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "ops/make_tuple.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -45,7 +45,7 @@ int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr>
graph->set_output(outputs.front(), false); graph->set_output(outputs.front(), false);
return RET_OK; return RET_OK;
} }
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(DEBUG) << "new MakeTuple failed"; MS_LOG(DEBUG) << "new MakeTuple failed";
return lite::RET_NULL_PTR; return lite::RET_NULL_PTR;

View File

@ -17,7 +17,6 @@
#include "tools/converter/adapter/acl/acl_pass_impl.h" #include "tools/converter/adapter/acl/acl_pass_impl.h"
#include <set> #include <set>
#include <map> #include <map>
#include "tools/converter/ops/ops_def.h"
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h" #include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
@ -28,6 +27,7 @@
#include "include/registry/pass_registry.h" #include "include/registry/pass_registry.h"
#include "common/utils.h" #include "common/utils.h"
#include "ops/custom.h" #include "ops/custom.h"
#include "ops/tuple_get_item.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "cxx_api/model/acl/model_converter.h" #include "cxx_api/model/acl/model_converter.h"
#include "backend/kernel_compiler/cpu/nnacl/op_base.h" #include "backend/kernel_compiler/cpu/nnacl/op_base.h"
@ -570,7 +570,7 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
} }
} else { } else {
for (size_t j = 0; j < graph_outputs_.size(); ++j) { for (size_t j = 0; j < graph_outputs_.size(); ++j) {
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "New TupleGetItem failed for output " << j; MS_LOG(ERROR) << "New TupleGetItem failed for output " << j;
return lite::RET_ERROR; return lite::RET_ERROR;

View File

@ -20,7 +20,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "tools/converter/adapter/acl/common/utils.h" #include "tools/converter/adapter/acl/common/utils.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h" #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@ -30,6 +29,7 @@
#include "ops/batch_norm.h" #include "ops/batch_norm.h"
#include "ops/fused_batch_norm.h" #include "ops/fused_batch_norm.h"
#include "ops/stack.h" #include "ops/stack.h"
#include "ops/tuple_get_item.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -44,7 +44,7 @@ const std::set<std::string> kCNodeWithDynamicInput = {kNamewiEltwise, ops::kName
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) { CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
CNodePtr get_item_cnode = nullptr; CNodePtr get_item_cnode = nullptr;
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "New TupleGetItem failed"; MS_LOG(ERROR) << "New TupleGetItem failed";
return nullptr; return nullptr;

View File

@ -18,6 +18,7 @@
#include "tools/common/parse_config_utils.h" #include "tools/common/parse_config_utils.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -37,35 +38,44 @@ int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) {
return ret; return ret;
} }
ret = ParseDataPreProcessString(maps); ret = ParseDataPreProcessString(maps);
(void)maps.erase(kDataPreprocessParam);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseDataPreProcessString failed."; MS_LOG(ERROR) << "ParseDataPreProcessString failed.";
return ret; return ret;
} }
ret = ParseCommonQuantString(maps); ret = ParseCommonQuantString(maps);
(void)maps.erase(kCommonQuantParam);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseCommonQuantString failed."; MS_LOG(ERROR) << "ParseCommonQuantString failed.";
return ret; return ret;
} }
ret = ParseMixedBitQuantString(maps); ret = ParseMixedBitQuantString(maps);
(void)maps.erase(kMixedBitWeightQuantParam);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseMixedBitQuantString failed."; MS_LOG(ERROR) << "ParseMixedBitQuantString failed.";
return ret; return ret;
} }
ret = ParseFullQuantString(maps); ret = ParseFullQuantString(maps);
(void)maps.erase(kFullQuantParam);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseFullQuantString failed."; MS_LOG(ERROR) << "ParseFullQuantString failed.";
return ret; return ret;
} }
ret = ParseRegistryInfoString(maps); ret = ParseRegistryInfoString(maps);
(void)maps.erase(kRegistry);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseExtendedintegrationString failed."; MS_LOG(ERROR) << "ParseExtendedintegrationString failed.";
return ret; return ret;
} }
ret = ParseAclOptionCfgString(maps); ret = ParseAclOptionCfgString(maps);
(void)maps.erase(kAclOptionParam);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseAclOptionCfgString failed."; MS_LOG(ERROR) << "ParseAclOptionCfgString failed.";
return ret; return ret;
} }
for (const auto &config_info : maps) {
ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
}
return RET_OK; return RET_OK;
} }

View File

@ -15,22 +15,41 @@
*/ */
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
#include <string>
#include <vector> #include <vector>
#include "include/registry/converter_context.h" #include "include/registry/converter_context.h"
namespace mindspore { namespace mindspore {
namespace converter { namespace converter {
void ConverterContext::SetGraphOutputTensorNames(const std::vector<std::string> &output_names) { void ConverterContext::SetGraphOutputTensorNames(const std::vector<std::vector<char>> &&output_names) {
auto converter_context = lite::ConverterInnerContext::GetInstance(); auto converter_context = lite::ConverterInnerContext::GetInstance();
MS_ASSERT(converter_context != nullptr); if (converter_context == nullptr) {
converter_context->SetGraphOutputTensorNames(output_names); MS_LOG(ERROR) << "Set graph output's names failed.";
return;
}
converter_context->SetGraphOutputTensorNames(VectorCharToString(output_names));
} }
std::vector<std::string> ConverterContext::GetGraphOutputTensorNames() { std::vector<std::vector<char>> ConverterContext::GetGraphOutputTensorNamesInChar() {
auto converter_context = lite::ConverterInnerContext::GetInstance(); auto converter_context = lite::ConverterInnerContext::GetInstance();
MS_ASSERT(converter_context != nullptr); if (converter_context == nullptr) {
return converter_context->GetGraphOutputTensorNames(); MS_LOG(ERROR) << "Get graph output's names failed.";
return {};
}
return VectorStringToChar(converter_context->GetGraphOutputTensorNames());
}
std::map<std::vector<char>, std::vector<char>> ConverterContext::GetConfigInfo(const std::vector<char> &&section) {
auto converter_context = lite::ConverterInnerContext::GetInstance();
if (converter_context == nullptr) {
MS_LOG(ERROR) << "Get config information only used by external extension failed.";
return {};
}
auto &external_used_config_infos = converter_context->GetExternalUsedConfigInfos();
if (external_used_config_infos.find(CharToString(section)) == external_used_config_infos.end()) {
MS_LOG(ERROR) << "This section " << section << " config info is not existed.";
return {};
}
return MapStringToVectorChar(external_used_config_infos.at(CharToString(section)));
} }
} // namespace converter } // namespace converter
} // namespace mindspore } // namespace mindspore

View File

@ -112,6 +112,18 @@ class ConverterInnerContext {
const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; } const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; }
void SetExternalUsedConfigInfos(const std::string &section,
const std::map<std::string, std::string> &external_infos) {
if (external_used_config_infos_.find(section) != external_used_config_infos_.end()) {
MS_LOG(WARNING) << "This section " << section << " has been saved. Now, the content will be overwrite.";
}
external_used_config_infos_.emplace(section, external_infos);
}
const std::map<std::string, std::map<std::string, std::string>> &GetExternalUsedConfigInfos() {
return external_used_config_infos_;
}
private: private:
ConverterInnerContext() = default; ConverterInnerContext() = default;
virtual ~ConverterInnerContext() = default; virtual ~ConverterInnerContext() = default;
@ -119,6 +131,7 @@ class ConverterInnerContext {
std::map<int32_t, int32_t> graph_output_data_type_map_; std::map<int32_t, int32_t> graph_output_data_type_map_;
std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_; std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_;
std::vector<std::string> graph_output_tensor_names_; std::vector<std::string> graph_output_tensor_names_;
std::map<std::string, std::map<std::string, std::string>> external_used_config_infos_;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -18,7 +18,8 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include <algorithm> #include <algorithm>
#include "tools/converter/ops/ops_def.h" #include "ops/make_tuple.h"
#include "ops/return.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
#include "tools/converter/quant_param_holder.h" #include "tools/converter/quant_param_holder.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
@ -130,7 +131,7 @@ FuncGraphPtr MindIRControlFlowAdjust::AddAfterFuncGraph(const FuncGraphPtr &fg,
if (after_fg->get_inputs().size() > 1) { if (after_fg->get_inputs().size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs = after_fg->get_inputs(); std::vector<AnfNodePtr> make_tuple_inputs = after_fg->get_inputs();
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed"; MS_LOG(ERROR) << "new MakeTuple failed";
return nullptr; return nullptr;
@ -141,7 +142,7 @@ FuncGraphPtr MindIRControlFlowAdjust::AddAfterFuncGraph(const FuncGraphPtr &fg,
auto make_tuple_cnode = after_fg->NewCNode(make_tuple_inputs); auto make_tuple_cnode = after_fg->NewCNode(make_tuple_inputs);
MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "Failed to create C node."); MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "Failed to create C node.");
make_tuple_cnode->set_fullname_with_scope("return tuple"); make_tuple_cnode->set_fullname_with_scope("return tuple");
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return nullptr; return nullptr;
@ -154,7 +155,7 @@ FuncGraphPtr MindIRControlFlowAdjust::AddAfterFuncGraph(const FuncGraphPtr &fg,
cnode->set_fullname_with_scope("Return"); cnode->set_fullname_with_scope("Return");
after_fg->set_return(cnode); after_fg->set_return(cnode);
} else { } else {
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return nullptr; return nullptr;

View File

@ -16,7 +16,6 @@
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ #ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_ #define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
#include "schema/inner/model_generated.h"
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
using mindspore::ops::PrimitiveC; using mindspore::ops::PrimitiveC;
@ -43,9 +42,6 @@ ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
ADD_CONVERTER_ONLY_OP(TensorArrayV3); ADD_CONVERTER_ONLY_OP(TensorArrayV3);
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3); ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
ADD_CONVERTER_ONLY_OP(Constant); ADD_CONVERTER_ONLY_OP(Constant);
ADD_CONVERTER_ONLY_OP(MakeTuple);
ADD_CONVERTER_ONLY_OP(TupleGetItem);
ADD_CONVERTER_ONLY_OP(Return);
ADD_CONVERTER_ONLY_OP(Merge); ADD_CONVERTER_ONLY_OP(Merge);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -25,7 +25,6 @@
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h" #include "tools/common/protobuf_utils.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "tools/converter/ops/ops_def.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
@ -35,6 +34,9 @@
#include "tools/converter/parser/unify_format.h" #include "tools/converter/parser/unify_format.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
using mindspore::converter::kFmkTypeCaffe; using mindspore::converter::kFmkTypeCaffe;
namespace mindspore::lite { namespace mindspore::lite {
@ -416,7 +418,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
caffeInspector.InspectModel(caffe_model_); caffeInspector.InspectModel(caffe_model_);
if (caffeInspector.GetGraphOutput().size() > 1) { if (caffeInspector.GetGraphOutput().size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs; std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
MSLITE_CHECK_PTR(make_tuple_prim_ptr); MSLITE_CHECK_PTR(make_tuple_prim_ptr);
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr); auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
MSLITE_CHECK_PTR(make_tuple_prim); MSLITE_CHECK_PTR(make_tuple_prim);
@ -434,7 +436,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
make_tuple_cnode->set_fullname_with_scope("return tuple"); make_tuple_cnode->set_fullname_with_scope("return tuple");
std::vector<AnfNodePtr> op_inputs; std::vector<AnfNodePtr> op_inputs;
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
MSLITE_CHECK_PTR(return_prim_ptr); MSLITE_CHECK_PTR(return_prim_ptr);
auto value_node = NewValueNode(return_prim_ptr); auto value_node = NewValueNode(return_prim_ptr);
MSLITE_CHECK_PTR(value_node); MSLITE_CHECK_PTR(value_node);
@ -445,7 +447,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
cnode->set_fullname_with_scope("Return"); cnode->set_fullname_with_scope("Return");
res_graph_->set_return(cnode); res_graph_->set_return(cnode);
} else { } else {
auto returnPrim = std::make_shared<lite::Return>(); auto returnPrim = std::make_shared<ops::Return>();
MSLITE_CHECK_PTR(returnPrim); MSLITE_CHECK_PTR(returnPrim);
auto valueNode = NewValueNode(returnPrim); auto valueNode = NewValueNode(returnPrim);
MSLITE_CHECK_PTR(valueNode); MSLITE_CHECK_PTR(valueNode);
@ -584,7 +586,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
return RET_ERROR; return RET_ERROR;
} }
abstract_list.emplace_back(abstract); abstract_list.emplace_back(abstract);
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -26,7 +26,6 @@
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h" #include "tools/common/protobuf_utils.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "tools/converter/ops/ops_def.h"
#include "ops/tensor_list_stack.h" #include "ops/tensor_list_stack.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
@ -38,6 +37,9 @@
#include "tools/converter/parser/unify_format.h" #include "tools/converter/parser/unify_format.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
using mindspore::converter::kFmkTypeOnnx; using mindspore::converter::kFmkTypeOnnx;
namespace mindspore { namespace mindspore {
@ -161,7 +163,7 @@ CNodePtr GetCNodeFromControlFlowNodesMap(
STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) { STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
MS_ASSERT(anf_graph != nullptr); MS_ASSERT(anf_graph != nullptr);
auto return_prim = std::make_shared<lite::Return>(); auto return_prim = std::make_shared<ops::Return>();
if (return_prim == nullptr) { if (return_prim == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -227,7 +229,7 @@ STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_
return RET_ERROR; return RET_ERROR;
} }
abstract_list.emplace_back(abstract_tensor); abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -326,7 +328,7 @@ STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPt
} }
if (onnx_graph.output_size() > 1) { if (onnx_graph.output_size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs; std::vector<AnfNodePtr> make_tuple_inputs;
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed"; MS_LOG(ERROR) << "new MakeTuple failed";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -23,6 +23,7 @@
#include "tools/converter/ops/ops_def.h" #include "tools/converter/ops/ops_def.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "ops/return.h"
namespace mindspore::opt { namespace mindspore::opt {
@ -157,7 +158,7 @@ FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::s
} }
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr"; MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return nullptr; return nullptr;

View File

@ -19,7 +19,9 @@
#include <memory> #include <memory>
#include <deque> #include <deque>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "tools/converter/ops/ops_def.h" #include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "tools/converter/ops/while.h" #include "tools/converter/ops/while.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
@ -224,7 +226,7 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
return RET_ERROR; return RET_ERROR;
} }
abstract_list.emplace_back(abstract); abstract_list.emplace_back(abstract);
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -370,7 +372,7 @@ STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() {
} }
STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() { STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr"; MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -531,7 +533,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
"_cnode"); "_cnode");
} }
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetReturnPrim return nullptr"; MS_LOG(ERROR) << "GetReturnPrim return nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -548,7 +550,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
return_cnode->add_input(tmp_output[0]); return_cnode->add_input(tmp_output[0]);
} else { } else {
std::vector<AnfNodePtr> make_tuple_inputs = tmp_output; std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr"; MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -26,7 +26,9 @@
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/ops/ops_def.h" #include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "abstract/utils.h" #include "abstract/utils.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
@ -884,7 +886,7 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
return RET_ERROR; return RET_ERROR;
} }
abstract_list.emplace_back(abstract_tensor); abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -1145,7 +1147,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_
} }
if (output_nodes.size() > 1) { if (output_nodes.size() > 1) {
std::vector<AnfNodePtr> make_tuple_inputs = output_nodes; std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed"; MS_LOG(ERROR) << "new MakeTuple failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -1157,7 +1159,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_
CHECK_NULL_RETURN(make_tuple_cnode); CHECK_NULL_RETURN(make_tuple_cnode);
make_tuple_cnode->set_fullname_with_scope("return_tuple"); make_tuple_cnode->set_fullname_with_scope("return_tuple");
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -1170,7 +1172,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_
cnode->set_fullname_with_scope("Return"); cnode->set_fullname_with_scope("Return");
anf_graph->set_return(cnode); anf_graph->set_return(cnode);
} else { } else {
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -24,7 +24,6 @@
#include "ops/primitive_c.h" #include "ops/primitive_c.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "src/common/file_utils.h" #include "src/common/file_utils.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/converter/quant_param_holder.h" #include "tools/converter/quant_param_holder.h"
#include "tools/converter/converter_context.h" #include "tools/converter/converter_context.h"
@ -34,6 +33,9 @@
#include "tools/converter/parser/unify_format.h" #include "tools/converter/parser/unify_format.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "src/common/log_util.h" #include "src/common/log_util.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
using mindspore::converter::kFmkTypeTflite; using mindspore::converter::kFmkTypeTflite;
namespace mindspore::lite { namespace mindspore::lite {
@ -479,7 +481,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptr<tflite::SubG
} }
output_nodes.emplace_back(cnode); output_nodes.emplace_back(cnode);
} }
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new MakeTuple failed"; MS_LOG(ERROR) << "new MakeTuple failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -491,7 +493,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptr<tflite::SubG
auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs); auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
MSLITE_CHECK_PTR(make_tuple_cnode); MSLITE_CHECK_PTR(make_tuple_cnode);
make_tuple_cnode->set_fullname_with_scope("return_tuple"); make_tuple_cnode->set_fullname_with_scope("return_tuple");
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -505,7 +507,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptr<tflite::SubG
cnode->set_fullname_with_scope("Return"); cnode->set_fullname_with_scope("Return");
func_graph->set_return(cnode); func_graph->set_return(cnode);
} else { } else {
auto returnPrim = std::make_shared<lite::Return>(); auto returnPrim = std::make_shared<ops::Return>();
if (returnPrim == nullptr) { if (returnPrim == nullptr) {
MS_LOG(ERROR) << "new Return failed"; MS_LOG(ERROR) << "new Return failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -774,7 +776,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const std::unique_ptr<tflite::SubG
return RET_ERROR; return RET_ERROR;
} }
abstract_list.emplace_back(abstract_tensor); abstract_list.emplace_back(abstract_tensor);
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
if (tuple_get_item_prim_ptr == nullptr) { if (tuple_get_item_prim_ptr == nullptr) {
MS_LOG(ERROR) << "new TupleGetItem failed"; MS_LOG(ERROR) << "new TupleGetItem failed";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -17,7 +17,7 @@
#include "tools/converter/quantizer/calibrator.h" #include "tools/converter/quantizer/calibrator.h"
#include <utility> #include <utility>
#include "tools/converter/preprocess/image_preprocess.h" #include "tools/converter/preprocess/image_preprocess.h"
#include "tools/converter/ops/ops_def.h" #include "ops/tuple_get_item.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
@ -63,7 +63,7 @@ int Calibrator::ComputeThreshold() {
for (const auto &output_diverg_info : outputs_diverg_info.second) { for (const auto &output_diverg_info : outputs_diverg_info.second) {
auto output_diverg_cnode = output_diverg_info.second->GetCNode(); auto output_diverg_cnode = output_diverg_info.second->GetCNode();
if (output_diverg_cnode == input_cnode) { if (output_diverg_cnode == input_cnode) {
if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) { if (NodePrimitiveType(input_cnode) != ops::kNameTupleGetItem) {
*(input_infos[i]) = *output_diverg_info.second; *(input_infos[i]) = *output_diverg_info.second;
input_infos[i]->GetCNode() = cnode; input_infos[i]->GetCNode() = cnode;
already_computed = true; already_computed = true;

View File

@ -28,7 +28,7 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include "ops/fusion/full_connection.h" #include "ops/fusion/full_connection.h"
#include "tools/converter/ops/ops_def.h" #include "ops/tuple_get_item.h"
#include "src/tensor.h" #include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/quantize_util.h" #include "tools/converter/quantizer/quantize_util.h"
@ -462,7 +462,7 @@ int FullQuantQuantizer::QuantNode(const FuncGraphPtr &func_graph) {
auto op_type = primitive->name(); auto op_type = primitive->name();
MS_LOG(DEBUG) << "OpName: " << op_name; MS_LOG(DEBUG) << "OpName: " << op_name;
if (op_type == lite::kNameTupleGetItem) { if (op_type == mindspore::ops::kNameTupleGetItem) {
constexpr int tuple_get_item_input_size = 3; constexpr int tuple_get_item_input_size = 3;
MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3"); MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
auto index_node = cnode->input(THIRD_INPUT); auto index_node = cnode->input(THIRD_INPUT);

View File

@ -25,7 +25,7 @@
#include "ops/fusion/conv2d_fusion.h" #include "ops/fusion/conv2d_fusion.h"
#include "ops/transpose.h" #include "ops/transpose.h"
#include "ops/gather.h" #include "ops/gather.h"
#include "tools/converter/ops/ops_def.h" #include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
@ -174,121 +174,6 @@ bool IsRealKernel(const AnfNodePtr &node) {
#endif #endif
return !is_virtual_node; return !is_virtual_node;
} }
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
if (utils::isa<float>(sexp)) {
return NewValueNode(utils::cast<float>(sexp));
}
if (utils::isa<bool>(sexp)) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
}
return nullptr;
}
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
}
if (utils::isa<VarPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
}
return nullptr;
}
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) {
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
}
if (utils::isa<FuncGraphPtr>(graph)) {
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
return nullptr;
}
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
if (primitive_vars == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) {
auto is_var = std::make_shared<Var>("G");
MS_CHECK_TRUE_RET(is_var != nullptr, nullptr);
AnfNodePtr node = SexpToNode(x, is_var, primitive_vars, true);
input_nodes.push_back(node);
}
auto var_ptr = utils::cast<VarPtr>(graph);
return std::make_shared<CNode>(input_nodes, var_ptr);
}
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node);
}
return CreateCNodeWithGraph(input_nodes, graph);
}
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
auto a_value_node = a_node->cast<ValueNodePtr>();
auto b_value_node = b_node->cast<ValueNodePtr>();
if (a_value_node == nullptr || b_value_node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
auto a_value = a_value_node->value();
auto b_value = b_value_node->value();
if (a_value == nullptr || b_value == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
auto a_prim = a_value->cast<PrimitivePtr>();
auto b_prim = b_value->cast<PrimitivePtr>();
if (a_prim == nullptr || b_prim == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
return a_prim->name() == b_prim->name();
}
bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
MS_LOG(ERROR) << "cast value node ptr fail";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
auto a_value_ptr = a_value_node_ptr->value();
auto b_value_ptr = b_value_node_ptr->value();
if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
MS_LOG(ERROR) << "value ptr is nullptr";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (utils::isa<ops::PrimitiveC>(a_value_ptr) && utils::isa<ops::PrimitiveC>(b_value_ptr)) {
auto a_obj = (ops::PrimitiveC *)(a_value_ptr.get());
auto b_obj = (ops::PrimitiveC *)(b_value_ptr.get());
return (*a_obj) == (*b_obj);
} else {
return (*a_value_ptr) == (*b_value_ptr);
}
}
} // namespace } // namespace
bool CheckInputs(const CNodePtr &cnode) { bool CheckInputs(const CNodePtr &cnode) {
@ -414,71 +299,6 @@ bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_ty
return false; return false;
} }
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
if (a_node == nullptr || b_node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
return AnfEqualPrimitive(a_node, b_node);
}
if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
return AnfEqualValueNode(a_node, b_node);
}
}
if (a.m_ptr->isa<mindspore::ops::PrimitiveC>() && b.m_ptr->isa<mindspore::ops::PrimitiveC>()) {
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
return a_value_node_ptr->name() == b_value_node_ptr->name();
}
return a == b;
}
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To matchCNode and Kernel's type
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
return true;
}
return a.type() == b.type();
}
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
if (primitive_vars == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (utils::isa<VectorRef>(sexp)) {
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
}
if (utils::isa<VarPtr>(sexp)) {
auto var_ptr = utils::cast<VarPtr>(sexp);
if (var_ptr == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (var_ptr->primitive()) {
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
return NewValueNode(var_ptr->primitive());
}
return CreateVarNodeWithSexp(sexp, graph);
}
if (utils::isa<AnfNodePtr>(sexp)) {
return utils::cast<AnfNodePtr>(sexp);
}
auto value_node = CreateValueNodeWithSexp(sexp);
if (value_node == nullptr) {
MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
return value_node;
}
bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) { bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
if (utils::isa<AnfNodePtr>(n)) { if (utils::isa<AnfNodePtr>(n)) {
auto anf_node = utils::cast<AnfNodePtr>(n); auto anf_node = utils::cast<AnfNodePtr>(n);
@ -795,31 +615,6 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
return false; return false;
} }
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
if (graph == nullptr || node == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
auto manager = graph->manager();
if (manager == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(ERROR) << "node has no output in manager";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return nullptr;
}
auto output_info_list = iter->second;
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
return output_node_list;
}
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) { size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) { if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
MS_LOG(ERROR) << "The node tuple_get_item is invalid."; MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
@ -843,43 +638,6 @@ size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
return indexes.front(); return indexes.front();
} }
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index) {
if (graph == nullptr || node == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr.";
return nullptr;
}
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
auto manager = graph->manager();
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(ERROR) << "node has no output in manager";
return output_node_list;
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
size_t used_output_index;
if (CheckPrimitiveType(output_info.first, prim::kPrimTupleGetItem)) {
used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
} else if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
used_output_index = output_index;
} else {
if (output_index != 0) {
MS_LOG(ERROR) << "node has no output in manager";
return output_node_list;
}
return output_node_list;
}
if (used_output_index == output_index) {
output_node_list->push_back(output_info);
}
}
return output_node_list;
}
STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) { STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR); MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>> std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
@ -1135,7 +893,7 @@ CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &inp
MS_LOG(ERROR) << "input parameter is nullptr, which is invalid."; MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
return nullptr; return nullptr;
} }
auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr); MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr);
auto second_input = NewValueNode(MakeValue<int>(index)); auto second_input = NewValueNode(MakeValue<int>(index));
MS_CHECK_TRUE_RET(second_input != nullptr, nullptr); MS_CHECK_TRUE_RET(second_input != nullptr, nullptr);
@ -1238,17 +996,6 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
return RET_OK; return RET_OK;
} }
// not implement for lite, just for api compatible
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
const std::vector<AnfNodePtr> &orig_nodes) {
return fg->NewCNode(inputs);
}
// not implement for lite, just for api compatible
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
return nullptr;
}
bool IsQuantParameterNode(const PrimitiveCPtr &prim) { bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
MS_CHECK_TRUE_RET(prim != nullptr, false); MS_CHECK_TRUE_RET(prim != nullptr, false);
auto quant_attr = prim->GetAttr("quant_params"); auto quant_attr = prim->GetAttr("quant_params");

View File

@ -0,0 +1,279 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/common/helper.h"
#include <memory>
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace opt {
namespace {
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
if (utils::isa<float>(sexp)) {
return NewValueNode(utils::cast<float>(sexp));
}
if (utils::isa<bool>(sexp)) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
}
return nullptr;
}
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
}
if (utils::isa<VarPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
}
return nullptr;
}
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) {
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
}
if (utils::isa<FuncGraphPtr>(graph)) {
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
return nullptr;
}
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
if (primitive_vars == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) {
auto is_var = std::make_shared<Var>("G");
MS_CHECK_TRUE_RET(is_var != nullptr, nullptr);
AnfNodePtr node = SexpToNode(x, is_var, primitive_vars, true);
input_nodes.push_back(node);
}
auto var_ptr = utils::cast<VarPtr>(graph);
return std::make_shared<CNode>(input_nodes, var_ptr);
}
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node);
}
return CreateCNodeWithGraph(input_nodes, graph);
}
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
auto a_value_node = a_node->cast<ValueNodePtr>();
auto b_value_node = b_node->cast<ValueNodePtr>();
if (a_value_node == nullptr || b_value_node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
auto a_value = a_value_node->value();
auto b_value = b_value_node->value();
if (a_value == nullptr || b_value == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
auto a_prim = a_value->cast<PrimitivePtr>();
auto b_prim = b_value->cast<PrimitivePtr>();
if (a_prim == nullptr || b_prim == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
return a_prim->name() == b_prim->name();
}
bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
MS_LOG(ERROR) << "cast value node ptr fail";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
auto a_value_ptr = a_value_node_ptr->value();
auto b_value_ptr = b_value_node_ptr->value();
if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
MS_LOG(ERROR) << "value ptr is nullptr";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (utils::isa<ops::PrimitiveC>(a_value_ptr) && utils::isa<ops::PrimitiveC>(b_value_ptr)) {
auto a_obj = (ops::PrimitiveC *)(a_value_ptr.get());
auto b_obj = (ops::PrimitiveC *)(b_value_ptr.get());
return (*a_obj) == (*b_obj);
} else {
return (*a_value_ptr) == (*b_value_ptr);
}
}
} // namespace
// not implement for lite, just for api compatible
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
const std::vector<AnfNodePtr> &orig_nodes) {
return fg->NewCNode(inputs);
}
// not implement for lite, just for api compatible
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
return nullptr;
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
if (graph == nullptr || node == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr.";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
auto manager = graph->manager();
if (manager == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(ERROR) << "node has no output in manager";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
return nullptr;
}
auto output_info_list = iter->second;
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
return output_node_list;
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index) {
if (graph == nullptr || node == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr.";
return nullptr;
}
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
auto manager = graph->manager();
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(ERROR) << "node has no output in manager";
return output_node_list;
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
size_t used_output_index;
if (CheckPrimitiveType(output_info.first, prim::kPrimTupleGetItem)) {
used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
} else if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
used_output_index = output_index;
} else {
if (output_index != 0) {
MS_LOG(ERROR) << "node has no output in manager";
return output_node_list;
}
return output_node_list;
}
if (used_output_index == output_index) {
output_node_list->push_back(output_info);
}
}
return output_node_list;
}
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
if (a_node == nullptr || b_node == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return false;
}
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
return AnfEqualPrimitive(a_node, b_node);
}
if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
return AnfEqualValueNode(a_node, b_node);
}
}
if (a.m_ptr->isa<mindspore::ops::PrimitiveC>() && b.m_ptr->isa<mindspore::ops::PrimitiveC>()) {
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
return a_value_node_ptr->name() == b_value_node_ptr->name();
}
return a == b;
}
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To matchCNode and Kernel's type
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
return true;
}
return a.type() == b.type();
}
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
if (primitive_vars == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (utils::isa<VectorRef>(sexp)) {
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
}
if (utils::isa<VarPtr>(sexp)) {
auto var_ptr = utils::cast<VarPtr>(sexp);
if (var_ptr == nullptr) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
if (var_ptr->primitive()) {
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
return NewValueNode(var_ptr->primitive());
}
return CreateVarNodeWithSexp(sexp, graph);
}
if (utils::isa<AnfNodePtr>(sexp)) {
return utils::cast<AnfNodePtr>(sexp);
}
auto value_node = CreateValueNodeWithSexp(sexp);
if (value_node == nullptr) {
MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
}
return value_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -19,7 +19,7 @@
#include <functional> #include <functional>
#include "ops/lstm.h" #include "ops/lstm.h"
#include "ops/squeeze.h" #include "ops/squeeze.h"
#include "tools/converter/ops/ops_def.h" #include "ops/tuple_get_item.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "tools/common/tensor_util.h" #include "tools/common/tensor_util.h"
#include "utils/utils.h" #include "utils/utils.h"
@ -608,7 +608,7 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap
const int item_index) { const int item_index) {
MS_ASSERT(func_graph != nullptr); MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr); MS_ASSERT(node != nullptr);
auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>(); auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
auto get_item_value = NewValueNode(MakeValue<int>(item_index)); auto get_item_value = NewValueNode(MakeValue<int>(item_index));
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
MS_LOG(ERROR) << "NewValueNode is nullptr"; MS_LOG(ERROR) << "NewValueNode is nullptr";
@ -801,9 +801,7 @@ const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, c
std::vector<int> squeeze_axis{1}; // our lstm output:0 have an extra axis that tflite not have, it must be squeezed std::vector<int> squeeze_axis{1}; // our lstm output:0 have an extra axis that tflite not have, it must be squeezed
auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis); auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
if (squeeze_node == nullptr) { MS_CHECK_TRUE_MSG(squeeze_node != nullptr, nullptr, "create a squeeze node failed.");
return nullptr;
}
auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1); auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr); MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);

View File

@ -23,8 +23,9 @@
#include "ops/tensor_array.h" #include "ops/tensor_array.h"
#include "ops/tensor_array_read.h" #include "ops/tensor_array_read.h"
#include "ops/tensor_array_write.h" #include "ops/tensor_array_write.h"
#include "tools/converter/ops/ops_def.h"
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
namespace mindspore::opt { namespace mindspore::opt {
constexpr auto kDefaultIndex = 0; constexpr auto kDefaultIndex = 0;
@ -78,7 +79,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
// for single output graph, create tuple for graph output // for single output graph, create tuple for graph output
// make_tuple node // make_tuple node
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>(); auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
if (make_tuple_prim_ptr == nullptr) { if (make_tuple_prim_ptr == nullptr) {
MS_LOG(ERROR) << "make_tuple_prim_ptr is nullptr"; MS_LOG(ERROR) << "make_tuple_prim_ptr is nullptr";
return lite::RET_NULL_PTR; return lite::RET_NULL_PTR;
@ -93,7 +94,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
make_tuple_cnode->set_fullname_with_scope("return tuple"); make_tuple_cnode->set_fullname_with_scope("return tuple");
// return node // return node
auto return_prim_ptr = std::make_shared<lite::Return>(); auto return_prim_ptr = std::make_shared<ops::Return>();
if (return_prim_ptr == nullptr) { if (return_prim_ptr == nullptr) {
MS_LOG(ERROR) << "return_prim_ptr is nullptr"; MS_LOG(ERROR) << "return_prim_ptr is nullptr";
return lite::RET_NULL_PTR; return lite::RET_NULL_PTR;

View File

@ -20,7 +20,7 @@
#include <utility> #include <utility>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "tools/anf_exporter/fetch_content.h" #include "tools/anf_exporter/fetch_content.h"
#include "tools/converter/ops/ops_def.h" #include "ops/make_tuple.h"
#include "ops/depend.h" #include "ops/depend.h"
#include "ops/fusion/pad_fusion.h" #include "ops/fusion/pad_fusion.h"
#include "ops/op_utils.h" #include "ops/op_utils.h"
@ -120,7 +120,7 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) { if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
return lite::RET_OK; return lite::RET_OK;
} }
auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>()); auto make_tuple_prim = NewValueNode(std::make_shared<ops::MakeTuple>());
auto manager = func_graph->manager(); auto manager = func_graph->manager();
MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed"); MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
MS_ASSERT(manager != nullptr); MS_ASSERT(manager != nullptr);

View File

@ -16,10 +16,10 @@
#include "tools/optimizer/parallel/operator_info.h" #include "tools/optimizer/parallel/operator_info.h"
#include <algorithm> #include <algorithm>
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/parallel/split_strategy.h" #include "tools/optimizer/parallel/split_strategy.h"
#include "ops/concat.h" #include "ops/concat.h"
#include "ops/addn.h" #include "ops/addn.h"
#include "ops/tuple_get_item.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "include/errorcode.h" #include "include/errorcode.h"
@ -120,7 +120,7 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index); auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR); MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
idx->set_abstract(abstract_scalar); idx->set_abstract(abstract_scalar);
auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx}); auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<ops::TupleGetItem>()), node, idx});
if (tuple_getitem == nullptr) { if (tuple_getitem == nullptr) {
MS_LOG(ERROR) << name_ << " : Failed to create output nodes."; MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
return lite::RET_ERROR; return lite::RET_ERROR;