!4478 Add an example of training NASNet in MindSpore
Merge pull request !4478 from dessyang/master
Commit: a6c1fb2c25
README.md
@@ -0,0 +1,111 @@
# NASNet Example

## Description

This is an example of training NASNet-A-Mobile in MindSpore.

## Requirements

- Install [MindSpore](http://www.mindspore.cn/install/en).
- Download the dataset.

## Structure

```shell
.
└─nasnet
  ├─README.md
  ├─scripts
  │ ├─run_standalone_train_for_gpu.sh # launch standalone training on the GPU platform (1p)
  │ ├─run_distribute_train_for_gpu.sh # launch distributed training on the GPU platform (8p)
  │ └─run_eval_for_gpu.sh             # launch evaluation on the GPU platform
  ├─src
  │ ├─config.py                       # parameter configuration
  │ ├─dataset.py                      # data preprocessing
  │ ├─loss.py                         # customized CrossEntropy loss function
  │ ├─lr_generator.py                 # learning rate generator
  │ └─nasnet_a_mobile.py              # network definition
  ├─eval.py                           # evaluate the network
  ├─export.py                         # export the checkpoint
  └─train.py                          # train the network
```
## Parameter Configuration

Parameters for both training and evaluation can be set in config.py.

```
'random_seed': 1,                             # fix the random seed
'rank': 0,                                    # local rank in distributed training
'group_size': 1,                              # world size in distributed training
'work_nums': 8,                               # number of workers to read the data
'epoch_size': 312,                            # total number of epochs
'keep_checkpoint_max': 100,                   # maximum number of checkpoints to keep
'ckpt_path': './nasnet_a_mobile_checkpoint/', # path for saving checkpoints
'is_save_on_master': 0,                       # if set, save checkpoints on rank 0 only (distributed)
'batch_size': 32,                             # input batch size
'image_size': 224,                            # input image size
'num_classes': 1000,                          # number of dataset classes
'label_smooth_factor': 0.1,                   # label smoothing factor
'aux_factor': 0.4,                            # loss weight of the aux logits
'lr_init': 0.04,                              # initial learning rate
'lr_decay_rate': 0.97,                        # decay rate of the learning rate
'num_epoch_per_decay': 2.4,                   # number of epochs per decay
'weight_decay': 0.00004,                      # weight decay
'momentum': 0.9,                              # momentum
'opt_eps': 1.0,                               # epsilon of RMSProp
'rmsprop_decay': 0.9,                         # decay of RMSProp
'loss_scale': 1,                              # loss scale
```
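
These entries live in `src/config.py` as an `easydict`, so the scripts read them as plain attributes, and `train.py` overwrites `rank` and `group_size` at runtime for distributed jobs. A minimal sketch (illustrative only):

```python
from src.config import nasnet_a_mobile_config_gpu as cfg

print(cfg.batch_size)   # 32
cfg.epoch_size = 1      # values can also be overridden in code before training
```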

## Running the Example

### Train

#### Usage

```
# distributed training example (8p)
sh run_distribute_train_for_gpu.sh DATA_DIR

# standalone training
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
```

#### Launch

```bash
# distributed training example (8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset/train

# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
```

#### Result

Checkpoint files are saved under the checkpoint path, and the training results can be found in the log.

### Evaluation

#### Usage

```
# Evaluation
sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```

#### Launch

```bash
# Evaluation with a checkpoint
sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/nasnet-a-mobile-rank0-248_10009.ckpt
```

> Checkpoints are produced during the training process.

#### Result

The evaluation result is stored in the scripts directory, where you can find results like the following in the log.
eval.py
@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate imagenet"""
import argparse
import os

import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import nasnet_a_mobile_config_gpu as cfg
from src.dataset import create_dataset
from src.nasnet_a_mobile import NASNetAMobile
from src.loss import CrossEntropy_Val


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of nasnet_a_mobile (Default: None)')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    if args_opt.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
    net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    dataset = create_dataset(args_opt.dataset_path, cfg, False)
    loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset)
    print("metric: ", metrics)
export.py
@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export the checkpoint file into onnx and air models
"""
import argparse
import numpy as np

import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src.config import nasnet_a_mobile_config_gpu as cfg
from src.nasnet_a_mobile import NASNetAMobile

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='checkpoint export')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of nasnet_a_mobile (Default: None)')
    args_opt = parser.parse_args()

    net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
    param_dict = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, param_dict)

    input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, cfg.image_size, cfg.image_size]), ms.float32)
    export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
    # config.py defines 'air_filename', so the second export targets the AIR format
    export(net, input_arr, file_name=cfg.air_filename, file_format="AIR")
scripts/run_distribute_train_for_gpu.sh
@@ -0,0 +1,17 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
scripts/run_eval_for_gpu.sh
@@ -0,0 +1,19 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path "$DATA_DIR" --checkpoint "$PATH_CHECKPOINT" > eval.log 2>&1 &
scripts/run_standalone_train_for_gpu.sh
@@ -0,0 +1,19 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID=$1
DATA_DIR=$2
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
src/config.py
@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py, eval.py and export.py
"""
from easydict import EasyDict as edict


nasnet_a_mobile_config_gpu = edict({
    'random_seed': 1,
    'rank': 0,
    'group_size': 1,
    'work_nums': 8,
    'epoch_size': 312,
    'keep_checkpoint_max': 100,
    'ckpt_path': './nasnet_a_mobile_checkpoint/',
    'is_save_on_master': 0,

    ### Dataset Config
    'batch_size': 32,
    'image_size': 224,
    'num_classes': 1000,

    ### Loss Config
    'label_smooth_factor': 0.1,
    'aux_factor': 0.4,

    ### Learning Rate Config
    # 'lr_decay_method': 'exponential',
    'lr_init': 0.04,
    'lr_decay_rate': 0.97,
    'num_epoch_per_decay': 2.4,

    ### Optimization Config
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'opt_eps': 1.0,
    'rmsprop_decay': 0.9,
    'loss_scale': 1,

    ### ONNX & AIR Config
    'onnx_filename': 'nasnet_a_mobile.onnx',
    'air_filename': 'nasnet_a_mobile.air'
})
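
# Note: 'rmsprop_decay', 'momentum' and 'opt_eps' feed the RMSProp optimizer in
# train.py; the large epsilon (1.0) follows the original NASNet training recipe
# rather than the usual tiny default. A sketch of the mapping (illustrative):
#
#   optimizer = RMSProp(params, lr, decay=cfg.rmsprop_decay,
#                       momentum=cfg.momentum, epsilon=cfg.opt_eps,
#                       weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)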
src/dataset.py
@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.c_transforms as C


def create_dataset(dataset_path, config, do_train, repeat_num=1):
    """
    Create a train or eval dataset.

    Args:
        dataset_path (string): the path of the dataset.
        config (dict): the config of the dataset.
        do_train (bool): whether the dataset is used for training or evaluation.
        repeat_num (int): the repeat times of the dataset. Default: 1.

    Returns:
        dataset
    """
    rank = config.rank
    group_size = config.group_size
    if group_size == 1:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=config.work_nums, shuffle=True)
    else:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=config.work_nums, shuffle=True,
                                     num_shards=group_size, shard_id=rank)
    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(config.image_size),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, saturation=0.5)  # fast mode
            # C.RandomColorAdjust(brightness=0.4, contrast=0.5, saturation=0.5, hue=0.2)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(int(config.image_size / 0.875)),
            C.CenterCrop(config.image_size)
        ]
    trans += [
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        C.HWC2CHW()
    ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
    # apply batch operations
    ds = ds.batch(config.batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds
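
# A minimal usage sketch (the path is illustrative), assuming an ImageNet-style
# folder layout with one sub-directory per class:
#
#   from src.config import nasnet_a_mobile_config_gpu as cfg
#   ds = create_dataset('/dataset/train', cfg, do_train=True)
#   print(ds.get_dataset_size())  # number of batches per epoch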
src/loss.py
@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define evaluation loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn


class CrossEntropy_Val(_Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in the inference process"""
    def __init__(self, smooth_factor=0, num_classes=1000):
        super(CrossEntropy_Val, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logits, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        return loss_logit
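
# For reference, the label-smoothing targets built by OneHot above follow the
# standard formulation: the true class gets 1 - smooth_factor and the remaining
# mass is spread evenly over the other classes. A NumPy sketch (illustrative
# only, not part of the MindSpore graph):
#
#   import numpy as np
#   def smoothed_targets(labels, num_classes, smooth_factor=0.1):
#       off = smooth_factor / (num_classes - 1)
#       t = np.full((len(labels), num_classes), off, dtype=np.float32)
#       t[np.arange(len(labels)), labels] = 1.0 - smooth_factor
#       return t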
src/lr_generator.py
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate exponential decay generator"""
import math
import numpy as np


def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False):
    """
    Generate the learning rate array.

    Args:
        lr_init (float): initial learning rate
        lr_decay_rate (float): multiplicative decay rate applied once per decay period
        num_epoch_per_decay (float): length of one decay period, in epochs
        total_epochs (int): total number of training epochs
        steps_per_epoch (int): number of steps in one epoch
        is_stair (bool): if `True`, decay the learning rate at discrete intervals

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    decay_steps = steps_per_epoch * num_epoch_per_decay
    for i in range(total_steps):
        p = i / decay_steps
        if is_stair:
            p = math.floor(p)
        lr_each_step.append(lr_init * math.pow(lr_decay_rate, p))
    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate
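
# A quick sanity check of the schedule shape (the numbers are illustrative and
# follow the config defaults):
#
#   lr = get_lr(lr_init=0.04, lr_decay_rate=0.97, num_epoch_per_decay=2.4,
#               total_epochs=10, steps_per_epoch=100, is_stair=True)
#   # with is_stair=True the rate holds at 0.04 for the first 240 steps
#   # (2.4 epochs), then steps down to 0.04 * 0.97, and so on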
src/nasnet_a_mobile.py
@@ -0,0 +1,937 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NASNet-A-Mobile model definition"""
import numpy as np

from mindspore import Tensor
import mindspore.nn as nn
from mindspore.nn.loss.loss import _Loss
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
import mindspore.common.dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean


GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 10.0

clip_grad = C.MultitypeFuncGraph("clip_grad")


# pylint: disable=consider-using-in
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type != 0 and clip_type != 1:
        return grad
    dt = F.dtype(grad)
    if clip_type == 0:
        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                   F.cast(F.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
    return new_grad
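
# Note: with GRADIENT_CLIP_TYPE = 1 each gradient tensor is rescaled so that
# its L2 norm does not exceed GRADIENT_CLIP_VALUE (per-tensor norm clipping,
# not a global norm). An equivalent NumPy sketch for one tensor (illustrative):
#
#   norm = np.linalg.norm(g)
#   g = g * (10.0 / norm) if norm > 10.0 else g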


class CrossEntropy(_Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
        super(CrossEntropy, self).__init__()
        self.factor = factor
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        logit, aux = logits
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logit, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
        loss_aux = self.ce(aux, one_hot_label_aux)
        loss_aux = self.mean(loss_aux, 0)
        return loss_logit + self.factor * loss_aux


class AuxLogits(nn.Cell):

    def __init__(self, in_channels, out_channels, name=None):
        super(AuxLogits, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(5, stride=3, pad_mode='valid')
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)
        self.bn = nn.BatchNorm2d(128)
        self.conv_1 = nn.Conv2d(128, 768, (4, 4), pad_mode='valid')
        self.bn_1 = nn.BatchNorm2d(768)
        self.flatten = nn.Flatten()
        if name == 'large':
            self.fc = nn.Dense(6912, out_channels)  # large: 6912, mobile: 768
        else:
            self.fc = nn.Dense(768, out_channels)

    def construct(self, x):
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv_1(x)
        x = self.bn_1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class SeparableConv2d(nn.Cell):

    def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=dw_kernel,
                                          stride=dw_stride, pad_mode='pad', padding=dw_padding, group=in_channels,
                                          has_bias=bias)
        self.pointwise_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1,
                                          pad_mode='pad', has_bias=bias)

    def construct(self, x):
        x = self.depthwise_conv2d(x)
        x = self.pointwise_conv2d(x)
        return x
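
# A depthwise separable convolution factors a dense KxK convolution into a
# per-channel KxK depthwise step followed by a 1x1 pointwise mix. A rough
# parameter count for C_in = C_out = 44, K = 5 (ignoring bias):
#   dense:     5 * 5 * 44 * 44 = 48,400
#   separable: 5 * 5 * 44 + 44 * 44 = 3,036
# which is where much of NASNet-A-Mobile's efficiency comes from.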


class BranchSeparables(nn.Cell):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super(BranchSeparables, self).__init__()
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(
            in_channels, in_channels, kernel_size, stride, padding, bias=bias
        )
        self.bn_sep_1 = nn.BatchNorm2d(num_features=in_channels, eps=0.001, momentum=0.9, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(
            in_channels, out_channels, kernel_size, 1, padding, bias=bias
        )
        self.bn_sep_2 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9, affine=True)

    def construct(self, x):
        x = self.relu(x)
        x = self.separable_1(x)
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x


class BranchSeparablesStem(nn.Cell):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
        super(BranchSeparablesStem, self).__init__()
        self.relu = nn.ReLU()
        self.separable_1 = SeparableConv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias
        )
        self.bn_sep_1 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9, affine=True)
        self.relu1 = nn.ReLU()
        self.separable_2 = SeparableConv2d(
            out_channels, out_channels, kernel_size, 1, padding, bias=bias
        )
        self.bn_sep_2 = nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.9, affine=True)

    def construct(self, x):
        x = self.relu(x)
        x = self.separable_1(x)
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x


class BranchSeparablesReduction(BranchSeparables):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
        BranchSeparables.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, bias
        )
        self.padding = nn.Pad(paddings=((0, 0), (0, 0), (z_padding, 0), (z_padding, 0)), mode="CONSTANT")

    def construct(self, x):
        x = self.relu(x)
        x = self.padding(x)
        x = self.separable_1(x)
        x = x[:, :, 1:, 1:]
        x = self.bn_sep_1(x)
        x = self.relu1(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x


class CellStem0(nn.Cell):

    def __init__(self, stem_filters, num_filters=42):
        super(CellStem0, self).__init__()
        self.num_filters = num_filters
        self.stem_filters = stem_filters
        self.conv_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=self.stem_filters, out_channels=self.num_filters, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.num_filters, eps=0.001, momentum=0.9, affine=True)
        ])

        self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2)
        self.comb_iter_0_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.comb_iter_1_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.comb_iter_2_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, x):
        x1 = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x1)
        x_comb_iter_0_right = self.comb_iter_0_right(x)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x1)
        x_comb_iter_1_right = self.comb_iter_1_right(x)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x1)
        x_comb_iter_2_right = self.comb_iter_2_right(x)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x1)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4))
        return x_out


class CellStem1(nn.Cell):

    def __init__(self, stem_filters, num_filters):
        super(CellStem1, self).__init__()
        self.num_filters = num_filters
        self.stem_filters = stem_filters
        self.conv_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=2*self.num_filters, out_channels=self.num_filters, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=self.num_filters, eps=0.001, momentum=0.9, affine=True)])

        self.relu = nn.ReLU()
        self.path_1 = nn.SequentialCell([
            nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'),
            nn.Conv2d(in_channels=self.stem_filters, out_channels=self.num_filters//2, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False)])

        self.path_2 = nn.CellList([])
        self.path_2.append(nn.Pad(paddings=((0, 0), (0, 0), (0, 1), (0, 1)), mode="CONSTANT"))
        self.path_2.append(nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'))
        self.path_2.append(nn.Conv2d(in_channels=self.stem_filters, out_channels=self.num_filters//2, kernel_size=1,
                                     stride=1, pad_mode='pad', has_bias=False))

        self.final_path_bn = nn.BatchNorm2d(num_features=self.num_filters, eps=0.001, momentum=0.9, affine=True)

        self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, pad_mode='same')
        self.comb_iter_1_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, pad_mode='same')
        self.comb_iter_2_right = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, pad_mode='same')
        self.shape = P.Shape()

    def construct(self, x_conv0, x_stem_0):
        x_left = self.conv_1x1(x_stem_0)
        x_relu = self.relu(x_conv0)
        # path 1
        x_path1 = self.path_1(x_relu)
        # path 2
        x_path2 = self.path_2[0](x_relu)
        x_path2 = x_path2[:, :, 1:, 1:]
        x_path2 = self.path_2[1](x_path2)
        x_path2 = self.path_2[2](x_path2)
        # final path
        x_right = self.final_path_bn(P.Concat(1)((x_path1, x_path2)))

        x_comb_iter_0_left = self.comb_iter_0_left(x_left)
        x_comb_iter_0_right = self.comb_iter_0_right(x_right)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_right)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_left)
        x_comb_iter_2_right = self.comb_iter_2_right(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_left)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4))
        return x_out


class FirstCell(nn.Cell):

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(FirstCell, self).__init__()
        self.conv_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)])

        self.relu = nn.ReLU()
        self.path_1 = nn.SequentialCell([
            nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'),
            nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False)])

        self.path_2 = nn.CellList([])
        self.path_2.append(nn.Pad(paddings=((0, 0), (0, 0), (0, 1), (0, 1)), mode="CONSTANT"))
        self.path_2.append(nn.AvgPool2d(kernel_size=1, stride=2, pad_mode='valid'))
        self.path_2.append(nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1,
                                     stride=1, pad_mode='pad', has_bias=False))

        self.final_path_bn = nn.BatchNorm2d(num_features=out_channels_left*2, eps=0.001, momentum=0.9, affine=True)

        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

        self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_3_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')
        self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

    def construct(self, x, x_prev):
        x_relu = self.relu(x_prev)
        x_path1 = self.path_1(x_relu)
        x_path2 = self.path_2[0](x_relu)
        x_path2 = x_path2[:, :, 1:, 1:]
        x_path2 = self.path_2[1](x_path2)
        x_path2 = self.path_2[2](x_path2)
        # final path
        x_left = self.final_path_bn(P.Concat(1)((x_path1, x_path2)))

        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        x_out = P.Concat(1)((x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4))
        return x_out


class NormalCell(nn.Cell):

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(NormalCell, self).__init__()
        self.conv_prev_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_left, eps=0.001, momentum=0.9, affine=True)])

        self.conv_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)])

        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)

        self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
        self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_3_left = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')
        self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)

    def construct(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        x_out = P.Concat(1)((x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4))
        return x_out


class ReductionCell0(nn.Cell):

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(ReductionCell0, self).__init__()
        self.conv_prev_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_left, eps=0.001, momentum=0.9, affine=True)])

        self.conv_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)])

        self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, pad_mode='same')
        self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, pad_mode='same')
        self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, pad_mode='same')

    def construct(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4))
        return x_out


class ReductionCell1(nn.Cell):

    def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
        super(ReductionCell1, self).__init__()
        self.conv_prev_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_left, out_channels=out_channels_left, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_left, eps=0.001, momentum=0.9, affine=True)])

        self.conv_1x1 = nn.SequentialCell([
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels_right, out_channels=out_channels_right, kernel_size=1, stride=1,
                      pad_mode='pad', has_bias=False),
            nn.BatchNorm2d(num_features=out_channels_right, eps=0.001, momentum=0.9, affine=True)])

        self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
        self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, pad_mode='same')
        self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, bias=False)

        self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, pad_mode='same')
        self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, bias=False)

        self.comb_iter_3_right = nn.AvgPool2d(kernel_size=3, stride=1, pad_mode='same')

        self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
        self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, pad_mode='same')

    def construct(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = P.Concat(1)((x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4))
        return x_out


class NASNetAMobile(nn.Cell):
    """Neural Architecture Search (NAS).

    Reference:
        Zoph et al. Learning Transferable Architectures
        for Scalable Image Recognition. CVPR 2018.

    - ``nasnetamobile``: NASNet-A Mobile.
    """

    def __init__(self, num_classes, is_training=True,
                 stem_filters=32, penultimate_filters=1056, filters_multiplier=2):
        super(NASNetAMobile, self).__init__()
        self.is_training = is_training
        self.stem_filters = stem_filters
        self.penultimate_filters = penultimate_filters
        self.filters_multiplier = filters_multiplier

        filters = self.penultimate_filters//24  # 24 is the default value for the architecture

        self.conv0 = nn.SequentialCell([
            nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3, stride=2, pad_mode='pad',
                      padding=0, has_bias=False),
            nn.BatchNorm2d(num_features=self.stem_filters, eps=0.001, momentum=0.9, affine=True)
        ])

        self.cell_stem_0 = CellStem0(self.stem_filters, num_filters=filters//(filters_multiplier**2))
        self.cell_stem_1 = CellStem1(self.stem_filters, num_filters=filters//filters_multiplier)

        self.cell_0 = FirstCell(in_channels_left=filters, out_channels_left=filters//2,      # 1, 0.5
                                in_channels_right=2*filters, out_channels_right=filters)     # 2, 1
        self.cell_1 = NormalCell(in_channels_left=2*filters, out_channels_left=filters,      # 2, 1
                                 in_channels_right=6*filters, out_channels_right=filters)    # 6, 1
        self.cell_2 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,      # 6, 1
                                 in_channels_right=6*filters, out_channels_right=filters)    # 6, 1
        self.cell_3 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,      # 6, 1
                                 in_channels_right=6*filters, out_channels_right=filters)    # 6, 1

        self.reduction_cell_0 = ReductionCell0(in_channels_left=6*filters, out_channels_left=2*filters,    # 6, 2
                                               in_channels_right=6*filters, out_channels_right=2*filters)  # 6, 2

        self.cell_6 = FirstCell(in_channels_left=6*filters, out_channels_left=filters,          # 6, 1
                                in_channels_right=8*filters, out_channels_right=2*filters)      # 8, 2
        self.cell_7 = NormalCell(in_channels_left=8*filters, out_channels_left=2*filters,       # 8, 2
                                 in_channels_right=12*filters, out_channels_right=2*filters)    # 12, 2
        self.cell_8 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,      # 12, 2
                                 in_channels_right=12*filters, out_channels_right=2*filters)    # 12, 2
        self.cell_9 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,      # 12, 2
                                 in_channels_right=12*filters, out_channels_right=2*filters)    # 12, 2

        if is_training:
            self.aux_logits = AuxLogits(in_channels=12*filters, out_channels=num_classes)

        self.reduction_cell_1 = ReductionCell1(in_channels_left=12*filters, out_channels_left=4*filters,    # 12, 4
                                               in_channels_right=12*filters, out_channels_right=4*filters)  # 12, 4

        self.cell_12 = FirstCell(in_channels_left=12*filters, out_channels_left=2*filters,      # 12, 2
                                 in_channels_right=16*filters, out_channels_right=4*filters)    # 16, 4
        self.cell_13 = NormalCell(in_channels_left=16*filters, out_channels_left=4*filters,     # 16, 4
                                  in_channels_right=24*filters, out_channels_right=4*filters)   # 24, 4
        self.cell_14 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,     # 24, 4
                                  in_channels_right=24*filters, out_channels_right=4*filters)   # 24, 4
        self.cell_15 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,     # 24, 4
                                  in_channels_right=24*filters, out_channels_right=4*filters)   # 24, 4

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(keep_prob=0.5)
        self.classifier = nn.Dense(in_channels=24*filters, out_channels=num_classes)
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self._initialize_weights()

    def _initialize_weights(self):
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2./n),
                                                                    m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_parameter_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_parameter_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
            elif isinstance(m, nn.Dense):
                m.weight.set_parameter_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

    def construct(self, x):
        x_conv0 = self.conv0(x)
        x_stem_0 = self.cell_stem_0(x_conv0)
        x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)

        x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = self.cell_3(x_cell_2, x_cell_1)

        x_reduction_cell_0 = self.reduction_cell_0(x_cell_3, x_cell_2)

        x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_3)
        x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = self.cell_9(x_cell_8, x_cell_7)

        if self.is_training:
            aux_logits = self.aux_logits(x_cell_9)
        else:
            aux_logits = None

        x_reduction_cell_1 = self.reduction_cell_1(x_cell_9, x_cell_8)

        x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_9)
        x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = self.cell_15(x_cell_14, x_cell_13)

        x_cell_15 = self.relu(x_cell_15)
        x_cell_15 = nn.AvgPool2d(F.shape(x_cell_15)[2:])(x_cell_15)  # global average pooling
        x_cell_15 = self.reshape(x_cell_15, (self.shape(x_cell_15)[0], -1,))
        x_cell_15 = self.dropout(x_cell_15)
        logits = self.classifier(x_cell_15)

        if self.is_training:
            return logits, aux_logits
        return logits
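
# A minimal sketch of using the bare backbone (shapes assume the config
# defaults; illustrative only):
#
#   import numpy as np
#   from mindspore import Tensor
#   net = NASNetAMobile(num_classes=1000, is_training=False)
#   x = Tensor(np.zeros((1, 3, 224, 224), np.float32))
#   logits = net(x)   # shape (1, 1000)
#   # with is_training=True, construct returns (logits, aux_logits) instead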


class NASNetAMobileWithLoss(nn.Cell):
    """
    Provide nasnet-a-mobile training loss through the network.

    Args:
        config (dict): The config of nasnet-a-mobile.
        is_training (bool): Specifies whether to use the training mode.

    Returns:
        Tensor, the loss of the network.
    """

    def __init__(self, config, is_training=True):
        super(NASNetAMobileWithLoss, self).__init__()
        self.network = NASNetAMobile(config.num_classes, is_training)
        self.loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
                                 num_classes=config.num_classes, factor=config.aux_factor)
        self.cast = P.Cast()

    def construct(self, data, label):
        prediction_scores = self.network(data)
        total_loss = self.loss(prediction_scores, label)
        return self.cast(total_loss, mstype.float32)


class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell):
    """Train-one-step cell that clips each gradient by norm before the optimizer update."""

    def __init__(self, network, optimizer, sens=1.0):
        super(NASNetAMobileTrainOneStepWithClipGradient, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.hyper_map = C.HyperMap()
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_mirror_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
train.py
@@ -0,0 +1,117 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import os
import random
import numpy as np

from mindspore import Tensor
from mindspore import context
from mindspore import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import dataset as de

from src.config import nasnet_a_mobile_config_gpu as cfg
from src.dataset import create_dataset
from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
from src.lr_generator import get_lr


random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
de.config.set_seed(cfg.random_seed)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--resume', type=str, default='', help='resume training with an existing checkpoint')
    parser.add_argument('--is_distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args_opt.is_distributed:
        if args_opt.platform == "Ascend":
            init()
        else:
            init("nccl")
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
                                          parameter_broadcast=True, mirror_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

    # dataloader
    dataset = create_dataset(args_opt.dataset_path, cfg, True)
    batches_per_epoch = dataset.get_dataset_size()

    # network
    net_with_loss = NASNetAMobileWithLoss(cfg)
    if args_opt.resume:
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net_with_loss, ckpt)

    # learning rate schedule
    lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate,
                num_epoch_per_decay=cfg.num_epoch_per_decay, total_epochs=cfg.epoch_size,
                steps_per_epoch=batches_per_epoch, is_stair=True)
    lr = Tensor(lr)

    # optimizer
    decayed_params = []
    no_decayed_params = []
    for param in net_with_loss.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net_with_loss.trainable_params()}]
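    # The grouping above applies weight decay only to conv/dense weights;
    # BatchNorm betas/gammas and bias terms are excluded, which is the usual
    # practice for ImageNet training. 'order_params' keeps the parameter order
    # aligned with trainable_params() for the optimizer.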
    optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                        momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)

    net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
    net_with_grads.set_train()
    model = Model(net_with_grads)

    print("============== Starting Training ==============")
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    callbacks = [loss_cb, time_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"nasnet-a-mobile-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
    if args_opt.is_distributed and cfg.is_save_on_master:
        # only rank 0 saves checkpoints when is_save_on_master is set
        if cfg.rank == 0:
            callbacks.append(ckpoint_cb)
    else:
        callbacks.append(ckpoint_cb)
    model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
    print("train success")