!11390 convert part of prim attr from str to int

From: @wangnan39
Reviewed-by: @kingxian
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-07 19:15:11 +08:00 committed by Gitee
commit 228a64de0f
14 changed files with 151 additions and 80 deletions

View File

@ -18,6 +18,7 @@
#include <set> #include <set>
#include "common/trans.h" #include "common/trans.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "utils/check_convert_utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
@ -66,9 +67,17 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
std::string InitDefaultFormat(const AnfNodePtr &node) { std::string InitDefaultFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) { if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat); auto primitive_ptr = GetCNodePrimitive(node);
if (attr == kOpFormat_NCDHW) { MS_EXCEPTION_IF_NULL(primitive_ptr);
return kOpFormat_NCDHW; auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
MS_EXCEPTION_IF_NULL(data_format_ptr);
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
if (!result) {
auto attr = GetValue<std::string>(data_format_ptr);
if (attr == kOpFormat_NCDHW) {
return kOpFormat_NCDHW;
}
} }
} else if (AnfAlgo::IsRealKernel(node)) { } else if (AnfAlgo::IsRealKernel(node)) {
auto formats = AnfAlgo::GetAllOutputFormats(node); auto formats = AnfAlgo::GetAllOutputFormats(node);

View File

@ -23,6 +23,7 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
@ -46,9 +47,15 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
if (group == 1) { if (group == 1) {
return false; return false;
} }
auto data_format = AnfAlgo::GetNodeAttr<std::string>(conv2d, kAttrFormat);
if (data_format != "NCHW") { auto primitive_ptr = GetCNodePrimitive(conv2d);
MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format; MS_EXCEPTION_IF_NULL(primitive_ptr);
auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
MS_EXCEPTION_IF_NULL(data_format_ptr);
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
if (!result || data_format != Format::NCHW) {
MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1";
} }
if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) { if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) {
MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size() MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size()

View File

@ -21,6 +21,7 @@
#include "base/core_ops.h" #include "base/core_ops.h"
#include "ir/param_info.h" #include "ir/param_info.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/check_convert_utils.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/kernel_build_info.h" #include "backend/kernel_compiler/kernel_build_info.h"
@ -402,9 +403,17 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
} }
SetKernelInfoForNode(cnode); SetKernelInfoForNode(cnode);
if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) { if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) {
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat); auto primitive_ptr = GetCNodePrimitive(cnode);
if (attr == kOpFormat_NCDHW) { MS_EXCEPTION_IF_NULL(primitive_ptr);
ResetInFormat(cnode, kOpFormat_NCDHW); auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
MS_EXCEPTION_IF_NULL(data_format_ptr);
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
if (!result) {
auto attr = GetValue<std::string>(data_format_ptr);
if (attr == kOpFormat_NCDHW) {
ResetInFormat(cnode, kOpFormat_NCDHW);
}
} }
} }
AnfAlgo::SetGraphId(graph_id_, cnode.get()); AnfAlgo::SetGraphId(graph_id_, cnode.get());

View File

@ -29,6 +29,7 @@
#include "utils/convert_utils_py.h" #include "utils/convert_utils_py.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/primitive_utils.h" #include "utils/primitive_utils.h"
#include "utils/check_convert_utils.h"
#include "pipeline/jit/resource.h" #include "pipeline/jit/resource.h"
#include "pipeline/pynative/pynative_execute.h" #include "pipeline/pynative/pynative_execute.h"
@ -280,6 +281,8 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) { if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
attr_name = kOpAttrNameReplaceMap[attr_name]; attr_name = kOpAttrNameReplaceMap[attr_name];
} }
const std::string &prim_name = this->name();
CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret);
(void)this->AddAttr(attr_name, converted_ret); (void)this->AddAttr(attr_name, converted_ret);
} }

View File

@ -26,6 +26,7 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "proto/mind_ir.pb.h" #include "proto/mind_ir.pb.h"
#include "utils/check_convert_utils.h"
namespace mindspore { namespace mindspore {
using FloatPtr = std::shared_ptr<Float>; using FloatPtr = std::shared_ptr<Float>;
@ -425,7 +426,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute(); mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first); attr_proto->set_name(attr.first);
SetValueToAttributeProto(attr.second, attr_proto); auto attr_value = attr.second;
CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value);
SetValueToAttributeProto(attr_value, attr_proto);
} }
} else { } else {
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();

View File

@ -25,6 +25,7 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "proto/onnx.pb.h" #include "proto/onnx.pb.h"
#include "utils/check_convert_utils.h"
namespace mindspore { namespace mindspore {
enum OpMergeMode { enum OpMergeMode {
@ -102,8 +103,9 @@ void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_Attrib
void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
auto attr_value = GetValue<std::string>(value); int64_t attr_value;
if (attr_value == "VALID") { CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value, true);
if (attr_value == PadMode::VALID) {
attr_proto->set_s("VALID"); attr_proto->set_s("VALID");
} else { } else {
attr_proto->set_s("SAME_UPPER"); attr_proto->set_s("SAME_UPPER");
@ -186,10 +188,11 @@ OPERATOR_ONNX_CONVERT_DEFINE(
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
const PrimitivePtr &prim) { const PrimitivePtr &prim) {
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
auto attr_value = GetValue<std::string>(value); int64_t attr_value;
if (attr_value == "valid") { CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
if (attr_value == PadMode::VALID) {
attr_proto->set_s("VALID"); attr_proto->set_s("VALID");
} else if (attr_value == "same") { } else if (attr_value == PadMode::SAME) {
attr_proto->set_s("SAME_UPPER"); attr_proto->set_s("SAME_UPPER");
} else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads' } else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
attr_proto->set_name("pads"); attr_proto->set_name("pads");
@ -834,12 +837,13 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
// set pad // set pad
onnx_attr_proto = node_proto->add_attribute(); onnx_attr_proto = node_proto->add_attribute();
auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode")); int64_t attr_value;
CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &attr_value);
onnx_attr_proto->set_name("auto_pad"); onnx_attr_proto->set_name("auto_pad");
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
if (attr_value == "valid") { if (attr_value == PadMode::VALID) {
onnx_attr_proto->set_s("VALID"); onnx_attr_proto->set_s("VALID");
} else if (attr_value == "same") { } else if (attr_value == PadMode::SAME) {
onnx_attr_proto->set_s("SAME_UPPER"); onnx_attr_proto->set_s("SAME_UPPER");
} else { } else {
onnx_attr_proto->set_name("pads"); onnx_attr_proto->set_name("pads");

View File

@ -59,20 +59,16 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0"; MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
} }
std::set<std::string> available_pad_mode{"pad", "same", "valid"};
auto pad_mode_ptr = primitive->GetAttr("pad_mode"); auto pad_mode_ptr = primitive->GetAttr("pad_mode");
if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa<StringImm>()) { if (pad_mode_ptr != nullptr) {
auto pad_mode = pad_mode_ptr->cast<StringImmPtr>()->value(); int64_t pad_mode;
if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; if (pad_mode == PadMode::VALID) {
}
if (pad_mode == "valid") {
padding = 0; padding = 0;
} else if (pad_mode == "same") { } else if (pad_mode == PadMode::SAME) {
padding = (window - 1) / 2; padding = (window - 1) / 2;
} }
} }
std::set<std::string> available_mode{"max", "avg"}; std::set<std::string> available_mode{"max", "avg"};
auto mode_ptr = primitive->GetAttr("mode"); auto mode_ptr = primitive->GetAttr("mode");
if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) { if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
@ -270,13 +266,13 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h, void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride, const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
const std::vector<int64_t> &dilation, const std::string &pad_mode, const std::vector<int64_t> &dilation, const int64_t &pad_mode,
const std::vector<int64_t> &padding) { const std::vector<int64_t> &padding) {
if (pad_mode == "valid") { if (pad_mode == PadMode::VALID) {
output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0])); output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]));
output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1])); output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]));
pad_list->insert(pad_list->begin(), 4, 0); pad_list->insert(pad_list->begin(), 4, 0);
} else if (pad_mode == "same") { } else if (pad_mode == PadMode::SAME) {
output_hw->push_back(std::ceil((x_h * 1.0) / stride[0])); output_hw->push_back(std::ceil((x_h * 1.0) / stride[0]));
output_hw->push_back(std::ceil((x_w * 1.0) / stride[1])); output_hw->push_back(std::ceil((x_w * 1.0) / stride[1]));
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h; int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
@ -287,7 +283,7 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
pad_needed_w = std::max((int64_t)0, pad_needed_w); pad_needed_w = std::max((int64_t)0, pad_needed_w);
pad_list->push_back(std::floor(pad_needed_w / 2)); pad_list->push_back(std::floor(pad_needed_w / 2));
pad_list->push_back(pad_needed_w - pad_list->at(2)); pad_list->push_back(pad_needed_w - pad_list->at(2));
} else if (pad_mode == "pad") { } else if (pad_mode == PadMode::PAD) {
pad_list->insert(pad_list->begin(), padding.begin(), padding.end()); pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
output_hw->push_back(std::floor( output_hw->push_back(std::floor(
1 + 1 +
@ -298,6 +294,15 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
} }
} }
int64_t GetAndCheckFormat(const ValuePtr &value) {
int64_t data_format;
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
}
return data_format;
}
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();
@ -322,12 +327,12 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
CheckShapeAnyAndPositive(op_name + " w_shape", w_shape); CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape); CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape); CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"});
int64_t n_axis = 0; int64_t n_axis = 0;
int64_t c_axis = 1; int64_t c_axis = 1;
int64_t h_axis = 2; int64_t h_axis = 2;
int64_t w_axis = 3; int64_t w_axis = 3;
if (data_format == "NHWC") { int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
if (data_format == Format::NHWC) {
c_axis = 3; c_axis = 3;
h_axis = 1; h_axis = 1;
w_axis = 2; w_axis = 2;
@ -352,8 +357,8 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2); std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2);
std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2); std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2);
std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4); std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4);
std::string pad_mode = int64_t pad_mode;
CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"}); CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
std::vector<int64_t> output_hw; std::vector<int64_t> output_hw;
std::vector<int64_t> pad_list; std::vector<int64_t> pad_list;
std::vector<int64_t> output_hw_min; std::vector<int64_t> output_hw_min;
@ -378,7 +383,7 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
ShapeVector output_shape; ShapeVector output_shape;
ShapeVector output_shape_min; ShapeVector output_shape_min;
ShapeVector output_shape_max; ShapeVector output_shape_max;
if (data_format == "NHWC") { if (data_format == Format::NHWC) {
output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel}; output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel}; output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel}; output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
@ -426,16 +431,12 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
ShapeVector bias_shape = bias->shape()->shape(); ShapeVector bias_shape = bias->shape()->shape();
ShapeVector x_min_shape = x->shape()->min_shape(); ShapeVector x_min_shape = x->shape()->min_shape();
ShapeVector x_max_shape = x->shape()->max_shape(); ShapeVector x_max_shape = x->shape()->max_shape();
std::set<std::string> available_data_format{"NCHW", "NHWC"};
auto data_format_ptr = primitive->GetAttr("format"); auto data_format_ptr = primitive->GetAttr("format");
std::string data_format = "NCHW"; int64_t data_format = Format::NCHW;
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) { if (data_format_ptr != nullptr) {
data_format = data_format_ptr->cast<StringImmPtr>()->value(); data_format = GetAndCheckFormat(data_format_ptr);
} }
if (available_data_format.find(data_format) == available_data_format.end()) { auto x_channel = data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1];
MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC.";
}
auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1];
// Additional check for dynamic shape // Additional check for dynamic shape
// Last infer will be real shape values // Last infer will be real shape values
bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; }); bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });

