forked from mindspore-Ecosystem/mindspore
!43556 Use correct type infer for nllloss
Merge pull request !43556 from panzhihui/nllloss
This commit is contained in:
commit
549a6fa81e
|
@ -82,15 +82,13 @@ class NLLLossInfer : public abstract::OpInferBase {
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
const std::set valid_types = {kFloat16, kFloat32};
|
||||
auto target_type = input_args[kIndex1]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32}, prim->name());
|
||||
std::map<std::string, TypePtr> types;
|
||||
auto logits_data_type = input_args[kIndex0]->BuildType();
|
||||
(void)types.emplace("logits", logits_data_type);
|
||||
(void)types.emplace("weight", input_args[kIndex2]->BuildType());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits", types["logits"], valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", types["weight"], valid_types, prim->name());
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{logits_data_type, logits_data_type});
|
||||
auto target_type = input_args[kIndex1]->BuildType();
|
||||
auto weight_data_type = input_args[kIndex2]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("target", target_type, {kInt32}, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("logits", logits_data_type, valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("weight", weight_data_type, valid_types, prim->name());
|
||||
return std::make_shared<Tuple>(std::vector<TypePtr>{logits_data_type, weight_data_type});
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue