forked from mindspore-Ecosystem/mindspore
!23270 fix tf model parser
Merge pull request !23270 from zhaodezan/master
This commit is contained in:
commit
73db02cf0f
|
@ -32,6 +32,7 @@ constexpr size_t NCHWTopPadPos = 4;
|
|||
|
||||
STATUS TFConvBaseParser::ParseKernels(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
|
||||
std::vector<int64_t> *kernel) {
|
||||
MS_ASSERT(kernel != nullptr);
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "value", &attr_value)) {
|
||||
MS_LOG(ERROR) << "The kernels should be specified";
|
||||
|
@ -51,6 +52,7 @@ STATUS TFConvBaseParser::ParseKernels(const tensorflow::NodeDef &node_def, const
|
|||
|
||||
STATUS TFConvBaseParser::ParseStrides(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
|
||||
std::vector<int64_t> *strides) {
|
||||
MS_ASSERT(strides != nullptr);
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "strides", &attr_value)) {
|
||||
strides->at(0) = 1;
|
||||
|
@ -97,6 +99,7 @@ STATUS TFConvBaseParser::ParseExplicitPaddings(const tensorflow::NodeDef &node_d
|
|||
|
||||
STATUS TFConvBaseParser::ParseDilations(const tensorflow::NodeDef &node_def, const mindspore::Format &format,
|
||||
std::vector<int64_t> *dilations) {
|
||||
MS_ASSERT(dilations != nullptr);
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!TensorFlowUtils::FindAttrValue(node_def, "dilations", &attr_value)) {
|
||||
dilations->at(0) = 1;
|
||||
|
|
|
@ -97,6 +97,7 @@ int GetShapeSize(const tensorflow::TensorProto &tensor_proto) {
|
|||
auto &tensor_shape = tensor_proto.tensor_shape();
|
||||
int shape_size = 1;
|
||||
for (int i = 0; i < tensor_shape.dim_size(); i++) {
|
||||
MS_CHECK_INT_MUL_NOT_OVERFLOW(shape_size, tensor_shape.dim(i).size(), 0);
|
||||
shape_size *= tensor_shape.dim(i).size();
|
||||
}
|
||||
return shape_size;
|
||||
|
@ -582,6 +583,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
return nullptr;
|
||||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false);
|
||||
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
|
@ -733,6 +735,7 @@ STATUS TFModelParser::ConvertSubgraph() {
|
|||
}
|
||||
|
||||
FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
|
||||
MS_CHECK_TRUE_RET(sub_func_graph != nullptr, RET_ERROR);
|
||||
sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
|
||||
sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeTf)));
|
||||
std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
|
||||
|
@ -1142,6 +1145,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
|
|||
int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
||||
for (const auto &func_graph : all_func_graphs) {
|
||||
auto functionalize_control_op_pass = std::make_shared<opt::FunctionalizeControlOpPass>();
|
||||
MS_CHECK_TRUE_RET(functionalize_control_op_pass != nullptr, RET_ERROR);
|
||||
if (!functionalize_control_op_pass->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "functionalize control op pass failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
|
|
Loading…
Reference in New Issue