!17936 fix bug of relu grad

From: @lianliguang
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-06-08 14:25:01 +08:00 committed by Gitee
commit 3828b819d4
1 changed files with 10 additions and 7 deletions

View File

@ -47,22 +47,25 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
CheckAndConvertUtils::CheckInteger("ReLUGrad infer", input_args.size(), kEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("ReLUGrad infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
abstract::CheckDtypeSame(prim_name, out, dout);
auto x_type_map = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type_map);
auto x_type = x_type_map->cast<TensorTypePtr>();
auto x_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
std::set<TypePtr> valid_x_type = {kTensorType};
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
if (!x_type->isa<TensorType>()) {
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s "
<< " input must be tensor type but got " << x_type->ToString();
}
return x_type;
}
} // namespace
AbstractBasePtr ReLUGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
auto type = InferType(primitive, input_args);
auto shape = InferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGrad, prim::kPrimReluGrad, ReLUGradInfer, nullptr, true);