diff --git a/model_zoo/official/cv/resnext50/README.md b/model_zoo/official/cv/resnext50/README.md index 55b54d84ac..aab2952c96 100644 --- a/model_zoo/official/cv/resnext50/README.md +++ b/model_zoo/official/cv/resnext50/README.md @@ -90,10 +90,15 @@ sh run_standalone_train.sh DEVICE_ID DATA_PATH #### Launch ```bash -# distributed training example(8p) +# distributed training example(8p) for Ascend sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /dataset/train -# standalone training example +# standalone training example for Ascend sh scripts/run_standalone_train.sh 0 /dataset/train + +# distributed training example(8p) for GPU +sh scripts/run_distribute_train_for_gpu.sh /dataset/train +# standalone training example for GPU +sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train ``` #### Result @@ -106,14 +111,15 @@ You can find checkpoint file together with result in log. ``` # Evaluation -sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH +sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM ``` +PLATFORM is either Ascend or GPU; if omitted, it defaults to Ascend. #### Launch ```bash # Evaluation with checkpoint -sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt +sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt Ascend ``` > checkpoint can be produced in training process. 
diff --git a/model_zoo/official/cv/resnext50/eval.py b/model_zoo/official/cv/resnext50/eval.py index ff5c83843e..4dc2aa485a 100644 --- a/model_zoo/official/cv/resnext50/eval.py +++ b/model_zoo/official/cv/resnext50/eval.py @@ -29,15 +29,11 @@ from mindspore.ops import functional as F from mindspore.common import dtype as mstype from src.utils.logging import get_logger +from src.utils.auto_mixed_precision import auto_mixed_precision from src.image_classification import get_network from src.dataset import classification_dataset from src.config import config -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target="Ascend", save_graphs=False, device_id=devid) - - class ParameterReduce(nn.Cell): """ParameterReduce""" @@ -56,6 +52,7 @@ class ParameterReduce(nn.Cell): def parse_args(cloud_args=None): """parse_args""" parser = argparse.ArgumentParser('mindspore classification test') + parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform') # dataset related parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir') @@ -108,12 +105,25 @@ def merge_args(args, cloud_args): def test(cloud_args=None): """test""" args = parse_args(cloud_args) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.platform, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) # init distributed if args.is_distributed: - init() + if args.platform == "Ascend": + init() + elif args.platform == "GPU": + init("nccl") args.rank = get_rank() args.group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, + parameter_broadcast=True, mirror_mean=True) + else: + args.rank = 0 + args.group_size = 
1 args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) @@ -140,7 +150,7 @@ def test(cloud_args=None): max_epoch=1, rank=args.rank, group_size=args.group_size, mode='eval') eval_dataloader = de_dataset.create_tuple_iterator() - network = get_network(args.backbone, args.num_classes) + network = get_network(args.backbone, args.num_classes, platform=args.platform) if network is None: raise NotImplementedError('not implement {}'.format(args.backbone)) @@ -157,12 +167,13 @@ def test(cloud_args=None): load_param_into_net(network, param_dict_new) args.logger.info('load model {} success'.format(model)) - # must add - network.add_flags_recursive(fp16=True) - img_tot = 0 top1_correct = 0 top5_correct = 0 + if args.platform == "Ascend": + network.to_float(mstype.float16) + else: + auto_mixed_precision(network) network.set_train(False) t_end = time.time() it = 0 diff --git a/model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh new file mode 100644 index 0000000000..6ab980a0fa --- /dev/null +++ b/model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +DATA_DIR=$1 # training data directory (first argument) +export RANK_SIZE=8 +PATH_CHECKPOINT="" # stays empty unless a second argument is given +if [ $# -eq 2 ] # -eq is the POSIX numeric test; '==' is a bash extension and the README launches this script via sh +then + PATH_CHECKPOINT=$2 +fi + +mpirun --allow-run-as-root -n $RANK_SIZE \ + python train.py \ + --is_distribute=1 \ + --platform="GPU" \ + --pretrained="$PATH_CHECKPOINT" \ + --data_dir="$DATA_DIR" > log.txt 2>&1 & diff --git a/model_zoo/official/cv/resnext50/scripts/run_eval.sh b/model_zoo/official/cv/resnext50/scripts/run_eval.sh index 610faa874e..c884180950 100644 --- a/model_zoo/official/cv/resnext50/scripts/run_eval.sh +++ b/model_zoo/official/cv/resnext50/scripts/run_eval.sh @@ -14,11 +14,16 @@ # limitations under the License. # ============================================================================ -DEVICE_ID=$1 +export DEVICE_ID=$1 DATA_DIR=$2 PATH_CHECKPOINT=$3 +PLATFORM=Ascend +if [ $# -eq 4 ] +then + PLATFORM=$4 +fi python eval.py \ - --device_id=$DEVICE_ID \ --pretrained=$PATH_CHECKPOINT \ + --platform="$PLATFORM" \ --data_dir=$DATA_DIR > log.txt 2>&1 & diff --git a/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh b/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh index ca5d8206f3..f10d7a2f57 100644 --- a/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh +++ b/model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh @@ -14,7 +14,7 @@ # limitations under the License. 
# ============================================================================ -DEVICE_ID=$1 +export DEVICE_ID=$1 DATA_DIR=$2 PATH_CHECKPOINT="" if [ $# == 3 ] diff --git a/model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh new file mode 100644 index 0000000000..1d1d82fb88 --- /dev/null +++ b/model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +export DEVICE_ID=$1 +DATA_DIR=$2 # training data directory (second argument) +PATH_CHECKPOINT="" # stays empty unless a third argument is given +if [ $# -eq 3 ] # -eq is the POSIX numeric test; '==' is a bash extension and the README launches this script via sh +then + PATH_CHECKPOINT=$3 +fi + +python train.py \ + --is_distribute=0 \ + --pretrained="$PATH_CHECKPOINT" \ + --platform="GPU" \ + --data_dir="$DATA_DIR" > log.txt 2>&1 & + diff --git a/model_zoo/official/cv/resnext50/src/backbone/resnet.py b/model_zoo/official/cv/resnext50/src/backbone/resnet.py index 5b69f9e1f5..9c880154ea 100644 --- a/model_zoo/official/cv/resnext50/src/backbone/resnet.py +++ b/model_zoo/official/cv/resnext50/src/backbone/resnet.py @@ -87,7 +87,8 @@ class BasicBlock(nn.Cell): """ expansion = 1 - def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs): + def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, + platform="Ascend", **kwargs): super(BasicBlock, self).__init__() self.conv1 = conv3x3(in_channels, out_channels, stride=stride) self.bn1 = nn.BatchNorm2d(out_channels) @@ -142,7 +143,7 @@ class Bottleneck(nn.Cell): expansion = 4 def __init__(self, in_channels, out_channels, stride=1, down_sample=None, - base_width=64, groups=1, use_se=False, **kwargs): + base_width=64, groups=1, use_se=False, platform="Ascend", **kwargs): super(Bottleneck, self).__init__() width = int(out_channels * (base_width / 64.0)) * groups @@ -153,7 +154,11 @@ class Bottleneck(nn.Cell): self.conv3x3s = nn.CellList() - self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups) + if platform == "GPU": + self.conv2 = nn.Conv2d(width, width, 3, stride, pad_mode='pad', padding=1, group=groups) + else: + self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups) + self.op_split = Split(axis=1, output_num=self.groups) self.op_concat = Concat(axis=1) @@ -211,7 +216,7 @@ class ResNet(nn.Cell): Examples: >>>ResNet() """ - def __init__(self, block, layers, 
width_per_group=64, groups=1, use_se=False, platform="Ascend"): super(ResNet, self).__init__() self.in_channels = 64 self.groups = groups @@ -222,10 +227,10 @@ class ResNet(nn.Cell): self.relu = P.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') - self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se) + self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se, platform=platform) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se, platform=platform) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se, platform=platform) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se, platform=platform) self.out_channels = 512 * block.expansion self.cast = P.Cast() @@ -242,7 +247,7 @@ class ResNet(nn.Cell): return x - def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False): + def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False, platform="Ascend"): """_make_layer""" down_sample = None if stride != 1 or self.in_channels != out_channels * block.expansion: @@ -257,11 +262,12 @@ class ResNet(nn.Cell): down_sample=down_sample, base_width=self.base_width, groups=self.groups, - use_se=use_se)) + use_se=use_se, + platform=platform)) self.in_channels = out_channels * block.expansion for _ in range(1, blocks_num): - layers.append(block(self.in_channels, out_channels, - base_width=self.base_width, groups=self.groups, use_se=use_se)) + layers.append(block(self.in_channels, out_channels, base_width=self.base_width, + groups=self.groups, use_se=use_se, platform=platform)) return nn.SequentialCell(layers) @@ -269,5 +275,5 @@ class ResNet(nn.Cell): return self.out_channels 
-def resnext50(): - return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32) +def resnext50(platform="Ascend"): + return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32, platform=platform) diff --git a/model_zoo/official/cv/resnext50/src/config.py b/model_zoo/official/cv/resnext50/src/config.py index c1a12aa14e..0acff08342 100644 --- a/model_zoo/official/cv/resnext50/src/config.py +++ b/model_zoo/official/cv/resnext50/src/config.py @@ -36,7 +36,8 @@ config = ed({ "label_smooth": 1, "label_smooth_factor": 0.1, - "ckpt_interval": 1250, + "ckpt_interval": 5, + "ckpt_save_max": 5, "ckpt_path": 'outputs/', "is_save_on_master": 1, diff --git a/model_zoo/official/cv/resnext50/src/dataset.py b/model_zoo/official/cv/resnext50/src/dataset.py index 9608e3c790..66fc653c47 100644 --- a/model_zoo/official/cv/resnext50/src/dataset.py +++ b/model_zoo/official/cv/resnext50/src/dataset.py @@ -143,8 +143,10 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) de_dataset.set_dataset_size(len(sampler)) - de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) - de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) + de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers, + operations=transform_img) + de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers, + operations=transform_label) columns_to_project = ["image", "label"] de_dataset = de_dataset.project(columns=columns_to_project) diff --git a/model_zoo/official/cv/resnext50/src/image_classification.py b/model_zoo/official/cv/resnext50/src/image_classification.py index d8003ad200..de5a4fcd2c 100644 --- a/model_zoo/official/cv/resnext50/src/image_classification.py +++ b/model_zoo/official/cv/resnext50/src/image_classification.py @@ 
-50,9 +50,9 @@ class Resnet(ImageClassificationNetwork): Returns: Resnet. """ - def __init__(self, backbone_name, num_classes): + def __init__(self, backbone_name, num_classes, platform="Ascend"): self.backbone_name = backbone_name - backbone = backbones.__dict__[self.backbone_name]() + backbone = backbones.__dict__[self.backbone_name](platform=platform) out_channels = backbone.get_out_channels() head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) super(Resnet, self).__init__(backbone, head) @@ -79,7 +79,7 @@ class Resnet(ImageClassificationNetwork): -def get_network(backbone_name, num_classes): +def get_network(backbone_name, num_classes, platform="Ascend"): if backbone_name in ['resnext50']: - return Resnet(backbone_name, num_classes) + return Resnet(backbone_name, num_classes, platform) return None diff --git a/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py new file mode 100644 index 0000000000..f8e27f5b52 --- /dev/null +++ b/model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Auto mixed precision.""" +import mindspore.nn as nn +from mindspore.ops import functional as F +from mindspore._checkparam import Validator as validator +from mindspore.common import dtype as mstype + + +class OutputTo(nn.Cell): + """Wrap a cell and cast its output to `to_type` (float16 or float32).""" + + def __init__(self, op, to_type=mstype.float16): + super(OutputTo, self).__init__(auto_prefix=False) + self._op = op + validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None) + self.to_type = to_type + + def construct(self, x): + return F.cast(self._op(x), self.to_type) + + +def auto_mixed_precision(network): + """Cast network to float16 recursively; BatchNorm and 'conv2' cells use float32, 'fc' outputs float32.""" + cells = network.name_cells() + change = False + network.to_float(mstype.float16) + for name in cells: + subcell = cells[name] + if subcell == network: + continue + elif name == 'fc': + network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32)) + change = True + elif name == 'conv2': + subcell.to_float(mstype.float32) + change = True + elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)): + network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16)) + change = True + else: + auto_mixed_precision(subcell) + if isinstance(network, nn.SequentialCell) and change: + network.cell_list = list(network.cells()) diff --git a/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py b/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py index cbe89a1610..f4062821ef 100644 --- a/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py +++ b/model_zoo/official/cv/resnext50/src/utils/cunstom_op.py @@ -29,14 +29,10 @@ class GlobalAvgPooling(nn.Cell): """ def __init__(self): super(GlobalAvgPooling, self).__init__() - self.mean = P.ReduceMean(True) - self.shape = P.Shape() - self.reshape = P.Reshape() + self.mean = P.ReduceMean(False) def construct(self, x): x = self.mean(x, (2, 3)) - b, c, _, _ = 
self.shape(x) - x = self.reshape(x, (b, c)) return x diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py index ec2e33aba3..6b0eaae03b 100644 --- a/model_zoo/official/cv/resnext50/train.py +++ b/model_zoo/official/cv/resnext50/train.py @@ -36,11 +36,9 @@ from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr from src.utils.logging import get_logger from src.utils.optimizers__init__ import get_param_groups from src.image_classification import get_network +from src.utils.auto_mixed_precision import auto_mixed_precision from src.config import config -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target="Ascend", save_graphs=False, device_id=devid) class BuildTrainNetwork(nn.Cell): """build training network""" @@ -109,6 +107,7 @@ class ProgressMonitor(Callback): def parse_args(cloud_args=None): """parameters""" parser = argparse.ArgumentParser('mindspore classification training') + parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform') # dataset related parser.add_argument('--data_dir', type=str, default='', help='train data dir') @@ -141,6 +140,7 @@ def parse_args(cloud_args=None): args.label_smooth = config.label_smooth args.label_smooth_factor = config.label_smooth_factor args.ckpt_interval = config.ckpt_interval + args.ckpt_save_max = config.ckpt_save_max args.ckpt_path = config.ckpt_path args.is_save_on_master = config.is_save_on_master args.rank = config.rank @@ -166,12 +166,25 @@ def merge_args(args, cloud_args): def train(cloud_args=None): """training process""" args = parse_args(cloud_args) + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.platform, save_graphs=False) + if os.getenv('DEVICE_ID', "not_set").isdigit(): + context.set_context(device_id=int(os.getenv('DEVICE_ID'))) # init distributed if args.is_distributed: 
- init() + if args.platform == "Ascend": + init() + else: + init("nccl") args.rank = get_rank() args.group_size = get_group_size() + parallel_mode = ParallelMode.DATA_PARALLEL + context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, + parameter_broadcast=True, mirror_mean=True) + else: + args.rank = 0 + args.group_size = 1 if args.is_dynamic_loss_scale == 1: args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt @@ -192,7 +205,7 @@ def train(cloud_args=None): # dataloader de_dataset = classification_dataset(args.data_dir, args.image_size, args.per_batch_size, 1, - args.rank, args.group_size) + args.rank, args.group_size, num_parallel_workers=8) de_dataset.map_model = 4 # !!!important args.steps_per_epoch = de_dataset.get_dataset_size() @@ -201,15 +214,9 @@ def train(cloud_args=None): # network args.logger.important_info('start create network') # get network and init - network = get_network(args.backbone, args.num_classes) + network = get_network(args.backbone, args.num_classes, platform=args.platform) if network is None: raise NotImplementedError('not implement {}'.format(args.backbone)) - network.add_flags_recursive(fp16=True) - # loss - if not args.label_smooth: - args.label_smooth_factor = 0.0 - criterion = CrossEntropy(smooth_factor=args.label_smooth_factor, - num_classes=args.num_classes) # load pretrain model if os.path.isfile(args.pretrained): @@ -252,31 +259,29 @@ def train(cloud_args=None): loss_scale=args.loss_scale) - criterion.add_flags_recursive(fp32=True) + # loss + if not args.label_smooth: + args.label_smooth_factor = 0.0 + loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes) - # package training process, adjust lr + forward + backward + optimizer - train_net = BuildTrainNetwork(network, criterion) - if args.is_distributed: - parallel_mode = ParallelMode.DATA_PARALLEL - else: - parallel_mode = ParallelMode.STAND_ALONE if args.is_dynamic_loss_scale 
== 1: loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) else: loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) - # Model api changed since TR5_branch 2020/03/09 - context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size, - parameter_broadcast=True, mirror_mean=True) - model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager) + if args.platform == "Ascend": + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, + metrics={'acc'}, amp_level="O3") + else: + auto_mixed_precision(network) + model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, metrics={'acc'}) # checkpoint save progress_cb = ProgressMonitor(args) callbacks = [progress_cb,] if args.rank_save_ckpt_flag: - ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval - ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, - keep_checkpoint_max=ckpt_max_num) + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch, + keep_checkpoint_max=args.ckpt_save_max) ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.outputs_dir, prefix='{}'.format(args.rank))