!19083 modify export air parameter format

Merge pull request !19083 from changzherui/mod_export_air_data_format
This commit is contained in:
i-robot 2021-07-03 09:21:39 +00:00 committed by Gitee
commit e1d7594fc6
3 changed files with 35 additions and 8 deletions

View File

@ -964,7 +964,17 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
std::ostringstream buf;
buf << "[" << shape << "]";
MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
std::string format = "NCHW";
if (it->isa<Parameter>()) {
auto param = it->cast<ParameterPtr>();
std::string param_name = param->DebugString();
auto param_format = param_format_.find(param_name);
if (param_format != param_format_.end()) {
format = param_format->second;
MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
}
}
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
if (desc == nullptr) {
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
} else {
@ -1660,6 +1670,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
}
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
SaveParamFormat(node);
std::string name = GetCNodeTargetFuncName(node);
if (!CheckCNode(name, node)) {
return nullptr;
@ -1707,6 +1718,27 @@ OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
return op_cache_[node.get()];
}
void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
AnfNodePtr op = node->input(0);
if (IsValueNode<Primitive>(op)) {
auto prim = GetValueNode<PrimitivePtr>(op);
for (auto attr : prim->attrs()) {
if (attr.first == "format" && attr.second->ToString() == "NCDHW") {
std::string format = attr.second->ToString();
auto inputs_size = node->size();
for (size_t i = 1; i < inputs_size; i++) {
auto input = node->input(i);
if (input->isa<Parameter>()) {
param_format_[input->DebugString()] = format;
MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
}
}
continue;
}
}
}
}
Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
MS_EXCEPTION_IF_NULL(node);
ValuePtr value = node->value();

View File

@ -135,6 +135,7 @@ class DfGraphConvertor {
std::ostringstream checkpoint_sout_;
std::ostringstream restore_checkpoint_sout_;
std::unordered_map<AnfNode *, std::string> op_draw_name_;
std::map<std::string, std::string> param_format_;
AnfNodePtr TraceTupleGetItem(const CNodePtr &node, uint64_t *index);
AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
@ -148,6 +149,7 @@ class DfGraphConvertor {
OperatorPtr ConvertParameter(AnfNodePtr node);
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
OperatorPtr ConvertValueNode(ValueNodePtr node);
void SaveParamFormat(CNodePtr node);
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
void ConvertTupleGetItem(const CNodePtr node);
void ConvertMakeTuple(const CNodePtr node);

View File

@ -306,13 +306,6 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
if (iter->second == "format") {
ValuePtr format = prim->GetAttr("format");
MS_EXCEPTION_IF_NULL(format);
std::string type_name = prim->name();
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, "format", &format);
if (!converted) {
MS_LOG(ERROR) << "Fail to convert from attr value to string"
<< " for Op: " << type_name;
return ret;
}
return GetValue<std::string>(format);
}
return iter->second;