forked from mindspore-Ecosystem/mindspore
!8420 mode_train_file
From: @bai-yangfan Reviewed-by: @chenfei52,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
623df51e06
|
@ -20,5 +20,11 @@ Helper functions in train piplines.
|
|||
from .model import Model
|
||||
from .dataset_helper import DatasetHelper, connect_network_with_dataset
|
||||
from . import amp
|
||||
from .amp import build_train_network
|
||||
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
|
||||
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
|
||||
build_searched_strategy, merge_sliced_parameter
|
||||
|
||||
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset"]
|
||||
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
|
||||
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
|
||||
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"]
|
||||
|
|
|
@ -26,8 +26,6 @@ from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
|
|||
from ..context import ParallelMode
|
||||
from .. import context
|
||||
|
||||
__all__ = ["build_train_network"]
|
||||
|
||||
|
||||
class OutputTo16(nn.Cell):
|
||||
"Wrap cell for amp. Cast network output back to float16"
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
from .._checkparam import Validator as validator
|
||||
from .. import nn
|
||||
|
||||
__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
|
||||
|
||||
|
||||
class LossScaleManager:
|
||||
"""Loss scale manager abstract class."""
|
||||
|
|
|
@ -33,8 +33,6 @@ from mindspore._checkparam import check_input_data, Validator
|
|||
from mindspore.compression.export import quant_export
|
||||
import mindspore.context as context
|
||||
|
||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
|
||||
"build_searched_strategy", "merge_sliced_parameter"]
|
||||
|
||||
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
|
||||
|
|
|
@ -20,9 +20,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
|
||||
from src.alexnet import AlexNet
|
||||
|
|
|
@ -16,9 +16,8 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import Tensor, context, load_checkpoint, export
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, export
|
||||
|
||||
from src.config import Config_CNNCTC
|
||||
from src.cnn_ctc import CNNCTC_Model
|
||||
|
|
|
@ -16,10 +16,7 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.serialization import export
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.nets import net_factory
|
||||
|
||||
|
|
|
@ -17,8 +17,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
|
||||
from src.config import config
|
||||
|
|
|
@ -20,8 +20,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import cifar_cfg, imagenet_cfg
|
||||
from src.googlenet import GoogleNet
|
||||
|
|
|
@ -19,8 +19,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.inception_v3 import InceptionV3
|
||||
|
|
|
@ -20,9 +20,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet import LeNet5
|
||||
|
|
|
@ -20,10 +20,8 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
|
||||
from src.config import config
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
mobilenetv2 export mindir.
|
||||
"""
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export
|
||||
from mindspore import Tensor, export
|
||||
from src.config import set_config
|
||||
from src.args import export_parse_args
|
||||
from src.models import define_net, load_ckpt
|
||||
|
|
|
@ -18,9 +18,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
|
|
|
@ -17,8 +17,7 @@ mobilenetv3 export mindir.
|
|||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.config import config_gpu
|
||||
from src.mobilenetV3 import mobilenet_v3_large
|
||||
|
||||
|
|
|
@ -19,8 +19,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||
from src.nasnet_a_mobile import NASNetAMobile
|
||||
|
|
|
@ -19,8 +19,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import config
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
|
|
|
@ -19,8 +19,7 @@ python export.py
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='resnet export')
|
||||
|
|
|
@ -16,9 +16,7 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.resnet_thor import resnet50 as resnet
|
||||
from src.config import config
|
||||
|
||||
|
|
|
@ -17,8 +17,7 @@ resnext export mindir.
|
|||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.config import config
|
||||
from src.image_classification import get_network
|
||||
|
||||
|
|
|
@ -17,8 +17,7 @@ ssd export mindir.
|
|||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.ssd import SSD300, ssd_mobilenet_v2
|
||||
from src.config import config
|
||||
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
|
||||
from mindspore import Tensor, export, load_checkpoint, load_param_into_net
|
||||
|
||||
from src.unet.unet_model import UNet
|
||||
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.warpctc import StackedRNN
|
||||
from src.config import config
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
|
|||
import mindspore.nn as nn
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore import load_checkpoint, load_param_into_net, export
|
||||
from mindspore.train import Model
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
|
|
Loading…
Reference in New Issue