forked from mindspore-Ecosystem/mindspore

add transforms to nn.probability

This commit is contained in:
parent 2c2fe9bed9
commit 1b1ad52e7c
mindspore/nn/probability/transforms/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Transforms.

The high-level components used to transform a model between DNN and BNN.
"""
from . import transform_bnn
from .transform_bnn import TransformToBNN

__all__ = []
__all__.extend(transform_bnn.__all__)
mindspore/nn/probability/transforms/bnn_loss/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
BNN loss.
"""
from . import generate_kl_loss
from .generate_kl_loss import gain_bnn_with_loss
mindspore/nn/probability/transforms/bnn_loss/generate_kl_loss.py
@@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Gain bnn_with_loss by rewriting WithLossCell as WithBNNLossCell to suit the BNN model."""
import ast
import importlib
import os
import sys
import tempfile
import astunparse
import mindspore


class _CodeTransformer(ast.NodeTransformer):
    """
    Add the kl_loss computation by analyzing the Python code structure with the help of the AST module.

    Args:
        layer_count (int): The number of kl-loss terms to be generated, namely the number of Bayesian layers.
    """

    def __init__(self, layer_count):
        self.layer_count = layer_count

    def visit_FunctionDef(self, node):
        """Visit a function definition and add the kl_loss computation."""
        self.generic_visit(node)
        if node.name == 'cal_kl_loss':  # the placeholder method defined in withLossCell.py
            for i in range(self.layer_count):
                # Build `loss = loss + self.kl_loss[i]()`; astunparse emits the composite
                # ast.Name id verbatim, which is why the attribute/subscript is encoded as a string.
                func = ast.Assign(targets=[ast.Name(id='loss', ctx=ast.Store())],
                                  value=ast.BinOp(left=ast.Name(id='loss', ctx=ast.Load()), op=ast.Add(),
                                                  right=ast.Call(func=ast.Name(id='self.kl_loss' + '[' + str(i) + ']',
                                                                               ctx=ast.Load()),
                                                                 args=[], keywords=[])))
                node.body.insert(-1, func)  # keep the final `return loss` in place
        return node


def _generate_kl_loss_func(layer_count):
    """Rewrite WithLossCell as WithBNNLossCell to suit the BNN model."""
    path = os.path.dirname(mindspore.__file__) + '/nn/probability/transforms/bnn_loss/withLossCell.py'
    with open(path, 'r') as fp:
        srclines = fp.readlines()
    src = ''.join(srclines)
    if src.startswith((' ', '\t')):  # wrap indented source so ast.parse accepts it
        src = 'if 1:\n' + src
    expr_ast = ast.parse(src, mode='exec')
    transformer = _CodeTransformer(layer_count)
    modify = transformer.visit(expr_ast)
    modify = ast.fix_missing_locations(modify)
    func = astunparse.unparse(modify)
    return func


def gain_bnn_with_loss(layer_count, backbone, loss_fn, dnn_factor, bnn_factor):
    """
    Gain bnn_with_loss, which wraps the BNN network with the loss function and the kl loss of each Bayesian layer.

    Args:
        layer_count (int): The number of kl-loss terms to be generated, namely the number of Bayesian layers.
        backbone (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
        dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function.
        bnn_factor (int, float): The coefficient of the kl loss, which is the kl divergence of the Bayesian layers.
    """
    bnn_loss_func = _generate_kl_loss_func(layer_count)
    path = os.path.dirname(mindspore.__file__)
    bnn_loss_file = tempfile.NamedTemporaryFile(mode='w+t', suffix='.py', delete=True,
                                                dir=path + '/nn/probability/transforms/bnn_loss')
    bnn_loss_file.write(bnn_loss_func)
    bnn_loss_file.seek(0)

    sys.path.append(path + '/nn/probability/transforms/bnn_loss')

    module_name = os.path.basename(bnn_loss_file.name)[0:-3]
    bnn_loss_module = importlib.import_module(module_name, __package__)
    bnn_with_loss = bnn_loss_module.WithBNNLossCell(backbone, loss_fn, dnn_factor, bnn_factor)
    return bnn_with_loss, bnn_loss_file
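The AST pass above can be exercised in isolation. Below is a minimal, self-contained sketch of the same technique; the KLInserter class and SRC string are illustrative stand-ins, not part of this commit. It parses each inserted statement from text instead of hand-building the Assign node, which also sidesteps the composite ast.Name id trick used above:

    import ast
    import astunparse

    SRC = """def cal_kl_loss(self):
        loss = 0.0
        return loss
    """

    class KLInserter(ast.NodeTransformer):
        """Insert `loss = loss + self.kl_loss[i]()` statements before the final return."""

        def __init__(self, layer_count):
            self.layer_count = layer_count

        def visit_FunctionDef(self, node):
            self.generic_visit(node)
            if node.name == 'cal_kl_loss':
                for i in range(self.layer_count):
                    stmt = ast.parse('loss = loss + self.kl_loss[%d]()' % i).body[0]
                    node.body.insert(-1, stmt)  # keep `return loss` as the last statement
            return node

    tree = ast.fix_missing_locations(KLInserter(2).visit(ast.parse(SRC)))
    print(astunparse.unparse(tree))

Running this prints a cal_kl_loss whose body accumulates self.kl_loss[0]() and self.kl_loss[1]() before returning, which is exactly the shape _generate_kl_loss_func writes into the temporary module.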
mindspore/nn/probability/transforms/bnn_loss/withLossCell.py
@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Original WithBNNLossCell for the AST pass to rewrite."""

import mindspore.nn as nn
from mindspore.nn.probability.bnn_layers.conv_variational import _ConvVariational
from mindspore.nn.probability.bnn_layers.dense_variational import _DenseVariational


class WithBNNLossCell(nn.Cell):
    """
    Cell with loss function.

    Wraps the network with the loss function. This Cell accepts data and label as inputs and
    returns the computed loss, weighting the backbone loss by backbone_factor and the kl loss
    by kl_factor.
    """
    def __init__(self, backbone, loss_fn, backbone_factor=1, kl_factor=1):
        super(WithBNNLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn
        self.backbone_factor = backbone_factor
        self.kl_factor = kl_factor
        self.kl_loss = []
        self._add_kl_loss(self._backbone)

    def construct(self, x, label):
        y_pred = self._backbone(x)
        backbone_loss = self._loss_fn(y_pred, label)
        kl_loss = self.cal_kl_loss()
        loss = backbone_loss * self.backbone_factor + kl_loss * self.kl_factor
        return loss

    def cal_kl_loss(self):
        """Calculate kl loss. This placeholder body is rewritten by the AST pass in generate_kl_loss.py."""
        loss = 0.0
        return loss

    def _add_kl_loss(self, net):
        """Collect the kl loss of each Bayesian layer."""
        for (_, layer) in net.name_cells().items():
            if isinstance(layer, (_DenseVariational, _ConvVariational)):
                self.kl_loss.append(layer.compute_kl_loss)
            else:
                self._add_kl_loss(layer)
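For reference, once _generate_kl_loss_func has run with two Bayesian layers, the placeholder above comes out of astunparse roughly as follows (illustrative output, not checked-in source; docstring elided, and the parentheses are added by astunparse):

    def cal_kl_loss(self):
        loss = 0.0
        loss = (loss + self.kl_loss[0]())
        loss = (loss + self.kl_loss[1]())
        return loss

Each self.kl_loss[i] is a bound compute_kl_loss method collected by _add_kl_loss, so the rewritten method sums the kl divergence of every Bayesian layer.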
mindspore/nn/probability/transforms/transform_bnn.py
@@ -0,0 +1,246 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform DNN to BNN."""
import mindspore.nn as nn
from ...wrap.cell_wrapper import TrainOneStepCell
from ....nn import optim
from ....nn import layer
from .bnn_loss.generate_kl_loss import gain_bnn_with_loss
from ...probability import bnn_layers
from ..bnn_layers.conv_variational import ConvReparam
from ..bnn_layers.dense_variational import DenseReparam

__all__ = ['TransformToBNN']


class TransformToBNN:
    r"""
    Transform a Deep Neural Network (DNN) model to a Bayesian Neural Network (BNN) model.

    Args:
        trainable_dnn (Cell): A trainable DNN model (backbone) wrapped by TrainOneStepCell.
        dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function.
        bnn_factor (int, float): The coefficient of the kl loss, which is the kl divergence of the Bayesian layers.

    Examples:
        >>> class Net(nn.Cell):
        >>>     def __init__(self):
        >>>         super(Net, self).__init__()
        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
        >>>         self.bn = nn.BatchNorm2d(64)
        >>>         self.relu = nn.ReLU()
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Dense(64*224*224, 12) # padding=0
        >>>
        >>>     def construct(self, x):
        >>>         x = self.conv(x)
        >>>         x = self.bn(x)
        >>>         x = self.relu(x)
        >>>         x = self.flatten(x)
        >>>         out = self.fc(x)
        >>>         return out
        >>>
        >>> net = Net()
        >>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> net_with_loss = WithLossCell(net, criterion)
        >>> train_network = TrainOneStepCell(net_with_loss, optim)
        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
    """

    def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
        net_with_loss = trainable_dnn.network
        self.optimizer = trainable_dnn.optimizer
        self.backbone = net_with_loss.backbone_network
        self.loss_fn = getattr(net_with_loss, "_loss_fn")
        self.dnn_factor = dnn_factor
        self.bnn_factor = bnn_factor
        self.bnn_loss_file = None

    def transform_to_bnn_model(self,
                               get_dense_args=lambda dp: {"in_channels": dp.in_channels, "has_bias": dp.has_bias,
                                                          "out_channels": dp.out_channels, "activation": dp.activation},
                               get_conv_args=lambda dp: {"in_channels": dp.in_channels, "out_channels": dp.out_channels,
                                                         "pad_mode": dp.pad_mode, "kernel_size": dp.kernel_size,
                                                         "stride": dp.stride, "has_bias": dp.has_bias,
                                                         "padding": dp.padding, "dilation": dp.dilation,
                                                         "group": dp.group},
                               add_dense_args=None,
                               add_conv_args=None):
        r"""
        Transform the whole DNN model to a BNN model, and wrap the BNN model by TrainOneStepCell.

        Args:
            get_dense_args (function): The arguments gotten from the DNN fully connected layer. Default: lambda dp:
                {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias,
                "activation": dp.activation}.
            get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
                {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
                "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias,
                "padding": dp.padding, "dilation": dp.dilation, "group": dp.group}.
            add_dense_args (dict): The new arguments added to the BNN fully connected layer. Default: {}.
            add_conv_args (dict): The new arguments added to the BNN convolutional layer. Default: {}.

        Returns:
            Cell, a trainable BNN model wrapped by TrainOneStepCell.

        Examples:
            >>> net = Net()
            >>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> net_with_loss = WithLossCell(net, criterion)
            >>> train_network = TrainOneStepCell(net_with_loss, optim)
            >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
            >>> train_bnn_network = bnn_transformer.transform_to_bnn_model()
        """
        if not add_dense_args:
            add_dense_args = {}
        if not add_conv_args:
            add_conv_args = {}

        layer_count = self._replace_all_bnn_layers(self.backbone, get_dense_args, get_conv_args, add_dense_args,
                                                   add_conv_args)

        # Rename the parameters of the BNN model to prevent duplication of names.
        for value, param in self.backbone.parameters_and_names():
            param.name = value

        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
                                                               self.dnn_factor, self.bnn_factor)
        bnn_optimizer = self._create_optimizer_with_bnn_params()
        train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
        return train_bnn_network

    def transform_to_bnn_layer(self, dnn_layer_type, bnn_layer_type, get_args=None, add_args=None):
        r"""
        Transform a specific type of layer in the DNN model to the corresponding BNN layer.

        Args:
            dnn_layer_type (Cell): The type of DNN layer to be transformed to a BNN layer. The optional values are
                nn.Dense and nn.Conv2d.
            bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
                DenseReparam and ConvReparam.
            get_args (dict): The arguments gotten from the DNN layer. Default: None.
            add_args (dict): The new arguments added to the BNN layer. Default: None.

        Returns:
            Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the
            corresponding Bayesian layer.

        Examples:
            >>> net = Net()
            >>> criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> net_with_loss = WithLossCell(net, criterion)
            >>> train_network = TrainOneStepCell(net_with_loss, optim)
            >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
            >>> train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, DenseReparam)
        """
        if dnn_layer_type.__name__ not in ["Dense", "Conv2d"]:
            raise ValueError("'dnn_layer_type' is " + str(dnn_layer_type) +
                             ", but should be one of 'nn.Dense' and 'nn.Conv2d'.")

        if bnn_layer_type.__name__ not in ["DenseReparam", "ConvReparam"]:
            raise ValueError("'bnn_layer_type' is " + str(bnn_layer_type) +
                             ", but should be one of 'DenseReparam' and 'ConvReparam'.")

        dnn_layer_type = getattr(layer, dnn_layer_type.__name__)
        bnn_layer_type = getattr(bnn_layers, bnn_layer_type.__name__)

        if not get_args:
            if dnn_layer_type.__name__ == "Dense":
                get_args = self._get_dense_args
            else:
                get_args = self._get_conv_args

        if not add_args:
            add_args = {}

        layer_count = self._replace_specified_dnn_layers(self.backbone, dnn_layer_type, bnn_layer_type, get_args,
                                                         add_args)
        for value, param in self.backbone.parameters_and_names():
            param.name = value

        bnn_with_loss, self.bnn_loss_file = gain_bnn_with_loss(layer_count, self.backbone, self.loss_fn,
                                                               self.dnn_factor, self.bnn_factor)
        bnn_optimizer = self._create_optimizer_with_bnn_params()

        train_bnn_network = TrainOneStepCell(bnn_with_loss, bnn_optimizer)
        return train_bnn_network

    def _get_dense_args(self, dense_layer):
        """Get arguments from a dense layer."""
        dense_args = {"in_channels": dense_layer.in_channels, "has_bias": dense_layer.has_bias,
                      "out_channels": dense_layer.out_channels, "activation": dense_layer.activation}
        return dense_args

    def _get_conv_args(self, conv_layer):
        """Get arguments from a conv2d layer."""
        conv_args = {"in_channels": conv_layer.in_channels, "out_channels": conv_layer.out_channels,
                     "pad_mode": conv_layer.pad_mode, "kernel_size": conv_layer.kernel_size,
                     "stride": conv_layer.stride, "has_bias": conv_layer.has_bias,
                     "padding": conv_layer.padding, "dilation": conv_layer.dilation,
                     "group": conv_layer.group}
        return conv_args

    def _create_optimizer_with_bnn_params(self):
        """Create a new optimizer that contains the BNN trainable parameters."""
        name = self.optimizer.__class__.__name__
        modules = optim.__all__

        if name not in modules:
            raise TypeError('The optimizer can be {}, but got {}'.format(str(modules), name))

        optimizer = getattr(optim, name)

        # Re-instantiate the same optimizer class over the new (Bayesian) parameters,
        # copying over whichever hyperparameters its __init__ accepts.
        args = {'params': self.backbone.trainable_params()}
        params = optimizer.__init__.__code__.co_varnames
        _params = self.optimizer.__dict__['_params']
        for param in params:
            if param in _params:
                args[param] = self.optimizer.__getattr__(param).data.asnumpy().tolist()

        new_optimizer = optimizer(**args)
        return new_optimizer

    def _replace_all_bnn_layers(self, backbone, get_dense_args, get_conv_args, add_dense_args, add_conv_args):
        """Replace both the dense layers and the conv2d layers in the DNN model with Bayesian layers."""
        count = 0
        for name, cell in backbone.name_cells().items():
            if isinstance(cell, nn.Dense):
                dense_args = get_dense_args(cell)
                new_layer = DenseReparam(**dense_args, **add_dense_args)
                setattr(backbone, name, new_layer)
                count += 1
            elif isinstance(cell, nn.Conv2d):
                conv_args = get_conv_args(cell)
                new_layer = ConvReparam(**conv_args, **add_conv_args)
                setattr(backbone, name, new_layer)
                count += 1
            else:
                count += self._replace_all_bnn_layers(cell, get_dense_args, get_conv_args, add_dense_args,
                                                      add_conv_args)
        return count

    def _replace_specified_dnn_layers(self, backbone, dnn_layer, bnn_layer, get_args, add_args):
        """Convert a specific type of layer in the DNN model to the corresponding Bayesian layer."""
        count = 0
        for name, cell in backbone.name_cells().items():
            if isinstance(cell, dnn_layer):
                args = get_args(cell)
                new_layer = bnn_layer(**args, **add_args)
                setattr(backbone, name, new_layer)
                count += 1
            else:
                count += self._replace_specified_dnn_layers(cell, dnn_layer, bnn_layer, get_args, add_args)
        return count
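Putting the pieces together, here is an end-to-end sketch using the Net and hyperparameters from the docstring examples above (ds_train is an assumed iterable yielding (data, label) pairs; the surrounding MindSpore setup is not part of this commit):

    import mindspore.nn as nn
    from mindspore.nn import WithLossCell, TrainOneStepCell
    from mindspore.nn.probability.transforms import TransformToBNN

    net = Net()  # the DNN backbone from the docstring example
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_network = TrainOneStepCell(WithLossCell(net, criterion), optimizer)

    # dnn_factor weights the task loss; bnn_factor weights the summed kl terms.
    bnn_transformer = TransformToBNN(train_network, dnn_factor=60000, bnn_factor=0.1)
    train_bnn_network = bnn_transformer.transform_to_bnn_model()

    # train_bnn_network is a TrainOneStepCell: each call runs one optimization step.
    for data, label in ds_train:  # assumes (data, label) batches
        loss = train_bnn_network(data, label)

Note that the transformer stores the temporary loss module on self.bnn_loss_file; because the file is created with delete=True, keeping that reference alive is what prevents the generated module from disappearing while the wrapped network still uses it.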