View File

@ -29,6 +29,7 @@
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "utils/shape_utils.h" #include "utils/shape_utils.h"
#include "utils/check_convert_utils.h"
using std::string; using std::string;
@ -494,7 +495,11 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
case FORM_PARSE_SCALAR: { case FORM_PARSE_SCALAR: {
std::size_t value_pos(0); std::size_t value_pos(0);
if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) { if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto); ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
const std::string &op_type = prim->name();
if (!IsLite()) {
CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res);
}
prim->AddAttr(attr_name, res); prim->AddAttr(attr_name, res);
break; break;
} }

View File

@ -39,6 +39,8 @@ class MSANFModelParser {
std::string GetProducerName() { return producer_name_; } std::string GetProducerName() { return producer_name_; }
std::string GetProducerVersion() { return model_version_; } std::string GetProducerVersion() { return model_version_; }
std::string GetIrVersion() { return ir_version_; } std::string GetIrVersion() { return ir_version_; }
void SetLite() { is_lite_ = true; }
bool IsLite() { return is_lite_; }
private: private:
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto); bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
@ -68,6 +70,7 @@ class MSANFModelParser {
std::string producer_name_; std::string producer_name_;
std::string model_version_; std::string model_version_;
std::string ir_version_; std::string ir_version_;
bool is_lite_ = false;
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_; std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
}; };
} // namespace mindspore } // namespace mindspore

View File

@ -71,7 +71,7 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
return buf; return buf;
} }
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) { std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
auto graphBuf = ReadProtoFile(file_name); auto graphBuf = ReadProtoFile(file_name);
if (graphBuf == nullptr) { if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str(); MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
@ -79,7 +79,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
} }
try { try {
auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size()); auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size(), is_lite);
return graph; return graph;
} catch (std::exception &e) { } catch (std::exception &e) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str(); MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
@ -87,7 +87,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
} }
} }
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) { std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
MS_EXCEPTION_IF_NULL(buf); MS_EXCEPTION_IF_NULL(buf);
std::string str((const char *)buf, buf_size); std::string str((const char *)buf, buf_size);
mind_ir::ModelProto model_; mind_ir::ModelProto model_;
@ -95,6 +95,9 @@ std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_
MS_LOG(ERROR) << "Parse model from buffer fail!"; MS_LOG(ERROR) << "Parse model from buffer fail!";
} }
MSANFModelParser model_parser; MSANFModelParser model_parser;
if (is_lite) {
model_parser.SetLite();
}
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
return dstgraph_ptr; return dstgraph_ptr;
} }

View File

@ -24,8 +24,8 @@
#include "ir/func_graph.h" #include "ir/func_graph.h"
namespace mindspore { namespace mindspore {
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name); std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false);
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file); std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size); std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CORE_LOAD_MODEL_H #endif // MINDSPORE_CORE_LOAD_MODEL_H

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,13 +14,16 @@
* limitations under the License. * limitations under the License.
*/ */
#include "utils/check_convert_utils.h"
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <typeinfo> #include <typeinfo>
#include <functional> #include <functional>
#include "utils/check_convert_utils.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "ops/op_utils.h"
#include "ir/dtype/type.h" #include "ir/dtype/type.h"
#include "ir/dtype/tensor_type.h" #include "ir/dtype/tensor_type.h"
#include "ir/dtype.h" #include "ir/dtype.h"
@ -84,21 +87,21 @@ AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMa
AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap); AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);
static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = { static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
{"format", DataFormatConverter}, {ops::kFormat, DataFormatConverter},
{"pad_mode", PadModeConverter}, {ops::kPadMode, PadModeConverter},
}; };
static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = { static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
{"format", DataFormatConverter}, {ops::kFormat, DataFormatConverter},
{"pad_mode", PadModeUpperConverter}, {ops::kPadMode, PadModeUpperConverter},
}; };
static std::map<std::string, AttrConverterPair> DataFormatMap = { static std::map<std::string, AttrConverterPair> DataFormatMap = {
{"format", DataFormatConverter}, {ops::kFormat, DataFormatConverter},
}; };
static std::map<std::string, AttrConverterPair> ReductionMap = { static std::map<std::string, AttrConverterPair> ReductionMap = {
{"reduction", ReductionConverter}, {ops::kReduction, ReductionConverter},
}; };
static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = { static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
@ -132,24 +135,42 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
{"BinaryCrossEntropy", ReductionMap}, {"BinaryCrossEntropy", ReductionMap},
{"BinaryCrossEntropyGrad", ReductionMap}, {"BinaryCrossEntropyGrad", ReductionMap},
{"NLLLoss", ReductionMap}, {"NLLLoss", ReductionMap},
{"DepthToSpace", DataFormatMap},
}; };
int64_t CheckAndConvertUtils::GetDataFormatEnumValue(const std::string &value) { bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
if (DataFormatToEnumMap.find(value) == DataFormatToEnumMap.end()) { MS_EXCEPTION_IF_NULL(value);
MS_LOG(ERROR) << "Can not convert data format " << value << "to enum"; if (value->isa<StringImm>()) {
auto attr_value_str = GetValue<std::string>(value);
if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) {
MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum.";
return false;
}
*enum_value = DataFormatToEnumMap[attr_value_str];
return true;
} else {
*enum_value = GetValue<int64_t>(value);
return true;
} }
return DataFormatToEnumMap[value]; return false;
} }
int64_t CheckAndConvertUtils::GetPadModEnumValue(const std::string &value, bool is_upper) { void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
std::map<std::string, int64_t> pad_map = PadModToEnumMap; MS_EXCEPTION_IF_NULL(value);
if (is_upper) { if (value->isa<StringImm>()) {
pad_map = PadModToEnumUpperMap; auto attr_value_str = GetValue<std::string>(value);
std::map<std::string, int64_t> pad_map = PadModToEnumMap;
if (is_upper) {
pad_map = PadModToEnumUpperMap;
}
if (pad_map.find(attr_value_str) == pad_map.end()) {
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
}
*enum_value = pad_map[attr_value_str];
} else {
*enum_value = GetValue<int64_t>(value);
} }
if (pad_map.find(value) == pad_map.end()) {
MS_LOG(ERROR) << "Can not convert pad mode " << value << "to enum";
}
return pad_map[value];
} }
AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) { AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
@ -172,8 +193,8 @@ AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op
bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) { ValuePtr *const value) {
if (value == nullptr) { if (value == nullptr || *value == nullptr) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr.";
return false; return false;
} }
if (!(*value)->isa<StringImm>()) { if (!(*value)->isa<StringImm>()) {
@ -191,12 +212,17 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con
} }
if (!do_convert) { if (!do_convert) {
transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper); transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
do_convert = true;
}
}
if (!do_convert) {
transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) { if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int"; MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
return false; return false;
} }
} }
*value = MakeValue<int64_t>(attr_map_pair.first[real_value]); *value = MakeValue<int64_t>(attr_map_pair.first[real_value]);
MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name; MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
return true; return true;
@ -204,7 +230,7 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con
bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) { ValuePtr *const value) {
if (value == nullptr) { if (value == nullptr || *value == nullptr) {
MS_LOG(ERROR) << "value is nullptr"; MS_LOG(ERROR) << "value is nullptr";
return false; return false;
} }
@ -226,7 +252,6 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
return true; return true;
} }
namespace { namespace {
typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction; typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
@ -242,7 +267,6 @@ std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis",
{"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}}; {"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
} // namespace } // namespace
bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2) { bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2) {
if (vec_1.size() != vec_2.size()) { if (vec_1.size() != vec_2.size()) {
return false; return false;

View File

@ -284,8 +284,8 @@ class CheckAndConvertUtils {
static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name); static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
static int64_t GetDataFormatEnumValue(const std::string &value); static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
static int64_t GetPadModEnumValue(const std::string &value, bool is_upper = false); static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value); static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
private: private:

View File

@ -856,7 +856,7 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model
int AnfImporterFromMindir::Import(const converter::Flags *flag) { int AnfImporterFromMindir::Import(const converter::Flags *flag) {
#if SUPPORT_TRAIN #if SUPPORT_TRAIN
func_graph_ = LoadMindIR(flag->modelFile); func_graph_ = LoadMindIR(flag->modelFile, true);
if (func_graph_ != nullptr) { if (func_graph_ != nullptr) {
return RET_OK; return RET_OK;
} else { } else {
@ -866,7 +866,7 @@ int AnfImporterFromMindir::Import(const converter::Flags *flag) {
onnx_model_ = ReadOnnxFromBinary(flag->modelFile); onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
if (onnx_model_ == nullptr) { if (onnx_model_ == nullptr) {
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model"; MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
func_graph_ = LoadMindIR(flag->modelFile); func_graph_ = LoadMindIR(flag->modelFile, true);
if (func_graph_ == nullptr) { if (func_graph_ == nullptr) {
MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file."; MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
return RET_GRAPH_FILE_ERR; return RET_GRAPH_FILE_ERR;