forked from mindspore-Ecosystem/mindspore
!18556 the bug of BiasAdd operator fixed.
Merge pull request !18556 from 沈竞兴/BiasAdd
This commit is contained in:
commit
6bbbb19dab
|
@ -31,7 +31,6 @@ namespace {
|
|||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_args.size(), kIncludeBoth, {2, 5}, prim_name);
|
||||
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto bias = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
|
@ -41,6 +40,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto input_shape = shape_map[kShape];
|
||||
auto min_shape = shape_map[kMinShape];
|
||||
auto max_shape = shape_map[kMaxShape];
|
||||
CheckAndConvertUtils::CheckInRange("bias_add_infer", input_shape.size(), kIncludeBoth, {2, 5}, prim_name);
|
||||
auto bias_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("bias rank", bias_shape.size(), kEqual, 1, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", input_shape.size(), kGreaterEqual, 2, prim_name);
|
||||
|
@ -59,8 +59,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
|
||||
[](int64_t value) { return value != abstract::Shape::SHP_ANY; });
|
||||
if (x_not_dyn && bias_shape[0] != x_channel) {
|
||||
MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
|
||||
<< ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
|
||||
MS_EXCEPTION(ValueError) << "BiasAdd shape error, data format is " << data_format
|
||||
<< ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
|
||||
}
|
||||
CheckAndConvertUtils::CheckMinMaxShape(input_shape, &min_shape, &max_shape);
|
||||
if (min_shape.size() != 0 && max_shape.size() != 0) {
|
||||
|
|
Loading…
Reference in New Issue