!27737 refactor mindir get abstract

Merge pull request !27737 from lianliguang/fix-mindir-abstract-bug
i-robot 2021-12-25 08:15:09 +00:00 committed by Gitee
commit 5f35cd0daa
5 changed files with 249 additions and 226 deletions


@@ -164,10 +164,20 @@ void AclModelMulti::SetInputs() {
for (const auto &in : inputs) {
auto input_param = std::dynamic_pointer_cast<Parameter>(in);
MS_EXCEPTION_IF_NULL(input_param);
MS_EXCEPTION_IF_NULL(input_param->abstract());
auto input_value = input_param->abstract()->GetValueTrack();
auto tensor = input_value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
auto input_abs = input_param->abstract();
MS_EXCEPTION_IF_NULL(input_abs);
auto tensor_abs = input_abs->cast<abstract::AbstractTensorPtr>();
if (tensor_abs == nullptr) {
MS_LOG(EXCEPTION) << "The graph input type is not a tensor. input args info:" << input_abs->ToString();
}
auto shape_ptr = tensor_abs->BuildShape();
MS_EXCEPTION_IF_NULL(shape_ptr);
auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(tensor_shape);
auto elem = tensor_abs->element();
MS_EXCEPTION_IF_NULL(elem);
auto type_id = elem->BuildType()->type_id();
auto tensor = std::make_shared<tensor::Tensor>(type_id, tensor_shape->shape());
std::vector<int64_t> shape = tensor->shape_c();
auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast<DataType>(tensor->data_type_c()),

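The hunk above stops reading a concrete Tensor out of the parameter's value track and instead rebuilds shape and dtype from the parameter's AbstractTensor. A minimal sketch of that extraction pattern, assuming the MindSpore headers already used in this file; the helper name GetInputShapeAndType is hypothetical and introduced only for illustration:

// Sketch only: the shape/dtype extraction introduced above, factored into a
// hypothetical helper. All calls (abstract(), cast<AbstractTensorPtr>(),
// BuildShape(), element(), BuildType()) appear in the diff itself.
std::pair<std::vector<int64_t>, TypeId> GetInputShapeAndType(const ParameterPtr &input_param) {
  auto input_abs = input_param->abstract();
  MS_EXCEPTION_IF_NULL(input_abs);
  auto tensor_abs = input_abs->cast<abstract::AbstractTensorPtr>();
  if (tensor_abs == nullptr) {
    MS_LOG(EXCEPTION) << "The graph input is not a tensor: " << input_abs->ToString();
  }
  auto shape_ptr = tensor_abs->BuildShape()->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  auto type_id = tensor_abs->element()->BuildType()->type_id();
  return {shape_ptr->shape(), type_id};  // no dependence on the abstract's value track
}

The point of the change is that SetInputs keeps working when the abstract carries no constant value, which the old GetValueTrack() path could not handle.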

@@ -111,12 +111,11 @@ class IrExportBuilder {
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string);
bool SetShapeToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string);
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
@@ -399,30 +398,18 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
return true;
}
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
auto tensor = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor);
auto elem_type = tensor->element();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
auto data_type = GetMindirDataType(elem_type->type_id());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
}
tensor_proto->set_data_type(data_type);
if (dims.size() == 0) {
MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
} else {
for (const auto &dim : dims) {
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
tensor_proto->add_dims(dim);
}
}
if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) {
if (!SetTensorProto(node->abstract(), tensor_proto)) {
return false;
}
} else if (type->isa<Tuple>()) {
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
std::string seq_string = "shape:";
mind_ir::AttributeProto *attribute = value_proto->mutable_attr_info();
if (!SetShapeToNodeProto(node->abstract(), attribute, &seq_string)) {
MS_LOG(ERROR) << "Set shape to Proto for " << node->DebugString() << " failed.";
return false;
}
attribute->set_ref_attr_name(seq_string);
} else {
value_proto->set_denotation(type->type_name());
}
@@ -454,14 +441,16 @@ bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
return true;
}
bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
mind_ir::TensorProto *const tensor_proto) {
bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
auto type = abstract->BuildType();
auto shape = abstract->BuildShape();
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
return false;
}
auto tensor = type->cast<TensorTypePtr>();
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
auto tensor_shape = shape->cast<abstract::ShapePtr>();
const auto &dims = tensor_shape->shape();
auto data_type = GetMindirDataType(tensor->element()->type_id());
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
return false;
@@ -470,22 +459,29 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh
for (const auto &dim : dims) {
tensor_proto->add_dims(dim);
}
return true;
}
bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs,
mind_ir::TensorProto *const tensor_proto) {
if (tensor_shape->IsDynamic()) {
auto min_shape = tensor_shape->min_shape();
auto max_shape = tensor_shape->max_shape();
for (auto item : min_shape) {
tensor_proto->add_min_dims(item);
}
for (auto item : max_shape) {
tensor_proto->add_max_dims(item);
}
}
// Deal Ref
if (!type->isa<RefType>()) {
return true;
}
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
auto abs_ref = abstract->cast<abstract::AbstractRefPtr>();
if (abs_ref == nullptr) {
MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef.";
MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRef.";
return false;
}
auto ref_key_value = abs_ref->ref_key_value();
if (ref_key_value == nullptr) {
MS_LOG(INFO) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr";
MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr";
return true;
}
tensor_proto->set_ref_key(ref_key_value->name());
@@ -497,7 +493,7 @@ bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
}
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
return SetTensorProto(param->Type(), param->Shape(), tensor_proto);
return SetTensorProto(param->abstract(), tensor_proto);
}
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
@@ -569,18 +565,16 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
return type_name;
}
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs,
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(shape);
bool IrExportBuilder::SetShapeToNodeProto(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto,
std::string *const seq_string) {
auto type = abs->BuildType();
auto shape = abs->BuildShape();
MS_EXCEPTION_IF_NULL(seq_string);
if (type->isa<Tuple>()) {
*seq_string += "Tuple[";
auto elements = type->cast<TuplePtr>()->elements();
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>()->elements();
for (size_t i = 0; i < elements.size(); i++) {
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) {
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
for (size_t i = 0; i < tuple_abs->size(); i++) {
if (!SetShapeToNodeProto((*tuple_abs)[i], attr_proto, seq_string)) {
return false;
}
}
@@ -590,7 +584,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
*seq_string += shape_name + ",";
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
tensor_proto->set_name(shape_name);
return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto);
return SetTensorProto(abs, tensor_proto);
} else if (type->isa<Number>()) {
if (type->isa<Bool>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
@@ -608,9 +602,10 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
*seq_string += type->type_name() + ",";
} else if (type->isa<CSRTensorType>()) {
auto csr_tensor_type = type->cast<CSRTensorTypePtr>();
MS_EXCEPTION_IF_NULL(csr_tensor_type);
if (!SetShapeToNodeProto(csr_tensor_type->element(), shape, abs, attr_proto, seq_string)) return false;
auto cst_tensor_abs = abs->cast<abstract::AbstractCSRTensorPtr>();
if (!SetShapeToNodeProto(cst_tensor_abs->element(), attr_proto, seq_string)) {
return false;
}
} else {
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
return false;
@@ -634,7 +629,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
ResetTupleIndex();
std::string seq_string = "shape:";
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) {
if (!SetShapeToNodeProto(abs, attr_proto, &seq_string)) {
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
return false;
}

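Net effect of the exporter changes: SetTensorProto now takes the abstract directly, derives dtype and dims from it, writes min_dims/max_dims when the shape is dynamic, and folds the former SetTensorProtoForRef handling into the same call. A rough usage sketch under that assumption, with invented values and written as it would look inside IrExportBuilder; it is not part of the commit:

// Sketch, not part of the commit: exporting a dynamic-shape tensor abstract.
ShapeVector dims = {-1, 224};        // -1 marks the dynamic dimension
ShapeVector min_shape = {1, 224};
ShapeVector max_shape = {32, 224};
auto abs = std::make_shared<abstract::AbstractTensor>(
    kFloat32, std::make_shared<abstract::Shape>(dims, min_shape, max_shape));
mind_ir::AttributeProto attr_proto;
mind_ir::TensorProto *tensor_proto = attr_proto.add_tensors();
// Fills data_type and dims; because the shape IsDynamic(), also min_dims/max_dims.
// ref_key is written only when the abstract's type is a RefType.
if (!SetTensorProto(abs, tensor_proto)) {
  MS_LOG(ERROR) << "Export tensor abstract failed.";
}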

@@ -241,11 +241,107 @@ string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
return "";
}
} // namespace
tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor,
bool need_load_data) {
ShapeVector shape;
const int attr_tensor_type = attr_tensor.data_type();
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) {
if (!IsIncLoad() || load_tensor_map_.find(attr_tensor.name()) == load_tensor_map_.end()) {
load_tensor_map_[attr_tensor.name()] = tensor;
}
MS_EXCEPTION_IF_NULL(tensor);
const std::string &tensor_buf = attr_tensor.raw_data();
if (attr_tensor.has_raw_data()) {
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
return nullptr;
}
} else if (need_load_data) {
MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
return nullptr;
}
return tensor;
}
bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto,
const AnfNodePtr &node_ptr) {
mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
if (attr_proto.ref_attr_name().find("shape:") == string::npos) {
MS_LOG(ERROR) << "Cannot use a attr_proto " << attr_proto.ref_attr_name() << " to init shape.";
return false;
}
bool is_tuple_or_list = false;
shape_ref_attr_name = attr_proto.ref_attr_name();
is_tuple_or_list =
shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos;
kv = GetAbstractForNode(attr_proto);
if (kv.empty()) {
return SetEmptyTensorProtoCNodeAbstract(node_ptr);
} else if (!is_tuple_or_list) {
auto iter = kv.begin();
if (iter->second != nullptr) {
node_ptr->set_abstract(iter->second);
}
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
node_ptr->set_abstract(abstract);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Node's attribute is nullptr.";
return false;
}
}
return true;
}
void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
auto prim = GetCNodePrimitive(cnode_ptr);
auto prim_to_add_attr = prim;
if (prim_to_add_attr != nullptr) {
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
}
}
prim_to_add_attr->set_attr("is_load", MakeValue(true));
}
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
// CNode abstract
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
SetCNodeAbstract(attr_proto, cnode_ptr);
continue;
}
if (prim_to_add_attr != nullptr) {
if (!GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parser prim: " << prim->ToString() << " attributes error : " << attr_proto.DebugString();
}
}
}
}
abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto) {
ShapeVector shape;
for (int i = 0; i < tensor_proto.dims_size(); ++i) {
shape.push_back(tensor_proto.dims(i));
shape.emplace_back(tensor_proto.dims(i));
}
ShapeVector min_shape;
for (int i = 0; i < tensor_proto.min_dims_size(); ++i) {
min_shape.emplace_back(tensor_proto.min_dims(i));
}
ShapeVector max_shape;
for (int i = 0; i < tensor_proto.max_dims_size(); ++i) {
max_shape.emplace_back(tensor_proto.max_dims(i));
}
if (!tensor_proto.has_data_type()) {
@@ -253,14 +349,20 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type.";
return nullptr;
}
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
auto iter = kDefaultValueSwitchMap.find(tensor_proto.data_type());
if (iter == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
return nullptr;
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
auto tensor_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
auto tensor_info = std::make_shared<abstract::AbstractTensor>(TypeIdToType(iter->second), tensor_shape);
if (tensor_proto.has_ref_key()) {
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, tensor_info);
return abs_ref;
}
return tensor_info;
}
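GetAbsTensorFromTensorProto is the import-side mirror of the exporter change: dims/min_dims/max_dims are reassembled into an abstract::Shape, and a ref_key promotes the result to an AbstractRef. A small sketch of the abstract it would produce for a dynamic-shape ref parameter, with values invented for illustration:

// Sketch only: the abstract built for a TensorProto with data_type FLOAT,
// dims {-1, 224}, min_dims {1, 224}, max_dims {32, 224} and ref_key "w1".
auto dyn_shape = std::make_shared<abstract::Shape>(
    ShapeVector{-1, 224}, ShapeVector{1, 224}, ShapeVector{32, 224});
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, dyn_shape);
auto ref_key = std::make_shared<RefKey>("w1");
auto abs_ref = std::make_shared<abstract::AbstractRef>(ref_key->ToAbstract(), abs_tensor);
// abs_ref is what BuildInputForFuncGraph / GetAbstractForNode now attach to the node.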
@@ -277,51 +379,34 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
node->set_debug_info(debug_info_ptr);
node->set_name(debug_info_name);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
if (tensor_info == nullptr) {
return false;
}
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
param_info->set_name(debug_info_name);
MS_LOG(DEBUG) << "Load parameter name: " << parameter_proto.name();
if (!IsIncLoad() || load_tensor_map_.find(parameter_proto.name()) == load_tensor_map_.end()) {
load_tensor_map_[parameter_proto.name()] = tensor_info;
} else {
if (IsIncLoad() && load_tensor_map_.find(parameter_proto.name()) != load_tensor_map_.end()) {
MS_LOG(DEBUG) << "Parameter: " << parameter_proto.name() << " has been already loaded, use it again.";
tensor::TensorPtr load_tensor_info = load_tensor_map_[parameter_proto.name()];
auto tensor_abstract = load_tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
auto load_tensor_info = load_tensor_map_[parameter_proto.name()];
MS_EXCEPTION_IF_NULL(load_tensor_info);
load_tensor_info->set_param_info(param_info);
node->set_default_param(load_tensor_info);
node->set_abstract(load_tensor_info->ToAbstract());
anfnode_build_map_[parameter_proto.name()] = node;
return true;
}
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
param_info->set_name(debug_info_name);
tensor_info->set_param_info(param_info);
auto tensor_abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(tensor_abstract);
node->set_abstract(tensor_abstract);
auto tensor = GenerateTensorPtrFromTensorProto(parameter_proto, false);
tensor->set_param_info(param_info);
if (parameter_proto.has_raw_data()) {
std::string initial_data = parameter_proto.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
MS_EXCEPTION_IF_NULL(tensor_data_buf);
auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), initial_data.data(),
initial_data.size());
if (ret != 0) {
MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
return false;
}
node->set_default_param(tensor_info);
node->set_default_param(tensor);
} else if (parameter_proto.has_external_data()) {
auto ret = GetTensorDataFromExternal(parameter_proto, tensor_info);
auto ret = GetTensorDataFromExternal(parameter_proto, tensor);
if (!ret) {
return false;
}
node->set_default_param(tensor_info);
node->set_default_param(tensor);
} else {
MS_LOG(DEBUG) << "Parameter will load initialized data.";
}
node->set_abstract(tensor->ToAbstract());
anfnode_build_map_[parameter_proto.name()] = node;
return true;
@@ -398,7 +483,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
// Set abstract of the parameter
if (value_proto.tensor_size() > 0) {
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
auto tensor_info = GetAbsTensorFromTensorProto(tensor_proto);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Get tensor_info fail.";
return false;
@@ -406,8 +491,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, tensor_info);
node->set_abstract(abs_ref);
auto parameters = top_graph_->parameters();
for (auto parameter : parameters) {
@@ -421,8 +505,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
}
}
} else {
auto tensor_abstract = tensor_info->ToAbstract();
node->set_abstract(tensor_abstract);
node->set_abstract(tensor_info);
}
} else if (value_proto.has_denotation()) {
if (value_proto.denotation() == "UMonadType") {
@@ -432,6 +515,12 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
}
MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
}
if (value_proto.has_attr_info()) {
if (!SetNodeAbstractFromAttrProto(value_proto.attr_info(), node)) {
MS_LOG(ERROR) << "Failed to get abstract from proto.";
return false;
}
}
anfnode_build_map_[value_proto.name()] = node;
return true;
}
@@ -589,19 +678,10 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
const mind_ir::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
const int attr_tensor_type = attr_tensor.data_type();
ShapeVector shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret =
memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(ERROR) << "Obtain CNode in TensorForm occur memcpy_s error.";
auto tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Failed to get attr[" << attr_proto.name() << "] for node " << prim->ToString()
<< " from the proto.";
return false;
}
prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
@@ -682,21 +762,10 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
const mind_ir::TensorProto &attr_tensor) {
const int attr_tensor_type = attr_tensor.data_type();
ShapeVector shape;
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
shape.push_back(attr_tensor.dims(i));
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
}
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
MS_EXCEPTION_IF_NULL(tensor_info);
const std::string &tensor_buf = attr_tensor.raw_data();
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
if (ret != 0) {
MS_LOG(ERROR) << "Obtain ValueNode in TensorForm occur memcpy_s error.";
return false;
}
auto new_value_node = NewValueNode(MakeValue(tensor_info));
MS_EXCEPTION_IF_NULL(new_value_node);
auto tensor_abstract = tensor_info->ToAbstract();
@@ -740,9 +809,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
return false;
}
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
new_value_node->set_abstract(abs_type);
auto value = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
auto new_value_node = NewValueNode(value);
new_value_node->set_abstract(value->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
@@ -872,29 +941,13 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node
return GetAttrValueForValueNode(value_node_name, attr_proto);
}
mindspore::HashMap<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
mindspore::HashMap<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForNode(
const mind_ir::AttributeProto &attr_proto) {
mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
ShapeVector shape_vec;
const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
shape_vec.push_back(attr_tensor.dims(j));
}
tensor::TensorPtr tensor_info =
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
MS_EXCEPTION_IF_NULL(tensor_info);
if (attr_tensor.has_ref_key()) {
auto ref_key = std::make_shared<RefKey>(attr_tensor.ref_key());
auto abs_ref_key = ref_key->ToAbstract();
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
(void)kv.emplace(attr_tensor.name(), abs_ref);
} else {
auto abstract = tensor_info->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
(void)kv.emplace(attr_tensor.name(), abstract);
}
auto tensor_info = GetAbsTensorFromTensorProto(attr_tensor);
(void)kv.emplace(attr_tensor.name(), tensor_info);
}
return kv;
}
@@ -946,26 +999,6 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
prim->set_instance_name(node_type);
}
}
MS_EXCEPTION_IF_NULL(prim);
auto prim_to_add_attr = prim;
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
prim_to_add_attr->set_attr("is_load", MakeValue(true));
}
}
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
// CNode abstract
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
continue;
}
if (!GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
return nullptr;
}
}
prim->set_attr("is_load", MakeValue(true));
return std::make_shared<ValueNode>(prim);
}
@@ -991,70 +1024,47 @@ bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) {
return false;
}
void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type) {
if (node_type == "UpdateState") {
cnode_ptr->set_abstract(kUMonad->ToAbstract());
} else if (node_type == "Depend") {
cnode_ptr->set_abstract(kBool->ToAbstract());
} else {
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
auto abs = cnode_ptr->input(index)->abstract();
if (abs != nullptr) {
if (abs->GetValueTrack() == nullptr) {
abs->set_value(kAnyValue);
bool MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr) {
auto primitive = GetCNodePrimitive(node_ptr);
if (primitive != nullptr) {
auto node_type = primitive->name();
if (node_type == "UpdateState") {
node_ptr->set_abstract(kUMonad->ToAbstract());
} else if (node_type == "Depend") {
node_ptr->set_abstract(kBool->ToAbstract());
} else {
auto cnode_ptr = node_ptr->cast<CNodePtr>();
AbstractBasePtrList elem;
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
auto abs = cnode_ptr->input(index)->abstract();
if (abs != nullptr) {
if (abs->GetValueTrack() == nullptr) {
abs->set_value(kAnyValue);
}
elem.push_back(abs);
}
elem.push_back(abs);
}
if (!elem.empty()) {
node_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
}
}
if (!elem.empty()) {
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
}
} else {
MS_LOG(ERROR) << "Failed to get the abstract of node:" << node_ptr->DebugString();
return false;
}
return true;
}
// Set CNode abstract.
void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
bool MSANFModelParser::SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr) {
if (CheckCNodePrim(cnode_ptr)) {
return;
return true;
}
mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
string shape_ref_attr_name;
bool is_tuple_or_list = false;
for (int i = 0; i < node_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
shape_ref_attr_name = attr_proto.ref_attr_name();
is_tuple_or_list =
shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos;
kv = GetAbstractForCNode(attr_proto);
break;
}
}
// Because there is not context in unit test,
// abstract->broaden() is replaced by abstract->set_value(kAnyValue).
const std::string &node_type = node_proto.op_type();
if (kv.size() == 0) {
SetEmptyTensorProtoCNodeAbstract(cnode_ptr, node_type);
} else if (kv.size() == 1 && !is_tuple_or_list) {
auto iter = kv.begin();
if (iter->second != nullptr) {
iter->second->set_value(kAnyValue);
cnode_ptr->set_abstract(iter->second);
}
} else {
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
if (abstract == nullptr) {
cnode_ptr->set_abstract(nullptr);
MS_LOG(ERROR) << "Node's attribute is nullptr.";
} else {
abstract->set_value(kAnyValue);
cnode_ptr->set_abstract(abstract);
}
if (!SetNodeAbstractFromAttrProto(attr_proto, cnode_ptr)) {
MS_LOG(ERROR) << "Failed to get CNode abstract from proto.";
return false;
}
return true;
}
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
@@ -1085,8 +1095,8 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(cnode_ptr);
SetCNodeAbstract(node_proto, cnode_ptr);
// Set Abstract and prim attr for CNode
SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
const std::string &fullname_with_scope = node_proto.domain();
string debug_info_name = ParseCNodeName(node_name);
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
@@ -1148,6 +1158,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
outputFuncGraph->set_return(return_node);
MS_LOG(DEBUG) << "Construct funcgraph finined, all success!";
}
return true;
}
@@ -1386,7 +1397,7 @@ AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
return nullptr;
}
FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
if (func_graph_ptr) {
if (func_graph_ptr != nullptr) {
return NewValueNode(func_graph_ptr);
} else {
return it->second;

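One detail of the new shared helper worth noting: need_load_data controls whether missing raw_data is an error. Attribute and value-node tensors must ship their payload, while parameters may receive data later from raw_data or external_data, so BuildParameterForFuncGraph passes false. The two call patterns, restated as a sketch that assumes a mind_ir::AttributeProto attr_proto and a mind_ir::TensorProto parameter_proto in scope:

// Attr / ValueNode tensors: payload is mandatory, so the default (true) is used
// and a missing raw_data yields nullptr plus an error log.
tensor::TensorPtr attr_tensor = GenerateTensorPtrFromTensorProto(attr_proto.tensors(0));
// Parameters: data may arrive afterwards via raw_data or external_data,
// so loading is skipped here and handled by the caller.
tensor::TensorPtr param_tensor = GenerateTensorPtrFromTensorProto(parameter_proto, false);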

@@ -63,7 +63,7 @@ class MSANFModelParser {
bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
abstract::AbstractTensorPtr GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
@@ -76,8 +76,10 @@ class MSANFModelParser {
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
bool CheckCNodePrim(CNodePtr cnode_ptr);
void SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type);
void SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
bool SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);
bool SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto, const AnfNodePtr &node_ptr);
void SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr);
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
@@ -85,9 +87,11 @@ class MSANFModelParser {
bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
bool little_endian() { return little_endian_; }
mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForNode(
const mind_ir::AttributeProto &attr_proto);
AnfNodePtr GetAnfNode(const std::string &node_name);
tensor::TensorPtr GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor,
bool need_load_data = true);
FuncGraphPtr top_graph_ = nullptr;
std::string producer_name_;


@@ -52,6 +52,7 @@ message ValueInfoProto {
repeated TensorProto tensor = 2;
optional string doc_string = 3;
optional string denotation = 4;
optional AttributeProto attr_info = 5; // graph input info for other types
}
@@ -150,6 +151,8 @@ message TensorProto {
repeated uint64 uint64_data = 11;
optional ExternalDataProto external_data = 12;
optional string ref_key = 13;
repeated int64 min_dims = 14;
repeated int64 max_dims = 15;
}
message ParallelProto {
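The two schema additions above back the C++ changes: attr_info lets a graph input carry the same "shape:"-style abstract attribute that CNodes already use (consumed by SetNodeAbstractFromAttrProto), and min_dims/max_dims record the bounds of dynamic dimensions. A hedged sketch of how the exporter would populate the new TensorProto fields through the protoc-generated C++ API; field values are invented, and the FLOAT enum constant is assumed from the existing DataType list:

// Sketch only: generated setters for the new fields (standard protoc naming
// for repeated int64 / optional string fields).
mind_ir::TensorProto t;
t.set_data_type(mind_ir::TensorProto_DataType_FLOAT);
t.add_dims(-1);   t.add_dims(224);           // -1: dynamic dimension
t.add_min_dims(1);  t.add_min_dims(224);     // new field 14: lower bounds
t.add_max_dims(32); t.add_max_dims(224);     // new field 15: upper bounds
t.set_ref_key("w1");                         // existing field 13, ties back to a Parameter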