forked from mindspore-Ecosystem/mindspore
!17936 fix bug of relu grad
From: @lianliguang Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
3828b819d4
|
@ -47,22 +47,25 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("ReLUGrad infer", input_args.size(), kEqual, 2, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("ReLUGrad infer", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
abstract::CheckDtypeSame(prim_name, out, dout);
|
||||
auto x_type_map = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type_map);
|
||||
auto x_type = x_type_map->cast<TensorTypePtr>();
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
|
||||
if (!x_type->isa<TensorType>()) {
|
||||
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s "
|
||||
<< " input must be tensor type but got " << x_type->ToString();
|
||||
}
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr ReLUGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
auto type = InferType(primitive, input_args);
|
||||
auto shape = InferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(shape, type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGrad, prim::kPrimReluGrad, ReLUGradInfer, nullptr, true);
|
||||
|
||||
|
|
Loading…
Reference in New Issue