!18777 the bug of BiasAdd operator fixed.

Merge pull request !18777 from 沈竞兴/BiasAdd
This commit is contained in:
i-robot 2021-06-26 09:28:57 +00:00 committed by Gitee
commit 96dc8f8eee
7 changed files with 13 additions and 13 deletions

View File

@ -28,7 +28,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -53,14 +53,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
MS_EXCEPTION_IF_NULL(context_ptr);
auto is_ascend = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
if (data_format == Format::NCDHW && (input_shape.size() != 5 || !is_ascend)) {
MS_LOG(EXCEPTION) << "NCDHW format only support 5-dims input in Ascend target.";
MS_EXCEPTION(ValueError) << "NCDHW format only support 5-dims input in Ascend target.";
}
auto x_channel = data_format == Format::NHWC ? input_shape[input_shape.size() - 1] : input_shape[1];
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });
if (x_not_dyn && bias_shape[0] != x_channel) {
MS_EXCEPTION(ValueError) << "BiasAdd shape error, data format is " << data_format
<< ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
<< ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
}
CheckAndConvertUtils::CheckMinMaxShape(input_shape, &min_shape, &max_shape);
if (min_shape.size() != 0 && max_shape.size() != 0) {
@ -76,9 +76,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
MS_EXCEPTION_IF_NULL(item);
}
for (size_t i = 0; i < input_args.size(); i++) {
auto x_type_tmp = input_args[i]->BuildType();
MS_EXCEPTION_IF_NULL(x_type_tmp);
auto x_type = x_type_tmp->cast<TensorTypePtr>();
auto x_type = input_args[i]->BuildType();
MS_EXCEPTION_IF_NULL(x_type);
std::set<TypePtr> valid_x_type = {kTensorType};
CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_type, valid_x_type, prim_name);
@ -100,7 +98,9 @@ Format BiasAdd::get_format() const {
void BiasAdd::Init(const Format &format) { this->set_format(format); }
AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
auto infertype = InferType(primitive, input_args);
auto infershape = InferShape(primitive, input_args);
return abstract::MakeAbstract(infershape, infertype);
}
REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer, nullptr, true);
} // namespace ops

View File

@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, op_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -29,7 +29,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}

View File

@ -30,7 +30,7 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}