diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 696237d1a42..fe2f2e53d0a 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -50,6 +50,7 @@ constexpr auto kMul = "Mul";
 constexpr auto kRealDiv = "RealDiv";
 constexpr auto kAdd = "Add";
 constexpr auto kTile = "Tile";
+constexpr auto kBiasAddGrad = "BiasAddGrad";
 
 // Arrays
 constexpr auto kStack = "Stack";
@@ -274,7 +275,7 @@ inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
   std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
 inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess");
 inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd");
-inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
+inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>(kBiasAddGrad);
 inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad");
 inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy");
 inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad");
diff --git a/mindspore/core/ops/grad/bias_add_grad.cc b/mindspore/core/ops/grad/bias_add_grad.cc
index 9fec587f180..5d8f269a14e 100644
--- a/mindspore/core/ops/grad/bias_add_grad.cc
+++ b/mindspore/core/ops/grad/bias_add_grad.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 #include <string>
 #include <algorithm>
 #include <memory>
+#include <set>
 #include <vector>
 #include <map>
 #include "ops/op_utils.h"
@@ -26,36 +27,55 @@
 namespace mindspore {
 namespace ops {
-void BiasAddGrad::Init(const Format format) { this->set_format(format); }
-
-void BiasAddGrad::set_format(const Format format) {
-  int64_t f = format;
-  AddAttr(kFormat, MakeValue(f));
+namespace {
+std::vector<int64_t> GetFormatShape(const int64_t &format, const std::vector<int64_t> &input_shape) {
+  std::vector<int64_t> output_shape;
+  if (format == NHWC) {
+    output_shape.push_back(input_shape.back());
+  } else {
+    output_shape.push_back(input_shape[1]);
+  }
+  return output_shape;
 }
-
-Format BiasAddGrad::get_format() const {
-  auto value_ptr = GetAttr(kFormat);
-  return Format(GetValue<int64_t>(value_ptr));
-}
-
-AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const std::vector<AbstractBasePtr> &input_args) {
+abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name);
-  MS_EXCEPTION_IF_NULL(input_args[0]);
-
-  // Infer shape
-  auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  for (size_t i = 0; i < inshape.size() - 1; i++) {
-    inshape[i] = 1;
+  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
   }
-
-  // Infer type
-  auto intype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
-
-  return std::make_shared<abstract::AbstractTensor>(intype, inshape);
+  int64_t format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
+  auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
+  auto input_shape = shape_map[kShape];
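+  // kMinShape/kMaxShape below are populated only for dynamic-shape inputs; they
+  // carry the per-dimension lower/upper bounds and are reduced to the channel
+  // dimension in the same way as the static shape.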
+  auto min_shape = shape_map[kMinShape];
+  auto max_shape = shape_map[kMaxShape];
+  auto input_shape_ = GetFormatShape(format, input_shape);
+  if (min_shape.size() != 0 && max_shape.size() != 0) {
+    auto min_shape_ = GetFormatShape(format, min_shape);
+    auto max_shape_ = GetFormatShape(format, max_shape);
+    return std::make_shared<abstract::Shape>(input_shape_, min_shape_, max_shape_);
+  }
+  return std::make_shared<abstract::Shape>(input_shape_);
 }
-REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad);
+TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  CheckAndConvertUtils::CheckInteger("BiasAddGrad infer", input_args.size(), kEqual, 1, prim_name);
+  MS_EXCEPTION_IF_NULL(input_args[0]);
+  auto x_type_map = input_args[0]->BuildType();
+  MS_EXCEPTION_IF_NULL(x_type_map);
+  auto x_type = x_type_map->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(x_type);
+  std::set<TypePtr> valid_x_type = {kTensorType};
+  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
+}
+}  // namespace
+AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const std::vector<AbstractBasePtr> &input_args) {
+  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
+                                                    InferShape(primitive, input_args)->shape());
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer, nullptr, true);
+
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/bias_add_grad.h b/mindspore/core/ops/grad/bias_add_grad.h
index 7399a77ccad..b1df50986de 100644
--- a/mindspore/core/ops/grad/bias_add_grad.h
+++ b/mindspore/core/ops/grad/bias_add_grad.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,15 +26,13 @@
 namespace mindspore {
 namespace ops {
-constexpr auto kNameBiasAddGrad = "BiasAddGrad";
+constexpr auto kNameBiasAddGrad = prim::kBiasAddGrad;
 class BiasAddGrad : public PrimitiveC {
  public:
-  BiasAddGrad() : PrimitiveC(kNameBiasAddGrad) {}
+  BiasAddGrad() : PrimitiveC(prim::kPrimBiasAddGrad->name()) { InitIOName({"x"}, {"output"}); }
   ~BiasAddGrad() = default;
   MS_DECLARE_PARENT(BiasAddGrad, PrimitiveC);
-  void Init(const Format format);
-  void set_format(const Format format);
-  Format get_format() const;
+  void Init() {}
 };
 
 AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args);
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index bfb52fe1cfd..0bd64b50824 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -224,7 +224,7 @@ class SyncBatchNormGrad(PrimitiveWithInfer):
         return (x_type, scale_type, scale_type)
 
 
-class BiasAddGrad(PrimitiveWithInfer):
+class BiasAddGrad(Primitive):
     """Computes gradients of BiasAdd."""
 
     @prim_attr_register
@@ -237,13 +237,6 @@ class BiasAddGrad(PrimitiveWithInfer):
         self.format = "NCHW"
         self.add_prim_attr('data_format', self.format)
 
-    def infer_shape(self, d_output):
-        channel = d_output[-1] if self.format == "NHWC" else d_output[1]
-        return (channel,)
-
-    def infer_dtype(self, dout_dtype):
-        return dout_dtype
-
 
 class KLDivLossGrad(PrimitiveWithInfer):
     """Computes gradients for `KLDivLoss` operation."""
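With the Python-side `infer_shape`/`infer_dtype` removed, `BiasAddGrad` now takes its shape and dtype from the C++ `BiasAddGradInfer` registered via `REGISTER_PRIMITIVE_EVAL_IMPL`, which returns the rank-1 `(C,)` bias-gradient shape and propagates min/max shapes for dynamic-shape inputs. A minimal sketch of the user-visible behaviour, assuming a PyNative-mode MindSpore build (note `_grad_ops` is an internal module, as used in MindSpore's own tests):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

# BiasAddGrad reduces dout over every non-channel axis, so an NCHW input of
# shape (N, C, H, W) yields a bias gradient of shape (C,).
bias_add_grad = G.BiasAddGrad()  # data_format defaults to "NCHW"
dout = Tensor(np.ones((2, 3, 4, 4)), ms.float32)
db = bias_add_grad(dout)
print(db.shape)  # expected: (3,)
```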