From afd8c43c8436481a6d6a1ecc95564d45bce795c7 Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Wed, 8 Feb 2023 10:19:44 +0800 Subject: [PATCH] fix bug for syncbatchnorm --- mindspore/core/ops/sync_batch_norm.cc | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mindspore/core/ops/sync_batch_norm.cc b/mindspore/core/ops/sync_batch_norm.cc index f52b5c43c93..f88670c1817 100644 --- a/mindspore/core/ops/sync_batch_norm.cc +++ b/mindspore/core/ops/sync_batch_norm.cc @@ -28,7 +28,34 @@ namespace mindspore { namespace ops { namespace { +constexpr int64_t kSyncBatchNormInputNum = 5; + +void CheckSyncBatchNormInputNum(const PrimitivePtr &prim, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(prim); + auto prim_name = prim->name(); + if (input_args.empty()) { + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSyncBatchNormInputNum, prim_name); + return; + } + + // the inputs has U + if (!input_args.back()->isa()) { + CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSyncBatchNormInputNum, prim_name); + return; + } + + // the inputs has not U + (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size() - 1), kEqual, + kSyncBatchNormInputNum, prim_name); + for (size_t index = 0; index < input_args.size(); index++) { + if (input_args[index] == nullptr) { + MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr."; + } + } +} + TuplePtr SyncBatchNormInferType(const PrimitivePtr &prim, const std::vector &input_args) { + CheckSyncBatchNormInputNum(prim, input_args); MS_EXCEPTION_IF_NULL(prim); auto prim_name = prim->name(); const std::set valid_types = {kFloat16, kFloat32}; @@ -53,6 +80,7 @@ TuplePtr SyncBatchNormInferType(const PrimitivePtr &prim, const std::vector &input_args) { + CheckSyncBatchNormInputNum(primitive, input_args); MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); auto x_shape_ptr = input_args[0]->BuildShape();