forked from mindspore-Ecosystem/mindspore
!5717 Modelzoo network seed interface update.
Merge pull request !5717 from linqingke/fasterrcnn
This commit is contained in:
commit
e17f3b8dd3
|
@ -30,7 +30,9 @@ from mindspore import Tensor
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
|
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
|
||||||
|
|
|
@ -21,11 +21,14 @@ from mindspore import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
|
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
|
||||||
|
from mindspore.common import set_seed
|
||||||
from src.md_dataset import create_dataset
|
from src.md_dataset import create_dataset
|
||||||
from src.losses import OhemLoss
|
from src.losses import OhemLoss
|
||||||
from src.deeplabv3 import deeplabv3_resnet50
|
from src.deeplabv3 import deeplabv3_resnet50
|
||||||
from src.config import config
|
from src.config import config
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Deeplabv3 training")
|
parser = argparse.ArgumentParser(description="Deeplabv3 training")
|
||||||
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
||||||
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
|
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
|
||||||
|
|
|
@ -87,7 +87,9 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
|
||||||
|
|
||||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||||
|
|
||||||
Note: 1.the first run will generate the mindeocrd file, which will take a long time. 2. pretrained model is a resnet50 checkpoint that trained over ImageNet2012. 3. VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
|
Note: 1.the first run will generate the mindeocrd file, which will take a long time.
|
||||||
|
2.pretrained model is a resnet50 checkpoint that trained over ImageNet2012.
|
||||||
|
3.VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
|
||||||
|
|
||||||
```
|
```
|
||||||
# standalone training
|
# standalone training
|
||||||
|
@ -106,7 +108,7 @@ sh run_eval_ascend.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
.
|
.
|
||||||
└─FasterRcnn
|
└─faster_rcnn
|
||||||
├─README.md // descriptions about fasterrcnn
|
├─README.md // descriptions about fasterrcnn
|
||||||
├─scripts
|
├─scripts
|
||||||
├─run_standalone_train_ascend.sh // shell script for standalone on ascend
|
├─run_standalone_train_ascend.sh // shell script for standalone on ascend
|
||||||
|
@ -148,6 +150,7 @@ sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL]
|
||||||
|
|
||||||
> Rank_table.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
> Rank_table.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||||
> As for PRETRAINED_MODEL,it should be a ResNet50 checkpoint that trained over ImageNet2012. Ready-made pretrained_models are not available now. Stay tuned.
|
> As for PRETRAINED_MODEL,it should be a ResNet50 checkpoint that trained over ImageNet2012. Ready-made pretrained_models are not available now. Stay tuned.
|
||||||
|
> The original dataset path needs to be in the config.py,you can select "coco_root" or "image_dir".
|
||||||
|
|
||||||
### Result
|
### Result
|
||||||
|
|
||||||
|
@ -205,10 +208,10 @@ Eval result will be stored in the example path, whose folder name is "eval". Und
|
||||||
| -------------------------- | ----------------------------------------------------------- |
|
| -------------------------- | ----------------------------------------------------------- |
|
||||||
| Model Version | V1 |
|
| Model Version | V1 |
|
||||||
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |
|
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |
|
||||||
| uploaded Date | 06/01/2020 (month/day/year) |
|
| uploaded Date | 08/31/2020 (month/day/year) |
|
||||||
| MindSpore Version | 0.3.0-alpha |
|
| MindSpore Version | 0.7.0-beta |
|
||||||
| Dataset | COCO2017 |
|
| Dataset | COCO2017 |
|
||||||
| Training Parameters | epoch=12, batch_size = 2 |
|
| Training Parameters | epoch=12, batch_size=2 |
|
||||||
| Optimizer | SGD |
|
| Optimizer | SGD |
|
||||||
| Loss Function | Softmax Cross Entropy ,Sigmoid Cross Entropy,SmoothL1Loss |
|
| Loss Function | Softmax Cross Entropy ,Sigmoid Cross Entropy,SmoothL1Loss |
|
||||||
| Speed | 1pc: 190 ms/step; 8pcs: 200 ms/step |
|
| Speed | 1pc: 190 ms/step; 8pcs: 200 ms/step |
|
||||||
|
@ -223,12 +226,12 @@ Eval result will be stored in the example path, whose folder name is "eval". Und
|
||||||
| ------------------- | --------------------------- |
|
| ------------------- | --------------------------- |
|
||||||
| Model Version | V1 |
|
| Model Version | V1 |
|
||||||
| Resource | Ascend 910 |
|
| Resource | Ascend 910 |
|
||||||
| Uploaded Date | 06/01/2020 (month/day/year) |
|
| Uploaded Date | 08/31/2020 (month/day/year) |
|
||||||
| MindSpore Version | 0.3.0-alpha |
|
| MindSpore Version | 0.7.0-beta |
|
||||||
| Dataset | COCO2017 |
|
| Dataset | COCO2017 |
|
||||||
| batch_size | 2 |
|
| batch_size | 2 |
|
||||||
| outputs | mAP |
|
| outputs | mAP |
|
||||||
| Accuracy | IoU=0.50: 58.6% |
|
| Accuracy | IoU=0.50: 57.6% |
|
||||||
| Model for inference | 250M (.ckpt file) |
|
| Model for inference | 250M (.ckpt file) |
|
||||||
|
|
||||||
# [ModelZoo Homepage](#contents)
|
# [ModelZoo Homepage](#contents)
|
||||||
|
|
|
@ -17,21 +17,18 @@
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
|
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
|
||||||
from src.util import coco_eval, bbox2result_1image, results2json
|
from src.util import coco_eval, bbox2result_1image, results2json
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="FasterRcnn evaluation")
|
parser = argparse.ArgumentParser(description="FasterRcnn evaluation")
|
||||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||||
|
|
|
@ -19,8 +19,6 @@ import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
|
@ -30,7 +28,7 @@ from mindspore.train import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.nn import SGD
|
from mindspore.nn import SGD
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
||||||
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
||||||
|
@ -38,9 +36,7 @@ from src.config import config
|
||||||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
|
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
|
||||||
from src.lr_schedule import dynamic_lr
|
from src.lr_schedule import dynamic_lr
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
||||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
|
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
|
||||||
|
@ -78,18 +74,24 @@ if __name__ == '__main__':
|
||||||
os.makedirs(mindrecord_dir)
|
os.makedirs(mindrecord_dir)
|
||||||
if args_opt.dataset == "coco":
|
if args_opt.dataset == "coco":
|
||||||
if os.path.isdir(config.coco_root):
|
if os.path.isdir(config.coco_root):
|
||||||
|
if not os.path.exists(config.coco_root):
|
||||||
|
print("Please make sure config:coco_root is valid.")
|
||||||
|
raise ValueError(config.coco_root)
|
||||||
print("Create Mindrecord. It may take some time.")
|
print("Create Mindrecord. It may take some time.")
|
||||||
data_to_mindrecord_byte_image("coco", True, prefix)
|
data_to_mindrecord_byte_image("coco", True, prefix)
|
||||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||||
else:
|
else:
|
||||||
print("coco_root not exits.")
|
print("coco_root not exits.")
|
||||||
else:
|
else:
|
||||||
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
|
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
|
||||||
|
if not os.path.exists(config.image_dir):
|
||||||
|
print("Please make sure config:image_dir is valid.")
|
||||||
|
raise ValueError(config.image_dir)
|
||||||
print("Create Mindrecord. It may take some time.")
|
print("Create Mindrecord. It may take some time.")
|
||||||
data_to_mindrecord_byte_image("other", True, prefix)
|
data_to_mindrecord_byte_image("other", True, prefix)
|
||||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||||
else:
|
else:
|
||||||
print("IMAGE_DIR or ANNO_PATH not exits.")
|
print("image_dir or anno_path not exits.")
|
||||||
|
|
||||||
while not os.path.exists(mindrecord_file + ".db"):
|
while not os.path.exists(mindrecord_file + ".db"):
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
|
@ -23,11 +23,14 @@ from mindspore import context
|
||||||
from mindspore.nn.optim.momentum import Momentum
|
from mindspore.nn.optim.momentum import Momentum
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.config import cifar_cfg as cfg
|
from src.config import cifar_cfg as cfg
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.googlenet import GoogleNet
|
from src.googlenet import GoogleNet
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='googlenet')
|
parser = argparse.ArgumentParser(description='googlenet')
|
||||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
|
@ -18,7 +18,6 @@ python train.py
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -31,13 +30,13 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.config import cifar_cfg as cfg
|
from src.config import cifar_cfg as cfg
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.googlenet import GoogleNet
|
from src.googlenet import GoogleNet
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
|
|
||||||
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
||||||
"""Set learning rate."""
|
"""Set learning rate."""
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
"""train_imagenet."""
|
"""train_imagenet."""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
@ -27,9 +25,9 @@ from mindspore.nn.optim.rmsprop import RMSProp
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore import dataset as de
|
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from mindspore.common.initializer import XavierUniform, initializer
|
from mindspore.common.initializer import XavierUniform, initializer
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.config import config_gpu, config_ascend
|
from src.config import config_gpu, config_ascend
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
|
@ -37,9 +35,7 @@ from src.inception_v3 import InceptionV3
|
||||||
from src.lr_generator import get_lr
|
from src.lr_generator import get_lr
|
||||||
from src.loss import CrossEntropy
|
from src.loss import CrossEntropy
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -94,7 +90,6 @@ if __name__ == '__main__':
|
||||||
if args_opt.platform == "Ascend":
|
if args_opt.platform == "Ascend":
|
||||||
for param in net.trainable_params():
|
for param in net.trainable_params():
|
||||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
||||||
np.random.seed(seed=1)
|
|
||||||
param.set_parameter_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
|
param.set_parameter_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
|
||||||
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
||||||
{'params': no_decayed_params},
|
{'params': no_decayed_params},
|
||||||
|
|
|
@ -29,7 +29,9 @@ from mindspore import context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||||
|
|
|
@ -28,11 +28,14 @@ from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from mindspore.train.quant import quant
|
from mindspore.train.quant import quant
|
||||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.config import mnist_cfg as cfg
|
from src.config import mnist_cfg as cfg
|
||||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||||
from src.loss_monitor import LossMonitor
|
from src.loss_monitor import LossMonitor
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||||
choices=['Ascend', 'GPU'],
|
choices=['Ascend', 'GPU'],
|
||||||
|
|
|
@ -17,21 +17,18 @@
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
|
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
|
||||||
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
|
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="MaskRcnn evaluation")
|
parser = argparse.ArgumentParser(description="MaskRcnn evaluation")
|
||||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||||
|
|
|
@ -17,9 +17,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
import ast
|
import ast
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context, Tensor
|
from mindspore import context, Tensor
|
||||||
|
@ -29,7 +27,7 @@ from mindspore.train import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.nn import SGD
|
from mindspore.nn import SGD
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
||||||
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
||||||
|
@ -37,9 +35,7 @@ from src.config import config
|
||||||
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
|
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
|
||||||
from src.lr_schedule import dynamic_lr
|
from src.lr_schedule import dynamic_lr
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="MaskRcnn training")
|
parser = argparse.ArgumentParser(description="MaskRcnn training")
|
||||||
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create "
|
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create "
|
||||||
|
|
|
@ -13,16 +13,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.train.model import ParallelMode
|
from mindspore.train.model import ParallelMode
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||||
from mindspore.communication.management import get_rank, init
|
from mindspore.communication.management import get_rank, init
|
||||||
from mindspore.dataset import engine as de
|
|
||||||
|
|
||||||
from src.models import Monitor
|
from src.models import Monitor
|
||||||
|
|
||||||
|
@ -84,10 +80,3 @@ def config_ckpoint(config, lr, step_size):
|
||||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
|
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
|
||||||
cb += [ckpt_cb]
|
cb += [ckpt_cb]
|
||||||
return cb
|
return cb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed=1):
|
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
de.config.set_seed(seed)
|
|
||||||
|
|
|
@ -27,16 +27,17 @@ from mindspore.common import dtype as mstype
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from mindspore.train.serialization import _exec_save_checkpoint
|
from mindspore.train.serialization import _exec_save_checkpoint
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.dataset import create_dataset, extract_features
|
from src.dataset import create_dataset, extract_features
|
||||||
from src.lr_generator import get_lr
|
from src.lr_generator import get_lr
|
||||||
from src.config import set_config
|
from src.config import set_config
|
||||||
|
|
||||||
from src.args import train_parse_args
|
from src.args import train_parse_args
|
||||||
from src.utils import set_random_seed, context_device_init, switch_precision, config_ckpoint
|
from src.utils import context_device_init, switch_precision, config_ckpoint
|
||||||
from src.models import CrossEntropyWithLabelSmooth, define_net
|
from src.models import CrossEntropyWithLabelSmooth, define_net
|
||||||
|
|
||||||
set_random_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args_opt = train_parse_args()
|
args_opt = train_parse_args()
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
@ -30,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint
|
||||||
from mindspore.communication.management import init, get_group_size, get_rank
|
from mindspore.communication.management import init, get_group_size, get_rank
|
||||||
from mindspore.train.quant import quant
|
from mindspore.train.quant import quant
|
||||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.lr_generator import get_lr
|
from src.lr_generator import get_lr
|
||||||
|
@ -38,9 +36,7 @@ from src.utils import Monitor, CrossEntropyWithLabelSmooth
|
||||||
from src.config import config_ascend_quant, config_gpu_quant
|
from src.config import config_ascend_quant, config_gpu_quant
|
||||||
from src.mobilenetV2 import mobilenetV2
|
from src.mobilenetV2 import mobilenetV2
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
@ -33,7 +32,7 @@ from mindspore.context import ParallelMode
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
from mindspore.communication.management import init, get_group_size, get_rank
|
from mindspore.communication.management import init, get_group_size, get_rank
|
||||||
|
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
|
@ -41,9 +40,7 @@ from src.lr_generator import get_lr
|
||||||
from src.config import config_gpu
|
from src.config import config_gpu
|
||||||
from src.mobilenetV3 import mobilenet_v3_large
|
from src.mobilenetV3 import mobilenet_v3_large
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
"""train imagenet."""
|
"""train imagenet."""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
@ -26,7 +24,7 @@ from mindspore.nn.optim.rmsprop import RMSProp
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore import dataset as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.config import nasnet_a_mobile_config_gpu as cfg
|
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
|
@ -34,9 +32,7 @@ from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStep
|
||||||
from src.lr_generator import get_lr
|
from src.lr_generator import get_lr
|
||||||
|
|
||||||
|
|
||||||
random.seed(cfg.random_seed)
|
set_seed(cfg.random_seed)
|
||||||
np.random.seed(cfg.random_seed)
|
|
||||||
de.config.set_seed(cfg.random_seed)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -14,11 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""train resnet."""
|
"""train resnet."""
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dataset as de
|
from mindspore.common import set_seed
|
||||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
@ -33,9 +31,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
|
||||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
if args_opt.net == "resnet50":
|
if args_opt.net == "resnet50":
|
||||||
from src.resnet import resnet50 as resnet
|
from src.resnet import resnet50 as resnet
|
||||||
|
|
|
@ -14,13 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""train resnet."""
|
"""train resnet."""
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import numpy as np
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore import dataset as de
|
|
||||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||||
from mindspore.nn.optim.momentum import Momentum
|
from mindspore.nn.optim.momentum import Momentum
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
|
@ -30,6 +27,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.communication.management import init, get_rank, get_group_size
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
from mindspore.common import set_seed
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.common.initializer as weight_init
|
import mindspore.common.initializer as weight_init
|
||||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||||
|
@ -47,9 +45,7 @@ parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained ch
|
||||||
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
|
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
if args_opt.net == "resnet50":
|
if args_opt.net == "resnet50":
|
||||||
from src.resnet import resnet50 as resnet
|
from src.resnet import resnet50 as resnet
|
||||||
|
|
|
@ -31,6 +31,7 @@ from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||||
from mindspore.communication.management import init
|
from mindspore.communication.management import init
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.common.initializer as weight_init
|
import mindspore.common.initializer as weight_init
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
|
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
|
||||||
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
|
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
|
||||||
|
@ -39,6 +40,8 @@ from src.lr_generator import get_lr
|
||||||
from src.config import config_quant
|
from src.config import config_quant
|
||||||
from src.crossentropy import CrossEntropy
|
from src.crossentropy import CrossEntropy
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||||
|
|
|
@ -14,11 +14,9 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""train resnet."""
|
"""train resnet."""
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dataset as de
|
from mindspore.common import set_seed
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.crossentropy import CrossEntropy
|
from src.crossentropy import CrossEntropy
|
||||||
|
@ -32,9 +30,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
|
||||||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
target = args_opt.device_target
|
target = args_opt.device_target
|
||||||
|
|
|
@ -14,13 +14,12 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""train resnet."""
|
"""train resnet."""
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore import dataset as de
|
from mindspore.common import set_seed
|
||||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
||||||
|
@ -46,9 +45,7 @@ else:
|
||||||
from src.thor import THOR_GPU as THOR
|
from src.thor import THOR_GPU as THOR
|
||||||
from src.config import config_gpu as config
|
from src.config import config_gpu as config
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
|
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
|
||||||
|
|
|
@ -151,7 +151,6 @@ class KaimingUniform(KaimingInit):
|
||||||
def _initialize(self, arr):
|
def _initialize(self, arr):
|
||||||
fan = _select_fan(arr, self.mode)
|
fan = _select_fan(arr, self.mode)
|
||||||
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
||||||
np.random.seed(0)
|
|
||||||
data = np.random.uniform(-bound, bound, arr.shape)
|
data = np.random.uniform(-bound, bound, arr.shape)
|
||||||
|
|
||||||
_assignment(arr, data)
|
_assignment(arr, data)
|
||||||
|
@ -179,7 +178,6 @@ class KaimingNormal(KaimingInit):
|
||||||
def _initialize(self, arr):
|
def _initialize(self, arr):
|
||||||
fan = _select_fan(arr, self.mode)
|
fan = _select_fan(arr, self.mode)
|
||||||
std = self.gain / math.sqrt(fan)
|
std = self.gain / math.sqrt(fan)
|
||||||
np.random.seed(0)
|
|
||||||
data = np.random.normal(0, std, arr.shape)
|
data = np.random.normal(0, std, arr.shape)
|
||||||
|
|
||||||
_assignment(arr, data)
|
_assignment(arr, data)
|
||||||
|
@ -195,7 +193,6 @@ def default_recurisive_init(custom_cell):
|
||||||
if cell.bias is not None:
|
if cell.bias is not None:
|
||||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||||
bound = 1 / math.sqrt(fan_in)
|
bound = 1 / math.sqrt(fan_in)
|
||||||
np.random.seed(0)
|
|
||||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||||
cell.bias.shape,
|
cell.bias.shape,
|
||||||
cell.bias.dtype)
|
cell.bias.dtype)
|
||||||
|
@ -206,7 +203,6 @@ def default_recurisive_init(custom_cell):
|
||||||
if cell.bias is not None:
|
if cell.bias is not None:
|
||||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||||
bound = 1 / math.sqrt(fan_in)
|
bound = 1 / math.sqrt(fan_in)
|
||||||
np.random.seed(0)
|
|
||||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||||
cell.bias.shape,
|
cell.bias.shape,
|
||||||
cell.bias.dtype)
|
cell.bias.dtype)
|
||||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.train.callback import CheckpointConfig, Callback
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.dataset import classification_dataset
|
from src.dataset import classification_dataset
|
||||||
from src.crossentropy import CrossEntropy
|
from src.crossentropy import CrossEntropy
|
||||||
|
@ -38,6 +39,7 @@ from src.utils.optimizers__init__ import get_param_groups
|
||||||
from src.image_classification import get_network
|
from src.image_classification import get_network
|
||||||
from src.config import config
|
from src.config import config
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
class BuildTrainNetwork(nn.Cell):
|
class BuildTrainNetwork(nn.Cell):
|
||||||
"""build training network"""
|
"""build training network"""
|
||||||
|
|
|
@ -16,14 +16,11 @@
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from network import ShuffleNetV2
|
from network import ShuffleNetV2
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dataset as de
|
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.communication.management import init, get_rank, get_group_size
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
@ -31,14 +28,13 @@ from mindspore.nn.optim.momentum import Momentum
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.config import config_gpu as cfg
|
from src.config import config_gpu as cfg
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.lr_generator import get_lr_basic
|
from src.lr_generator import get_lr_basic
|
||||||
|
|
||||||
random.seed(cfg.random_seed)
|
set_seed(cfg.random_seed)
|
||||||
np.random.seed(cfg.random_seed)
|
|
||||||
de.config.set_seed(cfg.random_seed)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Parameters utils"""
|
"""Parameters utils"""
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||||
|
|
||||||
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
||||||
|
@ -22,7 +21,6 @@ def init_net_param(network, initialize_mode='TruncatedNormal'):
|
||||||
params = network.trainable_params()
|
params = network.trainable_params()
|
||||||
for p in params:
|
for p in params:
|
||||||
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||||
np.random.seed(seed=1)
|
|
||||||
if initialize_mode == 'TruncatedNormal':
|
if initialize_mode == 'TruncatedNormal':
|
||||||
p.set_parameter_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
|
p.set_parameter_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -25,12 +25,14 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMoni
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
|
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
|
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
|
||||||
from src.lr_schedule import get_lr
|
from src.lr_schedule import get_lr
|
||||||
from src.init_params import init_net_param, filter_checkpoint_parameter
|
from src.init_params import init_net_param, filter_checkpoint_parameter
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="SSD training")
|
parser = argparse.ArgumentParser(description="SSD training")
|
||||||
|
|
|
@ -151,7 +151,6 @@ class KaimingUniform(KaimingInit):
|
||||||
def _initialize(self, arr):
|
def _initialize(self, arr):
|
||||||
fan = _select_fan(arr, self.mode)
|
fan = _select_fan(arr, self.mode)
|
||||||
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
||||||
np.random.seed(0)
|
|
||||||
data = np.random.uniform(-bound, bound, arr.shape)
|
data = np.random.uniform(-bound, bound, arr.shape)
|
||||||
|
|
||||||
_assignment(arr, data)
|
_assignment(arr, data)
|
||||||
|
@ -179,7 +178,6 @@ class KaimingNormal(KaimingInit):
|
||||||
def _initialize(self, arr):
|
def _initialize(self, arr):
|
||||||
fan = _select_fan(arr, self.mode)
|
fan = _select_fan(arr, self.mode)
|
||||||
std = self.gain / math.sqrt(fan)
|
std = self.gain / math.sqrt(fan)
|
||||||
np.random.seed(0)
|
|
||||||
data = np.random.normal(0, std, arr.shape)
|
data = np.random.normal(0, std, arr.shape)
|
||||||
|
|
||||||
_assignment(arr, data)
|
_assignment(arr, data)
|
||||||
|
@ -195,7 +193,6 @@ def default_recurisive_init(custom_cell):
|
||||||
if cell.bias is not None:
|
if cell.bias is not None:
|
||||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||||
bound = 1 / math.sqrt(fan_in)
|
bound = 1 / math.sqrt(fan_in)
|
||||||
np.random.seed(0)
|
|
||||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||||
cell.bias.shape,
|
cell.bias.shape,
|
||||||
cell.bias.dtype)
|
cell.bias.dtype)
|
||||||
|
@ -206,7 +203,6 @@ def default_recurisive_init(custom_cell):
|
||||||
if cell.bias is not None:
|
if cell.bias is not None:
|
||||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||||
bound = 1 / math.sqrt(fan_in)
|
bound = 1 / math.sqrt(fan_in)
|
||||||
np.random.seed(0)
|
|
||||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||||
cell.bias.shape,
|
cell.bias.shape,
|
||||||
cell.bias.dtype)
|
cell.bias.dtype)
|
||||||
|
|
|
@ -19,9 +19,6 @@ python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
@ -33,6 +30,7 @@ from mindspore.train.model import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_param_into_net, load_checkpoint
|
from mindspore.train.serialization import load_param_into_net, load_checkpoint
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
|
from mindspore.common import set_seed
|
||||||
from src.dataset import vgg_create_dataset
|
from src.dataset import vgg_create_dataset
|
||||||
from src.dataset import classification_dataset
|
from src.dataset import classification_dataset
|
||||||
|
|
||||||
|
@ -45,8 +43,7 @@ from src.utils.util import get_param_groups
|
||||||
from src.vgg import vgg16
|
from src.vgg import vgg16
|
||||||
|
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args(cloud_args=None):
|
def parse_args(cloud_args=None):
|
||||||
|
|
|
@ -15,11 +15,9 @@
|
||||||
"""Warpctc evaluation"""
|
"""Warpctc evaluation"""
|
||||||
import os
|
import os
|
||||||
import math as m
|
import math as m
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dataset as de
|
from mindspore.common import set_seed
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
@ -29,9 +27,7 @@ from src.dataset import create_dataset
|
||||||
from src.warpctc import StackedRNN, StackedRNNForGPU
|
from src.warpctc import StackedRNN, StackedRNNForGPU
|
||||||
from src.metric import WarpCTCAccuracy
|
from src.metric import WarpCTCAccuracy
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||||
|
|
|
@ -15,12 +15,10 @@
|
||||||
"""Warpctc training"""
|
"""Warpctc training"""
|
||||||
import os
|
import os
|
||||||
import math as m
|
import math as m
|
||||||
import random
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import dataset as de
|
from mindspore.common import set_seed
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.nn.wrap import WithLossCell
|
from mindspore.nn.wrap import WithLossCell
|
||||||
|
@ -34,9 +32,7 @@ from src.warpctc import StackedRNN, StackedRNNForGPU
|
||||||
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
||||||
from src.lr_schedule import get_lr
|
from src.lr_schedule import get_lr
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||||
|
|
|
@ -21,9 +21,6 @@ from mindspore.common.initializer import Initializer as MeInitializer
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
|
||||||
np.random.seed(5)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_gain(nonlinearity, param=None):
|
def calculate_gain(nonlinearity, param=None):
|
||||||
r"""Return the recommended gain value for the given nonlinearity function.
|
r"""Return the recommended gain value for the given nonlinearity function.
|
||||||
The values are as follows:
|
The values are as follows:
|
||||||
|
|
|
@ -30,6 +30,7 @@ import mindspore as ms
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore import amp
|
from mindspore import amp
|
||||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
||||||
from src.logger import get_logger
|
from src.logger import get_logger
|
||||||
|
@ -41,6 +42,7 @@ from src.initializer import default_recurisive_init
|
||||||
from src.config import ConfigYOLOV3DarkNet53
|
from src.config import ConfigYOLOV3DarkNet53
|
||||||
from src.util import keep_loss_fp32
|
from src.util import keep_loss_fp32
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
class BuildTrainNetwork(nn.Cell):
|
class BuildTrainNetwork(nn.Cell):
|
||||||
def __init__(self, network, criterion):
|
def __init__(self, network, criterion):
|
||||||
|
|
|
@ -21,9 +21,6 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
np.random.seed(5)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_gain(nonlinearity, param=None):
|
def calculate_gain(nonlinearity, param=None):
|
||||||
r"""Return the recommended gain value for the given nonlinearity function.
|
r"""Return the recommended gain value for the given nonlinearity function.
|
||||||
The values are as follows:
|
The values are as follows:
|
||||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.quant import quant
|
from mindspore.train.quant import quant
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
||||||
from src.logger import get_logger
|
from src.logger import get_logger
|
||||||
|
@ -41,6 +42,7 @@ from src.config import ConfigYOLOV3DarkNet53
|
||||||
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
|
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
|
||||||
from src.util import ShapeRecord
|
from src.util import ShapeRecord
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
devid = int(os.getenv('DEVICE_ID'))
|
devid = int(os.getenv('DEVICE_ID'))
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||||
|
|
|
@ -34,11 +34,13 @@ from mindspore.train import Model
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
||||||
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||||
from src.config import ConfigYOLOV3ResNet18
|
from src.config import ConfigYOLOV3ResNet18
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
|
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
|
||||||
"""Set learning rate."""
|
"""Set learning rate."""
|
||||||
|
@ -54,7 +56,7 @@ def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps
|
||||||
|
|
||||||
|
|
||||||
def init_net_param(network, init_value='ones'):
|
def init_net_param(network, init_value='ones'):
|
||||||
"""Init:wq the parameters in network."""
|
"""Init the parameters in network."""
|
||||||
params = network.trainable_params()
|
params = network.trainable_params()
|
||||||
for p in params:
|
for p in params:
|
||||||
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||||
|
|
|
@ -19,12 +19,14 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.config import GatConfig
|
from src.config import GatConfig
|
||||||
from src.dataset import load_and_process
|
from src.dataset import load_and_process
|
||||||
from src.gat import GAT
|
from src.gat import GAT
|
||||||
from src.utils import LossAccuracyWrapper, TrainGAT
|
from src.utils import LossAccuracyWrapper, TrainGAT
|
||||||
|
|
||||||
|
set_seed(1)
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
"""Train GAT model."""
|
"""Train GAT model."""
|
||||||
|
|
|
@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
|
||||||
from matplotlib import animation
|
from matplotlib import animation
|
||||||
from sklearn import manifold
|
from sklearn import manifold
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.gcn import GCN
|
from src.gcn import GCN
|
||||||
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||||
|
@ -55,7 +56,7 @@ def train():
|
||||||
parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
|
parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
np.random.seed(args_opt.seed)
|
set_seed(args_opt.seed)
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target="Ascend", save_graphs=False)
|
device_target="Ascend", save_graphs=False)
|
||||||
config = ConfigGCN()
|
config = ConfigGCN()
|
||||||
|
|
|
@ -19,7 +19,6 @@ python run_pretrain.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import numpy
|
|
||||||
import mindspore.communication.management as D
|
import mindspore.communication.management as D
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
@ -30,6 +29,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
from mindspore.common import set_seed
|
||||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||||
BertTrainAccumulateStepsWithLossScaleCell
|
BertTrainAccumulateStepsWithLossScaleCell
|
||||||
from src.dataset import create_bert_dataset
|
from src.dataset import create_bert_dataset
|
||||||
|
@ -196,5 +196,5 @@ def run_pretrain():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
numpy.random.seed(0)
|
set_seed(0)
|
||||||
run_pretrain()
|
run_pretrain()
|
||||||
|
|
|
@ -19,7 +19,6 @@ python run_pretrain.py
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import numpy
|
|
||||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||||
from src.bert_net_config import bert_net_cfg
|
from src.bert_net_config import bert_net_cfg
|
||||||
from src.config import cfg
|
from src.config import cfg
|
||||||
|
@ -36,6 +35,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
|
@ -197,5 +197,5 @@ def run_pretrain():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
numpy.random.seed(0)
|
set_seed(0)
|
||||||
run_pretrain()
|
run_pretrain()
|
||||||
|
|
|
@ -30,6 +30,7 @@ from mindspore import context, Parameter
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication import management as MultiAscend
|
from mindspore.communication import management as MultiAscend
|
||||||
from mindspore.train.serialization import load_checkpoint
|
from mindspore.train.serialization import load_checkpoint
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from config import TransformerConfig
|
from config import TransformerConfig
|
||||||
from src.dataset import load_dataset
|
from src.dataset import load_dataset
|
||||||
|
@ -337,7 +338,7 @@ if __name__ == '__main__':
|
||||||
_check_args(args.config)
|
_check_args(args.config)
|
||||||
_config = get_config(args.config)
|
_config = get_config(args.config)
|
||||||
|
|
||||||
np.random.seed(_config.random_seed)
|
set_seed(_config.random_seed)
|
||||||
context.set_context(save_graphs=_config.save_graphs)
|
context.set_context(save_graphs=_config.save_graphs)
|
||||||
|
|
||||||
if _rank_size is not None and int(_rank_size) > 1:
|
if _rank_size is not None and int(_rank_size) > 1:
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
import numpy
|
|
||||||
import mindspore.communication.management as D
|
import mindspore.communication.management as D
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
@ -28,6 +27,7 @@ from mindspore.context import ParallelMode
|
||||||
from mindspore.nn.optim import AdamWeightDecay
|
from mindspore.nn.optim import AdamWeightDecay
|
||||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
|
from mindspore.common import set_seed
|
||||||
from src.dataset import create_tinybert_dataset, DataType
|
from src.dataset import create_tinybert_dataset, DataType
|
||||||
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
||||||
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
||||||
|
@ -154,5 +154,5 @@ def run_general_distill():
|
||||||
sink_size=args_opt.data_sink_steps)
|
sink_size=args_opt.data_sink_steps)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
numpy.random.seed(0)
|
set_seed(0)
|
||||||
run_general_distill()
|
run_general_distill()
|
||||||
|
|
|
@ -16,8 +16,6 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import mindspore.common.dtype as mstype
|
import mindspore.common.dtype as mstype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
@ -27,10 +25,10 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||||
from mindspore.train.callback import Callback, TimeMonitor
|
from mindspore.train.callback import Callback, TimeMonitor
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
import mindspore.dataset.engine as de
|
|
||||||
import mindspore.communication.management as D
|
import mindspore.communication.management as D
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
|
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
|
||||||
TransformerTrainOneStepWithLossScaleCell
|
TransformerTrainOneStepWithLossScaleCell
|
||||||
|
@ -38,10 +36,7 @@ from src.config import cfg, transformer_net_cfg
|
||||||
from src.dataset import create_transformer_dataset
|
from src.dataset import create_transformer_dataset
|
||||||
from src.lr_schedule import create_dynamic_lr
|
from src.lr_schedule import create_dynamic_lr
|
||||||
|
|
||||||
random_seed = 1
|
set_seed(1)
|
||||||
random.seed(random_seed)
|
|
||||||
np.random.seed(random_seed)
|
|
||||||
de.config.set_seed(random_seed)
|
|
||||||
|
|
||||||
def get_ms_timestamp():
|
def get_ms_timestamp():
|
||||||
t = time.time()
|
t = time.time()
|
||||||
|
|
|
@ -16,15 +16,13 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import init, get_rank, get_group_size
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
import mindspore.dataset.engine as de
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.deepfm import ModelBuilder, AUCMetric
|
from src.deepfm import ModelBuilder, AUCMetric
|
||||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||||
|
@ -46,9 +44,7 @@ args_opt, _ = parser.parse_known_args()
|
||||||
args_opt.do_eval = args_opt.do_eval == 'True'
|
args_opt.do_eval = args_opt.do_eval == 'True'
|
||||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||||
|
|
||||||
random.seed(1)
|
set_seed(1)
|
||||||
np.random.seed(1)
|
|
||||||
de.config.set_seed(1)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
data_config = DataConfig()
|
data_config = DataConfig()
|
||||||
|
|
|
@ -17,11 +17,11 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from mindspore import Model, context
|
from mindspore import Model, context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import get_rank, get_group_size, init
|
from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
|
@ -69,7 +69,7 @@ def train_and_eval(config):
|
||||||
"""
|
"""
|
||||||
test_train_eval
|
test_train_eval
|
||||||
"""
|
"""
|
||||||
np.random.seed(1000)
|
set_seed(1000)
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
epochs = config.epochs
|
epochs = config.epochs
|
||||||
|
|
|
@ -17,11 +17,11 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from mindspore import Model, context
|
from mindspore import Model, context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import get_rank, get_group_size, init
|
from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
|
@ -70,7 +70,7 @@ def train_and_eval(config):
|
||||||
"""
|
"""
|
||||||
test_train_eval
|
test_train_eval
|
||||||
"""
|
"""
|
||||||
np.random.seed(1000)
|
set_seed(1000)
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
epochs = config.epochs
|
epochs = config.epochs
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from mindspore import Model, context
|
from mindspore import Model, context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||||
from mindspore.train.callback import TimeMonitor
|
from mindspore.train.callback import TimeMonitor
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.communication.management import get_rank, get_group_size, init
|
from mindspore.communication.management import get_rank, get_group_size, init
|
||||||
|
from mindspore.common import set_seed
|
||||||
|
|
||||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||||
from src.callbacks import LossCallBack, EvalCallBack
|
from src.callbacks import LossCallBack, EvalCallBack
|
||||||
|
@ -69,7 +69,7 @@ def train_and_eval(config):
|
||||||
"""
|
"""
|
||||||
train_and_eval
|
train_and_eval
|
||||||
"""
|
"""
|
||||||
np.random.seed(1000)
|
set_seed(1000)
|
||||||
data_path = config.data_path
|
data_path = config.data_path
|
||||||
epochs = config.epochs
|
epochs = config.epochs
|
||||||
print("epochs is {}".format(epochs))
|
print("epochs is {}".format(epochs))
|
||||||
|
|
Loading…
Reference in New Issue