forked from mindspore-Ecosystem/mindspore
fix LSTM quant bug and fix ToD: check input param
This commit is contained in:
parent
8e8f3043f9
commit
8f495b7aa8
|
@ -159,6 +159,18 @@ session::TrainSession *session::TrainSession::CreateTransferSession(const char *
|
|||
size_t size_backbone, const char *model_buf_head,
|
||||
size_t size_head, lite::Context *context,
|
||||
bool train_mode) {
|
||||
auto ValidModelSize = [](size_t size) -> bool {
|
||||
constexpr size_t MaxModelSize = 1024 * 1024 * 1024ULL; // 1G B
|
||||
return size < MaxModelSize && size > 0;
|
||||
};
|
||||
if (!ValidModelSize(size_backbone)) {
|
||||
MS_LOG(ERROR) << "size_backbone too large: " << size_backbone;
|
||||
return nullptr;
|
||||
}
|
||||
if (!ValidModelSize(size_head)) {
|
||||
MS_LOG(ERROR) << "size_head too large: " << size_head;
|
||||
return nullptr;
|
||||
}
|
||||
auto session = new (std::nothrow) lite::TransferSession(model_buf_backbone, size_backbone, context);
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "create transfer session failed";
|
||||
|
|
|
@ -275,8 +275,8 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
|
|||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_i, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
MS_LOG(INFO) << "LSTM input index " << index << " is not weight";
|
||||
return RET_OK;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
|
|
Loading…
Reference in New Issue