!47671 Add hypercomplex operators

Merge pull request !47671 from chenhaozhe/hypercomplex
i-robot 2023-02-06 10:44:46 +00:00 committed by Gitee
commit 9e8e43fcfb
34 changed files with 9482 additions and 0 deletions


@@ -74,6 +74,7 @@
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/_grad_experimental/grad_sparse_ops.py" "unused-variable"
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "not-callable"
"mindspore/mindspore/python/mindspore/hypercomplex" "useless-return"
# MindData
"mindspore/mindspore/python/mindspore/dataset/__init__.py" "redefined-builtin"


@@ -0,0 +1,24 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
HyperComplex operators.
Note:
This feature is a beta feature, and we are still improving its functionality.
The interface may be changed or removed in the future.
"""
import mindspore.hypercomplex.complex
import mindspore.hypercomplex.dual
import mindspore.hypercomplex.double
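Throughout this patch the layers exchange hypercomplex tensors in a two-channel encoding of shape (2, *, ..., *), as the docstrings below describe. A minimal NumPy sketch of that packing follows (not part of the patch); the assumption that index 0 holds the real or first component is inferred from the argument order of the to_2channel calls further down.

import numpy as np

real = np.random.randn(4, 3, 8, 8).astype(np.float32)   # real / first component
imag = np.random.randn(4, 3, 8, 8).astype(np.float32)   # imaginary / second component

packed = np.stack([real, imag], axis=0)                  # shape (2, N, C, H, W)
part_x, part_y = packed[0], packed[1]                    # split back into the two parts

assert packed.shape == (2, 4, 3, 8, 8)
assert np.array_equal(part_x, real) and np.array_equal(part_y, imag)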


@@ -0,0 +1,23 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Complex Operators"""
from mindspore.hypercomplex.complex.complex_relu import ReLU
from mindspore.hypercomplex.complex.complex_operators import Conv1d, Conv2d, Conv3d
from mindspore.hypercomplex.complex.complex_operators import BatchNorm1d, BatchNorm2d, BatchNorm3d
from mindspore.hypercomplex.complex.complex_operators import Dense
from mindspore.hypercomplex.hypercomplex.hc_pool import MaxPool1d, MaxPool2d, \
AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, \
AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d


@@ -0,0 +1,139 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex BatchNorm implementation"""
from typing import Tuple
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BaseBatchNormImpl as HCBatchNormImpl
from mindspore.hypercomplex.utils import get_x_and_y as get_real_and_imag
class _BatchNormImpl(HCBatchNormImpl):
r"""
The implementor class of the Batch Normalization layer for complex numbers.
Implements the functionality specific to complex numbers and needed by the 'BatchNorm' class. This includes:
getting the norm of a complex number, applying scaling and shift to a complex tensor, and updating the running
mean and variance, which are used during inference.
Args:
affine (bool): A bool value. When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
use the specified mean and variance values. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
num_features (int): The number of features in the input space.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def scale_and_shift(self,
u_x: Tensor,
u_y: Tensor,
scale_x: Tensor,
scale_y: Tensor,
shift_x: Tensor,
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Applies complex scaling and shift to an input tensor.
This function implements the operation as:
.. math::
\begin{align}
\text{Re(out)} = \text{Re(inp)} * \text{Re(scale)} - \text{Im(inp)} * \text{Im(scale)} + \text{Re(shift)}\\
\text{Im(out)} = \text{Re(inp)} * \text{Im(scale)} + \text{Im(inp)} * \text{Re(scale)} + \text{Im(shift)},
\end{align}
where :math:`inp` is the complex input tensors, :math:`scale` and :math:`shift` are complex parameters
representing the scaling and shift coefficients respectively. :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
are respectively real and imaginary parts of the complex-valued expression inside the parentheses.
Args:
u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
u_y (Tensor): A tensor of shape (C,), which represents the imaginary part of the normalized inputs.
scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
scale_y (Tensor): A tensor of shape (C,), which represents the imaginary part of the scaling vector.
shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
shift_y (Tensor): A tensor of shape (C,), which represents the imaginary part of the bias vector.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the imaginary parts of rescaled and
recentered inputs.
"""
out_x = u_x * scale_x - u_y * scale_y + shift_x
out_y = u_x * scale_y + u_y * scale_x + shift_y
return out_x, out_y
def get_norm(self,
u: Tensor) -> Tensor:
r"""
Calculates norm of complex elements of an input tensor.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \sqrt{\text{Re(inp)}^2 + \text{Im(inp)}^2 + \delta},
where :math:`inp` is the complex input tensors and :math:`\delta` is a small positive constant, which is needed
to avoid division by zero in case statistical variance is close to zero. :math:`\text{Re(...)}` and
:math:`\text{Im(...)}` are respectively real and imaginary parts of the complex-valued expression inside
the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the complex domain
and has a real and an imaginary part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
norm2 = self.get_square_norm(u)
eps = 1e-7
out = ops.sqrt(norm2 + eps)
return out
def get_square_norm(self,
u: Tensor) -> Tensor:
r"""
Calculates element-wise squared norm of complex elements of an input tensor.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \text{Re(inp)}^2 + \text{Im(inp)}^2,
where :math:`inp` is the complex input tensors, :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
are respectively real and imaginary parts of the complex-valued expression inside the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the complex domain
and has a real and an imaginary part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
u_r, u_i = get_real_and_imag(u)
out = u_r ** 2 + u_i ** 2
return out
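As a quick sanity check on the formulas above, here is a short NumPy sketch (illustrative, not part of the patch) confirming that the scale-and-shift is exactly complex multiplication plus a complex shift, and that get_square_norm matches the squared modulus.

import numpy as np

rng = np.random.default_rng(0)
u_x, u_y = rng.standard_normal(5), rng.standard_normal(5)          # normalized input parts
scale_x, scale_y = rng.standard_normal(5), rng.standard_normal(5)  # gamma parts
shift_x, shift_y = rng.standard_normal(5), rng.standard_normal(5)  # beta parts

# Formula implemented by _BatchNormImpl.scale_and_shift
out_x = u_x * scale_x - u_y * scale_y + shift_x
out_y = u_x * scale_y + u_y * scale_x + shift_y

# Reference: complex multiply-and-add
ref = (u_x + 1j * u_y) * (scale_x + 1j * scale_y) + (shift_x + 1j * shift_y)
assert np.allclose(out_x, ref.real) and np.allclose(out_y, ref.imag)

# The squared norm Re^2 + Im^2 equals the squared modulus |u|^2
assert np.allclose(u_x ** 2 + u_y ** 2, np.abs(u_x + 1j * u_y) ** 2)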


@@ -0,0 +1,201 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex convolution implementation"""
import numbers
from typing import Callable, Tuple, Union
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import Initializer
from mindspore.hypercomplex.hypercomplex._hc_conv_impl import _BaseConvImpl as BaseConvImpl
from mindspore import ops as P
class _ConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for complex numbers.
Applies complex-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})
- \text{ccor}(\text{Im(kernel)}, \text{Im(inp)})\\
\text{Im(ccor)} = \text{ccor}(\text{Im(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Re(kernel)}, \text{Im(inp)})
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the complex-valued input tensors, :math:`\text{kernel}` is a complex weight matrix with the same
data type as the :math:`inp` created by the layer, :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
are respectively real and imaginary parts of the complex-valued expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and imaginary parts of the kernel.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the complex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input.
- **imag** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the imaginary part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents the real and the imaginary parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
conv_op: Callable,
real: Tensor,
imag: Tensor) -> Tuple[Tensor, Tensor]:
out_rr = conv_op(real, self.weight_x)
out_ii = conv_op(imag, self.weight_y)
out_ri = conv_op(real, self.weight_y)
out_ir = conv_op(imag, self.weight_x)
out_r = out_rr - out_ii
out_i = out_ri + out_ir
return out_r, out_i
class _KaratsubaConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for complex numbers.
Applies complex-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{C1} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})\\
\text{C2} = \text{ccor}(\text{Im(kernel)}, \text{Im(inp)})\\
\text{C3} = \text{ccor}(\text{Re(kernel)} + \text{Im(kernel)}, \text{Re(inp)} + \text{Im(inp)})\\
\text{Re(out)} = C1 - C2\\
\text{Im(out)} = C3 - C1 - C2,
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the complex-valued input tensors, :math:`\text{kernel}` is a complex weight matrix with the same
data type as the :math:`inp` created by the layer, :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
are respectively real and imaginary parts of the complex-valued expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and imaginary parts of the kernel.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the complex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input.
- **imag** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the imaginary part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents the real and the imaginary parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
conv_op: Callable,
real: Tensor,
imag: Tensor) -> Tuple[Tensor, Tensor]:
c1 = conv_op(real, self.weight_x)
c2 = conv_op(imag, self.weight_y)
c3 = conv_op(real + imag, self.weight_x + self.weight_y)
out_r = c1 - c2
out_i = c3 - c1 - c2
return out_r, out_i
class _ReImConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for complex numbers.
Applies complex-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{inp_cat} = \text{cat}(\text{Re(inp)}, \text{Im(inp)}) \\
\text{K1} = \text{cat}(\text{Re(kernel)}, \text{-Im(kernel)}) \\
\text{K2} = \text{cat}(\text{Im(kernel)}, \text{Re(kernel)}) \\
\text{Re(ccor)} = \text{ccor}(\text{K1}, \text{inp_cat}) \\
\text{Im(ccor)} = \text{ccor}(\text{K2}, \text{inp_cat})
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the complex-valued input tensors, :math:`\text{kernel}` is a complex weight matrix with the same
data type as the :math:`inp` created by the layer, :math:`\text{cat}` is concatenation along the channel axis,
:math:`\text{Re(...)}` and :math:`\text{Im(...)}` are respectively real and imaginary parts of the complex-valued
expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and imaginary parts of the kernel.
factory_kwargs (dict): Additional parameters, which must include data_format.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the complex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input.
- **imag** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the imaginary part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents the real and the imaginary parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
weight_init: Union[Tensor, str, Initializer, numbers.Number],
weight_shape: tuple,
**factory_kwargs) -> None:
super(_ReImConvImpl, self).__init__(weight_init, weight_shape, **factory_kwargs)
data_format = factory_kwargs.get('data_format', 'nchw')
c_idx = data_format.lower().find('c')
if c_idx < 0:
raise ValueError(f"Data format {data_format} is unsupported")
self.concat = P.Concat(c_idx)
self.neg = P.Neg()
def construct(self,
conv_op: Callable,
real: Tensor,
imag: Tensor) -> Tuple[Tensor, Tensor]:
inp = self.concat([real, imag])
weight_y_neg = self.neg(self.weight_y)
w1 = self.concat([self.weight_x, weight_y_neg])
w2 = self.concat([self.weight_y, self.weight_x])
out_r = conv_op(inp, w1)
out_i = conv_op(inp, w2)
return out_r, out_i
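The three implementor classes above are algebraically interchangeable. A short NumPy sketch of that fact (not part of the patch): a dot product over channels stands in for the real-valued cross-correlation, so this checks only the algebra of the decompositions, not the convolution itself.

import numpy as np

def ccor(x, w):
    # toy stand-in for the real-valued cross-correlation: a dot product over channels
    return float(np.dot(x, w))

rng = np.random.default_rng(1)
real, imag = rng.standard_normal(4), rng.standard_normal(4)   # input parts
w_x, w_y = rng.standard_normal(4), rng.standard_normal(4)     # kernel parts

# _ConvImpl: four real "convolutions"
direct_r = ccor(real, w_x) - ccor(imag, w_y)
direct_i = ccor(real, w_y) + ccor(imag, w_x)

# _KaratsubaConvImpl: three real "convolutions"
c1, c2 = ccor(real, w_x), ccor(imag, w_y)
c3 = ccor(real + imag, w_x + w_y)
assert np.isclose(direct_r, c1 - c2) and np.isclose(direct_i, c3 - c1 - c2)

# _ReImConvImpl: one "convolution" per output part over concatenated channels
inp_cat = np.concatenate([real, imag])
k1 = np.concatenate([w_x, -w_y])
k2 = np.concatenate([w_y, w_x])
assert np.isclose(direct_r, ccor(inp_cat, k1)) and np.isclose(direct_i, ccor(inp_cat, k2))

# Reference: the complex product summed over channels
ref = np.sum((real + 1j * imag) * (w_x + 1j * w_y))
assert np.isclose(direct_r, ref.real) and np.isclose(direct_i, ref.imag)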


@@ -0,0 +1,125 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex dense implementation"""
from typing import Callable, Tuple
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _BaseDenseImpl as BaseDenseImpl
class _DenseImpl(BaseDenseImpl):
r"""
The implementor class of the dense connected layer for complex numbers.
Applies complex-valued matrix multiplication for dense connected layer. This layer implements the operation as:
.. math::
\begin{align}
\text{Re(out)} = \text{Re(inp)} * \text{Re(kernel)} - \text{Im(inp)} * \text{Im(kernel)}\\
\text{Im(out)} = \text{Re(inp)} * \text{Im(kernel)} + \text{Im(inp)} * \text{Re(kernel)},
\end{align}
where :math:`inp` is the complex input tensors, :math:`\text{kernel}` is a complex weight matrix with the same
data type as the :math:`inp` created by the layer, :math:`\text{Re(...)}` and :math:`\text{Im(...)}`
are respectively real and imaginary parts of the complex-valued expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and imaginary parts of the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the complex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
- **imag** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the imaginary part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the imaginary
parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
matmul_op: Callable,
real: Tensor,
imag: Tensor) -> Tuple[Tensor, Tensor]:
out_rr = matmul_op(real, self.weight_x)
out_ii = matmul_op(imag, self.weight_y)
out_ri = matmul_op(real, self.weight_y)
out_ir = matmul_op(imag, self.weight_x)
out_r = out_rr - out_ii
out_i = out_ri + out_ir
return out_r, out_i
class _KaratsubaDenseImpl(BaseDenseImpl):
r"""
The implementor class of the dense connected layer for complex numbers.
Applies complex-valued matrix multiplication for dense connected layer. This layer implements the operation as:
.. math::
\begin{align}
\text{L1} = \text{Re(inp)} * \text{Re(kernel)}\\
\text{L2} = \text{Im(inp)} * \text{Im(kernel)}\\
\text{L3} = (\text{Re(inp)} + \text{Im(inp)}) * (\text{Re(kernel)} + \text{Im(kernel)})\\
\text{Re(out)} = L1 - L2\\
\text{Im(out)} = L3 - L1 - L2,
\end{align}
where :math:`inp` is the complex input tensors, :math:`\text{kernel}` is a complex weight matrix with the same
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a complex bias vector with the same
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}` and
:math:`\text{Im(...)}` are respectively real and imaginary parts of the complex-valued expression inside
the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and imaginary parts of the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the complex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
- **imag** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the imaginary part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the imaginary
parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
matmul_op: Callable,
real: Tensor,
imag: Tensor) -> Tuple[Tensor, Tensor]:
l1 = matmul_op(real, self.weight_x)
l2 = matmul_op(imag, self.weight_y)
l3 = matmul_op(real + imag, self.weight_x + self.weight_y)
out_r = l1 - l2
out_i = l3 - l1 - l2
return out_r, out_i
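The same identities for the dense case, checked against NumPy's complex matmul (illustrative, not part of the patch). The weight shapes and the input-times-weight orientation are assumptions made only for this sketch; the actual MatMul configuration is defined in files not shown in full here.

import numpy as np

rng = np.random.default_rng(2)
real, imag = rng.standard_normal((3, 4)), rng.standard_normal((3, 4))   # input parts
w_x, w_y = rng.standard_normal((4, 5)), rng.standard_normal((4, 5))     # kernel parts

# _DenseImpl: four real matmuls
out_r = real @ w_x - imag @ w_y
out_i = real @ w_y + imag @ w_x

# _KaratsubaDenseImpl: three real matmuls
l1, l2 = real @ w_x, imag @ w_y
l3 = (real + imag) @ (w_x + w_y)
assert np.allclose(out_r, l1 - l2) and np.allclose(out_i, l3 - l1 - l2)

# Reference: complex matrix multiplication
ref = (real + 1j * imag) @ (w_x + 1j * w_y)
assert np.allclose(out_r, ref.real) and np.allclose(out_i, ref.imag)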

File diff suppressed because it is too large


@@ -0,0 +1,62 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""complex ReLU implementation"""
import mindspore
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.hypercomplex.utils import get_x_and_y as get_real_and_imag, to_2channel as to_complex
class ReLU(nn.Cell):
r"""
Rectified Linear Unit activation function for complex-valued input.
Applies ReLU activation layer for the complex-valued input. This layer applies the element-wise
:math:`\max(0, x)` for both real and imaginary parts of the input tensor independently:
.. math::
\begin{align}
\text{Re(out)} = (Re(inp))^+ = \max(0, Re(inp))\\
\text{Im(out)} = (Im(inp))^+ = \max(0, Im(inp)),
\end{align}
Inputs:
- **inp** (Tensor) - The input of ReLU is a Tensor of shape (2, *, ..., *), with float16 or float32 data type,
or (*, ..., *), with complex64 data type.
Outputs:
Tensor, with the same data type and shape as the `inp`.
Raises:
TypeError: If dtype of `inp` is not float16, float32, or complex64.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self):
"""Initialize ReLU."""
super(ReLU, self).__init__()
self.relu = P.ReLU()
def construct(self, u: Tensor) -> Tensor:
if u.dtype == mindspore.complex64:
real, imag = get_real_and_imag(u)
real = self.relu(real)
imag = self.relu(imag)
out = to_complex(real, imag, u.dtype)
else:
out = self.relu(u)
return out
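The split ReLU above, written out in plain NumPy for a complex-valued array (illustrative, not part of the patch):

import numpy as np

z = np.array([1.5 - 2.0j, -0.5 + 0.25j, -1.0 - 1.0j], dtype=np.complex64)
# max(0, .) applied to the real and imaginary parts independently
out = np.maximum(z.real, 0) + 1j * np.maximum(z.imag, 0)
print(out)   # real and imaginary parts clipped at zero independently: 1.5+0j, 0+0.25j, 0+0j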


@@ -0,0 +1,23 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double Operators"""
from mindspore.hypercomplex.double.double_operators import Conv1d, Conv2d, Conv3d
from mindspore.hypercomplex.double.double_operators import BatchNorm1d, BatchNorm2d, BatchNorm3d
from mindspore.hypercomplex.double.double_operators import Dense
from mindspore.hypercomplex.double.double_operators import ReLU
from mindspore.hypercomplex.hypercomplex.hc_pool import MaxPool1d, MaxPool2d, \
AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, \
AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d


@@ -0,0 +1,256 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double BatchNorm implementation"""
from typing import Tuple
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BaseBatchNormImpl as HCBatchNormImpl
from mindspore.hypercomplex.utils import get_x_and_y
get_real_and_double = get_x_and_y
get_u1_and_u2 = get_x_and_y
class _BatchNormImpl(HCBatchNormImpl):
r"""
The implementor class of the Batch Normalization layer for double numbers in regular representation.
Implements the functionality specific to double numbers and needed by the 'BatchNorm' class. This includes:
getting the norm of a double number, applying scaling and shift to a double-valued tensor, and updating the running
mean and variance, which are used during inference.
Args:
affine (bool): A bool value. When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
use the specified mean and variance values. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
num_features (int): The number of features in the input space.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def scale_and_shift(self,
u_x: Tensor,
u_y: Tensor,
scale_x: Tensor,
scale_y: Tensor,
shift_x: Tensor,
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Applies double scaling and shift to an input tensor in regular representation.
This function implements the operation as:
.. math::
\begin{align}
\text{Re(out)} = \text{Re(inp)} * \text{Re(scale)} + \text{Db(inp)} * \text{Db(scale)} + \text{Re(shift)}\\
\text{Db(out)} = \text{Re(inp)} * \text{Db(scale)} + \text{Db(inp)} * \text{Re(scale)} + \text{Db(shift)},
\end{align}
where :math:`inp` is the double input tensors, :math:`scale` and :math:`shift` are double parameters
representing the scaling and shift coefficients respectively. :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
are respectively real and double parts of the double-valued expression inside the parentheses.
Args:
u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
u_y (Tensor): A tensor of shape (C,), which represents the double part of the normalized inputs.
scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
scale_y (Tensor): A tensor of shape (C,), which represents the double part of the scaling vector.
shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
shift_y (Tensor): A tensor of shape (C,), which represents the double part of the bias vector.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the double parts of rescaled and
recentered inputs.
"""
out_x = u_x * scale_x + u_y * scale_y + shift_x
out_y = u_x * scale_y + u_y * scale_x + shift_y
return out_x, out_y
def get_norm(self, u: Tensor) -> Tensor:
r"""
Calculates norm of double elements of an input tensor in regular representation.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = |Re(inp)| + |Db(inp)|,
where :math:`inp` is the double input tensors, :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
are respectively real and double parts of the double-valued expression inside the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double domain
and has two components.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
u_r, u_d = get_real_and_double(u)
abs_op = ops.Abs()
out = abs_op(u_r) + abs_op(u_d)
return out
def get_square_norm(self, u: Tensor) -> Tensor:
r"""
Calculates element-wise squared norm of double elements of an input tensor in regular representation.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \left(|Re(inp)| + |Db(inp)|\right)^2,
where :math:`inp` is the double input tensors, :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
are respectively real and double parts of the double-valued expression inside the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double domain
and has a real and a double part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
norm = self.get_norm(u)
out = norm ** 2
return out
class _J1J2BatchNormImpl(HCBatchNormImpl):
r"""
The implementor class of the Batch Normalization layer for double numbers in diagonal representation.
Implements the functionality specific to double numbers and needed by the 'BatchNorm' class. This includes:
getting the norm of double number, applying scaling and shift to a double-valued tensor, and updating the running
mean and variance, which are used during inference.
Args:
affine (bool): A bool value. When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
use the specified mean and variance values. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
num_features (int): The number of features in the input space.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def scale_and_shift(self,
u_x: Tensor,
u_y: Tensor,
scale_x: Tensor,
scale_y: Tensor,
shift_x: Tensor,
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Applies double scaling and shift to an input tensor in diagonal representation.
This function implements the operation as:
.. math::
\begin{align}
\text{X(out)} = \text{X(inp)} * \text{X(scale)} + \text{X(shift)}\\
\text{Y(out)} = \text{Y(inp)} * \text{Y(scale)} + \text{Y(shift)},
\end{align}
where :math:`inp` is the double input tensors in diagonal form, :math:`scale` and :math:`shift` are
double parameters representing the scaling and shift coefficients respectively. :math:`\text{X(...)}`
and :math:`\text{Y(...)}` are respectively the first and the second components of the double-valued
expression inside the parentheses.
Args:
u_x (Tensor): A tensor of shape (C,), which represents the first part of the normalized inputs.
u_y (Tensor): A tensor of shape (C,), which represents the second part of the normalized inputs.
scale_x (Tensor): A tensor of shape (C,), which represents the first part of the scaling vector.
scale_y (Tensor): A tensor of shape (C,), which represents the second part of the scaling vector.
shift_x (Tensor): A tensor of shape (C,), which represents the first part of the bias vector.
shift_y (Tensor): A tensor of shape (C,), which represents the second part of the bias vector.
Returns:
Tuple of two tensors of shape (C,), which contains the first and the second parts of rescaled and
recentered inputs in the diagonal representation.
"""
out_x = u_x * scale_x + shift_x
out_y = u_y * scale_y + shift_y
return out_x, out_y
def get_norm(self, u: Tensor) -> Tensor:
r"""
Calculates norm of double elements of an input tensor in diagonal representation.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \text{max}(|X(inp)|, |Y(inp)|),
where :math:`inp` is the double input tensors in diagonal form, :math:`\text{max}` is the maximum value of its
arguments. :math:`\text{X(...)}` and :math:`\text{Y(...)}` are respectively the first and the second components
of the double-valued expression inside the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double domain
and has two components.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
u_1, u_2 = get_u1_and_u2(u)
abs_op = ops.Abs()
max_op = ops.Maximum()
out = max_op(abs_op(u_1), abs_op(u_2))
return out
def get_square_norm(self, u: Tensor) -> Tensor:
r"""
Calculates element-wise squared norm of double elements of an input tensor in diagonal representation.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \left(\text{max}(|X(inp)|, |Y(inp)|)\right)^2,
where :math:`inp` is the double input tensors in diagonal form, :math:`\text{max}` is the maximum value of its
arguments. :math:`\text{X(...)}` and :math:`\text{Y(...)}` are respectively the first and the second components
of the double-valued expression inside the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the double domain
and has two components.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
norm = self.get_norm(u)
out = norm ** 2
return out
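A NumPy sketch relating the two classes in this file (illustrative, not part of the patch): the regular-form scale-and-shift is double-number (split-complex, j^2 = +1) multiplication plus a shift, and it becomes purely component-wise in the diagonal basis X = Re + Db, Y = Re - Db used by _J1J2BatchNormImpl.

import numpy as np

def to_diag(r, d):
    # change of basis to the diagonal representation: X = Re + Db, Y = Re - Db
    return r + d, r - d

rng = np.random.default_rng(3)
u_r, u_d = rng.standard_normal(6), rng.standard_normal(6)   # input, regular form
s_r, s_d = rng.standard_normal(6), rng.standard_normal(6)   # scale
b_r, b_d = rng.standard_normal(6), rng.standard_normal(6)   # shift

# Regular representation (_BatchNormImpl.scale_and_shift)
out_r = u_r * s_r + u_d * s_d + b_r
out_d = u_r * s_d + u_d * s_r + b_d

# Diagonal representation (_J1J2BatchNormImpl.scale_and_shift): component-wise
u1, u2 = to_diag(u_r, u_d)
s1, s2 = to_diag(s_r, s_d)
b1, b2 = to_diag(b_r, b_d)
out1, out2 = u1 * s1 + b1, u2 * s2 + b2

# The two computations agree after changing basis
exp1, exp2 = to_diag(out_r, out_d)
assert np.allclose(exp1, out1) and np.allclose(exp2, out2)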


@@ -0,0 +1,126 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double convolution implementation"""
from typing import Callable, Tuple
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_conv_impl import _BaseConvImpl as BaseConvImpl
class _ConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for double numbers in regular representation.
Applies double-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Db(kernel)}, \text{Db(inp)})\\
\text{Db(ccor)} = \text{ccor}(\text{Db(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Re(kernel)}, \text{Db(inp)}),
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the double input tensors, :math:`\text{kernel}` is a double weight matrix with the same
data type as the :math:`inp` created by the layer. :math:`\text{Re(...)}` and :math:`\text{Db(...)}`
are respectively the real and the double parts of the double-valued expression inside the parentheses in the
regular form.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and double parts of the kernel.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the double convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the first part of the input.
- **double** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the second part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents both the first and the second parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
conv_op: Callable,
real: Tensor,
double: Tensor) -> Tuple[Tensor, Tensor]:
u1 = real + double
u2 = real - double
out1 = conv_op(u1, self.weight_x)
out2 = conv_op(u2, self.weight_y)
out_r = out1 + out2
out_d = out1 - out2
return out_r, out_d
class _J1J2ConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for double numbers in diagonal representation.
Applies double-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{X(ccor)} = \text{ccor}(\text{X(kernel)}, \text{X(inp)})\\
\text{Y(ccor)} = \text{ccor}(\text{Y(kernel)}, \text{Y(inp)}),
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the double input tensors, :math:`\text{kernel}` is a double weight matrix with the same
data type as the :math:`inp` created by the layer. :math:`\text{X(...)}` and :math:`\text{Y(...)}`
are respectively the first and the second parts of the double-valued expression inside the parentheses in the
diagonal form.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and double parts of the kernel.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the double convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **u1** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the first part of the input.
- **u2** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the second part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents both the first and the second parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
conv_op: Callable,
u1: Tensor,
u2: Tensor) -> Tuple[Tensor, Tensor]:
out1 = conv_op(u1, self.weight_x)
out2 = conv_op(u2, self.weight_y)
return out1, out2
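A NumPy sketch of the identity both classes in this file rely on (illustrative, not part of the patch). With a dot product over channels standing in for the real cross-correlation, a double-valued convolution reduces to two independent real convolutions in the diagonal basis u1 = Re + Db, u2 = Re - Db; _J1J2ConvImpl works in that basis directly, while _ConvImpl wraps the same two products with the change of basis.

import numpy as np

def ccor(x, w):
    # toy stand-in for the real cross-correlation: a dot product over channels
    return float(np.dot(x, w))

rng = np.random.default_rng(4)
real, double = rng.standard_normal(8), rng.standard_normal(8)   # input, regular form
k_r, k_d = rng.standard_normal(8), rng.standard_normal(8)       # kernel, regular form

# Regular-form reference (the formula in the _ConvImpl docstring)
ref_r = ccor(real, k_r) + ccor(double, k_d)
ref_d = ccor(real, k_d) + ccor(double, k_r)

# Diagonal route, as in _J1J2ConvImpl: one real "convolution" per component
u1, u2 = real + double, real - double           # diagonal form of the input
w1, w2 = k_r + k_d, k_r - k_d                   # diagonal form of the kernel
out1, out2 = ccor(u1, w1), ccor(u2, w2)

# Converting the diagonal result back recovers the regular-form reference
assert np.isclose((out1 + out2) / 2, ref_r) and np.isclose((out1 - out2) / 2, ref_d)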


@@ -0,0 +1,123 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double dense implementation"""
from typing import Callable, Tuple
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _BaseDenseImpl as BaseDenseImpl
class _DenseImpl(BaseDenseImpl):
r"""
The implementor class of the dense connected layer for double numbers in regular representation.
Applies double-valued matrix multiplication for dense connected layer. This layer implements the operation as:
.. math::
\begin{align}
\text{X(out)} = \text{X(inp)} * \text{X(kernel)} + \text{Y(inp)} * \text{Y(kernel)}\\
\text{Y(out)} = \text{X(inp)} * \text{Y(kernel)} + \text{Y(inp)} * \text{X(kernel)},
\end{align}
where :math:`inp` is the double input tensors, :math:`\text{kernel}` is a double weight matrix with the same
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a double bias vector with the same
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{X(...)}` and
:math:`\text{Y(...)}` are respectively the first and the second parts of the double-valued expression
inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and double parts of the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the double linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
- **double** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the double part of the
input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the double
parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
matmul_op: Callable,
real: Tensor,
double: Tensor) -> Tuple[Tensor, Tensor]:
u1 = real + double
u2 = real - double
out1 = matmul_op(u1, self.weight_x)
out2 = matmul_op(u2, self.weight_y)
out_r = out1 + out2
out_d = out1 - out2
return out_r, out_d
class _J1J2DenseImpl(BaseDenseImpl):
r"""
The implementor class of the dense connected layer for double numbers in the diagonal representation.
Applies double-valued matrix multiplication for dense connected layer. This layer implements the operation as:
.. math::
\begin{align}
\text{X(out)} = \text{X(inp)} * \text{X(kernel)}\\
\text{Y(out)} = \text{Y(inp)} * \text{Y(kernel)},
\end{align}
where :math:`inp` is the double input tensors in the diagonal form, :math:`\text{kernel}` is a double weight matrix
in the diagonal form with the same data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is
a double bias vector in the diagonal form with the same data type as the :math:`inp` created by the layer
(only if has_bias is True). :math:`\text{X(...)}` and :math:`\text{Y(...)}` are respectively the first and the
second parts of the double-valued expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of the first and the second parts of the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the double linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **u1** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the first part of the input.
- **u2** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the second part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the first and the second
parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
matmul_op: Callable,
u1: Tensor,
u2: Tensor) -> Tuple[Tensor, Tensor]:
out1 = matmul_op(u1, self.weight_x)
out2 = matmul_op(u2, self.weight_y)
return out1, out2
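The dense counterpart of the previous sketch, following the code path of _DenseImpl.construct with NumPy matrices (illustrative, not part of the patch). That weight_x and weight_y hold the halved diagonal components of the kernel is an assumption made only for this check; the actual weight construction lives in files not shown in this excerpt.

import numpy as np

rng = np.random.default_rng(5)
x_r, x_d = rng.standard_normal((3, 4)), rng.standard_normal((3, 4))   # input parts
k_r, k_d = rng.standard_normal((4, 5)), rng.standard_normal((4, 5))   # kernel parts

# Regular form of the double product (four real matmuls)
out_r = x_r @ k_r + x_d @ k_d
out_d = x_r @ k_d + x_d @ k_r

# The route of _DenseImpl.construct (two real matmuls), under the assumed weight layout
w_x, w_y = (k_r + k_d) / 2, (k_r - k_d) / 2
o1 = (x_r + x_d) @ w_x
o2 = (x_r - x_d) @ w_y
assert np.allclose(o1 + o2, out_r) and np.allclose(o1 - o2, out_d)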

File diff suppressed because it is too large


@@ -0,0 +1,77 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Double relu operators"""
from mindspore import nn, Tensor
from mindspore.ops import operations as P
from mindspore.hypercomplex.utils import get_x_and_y as get_u1_and_u2, \
to_2channel as to_double
class J1J2ReLU(nn.Cell):
r"""
Rectified Linear Unit activation function for double-valued input in the diagonal representation.
Applies ReLU activation layer for the double-valued input. This layer first converts the input to the regular form:
.. math::
\begin{align}
\text{Re(inp)} = 0.5 * (\text{X(inp)} + \text{Y(inp)})\\
\text{Db(inp)} = 0.5 * (\text{X(inp)} - \text{Y(inp)}),
\end{align}
then applies the element-wise :math:`\max(0, x)` for both real and double parts of the input tensor independently:
.. math::
\begin{align}
\text{Re(out)} = (Re(inp))^+ = \max(0, Re(inp))\\
\text{Db(out)} = (Db(inp))^+ = \max(0, Db(inp)),
\end{align}
and finally transfers the result back to the diagonal representation:
.. math::
\begin{align}
\text{X(out)} = \text{Re(out)} + \text{Db(out)}\\
\text{Y(out)} = \text{Re(out)} - \text{Db(out)}
\end{align}
Inputs:
- **inp** (Tensor) - The input of ReLU is a Tensor of shape (2, *, ..., *). The data type is
`number <https://www.mindspore.cn/docs/en/master/api_python/mindspore.html#mindspore.dtype>`_ .
Outputs:
Tensor, with the same type and shape as the `inp`.
Raises:
TypeError: If dtype of `inp` is not a number.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self):
"""Initialize J1J2ReLU."""
super(J1J2ReLU, self).__init__()
self.relu = P.ReLU()
def construct(self, u: Tensor) -> Tensor:
u = u / 2
u1, u2 = get_u1_and_u2(u)
x = self.relu(u1 + u2)
y = self.relu(u1 - u2)
out1 = x + y
out2 = x - y
out = to_double(out1, out2)
return out
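The J1J2ReLU data path above in plain NumPy (illustrative, not part of the patch): halve, change to the regular basis, apply the split ReLU, change back. The check confirms that in regular coordinates the layer acts as an ordinary split ReLU.

import numpy as np

def relu(t):
    return np.maximum(t, 0.0)

x = np.array([1.0, -2.0, 3.0])    # first diagonal component X(inp)
y = np.array([-0.5, 1.5, -4.0])   # second diagonal component Y(inp)

# The construct: halve, recombine into (Re, Db), apply ReLU, convert back to (X, Y)
u1, u2 = x / 2, y / 2
re, db = relu(u1 + u2), relu(u1 - u2)
out_x, out_y = re + db, re - db

# Claimed behaviour: in regular coordinates the layer is an ordinary split ReLU
inp_re, inp_db = (x + y) / 2, (x - y) / 2
assert np.allclose((out_x + out_y) / 2, relu(inp_re))
assert np.allclose((out_x - out_y) / 2, relu(inp_db))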


@@ -0,0 +1,24 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dual operators"""
from mindspore.nn import ReLU
from mindspore.hypercomplex.dual.dual_operators import Conv1d, Conv2d, Conv3d
from mindspore.hypercomplex.dual.dual_operators import BatchNorm1d, BatchNorm2d, BatchNorm3d
from mindspore.hypercomplex.dual.dual_operators import Dense
from mindspore.hypercomplex.hypercomplex.hc_pool import MaxPool1d, MaxPool2d, \
AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, \
AdaptiveAvgPool3d, AdaptiveMaxPool1d, AdaptiveMaxPool2d


@@ -0,0 +1,140 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dual BatchNorm Implementation"""
from typing import Tuple
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BaseBatchNormImpl as HCBatchNormImpl
from mindspore.hypercomplex.utils import get_x_and_y as get_real_and_dual
class _BatchNormImpl(HCBatchNormImpl):
r"""
The implementor class of the Batch Normalization layer for dual numbers.
Implements the functionality specific to dual numbers and needed by the 'BatchNorm' class. This includes:
getting the norm of a dual number, applying scaling and shift to a dual tensor, and updating the running
mean and variance, which are used during inference.
Args:
affine (bool): A bool value. When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
use the specified mean and variance values. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
num_features (int): The number of features in the input space.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def scale_and_shift(self,
u_x: Tensor,
u_y: Tensor,
scale_x: Tensor,
scale_y: Tensor,
shift_x: Tensor,
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Applies dual scaling and shift to an input tensor.
This function implements the operation as:
.. math::
\begin{align}
\text{Re(out)} = \text{Re(inp)} * \text{Re(scale)} + \text{Re(shift)}\\
\text{Du(out)} = \text{Re(inp)} * \text{Du(scale)} + \text{Du(inp)} * \text{Re(scale)} + \text{Du(shift)},
\end{align}
where :math:`inp` is the dual input tensors, :math:`scale` and :math:`shift` are dual parameters representing
the scaling and shift coefficients respectively. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are
respectively real and dual parts of the dual-valued expression inside the parentheses.
Args:
u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
u_y (Tensor): A tensor of shape (C,), which represents the dual part of the normalized inputs.
scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
scale_y (Tensor): A tensor of shape (C,), which represents the dual part of the scaling vector.
shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
shift_y (Tensor): A tensor of shape (C,), which represents the dual part of the bias vector.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the dual parts of rescaled and
recentered inputs.
"""
out_x = u_x * scale_x + shift_x
out_y = u_x * scale_y + u_y * scale_x + shift_y
return out_x, out_y
def get_norm(self, u: Tensor) -> Tensor:
r"""
Calculates norm of dual elements of an input tensor.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \left|\frac{Du(inp)}{2}\right|+\sqrt{Re(inp)^2+\frac{Du(inp)^2}{4}+\delta},
where :math:`inp` is the dual input tensors and :math:`\delta` is a small positive constant, which is needed
to avoid division by zero in case statistical variance is close to zero. :math:`\text{Re(...)}` and
:math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside
the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the dual domain
and has a real and a dual part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
u_r, u_d = get_real_and_dual(u)
dual_half = u_d.abs() / 2
eps = 1e-7
sqrt = u_r ** 2 + dual_half ** 2 + eps
sqrt = ops.sqrt(sqrt)
out = dual_half + sqrt
return out
def get_square_norm(self, u: Tensor) -> Tensor:
r"""
Calculates element-wise squared norm of dual elements of an input tensor.
Norm is a non-negative real number that is a characteristic of 'magnitude' of that number, i.e. how far away it
is from zero. The function implements the operation as:
.. math::
\text{out} = \left(\left|\frac{Du(inp)}{2}\right|+\sqrt{Re(inp)^2+\frac{Du(inp)^2}{4}+\delta}\right)^2,
where :math:`inp` is the dual input tensors, :math:`\text{Re(...)}` and :math:`\text{Du(...)}`
are respectively real and dual parts of the dual-valued expression inside the parentheses.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the dual domain
and has a real and a dual part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
norm = self.get_norm(u)
out = norm ** 2
return out
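# For reference, a standalone NumPy sketch (made-up values) of the dual norm used above,
# |Du/2| + sqrt(Re ** 2 + Du ** 2 / 4 + delta), applied to an array whose first axis stacks
# the real and dual parts, mirroring get_real_and_dual.
import numpy as np

def dual_norm_demo(u, delta=1e-7):
    u_r, u_d = u[0], u[1]                 # real and dual parts
    dual_half = np.abs(u_d) / 2
    return dual_half + np.sqrt(u_r ** 2 + dual_half ** 2 + delta)

u_demo = np.stack([np.array([3.0, -4.0]), np.array([2.0, 2.0])])
norm_demo = dual_norm_demo(u_demo)        # what get_norm computes
square_norm_demo = norm_demo ** 2         # what get_square_norm computes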

View File

@ -0,0 +1,138 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dual Convolution Implementation"""
import numbers
from typing import Callable, Tuple, Union
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import Initializer
from mindspore.hypercomplex.hypercomplex._hc_conv_impl import _BaseConvImpl as BaseConvImpl
from mindspore import ops as P
class _ConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for dual numbers.
Applies dual-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})\\
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}),
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
data type as the :math:`inp` created by the layer. :math:`\text{Re(...)}` and :math:`\text{Du(...)}`
are respectively real and dual parts of the dual-valued expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and dual parts of the kernel.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the dual convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input.
- **dual** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the dual part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents the real and the dual parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
conv_op: Callable,
real: Tensor,
dual: Tensor) -> Tuple[Tensor, Tensor]:
out_r = conv_op(real, self.weight_x)
out_rd = conv_op(real, self.weight_y)
out_dr = conv_op(dual, self.weight_x)
out_d = out_rd + out_dr
return out_r, out_d
class _ReDuConvImpl(BaseConvImpl):
r"""
The implementor class of the convolution layer for dual numbers.
Applies dual-valued convolution transformation. This layer implements the operation as:
.. math::
\begin{align}
\text{inp_cat} = \text{cat}(\text{Re(inp)}, \text{Du(inp)}) \\
\text{K} = \text{cat}(\text{Du(kernel)}, \text{Re(kernel)}) \\
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)})\\
\text{Du(ccor)} = \text{ccor}(\text{K}, \text{inp_cat})
\end{align}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
data type as the :math:`inp` created by the layer, :math:`\text{cat}` is concatenation along the channel axis.
:math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression
inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and dual parts of the kernel.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used for decomposition
of the dual convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be passed
- **real** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input.
- **dual** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the dual part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`,
which represents the real and the dual parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
weight_init: Union[Tensor, str, Initializer, numbers.Number],
weight_shape: tuple,
**factory_kwargs) -> None:
super(_ReDuConvImpl, self).__init__(weight_init, weight_shape, **factory_kwargs)
data_format = factory_kwargs.get('data_format', 'nchw')
c_idx = data_format.lower().find('c')
if c_idx < 0:
raise ValueError(f"Data format {data_format} is unsupported")
self.concat = P.Concat(c_idx)
def construct(self,
conv_op: Callable,
real: Tensor,
imag: Tensor) -> Tuple[Tensor, Tensor]:
out_r = conv_op(real, self.weight_x)
inp = self.concat([real, imag])
w = self.concat([self.weight_y, self.weight_x])
out_d = conv_op(inp, w)
return out_r, out_d
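# Both implementor classes above yield the same dual part. A rough standalone sketch with
# single-channel 1D signals, np.correlate standing in for the real-valued ccor and a matmul
# standing in for the channel-concatenation trick (all names here are hypothetical):
import numpy as np

re_x, du_x = np.random.rand(8), np.random.rand(8)     # Re(inp), Du(inp)
w_re, w_du = np.random.rand(3), np.random.rand(3)     # Re(kernel), Du(kernel)

# _ConvImpl-style decomposition: one correlation for the real part, two for the dual part.
out_re = np.correlate(re_x, w_re, mode='valid')
out_du = np.correlate(re_x, w_du, mode='valid') + np.correlate(du_x, w_re, mode='valid')

# _ReDuConvImpl idea, written with block matrices: concatenating the two input parts and the
# swapped kernel parts along the contracted axis gives the same dual part in a single call.
a, b = np.random.rand(4, 5), np.random.rand(4, 5)         # Re(inp), Du(inp)
m_re, m_du = np.random.rand(5, 6), np.random.rand(5, 6)   # Re(kernel), Du(kernel)
one_call = np.concatenate([a, b], axis=1) @ np.concatenate([m_du, m_re], axis=0)
two_calls = a @ m_du + b @ m_re
assert np.allclose(one_call, two_calls)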

View File

@ -0,0 +1,69 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dual Dense Implementation"""
from typing import Callable, Tuple
from mindspore.common.tensor import Tensor
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _BaseDenseImpl as BaseDenseImpl
class _DenseImpl(BaseDenseImpl):
r"""
The implementor class of the dense connected layer for dual numbers.
Applies dual-valued matrix multiplication for dense connected layer. This layer implements the operation as:
.. math::
\begin{align}
\text{Re(out)} = \text{Re(inp)} * \text{Re(kernel)}\\
\text{Du(out)} = \text{Re(inp)} * \text{Du(kernel)} + \text{Du(inp)} * \text{Re(kernel)},
\end{align}
where :math:`inp` is the hypercomplex input tensors, :math:`\text{kernel}` is
a hypercomplex weight matrix with the same data type as the :math:`inp` created by the layer,
:math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued
expression inside the parentheses.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and dual parts of the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the dual linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **real** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
- **dual** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the dual part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the dual
parts of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def construct(self,
matmul_op: Callable,
real: Tensor,
dual: Tensor) -> Tuple[Tensor, Tensor]:
out_r = matmul_op(real, self.weight_x)
out_rd = matmul_op(real, self.weight_y)
out_dr = matmul_op(dual, self.weight_x)
out_d = out_rd + out_dr
return out_r, out_d
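# A standalone cross-check of the decomposition in construct (shapes are arbitrary): a dual
# matrix X_re + epsilon * X_du can be represented by the block matrix [[X_re, X_du], [0, X_re]],
# and the block product reproduces Re(out) = X_re @ W_re and Du(out) = X_re @ W_du + X_du @ W_re.
import numpy as np

x_re, x_du = np.random.rand(4, 5), np.random.rand(4, 5)
w_re, w_du = np.random.rand(5, 7), np.random.rand(5, 7)

out_re = x_re @ w_re                    # Re(out)
out_du = x_re @ w_du + x_du @ w_re      # Du(out); the Du * Du term carries epsilon ** 2 = 0

block_x = np.block([[x_re, x_du], [np.zeros((4, 5)), x_re]])
block_w = np.block([[w_re, w_du], [np.zeros((5, 7)), w_re]])
prod = block_x @ block_w
assert np.allclose(prod[:4, :7], out_re)
assert np.allclose(prod[:4, 7:], out_du)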

View File

@ -0,0 +1,958 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dual Operators"""
import numbers
from typing import Union
from mindspore.common.initializer import Initializer
from mindspore.common.tensor import Tensor
# Batch Normalization
from mindspore.hypercomplex.hypercomplex.hc_bn import BatchNorm1d as HBatchNorm1d, \
BatchNorm2d as HBatchNorm2d, BatchNorm3d as HBatchNorm3d
from mindspore.hypercomplex.dual._dual_bn_impl import _BatchNormImpl as BatchNormImpl
# Convolution
from mindspore.hypercomplex.hypercomplex.hc_conv import Conv1d as HConv1d, \
Conv2d as HConv2d, Conv3d as HConv3d
from mindspore.hypercomplex.dual._dual_conv_impl import _ReDuConvImpl as ConvImpl
# Dense
from mindspore.hypercomplex.hypercomplex.hc_dense import Dense as HDense
from mindspore.hypercomplex.dual._dual_dense_impl import _DenseImpl as DenseImpl
from mindspore.hypercomplex.hypercomplex.uniform_operator import _UniformOperator
from mindspore.hypercomplex.utils import _size_1_t, _size_2_t, _size_3_t
class Conv2d(_UniformOperator):
r"""
2D convolution layer on the dual-valued input.
Calculates the 2D convolution on the input tensor which is typically of shape
:math:`(2, N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size, :math:`C_{in}` is a number of channels,
:math:`H_{in}, W_{in}` are the height and width of the feature layer respectively. The formula is defined as:
.. math::
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
\sum_{k = 0}^{C_{in} - 1} \text{hccor}({\text{weight}(C_{\text{out}_j}, k), \text{inp}(N_i, k)})
where :math:`C_{in}` is the channel number of the input, :math:`out_{j}` corresponds to the jth channel of
the output and :math:`j` is in the range of :math:`[0, C_{out}-1]`. :math:`\text{weight}(C_{\text{out}_j}, k)`
is a convolution kernel slice with shape :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`,
where :math:`\text{kernel_size[0]}` and :math:`\text{kernel_size[1]}` are the height and width of the convolution
kernel respectively. :math:`\text{bias}` is the bias parameter and :math:`\text{inp}` is the input tensor.
In this case, `data_format` of the input tensor is 'NCHW' and the shape of full convolution kernel is
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]})`,
where `group` is the number of groups to split the input `inp` in the channel dimension. If `data_format` of the
input tensor is 'NHWC', the shape of full convolution kernel will be
:math:`(C_{out}, \text{kernel_size[0]}, \text{kernel_size[1]}, C_{in} / \text{group})`.
:math:`hccor` is the dual-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_.
This implies the operation as:
.. math::
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)}) + \text{Re(bias)}\\
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}) + \text{Du(bias)}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}`
and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside
the parentheses.
For more details, please refer to the paper `Gradient Based Learning Applied to Document
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Note:
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
That is, when `group > 1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
'NHWC' format is supported only with the GPU target device as of now.
Args:
in_channels (int): The channel number of the input tensor of the Conv2d layer.
out_channels (int): The channel number of the output tensor of the Conv2d layer.
kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution kernel.
The data type is an integer or a tuple of two integers. An integer represents the height
and width of the convolution kernel. A tuple of two integers represents the height
and width of the convolution kernel respectively.
stride (Union[int, tuple[int]]): The movement stride of the 2D convolution kernel.
The data type is an integer or a tuple of two integers. An integer represents the movement step size
in both height and width directions. A tuple of two integers represents the movement step size in
the height and width directions respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: The width of the output is the same as the value of the input divided by `stride`.
If this mode is set, the value of `padding` must be 0.
- valid: Returns a valid calculated output without padding. Excess pixels that do not satisfy the
calculation will be discarded. If this mode is set, the value of `padding` must be 0.
- pad: Pads the input. Padding `padding` size of zero on both sides of the input.
If this mode is set, the value of `padding` must be greater than or equal to 0.
padding (Union[int, tuple[int]]): The number of padding on the height and width directions of the input.
The data type is an integer or a tuple of four integers. If `padding` is an integer,
then the top, bottom, left, and right padding are all equal to `padding`.
If `padding` is a tuple of 4 integers, then the top, bottom, left, and right padding
is equal to `padding[0]`, `padding[1]`, `padding[2]`, and `padding[3]` respectively.
The value should be greater than or equal to 0. Default: 0.
dilation (Union[int, tuple[int]]): Dilation size of 2D convolution kernel.
The data type is an integer or a tuple of two integers. If :math:`k > 1`, the kernel is sampled
every `k` elements. The value of `k` on the height and width directions is in range of [1, H]
and [1, W] respectively. Default: 1.
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
divisible by `group`. If the group is equal to `in_channels` and `out_channels`,
this 2D convolution layer also can be called 2D depthwise convolution layer. Default: 1.
has_bias (bool): Whether the Conv2d layer has a bias parameter. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of weight parameter.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of bias parameter.
Available initialization methods are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C_{in}, H_{in}, W_{in})` \
or :math:`(2, N, H_{in}, W_{in}, C_{in})`.
Outputs:
Tensor of shape :math:`(2, N, C_{out}, H_{out}, W_{out})` or :math:`(2, N, H_{out}, W_{out}, C_{out})`.
pad_mode is 'same':
.. math::
\begin{array}{ll} \\
H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[0]}}} \right \rceil \\
W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[1]}}} \right \rceil \\
\end{array}
pad_mode is 'valid':
.. math::
\begin{array}{ll} \\
H_{out} = \left \lceil{\frac{H_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
{\text{stride[0]}}} \right \rceil \\
W_{out} = \left \lceil{\frac{W_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
{\text{stride[1]}}} \right \rceil \\
\end{array}
pad_mode is 'pad':
.. math::
\begin{array}{ll} \\
H_{out} = \left \lfloor{\frac{H_{in} + padding[0] + padding[1] - (\text{kernel_size[0]} - 1) \times
\text{dilation[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
W_{out} = \left \lfloor{\frac{W_{in} + padding[2] + padding[3] - (\text{kernel_size[1]} - 1) \times
\text{dilation[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
\end{array}
Raises:
TypeError: If `in_channels`, `out_channels` or `group` is not an int.
TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
ValueError: If `padding` is less than 0.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
ValueError: If `padding` is a tuple whose length is not equal to 4.
ValueError: If `pad_mode` is not equal to 'pad' and `padding` is not equal to (0, 0, 0, 0).
ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import Conv2d
>>> from mindspore import Tensor
>>> w = Tensor(np.random.random((2, 128, 3, 7, 7)).astype(np.float32))
>>> b = Tensor(np.random.random((2, 128)).astype(np.float32))
>>> net = Conv2d(
...     in_channels=3, out_channels=128, kernel_size=7, stride=2, padding=3,
...     pad_mode='pad', weight_init=w, bias_init=b, has_bias=True
... )
>>> inp = Tensor(np.random.random((2, 16, 3, 224, 224)).astype(np.float32))
>>> out = net(inp)
>>> print(out.shape)
(2, 16, 128, 112, 112)
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
pad_mode: str = 'same',
padding: _size_2_t = 0,
dilation: _size_2_t = 1,
group: int = 1,
has_bias: bool = False,
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
data_format: str = 'NCHW') -> None:
super(Conv2d, self).__init__(HConv2d,
ConvImpl,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init,
data_format=data_format)
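# A quick arithmetic check of the 'pad' output-size formula above against the values used in the
# Conv2d example (H_in = W_in = 224, kernel_size = 7, stride = 2, padding = 3 on every side,
# dilation = 1):
import math
h_in, k, s, d, pad = 224, 7, 2, 1, 3
h_out = math.floor((h_in + pad + pad - (k - 1) * d - 1) / s + 1)
# h_out == 112, matching the (2, 16, 128, 112, 112) shape printed in the example.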
class Conv1d(_UniformOperator):
r"""
1D convolution layer on the dual-valued input.
Calculates the 1D convolution on the input tensor which is typically of shape :math:`(2, N, C_{in}, L_{in})`,
where :math:`N` is batch size, :math:`C_{in}` is a number of channels and :math:`L_{in}` is a length of sequence.
The formula is defined as:
.. math::
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
\sum_{k = 0}^{C_{in} - 1} \text{hccor}({\text{weight}(C_{\text{out}_j}, k), \text{inp}(N_i, k)})
where :math:`C_{in}` is the channel number of the input, :math:`out_{j}` corresponds to the jth channel of
the output and :math:`j` is in the range of :math:`[0, C_{out}-1]`. :math:`\text{weight}(C_{\text{out}_j}, k)`
is a convolution kernel slice with shape :math:`\text{kernel_size}`, where :math:`\text{kernel_size}`
is the width of the convolution kernel. :math:`\text{bias}` is the bias parameter,
and :math:`\text{inp}` is the input tensor. The shape of full convolution kernel is
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size})`,
where `group` is the number of groups to split the input `inp` in the channel dimension.
:math:`hccor` is the dual-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_.
This implies the operation as:
.. math::
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)}) + \text{Re(bias)}\\
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}) + \text{Du(bias)}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}` and
:math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside the parentheses.
For more details, please refer to the paper `Gradient Based Learning Applied to Document
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Note:
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
That is, when `group > 1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
Args:
in_channels (int): The channel number of the input tensor of the Conv1d layer.
out_channels (int): The channel number of the output tensor of the Conv1d layer.
kernel_size (int): Specifies the width of the 1D convolution kernel.
stride (int): The movement stride of the 1D convolution kernel. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: The width of the output is the same as the value of the input divided by `stride`.
If this mode is set, the value of `padding` must be 0.
- valid: Returns a valid calculated output without padding. Excess pixels that do not satisfy the
calculation will be discarded. If this mode is set, the value of `padding` must be 0.
- pad: Pads the input. Padding `padding` size of zero on both sides of the input.
If this mode is set, the value of `padding` must be greater than or equal to 0.
padding (int): The number of padding on both sides of input.
The value should be greater than or equal to 0. Default: 0.
dilation (int): Dilation size of 1D convolution kernel. If :math:`k > 1`, the kernel is sampled
every `k` elements. The value of `k` is in range of [1, L]. Default: 1.
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
divisible by `group`. Default: 1.
has_bias (bool): Whether the Conv1d layer has a bias parameter. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of weight parameter.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of bias parameter.
Available initialization methods are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C_{in}, L_{in})`.
Outputs:
Tensor of shape :math:`(2, N, C_{out}, L_{out})`.
pad_mode is 'same':
.. math::
L_{out} = \left \lceil{\frac{L_{in}}{\text{stride}}} \right \rceil
pad_mode is 'valid':
.. math::
L_{out} = \left \lceil{\frac{L_{in} - \text{dilation} \times (\text{kernel_size} - 1) }
{\text{stride}}} \right \rceil
pad_mode is 'pad':
.. math::
L_{out} = \left \lfloor{\frac{L_{in} + 2 \times padding - (\text{kernel_size} - 1) \times
\text{dilation} - 1 }{\text{stride}} + 1} \right \rfloor
Raises:
TypeError: If `in_channels`, `out_channels`, `kernel_size`, `stride`, `padding` or `dilation` is not an int.
ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
ValueError: If `padding` is less than 0.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import Conv1d
>>> from mindspore import Tensor
>>> w = Tensor(np.random.random((2, 16, 1, 6)).astype(np.float32))
>>> b = Tensor(np.random.random((2, 16)).astype(np.float32))
>>> net = Conv1d(
...     in_channels=1, out_channels=16, kernel_size=6, stride=2, padding=2,
...     pad_mode='pad', weight_init=w, bias_init=b, has_bias=True
... )
>>> u = Tensor(np.random.random((2, 8, 1, 4096)).astype(np.float32))
>>> out = net(u)
>>> print(out.shape)
(2, 8, 16, 2048)
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
pad_mode: str = 'same',
padding: _size_1_t = 0,
dilation: _size_1_t = 1,
group: int = 1,
has_bias: bool = False,
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros') -> None:
super(Conv1d, self).__init__(HConv1d,
ConvImpl,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
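# The analogous check for the Conv1d example above (L_in = 4096, kernel_size = 6, stride = 2,
# padding = 2, dilation = 1, pad_mode = 'pad'):
import math
l_in, k, s, d, pad = 4096, 6, 2, 1, 2
l_out = math.floor((l_in + 2 * pad - (k - 1) * d - 1) / s + 1)
# l_out == 2048, matching the (2, 8, 16, 2048) output shape in the example.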
class Conv3d(_UniformOperator):
r"""
3D convolution layer on the dual-valued input.
Calculates the 3D convolution on the input tensor which is typically of shape
:math:`(2, N, C_{in}, D_{in}, H_{in}, W_{in})`,
where :math:`N` is batch size, :math:`C_{in}` is a number of channels,
:math:`D_{in}, H_{in}, W_{in}` are the depth, height and width of the feature layer respectively.
The formula is defined as:
.. math::
\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
\sum_{k = 0}^{C_{in} - 1} \text{hccor}({\text{weight}(C_{\text{out}_j}, k), \text{inp}(N_i, k)})
where :math:`C_{in}` is the channel number of the input, :math:`out_{j}` corresponds to the jth channel of
the output and :math:`j` is in the range of :math:`[0, C_{out}-1]`. :math:`\text{weight}(C_{\text{out}_j}, k)`
is a convolution kernel slice with shape
:math:`(\text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})`,
where :math:`\text{kernel_size[0]}`, :math:`\text{kernel_size[1]}` and :math:`\text{kernel_size[2]}` are
the depth, height and width of the convolution kernel respectively. :math:`\text{bias}` is the bias parameter
and :math:`\text{inp}` is the input tensor.
The shape of full convolution kernel is
:math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]}, \text{kernel_size[2]})`,
where `group` is the number of groups to split the input `x` in the channel dimension.
:math:`hccor` is the dual-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_.
This implies the operation as:
.. math::
\text{Re(ccor)} = \text{ccor}(\text{Re(kernel)}, \text{Re(inp)}) + \text{Re(bias)}\\
\text{Du(ccor)} = \text{ccor}(\text{Du(kernel)}, \text{Re(inp)})
+ \text{ccor}(\text{Re(kernel)}, \text{Du(inp)}) + \text{Du(bias)}
where :math:`ccor` is the real-valued `cross-correlation <https://en.wikipedia.org/wiki/Cross-correlation>`_,
:math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}`
and :math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside
the parentheses.
For more details, please refer to the paper `Gradient Based Learning Applied to Document
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Note:
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
That is, when `group>1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
Args:
in_channels (int): The channel number of the input tensor of the Conv3d layer.
out_channels (int): The channel number of the output tensor of the Conv3d layer.
kernel_size (Union[int, tuple[int]]): Specifies the depth, height and width of the 3D convolution kernel.
The data type is an integer or a tuple of three integers. An integer represents the depth, height
and width of the convolution kernel. A tuple of three integers represents the depth, height
and width of the convolution kernel respectively.
stride (Union[int, tuple[int]]): The movement stride of the 3D convolution kernel.
The data type is an integer or a tuple of three integers. An integer represents the movement step size
in depth, height and width directions. A tuple of three integers represents the movement step size
in the depth, height and width directions respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: The width of the output is the same as the value of the input divided by `stride`.
If this mode is set, the value of `padding` must be 0.
- valid: Returns a valid calculated output without padding. Excess pixels that do not satisfy the
calculation will be discarded. If this mode is set, the value of `padding` must be 0.
- pad: Pads the input. Padding `padding` size of zero on both sides of the input.
If this mode is set, the value of `padding` must be greater than or equal to 0.
padding (Union[int, tuple[int]]): The number of padding on the depth, height and width directions of the input.
The data type is an integer or a tuple of six integers. If `padding` is an integer,
then the head, tail, top, bottom, left, and right padding are all equal to `padding`.
If `padding` is a tuple of six integers, then the head, tail, top, bottom, left, and right padding
is equal to `padding[0]`, `padding[1]`, `padding[2]`, `padding[3]`, `padding[4]` and `padding[5]`
respectively. The value should be greater than or equal to 0. Default: 0.
dilation (Union[int, tuple[int]]): Dilation size of 3D convolution kernel.
The data type is an integer or a tuple of three integers. If :math:`k > 1`, the kernel is sampled
every `k` elements. The value of `k` on the depth, height and width directions is in range of
[1, D], [1, H] and [1, W] respectively. Default: 1.
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
divisible by `group`. Default: 1. Only 1 is currently supported.
has_bias (bool): Whether the Conv3d layer has a bias parameter. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of weight parameter.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initialization method of bias parameter.
Available initialization methods are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
data_format (str): The optional value for data format. Currently only "NCDHW" is supported.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C_{in}, D_{in}, H_{in}, W_{in})`.
Currently the input data type only supports float16 and float32.
Outputs:
Tensor of shape :math:`(2, N, C_{out}, D_{out}, H_{out}, W_{out})`.
pad_mode is 'same':
.. math::
\begin{array}{ll} \\
D_{out} = \left \lceil{\frac{D_{in}}{\text{stride[0]}}} \right \rceil \\
H_{out} = \left \lceil{\frac{H_{in}}{\text{stride[1]}}} \right \rceil \\
W_{out} = \left \lceil{\frac{W_{in}}{\text{stride[2]}}} \right \rceil \\
\end{array}
pad_mode is 'valid':
.. math::
\begin{array}{ll} \\
D_{out} = \left \lfloor{\frac{D_{in} - \text{dilation[0]} \times (\text{kernel_size[0]} - 1) }
{\text{stride[0]}} + 1} \right \rfloor \\
H_{out} = \left \lfloor{\frac{H_{in} - \text{dilation[1]} \times (\text{kernel_size[1]} - 1) }
{\text{stride[1]}} + 1} \right \rfloor \\
W_{out} = \left \lfloor{\frac{W_{in} - \text{dilation[2]} \times (\text{kernel_size[2]} - 1) }
{\text{stride[2]}} + 1} \right \rfloor \\
\end{array}
pad_mode is 'pad':
.. math::
\begin{array}{ll} \\
D_{out} = \left \lfloor{\frac{D_{in} + padding[0] + padding[1] - (\text{kernel_size[0]} - 1) \times
\text{dilation[0]} - 1 }{\text{stride[0]}} + 1} \right \rfloor \\
H_{out} = \left \lfloor{\frac{H_{in} + padding[2] + padding[3] - (\text{kernel_size[1]} - 1) \times
\text{dilation[1]} - 1 }{\text{stride[1]}} + 1} \right \rfloor \\
W_{out} = \left \lfloor{\frac{W_{in} + padding[4] + padding[5] - (\text{kernel_size[2]} - 1) \times
\text{dilation[2]} - 1 }{\text{stride[2]}} + 1} \right \rfloor \\
\end{array}
Raises:
TypeError: If `in_channels`, `out_channels` or `group` is not an int.
TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
ValueError: If `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
ValueError: If `padding` is less than 0.
ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
ValueError: If `padding` is a tuple whose length is not equal to 6.
ValueError: If `pad_mode` is not equal to 'pad' and `padding` is not equal to (0, 0, 0, 0, 0, 0).
ValueError: If `data_format` is not 'NCDHW'.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import Conv3d
>>> from mindspore import Tensor
>>> w = Tensor(np.random.random((2, 128, 3, 3, 3, 3)).astype(np.float32))
>>> b = Tensor(np.random.random((2, 128)).astype(np.float32))
>>> net = Conv3d(
...     in_channels=3, out_channels=128, kernel_size=3, stride=1, padding=1,
...     pad_mode='pad', weight_init=w, bias_init=b, has_bias=True
... )
>>> u = Tensor(np.random.random((2, 64, 3, 32, 32, 32)).astype(np.float32))
>>> out = net(u)
>>> print(out.shape)
(2, 64, 128, 32, 32, 32)
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: _size_3_t,
stride: _size_3_t = 1,
pad_mode: str = 'same',
padding: _size_3_t = 0,
dilation: _size_3_t = 1,
group: int = 1,
has_bias: bool = False,
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
data_format: str = 'NCDHW') -> None:
super(Conv3d, self).__init__(HConv3d,
ConvImpl,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init,
data_format=data_format)
class BatchNorm1d(_UniformOperator):
r"""
The dual-valued Batch Normalization layer over a second-order dual input of four dimensions, including
one spatial dimension, or three dimensions.
This layer applies Batch Normalization over a dual input to reduce internal covariate shift.
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}\\
\text{Re(\hat{out})} = \text{Re(out)} * \text{Re(\gamma)} + \text{Re(\beta)}\\
\text{Du(\hat{out})} = \text{Re(out)} * \text{Du(\gamma)}
+ \text{Du(out)} * \text{Re(\gamma)} + \text{Du(\beta)},
\end{align}
where :math:`inp` is the dual input tensors, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor
over the batch dimension, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over spatial
dimensions, based on the dual norm :math:`\|x+\epsilon y\|=\left|\frac{y}{2}\right|+\sqrt{x^2+\frac{y^2}{4}}`,
:math:`\gamma` and :math:`\beta` are dual learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real
and dual parts of the dual-valued expression inside the parentheses.
Args:
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, W)` or :math:`(2, N, C)`.
'2' denotes that the input tensor belongs to the dual domain and has a real and
a dual part. The `num_features` in `Args` has to be equal to :math:`C` in `Inputs`.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same shape as :math:`u`:
:math:`(2, N, C, W)` or :math:`(2, N, C)`
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
ValueError: if 'inp' is not a Tensor of 3 or 4 dimensions.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import BatchNorm1d
>>> from mindspore import Tensor
>>> u = Tensor(np.random.random((2, 8, 64, 32)).astype(np.float32))
>>> bn = BatchNorm1d(64)
>>> y = bn(u)
>>> print(y.shape)
(2, 8, 64, 32)
"""
def __init__(self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.9,
affine: bool = True,
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
use_batch_statistics: bool = True) -> None:
super(BatchNorm1d, self).__init__(HBatchNorm1d,
BatchNormImpl,
num_features=num_features,
eps=eps,
momentum=momentum,
affine=affine,
gamma_init=gamma_init,
beta_init=beta_init,
moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init,
use_batch_statistics=use_batch_statistics)
class BatchNorm2d(_UniformOperator):
r"""
The dual-valued Batch Normalization layer over a second-order dual input of five dimensions, including
two spatial dimensions.
This layer applies Batch Normalization over a dual input to reduce internal covariate shift.
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}\\
\text{Re(\hat{out})} = \text{Re(out)} * \text{Re(\gamma)} + \text{Re(\beta)}\\
\text{Du(\hat{out})} = \text{Re(out)} * \text{Du(\gamma)}
+ \text{Du(out)} * \text{Re(\gamma)} + \text{Du(\beta)},
\end{align}
where :math:`inp` is the dual input tensors, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor
over the batch dimension, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over spatial
dimensions, based on the dual norm :math:`\|x+\epsilon y\|=\left|\frac{y}{2}\right|+\sqrt{x^2+\frac{y^2}{4}}`,
:math:`\gamma` and :math:`\beta` are dual learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real
and dual parts of the dual-valued expression inside the parentheses.
Args:
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: 'NCHW'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, H, W)` if data_format is 'NCHW', or
:math:`(2, N, H, W, C)` if data_format is 'NHWC'. '2' denotes that the input tensor belongs to the dual
domain and has a real and a dual part. The `num_features` in `Args` has to be equal to
:math:`C` in `Inputs`.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same shape as :math:`u`:
:math:`(2, N, C, H, W)` if data_format is 'NCHW', or :math:`(2, N, H, W, C)` if data_format is 'NHWC'.
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.
ValueError: if 'inp' is not a Tensor of 5 dimensions.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import BatchNorm2d
>>> from mindspore import Tensor
>>> u = Tensor(np.random.random((2, 8, 64, 32, 32)).astype(np.float32))
>>> bn = BatchNorm2d(64)
>>> y = bn(u)
>>> print(y.shape)
(2, 8, 64, 32, 32)
"""
def __init__(self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.9,
affine: bool = True,
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
use_batch_statistics: bool = True,
data_format='NCHW') -> None:
super(BatchNorm2d, self).__init__(HBatchNorm2d,
BatchNormImpl,
num_features=num_features,
eps=eps,
momentum=momentum,
affine=affine,
gamma_init=gamma_init,
beta_init=beta_init,
moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init,
use_batch_statistics=use_batch_statistics,
data_format=data_format)
class BatchNorm3d(_UniformOperator):
r"""
The dual-valued Batch Normalization layer over a second-order dual input of six dimensions, including
three spatial dimensions.
This layer applies Batch Normalization over a dual input to reduce internal covariate shift.
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}\\
\text{Re(\hat{out})} = \text{Re(out)} * \text{Re(\gamma)} + \text{Re(\beta)}\\
\text{Du(\hat{out})} = \text{Re(out)} * \text{Du(\gamma)}
+ \text{Du(out)} * \text{Re(\gamma)} + \text{Du(\beta)},
\end{align}
where :math:`inp` is the dual input tensors, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor
over the batch dimension, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over spatial
dimensions, based on the dual norm :math:`\|x+\epsilon y\|=\left|\frac{y}{2}\right|+\sqrt{x^2+\frac{y^2}{4}}`,
:math:`\gamma` and :math:`\beta` are dual learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero. :math:`\text{Re(...)}` and :math:`\text{Du(...)}` are respectively real
and dual parts of the dual-valued expression inside the parentheses.
Args:
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
data_format (str): The optional value for data format. Only 'NCDHW' format is supported as of now.
Default: 'NCDHW'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, D, H, W)`. '2' denotes that the input tensor belongs
to the dual domain and has a real and a dual part. The `num_features` in `Args` has to be equal
to :math:`C` in `Inputs`.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same shape :math:`(2, N, C, D, H, W)` as :math:`u`.
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
ValueError: If `data_format` is not 'NCDHW'.
ValueError: if 'inp' is not a Tensor of 6 dimensions.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import BatchNorm3d
>>> from mindspore import Tensor
>>> u = Tensor(np.random.random((2, 8, 64, 32, 32, 32)).astype(np.float32))
>>> bn = BatchNorm3d(64)
>>> y = bn(u)
>>> print(y.shape)
(2, 8, 64, 32, 32, 32)
"""
def __init__(self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.9,
affine: bool = True,
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
use_batch_statistics: bool = True,
data_format='NCDHW') -> None:
super(BatchNorm3d, self).__init__(HBatchNorm3d,
BatchNormImpl,
num_features=num_features,
eps=eps,
momentum=momentum,
affine=affine,
gamma_init=gamma_init,
beta_init=beta_init,
moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init,
use_batch_statistics=use_batch_statistics,
data_format=data_format)
class Dense(_UniformOperator):
r"""
The dual-valued dense connected layer.
Applies dense connected layer for the dual-valued input. This layer implements the operation as:
.. math::
\begin{align}
\text{Re(out)} = \text{Re(inp)} * \text{Re(kernel)} + \text{Re(bias)}\\
\text{Du(out)} = \text{Re(inp)} * \text{Du(kernel)} + \text{Du(inp)} * \text{Re(kernel)} + \text{Du(bias)},
\end{align}
where :math:`inp` is the dual input tensors, :math:`\text{kernel}` is a dual weight matrix with the same
data type as the :math:`inp` created by the layer, and :math:`\text{bias}` is a dual bias vector with the same
data type as the :math:`inp` created by the layer (only if has_bias is True). :math:`\text{Re(...)}` and
:math:`\text{Du(...)}` are respectively real and dual parts of the dual-valued expression inside the parentheses.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as `inp`. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, *, ..., *, in\_channels)`. '2' denotes that the input tensor
belongs to the dual domain and has a real and a dual part. The `in_channels` in `Args`
has to be equal to :math:`in\_channels` in `Inputs`. The number of intermediate dimensions denoted by '*' is
arbitrary but must be at least one.
Outputs:
Tensor of shape :math:`(2, *, ..., *, out\_channels)`. The number of intermediate dimensions is the same as
in 'Inputs'.
Raises:
TypeError: If `in_channels` or `out_channels` is not an int.
TypeError: If `has_bias` is not a bool.
ValueError: If length of shape of `weight_init` is not equal to 3,
or shape[0] of 'weight_init' is not equal to 2,
or shape[1] of `weight_init` is not equal to `out_channels`,
or shape[2] of `weight_init` is not equal to `in_channels`.
ValueError: If length of shape of `bias_init` is not equal to 2,
or shape[0] of 'bias_init' is not equal to 2,
or shape[1] of `bias_init` is not equal to `out_channels`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import numpy as np
>>> from mindspore.hypercomplex.dual import Dense
>>> from mindspore import Tensor
>>> w = Tensor(np.random.random((2, 7, 5)).astype(np.float32))
>>> b = Tensor(np.random.random((2, 7)).astype(np.float32))
>>> net = Dense(in_channels=5, out_channels=7, weight_init=w, bias_init=b, has_bias=True)
>>> u = Tensor(np.random.random((2, 34, 1, 5)).astype(np.float32))
>>> out = net(u)
>>> print(out.shape)
(2, 34, 1, 7)
"""
def __init__(self,
in_channels: int,
out_channels: int,
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
has_bias: bool = True) -> None:
super(Dense, self).__init__(HDense,
DenseImpl,
in_channels=in_channels,
out_channels=out_channels,
weight_init=weight_init,
bias_init=bias_init,
has_bias=has_bias)
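# Every operator in this file consumes tensors whose leading axis of size 2 stacks the real part
# with the dual part. A hedged usage sketch of assembling such an input from two real-valued
# tensors, assuming mindspore.ops.stack behaves as in the public API; the shapes are arbitrary
# and chosen to match the Dense example above.
import numpy as np
from mindspore import Tensor, ops

real_part = Tensor(np.random.random((34, 1, 5)).astype(np.float32))
dual_part = Tensor(np.random.random((34, 1, 5)).astype(np.float32))
dual_input = ops.stack([real_part, dual_part], axis=0)    # shape (2, 34, 1, 5), as Dense expects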

View File

@ -0,0 +1,339 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hypercomplex BatchNorm Implementation"""
import numbers
from typing import Union, Tuple
from abc import abstractmethod
import mindspore.nn as nn
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
class _BatchNormImpl(nn.Cell):
r"""
The interface of the implementor part of batch normalization layer on the second-order hypercomplex numbers.
Defines the API for getting the norm of a hypercomplex number, applying scaling and shift to a hypercomplex tensor,
and updating the running mean and variance, which are used during inference. The API is used by the 'BatchNorm'
class, and it must be implemented separately for every hypercomplex algebra.
Args:
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
num_features (int): The number of features in the input space.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
gamma_init: Union[Tensor, str, Initializer, numbers.Number],
beta_init: Union[Tensor, str, Initializer, numbers.Number],
num_features: int) -> None:
super(_BatchNormImpl, self).__init__()
@abstractmethod
def get_norm(self, u: Tensor) -> Tensor:
r"""
Calculates the norm of the hypercomplex elements of an input tensor.
The norm is a non-negative real number that characterizes the 'magnitude' of a hypercomplex number, i.e. how far
it is from zero.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the hypercomplex
domain and has a real and a hypercomplex part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
@abstractmethod
def get_square_norm(self, u: Tensor) -> Tensor:
r"""
Calculates the element-wise squared norm of the hypercomplex elements of an input tensor.
The norm is a non-negative real number that characterizes the 'magnitude' of a hypercomplex number, i.e. how far
it is from zero.
Args:
u (Tensor): Tensor of shape (2, *, ..., *). '2' denotes that the input tensor belongs to the hypercomplex
domain and has a real and a hypercomplex part.
Returns:
Tensor of shape (*, ..., *). The count and size of dimensions of the output tensor are the same ones as in
the input tensor, but without the very first dimension because the output tensor is real-valued.
"""
@abstractmethod
def scale_and_shift(self,
u_x: Tensor,
u_y: Tensor,
scale_x: Tensor,
scale_y: Tensor,
shift_x: Tensor,
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Applies hypercomplex scaling and shift to an input tensor.
This function implements the operation as:
.. math::
\text{out} = \text{mul}(\text{inp}, \text{scale}) + \text{shift},
where :math:`inp` is the hypercomplex input tensors, :math:`\text{mul}` is the channel-wise scaling operation,
which depends on the type of the number system and is provided by subclasses, :math:`\text{scale}` is
a hypercomplex scaling vector with the same data type as the :math:`inp` created by the layer, and
:math:`\text{shift}` is a hypercomplex bias vector with the same data type as the :math:`inp` created by
the layer.
Args:
u_x (Tensor): A tensor of shape (C,), which represents the real part of the normalized inputs.
u_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the normalized inputs.
scale_x (Tensor): A tensor of shape (C,), which represents the real part of the scaling vector.
scale_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the scaling vector.
shift_x (Tensor): A tensor of shape (C,), which represents the real part of the bias vector.
shift_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the bias vector.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of rescaled and
recentered inputs.
"""
@abstractmethod
def calculate_bn(self,
u_centered_x: Tensor,
u_centered_y: Tensor,
sigma: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Given a hypercomplex centered input tensor and the standard deviation of its elements, computes the
corresponding rescaled and recentered tensor with normalized variance.
This function implements the operation as:
.. math::
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma}) + \text{shift},
where :math:`inp` is the hypercomplex input tensors centered over spatial and mini-batch dimensions,
:math:`\sigma` is standard deviation of the input tensors over the same dimensions, :math:`\text{mul}` is a
channel-wise scaling operation, which depends on the type of the number system and is provided by subclasses,
:math:`\text{scale}` is a hypercomplex scaling vector with the same data type as the :math:`inp` created
by the layer, and :math:`\text{shift}` is a hypercomplex bias vector with the same data type as the
:math:`inp` created by the layer.
Args:
u_centered_x (Tensor): A tensor of shape (C,), which represents the real part of the centered inputs.
u_centered_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the
centered inputs.
sigma (Tensor): A tensor of shape (C,), which represents the statistical standard deviation of the inputs.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of rescaled and
recentered normalized inputs.
"""
@abstractmethod
def calculate_infer_bn(self,
moving_mean_x: Tensor,
moving_mean_y: Tensor,
moving_sigma: Tensor,
u_x: Tensor,
u_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Given a hypercomplex input tensor, computes the corresponding rescaled and recentered normalized tensor.
This function is supposed to be used during inference. The mean and standard deviation are accumulated during
the training phase. The function implements the operation as:
.. math::
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma})
+ \left(\text{mul}(-\mathrm{E}[inp], \frac{\text{scale}}{\sigma})+\text{shift}\right),
where :math:`inp` is the hypercomplex input tensors, :math:`\sigma` is the accumulated standard deviation of
the input tensors over spatial and mini-batch dimensions, :math:`\mathrm{E}[inp]` is the accumulated arithmetic
mean of the input tensor over the same dimensions, :math:`\text{mul}` is a channel-wise scaling operation, which
depends on the type of the number system and is provided by subclasses, :math:`\text{scale}` is a hypercomplex
scaling vector with the same data type as the :math:`inp` created by the layer, and :math:`\text{shift}` is a
hypercomplex bias vector with the same data type as the :math:`inp` created by the layer.
Args:
moving_mean_x (Tensor): A tensor of shape (C,), which represents the real part of the accumulated
arithmetic mean of inputs.
moving_mean_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the accumulated
arithmetic mean of inputs.
moving_sigma (Tensor): A tensor of shape (C,), which represents the accumulated statistical standard
deviation of inputs.
u_x (Tensor): A tensor of shape (C,), which represents the real part of the input tensor.
u_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the input tensor.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of normalized,
rescaled and recentered inputs.
"""
class _BaseBatchNormImpl(_BatchNormImpl):
r"""
The base implementor part of the batch normalization layer for all the hypercomplex numbers of the second order.
Contains initialization and processing logic, which are shared by all specific implementations of the
'BatchNormImpl' interface for dual, double, and complex numbers.
Args:
affine (bool): A bool value. When set to True, gamma and beta can be learned.
use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc.
num_features (int): The number of features in the input space.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
affine: bool,
use_batch_statistics: bool,
gamma_init: Union[Tensor, str, Initializer, numbers.Number],
beta_init: Union[Tensor, str, Initializer, numbers.Number],
num_features: int) -> None:
super(_BaseBatchNormImpl, self).__init__(gamma_init,
beta_init,
num_features)
self.scale_x = Parameter(initializer(gamma_init, num_features), name="scale_x", requires_grad=affine)
self.scale_y = Parameter(initializer(gamma_init, num_features), name="scale_y", requires_grad=affine)
self.shift_x = Parameter(initializer(beta_init, num_features), name="shift_x", requires_grad=affine)
self.shift_y = Parameter(initializer(beta_init, num_features), name="shift_y", requires_grad=affine)
def calculate_infer_bn(self,
moving_mean_x: Tensor,
moving_mean_y: Tensor,
moving_sigma: Tensor,
u_x: Tensor,
u_y: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Given a hypercomplex input tensor, computes the corresponding rescaled and recentered normalized tensor.
This function is supposed to be used during inference. The mean and standard deviation are accumulated during
the training phase. The function implements the operation as:
.. math::
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma})
+ \left(\text{mul}(-\mathrm{E}[inp], \frac{\text{scale}}{\sigma})+\text{shift}\right),
where :math:`inp` is the hypercomplex input tensors, :math:`\sigma` is the accumulated standard deviation of
the input tensors over spatial and mini-batch dimensions, :math:`\mathrm{E}[inp]` is the accumulated arithmetic
mean of the input tensor over the same dimensions, :math:`\text{mul}` is a channel-wise scaling operation, which
depends on the type of the number system and is provided by subclasses, :math:`\text{scale}` is a hypercomplex
scaling vector with the same data type as the :math:`inp` created by the layer, and :math:`\text{shift}` is a
hypercomplex bias vector with the same data type as the :math:`inp` created by the layer.
Args:
moving_mean_x (Tensor): A tensor of shape (C,), which represents the real part of the accumulated
arithmetic mean of inputs.
moving_mean_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the accumulated
arithmetic mean of inputs.
moving_sigma (Tensor): A tensor of shape (C,), which represents the accumulated statistical standard
deviation of inputs.
u_x (Tensor): A tensor of shape (C,), which represents the real part of the input tensor.
u_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the input tensor.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of normalized,
rescaled and recentered inputs.
"""
fused_scale_x = self.scale_x / moving_sigma
fused_scale_y = self.scale_y / moving_sigma
neg_mean_x = (-1) * moving_mean_x
neg_mean_y = (-1) * moving_mean_y
fused_shift_x, fused_shift_y = self.scale_and_shift(neg_mean_x,
neg_mean_y,
fused_scale_x,
fused_scale_y,
self.shift_x,
self.shift_y)
out_x, out_y = self.scale_and_shift(u_x,
u_y,
fused_scale_x,
fused_scale_y,
fused_shift_x,
fused_shift_y)
return out_x, out_y
def calculate_bn(self,
u_centered_x: Tensor,
u_centered_y: Tensor,
sigma: Tensor) -> Tuple[Tensor, Tensor]:
r"""
Given a hypercomplex centered input tensor and the standard deviation of its elements, computes the
corresponding rescaled and recentered tensor with normalized variance.
This function implements the operation as:
.. math::
\text{out} = \text{mul}(\text{inp}, \frac{\text{scale}}{\sigma}) + \text{shift},
where :math:`inp` is the hypercomplex input tensors centered over spatial and mini-batch dimensions,
:math:`\sigma` is standard deviation of the input tensors over the same dimensions, :math:`\text{mul}` is a
channel-wise scaling operation, which depends on the type of the number system and is provided by subclasses,
:math:`\text{scale}` is a hypercomplex scaling vector with the same data type as the :math:`inp` created
by the layer, and :math:`\text{shift}` is a hypercomplex bias vector with the same data type as the
:math:`inp` created by the layer.
Args:
u_centered_x (Tensor): A tensor of shape (C,), which represents the real part of the centered inputs.
u_centered_y (Tensor): A tensor of shape (C,), which represents the hypercomplex part of the
centered inputs.
sigma (Tensor): A tensor of shape (C,), which represents the statistical standard deviation of the inputs.
Returns:
Tuple of two tensors of shape (C,), which contains the real and the hypercomplex parts of rescaled and
recentered normalized inputs.
"""
scale_x = self.scale_x / sigma
scale_y = self.scale_y / sigma
out_x, out_y = self.scale_and_shift(u_centered_x,
u_centered_y,
scale_x,
scale_y,
self.shift_x,
self.shift_y)
return out_x, out_y
@abstractmethod
def get_norm(self, u: Tensor) -> Tensor:
pass
@abstractmethod
def get_square_norm(self, u: Tensor) -> Tensor:
pass
@abstractmethod
def scale_and_shift(self,
u_x: Tensor,
u_y: Tensor,
scale_x: Tensor,
scale_y: Tensor,
shift_x: Tensor,
shift_y: Tensor) -> Tuple[Tensor, Tensor]:
pass
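# The following is a minimal sketch, for illustration only, of how a concrete implementor
# for the complex algebra could fill in the abstract methods above. It assumes the
# (2, *, ..., *) layout described in the docstrings and plain element-wise tensor
# arithmetic; the actual complex, dual and double implementors shipped with this patch
# live in their own modules and may differ in details.
class _SketchComplexBatchNormImpl(_BaseBatchNormImpl):
    """Illustrative complex-algebra implementor (not part of the public API)."""
    def get_square_norm(self, u: Tensor) -> Tensor:
        # |x + iy|^2 = x^2 + y^2, computed element-wise
        u_x, u_y = u[0], u[1]
        return u_x ** 2 + u_y ** 2

    def get_norm(self, u: Tensor) -> Tensor:
        # |x + iy| = sqrt(x^2 + y^2)
        return self.get_square_norm(u) ** 0.5

    def scale_and_shift(self,
                        u_x: Tensor,
                        u_y: Tensor,
                        scale_x: Tensor,
                        scale_y: Tensor,
                        shift_x: Tensor,
                        shift_y: Tensor) -> Tuple[Tensor, Tensor]:
        # Channel-wise complex multiplication followed by a complex shift:
        # (u_x + i*u_y) * (scale_x + i*scale_y) + (shift_x + i*shift_y)
        out_x = u_x * scale_x - u_y * scale_y + shift_x
        out_y = u_x * scale_y + u_y * scale_x + shift_y
        return out_x, out_y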

View File

@ -0,0 +1,123 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hypercomplex Convolution Implementation"""
import numbers
from typing import Callable, Union, Tuple
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.hypercomplex.utils import get_x_and_y
class _ConvImpl(nn.Cell):
r"""
The interface of the implementor part of convolution layer on second-order hypercomplex numbers.
Defines the API for unbiased convolution transformation, which is used by the '_ConvNd' class. The API must
be implemented separately for every hypercomplex algebra:
.. math::
\text{out} = \text{conv}(\text{inp}, \text{kernel})
where :math:`inp` is the hypercomplex input tensors, :math:`\text{conv}` is the convolution transformation
operation, which is provided by subclasses, :math:`\text{kernel}` is a hypercomplex weight matrix with the same
data type as the :math:`inp` created by the layer.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used with decomposition
of the hypercomplex convolution transformation. For example, mindspore.ops.operations.Conv2D(...)
may be passed for a 2D convolution.
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input. The exact shape depends on data format and the number of spatial
dimensions.
- **y** (Tensor) - Tensor of the same shape as `x`, which defines the hypercomplex part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`, which
represent the real and the hypercomplex parts of the output respectively. Data format and the count of spatial
dimensions are the same as in `x` and `y`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
weight_init: Union[Tensor, str, Initializer, numbers.Number],
weight_shape: tuple,
**factory_kwargs) -> None:
super(_ConvImpl, self).__init__()
def construct(self,
conv_op: Callable,
x: Tensor,
y: Tensor) -> Tuple[Tensor, Tensor]:
pass
class _BaseConvImpl(_ConvImpl):
r"""
The base implementor part of the convolution layer for all the hypercomplex numbers of the second order.
Contains initialization of the kernel tensors, which is shared by all specific implementations of the 'ConvImpl'
interface for dual, double, and complex numbers.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **conv_op** (Callable) - the function of the real-valued convolution to be used with decomposition
of the hypercomplex convolution transformation. For example, mindspore.ops.operations.Conv2D(...) may be
passed for a 2D convolution.
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, *, ..., *)` or :math:`(N, *, ..., *, C_{in})`,
which defines the real part of the input. The exact shape depends on data format and the number of spatial
dimensions.
- **y** (Tensor) - Tensor of the same shape as `x`, which defines the hypercomplex part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(N, C_{out}, *, ..., *)` or :math:`(N, *, ..., *, C_{out})`, which
represent the real and the hypercomplex parts of the output respectively. Data format and the count of spatial
dimensions are the same as in `x` and `y`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
weight_init: Union[Tensor, str, Initializer, numbers.Number],
weight_shape: tuple,
**factory_kwargs) -> None:
super(_BaseConvImpl, self).__init__(weight_init,
weight_shape,
**factory_kwargs)
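# A Tensor initializer carries both hypercomplex parts explicitly and is split with get_x_and_y;
# string or numeric initializers are applied to the real and the hypercomplex parts independently.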
if isinstance(weight_init, Tensor):
weight_init_x, weight_init_y = get_x_and_y(weight_init)
else:
weight_init_x = weight_init_y = weight_init
self.weight_x = Parameter(initializer(weight_init_x, shape=weight_shape), name='weight_x')
self.weight_y = Parameter(initializer(weight_init_y, shape=weight_shape), name='weight_y')
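# The following is a minimal sketch, for illustration only, of how a complex-algebra
# subclass of _BaseConvImpl could decompose the hypercomplex convolution into four
# real-valued convolutions. It assumes that `conv_op` is a real convolution primitive
# such as mindspore.ops.Conv2D configured by the caller; the actual complex, dual and
# double implementations shipped with this patch may differ in details.
class _SketchComplexConvImpl(_BaseConvImpl):
    """Illustrative complex-algebra convolution implementor (not part of the public API)."""
    def construct(self,
                  conv_op: Callable,
                  x: Tensor,
                  y: Tensor) -> Tuple[Tensor, Tensor]:
        # (x + iy) * (w_x + i*w_y) = (x*w_x - y*w_y) + i*(x*w_y + y*w_x)
        out_x = conv_op(x, self.weight_x) - conv_op(y, self.weight_y)
        out_y = conv_op(x, self.weight_y) + conv_op(y, self.weight_x)
        return out_x, out_y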

View File

@ -0,0 +1,114 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hypercomplex dense implementation"""
import numbers
from typing import Callable, Union, Tuple
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.hypercomplex.utils import get_x_and_y
class _DenseImpl(nn.Cell):
r"""
The interface of the implementor part of dense connected layer on second-order hypercomplex numbers.
Defines the API for linear transformation, which is used by the 'Dense' class. The API must be implemented
separately for every hypercomplex algebra:
.. math::
\text{out} = \text{linear}(\text{inp}, \text{kernel})
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
which is provided by subclasses, :math:`\text{kernel}` is a hypercomplex weight matrix with the same data type as
the :math:`inp` created by the layer.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the hypercomplex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
- **y** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the hypercomplex part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the hypercomplex
part of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
weight_init: Union[Tensor, str, Initializer, numbers.Number],
weight_shape: tuple,
**factory_kwargs) -> None:
super(_DenseImpl, self).__init__()
def construct(self,
matmul_op: Callable,
x: Tensor,
y: Tensor) -> Tuple[Tensor, Tensor]:
pass
class _BaseDenseImpl(_DenseImpl):
r"""
The base implementor part of the dense connected layer for all the hypercomplex numbers of the second order.
Contains initialization of the kernel tensors, which is shared by all specific implementations of the 'DenseImpl'
interface for dual, double, and complex numbers.
Args:
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
weight_shape (tuple): The set of int numbers that defines the shape of real and hypercomplex parts of
the kernel.
**factory_kwargs (dict): Extra parameters which may be needed by specific subclasses.
Inputs:
- **matmul_op** (Callable) - the function of the real-valued matrix multiplication to be used for decomposition
of the hypercomplex linear transformation. Usually, mindspore.ops.operations.MatMul(...) is passed
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the real part of the input.
- **y** (Tensor) - Tensor of shape :math:`(*, in\_channels)`, which defines the hypercomplex part of the input.
Outputs:
Tuple of two tensors, each of shape :math:`(*, out\_channels)`, which represents the real and the hypercomplex
part of the output.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
weight_init: Union[Tensor, str, Initializer, numbers.Number],
weight_shape: tuple,
**factory_kwargs) -> None:
super(_BaseDenseImpl, self).__init__(weight_init,
weight_shape,
**factory_kwargs)
if isinstance(weight_init, Tensor):
weight_init_x, weight_init_y = get_x_and_y(weight_init)
else:
weight_init_x = weight_init_y = weight_init
self.weight_x = Parameter(initializer(weight_init_x, shape=weight_shape), name='weight_x')
self.weight_y = Parameter(initializer(weight_init_y, shape=weight_shape), name='weight_y')
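# The following is a minimal sketch, for illustration only, of how a complex-algebra
# subclass of _BaseDenseImpl could decompose the hypercomplex linear transformation into
# four real-valued matrix products. It assumes that `matmul_op` is a real primitive such
# as mindspore.ops.MatMul(transpose_b=True), so that it accepts activations of shape
# (*, in_channels) and a weight of shape (out_channels, in_channels); the actual complex,
# dual and double implementations shipped with this patch may differ in details.
class _SketchComplexDenseImpl(_BaseDenseImpl):
    """Illustrative complex-algebra dense implementor (not part of the public API)."""
    def construct(self,
                  matmul_op: Callable,
                  x: Tensor,
                  y: Tensor) -> Tuple[Tensor, Tensor]:
        # (x + iy) * (w_x + i*w_y) = (x*w_x - y*w_y) + i*(x*w_y + y*w_x)
        out_x = matmul_op(x, self.weight_x) - matmul_op(y, self.weight_y)
        out_y = matmul_op(x, self.weight_y) + matmul_op(y, self.weight_x)
        return out_x, out_y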

View File

@ -0,0 +1,627 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hypercomplex batchnorm"""
import numbers
from typing import TypeVar, Type, Union, Any
from abc import abstractmethod
import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as P
from mindspore._checkparam import Validator as validator
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore.ops import functional as F
from mindspore.hypercomplex.hypercomplex._hc_bn_impl import _BatchNormImpl as BatchNormImpl
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
TBatchNormImpl = TypeVar('TBatchNormImpl', bound=BatchNormImpl)
class _BatchNorm(nn.Cell):
r"""
The base class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
of some number of dimensions.
This layer applies Batch Normalization over a hypercomplex input to reduce internal covariate shift.
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature using
a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
\end{align}
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
which depends on the type of the number system and provided by the implementor part of the batch normalization
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero.
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate a child
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
Args:
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
class name of this argument defines the algebra that the batch normalization layer will operate on.
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: 'NCHW'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, *, ..., *)` if data_format is 'NCHW', or
:math:`(2, N, *, ..., *, C)` if data_format is 'NHWC', with float16 or float32 data type. '2' denotes that
the input tensor belongs to the hypercomplex domain and has a real and a hypercomplex part. Or,
:math:`(N, C, *, ..., *)` if data_format is 'NCHW', or :math:`(N, *, ..., *, C)` if data_format is 'NHWC',
with complex64 data type. The `num_features` in `Args` has to be equal to :math:`C` in `Inputs`.
The count of dimensions denoted by '*' must be equal to the number of spatial dimensions.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
:math:`(2, N, C, *, ..., *)` if data_format is 'NCHW', or :math:`(2, N, *, ..., *, C)` if data_format is 'NHWC',
with float16 or float32 data type. Or, :math:`(N, C, *, ..., *)` if data_format is 'NCHW', or
:math:`(N, *, ..., *, C)` if data_format is 'NHWC', with complex64 data type.
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
TypeError: If dtype of `inp` is not float16, float32 or complex64.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
bn_impl: Type[TBatchNormImpl],
num_features: int,
eps: float = 1e-5,
momentum: float = 0.9,
affine: bool = True,
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
use_batch_statistics: bool = None,
data_format: str = 'NCHW') -> None:
"""Initialize _BatchNorm."""
super(_BatchNorm, self).__init__()
validator.check_value_type('num_features', num_features, [int], self.cls_name)
if num_features < 1:
raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")
if momentum < 0 or momentum > 1:
raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
f"but got {momentum}.")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device "
f"target {context.get_context('device_target')}.")
self.use_batch_statistics = use_batch_statistics
if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' must be a boolean value or None,"
f" but got {use_batch_statistics}.")
self.num_features = num_features
self.eps = eps
self.beta_init = beta_init
self.gamma_init = gamma_init
self.moving_mean_init = moving_mean_init
self.moving_var_init = moving_var_init
self.affine = affine
self.bn_impl = bn_impl(affine, use_batch_statistics, gamma_init, beta_init, num_features)
self.moving_mean_x = Parameter(
initializer(moving_mean_init, (num_features)), name="mean_x", requires_grad=False
)
self.moving_mean_y = Parameter(
initializer(moving_mean_init, (num_features)), name="mean_y", requires_grad=False
)
self.moving_sigma2 = Parameter(
initializer(moving_var_init, num_features), name="sigma2", requires_grad=False
)
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
self._target = context.get_context("device_target")
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
self.momentum = 1.0 - momentum
self.reduce_mean_op1 = P.ReduceMean(keep_dims=True)
self.reduce_mean_op2 = P.ReduceMean(keep_dims=False)
self.features_dim = data_format.lower().find('c')
self.get_dtype = P.DType()
self.get_shape = P.Shape()
def construct(self, u: Tensor) -> Tensor:
"""construct"""
u_dtype = self.get_dtype(u)
u_shape = self.get_shape(u)
self._check_input_dim(u_shape, u_dtype)
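# For complex64 inputs the real and hypercomplex parts are encoded in the dtype itself, so there is
# no extra leading axis; float inputs carry them as a leading dimension of size 2, which shifts the
# feature axis by one.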
if u_dtype == mindspore.complex64:
hc_axis = None
feature_axis = self.features_dim
else:
hc_axis = 0
feature_axis = self.features_dim + 1
if self.training or not self.use_batch_statistics:
ndim = u.ndim
sh = np.arange(ndim)
sh = sh[sh != hc_axis]
sh = sh[sh != feature_axis]
if hc_axis is None:
u_x, u_y = get_x_and_y(u)
mu_x = self.reduce_mean_op1(u_x, sh.tolist())
mu_y = self.reduce_mean_op1(u_y, sh.tolist())
mu = to_2channel(mu_x, mu_y, mindspore.complex64)
else:
mu = self.reduce_mean_op1(u, sh.tolist())
u_centered = u - mu
norma2 = self.bn_impl.get_square_norm(u_centered)
norma_feature_axis = feature_axis if hc_axis is None or feature_axis < hc_axis else feature_axis - 1
ndim = norma2.ndim
mean_dims = np.arange(ndim)
mean_dims = mean_dims[mean_dims != norma_feature_axis]
sigma2 = self.reduce_mean_op2(norma2, mean_dims.tolist()) + self.eps
result = self._calculate_bn(u_centered, sigma2, feature_axis)
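# Track the running statistics with an exponential moving average. Note that self.momentum stores
# (1 - momentum) from __init__, so the update below is moving = momentum * moving + (1 - momentum) * batch_stat.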
if self.use_batch_statistics:
momentum = self.momentum
mu = mu.squeeze()
mu_x, mu_y = get_x_and_y(mu)
momentum_suppl = 1 - momentum
self.moving_mean_x *= momentum_suppl
self.moving_mean_x += mu_x * momentum
self.moving_mean_y *= momentum_suppl
self.moving_mean_y += mu_y * momentum
self.moving_sigma2 *= momentum_suppl
self.moving_sigma2 += sigma2 * momentum
elif self.affine:
result = self._calculate_infer_bn(u, axis=feature_axis)
else:
broadcast_mu_shape = [1] * u.ndim
broadcast_mu_shape[feature_axis] = u_shape[feature_axis]
if hc_axis is not None:
broadcast_mu_shape[hc_axis] = 2
moving_mean = to_2channel(self.moving_mean_x, self.moving_mean_y, u.dtype)
moving_mean = moving_mean.reshape(tuple(broadcast_mu_shape))
inference_centered = u - moving_mean
result = self._calculate_bn(inference_centered, self.moving_sigma2, feature_axis)
return result
def _calculate_bn(self,
u_centered: Tensor,
sigma2: Tensor,
axis: int) -> Tensor:
"""_calculate_bn, implement the abstract function"""
sigma = P.sqrt(sigma2)
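# Build a permutation that swaps the feature axis with the last axis so that the per-channel
# sigma, scale and shift broadcast over the trailing dimension; applying the same permutation
# again restores the original layout, since a transposition is its own inverse.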
ndim = u_centered.ndim
u_shape = list(np.arange(ndim))
u_shape[ndim - 1] = axis
u_shape[axis] = ndim - 1
u_shape = tuple(int(i) for i in u_shape)
out = P.transpose(u_centered, u_shape)
if self.affine:
out_x, out_y = get_x_and_y(out)
out_x, out_y = self.bn_impl.calculate_bn(out_x, out_y, sigma)
out = to_2channel(out_x, out_y, self.get_dtype(u_centered))
else:
out = out / sigma
out = P.transpose(out, u_shape)
return out
def _calculate_infer_bn(self,
u: Tensor,
axis: int) -> Tensor:
"""_calculate_infer_bn, implement the abstract function"""
ndim = u.ndim
shape = list(np.arange(ndim))
shape[ndim-1] = axis
shape[axis] = ndim - 1
shape = tuple(int(i) for i in shape)
out = P.transpose(u, shape)
out_x, out_y = get_x_and_y(out)
out_x, out_y = self.bn_impl.calculate_infer_bn(self.moving_mean_x,
self.moving_mean_y,
P.sqrt(self.moving_sigma2),
out_x,
out_y)
out = to_2channel(out_x, out_y, dtype=u.dtype)
out = P.transpose(out, shape)
return out
@abstractmethod
def _check_input_dim(self, shape: tuple, dtype: Any):
raise NotImplementedError
class BatchNorm1d(_BatchNorm):
r"""
The class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
of four dimensions including one spatial dimension, or three dimensions.
This layer applies Batch Normalization over a hypercomplex input of 'NCW' data format in order to reduce
internal covariate shift. Batch Normalization is widely used in convolutional networks. It rescales and recenters
the feature using a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
\end{align}
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
which depends on the type of the number system and provided by the implementor part of the batch normalization
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero.
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate this
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
Args:
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
class name of this argument defines the algebra that the batch normalization layer will operate on.
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, W)` or :math:`(2, N, C)`, with float16 or float32 data
type, or :math:`(N, C, W)` or :math:`(N, C)`, with complex64 data type. In the former case '2' denotes that
the input tensor belongs to the hypercomplex domain and has a real and a hypercomplex part.
The `num_features` in `Args` has to be equal to :math:`C` in `inp`.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
:math:`(2, N, C, W)` or :math:`(2, N, C)`, with float16 or float32 data type, or :math:`(N, C, W)` or
:math:`(N, C)`, with complex64 data type.
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
TypeError: If dtype of `inp` is not float16, float32 or complex64.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
ValueError: If `inp` is not a Tensor of 3 or 4 dimensions with float16 or float32 data type, and not a Tensor
of 2 or 3 dimensions with complex64 data type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
bn_impl: Type[TBatchNormImpl],
num_features: int,
eps: float = 1e-5,
momentum: float = 0.9,
affine: bool = True,
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
use_batch_statistics: bool = None) -> None:
"""Initialize _BatchNorm."""
super(BatchNorm1d, self).__init__(bn_impl,
num_features,
eps,
momentum,
affine,
gamma_init,
beta_init,
moving_mean_init,
moving_var_init,
use_batch_statistics)
def _check_input_dim(self, shape: tuple, dtype: Any):
dim = len(shape)
if dtype in [mindspore.float16, mindspore.float32]:
if dim not in (4, 3):
raise ValueError(f"For '{self.cls_name}', the in_shape must have 3-4 dims, but got {dim}.")
elif dtype == mindspore.complex64:
if dim not in (3, 2):
raise ValueError(f"For '{self.cls_name}', the in_shape must have 2-3 dims, but got {dim}.")
else:
raise TypeError(f"Only float16, float32 and complex64 data types are supported, but got {dtype}.")
return None
class BatchNorm2d(_BatchNorm):
r"""
The class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
of five dimensions, including two spatial dimensions.
This layer applies Batch Normalization over a hypercomplex input to reduce internal covariate shift.
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature
using a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
\end{align}
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
which depends on the type of the number system and provided by the implementor part of the batch normalization
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero.
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate this
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
Args:
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
class name of this argument defines the algebra that the batch normalization layer will operate on.
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: 'NCHW'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, H, W)` if data_format is 'NCHW', or
:math:`(2, N, H, W, C)` if data_format is 'NHWC', with float16 or float32 data type. '2' denotes that the
input tensor belongs to the hypercomplex domain and has a real and a hypercomplex part. Or,
:math:`(N, C, H, W)` if data_format is 'NCHW', or :math:`(N, H, W, C)` if data_format is 'NHWC', with
complex64 data type. The `num_features` in `Args` has to be equal to :math:`C` in `Inputs`.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
:math:`(2, N, C, H, W)` if data_format is 'NCHW', or :math:`(2, N, H, W, C)` if data_format is 'NHWC', with
float16 or float32 data type. Or, :math:`(N, C, H, W)` if data_format is 'NCHW', or :math:`(N, H, W, C)` if
data_format is 'NHWC', with complex64 data type.
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
TypeError: If dtype of `inp` is not float16, float32 or complex64.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.
ValueError: If `inp` is not a Tensor of 5 dimensions with float16 or float32 data type, and not a Tensor of 4
dimensions with complex64 data type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def _check_input_dim(self, shape: tuple, dtype: Any):
dim = len(shape)
if dtype in [mindspore.float16, mindspore.float32]:
if dim != 5:
raise ValueError(f"For '{self.cls_name}', the in_shape must have 5 dims, but got {dim}.")
elif dtype == mindspore.complex64:
if dim != 4:
raise ValueError(f"For '{self.cls_name}', the in_shape must have 4 dims, but got {dim}.")
else:
raise TypeError(f"Only float16, float32 and complex64 data types are supported, but got {dtype}.")
return None
class BatchNorm3d(nn.Cell):
r"""
The class of the abstract part of Batch Normalization layer over a second-order hypercomplex input
of six dimensions, including three spatial dimensions.
This layer applies Batch Normalization over a hypercomplex input to reduce internal covariate shift.
Batch Normalization is widely used in convolutional networks. It rescales and recenters the feature
using a mini-batch of data and the learned parameters which can be described by the following formula:
.. math::
\begin{align}
\mathrm{Var}[inp] = \mathrm{E}[\| inp_i - \mathrm{E}[inp] \|^2]\\
out = \text{linear}(\frac{inp - \mathrm{E}[inp]}{\sqrt{\mathrm{Var}[inp] + \delta}}, \gamma) + \beta,
\end{align}
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
which depends on the type of the number system and provided by the implementor part of the batch normalization
layer, :math:`\mathrm{E}[inp]` is the arithmetic mean of the input tensor over the spatial and mini-batch
dimensions, :math:`\mathrm{Var}[inp]` is the statistical variance of the input tensor over the same dimensions,
:math:`\gamma` and :math:`\beta` are hypercomplex learnable parameters representing the scale and shift coefficients
respectively, and :math:`\delta` is a small positive constant, which is needed to avoid division by zero in case
statistical variance is close to zero.
This is not a self-sufficient class. In order to construct a batch normalization layer, one should instantiate this
class and an implementor class, which acts like a bridge pattern and determines the exact set of hypercomplex
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
Args:
bn_impl(BatchNormImpl): The implementor object of the batch normalization layer. Essentially, the concrete
class name of this argument defines the algebra that the batch normalization layer will operate on.
num_features (int): The number of features in the input space.
eps (float): A small positive threshold, which is needed to avoid division by zero. Default: :math:`10^{-5}`
momentum (float): A floating hyperparameter of the momentum for the running_mean and running_var computation.
Default: 0.9.
affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
use_batch_statistics (bool):
- If True, use the mean value and variance value of current batch data and track running mean
and running variance.
- If False, use the mean value and variance value of specified value, and not track statistical value.
- If None, the use_batch_statistics is automatically set to True or False according to the training
and evaluation mode. During training, the parameter is set to True, and during evaluation, the
parameter is set to False. Default: None.
data_format (str): The optional value for data format. Only 'NCDHW' format is supported as of now.
Default: 'NCDHW'.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, N, C, D, H, W)`, with float16 or float32 data type, or
:math:`(N, C, D, H, W)`, with complex64 data type. In the former case '2' denotes that the input tensor
belongs to the hypercomplex domain and has a real and a hypercomplex part. The `num_features` in `Args`
has to be equal to :math:`C` in `Inputs`.
Outputs:
Tensor, the normalized, scaled, offset tensor of the same data type and shape as :math:`inp`:
:math:`(2, N, C, D, H, W)`, with float16 or float32 data type, or :math:`(N, C, D, H, W)`, with
complex64 data type.
Raises:
TypeError: If `num_features` is not an int.
TypeError: If `eps` is not a float.
TypeError: If dtype of `inp` is not float16, float32 or complex64.
ValueError: If `num_features` is less than 1.
ValueError: If `momentum` is not in range [0, 1].
ValueError: If `data_format` is not 'NCDHW'.
ValueError: If `inp` is not a Tensor of 6 dimensions with float16 or float32 data type, and not a Tensor of 5
dimensions with complex64 data type.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
bn_impl: Type[TBatchNormImpl],
num_features: int,
eps: float = 1e-5,
momentum: float = 0.9,
affine: bool = True,
gamma_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
beta_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_mean_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
moving_var_init: Union[Tensor, str, Initializer, numbers.Number] = 'ones',
use_batch_statistics: bool = None,
data_format: str = 'NCDHW') -> None:
"""Initialize _BatchNorm."""
super(BatchNorm3d, self).__init__()
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
self.reshape = P.Reshape()
self.bn2d = BatchNorm2d(bn_impl=bn_impl,
num_features=num_features,
eps=eps,
momentum=momentum,
affine=affine,
gamma_init=gamma_init,
beta_init=beta_init,
moving_mean_init=moving_mean_init,
moving_var_init=moving_var_init,
use_batch_statistics=use_batch_statistics,
data_format="NCHW")
def construct(self, u: Tensor) -> Tensor:
'''construct'''
u_shape = F.shape(u)
self._check_3d_shape(u_shape, F.dtype(u))
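# Fold the depth and height dimensions into one so that the 2D batch normalization core can be
# reused: (2, N, C, D, H, W) -> (2, N, C, D*H, W) for float inputs, or
# (N, C, D, H, W) -> (N, C, D*H, W) for complex64 inputs.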
reshape = list(u_shape)
reshape[-3] *= reshape[-2]
reshape = tuple(int(i) for i in reshape[:-2] + reshape[-1:])
u = self.reshape(u, tuple(reshape))
out = self.bn2d(u)
out = self.reshape(out, u_shape)
return out
def _check_3d_shape(self, input_shape, dtype: Any) -> None:
'''_check_3d_shape'''
dim = len(input_shape)
if dtype in [mindspore.float16, mindspore.float32]:
if dim != 6:
msg_prefix = f"For '{self.cls_name}', the" if self.cls_name else "The"
raise ValueError(f"{msg_prefix} input_shape must be 6-dimensional, but got the length of input_shape: "
f"{len(dim)}.")
elif dtype == mindspore.complex64:
if dim != 5:
msg_prefix = f"For '{self.cls_name}', the" if self.cls_name else "The"
raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
f"{len(dim)}.")
else:
raise TypeError(f"Only float16, float32 and complex64 data types are supported, but got {dtype}.")
return None

File diff suppressed because it is too large

View File

@ -0,0 +1,200 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hypercomplex Dense"""
import numbers
from typing import TypeVar, Type, Union
import mindspore
import mindspore.nn as nn
from mindspore._checkparam import Validator
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.hypercomplex.hypercomplex._hc_dense_impl import _DenseImpl as DenseImpl
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
TDenseImpl = TypeVar('TDenseImpl', bound=DenseImpl)
class Dense(nn.Cell):
r"""
The abstract part of dense connected layer.
Applies dense connected layer for the second-order hypercomplex input. This layer implements the operation as:
.. math::
\text{out} = \text{linear}(\text{inp}, \text{kernel}) + \text{bias},
where :math:`inp` is the hypercomplex input tensors, :math:`\text{linear}` is the linear transformation operation,
which is defined and provided by the implementor part of the dense connected layer, :math:`\text{kernel}` is
a hypercomplex weight matrix with the same data type as the :math:`inp` created by the layer, and
:math:`\text{bias}` is a hypercomplex bias vector with the same data type as the :math:`inp` created by the layer
(only if has_bias is True).
This is not a self-sufficient class. In order to construct a fully connected layer, one should instantiate this
class and an implementor class, which acts like a strategy pattern and determine the exact set of hypercomplex
numbers. That implies the rules of multiplication and therefore affects how a linear transformation works.
Args:
dense_impl(DenseImpl): The implementor object of the dense connected layer. Essentially, the concrete class
name of this argument defines the algebra that the dense layer will operate on.
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is the same as `inp`. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
the same as `inp`. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
Inputs:
- **inp** (Tensor) - Tensor of shape :math:`(2, *, ..., *, in\_channels)`, with float16 or float32 data type,
or :math:`(*, ..., *, in\_channels)`, with complex64 data type. In the former case, '2' denotes that the input
tensor belongs to the hypercomplex domain and has a real and an imaginary part. The `in_channels` in
`Args` has to be equal to :math:`in\_channels` in `Inputs`. The count of mediator dimensions denoted by '*'
is arbitrary but must be at least one.
Outputs:
Tensor of the same data type as 'inp' and of shape :math:`(2, *, ..., *, out\_channels)`, with float16 or
float32 data type, or :math:`(*, ..., *, out\_channels)`, with complex64 data type. The count of mediator
dimensions is the same as in 'inp'.
Raises:
TypeError: If `in_channels` or `out_channels` is not an int.
TypeError: If `has_bias` is not a bool.
TypeError: If any two of `inp`, `weight_init` and `bias_init` are Tensors of different data type.
ValueError: If length of shape of `weight_init` is not equal to 3,
or shape[0] of `weight_init` is not equal to 2,
or shape[1] of `weight_init` is not equal to `out_channels`,
or shape[2] of `weight_init` is not equal to `in_channels`.
ValueError: If length of shape of `bias_init` is not equal to 2,
or shape[0] of `bias_init` is not equal to 2,
or shape[1] of `bias_init` is not equal to `out_channels`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
dense_impl: Type[TDenseImpl],
in_channels: int,
out_channels: int,
weight_init: Union[Tensor, str, Initializer, numbers.Number] = 'normal',
bias_init: Union[Tensor, str, Initializer, numbers.Number] = 'zeros',
has_bias: bool = True) -> None:
"""Initialize Dense."""
super(Dense, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
self.dtype = None
self.reshape = P.Reshape()
self.shape_op = P.Shape()
self.weight_x = None
self.weight_y = None
if isinstance(weight_init, Tensor):
self.dtype = weight_init.dtype
if self.dtype in [mindspore.float16, mindspore.float32] and ( \
weight_init.ndim != 3
or weight_init.shape[0] != 2 \
or weight_init.shape[1] != out_channels \
or weight_init.shape[2] != in_channels):
raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
f"be equal to 3, and the first dim must be equal to 2, and the second dim must be "
f"equal to 'out_channels', and the third dim must be equal to 'in_channels'. But got "
f"'weight_init': {weight_init}, 'out_channels': {out_channels}, 'in_channels': "
f"{in_channels}.")
if self.dtype == mindspore.complex64 and ( \
weight_init.ndim != 2 \
or weight_init.shape[0] != out_channels \
or weight_init.shape[1] != in_channels):
raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
f"be equal to 2, and the first dim must be equal to 'out_channels', "
f"and the second dim must be equal to 'in_channels'. But got "
f"'weight_init': {weight_init}, 'out_channels': {out_channels}, 'in_channels': "
f"{in_channels}.")
self.dense_impl = dense_impl(weight_init, [out_channels, in_channels])
self.bias_x = None
self.bias_y = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if self.dtype is None:
self.dtype = bias_init.dtype
elif self.dtype != bias_init.dtype:
raise TypeError("Data type of weight init tensor and the bias init tensor must be equal, "
f"but got weight_init.dtype={self.dtype} and bias_init.dtype={bias_init.dtype}")
if self.dtype in [mindspore.float16, mindspore.float32] and ( \
bias_init.ndim != 2 \
or bias_init.shape[0] != 2 \
or bias_init.shape[1] != out_channels):
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
f"be equal to 2, and the second dim must be equal to 'out_channels'. But got "
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
if self.dtype == mindspore.complex64 and ( \
bias_init.ndim != 1 \
or bias_init.shape[0] != out_channels):
raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
f"be equal to 1, and the only dim must be equal to 'out_channels'. But got "
f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
bias_init_x, bias_init_y = get_x_and_y(bias_init)
else:
bias_init_x = bias_init_y = bias_init
self.bias_x = Parameter(initializer(bias_init_x, [out_channels]), name="bias_x")
self.bias_y = Parameter(initializer(bias_init_y, [out_channels]), name="bias_y")
self.bias_add = P.BiasAdd()
self.matmul = P.MatMul(transpose_b=True)
def check_dense_input_shape(self, x: Tensor, x_dtype):
msg_prefix = f"For '{self.cls_name}', the" if self.cls_name else "The"
if x_dtype in [mindspore.float16, mindspore.float32] and (len(x) < 3 or x[0] != 2):
raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 3, and the first dimension "
f"should be 2, but got {x}.")
if x_dtype == mindspore.complex64 and len(x) < 2:
raise ValueError(f"{msg_prefix} dimension of 'x' should not be less than 2, but got {x}.")
return None
def construct(self, u: Tensor) -> Tensor:
"""Construct"""
if self.dtype is not None and self.dtype != u.dtype:
raise TypeError("dtype must be equal to the data type of the inputs tensor, but got: "
f"dtype={self.dtype} and inputs.dtype={u.dtype}")
u_shape = self.shape_op(u)
self.check_dense_input_shape(u_shape, u.dtype)
u_reshape = [-1, u_shape[-1]]
if u.dtype in [mindspore.float16, mindspore.float32]:
u_reshape = [2] + u_reshape
if len(u_reshape) < len(u_shape):
u = self.reshape(u, tuple(u_reshape))
x, y = get_x_and_y(u)
out_x, out_y = self.dense_impl(self.matmul, x, y)
if self.has_bias:
out_x = self.bias_add(out_x, self.bias_x)
out_y = self.bias_add(out_y, self.bias_y)
out = to_2channel(out_x, out_y, u.dtype)
if len(u_reshape) < len(u_shape):
out_shape = u_shape[:-1] + (-1,)
out = self.reshape(out, out_shape)
return out
def extend_repr(self):
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
return s
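# Usage sketch (illustration only). It assumes the concrete dual-number Dense
# exported by mindspore.hypercomplex.dual, as used by the sample networks later
# in this change; the input is a 2-channel tensor of shape (2, batch, in_channels).
import numpy as np
import mindspore.hypercomplex.dual as hc_ops
from mindspore import Tensor

dense = hc_ops.Dense(8, 4)  # (in_channels, out_channels)
u = Tensor(np.random.random((2, 16, 8)).astype(np.float32))  # (2, batch, in_channels)
out = dense(u)
print(out.shape)  # expected (2, 16, 4): out_channels replaces the last dimension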

File diff suppressed because it is too large

View File

@ -0,0 +1,49 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Uniform operators"""
from mindspore.nn import Cell
class _UniformOperator(Cell):
r"""
Base class for layers that operate on second-order hypercomplex numbers and are designed
using the bridge pattern.
It constructs an object of the 'hc_op' type, passing 'hc_impl' as a parameter.
Args:
hc_op (Type): The abstraction part of the bridge pattern.
hc_impl (Type): The implementor part of the bridge pattern.
**kwargs (dict): Additional arguments that may be required to construct the specific layer.
Inputs:
- **inp** (Tensor) - input tensor. The shape is specific to the subclass.
Outputs:
Tensor, the shape of which is specific to the subclass.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
"""
def __init__(self,
hc_op,
hc_impl,
**kwargs) -> None:
super(_UniformOperator, self).__init__()
self.op = hc_op(hc_impl, **kwargs)
def construct(self, x):
return self.op(x)
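# Toy sketch of the bridge wiring shown above (illustration only; the names are
# hypothetical): the abstraction receives the implementor class at construction
# time, exactly as self.op = hc_op(hc_impl, **kwargs) does in _UniformOperator.
class _ToyImpl:
    """Implementor: fixes the algebra-specific arithmetic."""
    def transform(self, x):
        return 2 * x

class _ToyOp:
    """Abstraction: delegates the algebra-specific part to the implementor."""
    def __init__(self, impl_cls, bias=0):
        self.impl = impl_cls()
        self.bias = bias
    def __call__(self, x):
        return self.impl.transform(x) + self.bias

toy = _ToyOp(_ToyImpl, bias=1)  # mirrors hc_op(hc_impl, **kwargs)
print(toy(3))                   # prints 7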

View File

@ -0,0 +1,51 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils"""
from typing import Tuple, Union
import mindspore
from mindspore import ops as P
to_complex = P.Complex()
get_real = P.Real()
get_imag = P.Imag()
unstack = P.Unstack(0)
cat = P.Concat(0)
def get_x_and_y(tensor):
"""Split a hypercomplex tensor into its two component tensors."""
if tensor.dtype == mindspore.complex64:
return get_real(tensor), get_imag(tensor)
return unstack(tensor)
def to_2channel(real, imag, dtype=None):
"""Convert a pair of real tensors to the 2-channel or complex64 hypercomplex representation."""
if dtype is not None and dtype == mindspore.complex64:
return to_complex(real, imag)
if dtype is not None and (dtype != real.dtype or dtype != imag.dtype):
raise ValueError("dtype must match with data type of the input tensors, but got: "
f"dtype={dtype}, real.dtype={real.dtype}, imag.dtype={imag.dtype}")
expand_dims = P.ExpandDims()
real = expand_dims(real, 0)
imag = expand_dims(imag, 0)
return cat((real, imag))
_size_1_t = Union[int, Tuple[int]]
_size_2_t = Union[int, Tuple[int, int]]
_size_3_t = Union[int, Tuple[int, int, int]]
_size_any_t = Union[int, Tuple[int, ...]]
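# Usage sketch for the helpers above (illustration only): round-trips a pair of
# real tensors through the 2-channel representation and through complex64.
import numpy as np
from mindspore import Tensor

real = Tensor(np.ones((3, 4)), mindspore.float32)
imag = Tensor(np.zeros((3, 4)), mindspore.float32)

two_channel = to_2channel(real, imag)                      # shape (2, 3, 4), float32
x, y = get_x_and_y(two_channel)                            # unstacks along axis 0

as_complex = to_2channel(real, imag, mindspore.complex64)  # shape (3, 4), complex64
x_c, y_c = get_x_and_y(as_complex)                         # real and imaginary parts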

View File

@ -0,0 +1,83 @@
from mindspore import nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
import mindspore.hypercomplex.dual as ops
class DeepConvNet(nn.Cell):
def __init__(self):
super(DeepConvNet, self).__init__()
self.conv1 = ops.Conv1d(1, 16, kernel_size=6, stride=2, padding=2, pad_mode='pad')
self.bn1 = ops.BatchNorm1d(16)
self.avg_pool1 = ops.AvgPool1d(kernel_size=2, stride=2)
self.pad1 = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 2)), mode='CONSTANT')
self.conv2 = ops.Conv1d(16, 32, kernel_size=3, stride=2, padding=0)
self.bn2 = ops.BatchNorm1d(32)
self.avg_pool2 = ops.AvgPool1d(kernel_size=2, stride=2)
self.conv3 = ops.Conv1d(32, 64, kernel_size=3, stride=1, padding=1, pad_mode='pad')
self.bn3 = ops.BatchNorm1d(64)
self.avg_pool3 = ops.AvgPool1d(kernel_size=2, stride=2)
self.conv4 = ops.Conv1d(64, 64, kernel_size=3, stride=1, padding=1, pad_mode='pad')
self.bn4 = ops.BatchNorm1d(64)
self.avg_pool4 = ops.AvgPool1d(kernel_size=2, stride=2)
self.conv5 = ops.Conv1d(64, 128, kernel_size=3, stride=1, padding=1, pad_mode='pad')
self.conv6 = ops.Conv1d(128, 128, kernel_size=3, stride=1, padding=1, pad_mode='pad')
self.bn6 = ops.BatchNorm1d(128)
self.avg_pool6 = ops.AvgPool1d(kernel_size=2, stride=2)
self.shape_op = P.Shape()
self.reshape = P.Reshape()
self.permute = P.Transpose()
self.flatten = P.Flatten()
self.fc1 = ops.Dense(4096, 1024)
self.fc2 = nn.Dense(2048, 84)
self.relu = ops.ReLU()
self.sigmoid = nn.Sigmoid()
def construct(self, u: Tensor) -> Tensor:
u = self.conv1(u)
u = self.bn1(u)
u = self.relu(u)
u = self.avg_pool1(u)
u = self.pad1(u)
u = self.conv2(u)
u = self.bn2(u)
u = self.relu(u)
u = self.avg_pool2(u)
u = self.conv3(u)
u = self.bn3(u)
u = self.relu(u)
u = self.avg_pool3(u)
u = self.conv4(u)
u = self.bn4(u)
u = self.relu(u)
u = self.avg_pool4(u)
u = self.conv5(u)
u = self.relu(u)
u = self.conv6(u)
u = self.bn6(u)
u = self.relu(u)
u = self.avg_pool6(u)
u_shape = self.shape_op(u)
u = self.reshape(u, (u_shape[0], u_shape[1], -1))
u = self.fc1(u)
u = self.relu(u)
u = self.permute(u, (1, 0, 2))
x = self.flatten(u)
x = self.fc2(x)
x = self.sigmoid(x)
return x

View File

@ -0,0 +1,32 @@
from mindspore import nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
import mindspore.hypercomplex.dual as ops
class HCModel(nn.Cell):
def __init__(self):
super(HCModel, self).__init__()
self.conv1 = ops.Conv2d(1, 10, kernel_size=3)
self.bn1 = ops.BatchNorm2d(10)
self.max_pool = ops.MaxPool2d(2)
self.relu = ops.ReLU()
self.fc1 = ops.Dense(7290, 256)
self.fc2 = nn.Dense(512, 10)
self.concat = P.Concat(1)
def construct(self, u: Tensor) -> Tensor:
u = to_2channel(u[:, :1], u[:, 1:])
u = self.conv1(u)
u = self.bn1(u)
u = self.relu(u)
u = self.max_pool(u)
u = u.view(2, u.shape[1], -1)
u = self.fc1(u)
u = self.relu(u)
out_x, out_y = get_x_and_y(u)
out = self.concat([out_x, out_y])
out = self.fc2(out)
return out

View File

@ -0,0 +1,593 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
import mindspore.hypercomplex.dual as ops
from mindspore.hypercomplex.utils import get_x_and_y, to_2channel
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
scale /= max(1., fan_in)
stddev = (scale ** 0.5) / .87962566103423978
mu, sigma = 0, stddev
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
return Tensor(weight, dtype=mstype.float32)
def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
if res_base:
return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=1, pad_mode='pad', weight_init=weight)
return ops.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
if res_base:
return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='pad', weight_init=weight)
return ops.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
if use_se:
weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal((2, *weight_shape), mode="fan_out", nonlinearity='relu'))
if res_base:
return ops.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
return ops.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _bn(channel, res_base=False):
if res_base:
return ops.BatchNorm2d(channel, eps=1e-5, momentum=0.1,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
return ops.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel, use_se=False):
if use_se:
weight = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1,
use_se=False, se_block=False):
super(ResidualBlock, self).__init__()
self.stride = stride
self.use_se = use_se
self.se_block = se_block
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
self.bn1 = _bn(channel)
if self.use_se and self.stride != 1:
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
ops.ReLU(), ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
else:
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn(out_channel)
if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel / 4), out_channel, use_se=self.use_se)
self.se_sigmoid = nn.Sigmoid()
self.se_mul = P.Mul()
self.relu = ops.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
if self.use_se:
if stride == 1:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
stride, use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([ops.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
_conv1x1(in_channel, out_channel, 1,
use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)])
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if self.use_se and self.stride != 1:
out = self.e2(out)
else:
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.se_block:
out_se = out
out = self.se_global_pool(out, (2, 3))
out = self.se_dense_0(out)
out = self.relu(out)
out = self.se_dense_1(out)
out = self.se_sigmoid(out)
out = F.reshape(out, F.shape(out) + (1, 1))
out = self.se_mul(out, out_se)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = out + identity
out = self.relu(out)
return out
class ResidualBlockBase(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
res_base (bool): Enable parameter setting of resnet18. Default: True.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlockBase(3, 256, stride=2)
"""
def __init__(self,
in_channel,
out_channel,
stride=1,
use_se=False,
se_block=False,
res_base=True):
super(ResidualBlockBase, self).__init__()
self.res_base = res_base
self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
self.bn1d = _bn(out_channel)
self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
self.bn2d = _bn(out_channel)
self.relu = ops.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=use_se, res_base=self.res_base),
_bn(out_channel, res_base)])
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1d(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2d(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = out + identity
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images belong to.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
res_base (bool): Enable parameter setting of resnet18. Default: False.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes,
use_se=False,
res_base=False):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.use_se = use_se
self.res_base = res_base
self.se_block = False
if self.use_se:
self.se_block = True
if self.use_se:
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
self.bn1_0 = _bn(32)
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
self.bn1 = _bn(64, self.res_base)
self.relu = ops.ReLU()
if self.res_base:
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
else:
self.maxpool = ops.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
use_se=self.use_se)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
use_se=self.use_se)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2],
use_se=self.use_se,
se_block=self.se_block)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
use_se=self.use_se,
se_block=self.se_block)
self.avgpool = ops.AvgPool2d(4)
self.concat = P.Concat(1)
self.flatten = nn.Flatten()
self.end_point = _fc(16384, num_classes, use_se=self.use_se)
def construct(self, x):
x = to_2channel(x[:, :3], x[:, 3:])
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
if self.res_base:
x_1, x_2 = get_x_and_y(x)
x_1 = self.pad(x_1)
x_2 = self.pad(x_2)
x = to_2channel(x_1, x_2)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
out = self.avgpool(x)
out_x, out_y = get_x_and_y(out)
out = self.concat([out_x, out_y])
out = self.flatten(out)
out = self.end_point(out)
return out
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block(bool): Use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)
else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def resnet18(class_num=10):
"""
Get ResNet18 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet18 neural network.
Examples:
>>> net = resnet18(10)
"""
return ResNet(ResidualBlockBase,
[2, 2, 2, 2],
[64, 64, 128, 256],
[64, 128, 256, 512],
[1, 2, 2, 2],
class_num,
res_base=True)
def resnet34(class_num=10):
"""
Get ResNet34 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet34 neural network.
Examples:
>>> net = resnet34(10)
"""
return ResNet(ResidualBlockBase,
[3, 4, 6, 3],
[64, 64, 128, 256],
[64, 128, 256, 512],
[1, 2, 2, 2],
class_num,
res_base=True)
def resnet50(class_num=10):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
def se_resnet50(class_num=1001):
"""
Get SE-ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of SE-ResNet50 neural network.
Examples:
>>> net = se_resnet50(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
use_se=True)
def resnet101(class_num=1001):
"""
Get ResNet101 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet101 neural network.
Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

View File

@ -0,0 +1,13 @@
import numpy as np
from mindspore import context, Tensor
from mindspore.ops import operations as P
from deepconvnet import DeepConvNet
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
model = DeepConvNet()
model.set_train(False)
u = Tensor(np.random.random((2, 32, 1, 4096)).astype(np.float32))
y = model(u)
print(P.Shape()(y), y)

View File

@ -0,0 +1,112 @@
import argparse
import numpy as np
import mindspore
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import MnistDataset
from hcmodel import HCModel
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class ImageToDualImage:
@staticmethod
def __call__(img):
return np.concatenate((img, img), axis=0)
def create_dataset(dataset_dir, batch_size, usage=None):
dataset = MnistDataset(dataset_dir=dataset_dir, usage=usage)
type_cast_op = mindspore.dataset.transforms.TypeCast(mstype.int32)
# define map operations
trans = [mindspore.dataset.vision.Rescale(1.0 / 255.0, 0),
mindspore.dataset.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
mindspore.dataset.vision.HWC2CHW(),
ImageToDualImage()]
dataset = dataset.map(operations=type_cast_op, input_columns="label")
dataset = dataset.map(operations=trans, input_columns="image")
dataset = dataset.batch(batch_size)
return dataset
def train(model, dataset, loss_fn, optimizer):
# Define forward function
def forward_fn(data, label):
logits = model(data)
loss = loss_fn(logits, label)
return loss, logits
# Get gradient function
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# Define function of one-step training
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
loss = ops.depend(loss, optimizer(grads))
return loss
size = dataset.get_dataset_size()
model.set_train()
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
loss = train_step(data, label)
if batch % 100 == 0:
loss, current = loss.asnumpy(), batch
print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
def test(model, dataset, loss_fn):
num_batches = dataset.get_dataset_size()
model.set_train(False)
total, test_loss, correct = 0, 0, 0
for data, label in dataset.create_tuple_iterator():
pred = model(data)
total += len(data)
test_loss += loss_fn(pred, label).asnumpy()
correct += (pred.argmax(1) == label).asnumpy().sum()
test_loss /= num_batches
correct /= total
print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
def main():
parser = argparse.ArgumentParser(description='MindSpore MNIST Testing')
parser.add_argument(
'--dataset', default=None, type=str, metavar='DS', required=True,
help='Path to the dataset folder'
)
parser.add_argument(
'--bs', default=64, type=int, metavar='N', required=False,
help='Mini-batch size'
)
args = parser.parse_args()
# Process the MNIST dataset.
train_dataset = create_dataset(args.dataset, args.bs, "train")
test_dataset = create_dataset(args.dataset, args.bs, "test")
for img, lbl in test_dataset.create_tuple_iterator():
print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
print(f"Shape of label: {lbl.shape} {lbl.dtype}")
break
# Initialize hypercomplex model
net = HCModel()
# Initialize loss function and optimizer
criterion = nn.CrossEntropyLoss()
optim = nn.SGD(net.trainable_params(), 1e-2)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(net, train_dataset, criterion, optim)
test(net, test_dataset, criterion)
print("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,115 @@
import argparse
import numpy as np
from mindspore import nn, context, ops
from mindspore.common import dtype as mstype
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms
from resnet import resnet18
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class ImageToDualImage:
@staticmethod
def __call__(img):
return np.concatenate((img, img), axis=0)
def create_dataset(dataset_dir, batch_size, usage=None):
dataset = Cifar10Dataset(dataset_dir=dataset_dir, usage=usage)
type_cast_op = transforms.TypeCast(mstype.int32)
# define map operations
trans = [vision.ToPIL(),
vision.RandomCrop((32, 32), (4, 4, 4, 4)),
vision.RandomHorizontalFlip(prob=0.5),
vision.Resize((224, 224)),
vision.ToTensor(),
vision.Rescale(1.0 / 255.0, 0.0),
vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010], is_hwc=False),
ImageToDualImage()]
dataset = dataset.map(operations=type_cast_op, input_columns="label")
dataset = dataset.map(operations=trans, input_columns="image")
dataset = dataset.batch(batch_size)
return dataset
def train(model, dataset, loss_fn, optimizer):
# Define forward function
def forward_fn(data, label):
logits = model(data)
loss = loss_fn(logits, label)
return loss, logits
# Get gradient function
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# Define function of one-step training
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
loss = ops.depend(loss, optimizer(grads))
return loss
size = dataset.get_dataset_size()
model.set_train()
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
loss = train_step(data, label)
if batch % 100 == 0:
loss, current = loss.asnumpy(), batch
print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
def test(model, dataset, loss_fn):
num_batches = dataset.get_dataset_size()
model.set_train(False)
total, test_loss, correct = 0, 0, 0
for data, label in dataset.create_tuple_iterator():
pred = model(data)
total += len(data)
test_loss += loss_fn(pred, label).asnumpy()
correct += (pred.argmax(1) == label).asnumpy().sum()
test_loss /= num_batches
correct /= total
print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
def main():
parser = argparse.ArgumentParser(description='MindSpore ResNet Testing')
parser.add_argument(
'--dataset', default=None, type=str, metavar='DS', required=True,
help='Path to the dataset folder'
)
parser.add_argument(
'--bs', default=64, type=int, metavar='N', required=False,
help='Mini-batch size'
)
args = parser.parse_args()
# Process the cifar dataset.
train_dataset = create_dataset(args.dataset, args.bs, "train")
test_dataset = create_dataset(args.dataset, args.bs, "test")
for img, lbl in test_dataset.create_tuple_iterator():
print(f"Shape of image [N, C, H, W]: {img.shape} {img.dtype}")
print(f"Shape of label: {lbl.shape} {lbl.dtype}")
break
# Initialize hypercomplex model
net = resnet18()
# Initialize loss function and optimizer
criterion = nn.CrossEntropyLoss()
optim = nn.SGD(net.trainable_params(), 1e-2)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(net, train_dataset, criterion, optim)
test(net, test_dataset, criterion)
print("Done!")
if __name__ == "__main__":
main()