!29242 fix a bug about pynative mode error on Ascend

Merge pull request !29242 from 沈竞兴/codefix_r1.6
This commit is contained in:
i-robot 2022-02-25 12:32:02 +00:00 committed by Gitee
commit 6c8f580daa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 7 deletions

View File

@ -28,9 +28,6 @@ namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("gelu infer", SizeToLong(input_args.size()), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
@ -39,9 +36,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
return shape_ptr;
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto op_name = prim->name();
(void)CheckAndConvertUtils::CheckInteger("gelu infer", SizeToLong(input_args.size()), kEqual, 1, op_name);
std::map<std::string, TypePtr> types;
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
MS_EXCEPTION_IF_NULL(input_args[0]);
@ -51,7 +45,12 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
} // namespace
AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = InferType(primitive, input_args);
auto infer_shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true);
} // namespace ops