!8420 mode_train_file

From: @bai-yangfan
Reviewed-by: @chenfei52,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2020-11-11 13:09:52 +08:00 committed by Gitee
commit 623df51e06
25 changed files with 28 additions and 55 deletions

View File

@ -20,5 +20,11 @@ Helper functions in train piplines.
from .model import Model from .model import Model
from .dataset_helper import DatasetHelper, connect_network_with_dataset from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
build_searched_strategy, merge_sliced_parameter
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset"] __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"]

View File

@ -26,8 +26,6 @@ from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode from ..context import ParallelMode
from .. import context from .. import context
__all__ = ["build_train_network"]
class OutputTo16(nn.Cell): class OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16" "Wrap cell for amp. Cast network output back to float16"

View File

@ -17,8 +17,6 @@
from .._checkparam import Validator as validator from .._checkparam import Validator as validator
from .. import nn from .. import nn
__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
class LossScaleManager: class LossScaleManager:
"""Loss scale manager abstract class.""" """Loss scale manager abstract class."""

View File

@ -33,8 +33,6 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export from mindspore.compression.export import quant_export
import mindspore.context as context import mindspore.context as context
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,

View File

@ -20,9 +20,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.alexnet import AlexNet from src.alexnet import AlexNet

View File

@ -16,9 +16,8 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, export
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, export
from src.config import Config_CNNCTC from src.config import Config_CNNCTC
from src.cnn_ctc import CNNCTC_Model from src.cnn_ctc import CNNCTC_Model

View File

@ -16,10 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import export
from src.nets import net_factory from src.nets import net_factory

View File

@ -17,8 +17,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config from src.config import config

View File

@ -20,8 +20,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import cifar_cfg, imagenet_cfg from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet from src.googlenet import GoogleNet

View File

@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config_gpu as cfg from src.config import config_gpu as cfg
from src.inception_v3 import InceptionV3 from src.inception_v3 import InceptionV3

View File

@ -20,9 +20,7 @@ import argparse
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet import LeNet5 from src.lenet import LeNet5

View File

@ -20,10 +20,8 @@ import argparse
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion from src.lenet_fusion import LeNet5 as LeNet5Fusion

View File

@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config from src.config import config

View File

@ -16,8 +16,7 @@
mobilenetv2 export mindir. mobilenetv2 export mindir.
""" """
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, export
from mindspore.train.serialization import export
from src.config import set_config from src.config import set_config
from src.args import export_parse_args from src.args import export_parse_args
from src.models import define_net, load_ckpt from src.models import define_net, load_ckpt

View File

@ -18,9 +18,7 @@ import argparse
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2

View File

@ -17,8 +17,7 @@ mobilenetv3 export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config_gpu from src.config import config_gpu
from src.mobilenetV3 import mobilenet_v3_large from src.mobilenetV3 import mobilenet_v3_large

View File

@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import nasnet_a_mobile_config_gpu as cfg from src.config import nasnet_a_mobile_config_gpu as cfg
from src.nasnet_a_mobile import NASNetAMobile from src.nasnet_a_mobile import NASNetAMobile

View File

@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config from src.config import config
from src.ETSNET.etsnet import ETSNet from src.ETSNET.etsnet import ETSNet

View File

@ -19,8 +19,7 @@ python export.py
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='resnet export') parser = argparse.ArgumentParser(description='resnet export')

View File

@ -16,9 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.resnet_thor import resnet50 as resnet from src.resnet_thor import resnet50 as resnet
from src.config import config from src.config import config

View File

@ -17,8 +17,7 @@ resnext export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config from src.config import config
from src.image_classification import get_network from src.image_classification import get_network

View File

@ -17,8 +17,7 @@ ssd export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, ssd_mobilenet_v2 from src.ssd import SSD300, ssd_mobilenet_v2
from src.config import config from src.config import config

View File

@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, export, load_checkpoint, load_param_into_net
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.unet.unet_model import UNet from src.unet.unet_model import UNet

View File

@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.warpctc import StackedRNN from src.warpctc import StackedRNN
from src.config import config from src.config import config

View File

@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore import load_checkpoint, load_param_into_net, export
from mindspore.train import Model from mindspore.train import Model
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net