!18777 Fix a bug in the BiasAdd operator.
Merge pull request !18777 from 沈竞兴/BiasAdd
commit 96dc8f8eee
@@ -28,7 +28,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
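For context, the tightened check reflects that BiasAdd's infer routine works with exactly two inputs, the data tensor x and the bias, so "at least 2" becomes "exactly 2". A minimal standalone sketch of the same arity constraint in plain C++, not part of the diff; the helper name ValidateBiasAddInputCount is made up for illustration:

#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for CheckAndConvertUtils::CheckInteger(..., kEqual, 2, ...):
// BiasAdd expects exactly two inputs (x and bias).
void ValidateBiasAddInputCount(std::size_t num_inputs, const std::string &prim_name) {
  if (num_inputs != 2) {
    throw std::invalid_argument(prim_name + ": expected exactly 2 inputs (x, bias), got " +
                                std::to_string(num_inputs));
  }
}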
@@ -53,14 +53,14 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(context_ptr);
   auto is_ascend = (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
   if (data_format == Format::NCDHW && (input_shape.size() != 5 || !is_ascend)) {
-    MS_LOG(EXCEPTION) << "NCDHW format only support 5-dims input in Ascend target.";
+    MS_EXCEPTION(ValueError) << "NCDHW format only support 5-dims input in Ascend target.";
   }
   auto x_channel = data_format == Format::NHWC ? input_shape[input_shape.size() - 1] : input_shape[1];
   bool x_not_dyn = std::all_of(input_shape.begin(), input_shape.end(),
                                [](int64_t value) { return value != abstract::Shape::SHP_ANY; });
   if (x_not_dyn && bias_shape[0] != x_channel) {
-    MS_EXCEPTION(ValueError) << "BiasAdd shape error, data format is " << data_format
-                             << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
+    MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
+                      << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
   }
   CheckAndConvertUtils::CheckMinMaxShape(input_shape, &min_shape, &max_shape);
   if (min_shape.size() != 0 && max_shape.size() != 0) {
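The unchanged context in this hunk shows how the shape check works: the channel dimension of x is the last axis for NHWC and axis 1 otherwise, and bias_shape[0] must equal it unless x has a dynamic dimension (abstract::Shape::SHP_ANY). A rough standalone sketch of that logic, with the Format enum and the dynamic-dimension sentinel simplified for illustration; this is not MindSpore's actual API:

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace sketch {
constexpr int64_t kDynamicDim = -1;  // stand-in for abstract::Shape::SHP_ANY
enum class Format { NHWC, NCHW, NCDHW };

// Mirrors the check in InferShape above: the (1-D) bias length must equal the
// channel dimension of x, unless any dimension of x is dynamic.
void CheckBiasShape(const std::vector<int64_t> &x_shape, const std::vector<int64_t> &bias_shape,
                    Format data_format) {
  const int64_t x_channel =
      data_format == Format::NHWC ? x_shape[x_shape.size() - 1] : x_shape[1];
  const bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(),
                                     [](int64_t v) { return v != kDynamicDim; });
  if (x_not_dyn && bias_shape[0] != x_channel) {
    throw std::invalid_argument("BiasAdd shape error: bias_shape[0] does not match x_channel");
  }
}
}  // namespace sketch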
@@ -76,9 +76,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
   for (size_t i = 0; i < input_args.size(); i++) {
-    auto x_type_tmp = input_args[i]->BuildType();
-    MS_EXCEPTION_IF_NULL(x_type_tmp);
-    auto x_type = x_type_tmp->cast<TensorTypePtr>();
+    auto x_type = input_args[i]->BuildType();
     MS_EXCEPTION_IF_NULL(x_type);
     std::set<TypePtr> valid_x_type = {kTensorType};
     CheckAndConvertUtils::CheckTensorTypeValid("x_dtype", x_type, valid_x_type, prim_name);
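A plausible reading of this hunk, hedged since the commit message does not spell it out: the removed lines cast the result of BuildType() to TensorTypePtr and null-checked the cast, so a non-tensor input would surface as a bare null-pointer failure; the new line hands the TypePtr straight to CheckTensorTypeValid, which validates it against {kTensorType} and can report a proper type mismatch. The general pattern, sketched with standard C++ smart pointers rather than MindSpore's type classes:

#include <memory>
#include <stdexcept>

struct Type { virtual ~Type() = default; };
struct TensorType : Type {};

// Cast-then-null-check: a non-tensor input fails with an opaque null-pointer error.
void CheckViaCast(const std::shared_ptr<Type> &t) {
  auto tensor = std::dynamic_pointer_cast<TensorType>(t);
  if (tensor == nullptr) {
    throw std::runtime_error("unexpected nullptr");  // says nothing about the actual problem
  }
}

// Validate-against-allowed-types: the error names the real constraint.
void CheckViaValidator(const std::shared_ptr<Type> &t) {
  if (std::dynamic_pointer_cast<TensorType>(t) == nullptr) {
    throw std::invalid_argument("x_dtype must be a tensor type");
  }
}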
@@ -100,7 +98,9 @@ Format BiasAdd::get_format() const {
 void BiasAdd::Init(const Format &format) { this->set_format(format); }
 AbstractBasePtr BiasAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
-  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
+  auto infertype = InferType(primitive, input_args);
+  auto infershape = InferShape(primitive, input_args);
+  return abstract::MakeAbstract(infershape, infertype);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddInfer, nullptr, true);
 }  // namespace ops
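One hedged observation on this hunk, since the commit does not state a reason: in the removed one-liner the two arguments of abstract::MakeAbstract(...) are calls to InferShape and InferType, and C++ does not specify the evaluation order of function arguments, so either routine could run first; naming the intermediates makes InferType run before InferShape deterministically. A tiny generic example of the language rule:

#include <iostream>

int Trace(int id) {
  std::cout << "evaluated arg " << id << '\n';
  return id;
}

void Consume(int, int) {}

int main() {
  // The compiler may evaluate Trace(1) and Trace(2) in either order here.
  Consume(Trace(1), Trace(2));
  // Naming intermediates first (as the new BiasAddInfer does) fixes the order.
  int a = Trace(1);
  int b = Trace(2);
  Consume(a, b);
  return 0;
}

The remaining hunks below apply the same kGreaterEqual -> kEqual tightening of the input-count check to the other infer functions touched by this commit.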
@@ -30,7 +30,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
+  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, op_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -30,7 +30,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -30,7 +30,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto op_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kGreaterEqual, 2, op_name);
+  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 2, op_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -29,7 +29,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
@@ -30,7 +30,7 @@ namespace {
 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kGreaterEqual, 2, prim_name);
+  CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 2, prim_name);
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }