This commit is contained in:
ariaotp 2021-07-02 17:56:32 +08:00 committed by l_emon
parent fd262eb3b8
commit 53e3ce9715
19 changed files with 1470 additions and 0 deletions

View File

@@ -0,0 +1,222 @@
# Contents
- [DEM Description](#dem-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DEM Description](#contents)
The Deep Embedding Model (DEM) proposes a new zero-shot learning (ZSL) model that maps the semantic space to the visual feature space, i.e., the low-dimensional space to the high-dimensional one, thereby avoiding the hubness problem. It also proposes a multi-modal semantic feature fusion method, used for joint optimization in an end-to-end manner.
[Paper](https://arxiv.org/abs/1611.05088): Li Zhang, Tao Xiang, Shaogang Gong. "Learning a Deep Embedding Model for Zero-Shot Learning." *Proceedings of the CVPR*, 2017.
# [Model Architecture](#contents)
DEM uses GoogLeNet to extract visual features, then trains embeddings in the attribute space, the word-vector space, and a fusion space using a multi-modal fusion method.
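As a concrete reference, here is a minimal sketch of the attribute branch for AwA (layer sizes taken from src/demnet.py; the random tensors stand in for real data):
```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Semantic -> visual: embed low-dimensional class semantics into the
# high-dimensional GoogLeNet feature space and regress onto real features.
embed = nn.SequentialCell([
    nn.Dense(85, 700), nn.ReLU(),    # 85-dim AwA attribute vector
    nn.Dense(700, 1024), nn.ReLU(),  # 1024-dim GoogLeNet feature space
])
loss_fn = nn.MSELoss(reduction='mean')

att = Tensor(np.random.rand(64, 85).astype(np.float32))     # a batch of attributes
feat = Tensor(np.random.rand(64, 1024).astype(np.float32))  # matching visual features
loss = loss_fn(embed(att), feat)  # minimized end to end with Adam in train.py
```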
# [Dataset](#contents)
Datasets used: AwA and CUB. Download the data from [here](https://www.robots.ox.ac.uk/~lz/DEM_cvpr2017/data.zip)
```text
- Note: Data will be processed in dataset.py.
```
- The directory structure is as follows:
```bash
└─data
├─AwA_data
│ ├─attribute #attribute data
│ ├─wordvector #wordvector data
│ ├─test_googlenet_bn.mat
│ ├─test_labels.mat
│ ├─testclasses_id.mat
│ └─train_googlenet_bn.mat
  └─CUB_data          #The directory structure is similar to AwA_data
```
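The .mat files are read with scipy.io (and h5py for the MATLAB v7.3 attribute file); a minimal sketch mirroring src/dataset.py:
```python
import h5py
import numpy as np
import scipy.io as sio

data_path = '/data/DEM_data'  # location used in the Quick Start below
f = sio.loadmat(data_path + '/AwA_data/train_googlenet_bn.mat')
train_x = np.array(f['train_googlenet_bn'])        # GoogLeNet visual features
f = h5py.File(data_path + '/AwA_data/attribute/Z_s_con.mat', 'r')
train_att = np.array(f['Z_s_con'])                 # class attribute vectors
```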
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```bash
# Install required packages
pip install -r requirements.txt
# Place the dataset under '/data/DEM_data': rename, move and unzip
mv data.zip DEM_data.zip
mv ./DEM_data.zip /data
cd /data
unzip DEM_data.zip
# 1P example
# Enter the script dir and start training
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output
# Enter the script dir and start evaluating
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt
# 8P example
sh run_distributed_train_ascend.sh [RANK_TABLE_FILE(.json)] CUB att /data/DEM_data
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../train_parallel/7/auto_parallel-120_11.ckpt
# Note: word and fusion training modes are not supported on the CUB dataset.
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```bash
├── cv
├── DEM
├── README.md // descriptions about DEM
├── README_CN.md // descriptions about DEM in Chinese
├── requirements.txt // required packages
├── scripts
│ ├──run_distributed_train_ascend.sh // 8P training on Ascend
│ ├──run_standalone_train_ascend.sh // 1P training on Ascend
│ └──run_standalone_eval_ascend.sh // evaluation on Ascend
├── src
│ ├──dataset.py // dataset loading
│ ├──demnet.py // DEM architecture
│ ├──config.py // parameter configuration
│ ├──kNN.py // k-nearest-neighbor algorithm
│ ├──kNN_cosine.py // k-nearest-neighbor with cosine distance
│ ├──accuracy.py // accuracy computation
│ ├──set_parser.py // basic parameters
│ └──utils.py // helper functions
├── train.py // training script
├── eval.py // evaluation script
└── export.py // export script
```
## [Script Parameters](#contents)
```bash
# Major parameters in train.py and set_parser.py:
--device_target: Device target, default is "Ascend"
--device_id: Device ID
--distribute: Whether to train in a distributed environment
--device_num: Number of devices used
--dataset: Dataset to use, chosen from "AwA" and "CUB"
--train_mode: Training mode, chosen from "att" (attribute), "word" (wordvector) and "fusion"
--batch_size: Training batch size
--interval_step: Interval (in steps) for printing the loss
--epoch_size: Number of training epochs
--data_path: Path where the dataset is saved
--save_ckpt: Path where the checkpoint is saved
--file_format: Model export format
```
## [Training Process](#contents)
### Training
```bash
python train.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the script dir and run the script
sh run_standalone_train_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# 1P example:
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output
# 8P example:
sh run_distributed_train_ascend.sh [RANK_TABLE_FILE(.json)] CUB att /data/DEM_data
```
After training, loss values are printed as follows:
```bash
============== Starting Training ==============
epoch: 1 step: 100, loss is 0.24551314
epoch: 2 step: 61, loss is 0.2861852
epoch: 3 step: 22, loss is 0.2151301
...
epoch: 16 step: 115, loss is 0.13285707
epoch: 17 step: 76, loss is 0.15123637
...
```
The model checkpoint is saved to the [SAVE_CKPT] directory designated in the command.
## [Evaluation Process](#contents)
### Evaluation
Before running the command below, please check the checkpoint path used for evaluation.
```bash
python eval.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter script dir, and run the script
sh run_standalone_eval_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# Example:
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt
```
The accuracy of the test dataset is as follows:
```bash
============== Starting Evaluating ==============
accuracy _ CUB _ att = 0.58984
```
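Accuracy is computed by embedding the unseen-class semantic vectors into visual space and labeling each test feature with its nearest predicted class prototype; a condensed sketch of compute_accuracy_att from src/accuracy.py:
```python
import numpy as np
from src import kNN

def compute_accuracy(net, test_att, test_x, test_id, test_label):
    prototypes = net(test_att)  # embed unseen-class semantics into visual space
    pred = [kNN.kNNClassify(test_x[i, :], prototypes.asnumpy(), test_id, 1)
            for i in range(test_x.shape[0])]
    return np.equal(np.array(pred), test_label.astype("float32")).mean()
```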
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | DEM_AwA | DEM_CUB |
| ------------------ | ------------------- | ------------------ |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755 GB; OS CentOS 8.2 | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755 GB; OS CentOS 8.2 |
| Uploaded Date | 06/18/2021 (month/day/year) | 04/26/2021 (month/day/year) |
| MindSpore Version | 1.2.0 | 1.2.0 |
| Dataset | AwA | CUB |
| Training Parameters | epoch = 100, batch_size = 64, lr = 1e-5 / 1e-4 / 1e-4 | epoch = 100, batch_size = 100, lr = 1e-5 |
| Optimizer | Adam | Adam |
| Loss Function | MSELoss | MSELoss |
| Outputs | probability | probability |
| Training Mode | attribute, wordvector, fusion | attribute |
| Speed | 24.6 ms/step, 7.3 ms/step, 42.1 ms/step | 51.3 ms/step |
| Total Time | 951 s / 286 s / 1640 s | 551 s |
| Checkpoint for Fine-tuning | 3040k / 4005k / 7426k (.ckpt file) | 3660k (.ckpt file) |
| Accuracy Calculation Method | kNN / kNN_cosine / kNN_cosine | kNN |
# [Description of Random Situation](#contents)
In train.py, the dataset is shuffled via dataset.GeneratorDataset(..., shuffle=True).
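Note that train.py also fixes the global seed, so shuffled runs remain repeatable:
```python
from mindspore import set_seed

set_seed(1000)  # same seed as used in train.py
```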
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@@ -0,0 +1,229 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [DEM Description](#dem-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# DEM Description
The Deep Embedding Model (DEM) proposes a new zero-shot learning (ZSL) model that maps the semantic space to the visual feature space, i.e., the low-dimensional space to the high-dimensional one, thereby avoiding the hubness problem. It also proposes a multi-modal semantic feature fusion method for joint optimization in an end-to-end manner.
[Paper](https://arxiv.org/abs/1611.05088): Li Zhang, Tao Xiang, Shaogang Gong. "Learning a Deep Embedding Model for Zero-Shot Learning." *Proceedings of the CVPR*, 2017.
# Model Architecture
DEM uses GoogLeNet to extract visual features, then trains embeddings in the attribute space, the word-vector space, and a fusion space using a multi-modal fusion method.
# Dataset
Datasets used: AwA and CUB, [download here](https://www.robots.ox.ac.uk/~lz/DEM_cvpr2017/data.zip)
```text
- Note: Data is loaded in dataset.py.
```
- The directory structure is as follows:
```bash
└─data
├─AwA_data
│ ├─attribute #attribute data
│ ├─wordvector #wordvector data
│ ├─test_googlenet_bn.mat
│ ├─test_labels.mat
│ ├─testclasses_id.mat
│ └─train_googlenet_bn.mat
  └─CUB_data          #The directory structure is similar to AwA_data
```
# Environment Requirements
- Hardware (Ascend)
- Prepare the hardware environment with Ascend processors.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start
After installing MindSpore via the official website, you can start training and evaluation as follows:
```bash
# Install required packages
pip install -r requirements.txt
# Place the dataset under '/data/DEM_data': rename, move and unzip
mv data.zip DEM_data.zip
mv ./DEM_data.zip /data
cd /data
unzip DEM_data.zip
# 1P example
# Enter the script dir and start training DEM
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output
# Enter the script dir and start evaluating DEM
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt
# 8P example
sh run_distributed_train_ascend.sh [RANK_TABLE_FILE(.json)] CUB att /data/DEM_data
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../train_parallel/7/auto_parallel-120_11.ckpt
# Note: word and fusion training modes are not supported on the CUB dataset.
```
# Script Description
## Script and Sample Code
```bash
├── cv
├── DEM
├── README.md // descriptions about DEM
├── README_CN.md // descriptions about DEM in Chinese
├── requirements.txt // required packages
├── scripts
│ ├──run_distributed_train_ascend.sh // 8P training on Ascend
│ ├──run_standalone_train_ascend.sh // 1P training on Ascend
│ └──run_standalone_eval_ascend.sh // evaluation on Ascend
├── src
│ ├──dataset.py // dataset loading
│ ├──demnet.py // DEM architecture
│ ├──config.py // parameter configuration
│ ├──kNN.py // k-nearest-neighbor algorithm
│ ├──kNN_cosine.py // k-nearest-neighbor with cosine distance
│ ├──accuracy.py // accuracy computation
│ ├──set_parser.py // basic parameters
│ └──utils.py // helper functions
├── train.py // training script
├── eval.py // evaluation script
└── export.py // export script
```
## Script Parameters
```bash
# Major parameters in train.py and set_parser.py:
--device_target: Device target, default is "Ascend"
--device_id: Device ID
--distribute: Whether to train in a distributed environment
--device_num: Number of devices used
--dataset: Dataset to use, chosen from "AwA" and "CUB"
--train_mode: Training mode, chosen from "att" (attribute), "word" (wordvector) and "fusion"
--batch_size: Training batch size
--interval_step: Interval (in steps) for printing the loss
--epoch_size: Number of training epochs
--data_path: Path where the dataset is saved
--save_ckpt: Path where the checkpoint is saved
--file_format: Model export format
```
## Training Process
### Training
```bash
python train.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the ./scripts dir and run the script
sh run_standalone_train_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# 1P example:
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output
# 8P example:
sh run_distributed_train_ascend.sh [RANK_TABLE_FILE(.json)] CUB att /data/DEM_data
```
After training, loss values are printed as follows:
```bash
============== Starting Training ==============
epoch: 1 step: 100, loss is 0.24551314
epoch: 2 step: 61, loss is 0.2861852
epoch: 3 step: 22, loss is 0.2151301
...
epoch: 16 step: 115, loss is 0.13285707
epoch: 17 step: 76, loss is 0.15123637
...
```
The model checkpoint is saved to the specified [SAVE_CKPT] directory.
## Evaluation Process
### Evaluation
Before running the command below, please check the checkpoint path used for evaluation.
```bash
python eval.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the ./scripts dir and run the script
sh run_standalone_eval_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# Example:
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt
```
Accuracy on the test dataset is as follows:
```bash
============== Starting Evaluating ==============
accuracy _ CUB _ att = 0.58984
```
# Model Description
## Performance
### Evaluation Performance
| Parameters | DEM_AwA | DEM_CUB |
| ------------------ | ------------------- | ------------------ |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755 GB; OS CentOS 8.2 | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755 GB; OS CentOS 8.2 |
| Uploaded Date | 2021-05-25 | 2021-05-25 |
| MindSpore Version | 1.2.0 | 1.2.0 |
| Dataset | AwA | CUB |
| Training Parameters | epoch = 100, batch_size = 64, lr = 1e-5 / 1e-4 / 1e-4 | epoch = 100, batch_size = 100, lr = 1e-5 |
| Optimizer | Adam | Adam |
| Loss Function | MSELoss | MSELoss |
| Outputs | probability | probability |
| Training Mode | attribute, wordvector, fusion | attribute |
| Speed | 24.6 ms/step, 7.3 ms/step, 42.1 ms/step | 51.3 ms/step |
| Total Time | 951 s / 286 s / 1640 s | 551 s |
| Checkpoint for Fine-tuning | 3040k / 4005k / 7426k (.ckpt file) | 3660k (.ckpt file) |
| Accuracy Calculation Method | kNN / kNN_cosine / kNN_cosine | kNN |
# Description of Random Situation
In train.py, the dataset is shuffled via dataset.GeneratorDataset(..., shuffle=True).
# ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval DEM ########################
python eval.py --data_path = /YourDataPath \
--dataset = AwA or CUB \
--train_mode = att, word or fusion
"""
import mindspore.nn as nn
from mindspore import context
from mindspore import load_checkpoint
from src.dataset import dataset_AwA, dataset_CUB
from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg
from src.demnet import MyTrainOneStepCell, MyWithLossCell
from src.set_parser import set_parser
from src.accuracy import compute_accuracy_att, compute_accuracy_word, compute_accuracy_fusion
if __name__ == "__main__":
# Set graph mode, device id
args = set_parser()
context.set_context(mode=context.PYNATIVE_MODE, \
device_target=args.device_target, \
device_id=args.device_id)
# Loading datasets and iterators
if args.dataset == 'AwA':
train_x, train_att, train_word, \
test_x, test_att, test_word, \
test_label, test_id = dataset_AwA(args.data_path)
elif args.dataset == 'CUB':
train_att, train_x, \
test_x, test_att, \
test_label, test_id = dataset_CUB(args.data_path)
# Initialize parameters
pred_len = acc_cfg(args)
lr, weight_decay, clip_param = param_cfg(args)
save_ckpt = args.save_ckpt
# Build network
net = backbone_cfg(args)
    loss_fn = nn.MSELoss(reduction='mean')
    # Pass learning_rate and weight_decay by keyword; positionally, the third
    # argument of nn.Adam is beta1, not weight_decay.
    optim = nn.Adam(net.trainable_params(), learning_rate=lr, weight_decay=weight_decay)
    MyWithLossCell = withlosscell_cfg(args)
    loss_net = MyWithLossCell(net, loss_fn)
    train_net = MyTrainOneStepCell(loss_net, optim)
# Eval
print("============== Starting Evaluating ==============")
if args.train_mode == 'att':
load_checkpoint(save_ckpt, net)
acc = compute_accuracy_att(net, pred_len, test_att, test_x, test_id, test_label)
print('accuracy _', args.dataset, '_', args.train_mode, "=", acc)
elif args.train_mode == 'word':
load_checkpoint(save_ckpt, net)
acc = compute_accuracy_word(net, pred_len, test_word, test_x, test_id, test_label)
print('accuracy _', args.dataset, '_', args.train_mode, "=", acc)
elif args.train_mode == 'fusion':
load_checkpoint(save_ckpt, net)
acc = compute_accuracy_fusion(net, pred_len, test_att, test_word, test_x, test_id, test_label)
print('accuracy _', args.dataset, '_', args.train_mode, "=", acc)

View File

@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train DEMnet example ########################
train DEMnet
python train.py --data_path = /YourDataPath \
--dataset = AwA or CUB \
--train_mode = att, word or fusion
"""
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import export
from mindspore import Tensor
from src.set_parser import set_parser
from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg
from src.dataset import dataset_AwA, dataset_CUB
from src.demnet import MyTrainOneStepCell
import numpy as np
if __name__ == "__main__":
# Set graph mode, device id
args = set_parser()
context.set_context(mode=context.PYNATIVE_MODE, \
device_target=args.device_target, \
device_id=args.device_id)
# Loading datasets and iterators
if args.dataset == 'AwA':
train_x, train_att, train_word, \
test_x, test_att, test_word, \
test_label, test_id = dataset_AwA(args.data_path)
elif args.dataset == 'CUB':
train_att, train_x, \
test_x, test_att, \
test_label, test_id = dataset_CUB(args.data_path)
# Initialize parameters
num = acc_cfg(args)
lr, weight_decay, clip_param = param_cfg(args)
save_ckpt = args.save_ckpt
# Build network
net = backbone_cfg(args)
    loss_fn = nn.MSELoss(reduction='mean')
    # Pass learning_rate and weight_decay by keyword; positionally, the third
    # argument of nn.Adam is beta1, not weight_decay.
    optim = nn.Adam(net.trainable_params(), learning_rate=lr, weight_decay=weight_decay)
MyWithLossCell = withlosscell_cfg(args)
loss_net = MyWithLossCell(net, loss_fn)
train_net = MyTrainOneStepCell(loss_net, optim)
print("============== Starting Exporting ==============")
if args.train_mode == 'att':
if args.dataset == 'AwA':
input0 = Tensor(np.zeros([args.batch_size, 85]), mindspore.float32)
elif args.dataset == 'CUB':
input0 = Tensor(np.zeros([args.batch_size, 312]), mindspore.float32)
export(net, input0, file_name=save_ckpt, file_format=args.file_format)
print("Successfully convert to", args.file_format)
elif args.train_mode == 'word':
input0 = Tensor(np.zeros([args.batch_size, 1000]), mindspore.float32)
export(net, input0, file_name=save_ckpt, file_format=args.file_format)
print("Successfully convert to", args.file_format)
elif args.train_mode == 'fusion':
input1 = Tensor(np.zeros([args.batch_size, 85]), mindspore.float32)
input2 = Tensor(np.zeros([args.batch_size, 1000]), mindspore.float32)
export(net, input1, input2, file_name=save_ckpt, file_format=args.file_format)
print("Successfully convert to", args.file_format)

View File

@@ -0,0 +1,25 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.demnet import DEMNet1
def demnet(*args, **kwargs):
return DEMNet1(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == "DEM":
return demnet(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")

View File

@@ -0,0 +1,3 @@
easydict==1.9
h5py==2.10.0
scipy==1.6.3

View File

@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
export RANK_TABLE_FILE=$1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
# remove old train_parallel files
rm -rf ../train_parallel
mkdir ../train_parallel
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
echo "device num=$DEVICE_NUM"
DATASET=$2
TRAIN_MODE=$3
DATA_PATH=$4
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=${i}
export RANK_ID=$((rank_start + i))
# mkdirs
mkdir ../train_parallel/$i
mkdir ../train_parallel/$i/src
# copy source files
cp ../*.py ../train_parallel/$i
cp ../src/*.py ../train_parallel/$i/src
# enter the per-rank training directory
cd ../train_parallel/$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
# dump the environment to env.log
env > env.log
# start training single task
python train.py --device_id=$i \
--distribute=True \
--dataset=$DATASET \
--train_mode=$TRAIN_MODE \
--data_path=$DATA_PATH &> log_train_8p.txt &
cd ../../scripts
done

View File

@@ -0,0 +1,23 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATASET=$1
TRAIN_MODE=$2
DATA_PATH=$3
SAVE_CKPT=$4 # ../output/train.ckpt
python ../eval.py --dataset=$DATASET \
--train_mode=$TRAIN_MODE \
--data_path=$DATA_PATH \
--save_ckpt=$SAVE_CKPT &> log_eval.txt &

View File

@@ -0,0 +1,26 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATASET=$1
TRAIN_MODE=$2
DATA_PATH=$3
SAVE_CKPT=$4 # ../output
if [ ! -d "$SAVE_CKPT" ]; then
    mkdir "$SAVE_CKPT"
fi
python ../train.py --dataset=$DATASET \
--train_mode=$TRAIN_MODE \
--data_path=$DATA_PATH \
--save_ckpt=$SAVE_CKPT &> log_train_1p.txt &

View File

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
utils, will be used in train.py
"""
import numpy as np
from src import kNN
from src import kNN_cosine
def compute_accuracy_att(net, pred_len, test_att_0, test_visual_0, test_id_0, test_label_0):
att_pred_0 = net(test_att_0)
outpred = [0] * pred_len
test_label_0 = test_label_0.astype("float32")
for i in range(pred_len):
outputLabel = kNN.kNNClassify(test_visual_0[i, :], att_pred_0.asnumpy(), test_id_0, 1)
outpred[i] = outputLabel
outpred = np.array(outpred)
acc_0 = np.equal(outpred, test_label_0).mean()
return acc_0
def compute_accuracy_word(net, pred_len, test_word_0, test_visual_0, test_id_0, test_label_0):
word_pred_0 = net(test_word_0)
outpred = [0] * pred_len
test_label_0 = test_label_0.astype("float32")
for i in range(pred_len):
outputLabel = kNN_cosine.kNNClassify(test_visual_0[i, :], word_pred_0.asnumpy(), test_id_0, 1)
outpred[i] = outputLabel
outpred = np.array(outpred)
acc_0 = np.equal(outpred, test_label_0).mean()
return acc_0
def compute_accuracy_fusion(net, pred_len, test_att_0, test_word_0, test_visual_0, test_id_0, test_label_0):
fus_pred_0 = net(test_att_0, test_word_0)
outpred = [0] * pred_len
test_label_0 = test_label_0.astype("float32")
for i in range(pred_len):
outputLabel = kNN_cosine.kNNClassify(test_visual_0[i, :], fus_pred_0.asnumpy(), test_id_0, 1)
outpred[i] = outputLabel
outpred = np.array(outpred)
acc_0 = np.equal(outpred, test_label_0).mean()
return acc_0

View File

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
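# Per-mode hyperparameters: learning rate (lr_*), weight decay (wd_*) and
# gradient-clip bound (clip_*); src/utils.py:param_cfg() selects among them.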
awa_cfg = edict({
'lr_att': 1e-5,
'wd_att': 1e-2,
'clip_att': 0.2,
'lr_word': 1e-4,
'wd_word': 1e-3,
'clip_word': 0.5,
'lr_fusion': 1e-4,
'wd_fusion': 1e-2,
'clip_fusion': 0.5,
'batch_size': 64,
})
cub_cfg = edict({
'lr_att': 1e-5,
'wd_att': 1e-2,
'clip_att': 0.5,
'batch_size': 100,
})

View File

@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""
import mindspore
from mindspore import Tensor
import h5py
import scipy.io as sio
import numpy as np
def dataset_AwA(data_path):
"""input:*.mat, output:array"""
f = sio.loadmat(data_path+'/AwA_data/train_googlenet_bn.mat')
train_x_0 = np.array(f['train_googlenet_bn'])
f = h5py.File(data_path+'/AwA_data/attribute/Z_s_con.mat', 'r')
train_att_0 = np.array(f['Z_s_con'])
f = sio.loadmat(data_path+'/AwA_data/wordvector/train_word.mat')
train_word_0 = np.array(f['train_word'])
f = sio.loadmat(data_path+'/AwA_data/test_googlenet_bn.mat')
test_x_0 = np.array(f['test_googlenet_bn'])
f = sio.loadmat(data_path+'/AwA_data/attribute/pca_te_con_10x85.mat')
test_att_0 = np.array(f['pca_te_con_10x85'])
test_att_0 = test_att_0.astype("float16")
test_att_0 = Tensor(test_att_0, mindspore.float32)
f = sio.loadmat(data_path+'/AwA_data/wordvector/test_vectors.mat')
test_word_0 = np.array(f['test_vectors'])
test_word_0 = test_word_0.astype("float16")
test_word_0 = Tensor(test_word_0, mindspore.float32)
f = sio.loadmat(data_path+'/AwA_data/test_labels.mat')
test_label_0 = np.squeeze(np.array(f['test_labels']))
f = sio.loadmat(data_path+'/AwA_data/testclasses_id.mat')
test_id_0 = np.squeeze(np.array(f['testclasses_id']))
return train_x_0, train_att_0, train_word_0, test_x_0, \
test_att_0, test_word_0, test_label_0, test_id_0
def dataset_CUB(data_path):
"""input:*.mat, output:array"""
f = sio.loadmat(data_path+'/CUB_data/train_attr.mat')
train_att_0 = np.array(f['train_attr'])
# print('train attr:', train_att.shape)
f = sio.loadmat(data_path+'/CUB_data/train_cub_googlenet_bn.mat')
train_x_0 = np.array(f['train_cub_googlenet_bn'])
# print('train x:', train_x.shape)
f = sio.loadmat(data_path+'/CUB_data/test_cub_googlenet_bn.mat')
test_x_0 = np.array(f['test_cub_googlenet_bn'])
# print('test x:', test_x.shape)
f = sio.loadmat(data_path+'/CUB_data/test_proto.mat')
test_att_0 = np.array(f['test_proto'])
test_att_0 = test_att_0.astype("float16")
test_att_0 = Tensor(test_att_0, mindspore.float32)
# print('test att:', test_att.shape)
f = sio.loadmat(data_path+'/CUB_data/test_labels_cub.mat')
test_label_0 = np.squeeze(np.array(f['test_labels_cub']))
# print('test x2label:', test_x2label)
f = sio.loadmat(data_path+'/CUB_data/testclasses_id.mat')
test_id_0 = np.squeeze(np.array(f['testclasses_id']))
# print('test att2label:', test_att2label)
return train_att_0, train_x_0, test_x_0, test_att_0, test_label_0, test_id_0
class SingleDataIterable:
"""data+label"""
def __init__(self, data, label):
self._data = data
self._label = label
def __getitem__(self, index):
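        # Slice with index:index+1 (not plain indexing) so each item keeps a 2-D (1, dim) shape.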
item1 = self._data[index:index + 1]
item2 = self._label[index:index + 1]
return item1.astype(np.float32), item2.astype(np.float32)
def __len__(self):
return len(self._data)
class DoubleDataIterable:
"""data1+data2+label"""
def __init__(self, data1, data2, label):
self._data1 = data1
self._data2 = data2
self._label = label
def __getitem__(self, index):
item1 = self._data1[index:index + 1]
item2 = self._data2[index:index + 1]
item3 = self._label[index:index + 1]
return item1.astype(np.float32), item2.astype(np.float32), item3.astype(np.float32)
def __len__(self):
return len(self._data1)
if __name__ == "__main__":
train_att, train_x, test_x, test_att, test_label, test_id = dataset_CUB('/data/DEM_data')
print('train attr:', train_att.shape)
print('train x:', train_x.shape)
print('test x:', test_x.shape)
print('test att:', test_att.shape)
print('test label:', test_label)
print('test id:', test_id)

View File

@@ -0,0 +1,128 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
DEMNet, WithLossCell and TrainOneStepCell
"""
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.context as context
from mindspore.common.initializer import Normal
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_gradients_mean, _get_parallel_mode, _get_device_num
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
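# Scaled hyperbolic tangent, f(x) = 1.7159 * tanh(2x / 3), the variant
# recommended in LeCun et al., "Efficient BackProp" (1998).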
class MyTanh(nn.Cell):
def __init__(self):
super(MyTanh, self).__init__()
self.tanh = P.Tanh()
def construct(self, x):
return 1.7159 * self.tanh(2 * x / 3)
class DEMNet1(nn.Cell):
"""cub+att"""
def __init__(self):
super(DEMNet1, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Dense(312, 700, weight_init=Normal(0.0008))
self.fc2 = nn.Dense(700, 1024, weight_init=Normal(0.0012))
def construct(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return x
class DEMNet2(nn.Cell):
"""awa+att"""
def __init__(self):
super(DEMNet2, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Dense(85, 700, weight_init=Normal(0.0005))
self.fc2 = nn.Dense(700, 1024, weight_init=Normal(0.0005))
def construct(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return x
class DEMNet3(nn.Cell):
"""awa+word"""
def __init__(self):
super(DEMNet3, self).__init__()
self.relu = nn.ReLU()
self.fc1 = nn.Dense(1000, 1024, weight_init=Normal(0.0005))
def construct(self, x):
x = self.relu(self.fc1(x))
return x
class DEMNet4(nn.Cell):
"""awa+fusion"""
def __init__(self):
super(DEMNet4, self).__init__()
self.relu = nn.ReLU()
self.tanh = MyTanh()
self.fc1 = nn.Dense(1000, 900, weight_init=Normal(0.0008))
self.fc2 = nn.Dense(85, 900, weight_init=Normal(0.0012))
self.fc3 = nn.Dense(900, 1024, weight_init=Normal(0.0012))
def construct(self, att, word):
word = self.tanh(self.fc1(word))
att = self.tanh(self.fc2(att))
fus = word + 3 * att
fus = self.relu(self.fc3(fus))
return fus
class MyWithLossCell(nn.Cell):
def __init__(self, backbone, loss_fn):
super(MyWithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, data1, data2, label):
out = self._backbone(data1, data2)
return self._loss_fn(out, label)
class MyTrainOneStepCell(nn.Cell):
"""custom TrainOneStepCell"""
def __init__(self, network, optimizer, sens=1.0):
super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
if self.parallel_mode in (context.ParallelMode.DATA_PARALLEL, context.ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)
def construct(self, *inputs):
weights = self.weights
loss = self.network(*inputs)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*inputs, sens)
grads = self.grad_reducer(grads)
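        # NOTE: the global-norm clip bound is hardcoded to 0.2 here; src/config.py
        # defines per-mode clip_* values that are not wired into this cell.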
grads = ops.clip_by_global_norm(grads, 0.2)
return F.depend(loss, self.optimizer(grads))

View File

@@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
k-nearest neighbor algorithm, will be used to compute accuracy in train.py
"""
import numpy as np
# create a dataset which contains 4 samples with 2 classes
def createDataSet():
group = np.array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B'] # four samples and two classes
return group, labels
def kNNClassify(newInput, dataSet, labels, k):
"""classify using kNN"""
numSamples = dataSet.shape[0]
diff = np.tile(newInput, (numSamples, 1)) - dataSet
squaredDiff = diff ** 2
squaredDist = np.sum(squaredDiff, axis=1)
distance = squaredDist ** 0.5
sortedDistIndices = np.argsort(distance)
classCount = {}
for i in range(k):
voteLabel = labels[sortedDistIndices[i]]
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
maxCount = 0
for key, value in classCount.items():
if value > maxCount:
maxCount = value
maxIndex = key
return maxIndex
#return sortedDistIndices
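
# Minimal usage example (a sketch) against the toy dataset above:
if __name__ == "__main__":
    group_0, labels_0 = createDataSet()
    print(kNNClassify(np.array([0.9, 1.1]), group_0, labels_0, 3))  # -> 'A'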

View File

@@ -0,0 +1,50 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
k-nearest neighbor with cosine algorithm, will be used to compute accuracy in train.py
"""
import math
import numpy as np
# create a dataset which contains 4 samples with 2 classes
def createDataSet():
group = np.array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
def cosine_distance(v1, v2):
"compute cosine similarity of v1 to v2: (v1 dot v2)/{||v1||*||v2||)"
v1_sq = np.inner(v1, v1)
v2_sq = np.inner(v2, v2)
dis = 1 - np.inner(v1, v2) / math.sqrt(v1_sq * v2_sq)
return dis
def kNNClassify(newInput, dataSet, labels, k):
"""classify using kNN"""
distance = [0] * dataSet.shape[0]
for i in range(dataSet.shape[0]):
distance[i] = cosine_distance(newInput, dataSet[i])
sortedDistIndices = np.argsort(distance)
classCount = {}
for i in range(k):
voteLabel = labels[sortedDistIndices[i]]
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
maxCount = 0
for key, value in classCount.items():
if value > maxCount:
maxCount = value
maxIndex = key
return maxIndex
#return sortedDistIndices

View File

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Training parameter setting, will be used in train.py
"""
import argparse
def set_parser():
"""parser for train.py and eval.py"""
    parser = argparse.ArgumentParser(description='MindSpore DEMnet Training')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--device_id', type=int, default=0, help='ID of the target device')
    parser.add_argument('--distribute', type=bool, default=False, help='whether to run distributed training')
    parser.add_argument('--device_num', type=int, default=1, help='number of devices used')
    parser.add_argument('--dataset', type=str, default="CUB", choices=['AwA', 'CUB'],
                        help='dataset to train on (default: CUB)')
    parser.add_argument('--train_mode', type=str, default='att', choices=['att', 'word', 'fusion'],
                        help='training mode (default: att, i.e. attribute)')
    parser.add_argument('--batch_size', type=int, default=100, help='batch size of one training step')
    parser.add_argument('--interval_step', type=int, default=500, help='interval (in steps) for printing the loss')
    parser.add_argument('--epoch_size', type=int, default=120, help='number of training epochs')
    parser.add_argument('--data_path', type=str, default='/data/DEM_data', help='path where the dataset is saved')
    parser.add_argument('--save_ckpt', type=str, default='../output',
                        help='directory where the trained checkpoint is saved (required for eval)')
    parser.add_argument("--file_format", type=str, default="ONNX", choices=["AIR", "ONNX", "MINDIR"],
                        help="export file format")
    args = parser.parse_args()
return args

View File

@@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
utils, will be used in train.py
"""
import mindspore.nn as nn
from src.config import awa_cfg, cub_cfg
from src.demnet import DEMNet1, DEMNet2, DEMNet3, DEMNet4, MyWithLossCell
def acc_cfg(args):
if args.dataset == 'CUB':
pred_len = 2933
elif args.dataset == 'AwA':
pred_len = 6180
return pred_len
def backbone_cfg(args):
"""set backbone"""
if args.dataset == 'CUB':
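        # CUB supports the attribute mode only; word/fusion are rejected in train.py.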
net = DEMNet1()
elif args.dataset == 'AwA':
if args.train_mode == 'att':
net = DEMNet2()
elif args.train_mode == 'word':
net = DEMNet3()
elif args.train_mode == 'fusion':
net = DEMNet4()
return net
def param_cfg(args):
"""set Hyperparameter"""
if args.dataset == 'CUB':
lr = cub_cfg.lr_att
weight_decay = cub_cfg.wd_att
clip_param = cub_cfg.clip_att
elif args.dataset == 'AwA':
if args.train_mode == 'att':
lr = awa_cfg.lr_att
weight_decay = awa_cfg.wd_att
clip_param = awa_cfg.clip_att
elif args.train_mode == 'word':
lr = awa_cfg.lr_word
weight_decay = awa_cfg.wd_word
clip_param = awa_cfg.clip_word
elif args.train_mode == 'fusion':
lr = awa_cfg.lr_fusion
weight_decay = awa_cfg.wd_fusion
clip_param = awa_cfg.clip_fusion
return lr, weight_decay, clip_param
def withlosscell_cfg(args):
if args.train_mode == 'fusion':
return MyWithLossCell
return nn.WithLossCell
def save_min_acc_cfg(args):
if args.train_mode == 'att':
save_min_acc = 0.5
elif args.train_mode == 'word':
save_min_acc = 0.7
elif args.train_mode == 'fusion':
save_min_acc = 0.7
return save_min_acc

View File

@@ -0,0 +1,160 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train DEM ########################
train DEM
python train.py --data_path = /YourDataPath \
--dataset = AwA or CUB \
--train_mode = att, word or fusion
"""
import time
import sys
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore import save_checkpoint
from mindspore import dataset as ds
from mindspore import Model
from mindspore import set_seed
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import init, get_rank, get_group_size
from src.dataset import dataset_AwA, dataset_CUB, SingleDataIterable, DoubleDataIterable
from src.demnet import MyTrainOneStepCell
from src.set_parser import set_parser
from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg, save_min_acc_cfg
from src.accuracy import compute_accuracy_att, compute_accuracy_word, compute_accuracy_fusion
if __name__ == "__main__":
# Set graph mode, device id
set_seed(1000)
args = set_parser()
context.set_context(mode=context.GRAPH_MODE, \
device_target=args.device_target, \
device_id=args.device_id)
if args.distribute:
init()
args.device_num = get_group_size()
rank_id = get_rank()
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.AUTO_PARALLEL, \
gradients_mean=True, \
device_num=args.device_num)
else:
rank_id = 0
# Initialize parameters
pred_len = acc_cfg(args)
lr, weight_decay, clip_param = param_cfg(args)
    if args.distribute:
        lr = lr * 5
batch_size = args.batch_size
# Loading datasets and iterators
if args.dataset == 'AwA':
train_x, train_att, train_word, \
test_x, test_att, test_word, \
test_label, test_id = dataset_AwA(args.data_path)
if args.train_mode == 'att':
custom_data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x),
['label', 'data'],
num_shards=args.device_num,
shard_id=rank_id,
shuffle=True)
elif args.train_mode == 'word':
custom_data = ds.GeneratorDataset(SingleDataIterable(train_word, train_x),
['label', 'data'],
num_shards=args.device_num,
shard_id=rank_id,
shuffle=True)
elif args.train_mode == 'fusion':
custom_data = ds.GeneratorDataset(DoubleDataIterable(train_att, train_word, train_x),
['label1', 'label2', 'data'],
num_shards=args.device_num,
shard_id=rank_id,
shuffle=True)
elif args.dataset == 'CUB':
train_att, train_x, \
test_x, test_att, \
test_label, test_id = dataset_CUB(args.data_path)
if args.train_mode == 'att':
custom_data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x),
['label', 'data'],
num_shards=args.device_num,
shard_id=rank_id,
shuffle=True)
        elif args.train_mode == 'word':
            print("Warning: word-vector mode training is not supported on the CUB dataset.")
            print("Only attribute mode is supported for this dataset.")
            sys.exit(0)
        elif args.train_mode == 'fusion':
            print("Warning: fusion mode training is not supported on the CUB dataset.")
            print("Only attribute mode is supported for this dataset.")
            sys.exit(0)
# Note: Must set "drop_remainder = True" in parallel mode.
custom_data = custom_data.batch(batch_size, drop_remainder=True)
# Build network
net = backbone_cfg(args)
    loss_fn = nn.MSELoss(reduction='mean')
    # Pass learning_rate and weight_decay by keyword; positionally, the third
    # argument of nn.Adam is beta1, not weight_decay.
    optim = nn.Adam(net.trainable_params(), learning_rate=lr, weight_decay=weight_decay)
MyWithLossCell = withlosscell_cfg(args)
loss_net = MyWithLossCell(net, loss_fn)
train_net = MyTrainOneStepCell(loss_net, optim)
model = Model(train_net)
# Train
start = time.time()
acc_max = 0
save_min_acc = save_min_acc_cfg(args)
save_ckpt = args.save_ckpt
ckpt_file_name = save_ckpt + '/train.ckpt'
interval_step = args.interval_step
epoch_size = args.epoch_size
print("============== Starting Training ==============")
    if args.distribute:
now = time.localtime()
nowt = time.strftime("%Y-%m-%d-%H:%M:%S", now)
print(nowt)
loss_cb = LossMonitor(interval_step)
ckpt_config = CheckpointConfig(save_checkpoint_steps=interval_step)
ckpt_callback = ModelCheckpoint(prefix='auto_parallel', config=ckpt_config)
t1 = time.time()
model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb, ckpt_callback], dataset_sink_mode=False)
end = time.time()
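        # 88 is the assumed per-epoch step count, used only for speed reporting
        # (here and in the standalone branch below).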
t3 = 1000 * (end - t1) / (88 * epoch_size)
print('total time:', end - start)
print('speed_8p = %.3f ms/step'%t3)
now = time.localtime()
nowt = time.strftime("%Y-%m-%d-%H:%M:%S", now)
print(nowt)
else:
for i in range(epoch_size):
t1 = time.time()
model.train(1, train_dataset=custom_data, callbacks=LossMonitor(interval_step), dataset_sink_mode=False)
t2 = time.time()
t3 = 1000 * (t2 - t1) / 88
if args.train_mode == 'att':
acc = compute_accuracy_att(net, pred_len, test_att, test_x, test_id, test_label)
elif args.train_mode == 'word':
acc = compute_accuracy_word(net, pred_len, test_word, test_x, test_id, test_label)
else:
acc = compute_accuracy_fusion(net, pred_len, test_att, test_word, test_x, test_id, test_label)
if acc > acc_max:
acc_max = acc
if acc_max > save_min_acc:
save_checkpoint(net, ckpt_file_name)
print('epoch:', i + 1, 'accuracy = %.5f'%acc, 'speed = %.3f ms/step'%t3)
end = time.time()
print("total time:", end - start)