forked from mindspore-Ecosystem/mindspore
!19083 modify export air parameter format
Merge pull request !19083 from changzherui/mod_export_air_data_format
This commit is contained in:
commit
e1d7594fc6
|
@ -964,7 +964,17 @@ void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr
|
|||
std::ostringstream buf;
|
||||
buf << "[" << shape << "]";
|
||||
MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
|
||||
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
|
||||
std::string format = "NCHW";
|
||||
if (it->isa<Parameter>()) {
|
||||
auto param = it->cast<ParameterPtr>();
|
||||
std::string param_name = param->DebugString();
|
||||
auto param_format = param_format_.find(param_name);
|
||||
if (param_format != param_format_.end()) {
|
||||
format = param_format->second;
|
||||
MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
|
||||
}
|
||||
}
|
||||
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
|
||||
if (desc == nullptr) {
|
||||
MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
|
||||
} else {
|
||||
|
@ -1660,6 +1670,7 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
|||
}
|
||||
|
||||
OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
|
||||
SaveParamFormat(node);
|
||||
std::string name = GetCNodeTargetFuncName(node);
|
||||
if (!CheckCNode(name, node)) {
|
||||
return nullptr;
|
||||
|
@ -1707,6 +1718,27 @@ OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
|
|||
return op_cache_[node.get()];
|
||||
}
|
||||
|
||||
void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
|
||||
AnfNodePtr op = node->input(0);
|
||||
if (IsValueNode<Primitive>(op)) {
|
||||
auto prim = GetValueNode<PrimitivePtr>(op);
|
||||
for (auto attr : prim->attrs()) {
|
||||
if (attr.first == "format" && attr.second->ToString() == "NCDHW") {
|
||||
std::string format = attr.second->ToString();
|
||||
auto inputs_size = node->size();
|
||||
for (size_t i = 1; i < inputs_size; i++) {
|
||||
auto input = node->input(i);
|
||||
if (input->isa<Parameter>()) {
|
||||
param_format_[input->DebugString()] = format;
|
||||
MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
ValuePtr value = node->value();
|
||||
|
|
|
@ -135,6 +135,7 @@ class DfGraphConvertor {
|
|||
std::ostringstream checkpoint_sout_;
|
||||
std::ostringstream restore_checkpoint_sout_;
|
||||
std::unordered_map<AnfNode *, std::string> op_draw_name_;
|
||||
std::map<std::string, std::string> param_format_;
|
||||
|
||||
AnfNodePtr TraceTupleGetItem(const CNodePtr &node, uint64_t *index);
|
||||
AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index);
|
||||
|
@ -148,6 +149,7 @@ class DfGraphConvertor {
|
|||
OperatorPtr ConvertParameter(AnfNodePtr node);
|
||||
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
|
||||
OperatorPtr ConvertValueNode(ValueNodePtr node);
|
||||
void SaveParamFormat(CNodePtr node);
|
||||
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
|
||||
void ConvertTupleGetItem(const CNodePtr node);
|
||||
void ConvertMakeTuple(const CNodePtr node);
|
||||
|
|
|
@ -306,13 +306,6 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
|
|||
if (iter->second == "format") {
|
||||
ValuePtr format = prim->GetAttr("format");
|
||||
MS_EXCEPTION_IF_NULL(format);
|
||||
std::string type_name = prim->name();
|
||||
bool converted = CheckAndConvertUtils::ConvertAttrValueToString(type_name, "format", &format);
|
||||
if (!converted) {
|
||||
MS_LOG(ERROR) << "Fail to convert from attr value to string"
|
||||
<< " for Op: " << type_name;
|
||||
return ret;
|
||||
}
|
||||
return GetValue<std::string>(format);
|
||||
}
|
||||
return iter->second;
|
||||
|
|
Loading…
Reference in New Issue