!45263 add map-parameter export&load support

Merge pull request !45263 from hangq/fire
This commit is contained in:
i-robot 2022-11-18 06:24:21 +00:00 committed by Gitee
commit 43c452f099
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 411 additions and 91 deletions

View File

@ -131,6 +131,9 @@ class IrExportBuilder {
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
mind_ir::MapTensorProto *const map_tensor_proto);
bool ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
@ -448,17 +451,29 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
std::string param_name = GetUniqueNodeName(param);
if (top_graph && param->has_default()) {
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString();
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
parameter_proto->set_name(param_name);
if (!SetParamToTensorProto(param, parameter_proto)) {
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
if (param->abstract()->isa<abstract::AbstractMapTensor>()) {
auto *map_parameter_proto = graph_proto->add_map_parameter();
if (!ConvertMapParameterToMapTensorProto(param, map_parameter_proto)) {
MS_LOG(ERROR) << "Convert MapParameter " << param->ToString() << " to MapTensorProto failed.";
return false;
}
} else if (param->abstract()->isa<abstract::AbstractTensor>()) {
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
parameter_proto->set_name(param_name);
if (!SetParamToTensorProto(param, parameter_proto)) {
MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
return false;
}
auto tensor = param->default_param()->cast<tensor::TensorPtr>();
if (tensor != nullptr) {
parameter_proto->set_compression_type(
static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
}
} else {
MS_LOG(ERROR) << "Only support MapTensor or Tensor as default param of Parameter, got: "
<< param->default_param()->ToString();
return false;
}
auto tensor = param->default_param()->cast<tensor::TensorPtr>();
if (tensor != nullptr) {
parameter_proto->set_compression_type(
static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
}
} else {
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
input_proto->set_name(param_name);
@ -660,6 +675,101 @@ bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::
return SetTensorProto(param->abstract(), tensor_proto);
}
bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
mind_ir::MapTensorProto *const map_tensor_proto) {
if (map_parameter == nullptr || map_tensor_proto == nullptr) {
MS_LOG(EXCEPTION) << "MapParameter or MapTensorProto is null!";
}
MS_LOG(DEBUG) << "ConvertMapParameterToMapTensorProto: " << map_parameter->ToString();
// parameter name
map_tensor_proto->set_name(GetUniqueNodeName(map_parameter));
auto param_default = map_parameter->default_param();
MS_EXCEPTION_IF_NULL(param_default);
auto map_tensor = param_default->cast<tensor::MapTensorPtr>();
MS_EXCEPTION_IF_NULL(map_tensor);
// default value
auto default_value = map_tensor->default_value();
MS_EXCEPTION_IF_NULL(default_value);
auto *default_value_proto = map_tensor_proto->mutable_default_value();
MS_EXCEPTION_IF_NULL(default_value_proto);
if (!SetValueToAttributeProto(default_value, default_value_proto)) {
MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
return false;
}
tensor::MapTensor::ExportData export_data = map_tensor->Export(true);
// key_tensor
auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
MS_EXCEPTION_IF_NULL(key_tensor_proto);
auto &key_tensor = export_data.key_tensor;
MS_EXCEPTION_IF_NULL(key_tensor);
if (!SetTensorProto(key_tensor->ToAbstract(), key_tensor_proto)) {
MS_LOG(ERROR) << "Export key tensor of MapTensor failed, key_tensor: " << key_tensor->ToString();
return false;
}
// value_tensor
auto *value_tensor_proto = map_tensor_proto->mutable_value_tensor();
MS_EXCEPTION_IF_NULL(value_tensor_proto);
auto &value_tensor = export_data.value_tensor;
MS_EXCEPTION_IF_NULL(value_tensor);
if (!SetTensorProto(value_tensor->ToAbstract(), value_tensor_proto)) {
MS_LOG(ERROR) << "Export value tensor of MapTensor failed, value_tensor: " << value_tensor->ToString();
return false;
}
// status_tensor
auto *status_tensor_proto = map_tensor_proto->mutable_status_tensor();
MS_EXCEPTION_IF_NULL(status_tensor_proto);
auto &status_tensor = export_data.status_tensor;
MS_EXCEPTION_IF_NULL(status_tensor);
if (!SetTensorProto(status_tensor->ToAbstract(), status_tensor_proto)) {
MS_LOG(ERROR) << "Export status tensor of MapTensor failed, status_tensor: " << status_tensor->ToString();
return false;
}
return true;
}
bool IrExportBuilder::ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract,
mind_ir::AttributeProto *const attr_proto) {
auto map_tensor_abs = abstract->cast<abstract::AbstractMapTensorPtr>();
MS_EXCEPTION_IF_NULL(map_tensor_abs);
auto map_tensor_type = map_tensor_abs->map_tensor_type();
MS_EXCEPTION_IF_NULL(map_tensor_type);
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_MAP_TENSOR);
// key_tensor
auto key_dtype = map_tensor_type->key_dtype();
auto key_shape = {abstract::Shape::kShapeDimAny};
auto key_tensor_abs = std::make_shared<abstract::AbstractTensor>(key_dtype, key_shape);
auto *key_tensor_proto = attr_proto->add_tensors();
MS_EXCEPTION_IF_NULL(key_tensor_proto);
if (!SetTensorProto(key_tensor_abs, key_tensor_proto)) {
MS_LOG(ERROR) << "Export key tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
<< abstract->ToString();
return false;
}
// value_dtype value_shape
auto value_dtype = map_tensor_type->key_dtype();
auto value_shape = map_tensor_abs->value_shape()->shape();
auto value_tensor_abs = std::make_shared<abstract::AbstractTensor>(value_dtype, value_shape);
auto *value_tensor_proto = attr_proto->add_tensors();
MS_EXCEPTION_IF_NULL(value_tensor_proto);
if (!SetTensorProto(value_tensor_abs, value_tensor_proto)) {
MS_LOG(ERROR) << "Export value tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
<< abstract->ToString();
return false;
}
// default_value
auto default_value = map_tensor_abs->default_value();
auto *default_value_proto = attr_proto->add_values();
MS_EXCEPTION_IF_NULL(default_value_proto);
if (!SetValueToAttributeProto(default_value, default_value_proto)) {
MS_LOG(ERROR) << "Export default value of AbstractMapTensor failed, abstract_map_tensor: " << abstract->ToString();
return false;
}
return true;
}
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
for (const AnfNodePtr &node : nodes) {
@ -799,6 +909,8 @@ bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir
}
} else if (type->isa<TypeNone>()) {
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
} else if (type->isa<MapTensorType>()) {
return ConvertAbstractMapTensorToAttrProto(abs, attr_proto);
} else {
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
return false;

View File

@ -31,4 +31,16 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
abstract::AbstractBasePtr Null::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
abstract::AbstractBasePtr Ellipsis::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
abstract::AbstractBasePtr MindIRClassType::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRClassType>(), std::make_shared<TypeType>());
}
abstract::AbstractBasePtr MindIRNameSpace::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRNameSpace>(), std::make_shared<External>());
}
abstract::AbstractBasePtr MindIRSymbol::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRSymbol>(), std::make_shared<External>());
}
} // namespace mindspore

