draw out model_parser basic class
This commit is contained in:
parent
c8b9d45abc
commit
6f738d2e4a
|
@ -334,8 +334,6 @@ elseif(WIN32)
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/build/mindspore/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
|
@ -477,8 +475,6 @@ else()
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/schema/ DESTINATION ${CONVERTER_ROOT_DIR}/include/schema
|
||||||
|
|
|
@ -237,13 +237,13 @@ build_lite() {
|
||||||
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
|
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
|
||||||
cd ${BASEPATH}/../
|
cd ${BASEPATH}/../
|
||||||
if [[ "${local_lite_platform}" == "x86_64" ]]; then
|
if [[ "${local_lite_platform}" == "x86_64" ]]; then
|
||||||
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master -j $THREAD_NUM
|
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev -j $THREAD_NUM
|
||||||
if [[ $? -ne 0 ]]; then
|
if [[ $? -ne 0 ]]; then
|
||||||
echo "compile x86_64 for nnie failed."
|
echo "compile x86_64 for nnie failed."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
elif [[ "${local_lite_platform}" == "arm32" ]]; then
|
elif [[ "${local_lite_platform}" == "arm32" ]]; then
|
||||||
sh ${compile_nnie_script} -I arm32 -b nnie_3516_master -j $THREAD_NUM
|
sh ${compile_nnie_script} -I arm32 -b nnie_3516_master_dev -j $THREAD_NUM
|
||||||
if [[ $? -ne 0 ]]; then
|
if [[ $? -ne 0 ]]; then
|
||||||
echo "compile arm32 for nnie failed."
|
echo "compile arm32 for nnie failed."
|
||||||
exit 1
|
exit 1
|
||||||
|
|
|
@ -14,41 +14,34 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H
|
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_H
|
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
|
||||||
#include <google/protobuf/message.h>
|
|
||||||
#include <string>
|
|
||||||
#include <memory>
|
|
||||||
#include "schema/inner/model_generated.h"
|
|
||||||
#include "ir/anf.h"
|
|
||||||
#include "api/ir/func_graph.h"
|
|
||||||
#include "include/registry/model_parser_registry.h"
|
|
||||||
#include "utils/log_adapter.h"
|
|
||||||
|
|
||||||
namespace mindspore::converter {
|
#include "api/ir/func_graph.h"
|
||||||
class ModelParser {
|
#include "include/registry/parser_context.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace converter {
|
||||||
|
/// \brief ModelParser defined a base class to parse model.
|
||||||
|
class MS_API ModelParser {
|
||||||
public:
|
public:
|
||||||
|
/// \brief Constructor.
|
||||||
ModelParser() = default;
|
ModelParser() = default;
|
||||||
|
|
||||||
|
/// \brief Destructor.
|
||||||
virtual ~ModelParser() = default;
|
virtual ~ModelParser() = default;
|
||||||
|
|
||||||
|
/// \brief Method to parse model, which must be onnx/caffe/tf/tflite.
|
||||||
|
///
|
||||||
|
/// \param[in] flags Define the basic parameters when converting, which defined in parser_context.h.
|
||||||
|
///
|
||||||
|
/// \return FuncGraph Pointer, which contains all information about the model.
|
||||||
virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
api::FuncGraphPtr res_graph_ = nullptr;
|
api::FuncGraphPtr res_graph_ = nullptr;
|
||||||
};
|
};
|
||||||
|
} // namespace converter
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
typedef ModelParser *(*ModelParserCreator)();
|
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_H_
|
||||||
|
|
||||||
template <class T>
|
|
||||||
ModelParser *LiteModelParserCreator() {
|
|
||||||
auto *parser = new (std::nothrow) T();
|
|
||||||
if (parser == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new model parser failed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return parser;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::converter
|
|
||||||
|
|
||||||
#endif
|
|
|
@ -23,6 +23,9 @@
|
||||||
|
|
||||||
using mindspore::converter::FmkType;
|
using mindspore::converter::FmkType;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
namespace converter {
|
||||||
|
class ModelParser;
|
||||||
|
} // namespace converter
|
||||||
namespace registry {
|
namespace registry {
|
||||||
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
|
/// \brief ModelParserCreator defined function pointer to get a ModelParser class.
|
||||||
typedef converter::ModelParser *(*ModelParserCreator)();
|
typedef converter::ModelParser *(*ModelParserCreator)();
|
||||||
|
|
|
@ -39,9 +39,6 @@ struct MS_API ConverterParameters {
|
||||||
std::string weight_file;
|
std::string weight_file;
|
||||||
std::map<std::string, std::string> attrs;
|
std::map<std::string, std::string> attrs;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief ModelParser defined a base class of model parser
|
|
||||||
class MS_API ModelParser;
|
|
||||||
} // namespace converter
|
} // namespace converter
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "include/registry/model_parser.h"
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "ut/tools/converter/registry/node_parser_test.h"
|
#include "ut/tools/converter/registry/node_parser_test.h"
|
||||||
#include "tools/converter/model_parser.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelParserTest : public converter::ModelParser {
|
class ModelParserTest : public converter::ModelParser {
|
||||||
|
|
|
@ -24,7 +24,6 @@
|
||||||
#include "ops/fusion/add_fusion.h"
|
#include "ops/fusion/add_fusion.h"
|
||||||
#include "ops/addn.h"
|
#include "ops/addn.h"
|
||||||
#include "ops/custom.h"
|
#include "ops/custom.h"
|
||||||
#include "tools/converter/model_parser.h"
|
|
||||||
#include "tools/converter/optimizer_manager.h"
|
#include "tools/converter/optimizer_manager.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||||
|
|
|
@ -19,9 +19,9 @@
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include "include/registry/model_parser.h"
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "tools/converter/graphdef_transform.h"
|
#include "tools/converter/graphdef_transform.h"
|
||||||
#include "tools/converter/model_parser.h"
|
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "tools/converter/anf_transform.h"
|
#include "tools/converter/anf_transform.h"
|
||||||
|
|
|
@ -576,5 +576,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
||||||
}
|
}
|
||||||
return layer.name();
|
return layer.name();
|
||||||
}
|
}
|
||||||
REG_MODEL_PARSER(kFmkTypeCaffe, converter::LiteModelParserCreator<CaffeModelParser>)
|
REG_MODEL_PARSER(kFmkTypeCaffe, LiteModelParserCreator<CaffeModelParser>)
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "tools/converter/model_parser.h"
|
#include "include/registry/model_parser.h"
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "proto/caffe.pb.h"
|
#include "proto/caffe.pb.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
|
|
|
@ -1248,6 +1248,6 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_MODEL_PARSER(kFmkTypeOnnx, converter::LiteModelParserCreator<OnnxModelParser>)
|
REG_MODEL_PARSER(kFmkTypeOnnx, LiteModelParserCreator<OnnxModelParser>)
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -27,7 +27,7 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "securec/include/securec.h"
|
#include "securec/include/securec.h"
|
||||||
#include "tools/converter/model_parser.h"
|
#include "include/registry/model_parser.h"
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||||
#include "proto/onnx.pb.h"
|
#include "proto/onnx.pb.h"
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "include/registry/model_parser.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
@ -39,6 +40,16 @@ int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_nod
|
||||||
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
|
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
|
||||||
int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
|
int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
|
||||||
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
|
schema::Format dst_format, std::set<AnfNodePtr> *has_visited);
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
converter::ModelParser *LiteModelParserCreator() {
|
||||||
|
auto *parser = new (std::nothrow) T();
|
||||||
|
if (parser == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new model parser failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return parser;
|
||||||
|
}
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -1158,6 +1158,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_MODEL_PARSER(kFmkTypeTf, converter::LiteModelParserCreator<TFModelParser>)
|
REG_MODEL_PARSER(kFmkTypeTf, LiteModelParserCreator<TFModelParser>)
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "securec/include/securec.h"
|
#include "securec/include/securec.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "tools/converter/model_parser.h"
|
#include "include/registry/model_parser.h"
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "ops/primitive_c.h"
|
#include "ops/primitive_c.h"
|
||||||
|
|
||||||
|
|
|
@ -734,5 +734,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_MODEL_PARSER(kFmkTypeTflite, converter::LiteModelParserCreator<TfliteModelParser>)
|
REG_MODEL_PARSER(kFmkTypeTflite, LiteModelParserCreator<TfliteModelParser>)
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "tools/converter/model_parser.h"
|
#include "include/registry/model_parser.h"
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
|
@ -68,7 +68,7 @@ STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPt
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
|
for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
|
||||||
auto tensor_addr = GetInputAddr(stack_node->input(i), 2);
|
auto tensor_addr = GetInputAddr(stack_node->input(i), kInputIndexTwo);
|
||||||
if (tensor_addr == nullptr) {
|
if (tensor_addr == nullptr) {
|
||||||
MS_LOG(ERROR) << "input tensor addr nullptr";
|
MS_LOG(ERROR) << "input tensor addr nullptr";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -178,7 +178,7 @@ const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, cons
|
||||||
if (IsMarkedTrainOp(fullconnect_cnode)) {
|
if (IsMarkedTrainOp(fullconnect_cnode)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == 3, nullptr);
|
MS_CHECK_TRUE_RET(fullconnect_cnode->inputs().size() == kInputSizeThree, nullptr);
|
||||||
auto left_slice_node = fullconnect_cnode->input(1);
|
auto left_slice_node = fullconnect_cnode->input(1);
|
||||||
auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
|
auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
|
||||||
if (IsMarkedTrainOp(left_slice_cnode)) {
|
if (IsMarkedTrainOp(left_slice_cnode)) {
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
namespace {
|
namespace {
|
||||||
|
constexpr int kOffsetTwo = 2;
|
||||||
constexpr size_t kCondNodesNum = 12;
|
constexpr size_t kCondNodesNum = 12;
|
||||||
constexpr size_t kCondCNodesNum = 4;
|
constexpr size_t kCondCNodesNum = 4;
|
||||||
constexpr size_t kBodyNodesNum = 69;
|
constexpr size_t kBodyNodesNum = 69;
|
||||||
|
@ -162,7 +163,7 @@ const VectorRef TfBidirectionGruFusion::DefineFowardPattern() const {
|
||||||
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
|
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
|
||||||
auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve,
|
auto fw_while = VectorRef({is_while, fw_vars_[0], fw_vars_[1], is_param5, fw_stride, is_param6, fw_reserve,
|
||||||
fw_init_state_, fw_min, fw_from_tensor, input_length_});
|
fw_init_state_, fw_min, fw_from_tensor, input_length_});
|
||||||
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
|
fw_while.insert(fw_while.end(), fw_vars_.begin() + kOffsetTwo, fw_vars_.end());
|
||||||
auto is_var1 = std::make_shared<Var>();
|
auto is_var1 = std::make_shared<Var>();
|
||||||
MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
|
MS_CHECK_TRUE_RET(is_var1 != nullptr, {});
|
||||||
fw_while.emplace_back(is_var1);
|
fw_while.emplace_back(is_var1);
|
||||||
|
@ -232,7 +233,7 @@ const VectorRef TfBidirectionGruFusion::DefinebackwardPattern() const {
|
||||||
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
|
MS_CHECK_TRUE_RET(is_param6 != nullptr, {});
|
||||||
auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve,
|
auto bw_while = VectorRef({is_while, bw_vars_[0], bw_vars_[1], is_param5, bw_stride, is_param6, bw_reserve,
|
||||||
bw_init_state_, bw_min, bw_from_tensor, input_length_});
|
bw_init_state_, bw_min, bw_from_tensor, input_length_});
|
||||||
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
|
bw_while.insert(bw_while.end(), bw_vars_.begin() + kOffsetTwo, bw_vars_.end());
|
||||||
auto is_var2 = std::make_shared<Var>();
|
auto is_var2 = std::make_shared<Var>();
|
||||||
MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
|
MS_CHECK_TRUE_RET(is_var2 != nullptr, {});
|
||||||
bw_while.emplace_back(is_var2);
|
bw_while.emplace_back(is_var2);
|
||||||
|
@ -400,7 +401,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto fw_cand_kernel_shape = fw_cand_kernel_value->shape();
|
auto fw_cand_kernel_shape = fw_cand_kernel_value->shape();
|
||||||
if (fw_cand_kernel_shape.size() != 2) {
|
if (fw_cand_kernel_shape.size() != kInputSizeTwo) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf);
|
auto bw_cand_kernel_value = GetDefaultTensorInfo(bw_cand_kernel_anf);
|
||||||
|
@ -408,7 +409,7 @@ STATUS TfBidirectionGruFusion::GetInputAndHiddenSize(const AnfNodePtr &fw_cand_k
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto bw_cand_kernel_shape = bw_cand_kernel_value->shape();
|
auto bw_cand_kernel_shape = bw_cand_kernel_value->shape();
|
||||||
if (bw_cand_kernel_shape.size() != 2) {
|
if (bw_cand_kernel_shape.size() != kInputSizeTwo) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
|
if (fw_cand_kernel_shape != bw_cand_kernel_shape) {
|
||||||
|
|
|
@ -45,7 +45,7 @@ void SetConvAttr(const PrimitivePtr &prim, const std::vector<int64_t> &kernel_si
|
||||||
STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
|
STATUS UpdateConv2DParamPass::UpdateConv2DAttr(const CNodePtr &cnode) {
|
||||||
MS_ASSERT(cnode != nullptr);
|
MS_ASSERT(cnode != nullptr);
|
||||||
if (cnode->size() < kInputSizeThree) {
|
if (cnode->size() < kInputSizeThree) {
|
||||||
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << cnode->size() - 1;
|
MS_LOG(ERROR) << "conv2d's input size is invalid, now is " << (cnode->size() - 1);
|
||||||
return lite::RET_ERROR;
|
return lite::RET_ERROR;
|
||||||
}
|
}
|
||||||
auto weight = cnode->input(kInputIndexTwo);
|
auto weight = cnode->input(kInputIndexTwo);
|
||||||
|
|
Loading…
Reference in New Issue