!31522 mindir:support csrtensor export

Merge pull request !31522 from lanzhineng/func_closure
This commit is contained in:
i-robot 2022-03-21 09:26:29 +00:00 committed by Gitee
commit cfab60eca7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 17 additions and 2 deletions

View File

@ -118,6 +118,7 @@ class IrExportBuilder {
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetAbstractToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetAbstractToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
@ -529,6 +530,17 @@ bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
return true;
}
bool IrExportBuilder::SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
abstract::AbstractCSRTensorPtr csr_tensor_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
MS_EXCEPTION_IF_NULL(csr_tensor_abs);
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_CSR_TENSOR);
(void)SetTensorProto(csr_tensor_abs->indptr(), attr_proto->add_tensors());
(void)SetTensorProto(csr_tensor_abs->indices(), attr_proto->add_tensors());
(void)SetTensorProto(csr_tensor_abs->values(), attr_proto->add_tensors());
auto dense_proto = attr_proto->add_values();
return SetAbstractToNodeProto(csr_tensor_abs->dense_shape(), dense_proto);
}
bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
auto type = abstract->BuildType();
auto shape = abstract->BuildShape();
@ -691,7 +703,7 @@ bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_IOMONAD);
} else if (type->isa<CSRTensorType>()) {
auto csr_tensor_abs = abs->cast<abstract::AbstractCSRTensorPtr>();
if (!SetAbstractToNodeProto(csr_tensor_abs->element(), attr_proto)) {
if (!SetCSRTensorToProto(csr_tensor_abs, attr_proto)) {
return false;
}
} else {

View File

@ -311,7 +311,7 @@ abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType
return BuildAbstractFunction(attr_proto);
}
default: {
MS_LOG(ERROR) << "Not support to get the abstract from AttrProto type: " << attr_proto.type();
MS_LOG(INFO) << "Not support to get the abstract from AttrProto type: " << attr_proto.type();
return nullptr;
}
}

View File

@ -38,6 +38,9 @@ message AttributeProto {
FUNCGRAPHCLOSURE = 27;
PARTIALCLOSURE = 28;
UNIONFUNCCLOSURE = 29;
CSR_TENSOR= 30;
COO_TENSOR= 31;
ROW_TENSOR= 32;
}
optional string name = 1;
optional float f = 2;