forked from mindspore-Ecosystem/mindspore
refactor get abstract for mindir export and load
This commit is contained in:
parent
f9270b8e09
commit
281284d1b9
|
@ -164,10 +164,20 @@ void AclModelMulti::SetInputs() {
|
|||
for (const auto &in : inputs) {
|
||||
auto input_param = std::dynamic_pointer_cast<Parameter>(in);
|
||||
MS_EXCEPTION_IF_NULL(input_param);
|
||||
MS_EXCEPTION_IF_NULL(input_param->abstract());
|
||||
auto input_value = input_param->abstract()->GetValueTrack();
|
||||
auto tensor = input_value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto input_abs = input_param->abstract();
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
auto tensor_abs = input_abs->cast<abstract::AbstractTensorPtr>();
|
||||
if (tensor_abs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The graph input type is not a tensor. input args info:" << input_abs->ToString();
|
||||
}
|
||||
auto shape_ptr = tensor_abs->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_shape);
|
||||
auto elem = tensor_abs->element();
|
||||
MS_EXCEPTION_IF_NULL(elem);
|
||||
auto type_id = elem->BuildType()->type_id();
|
||||
auto tensor = std::make_shared<tensor::Tensor>(type_id, tensor_shape->shape());
|
||||
|
||||
std::vector<int64_t> shape = tensor->shape_c();
|
||||
auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast<DataType>(tensor->data_type_c()),
|
||||
|
|
|
@ -111,12 +111,11 @@ class IrExportBuilder {
|
|||
|
||||
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
||||
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const abstract::AbstractBasePtr &abstract,
|
||||
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string);
|
||||
bool SetShapeToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string);
|
||||
bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
|
||||
|
@ -399,30 +398,18 @@ bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueIn
|
|||
return true;
|
||||
}
|
||||
if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
|
||||
auto tensor = type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
auto elem_type = tensor->element();
|
||||
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
||||
mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
|
||||
auto data_type = GetMindirDataType(elem_type->type_id());
|
||||
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
||||
return false;
|
||||
}
|
||||
tensor_proto->set_data_type(data_type);
|
||||
if (dims.size() == 0) {
|
||||
MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
|
||||
} else {
|
||||
for (const auto &dim : dims) {
|
||||
MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
|
||||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
}
|
||||
if (!SetTensorProtoForRef(type, node->abstract(), tensor_proto)) {
|
||||
if (!SetTensorProto(node->abstract(), tensor_proto)) {
|
||||
return false;
|
||||
}
|
||||
} else if (type->isa<Tuple>()) {
|
||||
auto tup_shape = shape->cast<abstract::TupleShapePtr>();
|
||||
value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
|
||||
std::string seq_string = "shape:";
|
||||
mind_ir::AttributeProto *attribute = value_proto->mutable_attr_info();
|
||||
if (!SetShapeToNodeProto(node->abstract(), attribute, &seq_string)) {
|
||||
MS_LOG(ERROR) << "Set shape to Proto for " << node->DebugString() << " failed.";
|
||||
return false;
|
||||
}
|
||||
attribute->set_ref_attr_name(seq_string);
|
||||
} else {
|
||||
value_proto->set_denotation(type->type_name());
|
||||
}
|
||||
|
@ -454,14 +441,16 @@ bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
|
||||
mind_ir::TensorProto *const tensor_proto) {
|
||||
bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
|
||||
auto type = abstract->BuildType();
|
||||
auto shape = abstract->BuildShape();
|
||||
if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
|
||||
MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
|
||||
return false;
|
||||
}
|
||||
auto tensor = type->cast<TensorTypePtr>();
|
||||
const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
|
||||
auto tensor_shape = shape->cast<abstract::ShapePtr>();
|
||||
const auto &dims = tensor_shape->shape();
|
||||
auto data_type = GetMindirDataType(tensor->element()->type_id());
|
||||
if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
|
||||
return false;
|
||||
|
@ -470,22 +459,29 @@ bool IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &sh
|
|||
for (const auto &dim : dims) {
|
||||
tensor_proto->add_dims(dim);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetTensorProtoForRef(const TypePtr &type, const AbstractBasePtr &abs,
|
||||
mind_ir::TensorProto *const tensor_proto) {
|
||||
if (tensor_shape->IsDynamic()) {
|
||||
auto min_shape = tensor_shape->min_shape();
|
||||
auto max_shape = tensor_shape->max_shape();
|
||||
for (auto item : min_shape) {
|
||||
tensor_proto->add_min_dims(item);
|
||||
}
|
||||
for (auto item : max_shape) {
|
||||
tensor_proto->add_max_dims(item);
|
||||
}
|
||||
}
|
||||
// Deal Ref
|
||||
if (!type->isa<RefType>()) {
|
||||
return true;
|
||||
}
|
||||
auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
|
||||
|
||||
auto abs_ref = abstract->cast<abstract::AbstractRefPtr>();
|
||||
if (abs_ref == nullptr) {
|
||||
MS_LOG(ERROR) << "The abstract " << abs->ToString() << " should be AbstractRef.";
|
||||
MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRef.";
|
||||
return false;
|
||||
}
|
||||
auto ref_key_value = abs_ref->ref_key_value();
|
||||
if (ref_key_value == nullptr) {
|
||||
MS_LOG(INFO) << "The ref_key_value of abstract ref " << abs->ToString() << " is nullptr";
|
||||
MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr";
|
||||
return true;
|
||||
}
|
||||
tensor_proto->set_ref_key(ref_key_value->name());
|
||||
|
@ -497,7 +493,7 @@ bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::
|
|||
MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
|
||||
}
|
||||
MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
|
||||
return SetTensorProto(param->Type(), param->Shape(), tensor_proto);
|
||||
return SetTensorProto(param->abstract(), tensor_proto);
|
||||
}
|
||||
|
||||
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
|
@ -569,18 +565,16 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
|
|||
return type_name;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, const AbstractBasePtr &abs,
|
||||
mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
bool IrExportBuilder::SetShapeToNodeProto(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
auto type = abs->BuildType();
|
||||
auto shape = abs->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(seq_string);
|
||||
if (type->isa<Tuple>()) {
|
||||
*seq_string += "Tuple[";
|
||||
auto elements = type->cast<TuplePtr>()->elements();
|
||||
auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
|
||||
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>()->elements();
|
||||
for (size_t i = 0; i < elements.size(); i++) {
|
||||
if (!SetShapeToNodeProto(elements[i], tuple_shape[i], tuple_abs[i], attr_proto, seq_string)) {
|
||||
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
|
||||
for (size_t i = 0; i < tuple_abs->size(); i++) {
|
||||
if (!SetShapeToNodeProto((*tuple_abs)[i], attr_proto, seq_string)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -590,7 +584,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|||
*seq_string += shape_name + ",";
|
||||
mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
|
||||
tensor_proto->set_name(shape_name);
|
||||
return SetTensorProto(type, shape, tensor_proto) && SetTensorProtoForRef(type, abs, tensor_proto);
|
||||
return SetTensorProto(abs, tensor_proto);
|
||||
} else if (type->isa<Number>()) {
|
||||
if (type->isa<Bool>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
||||
|
@ -608,9 +602,10 @@ bool IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt
|
|||
} else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
|
||||
*seq_string += type->type_name() + ",";
|
||||
} else if (type->isa<CSRTensorType>()) {
|
||||
auto csr_tensor_type = type->cast<CSRTensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor_type);
|
||||
if (!SetShapeToNodeProto(csr_tensor_type->element(), shape, abs, attr_proto, seq_string)) return false;
|
||||
auto cst_tensor_abs = abs->cast<abstract::AbstractCSRTensorPtr>();
|
||||
if (!SetShapeToNodeProto(cst_tensor_abs->element(), attr_proto, seq_string)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
|
||||
return false;
|
||||
|
@ -634,7 +629,7 @@ bool IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodePro
|
|||
ResetTupleIndex();
|
||||
std::string seq_string = "shape:";
|
||||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
if (!SetShapeToNodeProto(type, shape, abs, attr_proto, &seq_string)) {
|
||||
if (!SetShapeToNodeProto(abs, attr_proto, &seq_string)) {
|
||||
MS_LOG(ERROR) << "Set shape to NodeProto for " << node->DebugString() << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -241,11 +241,107 @@ string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
|
|||
return "";
|
||||
}
|
||||
} // namespace
|
||||
tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor,
|
||||
bool need_load_data) {
|
||||
ShapeVector shape;
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
|
||||
tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto) {
|
||||
if (!IsIncLoad() || load_tensor_map_.find(attr_tensor.name()) == load_tensor_map_.end()) {
|
||||
load_tensor_map_[attr_tensor.name()] = tensor;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
if (attr_tensor.has_raw_data()) {
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
|
||||
return nullptr;
|
||||
}
|
||||
} else if (need_load_data) {
|
||||
MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
|
||||
return nullptr;
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto,
|
||||
const AnfNodePtr &node_ptr) {
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
|
||||
string shape_ref_attr_name;
|
||||
if (attr_proto.ref_attr_name().find("shape:") == string::npos) {
|
||||
MS_LOG(ERROR) << "Cannot use a attr_proto " << attr_proto.ref_attr_name() << " to init shape.";
|
||||
return false;
|
||||
}
|
||||
bool is_tuple_or_list = false;
|
||||
|
||||
shape_ref_attr_name = attr_proto.ref_attr_name();
|
||||
is_tuple_or_list =
|
||||
shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos;
|
||||
kv = GetAbstractForNode(attr_proto);
|
||||
if (kv.empty()) {
|
||||
return SetEmptyTensorProtoCNodeAbstract(node_ptr);
|
||||
} else if (!is_tuple_or_list) {
|
||||
auto iter = kv.begin();
|
||||
if (iter->second != nullptr) {
|
||||
node_ptr->set_abstract(iter->second);
|
||||
}
|
||||
} else {
|
||||
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
|
||||
node_ptr->set_abstract(abstract);
|
||||
if (abstract == nullptr) {
|
||||
MS_LOG(ERROR) << "Node's attribute is nullptr.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
|
||||
auto prim = GetCNodePrimitive(cnode_ptr);
|
||||
auto prim_to_add_attr = prim;
|
||||
if (prim_to_add_attr != nullptr) {
|
||||
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
|
||||
if (func != nullptr && func->isa<Primitive>()) {
|
||||
prim_to_add_attr = func->cast<PrimitivePtr>();
|
||||
}
|
||||
}
|
||||
prim_to_add_attr->set_attr("is_load", MakeValue(true));
|
||||
}
|
||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||
// CNode abstract
|
||||
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||
SetCNodeAbstract(attr_proto, cnode_ptr);
|
||||
continue;
|
||||
}
|
||||
if (prim_to_add_attr != nullptr) {
|
||||
if (!GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
|
||||
MS_LOG(ERROR) << "Parser prim: " << prim->ToString() << " attributes error : " << attr_proto.DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract::AbstractTensorPtr MSANFModelParser::GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto) {
|
||||
ShapeVector shape;
|
||||
for (int i = 0; i < tensor_proto.dims_size(); ++i) {
|
||||
shape.push_back(tensor_proto.dims(i));
|
||||
shape.emplace_back(tensor_proto.dims(i));
|
||||
}
|
||||
ShapeVector min_shape;
|
||||
for (int i = 0; i < tensor_proto.min_dims_size(); ++i) {
|
||||
min_shape.emplace_back(tensor_proto.min_dims(i));
|
||||
}
|
||||
|
||||
ShapeVector max_shape;
|
||||
for (int i = 0; i < tensor_proto.max_dims_size(); ++i) {
|
||||
max_shape.emplace_back(tensor_proto.min_dims(i));
|
||||
}
|
||||
|
||||
if (!tensor_proto.has_data_type()) {
|
||||
|
@ -253,14 +349,20 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
|
|||
MS_LOG(ERROR) << "mind_ir TensorProto has no data_type.";
|
||||
return nullptr;
|
||||
}
|
||||
if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
|
||||
auto iter = kDefaultValueSwitchMap.find(tensor_proto.data_type());
|
||||
if (iter == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "mind_ir build tensor: " << tensor_proto.name() << " failed";
|
||||
MS_LOG(ERROR) << "mind_ir TensorProto data_type: " << tensor_proto.data_type() << " is not support yet!";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor::TensorPtr tensor_info =
|
||||
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
|
||||
auto tensor_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
|
||||
auto tensor_info = std::make_shared<abstract::AbstractTensor>(TypeIdToType(iter->second), tensor_shape);
|
||||
if (tensor_proto.has_ref_key()) {
|
||||
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, tensor_info);
|
||||
return abs_ref;
|
||||
}
|
||||
return tensor_info;
|
||||
}
|
||||
|
||||
|
@ -277,51 +379,34 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
|||
node->set_debug_info(debug_info_ptr);
|
||||
node->set_name(debug_info_name);
|
||||
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
|
||||
if (tensor_info == nullptr) {
|
||||
return false;
|
||||
}
|
||||
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
|
||||
param_info->set_name(debug_info_name);
|
||||
|
||||
MS_LOG(DEBUG) << "Load parameter name: " << parameter_proto.name();
|
||||
if (!IsIncLoad() || load_tensor_map_.find(parameter_proto.name()) == load_tensor_map_.end()) {
|
||||
load_tensor_map_[parameter_proto.name()] = tensor_info;
|
||||
} else {
|
||||
if (IsIncLoad() && load_tensor_map_.find(parameter_proto.name()) != load_tensor_map_.end()) {
|
||||
MS_LOG(DEBUG) << "Parameter: " << parameter_proto.name() << " has been already loaded, use it again.";
|
||||
tensor::TensorPtr load_tensor_info = load_tensor_map_[parameter_proto.name()];
|
||||
auto tensor_abstract = load_tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
node->set_abstract(tensor_abstract);
|
||||
auto load_tensor_info = load_tensor_map_[parameter_proto.name()];
|
||||
MS_EXCEPTION_IF_NULL(load_tensor_info);
|
||||
load_tensor_info->set_param_info(param_info);
|
||||
node->set_default_param(load_tensor_info);
|
||||
node->set_abstract(load_tensor_info->ToAbstract());
|
||||
anfnode_build_map_[parameter_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
|
||||
param_info->set_name(debug_info_name);
|
||||
tensor_info->set_param_info(param_info);
|
||||
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
node->set_abstract(tensor_abstract);
|
||||
|
||||
auto tensor = GenerateTensorPtrFromTensorProto(parameter_proto, false);
|
||||
tensor->set_param_info(param_info);
|
||||
if (parameter_proto.has_raw_data()) {
|
||||
std::string initial_data = parameter_proto.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_buf);
|
||||
auto ret = memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), initial_data.data(),
|
||||
initial_data.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Build parameter occur memcpy_s error.";
|
||||
return false;
|
||||
}
|
||||
node->set_default_param(tensor_info);
|
||||
node->set_default_param(tensor);
|
||||
} else if (parameter_proto.has_external_data()) {
|
||||
auto ret = GetTensorDataFromExternal(parameter_proto, tensor_info);
|
||||
auto ret = GetTensorDataFromExternal(parameter_proto, tensor);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
node->set_default_param(tensor_info);
|
||||
node->set_default_param(tensor);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Parameter will load initialized data.";
|
||||
}
|
||||
node->set_abstract(tensor->ToAbstract());
|
||||
|
||||
anfnode_build_map_[parameter_proto.name()] = node;
|
||||
return true;
|
||||
|
@ -398,7 +483,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
// Set abstract of the parameter
|
||||
if (value_proto.tensor_size() > 0) {
|
||||
const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
|
||||
tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
|
||||
auto tensor_info = GetAbsTensorFromTensorProto(tensor_proto);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor_info fail.";
|
||||
return false;
|
||||
|
@ -406,8 +491,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
if (tensor_proto.has_ref_key() && top_graph_ != nullptr) {
|
||||
auto ref_key = std::make_shared<RefKey>(tensor_proto.ref_key());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, tensor_info);
|
||||
node->set_abstract(abs_ref);
|
||||
auto parameters = top_graph_->parameters();
|
||||
for (auto parameter : parameters) {
|
||||
|
@ -421,8 +505,7 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
}
|
||||
}
|
||||
} else {
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
node->set_abstract(tensor_abstract);
|
||||
node->set_abstract(tensor_info);
|
||||
}
|
||||
} else if (value_proto.has_denotation()) {
|
||||
if (value_proto.denotation() == "UMonadType") {
|
||||
|
@ -432,6 +515,12 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
|
|||
}
|
||||
MS_LOG(DEBUG) << "Not tensor. parameter type: " << value_proto.denotation();
|
||||
}
|
||||
if (value_proto.has_attr_info()) {
|
||||
if (!SetNodeAbstractFromAttrProto(value_proto.attr_info(), node)) {
|
||||
MS_LOG(ERROR) << "Failed to get abstract from proto.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
anfnode_build_map_[value_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
|
@ -589,19 +678,10 @@ bool MSANFModelParser::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim,
|
|||
const mind_ir::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const mind_ir::TensorProto attr_tensor = attr_proto.tensors(0);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ShapeVector shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret =
|
||||
memcpy_s(tensor_data_buf, static_cast<size_t>(tensor_info->data().nbytes()), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Obtain CNode in TensorForm occur memcpy_s error.";
|
||||
auto tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get attr[" << attr_proto.name() << "] for node " << prim->ToString()
|
||||
<< " from the proto.";
|
||||
return false;
|
||||
}
|
||||
prim->AddAttr(attr_proto.name(), MakeValue(tensor_info));
|
||||
|
@ -682,21 +762,10 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
|
|||
|
||||
bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name,
|
||||
const mind_ir::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ShapeVector shape;
|
||||
for (int i = 0; i < attr_tensor.dims_size(); ++i) {
|
||||
shape.push_back(attr_tensor.dims(i));
|
||||
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(attr_tensor);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode in TensorForm occur memcpy_s error.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_info));
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
auto tensor_abstract = tensor_info->ToAbstract();
|
||||
|
@ -740,9 +809,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n
|
|||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return false;
|
||||
}
|
||||
auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]));
|
||||
abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>());
|
||||
new_value_node->set_abstract(abs_type);
|
||||
auto value = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
|
||||
auto new_value_node = NewValueNode(value);
|
||||
new_value_node->set_abstract(value->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
@ -872,29 +941,13 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node
|
|||
return GetAttrValueForValueNode(value_node_name, attr_proto);
|
||||
}
|
||||
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForCNode(
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> MSANFModelParser::GetAbstractForNode(
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
|
||||
ShapeVector shape_vec;
|
||||
const mind_ir::TensorProto &attr_tensor = attr_proto.tensors(i);
|
||||
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
|
||||
shape_vec.push_back(attr_tensor.dims(j));
|
||||
}
|
||||
tensor::TensorPtr tensor_info =
|
||||
std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec);
|
||||
MS_EXCEPTION_IF_NULL(tensor_info);
|
||||
if (attr_tensor.has_ref_key()) {
|
||||
auto ref_key = std::make_shared<RefKey>(attr_tensor.ref_key());
|
||||
auto abs_ref_key = ref_key->ToAbstract();
|
||||
auto abs_value = tensor_info->ToAbstract()->Broaden()->cast<abstract::AbstractTensorPtr>();
|
||||
auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
|
||||
(void)kv.emplace(attr_tensor.name(), abs_ref);
|
||||
} else {
|
||||
auto abstract = tensor_info->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
(void)kv.emplace(attr_tensor.name(), abstract);
|
||||
}
|
||||
auto tensor_info = GetAbsTensorFromTensorProto(attr_tensor);
|
||||
(void)kv.emplace(attr_tensor.name(), tensor_info);
|
||||
}
|
||||
return kv;
|
||||
}
|
||||
|
@ -946,26 +999,6 @@ AnfNodePtr MSANFModelParser::BuildOperatorNode(const mind_ir::NodeProto &node_pr
|
|||
prim->set_instance_name(node_type);
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_to_add_attr = prim;
|
||||
if (prim->isa<prim::DoSignaturePrimitive>()) {
|
||||
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
|
||||
if (func != nullptr && func->isa<Primitive>()) {
|
||||
prim_to_add_attr = func->cast<PrimitivePtr>();
|
||||
prim_to_add_attr->set_attr("is_load", MakeValue(true));
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||
// CNode abstract
|
||||
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||
continue;
|
||||
}
|
||||
if (!GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
|
||||
MS_LOG(ERROR) << "Parser prim: " << node_type << " attributes error : " << attr_proto.DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
prim->set_attr("is_load", MakeValue(true));
|
||||
return std::make_shared<ValueNode>(prim);
|
||||
}
|
||||
|
@ -991,70 +1024,47 @@ bool MSANFModelParser::CheckCNodePrim(CNodePtr cnode_ptr) {
|
|||
return false;
|
||||
}
|
||||
|
||||
void MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type) {
|
||||
if (node_type == "UpdateState") {
|
||||
cnode_ptr->set_abstract(kUMonad->ToAbstract());
|
||||
} else if (node_type == "Depend") {
|
||||
cnode_ptr->set_abstract(kBool->ToAbstract());
|
||||
} else {
|
||||
AbstractBasePtrList elem;
|
||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||
auto abs = cnode_ptr->input(index)->abstract();
|
||||
if (abs != nullptr) {
|
||||
if (abs->GetValueTrack() == nullptr) {
|
||||
abs->set_value(kAnyValue);
|
||||
bool MSANFModelParser::SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr) {
|
||||
auto primitive = GetCNodePrimitive(node_ptr);
|
||||
if (primitive != nullptr) {
|
||||
auto node_type = primitive->name();
|
||||
if (node_type == "UpdateState") {
|
||||
node_ptr->set_abstract(kUMonad->ToAbstract());
|
||||
} else if (node_type == "Depend") {
|
||||
node_ptr->set_abstract(kBool->ToAbstract());
|
||||
} else {
|
||||
auto cnode_ptr = node_ptr->cast<CNodePtr>();
|
||||
AbstractBasePtrList elem;
|
||||
for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) {
|
||||
auto abs = cnode_ptr->input(index)->abstract();
|
||||
if (abs != nullptr) {
|
||||
if (abs->GetValueTrack() == nullptr) {
|
||||
abs->set_value(kAnyValue);
|
||||
}
|
||||
elem.push_back(abs);
|
||||
}
|
||||
elem.push_back(abs);
|
||||
}
|
||||
if (!elem.empty()) {
|
||||
node_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
}
|
||||
}
|
||||
if (!elem.empty()) {
|
||||
cnode_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Failed to get the abstract of node:" << node_ptr->DebugString();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Set CNode abstract.
|
||||
void MSANFModelParser::SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr) {
|
||||
bool MSANFModelParser::SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr) {
|
||||
if (CheckCNodePrim(cnode_ptr)) {
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> kv;
|
||||
string shape_ref_attr_name;
|
||||
|
||||
bool is_tuple_or_list = false;
|
||||
for (int i = 0; i < node_proto.attribute_size(); ++i) {
|
||||
const mind_ir::AttributeProto &attr_proto = node_proto.attribute(i);
|
||||
if (attr_proto.ref_attr_name().find("shape:") != string::npos) {
|
||||
shape_ref_attr_name = attr_proto.ref_attr_name();
|
||||
is_tuple_or_list =
|
||||
shape_ref_attr_name.find("Tuple[") != string::npos || shape_ref_attr_name.find("List[") != string::npos;
|
||||
kv = GetAbstractForCNode(attr_proto);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Because there is not context in unit test,
|
||||
// abstract->broaden() is replaced by abstract->set_value(kAnyValue).
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
if (kv.size() == 0) {
|
||||
SetEmptyTensorProtoCNodeAbstract(cnode_ptr, node_type);
|
||||
} else if (kv.size() == 1 && !is_tuple_or_list) {
|
||||
auto iter = kv.begin();
|
||||
if (iter->second != nullptr) {
|
||||
iter->second->set_value(kAnyValue);
|
||||
cnode_ptr->set_abstract(iter->second);
|
||||
}
|
||||
} else {
|
||||
auto abstract = ParserAttrShape(shape_ref_attr_name, kv);
|
||||
if (abstract == nullptr) {
|
||||
cnode_ptr->set_abstract(nullptr);
|
||||
MS_LOG(ERROR) << "Node's attribute is nullptr.";
|
||||
} else {
|
||||
abstract->set_value(kAnyValue);
|
||||
cnode_ptr->set_abstract(abstract);
|
||||
}
|
||||
if (!SetNodeAbstractFromAttrProto(attr_proto, cnode_ptr)) {
|
||||
MS_LOG(ERROR) << "Failed to get CNode abstract from proto.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
|
@ -1085,8 +1095,8 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
|
|||
|
||||
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
SetCNodeAbstract(node_proto, cnode_ptr);
|
||||
|
||||
// Set Abstract and prim attr for CNode
|
||||
SetCNodePrimAttrAndAbstract(node_proto, cnode_ptr);
|
||||
const std::string &fullname_with_scope = node_proto.domain();
|
||||
string debug_info_name = ParseCNodeName(node_name);
|
||||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||
|
@ -1148,6 +1158,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
|
|||
outputFuncGraph->set_return(return_node);
|
||||
MS_LOG(DEBUG) << "Construct funcgraph finined, all success!";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1386,7 +1397,7 @@ AnfNodePtr MSANFModelParser::GetAnfNode(const std::string &node_name) {
|
|||
return nullptr;
|
||||
}
|
||||
FuncGraphPtr func_graph_ptr = GetValueNode<FuncGraphPtr>(it->second);
|
||||
if (func_graph_ptr) {
|
||||
if (func_graph_ptr != nullptr) {
|
||||
return NewValueNode(func_graph_ptr);
|
||||
} else {
|
||||
return it->second;
|
||||
|
|
|
@ -63,7 +63,7 @@ class MSANFModelParser {
|
|||
bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
|
||||
bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
|
||||
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
|
||||
tensor::TensorPtr BuildTensorInfoForFuncGraph(const mind_ir::TensorProto &tensor_proto);
|
||||
abstract::AbstractTensorPtr GetAbsTensorFromTensorProto(const mind_ir::TensorProto &tensor_proto);
|
||||
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::NodeProto &node_proto);
|
||||
bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool GetAttrValueForCNode(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
|
@ -76,8 +76,10 @@ class MSANFModelParser {
|
|||
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
|
||||
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
|
||||
bool CheckCNodePrim(CNodePtr cnode_ptr);
|
||||
void SetEmptyTensorProtoCNodeAbstract(CNodePtr cnode_ptr, const std::string &node_type);
|
||||
void SetCNodeAbstract(const mind_ir::NodeProto &node_proto, CNodePtr cnode_ptr);
|
||||
bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
|
||||
bool SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);
|
||||
bool SetNodeAbstractFromAttrProto(const mind_ir::AttributeProto &attr_proto, const AnfNodePtr &node_ptr);
|
||||
void SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
|
||||
|
@ -85,9 +87,11 @@ class MSANFModelParser {
|
|||
bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
bool little_endian() { return little_endian_; }
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForCNode(
|
||||
mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForNode(
|
||||
const mind_ir::AttributeProto &attr_proto);
|
||||
AnfNodePtr GetAnfNode(const std::string &node_name);
|
||||
tensor::TensorPtr GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor,
|
||||
bool need_load_data = true);
|
||||
|
||||
FuncGraphPtr top_graph_ = nullptr;
|
||||
std::string producer_name_;
|
||||
|
|
|
@ -52,6 +52,7 @@ message ValueInfoProto {
|
|||
repeated TensorProto tensor = 2;
|
||||
optional string doc_string = 3;
|
||||
optional string denotation = 4;
|
||||
optional AttributeProto attr_info = 5; // graph input info for other type
|
||||
}
|
||||
|
||||
|
||||
|
@ -150,6 +151,8 @@ message TensorProto {
|
|||
repeated uint64 uint64_data = 11;
|
||||
optional ExternalDataProto external_data = 12;
|
||||
optional string ref_key = 13;
|
||||
repeated int64 min_dims = 14;
|
||||
repeated int64 max_dims = 15;
|
||||
}
|
||||
|
||||
message ParallelProto {
|
||||
|
|
Loading…
Reference in New Issue