forked from mindspore-Ecosystem/mindspore
fix fse return status
This commit is contained in:
parent
a65ee6cc81
commit
f2337e4a77
|
@ -124,7 +124,11 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
|
|||
return RET_OK;
|
||||
}
|
||||
quant::FSEEncoder fse_encoder;
|
||||
fse_encoder.Compress(tensor_input);
|
||||
auto status = fse_encoder.Compress(tensor_input);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "fse encode compress failed." << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (bit_num <= kBitNum8) {
|
||||
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
|
||||
} else {
|
||||
|
|
|
@ -293,6 +293,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(uint16_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
|
||||
offset += sizeof(uint16_t);
|
||||
|
@ -307,6 +308,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(uint16_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
|
||||
offset += sizeof(uint16_t);
|
||||
|
@ -315,6 +317,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(uint16_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
|
||||
offset += sizeof(uint16_t);
|
||||
|
@ -323,6 +326,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(float) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
|
||||
offset += sizeof(float);
|
||||
|
@ -331,6 +335,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(uint16_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
|
||||
offset += sizeof(uint16_t);
|
||||
|
@ -339,6 +344,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(uint64_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetChunks()[j];
|
||||
offset += sizeof(uint64_t);
|
||||
|
@ -346,12 +352,14 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
if (offset + sizeof(uint64_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetCurrChunk();
|
||||
offset += sizeof(uint64_t);
|
||||
if (offset + sizeof(uint8_t) > max_size) {
|
||||
MS_LOG(ERROR) << "offset over max size"
|
||||
<< " offset:" << offset << " max_size:" << max_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
*(reinterpret_cast<uint8_t *>(&out8[offset])) = (uint8_t)bs->GetCurrBitCount();
|
||||
offset += sizeof(uint8_t);
|
||||
|
@ -359,6 +367,7 @@ int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, c
|
|||
tensor_input->data.resize(offset);
|
||||
if (memcpy_s(tensor_input->data.data(), offset, out8, offset) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
tensor_input->quantParams.clear();
|
||||
|
|
Loading…
Reference in New Issue