forked from mindspore-Ecosystem/mindspore
!19083 modify export air parameter format
Merge pull request !19083 from changzherui/mod_export_air_data_format
This commit is contained in:
commit
e1d7594fc6
|
@ -964,7 +964,17 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
|
||||||
std::ostringstream buf;
|
std::ostringstream buf;
|
||||||
buf << "[" << shape << "]";
|
buf << "[" << shape << "]";
|
||||||
MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
|
MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
|
||||||
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
|
std::string format = "NCHW";
|
||||||
|
if (it->isa<Parameter>()) {
|
||||||
|
auto param = it->cast<ParameterPtr>();
|
||||||
|
std::string param_name = param->DebugString();
|
||||||
|
auto param_format = param_format_.find(param_name);
|
||||||
|
if (param_format != param_format_.end()) {
|
||||||
|
format = param_format->second;
|
||||||
|
MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
|
||||||
if (desc == nullptr) {
|
if (desc == nullptr) {
|
||||||
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
|
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
|
||||||
} else {
|
} else {
|
||||||
|
@ -1660,6 +1670,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
||||||
}
|
}
|
||||||
|
|
||||||
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
|
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
|
||||||
|
SaveParamFormat(node);
|
||||||
std::string name = GetCNodeTargetFuncName(node);
|
std::string name = GetCNodeTargetFuncName(node);
|
||||||
if (!CheckCNode(name, node)) {
|
if (!CheckCNode(name, node)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1707,6 +1718,27 @@ OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
|
||||||
return op_cache_[node.get()];
|
return op_cache_[node.get()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
|
||||||
|
AnfNodePtr op = node->input(0);
|
||||||
|
if (IsValueNode<Primitive>(op)) {
|
||||||
|
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||||
|
for (auto attr : prim->attrs()) {
|
||||||
|
if (attr.first == "format" && attr.second->ToString() == "NCDHW") {
|
||||||
|
std::string format = attr.second->ToString();
|
||||||
|
auto inputs_size = node->size();
|
||||||
|
for (size_t i = 1; i < inputs_size; i++) {
|
||||||
|
auto input = node->input(i);
|
||||||
|
if (input->isa<Parameter>()) {
|
||||||
|
param_format_[input->DebugString()] = format;
|
||||||
|
MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
|
Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
ValuePtr value = node->value();
|
ValuePtr value = node->value();
|
||||||
|
|
|
@ -135,6 +135,7 @@ class DfGraphConvertor {
|
||||||
std::ostringstream checkpoint_sout_;
|
std::ostringstream checkpoint_sout_;
|
||||||
std::ostringstream restore_checkpoint_sout_;
|
std::ostringstream restore_checkpoint_sout_;
|
||||||
std::unordered_map<AnfNode *, std::string> op_draw_name_;
|
std::unordered_map<AnfNode *, std::string> op_draw_name_;
|
||||||
|
std::map<std::string, std::string> param_format_;
|
||||||
|
|
||||||
AnfNodePtr TraceTupleGetItem(const CNodePtr &node, uint64_t *index);
|
AnfNodePtr TraceTupleGetItem(const CNodePtr &node, uint64_t *index);
|
||||||
AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
|
AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
|
||||||
|
@ -148,6 +149,7 @@ class DfGraphConvertor {
|
||||||
OperatorPtr ConvertParameter(AnfNodePtr node);
|
OperatorPtr ConvertParameter(AnfNodePtr node);
|
||||||
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
|
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
|
||||||
OperatorPtr ConvertValueNode(ValueNodePtr node);
|
OperatorPtr ConvertValueNode(ValueNodePtr node);
|
||||||
|
void SaveParamFormat(CNodePtr node);
|
||||||
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
|
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
|
||||||
void ConvertTupleGetItem(const CNodePtr node);
|
void ConvertTupleGetItem(const CNodePtr node);
|
||||||
void ConvertMakeTuple(const CNodePtr node);
|
void ConvertMakeTuple(const CNodePtr node);
|
||||||
|
|
|
@ -306,13 +306,6 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
|
||||||
if (iter->second == "format") {
|
if (iter->second == "format") {
|
||||||
ValuePtr format = prim->GetAttr("format");
|
ValuePtr format = prim->GetAttr("format");
|
||||||
MS_EXCEPTION_IF_NULL(format);
|
MS_EXCEPTION_IF_NULL(format);
|
||||||
std::string type_name = prim->name();
|
|
||||||
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, "format", &format);
|
|
||||||
if (!converted) {
|
|
||||||
MS_LOG(ERROR) << "Fail to convert from attr value to string"
|
|
||||||
<< " for Op: " << type_name;
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
return GetValue<std::string>(format);
|
return GetValue<std::string>(format);
|
||||||
}
|
}
|
||||||
return iter->second;
|
return iter->second;
|
||||||
|
|
Loading…
Reference in New Issue