raise error for BatchNorm3D.

This commit is contained in:
liuxiao93 2021-03-13 11:04:35 +08:00
parent 64e83c1f26
commit 431ec15bb7
1 changed files with 2 additions and 0 deletions

View File

@ -66,6 +66,8 @@ class _BatchNorm(Cell):
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
else:
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
raise ValueError("NCDHW format only support in Ascend target.")
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics