Add nn.Tril function

This commit is contained in:
l00591931 2020-11-16 20:40:07 +08:00
parent 7a192973ff
commit e1dba1337c
3 changed files with 291 additions and 1 deletions

View File

@ -35,7 +35,7 @@ from .activation import get_activation
__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold',
'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']
'Tril', 'Triu', 'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag']
class Dropout(Cell):
@ -547,6 +547,80 @@ class Unfold(Cell):
return result
@constexpr
def tril(x_shape, x_dtype, k):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "tril")
Validator.check_is_int(k, "k value", "tril")
mask = np.tril(np.ones(x_shape), k)
return Tensor(mask, x_dtype)
class Tril(Cell):
"""
Returns a tensor with elements above the kth diagonal zeroed.
Inputs:
- **x** (Tensor) - The input tensor.
- **k** (Int) - The index of diagonal. Default: 0
Outputs:
Tensor, has the same type as input `x`.
Examples:
>>> x = Tensor(np.array([[1, 2], [3, 4]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[1 0]
[3 4]]
"""
def __init__(self):
super(Tril, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()
def construct(self, x, k=0):
assist = tril(x.shape, self.dtype(x), k)
return self.mul(x, assist)
@constexpr
def triu(x_shape, x_dtype, k):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "triu")
Validator.check_is_int(k, "k value", "triu")
mask = np.triu(np.ones(x_shape), k)
return Tensor(mask, x_dtype)
class Triu(Cell):
"""
Returns a tensor with elements below the kth diagonal zeroed.
Inputs:
- **x** (Tensor) - The input tensor.
- **k** (Int) - The index of diagonal. Default: 0
Outputs:
Tensor, has the same type as input `x`.
Examples:
>>> x = Tensor(np.array([[1, 2], [3, 4]]))
>>> tril = nn.Tril()
>>> result = tril(x)
>>> print(result)
[[1 2]
[0 4]]
"""
def __init__(self):
super(Triu, self).__init__()
self.dtype = P.DType()
self.mul = P.Mul()
def construct(self, x, k=0):
assist = triu(x.shape, self.dtype(x), k)
return self.mul(x, assist)
@constexpr
def _get_matrix_diag_assist(x_shape, x_dtype):
Validator.check_int(len(x_shape), 1, Rel.GE, "x rank", "_get_matrix_diag_assist")

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Tril()
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_tril():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
tril = nn.Tril()
return tril(self.value, 0)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 34
def test_tril_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
tril = nn.Tril()
return tril(self.value, 1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 42
def test_tril_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
tril = nn.Tril()
return tril(self.value, -1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 19
def test_tril_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
tril = nn.Tril()
return tril(x, 0)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_tril_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
tril = nn.Tril()
return tril(x, 1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_tril_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
tril = nn.Tril()
return tril(x, -1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))

View File

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test nn.Triu()
"""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_triu():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
triu = nn.Triu()
return triu(self.value, 0)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 26
def test_triu_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
triu = nn.Triu()
return triu(self.value, 1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 11
def test_triu_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def construct(self):
triu = nn.Triu()
return triu(self.value, -1)
net = Net()
out = net()
assert np.sum(out.asnumpy()) == 38
def test_triu_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
triu = nn.Triu()
return triu(x, 0)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_triu_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
triu = nn.Triu()
return triu(x, 1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
def test_triu_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
triu = nn.Triu()
return triu(x, -1)
net = Net()
net(Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))