!18556 the bug of BiasAdd operator fixed.

Merge pull request !18556 from 沈竞兴/BiasAdd
This commit is contained in:
i-robot 2021-06-21 16:17:32 +08:00 committed by Gitee
commit 6bbbb19dab
1 changed files with 3 additions and 3 deletions

View File

@ -31,7 +31,6 @@ namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_args.size(), kIncludeBoth, {2, 5}, prim_name);
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto bias = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
MS_EXCEPTION_IF_NULL(x);
@ -41,6 +40,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto input_shape = shape_map[kShape];
auto min_shape = shape_map[kMinShape];
auto max_shape = shape_map[kMaxShape];
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {2, 5}, prim_name);
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::CheckInteger("bias rank", bias_shape.size(), kEqual, 1, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", input_shape.size(), kGreaterEqual, 2, prim_name);
@ -59,8 +59,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });
if (x_not_dyn && bias_shape[0] != x_channel) {
MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
<< ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
MS_EXCEPTION(ValueError) << "BiasAdd shape error, data format is " << data_format
<< ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
}
CheckAndConvertUtils::CheckMinMaxShape(input_shape, &min_shape, &max_shape);
if (min_shape.size() != 0 && max_shape.size() != 0) {