diff --git a/mindspore/common/initializer.py b/mindspore/common/initializer.py
index 31ca355cb6d..bcf37a8d3d6 100644
--- a/mindspore/common/initializer.py
+++ b/mindspore/common/initializer.py
@@ -387,6 +387,225 @@ class Constant(Initializer):
         _assignment(arr, self.value)
 
 
+@_register()
+class Identity(Initializer):
+    """
+    Initialize a 2-dimensional identity matrix to fill the input tensor.
+
+    Raises:
+        ValueError: If the dimension of the input tensor is not equal to 2.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore.common.initializer import initializer, Identity
+        >>> tensor1 = initializer(Identity(), [2, 3], mindspore.float32)
+        >>> tensor2 = initializer('identity', [2, 3], mindspore.float32)
+    """
+    def _initialize(self, arr):
+        if len(arr.shape) != 2:
+            raise ValueError('For Identity initializer, the dimension of the initialized tensor should be 2, '
+                             'but got {}.'.format(len(arr.shape)))
+        value = np.eye(arr.shape[0], arr.shape[1])
+        _assignment(arr, value)
+
+
+@_register()
+class Sparse(Initializer):
+    """
+    Initialize a 2-dimensional sparse matrix to fill the input tensor. The non-zero positions are filled with
+    values sampled from the normal distribution :math:`{N}(0, sigma^2)`.
+
+    Args:
+        sparsity (float): The fraction of elements being set to zero in each column.
+        sigma (float): The standard deviation of the normal distribution. Default: 0.01.
+
+    Raises:
+        ValueError: If the dimension of the input tensor is not equal to 2.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore.common.initializer import initializer, Sparse
+        >>> tensor1 = initializer(Sparse(sparsity=0.1, sigma=0.01), [5, 8], mindspore.float32)
+    """
+    def __init__(self, sparsity, sigma=0.01):
+        super(Sparse, self).__init__()
+        self.sparsity = sparsity
+        self.sigma = sigma
+
+    def _initialize(self, arr):
+        if len(arr.shape) != 2:
+            raise ValueError('For Sparse initializer, the dimension of the initialized tensor should be 2, '
+                             'but got {}.'.format(len(arr.shape)))
+        rows, cols = arr.shape
+        zero_num = int(np.ceil(self.sparsity * rows))
+        data = np.random.normal(0, self.sigma, arr.shape)
+        for col_idx in range(cols):
+            row_idx = np.random.permutation(list(range(rows)))[: zero_num]
+            data[row_idx, col_idx] = 0.
+        _assignment(arr, data)
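+
+
+# Note on the zeroing loop above: with sparsity=s, each column independently
+# zeroes ceil(s * rows) randomly chosen rows. For the (5, 8) example in the
+# docstring, ceil(0.1 * 5) = 1, so exactly one entry per column is zeroed and
+# the rest keep their N(0, sigma^2) samples.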
+
+
+@_register()
+class Dirac(Initializer):
+    """Initialize the input tensor with the Dirac delta function. It tries to preserve the identity of
+    the input for convolution layers. For group convolution, each group of channels is preserved respectively.
+
+    Args:
+        groups (int): The number of groups in the convolution layer. Default: 1.
+
+    Raises:
+        ValueError: If the dimension of the initialized tensor is not in [3, 4, 5], or if the first dimension
+            of the initialized tensor is not divisible by groups.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore.common.initializer import initializer, Dirac
+        >>> tensor1 = initializer(Dirac(groups=2), [6, 4, 3, 3], mindspore.float32)
+        >>> tensor2 = initializer("dirac", [6, 4, 3, 3], mindspore.float32)
+    """
+
+    def __init__(self, groups=1):
+        super(Dirac, self).__init__()
+        self.groups = groups
+
+    def _initialize(self, arr):
+        dimension = len(arr.shape)
+        data = np.zeros(arr.shape)
+        if dimension not in [3, 4, 5]:
+            raise ValueError("For Dirac initializer, only tensors with a dimension "
+                             "of 3, 4 or 5 are supported, but got {}.".format(dimension))
+
+        shapes = arr.shape
+        if shapes[0] % self.groups != 0:
+            raise ValueError("For Dirac initializer, the first dimension of "
+                             "the initialized tensor must be divisible by groups, "
+                             "but got {}/{}.".format(shapes[0], self.groups))
+
+        out_channel_per_group = shapes[0] // self.groups
+        min_dim = min(out_channel_per_group, shapes[1])
+
+        for group in range(self.groups):
+            for dim in range(min_dim):
+                if dimension == 3:
+                    data[group * out_channel_per_group + dim, dim, shapes[2] // 2] = 1
+                elif dimension == 4:
+                    data[group * out_channel_per_group + dim, dim, shapes[2] // 2, shapes[3] // 2] = 1
+                else:
+                    data[group * out_channel_per_group + dim, dim, shapes[2] // 2, shapes[3] // 2, shapes[4] // 2] = 1
+        _assignment(arr, data)
+
+
+@_register()
+class Orthogonal(Initializer):
+    r"""
+    Initialize a (semi) orthogonal matrix to fill the input tensor. The input tensor must have at least
+    2 dimensions; if it has more than 2, the trailing dimensions will be flattened.
+
+    Args:
+        gain (float): An optional scaling factor. Default: 1.
+
+    Raises:
+        ValueError: If the dimension of the input tensor is less than 2.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore.common.initializer import initializer, Orthogonal
+        >>> tensor1 = initializer(Orthogonal(gain=2.), [2, 3, 4], mindspore.float32)
+        >>> tensor2 = initializer('orthogonal', [2, 3, 4], mindspore.float32)
+    """
+    def __init__(self, gain=1.):
+        super(Orthogonal, self).__init__(gain=gain)
+        self.gain = gain
+
+    def _initialize(self, arr):
+        if len(arr.shape) < 2:
+            raise ValueError('For Orthogonal initializer, the dimension of the initialized tensor should'
+                             ' be no less than 2, but got {}.'.format(len(arr.shape)))
+        rows = arr.shape[0]
+
+        cols = np.prod(arr.shape) // rows
+        data = np.random.normal(0, 1, size=(rows, cols))
+
+        if rows < cols:
+            data = data.T
+
+        q, r = np.linalg.qr(data)
+        d = np.diag(r)
+        ph = np.sign(d)
+        q *= ph
+
+        if rows < cols:
+            q = q.T
+        q = q * self.gain
+        _assignment(arr, q.reshape(arr.shape))
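+
+
+# Note on the QR step in Orthogonal._initialize above: multiplying q by
+# sign(diag(r)) removes the sign ambiguity of the decomposition (equivalent to
+# forcing r to have a positive diagonal), so the sampled matrix is distributed
+# uniformly over (semi) orthogonal matrices rather than inheriting the sign
+# convention of the underlying LAPACK routine.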
+
+
+@_register()
+class VarianceScaling(Initializer):
+    r"""
+    Randomly initialize an array with scaling to fill the input tensor.
+    When distribution is truncated_normal or untruncated_normal, the values are sampled from a truncated or
+    untruncated normal distribution with a mean of 0 and a scaled standard deviation
+    :math:`stddev = sqrt(scale/n)`. :math:`n` is the number of input units if mode is fan_in, the number of
+    output units if mode is fan_out, and the average of fan_in and fan_out if mode is fan_avg.
+    When distribution is uniform, the values are sampled from a uniform distribution within the limit of
+    [`-sqrt(3*scale/n)`, `sqrt(3*scale/n)`].
+
+    Args:
+        scale (float): The scaling factor. Default: 1.0.
+        mode (str): Should be 'fan_in', 'fan_out' or 'fan_avg'. Default: 'fan_in'.
+        distribution (str): The type of distribution chosen to sample values, should be 'uniform',
+            'truncated_normal' or 'untruncated_normal'. Default: 'truncated_normal'.
+
+    Raises:
+        ValueError: If scale is not greater than 0.
+        ValueError: If mode is not fan_in, fan_out or fan_avg.
+        ValueError: If distribution is not uniform, truncated_normal or untruncated_normal.
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore.common.initializer import initializer, VarianceScaling
+        >>> tensor1 = initializer(VarianceScaling(scale=1.0, mode='fan_out',
+        ...                                       distribution='untruncated_normal'), [2, 3], mindspore.float32)
+        >>> tensor2 = initializer('varianceScaling', [2, 3], mindspore.float32)
+    """
+    def __init__(self, scale=1.0, mode='fan_in', distribution='truncated_normal'):
+        super(VarianceScaling, self).__init__(scale=scale, mode=mode, distribution=distribution)
+        if scale <= 0.:
+            raise ValueError("For VarianceScaling initializer, scale must be greater than 0, but got {}.".format(scale))
+
+        if mode not in ['fan_in', 'fan_out', 'fan_avg']:
+            raise ValueError('For VarianceScaling initializer, mode must be fan_in, '
+                             'fan_out or fan_avg, but got {}.'.format(mode))
+
+        if distribution not in ['uniform', 'truncated_normal', 'untruncated_normal']:
+            raise ValueError('For VarianceScaling initializer, distribution must be uniform, '
+                             'truncated_normal or untruncated_normal, but got {}.'.format(distribution))
+
+        self.scale = scale
+        self.mode = mode
+        self.distribution = distribution
+
+    def _initialize(self, arr):
+        scale = self.scale
+        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr.shape)
+        if self.mode == 'fan_in':
+            scale /= max(1., fan_in)
+        elif self.mode == 'fan_out':
+            scale /= max(1., fan_out)
+        else:
+            scale /= max(1., (fan_in + fan_out) / 2.)
+
+        if self.distribution == 'truncated_normal':
+            # 0.87962566103423978 is the standard deviation of a standard normal
+            # truncated to (-2, 2); dividing by it restores the requested stddev.
+            stddev = np.sqrt(scale) / 0.87962566103423978
+            data = truncnorm.rvs(-2, 2, loc=0, scale=stddev, size=arr.shape, random_state=None)
+        elif self.distribution == 'untruncated_normal':
+            stddev = np.sqrt(scale)
+            data = np.random.normal(0, stddev, arr.shape)
+        else:
+            limit = np.sqrt(3.0 * scale)
+            data = np.random.uniform(-limit, limit, arr.shape)
+        _assignment(arr, data)
+
+
 @_register()
 class Uniform(Initializer):
     r"""
@@ -546,4 +765,9 @@ __all__ = [
     'XavierUniform',
     'One',
     'Zero',
-    'Constant']
+    'Constant',
+    'Identity',
+    'Sparse',
+    'Dirac',
+    'Orthogonal',
+    'VarianceScaling']
diff --git a/tests/st/initializer/test_initializer.py b/tests/st/initializer/test_initializer.py
new file mode 100644
index 00000000000..0c20936df7e
--- /dev/null
+++ b/tests/st/initializer/test_initializer.py
@@ -0,0 +1,307 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore
+from mindspore.common.initializer import initializer, Identity, Dirac, Sparse, VarianceScaling, Orthogonal
+import numpy as np
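+
+
+# The sparse/orthogonal/variance-scaling tests below are smoke tests: their
+# output is random, so they only check that initialization succeeds. The
+# identity and dirac tests are deterministic and compare results elementwise.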
+ """ + initializer(Sparse(sparsity=0.1, sigma=0.01), [5, 8], mindspore.float32) + + +def test_orthogonal(): + """ + Feature: Test orthogonal initializer. + Description: Initialize a (semi) orthogonal matrix to fill the input tensor. + Expectation: The Tensor is initialized with values from orthogonal matrix. + """ + initializer(Orthogonal(gain=2.), [2, 3, 4], mindspore.float32) + initializer('orthogonal', [2, 3, 4], mindspore.float32) + + +def test_variancescaling(): + """ + Feature: Test varianceScaling initializer. + Description: Randomly initialize an array with scaling to fill the input tensor. + Expectation: The Tensor is initialized successfully. + """ + initializer('varianceScaling', [2, 3], mindspore.float32) + initializer(VarianceScaling(scale=1.0, mode='fan_out', distribution='untruncated_normal'), [2, 3], + mindspore.float32) + initializer(VarianceScaling(scale=2.0, mode='fan_in', distribution='truncated_normal'), [2, 3], + mindspore.float32) + initializer(VarianceScaling(scale=3.0, mode='fan_avg', distribution='uniform'), [2, 3], + mindspore.float32) + + +def test_identity(): + """ + Feature: Test identity initializer. + Description: Initialize an identity matrix to fill a Tensor. + Expectation: The Tensor is initialized with identity matrix. + """ + tensor1 = initializer(Identity(), [3, 3], mindspore.float32) + tensor2 = initializer('identity', [3, 4], mindspore.float32) + tensor3 = initializer('identity', [4, 3], mindspore.float32) + expect1 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + expect2 = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32) + expect3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=np.float32) + assert (tensor1.asnumpy() == expect1).all() + assert (tensor2.asnumpy() == expect2).all() + assert (tensor3.asnumpy() == expect3).all() + + +def test_dirac(): + """ + Feature: Test dirac initializer. + Description: Initialize input tensor with the Dirac delta function. + Expectation: The Tensor is correctly initialized. 
+ """ + tensor3_1 = initializer(Dirac(groups=1), [6, 2, 3], mindspore.float32) + tensor3_2 = initializer(Dirac(groups=2), [6, 2, 3], mindspore.float32) + tensor3_3 = initializer(Dirac(groups=3), [6, 2, 3], mindspore.float32) + + tensor4_1 = initializer(Dirac(groups=1), [6, 4, 3, 3], mindspore.float32) + tensor4_2 = initializer(Dirac(groups=2), [6, 4, 3, 3], mindspore.float32) + tensor4_3 = initializer(Dirac(groups=3), [6, 4, 3, 3], mindspore.float32) + + tensor5_1 = initializer(Dirac(groups=1), [6, 2, 3, 3, 3], mindspore.float32) + tensor5_2 = initializer(Dirac(groups=2), [6, 2, 3, 3, 3], mindspore.float32) + tensor5_3 = initializer(Dirac(groups=3), [6, 2, 3, 3, 3], mindspore.float32) + + expectation3_1 = np.array([[[0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.]], + [[0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.]]], dtype=np.float32) + + expectation3_2 = np.array([[[0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.]], + [[0., 0., 0.], [0., 0., 0.]], + [[0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.]], + [[0., 0., 0.], [0., 0., 0.]]], dtype=np.float32) + + expectation3_3 = np.array([[[0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.]], + [[0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.]], + [[0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.]]], dtype=np.float32) + + expectation4_1 = np.array([[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], dtype=np.float32) + + expectation4_2 = np.array([[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], 
[0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], dtype=np.float32) + + expectation4_3 = np.array([[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], dtype=np.float32) + + expectation5_1 = np.array([[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]], dtype=np.float32) + + expectation5_2 = np.array([[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 
0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]], dtype=np.float32) + + expectation5_3 = np.array([[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], + [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], + [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]], + [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]], dtype=np.float32) + + assert (tensor3_1.asnumpy() == expectation3_1).all() + assert (tensor3_2.asnumpy() == expectation3_2).all() + assert (tensor3_3.asnumpy() == expectation3_3).all() + + assert (tensor4_1.asnumpy() == expectation4_1).all() + assert (tensor4_2.asnumpy() == expectation4_2).all() + assert (tensor4_3.asnumpy() == expectation4_3).all() + + assert (tensor5_1.asnumpy() == expectation5_1).all() + assert (tensor5_2.asnumpy() == expectation5_2).all() + assert (tensor5_3.asnumpy() == expectation5_3).all()
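As a quick end-to-end check of the new Orthogonal initializer, here is a minimal sketch (not part of the patch; it assumes that, as in the tests above, the result of initializer() can be read back directly with asnumpy()). The rows of the flattened result should be mutually orthogonal and scaled by gain:

import numpy as np
import mindspore
from mindspore.common.initializer import initializer, Orthogonal

# A [2, 3, 4] target is flattened to 2 x 12 inside the initializer; with
# rows < cols the rows are orthonormal before scaling, so q @ q.T == gain^2 * I.
tensor = initializer(Orthogonal(gain=2.), [2, 3, 4], mindspore.float32)
q = tensor.asnumpy().reshape(2, -1)
assert np.allclose(q @ q.T, 4. * np.eye(2), atol=1e-4)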