forked from mindspore-Ecosystem/mindspore
!49186 [MSLITE][Fix] fix prelu dtype check
Merge pull request !49186 from 赵英灼/fix_prelu_dtype_check
This commit is contained in:
commit
c99db1410d
|
@ -67,15 +67,9 @@ TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
|||
auto weight_type = input_args[kInputIndex1]->BuildType();
|
||||
auto valid_types = {kFloat16, kFloat32};
|
||||
|
||||
if (IsAscend()) {
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
|
||||
} else {
|
||||
std::map<std::string, TypePtr> args;
|
||||
(void)args.emplace("x", x_type);
|
||||
(void)args.emplace("weight", weight_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
}
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
|
||||
|
||||
return x_type;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue