forked from mindspore-Ecosystem/mindspore
!16131 add export mindir for tuple tensor
From: @changzherui Reviewed-by: @kingxian,@zh_qh Signed-off-by: @kingxian
This commit is contained in:
commit
126abb728a
|
@ -519,7 +519,7 @@ def set_context(**kwargs):
|
|||
Set context for running environment.
|
||||
|
||||
Context should be configured before running your program. If there is no configuration,
|
||||
the "Ascend" device target will be used by default. GRAPH_MODE or
|
||||
it will automatic acquisition according to device target by default. GRAPH_MODE or
|
||||
PYNATIVE_MODE can be set by `mode` attribute and both modes support all backends, default
|
||||
mode is PYNATIVE_MODE.
|
||||
|
||||
|
@ -553,7 +553,7 @@ def set_context(**kwargs):
|
|||
|
||||
Args:
|
||||
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE(1).
|
||||
device_target (str): The target device to run, support "Ascend", "GPU", and "CPU". Default: "Ascend".
|
||||
device_target (str): The target device to run, support "Ascend", "GPU", and "CPU".
|
||||
device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
|
||||
while device_num_per_host should be no more than 4096. Default: 0.
|
||||
save_graphs (bool): Whether to save graphs. Default: False.
|
||||
|
@ -685,6 +685,7 @@ def set_context(**kwargs):
|
|||
def get_context(attr_key):
|
||||
"""
|
||||
Get context attribute value according to the input key.
|
||||
If some attribute are not set, it will be automatically obtained.
|
||||
|
||||
Args:
|
||||
attr_key (str): The key of the attribute.
|
||||
|
|
|
@ -555,6 +555,27 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInTupleTensorForm(const std::string &value_node_name,
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
std::vector<tensor::TensorPtr> tensor_vec;
|
||||
for (int i = 0; i < attr_proto.tensors_size(); ++i) {
|
||||
mind_ir::TensorProto attr_tensor = attr_proto.tensors(i);
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
ShapeVector shape;
|
||||
for (int j = 0; j < attr_tensor.dims_size(); ++j) {
|
||||
shape.push_back(attr_tensor.dims(j));
|
||||
}
|
||||
tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape);
|
||||
const std::string &tensor_buf = attr_tensor.raw_data();
|
||||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->data_c());
|
||||
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
tensor_vec.push_back(tensor_info);
|
||||
}
|
||||
auto new_value_node = NewValueNode(MakeValue(tensor_vec));
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_name,
|
||||
const mind_ir::TensorProto &attr_tensor) {
|
||||
const int attr_tensor_type = attr_tensor.data_type();
|
||||
|
@ -646,6 +667,11 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_na
|
|||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
}
|
||||
if ((value_pos = ref_attr_name.find("Tuple[value")) != std::string::npos && attr_proto.tensors_size() > 1) {
|
||||
MS_LOG(INFO) << "Build TupleTensor ValueNode for primitive.";
|
||||
ObtainValueNodeInTupleTensorForm(value_node_name, attr_proto);
|
||||
break;
|
||||
}
|
||||
ObtainCNodeAttrInScalarForm(attr_proto, &multi_value_map);
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -61,6 +61,7 @@ class MSANFModelParser {
|
|||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
|
||||
bool ObtainValueNodeInTensorForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInTupleTensorForm(const string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor);
|
||||
bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
|
||||
bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
|
||||
|
|
Loading…
Reference in New Issue