forked from mindspore-Ecosystem/mindspore
add bnn_layers to nn.probability
This commit is contained in:
parent
fb2f888ec8
commit
61dbb1b17c
@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bayesian Layers.

The high-level components (Cells) used to construct the Bayesian neural network.
"""
from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
from .conv_variational import ConvReparam
from .dense_variational import DenseReparam
from .layer_distribution import NormalPrior, NormalPosterior
from .bnn_cell_wrapper import WithBNNLossCell

__all__ = []
__all__.extend(conv_variational.__all__)
__all__.extend(dense_variational.__all__)
__all__.extend(layer_distribution.__all__)
__all__.extend(bnn_cell_wrapper.__all__)
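The exported components are enough to assemble a small Bayesian network by hand. A minimal sketch follows; the package path `mindspore.nn.probability.bnn_layers`, the LeNet-like shapes, and the plain `mindspore.nn` cells around the variational layers are illustrative assumptions, not part of this commit:

import mindspore.nn as nn
from mindspore.nn.probability.bnn_layers import ConvReparam, DenseReparam


class BNNLeNet(nn.Cell):
    """Toy backbone: one variational conv layer followed by one variational dense layer."""
    def __init__(self, num_classes=10):
        super(BNNLeNet, self).__init__()
        self.conv = ConvReparam(1, 6, 5, stride=1, pad_mode='valid')   # kernel sampled from a posterior
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc = DenseReparam(6 * 12 * 12, num_classes)               # 28x28 input -> 6x12x12 after conv + pool

    def construct(self, x):
        x = self.pool(self.relu(self.conv(x)))
        x = self.flatten(x)
        return self.fc(x)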
@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate WithLossCell suitable for BNN."""
from .conv_variational import _ConvVariational
from .dense_variational import _DenseVariational
from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss

__all__ = ['WithBNNLossCell']


class ClassWrap:
    """Decorator of WithBNNLossCell."""

    def __init__(self, cls):
        self._cls = cls
        self.bnn_loss_file = None

    def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
        obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
        bnn_with_loss = obj()
        self.bnn_loss_file = obj.bnn_loss_file
        return bnn_with_loss


@ClassWrap
class WithBNNLossCell:
    r"""
    Generate WithLossCell suitable for BNN.

    Args:
        backbone (Cell): The target network.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function.
            Default: 1.
        bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the Bayesian layers.
            Default: 1.

    Inputs:
        - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
        - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.

    Outputs:
        Tensor, a scalar tensor with shape :math:`()`.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> net_with_criterion_object = WithBNNLossCell(net, loss_fn)
        >>> net_with_criterion = net_with_criterion_object()
        >>>
        >>> batch_size = 2
        >>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
        >>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
        >>>
        >>> net_with_criterion(data, label)
    """

    def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
        self.backbone = backbone
        self.loss_fn = loss_fn
        self.dnn_factor = dnn_factor
        self.bnn_factor = bnn_factor
        self.bnn_loss_file = None

    def _generate_loss_cell(self):
        """Generate WithBNNLossCell by ast."""
        layer_count = self._kl_loss_count(self.backbone)
        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
                                                               self.dnn_factor, self.bnn_factor)
        return bnn_with_loss

    def _kl_loss_count(self, net):
        """Calculate the number of Bayesian layers."""
        count = 0
        for (_, layer) in net.name_cells().items():
            if isinstance(layer, (_DenseVariational, _ConvVariational)):
                count += 1
            else:
                count += self._kl_loss_count(layer)
        return count

    def __call__(self):
        return self._generate_loss_cell()
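A usage sketch, assuming the package path and the toy backbone from the earlier sketch; the optimizer and the one-step training wrapper are ordinary MindSpore cells. Note that, because of the `@ClassWrap` decorator, calling `WithBNNLossCell(...)` already returns the generated loss cell rather than an instance of the class:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.probability.bnn_layers import WithBNNLossCell

net = BNNLeNet()                       # the toy Bayesian backbone sketched earlier (hypothetical)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)

# ClassWrap.__call__ builds the wrapped object, invokes it, and hands back the loss cell,
# roughly: dnn_factor * task loss + bnn_factor * summed KL of the Bayesian layers.
bnn_with_loss = WithBNNLossCell(net, loss_fn, 1, 1)

optimizer = nn.Adam(params=net.trainable_params(), learning_rate=1e-3)
train_net = nn.TrainOneStepCell(bnn_with_loss, optimizer)

data = Tensor(np.random.rand(2, 1, 28, 28).astype(np.float32))
label = Tensor(np.array([1, 0]).astype(np.int32))
loss = train_net(data, label)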
@ -0,0 +1,270 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convolutional variational layers."""
from mindspore.ops import operations as P
from mindspore._checkparam import twice
from ...layer.conv import _Conv
from ...cell import Cell
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['ConvReparam']


class _ConvVariational(_Conv):
    """
    Base class for all convolutional variational layers.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_prior_fn=NormalPrior,
                 weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
                 bias_prior_fn=NormalPrior,
                 bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        kernel_size = twice(kernel_size)
        stride = twice(stride)
        dilation = twice(dilation)
        super(_ConvVariational, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init='normal',
            bias_init='zeros'
        )
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')

        # convolution args
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = padding
        self.dilation = dilation
        self.group = group
        self.has_bias = has_bias

        # distribution trainable parameters
        self.shape = [self.out_channels,
                      self.in_channels // self.group, *self.kernel_size]

        self.weight.requires_grad = False

        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()

        self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')

        if self.has_bias:
            self.bias.requires_grad = False

            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()

            self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')

        # mindspore operations
        self.bias_add = P.BiasAdd()
        self.conv2d = P.Conv2D(out_channel=self.out_channels,
                               kernel_size=self.kernel_size,
                               mode=1,
                               pad_mode=self.pad_mode,
                               pad=self.padding,
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group)

        self.log = P.Log()
        self.sum = P.ReduceSum()

    def construct(self, inputs):
        outputs = self._apply_variational_weight(inputs)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, padding={}, ' \
                   'dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}'\
            .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
                    self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std,
                    self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}'\
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
        return str_info

    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute the KL loss."""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal",
                               weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal",
                                 bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss
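For reference, when both prior and posterior are normal, the per-entry term evaluated by the `kl_loss` call above has the usual closed form

.. math::
    KL\big(\mathcal{N}(\mu_a, \sigma_a)\,\|\,\mathcal{N}(\mu_b, \sigma_b)\big)
    = \log\frac{\sigma_b}{\sigma_a} + \frac{\sigma_a^2 + (\mu_a - \mu_b)^2}{2\sigma_b^2} - \frac{1}{2},

and `compute_kl_loss` uses `P.ReduceSum` to aggregate this term over every kernel (and bias) entry.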

class ConvReparam(_ConvVariational):
    r"""
    Convolutional variational layer with reparameterization.

    See more details in the paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`_.

    Args:
        in_channels (int): The number of input channels :math:`C_{in}`.
        out_channels (int): The number of output channels :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): The data type is int or a
            tuple with 2 integers. Specifies the height and width of the 2D
            convolution window. A single int means the value is for both
            the height and the width of the kernel. A tuple of 2 ints means the
            first value is for the height and the other is for the width of
            the kernel.
        stride (Union[int, tuple[int]]): The distance of kernel moving,
            an int number that represents the height and width of movement
            are both strides, or a tuple of two int numbers that represent
            height and width of movement respectively. Default: 1.
        pad_mode (str): Specifies the padding mode. The optional values are
            "same", "valid", "pad". Default: "same".

            - same: Adopts the way of completion. Output height and width
              will be the same as the input. The total amount of padding
              will be calculated in the horizontal and vertical directions
              and evenly distributed to top and bottom, left and right if
              possible. Otherwise, the last extra padding will be done from
              the bottom and the right side. If this mode is set, `padding`
              must be 0.

            - valid: Adopts the way of discarding. The possibly largest
              height and width of the output will be returned without padding.
              Extra pixels will be discarded. If this mode is set, `padding`
              must be 0.

            - pad: Implicit paddings on both sides of the input. The number
              of `padding` will be padded to the input Tensor borders.
              `padding` should be greater than or equal to 0.

        padding (Union[int, tuple[int]]): Implicit paddings on both sides of
            the input. Default: 0.
        dilation (Union[int, tuple[int]]): The data type is int or a tuple
            with 2 integers. Specifies the dilation rate to use for dilated
            convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling
            location. Its value should be greater than or equal to 1 and bounded
            by the height and width of the input. Default: 1.
        group (int): Splits the filter into groups. `in_channels` and
            `out_channels` should be divisible by the number of groups.
            Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector.
            Default: False.
        weight_prior_fn: prior distribution for the convolution kernel.
            It should return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of a standard
            normal distribution).
        weight_posterior_fn: posterior distribution for sampling the convolution
            kernel. It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.
        bias_prior_fn: prior distribution for the bias vector. It should return
            a mindspore distribution.
            Default: NormalPrior (which creates an instance of a standard
            normal distribution).
        bias_posterior_fn: posterior distribution for sampling the bias vector.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = ConvReparam(120, 240, 4, has_bias=False)
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> net(input).shape
        (1, 240, 1024, 640)
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            pad_mode='same',
            padding=0,
            dilation=1,
            group=1,
            has_bias=False,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(ConvReparam, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            pad_mode=pad_mode,
            padding=padding,
            dilation=dilation,
            group=group,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )

    def _apply_variational_weight(self, inputs):
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.conv2d(inputs, weight_posterior_tensor)
        return outputs
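`_apply_variational_weight` is where the reparameterization happens: the kernel fed to `Conv2D` is a sample drawn from `NormalPosterior`, expressed as a deterministic transform of trainable parameters plus noise, so gradients can flow to the mean and the untransformed std. A sketch of that idea in plain NumPy (illustrative, not the MindSpore implementation):

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))  # same transform as NormalPosterior.std_trans, up to the 1e-6 offset

def sample_weight(mean, untransformed_std, rng=np.random):
    eps = rng.standard_normal(mean.shape)             # noise independent of the parameters
    return mean + softplus(untransformed_std) * eps   # differentiable w.r.t. mean and untransformed_std

w = sample_weight(np.zeros((240, 120, 4, 4)), -5 * np.ones((240, 120, 4, 4)))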
@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dense variational layers."""
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive, check_bool
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior

__all__ = ['DenseReparam']


class _DenseVariational(Cell):
    """
    Base class for all dense variational layers.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(_DenseVariational, self).__init__()
        self.in_channels = check_int_positive(in_channels)
        self.out_channels = check_int_positive(out_channels)
        self.has_bias = check_bool(has_bias)

        if isinstance(weight_prior_fn, Cell):
            self.weight_prior = weight_prior_fn
        else:
            self.weight_prior = weight_prior_fn()

        self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')

        if self.has_bias:
            if isinstance(bias_prior_fn, Cell):
                self.bias_prior = bias_prior_fn
            else:
                self.bias_prior = bias_prior_fn()

            self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')

        self.activation = activation
        if isinstance(self.activation, str):
            self.activation = get_activation(activation)
        self.activation_flag = self.activation is not None

        self.matmul = P.MatMul(transpose_b=True)
        self.bias_add = P.BiasAdd()
        self.sum = P.ReduceSum()

    def construct(self, x):
        outputs = self._apply_variational_weight(x)
        if self.has_bias:
            outputs = self._apply_variational_bias(outputs)
        if self.activation_flag:
            outputs = self.activation(outputs)
        return outputs

    def extend_repr(self):
        str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
            .format(self.in_channels, self.out_channels, self.weight_posterior.mean,
                    self.weight_posterior.untransformed_std, self.has_bias)
        if self.has_bias:
            str_info = str_info + ', bias_mean={}, bias_std={}' \
                .format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)

        if self.activation_flag:
            str_info = str_info + ', activation={}'.format(self.activation)
        return str_info

    def _apply_variational_bias(self, inputs):
        bias_posterior_tensor = self.bias_posterior("sample")
        return self.bias_add(inputs, bias_posterior_tensor)

    def compute_kl_loss(self):
        """Compute the KL loss."""
        weight_post_mean = self.weight_posterior("mean")
        weight_post_sd = self.weight_posterior("sd")

        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
        kl_loss = self.sum(kl)
        if self.has_bias:
            bias_post_mean = self.bias_posterior("mean")
            bias_post_sd = self.bias_posterior("sd")

            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
            kl = self.sum(kl)
            kl_loss += kl
        return kl_loss


class DenseReparam(_DenseVariational):
    r"""
    Dense variational layer with reparameterization.

    See more details in the paper `Auto-Encoding Variational Bayes
    <https://arxiv.org/abs/1312.6114>`_.

    Applies a dense-connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

    where :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the inputs, sampled from the posterior distribution of the weight,
    and :math:`\text{bias}` is a bias vector with the same data type as the inputs
    (only if has_bias is True), sampled from the posterior distribution of the bias.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
        weight_prior_fn: prior distribution for the weight.
            It should return a mindspore distribution instance.
            Default: NormalPrior (which creates an instance of a standard
            normal distribution).
        weight_posterior_fn: posterior distribution for sampling the weight.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.
        bias_prior_fn: prior distribution for the bias vector. It should return
            a mindspore distribution.
            Default: NormalPrior (which creates an instance of a standard
            normal distribution).
        bias_posterior_fn: posterior distribution for sampling the bias vector.
            It should be a function handle which returns a mindspore
            distribution instance.
            Default: NormalPosterior.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = DenseReparam(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            activation=None,
            has_bias=True,
            weight_prior_fn=NormalPrior,
            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
            bias_prior_fn=NormalPrior,
            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
        super(DenseReparam, self).__init__(
            in_channels,
            out_channels,
            activation=activation,
            has_bias=has_bias,
            weight_prior_fn=weight_prior_fn,
            weight_posterior_fn=weight_posterior_fn,
            bias_prior_fn=bias_prior_fn,
            bias_posterior_fn=bias_posterior_fn
        )

    def _apply_variational_weight(self, inputs):
        weight_posterior_tensor = self.weight_posterior("sample")
        outputs = self.matmul(inputs, weight_posterior_tensor)
        return outputs
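A small illustration (the layer sizes are arbitrary): because `_apply_variational_weight` draws a fresh kernel sample from the posterior on every call, repeated forward passes over the same input differ, and predictions are typically averaged over several passes:

import numpy as np
import mindspore
from mindspore import Tensor

dense = DenseReparam(3, 4, activation='relu')
x = Tensor(np.random.rand(2, 3), mindspore.float32)
y1 = dense(x)
y2 = dense(x)                              # generally differs from y1: a new kernel sample is drawn
samples = [dense(x) for _ in range(10)]    # a simple Monte Carlo ensemble over weight samples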
@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialize normal distributions."""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from ...cell import Cell
from ..distribution.normal import Normal

__all__ = ['NormalPrior', 'NormalPosterior']


class NormalPrior(Cell):
    r"""
    Initialize a normal distribution with mean 0 and standard deviation 0.1.

    Args:
        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
            Default: mindspore.float32.
        mean (int, float): Mean of the normal distribution. Default: 0.
        std (int, float): Standard deviation of the normal distribution. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """
    def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
        super(NormalPrior, self).__init__()
        self.normal = Normal(mean, std, dtype=dtype)

    def construct(self, *inputs):
        return self.normal(*inputs)


class NormalPosterior(Cell):
    r"""
    Build normal distributions with trainable parameters.

    Args:
        name (str): Name prepended to the trainable parameters.
        shape (list): Shape of the mean and standard deviation.
        dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
            Default: mindspore.float32.
        loc_mean (float, array_like of floats): Mean of the distribution used to initialize trainable parameters.
            Default: 0.
        loc_std (float, array_like of floats): Standard deviation of the distribution used to initialize trainable
            parameters. Default: 0.1.
        untransformed_scale_mean (float, array_like of floats): Mean of the distribution used to initialize trainable
            parameters. Default: -5.
        untransformed_scale_std (float, array_like of floats): Standard deviation of the distribution used to
            initialize trainable parameters. Default: 0.1.

    Returns:
        Cell, a normal distribution.
    """
    def __init__(self,
                 name,
                 shape,
                 dtype=mstype.float32,
                 loc_mean=0,
                 loc_std=0.1,
                 untransformed_scale_mean=-5,
                 untransformed_scale_std=0.1):
        super(NormalPosterior, self).__init__()
        if not isinstance(name, str):
            raise ValueError('The type of `name` should be `str`')

        self.mean = Parameter(
            Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')

        self.untransformed_std = Parameter(
            Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
            name=name + '_untransformed_std')

        self.normal = Normal()

    def std_trans(self, std_pre):
        """Transform std_pre to prevent its value being zero."""
        std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1)
        return std

    def construct(self, *inputs):
        std = self.std_trans(self.untransformed_std)
        return self.normal(*inputs, mean=self.mean, sd=std)
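Note that `std_trans` is the softplus transform, :math:`\sigma = 10^{-6} + \log(1 + e^{\rho})`, which keeps the standard deviation strictly positive while the trainable :math:`\rho` stays unconstrained; with the default `untransformed_scale_mean=-5`, the posterior starts out narrow, :math:`\sigma \approx \log(1 + e^{-5}) \approx 0.0067`.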
@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
import mindspore.nn as nn
+import mindspore.nn.probability as msp

def cast_to_tensor(t, hint_dtype=mstype.float32):
    """
@ -84,7 +85,7 @@ def check_scalar_from_param(params):
    Notes: String parameters are excluded.
    """
    for value in params.values():
-        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
+        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
            return params['distribution'].is_scalar_batch
        if isinstance(value, Parameter):
            return False
@ -109,7 +110,7 @@ def calc_broadcast_shape_from_param(params):
    """
    broadcast_shape = []
    for value in params.values():
-        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
+        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
            return params['distribution'].broadcast_shape
        if isinstance(value, (str, type(params['dtype']))):
            continue
@ -36,7 +36,7 @@ class _CodeTransformer(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        """visit function and add kl_loss computation."""
        self.generic_visit(node)
-        if node.name == 'compute_kl_loss':
+        if node.name == 'cal_kl_loss':
            for i in range(self.layer_count):
                func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
                                  value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
@ -71,7 +71,7 @@ def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
        layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
-        dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function.
+        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function.
        bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.
    """
    bnn_loss_func = _generate_kl_loss_func(layer_count)
@ -14,3 +14,4 @@ opencv-python >= 4.1.2.30 # for ut test
sklearn >= 0.0 # for st test
pandas >= 1.0.2 # for ut test
bs4
+astunparse