forked from mindspore-Ecosystem/mindspore
!49186 [MSLITE][Fix] fix prelu dtype check
Merge pull request !49186 from 赵英灼/fix_prelu_dtype_check
This commit is contained in:
commit
c99db1410d
|
@ -67,15 +67,9 @@ TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<Abstract
|
||||||
auto weight_type = input_args[kInputIndex1]->BuildType();
|
auto weight_type = input_args[kInputIndex1]->BuildType();
|
||||||
auto valid_types = {kFloat16, kFloat32};
|
auto valid_types = {kFloat16, kFloat32};
|
||||||
|
|
||||||
if (IsAscend()) {
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
|
|
||||||
} else {
|
|
||||||
std::map<std::string, TypePtr> args;
|
|
||||||
(void)args.emplace("x", x_type);
|
|
||||||
(void)args.emplace("weight", weight_type);
|
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
|
||||||
}
|
|
||||||
return x_type;
|
return x_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue