forked from mindspore-Ecosystem/mindspore
open ConverterContext and open special nodes
This commit is contained in:
parent
cf04d2eb66
commit
1dd50a0a6f
|
@ -503,8 +503,6 @@ else()
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||||
install(FILES ${UTILS_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
install(FILES ${UTILS_HEADER} DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "schema_generated.h" EXCLUDE)
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "schema_generated.h" EXCLUDE)
|
||||||
install(DIRECTORY ${flatbuffers_INC}/ DESTINATION ${CONVERTER_ROOT_DIR}/include/third_party
|
install(DIRECTORY ${flatbuffers_INC}/ DESTINATION ${CONVERTER_ROOT_DIR}/include/third_party
|
||||||
|
|
|
@ -103,6 +103,26 @@ inline std::unordered_map<std::string, std::string> UnorderedMapCharToString(
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline std::map<std::vector<char>, std::vector<char>> MapStringToVectorChar(
|
||||||
|
const std::map<std::string, std::string> &s) {
|
||||||
|
std::map<std::vector<char>, std::vector<char>> ret;
|
||||||
|
std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
|
||||||
|
return std::pair<std::vector<char>, std::vector<char>>(std::vector<char>(str.first.begin(), str.first.end()),
|
||||||
|
std::vector<char>(str.second.begin(), str.second.end()));
|
||||||
|
});
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::map<std::string, std::string> MapVectorCharToString(
|
||||||
|
const std::map<std::vector<char>, std::vector<char>> &c) {
|
||||||
|
std::map<std::string, std::string> ret;
|
||||||
|
std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
|
||||||
|
return std::pair<std::string, std::string>(std::string(ch.first.begin(), ch.first.end()),
|
||||||
|
std::string(ch.second.begin(), ch.second.end()));
|
||||||
|
});
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
inline std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ClassIndexStringToChar(
|
inline std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ClassIndexStringToChar(
|
||||||
const std::vector<std::pair<std::string, std::vector<int32_t>>> &s) {
|
const std::vector<std::pair<std::string, std::vector<int32_t>>> &s) {
|
||||||
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ret;
|
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ret;
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
REGISTER_PRIMITIVE_C(kNameMakeTuple, MakeTuple);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameMakeTuple = "MakeTuple";
|
||||||
|
/// \brief MakeTuple op is used to pack multiple nodes into a whole, which is only used in FuncGraph.
|
||||||
|
class MS_CORE_API MakeTuple : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
MakeTuple() : PrimitiveC(kNameMakeTuple) {}
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~MakeTuple() = default;
|
||||||
|
|
||||||
|
MS_DECLARE_PARENT(MakeTuple, PrimitiveC);
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_MAKE_TUPLE_H_
|
|
@ -0,0 +1,23 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/return.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
REGISTER_PRIMITIVE_C(kNameReturn, Return);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,38 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_RETURN_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_RETURN_H_
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameReturn = "Return";
|
||||||
|
/// \brief Return op is the output node, which is only used in FuncGraph.
|
||||||
|
class MS_CORE_API Return : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
Return() : PrimitiveC(kNameReturn) {}
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~Return() = default;
|
||||||
|
|
||||||
|
MS_DECLARE_PARENT(Return, PrimitiveC);
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_RETURN_H_
|
|
@ -0,0 +1,23 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
REGISTER_PRIMITIVE_C(kNameTupleGetItem, TupleGetItem);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameTupleGetItem = "TupleGetItem";
|
||||||
|
/// \brief TupleGetItem op is added to the multi-output node to describe which output of the node, which is only used
|
||||||
|
/// in FuncGraph.
|
||||||
|
class MS_CORE_API TupleGetItem : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
|
TupleGetItem() : PrimitiveC(kNameTupleGetItem) {}
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
|
~TupleGetItem() = default;
|
||||||
|
|
||||||
|
MS_DECLARE_PARENT(TupleGetItem, PrimitiveC);
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_TUPLE_GET_ITEM_H_
|
|
@ -151,7 +151,7 @@ build_lite() {
|
||||||
CMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake
|
CMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake
|
||||||
fi
|
fi
|
||||||
|
|
||||||
BRANCH_NAME=nnie_3516_master
|
BRANCH_NAME=nnie_3516_master_2
|
||||||
if [[ ("${MSLITE_REGISTRY_DEVICE}" == "Hi3516D" || "${TOOLCHAIN_NAME}" == "himix200") && "${local_lite_platform}" == "arm32" ]]; then
|
if [[ ("${MSLITE_REGISTRY_DEVICE}" == "Hi3516D" || "${TOOLCHAIN_NAME}" == "himix200") && "${local_lite_platform}" == "arm32" ]]; then
|
||||||
TOOLCHAIN_NAME="himix200"
|
TOOLCHAIN_NAME="himix200"
|
||||||
MSLITE_REGISTRY_DEVICE=Hi3516D
|
MSLITE_REGISTRY_DEVICE=Hi3516D
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "include/lite_utils.h"
|
#include "include/lite_utils.h"
|
||||||
|
#include "include/api/dual_abi_helper.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace converter {
|
namespace converter {
|
||||||
|
@ -53,12 +54,30 @@ class MS_API ConverterContext {
|
||||||
/// \brief Static method to set exported model's output name as needed by users.
|
/// \brief Static method to set exported model's output name as needed by users.
|
||||||
///
|
///
|
||||||
/// \param[in] output_names Define model's output name, the order of which is consistent with the original model.
|
/// \param[in] output_names Define model's output name, the order of which is consistent with the original model.
|
||||||
static void SetGraphOutputTensorNames(const std::vector<std::string> &output_names);
|
static void SetGraphOutputTensorNames(const std::vector<std::string> &output_names) {
|
||||||
|
SetGraphOutputTensorNames(VectorStringToChar(output_names));
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Static method to obtain the outputs' name.
|
/// \brief Static method to obtain the outputs' name.
|
||||||
///
|
///
|
||||||
/// \return the outputs' name.
|
/// \return the outputs' name.
|
||||||
static std::vector<std::string> GetGraphOutputTensorNames();
|
static std::vector<std::string> GetGraphOutputTensorNames() {
|
||||||
|
return VectorCharToString(GetGraphOutputTensorNamesInChar());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Static method to get configure information which is used only by external extension.
|
||||||
|
///
|
||||||
|
/// \param[in] section Define config section name.
|
||||||
|
///
|
||||||
|
/// \return config key-value map.
|
||||||
|
static std::map<std::string, std::string> GetConfigInfo(const std::string §ion) {
|
||||||
|
return MapVectorCharToString(GetConfigInfo(StringToChar(section)));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static void SetGraphOutputTensorNames(const std::vector<std::vector<char>> &&output_names);
|
||||||
|
static std::vector<std::vector<char>> GetGraphOutputTensorNamesInChar();
|
||||||
|
static std::map<std::vector<char>, std::vector<char>> GetConfigInfo(const std::vector<char> &§ion);
|
||||||
};
|
};
|
||||||
} // namespace converter
|
} // namespace converter
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,7 +18,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "ops/fusion/conv2d_fusion.h"
|
#include "ops/fusion/conv2d_fusion.h"
|
||||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||||
|
@ -93,7 +94,7 @@ CNodePtr FusionInoutTest::AddReturn(const FuncGraphPtr &graph, const std::vector
|
||||||
if (return_inputs.size() == 1) {
|
if (return_inputs.size() == 1) {
|
||||||
return_input = return_inputs.front();
|
return_input = return_inputs.front();
|
||||||
} else {
|
} else {
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -107,7 +108,7 @@ CNodePtr FusionInoutTest::AddReturn(const FuncGraphPtr &graph, const std::vector
|
||||||
return_input = return_input_cnode;
|
return_input = return_input_cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto return_prim = std::make_shared<lite::Return>();
|
auto return_prim = std::make_shared<ops::Return>();
|
||||||
MS_CHECK_TRUE_MSG(return_prim != nullptr, nullptr, "create return primitivec failed");
|
MS_CHECK_TRUE_MSG(return_prim != nullptr, nullptr, "create return primitivec failed");
|
||||||
auto return_cnode = graph->NewCNode(return_prim, {return_input});
|
auto return_cnode = graph->NewCNode(return_prim, {return_input});
|
||||||
MS_CHECK_TRUE_MSG(return_cnode != nullptr, nullptr, "create Return failed");
|
MS_CHECK_TRUE_MSG(return_cnode != nullptr, nullptr, "create Return failed");
|
||||||
|
|
|
@ -29,8 +29,10 @@
|
||||||
#include "ops/call.h"
|
#include "ops/call.h"
|
||||||
#include "ops/control_depend.h"
|
#include "ops/control_depend.h"
|
||||||
#include "ops/depend.h"
|
#include "ops/depend.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "ops/quant_dtype_cast.h"
|
#include "ops/quant_dtype_cast.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/converter/quantizer/bitpacking.h"
|
#include "tools/converter/quantizer/bitpacking.h"
|
||||||
|
@ -428,8 +430,8 @@ int AnfExporter::SetTailCallForNonOutput() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfExporter::CaseToContinue(const string &prim_name) {
|
bool AnfExporter::CaseToContinue(const string &prim_name) {
|
||||||
return prim_name == mindspore::ops::kNameDepend || prim_name == mindspore::lite::kNameTupleGetItem ||
|
return prim_name == mindspore::ops::kNameDepend || prim_name == mindspore::ops::kNameTupleGetItem ||
|
||||||
prim_name == mindspore::lite::kNameMakeTuple || prim_name == "make_tuple";
|
prim_name == mindspore::ops::kNameMakeTuple || prim_name == "make_tuple";
|
||||||
}
|
}
|
||||||
|
|
||||||
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
|
||||||
|
@ -464,7 +466,7 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
|
||||||
node->name = mindspore::lite::kNameReturn;
|
node->name = mindspore::ops::kNameReturn;
|
||||||
ret = SetSubGraphOutputIndex(cnode, subgraph_index, meta_graphT, node.get());
|
ret = SetSubGraphOutputIndex(cnode, subgraph_index, meta_graphT, node.get());
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "SetOpOutputN failed";
|
MS_LOG(ERROR) << "SetOpOutputN failed";
|
||||||
|
|
|
@ -27,8 +27,8 @@
|
||||||
#include "tools/common/node_util.h"
|
#include "tools/common/node_util.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -45,7 +45,7 @@ int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::vector<AnfNodePtr>
|
||||||
graph->set_output(outputs.front(), false);
|
graph->set_output(outputs.front(), false);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(DEBUG) << "new MakeTuple failed";
|
MS_LOG(DEBUG) << "new MakeTuple failed";
|
||||||
return lite::RET_NULL_PTR;
|
return lite::RET_NULL_PTR;
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
#include "tools/converter/adapter/acl/acl_pass_impl.h"
|
#include "tools/converter/adapter/acl/acl_pass_impl.h"
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||||
|
@ -28,6 +27,7 @@
|
||||||
#include "include/registry/pass_registry.h"
|
#include "include/registry/pass_registry.h"
|
||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
#include "ops/custom.h"
|
#include "ops/custom.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
#include "cxx_api/model/acl/model_converter.h"
|
#include "cxx_api/model/acl/model_converter.h"
|
||||||
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
#include "backend/kernel_compiler/cpu/nnacl/op_base.h"
|
||||||
|
@ -570,7 +570,7 @@ STATUS AclPassImpl::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, cons
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (size_t j = 0; j < graph_outputs_.size(); ++j) {
|
for (size_t j = 0; j < graph_outputs_.size(); ++j) {
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "New TupleGetItem failed for output " << j;
|
MS_LOG(ERROR) << "New TupleGetItem failed for output " << j;
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "tools/converter/adapter/acl/common/utils.h"
|
#include "tools/converter/adapter/acl/common/utils.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
#include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
@ -30,6 +29,7 @@
|
||||||
#include "ops/batch_norm.h"
|
#include "ops/batch_norm.h"
|
||||||
#include "ops/fused_batch_norm.h"
|
#include "ops/fused_batch_norm.h"
|
||||||
#include "ops/stack.h"
|
#include "ops/stack.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -44,7 +44,7 @@ const std::set<std::string> kCNodeWithDynamicInput = {kNamewiEltwise, ops::kName
|
||||||
|
|
||||||
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
|
CNodePtr CreateTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode) {
|
||||||
CNodePtr get_item_cnode = nullptr;
|
CNodePtr get_item_cnode = nullptr;
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "New TupleGetItem failed";
|
MS_LOG(ERROR) << "New TupleGetItem failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "tools/common/parse_config_utils.h"
|
#include "tools/common/parse_config_utils.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "tools/converter/converter_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
@ -37,35 +38,44 @@ int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
ret = ParseDataPreProcessString(maps);
|
ret = ParseDataPreProcessString(maps);
|
||||||
|
(void)maps.erase(kDataPreprocessParam);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParseDataPreProcessString failed.";
|
MS_LOG(ERROR) << "ParseDataPreProcessString failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
ret = ParseCommonQuantString(maps);
|
ret = ParseCommonQuantString(maps);
|
||||||
|
(void)maps.erase(kCommonQuantParam);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParseCommonQuantString failed.";
|
MS_LOG(ERROR) << "ParseCommonQuantString failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
ret = ParseMixedBitQuantString(maps);
|
ret = ParseMixedBitQuantString(maps);
|
||||||
|
(void)maps.erase(kMixedBitWeightQuantParam);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParseMixedBitQuantString failed.";
|
MS_LOG(ERROR) << "ParseMixedBitQuantString failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
ret = ParseFullQuantString(maps);
|
ret = ParseFullQuantString(maps);
|
||||||
|
(void)maps.erase(kFullQuantParam);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParseFullQuantString failed.";
|
MS_LOG(ERROR) << "ParseFullQuantString failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
ret = ParseRegistryInfoString(maps);
|
ret = ParseRegistryInfoString(maps);
|
||||||
|
(void)maps.erase(kRegistry);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParseExtendedintegrationString failed.";
|
MS_LOG(ERROR) << "ParseExtendedintegrationString failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
ret = ParseAclOptionCfgString(maps);
|
ret = ParseAclOptionCfgString(maps);
|
||||||
|
(void)maps.erase(kAclOptionParam);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "ParseAclOptionCfgString failed.";
|
MS_LOG(ERROR) << "ParseAclOptionCfgString failed.";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
for (const auto &config_info : maps) {
|
||||||
|
ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
|
||||||
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,22 +15,41 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "include/registry/converter_context.h"
|
#include "include/registry/converter_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace converter {
|
namespace converter {
|
||||||
void ConverterContext::SetGraphOutputTensorNames(const std::vector<std::string> &output_names) {
|
void ConverterContext::SetGraphOutputTensorNames(const std::vector<std::vector<char>> &&output_names) {
|
||||||
auto converter_context = lite::ConverterInnerContext::GetInstance();
|
auto converter_context = lite::ConverterInnerContext::GetInstance();
|
||||||
MS_ASSERT(converter_context != nullptr);
|
if (converter_context == nullptr) {
|
||||||
converter_context->SetGraphOutputTensorNames(output_names);
|
MS_LOG(ERROR) << "Set graph output's names failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
converter_context->SetGraphOutputTensorNames(VectorCharToString(output_names));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> ConverterContext::GetGraphOutputTensorNames() {
|
std::vector<std::vector<char>> ConverterContext::GetGraphOutputTensorNamesInChar() {
|
||||||
auto converter_context = lite::ConverterInnerContext::GetInstance();
|
auto converter_context = lite::ConverterInnerContext::GetInstance();
|
||||||
MS_ASSERT(converter_context != nullptr);
|
if (converter_context == nullptr) {
|
||||||
return converter_context->GetGraphOutputTensorNames();
|
MS_LOG(ERROR) << "Get graph output's names failed.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return VectorStringToChar(converter_context->GetGraphOutputTensorNames());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::vector<char>, std::vector<char>> ConverterContext::GetConfigInfo(const std::vector<char> &§ion) {
|
||||||
|
auto converter_context = lite::ConverterInnerContext::GetInstance();
|
||||||
|
if (converter_context == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Get config information only used by external extension failed.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
auto &external_used_config_infos = converter_context->GetExternalUsedConfigInfos();
|
||||||
|
if (external_used_config_infos.find(CharToString(section)) == external_used_config_infos.end()) {
|
||||||
|
MS_LOG(ERROR) << "This section " << section << " config info is not existed.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return MapStringToVectorChar(external_used_config_infos.at(CharToString(section)));
|
||||||
}
|
}
|
||||||
} // namespace converter
|
} // namespace converter
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -112,6 +112,18 @@ class ConverterInnerContext {
|
||||||
|
|
||||||
const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; }
|
const std::vector<std::string> GetGraphOutputTensorNames() const { return graph_output_tensor_names_; }
|
||||||
|
|
||||||
|
void SetExternalUsedConfigInfos(const std::string §ion,
|
||||||
|
const std::map<std::string, std::string> &external_infos) {
|
||||||
|
if (external_used_config_infos_.find(section) != external_used_config_infos_.end()) {
|
||||||
|
MS_LOG(WARNING) << "This section " << section << " has been saved. Now, the content will be overwrite.";
|
||||||
|
}
|
||||||
|
external_used_config_infos_.emplace(section, external_infos);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::map<std::string, std::map<std::string, std::string>> &GetExternalUsedConfigInfos() {
|
||||||
|
return external_used_config_infos_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ConverterInnerContext() = default;
|
ConverterInnerContext() = default;
|
||||||
virtual ~ConverterInnerContext() = default;
|
virtual ~ConverterInnerContext() = default;
|
||||||
|
@ -119,6 +131,7 @@ class ConverterInnerContext {
|
||||||
std::map<int32_t, int32_t> graph_output_data_type_map_;
|
std::map<int32_t, int32_t> graph_output_data_type_map_;
|
||||||
std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_;
|
std::map<std::string, std::vector<int64_t>> graph_input_tensor_shape_map_;
|
||||||
std::vector<std::string> graph_output_tensor_names_;
|
std::vector<std::string> graph_output_tensor_names_;
|
||||||
|
std::map<std::string, std::map<std::string, std::string>> external_used_config_infos_;
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -18,7 +18,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
@ -130,7 +131,7 @@ FuncGraphPtr MindIRControlFlowAdjust::AddAfterFuncGraph(const FuncGraphPtr &fg,
|
||||||
|
|
||||||
if (after_fg->get_inputs().size() > 1) {
|
if (after_fg->get_inputs().size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs = after_fg->get_inputs();
|
std::vector<AnfNodePtr> make_tuple_inputs = after_fg->get_inputs();
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -141,7 +142,7 @@ FuncGraphPtr MindIRControlFlowAdjust::AddAfterFuncGraph(const FuncGraphPtr &fg,
|
||||||
auto make_tuple_cnode = after_fg->NewCNode(make_tuple_inputs);
|
auto make_tuple_cnode = after_fg->NewCNode(make_tuple_inputs);
|
||||||
MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "Failed to create C node.");
|
MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "Failed to create C node.");
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -154,7 +155,7 @@ FuncGraphPtr MindIRControlFlowAdjust::AddAfterFuncGraph(const FuncGraphPtr &fg,
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
after_fg->set_return(cnode);
|
after_fg->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
|
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
|
||||||
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
|
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_OPS_DEF_H_
|
||||||
#include "schema/inner/model_generated.h"
|
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
using mindspore::ops::PrimitiveC;
|
using mindspore::ops::PrimitiveC;
|
||||||
|
|
||||||
|
@ -43,9 +42,6 @@ ADD_CONVERTER_ONLY_OP(TensorArraySizeV3);
|
||||||
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
|
ADD_CONVERTER_ONLY_OP(TensorArrayV3);
|
||||||
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
|
ADD_CONVERTER_ONLY_OP(TensorArrayWriteV3);
|
||||||
ADD_CONVERTER_ONLY_OP(Constant);
|
ADD_CONVERTER_ONLY_OP(Constant);
|
||||||
ADD_CONVERTER_ONLY_OP(MakeTuple);
|
|
||||||
ADD_CONVERTER_ONLY_OP(TupleGetItem);
|
|
||||||
ADD_CONVERTER_ONLY_OP(Return);
|
|
||||||
ADD_CONVERTER_ONLY_OP(Merge);
|
ADD_CONVERTER_ONLY_OP(Merge);
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -25,7 +25,6 @@
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
|
@ -35,6 +34,9 @@
|
||||||
#include "tools/converter/parser/unify_format.h"
|
#include "tools/converter/parser/unify_format.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
|
|
||||||
using mindspore::converter::kFmkTypeCaffe;
|
using mindspore::converter::kFmkTypeCaffe;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
@ -416,7 +418,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
caffeInspector.InspectModel(caffe_model_);
|
caffeInspector.InspectModel(caffe_model_);
|
||||||
if (caffeInspector.GetGraphOutput().size() > 1) {
|
if (caffeInspector.GetGraphOutput().size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
MSLITE_CHECK_PTR(make_tuple_prim_ptr);
|
MSLITE_CHECK_PTR(make_tuple_prim_ptr);
|
||||||
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
auto make_tuple_prim = NewValueNode(make_tuple_prim_ptr);
|
||||||
MSLITE_CHECK_PTR(make_tuple_prim);
|
MSLITE_CHECK_PTR(make_tuple_prim);
|
||||||
|
@ -434,7 +436,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
MSLITE_CHECK_PTR(return_prim_ptr);
|
MSLITE_CHECK_PTR(return_prim_ptr);
|
||||||
auto value_node = NewValueNode(return_prim_ptr);
|
auto value_node = NewValueNode(return_prim_ptr);
|
||||||
MSLITE_CHECK_PTR(value_node);
|
MSLITE_CHECK_PTR(value_node);
|
||||||
|
@ -445,7 +447,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
res_graph_->set_return(cnode);
|
res_graph_->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto returnPrim = std::make_shared<lite::Return>();
|
auto returnPrim = std::make_shared<ops::Return>();
|
||||||
MSLITE_CHECK_PTR(returnPrim);
|
MSLITE_CHECK_PTR(returnPrim);
|
||||||
auto valueNode = NewValueNode(returnPrim);
|
auto valueNode = NewValueNode(returnPrim);
|
||||||
MSLITE_CHECK_PTR(valueNode);
|
MSLITE_CHECK_PTR(valueNode);
|
||||||
|
@ -584,7 +586,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract);
|
abstract_list.emplace_back(abstract);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "ops/tensor_list_stack.h"
|
#include "ops/tensor_list_stack.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
|
@ -38,6 +37,9 @@
|
||||||
#include "tools/converter/parser/unify_format.h"
|
#include "tools/converter/parser/unify_format.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
|
|
||||||
using mindspore::converter::kFmkTypeOnnx;
|
using mindspore::converter::kFmkTypeOnnx;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -161,7 +163,7 @@ CNodePtr GetCNodeFromControlFlowNodesMap(
|
||||||
|
|
||||||
STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
|
STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector<AnfNodePtr> &return_inputs) {
|
||||||
MS_ASSERT(anf_graph != nullptr);
|
MS_ASSERT(anf_graph != nullptr);
|
||||||
auto return_prim = std::make_shared<lite::Return>();
|
auto return_prim = std::make_shared<ops::Return>();
|
||||||
if (return_prim == nullptr) {
|
if (return_prim == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -227,7 +229,7 @@ STATUS BuildOpOutputs(const onnx::NodeProto &onnx_node, const FuncGraphPtr &anf_
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -326,7 +328,7 @@ STATUS ConvertGraphOutputs(const onnx::GraphProto &onnx_graph, const FuncGraphPt
|
||||||
}
|
}
|
||||||
if (onnx_graph.output_size() > 1) {
|
if (onnx_graph.output_size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
|
|
||||||
|
@ -157,7 +158,7 @@ FuncGraphPtr FunctionalizeCond::CreateBranchGraph(const AnfNodePtr &node, std::s
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
|
if (!CheckPrimitiveType(node, prim::kPrimSwitch)) { // graph is not empty
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -19,7 +19,9 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <deque>
|
#include <deque>
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
#include "tools/converter/ops/while.h"
|
#include "tools/converter/ops/while.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
|
@ -224,7 +226,7 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract);
|
abstract_list.emplace_back(abstract);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -370,7 +372,7 @@ STATUS FunctionalizeWhile::IdentifyCondSubgraphInput() {
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
|
STATUS FunctionalizeWhile::IdentifyCondSubgraphOutput() {
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -531,7 +533,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
|
||||||
"_cnode");
|
"_cnode");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
MS_LOG(ERROR) << "GetReturnPrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -548,7 +550,7 @@ STATUS FunctionalizeWhile::IdentifyBodySubgraphOutput() {
|
||||||
return_cnode->add_input(tmp_output[0]);
|
return_cnode->add_input(tmp_output[0]);
|
||||||
} else {
|
} else {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
|
std::vector<AnfNodePtr> make_tuple_inputs = tmp_output;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
MS_LOG(ERROR) << "GetMakeTuplePrim return nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -26,7 +26,9 @@
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "abstract/utils.h"
|
#include "abstract/utils.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
|
@ -884,7 +886,7 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -1145,7 +1147,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_
|
||||||
}
|
}
|
||||||
if (output_nodes.size() > 1) {
|
if (output_nodes.size() > 1) {
|
||||||
std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
|
std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -1157,7 +1159,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_
|
||||||
CHECK_NULL_RETURN(make_tuple_cnode);
|
CHECK_NULL_RETURN(make_tuple_cnode);
|
||||||
make_tuple_cnode->set_fullname_with_scope("return_tuple");
|
make_tuple_cnode->set_fullname_with_scope("return_tuple");
|
||||||
|
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -1170,7 +1172,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
anf_graph->set_return(cnode);
|
anf_graph->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -24,7 +24,6 @@
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/converter/quant_param_holder.h"
|
#include "tools/converter/quant_param_holder.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "tools/converter/converter_context.h"
|
||||||
|
@ -34,6 +33,9 @@
|
||||||
#include "tools/converter/parser/unify_format.h"
|
#include "tools/converter/parser/unify_format.h"
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
#include "src/common/log_util.h"
|
#include "src/common/log_util.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
|
|
||||||
using mindspore::converter::kFmkTypeTflite;
|
using mindspore::converter::kFmkTypeTflite;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
@ -479,7 +481,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptr<tflite::SubG
|
||||||
}
|
}
|
||||||
output_nodes.emplace_back(cnode);
|
output_nodes.emplace_back(cnode);
|
||||||
}
|
}
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new MakeTuple failed";
|
MS_LOG(ERROR) << "new MakeTuple failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -491,7 +493,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptr<tflite::SubG
|
||||||
auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
|
auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
|
||||||
MSLITE_CHECK_PTR(make_tuple_cnode);
|
MSLITE_CHECK_PTR(make_tuple_cnode);
|
||||||
make_tuple_cnode->set_fullname_with_scope("return_tuple");
|
make_tuple_cnode->set_fullname_with_scope("return_tuple");
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -505,7 +507,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptr<tflite::SubG
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
func_graph->set_return(cnode);
|
func_graph->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto returnPrim = std::make_shared<lite::Return>();
|
auto returnPrim = std::make_shared<ops::Return>();
|
||||||
if (returnPrim == nullptr) {
|
if (returnPrim == nullptr) {
|
||||||
MS_LOG(ERROR) << "new Return failed";
|
MS_LOG(ERROR) << "new Return failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -774,7 +776,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const std::unique_ptr<tflite::SubG
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "tools/converter/quantizer/calibrator.h"
|
#include "tools/converter/quantizer/calibrator.h"
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "tools/converter/preprocess/image_preprocess.h"
|
#include "tools/converter/preprocess/image_preprocess.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/tuple_get_item.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
@ -63,7 +63,7 @@ int Calibrator::ComputeThreshold() {
|
||||||
for (const auto &output_diverg_info : outputs_diverg_info.second) {
|
for (const auto &output_diverg_info : outputs_diverg_info.second) {
|
||||||
auto output_diverg_cnode = output_diverg_info.second->GetCNode();
|
auto output_diverg_cnode = output_diverg_info.second->GetCNode();
|
||||||
if (output_diverg_cnode == input_cnode) {
|
if (output_diverg_cnode == input_cnode) {
|
||||||
if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) {
|
if (NodePrimitiveType(input_cnode) != ops::kNameTupleGetItem) {
|
||||||
*(input_infos[i]) = *output_diverg_info.second;
|
*(input_infos[i]) = *output_diverg_info.second;
|
||||||
input_infos[i]->GetCNode() = cnode;
|
input_infos[i]->GetCNode() = cnode;
|
||||||
already_computed = true;
|
already_computed = true;
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "ops/fusion/full_connection.h"
|
#include "ops/fusion/full_connection.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/tuple_get_item.h"
|
||||||
#include "src/tensor.h"
|
#include "src/tensor.h"
|
||||||
#include "tools/converter/quantizer/quant_cast.h"
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
#include "tools/converter/quantizer/quantize_util.h"
|
#include "tools/converter/quantizer/quantize_util.h"
|
||||||
|
@ -462,7 +462,7 @@ int FullQuantQuantizer::QuantNode(const FuncGraphPtr &func_graph) {
|
||||||
|
|
||||||
auto op_type = primitive->name();
|
auto op_type = primitive->name();
|
||||||
MS_LOG(DEBUG) << "OpName: " << op_name;
|
MS_LOG(DEBUG) << "OpName: " << op_name;
|
||||||
if (op_type == lite::kNameTupleGetItem) {
|
if (op_type == mindspore::ops::kNameTupleGetItem) {
|
||||||
constexpr int tuple_get_item_input_size = 3;
|
constexpr int tuple_get_item_input_size = 3;
|
||||||
MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
|
MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
|
||||||
auto index_node = cnode->input(THIRD_INPUT);
|
auto index_node = cnode->input(THIRD_INPUT);
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
#include "ops/fusion/conv2d_fusion.h"
|
#include "ops/fusion/conv2d_fusion.h"
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
#include "ops/gather.h"
|
#include "ops/gather.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/tuple_get_item.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "frontend/operator/ops.h"
|
#include "frontend/operator/ops.h"
|
||||||
#include "backend/optimizer/common/helper.h"
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
@ -174,121 +174,6 @@ bool IsRealKernel(const AnfNodePtr &node) {
|
||||||
#endif
|
#endif
|
||||||
return !is_virtual_node;
|
return !is_virtual_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
|
||||||
if (utils::isa<int>(sexp)) {
|
|
||||||
return NewValueNode(utils::cast<int>(sexp));
|
|
||||||
}
|
|
||||||
if (utils::isa<float>(sexp)) {
|
|
||||||
return NewValueNode(utils::cast<float>(sexp));
|
|
||||||
}
|
|
||||||
if (utils::isa<bool>(sexp)) {
|
|
||||||
return NewValueNode(utils::cast<bool>(sexp));
|
|
||||||
}
|
|
||||||
if (utils::isa<ValuePtr>(sexp)) {
|
|
||||||
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
|
||||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
|
||||||
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
|
||||||
}
|
|
||||||
if (utils::isa<VarPtr>(graph)) {
|
|
||||||
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
|
||||||
if (utils::isa<VarPtr>(graph)) {
|
|
||||||
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
|
||||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
|
||||||
}
|
|
||||||
if (utils::isa<FuncGraphPtr>(graph)) {
|
|
||||||
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
|
|
||||||
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
|
|
||||||
}
|
|
||||||
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
|
||||||
bool multigraph) {
|
|
||||||
if (primitive_vars == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
|
||||||
std::vector<AnfNodePtr> input_nodes;
|
|
||||||
const auto &tuple = utils::cast<VectorRef>(sexp);
|
|
||||||
if (multigraph && utils::isa<VarPtr>(graph)) {
|
|
||||||
for (auto &x : tuple) {
|
|
||||||
auto is_var = std::make_shared<Var>("G");
|
|
||||||
MS_CHECK_TRUE_RET(is_var != nullptr, nullptr);
|
|
||||||
AnfNodePtr node = SexpToNode(x, is_var, primitive_vars, true);
|
|
||||||
input_nodes.push_back(node);
|
|
||||||
}
|
|
||||||
auto var_ptr = utils::cast<VarPtr>(graph);
|
|
||||||
return std::make_shared<CNode>(input_nodes, var_ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &x : tuple) {
|
|
||||||
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
|
|
||||||
input_nodes.push_back(node);
|
|
||||||
}
|
|
||||||
return CreateCNodeWithGraph(input_nodes, graph);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
|
||||||
auto a_value_node = a_node->cast<ValueNodePtr>();
|
|
||||||
auto b_value_node = b_node->cast<ValueNodePtr>();
|
|
||||||
if (a_value_node == nullptr || b_value_node == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto a_value = a_value_node->value();
|
|
||||||
auto b_value = b_value_node->value();
|
|
||||||
if (a_value == nullptr || b_value == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto a_prim = a_value->cast<PrimitivePtr>();
|
|
||||||
auto b_prim = b_value->cast<PrimitivePtr>();
|
|
||||||
if (a_prim == nullptr || b_prim == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return a_prim->name() == b_prim->name();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
|
||||||
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
|
|
||||||
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
|
||||||
if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "cast value node ptr fail";
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto a_value_ptr = a_value_node_ptr->value();
|
|
||||||
auto b_value_ptr = b_value_node_ptr->value();
|
|
||||||
if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "value ptr is nullptr";
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (utils::isa<ops::PrimitiveC>(a_value_ptr) && utils::isa<ops::PrimitiveC>(b_value_ptr)) {
|
|
||||||
auto a_obj = (ops::PrimitiveC *)(a_value_ptr.get());
|
|
||||||
auto b_obj = (ops::PrimitiveC *)(b_value_ptr.get());
|
|
||||||
return (*a_obj) == (*b_obj);
|
|
||||||
} else {
|
|
||||||
return (*a_value_ptr) == (*b_value_ptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool CheckInputs(const CNodePtr &cnode) {
|
bool CheckInputs(const CNodePtr &cnode) {
|
||||||
|
@ -414,71 +299,6 @@ bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_ty
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
|
||||||
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
|
||||||
auto a_node = utils::cast<AnfNodePtr>(a);
|
|
||||||
auto b_node = utils::cast<AnfNodePtr>(b);
|
|
||||||
if (a_node == nullptr || b_node == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
|
|
||||||
return AnfEqualPrimitive(a_node, b_node);
|
|
||||||
}
|
|
||||||
if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
|
|
||||||
return AnfEqualValueNode(a_node, b_node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (a.m_ptr->isa<mindspore::ops::PrimitiveC>() && b.m_ptr->isa<mindspore::ops::PrimitiveC>()) {
|
|
||||||
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
|
|
||||||
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
|
|
||||||
return a_value_node_ptr->name() == b_value_node_ptr->name();
|
|
||||||
}
|
|
||||||
|
|
||||||
return a == b;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
|
||||||
// To matchCNode and Kernel's type
|
|
||||||
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return a.type() == b.type();
|
|
||||||
}
|
|
||||||
|
|
||||||
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
|
||||||
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
|
||||||
if (primitive_vars == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (utils::isa<VectorRef>(sexp)) {
|
|
||||||
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
|
|
||||||
}
|
|
||||||
if (utils::isa<VarPtr>(sexp)) {
|
|
||||||
auto var_ptr = utils::cast<VarPtr>(sexp);
|
|
||||||
if (var_ptr == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
if (var_ptr->primitive()) {
|
|
||||||
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
|
|
||||||
return NewValueNode(var_ptr->primitive());
|
|
||||||
}
|
|
||||||
return CreateVarNodeWithSexp(sexp, graph);
|
|
||||||
}
|
|
||||||
if (utils::isa<AnfNodePtr>(sexp)) {
|
|
||||||
return utils::cast<AnfNodePtr>(sexp);
|
|
||||||
}
|
|
||||||
auto value_node = CreateValueNodeWithSexp(sexp);
|
|
||||||
if (value_node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return value_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
|
bool IsOpType(const BaseRef &n, const PrimitivePtr &prim) {
|
||||||
if (utils::isa<AnfNodePtr>(n)) {
|
if (utils::isa<AnfNodePtr>(n)) {
|
||||||
auto anf_node = utils::cast<AnfNodePtr>(n);
|
auto anf_node = utils::cast<AnfNodePtr>(n);
|
||||||
|
@ -795,31 +615,6 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
|
||||||
const AnfNodePtr &node) {
|
|
||||||
if (graph == nullptr || node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "input parameter is nullptr.";
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
|
||||||
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
|
|
||||||
auto manager = graph->manager();
|
|
||||||
if (manager == nullptr) {
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto iter = manager->node_users().find(node);
|
|
||||||
if (iter == manager->node_users().end()) {
|
|
||||||
MS_LOG(ERROR) << "node has no output in manager";
|
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto output_info_list = iter->second;
|
|
||||||
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
|
|
||||||
return output_node_list;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
||||||
if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
|
if (tuple_get_item == nullptr || tuple_get_item->size() != kInputSizeThree) {
|
||||||
MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
|
MS_LOG(ERROR) << "The node tuple_get_item is invalid.";
|
||||||
|
@ -843,43 +638,6 @@ size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
||||||
return indexes.front();
|
return indexes.front();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
|
|
||||||
const AnfNodePtr &node,
|
|
||||||
size_t output_index) {
|
|
||||||
if (graph == nullptr || node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "input parameter is nullptr.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
|
||||||
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
|
|
||||||
auto manager = graph->manager();
|
|
||||||
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
|
|
||||||
auto iter = manager->node_users().find(node);
|
|
||||||
if (iter == manager->node_users().end()) {
|
|
||||||
MS_LOG(ERROR) << "node has no output in manager";
|
|
||||||
return output_node_list;
|
|
||||||
}
|
|
||||||
auto output_info_list = iter->second;
|
|
||||||
for (const auto &output_info : output_info_list) {
|
|
||||||
size_t used_output_index;
|
|
||||||
if (CheckPrimitiveType(output_info.first, prim::kPrimTupleGetItem)) {
|
|
||||||
used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
|
|
||||||
} else if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
|
||||||
used_output_index = output_index;
|
|
||||||
} else {
|
|
||||||
if (output_index != 0) {
|
|
||||||
MS_LOG(ERROR) << "node has no output in manager";
|
|
||||||
return output_node_list;
|
|
||||||
}
|
|
||||||
return output_node_list;
|
|
||||||
}
|
|
||||||
if (used_output_index == output_index) {
|
|
||||||
output_node_list->push_back(output_info);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return output_node_list;
|
|
||||||
}
|
|
||||||
|
|
||||||
STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
|
STATUS TransFilterFormat(const tensor::TensorPtr &tensor, schema::Format src_format, schema::Format dst_format) {
|
||||||
MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
|
MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
|
||||||
std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
|
std::unordered_map<TypeId, std::function<STATUS(const tensor::TensorPtr &, schema::Format, schema::Format)>>
|
||||||
|
@ -1135,7 +893,7 @@ CNodePtr GenTupleGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &inp
|
||||||
MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
|
MS_LOG(ERROR) << "input parameter is nullptr, which is invalid.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
|
||||||
MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(tuple_get_item_prim != nullptr, nullptr);
|
||||||
auto second_input = NewValueNode(MakeValue<int>(index));
|
auto second_input = NewValueNode(MakeValue<int>(index));
|
||||||
MS_CHECK_TRUE_RET(second_input != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(second_input != nullptr, nullptr);
|
||||||
|
@ -1238,17 +996,6 @@ int GetDataTypeFromAnfNode(const AnfNodePtr &anf_node, TypeId *type_id) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
// not implement for lite, just for api compatible
|
|
||||||
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
|
|
||||||
const std::vector<AnfNodePtr> &orig_nodes) {
|
|
||||||
return fg->NewCNode(inputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// not implement for lite, just for api compatible
|
|
||||||
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
|
bool IsQuantParameterNode(const PrimitiveCPtr &prim) {
|
||||||
MS_CHECK_TRUE_RET(prim != nullptr, false);
|
MS_CHECK_TRUE_RET(prim != nullptr, false);
|
||||||
auto quant_attr = prim->GetAttr("quant_params");
|
auto quant_attr = prim->GetAttr("quant_params");
|
||||||
|
|
|
@ -0,0 +1,279 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "nnacl/op_base.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
|
||||||
|
if (utils::isa<int>(sexp)) {
|
||||||
|
return NewValueNode(utils::cast<int>(sexp));
|
||||||
|
}
|
||||||
|
if (utils::isa<float>(sexp)) {
|
||||||
|
return NewValueNode(utils::cast<float>(sexp));
|
||||||
|
}
|
||||||
|
if (utils::isa<bool>(sexp)) {
|
||||||
|
return NewValueNode(utils::cast<bool>(sexp));
|
||||||
|
}
|
||||||
|
if (utils::isa<ValuePtr>(sexp)) {
|
||||||
|
return NewValueNode(utils::cast<ValuePtr>(sexp));
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
|
||||||
|
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||||
|
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
|
||||||
|
}
|
||||||
|
if (utils::isa<VarPtr>(graph)) {
|
||||||
|
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
|
||||||
|
if (utils::isa<VarPtr>(graph)) {
|
||||||
|
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
|
||||||
|
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
|
||||||
|
}
|
||||||
|
if (utils::isa<FuncGraphPtr>(graph)) {
|
||||||
|
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
|
||||||
|
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
|
||||||
|
}
|
||||||
|
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
|
||||||
|
bool multigraph) {
|
||||||
|
if (primitive_vars == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||||
|
std::vector<AnfNodePtr> input_nodes;
|
||||||
|
const auto &tuple = utils::cast<VectorRef>(sexp);
|
||||||
|
if (multigraph && utils::isa<VarPtr>(graph)) {
|
||||||
|
for (auto &x : tuple) {
|
||||||
|
auto is_var = std::make_shared<Var>("G");
|
||||||
|
MS_CHECK_TRUE_RET(is_var != nullptr, nullptr);
|
||||||
|
AnfNodePtr node = SexpToNode(x, is_var, primitive_vars, true);
|
||||||
|
input_nodes.push_back(node);
|
||||||
|
}
|
||||||
|
auto var_ptr = utils::cast<VarPtr>(graph);
|
||||||
|
return std::make_shared<CNode>(input_nodes, var_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &x : tuple) {
|
||||||
|
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
|
||||||
|
input_nodes.push_back(node);
|
||||||
|
}
|
||||||
|
return CreateCNodeWithGraph(input_nodes, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfEqualPrimitive(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
||||||
|
auto a_value_node = a_node->cast<ValueNodePtr>();
|
||||||
|
auto b_value_node = b_node->cast<ValueNodePtr>();
|
||||||
|
if (a_value_node == nullptr || b_value_node == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto a_value = a_value_node->value();
|
||||||
|
auto b_value = b_value_node->value();
|
||||||
|
if (a_value == nullptr || b_value == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto a_prim = a_value->cast<PrimitivePtr>();
|
||||||
|
auto b_prim = b_value->cast<PrimitivePtr>();
|
||||||
|
if (a_prim == nullptr || b_prim == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return a_prim->name() == b_prim->name();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfEqualValueNode(const AnfNodePtr &a_node, const AnfNodePtr &b_node) {
|
||||||
|
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
|
||||||
|
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
|
||||||
|
if (a_value_node_ptr == nullptr || b_value_node_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "cast value node ptr fail";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto a_value_ptr = a_value_node_ptr->value();
|
||||||
|
auto b_value_ptr = b_value_node_ptr->value();
|
||||||
|
if (a_value_ptr == nullptr || b_value_ptr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "value ptr is nullptr";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (utils::isa<ops::PrimitiveC>(a_value_ptr) && utils::isa<ops::PrimitiveC>(b_value_ptr)) {
|
||||||
|
auto a_obj = (ops::PrimitiveC *)(a_value_ptr.get());
|
||||||
|
auto b_obj = (ops::PrimitiveC *)(b_value_ptr.get());
|
||||||
|
return (*a_obj) == (*b_obj);
|
||||||
|
} else {
|
||||||
|
return (*a_value_ptr) == (*b_value_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
// not implement for lite, just for api compatible
|
||||||
|
CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg,
|
||||||
|
const std::vector<AnfNodePtr> &orig_nodes) {
|
||||||
|
return fg->NewCNode(inputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// not implement for lite, just for api compatible
|
||||||
|
CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector<AnfNodePtr> &orig_nodes) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
|
||||||
|
const AnfNodePtr &node) {
|
||||||
|
if (graph == nullptr || node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "input parameter is nullptr.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
||||||
|
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
|
||||||
|
auto manager = graph->manager();
|
||||||
|
if (manager == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto iter = manager->node_users().find(node);
|
||||||
|
if (iter == manager->node_users().end()) {
|
||||||
|
MS_LOG(ERROR) << "node has no output in manager";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_ERROR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto output_info_list = iter->second;
|
||||||
|
std::copy(output_info_list.begin(), output_info_list.end(), std::back_inserter(*output_node_list));
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
|
||||||
|
const AnfNodePtr &node,
|
||||||
|
size_t output_index) {
|
||||||
|
if (graph == nullptr || node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "input parameter is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
|
||||||
|
MS_CHECK_TRUE_RET(output_node_list != nullptr, nullptr);
|
||||||
|
auto manager = graph->manager();
|
||||||
|
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
|
||||||
|
auto iter = manager->node_users().find(node);
|
||||||
|
if (iter == manager->node_users().end()) {
|
||||||
|
MS_LOG(ERROR) << "node has no output in manager";
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
auto output_info_list = iter->second;
|
||||||
|
for (const auto &output_info : output_info_list) {
|
||||||
|
size_t used_output_index;
|
||||||
|
if (CheckPrimitiveType(output_info.first, prim::kPrimTupleGetItem)) {
|
||||||
|
used_output_index = GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
|
||||||
|
} else if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||||
|
used_output_index = output_index;
|
||||||
|
} else {
|
||||||
|
if (output_index != 0) {
|
||||||
|
MS_LOG(ERROR) << "node has no output in manager";
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
if (used_output_index == output_index) {
|
||||||
|
output_node_list->push_back(output_info);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return output_node_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
||||||
|
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
|
||||||
|
auto a_node = utils::cast<AnfNodePtr>(a);
|
||||||
|
auto b_node = utils::cast<AnfNodePtr>(b);
|
||||||
|
if (a_node == nullptr || b_node == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
|
||||||
|
return AnfEqualPrimitive(a_node, b_node);
|
||||||
|
}
|
||||||
|
if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
|
||||||
|
return AnfEqualValueNode(a_node, b_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (a.m_ptr->isa<mindspore::ops::PrimitiveC>() && b.m_ptr->isa<mindspore::ops::PrimitiveC>()) {
|
||||||
|
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
|
||||||
|
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
|
||||||
|
return a_value_node_ptr->name() == b_value_node_ptr->name();
|
||||||
|
}
|
||||||
|
|
||||||
|
return a == b;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
|
||||||
|
// To matchCNode and Kernel's type
|
||||||
|
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return a.type() == b.type();
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
|
||||||
|
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
|
||||||
|
if (primitive_vars == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (utils::isa<VectorRef>(sexp)) {
|
||||||
|
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
|
||||||
|
}
|
||||||
|
if (utils::isa<VarPtr>(sexp)) {
|
||||||
|
auto var_ptr = utils::cast<VarPtr>(sexp);
|
||||||
|
if (var_ptr == nullptr) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (var_ptr->primitive()) {
|
||||||
|
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
|
||||||
|
return NewValueNode(var_ptr->primitive());
|
||||||
|
}
|
||||||
|
return CreateVarNodeWithSexp(sexp, graph);
|
||||||
|
}
|
||||||
|
if (utils::isa<AnfNodePtr>(sexp)) {
|
||||||
|
return utils::cast<AnfNodePtr>(sexp);
|
||||||
|
}
|
||||||
|
auto value_node = CreateValueNodeWithSexp(sexp);
|
||||||
|
if (value_node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "sexp cannot converted. sexp: " << sexp.ToString();
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return value_node;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -19,7 +19,7 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "ops/lstm.h"
|
#include "ops/lstm.h"
|
||||||
#include "ops/squeeze.h"
|
#include "ops/squeeze.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/tuple_get_item.h"
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
|
@ -608,7 +608,7 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap
|
||||||
const int item_index) {
|
const int item_index) {
|
||||||
MS_ASSERT(func_graph != nullptr);
|
MS_ASSERT(func_graph != nullptr);
|
||||||
MS_ASSERT(node != nullptr);
|
MS_ASSERT(node != nullptr);
|
||||||
auto tuple_get_item_prim = std::make_shared<lite::TupleGetItem>();
|
auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
|
||||||
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
|
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
|
||||||
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
|
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
|
||||||
MS_LOG(ERROR) << "NewValueNode is nullptr";
|
MS_LOG(ERROR) << "NewValueNode is nullptr";
|
||||||
|
@ -801,9 +801,7 @@ const AnfNodePtr TfliteLstmCellFusion::Process(const FuncGraphPtr &func_graph, c
|
||||||
|
|
||||||
std::vector<int> squeeze_axis{1}; // our lstm output:0 have an extra axis that tflite not have, it must be squeezed
|
std::vector<int> squeeze_axis{1}; // our lstm output:0 have an extra axis that tflite not have, it must be squeezed
|
||||||
auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
|
auto squeeze_node = CreateSqueezeNode(func_graph, get_item_node, squeeze_axis);
|
||||||
if (squeeze_node == nullptr) {
|
MS_CHECK_TRUE_MSG(squeeze_node != nullptr, nullptr, "create a squeeze node failed.");
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
|
auto cond_cnode_index_pair = std::make_shared<CNodeIndexPair>(while_cnode, 1);
|
||||||
MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);
|
MS_CHECK_TRUE_RET(cond_cnode_index_pair != nullptr, nullptr);
|
||||||
|
|
|
@ -23,8 +23,9 @@
|
||||||
#include "ops/tensor_array.h"
|
#include "ops/tensor_array.h"
|
||||||
#include "ops/tensor_array_read.h"
|
#include "ops/tensor_array_read.h"
|
||||||
#include "ops/tensor_array_write.h"
|
#include "ops/tensor_array_write.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "nnacl/op_base.h"
|
#include "nnacl/op_base.h"
|
||||||
|
#include "ops/make_tuple.h"
|
||||||
|
#include "ops/return.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
constexpr auto kDefaultIndex = 0;
|
constexpr auto kDefaultIndex = 0;
|
||||||
|
@ -78,7 +79,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
|
||||||
|
|
||||||
// for single output graph, create tuple for graph output
|
// for single output graph, create tuple for graph output
|
||||||
// make_tuple node
|
// make_tuple node
|
||||||
auto make_tuple_prim_ptr = std::make_shared<lite::MakeTuple>();
|
auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
|
||||||
if (make_tuple_prim_ptr == nullptr) {
|
if (make_tuple_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "make_tuple_prim_ptr is nullptr";
|
MS_LOG(ERROR) << "make_tuple_prim_ptr is nullptr";
|
||||||
return lite::RET_NULL_PTR;
|
return lite::RET_NULL_PTR;
|
||||||
|
@ -93,7 +94,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
// return node
|
// return node
|
||||||
auto return_prim_ptr = std::make_shared<lite::Return>();
|
auto return_prim_ptr = std::make_shared<ops::Return>();
|
||||||
if (return_prim_ptr == nullptr) {
|
if (return_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "return_prim_ptr is nullptr";
|
MS_LOG(ERROR) << "return_prim_ptr is nullptr";
|
||||||
return lite::RET_NULL_PTR;
|
return lite::RET_NULL_PTR;
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "tools/anf_exporter/fetch_content.h"
|
#include "tools/anf_exporter/fetch_content.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "ops/make_tuple.h"
|
||||||
#include "ops/depend.h"
|
#include "ops/depend.h"
|
||||||
#include "ops/fusion/pad_fusion.h"
|
#include "ops/fusion/pad_fusion.h"
|
||||||
#include "ops/op_utils.h"
|
#include "ops/op_utils.h"
|
||||||
|
@ -120,7 +120,7 @@ int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &c
|
||||||
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
|
if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
|
auto make_tuple_prim = NewValueNode(std::make_shared<ops::MakeTuple>());
|
||||||
auto manager = func_graph->manager();
|
auto manager = func_graph->manager();
|
||||||
MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
|
MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
|
||||||
MS_ASSERT(manager != nullptr);
|
MS_ASSERT(manager != nullptr);
|
||||||
|
|
|
@ -16,10 +16,10 @@
|
||||||
|
|
||||||
#include "tools/optimizer/parallel/operator_info.h"
|
#include "tools/optimizer/parallel/operator_info.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "tools/converter/ops/ops_def.h"
|
|
||||||
#include "tools/optimizer/parallel/split_strategy.h"
|
#include "tools/optimizer/parallel/split_strategy.h"
|
||||||
#include "ops/concat.h"
|
#include "ops/concat.h"
|
||||||
#include "ops/addn.h"
|
#include "ops/addn.h"
|
||||||
|
#include "ops/tuple_get_item.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "base/core_ops.h"
|
#include "base/core_ops.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
@ -120,7 +120,7 @@ int OperatorInfo::CreateMultipleOutputsOfAnfNode(const AnfNodePtr &node, size_t
|
||||||
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
|
auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(index);
|
||||||
MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
|
MS_CHECK_TRUE_RET(abstract_scalar != nullptr, lite::RET_ERROR);
|
||||||
idx->set_abstract(abstract_scalar);
|
idx->set_abstract(abstract_scalar);
|
||||||
auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<lite::TupleGetItem>()), node, idx});
|
auto tuple_getitem = func_graph_->NewCNode({NewValueNode(std::make_shared<ops::TupleGetItem>()), node, idx});
|
||||||
if (tuple_getitem == nullptr) {
|
if (tuple_getitem == nullptr) {
|
||||||
MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
|
MS_LOG(ERROR) << name_ << " : Failed to create output nodes.";
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
|
|
Loading…
Reference in New Issue