fix fse return status

This commit is contained in:
yeyunpeng2020 2021-09-30 15:32:36 +08:00
parent a65ee6cc81
commit f2337e4a77
2 changed files with 14 additions and 1 deletions

View File

@ -124,7 +124,11 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
return RET_OK;
}
quant::FSEEncoder fse_encoder;
fse_encoder.Compress(tensor_input);
auto status = fse_encoder.Compress(tensor_input);
if (status != RET_OK) {
MS_LOG(ERROR) << "fse encode compress failed." << status;
return RET_ERROR;
}
} else if (bit_num <= kBitNum8) {
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
} else {

View File

@ -293,6 +293,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
offset += sizeof(uint16_t);
@ -307,6 +308,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
offset += sizeof(uint16_t);
@ -315,6 +317,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
offset += sizeof(uint16_t);
@ -323,6 +326,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(float) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
offset += sizeof(float);
@ -331,6 +335,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
offset += sizeof(uint16_t);
@ -339,6 +344,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(uint64_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetChunks()[j];
offset += sizeof(uint64_t);
@ -346,12 +352,14 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
if (offset + sizeof(uint64_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetCurrChunk();
offset += sizeof(uint64_t);
if (offset + sizeof(uint8_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
return RET_ERROR;
}
*(reinterpret_cast<uint8_t *>(&out8[offset])) = (uint8_t)bs->GetCurrBitCount();
offset += sizeof(uint8_t);
@ -359,6 +367,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
tensor_input->data.resize(offset);
if (memcpy_s(tensor_input->data.data(), offset, out8, offset) != EOK) {
MS_LOG(ERROR) << "memcpy failed.";
return RET_ERROR;
}
}
tensor_input->quantParams.clear();