fix LSTM quant bug and fix ToD: check input param

This commit is contained in:
xutianchun 2021-03-19 14:54:47 +08:00
parent 8e8f3043f9
commit 8f495b7aa8
2 changed files with 14 additions and 2 deletions

View File

@ -159,6 +159,18 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
size_t size_backbone, const char *model_buf_head,
size_t size_head, lite::Context *context,
bool train_mode) {
auto ValidModelSize = [](size_t size) -> bool {
constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1G B
return size < MaxModelSize && size > 0;
};
if (!ValidModelSize(size_backbone)) {
MS_LOG(ERROR) << "size_backbone too large: " << size_backbone;
return nullptr;
}
if (!ValidModelSize(size_head)) {
MS_LOG(ERROR) << "size_head too large: " << size_head;
return nullptr;
}
auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
if (session == nullptr) {
MS_LOG(ERROR) << "create transfer session failed";

View File

@ -275,8 +275,8 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
ParamValueLitePtr param_value;
GetLiteParameter(weight_i, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
MS_LOG(INFO) << "LSTM input index " << index << " is not weight";
return RET_OK;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";