!48560 fix bug for syncbatchnorm

Merge pull request !48560 from yangzhenzhang/fix-bug-for-syncbatchnorm
This commit is contained in:
i-robot 2023-02-11 02:25:04 +00:00 committed by Gitee
commit 7a15dc2385
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 28 additions and 0 deletions

View File

@ -28,7 +28,34 @@
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t kSyncBatchNormInputNum = 5;
void CheckSyncBatchNormInputNum(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
if (input_args.empty()) {
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSyncBatchNormInputNum, prim_name);
return;
}
// the inputs has U
if (!input_args.back()->isa<abstract::AbstractMonad>()) {
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSyncBatchNormInputNum, prim_name);
return;
}
// the inputs has not U
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size() - 1), kEqual,
kSyncBatchNormInputNum, prim_name);
for (size_t index = 0; index < input_args.size(); index++) {
if (input_args[index] == nullptr) {
MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
}
}
}
TuplePtr SyncBatchNormInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
CheckSyncBatchNormInputNum(prim, input_args);
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
@ -53,6 +80,7 @@ TuplePtr SyncBatchNormInferType(const PrimitivePtr &prim, const std::vector<Abst
abstract::TupleShapePtr SyncBatchNormInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckSyncBatchNormInputNum(primitive, input_args);
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape_ptr = input_args[0]->BuildShape();