[MSLITE] add data type check for NegGrad

This commit is contained in:
ling 2022-01-14 16:18:09 +08:00
parent 5e7a38d1e2
commit 6a89231fd9
1 changed files with 6 additions and 0 deletions

View File

@ -40,6 +40,12 @@ int NegGradCPUKernel::Prepare() {
CHECK_LESS_RETURN(out_tensors_.size(), 1);
CHECK_NULL_RETURN(in_tensors_.at(0));
CHECK_NULL_RETURN(out_tensors_.at(0));
if (in_tensors_.at(kInputIndex)->data_type() != kNumberTypeFloat32 ||
out_tensors_.at(kOutputIndex)->data_type() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "illegal data type for NegGrad: " << in_tensors_.at(kInputIndex)->data_type() << ", "
<< out_tensors_.at(kOutputIndex)->data_type();
return RET_ERROR;
}
return RET_OK;
}