add bnn_layers to nn.probability

bingyaweng 2020-08-13 19:31:24 +08:00
parent fb2f888ec8
commit 61dbb1b17c
9 changed files with 685 additions and 5 deletions

View File

@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Bayesian Layer.
The high-level components (Cells) used to construct the Bayesian neural network.
"""
from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
from .conv_variational import ConvReparam
from .dense_variational import DenseReparam
from .layer_distribution import NormalPrior, NormalPosterior
from .bnn_cell_wrapper import WithBNNLossCell
__all__ = []
__all__.extend(conv_variational.__all__)
__all__.extend(dense_variational.__all__)
__all__.extend(layer_distribution.__all__)
__all__.extend(bnn_cell_wrapper.__all__)
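The re-exports above define the public surface of the new package. A minimal import sketch (the package path mindspore.nn.probability.bnn_layers is assumed from the commit title, not confirmed by this diff):
# Sketch only: import path assumed from the commit title.
from mindspore.nn.probability.bnn_layers import (
    ConvReparam,       # Bayesian 2D convolution with reparameterized sampling
    DenseReparam,      # Bayesian dense layer with reparameterized sampling
    NormalPrior,       # normal prior over weights/bias (mean 0, std 0.1 by default)
    NormalPosterior,   # normal posterior with trainable mean and std
    WithBNNLossCell,   # wraps backbone + loss_fn and adds the layers' KL terms
)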

View File

@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate WithLossCell suitable for BNN."""
from .conv_variational import _ConvVariational
from .dense_variational import _DenseVariational
from ..transforms.bnn_loss.generate_kl_loss import gain_bnn_with_loss
__all__ = ['WithBNNLossCell']
class ClassWrap:
"""Decorator of WithBNNLossCell"""
def __init__(self, cls):
self._cls = cls
self.bnn_loss_file = None
def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
bnn_with_loss = obj()
self.bnn_loss_file = obj.bnn_loss_file
return bnn_with_loss
@ClassWrap
class WithBNNLossCell:
r"""
Generate WithLossCell suitable for BNN.
Args:
backbone (Cell): The target network.
loss_fn (Cell): The loss function used to compute loss.
dnn_factor (int, float): The coefficient of backbone's loss, which is computed by the loss function. Default: 1.
bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the Bayesian layers. Default: 1.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
Outputs:
Tensor, a scalar tensor with shape :math:`()`.
Examples:
>>> net = Net()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
>>> net_with_criterion = WithBNNLossCell(net, loss_fn, 1, 1)
>>>
>>> batch_size = 2
>>> data = Tensor(np.ones([batch_size, 3, 64, 64]).astype(np.float32) * 0.01)
>>> label = Tensor(np.ones([batch_size, 1, 1, 1]).astype(np.int32))
>>>
>>> net_with_criterion(data, label)
"""
def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
self.backbone = backbone
self.loss_fn = loss_fn
self.dnn_factor = dnn_factor
self.bnn_factor = bnn_factor
self.bnn_loss_file = None
def _generate_loss_cell(self):
"""Generate WithBNNLossCell by ast."""
layer_count = self._kl_loss_count(self.backbone)
bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
self.dnn_factor, self.bnn_factor)
return bnn_with_loss
def _kl_loss_count(self, net):
""" Calculate the number of Bayesian layers."""
count = 0
for (_, layer) in net.name_cells().items():
if isinstance(layer, (_DenseVariational, _ConvVariational)):
count += 1
else:
count += self._kl_loss_count(layer)
return count
def __call__(self):
return self._generate_loss_cell()
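Note that the ClassWrap decorator binds the name WithBNNLossCell to a ClassWrap instance, so a single call builds the wrapper, runs the AST-based generation, and returns the resulting loss cell. A stripped-down, self-contained sketch of that decorator pattern (hypothetical names, independent of MindSpore):
# Self-contained illustration of the ClassWrap pattern used above (hypothetical example).
class ClassWrap:
    def __init__(self, cls):
        self._cls = cls

    def __call__(self, *args, **kwargs):
        obj = self._cls(*args, **kwargs)   # build the wrapped object
        return obj()                       # invoke it immediately and return the result

@ClassWrap
class Adder:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __call__(self):
        return self.a + self.b

print(Adder(2, 3))  # prints 5: the decorated name yields the result, not an Adder instance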

View File

@@ -0,0 +1,270 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convolutional variational layers."""
from mindspore.ops import operations as P
from mindspore._checkparam import twice
from ...layer.conv import _Conv
from ...cell import Cell
from .layer_distribution import NormalPrior, NormalPosterior
__all__ = ['ConvReparam']
class _ConvVariational(_Conv):
"""
Base class for all convolutional variational layers.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_prior_fn=NormalPrior,
weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
bias_prior_fn=NormalPrior,
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
kernel_size = twice(kernel_size)
stride = twice(stride)
dilation = twice(dilation)
super(_ConvVariational, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init='normal',
bias_init='zeros'
)
if pad_mode not in ('valid', 'same', 'pad'):
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
# convolution args
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.pad_mode = pad_mode
self.padding = padding
self.dilation = dilation
self.group = group
self.has_bias = has_bias
# distribution trainable parameters
self.shape = [self.out_channels,
self.in_channels // self.group, *self.kernel_size]
self.weight.requires_grad = False
if isinstance(weight_prior_fn, Cell):
self.weight_prior = weight_prior_fn
else:
self.weight_prior = weight_prior_fn()
self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
if self.has_bias:
self.bias.requires_grad = False
if isinstance(bias_prior_fn, Cell):
self.bias_prior = bias_prior_fn
else:
self.bias_prior = bias_prior_fn()
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
# mindspore operations
self.bias_add = P.BiasAdd()
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
self.log = P.Log()
self.sum = P.ReduceSum()
def construct(self, inputs):
outputs = self._apply_variational_weight(inputs)
if self.has_bias:
outputs = self._apply_variational_bias(outputs)
return outputs
def extend_repr(self):
str_info = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, padding={}, ' \
'dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}'\
.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std,
self.has_bias)
if self.has_bias:
str_info = str_info + ', bias_mean={}, bias_std={}'\
.format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
return str_info
def _apply_variational_bias(self, inputs):
bias_posterior_tensor = self.bias_posterior("sample")
return self.bias_add(inputs, bias_posterior_tensor)
def compute_kl_loss(self):
"""Compute kl loss"""
weight_post_mean = self.weight_posterior("mean")
weight_post_sd = self.weight_posterior("sd")
kl = self.weight_prior("kl_loss", "Normal",
weight_post_mean, weight_post_sd)
kl_loss = self.sum(kl)
if self.has_bias:
bias_post_mean = self.bias_posterior("mean")
bias_post_sd = self.bias_posterior("sd")
kl = self.bias_prior("kl_loss", "Normal",
bias_post_mean, bias_post_sd)
kl = self.sum(kl)
kl_loss += kl
return kl_loss
class ConvReparam(_ConvVariational):
r"""
Convolutional variational layers with Reparameterization.
See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or
tuple with 2 integers. Specifies the height and width of the 2D
convolution window. A single int means the value is for both
height and width of the kernel. A tuple of 2 ints means the
first value is for the height and the other is for the width of
the kernel.
stride(Union[int, tuple[int]]): The distance of kernel moving,
an int number that represents the height and width of movement
are both strides, or a tuple of two int numbers that represent
height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: Adopts the way of completion. Output height and width
will be the same as the input.
Total number of padding will be calculated for horizontal and
vertical direction and evenly distributed to top and bottom,
left and right if possible. Otherwise, the last extra padding
will be done from the bottom and the right side. If this mode
is set, `padding` must be 0.
- valid: Adopts the way of discarding. The possibly largest
height and width of output will be returned without padding.
Extra pixels will be discarded. If this mode is set, `padding`
must be 0.
- pad: Implicit paddings on both sides of the input. The number
of `padding` will be padded to the input Tensor borders.
`padding` should be greater than or equal to 0.
padding (Union[int, tuple[int]]): Implicit paddings on both sides of
the input. Default: 0.
dilation (Union[int, tuple[int]]): The data type is int or tuple
with 2 integers. Specifies the dilation rate to use for dilated
convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling
location. Its value should be greater or equal to 1 and bounded
by the height and width of the input. Default: 1.
group (int): Split filter into groups, `in_channels` and
`out_channels` should be divisible by the number of groups.
Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector.
Default: False.
weight_prior_fn: prior distribution for convolution kernel.
It should return a mindspore distribution instance.
Default: NormalPrior (which creates a normal distribution with mean 0 and standard deviation 0.1).
weight_posterior_fn: posterior distribution for sampling convolution
kernel. It should be a function handle which returns a mindspore
distribution instance.
Default: NormalPosterior.
bias_prior_fn: prior distribution for bias vector. It should return
a mindspore distribution.
Default: NormalPrior (which creates a normal distribution with mean 0 and standard deviation 0.1).
bias_posterior_fn: posterior distribution for sampling bias vector.
It should be a function handle which returns a mindspore
distribution instance.
Default: NormalPosterior.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = ConvReparam(120, 240, 4, has_bias=False)
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_prior_fn=NormalPrior,
weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
bias_prior_fn=NormalPrior,
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
super(ConvReparam, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_prior_fn=weight_prior_fn,
weight_posterior_fn=weight_posterior_fn,
bias_prior_fn=bias_prior_fn,
bias_posterior_fn=bias_posterior_fn
)
def _apply_variational_weight(self, inputs):
weight_posterior_tensor = self.weight_posterior("sample")
outputs = self.conv2d(inputs, weight_posterior_tensor)
return outputs
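A rough usage sketch for the layer above (shapes arbitrary; import path assumed from the commit title). Because construct samples the kernel from the weight posterior on every call, two forward passes on the same input generally differ, and compute_kl_loss returns the summed KL term that WithBNNLossCell adds to the training loss:
# Hypothetical usage sketch for ConvReparam; not part of this commit.
import numpy as np
from mindspore import Tensor
from mindspore.nn.probability.bnn_layers import ConvReparam

conv = ConvReparam(3, 8, 3, has_bias=True)               # 3 -> 8 channels, 3x3 kernel
x = Tensor(np.ones([1, 3, 16, 16]).astype(np.float32))

y1 = conv(x)   # kernel and bias are re-sampled from their posteriors on each call,
y2 = conv(x)   # so y1 and y2 are generally not identical

kl = conv.compute_kl_loss()   # scalar: KL(posterior || prior) summed over kernel and bias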

View File

@@ -0,0 +1,188 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dense_variational"""
from mindspore.ops import operations as P
from mindspore._checkparam import check_int_positive, check_bool
from ...cell import Cell
from ...layer.activation import get_activation
from .layer_distribution import NormalPrior, NormalPosterior
__all__ = ['DenseReparam']
class _DenseVariational(Cell):
"""
Base class for all dense variational layers.
"""
def __init__(
self,
in_channels,
out_channels,
activation=None,
has_bias=True,
weight_prior_fn=NormalPrior,
weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
bias_prior_fn=NormalPrior,
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
super(_DenseVariational, self).__init__()
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias)
if isinstance(weight_prior_fn, Cell):
self.weight_prior = weight_prior_fn
else:
self.weight_prior = weight_prior_fn()
self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
if self.has_bias:
if isinstance(bias_prior_fn, Cell):
self.bias_prior = bias_prior_fn
else:
self.bias_prior = bias_prior_fn()
self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
self.activation = activation
if isinstance(self.activation, str):
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()
self.sum = P.ReduceSum()
def construct(self, x):
outputs = self._apply_variational_weight(x)
if self.has_bias:
outputs = self._apply_variational_bias(outputs)
if self.activation_flag:
outputs = self.activation(outputs)
return outputs
def extend_repr(self):
str_info = 'in_channels={}, out_channels={}, weight_mean={}, weight_std={}, has_bias={}' \
.format(self.in_channels, self.out_channels, self.weight_posterior.mean,
self.weight_posterior.untransformed_std, self.has_bias)
if self.has_bias:
str_info = str_info + ', bias_mean={}, bias_std={}' \
.format(self.bias_posterior.mean, self.bias_posterior.untransformed_std)
if self.activation_flag:
str_info = str_info + ', activation={}'.format(self.activation)
return str_info
def _apply_variational_bias(self, inputs):
bias_posterior_tensor = self.bias_posterior("sample")
return self.bias_add(inputs, bias_posterior_tensor)
def compute_kl_loss(self):
"""Compute kl loss."""
weight_post_mean = self.weight_posterior("mean")
weight_post_sd = self.weight_posterior("sd")
kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
kl_loss = self.sum(kl)
if self.has_bias:
bias_post_mean = self.bias_posterior("mean")
bias_post_sd = self.bias_posterior("sd")
kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
kl = self.sum(kl)
kl_loss += kl
return kl_loss
class DenseReparam(_DenseVariational):
r"""
Dense variational layers with Reparameterization.
See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`_.
Applies dense-connected layer for the input. This layer implements the operation as:
.. math::
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),
where :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
data type as the inputs created by the layer, sampled from the posterior
distribution of the weight, and :math:`\text{bias}` is a bias vector with the same
data type as the inputs created by the layer (only if has_bias is True), sampled
from the posterior distribution of the bias.
Args:
in_channels (int): The number of input channel.
out_channels (int): The number of output channel.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Activation function applied to the output of the layer, e.g. 'relu'. Default: None.
weight_prior_fn: prior distribution for weight.
It should return a mindspore distribution instance.
Default: NormalPrior (which creates a normal distribution with mean 0 and standard deviation 0.1).
weight_posterior_fn: posterior distribution for sampling weight.
It should be a function handle which returns a mindspore
distribution instance.
Default: NormalPosterior.
bias_prior_fn: prior distribution for bias vector. It should return
a mindspore distribution.
Default: NormalPrior (which creates a normal distribution with mean 0 and standard deviation 0.1).
bias_posterior_fn: posterior distribution for sampling bias vector.
It should be a function handle which returns a mindspore
distribution instance.
Default: NormalPosterior.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = DenseReparam(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
def __init__(
self,
in_channels,
out_channels,
activation=None,
has_bias=True,
weight_prior_fn=NormalPrior,
weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
bias_prior_fn=NormalPrior,
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
super(DenseReparam, self).__init__(
in_channels,
out_channels,
activation=activation,
has_bias=has_bias,
weight_prior_fn=weight_prior_fn,
weight_posterior_fn=weight_posterior_fn,
bias_prior_fn=bias_prior_fn,
bias_posterior_fn=bias_posterior_fn
)
def _apply_variational_weight(self, inputs):
weight_posterior_tensor = self.weight_posterior("sample")
outputs = self.matmul(inputs, weight_posterior_tensor)
return outputs
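A similar sketch for DenseReparam (hypothetical, assuming the same import path). The forward pass computes activation(x @ W.T + b) with W and b drawn from their posteriors:
# Hypothetical usage sketch for DenseReparam; not part of this commit.
import numpy as np
from mindspore import Tensor
from mindspore.nn.probability.bnn_layers import DenseReparam

dense = DenseReparam(3, 4, activation='relu')            # y = relu(x @ W_sampled.T + b_sampled)
x = Tensor(np.random.randn(2, 3).astype(np.float32))

y = dense(x)                   # shape (2, 4); weights are re-sampled on every call
kl = dense.compute_kl_loss()   # summed KL divergence between posterior and prior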

View File

@@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Initialize normal distributions"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from ...cell import Cell
from ..distribution.normal import Normal
__all__ = ['NormalPrior', 'NormalPosterior']
class NormalPrior(Cell):
r"""
Initialize a normal prior distribution, by default with a mean of 0 and a standard deviation of 0.1.
Args:
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32.
mean (int, float): Mean of the normal distribution. Default: 0.
std (int, float): Standard deviation of the normal distribution. Default: 0.1.
Returns:
Cell, a normal distribution.
"""
def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
super(NormalPrior, self).__init__()
self.normal = Normal(mean, std, dtype=dtype)
def construct(self, *inputs):
return self.normal(*inputs)
class NormalPosterior(Cell):
r"""
Build Normal distributions with trainable parameters.
Args:
name (str): Name prepended to trainable parameter.
shape (list): Shape of the mean and standard deviation.
dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor.
Default: mindspore.float32.
loc_mean (float, array_like of floats): Mean of the distribution used to initialize the trainable mean. Default: 0.
loc_std (float, array_like of floats): Standard deviation of the distribution used to initialize the trainable mean.
Default: 0.1.
untransformed_scale_mean (float, array_like of floats): Mean of the distribution used to initialize the trainable
untransformed standard deviation. Default: -5.
untransformed_scale_std (float, array_like of floats): Standard deviation of the distribution used to initialize
the trainable untransformed standard deviation. Default: 0.1.
Returns:
Cell, a normal distribution.
"""
def __init__(self,
name,
shape,
dtype=mstype.float32,
loc_mean=0,
loc_std=0.1,
untransformed_scale_mean=-5,
untransformed_scale_std=0.1):
super(NormalPosterior, self).__init__()
if not isinstance(name, str):
raise ValueError('The type of `name` should be `str`')
self.mean = Parameter(
Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')
self.untransformed_std = Parameter(
Tensor(np.random.normal(untransformed_scale_mean, untransformed_scale_std, shape), dtype=dtype),
name=name + '_untransformed_std')
self.normal = Normal()
def std_trans(self, std_pre):
"""Transform std_pre to prevent its value being zero."""
std = 1e-6 + P.Log()(P.Exp()(std_pre) + 1)
return std
def construct(self, *inputs):
std = self.std_trans(self.untransformed_std)
return self.normal(*inputs, mean=self.mean, sd=std)
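std_trans above is a softplus, log(1 + exp(x)), plus a small offset, so the trainable untransformed_std can take any real value while the effective standard deviation stays strictly positive. A quick NumPy-only sketch of the mapping:
# Numeric sketch of the softplus mapping used in NormalPosterior.std_trans.
import numpy as np

def std_trans(untransformed_std):
    return 1e-6 + np.log(np.exp(untransformed_std) + 1.0)

print(std_trans(np.array([-5.0, 0.0, 5.0])))
# ~[0.0067, 0.6931, 5.0067]: always positive, close to the identity for large inputs
With the default untransformed_scale_mean of -5, the posterior therefore starts with a standard deviation of roughly 0.0067.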

View File

@@ -21,6 +21,7 @@ from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 import mindspore.nn as nn
+import mindspore.nn.probability as msp

 def cast_to_tensor(t, hint_dtype=mstype.float32):
     """
@@ -84,7 +85,7 @@ def check_scalar_from_param(params):
         Notes: String parameters are excluded.
     """
     for value in params.values():
-        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
+        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].is_scalar_batch
         if isinstance(value, Parameter):
             return False
@@ -109,7 +110,7 @@ def calc_broadcast_shape_from_param(params):
     """
     broadcast_shape = []
     for value in params.values():
-        if isinstance(value, (nn.probability.bijector.Bijector, nn.probability.distribution.Distribution)):
+        if isinstance(value, (msp.bijector.Bijector, msp.distribution.Distribution)):
             return params['distribution'].broadcast_shape
         if isinstance(value, (str, type(params['dtype']))):
             continue

View File

@@ -36,7 +36,7 @@ class _CodeTransformer(ast.NodeTransformer):
     def visit_FunctionDef(self, node):
         """visit function and add kl_loss computation."""
         self.generic_visit(node)
-        if node.name == 'compute_kl_loss':
+        if node.name == 'cal_kl_loss':
             for i in range(self.layer_count):
                 func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
                                   value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
@@ -71,7 +71,7 @@ def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
         layer_count (int): The number of kl loss to be generated, namely the number of Bayesian layers.
         backbone (Cell): The target network to wrap.
         loss_fn (Cell): The loss function used to compute loss.
-        dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function.
+        dnn_factor (int, float): The coefficient of backbone's loss, which is computed by loss function.
         bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.
     """
     bnn_loss_func = _generate_kl_loss_func(layer_count)
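The transformer above splices extra `loss = loss + ...` statements into the generated cell as raw ast nodes; the astunparse dependency added below turns the modified tree back into source text. A standalone sketch of that round trip (hypothetical, not this commit's code):
# Standalone sketch: build `loss = loss + kl` as an AST node and unparse it to source.
import ast
import astunparse

assign = ast.Assign(
    targets=[ast.Name(id='loss', ctx=ast.Store())],
    value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()),
                    op=ast.Add(),
                    right=ast.Name(id='kl', ctx=ast.Load())))
module = ast.Module(body=[assign], type_ignores=[])
ast.fix_missing_locations(module)
print(astunparse.unparse(module).strip())   # prints something like: loss = (loss + kl)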

View File

@@ -14,3 +14,4 @@ opencv-python >= 4.1.2.30  # for ut test
 sklearn >= 0.0  # for st test
 pandas >= 1.0.2  # for ut test
 bs4
+astunparse

View File

@@ -92,7 +92,8 @@ required_package = [
     'easydict >= 1.9',
     'sympy >= 1.4',
     'cffi >= 1.13.2',
-    'decorator >= 4.4.0'
+    'decorator >= 4.4.0',
+    'astunparse >= 1.6.3'
 ]

 package_data = {