forked from mindspore-Ecosystem/mindspore
fix_lite_kernel
This commit is contained in:
parent
b33cc9d991
commit
48cdd7cd97
|
@ -20,7 +20,7 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
|
||||
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) {
|
||||
if (tensor->MutableData() == nullptr) {
|
||||
if (tensor->data_c() == nullptr) {
|
||||
MS_LOG(ERROR) << "Tensor data is null, cannot be parsed";
|
||||
return std::vector<StringPack>{};
|
||||
}
|
||||
|
|
|
@ -65,8 +65,12 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
|
|||
MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr";
|
||||
} else {
|
||||
if (TensorCategory(srcTensor) == Tensor::Category::CONST) {
|
||||
for (size_t j = 0; j < srcTensor->dims()->size(); j++) {
|
||||
shape.push_back(srcTensor->dims()->data()[j]);
|
||||
if (srcTensor->dataType() == kObjectTypeString && srcTensor->data() != nullptr) {
|
||||
shape.push_back(srcTensor->data()->size());
|
||||
} else {
|
||||
for (size_t j = 0; j < srcTensor->dims()->size(); j++) {
|
||||
shape.push_back(srcTensor->dims()->data()[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,11 +48,12 @@ int HashtableLookupCPUKernel::Run() {
|
|||
auto output_tensor = out_tensors_.at(0);
|
||||
auto hits_tensor = out_tensors_.at(1);
|
||||
|
||||
int rows = values_tensor->DimensionSize(0);
|
||||
int rows = GetStringCount(values_tensor);
|
||||
int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
|
||||
uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
|
||||
std::vector<lite::StringPack> output_string_pack;
|
||||
std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum());
|
||||
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor);
|
||||
lite::StringPack null_string_pack = {0, nullptr};
|
||||
|
||||
for (int i = 0; i < input_tensor->ElementsNum(); i++) {
|
||||
int index = -1;
|
||||
|
@ -61,11 +62,10 @@ int HashtableLookupCPUKernel::Run() {
|
|||
index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData());
|
||||
}
|
||||
if (index >= rows || index < 0) {
|
||||
lite::StringPack tmp = {0, nullptr};
|
||||
output_string_pack.push_back(tmp);
|
||||
output_string_pack[i] = null_string_pack;
|
||||
hits_data[i] = 0;
|
||||
} else {
|
||||
output_string_pack.push_back(all_string_pack[i]);
|
||||
output_string_pack[i] = all_string_pack[i];
|
||||
hits_data[i] = 1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,9 +88,10 @@ int PredictCPUKernel::Run() {
|
|||
if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) {
|
||||
output_label[i] = -1;
|
||||
output_weight[i] = 0.0f;
|
||||
} else {
|
||||
output_label[i] = label_info_vec[i].label;
|
||||
output_weight[i] = label_info_vec[i].weight;
|
||||
}
|
||||
output_label[i] = label_info_vec[i].label;
|
||||
output_weight[i] = label_info_vec[i].weight;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue