forked from mindspore-Ecosystem/mindspore
!3209 Fix bug about output of AddNGrad.
Merge pull request !3209 from liuxiao93/AddN
This commit is contained in:
commit
05de773775
|
@ -149,6 +149,8 @@ class _BatchNorm(Cell):
|
|||
def construct(self, x):
|
||||
if self.input_dims == '2d':
|
||||
_shape_check(self.shape(x))
|
||||
if self.input_dims == '1d':
|
||||
_shape_check_2d(self.shape(x))
|
||||
if self.use_batch_statistics is None:
|
||||
flag = self.training
|
||||
else:
|
||||
|
@ -200,6 +202,12 @@ def _channel_check(channel, num_channel):
|
|||
raise ValueError("the input channel is not equal with num_channel")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _shape_check_2d(input_shape):
|
||||
if len(input_shape) != 2:
|
||||
raise ValueError("The input must has 2 dims.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _shape_check(in_shape):
|
||||
if len(in_shape) != 4:
|
||||
|
|
|
@ -980,7 +980,8 @@ def get_bprop_scalar_accumulatenv2(self):
|
|||
dx = ()
|
||||
for _ in range(len(x)):
|
||||
dx = dx + (dout,)
|
||||
return dx
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
|
@ -992,7 +993,7 @@ def get_bprop_scalar_addn(self):
|
|||
dx = ()
|
||||
for _ in range(len(x)):
|
||||
dx = dx + (dout,)
|
||||
return dx
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
|
|
@ -1671,13 +1671,11 @@ test_case_array_ops = [
|
|||
('AddN', {
|
||||
'block': NetForTupleInput(P.AddN()),
|
||||
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
|
||||
'desc_bprop': [[2, 3, 3, 5]],
|
||||
'skip': ['backward']}),
|
||||
'desc_bprop': [[2, 3, 3, 5]]}),
|
||||
('AccumulateNV2', {
|
||||
'block': NetForTupleInput(P.AccumulateNV2()),
|
||||
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
|
||||
'desc_bprop': [[2, 3, 3, 5]],
|
||||
'skip': ['backward']}),
|
||||
'desc_bprop': [[2, 3, 3, 5]]}),
|
||||
('Shape', {
|
||||
'block': P.Shape(),
|
||||
'desc_inputs': [[3, 3, 2, 2]],
|
||||
|
|
|
@ -67,10 +67,10 @@ def test_bn2d():
|
|||
def test_bn1d():
|
||||
"""ut of nn.BatchNorm1d"""
|
||||
bn = nn.BatchNorm1d(3)
|
||||
input_data = Tensor(np.random.randint(0, 1, [1, 3, 100, 100]).astype(np.float32))
|
||||
input_data = Tensor(np.random.randint(0, 1, [1, 3]).astype(np.float32))
|
||||
output = bn(input_data)
|
||||
output_np = output.asnumpy()
|
||||
assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
|
||||
assert isinstance(output_np[0][0], (np.float32, np.float64))
|
||||
|
||||
|
||||
def test_bn2d_train():
|
||||
|
|
Loading…
Reference in New Issue