!14747 ADD ReNAS & ManiDP

From: @yttdebaba2
Reviewed-by: @c_34, @wuxuejian
Signed-off-by: @c_34
mindspore-ci-bot 2021-04-08 12:55:07 +08:00 committed by Gitee
commit 5d7cd47b0b
13 changed files with 1373 additions and 0 deletions

View File

@@ -0,0 +1,120 @@
# Contents
- [Manifold Dynamic Pruning Description](#manifold-dynamic-pruning-description)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Manifold Dynamic Pruning Description](#contents)
Neural network pruning is an essential approach for reducing the computational complexity of deep models so that they can be deployed on resource-limited devices. Compared with conventional methods, the recently developed dynamic pruning methods determine the redundant filters for each input instance, which achieves higher acceleration. Most existing methods discover effective sub-networks for each instance independently and do not utilize the relationship between different inputs. To maximally excavate the redundancy in a given network architecture, this paper proposes a new paradigm that dynamically removes redundant filters by embedding the manifold information of all instances into the space of pruned networks (dubbed ManiDP). We first investigate the recognition complexity and feature similarity between images in the training set. Then, the manifold relationship between instances and the pruned sub-networks is aligned during training. The effectiveness of the proposed method is verified on several benchmarks, showing better performance in terms of both accuracy and computational cost than state-of-the-art methods. For example, our method can reduce 55.3% of the FLOPs of ResNet-34 with only a 0.57% top-1 accuracy drop on ImageNet.
[Paper](https://arxiv.org/pdf/2103.05861.pdf): Yehui Tang, Yunhe Wang, Yixing Xu, Yiping Deng, Chao Xu, Dacheng Tao, Chang Xu. Manifold Regularized Dynamic Network Pruning. Submitted to CVPR 2021.
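The per-instance channel gating described above is what `MaskBlock` in src/resnet.py implements; below is a hedged NumPy sketch of the core thresholding step (the weights and threshold are hypothetical stand-ins, not values from this PR):
```python
import numpy as np

def channel_mask(features, w1, w2, thre):
    """features: (N, C, H, W); returns a per-instance channel mask of shape (N, C_out)."""
    squeezed = features.mean(axis=(2, 3))            # global average pooling
    saliency = np.maximum(squeezed @ w1, 0) @ w2     # two FC layers with a ReLU in between
    saliency = np.clip(np.maximum(saliency, 0), 0, 1000)
    return saliency * (saliency > thre)              # channels below the threshold are pruned
```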
# [Dataset](#contents)
Dataset used: [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)
- Dataset size: 60,000 color images in 10 classes
    - Train: 50,000 images
    - Test: 10,000 images
- Data format: RGB images
- Note: data is processed in src/dataset.py
# [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates deep neural network training by using both single-precision and half-precision data formats while maintaining the accuracy reached by single-precision training. Mixed precision training speeds up computation, reduces memory usage, and enables larger models or batch sizes to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check which operators ran with reduced precision by enabling the INFO log and then searching for `reduce precision`.
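As a concrete illustration, here is a minimal, hedged sketch of turning on mixed precision through MindSpore's `Model` wrapper (the toy network, loss, and optimizer are placeholders, not code from this PR):
```python
import mindspore.nn as nn
from mindspore.train.model import Model

# placeholder network: a single dense layer over flattened 32x32 RGB inputs
net = nn.SequentialCell([nn.Flatten(), nn.Dense(3 * 32 * 32, 10)])
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# amp_level="O2" casts most operators to FP16 while keeping BatchNorm (and the loss) in FP32
model = Model(net, loss_fn, opt, amp_level="O2")
```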
# [Environment Requirements](#contents)
- Hardware (Ascend/GPU/CPU)
    - Prepare the hardware environment with an Ascend, GPU, or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```python
├── ManiDP
├── Readme.md # descriptions about ManiDP
├── src
│ ├──loss.py # label-smoothing cross-entropy loss
│ ├──dataset.py # creating dataset
│ ├──resnet.py # Pruned ResNet architecture
├── eval.py # evaluation script
```
## [Training process](#contents)
To Be Done
## [Eval process](#contents)
### Usage
After installing MindSpore via the official website, you can start evaluation as follows:
### Launch
```bash
# infer example (flags match eval.py's argparse definitions)
GPU: python eval.py --data_path path/to/cifar10 --GPU --ckpt [CHECKPOINT_PATH]
Ascend/CPU: python eval.py --data_path path/to/cifar10 --ckpt [CHECKPOINT_PATH]   # uses the default device target
```
> The checkpoint can be produced during the training process.
### Result
```bash
result: {'acc': 0.9204727564102564}
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
#### ResNet20 on CIFAR-10
| Parameters                 | Value                                   |
| -------------------------- | --------------------------------------- |
| Model Version              | ResNet20                                |
| Uploaded Date              | 03/27/2021 (month/day/year)             |
| MindSpore Version | 0.6.0-alpha |
| Dataset | CIFAR-10 |
| Parameters (M) | 0.27 |
| FLOPs (M) | 18.74 |
| Accuracy (Top1) | 92.05 |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@@ -0,0 +1,111 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference Interface"""
import sys
import logging
import argparse
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore import context
from mindspore import Tensor
from src.dataset import create_dataset_cifar10
from src.loss import LabelSmoothingCrossEntropy
from src.resnet import resnet20
from easydict import EasyDict as edict
import numpy as np
root = logging.getLogger()
root.setLevel(logging.DEBUG)
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--data_path', type=str, default='/home/workspace/mindspore_dataset/',
metavar='DIR', help='path to dataset')
parser.add_argument('--model', default='hournas_f_c10', type=str, metavar='MODEL',
help='Name of model to train (default: "hournas_f_c10")')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
help='number of label classes (default: 10)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
parser.add_argument('--ckpt', type=str, default='./resnet20.ckpt',
help='model checkpoint to load')
parser.add_argument('--GPU', action='store_true', default=False,
help='Use GPU for training (default: False)')
parser.add_argument('--dataset_sink', action='store_true', default=False,
help='Data sink (default: False)')
parser.add_argument('--device_id', type=int, default=0,
help='Device ID (default: 0)')
parser.add_argument('--image-size', type=int, default=32, metavar='N',
help='input image size (default: 32)')
def main():
"""Main entrance for training"""
args = parser.parse_args()
print(sys.argv)
context.set_context(mode=context.GRAPH_MODE)
# context.set_context(mode=context.PYNATIVE_MODE)
if args.GPU:
context.set_context(device_target='GPU', device_id=args.device_id)
# parse model argument
assert args.model.startswith(
"hournas"), "Only hournas models are supported."
# load the per-layer pruning thresholds consumed by the mask blocks
thres = np.load('thres.npy')
thres = Tensor(thres.astype(np.float32))
net = resnet20(thres=thres)
cfg = edict({
'image_height': args.image_size,
'image_width': args.image_size,
})
cfg.batch_size = args.batch_size
val_data_url = args.data_path
val_dataset = create_dataset_cifar10(val_data_url, repeat_num=1, training=False, cifar_cfg=cfg)
loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
num_classes=args.num_classes)
loss.add_flags_recursive(fp32=True, fp16=False)
eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
ckpt = load_checkpoint(args.ckpt)
load_param_into_net(net, ckpt)
net.set_train(False)
model = Model(net, loss, metrics=eval_metrics)
metrics = model.eval(val_dataset, dataset_sink_mode=False)
print(metrics)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.resnet import resnet20
from mindspore import Tensor
import numpy as np
def create_network(name, thres_filename, *args, **kwargs):
if name == 'resnet20':
thres = np.load(thres_filename)
thres = Tensor(thres.astype(np.float32))
return resnet20(thres=thres)
raise NotImplementedError(f"{name} is not implemented in the repo")
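For reference, a hedged sketch of exercising this entry point directly, assuming it runs in the same module with the `src` package importable; the 18-element threshold file is a hypothetical stand-in (resnet20 holds 9 masked blocks with 2 masks each):
```python
import numpy as np

np.save('thres.npy', np.zeros(18, dtype=np.float32))  # hypothetical thresholds
net = create_network('resnet20', 'thres.npy')
print(type(net).__name__)  # CifarResNet
```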

View File

@@ -0,0 +1,200 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import math
import os
import numpy as np
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore.communication.management import get_rank, get_group_size
from mindspore.dataset.vision import Inter
import mindspore.dataset.vision.c_transforms as vision
# values that should remain constant
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# data preprocess configs
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)
ds.config.set_seed(1)
def split_imgs_and_labels(imgs, labels, batchInfo):
"""split data into labels and images"""
ret_imgs = []
ret_labels = []
for i, image in enumerate(imgs):
ret_imgs.append(image)
ret_labels.append(labels[i])
return np.array(ret_imgs), np.array(ret_labels)
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
input_size=224, color_jitter=0.4):
"""Create ImageNet training dataset"""
if not os.path.exists(train_data_url):
raise ValueError('Path does not exist')
decode_op = py_vision.Decode()
type_cast_op = c_transforms.TypeCast(mstype.int32)
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
scale=SCALE, ratio=RATIO,
interpolation=Inter.BICUBIC)
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
contrast=adjust_range,
saturation=adjust_range)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# assemble all the transforms
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
dataset_train = dataset_train.map(input_columns=["label"],
operations=type_cast_op,
num_parallel_workers=workers)
# batch dealing
ds_train = dataset_train.batch(batch_size,
per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
ds_train = ds_train.repeat(1)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
input_size=224):
"""Create ImageNet validation dataset"""
if not os.path.exists(val_data_url):
raise ValueError('Path does not exist')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id)
scale_size = None
if isinstance(input_size, tuple):
assert len(input_size) == 2
if input_size[-1] == input_size[-2]:
scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
else:
scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
else:
scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
type_cast_op = c_transforms.TypeCast(mstype.int32)
decode_op = py_vision.Decode()
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
center_crop = py_vision.CenterCrop(size=input_size)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
to_tensor, normalize_op])
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=image_ops,
num_parallel_workers=workers)
dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
dataset = dataset.repeat(1)
return dataset
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id
def create_dataset_cifar10(data_home, repeat_num=1, training=True, cifar_cfg=None):
"""Data operations."""
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
rank_size, rank_id = _get_rank_info()
if training:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)
resize_height = cifar_cfg.image_height
resize_width = cifar_cfg.image_width
# define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
changeswap_op = vision.HWC2CHW()
type_cast_op = c_transforms.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
# apply map operations on images
data_set = data_set.map(operations=type_cast_op, input_columns="label")
data_set = data_set.map(operations=c_trans, input_columns="image")
# apply batch operations
data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set
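A hedged usage sketch of the CIFAR-10 pipeline above, mirroring how eval.py drives it (the dataset path is a placeholder and must contain the cifar-10-verify-bin folder):
```python
from easydict import EasyDict as edict
from src.dataset import create_dataset_cifar10  # assumes this file is saved as src/dataset.py

cfg = edict({'image_height': 32, 'image_width': 32, 'batch_size': 32})
ds_val = create_dataset_cifar10('/path/to/cifar10', training=False, cifar_cfg=cfg)
print(ds_val.get_dataset_size())  # number of batches per epoch
```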

View File

@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):
super(LabelSmoothingCrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()
def construct(self, logits, label):
label = self.cast(label, mstype.int32)
one_hot_label = self.onehot(label, F.shape(
logits)[1], self.on_value, self.off_value)
loss_logit = self.ce(logits, one_hot_label)
loss_logit = self.mean(loss_logit, 0)
return loss_logit
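A hedged smoke test of the loss above, assuming it runs in the same module as the class (random logits, hand-picked labels):
```python
import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

loss_fn = LabelSmoothingCrossEntropy(smooth_factor=0.1, num_classes=10)
logits = Tensor(np.random.randn(4, 10).astype(np.float32))
labels = Tensor(np.array([1, 0, 3, 7]), mstype.int32)
print(loss_fn(logits, labels))  # scalar smoothed cross-entropy
```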

View File

@@ -0,0 +1,231 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import ops
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=1, pad_mode='pad', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
def _bn(channel):
return nn.BatchNorm2d(channel)
def _bn_last(channel):
return nn.BatchNorm2d(channel)
def _fc(in_channel, out_channel, bias=True):
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=bias, weight_init=weight, bias_init=0)
class MaskBlock(nn.Cell):
"""
ResNet basic mask block definition.
Args:
in_channels (int): number of input channels.
out_channels (int): number of output channels.
num (int): layer number.
thres (list): threshold of layers.
Returns:
Tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, num, thres=None):
super(MaskBlock, self).__init__()
#self.target_pruning_rate = gate_factor
self.clamp_min = Tensor(0, mstype.float32)
self.clamp_max = Tensor(1000, mstype.float32)
if out_channels < 80:
squeeze_rate = 1
else:
squeeze_rate = 2
self.avg_pool = P.ReduceMean(keep_dims=False)
self.fc1 = _fc(in_channels, out_channels // squeeze_rate, bias=False)
self.fc2 = _fc(out_channels // squeeze_rate, out_channels, bias=True)
self.relu = P.ReLU()
self.thre = thres[num]
def construct(self, x):
"""construct"""
x_averaged = self.avg_pool(x, (2, 3))
y = self.fc1(x_averaged)
y = self.relu(y)
y = self.fc2(y)
mask_before = self.relu(y)
mask_before = ops.clip_by_value(mask_before, self.clamp_min, self.clamp_max)
tmp = ops.Greater()(mask_before, self.thre)
mask = mask_before * tmp
return mask
class MaskedBasicblock(nn.Cell):
"""
ResNet basic mask block definition.
Args:
inplanes (int): number of input channels.
planes (int): number of output channels.
stride (int): convolution kernel stride.
downsample (Cell): downsample layer.
num (int): layer number.
thres (list): threshold of layers.
Returns:
Tensor, output tensor.
"""
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, num=0, thres=None):
super(MaskedBasicblock, self).__init__()
self.conv_a = _conv3x3(inplanes, planes, stride=stride)
self.bn_a = _bn(planes)
self.conv_b = _conv3x3(planes, planes, stride=1)
self.bn_b = _bn(planes)
self.downsample = downsample
self.mb1 = MaskBlock(inplanes, planes, num*2, thres)
self.mb2 = MaskBlock(planes, planes, num*2+1, thres)
self.relu = P.ReLU()
self.expand_dims = ops.ExpandDims()
def construct(self, x):
"""construct"""
residual = x
mask1 = self.mb1(x)
basicblock = self.conv_a(x)
basicblock = self.bn_a(basicblock)
basicblock = self.relu(basicblock)
basicblock = basicblock * self.expand_dims(self.expand_dims(mask1, -1), -1)
mask2 = self.mb2(basicblock)
basicblock = self.conv_b(basicblock)
basicblock = self.bn_b(basicblock)
basicblock = basicblock * self.expand_dims(self.expand_dims(mask2, -1), -1)
if self.downsample is not None:
residual = self.downsample(x)
return self.relu(residual + basicblock)
class CifarResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): block for network.
depth (int): network depth.
num_classes (int): The number of classes that the training images are belonging to.
thres (list): threshold of layers.
Returns:
Tensor, output tensor.
"""
def __init__(self, block, depth, num_classes, thres):
super(CifarResNet, self).__init__()
layer_blocks = (depth - 2) // 6
self.num_classes = num_classes
self.conv_1_3x3 = _conv3x3(3, 16, stride=1)
self.bn_1 = _bn(16)
self.relu = P.ReLU()
self.inplanes = 16
self.stage_1 = self._make_layer(block, 16, layer_blocks, 1, s_num=0, thres=thres)
self.stage_2 = self._make_layer(block, 32, layer_blocks, 2, s_num=1, thres=thres)
self.stage_3 = self._make_layer(block, 64, layer_blocks, 2, s_num=2, thres=thres)
self.avgpool = nn.AvgPool2d(8)
self.classifier = _fc(64 * block.expansion, num_classes)
self.flatten = nn.Flatten()
def _make_layer(self, block, planes, blocks, stride=1, s_num=0, thres=None):
"""make layer"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.SequentialCell([_conv1x1(self.inplanes, planes * block.expansion, stride=stride)])
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, num=s_num*3+0, thres=thres))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, num=s_num*3+i, thres=thres))
return nn.SequentialCell(layers)
def construct(self, x):
"""construct"""
x = self.conv_1_3x3(x)
x = self.relu(self.bn_1(x))
x = self.stage_1(x)
x = self.stage_2(x)
x = self.stage_3(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.classifier(x)
return x
def resnet20(num_classes=10, thres=None):
model = CifarResNet(MaskedBasicblock, 20, num_classes, thres)
return model
def resnet56(num_classes=10, thres=None):
model = CifarResNet(MaskedBasicblock, 56, num_classes, thres)
return model
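A hedged smoke test of the pruned network above (all-zero thresholds are a hypothetical stand-in; resnet20 expects 18 of them, 9 MaskedBasicblocks x 2 MaskBlocks):
```python
import numpy as np
from mindspore import Tensor

thres = Tensor(np.zeros(18).astype(np.float32))  # hypothetical thresholds
net = resnet20(thres=thres)
x = Tensor(np.random.randn(1, 3, 32, 32).astype(np.float32))
print(net(x).shape)  # (1, 10)
```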

Binary file not shown.

View File

@@ -0,0 +1,120 @@
# Contents
- [ReNAS Description](#renas-description)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [ReNAS Description](#contents)
An effective and efficient architecture performance evaluation scheme is essential for the success of Neural Architecture Search (NAS). To save computational cost, most existing NAS algorithms train and evaluate intermediate neural architectures on a small proxy dataset with limited training epochs. But it is difficult to expect an accurate performance estimation of an architecture from such a coarse evaluation. This paper advocates a new neural architecture evaluation scheme, which aims to determine which architecture would perform better instead of accurately predicting absolute architecture performance. Therefore, we propose a **relativistic** architecture performance predictor in NAS (ReNAS). We encode neural architectures into feature tensors and further refine the representations with the predictor. The proposed relativistic performance predictor can be deployed in discrete searching methods to search for the desired architectures without additional evaluation. Experimental results on the NAS-Bench-101 dataset suggest that sampling 424 (0.1% of the entire search space) neural architectures and their corresponding validation performance is already enough for learning an accurate architecture performance predictor. The accuracies of our searched neural architectures on the NAS-Bench-101 and NAS-Bench-201 datasets are higher than those of the state-of-the-art methods, showing the superiority of the proposed method.
[Paper](https://arxiv.org/pdf/1910.01523.pdf): Yixing Xu, Yunhe Wang, Kai Han, Yehui Tang, Shangling Jui, Chunjing Xu, Chang Xu. ReNAS: Relativistic Evaluation of Neural Architecture Search. Submitted to CVPR 2021.
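The "relativistic" idea is to learn pairwise orderings rather than absolute accuracies; below is a hedged sketch of a pairwise hinge ranking loss in that spirit (an illustration only, not the paper's exact objective):
```python
import numpy as np

def pairwise_ranking_loss(pred, acc, margin=0.1):
    """Penalize predictor scores whose ordering disagrees with the true accuracies."""
    loss, n = 0.0, len(pred)
    for i in range(n):
        for j in range(i + 1, n):
            sign = np.sign(acc[i] - acc[j])  # +1 if architecture i is truly better
            loss += max(0.0, margin - sign * (pred[i] - pred[j]))
    return loss / (n * (n - 1) / 2)
```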
# [Dataset](#contents)
- Dataset used: [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)
- Dataset size: 60,000 color images in 10 classes
    - Train: 50,000 images
    - Test: 10,000 images
- Data format: RGB images
- Note: data is processed in src/dataset.py
# [Features](#contents)
## [Mixed Precision(Ascend)](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates deep neural network training by using both single-precision and half-precision data formats while maintaining the accuracy reached by single-precision training. Mixed precision training speeds up computation, reduces memory usage, and enables larger models or batch sizes to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check which operators ran with reduced precision by enabling the INFO log and then searching for `reduce precision`.
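Relatedly, a hedged sketch of enabling the INFO log mentioned above (`GLOG_v` is MindSpore's log-level environment variable and must be set before the framework is imported):
```python
import os

os.environ['GLOG_v'] = '1'  # 0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR
import mindspore  # reduce-precision notices now appear in the INFO log
```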
# [Environment Requirements](#contents)
- Hardware (Ascend/GPU/CPU)
    - Prepare the hardware environment with an Ascend, GPU, or CPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```python
├── ReNAS
├── Readme.md # descriptions about ReNAS
├── src
│ ├──loss.py # label-smoothing cross-entropy loss
│ ├──dataset.py # creating dataset
│ ├──nasnet.py # NASBench-101 network architecture
├── eval.py # evaluation script
```
## [Training process](#contents)
To Be Done
## [Eval process](#contents)
### Usage
After installing MindSpore via the official website, you can start evaluation as follows:
### Launch
```bash
# infer example (flags match eval.py's argparse definitions)
GPU: python eval.py --data_path path/to/cifar10 --GPU --ckpt [CHECKPOINT_PATH]
Ascend/CPU: python eval.py --data_path path/to/cifar10 --ckpt [CHECKPOINT_PATH]   # uses the default device target
```
> The checkpoint can be produced during the training process.
### Result
```bash
result: {'acc': 0.9411057692307693} ckpt= ./resnet50-imgnet-0.65x-80.24.ckpt
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
#### NASBench101-Net on CIFAR-10
| Parameters                 | Value                                   |
| -------------------------- | --------------------------------------- |
| Model Version              | NASBench101-Net                         |
| Uploaded Date              | 03/27/2021 (month/day/year)             |
| MindSpore Version | 0.6.0-alpha |
| Dataset | CIFAR-10 |
| Parameters (M) | 4.44 |
| FLOPs (G) | 1.9 |
| Accuracy (Top1) | 94.11 |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference Interface"""
import sys
import argparse
import logging
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore import context
from src.dataset import create_dataset_cifar10
from src.loss import LabelSmoothingCrossEntropy
from src.nasnet import nasbenchnet
from easydict import EasyDict as edict
root = logging.getLogger()
root.setLevel(logging.DEBUG)
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--data_path', type=str, default='/home/workspace/mindspore_dataset/',
metavar='DIR', help='path to dataset')
parser.add_argument('--model', default='hournas_f_c10', type=str, metavar='MODEL',
help='Name of model to train (default: "hournas_f_c10")')
parser.add_argument('--num-classes', type=int, default=10, metavar='N',
help='number of label classes (default: 10)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
parser.add_argument('--ckpt', type=str, default='./nasmodel.ckpt',
help='model checkpoint to load')
parser.add_argument('--GPU', action='store_true', default=False,
help='Use GPU for training (default: False)')
parser.add_argument('--dataset_sink', action='store_true', default=False,
help='Data sink (default: False)')
parser.add_argument('--device_id', type=int, default=0,
help='Device ID (default: 0)')
parser.add_argument('--image-size', type=int, default=32, metavar='N',
help='input image size (default: 32)')
def main():
"""Main entrance for training"""
args = parser.parse_args()
print(sys.argv)
#context.set_context(mode=context.GRAPH_MODE)
context.set_context(mode=context.PYNATIVE_MODE)
if args.GPU:
context.set_context(device_target='GPU', device_id=args.device_id)
# parse model argument
assert args.model.startswith(
"hournas"), "Only hournas models are supported."
net = nasbenchnet()
cfg = edict({
'image_height': args.image_size,
'image_width': args.image_size,
})
cfg.batch_size = args.batch_size
val_data_url = args.data_path
val_dataset = create_dataset_cifar10(val_data_url, repeat_num=1, training=False, cifar_cfg=cfg)
loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
num_classes=args.num_classes)
loss.add_flags_recursive(fp32=True, fp16=False)
eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
ckpt = load_checkpoint(args.ckpt)
load_param_into_net(net, ckpt)
net.set_train(False)
model = Model(net, loss, metrics=eval_metrics)
metrics = model.eval(val_dataset, dataset_sink_mode=False)
print(metrics)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,23 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.nasnet import nasbenchnet
def create_network(name, *args, **kwargs):
if name == 'nasbenchnet':
return nasbenchnet(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@@ -0,0 +1,200 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import math
import os
import numpy as np
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore.communication.management import get_rank, get_group_size
from mindspore.dataset.vision import Inter
import mindspore.dataset.vision.c_transforms as vision
# values that should remain constant
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# data preprocess configs
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)
ds.config.set_seed(1)
def split_imgs_and_labels(imgs, labels, batchInfo):
"""split data into labels and images"""
ret_imgs = []
ret_labels = []
for i, image in enumerate(imgs):
ret_imgs.append(image)
ret_labels.append(labels[i])
return np.array(ret_imgs), np.array(ret_labels)
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
input_size=224, color_jitter=0.4):
"""Create ImageNet training dataset"""
if not os.path.exists(train_data_url):
raise ValueError('Path does not exist')
decode_op = py_vision.Decode()
type_cast_op = c_transforms.TypeCast(mstype.int32)
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
scale=SCALE, ratio=RATIO,
interpolation=Inter.BICUBIC)
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
contrast=adjust_range,
saturation=adjust_range)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# assemble all the transforms
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
random_horizontal_flip_op, random_color_jitter_op, to_tensor, normalize_op])
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
dataset_train = dataset_train.map(input_columns=["label"],
operations=type_cast_op,
num_parallel_workers=workers)
# batch dealing
ds_train = dataset_train.batch(batch_size,
per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
ds_train = ds_train.repeat(1)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
input_size=224):
"""Create ImageNet validation dataset"""
if not os.path.exists(val_data_url):
raise ValueError('Path does not exist')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id)
scale_size = None
if isinstance(input_size, tuple):
assert len(input_size) == 2
if input_size[-1] == input_size[-2]:
scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
else:
scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
else:
scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
type_cast_op = c_transforms.TypeCast(mstype.int32)
decode_op = py_vision.Decode()
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
center_crop = py_vision.CenterCrop(size=input_size)
to_tensor = py_vision.ToTensor()
normalize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
to_tensor, normalize_op])
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=image_ops,
num_parallel_workers=workers)
dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
dataset = dataset.repeat(1)
return dataset
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id
def create_dataset_cifar10(data_home, repeat_num=1, training=True, cifar_cfg=None):
"""Data operations."""
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin")
rank_size, rank_id = _get_rank_info()
if training:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=True)
else:
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)
resize_height = cifar_cfg.image_height
resize_width = cifar_cfg.image_width
# define map operations
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = vision.RandomHorizontalFlip()
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = vision.HWC2CHW()
type_cast_op = c_transforms.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
# apply map operations on images
data_set = data_set.map(operations=type_cast_op, input_columns="label")
data_set = data_set.map(operations=c_trans, input_columns="image")
# apply batch operations
data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set

View File

@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):
super(LabelSmoothingCrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()
def construct(self, logits, label):
label = self.cast(label, mstype.int32)
one_hot_label = self.onehot(label, F.shape(
logits)[1], self.on_value, self.off_value)
loss_logit = self.ce(logits, one_hot_label)
loss_logit = self.mean(loss_logit, 0)
return loss_logit

View File

@@ -0,0 +1,159 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NASNet."""
import numpy as np
import mindspore.nn as nn
from mindspore import ops
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)
def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=1, pad_mode='pad', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=weight)
def _bn(channel):
return nn.BatchNorm2d(channel)
def _bn_last(channel):
return nn.BatchNorm2d(channel)
def _fc(in_channel, out_channel):
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class BasicCell(nn.Cell):
"""
NASNet basic cell definition.
Args:
None.
Returns:
Tensor, output tensor.
"""
expansion = 4
def __init__(self):
super(BasicCell, self).__init__()
self.conv3x3_1 = _conv3x3(128, 128)
self.bn3x3_1 = _bn(128)
self.conv3x3_2 = _conv3x3(128, 128)
self.bn3x3_2 = _bn(128)
self.conv3x3_3 = _conv3x3(128, 128)
self.bn3x3_3 = _bn(128)
self.mp = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same")
self.proj1 = _conv1x1(128, 64)
self.bn1 = _bn(64)
self.proj2 = _conv1x1(128, 64)
self.bn2 = _bn(64)
self.proj3 = _conv1x1(128, 64)
self.bn3 = _bn(64)
self.proj4 = _conv1x1(128, 64)
self.bn4 = _bn(64)
self.proj5 = _conv1x1(128, 64)
self.bn5 = _bn(64)
self.proj6 = _conv1x1(128, 64)
self.bn6 = _bn(64)
self.relu = P.ReLU()
self.concat = ops.Concat(axis=1)
def construct(self, x):
o1 = self.mp(x)
o1 = self.concat((self.relu(self.bn1(self.proj1(o1))), self.relu(self.bn2(self.proj2(x)))))
o2 = self.relu(self.bn3x3_1(self.conv3x3_1(o1)))
o2 = self.concat((self.relu(self.bn3(self.proj3(o2))), self.relu(self.bn4(self.proj4(x)))))
o3 = self.relu(self.bn3x3_2(self.conv3x3_2(o2)))
o4 = self.relu(self.bn3x3_3(self.conv3x3_3(x)))
out = self.concat((self.relu(self.bn5(self.proj5(o3))), self.relu(self.bn6(self.proj6(o4)))))
return out
class NasBenchNet(nn.Cell):
"""
NASNet architecture.
Args:
cell (Cell): Cell for network.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
"""
def __init__(self,
cell,
num_classes=10):
super(NasBenchNet, self).__init__()
self.conv1 = _conv3x3(3, 128)
self.bn1 = _bn(128)
self.mp = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid")
self.block1 = self._make_block(cell)
self.block2 = self._make_block(cell)
self.block3 = self._make_block(cell)
self.linear = _fc(128, num_classes)
self.ap = nn.AvgPool2d(kernel_size=8, pad_mode='valid')
self.relu = P.ReLU()
self.flatten = nn.Flatten()
def _make_block(self, cell):
layers = []
for _ in range(3):
layers.append(cell())
return nn.SequentialCell(layers)
def construct(self, x):
"""construct"""
out = self.relu(self.bn1(self.conv1(x)))
out = self.block1(out)
out = self.mp(out)
out = self.block2(out)
out = self.mp(out)
out = self.block3(out)
out = self.ap(out)
out = self.flatten(out)
out = self.linear(out)
return out
def nasbenchnet():
return NasBenchNet(BasicCell)
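A hedged smoke test of the network above (shapes follow the code: three cell blocks with 2x2 max-pooling in between and 8x8 average pooling at the end):
```python
import numpy as np
from mindspore import Tensor

net = nasbenchnet()
x = Tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))
print(net(x).shape)  # (2, 10)
```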