!13233 modify attr_target when load and export
From: @wangnan39 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
17654135aa
|
@ -430,7 +430,7 @@ void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *cons
|
|||
mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
|
||||
attr_proto->set_name(attr.first);
|
||||
auto attr_value = attr.second;
|
||||
CheckAndConvertUtils::ConvertAttrValueToString(type_name, attr.first, &attr_value);
|
||||
CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
|
||||
SetValueToAttributeProto(attr_value, attr_proto);
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -501,7 +501,7 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
|
|||
ValuePtr res = ObtainCNodeAttrInSingleScalarForm(attr_proto);
|
||||
const std::string &op_type = prim->name();
|
||||
if (!IsLite()) {
|
||||
CheckAndConvertUtils::ConvertAttrValueToInt(op_type, attr_name, &res);
|
||||
CheckAndConvertUtils::ConvertAttrValueInLoad(op_type, attr_name, &res);
|
||||
}
|
||||
prim->AddAttr(attr_name, res);
|
||||
break;
|
||||
|
|
|
@ -253,6 +253,56 @@ bool CheckAndConvertUtils::ConvertAttrValueToString(const std::string &op_type,
|
|||
return true;
|
||||
}
|
||||
|
||||
void ConvertTargetAttr(const std::string &attr_name, ValuePtr *const value) {
|
||||
if (attr_name == "primitive_target") {
|
||||
auto target_value = GetValue<std::string>(*value);
|
||||
if (target_value == "CPU") {
|
||||
*value = MakeValue<std::string>("host");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The primitive_target only support CPU when export, but got " << target_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RestoreTargetAttr(const std::string &attr_name, ValuePtr *const value) {
|
||||
if (attr_name == "primitive_target") {
|
||||
auto target_value = GetValue<std::string>(*value);
|
||||
// compatible with exported model
|
||||
if (target_value == "CPU") {
|
||||
return;
|
||||
}
|
||||
if (target_value == "host") {
|
||||
*value = MakeValue<std::string>("CPU");
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid primitive_target value: " << target_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckAndConvertUtils::ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name,
|
||||
ValuePtr *const value) {
|
||||
if (value == nullptr || *value == nullptr) {
|
||||
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
|
||||
return;
|
||||
}
|
||||
// convert enum to string
|
||||
ConvertAttrValueToString(op_type, attr_name, value);
|
||||
// set cpu target as host
|
||||
ConvertTargetAttr(attr_name, value);
|
||||
}
|
||||
|
||||
void CheckAndConvertUtils::ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name,
|
||||
ValuePtr *const value) {
|
||||
if (value == nullptr || *value == nullptr) {
|
||||
MS_LOG(INFO) << "value is nullptr! op_type = " << op_type << ", attr_name = " << attr_name;
|
||||
return;
|
||||
}
|
||||
// convert string to enum
|
||||
ConvertAttrValueToInt(op_type, attr_name, value);
|
||||
// restore target as CPU
|
||||
RestoreTargetAttr(attr_name, value);
|
||||
}
|
||||
|
||||
namespace {
|
||||
typedef std::map<std::string, std::function<ValuePtr(ValuePtr)>> AttrFunction;
|
||||
|
||||
|
|
|
@ -284,6 +284,8 @@ class CheckAndConvertUtils {
|
|||
const std::string &prim_name);
|
||||
static bool ConvertAttrValueToInt(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
static bool ConvertAttrValueToString(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
static void ConvertAttrValueInExport(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
static void ConvertAttrValueInLoad(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
|
||||
static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
|
||||
static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
|
||||
static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
|
||||
|
|
Loading…
Reference in New Issue