forked from mindspore-Ecosystem/mindspore
Modelzoo interface change.
This commit is contained in:
parent
9018737e99
commit
4d9d8c3e74
|
@ -30,7 +30,9 @@ from mindspore import Tensor
|
|||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.common import set_seed
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
|
||||
|
|
|
@ -21,11 +21,14 @@ from mindspore import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
|
||||
from mindspore.common import set_seed
|
||||
from src.md_dataset import create_dataset
|
||||
from src.losses import OhemLoss
|
||||
from src.deeplabv3 import deeplabv3_resnet50
|
||||
from src.config import config
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Deeplabv3 training")
|
||||
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
||||
parser.add_argument('--data_url', required=True, default=None, help='Train data url')
|
||||
|
|
|
@ -87,7 +87,9 @@ Dataset used: [COCO2017](<http://images.cocodataset.org/>)
|
|||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
Note: 1.the first run will generate the mindeocrd file, which will take a long time. 2. pretrained model is a resnet50 checkpoint that trained over ImageNet2012. 3. VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
|
||||
Note: 1.the first run will generate the mindeocrd file, which will take a long time.
|
||||
2.pretrained model is a resnet50 checkpoint that trained over ImageNet2012.
|
||||
3.VALIDATION_JSON_FILE is label file. CHECKPOINT_PATH is a checkpoint file after trained.
|
||||
|
||||
```
|
||||
# standalone training
|
||||
|
@ -106,7 +108,7 @@ sh run_eval_ascend.sh [VALIDATION_JSON_FILE] [CHECKPOINT_PATH]
|
|||
|
||||
```shell
|
||||
.
|
||||
└─FasterRcnn
|
||||
└─faster_rcnn
|
||||
├─README.md // descriptions about fasterrcnn
|
||||
├─scripts
|
||||
├─run_standalone_train_ascend.sh // shell script for standalone on ascend
|
||||
|
@ -148,6 +150,7 @@ sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [PRETRAINED_MODEL]
|
|||
|
||||
> Rank_table.json which is specified by RANK_TABLE_FILE is needed when you are running a distribute task. You can generate it by using the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
> As for PRETRAINED_MODEL,it should be a ResNet50 checkpoint that trained over ImageNet2012. Ready-made pretrained_models are not available now. Stay tuned.
|
||||
> The original dataset path needs to be in the config.py,you can select "coco_root" or "image_dir".
|
||||
|
||||
### Result
|
||||
|
||||
|
@ -205,8 +208,8 @@ Eval result will be stored in the example path, whose folder name is "eval". Und
|
|||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 ;CPU 2.60GHz,56cores;Memory,314G |
|
||||
| uploaded Date | 06/01/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.3.0-alpha |
|
||||
| uploaded Date | 08/31/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.7.0-beta |
|
||||
| Dataset | COCO2017 |
|
||||
| Training Parameters | epoch=12, batch_size=2 |
|
||||
| Optimizer | SGD |
|
||||
|
@ -223,12 +226,12 @@ Eval result will be stored in the example path, whose folder name is "eval". Und
|
|||
| ------------------- | --------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 06/01/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.3.0-alpha |
|
||||
| Uploaded Date | 08/31/2020 (month/day/year) |
|
||||
| MindSpore Version | 0.7.0-beta |
|
||||
| Dataset | COCO2017 |
|
||||
| batch_size | 2 |
|
||||
| outputs | mAP |
|
||||
| Accuracy | IoU=0.50: 58.6% |
|
||||
| Accuracy | IoU=0.50: 57.6% |
|
||||
| Model for inference | 250M (.ckpt file) |
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
|
|
@ -17,21 +17,18 @@
|
|||
import os
|
||||
import argparse
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
||||
from src.config import config
|
||||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
|
||||
from src.util import coco_eval, bbox2result_1image, results2json
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="FasterRcnn evaluation")
|
||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||
|
|
|
@ -19,8 +19,6 @@ import os
|
|||
import time
|
||||
import argparse
|
||||
import ast
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
|
@ -30,7 +28,7 @@ from mindspore.train import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn import SGD
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
||||
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
||||
|
@ -38,9 +36,7 @@ from src.config import config
|
|||
from src.dataset import data_to_mindrecord_byte_image, create_fasterrcnn_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.")
|
||||
|
@ -78,18 +74,24 @@ if __name__ == '__main__':
|
|||
os.makedirs(mindrecord_dir)
|
||||
if args_opt.dataset == "coco":
|
||||
if os.path.isdir(config.coco_root):
|
||||
if not os.path.exists(config.coco_root):
|
||||
print("Please make sure config:coco_root is valid.")
|
||||
raise ValueError(config.coco_root)
|
||||
print("Create Mindrecord. It may take some time.")
|
||||
data_to_mindrecord_byte_image("coco", True, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("coco_root not exits.")
|
||||
else:
|
||||
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
|
||||
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
|
||||
if not os.path.exists(config.image_dir):
|
||||
print("Please make sure config:image_dir is valid.")
|
||||
raise ValueError(config.image_dir)
|
||||
print("Create Mindrecord. It may take some time.")
|
||||
data_to_mindrecord_byte_image("other", True, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("IMAGE_DIR or ANNO_PATH not exits.")
|
||||
print("image_dir or anno_path not exits.")
|
||||
|
||||
while not os.path.exists(mindrecord_file + ".db"):
|
||||
time.sleep(5)
|
||||
|
|
|
@ -23,11 +23,14 @@ from mindspore import context
|
|||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import cifar_cfg as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.googlenet import GoogleNet
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='googlenet')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
args_opt = parser.parse_args()
|
||||
|
|
|
@ -18,7 +18,6 @@ python train.py
|
|||
"""
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -31,13 +30,13 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMoni
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import cifar_cfg as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.googlenet import GoogleNet
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
set_seed(1)
|
||||
|
||||
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
||||
"""Set learning rate."""
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
"""train_imagenet."""
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
@ -27,9 +25,9 @@ from mindspore.nn.optim.rmsprop import RMSProp
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import dataset as de
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common.initializer import XavierUniform, initializer
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import config_gpu, config_ascend
|
||||
from src.dataset import create_dataset
|
||||
|
@ -37,9 +35,7 @@ from src.inception_v3 import InceptionV3
|
|||
from src.lr_generator import get_lr
|
||||
from src.loss import CrossEntropy
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -94,7 +90,6 @@ if __name__ == '__main__':
|
|||
if args_opt.platform == "Ascend":
|
||||
for param in net.trainable_params():
|
||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
||||
np.random.seed(seed=1)
|
||||
param.set_parameter_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
|
||||
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
|
|
|
@ -29,7 +29,9 @@ from mindspore import context
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.common import set_seed
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||
|
|
|
@ -28,11 +28,14 @@ from mindspore.train import Model
|
|||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.common import set_seed
|
||||
from src.dataset import create_dataset
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
from src.loss_monitor import LossMonitor
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU'],
|
||||
|
|
|
@ -17,21 +17,18 @@
|
|||
import os
|
||||
import argparse
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
from pycocotools.coco import COCO
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
||||
from src.config import config
|
||||
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
|
||||
from src.util import coco_eval, bbox2result_1image, results2json, get_seg_masks
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="MaskRcnn evaluation")
|
||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||
|
|
|
@ -17,9 +17,7 @@
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
|
@ -29,7 +27,7 @@ from mindspore.train import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn import SGD
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
||||
from src.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet
|
||||
|
@ -37,9 +35,7 @@ from src.config import config
|
|||
from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
|
||||
from src.lr_schedule import dynamic_lr
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="MaskRcnn training")
|
||||
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, help="If set it true, only create "
|
||||
|
|
|
@ -13,16 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.model import ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.communication.management import get_rank, init
|
||||
from mindspore.dataset import engine as de
|
||||
|
||||
from src.models import Monitor
|
||||
|
||||
|
@ -84,10 +80,3 @@ def config_ckpoint(config, lr, step_size):
|
|||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
return cb
|
||||
|
||||
|
||||
|
||||
def set_random_seed(seed=1):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
de.config.set_seed(seed)
|
||||
|
|
|
@ -27,16 +27,17 @@ from mindspore.common import dtype as mstype
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import _exec_save_checkpoint
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import create_dataset, extract_features
|
||||
from src.lr_generator import get_lr
|
||||
from src.config import set_config
|
||||
|
||||
from src.args import train_parse_args
|
||||
from src.utils import set_random_seed, context_device_init, switch_precision, config_ckpoint
|
||||
from src.utils import context_device_init, switch_precision, config_ckpoint
|
||||
from src.models import CrossEntropyWithLabelSmooth, define_net
|
||||
|
||||
set_random_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_opt = train_parse_args()
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
|
@ -30,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint
|
|||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.lr_generator import get_lr
|
||||
|
@ -38,9 +36,7 @@ from src.utils import Monitor, CrossEntropyWithLabelSmooth
|
|||
from src.config import config_ascend_quant, config_gpu_quant
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
import time
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
|
@ -33,7 +32,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
|
||||
from src.dataset import create_dataset
|
||||
|
@ -41,9 +40,7 @@ from src.lr_generator import get_lr
|
|||
from src.config import config_gpu
|
||||
from src.mobilenetV3 import mobilenet_v3_large
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
"""train imagenet."""
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
@ -26,7 +24,7 @@ from mindspore.nn.optim.rmsprop import RMSProp
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import dataset as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
|
@ -34,9 +32,7 @@ from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStep
|
|||
from src.lr_generator import get_lr
|
||||
|
||||
|
||||
random.seed(cfg.random_seed)
|
||||
np.random.seed(cfg.random_seed)
|
||||
de.config.set_seed(cfg.random_seed)
|
||||
set_seed(cfg.random_seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -14,11 +14,9 @@
|
|||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
@ -33,9 +31,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
|
|||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
if args_opt.net == "resnet50":
|
||||
from src.resnet import resnet50 as resnet
|
||||
|
|
|
@ -14,13 +14,10 @@
|
|||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
import ast
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore import dataset as de
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
|
@ -30,6 +27,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
|||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.common import set_seed
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||
|
@ -47,9 +45,7 @@ parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained ch
|
|||
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
if args_opt.net == "resnet50":
|
||||
from src.resnet import resnet50 as resnet
|
||||
|
|
|
@ -31,6 +31,7 @@ from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
|||
from mindspore.communication.management import init
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
|
||||
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
|
||||
|
@ -39,6 +40,8 @@ from src.lr_generator import get_lr
|
|||
from src.config import config_quant
|
||||
from src.crossentropy import CrossEntropy
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
|
|
|
@ -14,11 +14,9 @@
|
|||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.crossentropy import CrossEntropy
|
||||
|
@ -32,9 +30,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path
|
|||
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
target = args_opt.device_target
|
||||
|
|
|
@ -14,13 +14,12 @@
|
|||
# ============================================================================
|
||||
"""train resnet."""
|
||||
import os
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore import dataset as de
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
|
||||
|
@ -46,9 +45,7 @@ else:
|
|||
from src.thor import THOR_GPU as THOR
|
||||
from src.config import config_gpu as config
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
|
||||
|
|
|
@ -151,7 +151,6 @@ class KaimingUniform(KaimingInit):
|
|||
def _initialize(self, arr):
|
||||
fan = _select_fan(arr, self.mode)
|
||||
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
||||
np.random.seed(0)
|
||||
data = np.random.uniform(-bound, bound, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
@ -179,7 +178,6 @@ class KaimingNormal(KaimingInit):
|
|||
def _initialize(self, arr):
|
||||
fan = _select_fan(arr, self.mode)
|
||||
std = self.gain / math.sqrt(fan)
|
||||
np.random.seed(0)
|
||||
data = np.random.normal(0, std, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
@ -195,7 +193,6 @@ def default_recurisive_init(custom_cell):
|
|||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
np.random.seed(0)
|
||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||
cell.bias.shape,
|
||||
cell.bias.dtype)
|
||||
|
@ -206,7 +203,6 @@ def default_recurisive_init(custom_cell):
|
|||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
np.random.seed(0)
|
||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||
cell.bias.shape,
|
||||
cell.bias.dtype)
|
||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.train.callback import CheckpointConfig, Callback
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import classification_dataset
|
||||
from src.crossentropy import CrossEntropy
|
||||
|
@ -38,6 +39,7 @@ from src.utils.optimizers__init__ import get_param_groups
|
|||
from src.image_classification import get_network
|
||||
from src.config import config
|
||||
|
||||
set_seed(1)
|
||||
|
||||
class BuildTrainNetwork(nn.Cell):
|
||||
"""build training network"""
|
||||
|
|
|
@ -16,14 +16,11 @@
|
|||
import argparse
|
||||
import ast
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from network import ShuffleNetV2
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import Tensor
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
@ -31,14 +28,13 @@ from mindspore.nn.optim.momentum import Momentum
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.lr_generator import get_lr_basic
|
||||
|
||||
random.seed(cfg.random_seed)
|
||||
np.random.seed(cfg.random_seed)
|
||||
de.config.set_seed(cfg.random_seed)
|
||||
set_seed(cfg.random_seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# ============================================================================
|
||||
"""Parameters utils"""
|
||||
|
||||
import numpy as np
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
|
||||
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
||||
|
@ -22,7 +21,6 @@ def init_net_param(network, initialize_mode='TruncatedNormal'):
|
|||
params = network.trainable_params()
|
||||
for p in params:
|
||||
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||
np.random.seed(seed=1)
|
||||
if initialize_mode == 'TruncatedNormal':
|
||||
p.set_parameter_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
|
||||
else:
|
||||
|
|
|
@ -25,12 +25,14 @@ from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMoni
|
|||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
|
||||
from src.config import config
|
||||
from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
|
||||
from src.lr_schedule import get_lr
|
||||
from src.init_params import init_net_param, filter_checkpoint_parameter
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="SSD training")
|
||||
|
|
|
@ -151,7 +151,6 @@ class KaimingUniform(KaimingInit):
|
|||
def _initialize(self, arr):
|
||||
fan = _select_fan(arr, self.mode)
|
||||
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
|
||||
np.random.seed(0)
|
||||
data = np.random.uniform(-bound, bound, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
@ -179,7 +178,6 @@ class KaimingNormal(KaimingInit):
|
|||
def _initialize(self, arr):
|
||||
fan = _select_fan(arr, self.mode)
|
||||
std = self.gain / math.sqrt(fan)
|
||||
np.random.seed(0)
|
||||
data = np.random.normal(0, std, arr.shape)
|
||||
|
||||
_assignment(arr, data)
|
||||
|
@ -195,7 +193,6 @@ def default_recurisive_init(custom_cell):
|
|||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
np.random.seed(0)
|
||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||
cell.bias.shape,
|
||||
cell.bias.dtype)
|
||||
|
@ -206,7 +203,6 @@ def default_recurisive_init(custom_cell):
|
|||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_in_and_out(cell.weight)
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
np.random.seed(0)
|
||||
cell.bias.default_input = init.initializer(init.Uniform(bound),
|
||||
cell.bias.shape,
|
||||
cell.bias.dtype)
|
||||
|
|
|
@ -19,9 +19,6 @@ python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
|
|||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
@ -33,6 +30,7 @@ from mindspore.train.model import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_param_into_net, load_checkpoint
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
from src.dataset import vgg_create_dataset
|
||||
from src.dataset import classification_dataset
|
||||
|
||||
|
@ -45,8 +43,7 @@ from src.utils.util import get_param_groups
|
|||
from src.vgg import vgg16
|
||||
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def parse_args(cloud_args=None):
|
||||
|
|
|
@ -15,11 +15,9 @@
|
|||
"""Warpctc evaluation"""
|
||||
import os
|
||||
import math as m
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
@ -29,9 +27,7 @@ from src.dataset import create_dataset
|
|||
from src.warpctc import StackedRNN, StackedRNNForGPU
|
||||
from src.metric import WarpCTCAccuracy
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||
|
|
|
@ -15,12 +15,10 @@
|
|||
"""Warpctc training"""
|
||||
import os
|
||||
import math as m
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.wrap import WithLossCell
|
||||
|
@ -34,9 +32,7 @@ from src.warpctc import StackedRNN, StackedRNNForGPU
|
|||
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
||||
from src.lr_schedule import get_lr
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||
|
|
|
@ -21,9 +21,6 @@ from mindspore.common.initializer import Initializer as MeInitializer
|
|||
import mindspore.nn as nn
|
||||
|
||||
|
||||
np.random.seed(5)
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
r"""Return the recommended gain value for the given nonlinearity function.
|
||||
The values are as follows:
|
||||
|
|
|
@ -30,6 +30,7 @@ import mindspore as ms
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore import amp
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
||||
from src.logger import get_logger
|
||||
|
@ -41,6 +42,7 @@ from src.initializer import default_recurisive_init
|
|||
from src.config import ConfigYOLOV3DarkNet53
|
||||
from src.util import keep_loss_fp32
|
||||
|
||||
set_seed(1)
|
||||
|
||||
class BuildTrainNetwork(nn.Cell):
|
||||
def __init__(self, network, criterion):
|
||||
|
|
|
@ -21,9 +21,6 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
|
||||
|
||||
np.random.seed(5)
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
r"""Return the recommended gain value for the given nonlinearity function.
|
||||
The values are as follows:
|
||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
|
|||
import mindspore as ms
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
||||
from src.logger import get_logger
|
||||
|
@ -41,6 +42,7 @@ from src.config import ConfigYOLOV3DarkNet53
|
|||
from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single
|
||||
from src.util import ShapeRecord
|
||||
|
||||
set_seed(1)
|
||||
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
||||
|
|
|
@ -34,11 +34,13 @@ from mindspore.train import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
|
||||
from src.dataset import create_yolo_dataset, data_to_mindrecord_byte_image
|
||||
from src.config import ConfigYOLOV3ResNet18
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
|
||||
"""Set learning rate."""
|
||||
|
@ -54,7 +56,7 @@ def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps
|
|||
|
||||
|
||||
def init_net_param(network, init_value='ones'):
|
||||
"""Init:wq the parameters in network."""
|
||||
"""Init the parameters in network."""
|
||||
params = network.trainable_params()
|
||||
for p in params:
|
||||
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||
|
|
|
@ -19,12 +19,14 @@ import os
|
|||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import GatConfig
|
||||
from src.dataset import load_and_process
|
||||
from src.gat import GAT
|
||||
from src.utils import LossAccuracyWrapper, TrainGAT
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def train():
|
||||
"""Train GAT model."""
|
||||
|
|
|
@ -26,6 +26,7 @@ from matplotlib import pyplot as plt
|
|||
from matplotlib import animation
|
||||
from sklearn import manifold
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.gcn import GCN
|
||||
from src.metrics import LossAccuracyWrapper, TrainNetWrapper
|
||||
|
@ -55,7 +56,7 @@ def train():
|
|||
parser.add_argument('--save_TSNE', type=ast.literal_eval, default=False, help='Whether to save t-SNE graph')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
np.random.seed(args_opt.seed)
|
||||
set_seed(args_opt.seed)
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend", save_graphs=False)
|
||||
config = ConfigGCN()
|
||||
|
|
|
@ -19,7 +19,6 @@ python run_pretrain.py
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import numpy
|
||||
import mindspore.communication.management as D
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
|
@ -30,6 +29,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import set_seed
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||
BertTrainAccumulateStepsWithLossScaleCell
|
||||
from src.dataset import create_bert_dataset
|
||||
|
@ -196,5 +196,5 @@ def run_pretrain():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
numpy.random.seed(0)
|
||||
set_seed(0)
|
||||
run_pretrain()
|
||||
|
|
|
@ -19,7 +19,6 @@ python run_pretrain.py
|
|||
|
||||
import argparse
|
||||
import os
|
||||
import numpy
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
from src.bert_net_config import bert_net_cfg
|
||||
from src.config import cfg
|
||||
|
@ -36,6 +35,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
@ -197,5 +197,5 @@ def run_pretrain():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
numpy.random.seed(0)
|
||||
set_seed(0)
|
||||
run_pretrain()
|
||||
|
|
|
@ -30,6 +30,7 @@ from mindspore import context, Parameter
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication import management as MultiAscend
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from config import TransformerConfig
|
||||
from src.dataset import load_dataset
|
||||
|
@ -337,7 +338,7 @@ if __name__ == '__main__':
|
|||
_check_args(args.config)
|
||||
_config = get_config(args.config)
|
||||
|
||||
np.random.seed(_config.random_seed)
|
||||
set_seed(_config.random_seed)
|
||||
context.set_context(save_graphs=_config.save_graphs)
|
||||
|
||||
if _rank_size is not None and int(_rank_size) > 1:
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
import os
|
||||
import argparse
|
||||
import datetime
|
||||
import numpy
|
||||
import mindspore.communication.management as D
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
|
@ -28,6 +27,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore import log as logger
|
||||
from mindspore.common import set_seed
|
||||
from src.dataset import create_tinybert_dataset, DataType
|
||||
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
|
||||
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
|
||||
|
@ -154,5 +154,5 @@ def run_general_distill():
|
|||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
if __name__ == '__main__':
|
||||
numpy.random.seed(0)
|
||||
set_seed(0)
|
||||
run_general_distill()
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
|
||||
import time
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -27,10 +25,10 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
|||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.train.callback import Callback, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.communication.management as D
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
|
||||
TransformerTrainOneStepWithLossScaleCell
|
||||
|
@ -38,10 +36,7 @@ from src.config import cfg, transformer_net_cfg
|
|||
from src.dataset import create_transformer_dataset
|
||||
from src.lr_schedule import create_dynamic_lr
|
||||
|
||||
random_seed = 1
|
||||
random.seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
de.config.set_seed(random_seed)
|
||||
set_seed(1)
|
||||
|
||||
def get_ms_timestamp():
|
||||
t = time.time()
|
||||
|
|
|
@ -16,15 +16,13 @@
|
|||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
import mindspore.dataset.engine as de
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
|
@ -46,9 +44,7 @@ args_opt, _ = parser.parse_known_args()
|
|||
args_opt.do_eval = args_opt.do_eval == 'True'
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
|
|
|
@ -17,11 +17,11 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack, EvalCallBack
|
||||
|
@ -69,7 +69,7 @@ def train_and_eval(config):
|
|||
"""
|
||||
test_train_eval
|
||||
"""
|
||||
np.random.seed(1000)
|
||||
set_seed(1000)
|
||||
data_path = config.data_path
|
||||
batch_size = config.batch_size
|
||||
epochs = config.epochs
|
||||
|
|
|
@ -17,11 +17,11 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack, EvalCallBack
|
||||
|
@ -70,7 +70,7 @@ def train_and_eval(config):
|
|||
"""
|
||||
test_train_eval
|
||||
"""
|
||||
np.random.seed(1000)
|
||||
set_seed(1000)
|
||||
data_path = config.data_path
|
||||
batch_size = config.batch_size
|
||||
epochs = config.epochs
|
||||
|
|
|
@ -16,12 +16,12 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from mindspore import Model, context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.callback import TimeMonitor
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_rank, get_group_size, init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
||||
from src.callbacks import LossCallBack, EvalCallBack
|
||||
|
@ -69,7 +69,7 @@ def train_and_eval(config):
|
|||
"""
|
||||
train_and_eval
|
||||
"""
|
||||
np.random.seed(1000)
|
||||
set_seed(1000)
|
||||
data_path = config.data_path
|
||||
epochs = config.epochs
|
||||
print("epochs is {}".format(epochs))
|
||||
|
|
Loading…
Reference in New Issue