fix_lite_kernel

This commit is contained in:
sunsuodong 2020-10-20 15:37:00 +08:00
parent b33cc9d991
commit 48cdd7cd97
4 changed files with 15 additions and 10 deletions

View File

@ -20,7 +20,7 @@ namespace mindspore {
namespace lite {
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor) {
if (tensor->MutableData() == nullptr) {
if (tensor->data_c() == nullptr) {
MS_LOG(ERROR) << "Tensor data is null, cannot be parsed";
return std::vector<StringPack>{};
}

View File

@ -65,8 +65,12 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
MS_LOG(DEBUG) << "Dims of " << i << "th tensor is nullptr";
} else {
if (TensorCategory(srcTensor) == Tensor::Category::CONST) {
for (size_t j = 0; j < srcTensor->dims()->size(); j++) {
shape.push_back(srcTensor->dims()->data()[j]);
if (srcTensor->dataType() == kObjectTypeString && srcTensor->data() != nullptr) {
shape.push_back(srcTensor->data()->size());
} else {
for (size_t j = 0; j < srcTensor->dims()->size(); j++) {
shape.push_back(srcTensor->dims()->data()[j]);
}
}
}
}

View File

@ -48,11 +48,12 @@ int HashtableLookupCPUKernel::Run() {
auto output_tensor = out_tensors_.at(0);
auto hits_tensor = out_tensors_.at(1);
int rows = values_tensor->DimensionSize(0);
int rows = GetStringCount(values_tensor);
int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
std::vector<lite::StringPack> output_string_pack;
std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum());
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor);
lite::StringPack null_string_pack = {0, nullptr};
for (int i = 0; i < input_tensor->ElementsNum(); i++) {
int index = -1;
@ -61,11 +62,10 @@ int HashtableLookupCPUKernel::Run() {
index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData());
}
if (index >= rows || index < 0) {
lite::StringPack tmp = {0, nullptr};
output_string_pack.push_back(tmp);
output_string_pack[i] = null_string_pack;
hits_data[i] = 0;
} else {
output_string_pack.push_back(all_string_pack[i]);
output_string_pack[i] = all_string_pack[i];
hits_data[i] = 1;
}
}

View File

@ -88,9 +88,10 @@ int PredictCPUKernel::Run() {
if (static_cast<size_t>(i) >= label_info_vec.size() || label_info_vec[i].weight < param->weight_threshold) {
output_label[i] = -1;
output_weight[i] = 0.0f;
} else {
output_label[i] = label_info_vec[i].label;
output_weight[i] = label_info_vec[i].weight;
}
output_label[i] = label_info_vec[i].label;
output_weight[i] = label_info_vec[i].weight;
}
return RET_OK;
}