forked from mindspore-Ecosystem/mindspore
update mindspore/nn/layer/timedistributed.py.
Fix: reshape with aix bug.
This commit is contained in:
parent
885177769a
commit
640487a974
|
@ -41,7 +41,7 @@ def _check_expand_dims_axis(time_axis, ndim):
|
||||||
def _generate_perm(axis_a, axis_b, length):
|
def _generate_perm(axis_a, axis_b, length):
|
||||||
perm = tuple(range(length))
|
perm = tuple(range(length))
|
||||||
axis_a, axis_b = (axis_a, axis_b) if axis_a < axis_b else (axis_b, axis_a)
|
axis_a, axis_b = (axis_a, axis_b) if axis_a < axis_b else (axis_b, axis_a)
|
||||||
return perm[:axis_a] + perm[axis_a + 1: axis_b + 1] + (perm[axis_a],) + perm[axis_b + 1:]
|
return perm[:axis_a] + (perm[axis_b],) + perm[axis_a: axis_b] + perm[axis_b + 1:]
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
Loading…
Reference in New Issue