forked from mindspore-Ecosystem/mindspore
!45263 add map-parameter export&load support
Merge pull request !45263 from hangq/fire
This commit is contained in:
commit
43c452f099
|
@ -131,6 +131,9 @@ class IrExportBuilder {
|
|||
|
||||
bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
|
||||
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
||||
bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
|
||||
mind_ir::MapTensorProto *const map_tensor_proto);
|
||||
bool ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
|
||||
|
@ -448,6 +451,13 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|||
std::string param_name = GetUniqueNodeName(param);
|
||||
if (top_graph && param->has_default()) {
|
||||
MS_LOG(DEBUG) << "Parameter: '" << item->DebugString();
|
||||
if (param->abstract()->isa<abstract::AbstractMapTensor>()) {
|
||||
auto *map_parameter_proto = graph_proto->add_map_parameter();
|
||||
if (!ConvertMapParameterToMapTensorProto(param, map_parameter_proto)) {
|
||||
MS_LOG(ERROR) << "Convert MapParameter " << param->ToString() << " to MapTensorProto failed.";
|
||||
return false;
|
||||
}
|
||||
} else if (param->abstract()->isa<abstract::AbstractTensor>()) {
|
||||
mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
|
||||
parameter_proto->set_name(param_name);
|
||||
if (!SetParamToTensorProto(param, parameter_proto)) {
|
||||
|
@ -459,6 +469,11 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
|
|||
parameter_proto->set_compression_type(
|
||||
static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Only support MapTensor or Tensor as default param of Parameter, got: "
|
||||
<< param->default_param()->ToString();
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
|
||||
input_proto->set_name(param_name);
|
||||
|
@ -660,6 +675,101 @@ bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::
|
|||
return SetTensorProto(param->abstract(), tensor_proto);
|
||||
}
|
||||
|
||||
bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
|
||||
mind_ir::MapTensorProto *const map_tensor_proto) {
|
||||
if (map_parameter == nullptr || map_tensor_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "MapParameter or MapTensorProto is null!";
|
||||
}
|
||||
MS_LOG(DEBUG) << "ConvertMapParameterToMapTensorProto: " << map_parameter->ToString();
|
||||
|
||||
// parameter name
|
||||
map_tensor_proto->set_name(GetUniqueNodeName(map_parameter));
|
||||
|
||||
auto param_default = map_parameter->default_param();
|
||||
MS_EXCEPTION_IF_NULL(param_default);
|
||||
auto map_tensor = param_default->cast<tensor::MapTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(map_tensor);
|
||||
// default value
|
||||
auto default_value = map_tensor->default_value();
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
auto *default_value_proto = map_tensor_proto->mutable_default_value();
|
||||
MS_EXCEPTION_IF_NULL(default_value_proto);
|
||||
if (!SetValueToAttributeProto(default_value, default_value_proto)) {
|
||||
MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
|
||||
return false;
|
||||
}
|
||||
tensor::MapTensor::ExportData export_data = map_tensor->Export(true);
|
||||
// key_tensor
|
||||
auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
|
||||
MS_EXCEPTION_IF_NULL(key_tensor_proto);
|
||||
auto &key_tensor = export_data.key_tensor;
|
||||
MS_EXCEPTION_IF_NULL(key_tensor);
|
||||
if (!SetTensorProto(key_tensor->ToAbstract(), key_tensor_proto)) {
|
||||
MS_LOG(ERROR) << "Export key tensor of MapTensor failed, key_tensor: " << key_tensor->ToString();
|
||||
return false;
|
||||
}
|
||||
// value_tensor
|
||||
auto *value_tensor_proto = map_tensor_proto->mutable_value_tensor();
|
||||
MS_EXCEPTION_IF_NULL(value_tensor_proto);
|
||||
auto &value_tensor = export_data.value_tensor;
|
||||
MS_EXCEPTION_IF_NULL(value_tensor);
|
||||
if (!SetTensorProto(value_tensor->ToAbstract(), value_tensor_proto)) {
|
||||
MS_LOG(ERROR) << "Export value tensor of MapTensor failed, value_tensor: " << value_tensor->ToString();
|
||||
return false;
|
||||
}
|
||||
// status_tensor
|
||||
auto *status_tensor_proto = map_tensor_proto->mutable_status_tensor();
|
||||
MS_EXCEPTION_IF_NULL(status_tensor_proto);
|
||||
auto &status_tensor = export_data.status_tensor;
|
||||
MS_EXCEPTION_IF_NULL(status_tensor);
|
||||
if (!SetTensorProto(status_tensor->ToAbstract(), status_tensor_proto)) {
|
||||
MS_LOG(ERROR) << "Export status tensor of MapTensor failed, status_tensor: " << status_tensor->ToString();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract,
|
||||
mind_ir::AttributeProto *const attr_proto) {
|
||||
auto map_tensor_abs = abstract->cast<abstract::AbstractMapTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(map_tensor_abs);
|
||||
|
||||
auto map_tensor_type = map_tensor_abs->map_tensor_type();
|
||||
MS_EXCEPTION_IF_NULL(map_tensor_type);
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_MAP_TENSOR);
|
||||
// key_tensor
|
||||
auto key_dtype = map_tensor_type->key_dtype();
|
||||
auto key_shape = {abstract::Shape::kShapeDimAny};
|
||||
auto key_tensor_abs = std::make_shared<abstract::AbstractTensor>(key_dtype, key_shape);
|
||||
auto *key_tensor_proto = attr_proto->add_tensors();
|
||||
MS_EXCEPTION_IF_NULL(key_tensor_proto);
|
||||
if (!SetTensorProto(key_tensor_abs, key_tensor_proto)) {
|
||||
MS_LOG(ERROR) << "Export key tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
|
||||
<< abstract->ToString();
|
||||
return false;
|
||||
}
|
||||
// value_dtype value_shape
|
||||
auto value_dtype = map_tensor_type->key_dtype();
|
||||
auto value_shape = map_tensor_abs->value_shape()->shape();
|
||||
auto value_tensor_abs = std::make_shared<abstract::AbstractTensor>(value_dtype, value_shape);
|
||||
auto *value_tensor_proto = attr_proto->add_tensors();
|
||||
MS_EXCEPTION_IF_NULL(value_tensor_proto);
|
||||
if (!SetTensorProto(value_tensor_abs, value_tensor_proto)) {
|
||||
MS_LOG(ERROR) << "Export value tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
|
||||
<< abstract->ToString();
|
||||
return false;
|
||||
}
|
||||
// default_value
|
||||
auto default_value = map_tensor_abs->default_value();
|
||||
auto *default_value_proto = attr_proto->add_values();
|
||||
MS_EXCEPTION_IF_NULL(default_value_proto);
|
||||
if (!SetValueToAttributeProto(default_value, default_value_proto)) {
|
||||
MS_LOG(ERROR) << "Export default value of AbstractMapTensor failed, abstract_map_tensor: " << abstract->ToString();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
|
||||
for (const AnfNodePtr &node : nodes) {
|
||||
|
@ -799,6 +909,8 @@ bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir
|
|||
}
|
||||
} else if (type->isa<TypeNone>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
|
||||
} else if (type->isa<MapTensorType>()) {
|
||||
return ConvertAbstractMapTensorToAttrProto(abs, attr_proto);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
|
||||
return false;
|
||||
|
|
|
@ -31,4 +31,16 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
|
|||
abstract::AbstractBasePtr Null::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
|
||||
|
||||
abstract::AbstractBasePtr Ellipsis::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
|
||||
|
||||
abstract::AbstractBasePtr MindIRClassType::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRClassType>(), std::make_shared<TypeType>());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr MindIRNameSpace::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRNameSpace>(), std::make_shared<External>());
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr MindIRSymbol::ToAbstract() {
|
||||
return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRSymbol>(), std::make_shared<External>());
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -152,11 +152,12 @@ class MS_CORE_API Ellipsis final : public Named {
|
|||
};
|
||||
GVAR_DEF(NamedPtr, kEllipsis, std::make_shared<Ellipsis>());
|
||||
|
||||
class MindIRClassType final : public Named {
|
||||
class MS_CORE_API MindIRClassType final : public Named {
|
||||
public:
|
||||
explicit MindIRClassType(const std::string &class_type) : Named(class_type) {}
|
||||
~MindIRClassType() override = default;
|
||||
MS_DECLARE_PARENT(MindIRClassType, Named);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
};
|
||||
using MindIRClassTypePtr = std::shared_ptr<MindIRClassType>;
|
||||
|
||||
|
@ -168,24 +169,26 @@ class MindIRMetaFuncGraph final : public Named {
|
|||
};
|
||||
using MindIRMetaFuncGraphPtr = std::shared_ptr<MindIRMetaFuncGraph>;
|
||||
|
||||
class MindIRNameSpace final : public Named {
|
||||
class MS_CORE_API MindIRNameSpace final : public Named {
|
||||
public:
|
||||
explicit MindIRNameSpace(const std::string &name_space) : Named(name_space), name_space_(name_space) {}
|
||||
~MindIRNameSpace() override = default;
|
||||
MS_DECLARE_PARENT(MindIRNameSpace, Named);
|
||||
const std::string &name_space() const { return name_space_; }
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
private:
|
||||
std::string name_space_;
|
||||
};
|
||||
using MindIRNameSpacePtr = std::shared_ptr<MindIRNameSpace>;
|
||||
|
||||
class MindIRSymbol final : public Named {
|
||||
class MS_CORE_API MindIRSymbol final : public Named {
|
||||
public:
|
||||
explicit MindIRSymbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {}
|
||||
~MindIRSymbol() override = default;
|
||||
MS_DECLARE_PARENT(MindIRSymbol, Named);
|
||||
const std::string &symbol() const { return symbol_; }
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
|
||||
private:
|
||||
std::string symbol_;
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <algorithm>
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/map_tensor.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
|
@ -298,13 +299,13 @@ tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_
|
|||
auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
|
||||
errno_t ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
|
||||
MS_LOG(ERROR) << "Failed to copy data from tensor proto.";
|
||||
return nullptr;
|
||||
}
|
||||
} else if (attr_tensor.has_external_data()) {
|
||||
auto ret = GetTensorDataFromExternal(attr_tensor, tensor);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Failed Load data from external.";
|
||||
MS_LOG(ERROR) << "Failed to get external data from tensor proto.";
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
|
@ -321,28 +322,13 @@ abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType
|
|||
return GetAbsTensorFromTensorProto(attr_tensor);
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_CSR_TENSOR: {
|
||||
std::vector<abstract::AbstractBasePtr> vec;
|
||||
for (int i = 0; i < attr_proto.values_size(); ++i) {
|
||||
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
(void)vec.emplace_back(abs);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractCSRTensor>(vec);
|
||||
return BuildAbstractCSRTensorFromAttrProto(attr_proto);
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_COO_TENSOR: {
|
||||
std::vector<abstract::AbstractBasePtr> vec;
|
||||
for (int i = 0; i < attr_proto.values_size(); ++i) {
|
||||
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
|
||||
return nullptr;
|
||||
return BuildAbstractCOOTensorFromAttrProto(attr_proto);
|
||||
}
|
||||
(void)vec.emplace_back(abs);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractCOOTensor>(vec);
|
||||
case mind_ir::AttributeProto_AttributeType_MAP_TENSOR: {
|
||||
return BuildAbstractMapTensorFromAttrProto(attr_proto);
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_TUPLE: {
|
||||
std::vector<abstract::AbstractBasePtr> vec;
|
||||
|
@ -522,6 +508,146 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
|
|||
return true;
|
||||
}
|
||||
|
||||
abstract::AbstractCOOTensorPtr MSANFModelParser::BuildAbstractCOOTensorFromAttrProto(
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
std::vector<abstract::AbstractBasePtr> vec;
|
||||
for (int i = 0; i < attr_proto.values_size(); ++i) {
|
||||
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
(void)vec.emplace_back(abs);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractCOOTensor>(vec);
|
||||
}
|
||||
|
||||
abstract::AbstractCSRTensorPtr MSANFModelParser::BuildAbstractCSRTensorFromAttrProto(
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
std::vector<abstract::AbstractBasePtr> vec;
|
||||
for (int i = 0; i < attr_proto.values_size(); ++i) {
|
||||
auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
|
||||
if (abs == nullptr) {
|
||||
MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
(void)vec.emplace_back(abs);
|
||||
}
|
||||
return std::make_shared<abstract::AbstractCSRTensor>(vec);
|
||||
}
|
||||
|
||||
abstract::AbstractMapTensorPtr MSANFModelParser::BuildAbstractMapTensorFromAttrProto(
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
// default value
|
||||
if (attr_proto.values_size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "AttrProto for AbstractMapTensor should has 1 value, but got " << attr_proto.values_size();
|
||||
}
|
||||
const auto &default_value_proto = attr_proto.values(0);
|
||||
auto default_value = ObtainCNodeAttrInSingleScalarForm(default_value_proto);
|
||||
MS_EXCEPTION_IF_NULL(default_value);
|
||||
|
||||
constexpr size_t kAbstractMapTensorAttrProtoTensorsSize = 2;
|
||||
if (attr_proto.tensors_size() != kAbstractMapTensorAttrProtoTensorsSize) {
|
||||
MS_LOG(EXCEPTION) << "AttrProto for AbstractMapTensor should has 2 tensors, but got " << attr_proto.tensors_size();
|
||||
}
|
||||
// key tensor
|
||||
const auto &key_tensor_proto = attr_proto.tensors(0);
|
||||
auto key_tensor_abs = GetAbsTensorFromTensorProto(key_tensor_proto);
|
||||
MS_EXCEPTION_IF_NULL(key_tensor_abs);
|
||||
// value tensor
|
||||
const auto &value_tensor_proto = attr_proto.tensors(1);
|
||||
auto value_tensor_abs = GetAbsTensorFromTensorProto(value_tensor_proto);
|
||||
MS_EXCEPTION_IF_NULL(value_tensor_abs);
|
||||
auto value_build_shape_ptr = value_tensor_abs->BuildShape();
|
||||
if (!value_build_shape_ptr->isa<abstract::Shape>()) {
|
||||
MS_LOG(EXCEPTION) << "value_shape of AbstractMapTensor should be a Shape, but got "
|
||||
<< value_build_shape_ptr->ToString();
|
||||
}
|
||||
auto value_shape_ptr = value_build_shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_shape_ptr);
|
||||
auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor_abs->BuildType()->type_id(),
|
||||
value_tensor_abs->BuildType()->type_id(),
|
||||
value_shape_ptr->shape(), default_value);
|
||||
return std::make_shared<abstract::AbstractMapTensor>(map_tensor);
|
||||
}
|
||||
|
||||
bool MSANFModelParser::BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
|
||||
const mind_ir::MapTensorProto &map_parameter_proto) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (!map_parameter_proto.has_name()) {
|
||||
MS_LOG(ERROR) << "mind_ir MapTensorProto has no name!";
|
||||
return false;
|
||||
}
|
||||
|
||||
string debug_info_name = ParseParameterName(map_parameter_proto.name());
|
||||
auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
|
||||
node->set_debug_info(debug_info_ptr);
|
||||
node->set_name(debug_info_name);
|
||||
|
||||
ParamInfoPtr param_info = std::make_shared<ParamInfo>();
|
||||
param_info->set_name(debug_info_name);
|
||||
|
||||
MS_LOG(DEBUG) << "Load map parameter name: " << map_parameter_proto.name();
|
||||
if (IsIncLoad() && load_tensor_map_.find(map_parameter_proto.name()) != load_tensor_map_.end()) {
|
||||
MS_LOG(ERROR) << "MapParameter dose not support incremental loading, param_name: " << map_parameter_proto.name();
|
||||
return false;
|
||||
}
|
||||
|
||||
// default value
|
||||
if (!map_parameter_proto.has_default_value()) {
|
||||
MS_LOG(ERROR) << "MapTensorProto should have default value: " << map_parameter_proto.name();
|
||||
return false;
|
||||
}
|
||||
const auto &default_value_proto = map_parameter_proto.default_value();
|
||||
auto default_value = BuildValueFromAttributeProto(default_value_proto);
|
||||
if (default_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Build default value from AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
// key tensor
|
||||
if (!map_parameter_proto.has_key_tensor()) {
|
||||
MS_LOG(ERROR) << "MapTensorProto should have key tensor: " << map_parameter_proto.name();
|
||||
return false;
|
||||
}
|
||||
const auto &key_tensor_proto = map_parameter_proto.key_tensor();
|
||||
auto key_tensor = GenerateTensorPtrFromTensorProto(key_tensor_proto);
|
||||
if (key_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Generate key tensor from TensorProto failed.";
|
||||
return false;
|
||||
}
|
||||
// value tensor
|
||||
if (!map_parameter_proto.has_value_tensor()) {
|
||||
MS_LOG(ERROR) << "MapTensorProto should have value tensor: " << map_parameter_proto.name();
|
||||
return false;
|
||||
}
|
||||
const auto &value_tensor_proto = map_parameter_proto.value_tensor();
|
||||
auto value_tensor = GenerateTensorPtrFromTensorProto(value_tensor_proto);
|
||||
if (value_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Generate value tensor from TensorProto failed.";
|
||||
return false;
|
||||
}
|
||||
// status tensor
|
||||
if (!map_parameter_proto.has_status_tensor()) {
|
||||
MS_LOG(ERROR) << "MapTensorProto should have status tensor: " << map_parameter_proto.name();
|
||||
return false;
|
||||
}
|
||||
const auto &status_tensor_proto = map_parameter_proto.status_tensor();
|
||||
auto status_tensor = GenerateTensorPtrFromTensorProto(status_tensor_proto);
|
||||
if (status_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Generate status tensor from TensorProto failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor, value_tensor, status_tensor, default_value);
|
||||
map_tensor->set_param_info(param_info);
|
||||
node->set_default_param(map_tensor);
|
||||
node->set_abstract(map_tensor->ToAbstract());
|
||||
|
||||
anfnode_build_map_[map_parameter_proto.name()] = node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto,
|
||||
const tensor::TensorPtr &tensor_info) {
|
||||
if (!tensor_proto.has_external_data()) {
|
||||
|
@ -678,6 +804,21 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph,
|
||||
const mind_ir::GraphProto &importProto) {
|
||||
MS_EXCEPTION_IF_NULL(outputFuncGraph);
|
||||
MS_LOG(INFO) << "All MapParameters size is: " << importProto.map_parameter_size();
|
||||
for (int i = 0; i < importProto.map_parameter_size(); ++i) {
|
||||
const mind_ir::MapTensorProto &map_parameter_proto = importProto.map_parameter(i);
|
||||
if (!BuildMapParameterFromMapTensorProto(outputFuncGraph->add_parameter(), map_parameter_proto)) {
|
||||
MS_LOG(ERROR) << "Build map parameter for funcgraph fail at index: " << i;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
const int attr_tensor_type = attr_proto.tensors(0).data_type();
|
||||
|
@ -1040,14 +1181,6 @@ bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_n
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInTypeNullForm(const std::string &value_node_name) {
|
||||
auto new_value_node = NewValueNode(kTypeNull);
|
||||
MS_EXCEPTION_IF_NULL(new_value_node);
|
||||
new_value_node->set_abstract(kTypeNull->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
const std::string &ref_attr_name = attr_proto.ref_attr_name();
|
||||
|
@ -1201,85 +1334,75 @@ ValuePtr MSANFModelParser::ObtainValueInSequenceForm(const mind_ir::AttributePro
|
|||
return value_sequence;
|
||||
}
|
||||
|
||||
bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
ValueNodePtr new_value_node;
|
||||
ValuePtr MSANFModelParser::BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
|
||||
switch (attr_proto.type()) {
|
||||
case mind_ir::AttributeProto_AttributeType_TENSORS: {
|
||||
mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
|
||||
const auto &tensor_proto = attr_proto.tensors(0);
|
||||
if (tensor_proto.has_raw_data()) {
|
||||
// For real tensor.
|
||||
(void)ObtainValueNodeInTensorForm(value_node_name, tensor_proto);
|
||||
tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to GenerateTensorPtrFromTensorProto.";
|
||||
return nullptr;
|
||||
}
|
||||
return MakeValue(tensor_info);
|
||||
} else {
|
||||
// For data type.
|
||||
(void)ObtainValueNodeInTypeForm(value_node_name, tensor_proto);
|
||||
const int attr_tensor_type = tensor_proto.data_type();
|
||||
auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
|
||||
if (iter == kDefaultValueSwitchMap.end()) {
|
||||
MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
|
||||
return nullptr;
|
||||
}
|
||||
return TypeIdToType(iter->second);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_NONE: {
|
||||
(void)ObtainValueNodeInNoneForm(value_node_name);
|
||||
break;
|
||||
return kNone;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_UMONAD: {
|
||||
new_value_node = NewValueNode(kUMonad);
|
||||
new_value_node->set_abstract(kUMonad->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
return kUMonad;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_IOMONAD: {
|
||||
new_value_node = NewValueNode(kIOMonad);
|
||||
new_value_node->set_abstract(kIOMonad->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
return kIOMonad;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_TUPLE:
|
||||
case mind_ir::AttributeProto_AttributeType_LIST: {
|
||||
auto sequence_value = ObtainValueInSequenceForm(attr_proto);
|
||||
if (sequence_value == nullptr) {
|
||||
MS_LOG(ERROR) << "Failed to get sequence value for " << value_node_name;
|
||||
return false;
|
||||
}
|
||||
new_value_node = NewValueNode(sequence_value);
|
||||
new_value_node->set_abstract(sequence_value->ToAbstract());
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
return ObtainValueInSequenceForm(attr_proto);
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_CLASS_TYPE: {
|
||||
auto class_type = static_cast<std::string>(attr_proto.s());
|
||||
auto mindir_class_type = std::make_shared<MindIRClassType>(class_type);
|
||||
new_value_node = NewValueNode(mindir_class_type);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
return std::make_shared<MindIRClassType>(class_type);
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_TYPE_NULL: {
|
||||
(void)ObtainValueNodeInTypeNullForm(value_node_name);
|
||||
break;
|
||||
return kTypeNull;
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_NAME_SPACE: {
|
||||
auto name_space = static_cast<std::string>(attr_proto.s());
|
||||
auto mindir_name_space = std::make_shared<MindIRNameSpace>(name_space);
|
||||
new_value_node = NewValueNode(mindir_name_space);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
return std::make_shared<MindIRNameSpace>(name_space);
|
||||
}
|
||||
case mind_ir::AttributeProto_AttributeType_SYMBOL: {
|
||||
auto symbol = static_cast<std::string>(attr_proto.s());
|
||||
auto mindir_symbol = std::make_shared<MindIRSymbol>(symbol);
|
||||
new_value_node = NewValueNode(mindir_symbol);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
break;
|
||||
return std::make_shared<MindIRSymbol>(symbol);
|
||||
}
|
||||
default: {
|
||||
ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
return ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
|
||||
const mind_ir::AttributeProto &attr_proto) {
|
||||
auto value = BuildValueFromAttributeProto(attr_proto);
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Can not get the value for attr: " << value_node_name;
|
||||
MS_LOG(ERROR) << "Failed to build value from AttributeProto while building valuenode: " << value_node_name;
|
||||
return false;
|
||||
}
|
||||
new_value_node = NewValueNode(value);
|
||||
new_value_node->set_abstract(value->ToAbstract());
|
||||
auto abstract = value->ToAbstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
ValueNodePtr new_value_node = NewValueNode(value);
|
||||
new_value_node->set_abstract(abstract);
|
||||
anfnode_build_map_[value_node_name] = new_value_node;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1650,11 +1773,15 @@ bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const
|
|||
MS_LOG(ERROR) << "Import parameters for graph fail!";
|
||||
return false;
|
||||
}
|
||||
if (!ImportMapParametersForGraph(outputFuncGraph, importProto)) {
|
||||
MS_LOG(ERROR) << "Import map parameters for graph failed!";
|
||||
return false;
|
||||
}
|
||||
if (ImportNodesForGraph(outputFuncGraph, importProto)) {
|
||||
MS_LOG(DEBUG) << "Success to parse graph: " << outputFuncGraph->ToString() << ": " << outputFuncGraph.get();
|
||||
return true;
|
||||
}
|
||||
MS_LOG(ERROR) << "Failed to parse nodes. " << importProto.DebugString();
|
||||
MS_LOG(ERROR) << "Failed to parse nodes. ";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -91,8 +91,14 @@ class MSANFModelParser {
|
|||
bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
|
||||
bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto ¶meter_proto);
|
||||
bool BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
|
||||
const mind_ir::MapTensorProto &map_parameter_proto);
|
||||
abstract::AbstractMapTensorPtr BuildAbstractMapTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
|
||||
abstract::AbstractCOOTensorPtr BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
|
||||
abstract::AbstractCSRTensorPtr BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
|
||||
bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
|
||||
bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
|
||||
bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
|
||||
|
@ -108,6 +114,7 @@ class MSANFModelParser {
|
|||
ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
|
||||
bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
|
||||
bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
|
||||
ValuePtr BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
|
||||
AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
|
||||
bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
|
||||
void SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);
|
||||
|
|
|
@ -45,6 +45,7 @@ message AttributeProto {
|
|||
NAME_SPACE = 34;
|
||||
SYMBOL = 35;
|
||||
TYPE_NULL = 36;
|
||||
MAP_TENSOR = 37;
|
||||
}
|
||||
optional string name = 1;
|
||||
optional float f = 2;
|
||||
|
@ -128,6 +129,7 @@ message GraphProto {
|
|||
optional string bprop_hash = 7;
|
||||
repeated AttributeProto attribute = 8;
|
||||
optional string bprop_filepath = 9;
|
||||
repeated MapTensorProto map_parameter = 10;
|
||||
}
|
||||
|
||||
|
||||
|
@ -192,6 +194,14 @@ message TensorProto {
|
|||
repeated QuantParamProto quant_params = 17;
|
||||
}
|
||||
|
||||
message MapTensorProto {
|
||||
required string name = 1;
|
||||
required AttributeProto default_value = 2;
|
||||
required TensorProto key_tensor = 3;
|
||||
required TensorProto value_tensor = 4;
|
||||
required TensorProto status_tensor = 5;
|
||||
}
|
||||
|
||||
message ParallelProto {
|
||||
repeated LayoutProto layout = 1;
|
||||
}
|
||||
|
|
|
@ -1416,6 +1416,17 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
|
|||
else:
|
||||
logger.warning("The parameter '{}' is not belongs to any cell,the data of parameter cannot be exported."
|
||||
.format(param_proto.name))
|
||||
for map_param_proto in model.graph.map_parameter:
|
||||
map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
|
||||
if map_param_name in net_dict.keys():
|
||||
map_parameter = net_dict[map_param_name]
|
||||
key_nparr, value_nparr, status_nparr = map_parameter.export_data()
|
||||
map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
|
||||
map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
|
||||
map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
|
||||
else:
|
||||
logger.warning("The map_parameter '{}' is not belongs to any cell,the data of parameter cannot be exported."
|
||||
.format(map_param_proto.name))
|
||||
if not file_name.endswith('.mindir'):
|
||||
file_name += ".mindir"
|
||||
current_path = os.path.abspath(file_name)
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import os.path
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
@ -19,6 +21,7 @@ from mindspore import context, Tensor, Parameter, ParameterTuple
|
|||
from mindspore.experimental import MapParameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import export, load
|
||||
|
||||
|
||||
def test_basic_operations():
|
||||
|
@ -237,3 +240,38 @@ def test_map_parameter_filter():
|
|||
net = MyNet()
|
||||
out = net()
|
||||
print("out:", out)
|
||||
|
||||
|
||||
def test_simple_graph_export_load():
|
||||
"""
|
||||
Feature: MapParameter
|
||||
Description: Test IR graph export and load with MapParameter.
|
||||
Expectation: IR graph with MapParameter exported and loaded without exceptions.
|
||||
"""
|
||||
|
||||
class MyNet(nn.Cell):
|
||||
def __init__(self):
|
||||
nn.Cell.__init__(self)
|
||||
self.p = Parameter(initializer('ones', (2, 3), ms.float32))
|
||||
self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
|
||||
self.key = Tensor([1, 2], dtype=ms.int32)
|
||||
|
||||
def construct(self, x):
|
||||
self.m.put(self.key, x)
|
||||
value1 = self.m.get(self.key)
|
||||
value2 = self.m[self.key]
|
||||
self.m[self.key] = value2
|
||||
self.m.erase(self.key)
|
||||
keys = self.m.get_keys()
|
||||
values = self.m.get_values()
|
||||
self.m.put(keys, values)
|
||||
return self.p + value1 + value2
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = MyNet()
|
||||
t = initializer('ones', (2, 3), ms.float32)
|
||||
t = t.init_data()
|
||||
file_path = "./map-parameter.mindir"
|
||||
export(net, t, file_name=file_path, file_format="MINDIR")
|
||||
assert os.path.isfile(file_path)
|
||||
load(file_path)
|
||||
|
|
Loading…
Reference in New Issue