add_dem

parent fd262eb3b8
commit 53e3ce9715

@@ -0,0 +1,222 @@
# Contents

- [DEM Description](#dem-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DEM Description](#contents)

The Deep Embedding Model (DEM) is a zero-shot learning (ZSL) model that maps the semantic space to the visual feature space. In other words, DEM maps a low-dimensional space into a high-dimensional one, which avoids the hubness problem. It also introduces a multi-modal semantic feature fusion method that is optimized jointly in an end-to-end manner.

[Paper](https://arxiv.org/abs/1611.05088): Li Zhang, Tao Xiang, Shaogang Gong. "Learning a Deep Embedding Model for Zero-Shot Learning." Proceedings of CVPR, 2017.
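The objective in miniature: a small embedding net maps a semantic (attribute) vector into the visual feature space, and an MSE loss pulls the embedding toward the GoogLeNet feature of the matching image. The sketch below is mine, built from the same MindSpore pieces this repo uses; the random tensors are placeholders and the 85/1024 dimensions follow the AwA attribute setting.

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# semantic -> visual embedding (85-d attribute vector to 1024-d GoogLeNet feature space)
net = nn.SequentialCell([nn.Dense(85, 700), nn.ReLU(), nn.Dense(700, 1024), nn.ReLU()])
loss_fn = nn.MSELoss(reduction='mean')

att = Tensor(np.random.rand(64, 85), mindspore.float32)       # semantic inputs (placeholder data)
visual = Tensor(np.random.rand(64, 1024), mindspore.float32)  # visual-feature targets (placeholder data)
loss = loss_fn(net(att), visual)  # regressing in the high-dimensional visual space avoids hubness
print(loss)
```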
# [Model Architecture](#contents)

DEM uses GoogLeNet to extract visual features, then applies a multi-modal fusion method to train in the attribute space, the word-vector space, and the fusion space.
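Concretely, the fusion branch (DEMNet4 in src/demnet.py) embeds each modality separately and merges them with a weighted sum, `tanh(fc1(word)) + 3 * tanh(fc2(att))`, before projecting into the 1024-d visual space. A quick shape check against the repo's own network:

```python
import numpy as np
import mindspore
from mindspore import Tensor
from src.demnet import DEMNet4

net = DEMNet4()  # AwA fusion network from src/demnet.py
att = Tensor(np.zeros([64, 85]), mindspore.float32)     # attribute branch input
word = Tensor(np.zeros([64, 1000]), mindspore.float32)  # word-vector branch input
out = net(att, word)  # internally: relu(fc3(tanh(fc1(word)) + 3 * tanh(fc2(att))))
print(out.shape)      # (64, 1024): points in the visual feature space
```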
# [Dataset](#contents)

Dataset used: AwA, CUB. Download the data from [here](https://www.robots.ox.ac.uk/~lz/DEM_cvpr2017/data.zip).

```bash
- Note: Data will be processed in dataset.py
```

- The directory structure is as follows:

```bash
└─data
    ├─AwA_data
    │ ├─attribute              # attribute data
    │ ├─wordvector             # wordvector data
    │ ├─test_googlenet_bn.mat
    │ ├─test_labels.mat
    │ ├─testclasses_id.mat
    │ └─train_googlenet_bn.mat
    └─CUB_data                 # the directory is similar to AwA_data
```
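The `.mat` files are read with `scipy.io` (and `h5py` for the v7.3 attribute file), exactly as src/dataset.py does. A minimal sketch for one split, assuming the archive was unpacked to /data/DEM_data:

```python
import numpy as np
import scipy.io as sio

# load the AwA training features the same way src/dataset.py does
mat = sio.loadmat('/data/DEM_data/AwA_data/train_googlenet_bn.mat')
train_x = np.array(mat['train_googlenet_bn'])  # one GoogLeNet feature row per training image
print(train_x.shape)
```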
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare the hardware environment with Ascend.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# Install necessary packages
pip install -r requirements.txt

# Place the dataset in '/data/DEM_data': rename and unzip it
mv data.zip DEM_data.zip
mv ./DEM_data.zip /data
cd /data
unzip DEM_data.zip

# 1p example
# Enter the script dir and start training
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output

# Enter the script dir and start evaluating
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt

# 8p example
sh run_distributed_train_ascend.sh [hccl configuration, .json format] CUB att /data/DEM_data

sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../train_parallel/7/auto_parallel-120_11.ckpt

# Note: word and fusion training are not supported on the CUB dataset
```
# [Script Description](#contents)

## [Script and Sample Code](#contents)

```bash
├── cv
    ├── DEM
        ├── README.md                          // descriptions about DEM
        ├── README_CN.md                       // descriptions about DEM in Chinese
        ├── requirements.txt                   // packages needed
        ├── scripts
        │   ├──run_distributed_train_ascend.sh // train on Ascend with 8P
        │   ├──run_standalone_train_ascend.sh  // train on Ascend with 1P
        │   └──run_standalone_eval_ascend.sh   // evaluate on Ascend
        ├── src
        │   ├──dataset.py                      // load dataset
        │   ├──demnet.py                       // DEM framework
        │   ├──config.py                       // parameter configuration
        │   ├──kNN.py                          // k-nearest neighbor
        │   ├──kNN_cosine.py                   // k-nearest neighbor with cosine distance
        │   ├──accuracy.py                     // compute accuracy
        │   ├──set_parser.py                   // basic parameters
        │   └──utils.py                        // helper functions
        ├── train.py                           // training script
        ├── eval.py                            // evaluation script
        └── export.py                          // export script
```
## [Script Parameters](#contents)

```bash
# Major parameters in train.py and set_parser.py are as follows:

--device_target: Device target, default is "Ascend"
--device_id:     Device ID
--distribute:    Whether to train in a distributed environment
--device_num:    The number of devices used
--dataset:       Dataset used, chosen from "AwA", "CUB"
--train_mode:    Training mode, chosen from "att" (attribute), "word" (wordvector), "fusion"
--batch_size:    Training batch size
--interval_step: The interval between loss printouts
--epoch_size:    The number of training epochs
--data_path:     Path where the dataset is saved
--save_ckpt:     Path where the ckpt is saved
--file_format:   Export format for model conversion
```
## [Training Process](#contents)

### Training

```bash
python train.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the script dir and run the script
sh run_standalone_train_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# 1p example:
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output
# 8p example:
sh run_distributed_train_ascend.sh [hccl configuration, .json format] CUB att /data/DEM_data
```

After training, the loss values look as follows:

```bash
============== Starting Training ==============
epoch: 1 step: 100, loss is 0.24551314
epoch: 2 step: 61, loss is 0.2861852
epoch: 3 step: 22, loss is 0.2151301

...

epoch: 16 step: 115, loss is 0.13285707
epoch: 17 step: 76, loss is 0.15123637

...
```

The model checkpoint will be saved under [SAVE_CKPT], the directory designated in the script.
## [Evaluation Process](#contents)

### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

```bash
python eval.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the script dir and run the script
sh run_standalone_eval_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# Example:
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt
```

The accuracy on the test dataset is as follows:

```bash
============== Starting Evaluating ==============
accuracy _ CUB _ att = 0.58984
```
# [Model Description](#contents)

## [Performance](#contents)

### Evaluation Performance

| Parameters                  | DEM_AwA                                                        | DEM_CUB                                                        |
| --------------------------- | -------------------------------------------------------------- | -------------------------------------------------------------- |
| Resource                    | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS CentOS 8.2 | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS CentOS 8.2 |
| Uploaded Date               | 06/18/2021 (month/day/year)                                    | 04/26/2021 (month/day/year)                                    |
| MindSpore Version           | 1.2.0                                                          | 1.2.0                                                          |
| Dataset                     | AwA                                                            | CUB                                                            |
| Training Parameters         | epoch = 100, batch_size = 64, lr = 1e-5 / 1e-4 / 1e-4          | epoch = 100, batch_size = 100, lr = 1e-5                       |
| Optimizer                   | Adam                                                           | Adam                                                           |
| Loss Function               | MSELoss                                                        | MSELoss                                                        |
| Outputs                     | probability                                                    | probability                                                    |
| Training Mode               | attribute, wordvector, fusion                                  | attribute                                                      |
| Speed                       | 24.6 ms/step, 7.3 ms/step, 42.1 ms/step                        | 51.3 ms/step                                                   |
| Total Time                  | 951 s / 286 s / 1640 s                                         | 551 s                                                          |
| Checkpoint for Fine Tuning  | 3040k / 4005k / 7426k (.ckpt file)                             | 3660k (.ckpt file)                                             |
| Accuracy Calculation Method | kNN / kNN_cosine / kNN_cosine                                  | kNN                                                            |
# [Description of Random Situation](#contents)

In train.py, we use "ds.GeneratorDataset(..., shuffle=True)" to shuffle the dataset.
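This corresponds to the dataset construction in train.py, condensed below; `train_att`, `train_x`, and `batch_size` are defined earlier in that script, and the global seed is fixed with `set_seed(1000)` at startup:

```python
from mindspore import dataset as ds
from src.dataset import SingleDataIterable

# shuffle=True randomizes the sample order each epoch
data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x), ['label', 'data'], shuffle=True)
data = data.batch(batch_size, drop_remainder=True)  # drop_remainder is required in parallel mode
```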
# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,229 @@

# Contents

<!-- TOC -->

- [Contents](#contents)
- [DEM Description](#dem-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# DEM Description

The Deep Embedding Model (DEM) is a zero-shot learning (ZSL) model that maps the semantic space to the visual feature space, i.e. a low-dimensional space into a high-dimensional one, which neatly avoids the hubness problem. It also proposes a multi-modal semantic feature fusion method that is optimized jointly in an end-to-end manner.

[Paper](https://arxiv.org/abs/1611.05088): Li Zhang, Tao Xiang, Shaogang Gong. "Learning a Deep Embedding Model for Zero-Shot Learning." Proceedings of CVPR, 2017.

# Model Architecture

DEM uses GoogLeNet for feature extraction, then applies a multi-modal fusion method to train in the attribute space, the word-vector space, and the fusion space.

# Dataset

Datasets used: AwA, CUB. [Download](https://www.robots.ox.ac.uk/~lz/DEM_cvpr2017/data.zip)

```bash
- Note: Data is loaded in dataset.py.
```

- The directory structure is as follows:

```bash
└─data
    ├─AwA_data
    │ ├─attribute              # attribute vectors
    │ ├─wordvector             # word vectors
    │ ├─test_googlenet_bn.mat
    │ ├─test_labels.mat
    │ ├─testclasses_id.mat
    │ └─train_googlenet_bn.mat
    └─CUB_data                 # structured like AwA_data
```

# Environment Requirements

- Hardware (Ascend)
    - Prepare the hardware environment with an Ascend processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For details, see the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

# Quick Start

After installing MindSpore via the official website, you can start training and evaluation as follows:

```bash
# Install dependencies
pip install -r requirements.txt

# Place the dataset under '/data/DEM_data': rename and unzip it
mv data.zip DEM_data.zip
mv ./DEM_data.zip /data
cd /data
unzip DEM_data.zip

# 1p example
# Enter the script dir and train DEM
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output

# Enter the script dir and evaluate DEM
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt

# 8p example
sh run_distributed_train_ascend.sh [hccl configuration, .json format] CUB att /data/DEM_data

sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../train_parallel/7/auto_parallel-120_11.ckpt

# Note: word-vector (word) and fusion modes are not yet supported on the CUB dataset
```

# Script Description

## Script and Sample Code

```bash
├── cv
    ├── DEM
        ├── README.md                          // DEM description
        ├── README_CN.md                       // DEM description in Chinese
        ├── requirements.txt                   // required packages
        ├── scripts
        │   ├──run_distributed_train_ascend.sh // 8-device Ascend training
        │   ├──run_standalone_train_ascend.sh  // single-device Ascend training
        │   └──run_standalone_eval_ascend.sh   // Ascend evaluation
        ├── src
        │   ├──dataset.py                      // dataset loading
        │   ├──demnet.py                       // DEM architecture
        │   ├──config.py                       // parameter configuration
        │   ├──kNN.py                          // k-nearest neighbor algorithm
        │   ├──kNN_cosine.py                   // k-nearest neighbor with cosine distance
        │   ├──accuracy.py                     // accuracy computation
        │   ├──set_parser.py                   // basic parameters
        │   └──utils.py                        // common functions
        ├── train.py                           // training script
        ├── eval.py                            // accuracy validation script
        └── export.py                          // inference model export script
```

## Script Parameters

```bash
# Major parameters in train.py and set_parser.py are as follows:

--device_target: device to run the code on, default "Ascend"
--device_id:     ID of the device to run on
--distribute:    whether to run distributed training
--device_num:    number of training devices
--dataset:       dataset to use, chosen from "AwA", "CUB"
--train_mode:    training mode, chosen from "att" (attribute), "word" (wordvector), "fusion"
--batch_size:    training batch size
--interval_step: interval between loss printouts
--epoch_size:    number of training epochs
--data_path:     path where the dataset is stored
--save_ckpt:     path where the model is saved
--file_format:   model export format
```

## Training Process

### Training

```bash
python train.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the ./scripts directory and run the script
sh run_standalone_train_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# single-device example:
sh run_standalone_train_ascend.sh CUB att /data/DEM_data ../output

# 8-device example:
sh run_distributed_train_ascend.sh [hccl configuration, .json format] CUB att /data/DEM_data
```

After training, the loss values look as follows:

```bash
============== Starting Training ==============
epoch: 1 step: 100, loss is 0.24551314
epoch: 2 step: 61, loss is 0.2861852
epoch: 3 step: 22, loss is 0.2151301

...

epoch: 16 step: 115, loss is 0.13285707
epoch: 17 step: 76, loss is 0.15123637

...
```

The model checkpoint is saved under the designated directory [SAVE_CKPT].

## Evaluation Process

### Evaluation

Before running the command below, please check the checkpoint path used for evaluation.

```bash
python eval.py --data_path=/YourDataPath --save_ckpt=/YourCkptPath --dataset=[AwA|CUB] --train_mode=[att|word|fusion]
# or enter the ./scripts directory and run the script
sh run_standalone_eval_ascend.sh [AwA|CUB] [att|word|fusion] [DATA_PATH] [SAVE_CKPT]
# example:
sh run_standalone_eval_ascend.sh CUB att /data/DEM_data ../output/train.ckpt
```

The accuracy on the test dataset is as follows:

```bash
============== Starting Evaluating ==============
accuracy _ CUB _ att = 0.58984
```

# Model Description

## Performance

### Evaluation Performance

| Parameters                  | DEM_AwA                                                        | DEM_CUB                                                        |
| --------------------------- | -------------------------------------------------------------- | -------------------------------------------------------------- |
| Resource                    | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS CentOS 8.2 | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G; OS CentOS 8.2 |
| Uploaded Date               | 2021-05-25                                                     | 2021-05-25                                                     |
| MindSpore Version           | 1.2.0                                                          | 1.2.0                                                          |
| Dataset                     | AwA                                                            | CUB                                                            |
| Training Parameters         | epoch = 100, batch_size = 64, lr = 1e-5 / 1e-4 / 1e-4          | epoch = 100, batch_size = 100, lr = 1e-5                       |
| Optimizer                   | Adam                                                           | Adam                                                           |
| Loss Function               | MSELoss                                                        | MSELoss                                                        |
| Outputs                     | probability                                                    | probability                                                    |
| Training Mode               | attribute, wordvector, fusion                                  | attribute                                                      |
| Speed                       | 24.6 ms/step, 7.3 ms/step, 42.1 ms/step                        | 51.3 ms/step                                                   |
| Total Time                  | 951 s / 286 s / 1640 s                                         | 551 s                                                          |
| Checkpoint for Fine Tuning  | 3040k / 4005k / 7426k (.ckpt file)                             | 3660k (.ckpt file)                                             |
| Accuracy Calculation Method | kNN / kNN_cosine / kNN_cosine                                  | kNN                                                            |

# Description of Random Situation

In train.py, we use ds.GeneratorDataset(..., shuffle=True) for randomization.

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,75 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval DEM ########################
python eval.py --data_path = /YourDataPath \
               --dataset = AwA or CUB \
               --train_mode = att, word or fusion
"""

import mindspore.nn as nn
from mindspore import context
from mindspore import load_checkpoint

from src.dataset import dataset_AwA, dataset_CUB
from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg
from src.demnet import MyTrainOneStepCell, MyWithLossCell
from src.set_parser import set_parser
from src.accuracy import compute_accuracy_att, compute_accuracy_word, compute_accuracy_fusion

if __name__ == "__main__":
    # Set pynative mode, device target and device id
    args = set_parser()
    context.set_context(mode=context.PYNATIVE_MODE,
                        device_target=args.device_target,
                        device_id=args.device_id)

    # Load datasets and iterators
    if args.dataset == 'AwA':
        train_x, train_att, train_word, \
        test_x, test_att, test_word, \
        test_label, test_id = dataset_AwA(args.data_path)
    elif args.dataset == 'CUB':
        train_att, train_x, \
        test_x, test_att, \
        test_label, test_id = dataset_CUB(args.data_path)

    # Initialize parameters
    pred_len = acc_cfg(args)
    lr, weight_decay, clip_param = param_cfg(args)
    save_ckpt = args.save_ckpt

    # Build network
    net = backbone_cfg(args)
    loss_fn = nn.MSELoss(reduction='mean')
    optim = nn.Adam(net.trainable_params(), lr, weight_decay=weight_decay)
    MyWithLossCell = withlosscell_cfg(args)
    loss_net = MyWithLossCell(net, loss_fn)
    train_net = MyTrainOneStepCell(loss_net, optim)

    # Evaluate: load the trained checkpoint and compute kNN accuracy on the test set
    print("============== Starting Evaluating ==============")
    if args.train_mode == 'att':
        load_checkpoint(save_ckpt, net)
        acc = compute_accuracy_att(net, pred_len, test_att, test_x, test_id, test_label)
        print('accuracy _', args.dataset, '_', args.train_mode, "=", acc)
    elif args.train_mode == 'word':
        load_checkpoint(save_ckpt, net)
        acc = compute_accuracy_word(net, pred_len, test_word, test_x, test_id, test_label)
        print('accuracy _', args.dataset, '_', args.train_mode, "=", acc)
    elif args.train_mode == 'fusion':
        load_checkpoint(save_ckpt, net)
        acc = compute_accuracy_fusion(net, pred_len, test_att, test_word, test_x, test_id, test_label)
        print('accuracy _', args.dataset, '_', args.train_mode, "=", acc)
@@ -0,0 +1,82 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## export DEM ########################
export a trained DEMnet to the chosen inference format
python export.py --data_path = /YourDataPath \
                 --dataset = AwA or CUB \
                 --train_mode = att, word or fusion
"""

import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import export
from mindspore import Tensor

from src.dataset import dataset_AwA, dataset_CUB
from src.demnet import MyTrainOneStepCell
from src.set_parser import set_parser
from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg

if __name__ == "__main__":
    # Set pynative mode, device target and device id
    args = set_parser()
    context.set_context(mode=context.PYNATIVE_MODE,
                        device_target=args.device_target,
                        device_id=args.device_id)

    # Load datasets and iterators
    if args.dataset == 'AwA':
        train_x, train_att, train_word, \
        test_x, test_att, test_word, \
        test_label, test_id = dataset_AwA(args.data_path)
    elif args.dataset == 'CUB':
        train_att, train_x, \
        test_x, test_att, \
        test_label, test_id = dataset_CUB(args.data_path)

    # Initialize parameters
    num = acc_cfg(args)
    lr, weight_decay, clip_param = param_cfg(args)
    save_ckpt = args.save_ckpt

    # Build network
    net = backbone_cfg(args)
    loss_fn = nn.MSELoss(reduction='mean')
    optim = nn.Adam(net.trainable_params(), lr, weight_decay=weight_decay)
    MyWithLossCell = withlosscell_cfg(args)
    loss_net = MyWithLossCell(net, loss_fn)
    train_net = MyTrainOneStepCell(loss_net, optim)

    print("============== Starting Exporting ==============")
    if args.train_mode == 'att':
        if args.dataset == 'AwA':
            # AwA attribute vectors are 85-d
            input0 = Tensor(np.zeros([args.batch_size, 85]), mindspore.float32)
        elif args.dataset == 'CUB':
            # CUB attribute vectors are 312-d
            input0 = Tensor(np.zeros([args.batch_size, 312]), mindspore.float32)
        export(net, input0, file_name=save_ckpt, file_format=args.file_format)
        print("Successfully convert to", args.file_format)

    elif args.train_mode == 'word':
        # word vectors are 1000-d
        input0 = Tensor(np.zeros([args.batch_size, 1000]), mindspore.float32)
        export(net, input0, file_name=save_ckpt, file_format=args.file_format)
        print("Successfully convert to", args.file_format)

    elif args.train_mode == 'fusion':
        # fusion takes an 85-d attribute input and a 1000-d word-vector input
        input1 = Tensor(np.zeros([args.batch_size, 85]), mindspore.float32)
        input2 = Tensor(np.zeros([args.batch_size, 1000]), mindspore.float32)
        export(net, input1, input2, file_name=save_ckpt, file_format=args.file_format)
        print("Successfully convert to", args.file_format)
@@ -0,0 +1,25 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config"""
from src.demnet import DEMNet1


def demnet(*args, **kwargs):
    return DEMNet1(*args, **kwargs)


def create_network(name, *args, **kwargs):
    if name == "DEM":
        return demnet(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")
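Hub consumers resolve models by name through `create_network`; for example (assuming this config file keeps the conventional `mindspore_hub_conf.py` name, which is not shown in this diff):

```python
from mindspore_hub_conf import create_network  # hypothetical module name for this hub config

net = create_network("DEM")  # returns a DEMNet1 (CUB, attribute mode) instance
```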
@@ -0,0 +1,3 @@

easydict==1.9
h5py==2.10.0
scipy==1.6.3
@@ -0,0 +1,66 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
export RANK_TABLE_FILE=$1

export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))

# remove old train_parallel files
rm -rf ../train_parallel
mkdir ../train_parallel

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

echo "device num=$DEVICE_NUM"

DATASET=$2
TRAIN_MODE=$3
DATA_PATH=$4

for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=${i}
    export RANK_ID=$((rank_start + i))
    # make a working dir for this rank
    mkdir ../train_parallel/$i
    mkdir ../train_parallel/$i/src
    # copy source files
    cp ../*.py ../train_parallel/$i
    cp ../src/*.py ../train_parallel/$i/src

    # enter the working dir of this rank
    cd ../train_parallel/$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    # record the environment to env.log
    env > env.log
    # launch the training task for this rank in the background
    python train.py --device_id=$i \
                    --distribute=True \
                    --dataset=$DATASET \
                    --train_mode=$TRAIN_MODE \
                    --data_path=$DATA_PATH &> log_train_8p.txt &
    cd ../../scripts
done
@@ -0,0 +1,23 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATASET=$1
TRAIN_MODE=$2
DATA_PATH=$3
SAVE_CKPT=$4 # e.g. ../output/train.ckpt
python ../eval.py --dataset=$DATASET \
                  --train_mode=$TRAIN_MODE \
                  --data_path=$DATA_PATH \
                  --save_ckpt=$SAVE_CKPT &> log_eval.txt &
@@ -0,0 +1,26 @@

#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATASET=$1
TRAIN_MODE=$2
DATA_PATH=$3
SAVE_CKPT=$4 # e.g. ../output
if [ ! -d $SAVE_CKPT ]; then
    mkdir $SAVE_CKPT
fi
python ../train.py --dataset=$DATASET \
                   --train_mode=$TRAIN_MODE \
                   --data_path=$DATA_PATH \
                   --save_ckpt=$SAVE_CKPT &> log_train_1p.txt &
@@ -0,0 +1,53 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
kNN-based accuracy computation, used in train.py and eval.py
"""
import numpy as np
from src import kNN
from src import kNN_cosine


def compute_accuracy_att(net, pred_len, test_att_0, test_visual_0, test_id_0, test_label_0):
    """attribute mode: match each test visual feature to its nearest embedded class prototype"""
    att_pred_0 = net(test_att_0)
    outpred = [0] * pred_len
    test_label_0 = test_label_0.astype("float32")
    for i in range(pred_len):
        outputLabel = kNN.kNNClassify(test_visual_0[i, :], att_pred_0.asnumpy(), test_id_0, 1)
        outpred[i] = outputLabel
    outpred = np.array(outpred)
    acc_0 = np.equal(outpred, test_label_0).mean()
    return acc_0


def compute_accuracy_word(net, pred_len, test_word_0, test_visual_0, test_id_0, test_label_0):
    """word-vector mode: nearest prototype under cosine distance"""
    word_pred_0 = net(test_word_0)
    outpred = [0] * pred_len
    test_label_0 = test_label_0.astype("float32")
    for i in range(pred_len):
        outputLabel = kNN_cosine.kNNClassify(test_visual_0[i, :], word_pred_0.asnumpy(), test_id_0, 1)
        outpred[i] = outputLabel
    outpred = np.array(outpred)
    acc_0 = np.equal(outpred, test_label_0).mean()
    return acc_0


def compute_accuracy_fusion(net, pred_len, test_att_0, test_word_0, test_visual_0, test_id_0, test_label_0):
    """fusion mode: nearest fused prototype under cosine distance"""
    fus_pred_0 = net(test_att_0, test_word_0)
    outpred = [0] * pred_len
    test_label_0 = test_label_0.astype("float32")
    for i in range(pred_len):
        outputLabel = kNN_cosine.kNNClassify(test_visual_0[i, :], fus_pred_0.asnumpy(), test_id_0, 1)
        outpred[i] = outputLabel
    outpred = np.array(outpred)
    acc_0 = np.equal(outpred, test_label_0).mean()
    return acc_0
@@ -0,0 +1,42 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict

awa_cfg = edict({
    # attribute mode
    'lr_att': 1e-5,
    'wd_att': 1e-2,
    'clip_att': 0.2,

    # word-vector mode
    'lr_word': 1e-4,
    'wd_word': 1e-3,
    'clip_word': 0.5,

    # fusion mode
    'lr_fusion': 1e-4,
    'wd_fusion': 1e-2,
    'clip_fusion': 0.5,

    'batch_size': 64,
})

cub_cfg = edict({
    # attribute mode (the only mode supported on CUB)
    'lr_att': 1e-5,
    'wd_att': 1e-2,
    'clip_att': 0.5,

    'batch_size': 100,
})
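EasyDict exposes these entries as attributes, which is how src/utils.py reads them:

```python
from src.config import awa_cfg, cub_cfg

# attribute-style access provided by easydict
print(awa_cfg.lr_att, awa_cfg.wd_att, awa_cfg.clip_att)  # 1e-05 0.01 0.2
print(cub_cfg.batch_size)                                # 100
```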
@@ -0,0 +1,125 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""

import mindspore
from mindspore import Tensor
import h5py
import scipy.io as sio
import numpy as np


def dataset_AwA(data_path):
    """input: *.mat files, output: numpy arrays (train) and Tensors (test)"""
    f = sio.loadmat(data_path + '/AwA_data/train_googlenet_bn.mat')
    train_x_0 = np.array(f['train_googlenet_bn'])

    # the attribute file is in MATLAB v7.3 format, so it is read with h5py
    f = h5py.File(data_path + '/AwA_data/attribute/Z_s_con.mat', 'r')
    train_att_0 = np.array(f['Z_s_con'])

    f = sio.loadmat(data_path + '/AwA_data/wordvector/train_word.mat')
    train_word_0 = np.array(f['train_word'])

    f = sio.loadmat(data_path + '/AwA_data/test_googlenet_bn.mat')
    test_x_0 = np.array(f['test_googlenet_bn'])

    f = sio.loadmat(data_path + '/AwA_data/attribute/pca_te_con_10x85.mat')
    test_att_0 = np.array(f['pca_te_con_10x85'])
    test_att_0 = test_att_0.astype("float16")
    test_att_0 = Tensor(test_att_0, mindspore.float32)

    f = sio.loadmat(data_path + '/AwA_data/wordvector/test_vectors.mat')
    test_word_0 = np.array(f['test_vectors'])
    test_word_0 = test_word_0.astype("float16")
    test_word_0 = Tensor(test_word_0, mindspore.float32)

    f = sio.loadmat(data_path + '/AwA_data/test_labels.mat')
    test_label_0 = np.squeeze(np.array(f['test_labels']))

    f = sio.loadmat(data_path + '/AwA_data/testclasses_id.mat')
    test_id_0 = np.squeeze(np.array(f['testclasses_id']))

    return train_x_0, train_att_0, train_word_0, test_x_0, \
           test_att_0, test_word_0, test_label_0, test_id_0


def dataset_CUB(data_path):
    """input: *.mat files, output: numpy arrays (train) and Tensors (test)"""
    f = sio.loadmat(data_path + '/CUB_data/train_attr.mat')
    train_att_0 = np.array(f['train_attr'])

    f = sio.loadmat(data_path + '/CUB_data/train_cub_googlenet_bn.mat')
    train_x_0 = np.array(f['train_cub_googlenet_bn'])

    f = sio.loadmat(data_path + '/CUB_data/test_cub_googlenet_bn.mat')
    test_x_0 = np.array(f['test_cub_googlenet_bn'])

    f = sio.loadmat(data_path + '/CUB_data/test_proto.mat')
    test_att_0 = np.array(f['test_proto'])
    test_att_0 = test_att_0.astype("float16")
    test_att_0 = Tensor(test_att_0, mindspore.float32)

    f = sio.loadmat(data_path + '/CUB_data/test_labels_cub.mat')
    test_label_0 = np.squeeze(np.array(f['test_labels_cub']))

    f = sio.loadmat(data_path + '/CUB_data/testclasses_id.mat')
    test_id_0 = np.squeeze(np.array(f['testclasses_id']))

    return train_att_0, train_x_0, test_x_0, test_att_0, test_label_0, test_id_0


class SingleDataIterable:
    """data + label"""
    def __init__(self, data, label):
        self._data = data
        self._label = label

    def __getitem__(self, index):
        item1 = self._data[index:index + 1]
        item2 = self._label[index:index + 1]
        return item1.astype(np.float32), item2.astype(np.float32)

    def __len__(self):
        return len(self._data)


class DoubleDataIterable:
    """data1 + data2 + label"""
    def __init__(self, data1, data2, label):
        self._data1 = data1
        self._data2 = data2
        self._label = label

    def __getitem__(self, index):
        item1 = self._data1[index:index + 1]
        item2 = self._data2[index:index + 1]
        item3 = self._label[index:index + 1]
        return item1.astype(np.float32), item2.astype(np.float32), item3.astype(np.float32)

    def __len__(self):
        return len(self._data1)


if __name__ == "__main__":
    train_att, train_x, test_x, test_att, test_label, test_id = dataset_CUB('/data/DEM_data')
    print('train attr:', train_att.shape)
    print('train x:', train_x.shape)
    print('test x:', test_x.shape)
    print('test att:', test_att.shape)
    print('test label:', test_label)
    print('test id:', test_id)
@@ -0,0 +1,128 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
DEMNet, WithLossCell and TrainOneStepCell
"""
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.context as context
from mindspore.common.initializer import Normal
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_gradients_mean, _get_parallel_mode, _get_device_num
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer


class MyTanh(nn.Cell):
    """LeCun-style scaled tanh: 1.7159 * tanh(2x / 3)"""
    def __init__(self):
        super(MyTanh, self).__init__()
        self.tanh = P.Tanh()

    def construct(self, x):
        return 1.7159 * self.tanh(2 * x / 3)


class DEMNet1(nn.Cell):
    """CUB + attribute: 312-d attributes -> 1024-d visual features"""
    def __init__(self):
        super(DEMNet1, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Dense(312, 700, weight_init=Normal(0.0008))
        self.fc2 = nn.Dense(700, 1024, weight_init=Normal(0.0012))

    def construct(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return x


class DEMNet2(nn.Cell):
    """AwA + attribute: 85-d attributes -> 1024-d visual features"""
    def __init__(self):
        super(DEMNet2, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Dense(85, 700, weight_init=Normal(0.0005))
        self.fc2 = nn.Dense(700, 1024, weight_init=Normal(0.0005))

    def construct(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return x


class DEMNet3(nn.Cell):
    """AwA + word vector: 1000-d word vectors -> 1024-d visual features"""
    def __init__(self):
        super(DEMNet3, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Dense(1000, 1024, weight_init=Normal(0.0005))

    def construct(self, x):
        x = self.relu(self.fc1(x))
        return x


class DEMNet4(nn.Cell):
    """AwA + fusion: merge the word-vector and attribute branches"""
    def __init__(self):
        super(DEMNet4, self).__init__()
        self.relu = nn.ReLU()
        self.tanh = MyTanh()
        self.fc1 = nn.Dense(1000, 900, weight_init=Normal(0.0008))
        self.fc2 = nn.Dense(85, 900, weight_init=Normal(0.0012))
        self.fc3 = nn.Dense(900, 1024, weight_init=Normal(0.0012))

    def construct(self, att, word):
        word = self.tanh(self.fc1(word))
        att = self.tanh(self.fc2(att))
        fus = word + 3 * att  # weighted multi-modal fusion
        fus = self.relu(self.fc3(fus))
        return fus


class MyWithLossCell(nn.Cell):
    """loss cell for the two-input fusion network"""
    def __init__(self, backbone, loss_fn):
        super(MyWithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data1, data2, label):
        out = self._backbone(data1, data2)
        return self._loss_fn(out, label)


class MyTrainOneStepCell(nn.Cell):
    """custom TrainOneStepCell with distributed gradient reduction and global-norm clipping"""
    def __init__(self, network, optimizer, sens=1.0):
        super(MyTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (context.ParallelMode.DATA_PARALLEL, context.ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, mean, degree)

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        # clip gradients by a fixed global norm of 0.2 before the update
        grads = ops.clip_by_global_norm(grads, 0.2)
        return F.depend(loss, self.optimizer(grads))
@@ -0,0 +1,44 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
k-nearest neighbor algorithm, used to compute accuracy in train.py
"""
import numpy as np


# create a toy dataset which contains 4 samples with 2 classes
def createDataSet():
    group = np.array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A', 'A', 'B', 'B']  # four samples and two classes
    return group, labels


def kNNClassify(newInput, dataSet, labels, k):
    """classify newInput by majority vote among its k nearest neighbors (Euclidean distance)"""
    numSamples = dataSet.shape[0]
    # Euclidean distance from newInput to every row of dataSet
    diff = np.tile(newInput, (numSamples, 1)) - dataSet
    squaredDiff = diff ** 2
    squaredDist = np.sum(squaredDiff, axis=1)
    distance = squaredDist ** 0.5
    sortedDistIndices = np.argsort(distance)
    # vote among the k nearest neighbors
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndices[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    maxCount = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key
    return maxIndex
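A quick self-contained check of the classifier using the toy data above (my example, not part of the repo):

```python
import numpy as np
from src.kNN import createDataSet, kNNClassify

group, labels = createDataSet()
# a query near (1, 1) collects votes A, A, B among its 3 nearest neighbors
print(kNNClassify(np.array([0.9, 1.0]), group, labels, 3))  # -> 'A'
```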
@@ -0,0 +1,50 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
k-nearest neighbor with cosine distance, used to compute accuracy in train.py
"""
import math
import numpy as np


# create a toy dataset which contains 4 samples with 2 classes
def createDataSet():
    group = np.array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def cosine_distance(v1, v2):
    """compute cosine distance of v1 to v2: 1 - (v1 . v2) / (||v1|| * ||v2||)"""
    v1_sq = np.inner(v1, v1)
    v2_sq = np.inner(v2, v2)
    dis = 1 - np.inner(v1, v2) / math.sqrt(v1_sq * v2_sq)
    return dis


def kNNClassify(newInput, dataSet, labels, k):
    """classify newInput by majority vote among its k nearest neighbors (cosine distance)"""
    distance = [0] * dataSet.shape[0]
    for i in range(dataSet.shape[0]):
        distance[i] = cosine_distance(newInput, dataSet[i])
    sortedDistIndices = np.argsort(distance)
    # vote among the k nearest neighbors
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndices[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    maxCount = 0
    for key, value in classCount.items():
        if value > maxCount:
            maxCount = value
            maxIndex = key
    return maxIndex
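As a sanity check on the distance: parallel vectors give 0 and orthogonal vectors give 1 (my example, not part of the repo):

```python
import numpy as np
from src.kNN_cosine import cosine_distance

print(cosine_distance(np.array([1.0, 0.0]), np.array([2.0, 0.0])))  # 0.0, same direction
print(cosine_distance(np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # 1.0, orthogonal
```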
@@ -0,0 +1,42 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Training parameter settings, used in train.py and eval.py
"""
import argparse


def set_parser():
    """parser for train.py and eval.py"""
    parser = argparse.ArgumentParser(description='MindSpore DEMnet Training')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--device_id', type=int, default=0, help='ID of the target device')
    # note: argparse's type=bool treats any non-empty string as True, so pass --distribute=True or omit the flag
    parser.add_argument('--distribute', type=bool, default=False, help='whether to run distributed training')
    parser.add_argument('--device_num', type=int, default=1, help='number of devices used')
    parser.add_argument('--dataset', type=str, default="CUB", choices=['AwA', 'CUB'],
                        help='dataset chosen to train (default: CUB)')
    parser.add_argument('--train_mode', type=str, default='att', choices=['att', 'word', 'fusion'],
                        help='mode chosen to train (default: att)')
    parser.add_argument('--batch_size', type=int, default=100, help='batch size of one training step')
    parser.add_argument('--interval_step', type=int, default=500, help='interval between loss printouts')
    parser.add_argument('--epoch_size', type=int, default=120, help='number of training epochs')
    parser.add_argument('--data_path', type=str, default='/data/DEM_data', help='path where the dataset is saved')
    parser.add_argument('--save_ckpt', type=str, default='../output',
                        help='path where the trained ckpt is saved; for evaluation, the ckpt file to load')

    parser.add_argument("--file_format", type=str, default="ONNX", choices=["AIR", "ONNX", "MINDIR"],
                        help="export format")
    args = parser.parse_args()

    return args
@@ -0,0 +1,75 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
configuration helpers, used in train.py and eval.py
"""
import mindspore.nn as nn
from src.config import awa_cfg, cub_cfg
from src.demnet import DEMNet1, DEMNet2, DEMNet3, DEMNet4, MyWithLossCell


def acc_cfg(args):
    """number of test samples used for the accuracy computation"""
    if args.dataset == 'CUB':
        pred_len = 2933
    elif args.dataset == 'AwA':
        pred_len = 6180
    return pred_len


def backbone_cfg(args):
    """select the backbone for the dataset and training mode"""
    if args.dataset == 'CUB':
        net = DEMNet1()
    elif args.dataset == 'AwA':
        if args.train_mode == 'att':
            net = DEMNet2()
        elif args.train_mode == 'word':
            net = DEMNet3()
        elif args.train_mode == 'fusion':
            net = DEMNet4()
    return net


def param_cfg(args):
    """set hyperparameters"""
    if args.dataset == 'CUB':
        lr = cub_cfg.lr_att
        weight_decay = cub_cfg.wd_att
        clip_param = cub_cfg.clip_att
    elif args.dataset == 'AwA':
        if args.train_mode == 'att':
            lr = awa_cfg.lr_att
            weight_decay = awa_cfg.wd_att
            clip_param = awa_cfg.clip_att
        elif args.train_mode == 'word':
            lr = awa_cfg.lr_word
            weight_decay = awa_cfg.wd_word
            clip_param = awa_cfg.clip_word
        elif args.train_mode == 'fusion':
            lr = awa_cfg.lr_fusion
            weight_decay = awa_cfg.wd_fusion
            clip_param = awa_cfg.clip_fusion
    return lr, weight_decay, clip_param


def withlosscell_cfg(args):
    """fusion needs the two-input loss cell; the other modes use the standard one"""
    if args.train_mode == 'fusion':
        return MyWithLossCell
    return nn.WithLossCell


def save_min_acc_cfg(args):
    """minimum accuracy a checkpoint must reach before it is saved"""
    if args.train_mode == 'att':
        save_min_acc = 0.5
    elif args.train_mode == 'word':
        save_min_acc = 0.7
    elif args.train_mode == 'fusion':
        save_min_acc = 0.7
    return save_min_acc
@@ -0,0 +1,160 @@

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train DEM ########################
train DEM
python train.py --data_path = /YourDataPath \
                --dataset = AwA or CUB \
                --train_mode = att, word or fusion
"""
import time
import sys

import mindspore.nn as nn
from mindspore import context
from mindspore import save_checkpoint
from mindspore import dataset as ds
from mindspore import Model
from mindspore import set_seed
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import init, get_rank, get_group_size

from src.dataset import dataset_AwA, dataset_CUB, SingleDataIterable, DoubleDataIterable
from src.demnet import MyTrainOneStepCell
from src.set_parser import set_parser
from src.utils import acc_cfg, backbone_cfg, param_cfg, withlosscell_cfg, save_min_acc_cfg
from src.accuracy import compute_accuracy_att, compute_accuracy_word, compute_accuracy_fusion

if __name__ == "__main__":
    # Set graph mode, device target and device id
    set_seed(1000)
    args = set_parser()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target,
                        device_id=args.device_id)
    if args.distribute:
        init()
        args.device_num = get_group_size()
        rank_id = get_rank()
        context.set_auto_parallel_context(parallel_mode=context.ParallelMode.AUTO_PARALLEL,
                                          gradients_mean=True,
                                          device_num=args.device_num)
    else:
        rank_id = 0
    # Initialize parameters
    pred_len = acc_cfg(args)
    lr, weight_decay, clip_param = param_cfg(args)
    if args.distribute:
        lr = lr * 5  # scale the learning rate up for multi-device training
    batch_size = args.batch_size

    # Load datasets and build iterators
    if args.dataset == 'AwA':
        train_x, train_att, train_word, \
        test_x, test_att, test_word, \
        test_label, test_id = dataset_AwA(args.data_path)
        if args.train_mode == 'att':
            custom_data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x),
                                              ['label', 'data'],
                                              num_shards=args.device_num,
                                              shard_id=rank_id,
                                              shuffle=True)
        elif args.train_mode == 'word':
            custom_data = ds.GeneratorDataset(SingleDataIterable(train_word, train_x),
                                              ['label', 'data'],
                                              num_shards=args.device_num,
                                              shard_id=rank_id,
                                              shuffle=True)
        elif args.train_mode == 'fusion':
            custom_data = ds.GeneratorDataset(DoubleDataIterable(train_att, train_word, train_x),
                                              ['label1', 'label2', 'data'],
                                              num_shards=args.device_num,
                                              shard_id=rank_id,
                                              shuffle=True)
    elif args.dataset == 'CUB':
        train_att, train_x, \
        test_x, test_att, \
        test_label, test_id = dataset_CUB(args.data_path)
        if args.train_mode == 'att':
            custom_data = ds.GeneratorDataset(SingleDataIterable(train_att, train_x),
                                              ['label', 'data'],
                                              num_shards=args.device_num,
                                              shard_id=rank_id,
                                              shuffle=True)
        elif args.train_mode == 'word':
            print("Warning: word vector mode training is not supported on the CUB dataset.")
            print("Only attribute mode is supported on this dataset.")
            sys.exit(0)
        elif args.train_mode == 'fusion':
            print("Warning: fusion mode training is not supported on the CUB dataset.")
            print("Only attribute mode is supported on this dataset.")
            sys.exit(0)
    # Note: must set "drop_remainder = True" in parallel mode.
    custom_data = custom_data.batch(batch_size, drop_remainder=True)

    # Build network
    net = backbone_cfg(args)
    loss_fn = nn.MSELoss(reduction='mean')
    optim = nn.Adam(net.trainable_params(), lr, weight_decay=weight_decay)
    MyWithLossCell = withlosscell_cfg(args)
    loss_net = MyWithLossCell(net, loss_fn)
    train_net = MyTrainOneStepCell(loss_net, optim)
    model = Model(train_net)

    # Train
    start = time.time()
    acc_max = 0
    save_min_acc = save_min_acc_cfg(args)
    save_ckpt = args.save_ckpt
    ckpt_file_name = save_ckpt + '/train.ckpt'
    interval_step = args.interval_step
    epoch_size = args.epoch_size
    print("============== Starting Training ==============")
    if args.distribute:
        now = time.localtime()
        nowt = time.strftime("%Y-%m-%d-%H:%M:%S", now)
        print(nowt)
        loss_cb = LossMonitor(interval_step)
        ckpt_config = CheckpointConfig(save_checkpoint_steps=interval_step)
        ckpt_callback = ModelCheckpoint(prefix='auto_parallel', config=ckpt_config)
        t1 = time.time()
        model.train(epoch_size, train_dataset=custom_data, callbacks=[loss_cb, ckpt_callback], dataset_sink_mode=False)
        end = time.time()
        # speed estimate assumes 88 steps per epoch (CUB attribute mode with the default batch size)
        t3 = 1000 * (end - t1) / (88 * epoch_size)
        print('total time:', end - start)
        print('speed_8p = %.3f ms/step' % t3)
        now = time.localtime()
        nowt = time.strftime("%Y-%m-%d-%H:%M:%S", now)
        print(nowt)
    else:
        for i in range(epoch_size):
            t1 = time.time()
            model.train(1, train_dataset=custom_data, callbacks=LossMonitor(interval_step), dataset_sink_mode=False)
            t2 = time.time()
            t3 = 1000 * (t2 - t1) / 88
            if args.train_mode == 'att':
                acc = compute_accuracy_att(net, pred_len, test_att, test_x, test_id, test_label)
            elif args.train_mode == 'word':
                acc = compute_accuracy_word(net, pred_len, test_word, test_x, test_id, test_label)
            else:
                acc = compute_accuracy_fusion(net, pred_len, test_att, test_word, test_x, test_id, test_label)
            # keep the best checkpoint once it clears the save threshold
            if acc > acc_max:
                acc_max = acc
                if acc_max > save_min_acc:
                    save_checkpoint(net, ckpt_file_name)
            print('epoch:', i + 1, 'accuracy = %.5f' % acc, 'speed = %.3f ms/step' % t3)
        end = time.time()
        print("total time:", end - start)