forked from mindspore-Ecosystem/mindspore
!45263 add map-parameter export&load support
Merge pull request !45263 from hangq/fire
commit 43c452f099
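In brief, the change lets a network that owns a MapParameter round-trip through MINDIR: export() serializes it into the new graph.map_parameter field and load() rebuilds it from MapTensorProto. A minimal usage sketch, adapted from the test added at the end of this diff (the class name ExportNet and the variable names are mine, for illustration only):

import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor, Parameter, export, load
from mindspore.common.initializer import initializer
from mindspore.experimental import MapParameter

class ExportNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.p = Parameter(initializer('ones', (2, 3), ms.float32))
        self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
        self.key = Tensor([1, 2], dtype=ms.int32)

    def construct(self, x):
        self.m.put(self.key, x)              # hash-map style insert
        return self.p + self.m.get(self.key)

context.set_context(mode=context.GRAPH_MODE)
net = ExportNet()
x = initializer('ones', (2, 3), ms.float32).init_data()
export(net, x, file_name="map-parameter.mindir", file_format="MINDIR")  # MapParameter lands in graph.map_parameter
load("map-parameter.mindir")  # parsed back via ImportMapParametersForGraph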
@@ -131,6 +131,9 @@ class IrExportBuilder {
   bool SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
   bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
+  bool ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
+                                           mind_ir::MapTensorProto *const map_tensor_proto);
+  bool ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
   bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
   bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
   bool SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
@@ -448,17 +451,29 @@ bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
     std::string param_name = GetUniqueNodeName(param);
     if (top_graph && param->has_default()) {
       MS_LOG(DEBUG) << "Parameter: '" << item->DebugString();
-      mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
-      parameter_proto->set_name(param_name);
-      if (!SetParamToTensorProto(param, parameter_proto)) {
-        MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
+      if (param->abstract()->isa<abstract::AbstractMapTensor>()) {
+        auto *map_parameter_proto = graph_proto->add_map_parameter();
+        if (!ConvertMapParameterToMapTensorProto(param, map_parameter_proto)) {
+          MS_LOG(ERROR) << "Convert MapParameter " << param->ToString() << " to MapTensorProto failed.";
+          return false;
+        }
+      } else if (param->abstract()->isa<abstract::AbstractTensor>()) {
+        mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
+        parameter_proto->set_name(param_name);
+        if (!SetParamToTensorProto(param, parameter_proto)) {
+          MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
+          return false;
+        }
+        auto tensor = param->default_param()->cast<tensor::TensorPtr>();
+        if (tensor != nullptr) {
+          parameter_proto->set_compression_type(
+            static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
+        }
+      } else {
+        MS_LOG(ERROR) << "Only support MapTensor or Tensor as default param of Parameter, got: "
+                      << param->default_param()->ToString();
         return false;
       }
-      auto tensor = param->default_param()->cast<tensor::TensorPtr>();
-      if (tensor != nullptr) {
-        parameter_proto->set_compression_type(
-          static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
-      }
     } else {
       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
       input_proto->set_name(param_name);
@@ -660,6 +675,101 @@ bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::
   return SetTensorProto(param->abstract(), tensor_proto);
 }
 
+bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
+                                                          mind_ir::MapTensorProto *const map_tensor_proto) {
+  if (map_parameter == nullptr || map_tensor_proto == nullptr) {
+    MS_LOG(EXCEPTION) << "MapParameter or MapTensorProto is null!";
+  }
+  MS_LOG(DEBUG) << "ConvertMapParameterToMapTensorProto: " << map_parameter->ToString();
+
+  // parameter name
+  map_tensor_proto->set_name(GetUniqueNodeName(map_parameter));
+
+  auto param_default = map_parameter->default_param();
+  MS_EXCEPTION_IF_NULL(param_default);
+  auto map_tensor = param_default->cast<tensor::MapTensorPtr>();
+  MS_EXCEPTION_IF_NULL(map_tensor);
+  // default value
+  auto default_value = map_tensor->default_value();
+  MS_EXCEPTION_IF_NULL(default_value);
+  auto *default_value_proto = map_tensor_proto->mutable_default_value();
+  MS_EXCEPTION_IF_NULL(default_value_proto);
+  if (!SetValueToAttributeProto(default_value, default_value_proto)) {
+    MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
+    return false;
+  }
+  tensor::MapTensor::ExportData export_data = map_tensor->Export(true);
+  // key_tensor
+  auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
+  MS_EXCEPTION_IF_NULL(key_tensor_proto);
+  auto &key_tensor = export_data.key_tensor;
+  MS_EXCEPTION_IF_NULL(key_tensor);
+  if (!SetTensorProto(key_tensor->ToAbstract(), key_tensor_proto)) {
+    MS_LOG(ERROR) << "Export key tensor of MapTensor failed, key_tensor: " << key_tensor->ToString();
+    return false;
+  }
+  // value_tensor
+  auto *value_tensor_proto = map_tensor_proto->mutable_value_tensor();
+  MS_EXCEPTION_IF_NULL(value_tensor_proto);
+  auto &value_tensor = export_data.value_tensor;
+  MS_EXCEPTION_IF_NULL(value_tensor);
+  if (!SetTensorProto(value_tensor->ToAbstract(), value_tensor_proto)) {
+    MS_LOG(ERROR) << "Export value tensor of MapTensor failed, value_tensor: " << value_tensor->ToString();
+    return false;
+  }
+  // status_tensor
+  auto *status_tensor_proto = map_tensor_proto->mutable_status_tensor();
+  MS_EXCEPTION_IF_NULL(status_tensor_proto);
+  auto &status_tensor = export_data.status_tensor;
+  MS_EXCEPTION_IF_NULL(status_tensor);
+  if (!SetTensorProto(status_tensor->ToAbstract(), status_tensor_proto)) {
+    MS_LOG(ERROR) << "Export status tensor of MapTensor failed, status_tensor: " << status_tensor->ToString();
+    return false;
+  }
+  return true;
+}
+
+bool IrExportBuilder::ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract,
+                                                          mind_ir::AttributeProto *const attr_proto) {
+  auto map_tensor_abs = abstract->cast<abstract::AbstractMapTensorPtr>();
+  MS_EXCEPTION_IF_NULL(map_tensor_abs);
+
+  auto map_tensor_type = map_tensor_abs->map_tensor_type();
+  MS_EXCEPTION_IF_NULL(map_tensor_type);
+  attr_proto->set_type(mind_ir::AttributeProto_AttributeType_MAP_TENSOR);
+  // key_tensor
+  auto key_dtype = map_tensor_type->key_dtype();
+  auto key_shape = {abstract::Shape::kShapeDimAny};
+  auto key_tensor_abs = std::make_shared<abstract::AbstractTensor>(key_dtype, key_shape);
+  auto *key_tensor_proto = attr_proto->add_tensors();
+  MS_EXCEPTION_IF_NULL(key_tensor_proto);
+  if (!SetTensorProto(key_tensor_abs, key_tensor_proto)) {
+    MS_LOG(ERROR) << "Export key tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
+                  << abstract->ToString();
+    return false;
+  }
+  // value_dtype value_shape
+  auto value_dtype = map_tensor_type->value_dtype();
+  auto value_shape = map_tensor_abs->value_shape()->shape();
+  auto value_tensor_abs = std::make_shared<abstract::AbstractTensor>(value_dtype, value_shape);
+  auto *value_tensor_proto = attr_proto->add_tensors();
+  MS_EXCEPTION_IF_NULL(value_tensor_proto);
+  if (!SetTensorProto(value_tensor_abs, value_tensor_proto)) {
+    MS_LOG(ERROR) << "Export value tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
+                  << abstract->ToString();
+    return false;
+  }
+  // default_value
+  auto default_value = map_tensor_abs->default_value();
+  auto *default_value_proto = attr_proto->add_values();
+  MS_EXCEPTION_IF_NULL(default_value_proto);
+  if (!SetValueToAttributeProto(default_value, default_value_proto)) {
+    MS_LOG(ERROR) << "Export default value of AbstractMapTensor failed, abstract_map_tensor: " << abstract->ToString();
+    return false;
+  }
+  return true;
+}
+
 bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
   for (const AnfNodePtr &node : nodes) {
@@ -799,6 +909,8 @@ bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir
     }
   } else if (type->isa<TypeNone>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
+  } else if (type->isa<MapTensorType>()) {
+    return ConvertAbstractMapTensorToAttrProto(abs, attr_proto);
   } else {
     MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
     return false;
@@ -31,4 +31,16 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
 abstract::AbstractBasePtr Null::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
 
 abstract::AbstractBasePtr Ellipsis::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
+
+abstract::AbstractBasePtr MindIRClassType::ToAbstract() {
+  return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRClassType>(), std::make_shared<TypeType>());
+}
+
+abstract::AbstractBasePtr MindIRNameSpace::ToAbstract() {
+  return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRNameSpace>(), std::make_shared<External>());
+}
+
+abstract::AbstractBasePtr MindIRSymbol::ToAbstract() {
+  return std::make_shared<abstract::AbstractScalar>(shared_from_base<MindIRSymbol>(), std::make_shared<External>());
+}
 }  // namespace mindspore
@@ -152,11 +152,12 @@ class MS_CORE_API Ellipsis final : public Named {
 };
 GVAR_DEF(NamedPtr, kEllipsis, std::make_shared<Ellipsis>());
 
-class MindIRClassType final : public Named {
+class MS_CORE_API MindIRClassType final : public Named {
  public:
   explicit MindIRClassType(const std::string &class_type) : Named(class_type) {}
   ~MindIRClassType() override = default;
   MS_DECLARE_PARENT(MindIRClassType, Named);
+  abstract::AbstractBasePtr ToAbstract() override;
 };
 using MindIRClassTypePtr = std::shared_ptr<MindIRClassType>;
@@ -168,24 +169,26 @@ class MindIRMetaFuncGraph final : public Named {
 };
 using MindIRMetaFuncGraphPtr = std::shared_ptr<MindIRMetaFuncGraph>;
 
-class MindIRNameSpace final : public Named {
+class MS_CORE_API MindIRNameSpace final : public Named {
  public:
   explicit MindIRNameSpace(const std::string &name_space) : Named(name_space), name_space_(name_space) {}
   ~MindIRNameSpace() override = default;
   MS_DECLARE_PARENT(MindIRNameSpace, Named);
   const std::string &name_space() const { return name_space_; }
+  abstract::AbstractBasePtr ToAbstract() override;
 
  private:
   std::string name_space_;
 };
 using MindIRNameSpacePtr = std::shared_ptr<MindIRNameSpace>;
 
-class MindIRSymbol final : public Named {
+class MS_CORE_API MindIRSymbol final : public Named {
  public:
   explicit MindIRSymbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {}
   ~MindIRSymbol() override = default;
   MS_DECLARE_PARENT(MindIRSymbol, Named);
   const std::string &symbol() const { return symbol_; }
+  abstract::AbstractBasePtr ToAbstract() override;
 
  private:
   std::string symbol_;
@@ -27,6 +27,7 @@
 #include <algorithm>
 #include "ir/tensor.h"
 #include "ir/param_info.h"
+#include "ir/map_tensor.h"
 #include "ops/primitive_c.h"
 #include "abstract/abstract_value.h"
 #include "abstract/ops/primitive_infer_map.h"
@@ -298,13 +299,13 @@ tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_
     auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor->data_c());
     errno_t ret = memcpy_s(tensor_data_buf, tensor->data().nbytes(), tensor_buf.data(), tensor_buf.size());
     if (ret != EOK) {
-      MS_LOG(ERROR) << "Failed to get tensor form tensor proto.";
+      MS_LOG(ERROR) << "Failed to copy data from tensor proto.";
       return nullptr;
     }
   } else if (attr_tensor.has_external_data()) {
     auto ret = GetTensorDataFromExternal(attr_tensor, tensor);
     if (!ret) {
-      MS_LOG(ERROR) << "Failed Load data from external.";
+      MS_LOG(ERROR) << "Failed to get external data from tensor proto.";
       return nullptr;
     }
   } else {
@@ -321,28 +322,13 @@ abstract::AbstractBasePtr MSANFModelParser::GetNodeAbstractFromAttrProtoWithType
       return GetAbsTensorFromTensorProto(attr_tensor);
     }
     case mind_ir::AttributeProto_AttributeType_CSR_TENSOR: {
-      std::vector<abstract::AbstractBasePtr> vec;
-      for (int i = 0; i < attr_proto.values_size(); ++i) {
-        auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
-        if (abs == nullptr) {
-          MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
-          return nullptr;
-        }
-        (void)vec.emplace_back(abs);
-      }
-      return std::make_shared<abstract::AbstractCSRTensor>(vec);
+      return BuildAbstractCSRTensorFromAttrProto(attr_proto);
     }
     case mind_ir::AttributeProto_AttributeType_COO_TENSOR: {
-      std::vector<abstract::AbstractBasePtr> vec;
-      for (int i = 0; i < attr_proto.values_size(); ++i) {
-        auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
-        if (abs == nullptr) {
-          MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
-          return nullptr;
-        }
-        (void)vec.emplace_back(abs);
-      }
-      return std::make_shared<abstract::AbstractCOOTensor>(vec);
+      return BuildAbstractCOOTensorFromAttrProto(attr_proto);
+    }
+    case mind_ir::AttributeProto_AttributeType_MAP_TENSOR: {
+      return BuildAbstractMapTensorFromAttrProto(attr_proto);
     }
     case mind_ir::AttributeProto_AttributeType_TUPLE: {
       std::vector<abstract::AbstractBasePtr> vec;
@@ -522,6 +508,146 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
   return true;
 }
 
+abstract::AbstractCOOTensorPtr MSANFModelParser::BuildAbstractCOOTensorFromAttrProto(
+  const mind_ir::AttributeProto &attr_proto) {
+  std::vector<abstract::AbstractBasePtr> vec;
+  for (int i = 0; i < attr_proto.values_size(); ++i) {
+    auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
+    if (abs == nullptr) {
+      MS_LOG(WARNING) << "Failed to get the COOTensor's abstract from AttrProto. " << attr_proto.DebugString();
+      return nullptr;
+    }
+    (void)vec.emplace_back(abs);
+  }
+  return std::make_shared<abstract::AbstractCOOTensor>(vec);
+}
+
+abstract::AbstractCSRTensorPtr MSANFModelParser::BuildAbstractCSRTensorFromAttrProto(
+  const mind_ir::AttributeProto &attr_proto) {
+  std::vector<abstract::AbstractBasePtr> vec;
+  for (int i = 0; i < attr_proto.values_size(); ++i) {
+    auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto.values(i));
+    if (abs == nullptr) {
+      MS_LOG(WARNING) << "Failed to get the CSRTensor's abstract from AttrProto. " << attr_proto.DebugString();
+      return nullptr;
+    }
+    (void)vec.emplace_back(abs);
+  }
+  return std::make_shared<abstract::AbstractCSRTensor>(vec);
+}
+
+abstract::AbstractMapTensorPtr MSANFModelParser::BuildAbstractMapTensorFromAttrProto(
+  const mind_ir::AttributeProto &attr_proto) {
+  // default value
+  if (attr_proto.values_size() != 1) {
+    MS_LOG(EXCEPTION) << "AttrProto for AbstractMapTensor should have 1 value, but got " << attr_proto.values_size();
+  }
+  const auto &default_value_proto = attr_proto.values(0);
+  auto default_value = ObtainCNodeAttrInSingleScalarForm(default_value_proto);
+  MS_EXCEPTION_IF_NULL(default_value);
+
+  constexpr size_t kAbstractMapTensorAttrProtoTensorsSize = 2;
+  if (attr_proto.tensors_size() != kAbstractMapTensorAttrProtoTensorsSize) {
+    MS_LOG(EXCEPTION) << "AttrProto for AbstractMapTensor should have 2 tensors, but got " << attr_proto.tensors_size();
+  }
+  // key tensor
+  const auto &key_tensor_proto = attr_proto.tensors(0);
+  auto key_tensor_abs = GetAbsTensorFromTensorProto(key_tensor_proto);
+  MS_EXCEPTION_IF_NULL(key_tensor_abs);
+  // value tensor
+  const auto &value_tensor_proto = attr_proto.tensors(1);
+  auto value_tensor_abs = GetAbsTensorFromTensorProto(value_tensor_proto);
+  MS_EXCEPTION_IF_NULL(value_tensor_abs);
+  auto value_build_shape_ptr = value_tensor_abs->BuildShape();
+  if (!value_build_shape_ptr->isa<abstract::Shape>()) {
+    MS_LOG(EXCEPTION) << "value_shape of AbstractMapTensor should be a Shape, but got "
+                      << value_build_shape_ptr->ToString();
+  }
+  auto value_shape_ptr = value_build_shape_ptr->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(value_shape_ptr);
+  auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor_abs->BuildType()->type_id(),
+                                                        value_tensor_abs->BuildType()->type_id(),
+                                                        value_shape_ptr->shape(), default_value);
+  return std::make_shared<abstract::AbstractMapTensor>(map_tensor);
+}
+
+bool MSANFModelParser::BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
+                                                           const mind_ir::MapTensorProto &map_parameter_proto) {
+  MS_EXCEPTION_IF_NULL(node);
+
+  if (!map_parameter_proto.has_name()) {
+    MS_LOG(ERROR) << "mind_ir MapTensorProto has no name!";
+    return false;
+  }
+
+  string debug_info_name = ParseParameterName(map_parameter_proto.name());
+  auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
+  node->set_debug_info(debug_info_ptr);
+  node->set_name(debug_info_name);
+
+  ParamInfoPtr param_info = std::make_shared<ParamInfo>();
+  param_info->set_name(debug_info_name);
+
+  MS_LOG(DEBUG) << "Load map parameter name: " << map_parameter_proto.name();
+  if (IsIncLoad() && load_tensor_map_.find(map_parameter_proto.name()) != load_tensor_map_.end()) {
+    MS_LOG(ERROR) << "MapParameter does not support incremental loading, param_name: " << map_parameter_proto.name();
+    return false;
+  }
+
+  // default value
+  if (!map_parameter_proto.has_default_value()) {
+    MS_LOG(ERROR) << "MapTensorProto should have default value: " << map_parameter_proto.name();
+    return false;
+  }
+  const auto &default_value_proto = map_parameter_proto.default_value();
+  auto default_value = BuildValueFromAttributeProto(default_value_proto);
+  if (default_value == nullptr) {
+    MS_LOG(ERROR) << "Build default value from AttributeProto failed.";
+    return false;
+  }
+  // key tensor
+  if (!map_parameter_proto.has_key_tensor()) {
+    MS_LOG(ERROR) << "MapTensorProto should have key tensor: " << map_parameter_proto.name();
+    return false;
+  }
+  const auto &key_tensor_proto = map_parameter_proto.key_tensor();
+  auto key_tensor = GenerateTensorPtrFromTensorProto(key_tensor_proto);
+  if (key_tensor == nullptr) {
+    MS_LOG(ERROR) << "Generate key tensor from TensorProto failed.";
+    return false;
+  }
+  // value tensor
+  if (!map_parameter_proto.has_value_tensor()) {
+    MS_LOG(ERROR) << "MapTensorProto should have value tensor: " << map_parameter_proto.name();
+    return false;
+  }
+  const auto &value_tensor_proto = map_parameter_proto.value_tensor();
+  auto value_tensor = GenerateTensorPtrFromTensorProto(value_tensor_proto);
+  if (value_tensor == nullptr) {
+    MS_LOG(ERROR) << "Generate value tensor from TensorProto failed.";
+    return false;
+  }
+  // status tensor
+  if (!map_parameter_proto.has_status_tensor()) {
+    MS_LOG(ERROR) << "MapTensorProto should have status tensor: " << map_parameter_proto.name();
+    return false;
+  }
+  const auto &status_tensor_proto = map_parameter_proto.status_tensor();
+  auto status_tensor = GenerateTensorPtrFromTensorProto(status_tensor_proto);
+  if (status_tensor == nullptr) {
+    MS_LOG(ERROR) << "Generate status tensor from TensorProto failed.";
+    return false;
+  }
+
+  auto map_tensor = std::make_shared<tensor::MapTensor>(key_tensor, value_tensor, status_tensor, default_value);
+  map_tensor->set_param_info(param_info);
+  node->set_default_param(map_tensor);
+  node->set_abstract(map_tensor->ToAbstract());
+
+  anfnode_build_map_[map_parameter_proto.name()] = node;
+  return true;
+}
+
 bool MSANFModelParser::GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto,
                                                  const tensor::TensorPtr &tensor_info) {
   if (!tensor_proto.has_external_data()) {
@@ -678,6 +804,21 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
   return true;
 }
 
+bool MSANFModelParser::ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph,
+                                                   const mind_ir::GraphProto &importProto) {
+  MS_EXCEPTION_IF_NULL(outputFuncGraph);
+  MS_LOG(INFO) << "All MapParameters size is: " << importProto.map_parameter_size();
+  for (int i = 0; i < importProto.map_parameter_size(); ++i) {
+    const mind_ir::MapTensorProto &map_parameter_proto = importProto.map_parameter(i);
+    if (!BuildMapParameterFromMapTensorProto(outputFuncGraph->add_parameter(), map_parameter_proto)) {
+      MS_LOG(ERROR) << "Build map parameter for funcgraph failed at index: " << i;
+      return false;
+    }
+  }
+  outputFuncGraph->set_fv_param_count(IntToSize(importProto.parameter_size()));
+  return true;
+}
+
 bool MSANFModelParser::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto) {
   MS_EXCEPTION_IF_NULL(prim);
   const int attr_tensor_type = attr_proto.tensors(0).data_type();
@@ -1040,14 +1181,6 @@ bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_n
   return true;
 }
 
-bool MSANFModelParser::ObtainValueNodeInTypeNullForm(const std::string &value_node_name) {
-  auto new_value_node = NewValueNode(kTypeNull);
-  MS_EXCEPTION_IF_NULL(new_value_node);
-  new_value_node->set_abstract(kTypeNull->ToAbstract());
-  anfnode_build_map_[value_node_name] = new_value_node;
-  return true;
-}
-
 bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name,
                                                   const mind_ir::AttributeProto &attr_proto) {
   const std::string &ref_attr_name = attr_proto.ref_attr_name();
@@ -1201,85 +1334,75 @@ ValuePtr MSANFModelParser::ObtainValueInSequenceForm(const mind_ir::AttributePro
   return value_sequence;
 }
 
-bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
-                                                        const mind_ir::AttributeProto &attr_proto) {
-  ValueNodePtr new_value_node;
+ValuePtr MSANFModelParser::BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto) {
   switch (attr_proto.type()) {
     case mind_ir::AttributeProto_AttributeType_TENSORS: {
-      mind_ir::TensorProto tensor_proto = attr_proto.tensors(0);
+      const auto &tensor_proto = attr_proto.tensors(0);
       if (tensor_proto.has_raw_data()) {
         // For real tensor.
-        (void)ObtainValueNodeInTensorForm(value_node_name, tensor_proto);
+        tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
+        if (tensor_info == nullptr) {
+          MS_LOG(ERROR) << "Failed to GenerateTensorPtrFromTensorProto.";
+          return nullptr;
+        }
+        return MakeValue(tensor_info);
       } else {
         // For data type.
-        (void)ObtainValueNodeInTypeForm(value_node_name, tensor_proto);
+        const int attr_tensor_type = tensor_proto.data_type();
+        auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
+        if (iter == kDefaultValueSwitchMap.end()) {
+          MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
+          return nullptr;
+        }
+        return TypeIdToType(iter->second);
       }
-      break;
     }
     case mind_ir::AttributeProto_AttributeType_NONE: {
-      (void)ObtainValueNodeInNoneForm(value_node_name);
-      break;
+      return kNone;
     }
     case mind_ir::AttributeProto_AttributeType_UMONAD: {
-      new_value_node = NewValueNode(kUMonad);
-      new_value_node->set_abstract(kUMonad->ToAbstract());
-      anfnode_build_map_[value_node_name] = new_value_node;
-      break;
+      return kUMonad;
    }
    case mind_ir::AttributeProto_AttributeType_IOMONAD: {
-      new_value_node = NewValueNode(kIOMonad);
-      new_value_node->set_abstract(kIOMonad->ToAbstract());
-      anfnode_build_map_[value_node_name] = new_value_node;
-      break;
+      return kIOMonad;
     }
     case mind_ir::AttributeProto_AttributeType_TUPLE:
     case mind_ir::AttributeProto_AttributeType_LIST: {
-      auto sequence_value = ObtainValueInSequenceForm(attr_proto);
-      if (sequence_value == nullptr) {
-        MS_LOG(ERROR) << "Failed to get sequence value for " << value_node_name;
-        return false;
-      }
-      new_value_node = NewValueNode(sequence_value);
-      new_value_node->set_abstract(sequence_value->ToAbstract());
-      anfnode_build_map_[value_node_name] = new_value_node;
-      break;
+      return ObtainValueInSequenceForm(attr_proto);
     }
     case mind_ir::AttributeProto_AttributeType_CLASS_TYPE: {
       auto class_type = static_cast<std::string>(attr_proto.s());
-      auto mindir_class_type = std::make_shared<MindIRClassType>(class_type);
-      new_value_node = NewValueNode(mindir_class_type);
-      anfnode_build_map_[value_node_name] = new_value_node;
-      break;
+      return std::make_shared<MindIRClassType>(class_type);
     }
     case mind_ir::AttributeProto_AttributeType_TYPE_NULL: {
-      (void)ObtainValueNodeInTypeNullForm(value_node_name);
-      break;
+      return kTypeNull;
     }
     case mind_ir::AttributeProto_AttributeType_NAME_SPACE: {
       auto name_space = static_cast<std::string>(attr_proto.s());
-      auto mindir_name_space = std::make_shared<MindIRNameSpace>(name_space);
-      new_value_node = NewValueNode(mindir_name_space);
-      anfnode_build_map_[value_node_name] = new_value_node;
-      break;
+      return std::make_shared<MindIRNameSpace>(name_space);
    }
    case mind_ir::AttributeProto_AttributeType_SYMBOL: {
       auto symbol = static_cast<std::string>(attr_proto.s());
-      auto mindir_symbol = std::make_shared<MindIRSymbol>(symbol);
-      new_value_node = NewValueNode(mindir_symbol);
-      anfnode_build_map_[value_node_name] = new_value_node;
-      break;
+      return std::make_shared<MindIRSymbol>(symbol);
     }
     default: {
-      ValuePtr value = ObtainCNodeAttrInSingleScalarForm(attr_proto);
-      if (value == nullptr) {
-        MS_LOG(ERROR) << "Can not get the value for attr: " << value_node_name;
-        return false;
-      }
-      new_value_node = NewValueNode(value);
-      new_value_node->set_abstract(value->ToAbstract());
-      anfnode_build_map_[value_node_name] = new_value_node;
+      return ObtainCNodeAttrInSingleScalarForm(attr_proto);
     }
   }
+}
+
+bool MSANFModelParser::GetAttrValueForValueNodeWithType(const std::string &value_node_name,
+                                                        const mind_ir::AttributeProto &attr_proto) {
+  auto value = BuildValueFromAttributeProto(attr_proto);
+  if (value == nullptr) {
+    MS_LOG(ERROR) << "Failed to build value from AttributeProto while building valuenode: " << value_node_name;
+    return false;
+  }
+  auto abstract = value->ToAbstract();
+  MS_EXCEPTION_IF_NULL(abstract);
+  ValueNodePtr new_value_node = NewValueNode(value);
+  new_value_node->set_abstract(abstract);
+  anfnode_build_map_[value_node_name] = new_value_node;
   return true;
 }
@@ -1650,11 +1773,15 @@ bool MSANFModelParser::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const
     MS_LOG(ERROR) << "Import parameters for graph fail!";
     return false;
   }
+  if (!ImportMapParametersForGraph(outputFuncGraph, importProto)) {
+    MS_LOG(ERROR) << "Import map parameters for graph failed!";
+    return false;
+  }
   if (ImportNodesForGraph(outputFuncGraph, importProto)) {
     MS_LOG(DEBUG) << "Success to parse graph: " << outputFuncGraph->ToString() << ": " << outputFuncGraph.get();
     return true;
   }
-  MS_LOG(ERROR) << "Failed to parse nodes. " << importProto.DebugString();
+  MS_LOG(ERROR) << "Failed to parse nodes. ";
   return false;
 }
@@ -91,8 +91,14 @@ class MSANFModelParser {
   bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool BuildAttrForFuncGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
+  bool ImportMapParametersForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const mind_ir::GraphProto &importProto);
   bool BuildParameterForFuncGraph(const ParameterPtr &node, const mind_ir::TensorProto &parameter_proto);
+  bool BuildMapParameterFromMapTensorProto(const ParameterPtr &node,
+                                           const mind_ir::MapTensorProto &map_parameter_proto);
+  abstract::AbstractMapTensorPtr BuildAbstractMapTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
+  abstract::AbstractCOOTensorPtr BuildAbstractCOOTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
+  abstract::AbstractCSRTensorPtr BuildAbstractCSRTensorFromAttrProto(const mind_ir::AttributeProto &attr_proto);
   bool SetValueForTopGraphParameter(const FuncGraphPtr &topGraph, const std::map<std::string, ValuePtr> &weights);
   bool GetTensorDataFromExternal(const mind_ir::TensorProto &tensor_proto, const tensor::TensorPtr &tensor_info);
   bool BuildInputForFuncGraph(const ParameterPtr &node, const mind_ir::ValueInfoProto &value_proto);
@@ -108,6 +114,7 @@ class MSANFModelParser {
   ValuePtr ObtainCNodeAttrInSingleScalarForm(const mind_ir::AttributeProto &attr_proto);
   bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const mind_ir::AttributeProto &attr_proto);
   bool BuildValueNodeForFuncGraph(const mind_ir::NodeProto &node_proto);
+  ValuePtr BuildValueFromAttributeProto(const mind_ir::AttributeProto &attr_proto);
   AnfNodePtr BuildOperatorNode(const mind_ir::NodeProto &node_proto);
   bool SetEmptyTensorProtoCNodeAbstract(const AnfNodePtr &node_ptr);
   void SetCNodeAbstract(const mind_ir::AttributeProto &attr_proto, const CNodePtr &cnode_ptr);
@@ -45,6 +45,7 @@ message AttributeProto {
     NAME_SPACE = 34;
     SYMBOL = 35;
     TYPE_NULL = 36;
+    MAP_TENSOR = 37;
   }
   optional string name = 1;
   optional float f = 2;
@@ -128,6 +129,7 @@ message GraphProto {
   optional string bprop_hash = 7;
   repeated AttributeProto attribute = 8;
   optional string bprop_filepath = 9;
+  repeated MapTensorProto map_parameter = 10;
 }
 
@@ -192,6 +194,14 @@ message TensorProto {
   repeated QuantParamProto quant_params = 17;
 }
 
+message MapTensorProto {
+  required string name = 1;
+  required AttributeProto default_value = 2;
+  required TensorProto key_tensor = 3;
+  required TensorProto value_tensor = 4;
+  required TensorProto status_tensor = 5;
+}
+
 message ParallelProto {
   repeated LayoutProto layout = 1;
 }
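For orientation, each entry in the new repeated map_parameter field is a MapTensorProto holding the exported key/value/status tensors plus the default value. A rough sketch of inspecting a saved file with protobuf bindings generated from mind_ir.proto; the module name mind_ir_pb2, the ModelProto message name, and the assumption that the .mindir file is a single serialized ModelProto are mine and may differ from the real build layout:

import numpy as np
from mind_ir_pb2 import ModelProto  # assumed name of the generated binding

model = ModelProto()
with open("map-parameter.mindir", "rb") as f:
    model.ParseFromString(f.read())

for map_param in model.graph.map_parameter:  # repeated MapTensorProto, field 10 of GraphProto
    # raw_data carries the bytes written by MapParameter.export_data() during export
    keys = np.frombuffer(map_param.key_tensor.raw_data, dtype=np.int32)        # int32 keys, as in the test
    values = np.frombuffer(map_param.value_tensor.raw_data, dtype=np.float32)  # float32 values, as in the test
    print(map_param.name, keys.shape, values.shape)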
@@ -1416,6 +1416,17 @@ def _save_mindir_together(net_dict, model, file_name, is_encrypt, **kwargs):
         else:
             logger.warning("The parameter '{}' is not belongs to any cell,the data of parameter cannot be exported."
                            .format(param_proto.name))
+    for map_param_proto in model.graph.map_parameter:
+        map_param_name = map_param_proto.name[map_param_proto.name.find(":") + 1:]
+        if map_param_name in net_dict.keys():
+            map_parameter = net_dict[map_param_name]
+            key_nparr, value_nparr, status_nparr = map_parameter.export_data()
+            map_param_proto.key_tensor.raw_data = key_nparr.tobytes()
+            map_param_proto.value_tensor.raw_data = value_nparr.tobytes()
+            map_param_proto.status_tensor.raw_data = status_nparr.tobytes()
+        else:
+            logger.warning("The map_parameter '{}' does not belong to any cell, its data cannot be exported."
+                           .format(map_param_proto.name))
     if not file_name.endswith('.mindir'):
         file_name += ".mindir"
     current_path = os.path.abspath(file_name)
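In other words, export_data() hands back three aligned NumPy arrays and the exporter stores their raw bytes unchanged in the three TensorProto fields. A tiny self-contained sketch of that byte-level round trip; the array contents and the int32 status dtype are assumed purely for illustration:

import numpy as np

key_nparr = np.array([1, 2], dtype=np.int32)        # keys
value_nparr = np.ones((2, 3), dtype=np.float32)     # one value row per key
status_nparr = np.zeros((2,), dtype=np.int32)       # per-key status flags (dtype assumed)

raw = value_nparr.tobytes()                         # exactly what is assigned to value_tensor.raw_data above
restored = np.frombuffer(raw, dtype=np.float32).reshape(value_nparr.shape)
assert np.array_equal(restored, value_nparr)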
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+import os.path
+
 import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
@@ -19,6 +21,7 @@ from mindspore import context, Tensor, Parameter, ParameterTuple
 from mindspore.experimental import MapParameter
 from mindspore.common.initializer import initializer
 from mindspore.ops import composite as C
+from mindspore import export, load
 
 
 def test_basic_operations():
@@ -237,3 +240,38 @@ def test_map_parameter_filter():
     net = MyNet()
     out = net()
     print("out:", out)
+
+
+def test_simple_graph_export_load():
+    """
+    Feature: MapParameter
+    Description: Test IR graph export and load with MapParameter.
+    Expectation: IR graph with MapParameter exported and loaded without exceptions.
+    """
+
+    class MyNet(nn.Cell):
+        def __init__(self):
+            nn.Cell.__init__(self)
+            self.p = Parameter(initializer('ones', (2, 3), ms.float32))
+            self.m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,))
+            self.key = Tensor([1, 2], dtype=ms.int32)
+
+        def construct(self, x):
+            self.m.put(self.key, x)
+            value1 = self.m.get(self.key)
+            value2 = self.m[self.key]
+            self.m[self.key] = value2
+            self.m.erase(self.key)
+            keys = self.m.get_keys()
+            values = self.m.get_values()
+            self.m.put(keys, values)
+            return self.p + value1 + value2
+
+    context.set_context(mode=context.GRAPH_MODE)
+    net = MyNet()
+    t = initializer('ones', (2, 3), ms.float32)
+    t = t.init_data()
+    file_path = "./map-parameter.mindir"
+    export(net, t, file_name=file_path, file_format="MINDIR")
+    assert os.path.isfile(file_path)
+    load(file_path)