forked from mindspore-Ecosystem/mindspore
!6611 Add random seed in deeplabv3
Merge pull request !6611 from jiangzhenguang/add_set_random_in_deeplabv3
This commit is contained in:
commit
11f0b85eb3
|
@ -15,6 +15,7 @@
|
||||||
- [Model Description](#model-description)
|
- [Model Description](#model-description)
|
||||||
- [Performance](#performance)
|
- [Performance](#performance)
|
||||||
- [Evaluation Performance](#evaluation-performance)
|
- [Evaluation Performance](#evaluation-performance)
|
||||||
|
- [Description of Random Situation](#description-of-random-situation)
|
||||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||||
|
|
||||||
|
|
||||||
|
@ -141,7 +142,7 @@ run_eval_s8_multiscale_flip.sh
|
||||||
├── run_standalone_train.sh # launch ascend standalone training(1 pc)
|
├── run_standalone_train.sh # launch ascend standalone training(1 pc)
|
||||||
├── src
|
├── src
|
||||||
├── data
|
├── data
|
||||||
├── data_generator.py # mindrecord data generator
|
├── dataset.py # mindrecord data generator
|
||||||
├── build_seg_data.py # data preprocessing
|
├── build_seg_data.py # data preprocessing
|
||||||
├── loss
|
├── loss
|
||||||
├── loss.py # loss definition for deeplabv3
|
├── loss.py # loss definition for deeplabv3
|
||||||
|
@ -412,10 +413,9 @@ Note: There OS is output stride, and MS is multiscale.
|
||||||
| Checkpoint for Fine tuning | 443M (.ckpt file) |
|
| Checkpoint for Fine tuning | 443M (.ckpt file) |
|
||||||
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) |
|
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) |
|
||||||
|
|
||||||
|
# [Description of Random Situation](#contents)
|
||||||
|
In dataset.py, we set the seed inside "create_dataset" function. We also use random seed in train.py.
|
||||||
|
|
||||||
# [ModelZoo Homepage](#contents)
|
# [ModelZoo Homepage](#contents)
|
||||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,10 +24,13 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.communication.management import init, get_rank, get_group_size
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from src.data import data_generator
|
from mindspore.common import set_seed
|
||||||
|
from src.data import dataset as data_generator
|
||||||
from src.loss import loss
|
from src.loss import loss
|
||||||
from src.nets import net_factory
|
from src.nets import net_factory
|
||||||
from src.utils import learning_rates
|
from src.utils import learning_rates
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||||
device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
|
device_target="Ascend", device_id=int(os.getenv('DEVICE_ID')))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue