forked from mindspore-Ecosystem/mindspore
!17687 repair fcn4
From: @huchunmei Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
c469efbdfe
|
@ -14,7 +14,6 @@ enable_profiling: False
|
|||
pre_trained: "/cache/data"
|
||||
coco_root: "/cache/data"
|
||||
ckpt_path: './ckpt_maskrcnn/mask_rcnn-12_7393.ckpt'
|
||||
ckpt_file: '/cache/data/cocodataset/ckpt_maskrcnn/mask_rcnn-12_7393.ckpt'
|
||||
ann_file: "./annotations/instances_val2017.json"
|
||||
# ==============================================================================
|
||||
modelarts_dataset_unzip_name: 'cocodataset'
|
||||
|
@ -172,9 +171,13 @@ dataset: "coco"
|
|||
device_id: 0
|
||||
device_num: 1
|
||||
rank_id: 0
|
||||
# batch_size_export: 1
|
||||
|
||||
# maskrcnn export
|
||||
batch_size_export: 1
|
||||
file_name: "maskrcnn"
|
||||
file_format: "AIR"
|
||||
file_format: "MINDIR"
|
||||
ckpt_file: '/cache/data/cocodataset/ckpt_maskrcnn/mask_rcnn-12_7393.ckpt'
|
||||
ckpt_file_local: './maskrcnn_b3/scripts/train_parallel0/ckpt_0/mask_rcnn-12_7393.ckpt'
|
||||
|
||||
# other
|
||||
learning_rate: 0.002
|
||||
|
|
|
@ -27,6 +27,12 @@ config.feature_shapes = [(lss[2*i], lss[2*i+1]) for i in range(int(len(lss)/2))]
|
|||
config.roi_layer = dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2)
|
||||
config.warmup_ratio = 1/3.0
|
||||
config.mask_shape = (28, 28)
|
||||
train_cls = [i for i in re.findall(r'[a-zA-Z\s]+', config.coco_classes) if i != ' ']
|
||||
config.coco_classes = np.array(train_cls)
|
||||
config.batch_size = config.batch_size_export
|
||||
|
||||
if not config.enable_modelarts:
|
||||
config.ckpt_file = config.ckpt_file_local
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if config.device_target == "Ascend":
|
||||
|
|
|
@ -56,7 +56,7 @@ After installing MindSpore via the official website, you can start training and
|
|||
3. The information file of each clip should contain the label and path. Please refer to the annotations_final.csv in MagnaTagATune Dataset.
|
||||
4. The provided pre-processing script use MagnaTagATune Dataset as an example. Please modify the code accprding to your own need.
|
||||
|
||||
### 2. setup parameters (src/config.py)
|
||||
### 2. setup parameters (src/model_utils/default_config.yaml)
|
||||
|
||||
### 3. Train
|
||||
|
||||
|
@ -80,6 +80,67 @@ Then you can test your model
|
|||
SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
|
||||
```
|
||||
|
||||
- Running on [ModelArts](https://support.huaweicloud.com/modelarts/)
|
||||
|
||||
```bash
|
||||
# Train 8p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "distribute=True" on default_config.yaml file.
|
||||
# Set "data_dir='/cache/data'" on default_config.yaml file.
|
||||
# Set "checkpoint_path='/cache/data/musicTagger'" on default_config.yaml file.
|
||||
# (optional)Set "checkpoint_url='s3://dir_to_your_pretrained/'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "distribute=True" on the website UI interface.
|
||||
# Add "data_dir=/cache/data" on the website UI interface.
|
||||
# Add "checkpoint_path='/cache/data/musicTagger'" on the website UI interface.
|
||||
# (optional)Add "checkpoint_url='s3://dir_to_your_pretrained/'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (3) Upload or copy your pretrained model to S3 bucket if you want to finetune.
|
||||
# (4) Upload the original MusicTagger dataset to S3 bucket.
|
||||
# (5) Set the code directory to "/path/fcn-4" on the website UI interface.
|
||||
# (6) Set the startup file to "train.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
#
|
||||
# Train 1p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "data_dir='/cache/data'" on default_config.yaml file.
|
||||
# Set "checkpoint_path='/cache/data/musicTagger'" on default_config.yaml file.
|
||||
# (optional)Set "checkpoint_url='s3://dir_to_your_pretrained/'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "data_dir='/cache/data'" on the website UI interface.
|
||||
# Add "checkpoint_path='/cache/data/musicTagger'" on the website UI interface.
|
||||
# (optional)Add "checkpoint_url='s3://dir_to_your_pretrained/'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (3) Upload or copy your pretrained model to S3 bucket if you want to finetune.
|
||||
# (4) Upload the original MusicTagger dataset to S3 bucket.
|
||||
# (5) Set the code directory to "/path/fcn-4" on the website UI interface.
|
||||
# (6) Set the startup file to "train.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
#
|
||||
# Eval 1p with Ascend
|
||||
# (1) Perform a or b.
|
||||
# a. Set "enable_modelarts=True" on default_config.yaml file.
|
||||
# Set "data_dir='/cache/data'" on default_config.yaml file.
|
||||
# Set "checkpoint_path='/cache/data/musicTagger'" on default_config.yaml file.
|
||||
# Set other parameters on default_config.yaml file you need.
|
||||
# b. Add "enable_modelarts=True" on the website UI interface.
|
||||
# Add "data_dir='/cache/data'" on the website UI interface.
|
||||
# Add "checkpoint_path='/cache/data/musicTagger'" on the website UI interface.
|
||||
# Add other parameters on the website UI interface.
|
||||
# (3) Upload or copy your trained model to S3 bucket.
|
||||
# (4) Upload the original MusicTagger dataset to S3 bucket.
|
||||
# (5) Set the code directory to "/path/fcn-4" on the website UI interface.
|
||||
# (6) Set the startup file to "eval.py" on the website UI interface.
|
||||
# (7) Set the "Dataset path" and "Output file path" and "Job log path" to your path on the website UI interface.
|
||||
# (8) Create your job.
|
||||
```
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
@ -113,7 +174,7 @@ SLOG_PRINT_TO_STDOUT=1 python eval.py --device_id 0
|
|||
|
||||
### [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
Parameters for both training and evaluation can be set in default_config.yaml
|
||||
|
||||
- config for FCN-4
|
||||
|
||||
|
|
|
@ -40,6 +40,10 @@ checkpoint_path: "/cache/data/musicTagger"
|
|||
prefix: 'MusicTagger'
|
||||
model_name: 'MusicTagger-10_543.ckpt'
|
||||
|
||||
# export
|
||||
file_name: "/cache/data/musicTagger/fcn-4.air"
|
||||
file_format: "MINDIR"
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: Ascend
|
||||
enable_profiling: False
|
||||
|
||||
# ==============================================================================
|
||||
# What is the meaning of separating the two dictionaries in the original config file?
|
||||
pre_trained: False
|
||||
lr: 0.0005
|
||||
batch_size: 32
|
||||
epoch_size: 10
|
||||
loss_scale: 1024.0
|
||||
num_consumer: 4
|
||||
mixed_precision: False
|
||||
train_filename: 'train.mindrecord0'
|
||||
val_filename: 'val.mindrecord0'
|
||||
data_dir: "/cache/data"
|
||||
device_target: 'Ascend'
|
||||
device_id: 0
|
||||
keep_checkpoint_max: 10
|
||||
save_step: 2000
|
||||
checkpoint_path: "/cache/data/musicTagger"
|
||||
prefix: 'MusicTagger'
|
||||
model_name: 'MusicTagger-10_543.ckpt'
|
||||
|
||||
---
|
||||
# Config description for each option
|
||||
enable_modelarts: 'Whether training on modelarts, default: False'
|
||||
data_url: 'Dataset url for obs'
|
||||
train_url: 'Training output url for obs'
|
||||
data_path: 'Dataset path for local'
|
||||
output_path: 'Training output path for local'
|
||||
|
||||
device_target: 'Target device type'
|
||||
enable_profiling: 'Whether enable profiling while training, default: False'
|
||||
file_name: 'output file name.'
|
||||
file_format: 'file format'
|
||||
|
||||
---
|
||||
device_target: ['Ascend', 'GPU', 'CPU']
|
||||
file_format: ['AIR', 'ONNX', 'MINDIR']
|
|
@ -35,8 +35,4 @@ if __name__ == "__main__":
|
|||
param_dict = load_checkpoint(config.checkpoint_path + "/" + config.model_name)
|
||||
load_param_into_net(network, param_dict)
|
||||
input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 96, 1366]).astype(np.float32)
|
||||
export(network,
|
||||
Tensor(input_data),
|
||||
filename="{}/{}.air".format(config.checkpoint_path,
|
||||
config.model_name[:-5]),
|
||||
file_format="AIR")
|
||||
export(network, Tensor(input_data), file_name=config.file_name, file_format=config.file_format)
|
||||
|
|
|
@ -19,5 +19,4 @@ __init__.py
|
|||
from . import musictagger
|
||||
from . import loss
|
||||
from . import dataset
|
||||
from . import config
|
||||
from . import pre_process_data
|
||||
|
|
|
@ -15,13 +15,13 @@
|
|||
'''python dataset.py'''
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import librosa
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from mindspore import context
|
||||
from model_utils.config import config as cfg
|
||||
from model_utils.device_adapter import get_device_id
|
||||
|
||||
|
||||
def compute_melgram(audio_path, save_path='', filename='', save_npy=True):
|
||||
|
@ -192,12 +192,6 @@ def convert_to_mindrecord(info_name, file_path, store_path, mr_name,
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='get feature')
|
||||
parser.add_argument('--device_id',
|
||||
type=int,
|
||||
help='device ID',
|
||||
default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
if cfg.get_npy:
|
||||
GetLabel(cfg.info_path, cfg.info_name)
|
||||
|
@ -211,14 +205,8 @@ if __name__ == "__main__":
|
|||
"{}/{}/".format(cfg.npy_path, d), f)
|
||||
|
||||
if cfg.get_mindrecord:
|
||||
if args.device_id is not None:
|
||||
context.set_context(device_target='Ascend',
|
||||
mode=context.GRAPH_MODE,
|
||||
device_id=args.device_id)
|
||||
else:
|
||||
context.set_context(device_target='Ascend',
|
||||
mode=context.GRAPH_MODE,
|
||||
device_id=cfg.device_id)
|
||||
context.set_context(device_target='Ascend', mode=context.GRAPH_MODE, device_id=get_device_id())
|
||||
|
||||
for cmn in cfg.mr_nam:
|
||||
if cmn in ['train', 'val']:
|
||||
convert_to_mindrecord('music_tagging_{}_tmp.csv'.format(cmn),
|
||||
|
|
|
@ -34,7 +34,6 @@ from src.loss import BCELoss
|
|||
|
||||
def modelarts_pre_process():
|
||||
pass
|
||||
# config.ckpt_path = os.path.join(config.output_path, str(get_rank_id()), config.checkpoint_path)
|
||||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process)
|
||||
def train(model, dataset_direct, filename, columns_list, num_consumer=4,
|
||||
|
|
Loading…
Reference in New Issue