!6611 Add random seed in deeplabv3

Merge pull request !6611 from jiangzhenguang/add_set_random_in_deeplabv3
This commit is contained in:
mindspore-ci-bot 2020-09-22 14:46:57 +08:00 committed by Gitee
commit 11f0b85eb3
3 changed files with 10 additions and 7 deletions

View File

@ -15,6 +15,7 @@
- [Model Description](#model-description) - [Model Description](#model-description)
- [Performance](#performance) - [Performance](#performance)
- [Evaluation Performance](#evaluation-performance) - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)
@ -141,7 +142,7 @@ run_eval_s8_multiscale_flip.sh
├── run_standalone_train.sh # launch ascend standalone training(1 pc) ├── run_standalone_train.sh # launch ascend standalone training(1 pc)
├── src ├── src
├── data ├── data
├── data_generator.py # mindrecord data generator ├── dataset.py # mindrecord data generator
├── build_seg_data.py # data preprocessing ├── build_seg_data.py # data preprocessing
├── loss ├── loss
├── loss.py # loss definition for deeplabv3 ├── loss.py # loss definition for deeplabv3
@ -412,10 +413,9 @@ Note: There OS is output stride, and MS is multiscale.
| Checkpoint for Fine tuning | 443M (.ckpt file) | | Checkpoint for Fine tuning | 443M (.ckpt file) |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) | | Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents) # [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -24,10 +24,13 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import LossMonitor, TimeMonitor from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.data import data_generator from mindspore.common import set_seed
from src.data import dataset as data_generator
from src.loss import loss from src.loss import loss
from src.nets import net_factory from src.nets import net_factory
from src.utils import learning_rates from src.utils import learning_rates
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
device_target="Ascend", device_id=int(os.getenv('DEVICE_ID'))) device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))