!37190 fix some ops infer
Merge pull request !37190 from lianliguang/master
This commit is contained in:
commit
96d9911a90
|
@ -49,7 +49,8 @@ TypePtr MulInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -33,7 +33,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, op_name);
|
||||
return std::make_shared<TensorType>(kBool);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -47,8 +48,8 @@ AbstractBasePtr NotEqualInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
(void)InferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), kBool);
|
||||
auto type = InferType(primitive, input_args);
|
||||
return abstract::MakeAbstract(BroadCastInferShape(op_name, input_args), type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameNotEqual, NotEqual);
|
||||
} // namespace ops
|
||||
|
|
|
@ -49,7 +49,8 @@ TypePtr RealDivInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("y", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
Loading…
Reference in New Issue