From 2f79905f60d78d28618ccb008f868df986126eb2 Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com" <wangnan39@huawei.com>
Date: Thu, 4 Feb 2021 22:09:32 +0800
Subject: [PATCH] convert attr from str to enum
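
The "format" and "pad_mode" attrs were previously stored on primitives as
strings and compared against string literals at every use site. This patch
converts them to int64 enum values (Format, PadMode): the conversion happens
when an attr is set (PrimitivePy::AddPyAttr) or imported from MindIR, and
readers now consume the enum. Readers that may still see the string form
(InitDefaultFormat, KernelGraph::NewCNode) fall back to the string value when
conversion fails. MindSpore Lite keeps string attrs, so LoadMindIR and
ConvertStreamToFuncGraph gain an is_lite flag that disables the conversion in
MSANFModelParser, and the lite importer passes true.

Illustrative before/after for a reader of "pad_mode" (a sketch of the pattern
used below, not an excerpt from a specific file):

    // before: compare the raw string attr
    auto pad_mode = GetValue<std::string>(prim->GetAttr("pad_mode"));
    if (pad_mode == "valid") { /* ... */ }

    // after: read the enum; GetPadModEnumValue still accepts the legacy
    // string form and maps it through PadModToEnumMap
    int64_t pad_mode;
    CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &pad_mode);
    if (pad_mode == PadMode::VALID) { /* ... */ }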
---
 .../backend/optimizer/ascend/ascend_helper.cc | 15 +++-
 .../ascend/mindir/conv2d_unify_mindir.cc      | 13 +++-
 .../ccsrc/backend/session/kernel_graph.cc     | 15 +++-
 .../pipeline/pynative/pynative_execute.cc     |  2 -
 mindspore/ccsrc/pybind_api/ir/primitive_py.cc |  3 +
 .../transform/express_ir/mindir_exporter.cc   |  5 +-
 .../transform/express_ir/onnx_exporter.cc     | 20 +++--
 mindspore/core/abstract/prim_nn.cc            | 53 ++++++-------
 .../core/load_mindir/anf_model_parser.cc      |  7 +-
 mindspore/core/load_mindir/anf_model_parser.h |  3 +
 mindspore/core/load_mindir/load_model.cc      |  9 ++-
 mindspore/core/load_mindir/load_model.h       |  4 +-
 mindspore/core/utils/check_convert_utils.cc   | 76 ++++++++++++-------
 mindspore/core/utils/check_convert_utils.h    |  4 +-
 .../tools/anf_importer/import_from_mindir.cc  |  4 +-
 15 files changed, 151 insertions(+), 82 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index 225b6d266d6..a42539e1dd7 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -18,6 +18,7 @@
 #include
 #include "common/trans.h"
 #include "utils/ms_utils.h"
+#include "utils/check_convert_utils.h"
 #include "backend/optimizer/common/helper.h"
 #include "utils/utils.h"
 #include "runtime/device/kernel_info.h"
@@ -66,9 +67,17 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
 std::string InitDefaultFormat(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) {
-    auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat);
-    if (attr == kOpFormat_NCDHW) {
-      return kOpFormat_NCDHW;
+    auto primitive_ptr = GetCNodePrimitive(node);
+    MS_EXCEPTION_IF_NULL(primitive_ptr);
+    auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
+    MS_EXCEPTION_IF_NULL(data_format_ptr);
+    int64_t data_format;
+    bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
+    if (!result) {
+      auto attr = GetValue<std::string>(data_format_ptr);
+      if (attr == kOpFormat_NCDHW) {
+        return kOpFormat_NCDHW;
+      }
     }
   } else if (AnfAlgo::IsRealKernel(node)) {
     auto formats = AnfAlgo::GetAllOutputFormats(node);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc
index 56b4664332c..f31fa66793f 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc
@@ -23,6 +23,7 @@
 
 #include "utils/utils.h"
 #include "utils/ms_context.h"
+#include "utils/check_convert_utils.h"
 #include "backend/optimizer/common/helper.h"
 #include "runtime/device/kernel_info.h"
 #include "backend/session/anf_runtime_algorithm.h"
@@ -46,9 +47,15 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
   if (group == 1) {
     return false;
   }
-  auto data_format = AnfAlgo::GetNodeAttr<std::string>(conv2d, kAttrFormat);
-  if (data_format != "NCHW") {
-    MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format;
+
+  auto primitive_ptr = GetCNodePrimitive(conv2d);
+  MS_EXCEPTION_IF_NULL(primitive_ptr);
+  auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
+  MS_EXCEPTION_IF_NULL(data_format_ptr);
+  int64_t data_format;
+  bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
+  if (!result || data_format != Format::NCHW) {
+    MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1";
   }
   if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) {
     MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size()
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index 71ffbfc6d4b..d0d1f1f331d 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -21,6 +21,7 @@
 #include "base/core_ops.h"
 #include "ir/param_info.h"
 #include "utils/utils.h"
+#include "utils/check_convert_utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/kernel_info.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
@@ -402,9 +403,17 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
   }
   SetKernelInfoForNode(cnode);
   if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) {
-    auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat);
-    if (attr == kOpFormat_NCDHW) {
-      ResetInFormat(cnode, kOpFormat_NCDHW);
+    auto primitive_ptr = GetCNodePrimitive(cnode);
+    MS_EXCEPTION_IF_NULL(primitive_ptr);
+    auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
+    MS_EXCEPTION_IF_NULL(data_format_ptr);
+    int64_t data_format;
+    bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
+    if (!result) {
+      auto attr = GetValue<std::string>(data_format_ptr);
+      if (attr == kOpFormat_NCDHW) {
+        ResetInFormat(cnode, kOpFormat_NCDHW);
+      }
     }
   }
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index b52997f9fdd..c4fdb0c92d8 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -459,7 +459,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector
   op_prim->EndRecordAddAttr();
 }
 
-
 void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
   MS_EXCEPTION_IF_NULL(op_run_info);
   PrimitivePtr op_prim = op_run_info->py_primitive;
@@ -479,7 +478,6 @@ void ConvertAttrToUnifyMindIR(const OpExecInfoPtr &op_run_info) {
   }
 }
 
-
 BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
   if (utils::isa<VectorRef>(base_ref)) {
     auto ref_list = utils::cast<VectorRef>(base_ref);
diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
index 449d7bc6f3d..23cfee8d677 100644
--- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
@@ -29,6 +29,7 @@
 #include "utils/convert_utils_py.h"
 #include "utils/ms_context.h"
 #include "utils/primitive_utils.h"
+#include "utils/check_convert_utils.h"
 #include "pipeline/jit/resource.h"
 #include "pipeline/pynative/pynative_execute.h"
 
@@ -280,6 +281,8 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
   if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
     attr_name = kOpAttrNameReplaceMap[attr_name];
   }
+  const std::string &prim_name = this->name();
+  CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret);
   (void)this->AddAttr(attr_name, converted_ret);
 }
 
diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
index 4f38cf5f9ea..cadc5e2f7c5 100644
--- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
@@ -26,6 +26,7 @@
 #include "ir/func_graph.h"
 #include "base/core_ops.h"
 #include "proto/mind_ir.pb.h"
+#include "utils/check_convert_utils.h"
 
 namespace mindspore {
 using FloatPtr = std::shared_ptr<Float>;
@@ -425,7 +426,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
       MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
       mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
       attr_proto->set_name(attr.first);
-      SetValueToAttributeProto(attr.second, attr_proto);
+      auto attr_value = attr.second;
+      CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value);
+      SetValueToAttributeProto(attr_value, attr_proto);
     }
   } else {
     MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
diff --git a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
index 572058555b4..69e41dd840f 100644
--- a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc
@@ -25,6 +25,7 @@
 #include "ir/func_graph.h"
 #include "base/core_ops.h"
 #include "proto/onnx.pb.h"
+#include "utils/check_convert_utils.h"
 
 namespace mindspore {
 enum OpMergeMode {
@@ -102,8 +103,9 @@ void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_Attrib
 void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
                        onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
   attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
-  auto attr_value = GetValue<std::string>(value);
-  if (attr_value == "VALID") {
+  int64_t attr_value;
+  CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value, true);
+  if (attr_value == PadMode::VALID) {
     attr_proto->set_s("VALID");
   } else {
     attr_proto->set_s("SAME_UPPER");
@@ -186,10 +188,11 @@ OPERATOR_ONNX_CONVERT_DEFINE(
   [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
      const PrimitivePtr &prim) {
     attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
-    auto attr_value = GetValue<std::string>(value);
-    if (attr_value == "valid") {
+    int64_t attr_value;
+    CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
+    if (attr_value == PadMode::VALID) {
       attr_proto->set_s("VALID");
-    } else if (attr_value == "same") {
+    } else if (attr_value == PadMode::SAME) {
       attr_proto->set_s("SAME_UPPER");
     } else {  // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
       attr_proto->set_name("pads");
@@ -834,12 +837,13 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
 
   // set pad
   onnx_attr_proto = node_proto->add_attribute();
-  auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode"));
+  int64_t attr_value;
+  CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &attr_value);
   onnx_attr_proto->set_name("auto_pad");
   onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
-  if (attr_value == "valid") {
+  if (attr_value == PadMode::VALID) {
     onnx_attr_proto->set_s("VALID");
-  } else if (attr_value == "same") {
+  } else if (attr_value == PadMode::SAME) {
     onnx_attr_proto->set_s("SAME_UPPER");
   } else {
     onnx_attr_proto->set_name("pads");
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index a92094f50d5..af13884cf96 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -59,20 +59,16 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
     MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
   }
 
-  std::set<std::string> available_pad_mode{"pad", "same", "valid"};
   auto pad_mode_ptr = primitive->GetAttr("pad_mode");
-  if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa<StringImm>()) {
-    auto pad_mode = pad_mode_ptr->cast<StringImmPtr>()->value();
-    if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
-      MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid";
-    }
-    if (pad_mode == "valid") {
+  if (pad_mode_ptr != nullptr) {
+    int64_t pad_mode;
+    CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
+    if (pad_mode == PadMode::VALID) {
       padding = 0;
-    } else if (pad_mode == "same") {
+    } else if (pad_mode == PadMode::SAME) {
       padding = (window - 1) / 2;
     }
   }
-  std::set<std::string> available_mode{"max", "avg"};
   auto mode_ptr = primitive->GetAttr("mode");
   if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
@@ -270,13 +266,13 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
 
 void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
                        const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
-                       const std::vector<int64_t> &dilation, const std::string &pad_mode,
+                       const std::vector<int64_t> &dilation, const int64_t &pad_mode,
                        const std::vector<int64_t> &padding) {
-  if (pad_mode == "valid") {
+  if (pad_mode == PadMode::VALID) {
     output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]));
     output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]));
     pad_list->insert(pad_list->begin(), 4, 0);
-  } else if (pad_mode == "same") {
+  } else if (pad_mode == PadMode::SAME) {
     output_hw->push_back(std::ceil((x_h * 1.0) / stride[0]));
     output_hw->push_back(std::ceil((x_w * 1.0) / stride[1]));
     int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
@@ -287,7 +283,7 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
     pad_needed_w = std::max((int64_t)0, pad_needed_w);
     pad_list->push_back(std::floor(pad_needed_w / 2));
     pad_list->push_back(pad_needed_w - pad_list->at(2));
-  } else if (pad_mode == "pad") {
+  } else if (pad_mode == PadMode::PAD) {
     pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
     output_hw->push_back(std::floor(
       1 +
@@ -298,6 +294,15 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
   }
 }
 
+int64_t GetAndCheckFormat(const ValuePtr &value) {
+  int64_t data_format;
+  bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
+  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
+    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
+  }
+  return data_format;
+}
+
 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
@@ -322,12 +327,12 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
   CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
   CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
   CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
-  std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"});
   int64_t n_axis = 0;
   int64_t c_axis = 1;
   int64_t h_axis = 2;
   int64_t w_axis = 3;
-  if (data_format == "NHWC") {
+  int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
+  if (data_format == Format::NHWC) {
     c_axis = 3;
     h_axis = 1;
     w_axis = 2;
@@ -351,8 +356,8 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
   std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2);
   std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2);
   std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4);
-  std::string pad_mode =
-    CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"});
+  int64_t pad_mode;
+  CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
   std::vector<int64_t> output_hw;
   std::vector<int64_t> pad_list;
   std::vector<int64_t> output_hw_min;
@@ -377,7 +382,7 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
   ShapeVector output_shape;
   ShapeVector output_shape_min;
   ShapeVector output_shape_max;
-  if (data_format == "NHWC") {
+  if (data_format == Format::NHWC) {
     output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
     output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
     output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
@@ -425,16 +430,12 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
   ShapeVector bias_shape = bias->shape()->shape();
   ShapeVector x_min_shape = x->shape()->min_shape();
   ShapeVector x_max_shape = x->shape()->max_shape();
-  std::set<std::string> available_data_format{"NCHW", "NHWC"};
   auto data_format_ptr = primitive->GetAttr("format");
-  std::string data_format = "NCHW";
-  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
-    data_format = data_format_ptr->cast<StringImmPtr>()->value();
+  int64_t data_format = Format::NCHW;
+  if (data_format_ptr != nullptr) {
+    data_format = GetAndCheckFormat(data_format_ptr);
   }
-  if (available_data_format.find(data_format) == available_data_format.end()) {
-    MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC.";
-  }
-  auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1];
+  auto x_channel = data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1];
   // Additional check for dynamic shape
   // Last infer will be real shape values
   bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc
index 8df89433817..f787e3af106 100644
--- a/mindspore/core/load_mindir/anf_model_parser.cc
+++ b/mindspore/core/load_mindir/anf_model_parser.cc
@@ -29,6 +29,7 @@
 #include "abstract/abstract_value.h"
 #include "utils/log_adapter.h"
 #include "utils/shape_utils.h"
+#include "utils/check_convert_utils.h"
 
 using std::string;
 
@@ -494,7 +495,11 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
     case FORM_PARSE_SCALAR: {
       std::size_t value_pos(0);
      if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
-        auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
+        ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
+        const std::string &op_type = prim->name();
+        if (!IsLite()) {
+          CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res);
+        }
         prim->AddAttr(attr_name, res);
         break;
       }
diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h
index fa03d0e82ec..805bb765b1c 100644
--- a/mindspore/core/load_mindir/anf_model_parser.h
+++ b/mindspore/core/load_mindir/anf_model_parser.h
@@ -39,6 +39,8 @@ class MSANFModelParser {
   std::string GetProducerName() { return producer_name_; }
   std::string GetProducerVersion() { return model_version_; }
   std::string GetIrVersion() { return ir_version_; }
+  void SetLite() { is_lite_ = true; }
+  bool IsLite() { return is_lite_; }
 
  private:
   bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
@@ -68,6 +70,7 @@ class MSANFModelParser {
   std::string producer_name_;
   std::string model_version_;
   std::string ir_version_;
+  bool is_lite_ = false;
   std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
 };
 }  // namespace mindspore
diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc
index 92122f77827..ed34247ddc3 100644
--- a/mindspore/core/load_mindir/load_model.cc
+++ b/mindspore/core/load_mindir/load_model.cc
@@ -71,7 +71,7 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
   return buf;
 }
 
-std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
+std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
   auto graphBuf = ReadProtoFile(file_name);
   if (graphBuf == nullptr) {
     MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
@@ -79,7 +79,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
   }
 
   try {
-    auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size());
+    auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size(), is_lite);
     return graph;
   } catch (std::exception &e) {
     MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
@@ -87,7 +87,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
 }
 
-std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) {
+std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
   MS_EXCEPTION_IF_NULL(buf);
   std::string str((const char *)buf, buf_size);
   mind_ir::ModelProto model_;
@@ -95,6 +95,9 @@ std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_
   if (!model_.ParseFromString(str)) {
     MS_LOG(ERROR) << "Parse model from buffer fail!";
   }
   MSANFModelParser model_parser;
+  if (is_lite) {
+    model_parser.SetLite();
+  }
   FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
   return dstgraph_ptr;
 }
diff --git a/mindspore/core/load_mindir/load_model.h b/mindspore/core/load_mindir/load_model.h
index afe739cb3b1..46f6180d0fa 100644
--- a/mindspore/core/load_mindir/load_model.h
+++ b/mindspore/core/load_mindir/load_model.h
@@ -24,8 +24,8 @@
 #include "ir/func_graph.h"
 
 namespace mindspore {
-std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
+std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false);
 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
-std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size);
+std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_LOAD_MODEL_H
diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc
index 6c778a0498d..bdc22af34ae 100644
--- a/mindspore/core/utils/check_convert_utils.cc
+++ b/mindspore/core/utils/check_convert_utils.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,13 +14,16 @@
  * limitations under the License.
  */
 
+#include "utils/check_convert_utils.h"
+
 #include
 #include
 #include
 #include
 #include
-#include "utils/check_convert_utils.h"
+
 #include "abstract/abstract_value.h"
+#include "ops/op_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/dtype/tensor_type.h"
 #include "ir/dtype.h"
@@ -84,21 +87,21 @@ AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMa
 AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);
 
 static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
-  {"format", DataFormatConverter},
-  {"pad_mode", PadModeConverter},
+  {ops::kFormat, DataFormatConverter},
+  {ops::kPadMode, PadModeConverter},
 };
 
 static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
-  {"format", DataFormatConverter},
-  {"pad_mode", PadModeUpperConverter},
+  {ops::kFormat, DataFormatConverter},
+  {ops::kPadMode, PadModeUpperConverter},
 };
 
 static std::map<std::string, AttrConverterPair> DataFormatMap = {
-  {"format", DataFormatConverter},
+  {ops::kFormat, DataFormatConverter},
 };
 
 static std::map<std::string, AttrConverterPair> ReductionMap = {
-  {"reduction", ReductionConverter},
+  {ops::kReduction, ReductionConverter},
 };
 
 static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
@@ -132,24 +135,42 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
   {"BinaryCrossEntropy", ReductionMap},
   {"BinaryCrossEntropyGrad", ReductionMap},
   {"NLLLoss", ReductionMap},
+  {"DepthToSpace", DataFormatMap},
 };
 
-int64_t CheckAndConvertUtils::GetDataFormatEnumValue(const std::string &value) {
-  if (DataFormatToEnumMap.find(value) == DataFormatToEnumMap.end()) {
-    MS_LOG(ERROR) << "Can not convert data format " << value << "to enum";
+bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
+  MS_EXCEPTION_IF_NULL(value);
+  if (value->isa<StringImm>()) {
+    auto attr_value_str = GetValue<std::string>(value);
+    if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) {
+      MS_LOG(DEBUG) << "The data format " << attr_value_str << " cannot be converted to enum.";
+      return false;
+    }
+    *enum_value = DataFormatToEnumMap[attr_value_str];
+    return true;
+  } else {
+    *enum_value = GetValue<int64_t>(value);
+    return true;
   }
-  return DataFormatToEnumMap[value];
+  return false;
 }
 
-int64_t CheckAndConvertUtils::GetPadModEnumValue(const std::string &value, bool is_upper) {
-  std::map<std::string, int64_t> pad_map = PadModToEnumMap;
-  if (is_upper) {
-    pad_map = PadModToEnumUpperMap;
+void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
+  MS_EXCEPTION_IF_NULL(value);
+  if (value->isa<StringImm>()) {
+    auto attr_value_str = GetValue<std::string>(value);
+
+    std::map<std::string, int64_t> pad_map = PadModToEnumMap;
+    if (is_upper) {
+      pad_map = PadModToEnumUpperMap;
+    }
+    if (pad_map.find(attr_value_str) == pad_map.end()) {
+      MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << ", use pad, valid or same";
+    }
+    *enum_value = pad_map[attr_value_str];
+  } else {
+    *enum_value = GetValue<int64_t>(value);
   }
-  if (pad_map.find(value) == pad_map.end()) {
-    MS_LOG(ERROR) << "Can not convert pad mode " << value << "to enum";
-  }
-  return pad_map[value];
 }
 
@@ -172,8 +193,8 @@ AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op
 
 bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
                                                  ValuePtr *const value) {
-  if (value == nullptr) {
-    MS_LOG(ERROR) << "value is nullptr";
+  if (value == nullptr || *value == nullptr) {
+    MS_LOG(DEBUG) << "value of attr " << op_type << " " << attr_name << " is nullptr.";
     return false;
   }
   if (!(*value)->isa<StringImm>()) {
@@ -191,12 +212,17 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con
   }
   if (!do_convert) {
     transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
+    if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
+      do_convert = true;
+    }
+  }
+  if (!do_convert) {
+    transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
     if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
       MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
       return false;
     }
   }
-
   *value = MakeValue(attr_map_pair.first[real_value]);
   MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
   return true;
@@ -204,7 +230,7 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
                                                     ValuePtr *const value) {
-  if (value == nullptr) {
+  if (value == nullptr || *value == nullptr) {
     MS_LOG(ERROR) << "value is nullptr";
     return false;
   }
@@ -226,7 +252,6 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
   return true;
 }
 
-
 namespace {
 typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
 
@@ -242,7 +267,6 @@ std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis",
                                                        {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
 }  // namespace
 
-
 bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2) {
   if (vec_1.size() != vec_2.size()) {
     return false;
diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h
index 758d55f8cca..c83f13c12c3 100644
--- a/mindspore/core/utils/check_convert_utils.h
+++ b/mindspore/core/utils/check_convert_utils.h
@@ -282,8 +282,8 @@ class CheckAndConvertUtils {
   static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
   static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
                                        ValuePtr *const value);
   static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
-  static int64_t GetDataFormatEnumValue(const std::string &value);
-  static int64_t GetPadModEnumValue(const std::string &value, bool is_upper = false);
+  static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
+  static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
   static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
 
  private:
diff --git a/mindspore/lite/tools/anf_importer/import_from_mindir.cc b/mindspore/lite/tools/anf_importer/import_from_mindir.cc
index 3c44c319e6a..63a01b9caa1 100644
--- a/mindspore/lite/tools/anf_importer/import_from_mindir.cc
+++ b/mindspore/lite/tools/anf_importer/import_from_mindir.cc
@@ -856,7 +856,7 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model
 
 int AnfImporterFromMindir::Import(const converter::Flags *flag) {
 #if SUPPORT_TRAIN
-  func_graph_ = LoadMindIR(flag->modelFile);
+  func_graph_ = LoadMindIR(flag->modelFile, true);
   if (func_graph_ != nullptr) {
     return RET_OK;
   } else {
@@ -866,7 +866,7 @@ int AnfImporterFromMindir::Import(const converter::Flags *flag) {
   onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
   if (onnx_model_ == nullptr) {
     MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
-    func_graph_ = LoadMindIR(flag->modelFile);
+    func_graph_ = LoadMindIR(flag->modelFile, true);
     if (func_graph_ == nullptr) {
       MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
       return RET_GRAPH_FILE_ERR;