forked from mindspore-Ecosystem/mindspore
add hypercomplex operators and corresponding tests
fix comment length and function order; fix lints; add copyrights; fix code style; fix pylint warnings; add method docs; fix code checks; move hypercomplex to the upper directory; support the complex data type
This commit is contained in:
parent dbc4265f43
commit b3b47143a8
@@ -74,6 +74,7 @@
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py" "unused-variable"
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "not-callable"
"mindspore/mindspore/python/mindspore/hypercomplex" "useless-return"

# MindData
"mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"
@@ -0,0 +1,24 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
HyperComplex operators.

Note:
    This feature is a beta feature, and we are still improving its functionality.
    The interface may be changed or removed in the future.
"""
import mindspore.hypercomplex.complex
import mindspore.hypercomplex.dual
import mindspore.hypercomplex.double
@@ -0,0 +1,23 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Complex Operators"""
from mindspore.hypercomplex.complex.complex_relu import ReLU
from mindspore.hypercomplex.complex.complex_operators import Conv1d, Conv2d, Conv3d
from mindspore.hypercomplex.complex.complex_operators import BatchNorm1d, BatchNorm2d, BatchNorm3d
from mindspore.hypercomplex.complex.complex_operators import Dense

from mindspore.hypercomplex.hypercomplex.hc_pool import MaxPool1d, MaxPool2d, \
    AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, \
    AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d
@@ -0,0 +1,139 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex BatchNorm implementation"""
from typing import Tuple

import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BaseBatchNormImpl as HCBatchNormImpl
from mindspore.hypercomplex.utils import get_x_and_y as get_real_and_imag


class _BatchNormImpl(HCBatchNormImpl):
    r"""
    The implementor class of the Batch Normalization layer for complex numbers.

    Implements the functionality specific to complex numbers and needed by the 'BatchNorm' class. This includes:
    getting the norm of a complex number, applying scaling and shift to a complex tensor, and updating the running
    mean and variance, which are used during inference.

    Args:
        affine (bool): A bool value. When set to True, gamma and beta can be learned.
        use_batch_statistics (bool): If true, use the mean and variance values of the current batch data. If false,
            use the specified mean and variance values. If None, the training process will use the mean and variance
            of the current batch data and track the running mean and variance, while the evaluation process will use
            the running mean and variance.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
        num_features (int): The number of features in the input space.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def scale_and_shift(self,
                        u_x: Tensor,
                        u_y: Tensor,
                        scale_x: Tensor,
                        scale_y: Tensor,
                        shift_x: Tensor,
                        shift_y: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Applies complex scaling and shift to an input tensor.

        This function implements the operation as:

        .. math::
            \begin{align}
            \text{Re(out)} = \text{Re(inp)} * \text{Re(scale)} - \text{Im(inp)} * \text{Im(scale)} + \text{Re(shift)}\\
            \text{Im(out)} = \text{Re(inp)} * \text{Im(scale)} + \text{Im(inp)} * \text{Re(scale)} + \text{Im(shift)},
            \end{align}

        where :math:`inp` is the complex input tensor, :math:`scale` and :math:`shift` are complex parameters
        representing the scaling and shift coefficients respectively. :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
        are respectively the real and imaginary parts of the complex-valued expression inside the parentheses.

        Args:
            u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
            u_y (Tensor): A tensor of shape (C,), which represents the imaginary part of the normalized inputs.
            scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
            scale_y (Tensor): A tensor of shape (C,), which represents the imaginary part of the scaling vector.
            shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
            shift_y (Tensor): A tensor of shape (C,), which represents the imaginary part of the bias vector.

        Returns:
            Tuple of two tensors of shape (C,), which contains the real and the imaginary parts of the rescaled and
            recentered inputs.
        """
        out_x = u_x * scale_x - u_y * scale_y + shift_x
        out_y = u_x * scale_y + u_y * scale_x + shift_y
        return out_x, out_y

    def get_norm(self,
                 u: Tensor) -> Tensor:
        r"""
        Calculates the norm of the complex elements of an input tensor.

        The norm is a non-negative real number that characterizes the 'magnitude' of a complex number, i.e. how far
        away it is from zero. The function implements the operation as:

        .. math::
            \text{out} = \sqrt{\text{Re(inp)}^2 + \text{Im(inp)}^2 + \delta},

        where :math:`inp` is the complex input tensor and :math:`\delta` is a small positive constant, which is
        needed to avoid division by zero in case the statistical variance is close to zero. :math:`\text{Re(...)}`
        and :math:`\text{Im(...)}` are respectively the real and imaginary parts of the complex-valued expression
        inside the parentheses.

        Args:
            u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the complex
                domain and has a real and an imaginary part.

        Returns:
            Tensor of shape (*, ..., *). The count and size of the dimensions of the output tensor are the same as in
            the input tensor, but without the very first dimension because the output tensor is real-valued.
        """
        norm2 = self.get_square_norm(u)
        eps = 1e-7
        out = ops.sqrt(norm2 + eps)
        return out

    def get_square_norm(self,
                        u: Tensor) -> Tensor:
        r"""
        Calculates the element-wise squared norm of the complex elements of an input tensor.

        The norm is a non-negative real number that characterizes the 'magnitude' of a complex number, i.e. how far
        away it is from zero. The function implements the operation as:

        .. math::
            \text{out} = \text{Re(inp)}^2 + \text{Im(inp)}^2,

        where :math:`inp` is the complex input tensor, :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
        are respectively the real and imaginary parts of the complex-valued expression inside the parentheses.

        Args:
            u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the complex
                domain and has a real and an imaginary part.

        Returns:
            Tensor of shape (*, ..., *). The count and size of the dimensions of the output tensor are the same as in
            the input tensor, but without the very first dimension because the output tensor is real-valued.
        """
        u_r, u_i = get_real_and_imag(u)
        out = u_r ** 2 + u_i ** 2
        return out
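The scale_and_shift formulas above are ordinary complex multiplication followed by a complex shift. A minimal NumPy sketch (an illustration, not part of the commit) that checks the two channel-wise formulas against native complex arithmetic:

```python
import numpy as np

def complex_scale_and_shift(u_x, u_y, scale_x, scale_y, shift_x, shift_y):
    # Same real-valued formulas as _BatchNormImpl.scale_and_shift above.
    out_x = u_x * scale_x - u_y * scale_y + shift_x
    out_y = u_x * scale_y + u_y * scale_x + shift_y
    return out_x, out_y

rng = np.random.default_rng(0)
u_x, u_y, s_x, s_y, b_x, b_y = (rng.standard_normal(4) for _ in range(6))

out_x, out_y = complex_scale_and_shift(u_x, u_y, s_x, s_y, b_x, b_y)

# Direct complex arithmetic gives the same result: out = u * scale + shift.
ref = (u_x + 1j * u_y) * (s_x + 1j * s_y) + (b_x + 1j * b_y)
assert np.allclose(out_x, ref.real) and np.allclose(out_y, ref.imag)
```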
@@ -0,0 +1,201 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex convolution implementation"""
import numbers
from typing import Callable, Tuple, Union

from mindspore.common.tensor import Tensor
from mindspore.common.initializer import Initializer
from mindspore.hypercomplex.hypercomplex._hc_conv_impl import _BaseConvImpl as BaseConvImpl
from mindspore import ops as P


class _ConvImpl(BaseConvImpl):
    r"""
    The implementor class of the convolution layer for complex numbers.

    Applies complex-valued convolution transformation. This layer implements the operation as:

    .. math::
        \begin{align}
        \text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})
        - \text{ccor}(\text{Im(kernel)}, \text{Im(inp)})\\
        \text{Im(ccor)} = \text{ccor}(\text{Im(kernel)}, \text{Re(inp)})
        + \text{ccor}(\text{Re(kernel)}, \text{Im(inp)})
        \end{align}

    where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
    :math:`inp` is the complex-valued input tensor, :math:`\text{kernel}` is a complex weight matrix with the same
    data type as the :math:`inp` created by the layer, and :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
    are respectively the real and imaginary parts of the complex-valued expression inside the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and imaginary parts of
            the kernel.

    Inputs:
        - **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
          of the complex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed.
        - **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the real part of the input.
        - **imag** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the imaginary part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
        which represents the real and the imaginary parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  conv_op: Callable,
                  real: Tensor,
                  imag: Tensor) -> Tuple[Tensor, Tensor]:
        out_rr = conv_op(real, self.weight_x)
        out_ii = conv_op(imag, self.weight_y)
        out_ri = conv_op(real, self.weight_y)
        out_ir = conv_op(imag, self.weight_x)

        out_r = out_rr - out_ii
        out_i = out_ri + out_ir

        return out_r, out_i


class _KaratsubaConvImpl(BaseConvImpl):
    r"""
    The implementor class of the convolution layer for complex numbers, based on the Karatsuba decomposition.

    Applies complex-valued convolution transformation using three real-valued convolutions instead of four.
    This layer implements the operation as:

    .. math::
        \begin{align}
        \text{C1} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})\\
        \text{C2} = \text{ccor}(\text{Im(kernel)}, \text{Im(inp)})\\
        \text{C3} = \text{ccor}(\text{Re(kernel)} + \text{Im(kernel)}, \text{Re(inp)} + \text{Im(inp)})\\
        \text{Re(out)} = C1 - C2\\
        \text{Im(out)} = C3 - C1 - C2,
        \end{align}

    where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
    :math:`inp` is the complex-valued input tensor, :math:`\text{kernel}` is a complex weight matrix with the same
    data type as the :math:`inp` created by the layer, and :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
    are respectively the real and imaginary parts of the complex-valued expression inside the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and imaginary parts of
            the kernel.

    Inputs:
        - **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
          of the complex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed.
        - **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the real part of the input.
        - **imag** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the imaginary part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
        which represents the real and the imaginary parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  conv_op: Callable,
                  real: Tensor,
                  imag: Tensor) -> Tuple[Tensor, Tensor]:
        c1 = conv_op(real, self.weight_x)
        c2 = conv_op(imag, self.weight_y)
        c3 = conv_op(real + imag, self.weight_x + self.weight_y)

        out_r = c1 - c2
        out_i = c3 - c1 - c2

        return out_r, out_i


class _ReImConvImpl(BaseConvImpl):
    r"""
    The implementor class of the convolution layer for complex numbers, based on channel concatenation.

    Applies complex-valued convolution transformation as two real-valued convolutions over the channel-concatenated
    input. This layer implements the operation as:

    .. math::
        \begin{align}
        \text{inp_cat} = \text{cat}(\text{Re(inp)}, \text{Im(inp)}) \\
        \text{K1} = \text{cat}(\text{Re(kernel)}, \text{-Im(kernel)}) \\
        \text{K2} = \text{cat}(\text{Im(kernel)}, \text{Re(kernel)}) \\
        \text{Re(ccor)} = \text{ccor}(\text{K1}, \text{inp_cat}) \\
        \text{Im(ccor)} = \text{ccor}(\text{K2}, \text{inp_cat})
        \end{align}

    where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
    :math:`inp` is the complex-valued input tensor, :math:`\text{kernel}` is a complex weight matrix with the same
    data type as the :math:`inp` created by the layer, :math:`\text{cat}` is concatenation along the channel axis,
    and :math:`\text{Re(...)}` and :math:`\text{Im(...)}` are respectively the real and imaginary parts of the
    complex-valued expression inside the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and imaginary parts of
            the kernel.
        factory_kwargs (dict): Additional parameters, which must include data_format.

    Inputs:
        - **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
          of the complex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed.
        - **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the real part of the input.
        - **imag** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the imaginary part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
        which represents the real and the imaginary parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self,
                 weight_init: Union[Tensor, str, Initializer, numbers.Number],
                 weight_shape: tuple,
                 **factory_kwargs) -> None:
        super(_ReImConvImpl, self).__init__(weight_init, weight_shape, **factory_kwargs)
        data_format = factory_kwargs.get('data_format', 'nchw')
        c_idx = data_format.lower().find('c')
        if c_idx < 0:
            raise ValueError(f"Data format {data_format} is unsupported")
        self.concat = P.Concat(c_idx)
        self.neg = P.Neg()

    def construct(self,
                  conv_op: Callable,
                  real: Tensor,
                  imag: Tensor) -> Tuple[Tensor, Tensor]:
        inp = self.concat([real, imag])
        weight_y_neg = self.neg(self.weight_y)
        w1 = self.concat([self.weight_x, weight_y_neg])
        w2 = self.concat([self.weight_y, self.weight_x])
        out_r = conv_op(inp, w1)
        out_i = conv_op(inp, w2)
        return out_r, out_i
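Because cross-correlation is linear in both arguments, the naive 4-term decomposition in _ConvImpl and the 3-term Karatsuba decomposition in _KaratsubaConvImpl produce the same complex result. A small NumPy sketch (illustrative only; plain element-wise multiplication stands in for conv_op):

```python
import numpy as np

def naive(real, imag, w_x, w_y, op):
    # Four applications of the real-valued bilinear op.
    return op(real, w_x) - op(imag, w_y), op(real, w_y) + op(imag, w_x)

def karatsuba(real, imag, w_x, w_y, op):
    # Three applications of the real-valued bilinear op.
    c1 = op(real, w_x)
    c2 = op(imag, w_y)
    c3 = op(real + imag, w_x + w_y)
    return c1 - c2, c3 - c1 - c2

rng = np.random.default_rng(1)
real, imag, w_x, w_y = (rng.standard_normal(8) for _ in range(4))
op = np.multiply  # any bilinear op, e.g. a real-valued convolution

assert all(np.allclose(a, b) for a, b in zip(naive(real, imag, w_x, w_y, op),
                                             karatsuba(real, imag, w_x, w_y, op)))
```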
@@ -0,0 +1,125 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex dense implementation"""
from typing import Callable, Tuple

from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _BaseDenseImpl as BaseDenseImpl


class _DenseImpl(BaseDenseImpl):
    r"""
    The implementor class of the dense connected layer for complex numbers.

    Applies complex-valued matrix multiplication for the dense connected layer. This layer implements the operation
    as:

    .. math::
        \begin{align}
        \text{Re(out)} = \text{Re(inp)} * \text{Re(kernel)} - \text{Im(inp)} * \text{Im(kernel)}\\
        \text{Im(out)} = \text{Re(inp)} * \text{Im(kernel)} + \text{Im(inp)} * \text{Re(kernel)},
        \end{align}

    where :math:`inp` is the complex input tensor, :math:`\text{kernel}` is a complex weight matrix with the same
    data type as the :math:`inp` created by the layer, and :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
    are respectively the real and imaginary parts of the complex-valued expression inside the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and imaginary parts of
            the kernel.
        **factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.

    Inputs:
        - **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for
          decomposition of the complex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed.
        - **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
        - **imag** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the imaginary part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the imaginary
        parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  matmul_op: Callable,
                  real: Tensor,
                  imag: Tensor) -> Tuple[Tensor, Tensor]:
        out_rr = matmul_op(real, self.weight_x)
        out_ii = matmul_op(imag, self.weight_y)
        out_ri = matmul_op(real, self.weight_y)
        out_ir = matmul_op(imag, self.weight_x)

        out_r = out_rr - out_ii
        out_i = out_ri + out_ir

        return out_r, out_i


class _KaratsubaDenseImpl(BaseDenseImpl):
    r"""
    The implementor class of the dense connected layer for complex numbers, based on the Karatsuba decomposition.

    Applies complex-valued matrix multiplication for the dense connected layer using three real-valued matrix
    multiplications instead of four. This layer implements the operation as:

    .. math::
        \begin{align}
        \text{L1} = \text{Re(inp)} * \text{Re(kernel)}\\
        \text{L2} = \text{Im(inp)} * \text{Im(kernel)}\\
        \text{L3} = (\text{Re(inp)} + \text{Im(inp)}) * (\text{Re(kernel)} + \text{Im(kernel)})\\
        \text{Re(out)} = L1 - L2\\
        \text{Im(out)} = L3 - L1 - L2,
        \end{align}

    where :math:`inp` is the complex input tensor, :math:`\text{kernel}` is a complex weight matrix with the same
    data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a complex bias vector with the same
    data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}` and
    :math:`\text{Im(...)}` are respectively the real and imaginary parts of the complex-valued expression inside
    the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and imaginary parts of
            the kernel.
        **factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.

    Inputs:
        - **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for
          decomposition of the complex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed.
        - **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
        - **imag** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the imaginary part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the imaginary
        parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  matmul_op: Callable,
                  real: Tensor,
                  imag: Tensor) -> Tuple[Tensor, Tensor]:
        l1 = matmul_op(real, self.weight_x)
        l2 = matmul_op(imag, self.weight_y)
        l3 = matmul_op(real + imag, self.weight_x + self.weight_y)

        out_r = l1 - l2
        out_i = l3 - l1 - l2

        return out_r, out_i
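For reference, the 4-matmul decomposition in _DenseImpl is just complex matrix multiplication written with real matrices. A NumPy sketch (illustrative only; it assumes, purely for the example, weights stored as (in_channels, out_channels) with plain `@` standing in for matmul_op):

```python
import numpy as np

rng = np.random.default_rng(2)
real, imag = rng.standard_normal((2, 3, 5))   # input:  (*, in_channels)
w_x, w_y = rng.standard_normal((2, 5, 4))     # kernel: real and imaginary parts

# The four real matmuls of _DenseImpl.construct.
out_r = real @ w_x - imag @ w_y
out_i = real @ w_y + imag @ w_x

# NumPy's native complex matmul agrees.
ref = (real + 1j * imag) @ (w_x + 1j * w_y)
assert np.allclose(out_r, ref.real) and np.allclose(out_i, ref.imag)
```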
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,62 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex ReLU implementation"""
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.hypercomplex.utils import get_x_and_y as get_real_and_imag, to_2channel as to_complex


class ReLU(nn.Cell):
    r"""
    Rectified Linear Unit activation function for complex-valued input.

    Applies the ReLU activation layer to the complex-valued input. This layer applies the element-wise
    :math:`\max(0, x)` to both the real and imaginary parts of the input tensor independently:

    .. math::
        \begin{align}
        \text{Re(out)} = (Re(inp))^+ = \max(0, Re(inp))\\
        \text{Im(out)} = (Im(inp))^+ = \max(0, Im(inp)),
        \end{align}

    Inputs:
        - **inp** (Tensor) - The input of ReLU is a Tensor of shape (2, *, ..., *), with float16 or float32 data
          type, or (*, ..., *), with complex64 data type.

    Outputs:
        Tensor, with the same data type and shape as the `inp`.

    Raises:
        TypeError: If dtype of `inp` is not float16, float32, or complex64.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self):
        """Initialize ReLU."""
        super(ReLU, self).__init__()
        self.relu = P.ReLU()

    def construct(self, u: Tensor) -> Tensor:
        if u.dtype == mindspore.complex64:
            real, imag = get_real_and_imag(u)
            real = self.relu(real)
            imag = self.relu(imag)
            out = to_complex(real, imag, u.dtype)
        else:
            out = self.relu(u)
        return out
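A minimal usage sketch (assuming a MindSpore build with the hypercomplex package laid out as in this commit): complex data is carried either as a (2, ...) real tensor, channel 0 holding the real part and channel 1 the imaginary part, or as a complex64 tensor.

```python
import numpy as np
import mindspore as ms
from mindspore.hypercomplex.complex import ReLU

relu = ReLU()
inp = ms.Tensor(np.random.randn(2, 4, 8).astype(np.float32))  # (2, *, ..., *)
out = relu(inp)      # ReLU applied to both parts independently
print(out.shape)     # (2, 4, 8)
```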
@@ -0,0 +1,23 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double Operators"""
from mindspore.hypercomplex.double.double_operators import Conv1d, Conv2d, Conv3d
from mindspore.hypercomplex.double.double_operators import BatchNorm1d, BatchNorm2d, BatchNorm3d
from mindspore.hypercomplex.double.double_operators import Dense
from mindspore.hypercomplex.double.double_operators import ReLU

from mindspore.hypercomplex.hypercomplex.hc_pool import MaxPool1d, MaxPool2d, \
    AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, \
    AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d
@@ -0,0 +1,256 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double BatchNorm implementation"""
from typing import Tuple

import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BaseBatchNormImpl as HCBatchNormImpl
from mindspore.hypercomplex.utils import get_x_and_y

get_real_and_double = get_x_and_y
get_u1_and_u2 = get_x_and_y


class _BatchNormImpl(HCBatchNormImpl):
    r"""
    The implementor class of the Batch Normalization layer for double numbers in the regular representation.

    Implements the functionality specific to double numbers and needed by the 'BatchNorm' class. This includes:
    getting the norm of a double number, applying scaling and shift to a double-valued tensor, and updating the
    running mean and variance, which are used during inference.

    Args:
        affine (bool): A bool value. When set to True, gamma and beta can be learned.
        use_batch_statistics (bool): If true, use the mean and variance values of the current batch data. If false,
            use the specified mean and variance values. If None, the training process will use the mean and variance
            of the current batch data and track the running mean and variance, while the evaluation process will use
            the running mean and variance.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
        num_features (int): The number of features in the input space.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def scale_and_shift(self,
                        u_x: Tensor,
                        u_y: Tensor,
                        scale_x: Tensor,
                        scale_y: Tensor,
                        shift_x: Tensor,
                        shift_y: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Applies double scaling and shift to an input tensor in the regular representation.

        This function implements the operation as:

        .. math::
            \begin{align}
            \text{Re(out)} = \text{Re(inp)} * \text{Re(scale)} + \text{Db(inp)} * \text{Db(scale)} + \text{Re(shift)}\\
            \text{Db(out)} = \text{Re(inp)} * \text{Db(scale)} + \text{Db(inp)} * \text{Re(scale)} + \text{Db(shift)},
            \end{align}

        where :math:`inp` is the double input tensor, :math:`scale` and :math:`shift` are double parameters
        representing the scaling and shift coefficients respectively. :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
        are respectively the real and the double parts of the double-valued expression inside the parentheses.

        Args:
            u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
            u_y (Tensor): A tensor of shape (C,), which represents the double part of the normalized inputs.
            scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
            scale_y (Tensor): A tensor of shape (C,), which represents the double part of the scaling vector.
            shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
            shift_y (Tensor): A tensor of shape (C,), which represents the double part of the bias vector.

        Returns:
            Tuple of two tensors of shape (C,), which contains the real and the double parts of the rescaled and
            recentered inputs.
        """
        out_x = u_x * scale_x + u_y * scale_y + shift_x
        out_y = u_x * scale_y + u_y * scale_x + shift_y
        return out_x, out_y

    def get_norm(self, u: Tensor) -> Tensor:
        r"""
        Calculates the norm of the double elements of an input tensor in the regular representation.

        The norm is a non-negative real number that characterizes the 'magnitude' of a double number, i.e. how far
        away it is from zero. The function implements the operation as:

        .. math::
            \text{out} = |Re(inp)| + |Db(inp)|,

        where :math:`inp` is the double input tensor, :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
        are respectively the real and the double parts of the double-valued expression inside the parentheses.

        Args:
            u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double
                domain and has two components.

        Returns:
            Tensor of shape (*, ..., *). The count and size of the dimensions of the output tensor are the same as in
            the input tensor, but without the very first dimension because the output tensor is real-valued.
        """
        u_r, u_d = get_real_and_double(u)
        abs_op = ops.Abs()
        out = abs_op(u_r) + abs_op(u_d)
        return out

    def get_square_norm(self, u: Tensor) -> Tensor:
        r"""
        Calculates the element-wise squared norm of the double elements of an input tensor in the regular
        representation.

        The norm is a non-negative real number that characterizes the 'magnitude' of a double number, i.e. how far
        away it is from zero. The function implements the operation as:

        .. math::
            \text{out} = \left(|Re(inp)| + |Db(inp)|\right)^2,

        where :math:`inp` is the double input tensor, :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
        are respectively the real and the double parts of the double-valued expression inside the parentheses.

        Args:
            u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double
                domain and has a real and a double part.

        Returns:
            Tensor of shape (*, ..., *). The count and size of the dimensions of the output tensor are the same as in
            the input tensor, but without the very first dimension because the output tensor is real-valued.
        """
        norm = self.get_norm(u)
        out = norm ** 2
        return out


class _J1J2BatchNormImpl(HCBatchNormImpl):
    r"""
    The implementor class of the Batch Normalization layer for double numbers in the diagonal representation.

    Implements the functionality specific to double numbers and needed by the 'BatchNorm' class. This includes:
    getting the norm of a double number, applying scaling and shift to a double-valued tensor, and updating the
    running mean and variance, which are used during inference.

    Args:
        affine (bool): A bool value. When set to True, gamma and beta can be learned.
        use_batch_statistics (bool): If true, use the mean and variance values of the current batch data. If false,
            use the specified mean and variance values. If None, the training process will use the mean and variance
            of the current batch data and track the running mean and variance, while the evaluation process will use
            the running mean and variance.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
        num_features (int): The number of features in the input space.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def scale_and_shift(self,
                        u_x: Tensor,
                        u_y: Tensor,
                        scale_x: Tensor,
                        scale_y: Tensor,
                        shift_x: Tensor,
                        shift_y: Tensor) -> Tuple[Tensor, Tensor]:
        r"""
        Applies double scaling and shift to an input tensor in the diagonal representation.

        This function implements the operation as:

        .. math::
            \begin{align}
            \text{X(out)} = \text{X(inp)} * \text{X(scale)} + \text{X(shift)}\\
            \text{Y(out)} = \text{Y(inp)} * \text{Y(scale)} + \text{Y(shift)},
            \end{align}

        where :math:`inp` is the double input tensor in the diagonal form, :math:`scale` and :math:`shift` are
        double parameters representing the scaling and shift coefficients respectively. :math:`\text{X(...)}`
        and :math:`\text{Y(...)}` are respectively the first and the second components of the double-valued
        expression inside the parentheses.

        Args:
            u_x (Tensor): A tensor of shape (C,), which represents the first part of the normalized inputs.
            u_y (Tensor): A tensor of shape (C,), which represents the second part of the normalized inputs.
            scale_x (Tensor): A tensor of shape (C,), which represents the first part of the scaling vector.
            scale_y (Tensor): A tensor of shape (C,), which represents the second part of the scaling vector.
            shift_x (Tensor): A tensor of shape (C,), which represents the first part of the bias vector.
            shift_y (Tensor): A tensor of shape (C,), which represents the second part of the bias vector.

        Returns:
            Tuple of two tensors of shape (C,), which contains the first and the second parts of the rescaled and
            recentered inputs in the diagonal representation.
        """
        out_x = u_x * scale_x + shift_x
        out_y = u_y * scale_y + shift_y
        return out_x, out_y

    def get_norm(self, u: Tensor) -> Tensor:
        r"""
        Calculates the norm of the double elements of an input tensor in the diagonal representation.

        The norm is a non-negative real number that characterizes the 'magnitude' of a double number, i.e. how far
        away it is from zero. The function implements the operation as:

        .. math::
            \text{out} = \text{max}(|X(inp)|, |Y(inp)|),

        where :math:`inp` is the double input tensor in the diagonal form, :math:`\text{max}` is the maximum value
        of its arguments. :math:`\text{X(...)}` and :math:`\text{Y(...)}` are respectively the first and the second
        components of the double-valued expression inside the parentheses.

        Args:
            u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double
                domain and has two components.

        Returns:
            Tensor of shape (*, ..., *). The count and size of the dimensions of the output tensor are the same as in
            the input tensor, but without the very first dimension because the output tensor is real-valued.
        """
        u_1, u_2 = get_u1_and_u2(u)
        abs_op = ops.Abs()
        max_op = ops.Maximum()
        out = max_op(abs_op(u_1), abs_op(u_2))
        return out

    def get_square_norm(self, u: Tensor) -> Tensor:
        r"""
        Calculates the element-wise squared norm of the double elements of an input tensor in the diagonal
        representation.

        The norm is a non-negative real number that characterizes the 'magnitude' of a double number, i.e. how far
        away it is from zero. The function implements the operation as:

        .. math::
            \text{out} = \left(\text{max}(|X(inp)|, |Y(inp)|)\right)^2,

        where :math:`inp` is the double input tensor in the diagonal form, :math:`\text{max}` is the maximum value
        of its arguments. :math:`\text{X(...)}` and :math:`\text{Y(...)}` are respectively the first and the second
        components of the double-valued expression inside the parentheses.

        Args:
            u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double
                domain and has two components.

        Returns:
            Tensor of shape (*, ..., *). The count and size of the dimensions of the output tensor are the same as in
            the input tensor, but without the very first dimension because the output tensor is real-valued.
        """
        norm = self.get_norm(u)
        out = norm ** 2
        return out
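A short sketch of the two double-number representations used above (spelled out from the formulas and the J1J2 naming, as an assumption for illustration): the regular parts (Re, Db) and the diagonal components X = Re + Db, Y = Re - Db, together with the two norms the BatchNorm implementors compute.

```python
import numpy as np

re, db = 1.5, -0.25
x, y = re + db, re - db          # to the diagonal representation
assert np.isclose(0.5 * (x + y), re) and np.isclose(0.5 * (x - y), db)

norm_regular = abs(re) + abs(db)        # _BatchNormImpl.get_norm
norm_diagonal = max(abs(x), abs(y))     # _J1J2BatchNormImpl.get_norm
print(norm_regular, norm_diagonal)
```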
@@ -0,0 +1,126 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double convolution implementation"""
from typing import Callable, Tuple

from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_conv_impl import _BaseConvImpl as BaseConvImpl


class _ConvImpl(BaseConvImpl):
    r"""
    The implementor class of the convolution layer for double numbers in the regular representation.

    Applies double-valued convolution transformation. This layer implements the operation as:

    .. math::
        \begin{align}
        \text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})
        + \text{ccor}(\text{Db(kernel)}, \text{Db(inp)})\\
        \text{Db(ccor)} = \text{ccor}(\text{Db(kernel)}, \text{Re(inp)})
        + \text{ccor}(\text{Re(kernel)}, \text{Db(inp)}),
        \end{align}

    where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
    :math:`inp` is the double input tensor, :math:`\text{kernel}` is a double weight matrix with the same
    data type as the :math:`inp` created by the layer. :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
    are respectively the first and the second parts of the double-valued expression inside the parentheses in the
    regular form.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and double parts of
            the kernel.

    Inputs:
        - **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
          of the double convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed.
        - **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the first part of the input.
        - **double** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the second part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
        which represents both the first and the second parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  conv_op: Callable,
                  real: Tensor,
                  double: Tensor) -> Tuple[Tensor, Tensor]:
        u1 = real + double
        u2 = real - double

        out1 = conv_op(u1, self.weight_x)
        out2 = conv_op(u2, self.weight_y)

        out_r = out1 + out2
        out_d = out1 - out2
        return out_r, out_d


class _J1J2ConvImpl(BaseConvImpl):
    r"""
    The implementor class of the convolution layer for double numbers in the diagonal representation.

    Applies double-valued convolution transformation. This layer implements the operation as:

    .. math::
        \begin{align}
        \text{X(ccor)} = \text{ccor}(\text{X(kernel)}, \text{X(inp)}) + \text{X(bias)}\\
        \text{Y(ccor)} = \text{ccor}(\text{Y(kernel)}, \text{Y(inp)}) + \text{Y(bias)},
        \end{align}

    where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
    :math:`inp` is the double input tensor, :math:`\text{kernel}` is a double weight matrix with the same
    data type as the :math:`inp` created by the layer. :math:`\text{X(...)}` and :math:`\text{Y(...)}`
    are respectively the first and the second parts of the double-valued expression inside the parentheses in the
    diagonal form.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and double parts of
            the kernel.

    Inputs:
        - **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
          of the double convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed.
        - **u1** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the first part of the input.
        - **u2** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
          which defines the second part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
        which represents both the first and the second parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  conv_op: Callable,
                  u1: Tensor,
                  u2: Tensor) -> Tuple[Tensor, Tensor]:
        out1 = conv_op(u1, self.weight_x)
        out2 = conv_op(u2, self.weight_y)

        return out1, out2
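The identity behind the regular-form implementor above: double (split-complex) multiplication, with j * j = +1, becomes component-wise multiplication in the diagonal basis x = Re + Db, y = Re - Db. A NumPy sketch, with scalars standing in for the tensors and any constant factor assumed to be absorbed into the stored weights:

```python
import numpy as np

rng = np.random.default_rng(3)
r, d, kr, kd = rng.standard_normal(4)

# regular form: (r + j*d) * (kr + j*kd) with j*j = +1
out_re = r * kr + d * kd
out_db = r * kd + d * kr

# diagonal form: multiply component-wise, then convert back
x_out = (r + d) * (kr + kd)
y_out = (r - d) * (kr - kd)
assert np.isclose(out_re, 0.5 * (x_out + y_out))
assert np.isclose(out_db, 0.5 * (x_out - y_out))
```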
@@ -0,0 +1,123 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double dense implementation"""
from typing import Callable, Tuple

from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _BaseDenseImpl as BaseDenseImpl


class _DenseImpl(BaseDenseImpl):
    r"""
    The implementor class of the dense connected layer for double numbers in the regular representation.

    Applies double-valued matrix multiplication for the dense connected layer. This layer implements the operation
    as:

    .. math::
        \begin{align}
        \text{X(out)} = \text{X(inp)} * \text{X(kernel)} + \text{Y(inp)} * \text{Y(kernel)}\\
        \text{Y(out)} = \text{X(inp)} * \text{Y(kernel)} + \text{Y(inp)} * \text{X(kernel)},
        \end{align}

    where :math:`inp` is the double input tensor, :math:`\text{kernel}` is a double weight matrix with the same
    data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a double bias vector with the same
    data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{X(...)}` and
    :math:`\text{Y(...)}` are respectively the first and the second parts of the double-valued expression
    inside the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and double parts of
            the kernel.
        **factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.

    Inputs:
        - **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for
          decomposition of the double linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed.
        - **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
        - **double** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the double part of the
          input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the double
        parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  matmul_op: Callable,
                  real: Tensor,
                  double: Tensor) -> Tuple[Tensor, Tensor]:
        u1 = real + double
        u2 = real - double

        out1 = matmul_op(u1, self.weight_x)
        out2 = matmul_op(u2, self.weight_y)

        out_r = out1 + out2
        out_d = out1 - out2

        return out_r, out_d


class _J1J2DenseImpl(BaseDenseImpl):
    r"""
    The implementor class of the dense connected layer for double numbers in the diagonal representation.

    Applies double-valued matrix multiplication for the dense connected layer. This layer implements the operation
    as:

    .. math::
        \begin{align}
        \text{X(out)} = \text{X(inp)} * \text{X(kernel)}\\
        \text{Y(out)} = \text{Y(inp)} * \text{Y(kernel)},
        \end{align}

    where :math:`inp` is the double input tensor in the diagonal form, :math:`\text{kernel}` is a double weight
    matrix in the diagonal form with the same data type as the :math:`inp` created by the layer, and
    :math:`\text{bias}` is a double bias vector in the diagonal form with the same data type as the :math:`inp`
    created by the layer (only if has_bias is True). :math:`\text{X(...)}` and :math:`\text{Y(...)}` are
    respectively the first and the second parts of the double-valued expression inside the parentheses.

    Args:
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
        weight_shape (tuple): The set of int numbers that defines the shape of the real and double parts of
            the kernel.
        **factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.

    Inputs:
        - **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for
          decomposition of the double linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed.
        - **u1** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the first part of the input.
        - **u2** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the second part of the input.

    Outputs:
        Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the first and the second
        parts of the output.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def construct(self,
                  matmul_op: Callable,
                  u1: Tensor,
                  u2: Tensor) -> Tuple[Tensor, Tensor]:
        out1 = matmul_op(u1, self.weight_x)
        out2 = matmul_op(u2, self.weight_y)

        return out1, out2
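In the diagonal representation, the double-valued Dense layer reduces to two independent real matrix multiplications, one per component, which is what _J1J2DenseImpl.construct computes. A tiny sketch (plain `@` and a (in_channels, out_channels) weight layout are assumptions for illustration only):

```python
import numpy as np

rng = np.random.default_rng(5)
u1, u2 = rng.standard_normal((2, 3, 5))     # diagonal components of the input
w_x, w_y = rng.standard_normal((2, 5, 4))   # diagonal components of the kernel
out1, out2 = u1 @ w_x, u2 @ w_y             # two independent real matmuls
print(out1.shape, out2.shape)               # (3, 4) (3, 4)
```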
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,77 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Double relu operators"""
|
||||
from mindspore import nn, Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.hypercomplex.utils import get_x_and_y as get_u1_and_u2, \
|
||||
to_2channel as to_double
|
||||
|
||||
|
||||
class J1J2ReLU(nn.Cell):
|
||||
r"""
|
||||
Rectified Linear Unit activation function for double-valued input in the diagonal representation.
|
||||
|
||||
Applies ReLU activation layer for the double-valued input. This layer first converts the input to the regular form:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{Re(inp)} = 0.5 * (\text{X(inp)} + \text{Y(inp)})\\
|
||||
\text{Db(inp)} = 0.5 * (\text{X(inp)} - \text{Y(inp)}),
|
||||
\end{align}
|
||||
|
||||
then applies the element-wise :math:`\max(0, x)` for both real and double parts of the input tensor independently:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{Re(out)} = (Re(inp))^+ = \max(0, Re(inp))\\
|
||||
\text{Db(out)} = (Db(inp))^+ = \max(0, Db(inp)),
|
||||
\end{align}
|
||||
|
||||
and finally transfers the result back to the diagonal representation:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{X(out)} = \text{Re(out)} + \text{Db(out)}\\
|
||||
\text{Y(out)} = \text{Re(out)} - \text{Db(out)}
|
||||
\end{align}
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - The input of ReLU is a Tensor of shape (2, *, ..., *). The data type is
|
||||
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as `inp`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `inp` is not a number.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize J1J2ReLU."""
|
||||
super(J1J2ReLU, self).__init__()
|
||||
self.relu = P.ReLU()
|
||||
|
||||
def construct(self, u: Tensor) -> Tensor:
|
||||
u = u / 2
|
||||
u1, u2 = get_u1_and_u2(u)
|
||||
x = self.relu(u1 + u2)
|
||||
y = self.relu(u1 - u2)
|
||||
out1 = x + y
|
||||
out2 = x - y
|
||||
out = to_double(out1, out2)
|
||||
return out
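# Illustrative NumPy check (not part of the MindSpore API) that the diagonal-form
# ReLU computed by J1J2ReLU.construct matches the three-step recipe from the
# docstring: convert to the regular form, apply max(0, x) element-wise, convert
# back to the diagonal form. Array shapes are arbitrary.
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

u1, u2 = np.random.randn(3, 5), np.random.randn(3, 5)     # X(inp), Y(inp)

re, db = 0.5 * (u1 + u2), 0.5 * (u1 - u2)                  # regular form
x_ref, y_ref = relu(re) + relu(db), relu(re) - relu(db)    # back to diagonal form

# the same computation in the order used by construct()
x_out = relu(u1 / 2 + u2 / 2) + relu(u1 / 2 - u2 / 2)
y_out = relu(u1 / 2 + u2 / 2) - relu(u1 / 2 - u2 / 2)
assert np.allclose(x_out, x_ref) and np.allclose(y_out, y_ref)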
|
|
@@ -0,0 +1,24 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dual operators"""
|
||||
from mindspore.nn import ReLU
|
||||
|
||||
from mindspore.hypercomplex.dual.dual_operators import Conv1d, Conv2d, Conv3d
|
||||
from mindspore.hypercomplex.dual.dual_operators import BatchNorm1d, BatchNorm2d, BatchNorm3d
|
||||
from mindspore.hypercomplex.dual.dual_operators import Dense
|
||||
|
||||
from mindspore.hypercomplex.hypercomplex.hc_pool import MaxPool1d, MaxPool2d, \
|
||||
AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, \
|
||||
AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d
|
|
@@ -0,0 +1,140 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dual BatchNorm Implementation"""
|
||||
from typing import Tuple
|
||||
|
||||
import mindspore.ops as ops
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BaseBatchNormImpl as HCBatchNormImpl
|
||||
from mindspore.hypercomplex.utils import get_x_and_y as get_real_and_dual
|
||||
|
||||
|
||||
class _BatchNormImpl(HCBatchNormImpl):
|
||||
r"""
|
||||
The implementor class of the Batch Normalization layer for dual numbers.
|
||||
|
||||
Implements the functionality specific to dual numbers and needed by the 'BatchNorm' class. This includes:
|
||||
getting the norm of a dual number, applying scaling and a shift to a dual tensor, and updating the running
|
||||
mean and variance, which are used during inference.
|
||||
|
||||
Args:
|
||||
affine (bool): When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If True, use the mean and variance of the current batch. If False, use the
specified (running) mean and variance values instead. If None, the training process uses the mean and
variance of the current batch and tracks the running mean and variance, while the evaluation process
uses the running mean and variance.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
|
||||
num_features (int): The number of features in the input space.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def scale_and_shift(self,
|
||||
u_x: Tensor,
|
||||
u_y: Tensor,
|
||||
scale_x: Tensor,
|
||||
scale_y: Tensor,
|
||||
shift_x: Tensor,
|
||||
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
r"""
|
||||
Applies dual scaling and shift to an input tensor.
|
||||
|
||||
This function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{Re(out)} = \text{Re(inp)} * \text{Re(scale)} + \text{Re(shift)}\\
|
||||
\text{Du(out)} = \text{Re(inp)} * \text{Du(scale)} + \text{Du(inp)} * \text{Re(scale)} + \text{Du(shift)},
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`scale` and :math:`shift` are dual parameters representing
|
||||
the scaling and shift coefficients respectively. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are
|
||||
respectively real and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
|
||||
u_y (Tensor): A tensor of shape (C,), which represents the dual part of the normalized inputs.
|
||||
scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
|
||||
scale_y (Tensor): A tensor of shape (C,), which represents the dual part of the scaling vector.
|
||||
shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
|
||||
shift_y (Tensor): A tensor of shape (C,), which represents the dual part of the bias vector.
|
||||
|
||||
Returns:
|
||||
Tuple of two tensors of shape (C,), which contains the real and the dual parts of rescaled and
|
||||
recentered inputs.
|
||||
"""
|
||||
out_x = u_x * scale_x + shift_x
|
||||
out_y = u_x * scale_y + u_y * scale_x + shift_y
|
||||
return out_x, out_y
|
||||
|
||||
def get_norm(self, u: Tensor) -> Tensor:
|
||||
r"""
|
||||
Calculates norm of dual elements of an input tensor.
|
||||
|
||||
The norm is a non-negative real number that characterizes the 'magnitude' of a dual number, i.e. how far away
it is from zero. The function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \left|\frac{Du(inp)}{2}\right|+\sqrt{Re(inp)^2+\frac{Du(inp)^2}{4}+\delta},
|
||||
|
||||
where :math:`inp` is the dual input tensors and :math:`\delta` is a small positive constant, which is needed
|
||||
to avoid division by zero in case statistical variance is close to zero. :math:`\text{Re(...)}` and
|
||||
:math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside
|
||||
the parentheses.
|
||||
|
||||
Args:
|
||||
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the dual domain
|
||||
and has a real part and a dual part.
|
||||
|
||||
Returns:
|
||||
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
|
||||
the input tensor, but without the very first dimension because the output tensor is real-valued.
|
||||
"""
|
||||
u_r, u_d = get_real_and_dual(u)
|
||||
dual_half = u_d.abs() / 2
|
||||
eps = 1e-7
|
||||
sqrt = u_r ** 2 + dual_half ** 2 + eps
|
||||
sqrt = ops.sqrt(sqrt)
|
||||
out = dual_half + sqrt
|
||||
return out
|
||||
|
||||
def get_square_norm(self, u: Tensor) -> Tensor:
|
||||
r"""
|
||||
Calculates element-wise squared norm of dual elements of an input tensor.
|
||||
|
||||
The norm is a non-negative real number that characterizes the 'magnitude' of a dual number, i.e. how far away
it is from zero. The function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \left(\left|\frac{Du(inp)}{2}\right|+\sqrt{Re(inp)^2+\frac{Du(inp)^2}{4}+\delta}\right)^2,
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`\text{Re(...)}` and :math:`\text{Du(...)}`
|
||||
are respectively real and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the dual domain
|
||||
and has a real part and a dual part.
|
||||
|
||||
Returns:
|
||||
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
|
||||
the input tensor, but without the very first dimension because the output tensor is real-valued.
|
||||
"""
|
||||
norm = self.get_norm(u)
|
||||
out = norm ** 2
|
||||
return out
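# Illustrative NumPy sketch (not MindSpore code) of the two dual-number primitives
# defined above: the dual norm |Du/2| + sqrt(Re^2 + Du^2/4 + eps) and the dual
# scale-and-shift rule. Values and shapes are arbitrary.
import numpy as np

u_r, u_d = np.random.randn(6), np.random.randn(6)           # Re(inp), Du(inp)
eps = 1e-7
norm = np.abs(u_d) / 2 + np.sqrt(u_r ** 2 + (u_d / 2) ** 2 + eps)

scale_r, scale_d = np.ones(6), 0.1 * np.ones(6)              # Re(scale), Du(scale)
shift_r, shift_d = np.zeros(6), np.zeros(6)                  # Re(shift), Du(shift)
out_r = u_r * scale_r + shift_r                              # Re(out)
out_d = u_r * scale_d + u_d * scale_r + shift_d              # Du(out)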
|
|
@@ -0,0 +1,138 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dual Convolution Implementation"""
|
||||
import numbers
|
||||
from typing import Callable, Tuple, Union
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.initializer import Initializer
|
||||
from mindspore.hypercomplex.hypercomplex._hc_conv_impl import _BaseConvImpl as BaseConvImpl
|
||||
from mindspore import ops as P
|
||||
|
||||
|
||||
class _ConvImpl(BaseConvImpl):
|
||||
r"""
|
||||
The implementor class of the convolution layer for dual numbers.
|
||||
|
||||
Applies dual-valued convolution transformation. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})\\
|
||||
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
|
||||
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}),
|
||||
\end{align}
|
||||
|
||||
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
|
||||
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer. :math:`\text{Re(...)}` and :math:`\text{Du(...)}`
|
||||
are respectively real and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and dual parts of the kernel.
|
||||
|
||||
Inputs:
|
||||
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
|
||||
of the dual convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
|
||||
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
|
||||
which defines the real part of the input.
|
||||
- **dual** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
|
||||
which defines the dual part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
|
||||
which represents the real and the dual parts of the output.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def construct(self,
|
||||
conv_op: Callable,
|
||||
real: Tensor,
|
||||
dual: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
out_r = conv_op(real, self.weight_x)
|
||||
out_rd = conv_op(real, self.weight_y)
|
||||
out_dr = conv_op(dual, self.weight_x)
|
||||
|
||||
out_d = out_rd + out_dr
|
||||
return out_r, out_d
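# Illustrative NumPy sketch (not the MindSpore kernel) of the dual convolution
# decomposition above, with a hand-rolled 1D 'valid' cross-correlation standing
# in for conv_op. Signal and kernel sizes are arbitrary.
import numpy as np

def ccor1d(x, k):
    # real-valued 'valid' cross-correlation
    n = len(x) - len(k) + 1
    return np.array([np.dot(x[i:i + len(k)], k) for i in range(n)])

real, dual = np.random.randn(16), np.random.randn(16)        # Re(inp), Du(inp)
w_x, w_y = np.random.randn(5), np.random.randn(5)            # Re(kernel), Du(kernel)

out_r = ccor1d(real, w_x)                                    # Re(ccor)
out_d = ccor1d(real, w_y) + ccor1d(dual, w_x)                # Du(ccor)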
|
||||
|
||||
|
||||
class _ReDuConvImpl(BaseConvImpl):
|
||||
r"""
|
||||
The implementor class of the convolution layer for dual numbers.
|
||||
|
||||
Applies dual-valued convolution transformation. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{inp_cat} = \text{cat}(\text{Re(inp)}, \text{Du(inp)}) \\
|
||||
\text{K} = \text{cat}(\text{Du(kernel)}, \text{Re(kernel)}) \\
|
||||
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})\\
|
||||
\text{Du(ccor)} = \text{ccor}(\text{K}, \text{inp_cat}),
|
||||
\end{align}
|
||||
|
||||
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
|
||||
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer, :math:`\text{cat}` is concatenation along the channel axis.
|
||||
:math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression
|
||||
inside the parentheses.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and dual parts of the kernel.
|
||||
|
||||
Inputs:
|
||||
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
|
||||
of the dual convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
|
||||
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
|
||||
which defines the real part of the input.
|
||||
- **dual** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
|
||||
which defines the dual part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
|
||||
which represents the real and the dual parts of the output.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
weight_shape: tuple,
|
||||
**factory_kwargs) -> None:
|
||||
super(_ReDuConvImpl, self).__init__(weight_init, weight_shape, **factory_kwargs)
|
||||
data_format = factory_kwargs.get('data_format', 'nchw')
|
||||
c_idx = data_format.lower().find('c')
|
||||
if c_idx < 0:
|
||||
raise ValueError(f"Data format {data_format} is unsupported")
|
||||
self.concat = P.Concat(c_idx)
|
||||
|
||||
def construct(self,
|
||||
conv_op: Callable,
|
||||
real: Tensor,
|
||||
imag: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
out_r = conv_op(real, self.weight_x)
|
||||
inp = self.concat([real, imag])
|
||||
w = self.concat([self.weight_y, self.weight_x])
|
||||
out_d = conv_op(inp, w)
|
||||
return out_r, out_d
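# Illustrative NumPy sketch of the channel-concatenation trick used by
# _ReDuConvImpl: concatenating [Re(inp), Du(inp)] along the channel axis and
# [Du(kernel), Re(kernel)] along the input-channel axis merges the two real
# convolutions of the dual part into one call. A 1x1 convolution (a matmul over
# channels) stands in for conv_op; shapes are arbitrary and this is not MindSpore code.
import numpy as np

c_in, c_out, hw = 3, 4, 10
real, dual = np.random.randn(c_in, hw), np.random.randn(c_in, hw)       # Re(inp), Du(inp)
w_x, w_y = np.random.randn(c_out, c_in), np.random.randn(c_out, c_in)   # Re(kernel), Du(kernel)

two_convs = w_y @ real + w_x @ dual                                     # Du part, two calls
one_conv = np.concatenate([w_y, w_x], axis=1) @ np.concatenate([real, dual], axis=0)
assert np.allclose(two_convs, one_conv)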
|
|
@@ -0,0 +1,69 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dual Dense Implementation"""
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _BaseDenseImpl as BaseDenseImpl
|
||||
|
||||
|
||||
class _DenseImpl(BaseDenseImpl):
|
||||
r"""
|
||||
The implementor class of the dense connected layer for dual numbers.
|
||||
|
||||
Applies dual-valued matrix multiplication for dense connected layer. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{Re(out)} = \text{Re(inp)} * \text{Re(kernel)}\\
|
||||
\text{Du(out)} = \text{Re(inp)} * \text{Du(kernel)} + \text{Du(inp)} * \text{Re(kernel)},
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`\text{kernel}` is
a dual weight matrix with the same data type as the :math:`inp` created by the layer,
|
||||
:math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued
|
||||
expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and dual parts of the kernel.
|
||||
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
|
||||
|
||||
Inputs:
|
||||
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
|
||||
of the dual linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
|
||||
- **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
|
||||
- **dual** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the dual part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the dual
|
||||
parts of the output.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def construct(self,
|
||||
matmul_op: Callable,
|
||||
real: Tensor,
|
||||
dual: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
out_r = matmul_op(real, self.weight_x)
|
||||
out_rd = matmul_op(real, self.weight_y)
|
||||
out_dr = matmul_op(dual, self.weight_x)
|
||||
|
||||
out_d = out_rd + out_dr
|
||||
return out_r, out_d
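# Illustrative NumPy sketch (not MindSpore code) of the dual-valued dense rule
# above; it is the matrix form of (a + eps*b)(c + eps*d) = ac + eps*(ad + bc)
# with eps^2 = 0. Shapes are made up for the example.
import numpy as np

real, dual = np.random.randn(2, 8), np.random.randn(2, 8)    # Re(inp), Du(inp)
w_x, w_y = np.random.randn(8, 4), np.random.randn(8, 4)      # Re(kernel), Du(kernel)

out_r = real @ w_x                   # Re(out)
out_d = real @ w_y + dual @ w_x      # Du(out)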
|
|
@@ -0,0 +1,958 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dual Operators"""
|
||||
import numbers
|
||||
from typing import Union
|
||||
from mindspore.common.initializer import Initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
# Batch Normalization
|
||||
from mindspore.hypercomplex.hypercomplex.hc_bn import BatchNorm1d as HBatchNorm1d, \
|
||||
BatchNorm2d as HBatchNorm2d, BatchNorm3d as HBatchNorm3d
|
||||
from mindspore.hypercomplex.dual._dual_bn_impl import _BatchNormImpl as BatchNormImpl
|
||||
# Convolution
|
||||
from mindspore.hypercomplex.hypercomplex.hc_conv import Conv1d as HConv1d, \
|
||||
Conv2d as HConv2d, Conv3d as HConv3d
|
||||
from mindspore.hypercomplex.dual._dual_conv_impl import _ReDuConvImpl as ConvImpl
|
||||
# Dense
|
||||
from mindspore.hypercomplex.hypercomplex.hc_dense import Dense as HDense
|
||||
from mindspore.hypercomplex.dual._dual_dense_impl import _DenseImpl as DenseImpl
|
||||
from mindspore.hypercomplex.hypercomplex.uniform_operator import _UniformOperator
|
||||
|
||||
from mindspore.hypercomplex.utils import _size_1_t, _size_2_t, _size_3_t
|
||||
|
||||
|
||||
class Conv2d(_UniformOperator):
|
||||
r"""
|
||||
2D convolution layer on the dual-valued input.
|
||||
|
||||
Calculates the 2D convolution on the input tensor which is typically of shape
|
||||
:math:`(2, N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size, :math:`C_{in}` is a number of channels,
|
||||
:math:`H_{in}, W_{in}` are the height and width of the feature layer respectively. The formula is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
|
||||
\sum_{k = 0}^{C_{in} - 1} \text{hccor}({\text{weight}(C_{\text{out}_j}, k), \text{inp}(N_i, k)})
|
||||
|
||||
where :math:`C_{in}` is the channel number of the input, :math:`out_{j}` corresponds to the jth channel of
|
||||
the output and :math:`j` is in the range of :math:`[0, C_{out}-1]`. :math:`\text{weight}(C_{\text{out}_j}, k)`
|
||||
is a convolution kernel slice with shape :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`,
|
||||
where :math:`\text{kernel_size[0]}` and :math:`\text{kernel_size[1]}` are the height and width of the convolution
|
||||
kernel respectively. :math:`\text{bias}` is the bias parameter and :math:`\text{inp}` is the input tensor.
|
||||
In this case, `data_format` of the input tensor is 'NCHW' and the shape of full convolution kernel is
|
||||
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]})`,
|
||||
where `group` is the number of groups to split the input `inp` in the channel dimension. If `data_format` of the
|
||||
input tensor is 'NHWC', the shape of full convolution kernel will be
|
||||
:math:`(C_{out}, \text{kernel_size[0]}, \text{kernel_size[1]}, C_{in} / \text{group})`.
|
||||
:math:`hccor` is the dual-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_.
|
||||
This implies the operation as:
|
||||
|
||||
.. math::
|
||||
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)}) + \text{Re(bias)}\\
|
||||
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
|
||||
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}) + \text{Du(bias)}
|
||||
|
||||
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
|
||||
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
|
||||
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}`
|
||||
and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside
|
||||
the parentheses.
|
||||
|
||||
For more details, please refer to the paper `Gradient Based Learning Applied to Document
|
||||
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
|
||||
|
||||
Note:
|
||||
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
|
||||
That is, when `group > 1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
|
||||
|
||||
'NCHW' format is supported only with GPU target device as of now.
|
||||
|
||||
Args:
|
||||
in_channels (int): The channel number of the input tensor of the Conv2d layer.
|
||||
out_channels (int): The channel number of the output tensor of the Conv2d layer.
|
||||
kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution kernel.
|
||||
The data type is an integer or a tuple of two integers. An integer represents the height
|
||||
and width of the convolution kernel. A tuple of two integers represents the height
|
||||
and width of the convolution kernel respectively.
|
||||
stride (Union[int, tuple[int]]): The movement stride of the 2D convolution kernel.
|
||||
The data type is an integer or a tuple of two integers. An integer represents the movement step size
|
||||
in both height and width directions. A tuple of two integers represents the movement step size in
|
||||
the height and width directions respectively. Default: 1.
|
||||
pad_mode (str): Specifies padding mode. The optional values are
|
||||
"same", "valid", "pad". Default: "same".
|
||||
|
||||
- same: The width of the output is the same as the value of the input divided by `stride`.
|
||||
If this mode is set, the value of `padding` must be 0.
|
||||
|
||||
- valid: Returns a valid calculated output without padding. Excess pixels that do not satisfy the
|
||||
calculation will be discarded. If this mode is set, the value of `padding` must be 0.
|
||||
|
||||
- pad: Pads the input. Padding `padding` size of zero on both sides of the input.
|
||||
If this mode is set, the value of `padding` must be greater than or equal to 0.
|
||||
|
||||
padding (Union[int, tuple[int]]): The number of padding on the height and width directions of the input.
|
||||
The data type is an integer or a tuple of four integers. If `padding` is an integer,
|
||||
then the top, bottom, left, and right padding are all equal to `padding`.
|
||||
If `padding` is a tuple of 4 integers, then the top, bottom, left, and right padding
|
||||
is equal to `padding[0]`, `padding[1]`, `padding[2]`, and `padding[3]` respectively.
|
||||
The value should be greater than or equal to 0. Default: 0.
|
||||
dilation (Union[int, tuple[int]]): Dilation size of 2D convolution kernel.
|
||||
The data type is an integer or a tuple of two integers. If :math:`k > 1`, the kernel is sampled
|
||||
every `k` elements. The value of `k` on the height and width directions is in range of [1, H]
|
||||
and [1, W] respectively. Default: 1.
|
||||
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
|
||||
divisible by `group`. If the group is equal to `in_channels` and `out_channels`,
|
||||
this 2D convolution layer also can be called 2D depthwise convolution layer. Default: 1.
|
||||
has_bias (bool): Whether the Conv2d layer has a bias parameter. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of weight parameter.
|
||||
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||
Initializer for more details. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of bias parameter.
|
||||
Available initialization methods are the same as 'weight_init'. Refer to the values of
|
||||
Initializer for more details. Default: 'zeros'.
|
||||
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
|
||||
Default: 'NCHW'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C_{in}, H_{in}, W_{in})` \
|
||||
or :math:`(2, N, H_{in}, W_{in}, C_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(2, N, C_{out}, H_{out}, W_{out})` or :math:`(2, N, H_{out}, W_{out}, C_{out})`.
|
||||
|
||||
pad_mode is 'same':
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[0]}}} \right \rceil \\
|
||||
W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[1]}}} \right \rceil \\
|
||||
\end{array}
|
||||
|
||||
pad_mode is 'valid':
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
H_{out} = \left \lceil{\frac{H_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
|
||||
{\text{stride[0]}}} \right \rceil \\
|
||||
W_{out} = \left \lceil{\frac{W_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
|
||||
{\text{stride[1]}}} \right \rceil \\
|
||||
\end{array}
|
||||
|
||||
pad_mode is 'pad':
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
H_{out} = \left \lfloor{\frac{H_{in} + padding[0] + padding[1] - (\text{kernel_size[0]} - 1) \times
|
||||
\text{dilation[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
|
||||
W_{out} = \left \lfloor{\frac{W_{in} + padding[2] + padding[3] - (\text{kernel_size[1]} - 1) \times
|
||||
\text{dilation[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
|
||||
\end{array}
|
||||
|
||||
Raises:
|
||||
TypeError: If `in_channels`, `out_channels` or `group` is not an int.
|
||||
TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
|
||||
ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
|
||||
ValueError: If `padding` is less than 0.
|
||||
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
|
||||
ValueError: If `padding` is a tuple whose length is not equal to 4.
|
||||
ValueError: If `pad_mode` is not equal to 'pad' and `padding` is not equal to (0, 0, 0, 0).
|
||||
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import Conv2d
|
||||
>>> from mindspore import Tensor
|
||||
>>> w = Tensor(np.random.random((2, 128, 3, 7, 7)).astype(np.float32))
|
||||
>>> b = Tensor(np.random.random((2, 128)).astype(np.float32))
|
||||
>>> net = Conv2d(
...     in_channels=3, out_channels=128, kernel_size=7, stride=2, padding=3,
...     pad_mode='pad', weight_init=w, bias_init=b, has_bias=True
... )
|
||||
>>> inp = Tensor(np.random.random((2, 16, 3, 224, 224)).astype(np.float32))
|
||||
>>> out = net(inp)
|
||||
>>> print(out.shape)
|
||||
(2, 16, 128, 112, 112)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: _size_2_t,
|
||||
stride: _size_2_t = 1,
|
||||
pad_mode: str = 'same',
|
||||
padding: _size_2_t = 0,
|
||||
dilation: _size_2_t = 1,
|
||||
group: int = 1,
|
||||
has_bias: bool = False,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
|
||||
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
data_format: str = 'NCHW') -> None:
|
||||
super(Conv2d, self).__init__(HConv2d,
|
||||
ConvImpl,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad_mode=pad_mode,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
group=group,
|
||||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init,
|
||||
data_format=data_format)
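# Small illustrative helper (hypothetical, not part of the public API) that
# evaluates the 'pad'-mode output-size formula from the Conv2d docstring above.
# With a 224x224 input, kernel 7, stride 2 and padding 3 it yields 112, which
# matches the shape printed in the docstring example.
import math

def conv_out_size(size, kernel, stride, pad_lo, pad_hi, dilation=1):
    return math.floor((size + pad_lo + pad_hi - (kernel - 1) * dilation - 1) / stride + 1)

# conv_out_size(224, kernel=7, stride=2, pad_lo=3, pad_hi=3) -> 112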
|
||||
|
||||
|
||||
class Conv1d(_UniformOperator):
|
||||
r"""
|
||||
1D convolution layer on the dual-valued input.
|
||||
|
||||
Calculates the 1D convolution on the input tensor which is typically of shape :math:`(2, N, C_{in}, L_{in})`,
|
||||
where :math:`N` is batch size, :math:`C_{in}` is a number of channels and :math:`L_{in}` is a length of sequence.
|
||||
The formula is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
|
||||
\sum_{k = 0}^{C_{in} - 1} \text{hccor}({\text{weight}(C_{\text{out}_j}, k), \text{inp}(N_i, k)})
|
||||
|
||||
where :math:`C_{in}` is the channel number of the input, :math:`out_{j}` corresponds to the jth channel of
|
||||
the output and :math:`j` is in the range of :math:`[0, C_{out}-1]`. :math:`\text{weight}(C_{\text{out}_j}, k)`
|
||||
is a convolution kernel slice with shape :math:`\text{kernel_size}`, where :math:`\text{kernel_size}`
|
||||
is the width of the convolution kernel. :math:`\text{bias}` is the bias parameter,
|
||||
and :math:`\text{inp}` is the input tensor. The shape of full convolution kernel is
|
||||
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size})`,
|
||||
where `group` is the number of groups to split the input `inp` in the channel dimension.
|
||||
:math:`hccor` is the dual-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_.
|
||||
This implies the operation as:
|
||||
|
||||
.. math::
|
||||
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)}) + \text{Re(bias)}\\
|
||||
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
|
||||
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}) + \text{Du(bias)}
|
||||
|
||||
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
|
||||
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
|
||||
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}` and
|
||||
:math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
For more details, please refer to the paper `Gradient Based Learning Applied to Document
|
||||
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
|
||||
|
||||
Note:
|
||||
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
|
||||
That is, when `group > 1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
|
||||
|
||||
Args:
|
||||
in_channels (int): The channel number of the input tensor of the Conv1d layer.
|
||||
out_channels (int): The channel number of the output tensor of the Conv1d layer.
|
||||
kernel_size (int): Specifies the width of the 1D convolution kernel.
|
||||
stride (int): The movement stride of the 1D convolution kernel. Default: 1.
|
||||
pad_mode (str): Specifies padding mode. The optional values are
|
||||
"same", "valid", "pad". Default: "same".
|
||||
|
||||
- same: The width of the output is the same as the value of the input divided by `stride`.
|
||||
If this mode is set, the value of `padding` must be 0.
|
||||
|
||||
- valid: Returns a valid calculated output without padding. Excess pixels that do not satisfy the
|
||||
calculation will be discarded. If this mode is set, the value of `padding` must be 0.
|
||||
|
||||
- pad: Pads the input. Padding `padding` size of zero on both sides of the input.
|
||||
If this mode is set, the value of `padding` must be greater than or equal to 0.
|
||||
|
||||
padding (int): The number of padding on both sides of input.
|
||||
The value should be greater than or equal to 0. Default: 0.
|
||||
dilation (int): Dilation size of 1D convolution kernel. If :math:`k > 1`, the kernel is sampled
|
||||
every `k` elements. The value of `k` is in range of [1, L]. Default: 1.
|
||||
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
|
||||
divisible by `group`. Default: 1.
|
||||
has_bias (bool): Whether the Conv1d layer has a bias parameter. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of weight parameter.
|
||||
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||
Initializer for more details. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of bias parameter.
|
||||
Available initialization methods are the same as 'weight_init'. Refer to the values of
|
||||
Initializer for more details. Default: 'zeros'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C_{in}, L_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(2, N, C_{out}, L_{out})`.
|
||||
|
||||
pad_mode is 'same':
|
||||
|
||||
.. math::
|
||||
L_{out} = \left \lceil{\frac{L_{in}}{\text{stride}}} \right \rceil
|
||||
|
||||
pad_mode is 'valid':
|
||||
|
||||
.. math::
|
||||
L_{out} = \left \lceil{\frac{L_{in} - \text{dilation} \times (\text{kernel_size} - 1) }
|
||||
{\text{stride}}} \right \rceil
|
||||
|
||||
pad_mode is 'pad':
|
||||
|
||||
.. math::
|
||||
L_{out} = \left \lfloor{\frac{L_{in} + 2 \times padding - (\text{kernel_size} - 1) \times
|
||||
\text{dilation} - 1 }{\text{stride}} + 1} \right \rfloor
|
||||
|
||||
Raises:
|
||||
TypeError: If `in_channels`, `out_channels`, `kernel_size`, `stride`, `padding` or `dilation` is not an int.
|
||||
ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
|
||||
ValueError: If `padding` is less than 0.
|
||||
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import Conv1d
|
||||
>>> from mindspore import Tensor
|
||||
>>> w = Tensor(np.random.random((2, 16, 1, 6)).astype(np.float32))
|
||||
>>> b = Tensor(np.random.random((2, 16)).astype(np.float32))
|
||||
>>> net = Conv1d(
...     in_channels=1, out_channels=16, kernel_size=6, stride=2, padding=2,
...     pad_mode='pad', weight_init=w, bias_init=b, has_bias=True
... )
|
||||
>>> u = Tensor(np.random.random((2, 8, 1, 4096)).astype(np.float32))
|
||||
>>> out = net(u)
|
||||
>>> print(out.shape)
|
||||
(2, 8, 16, 2048)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: _size_1_t,
|
||||
stride: _size_1_t = 1,
|
||||
pad_mode: str = 'same',
|
||||
padding: _size_1_t = 0,
|
||||
dilation: _size_1_t = 1,
|
||||
group: int = 1,
|
||||
has_bias: bool = False,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
|
||||
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros') -> None:
|
||||
super(Conv1d, self).__init__(HConv1d,
|
||||
ConvImpl,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad_mode=pad_mode,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
group=group,
|
||||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init)
|
||||
|
||||
|
||||
class Conv3d(_UniformOperator):
|
||||
r"""
|
||||
3D convolution layer on the dual-valued input.
|
||||
|
||||
Calculates the 3D convolution on the input tensor which is typically of shape
|
||||
:math:`(2, N, C_{in}, D_{in}, H_{in}, W_{in})`,
|
||||
where :math:`N` is batch size, :math:`C_{in}` is a number of channels,
|
||||
:math:`D_{in}, H_{in}, W_{in}` are the depth, height and width of the feature layer respectively.
|
||||
The formula is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
|
||||
\sum_{k = 0}^{C_{in} - 1} \text{hccor}({\text{weight}(C_{\text{out}_j}, k), \text{inp}(N_i, k)})
|
||||
|
||||
where :math:`C_{in}` is the channel number of the input, :math:`out_{j}` corresponds to the jth channel of
|
||||
the output and :math:`j` is in the range of :math:`[0,C_{out}-1]`. :math:`\text{weight}(C_{\text{out}_j}, k)`
|
||||
is a convolution kernel slice with shape
|
||||
:math:`(\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})`,
|
||||
where :math:`\text{kernel_size[0]}`, :math:`\text{kernel_size[1]}` and :math:`\text{kernel_size[2]}` are
|
||||
the depth, height and width of the convolution kernel respectively. :math:`\text{bias}` is the bias parameter
|
||||
and :math:`\text{inp}` is the input tensor.
|
||||
The shape of full convolution kernel is
|
||||
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})`,
|
||||
where `group` is the number of groups to split the input `x` in the channel dimension.
|
||||
:math:`hccor` is the dual-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_.
|
||||
This implies the operation as:
|
||||
|
||||
.. math::
|
||||
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)}) + \text{Re(bias)}\\
|
||||
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
|
||||
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}) + \text{Du(bias)}
|
||||
|
||||
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
|
||||
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
|
||||
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}`
|
||||
and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside
|
||||
the parentheses.
|
||||
|
||||
For more details, please refer to the paper `Gradient Based Learning Applied to Document
|
||||
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
|
||||
|
||||
Note:
|
||||
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
|
||||
That is, when `group>1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
|
||||
|
||||
Args:
|
||||
in_channels (int): The channel number of the input tensor of the Conv3d layer.
|
||||
out_channels (int): The channel number of the output tensor of the Conv3d layer.
|
||||
kernel_size (Union[int, tuple[int]]): Specifies the depth, height and width of the 3D convolution kernel.
|
||||
The data type is an integer or a tuple of three integers. An integer represents the depth, height
|
||||
and width of the convolution kernel. A tuple of three integers represents the depth, height
|
||||
and width of the convolution kernel respectively.
|
||||
stride (Union[int, tuple[int]]): The movement stride of the 3D convolution kernel.
|
||||
The data type is an integer or a tuple of three integers. An integer represents the movement step size
|
||||
in depth, height and width directions. A tuple of three integers represents the movement step size
|
||||
in the depth, height and width directions respectively. Default: 1.
|
||||
pad_mode (str): Specifies padding mode. The optional values are
|
||||
"same", "valid", "pad". Default: "same".
|
||||
|
||||
- same: The width of the output is the same as the value of the input divided by `stride`.
|
||||
If this mode is set, the value of `padding` must be 0.
|
||||
|
||||
- valid: Returns a valid calculated output without padding. Excess pixels that do not satisfy the
|
||||
calculation will be discarded. If this mode is set, the value of `padding` must be 0.
|
||||
|
||||
- pad: Pads the input. Padding `padding` size of zero on both sides of the input.
|
||||
If this mode is set, the value of `padding` must be greater than or equal to 0.
|
||||
|
||||
padding (Union[int, tuple[int]]): The number of padding on the depth, height and width directions of the input.
|
||||
The data type is an integer or a tuple of six integers. If `padding` is an integer,
|
||||
then the head, tail, top, bottom, left, and right padding are all equal to `padding`.
|
||||
If `padding` is a tuple of six integers, then the head, tail, top, bottom, left, and right padding
|
||||
is equal to `padding[0]`, `padding[1]`, `padding[2]`, `padding[3]`, `padding[4]` and `padding[5]`
|
||||
respectively. The value should be greater than or equal to 0. Default: 0.
|
||||
dilation (Union[int, tuple[int]]): Dilation size of 3D convolution kernel.
|
||||
The data type is an integer or a tuple of three integers. If :math:`k > 1`, the kernel is sampled
|
||||
every `k` elements. The value of `k` on the depth, height and width directions is in range of
|
||||
[1, D], [1, H] and [1, W] respectively. Default: 1.
|
||||
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
|
||||
divisible by `group`. Default: 1. Only 1 is currently supported.
|
||||
has_bias (bool): Whether the Conv3d layer has a bias parameter. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of weight parameter.
|
||||
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||
Initializer for more details. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of bias parameter.
|
||||
Available initialization methods are the same as 'weight_init'. Refer to the values of
|
||||
Initializer for more details. Default: 'zeros'.
|
||||
data_format (str): The optional value for data format. Currently only support "NCDHW".
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C_{in}, D_{in}, H_{in}, W_{in})`.
|
||||
Currently, the input data type only supports float16 and float32.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(2, N, C_{out}, D_{out}, H_{out}, W_{out})`.
|
||||
|
||||
pad_mode is 'same':
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
|
||||
H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
|
||||
W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
|
||||
\end{array}
|
||||
|
||||
|
||||
pad_mode is 'valid':
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
|
||||
{\text{stride[0]}} + 1} \right \rfloor \\
|
||||
H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
|
||||
{\text{stride[1]}} + 1} \right \rfloor \\
|
||||
W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
|
||||
{\text{stride[2]}} + 1} \right \rfloor \\
|
||||
\end{array}
|
||||
|
||||
pad_mode is 'pad':
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{kernel_size[0]} - 1) \times
\text{dilation[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{kernel_size[1]} - 1) \times
\text{dilation[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{kernel_size[2]} - 1) \times
\text{dilation[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
|
||||
\end{array}
|
||||
|
||||
Raises:
|
||||
TypeError: If `in_channels`, `out_channels` or `group` is not an int.
|
||||
TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
|
||||
ValueError: If `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
|
||||
ValueError: If `padding` is less than 0.
|
||||
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
|
||||
ValueError: If `padding` is a tuple whose length is not equal to 6.
|
||||
ValueError: If `pad_mode` is not equal to 'pad' and `padding` is not equal to (0, 0, 0, 0, 0, 0).
|
||||
ValueError: If `data_format` is not 'NCDHW'.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import Conv3d
|
||||
>>> from mindspore import Tensor
|
||||
>>> w = Tensor(np.random.random((2, 128, 3, 3, 3, 3)).astype(np.float32))
|
||||
>>> b = Tensor(np.random.random((2, 128)).astype(np.float32))
|
||||
>>> net = Conv3d(
...     in_channels=3, out_channels=128, kernel_size=3, stride=1, padding=1,
...     pad_mode='pad', weight_init=w, bias_init=b, has_bias=True
... )
|
||||
>>> u = Tensor(np.random.random((2, 64, 3, 32, 32, 32)).astype(np.float32))
|
||||
>>> out = net(u)
|
||||
>>> print(out.shape)
|
||||
(2, 64, 128, 32, 32, 32)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: _size_3_t,
|
||||
stride: _size_3_t = 1,
|
||||
pad_mode: str = 'same',
|
||||
padding: _size_3_t = 0,
|
||||
dilation: _size_3_t = 1,
|
||||
group: int = 1,
|
||||
has_bias: bool = False,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
|
||||
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
data_format: str = 'NCDHW') -> None:
|
||||
super(Conv3d, self).__init__(HConv3d,
|
||||
ConvImpl,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad_mode=pad_mode,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
group=group,
|
||||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init,
|
||||
data_format=data_format)
|
||||
|
||||
|
||||
class BatchNorm1d(_UniformOperator):
|
||||
r"""
|
||||
The dual-valued Batch Normalization layer, applied over a second-order dual input of four dimensions
(including one spatial dimension) or of three dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a dual input to reduce internal covariate shift.
|
||||
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
|
||||
a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
out = \frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}\\
|
||||
\text{Re}(\hat{out}) = \text{Re}(out) * \text{Re}(\gamma) + \text{Re}(\beta)\\
\text{Du}(\hat{out}) = \text{Re}(out) * \text{Du}(\gamma)
+ \text{Du}(out) * \text{Re}(\gamma) + \text{Du}(\beta),
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor
|
||||
over the batch dimension, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over spatial
|
||||
dimensions, based on the dual norm :math:`\|x+\epsilon y\|=\left|\frac{y}{2}\right|+\sqrt{x^2+\frac{y^2}{4}}`,
|
||||
:math:`\gamma` and :math:`\beta` are dual learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real
|
||||
and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the specified mean and variance values, and do not track statistical values.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, W)` or :math:`(2, N, C)`.
|
||||
'2' denotes that the input tensor belongs to the dual domain and has a real part and
a dual part. The `num_features` in `Args` has to be equal to :math:`C` in `Inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same shape as :math:`inp`:
|
||||
:math:`(2, N, C, W)` or :math:`(2, N, C)`
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
ValueError: If `inp` is not a Tensor of 3 or 4 dimensions.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import BatchNorm1d
|
||||
>>> from mindspore import Tensor
|
||||
>>> u = Tensor(np.random.random((2, 8, 64, 32)).astype(np.float32))
|
||||
>>> bn = BatchNorm1d(64)
|
||||
>>> y = bn(u)
|
||||
>>> print(y.shape)
|
||||
(2, 8, 64, 32)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.9,
|
||||
affine: bool = True,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
use_batch_statistics: bool = True) -> None:
|
||||
super(BatchNorm1d, self).__init__(HBatchNorm1d,
|
||||
BatchNormImpl,
|
||||
num_features=num_features,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
gamma_init=gamma_init,
|
||||
beta_init=beta_init,
|
||||
moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init,
|
||||
use_batch_statistics=use_batch_statistics)
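# Illustrative NumPy sketch (not the MindSpore implementation) of one dual-valued
# batch-normalization step over a (2, N, C) input, following the formula in the
# BatchNorm1d docstring above. The gamma/beta values below are arbitrary.
import numpy as np

n, c, eps = 8, 4, 1e-5
u = np.random.randn(2, n, c)                        # [Re(inp), Du(inp)]
centered = u - u.mean(axis=1, keepdims=True)        # inp - E[inp], per part and channel
# statistical variance based on the dual norm |Du/2| + sqrt(Re^2 + Du^2/4)
norm = np.abs(centered[1]) / 2 + np.sqrt(centered[0] ** 2 + centered[1] ** 2 / 4)
var = (norm ** 2).mean(axis=0)                      # per-channel Var[inp]
out = centered / np.sqrt(var + eps)

gamma = np.stack([np.ones(c), 0.1 * np.ones(c)])    # [Re(gamma), Du(gamma)]
beta = np.zeros((2, c))                             # [Re(beta), Du(beta)]
out_hat_re = out[0] * gamma[0] + beta[0]                        # Re part of the output
out_hat_du = out[0] * gamma[1] + out[1] * gamma[0] + beta[1]    # Du part of the output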
|
||||
|
||||
|
||||
class BatchNorm2d(_UniformOperator):
|
||||
r"""
|
||||
The dual-valued Batch Normalization layer over a second-order dual input of five dimensions, including
|
||||
two spatial dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a dual input to reduce internal covariate shift.
|
||||
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
|
||||
a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
out = \frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}\\
|
||||
\text{Re(\hat{out})} = \text{Re(out)} * \text{Re(\gamma)} + \text{Re(\beta)}\\
|
||||
\text{Du(\hat{out})} = \text{Re(out)} * \text{Du(\gamma)}
|
||||
+ \text{Du(out)} * \text{Re(\gamma)} + \text{Du(\beta)},
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor
|
||||
over the batch dimension, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over spatial
|
||||
dimensions, based on the dual norm :math:`\|x+\epsilon y\|=\left|\frac{y}{2}\right|+\sqrt{x^2+\frac{y^2}{4}}`,
|
||||
:math:`\gamma` and :math:`\beta` are dual learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real
|
||||
and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the mean value and variance value of specified value, and not track statistical value.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: 'NCHW'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, H, W)` if data_format is 'NCHW', or
|
||||
:math:`(2, N, H, W, C)` if data_format is 'NHWC'. '2' denotes that the input tensor belongs to the dual
|
||||
domain and has a real and a dual part. The `num_features` in `Args` has to be equal to
|
||||
:math:`C` in `Inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same shape as :math:`u`:
|
||||
:math:`(2, N, C, H, W)` if data_format is 'NCHW', or :math:`(2, N, H, W, C)` if data_format is 'NHWC'.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.
ValueError: If `inp` is not a Tensor of 5 dimensions.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import BatchNorm2d
|
||||
>>> from mindspore import Tensor
|
||||
>>> u = Tensor(np.random.random((2, 8, 64, 32, 32)).astype(np.float32))
|
||||
>>> bn = BatchNorm2d(64)
|
||||
>>> y = bn(u)
|
||||
>>> print(y.shape)
|
||||
(2, 8, 64, 32, 32)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.9,
|
||||
affine: bool = True,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
use_batch_statistics: bool = True,
|
||||
data_format='NCHW') -> None:
|
||||
super(BatchNorm2d, self).__init__(HBatchNorm2d,
|
||||
BatchNormImpl,
|
||||
num_features=num_features,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
gamma_init=gamma_init,
|
||||
beta_init=beta_init,
|
||||
moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init,
|
||||
use_batch_statistics=use_batch_statistics,
|
||||
data_format=data_format)
|
||||
|
||||
|
||||
class BatchNorm3d(_UniformOperator):
|
||||
r"""
|
||||
The dual-valued Batch Normalization layer over a second-order dual input of six dimensions, including
|
||||
three spatial dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a dual input to reduce internal covariate shift.
|
||||
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
|
||||
a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
out = \frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}\\
|
||||
\text{Re(\hat{out})} = \text{Re(out)} * \text{Re(\gamma)} + \text{Re(\beta)}\\
|
||||
\text{Du(\hat{out})} = \text{Re(out)} * \text{Du(\gamma)}
|
||||
+ \text{Du(out)} * \text{Re(\gamma)} + \text{Du(\beta)},
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor
|
||||
over the batch dimension, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over spatial
|
||||
dimensions, based on the dual norm :math:`\|x+\epsilon y\|=\left|\frac{y}{2}\right|+\sqrt{x^2+\frac{y^2}{4}}`,
|
||||
:math:`\gamma` and :math:`\beta` are dual learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real
|
||||
and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the mean value and variance value of specified value, and not track statistical value.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
data_format (str): The optional value for data format. Only the 'NCDHW' format is currently supported.
|
||||
Default: 'NCDHW'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, D, H, W)`. '2' denotes that the input tensor belongs
|
||||
to the dual domain and has a real and a dual part. The `num_features` in `Args` has to be equal
|
||||
to :math:`C` in `Inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same shape :math:`(2, N, C, D, H, W)` as :math:`u`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
ValueError: If `data_format` is not 'NCDHW'.
|
||||
ValueError: If `inp` is not a Tensor of 6 dimensions.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import BatchNorm3d
|
||||
>>> from mindspore import Tensor
|
||||
>>> u = Tensor(np.random.random((2, 8, 64, 32, 32, 32)).astype(np.float32))
|
||||
>>> bn = BatchNorm3d(64)
|
||||
>>> y = bn(u)
|
||||
>>> print(y.shape)
|
||||
(2, 8, 64, 32, 32, 32)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.9,
|
||||
affine: bool = True,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
use_batch_statistics: bool = True,
|
||||
data_format='NCDHW') -> None:
|
||||
super(BatchNorm3d, self).__init__(HBatchNorm3d,
|
||||
BatchNormImpl,
|
||||
num_features=num_features,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
gamma_init=gamma_init,
|
||||
beta_init=beta_init,
|
||||
moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init,
|
||||
use_batch_statistics=use_batch_statistics,
|
||||
data_format=data_format)
|
||||
|
||||
|
||||
class Dense(_UniformOperator):
|
||||
r"""
|
||||
The dual-valued dense connected layer.
|
||||
|
||||
Applies a dense connected layer to the dual-valued input. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\text{Re(out)} = \text{Re(inp)} * \text{Re(kernel)} + \text{Re(bias)}\\
|
||||
\text{Du(out)} = \text{Re(inp)} * \text{Du(kernel)} + \text{Du(inp)} * \text{Re(kernel)} + \text{Du(bias)},
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
|
||||
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}` and
|
||||
:math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside the parentheses.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input space.
|
||||
out_channels (int): The number of channels in the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as `inp`. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, *, ..., *, in\_channels)`. '2' denotes that the input tensor
|
||||
belongs to the dual domain and has a real and a dual part. The `in_channels` in `Args`
|
||||
has to be equal to :math:`in\_channels` in `Inputs`. The count of mediator dimensions denoted by '*' is
|
||||
arbitrary but must be at least one.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(2, *, ..., *, out\_channels)`. The count of mediator dimensions is the same as one
|
||||
in 'Inputs'.
|
||||
|
||||
Raises:
|
||||
TypeError: If `in_channels` or `out_channels` is not an int.
|
||||
TypeError: If `has_bias` is not a bool.
|
||||
ValueError: If length of shape of `weight_init` is not equal to 3,
|
||||
or shape[0] of 'weight_init' is not equal to 2,
|
||||
or shape[1] of `weight_init` is not equal to `out_channels`,
|
||||
or shape[2] of `weight_init` is not equal to `in_channels`.
|
||||
ValueError: If length of shape of `bias_init` is not equal to 2,
|
||||
or shape[0] of 'bias_init' is not equal to 2,
|
||||
or shape[1] of `bias_init` is not equal to `out_channels`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore.hypercomplex.dual import Dense
|
||||
>>> from mindspore import Tensor
|
||||
>>> w = Tensor(np.random.random((2, 7, 5)).astype(np.float32))
|
||||
>>> b = Tensor(np.random.random((2, 7)).astype(np.float32))
|
||||
>>> net = Dense(in_channels=5, out_channels=7, weight_init=w, bias_init=b, has_bias=True)
|
||||
>>> u = Tensor(np.random.random((2, 34, 1, 5)).astype(np.float32))
|
||||
>>> out = net(u)
|
||||
>>> print(out.shape)
|
||||
(2, 34, 1, 7)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
|
||||
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
has_bias: bool = True) -> None:
|
||||
super(Dense, self).__init__(HDense,
|
||||
DenseImpl,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init,
|
||||
has_bias=has_bias)
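# Illustration only (not part of this module): the dual-number linear rule documented above can be
# checked with a small NumPy sketch, where *_re / *_du denote the real and dual parts:
#
#     import numpy as np
#     x_re, x_du = np.random.rand(3, 5), np.random.rand(3, 5)   # input, 5 channels
#     w_re, w_du = np.random.rand(7, 5), np.random.rand(7, 5)   # kernel, 7 output channels
#     out_re = x_re @ w_re.T                                    # Re(out) = Re(inp) * Re(kernel)
#     out_du = x_re @ w_du.T + x_du @ w_re.T                    # Du(out) = Re(inp) * Du(kernel) + Du(inp) * Re(kernel)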
|
|
@ -0,0 +1,339 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Hypercomplex BatchNorm Implementation"""
|
||||
import numbers
|
||||
from typing import Union, Tuple
|
||||
from abc import abstractmethod
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
class _BatchNormImpl(nn.Cell):
|
||||
r"""
|
||||
The interface of the implementor part of batch normalization layer on the second-order hypercomplex numbers.
|
||||
|
||||
Defines the API for getting the norm of a hypercomplex number, applying scaling and shift to a hypercomplex tensor,
and updating the running mean and variance, which are used during inference. The API is used by the 'BatchNorm'
class, and it must be implemented separately for every hypercomplex algebra.
|
||||
|
||||
Args:
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
|
||||
num_features (int): The number of features in the input space.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
num_features: int) -> None:
|
||||
super(_BatchNormImpl, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def get_norm(self, u: Tensor) -> Tensor:
|
||||
r"""
|
||||
Calculates the norm of the hypercomplex elements of an input tensor.

The norm is a non-negative real number that characterizes the 'magnitude' of a hypercomplex number,
i.e. how far away it is from zero.
|
||||
|
||||
Args:
|
||||
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the hypercomplex
|
||||
domain and has a real and a hypercomplex part.
|
||||
|
||||
Returns:
|
||||
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
|
||||
the input tensor, but without the very first dimension because the output tensor is real-valued.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_square_norm(self, u: Tensor) -> Tensor:
|
||||
r"""
|
||||
Calculates the element-wise squared norm of the hypercomplex elements of an input tensor.
|
||||
|
||||
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
|
||||
is from zero.
|
||||
|
||||
Args:
|
||||
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the hypercomplex
|
||||
domain and has a real and a hypercomplex part.
|
||||
|
||||
Returns:
|
||||
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
|
||||
the input tensor, but without the very first dimension because the output tensor is real-valued.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def scale_and_shift(self,
|
||||
u_x: Tensor,
|
||||
u_y: Tensor,
|
||||
scale_x: Tensor,
|
||||
scale_y: Tensor,
|
||||
shift_x: Tensor,
|
||||
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
r"""
|
||||
Applies hypercomplex scaling and shift to an input tensor.
|
||||
|
||||
This function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{mul}(\text{inp}, \text{scale}) + \text{shift},
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{mul}` is the channel-wise scaling operation,
|
||||
which depends on the type of the number system and is provided by subclasses, :math:`\text{scale}` is
|
||||
a hypercomplex scaling vector with the same data type as the :math:`inp` created by the layer, and
|
||||
:math:`\text{shift}` is a hypercomplex bias vector with the same data type as the :math:`inp` created by
|
||||
the layer.
|
||||
|
||||
Args:
|
||||
u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
|
||||
u_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the normalized inputs.
|
||||
scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
|
||||
scale_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the scaling vector.
|
||||
shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
|
||||
shift_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the bias vector.
|
||||
|
||||
Returns:
|
||||
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of rescaled and
|
||||
recentered inputs.
|
||||
"""
|
||||
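# Illustration only (an assumption, not the prescribed implementation): for dual numbers a concrete
# subclass would typically realize scale_and_shift as the channel-wise dual product plus the shift,
# following the formulas quoted in the BatchNorm docstrings of the dual package:
#
#     out_x = u_x * scale_x + shift_x
#     out_y = u_x * scale_y + u_y * scale_x + shift_y
#     return out_x, out_y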
|
||||
@abstractmethod
|
||||
def calculate_bn(self,
|
||||
u_centered_x: Tensor,
|
||||
u_centered_y: Tensor,
|
||||
sigma: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
r"""
|
||||
Given a hypercomplex centered input tensor and the standard deviation of its elements, computes the
|
||||
corresponding rescaled and recentered tensor with normalized variance.
|
||||
|
||||
This function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma}) + \text{shift},
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors centered over spatial and mini-batch dimensions,
|
||||
:math:`\sigma` is standard deviation of the input tensors over the same dimensions, :math:`\text{mul}` is a
|
||||
channel-wise scaling operation, which depends on the type of the number system and is provided by subclasses,
|
||||
:math:`\text{scale}` is a hypercomplex scaling vector with the same data type as the :math:`inp` created
|
||||
by the layer, and :math:`\text{shift}` is a hypercomplex bias vector with the same data type as the
|
||||
:math:`inp` created by the layer.
|
||||
|
||||
Args:
|
||||
u_centered_x (Tensor): A tensor of shape (C,), which represents the real part of the centered inputs.
|
||||
u_centered_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the
|
||||
centered inputs.
|
||||
sigma (Tensor): A tensor of shape (C,), which represents the statistical standard deviation of the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of rescaled and
|
||||
recentered normalized inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate_infer_bn(self,
|
||||
moving_mean_x: Tensor,
|
||||
moving_mean_y: Tensor,
|
||||
moving_sigma: Tensor,
|
||||
u_x: Tensor,
|
||||
u_y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
r"""
|
||||
Given a hypercomplex input tensor, computes the corresponding rescaled and recentered normalized tensor.
|
||||
|
||||
This function is supposed to be used during inference. The mean and standard deviation are accumulated during
|
||||
the training phase. The function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma})
|
||||
+ \left(\text{mul}(-\mathrm{E}[inp], \frac{\text{scale}}{\sigma})+\text{shift}\right),
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\sigma` is the accumulated standard deviation of
|
||||
the input tensors over spatial and mini-batch dimensions, :math:`\mathrm{E}[inp]` is the accumulated arithmetic
|
||||
mean of the input tensor over the same dimensions, :math:`\text{mul}` is a channel-wise scaling operation, which
depends on the type of the number system and is provided by subclasses, :math:`\text{scale}` is a hypercomplex
|
||||
scaling vector with the same data type as the :math:`inp` created by the layer, and :math:`\text{shift}` is a
|
||||
hypercomplex bias vector with the same data type as the :math:`inp` created by the layer.
|
||||
|
||||
Args:
|
||||
moving_mean_x (Tensor): A tensor of shape (C,), which represents the real part of the accumulated
|
||||
arithmetic mean of inputs.
|
||||
moving_mean_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the accumulated
|
||||
arithmetic mean of inputs.
|
||||
moving_sigma (Tensor): A tensor of shape (C,), which represents the accumulated statistical standard
|
||||
deviation of inputs.
|
||||
u_x (Tensor): A tensor of shape (C,), which represents the real part of the input tensor.
|
||||
u_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the input tensor.
|
||||
|
||||
Returns:
|
||||
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of normalized,
|
||||
rescaled and recentered inputs.
|
||||
"""
|
||||
|
||||
|
||||
class _BaseBatchNormImpl(_BatchNormImpl):
|
||||
r"""
|
||||
The base implementor part of the batch normalization layer for all the hypercomplex numbers of the second order.
|
||||
|
||||
Contains initialization and processing logic, which are shared by all specific implementations of the
|
||||
'BatchNormImpl' interface for dual, double, and complex numbers.
|
||||
|
||||
Args:
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If True, use the mean value and variance value of the current batch data. If False,
use the mean value and variance value of the specified values. If None, the training process will use the mean
and variance of the current batch data and track the running mean and variance, and the evaluation process will
use the running mean and variance.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
|
||||
num_features (int): The number of features in the input space.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
affine: bool,
|
||||
use_batch_statistics: bool,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
num_features: int) -> None:
|
||||
super(_BaseBatchNormImpl, self).__init__(gamma_init,
|
||||
beta_init,
|
||||
num_features)
|
||||
self.scale_x = Parameter(initializer(gamma_init, num_features), name="scale_x", requires_grad=affine)
|
||||
self.scale_y = Parameter(initializer(gamma_init, num_features), name="scale_y", requires_grad=affine)
|
||||
self.shift_x = Parameter(initializer(beta_init, num_features), name="shift_x", requires_grad=affine)
|
||||
self.shift_y = Parameter(initializer(beta_init, num_features), name="shift_y", requires_grad=affine)
|
||||
|
||||
def calculate_infer_bn(self,
|
||||
moving_mean_x: Tensor,
|
||||
moving_mean_y: Tensor,
|
||||
moving_sigma: Tensor,
|
||||
u_x: Tensor,
|
||||
u_y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
r"""
|
||||
Given a hypercomplex input tensor, computes the corresponding rescaled and recentered normalized tensor.
|
||||
|
||||
This function is supposed to be used during inference. The mean and standard deviation are accumulated during
|
||||
the training phase. The function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma})
|
||||
+ \left(\text{mul}(-\mathrm{E}[inp], \frac{\text{scale}}{\sigma})+\text{shift}\right),
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\sigma` is the accumulated standard deviation of
|
||||
the input tensors over spatial and mini-batch dimensions, :math:`\mathrm{E}[inp]` is the accumulated arithmetic
|
||||
mean of the input tensor over the same dimensions, :math:`\text{mul}` is a channel-wise scaling operation, which
depends on the type of the number system and is provided by subclasses, :math:`\text{scale}` is a hypercomplex
|
||||
scaling vector with the same data type as the :math:`inp` created by the layer, and :math:`\text{shift}` is a
|
||||
hypercomplex bias vector with the same data type as the :math:`inp` created by the layer.
|
||||
|
||||
Args:
|
||||
moving_mean_x (Tensor): A tensor of shape (C,), which represents the real part of the accumulated
|
||||
arithmetic mean of inputs.
|
||||
moving_mean_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the accumulated
|
||||
arithmetic mean of inputs.
|
||||
moving_sigma (Tensor): A tensor of shape (C,), which represents the accumulated statistical standard
|
||||
deviation of inputs.
|
||||
u_x (Tensor): A tensor of shape (C,), which represents the real part of the input tensor.
|
||||
u_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the input tensor.
|
||||
|
||||
Returns:
|
||||
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of normalized,
|
||||
rescaled and recentered inputs.
|
||||
"""
|
||||
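# The two scale_and_shift calls below fold the accumulated mean into a single fused affine
# transform: out = mul(u, scale / sigma) + (mul(-E[u], scale / sigma) + shift), so the inference
# path costs one channel-wise multiply-add per part instead of an explicit centering step.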
fused_scale_x = self.scale_x / moving_sigma
|
||||
fused_scale_y = self.scale_y / moving_sigma
|
||||
neg_mean_x = (-1) * moving_mean_x
|
||||
neg_mean_y = (-1) * moving_mean_y
|
||||
fused_shift_x, fused_shift_y = self.scale_and_shift(neg_mean_x,
|
||||
neg_mean_y,
|
||||
fused_scale_x,
|
||||
fused_scale_y,
|
||||
self.shift_x,
|
||||
self.shift_y)
|
||||
out_x, out_y = self.scale_and_shift(u_x,
|
||||
u_y,
|
||||
fused_scale_x,
|
||||
fused_scale_y,
|
||||
fused_shift_x,
|
||||
fused_shift_y)
|
||||
return out_x, out_y
|
||||
|
||||
def calculate_bn(self,
|
||||
u_centered_x: Tensor,
|
||||
u_centered_y: Tensor,
|
||||
sigma: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
r"""
|
||||
Given a hypercomplex centered input tensor and the standard deviation of its elements, computes the
|
||||
corresponding rescaled and recentered tensor with normalized variance.
|
||||
|
||||
This function implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma}) + \text{shift},
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors centered over spatial and mini-batch dimensions,
|
||||
:math:`\sigma` is standard deviation of the input tensors over the same dimensions, :math:`\text{mul}` is a
|
||||
channel-wise scaling operation, which depends on the type of the number system and is provided by subclasses,
|
||||
:math:`\text{scale}` is a hypercomplex scaling vector with the same data type as the :math:`inp` created
|
||||
by the layer, and :math:`\text{shift}` is a hypercomplex bias vector with the same data type as the
|
||||
:math:`inp` created by the layer.
|
||||
|
||||
Args:
|
||||
u_centered_x (Tensor): A tensor of shape (C,), which represents the real part of the centered inputs.
|
||||
u_centered_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the
|
||||
centered inputs.
|
||||
sigma (Tensor): A tensor of shape (C,), which represents the statistical standard deviation of the inputs.
|
||||
|
||||
Returns:
|
||||
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of rescaled and
|
||||
recentered normalized inputs.
|
||||
"""
|
||||
scale_x = self.scale_x / sigma
|
||||
scale_y = self.scale_y / sigma
|
||||
out_x, out_y = self.scale_and_shift(u_centered_x,
|
||||
u_centered_y,
|
||||
scale_x,
|
||||
scale_y,
|
||||
self.shift_x,
|
||||
self.shift_y)
|
||||
return out_x, out_y
|
||||
|
||||
@abstractmethod
|
||||
def get_norm(self, u: Tensor) -> Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_square_norm(self, u: Tensor) -> Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def scale_and_shift(self,
|
||||
u_x: Tensor,
|
||||
u_y: Tensor,
|
||||
scale_x: Tensor,
|
||||
scale_y: Tensor,
|
||||
shift_x: Tensor,
|
||||
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
pass
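# Illustration only (an assumption based on the dual norm quoted in the BatchNorm docstrings,
# not a prescriptive implementation): a dual-number subclass could realize the two abstract
# norms as follows, assuming `from mindspore import ops` and `get_x_and_y` from
# mindspore.hypercomplex.utils:
#
#     def get_norm(self, u):
#         u_x, u_y = get_x_and_y(u)
#         return ops.abs(u_y / 2) + ops.sqrt(u_x ** 2 + u_y ** 2 / 4)
#
#     def get_square_norm(self, u):
#         return self.get_norm(u) ** 2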
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Hypercomplex Convolution Implementation"""
|
||||
import numbers
|
||||
from typing import Callable, Union, Tuple
|
||||
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.hypercomplex.utils import get_x_and_y
|
||||
|
||||
|
||||
class _ConvImpl(nn.Cell):
|
||||
r"""
|
||||
The interface of the implementor part of convolution layer on second-order hypercomplex numbers.
|
||||
|
||||
Defines the API for unbiased convolution transformation, which is used by the '_ConvNd' class. The API must
|
||||
be implemented separately for every hypercomplex algebra:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{conv}(\text{inp}, \text{kernel})
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{conv}` is the convolution transformation
|
||||
operation, which is provided by subclasses, :math:`\text{kernel}` is a hypercomplex weight matrix with the same
|
||||
data type as the :math:`inp` created by the layer.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
|
||||
the kernel.
|
||||
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
|
||||
|
||||
Inputs:
|
||||
- **conv_op** (Callable) - the function of the real-valued convolution to be used with decomposition
|
||||
of the hypercomplex convolution transformation. For example, mindspore.ops.operations.Conv2D(...)
|
||||
may be passed for a 2D convolution.
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
|
||||
which defines the real part of the input. The exact shape depends on data format and the number of spatial
|
||||
dimensions.
|
||||
- **y** (Tensor) - Tensor of the same shape as `x`, which defines the hypercomplex part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`, which
|
||||
represent the real and the hypercomplex parts of the output respectively. Data format and the count of spatial
|
||||
dimensions are the same as in `x` and `y`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
weight_shape: tuple,
|
||||
**factory_kwargs) -> None:
|
||||
super(_ConvImpl, self).__init__()
|
||||
|
||||
def construct(self,
|
||||
conv_op: Callable,
|
||||
x: Tensor,
|
||||
y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class _BaseConvImpl(_ConvImpl):
|
||||
r"""
|
||||
The base implementor part of the convolution layer for all the hypercomplex numbers of the second order.
|
||||
|
||||
Contains initialization of the kernel tensors, which is shared by all specific implementations of the 'ConvImpl'
|
||||
interface for dual, double, and complex numbers.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
|
||||
the kernel.
|
||||
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
|
||||
|
||||
Inputs:
|
||||
- **conv_op** (Callable) - the function of the real-valued convolution to be used with decomposition
|
||||
of the hypercomplex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be
|
||||
passed for a 2D convolution.
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
|
||||
which defines the real part of the input. The exact shape depends on data format and the number of spatial
|
||||
dimensions.
|
||||
- **y** (Tensor) - Tensor of the same shape as `x`, which defines the hypercomplex part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`, which
|
||||
represent the real and the hypercomplex parts of the output respectively. Data format and the count of spatial
|
||||
dimensions are the same as in `x` and `y`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
weight_shape: tuple,
|
||||
**factory_kwargs) -> None:
|
||||
super(_BaseConvImpl, self).__init__(weight_init,
|
||||
weight_shape,
|
||||
**factory_kwargs)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
weight_init_x, weight_init_y = get_x_and_y(weight_init)
|
||||
else:
|
||||
weight_init_x = weight_init_y = weight_init
|
||||
self.weight_x = Parameter(initializer(weight_init_x, shape=weight_shape), name='weight_x')
|
||||
self.weight_y = Parameter(initializer(weight_init_y, shape=weight_shape), name='weight_y')
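# Illustration only (standard complex arithmetic, not necessarily the exact kernel of the complex
# subclass in this package): a complex-valued implementor could decompose the convolution into
# four real-valued convolutions on the weight_x / weight_y parameters created above:
#
#     def construct(self, conv_op, x, y):
#         out_x = conv_op(x, self.weight_x) - conv_op(y, self.weight_y)   # real part
#         out_y = conv_op(x, self.weight_y) + conv_op(y, self.weight_x)   # imaginary part
#         return out_x, out_y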
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hypercomplex dense implementation"""
|
||||
import numbers
|
||||
from typing import Callable, Union, Tuple
|
||||
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.hypercomplex.utils import get_x_and_y
|
||||
|
||||
|
||||
class _DenseImpl(nn.Cell):
|
||||
r"""
|
||||
The interface of the implementor part of dense connected layer on second-order hypercomplex numbers.
|
||||
|
||||
Defines the API for linear transformation, which is used by the 'Dense' class. The API must be implemented
|
||||
separately for every hypercomplex algebra:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{linear}(\text{inp}, \text{kernel})
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
|
||||
which is provided by subclasses, :math:`\text{kernel}` is a hypercomplex weight matrix with the same data type as
|
||||
the :math:`inp` created by the layer.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
|
||||
the kernel.
|
||||
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
|
||||
|
||||
Inputs:
|
||||
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
|
||||
of the hypercomplex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
|
||||
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
|
||||
- **y** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the hypercomplex part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the hypercomplex
|
||||
part of the output.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
weight_shape: tuple,
|
||||
**factory_kwargs) -> None:
|
||||
super(_DenseImpl, self).__init__()
|
||||
|
||||
def construct(self,
|
||||
matmul_op: Callable,
|
||||
x: Tensor,
|
||||
y: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
class _BaseDenseImpl(_DenseImpl):
|
||||
r"""
|
||||
The base implementor part of the dense connected layer for all the hypercomplex numbers of the second order.
|
||||
|
||||
Contains initialization of the kernel tensors, which is shared by all specific implementations of the 'DenseImpl'
|
||||
interface for dual, double, and complex numbers.
|
||||
|
||||
Args:
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
|
||||
the kernel.
|
||||
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
|
||||
|
||||
Inputs:
|
||||
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
|
||||
of the hypercomplex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
|
||||
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
|
||||
- **y** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the hypercomplex part of the input.
|
||||
|
||||
Outputs:
|
||||
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the hypercomplex
|
||||
part of the output.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number],
|
||||
weight_shape: tuple,
|
||||
**factory_kwargs) -> None:
|
||||
super(_BaseDenseImpl, self).__init__(weight_init,
|
||||
weight_shape,
|
||||
**factory_kwargs)
|
||||
if isinstance(weight_init, Tensor):
|
||||
weight_init_x, weight_init_y = get_x_and_y(weight_init)
|
||||
else:
|
||||
weight_init_x = weight_init_y = weight_init
|
||||
self.weight_x = Parameter(initializer(weight_init_x, shape=weight_shape), name='weight_x')
|
||||
self.weight_y = Parameter(initializer(weight_init_y, shape=weight_shape), name='weight_y')
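# Illustration only (the dual-number product, matching the Dense docstring of the dual package;
# other algebras differ, and the kernel layout handled by `matmul_op` is assumed): a dual-valued
# implementor could decompose the linear map into three real matrix multiplications:
#
#     def construct(self, matmul_op, x, y):
#         out_x = matmul_op(x, self.weight_x)                                 # Re(out)
#         out_y = matmul_op(x, self.weight_y) + matmul_op(y, self.weight_x)   # Du(out)
#         return out_x, out_y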
|
|
@ -0,0 +1,627 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Hypercomplex batchnorm"""
|
||||
import numbers
|
||||
from typing import TypeVar, Type, Union, Any
|
||||
from abc import abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as P
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BatchNormImpl as BatchNormImpl
|
||||
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
|
||||
|
||||
TBatchNormImpl = TypeVar('TBatchNormImpl', bound=BatchNormImpl)
|
||||
|
||||
|
||||
class _BatchNorm(nn.Cell):
|
||||
r"""
|
||||
The base class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
|
||||
of some number of dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a hypercomplex input to reduce internal covariate shift.
|
||||
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
|
||||
a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
out = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
|
||||
which depends on the type of the number system and provided by the implementor part of the batch normalization
|
||||
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
|
||||
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
|
||||
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero.
|
||||
|
||||
This is not a self-sufficient class. In order to construct a working batch normalization layer, one should instantiate a child
|
||||
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
|
||||
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
|
||||
|
||||
Args:
|
||||
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
|
||||
class name of this argument defines the algebra that the batch normalization layer will operate on.
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the mean value and variance value of specified value, and not track statistical value.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: 'NCHW'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, *, ..., *)` if data_format is 'NCHW', or
|
||||
:math:`(2, N, *, ..., *, C)` if data_format is 'NHWC', with float16 or float32 data type. '2' denotes that
|
||||
the input tensor belongs to the hypercomplex domain and has a real and a hypercomplex part. Or,
|
||||
:math:`(N, C, *, ..., *)` if data_format is 'NCHW', or :math:`(N, *, ..., *, C)` if data_format is 'NHWC',
|
||||
with complex64 data type. The `num_features` in `Args` has to be equal to :math:`C` in `Inputs`.
|
||||
The count of dimensions denoted by '*' must be equal to the number of spatial dimensions.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
|
||||
:math:`(2, N, C, *, ..., *)` if data_format is 'NCHW', or :math:`(2, N, *, ..., *, C)` if data_format is 'NHWC',
|
||||
with float16 or float32 data type. Or, :math:`(N, C, *, ..., *)` if data_format is 'NCHW', or
|
||||
:math:`(N, *, ..., *, C)` if data_format is 'NHWC', with complex64 data type.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
TypeError: If dtype of `inp` is not float16, float32 or complex64.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
bn_impl: Type[TBatchNormImpl],
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.9,
|
||||
affine: bool = True,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
use_batch_statistics: bool = None,
|
||||
data_format: str = 'NCHW') -> None:
|
||||
"""Initialize _BatchNorm."""
|
||||
super(_BatchNorm, self).__init__()
|
||||
validator.check_value_type('num_features', num_features, [int], self.cls_name)
|
||||
if num_features < 1:
|
||||
raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")
|
||||
|
||||
if momentum < 0 or momentum > 1:
|
||||
raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
|
||||
f"but got {momentum}.")
|
||||
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
|
||||
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
||||
raise ValueError(f"For '{self.cls_name}', the 'NHWC' format is only supported on the GPU target, but got device "
|
||||
f"target {context.get_context('device_target')}.")
|
||||
self.use_batch_statistics = use_batch_statistics
|
||||
if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
|
||||
raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' must be a boolean value or None,"
|
||||
f" but got {use_batch_statistics}.")
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.beta_init = beta_init
|
||||
self.gamma_init = gamma_init
|
||||
self.moving_mean_init = moving_mean_init
|
||||
self.moving_var_init = moving_var_init
|
||||
self.affine = affine
|
||||
|
||||
self.bn_impl = bn_impl(affine, use_batch_statistics, gamma_init, beta_init, num_features)
|
||||
|
||||
self.moving_mean_x = Parameter(
|
||||
initializer(moving_mean_init, (num_features)), name="mean_x", requires_grad=False
|
||||
)
|
||||
self.moving_mean_y = Parameter(
|
||||
initializer(moving_mean_init, (num_features)), name="mean_y", requires_grad=False
|
||||
)
|
||||
self.moving_sigma2 = Parameter(
|
||||
initializer(moving_var_init, num_features), name="sigma2", requires_grad=False
|
||||
)
|
||||
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
|
||||
self._target = context.get_context("device_target")
|
||||
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
|
||||
self.momentum = 1.0 - momentum
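# Stored as (1 - momentum): the running statistics below are updated as
# moving = momentum * moving + (1 - momentum) * batch_statistic.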
|
||||
|
||||
self.reduce_mean_op1 = P.ReduceMean(keep_dims=True)
|
||||
self.reduce_mean_op2 = P.ReduceMean(keep_dims=False)
|
||||
|
||||
self.features_dim = data_format.lower().find('c')
|
||||
self.get_dtype = P.DType()
|
||||
self.get_shape = P.Shape()
|
||||
|
||||
def construct(self, u: Tensor) -> Tensor:
|
||||
"""construct"""
|
||||
u_dtype = self.get_dtype(u)
|
||||
u_shape = self.get_shape(u)
|
||||
self._check_input_dim(u_shape, u_dtype)
|
||||
if u_dtype == mindspore.complex64:
|
||||
hc_axis = None
|
||||
feature_axis = self.features_dim
|
||||
else:
|
||||
hc_axis = 0
|
||||
feature_axis = self.features_dim + 1
|
||||
|
||||
if self.training or not self.use_batch_statistics:
|
||||
ndim = u.ndim
|
||||
|
||||
sh = np.arange(ndim)
|
||||
sh = sh[sh != hc_axis]
|
||||
sh = sh[sh != feature_axis]
|
||||
if hc_axis is None:
|
||||
u_x, u_y = get_x_and_y(u)
|
||||
mu_x = self.reduce_mean_op1(u_x, sh.tolist())
|
||||
mu_y = self.reduce_mean_op1(u_y, sh.tolist())
|
||||
mu = to_2channel(mu_x, mu_y, mindspore.complex64)
|
||||
else:
|
||||
mu = self.reduce_mean_op1(u, sh.tolist())
|
||||
|
||||
u_centered = u - mu
|
||||
norma2 = self.bn_impl.get_square_norm(u_centered)
|
||||
norma_feature_axis = feature_axis if hc_axis is None or feature_axis < hc_axis else feature_axis - 1
|
||||
ndim = norma2.ndim
|
||||
mean_dims = np.arange(ndim)
|
||||
mean_dims = mean_dims[mean_dims != norma_feature_axis]
|
||||
sigma2 = self.reduce_mean_op2(norma2, mean_dims.tolist()) + self.eps
|
||||
result = self._calculate_bn(u_centered, sigma2, feature_axis)
|
||||
|
||||
if self.use_batch_statistics:
|
||||
momentum = self.momentum
|
||||
mu = mu.squeeze()
|
||||
mu_x, mu_y = get_x_and_y(mu)
|
||||
momentum_suppl = 1 - momentum
|
||||
self.moving_mean_x *= momentum_suppl
|
||||
self.moving_mean_x += mu_x * momentum
|
||||
self.moving_mean_y *= momentum_suppl
|
||||
self.moving_mean_y += mu_y * momentum
|
||||
self.moving_sigma2 *= momentum_suppl
|
||||
self.moving_sigma2 += sigma2 * momentum
|
||||
elif self.affine:
|
||||
result = self._calculate_infer_bn(u, axis=feature_axis)
|
||||
else:
|
||||
broadcast_mu_shape = [1] * u.ndim
|
||||
broadcast_mu_shape[feature_axis] = u_shape[feature_axis]
|
||||
if hc_axis is not None:
|
||||
broadcast_mu_shape[hc_axis] = 2
|
||||
moving_mean = to_2channel(self.moving_mean_x, self.moving_mean_y, u.dtype)
|
||||
moving_mean = moving_mean.reshape(tuple(broadcast_mu_shape))
|
||||
inference_centered = u - moving_mean
|
||||
result = self._calculate_bn(inference_centered, self.moving_sigma2, feature_axis)
|
||||
return result
|
||||
|
||||
def _calculate_bn(self,
|
||||
u_centered: Tensor,
|
||||
sigma2: Tensor,
|
||||
axis: int) -> Tensor:
|
||||
"""_calculate_bn, implement the abstract function"""
|
||||
sigma = P.sqrt(sigma2)
|
||||
ndim = u_centered.ndim
|
||||
u_shape = list(np.arange(ndim))
|
||||
u_shape[ndim - 1] = axis
|
||||
u_shape[axis] = ndim - 1
|
||||
u_shape = tuple(int(i) for i in u_shape)
|
||||
out = P.transpose(u_centered, u_shape)
|
||||
if self.affine:
|
||||
out_x, out_y = get_x_and_y(out)
|
||||
out_x, out_y = self.bn_impl.calculate_bn(out_x, out_y, sigma)
|
||||
out = to_2channel(out_x, out_y, self.get_dtype(u_centered))
|
||||
else:
|
||||
out = out / sigma
|
||||
out = P.transpose(out, u_shape)
|
||||
return out
|
||||
|
||||
def _calculate_infer_bn(self,
|
||||
u: Tensor,
|
||||
axis: int) -> Tensor:
|
||||
"""_calculate_infer_bn, implement the abstract function"""
|
||||
ndim = u.ndim
|
||||
shape = list(np.arange(ndim))
|
||||
shape[ndim-1] = axis
|
||||
shape[axis] = ndim - 1
|
||||
shape = tuple(int(i) for i in shape)
|
||||
|
||||
out = P.transpose(u, shape)
|
||||
out_x, out_y = get_x_and_y(out)
|
||||
out_x, out_y = self.bn_impl.calculate_infer_bn(self.moving_mean_x,
|
||||
self.moving_mean_y,
|
||||
P.sqrt(self.moving_sigma2),
|
||||
out_x,
|
||||
out_y)
|
||||
out = to_2channel(out_x, out_y, dtype=u.dtype)
|
||||
out = P.transpose(out, shape)
|
||||
return out
|
||||
|
||||
@abstractmethod
|
||||
def _check_input_dim(self, shape: tuple, dtype: Any):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BatchNorm1d(_BatchNorm):
|
||||
r"""
|
||||
The class of the abstract part of the Batch Normalization layer over a second-order hypercomplex input
of four dimensions (including one spatial dimension) or of three dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a hypercomplex input of 'NCW' data format in order to reduce
|
||||
internal covariate shift. Batch Normalization is widely used in convolutional networks. It rescales and recenters
|
||||
the feature using a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
out = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
|
||||
which depends on the type of the number system and provided by the implementor part of the batch normalization
|
||||
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
|
||||
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
|
||||
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero.
|
||||
|
||||
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate this
|
||||
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
|
||||
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
|
||||
|
||||
Args:
|
||||
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
|
||||
class name of this argument defines the algebra that the batch normalization layer will operate on.
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the mean value and variance value of specified value, and not track statistical value.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, W)` or :math:`(2, N, C)`, with float16 or float32 data
|
||||
type, or :math:`(N, C, W)` or :math:`(N, C)`, with complex64 data type. In the former case '2' denotes that
|
||||
the input tensor belongs to the hypercomplex domain and has a real and a hypercomplex part.
|
||||
The `num_features` in `Args` has to be equal to :math:`C` in `inp`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
|
||||
:math:`(2, N, C, W)` or :math:`(2, N, C)`, with float16 or float32 data type, or :math:`(N, C, W)` or
|
||||
:math:`(N, C)`, with complex64 data type.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
TypeError: If dtype of `inp` is not float16, float32 or complex64.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
ValueError: If `inp` is not a Tensor of 3 or 4 dimensions with float16 or float32 data type, and not a Tensor
|
||||
of 2 or 3 dimensions with complex64 data type.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
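
Examples:
An illustrative sketch (assuming the complex-algebra wrapper, re-exported as
``mindspore.hypercomplex.complex.BatchNorm1d``, binds this abstract part to its implementor
and accepts ``num_features`` directly):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.hypercomplex.complex import BatchNorm1d
>>> u = Tensor(np.random.randn(2, 8, 16, 32), mindspore.float32)  # (2, N, C, W)
>>> bn = BatchNorm1d(num_features=16)
>>> out = bn(u)
>>> print(out.shape)
(2, 8, 16, 32)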
"""
|
||||
|
||||
def __init__(self,
|
||||
bn_impl: Type[TBatchNormImpl],
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.9,
|
||||
affine: bool = True,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
use_batch_statistics: bool = None) -> None:
|
||||
"""Initialize _BatchNorm."""
|
||||
|
||||
super(BatchNorm1d, self).__init__(bn_impl,
|
||||
num_features,
|
||||
eps,
|
||||
momentum,
|
||||
affine,
|
||||
gamma_init,
|
||||
beta_init,
|
||||
moving_mean_init,
|
||||
moving_var_init,
|
||||
use_batch_statistics)
|
||||
|
||||
def _check_input_dim(self, shape: tuple, dtype: Any):
|
||||
dim = len(shape)
|
||||
if dtype in [mindspore.float16, mindspore.float32]:
|
||||
if dim not in (4, 3):
|
||||
raise ValueError(f"For '{self.cls_name}', the in_shape must have 3-4 dims, but got {dim}.")
|
||||
elif dtype == mindspore.complex64:
|
||||
if dim not in (3, 2):
|
||||
raise ValueError(f"For '{self.cls_name}', the in_shape must have 2-3 dims, but got {dim}.")
|
||||
else:
|
||||
raise TypeError(f"Only float16, float32 and complex64 data types are supported, but got {dtype}.")
|
||||
return None
|
||||
|
||||
|
||||
class BatchNorm2d(_BatchNorm):
|
||||
r"""
|
||||
The class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
|
||||
of five dimensions, including two spatial dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a hypercomplex input to reduce internal covariate shift.
|
||||
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature
|
||||
using a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
y = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
|
||||
which depends on the type of the number system and provided by the implementor part of the batch normalization
|
||||
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
|
||||
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
|
||||
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero.
|
||||
|
||||
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate this
|
||||
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
|
||||
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
|
||||
|
||||
Args:
|
||||
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
|
||||
class name of this argument defines the algebra that the batch normalization layer will operate on.
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the mean value and variance value of specified value, and not track statistical value.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: 'NCHW'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, H, W)` if data_format is 'NCHW', or
|
||||
:math:`(2, N, H, W, C)` if data_format is 'NHWC', with float16 or float32 data type. '2' denotes that the
|
||||
input tensor belongs to the hypercomplex domain and has a real and a hypercomplex part. Or,
|
||||
:math:`(N, C, H, W)` if data_format is 'NCHW', or :math:`(N, H, W, C)` if data_format is 'NHWC', with
|
||||
complex64 data type. The `num_features` in `Args` has to be equal to :math:`C` in `Inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
|
||||
:math:`(2, N, C, H, W)` if data_format is 'NCHW', or :math:`(2, N, H, W, C)` if data_format is 'NHWC', with
|
||||
float16 or float32 data type. Or, :math:`(N, C, H, W)` if data_format is 'NCHW', or :math:`(N, H, W, C)` if
|
||||
data_format is 'NHWC', with complex64 data type.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
TypeError: If dtype of `inp` is not float16, float32 or complex64.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
ValueError: If `data_format` is neither 'NHWC' not 'NCHW'.
|
||||
ValueError: If `inp` is not a Tensor of 5 dimensions with float16 or float32 data type, and not a Tensor of 4
|
||||
dimensions with complex64 data type.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
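
Examples:
An illustrative sketch with a complex64 input (assuming the complex-algebra wrapper is
re-exported as ``mindspore.hypercomplex.complex.BatchNorm2d`` and accepts ``num_features`` directly):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.hypercomplex.complex import BatchNorm2d
>>> arr = (np.random.randn(4, 3, 16, 16) + 1j * np.random.randn(4, 3, 16, 16)).astype(np.complex64)
>>> u = Tensor(arr)  # (N, C, H, W), complex64
>>> bn = BatchNorm2d(num_features=3)
>>> out = bn(u)
>>> print(out.shape)
(4, 3, 16, 16)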
"""
|
||||
|
||||
def _check_input_dim(self, shape: tuple, dtype: Any):
|
||||
dim = len(shape)
|
||||
if dtype in [mindspore.float16, mindspore.float32]:
|
||||
if dim != 5:
|
||||
raise ValueError(f"For '{self.cls_name}', the in_shape must have 5 dims, but got {dim}.")
|
||||
elif dtype == mindspore.complex64:
|
||||
if dim != 4:
|
||||
raise ValueError(f"For '{self.cls_name}', the in_shape must have 4 dims, but got {dim}.")
|
||||
else:
|
||||
raise TypeError(f"Only float16, float32 and complex64 data types are supported, but got {dtype}.")
|
||||
return None
|
||||
|
||||
|
||||
class BatchNorm3d(nn.Cell):
|
||||
r"""
|
||||
The class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
|
||||
of six dimensions, including three spatial dimensions.
|
||||
|
||||
This layer applies Batch Normalization over a hypercomplex input to reduce internal covariate shift.
|
||||
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature
|
||||
using a mini-batch of data and the learned parameters which can be described by the following formula:
|
||||
|
||||
.. math::
|
||||
\begin{align}
|
||||
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
|
||||
y = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
|
||||
\end{align}
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
|
||||
which depends on the type of the number system and provided by the implementor part of the batch normalization
|
||||
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
|
||||
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
|
||||
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
|
||||
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
|
||||
statistical variance is close to zero.
|
||||
|
||||
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate this
|
||||
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
|
||||
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
|
||||
|
||||
Args:
|
||||
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
|
||||
class name of this argument defines the algebra that the batch normalization layer will operate on.
|
||||
num_features (int): The number of features in the input space.
|
||||
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
|
||||
|
||||
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
|
||||
Default: 0.9.
|
||||
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
|
||||
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
|
||||
use_batch_statistics (bool):
|
||||
|
||||
- If True, use the mean value and variance value of current batch data and track running mean
|
||||
and running variance.
|
||||
- If False, use the mean value and variance value of specified value, and not track statistical value.
|
||||
- If None, the use_batch_statistics is automatically set to True or False according to the training
|
||||
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
|
||||
parameter is set to False. Default: None.
|
||||
data_format (str): The optional value for data format. Only 'NCDHW' format is supported as of now.
|
||||
Default: 'NCDHW'.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, D, H, W)`, with float16 or float32 data type, or
|
||||
:math:`(N, C, D, H, W)`, with complex64 data type. In the former case '2' denotes that the input tensor
|
||||
belongs to the hypercomplex domain and has a real and a hypercomplex part. The `num_features` in `Args`
|
||||
has to be equal to :math:`C` in `Inputs`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
|
||||
:math:`(2, N, C, D, H, W)`, with float16 and float32 data type, or :math:`(N, C, D, H, W)`, with
|
||||
complex64 data type.
|
||||
|
||||
Raises:
|
||||
TypeError: If `num_features` is not an int.
|
||||
TypeError: If `eps` is not a float.
|
||||
TypeError: If dtype of `inp` is not float16, float32 or complex64.
|
||||
ValueError: If `num_features` is less than 1.
|
||||
ValueError: If `momentum` is not in range [0, 1].
|
||||
ValueError: If `data_format` is not 'NCDHW'.
|
||||
ValueError: If `inp` is not a Tensor of 6 dimensions with float16 or float32 data type, and not a Tensor of 5 dimensions with complex64 data type.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
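
Examples:
An illustrative sketch (assuming the complex-algebra wrapper is re-exported as
``mindspore.hypercomplex.complex.BatchNorm3d``); internally the depth and height dimensions
are folded together and processed by the 2-D layer:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.hypercomplex.complex import BatchNorm3d
>>> u = Tensor(np.random.randn(2, 4, 3, 5, 8, 8), mindspore.float32)  # (2, N, C, D, H, W)
>>> bn = BatchNorm3d(num_features=3)
>>> out = bn(u)
>>> print(out.shape)
(2, 4, 3, 5, 8, 8)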
"""
|
||||
|
||||
def __init__(self,
|
||||
bn_impl: Type[TBatchNormImpl],
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: float = 0.9,
|
||||
affine: bool = True,
|
||||
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
|
||||
use_batch_statistics: bool = None,
|
||||
data_format: str = 'NCDHW') -> None:
|
||||
"""Initialize _BatchNorm."""
|
||||
super(BatchNorm3d, self).__init__()
|
||||
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
|
||||
self.reshape = P.Reshape()
|
||||
self.bn2d = BatchNorm2d(bn_impl=bn_impl,
|
||||
num_features=num_features,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
affine=affine,
|
||||
gamma_init=gamma_init,
|
||||
beta_init=beta_init,
|
||||
moving_mean_init=moving_mean_init,
|
||||
moving_var_init=moving_var_init,
|
||||
use_batch_statistics=use_batch_statistics,
|
||||
data_format="NCHW")
|
||||
|
||||
def construct(self, u: Tensor) -> Tensor:
|
||||
"""Fold the depth and height dimensions together and apply 2-D hypercomplex batch normalization."""
|
||||
u_shape = F.shape(u)
|
||||
self._check_3d_shape(u_shape, F.dtype(u))
|
||||
reshape = list(u_shape)
|
||||
reshape[-3] *= reshape[-2]
|
||||
reshape = tuple(int(i) for i in reshape[:-2] + reshape[-1:])
|
||||
u = self.reshape(u, tuple(reshape))
|
||||
out = self.bn2d(u)
|
||||
out = self.reshape(out, u_shape)
|
||||
return out
|
||||
|
||||
def _check_3d_shape(self, input_shape, dtype: Any) -> None:
|
||||
"""Validate the dimensionality of the 3-D batch normalization input for the given data type."""
|
||||
dim = len(input_shape)
|
||||
if dtype in [mindspore.float16, mindspore.float32]:
|
||||
if dim != 6:
|
||||
msg_prefix = f"For '{self.cls_name}', the" if self.cls_name else "The"
|
||||
raise ValueError(f"{msg_prefix} input_shape must be 6-dimensional, but got the length of input_shape: "
|
||||
f"{len(dim)}.")
|
||||
elif dtype == mindspore.complex64:
|
||||
if dim != 5:
|
||||
msg_prefix = f"For '{self.cls_name}', the" if self.cls_name else "The"
|
||||
raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
|
||||
f"{len(dim)}.")
|
||||
else:
|
||||
raise TypeError(f"Only float16, float32 and complex64 data types are supported, but got {dtype}.")
|
||||
return None
|
|
@ -0,0 +1,200 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Hypercomplex Dense"""
|
||||
import numbers
|
||||
from typing import TypeVar, Type, Union
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _DenseImpl as DenseImpl
|
||||
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
|
||||
|
||||
|
||||
TDenseImpl = TypeVar('TDenseImpl', bound=DenseImpl)
|
||||
|
||||
|
||||
class Dense(nn.Cell):
|
||||
r"""
|
||||
The abstract part of dense connected layer.
|
||||
|
||||
Applies dense connected layer for the second-order hypercomplex input. This layer implements the operation as:
|
||||
|
||||
.. math::
|
||||
\text{out} = \text{linear}(\text{inp}, \text{kernel}) + \text{bias},
|
||||
|
||||
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
|
||||
which is defined and provided by the implementor part of the dense connected layer, :math:`\text{kernel}` is
|
||||
a hypercomplex weight matrix with the same data type as the :math:`inp` created by the layer, and
|
||||
:math:`\text{bias}` is a hypercomplex bias vector with the same data type as the :math:`inp` created by the layer
|
||||
(only if has_bias is True).
|
||||
|
||||
This is not a self-sufficient class. In order to construct a fully connected layer, one should instantiate this
|
||||
class and an implementor class, which acts like a strategy pattern and determines the exact set of hypercomplex
|
||||
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
|
||||
|
||||
Args:
|
||||
dense_impl(DenseImpl): The implementor object of the dense connected layer. Essentially, the concrete class
|
||||
name of this argument defines the algebra that the dense layer will operate on.
|
||||
in_channels (int): The number of channels in the input space.
|
||||
out_channels (int): The number of channels in the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as `inp`. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - Tensor of shape :math:`(2, *, ..., *, in\_channels)`, with float16 or float32 data type,
|
||||
or :math:`(*, ..., *, in\_channels)`, with complex64 data type. In the former case '2' denotes that the input
|
||||
tensor belongs to the hypercomplex domain and has a real and an imaginary part. The `in_channels` in
|
||||
`Args` has to be equal to :math:`in\_channels` in `Inputs`. The count of mediator dimensions denoted by '*'
|
||||
is arbitrary but must be at least one.
|
||||
|
||||
Outputs:
|
||||
Tensor of the same data type as 'inp' and of shape :math:`(2, *, ..., *, out\_channels)`, with float16 or
|
||||
float32 data type, or :math:`(*, ..., *, out\_channels)`, with complex64 data type. The count of mediator
|
||||
dimensions is the same as one in 'inp'.
|
||||
|
||||
Raises:
|
||||
TypeError: If `in_channels` or `out_channels` is not an int.
|
||||
TypeError: If `has_bias` is not a bool.
|
||||
TypeError: If any two of `inp`, `weight_init` and `bias_init` are Tensors of different data type.
|
||||
ValueError: If length of shape of `weight_init` is not equal to 3,
|
||||
or shape[0] of 'weight_init' is not equal to 2,
|
||||
or shape[1] of `weight_init` is not equal to `out_channels`,
|
||||
or shape[2] of `weight_init` is not equal to `in_channels`.
|
||||
ValueError: If length of shape of `bias_init` is not equal to 2,
|
||||
or shape[0] of 'bias_init' is not equal to 2,
|
||||
or shape[1] of `bias_init` is not equal to `out_channels`.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
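
Examples:
An illustrative sketch (assuming the complex-algebra wrapper is re-exported as
``mindspore.hypercomplex.complex.Dense`` and takes ``in_channels``/``out_channels`` directly):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.hypercomplex.complex import Dense
>>> u = Tensor(np.random.randn(2, 16, 32), mindspore.float32)  # (2, *, in_channels)
>>> dense = Dense(in_channels=32, out_channels=8)
>>> out = dense(u)
>>> print(out.shape)
(2, 16, 8)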
"""
|
||||
|
||||
def __init__(self,
|
||||
dense_impl: Type[TDenseImpl],
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
|
||||
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
|
||||
has_bias: bool = True) -> None:
|
||||
"""Initialize Dense."""
|
||||
super(Dense, self).__init__()
|
||||
self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
|
||||
self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
|
||||
self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
|
||||
self.dtype = None
|
||||
self.reshape = P.Reshape()
|
||||
self.shape_op = P.Shape()
|
||||
|
||||
self.weight_x = None
|
||||
self.weight_y = None
|
||||
if isinstance(weight_init, Tensor):
|
||||
self.dtype = weight_init.dtype
|
||||
if self.dtype in [mindspore.float16, mindspore.float32] and ( \
|
||||
weight_init.ndim != 3
|
||||
or weight_init.shape[0] != 2 \
|
||||
or weight_init.shape[1] != out_channels \
|
||||
or weight_init.shape[2] != in_channels):
|
||||
raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
|
||||
f"be equal to 3, and the first dim must be equal to 2, and the second dim must be "
|
||||
f"equal to 'out_channels', and the third dim must be equal to 'in_channels'. But got "
|
||||
f"'weight_init': {weight_init}, 'out_channels': {out_channels}, 'in_channels': "
|
||||
f"{in_channels}.")
|
||||
if self.dtype == mindspore.complex64 and ( \
|
||||
weight_init.ndim != 2 \
|
||||
or weight_init.shape[0] != out_channels \
|
||||
or weight_init.shape[1] != in_channels):
|
||||
raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
|
||||
f"be equal to 2, and the first dim must be equal to 'out_channels', "
|
||||
f"and the second dim must be equal to 'in_channels'. But got "
|
||||
f"'weight_init': {weight_init}, 'out_channels': {out_channels}, 'in_channels': "
|
||||
f"{in_channels}.")
|
||||
|
||||
self.dense_impl = dense_impl(weight_init, [out_channels, in_channels])
|
||||
|
||||
self.bias_x = None
|
||||
self.bias_y = None
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if self.dtype is None:
|
||||
self.dtype = bias_init.dtype
|
||||
elif self.dtype != bias_init.dtype:
|
||||
raise TypeError("Data type of weight init tensor and the bias init tensor must be equal, "
|
||||
f"but got weight_init.dtype={self.dtype} and bias_init.dtype={bias_init.dtype}")
|
||||
if self.dtype in [mindspore.float16, mindspore.float32] and ( \
|
||||
bias_init.ndim != 2 \
|
||||
or bias_init.shape[0] != 2 \
|
||||
or bias_init.shape[1] != out_channels):
|
||||
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
|
||||
f"be equal to 2, and the second dim must be equal to 'out_channels'. But got "
|
||||
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
|
||||
if self.dtype == mindspore.complex64 and ( \
|
||||
bias_init.ndim != 1 \
|
||||
or bias_init.shape[0] != out_channels):
|
||||
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
|
||||
f"be equal to 1, and the only dim must be equal to 'out_channels'. But got "
|
||||
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
|
||||
bias_init_x, bias_init_y = get_x_and_y(bias_init)
|
||||
else:
|
||||
bias_init_x = bias_init_y = bias_init
|
||||
self.bias_x = Parameter(initializer(bias_init_x, [out_channels]), name="bias_x")
|
||||
self.bias_y = Parameter(initializer(bias_init_y, [out_channels]), name="bias_y")
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
|
||||
def check_dense_input_shape(self, x: Tensor, x_dtype):
|
||||
msg_prefix = f"For '{self.cls_name}', the" if self.cls_name else "The"
|
||||
if x_dtype in [mindspore.float16, mindspore.float32] and (len(x) < 3 or x[0] != 2):
|
||||
raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 3, and the first dimension "
|
||||
f"should be 2, but got {x}.")
|
||||
if x_dtype == mindspore.complex64 and len(x) < 2:
|
||||
raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 2, but got {x}.")
|
||||
return None
|
||||
|
||||
def construct(self, u: Tensor) -> Tensor:
|
||||
"""Construct"""
|
||||
if self.dtype is not None and self.dtype != u.dtype:
|
||||
raise TypeError("dtype must be equal to the data type of the inputs tensor, but got: "
|
||||
f"dtype={self.dtype} and inputs.dtype={u.dtype}")
|
||||
u_shape = self.shape_op(u)
|
||||
self.check_dense_input_shape(u_shape, u.dtype)
|
||||
u_reshape = [-1, u_shape[-1]]
|
||||
if u.dtype in [mindspore.float16, mindspore.float32]:
|
||||
u_reshape = [2] + u_reshape
|
||||
if len(u_reshape) < len(u_shape):
|
||||
u = self.reshape(u, tuple(u_reshape))
|
||||
x, y = get_x_and_y(u)
|
||||
out_x, out_y = self.dense_impl(self.matmul, x, y)
|
||||
if self.has_bias:
|
||||
out_x = self.bias_add(out_x, self.bias_x)
|
||||
out_y = self.bias_add(out_y, self.bias_y)
|
||||
out = to_2channel(out_x, out_y, u.dtype)
|
||||
if len(u_reshape) < len(u_shape):
|
||||
out_shape = u_shape[:-1] + (-1,)
|
||||
out = self.reshape(out, out_shape)
|
||||
return out
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
|
||||
if self.has_bias:
|
||||
s += ', has_bias={}'.format(self.has_bias)
|
||||
return s
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Uniform operators"""
|
||||
from mindspore.nn import Cell
|
||||
|
||||
|
||||
class _UniformOperator(Cell):
|
||||
r"""
|
||||
Base class for layers that operate with the second-order hypercomplex numbers, and are designed
|
||||
using the bridge pattern.
|
||||
|
||||
Constructs the object of the 'hc_op' type, passing 'hc_impl' as a parameter.
|
||||
|
||||
Args:
|
||||
hc_op (Type): The abstraction part of the bridge pattern.
|
||||
hc_impl (Type): The implementor part of the bridge pattern.
|
||||
**kwargs (dict): Additional arguments that may be required to construct the specific layer
|
||||
|
||||
Inputs:
|
||||
- **inp** (Tensor) - input tensor. The shape is specific to the subclass.
|
||||
|
||||
Outputs:
|
||||
Tensor, whose shape is specific to the subclass.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
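
Examples:
A schematic sketch of how a concrete layer is typically wired with the bridge pattern
(the names ``Dense`` and ``ComplexDenseImpl`` below are placeholders, not actual library classes):

>>> class ComplexDense(_UniformOperator):
...     def __init__(self, in_channels, out_channels):
...         super(ComplexDense, self).__init__(Dense,
...                                            ComplexDenseImpl,
...                                            in_channels=in_channels,
...                                            out_channels=out_channels)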
"""
|
||||
|
||||
def __init__(self,
|
||||
hc_op,
|
||||
hc_impl,
|
||||
**kwargs) -> None:
|
||||
super(_UniformOperator, self).__init__()
|
||||
self.op = hc_op(hc_impl, **kwargs)
|
||||
|
||||
def construct(self, x):
|
||||
return self.op(x)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Utils"""
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mindspore
|
||||
from mindspore import ops as P
|
||||
|
||||
|
||||
to_complex = P.Complex()
|
||||
get_real = P.Real()
|
||||
get_imag = P.Imag()
|
||||
unstack = P.Unstack(0)
|
||||
cat = P.Concat(0)
|
||||
|
||||
|
||||
def get_x_and_y(tensor):
|
||||
if tensor.dtype == mindspore.complex64:
|
||||
return get_real(tensor), get_imag(tensor)
|
||||
return unstack(tensor)
|
||||
|
||||
|
||||
def to_2channel(real, imag, dtype=None):
|
||||
'''Convert to 2 channel format'''
|
||||
if dtype is not None and dtype == mindspore.complex64:
|
||||
return to_complex(real, imag)
|
||||
if dtype is not None and (dtype != real.dtype or dtype != imag.dtype):
|
||||
raise ValueError("dtype must match with data type of the input tensors, but got: "
|
||||
f"dtype={dtype}, real.dtype={real.dtype}, imag.dtype={imag.dtype}")
|
||||
expand_dims = P.ExpandDims()
|
||||
real = expand_dims(real, 0)
|
||||
imag = expand_dims(imag, 0)
|
||||
return cat((real, imag))
|
||||
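

# Illustrative round-trip between the two hypercomplex representations
# (shapes and values below are examples only):
#
#     >>> import numpy as np
#     >>> import mindspore
#     >>> from mindspore import Tensor
#     >>> re = Tensor(np.ones((4, 3)), mindspore.float32)
#     >>> im = Tensor(np.zeros((4, 3)), mindspore.float32)
#     >>> z = to_2channel(re, im)                       # 2-channel form, shape (2, 4, 3)
#     >>> x, y = get_x_and_y(z)                         # recovers the two components
#     >>> c = to_2channel(re, im, mindspore.complex64)  # complex form, shape (4, 3)
#     >>> x, y = get_x_and_y(c)                         # real and imaginary parts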
|
||||
|
||||
_size_1_t = Union[int, Tuple[int]]
|
||||
_size_2_t = Union[int, Tuple[int, int]]
|
||||
_size_3_t = Union[int, Tuple[int, int, int]]
|
||||
_size_any_t = Union[int, Tuple[int, ...]]
|
|
@ -0,0 +1,83 @@
|
|||
from mindspore import nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.hypercomplex.dual as ops
|
||||
|
||||
|
||||
class DeepConvNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(DeepConvNet, self).__init__()
|
||||
|
||||
self.conv1 = ops.Conv1d(1, 16, kernel_size=6, stride=2, padding=2, pad_mode='pad')
|
||||
self.bn1 = ops.BatchNorm1d(16)
|
||||
self.avg_pool1 = ops.AvgPool1d(kernel_size=2, stride=2)
|
||||
self.pad1 = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 2)), mode='CONSTANT')
|
||||
|
||||
self.conv2 = ops.Conv1d(16, 32, kernel_size=3, stride=2, padding=0)
|
||||
self.bn2 = ops.BatchNorm1d(32)
|
||||
self.avg_pool2 = ops.AvgPool1d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv3 = ops.Conv1d(32, 64, kernel_size=3, stride=1, padding=1, pad_mode='pad')
|
||||
self.bn3 = ops.BatchNorm1d(64)
|
||||
self.avg_pool3 = ops.AvgPool1d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv4 = ops.Conv1d(64, 64, kernel_size=3, stride=1, padding=1, pad_mode='pad')
|
||||
self.bn4 = ops.BatchNorm1d(64)
|
||||
self.avg_pool4 = ops.AvgPool1d(kernel_size=2, stride=2)
|
||||
|
||||
self.conv5 = ops.Conv1d(64, 128, kernel_size=3, stride=1, padding=1, pad_mode='pad')
|
||||
self.conv6 = ops.Conv1d(128, 128, kernel_size=3, stride=1, padding=1, pad_mode='pad')
|
||||
self.bn6 = ops.BatchNorm1d(128)
|
||||
self.avg_pool6 = ops.AvgPool1d(kernel_size=2, stride=2)
|
||||
|
||||
self.shape_op = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.permute = P.Transpose()
|
||||
self.flatten = P.Flatten()
|
||||
|
||||
self.fc1 = ops.Dense(4096, 1024)
|
||||
self.fc2 = nn.Dense(2048, 84)
|
||||
|
||||
self.relu = ops.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def construct(self, u: Tensor) -> Tensor:
|
||||
u = self.conv1(u)
|
||||
u = self.bn1(u)
|
||||
u = self.relu(u)
|
||||
u = self.avg_pool1(u)
|
||||
u = self.pad1(u)
|
||||
|
||||
u = self.conv2(u)
|
||||
u = self.bn2(u)
|
||||
u = self.relu(u)
|
||||
u = self.avg_pool2(u)
|
||||
|
||||
u = self.conv3(u)
|
||||
u = self.bn3(u)
|
||||
u = self.relu(u)
|
||||
u = self.avg_pool3(u)
|
||||
|
||||
u = self.conv4(u)
|
||||
u = self.bn4(u)
|
||||
u = self.relu(u)
|
||||
u = self.avg_pool4(u)
|
||||
|
||||
u = self.conv5(u)
|
||||
u = self.relu(u)
|
||||
|
||||
u = self.conv6(u)
|
||||
u = self.bn6(u)
|
||||
u = self.relu(u)
|
||||
u = self.avg_pool6(u)
|
||||
|
||||
u_shape = self.shape_op(u)
|
||||
u = self.reshape(u, (u_shape[0], u_shape[1], -1))
|
||||
u = self.fc1(u)
|
||||
u = self.relu(u)
|
||||
|
||||
u = self.permute(u, (1, 0, 2))
|
||||
x = self.flatten(u)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return x
|
|
@ -0,0 +1,32 @@
|
|||
from mindspore import nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
|
||||
import mindspore.hypercomplex.dual as ops
|
||||
|
||||
|
||||
class HCModel(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(HCModel, self).__init__()
|
||||
self.conv1 = ops.Conv2d(1, 10, kernel_size=3)
|
||||
self.bn1 = ops.BatchNorm2d(10)
|
||||
self.max_pool = ops.MaxPool2d(2)
|
||||
self.relu = ops.ReLU()
|
||||
self.fc1 = ops.Dense(7290, 256)
|
||||
self.fc2 = nn.Dense(512, 10)
|
||||
self.concat = P.Concat(1)
|
||||
|
||||
def construct(self, u: Tensor) -> Tensor:
|
||||
u = to_2channel(u[:, :1], u[:, 1:])
|
||||
u = self.conv1(u)
|
||||
u = self.bn1(u)
|
||||
u = self.relu(u)
|
||||
u = self.max_pool(u)
|
||||
u = u.view(2, u.shape[1], -1)
|
||||
u = self.fc1(u)
|
||||
u = self.relu(u)
|
||||
out_x, out_y = get_x_and_y(u)
|
||||
out = self.concat([out_x, out_y])
|
||||
out = self.fc2(out)
|
||||
return out
|
|
@ -0,0 +1,593 @@
|
|||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet."""
|
||||
import math
|
||||
import numpy as np
|
||||
from scipy.stats import truncnorm
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.hypercomplex.dual as ops
|
||||
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
|
||||
|
||||
|
||||
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
|
||||
fan_in = in_channel * kernel_size * kernel_size
|
||||
scale = 1.0
|
||||
scale /= max(1., fan_in)
|
||||
stddev = (scale ** 0.5) / .87962566103423978
|
||||
mu, sigma = 0, stddev
|
||||
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
|
||||
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
|
||||
return Tensor(weight, dtype=mstype.float32)
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
"""calculate_gain"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
res = 0
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
res = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
res = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
res = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, (int, float)):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return res
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
"""_calculate_fan_in_and_fan_out"""
|
||||
dimensions = len(tensor)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2: # Linear
|
||||
fan_in = tensor[1]
|
||||
fan_out = tensor[0]
|
||||
else:
|
||||
num_input_fmaps = tensor[1]
|
||||
num_output_fmaps = tensor[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = tensor[2] * tensor[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||
if use_se:
|
||||
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel, 3, 3)
|
||||
weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
|
||||
if res_base:
|
||||
return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
|
||||
padding=1, pad_mode='pad', weight_init=weight)
|
||||
return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
|
||||
padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||
if use_se:
|
||||
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel, 1, 1)
|
||||
weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
|
||||
if res_base:
|
||||
return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
|
||||
padding=0, pad_mode='pad', weight_init=weight)
|
||||
return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
|
||||
padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
|
||||
if use_se:
|
||||
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel, 7, 7)
|
||||
weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
|
||||
if res_base:
|
||||
return ops.Conv2d(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
|
||||
return ops.Conv2d(in_channel, out_channel,
|
||||
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
|
||||
|
||||
|
||||
def _bn(channel, res_base=False):
|
||||
if res_base:
|
||||
return ops.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
return ops.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
|
||||
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
|
||||
|
||||
|
||||
def _fc(in_channel, out_channel, use_se=False):
|
||||
if use_se:
|
||||
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
|
||||
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
|
||||
else:
|
||||
weight_shape = (out_channel, in_channel)
|
||||
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
|
||||
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
use_se (bool): Enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResidualBlock(3, 256, stride=2)
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
stride=1,
|
||||
use_se=False, se_block=False):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.use_se = use_se
|
||||
self.se_block = se_block
|
||||
channel = out_channel // self.expansion
|
||||
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
|
||||
self.bn1 = _bn(channel)
|
||||
if self.use_se and self.stride != 1:
|
||||
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
|
||||
ops.ReLU(), ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
|
||||
else:
|
||||
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
|
||||
self.bn2 = _bn(channel)
|
||||
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
|
||||
self.bn3 = _bn(out_channel)
|
||||
if self.se_block:
|
||||
self.se_global_pool = P.ReduceMean(keep_dims=False)
|
||||
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
|
||||
self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
|
||||
self.se_sigmoid = nn.Sigmoid()
|
||||
self.se_mul = P.Mul()
|
||||
self.relu = ops.ReLU()
|
||||
|
||||
self.down_sample = False
|
||||
|
||||
if stride != 1 or in_channel != out_channel:
|
||||
self.down_sample = True
|
||||
self.down_sample_layer = None
|
||||
|
||||
if self.down_sample:
|
||||
if self.use_se:
|
||||
if stride == 1:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
|
||||
stride, use_se=self.use_se), _bn(out_channel)])
|
||||
else:
|
||||
self.down_sample_layer = nn.SequentialCell([ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
|
||||
_conv1x1(in_channel, out_channel, 1,
|
||||
use_se=self.use_se), _bn(out_channel)])
|
||||
else:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
|
||||
use_se=self.use_se), _bn(out_channel)])
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
if self.use_se and self.stride != 1:
|
||||
out = self.e2(out)
|
||||
else:
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
if self.se_block:
|
||||
out_se = out
|
||||
out = self.se_global_pool(out, (2, 3))
|
||||
out = self.se_dense_0(out)
|
||||
out = self.relu(out)
|
||||
out = self.se_dense_1(out)
|
||||
out = self.se_sigmoid(out)
|
||||
out = F.reshape(out, F.shape(out) + (1, 1))
|
||||
out = self.se_mul(out, out_se)
|
||||
|
||||
if self.down_sample:
|
||||
identity = self.down_sample_layer(identity)
|
||||
|
||||
out = out + identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResidualBlockBase(nn.Cell):
|
||||
"""
|
||||
ResNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channel (int): Input channel.
|
||||
out_channel (int): Output channel.
|
||||
stride (int): Stride size for the first convolutional layer. Default: 1.
|
||||
use_se (bool): Enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
|
||||
res_base (bool): Enable parameter setting of resnet18. Default: True.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResidualBlockBase(3, 256, stride=2)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
stride=1,
|
||||
use_se=False,
|
||||
se_block=False,
|
||||
res_base=True):
|
||||
super(ResidualBlockBase, self).__init__()
|
||||
self.res_base = res_base
|
||||
self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
|
||||
self.bn1d = _bn(out_channel)
|
||||
self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
|
||||
self.bn2d = _bn(out_channel)
|
||||
self.relu = ops.ReLU()
|
||||
|
||||
self.down_sample = False
|
||||
if stride != 1 or in_channel != out_channel:
|
||||
self.down_sample = True
|
||||
|
||||
self.down_sample_layer = None
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
|
||||
use_se=use_se, res_base=self.res_base),
|
||||
_bn(out_channel, res_base)])
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1d(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2d(out)
|
||||
|
||||
if self.down_sample:
|
||||
identity = self.down_sample_layer(identity)
|
||||
|
||||
out = out + identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
|
||||
"""
|
||||
ResNet architecture.
|
||||
|
||||
Args:
|
||||
block (Cell): Block for network.
|
||||
layer_nums (list): Numbers of block in different layers.
|
||||
in_channels (list): Input channel in each layer.
|
||||
out_channels (list): Output channel in each layer.
|
||||
strides (list): Stride size in each layer.
|
||||
num_classes (int): The number of classes that the training images are belonging to.
|
||||
use_se (bool): Enable SE-ResNet50 net. Default: False.
|
||||
se_block(bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
|
||||
res_base (bool): Enable parameter setting of resnet18. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResNet(ResidualBlock,
|
||||
>>> [3, 4, 6, 3],
|
||||
>>> [64, 256, 512, 1024],
|
||||
>>> [256, 512, 1024, 2048],
|
||||
>>> [1, 2, 2, 2],
|
||||
>>> 10)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layer_nums,
|
||||
in_channels,
|
||||
out_channels,
|
||||
strides,
|
||||
num_classes,
|
||||
use_se=False,
|
||||
res_base=False):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
self.use_se = use_se
|
||||
self.res_base = res_base
|
||||
self.se_block = False
|
||||
if self.use_se:
|
||||
self.se_block = True
|
||||
|
||||
if self.use_se:
|
||||
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
|
||||
self.bn1_0 = _bn(32)
|
||||
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
|
||||
self.bn1_1 = _bn(32)
|
||||
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
|
||||
else:
|
||||
self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
|
||||
self.bn1 = _bn(64, self.res_base)
|
||||
self.relu = ops.ReLU()
|
||||
|
||||
if self.res_base:
|
||||
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
|
||||
self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
|
||||
else:
|
||||
self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
in_channel=in_channels[0],
|
||||
out_channel=out_channels[0],
|
||||
stride=strides[0],
|
||||
use_se=self.use_se)
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=in_channels[1],
|
||||
out_channel=out_channels[1],
|
||||
stride=strides[1],
|
||||
use_se=self.use_se)
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=in_channels[2],
|
||||
out_channel=out_channels[2],
|
||||
stride=strides[2],
|
||||
use_se=self.use_se,
|
||||
se_block=self.se_block)
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=in_channels[3],
|
||||
out_channel=out_channels[3],
|
||||
stride=strides[3],
|
||||
use_se=self.use_se,
|
||||
se_block=self.se_block)
|
||||
|
||||
self.avgpool = ops.AvgPool2d(4)
|
||||
self.concat = P.Concat(1)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = _fc(16384, num_classes, use_se=self.use_se)
|
||||
|
||||
    def construct(self, x):
        x = to_2channel(x[:, :3], x[:, 3:])
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.res_base:
            x_1, x_2 = get_x_and_y(x)
            x_1 = self.pad(x_1)
            x_2 = self.pad(x_2)
            x = to_2channel(x_1, x_2)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        out = self.avgpool(x)
        out_x, out_y = get_x_and_y(out)
        out = self.concat([out_x, out_y])
        out = self.flatten(out)
        out = self.end_point(out)
        return out

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable SE-ResNet50 net. Default: False.
            se_block (bool): Use SE block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = []

        resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
        layers.append(resnet_block)
        if se_block:
            for _ in range(1, layer_num - 1):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
            resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
            layers.append(resnet_block)
        else:
            for _ in range(1, layer_num):
                resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
                layers.append(resnet_block)
        return nn.SequentialCell(layers)


def resnet18(class_num=10):
    """
    Get ResNet18 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        >>> net = resnet18(10)
    """
    return ResNet(ResidualBlockBase,
                  [2, 2, 2, 2],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet34(class_num=10):
    """
    Get ResNet34 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet34 neural network.

    Examples:
        >>> net = resnet34(10)
    """
    return ResNet(ResidualBlockBase,
                  [3, 4, 6, 3],
                  [64, 64, 128, 256],
                  [64, 128, 256, 512],
                  [1, 2, 2, 2],
                  class_num,
                  res_base=True)


def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)


def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  use_se=True)


def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(ResidualBlock,
                  [3, 4, 23, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num)

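A hedged usage sketch for the hypercomplex ResNet factories above; the 6-channel input and the class count are illustrative assumptions, chosen because construct() splits the input channel-wise into two 3-channel components with to_2channel:

    import numpy as np
    from mindspore import Tensor

    net = resnet50(class_num=10)
    x = Tensor(np.random.random((1, 6, 256, 256)).astype(np.float32))  # two stacked RGB components
    # y = net(x)  # the spatial size must be compatible with the 16384-unit dense head (_fc above)
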
@ -0,0 +1,13 @@
import numpy as np
from mindspore import context, Tensor
from mindspore.ops import operations as P
from deepconvnet import DeepConvNet


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    model = DeepConvNet()
    model.set_train(False)
    u = Tensor(np.random.random((2, 32, 1, 4096)).astype(np.float32))
    y = model(u)
    print(P.Shape()(y), y)

@ -0,0 +1,112 @@
import argparse

import numpy as np
import mindspore
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import MnistDataset
from hcmodel import HCModel

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class ImageToDualImage:
    @staticmethod
    def __call__(img):
        return np.concatenate((img, img), axis=0)
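# A quick sanity check of the dual-image transform above (illustrative only):
# the image is stacked with itself along the channel axis, so a 1x28x28 MNIST
# image becomes the 2x28x28 two-component input expected by the hypercomplex
# layers.
#
#     dual = ImageToDualImage()(np.zeros((1, 28, 28), dtype=np.float32))
#     dual.shape  # (2, 28, 28)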


def create_dataset(dataset_dir, batch_size, usage=None):
    dataset = MnistDataset(dataset_dir=dataset_dir, usage=usage)
    type_cast_op = mindspore.dataset.transforms.TypeCast(mstype.int32)

    # define map operations
    trans = [mindspore.dataset.vision.Rescale(1.0 / 255.0, 0),
             mindspore.dataset.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
             mindspore.dataset.vision.HWC2CHW(),
             ImageToDualImage()]

    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    dataset = dataset.map(operations=trans, input_columns="image")
    dataset = dataset.batch(batch_size)
    return dataset


def train(model, dataset, loss_fn, optimizer):
    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = ops.depend(loss, optimizer(grads))
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")

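# A sketch of what grad_fn above returns, using the names defined in train():
# with has_aux=True, gradients are taken only with respect to the first output
# of forward_fn (the loss); logits is carried through unchanged, and grads is a
# tuple aligned with optimizer.parameters:
#
#     (loss, logits), grads = grad_fn(data, label)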

def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def main():
    parser = argparse.ArgumentParser(description='MindSpore MNIST Testing')
    parser.add_argument(
        '--dataset', default=None, type=str, metavar='DS', required=True,
        help='Path to the dataset folder'
    )
    parser.add_argument(
        '--bs', default=64, type=int, metavar='N', required=False,
        help='Mini-batch size'
    )
    args = parser.parse_args()

    # Process the MNIST dataset.
    train_dataset = create_dataset(args.dataset, args.bs, "train")
    test_dataset = create_dataset(args.dataset, args.bs, "test")

    for img, lbl in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
        print(f"Shape of label: {lbl.shape} {lbl.dtype}")
        break

    # Initialize hypercomplex model
    net = HCModel()

    # Initialize loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optim = nn.SGD(net.trainable_params(), 1e-2)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(net, train_dataset, criterion, optim)
        test(net, test_dataset, criterion)
    print("Done!")


if __name__ == "__main__":
    main()

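A hedged example of how this MNIST script might be launched; the script filename and dataset path are placeholders, not part of the change:

    python train_mnist_hc.py --dataset ./MNIST_Data --bs 64
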
@ -0,0 +1,115 @@
import argparse
import numpy as np
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms
from resnet import resnet18

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class ImageToDualImage:
    @staticmethod
    def __call__(img):
        return np.concatenate((img, img), axis=0)


def create_dataset(dataset_dir, batch_size, usage=None):
    dataset = Cifar10Dataset(dataset_dir=dataset_dir, usage=usage)
    type_cast_op = transforms.TypeCast(mstype.int32)

    # define map operations
    trans = [vision.ToPIL(),
             vision.RandomCrop((32, 32), (4, 4, 4, 4)),
             vision.RandomHorizontalFlip(prob=0.5),
             vision.Resize((224, 224)),
             vision.ToTensor(),
             vision.Rescale(1.0 / 255.0, 0.0),
             vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010], is_hwc=False),
             ImageToDualImage()]

    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    dataset = dataset.map(operations=trans, input_columns="image")
    dataset = dataset.batch(batch_size)
    return dataset
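# After the pipeline above, each batch has shape (N, 6, 224, 224): the three RGB
# channels are duplicated by ImageToDualImage, which matches the channel split
# performed by to_2channel(x[:, :3], x[:, 3:]) in the ResNet construct() above.
# A hedged standalone check (the dataset path is a placeholder):
#
#     ds = create_dataset("./cifar-10-batches-bin", batch_size=2, usage="train")
#     img, lbl = next(ds.create_tuple_iterator())
#     img.shape  # (2, 6, 224, 224)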


def train(model, dataset, loss_fn, optimizer):
    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = ops.depend(loss, optimizer(grads))
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")


def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


def main():
    parser = argparse.ArgumentParser(description='MindSpore ResNet Testing')
    parser.add_argument(
        '--dataset', default=None, type=str, metavar='DS', required=True,
        help='Path to the dataset folder'
    )
    parser.add_argument(
        '--bs', default=64, type=int, metavar='N', required=False,
        help='Mini-batch size'
    )
    args = parser.parse_args()

    # Process the cifar dataset.
    train_dataset = create_dataset(args.dataset, args.bs, "train")
    test_dataset = create_dataset(args.dataset, args.bs, "test")

    for img, lbl in test_dataset.create_tuple_iterator():
        print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
        print(f"Shape of label: {lbl.shape} {lbl.dtype}")
        break

    # Initialize hypercomplex model
    net = resnet18()

    # Initialize loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optim = nn.SGD(net.trainable_params(), 1e-2)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(net, train_dataset, criterion, optim)
        test(net, test_dataset, criterion)
    print("Done!")


if __name__ == "__main__":
    main()

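This CIFAR-10 script targets a GPU device (see context.set_context above); a hedged invocation with placeholder script and dataset paths:

    python train_cifar10_hc.py --dataset ./cifar-10-batches-bin --bs 64
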