forked from mindspore-Ecosystem/mindspore
fix some bugs in resnet, ssd and naml
This commit is contained in:
parent
b52f0ced25
commit
26457a8ee3
|
@ -284,6 +284,8 @@ Please follow the instructions in the link [hccn_tools](https://gitee.com/mindsp
|
|||
|
||||
Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the following in log.
|
||||
|
||||
If you want to change device_id for standalone training, you can set environment variable `export DEVICE_ID=x` or set `device_id=x` in context.
|
||||
|
||||
#### Running on GPU
|
||||
|
||||
```bash
|
||||
|
|
|
@ -268,6 +268,8 @@ bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH]
|
|||
|
||||
训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。
|
||||
|
||||
运行单卡用例时如果想更换运行卡号,可以通过设置环境变量 `export DEVICE_ID=x` 或者在context中设置 `device_id=x`指定相应的卡号。
|
||||
|
||||
#### GPU处理器环境运行
|
||||
|
||||
```text
|
||||
|
|
|
@ -73,7 +73,6 @@ fi
|
|||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
|
|
|
@ -391,7 +391,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
|
|||
|
||||
|
||||
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
|
||||
is_training=True, num_parallel_workers=4, use_multiprocessing=True):
|
||||
is_training=True, num_parallel_workers=6, use_multiprocessing=True):
|
||||
"""Create SSD dataset with MindDataset."""
|
||||
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
|
||||
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)
|
||||
|
|
|
@ -99,7 +99,7 @@ def get_args(phase):
|
|||
args.n_sub_categories = cfg.n_sub_categories
|
||||
args.n_words = cfg.n_words
|
||||
if phase == "train":
|
||||
args.epochs = cfg.epochs if args.epochs is None else args.epochs * math.ceil(args.device_num ** 0.5)
|
||||
args.epochs = cfg.epochs * math.ceil(args.device_num ** 0.5) if args.epochs is None else args.epochs
|
||||
args.lr = cfg.lr if args.lr is None else args.lr
|
||||
args.print_times = cfg.print_times if args.print_times is None else args.print_times
|
||||
args.embedding_file = cfg.embedding_file.format(args.dataset_path)
|
||||
|
|
Loading…
Reference in New Issue