!43556 Use correct type infer for nllloss

Merge pull request !43556 from panzhihui/nllloss
This commit is contained in:
i-robot 2022-10-15 03:53:59 +00:00 committed by Gitee
commit 549a6fa81e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 8 deletions

View File

@ -82,15 +82,13 @@ class NLLLossInfer : public abstract::OpInferBase {
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
const std::set valid_types = {kFloat16, kFloat32};
auto target_type = input_args[kIndex1]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32}, prim->name());
std::map<std::string, TypePtr> types;
auto logits_data_type = input_args[kIndex0]->BuildType();
(void)types.emplace("logits", logits_data_type);
(void)types.emplace("weight", input_args[kIndex2]->BuildType());
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits", types["logits"], valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", types["weight"], valid_types, prim->name());
return std::make_shared<Tuple>(std::vector<TypePtr>{logits_data_type, logits_data_type});
auto target_type = input_args[kIndex1]->BuildType();
auto weight_data_type = input_args[kIndex2]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32}, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits", logits_data_type, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_data_type, valid_types, prim->name());
return std::make_shared<Tuple>(std::vector<TypePtr>{logits_data_type, weight_data_type});
}
};