This commit is contained in:
bingyaweng 2020-08-20 16:05:14 +08:00
parent 99bac63475
commit 2037336195
4 changed files with 504 additions and 0 deletions

View File

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds

View File

@ -0,0 +1,145 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test bnn layers"""
import numpy as np
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal
import mindspore.nn as nn
from mindspore.nn import TrainOneStepCell
from mindspore.nn.probability import bnn_layers
from mindspore.ops import operations as P
from mindspore import context
from dataset import create_dataset
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class BNNLeNet5(nn.Cell):
"""
bayesian Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> BNNLeNet5(num_class=10)
"""
def __init__(self, num_class=10):
super(BNNLeNet5, self).__init__()
self.num_class = num_class
self.conv1 = bnn_layers.ConvReparam(1, 6, 5, stride=1, padding=0, has_bias=False, pad_mode="valid")
self.conv2 = conv(6, 16, 5)
self.fc1 = bnn_layers.DenseReparam(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
def train_model(train_net, net, dataset):
accs = []
loss_sum = 0
for _, data in enumerate(dataset.create_dict_iterator()):
train_x = Tensor(data['image'].astype(np.float32))
label = Tensor(data['label'].astype(np.int32))
loss = train_net(train_x, label)
output = net(train_x)
log_output = P.LogSoftmax(axis=1)(output)
acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
accs.append(acc)
loss_sum += loss.asnumpy()
loss_sum = loss_sum / len(accs)
acc_mean = np.mean(accs)
return loss_sum, acc_mean
def validate_model(net, dataset):
accs = []
for _, data in enumerate(dataset.create_dict_iterator()):
train_x = Tensor(data['image'].astype(np.float32))
label = Tensor(data['label'].astype(np.int32))
output = net(train_x)
log_output = P.LogSoftmax(axis=1)(output)
acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
accs.append(acc)
acc_mean = np.mean(accs)
return acc_mean
if __name__ == "__main__":
network = BNNLeNet5()
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)
net_with_loss = bnn_layers.WithBNNLossCell(network, criterion, 60000, 0.000001)
train_bnn_network = TrainOneStepCell(net_with_loss, optimizer)
train_bnn_network.set_train()
train_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/train', 64, 1)
test_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/test', 64, 1)
epoch = 100
for i in range(epoch):
train_loss, train_acc = train_model(train_bnn_network, network, test_set)
valid_acc = validate_model(network, test_set)
print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tvalidation Accuracy: {:.4f}'.format(
i, train_loss, train_acc, valid_acc))

View File

@ -0,0 +1,150 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test transform_to_bnn_layer"""
import numpy as np
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal
import mindspore.nn as nn
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.probability import transforms, bnn_layers
from mindspore.ops import operations as P
from mindspore import context
from dataset import create_dataset
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet5(num_class=10)
"""
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
def train_model(train_net, net, dataset):
accs = []
loss_sum = 0
for _, data in enumerate(dataset.create_dict_iterator()):
train_x = Tensor(data['image'].astype(np.float32))
label = Tensor(data['label'].astype(np.int32))
loss = train_net(train_x, label)
output = net(train_x)
log_output = P.LogSoftmax(axis=1)(output)
acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
accs.append(acc)
loss_sum += loss.asnumpy()
loss_sum = loss_sum / len(accs)
acc_mean = np.mean(accs)
return loss_sum, acc_mean
def validate_model(net, dataset):
accs = []
for _, data in enumerate(dataset.create_dict_iterator()):
train_x = Tensor(data['image'].astype(np.float32))
label = Tensor(data['label'].astype(np.int32))
output = net(train_x)
log_output = P.LogSoftmax(axis=1)(output)
acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
accs.append(acc)
acc_mean = np.mean(accs)
return acc_mean
if __name__ == "__main__":
network = LeNet5()
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)
net_with_loss = WithLossCell(network, criterion)
train_network = TrainOneStepCell(net_with_loss, optimizer)
bnn_transformer = transforms.TransformToBNN(train_network, 60000, 0.000001)
train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Conv2d, bnn_layers.ConvReparam)
# train_bnn_network = bnn_transformer.transform_to_bnn_layer(nn.Dense, bnn_layers.DenseReparam)
train_bnn_network.set_train()
train_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/train', 64, 1)
test_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/test', 64, 1)
epoch = 100
for i in range(epoch):
train_loss, train_acc = train_model(train_bnn_network, network, test_set)
valid_acc = validate_model(network, test_set)
print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tvalidation Accuracy: {:.4f}'.format(
i, train_loss, train_acc, valid_acc))

View File

@ -0,0 +1,149 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test transform_to_bnn_model"""
import numpy as np
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal
import mindspore.nn as nn
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn.probability import transforms
from mindspore.ops import operations as P
from mindspore import context
from dataset import create_dataset
context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet5(num_class=10)
"""
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
def train_model(train_net, net, dataset):
accs = []
loss_sum = 0
for _, data in enumerate(dataset.create_dict_iterator()):
train_x = Tensor(data['image'].astype(np.float32))
label = Tensor(data['label'].astype(np.int32))
loss = train_net(train_x, label)
output = net(train_x)
log_output = P.LogSoftmax(axis=1)(output)
acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
accs.append(acc)
loss_sum += loss.asnumpy()
loss_sum = loss_sum / len(accs)
acc_mean = np.mean(accs)
return loss_sum, acc_mean
def validate_model(net, dataset):
accs = []
for _, data in enumerate(dataset.create_dict_iterator()):
train_x = Tensor(data['image'].astype(np.float32))
label = Tensor(data['label'].astype(np.int32))
output = net(train_x)
log_output = P.LogSoftmax(axis=1)(output)
acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
accs.append(acc)
acc_mean = np.mean(accs)
return acc_mean
if __name__ == "__main__":
network = LeNet5()
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
optimizer = nn.AdamWeightDecay(params=network.trainable_params(), learning_rate=0.0001)
net_with_loss = WithLossCell(network, criterion)
train_network = TrainOneStepCell(net_with_loss, optimizer)
bnn_transformer = transforms.TransformToBNN(train_network, 60000, 0.000001)
train_bnn_network = bnn_transformer.transform_to_bnn_model()
train_bnn_network.set_train()
train_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/train', 64, 1)
test_set = create_dataset('/home/workspace/mindspore_dataset/mnist_data/test', 64, 1)
epoch = 500
for i in range(epoch):
train_loss, train_acc = train_model(train_bnn_network, network, test_set)
valid_acc = validate_model(network, test_set)
print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tvalidation Accuracy: {:.4f}'.format(
i, train_loss, train_acc, valid_acc))