!49186 [MSLITE][Fix] fix prelu dtype check

Merge pull request !49186 from 赵英灼/fix_prelu_dtype_check
This commit is contained in:
i-robot 2023-02-22 02:46:00 +00:00 committed by Gitee
commit c99db1410d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 3 additions and 9 deletions

View File

@ -67,15 +67,9 @@ TypePtr PReLUInferType(const PrimitivePtr &primitive, const std::vector<Abstract
auto weight_type = input_args[kInputIndex1]->BuildType(); auto weight_type = input_args[kInputIndex1]->BuildType();
auto valid_types = {kFloat16, kFloat32}; auto valid_types = {kFloat16, kFloat32};
if (IsAscend()) { (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name); (void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_type, valid_types, prim_name);
} else {
std::map<std::string, TypePtr> args;
(void)args.emplace("x", x_type);
(void)args.emplace("weight", weight_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
}
return x_type; return x_type;
} }