forked from mindspore-Ecosystem/mindspore
!15602 [GraphKernel]Bugfix about GetValue in akg_kernel_json_generator
From: @dayschan Reviewed-by: @gaoxiong1,@dylangeng Signed-off-by: @dylangeng
This commit is contained in:
commit
8de54f3355
|
@ -111,8 +111,8 @@ class OpInfoExtractor {
|
||||||
op_attr->set_type("bool");
|
op_attr->set_type("bool");
|
||||||
} else if (v->isa<StringImm>()) {
|
} else if (v->isa<StringImm>()) {
|
||||||
op_attr->set_type("str");
|
op_attr->set_type("str");
|
||||||
} else if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
|
} else if (v->isa<ValueSequeue>()) {
|
||||||
auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
|
const auto &vec = v->cast<ValueSequeuePtr>()->value();
|
||||||
if (vec.empty()) {
|
if (vec.empty()) {
|
||||||
op_attr->set_type("listInt");
|
op_attr->set_type("listInt");
|
||||||
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
|
} else if (vec[0]->isa<Int32Imm>() || vec[0]->isa<Int64Imm>()) {
|
||||||
|
@ -262,10 +262,14 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
|
||||||
MS_EXCEPTION_IF_NULL(anf_node);
|
MS_EXCEPTION_IF_NULL(anf_node);
|
||||||
MS_EXCEPTION_IF_NULL(op_attr);
|
MS_EXCEPTION_IF_NULL(op_attr);
|
||||||
MS_EXCEPTION_IF_NULL(attr_json);
|
MS_EXCEPTION_IF_NULL(attr_json);
|
||||||
|
|
||||||
|
auto get_int_value = [](const ValuePtr &value) -> int {
|
||||||
|
return value->isa<Int64Imm>() ? static_cast<int>(GetValue<int64_t>(value)) : GetValue<int>(value);
|
||||||
|
};
|
||||||
std::string type = op_attr->type();
|
std::string type = op_attr->type();
|
||||||
(*attr_json)[kJsonKeyDataType] = type;
|
(*attr_json)[kJsonKeyDataType] = type;
|
||||||
if (type == "int") {
|
if (type == "int") {
|
||||||
(*attr_json)[kJsonKeyValue] = static_cast<int>(GetValue<int64_t>(attr_value));
|
(*attr_json)[kJsonKeyValue] = get_int_value(attr_value);
|
||||||
} else if (type == "str") {
|
} else if (type == "str") {
|
||||||
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
|
(*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
|
||||||
} else if (type == "bool") {
|
} else if (type == "bool") {
|
||||||
|
@ -274,9 +278,8 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
|
||||||
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
|
(*attr_json)[kJsonKeyValue] = GetValue<float>(attr_value);
|
||||||
} else if (type == "listInt") {
|
} else if (type == "listInt") {
|
||||||
std::vector<int> list_int;
|
std::vector<int> list_int;
|
||||||
std::vector<int64_t> list_int_me = GetValue<std::vector<int64_t>>(attr_value);
|
const auto &vals = attr_value->cast<ValueSequeuePtr>()->value();
|
||||||
(void)std::transform(list_int_me.begin(), list_int_me.end(), std::back_inserter(list_int),
|
(void)std::transform(vals.begin(), vals.end(), std::back_inserter(list_int), get_int_value);
|
||||||
[](const int64_t &value) { return static_cast<int>(value); });
|
|
||||||
(*attr_json)[kJsonKeyValue] = list_int;
|
(*attr_json)[kJsonKeyValue] = list_int;
|
||||||
} else if (type == "listStr") {
|
} else if (type == "listStr") {
|
||||||
std::vector<std::string> data_format;
|
std::vector<std::string> data_format;
|
||||||
|
|
Loading…
Reference in New Issue