forked from mindspore-Ecosystem/mindspore
raise error for BatchNorm3D.
This commit is contained in:
parent
64e83c1f26
commit
431ec15bb7
|
@ -66,6 +66,8 @@ class _BatchNorm(Cell):
|
|||
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
|
||||
else:
|
||||
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
|
||||
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
|
||||
raise ValueError("NCDHW format only support in Ascend target.")
|
||||
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
||||
raise ValueError("NHWC format only support in GPU target.")
|
||||
self.use_batch_statistics = use_batch_statistics
|
||||
|
|
Loading…
Reference in New Issue