From 55b1d6feef2b32624d812078a49d0c3c5999087b Mon Sep 17 00:00:00 2001 From: caojian05 Date: Tue, 16 Jun 2020 11:22:52 +0800 Subject: [PATCH] refactoring code directory for vgg16 and lstm --- .../lstm_aclImdb => model_zoo/lstm}/README.md | 0 .../lstm_aclImdb => model_zoo/lstm}/eval.py | 6 +- model_zoo/lstm/src/__init__.py | 14 +++ .../lstm/src}/config.py | 0 .../lstm/src}/dataset.py | 4 +- .../lstm/src}/imdb.py | 0 model_zoo/lstm/src/lstm.py | 93 ++++++++++++++++ .../lstm_aclImdb => model_zoo/lstm}/train.py | 8 +- .../vgg16}/README.md | 2 +- .../vgg16_cifar10 => model_zoo/vgg16}/eval.py | 11 +- .../vgg16/scripts}/run_distribute_train.sh | 17 ++- model_zoo/vgg16/src/__init__.py | 14 +++ .../vgg16/src}/config.py | 0 .../vgg16/src}/dataset.py | 8 +- model_zoo/vgg16/src/vgg.py | 104 ++++++++++++++++++ .../vgg16}/train.py | 19 ++-- 16 files changed, 266 insertions(+), 34 deletions(-) rename {example/lstm_aclImdb => model_zoo/lstm}/README.md (100%) rename {example/lstm_aclImdb => model_zoo/lstm}/eval.py (94%) create mode 100644 model_zoo/lstm/src/__init__.py rename {example/lstm_aclImdb => model_zoo/lstm/src}/config.py (100%) rename {example/lstm_aclImdb => model_zoo/lstm/src}/dataset.py (96%) rename {example/lstm_aclImdb => model_zoo/lstm/src}/imdb.py (100%) create mode 100644 model_zoo/lstm/src/lstm.py rename {example/lstm_aclImdb => model_zoo/lstm}/train.py (94%) rename {example/vgg16_cifar10 => model_zoo/vgg16}/README.md (97%) rename {example/vgg16_cifar10 => model_zoo/vgg16}/eval.py (93%) rename {example/vgg16_cifar10 => model_zoo/vgg16/scripts}/run_distribute_train.sh (92%) create mode 100644 model_zoo/vgg16/src/__init__.py rename {example/vgg16_cifar10 => model_zoo/vgg16/src}/config.py (100%) rename {example/vgg16_cifar10 => model_zoo/vgg16/src}/dataset.py (96%) create mode 100644 model_zoo/vgg16/src/vgg.py rename {example/vgg16_cifar10 => model_zoo/vgg16}/train.py (93%) diff --git a/example/lstm_aclImdb/README.md b/model_zoo/lstm/README.md similarity index 100% 
rename from example/lstm_aclImdb/README.md rename to model_zoo/lstm/README.md diff --git a/example/lstm_aclImdb/eval.py b/model_zoo/lstm/eval.py similarity index 94% rename from example/lstm_aclImdb/eval.py rename to model_zoo/lstm/eval.py index e76d40ac67f..04e60d3a074 100644 --- a/example/lstm_aclImdb/eval.py +++ b/model_zoo/lstm/eval.py @@ -21,8 +21,8 @@ import os import numpy as np -from config import lstm_cfg as cfg -from dataset import create_dataset, convert_to_mindrecord +from src.config import lstm_cfg as cfg +from src.dataset import lstm_create_dataset, convert_to_mindrecord from mindspore import Tensor, nn, Model, context from mindspore.model_zoo.lstm import SentimentNet from mindspore.nn import Accuracy @@ -71,7 +71,7 @@ if __name__ == '__main__': model = Model(network, loss, opt, {'acc': Accuracy()}) print("============== Starting Testing ==============") - ds_eval = create_dataset(args.preprocess_path, cfg.batch_size, training=False) + ds_eval = lstm_create_dataset(args.preprocess_path, cfg.batch_size, training=False) param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) if args.device_target == "CPU": diff --git a/model_zoo/lstm/src/__init__.py b/model_zoo/lstm/src/__init__.py new file mode 100644 index 00000000000..301ef9dcb71 --- /dev/null +++ b/model_zoo/lstm/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ diff --git a/example/lstm_aclImdb/config.py b/model_zoo/lstm/src/config.py similarity index 100% rename from example/lstm_aclImdb/config.py rename to model_zoo/lstm/src/config.py diff --git a/example/lstm_aclImdb/dataset.py b/model_zoo/lstm/src/dataset.py similarity index 96% rename from example/lstm_aclImdb/dataset.py rename to model_zoo/lstm/src/dataset.py index 24797198e0c..03d4276dfd0 100644 --- a/example/lstm_aclImdb/dataset.py +++ b/model_zoo/lstm/src/dataset.py @@ -19,12 +19,12 @@ import os import numpy as np -from imdb import ImdbParser import mindspore.dataset as ds from mindspore.mindrecord import FileWriter +from .imdb import ImdbParser -def create_dataset(data_home, batch_size, repeat_num=1, training=True): +def lstm_create_dataset(data_home, batch_size, repeat_num=1, training=True): """Data operations.""" ds.config.set_seed(1) data_dir = os.path.join(data_home, "aclImdb_train.mindrecord0") diff --git a/example/lstm_aclImdb/imdb.py b/model_zoo/lstm/src/imdb.py similarity index 100% rename from example/lstm_aclImdb/imdb.py rename to model_zoo/lstm/src/imdb.py diff --git a/model_zoo/lstm/src/lstm.py b/model_zoo/lstm/src/lstm.py new file mode 100644 index 00000000000..f014eef8df0 --- /dev/null +++ b/model_zoo/lstm/src/lstm.py @@ -0,0 +1,93 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""LSTM.""" + +import numpy as np + +from mindspore import Tensor, nn, context +from mindspore.ops import operations as P + +# Initialize short-term memory (h) and long-term memory (c) to 0 +def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional): + """init default input.""" + num_directions = 1 + if bidirectional: + num_directions = 2 + + if context.get_context("device_target") == "CPU": + h_list = [] + c_list = [] + i = 0 + while i < num_layers: + hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) + h_list.append(hi) + ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)) + c_list.append(ci) + i = i + 1 + h = tuple(h_list) + c = tuple(c_list) + return h, c + + h = Tensor( + np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) + c = Tensor( + np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32)) + return h, c + + +class SentimentNet(nn.Cell): + """Sentiment network structure.""" + + def __init__(self, + vocab_size, + embed_size, + num_hiddens, + num_layers, + bidirectional, + num_classes, + weight, + batch_size): + super(SentimentNet, self).__init__() + # Mapp words to vectors + self.embedding = nn.Embedding(vocab_size, + embed_size, + embedding_table=weight) + self.embedding.embedding_table.requires_grad = False + self.trans = P.Transpose() + self.perm = (1, 0, 2) + self.encoder = nn.LSTM(input_size=embed_size, + hidden_size=num_hiddens, + num_layers=num_layers, + has_bias=True, + bidirectional=bidirectional, + dropout=0.0) + + self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) + + self.concat = P.Concat(1) + if bidirectional: + self.decoder = nn.Dense(num_hiddens * 4, num_classes) + else: + self.decoder = nn.Dense(num_hiddens * 2, num_classes) + + def construct(self, inputs): + # input:(64,500,300) + 
embeddings = self.embedding(inputs) + embeddings = self.trans(embeddings, self.perm) + output, _ = self.encoder(embeddings, (self.h, self.c)) + # states[i] size(64,200) -> encoding.size(64,400) + encoding = self.concat((output[0], output[-1])) + outputs = self.decoder(encoding) + return outputs diff --git a/example/lstm_aclImdb/train.py b/model_zoo/lstm/train.py similarity index 94% rename from example/lstm_aclImdb/train.py rename to model_zoo/lstm/train.py index 08bea7c63d0..fd0e7fdd15b 100644 --- a/example/lstm_aclImdb/train.py +++ b/model_zoo/lstm/train.py @@ -21,9 +21,9 @@ import os import numpy as np -from config import lstm_cfg as cfg -from dataset import convert_to_mindrecord -from dataset import create_dataset +from src.config import lstm_cfg as cfg +from src.dataset import convert_to_mindrecord +from src.dataset import lstm_create_dataset from mindspore import Tensor, nn, Model, context from mindspore.model_zoo.lstm import SentimentNet from mindspore.nn import Accuracy @@ -71,7 +71,7 @@ if __name__ == '__main__': model = Model(network, loss, opt, {'acc': Accuracy()}) print("============== Starting Training ==============") - ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) + ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs) config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck) diff --git a/example/vgg16_cifar10/README.md b/model_zoo/vgg16/README.md similarity index 97% rename from example/vgg16_cifar10/README.md rename to model_zoo/vgg16/README.md index 2c3de2eed96..75f36d28884 100644 --- a/example/vgg16_cifar10/README.md +++ b/model_zoo/vgg16/README.md @@ -98,7 +98,7 @@ parameters/options: ### Distribute Training ``` -Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH] +Usage: sh scripts/run_distribute_train.sh 
[MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH] parameters/options: MINDSPORE_HCCL_CONFIG_PATH HCCL configuration file path. diff --git a/example/vgg16_cifar10/eval.py b/model_zoo/vgg16/eval.py similarity index 93% rename from example/vgg16_cifar10/eval.py rename to model_zoo/vgg16/eval.py index ec9fc607c2d..8cdcc86031b 100644 --- a/example/vgg16_cifar10/eval.py +++ b/model_zoo/vgg16/eval.py @@ -17,14 +17,15 @@ python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID """ import argparse + import mindspore.nn as nn +from mindspore import context from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model -from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.model_zoo.vgg import vgg16 -from config import cifar_cfg as cfg -import dataset +from src.config import cifar_cfg as cfg +from src.dataset import vgg_create_dataset +from src.vgg import vgg16 if __name__ == '__main__': parser = argparse.ArgumentParser(description='Cifar10 classification') @@ -47,6 +48,6 @@ if __name__ == '__main__': param_dict = load_checkpoint(args_opt.checkpoint_path) load_param_into_net(net, param_dict) net.set_train(False) - dataset = dataset.create_dataset(args_opt.data_path, 1, False) + dataset = vgg_create_dataset(args_opt.data_path, 1, False) res = model.eval(dataset) print("result: ", res) diff --git a/example/vgg16_cifar10/run_distribute_train.sh b/model_zoo/vgg16/scripts/run_distribute_train.sh similarity index 92% rename from example/vgg16_cifar10/run_distribute_train.sh rename to model_zoo/vgg16/scripts/run_distribute_train.sh index c9b8dfc48f9..ca4c993deda 100755 --- a/example/vgg16_cifar10/run_distribute_train.sh +++ b/model_zoo/vgg16/scripts/run_distribute_train.sh @@ -15,39 +15,38 @@ # ============================================================================ if [ $# != 2 ] -then +then echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATA_PATH]" exit 1 fi if [ 
! -f $1 ] -then +then echo "error: MINDSPORE_HCCL_CONFIG_PATH=$1 is not a file" exit 1 -fi +fi if [ ! -d $2 ] -then +then echo "error: DATA_PATH=$2 is not a directory" exit 1 -fi +fi -ulimit -u unlimited export DEVICE_NUM=8 export RANK_SIZE=8 export MINDSPORE_HCCL_CONFIG_PATH=$1 -for((i=0; i<${DEVICE_NUM}; i++)) +for((i=0;i env.log python train.py --data_path=$2 --device_id=$i &> log & cd .. -done +done \ No newline at end of file diff --git a/model_zoo/vgg16/src/__init__.py b/model_zoo/vgg16/src/__init__.py new file mode 100644 index 00000000000..301ef9dcb71 --- /dev/null +++ b/model_zoo/vgg16/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ diff --git a/example/vgg16_cifar10/config.py b/model_zoo/vgg16/src/config.py similarity index 100% rename from example/vgg16_cifar10/config.py rename to model_zoo/vgg16/src/config.py diff --git a/example/vgg16_cifar10/dataset.py b/model_zoo/vgg16/src/dataset.py similarity index 96% rename from example/vgg16_cifar10/dataset.py rename to model_zoo/vgg16/src/dataset.py index e8dfd777e6b..b08659fb5ea 100644 --- a/example/vgg16_cifar10/dataset.py +++ b/model_zoo/vgg16/src/dataset.py @@ -16,13 +16,15 @@ Data operations, will be used in train.py and eval.py """ import os + +import mindspore.common.dtype as mstype import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as vision -import mindspore.common.dtype as mstype -from config import cifar_cfg as cfg +from .config import cifar_cfg as cfg -def create_dataset(data_home, repeat_num=1, training=True): + +def vgg_create_dataset(data_home, repeat_num=1, training=True): """Data operations.""" ds.config.set_seed(1) data_dir = os.path.join(data_home, "cifar-10-batches-bin") diff --git a/model_zoo/vgg16/src/vgg.py b/model_zoo/vgg16/src/vgg.py new file mode 100644 index 00000000000..55130871cc9 --- /dev/null +++ b/model_zoo/vgg16/src/vgg.py @@ -0,0 +1,104 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""VGG.""" +import mindspore.nn as nn +from mindspore.common.initializer import initializer +import mindspore.common.dtype as mstype + +def _make_layer(base, batch_norm): + """Make stage network of VGG.""" + layers = [] + in_channels = 3 + for v in base: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + weight_shape = (v, in_channels, 3, 3) + weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor() + conv2d = nn.Conv2d(in_channels=in_channels, + out_channels=v, + kernel_size=3, + padding=0, + pad_mode='same', + weight_init=weight) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] + else: + layers += [conv2d, nn.ReLU()] + in_channels = v + return nn.SequentialCell(layers) + + +class Vgg(nn.Cell): + """ + VGG network definition. + + Args: + base (list): Configuration for different layers, mainly the channel number of Conv layer. + num_classes (int): Class numbers. Default: 1000. + batch_norm (bool): Whether to do the batchnorm. Default: False. + batch_size (int): Batch size. Default: 1. + + Returns: + Tensor, infer output tensor. 
+ + Examples: + >>> Vgg([64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + >>> num_classes=1000, batch_norm=False, batch_size=1) + """ + + def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): + super(Vgg, self).__init__() + _ = batch_size + self.layers = _make_layer(base, batch_norm=batch_norm) + self.flatten = nn.Flatten() + self.classifier = nn.SequentialCell([ + nn.Dense(512 * 7 * 7, 4096), + nn.ReLU(), + nn.Dense(4096, 4096), + nn.ReLU(), + nn.Dense(4096, num_classes)]) + + def construct(self, x): + x = self.layers(x) + x = self.flatten(x) + x = self.classifier(x) + return x + + +cfg = { + '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + '16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + '19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def vgg16(num_classes=1000): + """ + Get Vgg16 neural network with batch normalization. + + Args: + num_classes (int): Class numbers. Default: 1000. + + Returns: + Cell, cell instance of Vgg16 neural network with batch normalization. 
+ + Examples: + >>> vgg16(num_classes=1000) + """ + + net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) + return net diff --git a/example/vgg16_cifar10/train.py b/model_zoo/vgg16/train.py similarity index 93% rename from example/vgg16_cifar10/train.py rename to model_zoo/vgg16/train.py index 9993db706a4..496aedb25a3 100644 --- a/example/vgg16_cifar10/train.py +++ b/model_zoo/vgg16/train.py @@ -19,20 +19,24 @@ python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID import argparse import os import random + import numpy as np + import mindspore.nn as nn from mindspore import Tensor +from mindspore import context from mindspore.communication.management import init from mindspore.nn.optim.momentum import Momentum -from mindspore.train.model import Model, ParallelMode -from mindspore import context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor -from mindspore.model_zoo.vgg import vgg16 -from dataset import create_dataset -from config import cifar_cfg as cfg +from mindspore.train.model import Model, ParallelMode +from src.config import cifar_cfg as cfg +from src.dataset import vgg_create_dataset +from src.vgg import vgg16 + random.seed(1) np.random.seed(1) + def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): """Set learning rate.""" lr_each_step = [] @@ -72,12 +76,13 @@ if __name__ == '__main__': mirror_mean=True) init() - dataset = create_dataset(args_opt.data_path, cfg.epoch_size) + dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size) batch_num = dataset.get_dataset_size() net = vgg16(num_classes=cfg.num_classes) lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num) - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay) + opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, + 
weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)