forked from mindspore-Ecosystem/mindspore
!16800 covert python to C++ in BiasAddGrad operator
From: @shen_jingxing Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
56bb6468df
|
@ -50,6 +50,7 @@ constexpr auto kMul = "Mul";
|
|||
constexpr auto kRealDiv = "RealDiv";
|
||||
constexpr auto kAdd = "Add";
|
||||
constexpr auto kTile = "Tile";
|
||||
constexpr auto kBiasAddGrad = "BiasAddGrad";
|
||||
|
||||
// Arrays
|
||||
constexpr auto kStack = "Stack";
|
||||
|
@ -274,7 +275,7 @@ inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
|
|||
std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
|
||||
inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared<Primitive>("DetectionPostProcess");
|
||||
inline const PrimitivePtr kPrimBiasAdd = std::make_shared<Primitive>("BiasAdd");
|
||||
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
|
||||
inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>(kBiasAddGrad);
|
||||
inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared<Primitive>("BiasSubGrad");
|
||||
inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>("BinaryCrossEntropy");
|
||||
inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -18,6 +18,7 @@
|
|||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
|
@ -26,36 +27,55 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void BiasAddGrad::Init(const Format format) { this->set_format(format); }
|
||||
|
||||
void BiasAddGrad::set_format(const Format format) {
|
||||
int64_t f = format;
|
||||
AddAttr(kFormat, MakeValue(f));
|
||||
namespace {
|
||||
std::vector<int64_t> GetFormatShape(const int64_t &format, const std::vector<int64_t> &input_shape) {
|
||||
std::vector<int64_t> output_shape;
|
||||
if (format == NHWC) {
|
||||
output_shape.push_back(input_shape.back());
|
||||
} else {
|
||||
output_shape.push_back(input_shape[1]);
|
||||
}
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
Format BiasAddGrad::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
return Format(GetValue<int64_t>(value_ptr));
|
||||
}
|
||||
|
||||
AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("bias_grad_infer", input_args.size(), kEqual, 1, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
|
||||
// Infer shape
|
||||
auto inshape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
for (size_t i = 0; i < inshape.size() - 1; i++) {
|
||||
inshape[i] = 1;
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
||||
// Infer type
|
||||
auto intype = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
|
||||
return std::make_shared<abstract::AbstractTensor>(intype, inshape);
|
||||
int64_t format = CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr("format"));
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto input_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
auto input_shape_ = GetFormatShape(format, input_shape);
|
||||
if (min_shape.size() != 0 && max_shape.size() != 0) {
|
||||
auto min_shape_ = GetFormatShape(format, min_shape);
|
||||
auto max_shape_ = GetFormatShape(format, max_shape);
|
||||
return std::make_shared<abstract::Shape>(input_shape_, min_shape_, max_shape_);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(input_shape_);
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameBiasAddGrad, BiasAddGrad);
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
CheckAndConvertUtils::CheckInteger("BiasAddGrad infer", input_args.size(), kEqual, 1, prim_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto x_type_map = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(x_type_map);
|
||||
auto x_type = x_type_map->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(x_type);
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer, nullptr, true);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,15 +26,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameBiasAddGrad = "BiasAddGrad";
|
||||
constexpr auto kNameBiasAddGrad = prim::kBiasAddGrad;
|
||||
class BiasAddGrad : public PrimitiveC {
|
||||
public:
|
||||
BiasAddGrad() : PrimitiveC(kNameBiasAddGrad) {}
|
||||
BiasAddGrad() : PrimitiveC(prim::kPrimBiasAddGrad->name()) { InitIOName({"x"}, {"output"}); }
|
||||
~BiasAddGrad() = default;
|
||||
MS_DECLARE_PARENT(BiasAddGrad, PrimitiveC);
|
||||
void Init(const Format format);
|
||||
void set_format(const Format format);
|
||||
Format get_format() const;
|
||||
void Init() {}
|
||||
};
|
||||
AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
|
|
|
@ -224,7 +224,7 @@ class SyncBatchNormGrad(PrimitiveWithInfer):
|
|||
return (x_type, scale_type, scale_type)
|
||||
|
||||
|
||||
class BiasAddGrad(PrimitiveWithInfer):
|
||||
class BiasAddGrad(Primitive):
|
||||
"""Computes gradients of BiasAdd."""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -237,13 +237,6 @@ class BiasAddGrad(PrimitiveWithInfer):
|
|||
self.format = "NCHW"
|
||||
self.add_prim_attr('data_format', self.format)
|
||||
|
||||
def infer_shape(self, d_output):
|
||||
channel = d_output[-1] if self.format == "NHWC" else d_output[1]
|
||||
return (channel,)
|
||||
|
||||
def infer_dtype(self, dout_dtype):
|
||||
return dout_dtype
|
||||
|
||||
|
||||
class KLDivLossGrad(PrimitiveWithInfer):
|
||||
"""Computes gradients for `KLDivLoss` operation."""
|
||||
|
|
Loading…
Reference in New Issue