forked from mindspore-Ecosystem/mindspore
!48560 fix bug for syncbatchnorm
Merge pull request !48560 from yangzhenzhang/fix-bug-for-syncbatchnorm
This commit is contained in:
commit
7a15dc2385
|
@ -28,7 +28,34 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t kSyncBatchNormInputNum = 5;
|
||||
|
||||
void CheckSyncBatchNormInputNum(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
if (input_args.empty()) {
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSyncBatchNormInputNum, prim_name);
|
||||
return;
|
||||
}
|
||||
|
||||
// the inputs has U
|
||||
if (!input_args.back()->isa<abstract::AbstractMonad>()) {
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSyncBatchNormInputNum, prim_name);
|
||||
return;
|
||||
}
|
||||
|
||||
// the inputs has not U
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size() - 1), kEqual,
|
||||
kSyncBatchNormInputNum, prim_name);
|
||||
for (size_t index = 0; index < input_args.size(); index++) {
|
||||
if (input_args[index] == nullptr) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TuplePtr SyncBatchNormInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckSyncBatchNormInputNum(prim, input_args);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
|
@ -53,6 +80,7 @@ TuplePtr SyncBatchNormInferType(const PrimitivePtr &prim, const std::vector<Abst
|
|||
|
||||
abstract::TupleShapePtr SyncBatchNormInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckSyncBatchNormInputNum(primitive, input_args);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto x_shape_ptr = input_args[0]->BuildShape();
|
||||
|
|
Loading…
Reference in New Issue