!8420 mode_train_file

From: @bai-yangfan
Reviewed-by: @chenfei52,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2020-11-11 13:09:52 +08:00 committed by Gitee
commit 623df51e06
25 changed files with 28 additions and 55 deletions

View File

@ -20,5 +20,11 @@ Helper functions in train piplines.
from .model import Model
from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
build_searched_strategy, merge_sliced_parameter
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset"]
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"]

View File

@ -26,8 +26,6 @@ from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode
from .. import context
__all__ = ["build_train_network"]
class OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16"

View File

@ -17,8 +17,6 @@
from .._checkparam import Validator as validator
from .. import nn
__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
class LossScaleManager:
"""Loss scale manager abstract class."""

View File

@ -33,8 +33,6 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
import mindspore.context as context
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,

View File

@ -20,9 +20,7 @@ import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.alexnet import AlexNet

View File

@ -16,9 +16,8 @@
import argparse
import numpy as np
from mindspore import Tensor, context
from mindspore import Tensor, context, load_checkpoint, export
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, export
from src.config import Config_CNNCTC
from src.cnn_ctc import CNNCTC_Model

View File

@ -16,10 +16,7 @@
import argparse
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.nets import net_factory

View File

@ -17,8 +17,7 @@ import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config

View File

@ -20,8 +20,7 @@ import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet

View File

@ -19,8 +19,7 @@ import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from src.config import config_gpu as cfg
from src.inception_v3 import InceptionV3

View File

@ -20,9 +20,7 @@ import argparse
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5

View File

@ -20,10 +20,8 @@ import argparse
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion

View File

@ -16,8 +16,7 @@
import argparse
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config

View File

@ -16,8 +16,7 @@
mobilenetv2 export mindir.
"""
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export
from mindspore import Tensor, export
from src.config import set_config
from src.args import export_parse_args
from src.models import define_net, load_ckpt

View File

@ -18,9 +18,7 @@ import argparse
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining
from src.mobilenetV2 import mobilenetV2

View File

@ -17,8 +17,7 @@ mobilenetv3 export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config_gpu
from src.mobilenetV3 import mobilenet_v3_large

View File

@ -19,8 +19,7 @@ import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from src.config import nasnet_a_mobile_config_gpu as cfg
from src.nasnet_a_mobile import NASNetAMobile

View File

@ -19,8 +19,7 @@ import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.ETSNET.etsnet import ETSNet

View File

@ -19,8 +19,7 @@ python export.py
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='resnet export')

View File

@ -16,9 +16,7 @@
import argparse
import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.resnet_thor import resnet50 as resnet
from src.config import config

View File

@ -17,8 +17,7 @@ resnext export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.image_classification import get_network

View File

@ -17,8 +17,7 @@ ssd export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, ssd_mobilenet_v2
from src.config import config

View File

@ -16,8 +16,7 @@
import argparse
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from mindspore import Tensor, export, load_checkpoint, load_param_into_net
from src.unet.unet_model import UNet

View File

@ -16,8 +16,7 @@
import argparse
import numpy as np
from mindspore import Tensor, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.warpctc import StackedRNN
from src.config import config

View File

@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore import load_checkpoint, load_param_into_net, export
from mindspore.train import Model
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net