!13233 modify attr_target when load and export

From: @wangnan39
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-17 14:45:44 +08:00 committed by Gitee
commit 17654135aa
4 changed files with 54 additions and 2 deletions

View File

@ -430,7 +430,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
attr_proto->set_name(attr.first);
auto attr_value = attr.second;
CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value);
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
SetValueToAttributeProto(attr_value, attr_proto);
}
} else {

View File

@ -501,7 +501,7 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
const std::string &op_type = prim->name();
if (!IsLite()) {
CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res);
CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
}
prim->AddAttr(attr_name, res);
break;

View File

@ -253,6 +253,56 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
return true;
}
void ConvertTargetAttr(const std::string &attr_name, ValuePtr *const value) {
if (attr_name == "primitive_target") {
auto target_value = GetValue<std::string>(*value);
if (target_value == "CPU") {
*value = MakeValue<std::string>("host");
} else {
MS_LOG(EXCEPTION) << "The primitive_target only support CPU when export, but got " << target_value;
}
}
}
void RestoreTargetAttr(const std::string &attr_name, ValuePtr *const value) {
if (attr_name == "primitive_target") {
auto target_value = GetValue<std::string>(*value);
// compatible with exported model
if (target_value == "CPU") {
return;
}
if (target_value == "host") {
*value = MakeValue<std::string>("CPU");
} else {
MS_LOG(EXCEPTION) << "Invalid primitive_target value: " << target_value;
}
}
}
void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) {
if (value == nullptr || *value == nullptr) {
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
return;
}
// convert enum to string
ConvertAttrValueToString(op_type, attr_name, value);
// set cpu target as host
ConvertTargetAttr(attr_name, value);
}
void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
ValuePtr *const value) {
if (value == nullptr || *value == nullptr) {
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
return;
}
// convert string to enum
ConvertAttrValueToInt(op_type, attr_name, value);
// restore target as CPU
RestoreTargetAttr(attr_name, value);
}
namespace {
typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;

View File

@ -284,6 +284,8 @@ class CheckAndConvertUtils {
const std::string &prim_name);
static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);