[OPS] fix bug of hsigmoid infer

This commit is contained in:
yangruoqi713 2023-01-29 11:48:10 +08:00
parent 284d260d92
commit b8e7147ba7
1 changed files with 3 additions and 0 deletions

View File

@ -26,6 +26,7 @@ namespace ops {
namespace { namespace {
abstract::ShapePtr HSigmoidInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { abstract::ShapePtr HSigmoidInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1, primitive->name());
for (const auto &item : input_args) { for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item); MS_EXCEPTION_IF_NULL(item);
} }
@ -34,6 +35,8 @@ abstract::ShapePtr HSigmoidInferShape(const PrimitivePtr &primitive, const std::
} }
TypePtr HSigmoidInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) { TypePtr HSigmoidInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, 1, prim->name());
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) { if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "For '" << prim->name() MS_LOG(EXCEPTION) << "For '" << prim->name()
<< "', the input args used for infer shape and type is necessary, but missing it."; << "', the input args used for infer shape and type is necessary, but missing it.";