forked from mindspore-Ecosystem/mindspore
!11390 convert part of prim attr from str to int
From: @wangnan39 Reviewed-by: @kingxian Signed-off-by:
This commit is contained in:
commit
228a64de0f
|
@ -18,6 +18,7 @@
|
|||
#include <set>
|
||||
#include "common/trans.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "utils/utils.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
|
@ -66,10 +67,18 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
|
|||
std::string InitDefaultFormat(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat);
|
||||
auto primitive_ptr = GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive_ptr);
|
||||
auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
|
||||
if (!result) {
|
||||
auto attr = GetValue<std::string>(data_format_ptr);
|
||||
if (attr == kOpFormat_NCDHW) {
|
||||
return kOpFormat_NCDHW;
|
||||
}
|
||||
}
|
||||
} else if (AnfAlgo::IsRealKernel(node)) {
|
||||
auto formats = AnfAlgo::GetAllOutputFormats(node);
|
||||
if (std::any_of(formats.begin(), formats.end(),
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
@ -46,9 +47,15 @@ bool NeedUpdate(const CNodePtr &conv2d, std::vector<size_t> in_shape, std::vecto
|
|||
if (group == 1) {
|
||||
return false;
|
||||
}
|
||||
auto data_format = AnfAlgo::GetNodeAttr<std::string>(conv2d, kAttrFormat);
|
||||
if (data_format != "NCHW") {
|
||||
MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1, but got " << data_format;
|
||||
|
||||
auto primitive_ptr = GetCNodePrimitive(conv2d);
|
||||
MS_EXCEPTION_IF_NULL(primitive_ptr);
|
||||
auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
|
||||
if (!result || data_format != Format::NCHW) {
|
||||
MS_LOG(EXCEPTION) << "Conv2D only supports NCHW when group > 1";
|
||||
}
|
||||
if (in_shape.size() != kConv2DAxisNum || out_shape.size() != kConv2DAxisNum) {
|
||||
MS_LOG(EXCEPTION) << "Conv2D's input and output should have 4 axis, but got input axis num: " << in_shape.size()
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "base/core_ops.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
|
@ -402,11 +403,19 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
}
|
||||
SetKernelInfoForNode(cnode);
|
||||
if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat);
|
||||
auto primitive_ptr = GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive_ptr);
|
||||
auto data_format_ptr = primitive_ptr->GetAttr(kAttrFormat);
|
||||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
|
||||
if (!result) {
|
||||
auto attr = GetValue<std::string>(data_format_ptr);
|
||||
if (attr == kOpFormat_NCDHW) {
|
||||
ResetInFormat(cnode, kOpFormat_NCDHW);
|
||||
}
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetGraphId(graph_id_, cnode.get());
|
||||
return cnode;
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "utils/convert_utils_py.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/primitive_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
|
||||
|
@ -280,6 +281,8 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
|
|||
if (kOpAttrNameReplaceMap.find(attr_name) != kOpAttrNameReplaceMap.end()) {
|
||||
attr_name = kOpAttrNameReplaceMap[attr_name];
|
||||
}
|
||||
const std::string &prim_name = this->name();
|
||||
CheckAndConvertUtils::ConvertAttrValueToInt(prim_name, attr_name, &converted_ret);
|
||||
(void)this->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "proto/mind_ir.pb.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
using FloatPtr = std::shared_ptr<Float>;
|
||||
|
@ -425,7 +426,9 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
|||
MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name(attr.first);
|
||||
SetValueToAttributeProto(attr.second, attr_proto);
|
||||
auto attr_value = attr.second;
|
||||
CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value);
|
||||
SetValueToAttributeProto(attr_value, attr_proto);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
enum OpMergeMode {
|
||||
|
@ -102,8 +103,9 @@ void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_Attrib
|
|||
void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType,
|
||||
onnx::AttributeProto *const attr_proto, const PrimitivePtr &) {
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
auto attr_value = GetValue<std::string>(value);
|
||||
if (attr_value == "VALID") {
|
||||
int64_t attr_value;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value, true);
|
||||
if (attr_value == PadMode::VALID) {
|
||||
attr_proto->set_s("VALID");
|
||||
} else {
|
||||
attr_proto->set_s("SAME_UPPER");
|
||||
|
@ -186,10 +188,11 @@ OPERATOR_ONNX_CONVERT_DEFINE(
|
|||
[](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto,
|
||||
const PrimitivePtr &prim) {
|
||||
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
auto attr_value = GetValue<std::string>(value);
|
||||
if (attr_value == "valid") {
|
||||
int64_t attr_value;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(value, &attr_value);
|
||||
if (attr_value == PadMode::VALID) {
|
||||
attr_proto->set_s("VALID");
|
||||
} else if (attr_value == "same") {
|
||||
} else if (attr_value == PadMode::SAME) {
|
||||
attr_proto->set_s("SAME_UPPER");
|
||||
} else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads'
|
||||
attr_proto->set_name("pads");
|
||||
|
@ -834,12 +837,13 @@ void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/
|
|||
|
||||
// set pad
|
||||
onnx_attr_proto = node_proto->add_attribute();
|
||||
auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode"));
|
||||
int64_t attr_value;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr("pad_mode"), &attr_value);
|
||||
onnx_attr_proto->set_name("auto_pad");
|
||||
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING);
|
||||
if (attr_value == "valid") {
|
||||
if (attr_value == PadMode::VALID) {
|
||||
onnx_attr_proto->set_s("VALID");
|
||||
} else if (attr_value == "same") {
|
||||
} else if (attr_value == PadMode::SAME) {
|
||||
onnx_attr_proto->set_s("SAME_UPPER");
|
||||
} else {
|
||||
onnx_attr_proto->set_name("pads");
|
||||
|
|
|
@ -59,20 +59,16 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
|
||||
}
|
||||
|
||||
std::set<std::string> available_pad_mode{"pad", "same", "valid"};
|
||||
auto pad_mode_ptr = primitive->GetAttr("pad_mode");
|
||||
if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa<StringImm>()) {
|
||||
auto pad_mode = pad_mode_ptr->cast<StringImmPtr>()->value();
|
||||
if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid";
|
||||
}
|
||||
if (pad_mode == "valid") {
|
||||
if (pad_mode_ptr != nullptr) {
|
||||
int64_t pad_mode;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(pad_mode_ptr, &pad_mode, true);
|
||||
if (pad_mode == PadMode::VALID) {
|
||||
padding = 0;
|
||||
} else if (pad_mode == "same") {
|
||||
} else if (pad_mode == PadMode::SAME) {
|
||||
padding = (window - 1) / 2;
|
||||
}
|
||||
}
|
||||
|
||||
std::set<std::string> available_mode{"max", "avg"};
|
||||
auto mode_ptr = primitive->GetAttr("mode");
|
||||
if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
|
||||
|
@ -270,13 +266,13 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
|
|||
|
||||
void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pad_list, const int64_t x_h,
|
||||
const int64_t x_w, const std::vector<int64_t> &kernel, const std::vector<int64_t> &stride,
|
||||
const std::vector<int64_t> &dilation, const std::string &pad_mode,
|
||||
const std::vector<int64_t> &dilation, const int64_t &pad_mode,
|
||||
const std::vector<int64_t> &padding) {
|
||||
if (pad_mode == "valid") {
|
||||
if (pad_mode == PadMode::VALID) {
|
||||
output_hw->push_back(std::ceil(((x_h * 1.0) - dilation[0] * (kernel[0] - 1)) / stride[0]));
|
||||
output_hw->push_back(std::ceil(((x_w * 1.0) - dilation[1] * (kernel[1] - 1)) / stride[1]));
|
||||
pad_list->insert(pad_list->begin(), 4, 0);
|
||||
} else if (pad_mode == "same") {
|
||||
} else if (pad_mode == PadMode::SAME) {
|
||||
output_hw->push_back(std::ceil((x_h * 1.0) / stride[0]));
|
||||
output_hw->push_back(std::ceil((x_w * 1.0) / stride[1]));
|
||||
int64_t pad_needed_h = (output_hw->at(0) - 1) * stride[0] + dilation[0] * (kernel[0] - 1) + 1 - x_h;
|
||||
|
@ -287,7 +283,7 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
|
|||
pad_needed_w = std::max((int64_t)0, pad_needed_w);
|
||||
pad_list->push_back(std::floor(pad_needed_w / 2));
|
||||
pad_list->push_back(pad_needed_w - pad_list->at(2));
|
||||
} else if (pad_mode == "pad") {
|
||||
} else if (pad_mode == PadMode::PAD) {
|
||||
pad_list->insert(pad_list->begin(), padding.begin(), padding.end());
|
||||
output_hw->push_back(std::floor(
|
||||
1 +
|
||||
|
@ -298,6 +294,15 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
|
|||
}
|
||||
}
|
||||
|
||||
int64_t GetAndCheckFormat(const ValuePtr &value) {
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
|
||||
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
|
||||
MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
|
||||
}
|
||||
return data_format;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
@ -322,12 +327,12 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
CheckShapeAnyAndPositive(op_name + " w_shape", w_shape);
|
||||
CheckShapeAllPositive(op_name + " w_min_shape", w_min_shape);
|
||||
CheckShapeAllPositive(op_name + " w_max_shape", w_max_shape);
|
||||
std::string data_format = CheckAttrStringSet(op_name, primitive->GetAttr("format"), "format", {"NCHW", "NHWC"});
|
||||
int64_t n_axis = 0;
|
||||
int64_t c_axis = 1;
|
||||
int64_t h_axis = 2;
|
||||
int64_t w_axis = 3;
|
||||
if (data_format == "NHWC") {
|
||||
int64_t data_format = GetAndCheckFormat(primitive->GetAttr("format"));
|
||||
if (data_format == Format::NHWC) {
|
||||
c_axis = 3;
|
||||
h_axis = 1;
|
||||
w_axis = 2;
|
||||
|
@ -352,8 +357,8 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
std::vector<int64_t> stride = CheckAttrIntOrTuple(op_name, primitive->GetAttr("stride"), 2, 2);
|
||||
std::vector<int64_t> dilation = CheckAttrIntOrTuple(op_name, primitive->GetAttr("dilation"), 2, 2);
|
||||
std::vector<int64_t> padding = CheckAttrIntOrTuple(op_name, primitive->GetAttr("pad"), 0, 4);
|
||||
std::string pad_mode =
|
||||
CheckAttrStringSet(op_name, primitive->GetAttr("pad_mode"), "pad_mode", {"pad", "same", "valid"});
|
||||
int64_t pad_mode;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr("pad_mode"), &pad_mode);
|
||||
std::vector<int64_t> output_hw;
|
||||
std::vector<int64_t> pad_list;
|
||||
std::vector<int64_t> output_hw_min;
|
||||
|
@ -378,7 +383,7 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
ShapeVector output_shape;
|
||||
ShapeVector output_shape_min;
|
||||
ShapeVector output_shape_max;
|
||||
if (data_format == "NHWC") {
|
||||
if (data_format == Format::NHWC) {
|
||||
output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
|
||||
output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
|
||||
output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
|
||||
|
@ -426,16 +431,12 @@ AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
ShapeVector bias_shape = bias->shape()->shape();
|
||||
ShapeVector x_min_shape = x->shape()->min_shape();
|
||||
ShapeVector x_max_shape = x->shape()->max_shape();
|
||||
std::set<std::string> available_data_format{"NCHW", "NHWC"};
|
||||
auto data_format_ptr = primitive->GetAttr("format");
|
||||
std::string data_format = "NCHW";
|
||||
if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
|
||||
data_format = data_format_ptr->cast<StringImmPtr>()->value();
|
||||
int64_t data_format = Format::NCHW;
|
||||
if (data_format_ptr != nullptr) {
|
||||
data_format = GetAndCheckFormat(data_format_ptr);
|
||||
}
|
||||
if (available_data_format.find(data_format) == available_data_format.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC.";
|
||||
}
|
||||
auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1];
|
||||
auto x_channel = data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1];
|
||||
// Additional check for dynamic shape
|
||||
// Last infer will be real shape values
|
||||
bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(), [](int64_t value) { return value != Shape::SHP_ANY; });
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "abstract/abstract_value.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
using std::string;
|
||||
|
||||
|
@ -494,7 +495,11 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
|
|||
case FORM_PARSE_SCALAR: {
|
||||
std::size_t value_pos(0);
|
||||
if ((value_pos = ref_attr_name.find("value0")) != std::string::npos) {
|
||||
auto res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
const std::string &op_type = prim->name();
|
||||
if (!IsLite()) {
|
||||
CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res);
|
||||
}
|
||||
prim->AddAttr(attr_name, res);
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -39,6 +39,8 @@ class MSANFModelParser {
|
|||
std::string GetProducerName() { return producer_name_; }
|
||||
std::string GetProducerVersion() { return model_version_; }
|
||||
std::string GetIrVersion() { return ir_version_; }
|
||||
void SetLite() { is_lite_ = true; }
|
||||
bool IsLite() { return is_lite_; }
|
||||
|
||||
private:
|
||||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
|
@ -68,6 +70,7 @@ class MSANFModelParser {
|
|||
std::string producer_name_;
|
||||
std::string model_version_;
|
||||
std::string ir_version_;
|
||||
bool is_lite_ = false;
|
||||
std::unordered_map<std::string, AnfNodePtr> anfnode_build_map_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -71,7 +71,7 @@ std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
|
|||
return buf;
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite) {
|
||||
auto graphBuf = ReadProtoFile(file_name);
|
||||
if (graphBuf == nullptr) {
|
||||
MS_LOG(ERROR) << "Read Mind IR failed, file name is " << file_name.c_str();
|
||||
|
@ -79,7 +79,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
|
|||
}
|
||||
|
||||
try {
|
||||
auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size());
|
||||
auto graph = ConvertStreamToFuncGraph(graphBuf->data(), graphBuf->size(), is_lite);
|
||||
return graph;
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
|
||||
|
@ -87,7 +87,7 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name) {
|
|||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size) {
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
std::string str((const char *)buf, buf_size);
|
||||
mind_ir::ModelProto model_;
|
||||
|
@ -95,6 +95,9 @@ std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_
|
|||
MS_LOG(ERROR) << "Parse model from buffer fail!";
|
||||
}
|
||||
MSANFModelParser model_parser;
|
||||
if (is_lite) {
|
||||
model_parser.SetLite();
|
||||
}
|
||||
FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
|
||||
return dstgraph_ptr;
|
||||
}
|
||||
|
|
|
@ -24,8 +24,8 @@
|
|||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name);
|
||||
std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite = false);
|
||||
std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file);
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size);
|
||||
std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite = false);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_LOAD_MODEL_H
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,13 +14,16 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <typeinfo>
|
||||
#include <functional>
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ir/dtype/type.h"
|
||||
#include "ir/dtype/tensor_type.h"
|
||||
#include "ir/dtype.h"
|
||||
|
@ -84,21 +87,21 @@ AttrConverterPair PadModeUpperConverter(PadModToEnumUpperMap, PadModToStrUpperMa
|
|||
AttrConverterPair ReductionConverter(ReductionToEnumMap, ReductionToStrMap);
|
||||
|
||||
static std::map<std::string, AttrConverterPair> FormatAndPadAttrMap = {
|
||||
{"format", DataFormatConverter},
|
||||
{"pad_mode", PadModeConverter},
|
||||
{ops::kFormat, DataFormatConverter},
|
||||
{ops::kPadMode, PadModeConverter},
|
||||
};
|
||||
|
||||
static std::map<std::string, AttrConverterPair> FormatAndPadUpperAttrMap = {
|
||||
{"format", DataFormatConverter},
|
||||
{"pad_mode", PadModeUpperConverter},
|
||||
{ops::kFormat, DataFormatConverter},
|
||||
{ops::kPadMode, PadModeUpperConverter},
|
||||
};
|
||||
|
||||
static std::map<std::string, AttrConverterPair> DataFormatMap = {
|
||||
{"format", DataFormatConverter},
|
||||
{ops::kFormat, DataFormatConverter},
|
||||
};
|
||||
|
||||
static std::map<std::string, AttrConverterPair> ReductionMap = {
|
||||
{"reduction", ReductionConverter},
|
||||
{ops::kReduction, ReductionConverter},
|
||||
};
|
||||
|
||||
static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrConvertMap = {
|
||||
|
@ -132,24 +135,42 @@ static std::map<std::string, std::map<std::string, AttrConverterPair>> PrimAttrC
|
|||
{"BinaryCrossEntropy", ReductionMap},
|
||||
{"BinaryCrossEntropyGrad", ReductionMap},
|
||||
{"NLLLoss", ReductionMap},
|
||||
{"DepthToSpace", DataFormatMap},
|
||||
};
|
||||
|
||||
int64_t CheckAndConvertUtils::GetDataFormatEnumValue(const std::string &value) {
|
||||
if (DataFormatToEnumMap.find(value) == DataFormatToEnumMap.end()) {
|
||||
MS_LOG(ERROR) << "Can not convert data format " << value << "to enum";
|
||||
bool CheckAndConvertUtils::GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<StringImm>()) {
|
||||
auto attr_value_str = GetValue<std::string>(value);
|
||||
if (DataFormatToEnumMap.find(attr_value_str) == DataFormatToEnumMap.end()) {
|
||||
MS_LOG(DEBUG) << "The data format " << attr_value_str << " not be converted to enum.";
|
||||
return false;
|
||||
}
|
||||
return DataFormatToEnumMap[value];
|
||||
*enum_value = DataFormatToEnumMap[attr_value_str];
|
||||
return true;
|
||||
} else {
|
||||
*enum_value = GetValue<int64_t>(value);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t CheckAndConvertUtils::GetPadModEnumValue(const std::string &value, bool is_upper) {
|
||||
void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<StringImm>()) {
|
||||
auto attr_value_str = GetValue<std::string>(value);
|
||||
|
||||
std::map<std::string, int64_t> pad_map = PadModToEnumMap;
|
||||
if (is_upper) {
|
||||
pad_map = PadModToEnumUpperMap;
|
||||
}
|
||||
if (pad_map.find(value) == pad_map.end()) {
|
||||
MS_LOG(ERROR) << "Can not convert pad mode " << value << "to enum";
|
||||
if (pad_map.find(attr_value_str) == pad_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
|
||||
}
|
||||
*enum_value = pad_map[attr_value_str];
|
||||
} else {
|
||||
*enum_value = GetValue<int64_t>(value);
|
||||
}
|
||||
return pad_map[value];
|
||||
}
|
||||
|
||||
AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
|
||||
|
@ -172,8 +193,8 @@ AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op
|
|||
|
||||
bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name,
|
||||
ValuePtr *const value) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr";
|
||||
if (value == nullptr || *value == nullptr) {
|
||||
MS_LOG(DEBUG) << "value of attr " << op_type << attr_name << " is nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (!(*value)->isa<StringImm>()) {
|
||||
|
@ -191,12 +212,17 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con
|
|||
}
|
||||
if (!do_convert) {
|
||||
transform(real_value.begin(), real_value.end(), real_value.begin(), ::toupper);
|
||||
if (attr_map_pair.first.find(real_value) != attr_map_pair.first.end()) {
|
||||
do_convert = true;
|
||||
}
|
||||
}
|
||||
if (!do_convert) {
|
||||
transform(real_value.begin(), real_value.end(), real_value.begin(), ::tolower);
|
||||
if (attr_map_pair.first.find(real_value) == attr_map_pair.first.end()) {
|
||||
MS_LOG(DEBUG) << "Can not convert " << op_type << " attr " << attr_name << ": " << real_value << " to int";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
*value = MakeValue<int64_t>(attr_map_pair.first[real_value]);
|
||||
MS_LOG(DEBUG) << "convert str to int, name: " << op_type << ", attr: " << attr_name;
|
||||
return true;
|
||||
|
@ -204,7 +230,7 @@ bool CheckAndConvertUtils::ConvertAttrValueToInt(const std::string &op_type, con
|
|||
|
||||
bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name,
|
||||
ValuePtr *const value) {
|
||||
if (value == nullptr) {
|
||||
if (value == nullptr || *value == nullptr) {
|
||||
MS_LOG(ERROR) << "value is nullptr";
|
||||
return false;
|
||||
}
|
||||
|
@ -226,7 +252,6 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
|
||||
|
||||
|
@ -242,7 +267,6 @@ std::map<std::string, AttrFunction> kIrAttrToOpAttr = {{"L2Normalize", {{"axis",
|
|||
{"L2NormalizeGrad", {{"axis", L2NormalizeAttrConversion}}}};
|
||||
} // namespace
|
||||
|
||||
|
||||
bool CheckAndConvertUtils::IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2) {
|
||||
if (vec_1.size() != vec_2.size()) {
|
||||
return false;
|
||||
|
|
|
@ -284,8 +284,8 @@ class CheckAndConvertUtils {
|
|||
static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
|
||||
static int64_t GetDataFormatEnumValue(const std::string &value);
|
||||
static int64_t GetPadModEnumValue(const std::string &value, bool is_upper = false);
|
||||
static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
|
||||
static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
|
||||
static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
|
||||
private:
|
||||
|
|
|
@ -856,7 +856,7 @@ int AnfImporterFromMindir::ParseModelConfigureInfo(const onnx::ModelProto &model
|
|||
|
||||
int AnfImporterFromMindir::Import(const converter::Flags *flag) {
|
||||
#if SUPPORT_TRAIN
|
||||
func_graph_ = LoadMindIR(flag->modelFile);
|
||||
func_graph_ = LoadMindIR(flag->modelFile, true);
|
||||
if (func_graph_ != nullptr) {
|
||||
return RET_OK;
|
||||
} else {
|
||||
|
@ -866,7 +866,7 @@ int AnfImporterFromMindir::Import(const converter::Flags *flag) {
|
|||
onnx_model_ = ReadOnnxFromBinary(flag->modelFile);
|
||||
if (onnx_model_ == nullptr) {
|
||||
MS_LOG(DEBUG) << "Parse model failed, which is not an old mindir model";
|
||||
func_graph_ = LoadMindIR(flag->modelFile);
|
||||
func_graph_ = LoadMindIR(flag->modelFile, true);
|
||||
if (func_graph_ == nullptr) {
|
||||
MS_LOG(ERROR) << "The mindir model cannot be parsed, which may not match proto file.";
|
||||
return RET_GRAPH_FILE_ERR;
|
||||
|
|
Loading…
Reference in New Issue