View File

@ -152,11 +152,12 @@ class MS_CORE_API Ellipsis final : public Named {
};
GVAR_DEF(NamedPtr, kEllipsis, std::make_shared<Ellipsis>());
class MindIRClassType final : public Named {
class MS_CORE_API MindIRClassType final : public Named {
public:
explicit MindIRClassType(const std::string &class_type) : Named(class_type) {}
~MindIRClassType() override = default;
MS_DECLARE_PARENT(MindIRClassType, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
using MindIRClassTypePtr = std::shared_ptr<MindIRClassType>;
@ -168,24 +169,26 @@ class MindIRMetaFuncGraph final : public Named {
};
using MindIRMetaFuncGraphPtr = std::shared_ptr<MindIRMetaFuncGraph>;
class MindIRNameSpace final : public Named {
class MS_CORE_API MindIRNameSpace final : public Named {
public:
explicit MindIRNameSpace(const std::string &name_space) : Named(name_space), name_space_(name_space) {}
~MindIRNameSpace() override = default;
MS_DECLARE_PARENT(MindIRNameSpace, Named);
const std::string &name_space() const { return name_space_; }
abstract::AbstractBasePtr ToAbstract() override;
private:
std::string name_space_;
};
using MindIRNameSpacePtr = std::shared_ptr<MindIRNameSpace>;
class MindIRSymbol final : public Named {
class MS_CORE_API MindIRSymbol final : public Named {
public:
explicit MindIRSymbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {}
~MindIRSymbol() override = default;
MS_DECLARE_PARENT(MindIRSymbol, Named);
const std::string &symbol() const { return symbol_; }
abstract::AbstractBasePtr ToAbstract() override;
private:
std::string symbol_;

View File

@ -27,6 +27,7 @@
#include <algorithm>
#include "ir/tensor.h"
#include "ir/param_info.h"
#include "ir/map_tensor.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "abstract/ops/primitive_infer_map.h"
@ -298,13 +299,13 @@ tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
errno_t ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
if (ret != EOK) {
MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
MS_LOG(ERROR) << "Failed to copy data from tensor proto.";
return nullptr;
}
} else if (attr_tensor.has_external_data()) {
auto ret = GetTensorDataFromExternal(attr_tensor, tensor);
if (!ret) {
MS_LOG(ERROR) << "Failed Load data from external.";
MS_LOG(ERROR) << "Failed to get external data from tensor proto.";
return nullptr;
}
} else {
@ -321,28 +322,13 @@ abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType
return GetAbsTensorFromTensorProto(attr_tensor);
}
case mind_ir::AttributeProto_AttributeType_CSR_TENSOR: {
std::vector<abstract::AbstractBasePtr> vec;
for (int i = 0; i < attr_proto.values_size(); ++i) {
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
if (abs == nullptr) {
MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
return nullptr;
}
(void)vec.emplace_back(abs);
}
return std::make_shared<abstract::AbstractCSRTensor>(vec);
return BuildAbstractCSRTensorFromAttrProto(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_COO_TENSOR: {
std::vector<abstract::AbstractBasePtr> vec;
for (int i = 0; i < attr_proto.values_size(); ++i) {
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
if (abs == nullptr) {
MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
return nullptr;
}
(void)vec.emplace_back(abs);
}
return std::make_shared<abstract::AbstractCOOTensor>(vec);
return BuildAbstractCOOTensorFromAttrProto(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_MAP_TENSOR: {
return BuildAbstractMapTensorFromAttrProto(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_TUPLE: {
std::vector<abstract::AbstractBasePtr> vec;
@ -522,6 +508,146 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
return true;
}
abstract::AbstractCOOTensorPtr MSANFModelParser::BuildAbstractCOOTensorFromAttrProto(
const mind_ir::AttributeProto &attr_proto) {
std::vector<abstract::AbstractBasePtr> vec;
for (int i = 0; i < attr_proto.values_size(); ++i) {
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
if (abs == nullptr) {
MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
return nullptr;
}
(void)vec.emplace_back(abs);
}
return std::make_shared<abstract::AbstractCOOTensor>(vec);
}
abstract::AbstractCSRTensorPtr MSANFModelParser::BuildAbstractCSRTensorFromAttrProto(
const mind_ir::AttributeProto &attr_proto) {
std::vector<abstract::AbstractBasePtr> vec;
for (int i = 0; i < attr_proto.values_size(); ++i) {
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
if (abs == nullptr) {
MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
return nullptr;
}
(void)vec.emplace_back(abs);
}
return std::make_shared<abstract::AbstractCSRTensor>(vec);
}
abstract::AbstractMapTensorPtr MSANFModelParser::BuildAbstractMapTensorFromAttrProto(
const mind_ir::AttributeProto &attr_proto) {
// default value
if (attr_proto.values_size() != 1) {
MS_LOG(EXCEPTION) << "AttrProto for AbstractMapTensor should has 1 value, but got " << attr_proto.values_size();
}
const auto &default_value_proto = attr_proto.values(0);
auto default_value = ObtainCNodeAttrInSingleScalarForm(default_value_proto);
MS_EXCEPTION_IF_NULL(default_value);
constexpr size_t kAbstractMapTensorAttrProtoTensorsSize = 2;
if (attr_proto.tensors_size() != kAbstractMapTensorAttrProtoTensorsSize) {
MS_LOG(EXCEPTION) << "AttrProto for AbstractMapTensor should has 2 tensors, but got " << attr_proto.tensors_size();
}
// key tensor
const auto &key_tensor_proto = attr_proto.tensors(0);
auto key_tensor_abs = GetAbsTensorFromTensorProto(key_tensor_proto);
MS_EXCEPTION_IF_NULL(key_tensor_abs);
// value tensor
const auto &value_tensor_proto = attr_proto.tensors(1);
auto value_tensor_abs = GetAbsTensorFromTensorProto(value_tensor_proto);
MS_EXCEPTION_IF_NULL(value_tensor_abs);
auto value_build_shape_ptr = value_tensor_abs->BuildShape();
if (!value_build_shape_ptr->isa<abstract::Shape>()) {
MS_LOG(EXCEPTION) << "value_shape of AbstractMapTensor should be a Shape, but got "
<< value_build_shape_ptr->ToString();
}
auto value_shape_ptr = value_build_shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(value_shape_ptr);
auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor_abs->BuildType()->type_id(),
value_tensor_abs->BuildType()->type_id(),
value_shape_ptr->shape(), default_value);
return std::make_shared<abstract::AbstractMapTensor>(map_tensor);
}
bool MSANFModelParser::BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
const mind_ir::MapTensorProto &map_parameter_proto) {
MS_EXCEPTION_IF_NULL(node);
if (!map_parameter_proto.has_name()) {
MS_LOG(ERROR) << "mind_ir MapTensorProto has no name!";
return false;
}
string debug_info_name = ParseParameterName(map_parameter_proto.name());
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
node->set_debug_info(debug_info_ptr);
node->set_name(debug_info_name);
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
param_info->set_name(debug_info_name);
MS_LOG(DEBUG) << "Load map parameter name: " << map_parameter_proto.name();
if (IsIncLoad() && load_tensor_map_.find(map_parameter_proto.name()) != load_tensor_map_.end()) {
MS_LOG(ERROR) << "MapParameter dose not support incremental loading, param_name: " << map_parameter_proto.name();
return false;
}
// default value
if (!map_parameter_proto.has_default_value()) {
MS_LOG(ERROR) << "MapTensorProto should have default value: " << map_parameter_proto.name();
return false;
}
const auto &default_value_proto = map_parameter_proto.default_value();
auto default_value = BuildValueFromAttributeProto(default_value_proto);
if (default_value == nullptr) {
MS_LOG(ERROR) << "Build default value from AttributeProto failed.";
return false;
}
// key tensor
if (!map_parameter_proto.has_key_tensor()) {
MS_LOG(ERROR) << "MapTensorProto should have key tensor: " << map_parameter_proto.name();
return false;
}
const auto &key_tensor_proto = map_parameter_proto.key_tensor();
auto key_tensor = GenerateTensorPtrFromTensorProto(key_tensor_proto);
if (key_tensor == nullptr) {
MS_LOG(ERROR) << "Generate key tensor from TensorProto failed.";
return false;
}
// value tensor
if (!map_parameter_proto.has_value_tensor()) {
MS_LOG(ERROR) << "MapTensorProto should have value tensor: " << map_parameter_proto.name();
return false;
}
const auto &value_tensor_proto = map_parameter_proto.value_tensor();
auto value_tensor = GenerateTensorPtrFromTensorProto(value_tensor_proto);
if (value_tensor == nullptr) {
MS_LOG(ERROR) << "Generate value tensor from TensorProto failed.";
return false;
}
// status tensor
if (!map_parameter_proto.has_status_tensor()) {
MS_LOG(ERROR) << "MapTensorProto should have status tensor: " << map_parameter_proto.name();
return false;
}
const auto &status_tensor_proto = map_parameter_proto.status_tensor();
auto status_tensor = GenerateTensorPtrFromTensorProto(status_tensor_proto);
if (status_tensor == nullptr) {
MS_LOG(ERROR) << "Generate status tensor from TensorProto failed.";
return false;
}
auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor, value_tensor, status_tensor, default_value);
map_tensor->set_param_info(param_info);
node->set_default_param(map_tensor);
node->set_abstract(map_tensor->ToAbstract());
anfnode_build_map_[map_parameter_proto.name()] = node;
return true;
}
bool MSANFModelParser::GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto,
const tensor::TensorPtr &tensor_info) {
if (!tensor_proto.has_external_data()) {
@ -678,6 +804,21 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
return true;
}
bool MSANFModelParser::ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const mind_ir::GraphProto &importProto) {
MS_EXCEPTION_IF_NULL(outputFuncGraph);
MS_LOG(INFO) << "All MapParameters size is: " << importProto.map_parameter_size();
for (int i = 0; i < importProto.map_parameter_size(); ++i) {
const mind_ir::MapTensorProto &map_parameter_proto = importProto.map_parameter(i);
if (!BuildMapParameterFromMapTensorProto(outputFuncGraph->add_parameter(), map_parameter_proto)) {
MS_LOG(ERROR) << "Build map parameter for funcgraph fail at index: " << i;
return false;
}
}
outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
return true;
}
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
MS_EXCEPTION_IF_NULL(prim);
const int attr_tensor_type = attr_proto.tensors(0).data_type();
@ -1040,14 +1181,6 @@ bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_n
return true;
}
bool MSANFModelParser::ObtainValueNodeInTypeNullForm(const std::string &value_node_name) {
auto new_value_node = NewValueNode(kTypeNull);
MS_EXCEPTION_IF_NULL(new_value_node);
new_value_node->set_abstract(kTypeNull->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
const mind_ir::AttributeProto &attr_proto) {
const std::string &ref_attr_name = attr_proto.ref_attr_name();
@ -1201,85 +1334,75 @@ ValuePtr MSANFModelParser::ObtainValueInSequenceForm(const mind_ir::AttributePro
return value_sequence;
}
bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
const mind_ir::AttributeProto &attr_proto) {
ValueNodePtr new_value_node;
ValuePtr MSANFModelParser::BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
switch (attr_proto.type()) {
case mind_ir::AttributeProto_AttributeType_TENSORS: {
mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
const auto &tensor_proto = attr_proto.tensors(0);
if (tensor_proto.has_raw_data()) {
// For real tensor.
(void)ObtainValueNodeInTensorForm(value_node_name, tensor_proto);
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
if (tensor_info == nullptr) {
MS_LOG(ERROR) << "Failed to GenerateTensorPtrFromTensorProto.";
return nullptr;
}
return MakeValue(tensor_info);
} else {
// For data type.
(void)ObtainValueNodeInTypeForm(value_node_name, tensor_proto);
const int attr_tensor_type = tensor_proto.data_type();
auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
if (iter == kDefaultValueSwitchMap.end()) {
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
return nullptr;
}
return TypeIdToType(iter->second);
}
break;
}
case mind_ir::AttributeProto_AttributeType_NONE: {
(void)ObtainValueNodeInNoneForm(value_node_name);
break;
return kNone;
}
case mind_ir::AttributeProto_AttributeType_UMONAD: {
new_value_node = NewValueNode(kUMonad);
new_value_node->set_abstract(kUMonad->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
break;
return kUMonad;
}
case mind_ir::AttributeProto_AttributeType_IOMONAD: {
new_value_node = NewValueNode(kIOMonad);
new_value_node->set_abstract(kIOMonad->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
break;
return kIOMonad;
}
case mind_ir::AttributeProto_AttributeType_TUPLE:
case mind_ir::AttributeProto_AttributeType_LIST: {
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
if (sequence_value == nullptr) {
MS_LOG(ERROR) << "Failed to get sequence value for " << value_node_name;
return false;
}
new_value_node = NewValueNode(sequence_value);
new_value_node->set_abstract(sequence_value->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
break;
return ObtainValueInSequenceForm(attr_proto);
}
case mind_ir::AttributeProto_AttributeType_CLASS_TYPE: {
auto class_type = static_cast<std::string>(attr_proto.s());
auto mindir_class_type = std::make_shared<MindIRClassType>(class_type);
new_value_node = NewValueNode(mindir_class_type);
anfnode_build_map_[value_node_name] = new_value_node;
break;
return std::make_shared<MindIRClassType>(class_type);
}
case mind_ir::AttributeProto_AttributeType_TYPE_NULL: {
(void)ObtainValueNodeInTypeNullForm(value_node_name);
break;
return kTypeNull;
}
case mind_ir::AttributeProto_AttributeType_NAME_SPACE: {
auto name_space = static_cast<std::string>(attr_proto.s());
auto mindir_name_space = std::make_shared<MindIRNameSpace>(name_space);
new_value_node = NewValueNode(mindir_name_space);
anfnode_build_map_[value_node_name] = new_value_node;
break;
return std::make_shared<MindIRNameSpace>(name_space);
}
case mind_ir::AttributeProto_AttributeType_SYMBOL: {
auto symbol = static_cast<std::string>(attr_proto.s());
auto mindir_symbol = std::make_shared<MindIRSymbol>(symbol);
new_value_node = NewValueNode(mindir_symbol);
anfnode_build_map_[value_node_name] = new_value_node;
break;
return std::make_shared<MindIRSymbol>(symbol);
}
default: {
ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
if (value == nullptr) {
MS_LOG(ERROR) << "Can not get the value for attr: " << value_node_name;
return false;
}
new_value_node = NewValueNode(value);
new_value_node->set_abstract(value->ToAbstract());
anfnode_build_map_[value_node_name] = new_value_node;
return ObtainCNodeAttrInSingleScalarForm(attr_proto);
}
}
}
bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
const mind_ir::AttributeProto &attr_proto) {
auto value = BuildValueFromAttributeProto(attr_proto);
if (value == nullptr) {
MS_LOG(ERROR) << "Failed to build value from AttributeProto while building valuenode: " << value_node_name;
return false;
}
auto abstract = value->ToAbstract();
MS_EXCEPTION_IF_NULL(abstract);
ValueNodePtr new_value_node = NewValueNode(value);
new_value_node->set_abstract(abstract);
anfnode_build_map_[value_node_name] = new_value_node;
return true;
}
@ -1650,11 +1773,15 @@ bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const
MS_LOG(ERROR) << "Import parameters for graph fail!";
return false;
}
if (!ImportMapParametersForGraph(outputFuncGraph, importProto)) {
MS_LOG(ERROR) << "Import map parameters for graph failed!";
return false;
}
if (ImportNodesForGraph(outputFuncGraph, importProto)) {
MS_LOG(DEBUG) << "Success to parse graph: " << outputFuncGraph->ToString() << ": " << outputFuncGraph.get();
return true;
}
MS_LOG(ERROR) << "Failed to parse nodes. " << importProto.DebugString();
MS_LOG(ERROR) << "Failed to parse nodes. ";
return false;
}

View File

@ -91,8 +91,14 @@ class MSANFModelParser {
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &parameter_proto);
bool BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
const mind_ir::MapTensorProto &map_parameter_proto);
abstract::AbstractMapTensorPtr BuildAbstractMapTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
abstract::AbstractCOOTensorPtr BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
abstract::AbstractCSRTensorPtr BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
@ -108,6 +114,7 @@ class MSANFModelParser {
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
ValuePtr BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
void SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);

View File

@ -45,6 +45,7 @@ message AttributeProto {
NAME_SPACE = 34;
SYMBOL = 35;
TYPE_NULL = 36;
MAP_TENSOR = 37;
}
optional string name = 1;
optional float f = 2;
@ -128,6 +129,7 @@ message GraphProto {
optional string bprop_hash = 7;
repeated AttributeProto attribute = 8;
optional string bprop_filepath = 9;
repeated MapTensorProto map_parameter = 10;
}
@ -192,6 +194,14 @@ message TensorProto {
repeated QuantParamProto quant_params = 17;
}
message MapTensorProto {
required string name = 1;
required AttributeProto default_value = 2;
required TensorProto key_tensor = 3;
required TensorProto value_tensor = 4;
required TensorProto status_tensor = 5;
}
message ParallelProto {
repeated LayoutProto layout = 1;
}

View File

@ -1416,6 +1416,17 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
else:
logger.warning("The parameter '{}' is not belongs to any cell,the data of parameter cannot be exported."
.format(param_proto.name))
for map_param_proto in model.graph.map_parameter:
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
if map_param_name in net_dict.keys():
map_parameter = net_dict[map_param_name]
key_nparr, value_nparr, status_nparr = map_parameter.export_data()
map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
else:
logger.warning("The map_parameter '{}' is not belongs to any cell,the data of parameter cannot be exported."
.format(map_param_proto.name))
if not file_name.endswith('.mindir'):
file_name += ".mindir"
current_path = os.path.abspath(file_name)

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os.path
import numpy as np
import mindspore as ms
import mindspore.nn as nn
@ -19,6 +21,7 @@ from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore.experimental import MapParameter
from mindspore.common.initializer import initializer
from mindspore.ops import composite as C
from mindspore import export, load
def test_basic_operations():
@ -237,3 +240,38 @@ def test_map_parameter_filter():
net = MyNet()
out = net()
print("out:", out)
def test_simple_graph_export_load():
"""
Feature: MapParameter
Description: Test IR graph export and load with MapParameter.
Expectation: IR graph with MapParameter exported and loaded without exceptions.
"""
class MyNet(nn.Cell):
def __init__(self):
nn.Cell.__init__(self)
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
self.key = Tensor([1, 2], dtype=ms.int32)
def construct(self, x):
self.m.put(self.key, x)
value1 = self.m.get(self.key)
value2 = self.m[self.key]
self.m[self.key] = value2
self.m.erase(self.key)
keys = self.m.get_keys()
values = self.m.get_values()
self.m.put(keys, values)
return self.p + value1 + value2
context.set_context(mode=context.GRAPH_MODE)
net = MyNet()
t = initializer('ones', (2, 3), ms.float32)
t = t.init_data()
file_path = "./map-parameter.mindir"
export(net, t, file_name=file_path, file_format="MINDIR")
assert os.path.isfile(file_path)
load(file_path)