forked from mindspore-Ecosystem/mindspore

commit 64b77ac3a4 (parent 74fe67db83): add cpu model scripts

README.md

@@ -55,8 +55,9 @@ The directory structure is as follows:
 
 # [Environment Requirements](#contents)
 
-- Hardware(Ascend)
-    - Prepare hardware environment with Ascend processor.
+- Hardware(Ascend, CPU)
+    - Prepare the hardware environment with an Ascend processor. It also supports
+      running on a CPU environment.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:

@@ -76,9 +77,12 @@ The entire code structure is as following:
 │   ├── run_distribute_train_base.sh      // shell script for distributed training on Ascend
 │   ├── run_distribute_train_beta.sh      // shell script for distributed training on Ascend
 │   ├── run_eval.sh                       // shell script for evaluation on Ascend
+│   ├── run_eval_cpu.sh                   // shell script for evaluation on CPU
 │   ├── run_export.sh                     // shell script for exporting air model
 │   ├── run_standalone_train_base.sh      // shell script for standalone training on Ascend
 │   ├── run_standalone_train_beta.sh      // shell script for standalone training on Ascend
+│   ├── run_train_base_cpu.sh             // shell script for training the base model on CPU
+│   ├── run_train_beta_cpu.sh             // shell script for training the beta model on CPU
 ├── src
 │   ├── backbone
 │   │   ├── head.py                       // head unit

@@ -100,8 +104,11 @@ The entire code structure is as following:
 │   ├── local_adapter.py                  // local adapter
 │   ├── moxing_adapter.py                 // moxing adapter
 ├─ base_config.yaml                       // parameter configuration
+├─ base_config_cpu.yaml                   // parameter configuration for base model training on CPU
 ├─ beta_config.yaml                       // parameter configuration
+├─ beta_config_cpu.yaml                   // parameter configuration for beta model training on CPU
 ├─ inference_config.yaml                  // parameter configuration
+├─ inference_config_cpu.yaml              // parameter configuration for inference on CPU
 ├─ train.py                               // training scripts
 ├─ eval.py                                // evaluation scripts
 └─ export.py                              // export air model

@@ -111,7 +118,7 @@ The entire code structure is as following:
 
 ### Train
 
-- Stand alone mode
+- Stand alone mode (Ascend)
 
 - base model
 

@@ -171,6 +178,36 @@ The entire code structure is as following:
 sh run_distribute_train_beta.sh ./rank_table_8p.json
 ```
 
+- Stand alone mode (CPU)
+
+- base model
+
+```bash
+cd ./scripts
+sh run_train_base_cpu.sh
+```
+
+for example:
+
+```bash
+cd ./scripts
+sh run_train_base_cpu.sh
+```
+
+- beta model
+
+```bash
+cd ./scripts
+sh run_train_beta_cpu.sh
+```
+
+for example:
+
+```bash
+cd ./scripts
+sh run_train_beta_cpu.sh
+```
+
 - ModelArts (If you want to run in modelarts, please check the official documentation of [modelarts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
 
 - base model

base_config_cpu.yaml (new file)

@@ -0,0 +1,76 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "CPU"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+train_stage: "base"
+is_distributed: 0
+
+# dataset related
+data_dir: "/cache/data/face_recognition_dataset/train_dataset/"
+num_classes: 1
+per_batch_size: 64
+need_modelarts_dataset_unzip: True
+
+# network structure related
+backbone: "r100"
+use_se: 1
+emb_size: 512
+act_type: "relu"
+fp16: 1
+pre_bn: 1
+inference: 0
+use_drop: 1
+nc_16: 1
+
+# loss related
+margin_a: 1.0
+margin_b: 0.2
+margin_m: 0.3
+margin_s: 64
+
+# optimizer related
+lr: 0.01
+lr_scale: 1
+lr_epochs: "8,14,18"
+weight_decay: 0.0002
+momentum: 0.9
+max_epoch: 20
+pretrained: ""
+warmup_epochs: 0
+
+# distributed parameter
+local_rank: 0
+world_size: 1
+model_parallel: 0
+
+# logging related
+log_interval: 100
+ckpt_path: "outputs"
+max_ckpts: -1
+dynamic_init_loss_scale: 65536
+ckpt_steps: 1000
+
+---
+
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'
+
+train_stage: "Train stage, base or beta"
+is_distributed: "If multi device"
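
Like the existing Ascend configs, the new CPU config is a single file holding two YAML documents: the values come first, and the help strings follow the `---` separator, with `train_stage`, `device_target`, and the optimizer settings driving the run. Below is a minimal sketch of reading such a file with PyYAML; it is illustrative only (the repository's own `model_utils.config` helper does the real parsing, and the relative path is assumed to be the model's root directory).

```python
import yaml  # PyYAML; illustrative only -- model_utils.config does the real parsing

# The file contains two YAML documents separated by "---":
# the first holds the values, the second holds the help strings.
with open("base_config_cpu.yaml") as f:       # path assumed: model root directory
    values, help_text = yaml.safe_load_all(f)

print(values["device_target"], values["train_stage"], values["lr"])  # CPU base 0.01
print(help_text["train_stage"])                                      # Train stage, base or beta
```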

beta_config_cpu.yaml (new file)

@@ -0,0 +1,76 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "CPU"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+train_stage: "beta"
+is_distributed: 0
+
+# dataset related
+data_dir: "/cache/data/face_recognition_dataset/train_dataset/"
+num_classes: 1
+per_batch_size: 64
+need_modelarts_dataset_unzip: True
+
+# network structure related
+backbone: "r100"
+use_se: 0
+emb_size: 256
+act_type: "relu"
+fp16: 1
+pre_bn: 0
+inference: 0
+use_drop: 1
+nc_16: 1
+
+# loss related
+margin_a: 1.0
+margin_b: 0.2
+margin_m: 0.3
+margin_s: 64
+
+# optimizer related
+lr: 0.04
+lr_scale: 1
+lr_epochs: "8,14,18"
+weight_decay: 0.0002
+momentum: 0.9
+max_epoch: 20
+pretrained: "your_pretrained_model"
+warmup_epochs: 0
+
+# distributed parameter
+local_rank: 0
+world_size: 1
+model_parallel: 0
+
+# logging related
+log_interval: 100
+ckpt_path: "outputs"
+max_ckpts: -1
+dynamic_init_loss_scale: 65536
+ckpt_steps: 1000
+
+---
+
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'
+
+train_stage: "Train stage, base or beta"
+is_distributed: "If multi device"
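
The beta config differs from the base config only in a handful of fields (train_stage, use_se, emb_size, pre_bn, lr, pretrained); everything else, including the CPU device target, is shared. A small illustrative sketch that prints that delta, assuming the same file names and location as above:

```python
import yaml  # PyYAML; illustrative only

def load_values(path):
    """Return the first YAML document (the values) of a two-document config file."""
    with open(path) as f:
        values, _help = yaml.safe_load_all(f)
        return values

base = load_values("base_config_cpu.yaml")
beta = load_values("beta_config_cpu.yaml")
for key in base:
    if base[key] != beta[key]:
        print(f"{key}: base={base[key]!r}  beta={beta[key]!r}")
# expected keys: train_stage, use_se, emb_size, pre_bn, lr, pretrained
```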

eval.py

@@ -33,7 +33,7 @@ from model_utils.config import config
 from model_utils.moxing_adapter import moxing_wrapper
 from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
 
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=get_device_id())
 
 
 class TxtDataset():
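
The only change here is that `device_target` now comes from the config instead of being hard-coded to "Ascend", which is what lets the same entry script run on CPU. A standalone sketch of the pattern follows; the Ascend-only `device_id` branch is an illustration and not the repository's exact code, which always uses its `get_device_id()` helper.

```python
from mindspore import context

device_target = "CPU"  # in the repository this is config.device_target

ctx_kwargs = {"mode": context.GRAPH_MODE, "device_target": device_target}
if device_target == "Ascend":
    ctx_kwargs["device_id"] = 0  # selecting a specific NPU card only matters on Ascend
context.set_context(**ctx_kwargs)
```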

inference_config_cpu.yaml (new file)

@@ -0,0 +1,60 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "CPU"
+enable_profiling: False
+
+# ==============================================================================
+# Training options
+
+# distributed parameter
+is_distributed: 0
+local_rank: 0
+world_size: 1
+
+# test weight
+weight: 'your_test_model'
+test_dir: '/cache/data/face_recognition_dataset/'
+need_modelarts_dataset_unzip: True
+
+# model define
+backbone: "r100"
+use_se: 0
+emb_size: 256
+act_type: "relu"
+fp16: 1
+pre_bn: 0
+inference: 1
+use_drop: 0
+
+# test and dis batch size
+test_batch_size: 128
+dis_batch_size: 512
+
+# log
+log_interval: 100
+ckpt_path: "outputs/models"
+
+# test and dis image list
+test_img_predix: ""
+test_img_list: ""
+dis_img_predix: ""
+dis_img_list: ""
+
+---
+
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'

scripts/run_eval_cpu.sh (new file)

@@ -0,0 +1,38 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+dirname_path=$(dirname "$(pwd)")
+echo ${dirname_path}
+export PYTHONPATH=${dirname_path}:$PYTHONPATH
+
+USE_DEVICE_ID=0
+echo 'start device '$USE_DEVICE_ID
+dev=`expr $USE_DEVICE_ID + 0`
+export DEVICE_ID=$dev
+
+EXECUTE_PATH=$(pwd)
+echo "*******************EXECUTE_PATH= ${EXECUTE_PATH}"
+if [ -d "${EXECUTE_PATH}/log_inference" ]; then
+    echo "[INFO] Delete old log_inference log files"
+    rm -rf ${EXECUTE_PATH}/log_inference
+fi
+mkdir ${EXECUTE_PATH}/log_inference
+
+cd ${EXECUTE_PATH}/log_inference || exit
+env > ${EXECUTE_PATH}/log_inference/face_recognition.log
+python ${EXECUTE_PATH}/../eval.py --config_path=${EXECUTE_PATH}/../inference_config_cpu.yaml &> ${EXECUTE_PATH}/log_inference/face_recognition.log &
+
+echo "[INFO] Start inference..."

scripts/run_train_base_cpu.sh (new file)

@@ -0,0 +1,45 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+dirname_path=$(dirname "$(pwd)")
+echo ${dirname_path}
+export PYTHONPATH=${dirname_path}:$PYTHONPATH
+
+
+USE_DEVICE_ID=0
+dev=`expr $USE_DEVICE_ID + 0`
+export DEVICE_ID=$dev
+
+EXECUTE_PATH=$(pwd)
+echo "*******************EXECUTE_PATH= ${EXECUTE_PATH}"
+if [ -d "${EXECUTE_PATH}/log_standalone_graph" ]; then
+    echo "[INFO] Delete old data_standalone log files"
+    rm -rf ${EXECUTE_PATH}/log_standalone_graph
+fi
+mkdir ${EXECUTE_PATH}/log_standalone_graph
+
+
+rm -rf ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
+mkdir -p ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
+cd ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID || exit
+
+env > ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log
+python ${EXECUTE_PATH}/../train.py \
+    --config_path=${EXECUTE_PATH}/../base_config_cpu.yaml \
+    --train_stage=base \
+    --is_distributed=0 &> ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log &
+
+echo "[INFO] Start training..."

scripts/run_train_beta_cpu.sh (new file)

@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+dirname_path=$(dirname "$(pwd)")
+echo ${dirname_path}
+export PYTHONPATH=${dirname_path}:$PYTHONPATH
+
+USE_DEVICE_ID=0
+dev=`expr $USE_DEVICE_ID + 0`
+export DEVICE_ID=$dev
+
+EXECUTE_PATH=$(pwd)
+echo "*******************EXECUTE_PATH= ${EXECUTE_PATH}"
+if [ -d "${EXECUTE_PATH}/log_standalone_graph" ]; then
+    echo "[INFO] Delete old data_standalone log files"
+    rm -rf ${EXECUTE_PATH}/log_standalone_graph
+fi
+mkdir ${EXECUTE_PATH}/log_standalone_graph
+
+
+rm -rf ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
+mkdir -p ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID
+cd ${EXECUTE_PATH}/data_standalone_log_$USE_DEVICE_ID || exit
+
+env > ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log
+python ${EXECUTE_PATH}/../train.py \
+    --config_path=${EXECUTE_PATH}/../beta_config_cpu.yaml \
+    --train_stage=beta \
+    --is_distributed=0 &> ${EXECUTE_PATH}/log_standalone_graph/face_recognition_$USE_DEVICE_ID.log &
+
+echo "[INFO] Start training..."

src/custom_dataset.py

@@ -21,6 +21,7 @@ from collections import defaultdict
 import numpy as np
 
 from PIL import Image, ImageFile
+from utils.config import config
 from mindspore.communication.management import get_group_size, get_rank
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 

@@ -56,9 +57,14 @@ class DistributedCustomSampler:
         self.epoch_gen = 1
 
     def _sample_(self, indices):
+        """sample"""
         sampled = []
 
         for indice in indices:
             sampled_id = indice
+            if config.device_target == 'CPU':
+                if self.k >= len(sampled_id):
+                    continue
             sampled.extend(np.random.choice(self.dataset.id2range[sampled_id][:], self.k).tolist())
 
         return sampled
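
On CPU, the added guard skips an identity whenever `self.k >= len(sampled_id)`; otherwise `k` samples are drawn per identity from `id2range`, with replacement. A toy illustration of that per-identity draw is below; the dictionary contents and identity keys are made up for the sketch.

```python
import numpy as np

# Toy data: id2range maps an identity to the indices of its samples.
id2range = {"id_0": [0, 1, 2, 3], "id_1": [4, 5]}
k = 3
sampled = []
for sampled_id in id2range:
    # np.random.choice samples with replacement by default, so k may exceed the range length
    sampled.extend(np.random.choice(id2range[sampled_id][:], k).tolist())
print(sampled)  # e.g. [1, 3, 0, 5, 4, 4]
```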

src/dataset_factory.py

@@ -21,6 +21,7 @@ import mindspore.dataset as de
 import mindspore.dataset.vision.py_transforms as F
 import mindspore.dataset.transforms.py_transforms as F2
 
+from utils.config import config
 from src.custom_dataset import DistributedCustomSampler, CustomDataset
 
 __all__ = ['get_de_dataset']

@@ -44,9 +45,12 @@ def get_de_dataset(args):
         os.makedirs(os.path.dirname(cache_path))
     dataset = CustomDataset(args.data_dir, cache_path, args.is_distributed)
     args.logger.info("dataset len:{}".format(dataset.__len__()))
-    sampler = DistributedCustomSampler(dataset, num_replicas=args.world_size, rank=args.local_rank,
-                                       is_distributed=args.is_distributed)
-    de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
+    if config.device_target == 'Ascend':
+        sampler = DistributedCustomSampler(dataset, num_replicas=args.world_size, rank=args.local_rank,
+                                           is_distributed=args.is_distributed)
+        de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
+    elif config.device_target == 'CPU':
+        de_dataset = de.GeneratorDataset(dataset, ["image", "label"])
     args.logger.info("after sampler de_dataset datasize :{}".format(de_dataset.get_dataset_size()))
     de_dataset = de_dataset.map(input_columns="image", operations=transform)
     de_dataset = de_dataset.map(input_columns="label", operations=transform_label)
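
On Ascend the dataset is built with the distributed sampler so each device reads its own shard, while on CPU the `GeneratorDataset` is created without a sampler. The following self-contained sketch shows the same branching, with a dummy random-access source standing in for the repository's `CustomDataset` and MindSpore's built-in `DistributedSampler` standing in for `DistributedCustomSampler`:

```python
import numpy as np
import mindspore.dataset as ds

class DummySource:
    """Stands in for the repository's CustomDataset in this sketch."""
    def __init__(self, n=8):
        self._items = [(np.full((3, 112, 112), i, np.float32), np.array(i, np.int32)) for i in range(n)]
    def __getitem__(self, index):
        return self._items[index]
    def __len__(self):
        return len(self._items)

device_target = "CPU"  # assumed to come from the parsed config
source = DummySource()
if device_target == "Ascend":
    # shard the data across devices; the repository uses its own DistributedCustomSampler
    sampler = ds.DistributedSampler(num_shards=1, shard_id=0)
    de_dataset = ds.GeneratorDataset(source, ["image", "label"], sampler=sampler)
else:
    de_dataset = ds.GeneratorDataset(source, ["image", "label"])
print(de_dataset.get_dataset_size())  # 8
```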

train.py

@@ -41,7 +41,7 @@ from model_utils.config import config
 from model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
 
 mindspore.common.seed.set_seed(1)
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
+context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False,
                     device_id=get_device_id(), reserve_class_name_in_scope=False, enable_auto_mixed_precision=False)
 
 class DistributedHelper(Cell):