!4021 Add ScalarAffine and Softplus bijector

Merge pull request !4021 from XunDeng/scalar_affine_softplus
mindspore-ci-bot 2020-08-11 14:14:35 +08:00 committed by Gitee
commit 6772bbde8d
10 changed files with 735 additions and 5 deletions


@@ -21,7 +21,13 @@ The high-level components (Bijectors) used to construct the probabilistic network
from .bijector import Bijector
from .power_transform import PowerTransform
from .exp import Exp
from .scalar_affine import ScalarAffine
from .softplus import Softplus
__all__ = ['Bijector',
'PowerTransform',
'Exp']
__all__ = [
'Bijector',
'PowerTransform',
'Exp',
'ScalarAffine',
'Softplus',
]
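
With these exports in place, both new bijectors are reachable from the public package path. A minimal usage sketch (constructor values mirror the tests further down):

import mindspore.nn.probability.bijector as msb

affine = msb.ScalarAffine(scale=2.0, shift=1.0)  # Y = 2X + 1
softplus = msb.Softplus(sharpness=2.0)           # Y = log(1 + exp(2X)) / 2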


@@ -14,6 +14,7 @@
# ============================================================================
"""Bijector"""
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
from ..distribution import Distribution
from ..distribution import TransformedDistribution
@@ -39,6 +40,9 @@ class Bijector(Cell):
Constructor of bijector class.
"""
super(Bijector, self).__init__()
validator.check_value_type('name', name, [str], 'Bijector')
validator.check_value_type('is_constant_jacobian', is_constant_jacobian, [bool], name)
validator.check_value_type('is_injective', is_injective, [bool], name)
self._name = name
self._dtype = dtype
self._parameters = {}
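
These validator calls make a misconfigured constructor fail fast with a TypeError, which the new test_type cases below rely on. A small sketch of the expected behaviour:

import mindspore.nn.probability.bijector as msb

try:
    msb.Exp(name=0.1)  # name must be a str, so the validator raises TypeError
except TypeError as err:
    print(err)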


@@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Scalar Affine Bijector"""
from mindspore.ops import operations as P
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from .bijector import Bijector
class ScalarAffine(Bijector):
"""
Scalar Affine Bijector.
This Bijector performs the operation: Y = a * X + b, where a is the scale
factor and b is the shift factor.
Args:
scale (float): scale factor. Default: 1.0.
shift (float): shift factor. Default: 0.0.
Examples:
>>> # To initialize a ScalarAffine bijector of scale 1.0 and shift 2.0
>>> scalaraffine = nn.probability.bijector.ScalarAffine(1.0, 2.0)
>>>
>>> # To use ScalarAffine bijector in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__()
>>> self.s1 = nn.probability.bijector.ScalarAffine(1.0, 2.0)
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'forward' with the name of the function
>>> ans = self.s1.forward(value)
>>> ans = self.s1.inverse(value)
>>> ans = self.s1.forward_log_jacobian(value)
>>> ans = self.s1.inverse_log_jacobian(value)
"""
def __init__(self,
scale=1.0,
shift=0.0,
name='ScalarAffine'):
"""
Constructor of scalar affine bijector.
"""
param = dict(locals())
validator.check_value_type('scale', scale, [float], name)
validator.check_value_type('shift', shift, [float], name)
self._scale = cast_to_tensor(scale)
self._shift = cast_to_tensor(shift)
super(ScalarAffine, self).__init__(
is_constant_jacobian=True,
is_injective=True,
name=name,
dtype=None,
param=param)
self.log = P.Log()
self.oneslike = P.OnesLike()
@property
def scale(self):
return self._scale
@property
def shift(self):
return self._shift
def extend_repr(self):
str_info = f'scale = {self.scale}, shift = {self.shift}'
return str_info
def shape_mapping(self, shape):
return shape
def _forward(self, x):
r"""
.. math::
f(x) = a * x + b
"""
return self.scale * x + self.shift
def _inverse(self, y):
r"""
.. math::
f(y) = \frac{y - b}{a}
"""
return (y - self.shift) / self.scale
def _forward_log_jacobian(self, value):
r"""
.. math::
f(x) = a * x + b
f'(x) = a
\log(f'(x)) = \log(a)
"""
return self.log(self.scale) * self.oneslike(value)
def _inverse_log_jacobian(self, value):
r"""
.. math::
f(y) = \frac{y - b}{a}
f'(y) = \frac{1}{a}
\log(f'(y)) = -\log(a)
"""
return -1. * self.log(self.scale) * self.oneslike(value)
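
As a sanity check on the formulas above, here is a standalone NumPy sketch (an illustration independent of MindSpore, not part of the patch) mirroring the four operations:

import numpy as np

a, b = 2.0, 1.0  # scale and shift
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)

y = a * x + b                               # forward: f(x) = a*x + b
x_back = (y - b) / a                        # inverse recovers x
fwd_log_jac = np.log(a) * np.ones_like(x)   # log f'(x) = log(a), constant
inv_log_jac = -np.log(a) * np.ones_like(y)  # log (f^{-1})'(y) = -log(a)

assert np.allclose(x_back, x)
assert np.allclose(fwd_log_jac + inv_log_jac, 0.0)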


@@ -0,0 +1,124 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Softplus Bijector"""
from mindspore.ops import operations as P
from mindspore.nn.layer.activation import LogSigmoid
from mindspore._checkparam import Validator as validator
from ..distribution._utils.utils import cast_to_tensor
from .bijector import Bijector
class Softplus(Bijector):
r"""
Softplus Bijector.
This Bijector performs the operation: Y = \frac{\log(1 + e^{kX})}{k}, where k is the sharpness factor.
Args:
sharpness (float): sharpness factor. Default: 1.0.
Examples:
>>> # To initialize a Softplus bijector of sharpness 2.0
>>> softplus = nn.probability.bijector.Softplus(2.0)
>>>
>>> # To use a Softplus bijector in a network
>>> class net(Cell):
>>> def __init__(self):
>>> super(net, self).__init__()
>>> self.sp1 = nn.probability.bijector.Softplus(2.0)
>>>
>>> def construct(self, value):
>>> # Similar calls can be made to other probability functions
>>> # by replacing 'forward' with the name of the function
>>> ans = self.sp1.forward(value)
>>> ans = self.sp1.inverse(value)
>>> ans = self.sp1.forward_log_jacobian(value)
>>> ans = self.sp1.inverse_log_jacobian(value)
"""
def __init__(self,
sharpness=1.0,
name='Softplus'):
param = dict(locals())
validator.check_value_type('sharpness', sharpness, [float], name)
super(Softplus, self).__init__(name=name, param=param)
self._sharpness = cast_to_tensor(sharpness)
self.exp = P.Exp()
self.expm1 = self._expm1_by_step
self.log_sigmoid = LogSigmoid()
self.log = P.Log()
self.sigmoid = P.Sigmoid()
self.softplus = self._softplus
self.inverse_softplus = self._inverse_softplus
def _expm1_by_step(self, x):
"""
Compute expm1 step by step as exp(x) - 1.0, used in place of the Expm1 op.
"""
return self.exp(x) - 1.0
def _softplus(self, x):
return self.log(self.exp(x) + 1.0)
def _inverse_softplus(self, x):
r"""
.. math::
f(x) = \log(1 + e^{x})
f^{-1}(y) = \log(e^{y} - 1)
"""
return self.log(self.expm1(x))
@property
def sharpness(self):
return self._sharpness
def extend_repr(self):
str_info = f'sharpness = {self.sharpness}'
return str_info
def shape_mapping(self, shape):
return shape
def _forward(self, x):
scaled_value = self.sharpness * x
return self.softplus(scaled_value) / self.sharpness
def _inverse(self, y):
r"""
.. math::
f(x) = \frac{\log(1 + e^{kx})}{k}
f^{-1}(y) = \frac{\log(e^{ky} - 1)}{k}
"""
scaled_value = self.sharpness * y
return self.inverse_softplus(scaled_value) / self.sharpness
def _forward_log_jacobian(self, x):
r"""
.. math::
f(x) = \frac{\log(1 + e^{kx})}{k}
f'(x) = \frac{e^{kx}}{1 + e^{kx}}
\log(f'(x)) = kx - \log(1 + e^{kx}) = \log(\mathrm{sigmoid}(kx))
"""
scaled_value = self.sharpness * x
return self.log_sigmoid(scaled_value)
def _inverse_log_jacobian(self, y):
r"""
.. math::
f(y) = \frac{\log(e^{ky} - 1)}{k}
f'(y) = \frac{e^{ky}}{e^{ky} - 1}
\log(f'(y)) = ky - \log(e^{ky} - 1)
"""
scaled_value = self.sharpness * y
return scaled_value - self.inverse_softplus(scaled_value)
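
The same kind of standalone NumPy sanity check for the Softplus math (np.log1p and np.expm1 are numerically stable stand-ins for the helpers defined above; again an illustration, not part of the patch):

import numpy as np

k = 2.0  # sharpness
x = np.array([2.0, 3.0, 4.0, 5.0], dtype=np.float32)

y = np.log1p(np.exp(k * x)) / k        # forward: softplus(k*x) / k
x_back = np.log(np.expm1(k * y)) / k   # inverse: inverse_softplus(k*y) / k

fwd_log_jac = k * x - np.log1p(np.exp(k * x))  # log(sigmoid(k*x))
inv_log_jac = k * y - np.log(np.expm1(k * y))  # k*y - log(e^{k*y} - 1)

assert np.allclose(x_back, x, atol=1e-4)
assert np.allclose(fwd_log_jac + inv_log_jac, 0.0, atol=1e-4)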


@@ -0,0 +1,99 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
"""
Test class: forward pass of bijector.
"""
def __init__(self):
super(Net, self).__init__()
self.bijector = msb.ScalarAffine(scale=2.0, shift=1.0)
def construct(self, x_):
return self.bijector.forward(x_)
def test_forward():
forward = Net()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
expected = 2 * x + 1
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net1(nn.Cell):
"""
Test class: backward pass of bijector.
"""
def __init__(self):
super(Net1, self).__init__()
self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0)
def construct(self, x_):
return self.bijector.inverse(x_)
def test_backward():
backward = Net1()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = backward(Tensor(x, dtype=dtype.float32))
tol = 1e-6
expected = 0.5 * (x - 1.0)
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net2(nn.Cell):
"""
Test class: Forward Jacobian.
"""
def __init__(self):
super(Net2, self).__init__()
self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0)
def construct(self, x_):
return self.bijector.forward_log_jacobian(x_)
def test_forward_jacobian():
forward_jacobian = Net2()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = forward_jacobian(x)
expected = np.log([2.0, 2.0, 2.0, 2.0])
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net3(nn.Cell):
"""
Test class: Backward Jacobian.
"""
def __init__(self):
super(Net3, self).__init__()
self.bijector = msb.ScalarAffine(shift=1.0, scale=2.0)
def construct(self, x_):
return self.bijector.inverse_log_jacobian(x_)
def test_backward_jacobian():
backward_jacobian = Net3()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = backward_jacobian(x)
expected = np.log([0.5, 0.5, 0.5, 0.5])
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()


@@ -0,0 +1,99 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
context.set_context(device_target="Ascend")
class Net(nn.Cell):
"""
Test class: forward pass of bijector.
"""
def __init__(self):
super(Net, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.forward(x_)
def test_forward():
forward = Net()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward(Tensor(x, dtype=dtype.float32))
expected = np.log(1 + np.exp(2 * x)) * 0.5
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net1(nn.Cell):
"""
Test class: backward pass of bijector.
"""
def __init__(self):
super(Net1, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.inverse(x_)
def test_backward():
backward = Net1()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = backward(Tensor(x, dtype=dtype.float32))
expected = np.log(np.exp(2 * x) - 1) * 0.5
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net2(nn.Cell):
"""
Test class: Forward Jacobian.
"""
def __init__(self):
super(Net2, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.forward_log_jacobian(x_)
def test_forward_jacobian():
forward_jacobian = Net2()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = forward_jacobian(Tensor(x, dtype=dtype.float32))
expected = np.log(np.exp(2 * x) / (1 + np.exp(2.0 * x)))
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()
class Net3(nn.Cell):
"""
Test class: Backward Jacobian.
"""
def __init__(self):
super(Net3, self).__init__()
self.bijector = msb.Softplus(sharpness=2.0)
def construct(self, x_):
return self.bijector.inverse_log_jacobian(x_)
def test_backward_jacobian():
backward_jacobian = Net3()
x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32)
ans = backward_jacobian(Tensor(x, dtype=dtype.float32))
expected = np.log(np.exp(2.0 * x) / np.expm1(2.0 * x))
tol = 1e-6
assert (np.abs(ans.asnumpy() - expected) < tol).all()


@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""test cases for exp"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
@@ -21,8 +22,10 @@ from mindspore import dtype
def test_init():
b = msb.Exp()
assert isinstance(b, msb.Bijector)
b = msb.Exp(1.0)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.Exp(name=0.1)
class Net(nn.Cell):
"""


@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""test cases for powertransform"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
@@ -24,6 +25,12 @@ def test_init():
b = msb.PowerTransform(1)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.PowerTransform(power='power')
with pytest.raises(TypeError):
msb.PowerTransform(name=0.1)
class Net(nn.Cell):
"""
Test class: forward and inverse pass of bijector.


@@ -0,0 +1,139 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
"""
Test initializations.
"""
b = msb.ScalarAffine()
assert isinstance(b, msb.Bijector)
b = msb.ScalarAffine(scale=1.0)
assert isinstance(b, msb.Bijector)
b = msb.ScalarAffine(shift=2.0)
assert isinstance(b, msb.Bijector)
b = msb.ScalarAffine(3.0, 4.0)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.ScalarAffine(scale='scale')
with pytest.raises(TypeError):
msb.ScalarAffine(shift='shift')
with pytest.raises(TypeError):
msb.ScalarAffine(name=0.1)
class ForwardBackward(nn.Cell):
"""
Test class: forward and backward pass.
"""
def __init__(self):
super(ForwardBackward, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1.inverse(self.b1.forward(x_))
ans2 = self.b2.inverse(self.b2.forward(x_))
return ans1 + ans2
def test_forward_and_backward_pass():
"""
Test forward and backward pass of ScalarAffine bijector.
"""
net = ForwardBackward()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class ForwardJacobian(nn.Cell):
"""
Test class: Forward log Jacobian.
"""
def __init__(self):
super(ForwardJacobian, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1.forward_log_jacobian(x_)
ans2 = self.b2.forward_log_jacobian(x_)
return ans1 + ans2
def test_forward_jacobian():
"""
Test forward log jacobian of ScalarAffine bijector.
"""
net = ForwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class BackwardJacobian(nn.Cell):
"""
Test class: Backward log Jacobian.
"""
def __init__(self):
super(BackwardJacobian, self).__init__()
self.b1 = msb.ScalarAffine(2.0, 1.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1.inverse_log_jacobian(x_)
ans2 = self.b2.inverse_log_jacobian(x_)
return ans1 + ans2
def test_backward_jacobian():
"""
Test backward log jacobian of ScalarAffine bijector.
"""
net = BackwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class Net(nn.Cell):
"""
Test class: function calls going through construct.
"""
def __init__(self):
super(Net, self).__init__()
self.b1 = msb.ScalarAffine(1.0, 0.0)
self.b2 = msb.ScalarAffine()
def construct(self, x_):
ans1 = self.b1('inverse', self.b1('forward', x_))
ans2 = self.b2('inverse', self.b2('forward', x_))
ans3 = self.b1('forward_log_jacobian', x_)
ans4 = self.b2('forward_log_jacobian', x_)
ans5 = self.b1('inverse_log_jacobian', x_)
ans6 = self.b2('inverse_log_jacobian', x_)
return ans1 - ans2 + ans3 - ans4 + ans5 - ans6
def test_old_api():
"""
Test the old API, which goes through construct.
"""
net = Net()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)


@@ -0,0 +1,133 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for scalar affine"""
import pytest
import mindspore.nn as nn
import mindspore.nn.probability.bijector as msb
from mindspore import Tensor
from mindspore import dtype
def test_init():
"""
Test initializations.
"""
b = msb.Softplus()
assert isinstance(b, msb.Bijector)
b = msb.Softplus(1.0)
assert isinstance(b, msb.Bijector)
def test_type():
with pytest.raises(TypeError):
msb.Softplus(sharpness='sharpness')
with pytest.raises(TypeError):
msb.Softplus(name=0.1)
class ForwardBackward(nn.Cell):
"""
Test class: forward and backward pass.
"""
def __init__(self):
super(ForwardBackward, self).__init__()
self.b1 = msb.Softplus(2.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1.inverse(self.b1.forward(x_))
ans2 = self.b2.inverse(self.b2.forward(x_))
return ans1 + ans2
def test_forward_and_backward_pass():
"""
Test forward and backward pass of Softplus bijector.
"""
net = ForwardBackward()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class ForwardJacobian(nn.Cell):
"""
Test class: Forward log Jacobian.
"""
def __init__(self):
super(ForwardJacobian, self).__init__()
self.b1 = msb.Softplus(2.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1.forward_log_jacobian(x_)
ans2 = self.b2.forward_log_jacobian(x_)
return ans1 + ans2
def test_forward_jacobian():
"""
Test forward log jacobian of Softplus bijector.
"""
net = ForwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class BackwardJacobian(nn.Cell):
"""
Test class: Backward log Jacobian.
"""
def __init__(self):
super(BackwardJacobian, self).__init__()
self.b1 = msb.Softplus(2.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1.inverse_log_jacobian(x_)
ans2 = self.b2.inverse_log_jacobian(x_)
return ans1 + ans2
def test_backward_jacobian():
"""
Test backward log jacobian of Softplus bijector.
"""
net = BackwardJacobian()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)
class Net(nn.Cell):
"""
Test class: function calls going through construct.
"""
def __init__(self):
super(Net, self).__init__()
self.b1 = msb.Softplus(1.0)
self.b2 = msb.Softplus()
def construct(self, x_):
ans1 = self.b1('inverse', self.b1('forward', x_))
ans2 = self.b2('inverse', self.b2('forward', x_))
ans3 = self.b1('forward_log_jacobian', x_)
ans4 = self.b2('forward_log_jacobian', x_)
ans5 = self.b1('inverse_log_jacobian', x_)
ans6 = self.b2('inverse_log_jacobian', x_)
return ans1 - ans2 + ans3 - ans4 + ans5 - ans6
def test_old_api():
"""
Test the old API, which goes through construct.
"""
net = Net()
x = Tensor([2.0, 3.0, 4.0, 5.0], dtype=dtype.float32)
ans = net(x)
assert isinstance(ans, Tensor)