forked from mindspore-Ecosystem/mindspore
133 lines
5.5 KiB
Python
133 lines
5.5 KiB
Python
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""start train """
|
|
import sys
|
|
import os
|
|
import pickle
|
|
import argparse
|
|
import lmdb
|
|
from mindspore.common import set_seed
|
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
|
from mindspore import context
|
|
from mindspore.context import ParallelMode
|
|
import mindspore.dataset as ds
|
|
from mindspore import nn
|
|
from mindspore.train import Model
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
|
import mindspore.dataset.transforms.py_transforms as py_transforms
|
|
from src.config import config
|
|
from src.alexnet import SiameseAlexNet
|
|
from src.dataset import ImagnetVIDDataset
|
|
from src.custom_transforms import ToTensor, RandomStretch, RandomCrop, CenterCrop
|
|
# Make modules in the current working directory importable.
# NOTE(review): this runs after the `src.*` imports above, so it cannot help
# resolve them -- confirm whether it should be moved before those imports.
sys.path.append(os.getcwd())
|
|
|
|
|
|
|
|
def train(data_dir):
    """Build the dataset, network and optimizer, then run training.

    Reads ``meta_data.pkl`` from ``data_dir``, opens the LMDB database at
    ``data_dir + '.lmdb'`` read-only, wraps both in an ``ImagnetVIDDataset``
    and trains ``SiameseAlexNet`` for ``config.epoch`` epochs using SGD with
    a polynomially decayed learning rate and dynamic loss scaling.

    Args:
        data_dir (str): Directory containing ``meta_data.pkl``. The LMDB
            database path is derived by appending ``.lmdb`` to this string.

    Raises:
        FileNotFoundError: If ``meta_data.pkl`` is missing from ``data_dir``.
    """
    # loading meta data
    meta_data_path = os.path.join(data_dir, "meta_data.pkl")
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    # NOTE(review): pickle executes arbitrary code on load; meta_data.pkl must
    # come from a trusted preprocessing step.
    with open(meta_data_path, 'rb') as meta_file:
        meta_data = pickle.load(meta_file)
    # The first element of every meta-data entry is taken as the video entry.
    all_videos = [x[0] for x in meta_data]

    set_seed(1234)

    # Instance crops are shrunk by one total stride on each side.
    random_crop_size = config.instance_size - 2 * config.total_stride
    train_z_transforms = py_transforms.Compose([
        RandomStretch(),
        CenterCrop((config.exemplar_size, config.exemplar_size)),
        ToTensor()
    ])
    train_x_transforms = py_transforms.Compose([
        RandomStretch(),
        RandomCrop((random_crop_size, random_crop_size),
                   config.max_translate),
        ToTensor()
    ])
    db_open = lmdb.open(data_dir + '.lmdb', readonly=True, map_size=int(50e12))

    # create dataset
    train_dataset = ImagnetVIDDataset(db_open, all_videos, data_dir,
                                      train_z_transforms, train_x_transforms)
    dataset = ds.GeneratorDataset(train_dataset, ["exemplar_img", "instance_img"], shuffle=True,
                                  num_parallel_workers=8)
    dataset = dataset.batch(batch_size=8, drop_remainder=True)

    # set network
    network = SiameseAlexNet(train=True)
    # power=1.0 makes the polynomial decay linear from config.lr to config.end_lr.
    decay_lr = nn.polynomial_decay_lr(config.lr,
                                      config.end_lr,
                                      total_step=config.epoch * config.num_per_epoch,
                                      step_per_epoch=config.num_per_epoch,
                                      decay_epoch=config.epoch,
                                      power=1.0)
    optim = nn.SGD(params=network.trainable_params(),
                   learning_rate=decay_lr,
                   momentum=config.momentum,
                   weight_decay=config.weight_decay)

    loss_scale_manager = DynamicLossScaleManager()
    model = Model(network,
                  optimizer=optim,
                  loss_scale_manager=loss_scale_manager,
                  metrics=None,
                  amp_level='O3')

    # callbacks: checkpointing, per-epoch timing and loss printing
    config_ck_train = CheckpointConfig(save_checkpoint_steps=6650, keep_checkpoint_max=20)
    # NOTE(review): the '{}' placeholder in `directory` is never formatted, so
    # this literal string is what gets passed -- confirm that is intended.
    ckpoint_cb_train = ModelCheckpoint(prefix='SiamFC',
                                       directory='./models/siamfc_{}.ckpt',
                                       config=config_ck_train)
    time_cb_train = TimeMonitor(data_size=config.num_per_epoch)
    loss_cb_train = LossMonitor()

    model.train(epoch=config.epoch,
                train_dataset=dataset,
                callbacks=[time_cb_train, ckpoint_cb_train, loss_cb_train],
                dataset_sink_mode=True
                )
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse command-line options, configure the MindSpore
    # execution context, then start training.
    ARGPARSER = argparse.ArgumentParser(description=" SiamFC Train")
    # The help text now lists every accepted choice, including the default.
    ARGPARSER.add_argument('--device_target',
                           type=str,
                           default="Ascend",
                           choices=['GPU', 'CPU', 'Ascend'],
                           help='the target device to run, support "GPU", "CPU", "Ascend"')
    ARGPARSER.add_argument('--data_path',
                           default="/data/VID/ILSVRC_VID_CURATION_train",
                           type=str,
                           help=" the path of data")
    # NOTE(review): --sink_size is parsed but never forwarded to train().
    ARGPARSER.add_argument('--sink_size',
                           type=int, default=-1,
                           help='control the amount of data in each sink')
    ARGPARSER.add_argument('--device_id',
                           type=int, default=7,
                           help='device id of GPU or Ascend')
    ARGS = ARGPARSER.parse_args()

    # The number of devices comes from the DEVICE_NUM environment variable
    # and defaults to a single device.
    DEVICENUM = int(os.environ.get("DEVICE_NUM", 1))
    DEVICETARGET = ARGS.device_target
    # NOTE(review): the context is only configured for Ascend; with "GPU" or
    # "CPU" no set_context call is made here -- confirm that is intended.
    if DEVICETARGET == "Ascend":
        context.set_context(
            mode=context.GRAPH_MODE,
            device_id=ARGS.device_id,
            save_graphs=False,
            device_target=ARGS.device_target)
        # Enable data-parallel training when more than one device is requested.
        if DEVICENUM > 1:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=DEVICENUM,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
    # train
    train(ARGS.data_path)
|