!25791 Add new initializer

Merge pull request !25791 from wanyiming/initilizer
This commit is contained in:
i-robot 2021-11-23 07:14:03 +00:00 committed by Gitee
commit e40f3ec063
2 changed files with 532 additions and 1 deletions

View File

@ -377,6 +377,225 @@ class Constant(Initializer):
_assignment(arr, self.value)
@_register()
class Identity(Initializer):
"""
Initialize a 2 dimension identity matrix to fill the input tensor.
Raises:
ValueError: If the dimension of input tensor is not equal to 2.
Examples:
>>> import mindspore
>>> from mindspore.common.initializer import initializer, Identity
>>> tensor1 = initializer(Identity(), [2, 3], mindspore.float32)
>>> tensor2 = initializer('identity', [2, 3], mindspore.float32)
"""
def _initialize(self, arr):
if len(arr.shape) != 2:
raise ValueError('For Identity initializer, the dimension of the initialized tensor should be 2, '
'but got {}.'.format(len(arr.shape)))
value = np.eye(arr.shape[0], arr.shape[1])
_assignment(arr, value)
@_register()
class Sparse(Initializer):
"""
Initialize a 2 dimension sparse matrix to fill the input tensor. The non-zero positions will be filled with
the value sampled from the normal distribution :math:`{N}(0, 0.01)`
Args:
sparsity (float): The fraction of elements being set to zero in each column.
sigma (float): The standard deviation of the normal distribution. Default: 0.01.
Raises:
ValueError: If the dimension of input tensor is not equal to 2.
Examples:
>>> import mindspore
>>> from mindspore.common.initializer import initializer, Sparse
>>> tensor1 = initializer(Sparse(sparsity=0.1, sigma=0.01), [5, 8], mindspore.float32)
"""
def __init__(self, sparsity, sigma=0.01):
super(Sparse, self).__init__()
self.sparsity = sparsity
self.sigma = sigma
def _initialize(self, arr):
if len(arr.shape) != 2:
raise ValueError('For Sparse initializer, the dimension of the initialized tensor should be 2, '
'but got {}.'.format(len(arr.shape)))
rows, cols = arr.shape
zero_num = int(np.ceil(self.sparsity * rows))
data = np.random.normal(0, self.sigma, arr.shape)
for col_idx in range(cols):
row_idx = np.random.permutation(list(range(rows)))[: zero_num]
data[row_idx, col_idx] = 0.
_assignment(arr, data)
class Dirac(Initializer):
"""Initialize input tensor with the Dirac delta function. It tries to preserves the identity of
input for convolution layers. For group convolution, each group of channels will be preserved respectively.
Args:
groups (int): The number of group in convolution layer. Default: 1.
Raises:
ValueError: If the value of group is not in [3, 4, 5] or the first dimension of the initialized
tensor cannot be divisible by group.
Examples:
>>> import mindspore
>>> from mindspore.common.initializer import initializer, Dirac
>>> tensor1 = initializer(Dirac(groups=2), [6, 4, 3, 3], mindspore.float32)
>>> tensor2 = initializer("dirac", [6, 4, 3, 3], mindspore.float32)
"""
def __init__(self, groups=1):
super(Dirac, self).__init__()
self.groups = groups
def _initialize(self, arr):
dimension = len(arr.shape)
data = np.zeros(arr.shape)
if dimension not in [3, 4, 5]:
raise ValueError("For Dirac initializer, only support "
"to initialize tensor with dimension of 3, 4 or 5, but got {}.".format(dimension))
shapes = arr.shape
if shapes[0] % self.groups != 0:
raise ValueError("For Dirac initializer, the first dimension of"
"the initialized tensor must be divisible by group, "
"but got {}/{}.".format(shapes[0], self.groups))
out_channel_per_group = shapes[0] // self.groups
min_dim = min(out_channel_per_group, shapes[1])
for group in range(self.groups):
for dim in range(min_dim):
if dimension == 3:
data[group * out_channel_per_group + dim, dim, shapes[2]//2] = 1
elif dimension == 4:
data[group * out_channel_per_group + dim, dim, shapes[2] // 2, shapes[3] // 2] = 1
else:
data[group * out_channel_per_group + dim, dim, shapes[2] // 2, shapes[3] // 2, shapes[4] // 2] = 1
_assignment(arr, data)
@_register()
class Orthogonal(Initializer):
r"""
Initialize a (semi) orthogonal matrix to fill the input tensor. The dimension of input tensor must have at least 2
dimensions. If the dimension is greater than 2, the trailing dimensions will be flattened.
Args:
gain (float): An optional scaling factor. Default: 1.
Raises:
ValueError: If the dimension of input tensor is less than 2.
Examples:
>>> import mindspore
>>> from mindspore.common.initializer import initializer, Orthogonal
>>> tensor1 = initializer(Orthogonal(gain=2.), [2, 3, 4], mindspore.float32)
>>> tensor2 = initializer('orthogonal', [2, 3, 4], mindspore.float32)
"""
def __init__(self, gain=1.):
super(Orthogonal, self).__init__(gain=gain)
self.gain = gain
def _initialize(self, arr):
if len(arr.shape) < 2:
raise ValueError('For Orthogonal initializer, the dimension of the initialized tensor should'
' be no less than 2, but got {}.'.format(len(arr.shape)))
rows = arr.shape[0]
cols = np.prod(arr.shape) // rows
data = np.random.normal(0, 1, size=(rows, cols))
if rows < cols:
data = data.T
q, r = np.linalg.qr(data)
d = np.diag(r)
ph = np.sign(d)
q *= ph
if rows < cols:
q = q.T
q = q * self.gain
_assignment(arr, q.reshape(arr.shape))
@_register()
class VarianceScaling(Initializer):
r"""
Randomly initialize an array with scaling to fill the input tensor.
When distribution is truncated_normal or untruncated_normal, the value will be sampled from truncated or
untruncated normal distribution with a mean of 0 and a scaled standard deviation :math:`stddev = sqrt(scale/n)`.
:math:`n` will be the number of input units if mode is fan_in, the number of output units if mode is fan_out,
the average of fan_in and fan_out if mode is fan_avg.
When distribution is uniform, the value will be sampled from a uniform distribution within the limit of
[`-sqrt(3*scale/n)`, `sqrt(3*scale/n)`].
Args:
scale (float): The scaling factor. Default: 1.0.
mode (str): Should be 'fan_in', 'fan_out' or 'fan_avg'. Default: 'fan_in'.
distribution(str): The type of distribution chose to sample values. Default: 'truncated_normal'.
Raises:
ValueError: If scale is not greater than 0..
ValueError: If mode is not fan_in, fan_out or fan_avg.
ValueError: If distribution is not uniform, truncated_normal or untruncated_normal.
Examples:
>>> import mindspore
>>> from mindspore.common.initializer import initializer, VarianceScaling
>>> tensor1 = initializer(VarianceScaling(scale=1.0, mode='fan_out',
>>> distribution='untruncated_normal'), [2, 3], mindspore.float32)
>>> tensor2 = initializer('varianceScaling', [2, 3], mindspore.float32)
"""
def __init__(self, scale=1.0, mode='fan_in', distribution='truncated_normal'):
super(VarianceScaling, self).__init__(scale=scale, mode=mode, distribution=distribution)
if scale <= 0.:
raise ValueError("For VarianceScaling initializer, scale must be greater than 0, but got {}.".format(scale))
if mode not in ['fan_in', 'fan_out', 'fan_avg']:
raise ValueError('For VarianceScaling initializer, mode must be fan_in, '
'fan_out or fan_avg, but got {}.'.format(mode))
if distribution not in ['uniform', 'truncated_normal', 'untruncated_normal']:
raise ValueError('For VarianceScaling initializer, distribution must be uniform, '
'truncated_norm or untruncated_norm, but got {}.'.format(distribution))
self.scale = scale
self.mode = mode
self.distribution = distribution
def _initialize(self, arr):
scale = self.scale
fan_in, fan_out = _calculate_fan_in_and_fan_out(arr.shape)
if self.mode == 'fan_in':
scale /= max(1., fan_in)
elif self.mode == 'fan_out':
scale /= max(1., fan_out)
else:
scale /= max(1., (fan_in + fan_out) / 2.)
if self.distribution == 'truncated_norm':
stddev = np.sqrt(scale) / 0.87962566103423978
data = truncnorm.rvs(-2, 2, loc=0, scale=stddev, size=arr.shape, random_state=None)
elif self.distribution == 'untruncated_normal':
stddev = np.sqrt(scale)
data = np.random.normal(0, stddev, arr.shape)
else:
limit = np.sqrt(3.0 * scale)
data = np.random.uniform(-limit, limit, arr.shape)
_assignment(arr, data)
@_register()
class Uniform(Initializer):
r"""
@ -534,4 +753,9 @@ __all__ = [
'XavierUniform',
'One',
'Zero',
'Constant']
'Constant',
'Identity',
'Sparse',
'Dirac',
'Orthogonal',
'VarianceScaling']

View File

@ -0,0 +1,307 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore
from mindspore.common.initializer import initializer, Identity, Dirac, Sparse, VarianceScaling, Orthogonal
import numpy as np
def test_sparse():
"""
Feature: Test sparse initializer.
Description: Initialize a 2 dimension sparse matrix to fill the input tensor.
Expectation: The Tensor is initialized with a 2 dimension sparse matrix.
"""
initializer(Sparse(sparsity=0.1, sigma=0.01), [5, 8], mindspore.float32)
def test_orthogonal():
"""
Feature: Test orthogonal initializer.
Description: Initialize a (semi) orthogonal matrix to fill the input tensor.
Expectation: The Tensor is initialized with values from orthogonal matrix.
"""
initializer(Orthogonal(gain=2.), [2, 3, 4], mindspore.float32)
initializer('orthogonal', [2, 3, 4], mindspore.float32)
def test_variancescaling():
"""
Feature: Test varianceScaling initializer.
Description: Randomly initialize an array with scaling to fill the input tensor.
Expectation: The Tensor is initialized successfully.
"""
initializer('varianceScaling', [2, 3], mindspore.float32)
initializer(VarianceScaling(scale=1.0, mode='fan_out', distribution='untruncated_normal'), [2, 3],
mindspore.float32)
initializer(VarianceScaling(scale=2.0, mode='fan_in', distribution='truncated_normal'), [2, 3],
mindspore.float32)
initializer(VarianceScaling(scale=3.0, mode='fan_avg', distribution='uniform'), [2, 3],
mindspore.float32)
def test_identity():
"""
Feature: Test identity initializer.
Description: Initialize an identity matrix to fill a Tensor.
Expectation: The Tensor is initialized with identity matrix.
"""
tensor1 = initializer(Identity(), [3, 3], mindspore.float32)
tensor2 = initializer('identity', [3, 4], mindspore.float32)
tensor3 = initializer('identity', [4, 3], mindspore.float32)
expect1 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
expect2 = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], dtype=np.float32)
expect3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]], dtype=np.float32)
assert (tensor1.asnumpy() == expect1).all()
assert (tensor2.asnumpy() == expect2).all()
assert (tensor3.asnumpy() == expect3).all()
def test_dirac():
"""
Feature: Test dirac initializer.
Description: Initialize input tensor with the Dirac delta function.
Expectation: The Tensor is correctly initialized.
"""
tensor3_1 = initializer(Dirac(groups=1), [6, 2, 3], mindspore.float32)
tensor3_2 = initializer(Dirac(groups=2), [6, 2, 3], mindspore.float32)
tensor3_3 = initializer(Dirac(groups=3), [6, 2, 3], mindspore.float32)
tensor4_1 = initializer(Dirac(groups=1), [6, 4, 3, 3], mindspore.float32)
tensor4_2 = initializer(Dirac(groups=2), [6, 4, 3, 3], mindspore.float32)
tensor4_3 = initializer(Dirac(groups=3), [6, 4, 3, 3], mindspore.float32)
tensor5_1 = initializer(Dirac(groups=1), [6, 2, 3, 3, 3], mindspore.float32)
tensor5_2 = initializer(Dirac(groups=2), [6, 2, 3, 3, 3], mindspore.float32)
tensor5_3 = initializer(Dirac(groups=3), [6, 2, 3, 3, 3], mindspore.float32)
expectation3_1 = np.array([[[0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.]],
[[0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.]]], dtype=np.float32)
expectation3_2 = np.array([[[0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.]],
[[0., 0., 0.], [0., 0., 0.]],
[[0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.]],
[[0., 0., 0.], [0., 0., 0.]]], dtype=np.float32)
expectation3_3 = np.array([[[0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.]],
[[0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.]],
[[0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.]]], dtype=np.float32)
expectation4_1 = np.array([[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], dtype=np.float32)
expectation4_2 = np.array([[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], dtype=np.float32)
expectation4_3 = np.array([[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]], dtype=np.float32)
expectation5_1 = np.array([[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]], dtype=np.float32)
expectation5_2 = np.array([[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]], dtype=np.float32)
expectation5_3 = np.array([[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]],
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],
[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]],
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]], dtype=np.float32)
assert (tensor3_1.asnumpy() == expectation3_1).all()
assert (tensor3_2.asnumpy() == expectation3_2).all()
assert (tensor3_3.asnumpy() == expectation3_3).all()
assert (tensor4_1.asnumpy() == expectation4_1).all()
assert (tensor4_2.asnumpy() == expectation4_2).all()
assert (tensor4_3.asnumpy() == expectation4_3).all()
assert (tensor5_1.asnumpy() == expectation5_1).all()
assert (tensor5_2.asnumpy() == expectation5_2).all()
assert (tensor5_3.asnumpy() == expectation5_3).all()