!15602 [GraphKernel]Bugfix about GetValue in akg_kernel_json_generator

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
This commit is contained in:
mindspore-ci-bot 2021-04-25 14:42:27 +08:00 committed by Gitee
commit 8de54f3355
1 changed files with 9 additions and 6 deletions

View File

@ -111,8 +111,8 @@ class OpInfoExtractor {
op_attr->set_type("bool"); op_attr->set_type("bool");
} else if (v->isa<StringImm>()) { } else if (v->isa<StringImm>()) {
op_attr->set_type("str"); op_attr->set_type("str");
} else if (v->isa<ValueList>() || v->isa<ValueTuple>()) { } else if (v->isa<ValueSequeue>()) {
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value(); const auto &vec = v->cast<ValueSequeuePtr>()->value();
if (vec.empty()) { if (vec.empty()) {
op_attr->set_type("listInt"); op_attr->set_type("listInt");
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) { } else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
@ -262,10 +262,14 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(op_attr); MS_EXCEPTION_IF_NULL(op_attr);
MS_EXCEPTION_IF_NULL(attr_json); MS_EXCEPTION_IF_NULL(attr_json);
auto get_int_value = [](const ValuePtr &value) -> int {
return value->isa<Int64Imm>() ? static_cast<int>(GetValue<int64_t>(value)) : GetValue<int>(value);
};
std::string type = op_attr->type(); std::string type = op_attr->type();
(*attr_json)[kJsonKeyDataType] = type; (*attr_json)[kJsonKeyDataType] = type;
if (type == "int") { if (type == "int") {
(*attr_json)[kJsonKeyValue] = static_cast<int>(GetValue<int64_t>(attr_value)); (*attr_json)[kJsonKeyValue] = get_int_value(attr_value);
} else if (type == "str") { } else if (type == "str") {
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value); (*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
} else if (type == "bool") { } else if (type == "bool") {
@ -274,9 +278,8 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value); (*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
} else if (type == "listInt") { } else if (type == "listInt") {
std::vector<int> list_int; std::vector<int> list_int;
std::vector<int64_t> list_int_me = GetValue<std::vector<int64_t>>(attr_value); const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
(void)std::transform(list_int_me.begin(), list_int_me.end(), std::back_inserter(list_int), (void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
[](const int64_t &value) { return static_cast<int>(value); });
(*attr_json)[kJsonKeyValue] = list_int; (*attr_json)[kJsonKeyValue] = list_int;
} else if (type == "listStr") { } else if (type == "listStr") {
std::vector<std::string> data_format; std::vector<std::string> data_format;