forked from mindspore-Ecosystem/mindspore
Bugfix of GetValue
when the attr is a list that mixed with Int32Imm and Int64Imm (unreasonable, but it occured), it will crash whether we use GetValue<std::vector<int>> or GetValue<std::vector<int64_t>>. so we need to traverse the list and pick the numbers manually.
This commit is contained in:
parent
1c44e367e0
commit
5edf2b3d92
|
@ -111,8 +111,8 @@ class OpInfoExtractor {
|
|||
op_attr->set_type("bool");
|
||||
} else if (v->isa<StringImm>()) {
|
||||
op_attr->set_type("str");
|
||||
} else if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
|
||||
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
|
||||
} else if (v->isa<ValueSequeue>()) {
|
||||
const auto &vec = v->cast<ValueSequeuePtr>()->value();
|
||||
if (vec.empty()) {
|
||||
op_attr->set_type("listInt");
|
||||
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
|
||||
|
@ -262,10 +262,14 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
|
|||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(op_attr);
|
||||
MS_EXCEPTION_IF_NULL(attr_json);
|
||||
|
||||
auto get_int_value = [](const ValuePtr &value) -> int {
|
||||
return value->isa<Int64Imm>() ? static_cast<int>(GetValue<int64_t>(value)) : GetValue<int>(value);
|
||||
};
|
||||
std::string type = op_attr->type();
|
||||
(*attr_json)[kJsonKeyDataType] = type;
|
||||
if (type == "int") {
|
||||
(*attr_json)[kJsonKeyValue] = static_cast<int>(GetValue<int64_t>(attr_value));
|
||||
(*attr_json)[kJsonKeyValue] = get_int_value(attr_value);
|
||||
} else if (type == "str") {
|
||||
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
|
||||
} else if (type == "bool") {
|
||||
|
@ -274,9 +278,8 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
|
|||
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
|
||||
} else if (type == "listInt") {
|
||||
std::vector<int> list_int;
|
||||
std::vector<int64_t> list_int_me = GetValue<std::vector<int64_t>>(attr_value);
|
||||
(void)std::transform(list_int_me.begin(), list_int_me.end(), std::back_inserter(list_int),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
|
||||
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
|
||||
(*attr_json)[kJsonKeyValue] = list_int;
|
||||
} else if (type == "listStr") {
|
||||
std::vector<std::string> data_format;
|
||||
|
|
Loading…
Reference in New Issue