forked from mindspore-Ecosystem/mindspore
!25734 Enable ref type in mindir export and loading
Merge pull request !25734 from LiangZhibo/mindir_ref
This commit is contained in:
commit
2cd87ddf79
|
@ -107,10 +107,11 @@ class IrExportBuilder {
|
|||
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
||||
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract,
|
||||
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string);
|
||||
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
|
@ -349,6 +350,9 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
}
|
||||
if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) {
|
||||
return false;
|
||||
}
|
||||
} else if (type->isa<Tuple>()) {
|
||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
|
@ -402,6 +406,25 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs,
|
||||
mind_ir::TensorProto *const tensor_proto) {
|
||||
if (!type->isa<RefType>()) {
|
||||
return true;
|
||||
}
|
||||
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
|
||||
if (abs_ref == nullptr) {
|
||||
MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef.";
|
||||
return false;
|
||||
}
|
||||
auto ref_key_value = abs_ref->ref_key_value();
|
||||
if (ref_key_value == nullptr) {
|
||||
MS_LOG(ERROR) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr";
|
||||
return false;
|
||||
}
|
||||
tensor_proto->set_ref_key(ref_key_value->name());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) {
|
||||
if (param == nullptr || tensor_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
|
||||
|
@ -479,7 +502,7 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|||
return type_name;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs,
|
||||
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
|
@ -488,8 +511,9 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|||
*seq_string += "Tuple[";
|
||||
auto elements = type->cast<TuplePtr>()->elements();
|
||||
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
|
||||
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string)) {
|
||||
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -499,7 +523,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|||
*seq_string += shape_name + ",";
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name(shape_name);
|
||||
return SetTensorProto(type, shape, tensor_proto);
|
||||
return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto);
|
||||
} else if (type->isa<Number>()) {
|
||||
if (type->isa<Bool>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
||||
|
@ -531,6 +555,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto type = node->Type();
|
||||
auto shape = node->Shape();
|
||||
auto abs = node->abstract();
|
||||
// For the bprop fg which has not been renormalized.
|
||||
if (type == nullptr || shape == nullptr) {
|
||||
return true;
|
||||
|
@ -538,7 +563,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
|
|||
ResetTupleIndex();
|
||||
std::string seq_string = "shape:";
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
if (!SetShapeToNodeProto(type, shape, attr_proto, &seq_string)) {
|
||||
if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) {
|
||||
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -394,11 +394,36 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor_info fail.";
|
||||
return false;
|
||||
}
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
node->set_abstract(tensor_abstract);
|
||||
if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
|
||||
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||
node->set_abstract(abs_ref);
|
||||
auto parameters = top_graph_->parameters();
|
||||
for (auto parameter : parameters) {
|
||||
auto parameter_abs = parameter->abstract();
|
||||
if (parameter_abs->isa<abstract::AbstractRef>()) {
|
||||
auto parameter_abs_value = parameter_abs->cast<abstract::AbstractRefPtr>()->ref_key_value();
|
||||
if (parameter_abs_value->name() == tensor_proto.ref_key()) {
|
||||
node->set_default_param(parameter->cast<ParameterPtr>()->default_param());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
node->set_abstract(tensor_abstract);
|
||||
}
|
||||
} else if (value_proto.has_denotation()) {
|
||||
if (value_proto.denotation() == "UMonadType") {
|
||||
node->set_abstract(kUMonad->ToAbstract());
|
||||
} else if (value_proto.denotation() == "IOMonadType") {
|
||||
node->set_abstract(kIOMonad->ToAbstract());
|
||||
}
|
||||
MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
|
||||
}
|
||||
anfnode_build_map_[value_proto.name()] = node;
|
||||
|
@ -849,9 +874,17 @@ std::unordered_map<std::string, abstract::AbstractBasePtr> MSANFModelParser::Get
|
|||
tensor::TensorPtr tensor_info =
|
||||
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
auto abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
|
||||
if (attr_tensor.has_ref_key()) {
|
||||
auto ref_key = std::make_shared<RefKey>(attr_tensor.ref_key());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abs_ref));
|
||||
} else {
|
||||
auto abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
kv.insert(std::pair<string, abstract::AbstractBasePtr>(attr_tensor.name(), abstract));
|
||||
}
|
||||
}
|
||||
return kv;
|
||||
}
|
||||
|
@ -1198,6 +1231,7 @@ FuncGraphPtr MSANFModelParser::Parse(const mind_ir::ModelProto &model_proto) {
|
|||
return nullptr;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Parse pb to build FuncGraph Success! " << graphBuild.name();
|
||||
top_graph_ = dstGraph;
|
||||
for (int i = 0; i < model_proto.functions_size(); ++i) {
|
||||
const auto &graph_proto = model_proto.functions(i);
|
||||
FuncGraphPtr graph = GetValueNode<FuncGraphPtr>(anfnode_build_map_[graph_proto.name()]);
|
||||
|
|
|
@ -81,6 +81,7 @@ class MSANFModelParser {
|
|||
const mind_ir::AttributeProto &attr_proto);
|
||||
AnfNodePtr GetAnfNode(const std::string &node_name);
|
||||
|
||||
FuncGraphPtr top_graph_ = nullptr;
|
||||
std::string producer_name_;
|
||||
std::string model_version_;
|
||||
std::string ir_version_;
|
||||
|
|
|
@ -133,6 +133,7 @@ message TensorProto {
|
|||
repeated double double_data = 10;
|
||||
repeated uint64 uint64_data = 11;
|
||||
optional ExternalDataProto external_data = 12;
|
||||
optional string ref_key = 13;
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue