forked from mindspore-Ecosystem/mindspore
!48611 fix sigmoid ops bug for NULL input
Merge pull request !48611 from zhangqi/0209
This commit is contained in:
commit
68a49590e4
|
@ -38,6 +38,9 @@ class SigmoidInfer : public abstract::OpInferBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||||
auto x_dtype = input_args[0]->BuildType();
|
auto x_dtype = input_args[0]->BuildType();
|
||||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
|
||||||
|
|
Loading…
Reference in New Issue