From b126272b39cc1a16e138bec3ae24c4839d346bc9 Mon Sep 17 00:00:00 2001 From: bingyaweng Date: Fri, 21 Aug 2020 10:54:40 +0800 Subject: [PATCH] fix error of conv_variational --- mindspore/nn/probability/bnn_layers/conv_variational.py | 8 ++++---- tests/st/probability/test_bnn_layer.py | 2 +- tests/st/probability/test_transform_bnn_layer.py | 2 +- tests/st/probability/test_transform_bnn_model.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mindspore/nn/probability/bnn_layers/conv_variational.py b/mindspore/nn/probability/bnn_layers/conv_variational.py index b3f77e9a7ce..f0087870d69 100644 --- a/mindspore/nn/probability/bnn_layers/conv_variational.py +++ b/mindspore/nn/probability/bnn_layers/conv_variational.py @@ -61,10 +61,10 @@ class _ConvVariational(_Conv): raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') - if not isinstance(stride, (int, tuple)): + if isinstance(stride, bool) or not isinstance(stride, (int, tuple)): raise TypeError('The type of `stride` should be `int` of `tuple`') - if not isinstance(dilation, (int, tuple)): + if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)): raise TypeError('The type of `dilation` should be `int` of `tuple`') # convolution args @@ -136,8 +136,8 @@ class _ConvVariational(_Conv): return outputs def extend_repr(self): - str_info = 'in_channels={}, out_channels={}, kernel_size={}, weight_mean={}, stride={}, pad_mode={}, ' \ - 'padding={}, dilation={}, group={}, weight_std={}, has_bias={}'\ + str_info = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \ + 'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}'\ .format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std, self.has_bias) diff --git a/tests/st/probability/test_bnn_layer.py b/tests/st/probability/test_bnn_layer.py index b135d0bf085..742b17c2688 100644 --- a/tests/st/probability/test_bnn_layer.py +++ b/tests/st/probability/test_bnn_layer.py @@ -137,7 +137,7 @@ if __name__ == "__main__": epoch = 100 for i in range(epoch): - train_loss, train_acc = train_model(train_bnn_network, network, test_set) + train_loss, train_acc = train_model(train_bnn_network, network, train_set) valid_acc = validate_model(network, test_set) diff --git a/tests/st/probability/test_transform_bnn_layer.py b/tests/st/probability/test_transform_bnn_layer.py index 590fee8e811..3fd4bfd4001 100644 --- a/tests/st/probability/test_transform_bnn_layer.py +++ b/tests/st/probability/test_transform_bnn_layer.py @@ -142,7 +142,7 @@ if __name__ == "__main__": epoch = 100 for i in range(epoch): - train_loss, train_acc = train_model(train_bnn_network, network, test_set) + train_loss, train_acc = train_model(train_bnn_network, network, train_set) valid_acc = validate_model(network, test_set) diff --git a/tests/st/probability/test_transform_bnn_model.py b/tests/st/probability/test_transform_bnn_model.py index 015a1f41d76..5cc7733e891 100644 --- a/tests/st/probability/test_transform_bnn_model.py +++ b/tests/st/probability/test_transform_bnn_model.py @@ -141,7 +141,7 @@ if __name__ == "__main__": epoch = 500 for i in range(epoch): - train_loss, train_acc = train_model(train_bnn_network, network, test_set) + train_loss, train_acc = train_model(train_bnn_network, network, train_set) valid_acc = validate_model(network, test_set)