!37190 fix some ops infer

Merge pull request !37190 from lianliguang/master
This commit is contained in:
i-robot 2022-07-05 09:13:23 +00:00 committed by Gitee
commit 96d9911a90
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 8 additions and 5 deletions

View File

@ -49,7 +49,8 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[0]->BuildType();
}
} // namespace

View File

@ -33,7 +33,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
return std::make_shared<TensorType>(kBool);
}
} // namespace
@ -47,8 +48,8 @@ AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const Primiti
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
(void)InferType(primitive, input_args);
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool);
auto type = InferType(primitive, input_args);
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), type);
}
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
} // namespace ops

View File

@ -49,7 +49,8 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("y", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
return input_args[0]->BuildType();
}
} // namespace