!40159 clean code for master

Merge pull request !40159 from changzherui/clean_code_python2_ma
i-robot 2022-08-10 02:42:32 +00:00 committed by Gitee
commit 16588a9eed
77 changed files with 323 additions and 308 deletions
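
The diff below applies one mechanical pattern across the 77 files: intra-package relative imports are rewritten as absolute mindspore.* imports, and several files gain an explicit "from __future__ import absolute_import". A minimal sketch of the pattern, using import forms of the kind that appear in the files below (exact names vary per file):

    # Before: explicit relative imports, resolved against the importing
    # module's package.
    from .tensor import Tensor
    from .._checkparam import Validator

    # After: absolute imports, always resolved from the top-level
    # mindspore package, the same way on every Python version.
    from mindspore.common.tensor import Tensor
    from mindspore._checkparam import Validator

Because only import paths change, the rewrite is behavior-preserving; the few hunks that touch other lines are line-wrapping and docstring tidy-ups from the same clean-code pass.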

@@ -14,20 +14,20 @@
# ============================================================================
"""Top-level reference to dtype of common module."""
from __future__ import absolute_import
from . import dtype
from .api import ms_function, ms_memory_recycle, ms_class, _convert_python_data
from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
from mindspore.common import dtype
from mindspore.common.api import ms_function, ms_memory_recycle, ms_class, _convert_python_data
from mindspore.common.dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \
uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \
float32, single, float64, double, bool_, float_, list_, tuple_, int_, \
uint, number, tensor, string, type_none, tensor_type, Int, \
complex64, complex128, dtype_to_nptype, issubclass_, \
dtype_to_pytype, pytype_to_dtype, get_py_obj_dtype
from .dump import set_dump
from .parameter import Parameter, ParameterTuple
from .seed import set_seed, get_seed
from .tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor
from .mutable import mutable
from .jit_config import JitConfig
from mindspore.common.dump import set_dump
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.seed import set_seed, get_seed
from mindspore.common.tensor import Tensor, RowTensor, SparseTensor, COOTensor, CSRTensor
from mindspore.common.mutable import mutable
from mindspore.common.jit_config import JitConfig
# symbols from dtype
__all__ = [

@@ -14,7 +14,7 @@
# ============================================================================
"""Define Monad default value."""
from __future__ import absolute_import
from .._c_expression import IOMonad, UMonad
from mindspore._c_expression import IOMonad, UMonad
# Universe monad default value.
U = UMonad()

@@ -17,7 +17,7 @@
from __future__ import absolute_import
from collections import UserDict
from .. import context
from mindspore import context
class Registry(UserDict):

@@ -31,21 +31,21 @@ import mindspore as ms
from mindspore import context
from mindspore import log as logger
from mindspore._extends.remote import kernel_build_server
from .tensor import Tensor as PythonTensor
from .tensor import CSRTensor as PythonCSRTensor
from .tensor import COOTensor as PythonCOOTensor
from .tensor import RowTensor as PythonRowTensor
from .initializer import initializer
from .._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
_ms_memory_recycle
from ..parallel._tensor import _load_tensor_by_layout
from ..parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
_get_parameter_broadcast, _get_pipeline_stages
from .._checkparam import Validator
from ._utils import is_shape_unknown
from ..common.mutable import mutable
from mindspore.common.tensor import Tensor as PythonTensor
from mindspore.common.tensor import CSRTensor as PythonCSRTensor
from mindspore.common.tensor import COOTensor as PythonCOOTensor
from mindspore.common.tensor import RowTensor as PythonRowTensor
from mindspore.common.initializer import initializer
from mindspore._c_expression import GraphExecutor_, Tensor, MetaTensor, CSRTensor, RowTensor, COOTensor,\
PynativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
from mindspore.parallel._tensor import _load_tensor_by_layout
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched, _enable_distributed_mindrt
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, \
_to_full_tensor, _get_parameter_broadcast, _get_pipeline_stages
from mindspore._checkparam import Validator
from mindspore.common._utils import is_shape_unknown
from mindspore.common.mutable import mutable
# store ms_function class compiled pipeline cache
ms_compile_cache = set()

@@ -20,8 +20,8 @@ from __future__ import absolute_import
from inspect import isfunction
import numpy as np
from mindspore import log as logger
from .._c_expression import typing
from .._c_expression.typing import Type
from mindspore._c_expression import typing
from mindspore._c_expression.typing import Type
__dtype__ = [
"int8", "byte",

@@ -21,10 +21,10 @@ import math
from functools import reduce
import numpy as np
from scipy.stats import truncnorm
from .seed import get_seed, _get_graph_seed
from . import dtype as mstype
from .tensor import Tensor
from .._c_expression import random_normal
from mindspore.common.seed import get_seed, _get_graph_seed
from mindspore.common import dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore._c_expression import random_normal
_INITIALIZER_ALIAS = dict()

@@ -15,7 +15,7 @@
"""mutable function for setting constants mutable."""
from __future__ import absolute_import
from ..common.tensor import Tensor
from mindspore.common.tensor import Tensor
class _Tuple(tuple):

@@ -22,21 +22,21 @@ import numbers
import numpy as np
from mindspore import log as logger
from mindspore.log import _LogActionOnce
from .._c_expression import ParamInfo
from . import dtype as mstype
from .. import context
from ..parallel._utils import _get_parallel_mode
from .initializer import initializer
from .tensor import Tensor
from .._checkparam import Validator
from .._c_expression import Tensor as Tensor_
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table,\
_is_fl_mode, _enable_distributed_mindrt
from ..parallel._ps_context import _reinsert_hash_table_size
from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
from .seed import _get_global_and_op_seed
from mindspore._c_expression import ParamInfo
from mindspore.common import dtype as mstype
from mindspore import context
from mindspore.parallel._utils import _get_parallel_mode
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator
from mindspore._c_expression import Tensor as Tensor_
from mindspore.parallel._tensor import _get_slice_index
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table, \
_is_fl_mode, _enable_distributed_mindrt
from mindspore.parallel._ps_context import _reinsert_hash_table_size
from mindspore.parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info
from mindspore.common.seed import _get_global_and_op_seed
__all__ = ['Parameter', 'ParameterTuple']

@@ -23,14 +23,14 @@ from mindspore.common._utils import is_shape_unknown
from mindspore import context
from mindspore import log as logger
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
from .._c_expression import COOTensor as COOTensor_
from .._c_expression import CSRTensor as CSRTensor_
from .._c_expression import RowTensor as RowTensor_
from .._c_expression import Tensor as Tensor_
from .._checkparam import Rel
from .._checkparam import Validator as validator
from mindspore.common import dtype as mstype
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore._c_expression import COOTensor as COOTensor_
from mindspore._c_expression import CSRTensor as CSRTensor_
from mindspore._c_expression import RowTensor as RowTensor_
from mindspore._c_expression import Tensor as Tensor_
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
__all__ = ['Tensor', 'RowTensor', 'SparseTensor', 'COOTensor', 'CSRTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,

@@ -19,22 +19,21 @@ Pre-defined building blocks or computing units to construct neural networks.
"""
from __future__ import absolute_import
from . import layer, loss, optim, metrics, wrap, grad, probability, sparse, dynamic_lr,\
reinforcement
from .learning_rate_schedule import *
from .dynamic_lr import *
from .cell import Cell, GraphCell
from .layer import *
from .loss import *
from .optim import *
from .metrics import *
from .wrap import *
from .grad import Jvp, Vjp
from .sparse import *
from .reinforcement import *
from .transformer import AttentionMask, VocabEmbedding, MultiHeadAttention, FeedForward, TransformerEncoder, \
TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer, Transformer, TransformerOpParallelConfig, \
EmbeddingOpParallelConfig, TransformerRecomputeConfig, MoEConfig, OpParallelConfig
from mindspore.nn import layer, loss, optim, metrics, wrap, grad, probability, sparse, dynamic_lr, reinforcement
from mindspore.nn.learning_rate_schedule import *
from mindspore.nn.dynamic_lr import *
from mindspore.nn.cell import Cell, GraphCell
from mindspore.nn.layer import *
from mindspore.nn.loss import *
from mindspore.nn.optim import *
from mindspore.nn.metrics import *
from mindspore.nn.wrap import *
from mindspore.nn.grad import Jvp, Vjp
from mindspore.nn.sparse import *
from mindspore.nn.reinforcement import *
from mindspore.nn.transformer import AttentionMask, VocabEmbedding, MultiHeadAttention, FeedForward, \
TransformerEncoder, TransformerDecoder, TransformerEncoderLayer, TransformerDecoderLayer, Transformer,\
TransformerOpParallelConfig, EmbeddingOpParallelConfig, TransformerRecomputeConfig, MoEConfig, OpParallelConfig
__all__ = ["Cell", "GraphCell"]
__all__.extend(layer.__all__)

@@ -30,17 +30,17 @@ from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
from mindspore.common.hook_handle import HookHandle
from mindspore.context import ParallelMode
from mindspore.ops.composite import Shard
from .. import context
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
from ..common.parameter import Parameter, ParameterTuple
from ..common.tensor import Tensor, CSRTensor, COOTensor
from ..ops.operations import Cast
from ..ops.primitive import Primitive
from ..ops.operations import _inner_ops as inner
from ..parallel._tensor import _load_tensor_by_layout
from mindspore import context
from mindspore._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
from mindspore._checkparam import Validator
from mindspore.common import dtype as mstype
from mindspore.common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor, CSRTensor, COOTensor
from mindspore.ops.operations import Cast
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _inner_ops as inner
from mindspore.parallel._tensor import _load_tensor_by_layout
class Cell(Cell_):

@@ -15,7 +15,7 @@
"""Cells of grad function. Calculate the gradient of input network or function."""
from __future__ import absolute_import
from .cell_grad import Jvp, Vjp
from mindspore.nn.grad.cell_grad import Jvp, Vjp
__all__ = ['Jvp', 'Vjp']

@@ -15,12 +15,12 @@
"""cell grad"""
from __future__ import absolute_import
from ..cell import Cell
from ...ops import composite as C
from ...ops import operations as P
from ...ops.primitive import Primitive
from ...common import dtype as mstype
from ...common.api import ms_function
from mindspore.nn.cell import Cell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.primitive import Primitive
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
class _FirstGrad(Cell):
@@ -115,6 +115,7 @@ class Jvp(Cell):
@ms_function
def construct(self, *args):
"""construct for jvp."""
jvp_input = args[0:-1]
v = args[-1]
output = self.fn(*jvp_input)
@@ -155,9 +156,7 @@ class _JvpInner(Cell):
self.tuple_len = Primitive("tuple_len")
def compute_jvp(self, fn, v, jvp_input, output):
"""
Compute the jacobian-vector-product of the given fn, vector, inputs and outputs.
"""
"""Compute the jacobian-vector-product of the given fn, vector, inputs and outputs."""
if self.issubclass_(self.typeof(output), mstype.tuple_):
u = self.make_tuple()
for i in range(self.tuple_len(output)):

@@ -18,10 +18,10 @@ from __future__ import division
import math
from ..common import dtype as mstype
from ..ops import operations as P
from .cell import Cell
from .._checkparam import Validator as validator
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore._checkparam import Validator as validator
class LearningRateSchedule(Cell):

@@ -20,25 +20,25 @@ on the evaluation dataset. It's used to choose the best model.
"""
from __future__ import absolute_import
from .accuracy import Accuracy
from .hausdorff_distance import HausdorffDistance
from .error import MAE, MSE
from .metric import Metric, rearrange_inputs
from .precision import Precision
from .recall import Recall
from .fbeta import Fbeta, F1
from .dice import Dice
from .roc import ROC
from .auc import auc
from .topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from .loss import Loss
from .mean_surface_distance import MeanSurfaceDistance
from .root_mean_square_surface_distance import RootMeanSquareDistance
from .bleu_score import BleuScore
from .cosine_similarity import CosineSimilarity
from .occlusion_sensitivity import OcclusionSensitivity
from .perplexity import Perplexity
from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrix
from mindspore.nn.metrics.accuracy import Accuracy
from mindspore.nn.metrics.hausdorff_distance import HausdorffDistance
from mindspore.nn.metrics.error import MAE, MSE
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
from mindspore.nn.metrics.precision import Precision
from mindspore.nn.metrics.recall import Recall
from mindspore.nn.metrics.fbeta import Fbeta, F1
from mindspore.nn.metrics.dice import Dice
from mindspore.nn.metrics.roc import ROC
from mindspore.nn.metrics.auc import auc
from mindspore.nn.metrics.topk import TopKCategoricalAccuracy, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore.nn.metrics.loss import Loss
from mindspore.nn.metrics.mean_surface_distance import MeanSurfaceDistance
from mindspore.nn.metrics.root_mean_square_surface_distance import RootMeanSquareDistance
from mindspore.nn.metrics.bleu_score import BleuScore
from mindspore.nn.metrics.cosine_similarity import CosineSimilarity
from mindspore.nn.metrics.occlusion_sensitivity import OcclusionSensitivity
from mindspore.nn.metrics.perplexity import Perplexity
from mindspore.nn.metrics.confusion_matrix import ConfusionMatrixMetric, ConfusionMatrix
__all__ = [
"names",

@@ -16,7 +16,7 @@
from __future__ import absolute_import
import numpy as np
from .metric import EvaluationBase, rearrange_inputs, _check_onehot_data
from mindspore.nn.metrics.metric import EvaluationBase, rearrange_inputs, _check_onehot_data
class Accuracy(EvaluationBase):

@@ -18,7 +18,7 @@ from __future__ import absolute_import
from collections import Counter
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class BleuScore(Metric):

@@ -17,7 +17,7 @@ from __future__ import absolute_import
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class ConfusionMatrix(Metric):

@@ -17,7 +17,7 @@ from __future__ import absolute_import
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class CosineSimilarity(Metric):

@@ -18,7 +18,7 @@ from __future__ import absolute_import
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class Dice(Metric):

@@ -17,7 +17,7 @@ from __future__ import absolute_import
import numpy as np
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class MAE(Metric):

@@ -19,7 +19,7 @@ import sys
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs, _check_onehot_data
from mindspore.nn.metrics.metric import Metric, rearrange_inputs, _check_onehot_data
class Fbeta(Metric):

@@ -22,7 +22,7 @@ import numpy as np
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class _ROISpatialData(metaclass=ABCMeta):
@@ -114,8 +114,8 @@ class HausdorffDistance(Metric):
string_list = ["euclidean", "chessboard", "taxicab"]
distance_metric = validator.check_value_type("distance_metric", distance_metric, [str])
self.distance_metric = validator.check_string(distance_metric, string_list, "distance_metric")
self.percentile = percentile if percentile is None else validator.check_value_type("percentile",
percentile, [float])
self.percentile = percentile if percentile is None else \
validator.check_value_type("percentile", percentile, [float])
self.directed = directed if directed is None else validator.check_value_type("directed", directed, [bool])
self.crop = crop if crop is None else validator.check_value_type("crop", crop, [bool])
self.clear()

@@ -15,7 +15,7 @@
"""Loss for evaluation"""
from __future__ import absolute_import
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class Loss(Metric):

@@ -19,7 +19,7 @@ from scipy.ndimage import morphology
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class MeanSurfaceDistance(Metric):

@@ -20,7 +20,7 @@ import numpy as np
from mindspore import nn
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
try:
from tqdm import trange

@@ -19,7 +19,7 @@ import math
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class Perplexity(Metric):

@@ -19,7 +19,7 @@ import sys
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import EvaluationBase, rearrange_inputs, _check_onehot_data
from mindspore.nn.metrics.metric import EvaluationBase, rearrange_inputs, _check_onehot_data
class Precision(EvaluationBase):

@@ -19,7 +19,7 @@ import sys
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import EvaluationBase, rearrange_inputs, _check_onehot_data
from mindspore.nn.metrics.metric import EvaluationBase, rearrange_inputs, _check_onehot_data
class Recall(EvaluationBase):

@@ -18,7 +18,7 @@ from __future__ import absolute_import
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs, _binary_clf_curve
from mindspore.nn.metrics.metric import Metric, rearrange_inputs, _binary_clf_curve
class ROC(Metric):

@@ -19,7 +19,7 @@ from scipy.ndimage import morphology
import numpy as np
from mindspore._checkparam import Validator as validator
from .metric import Metric, rearrange_inputs
from mindspore.nn.metrics.metric import Metric, rearrange_inputs
class RootMeanSquareDistance(Metric):

@@ -17,7 +17,7 @@ from __future__ import absolute_import
import numpy as np
from .metric import Metric, rearrange_inputs, _check_onehot_data
from mindspore.nn.metrics.metric import Metric, rearrange_inputs, _check_onehot_data
class TopKCategoricalAccuracy(Metric):

@@ -20,24 +20,24 @@ The optimizer is used to calculate and update the gradients.
"""
from __future__ import absolute_import
from .optimizer import Optimizer
from .momentum import Momentum
from .adam import Adam, AdamWeightDecay, AdamOffload
from .lamb import Lamb
from .sgd import SGD
from .asgd import ASGD
from .rprop import Rprop
from .lars import LARS
from .ftrl import FTRL
from .rmsprop import RMSProp
from .proximal_ada_grad import ProximalAdagrad
from .lazyadam import LazyAdam
from .ada_grad import Adagrad
from .thor import thor
from .adafactor import AdaFactor
from .adasum import AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
from .adamax import AdaMax
from .adadelta import Adadelta
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.optim.adam import Adam, AdamWeightDecay, AdamOffload
from mindspore.nn.optim.lamb import Lamb
from mindspore.nn.optim.sgd import SGD
from mindspore.nn.optim.asgd import ASGD
from mindspore.nn.optim.rprop import Rprop
from mindspore.nn.optim.lars import LARS
from mindspore.nn.optim.ftrl import FTRL
from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.nn.optim.proximal_ada_grad import ProximalAdagrad
from mindspore.nn.optim.lazyadam import LazyAdam
from mindspore.nn.optim.ada_grad import Adagrad
from mindspore.nn.optim.thor import thor
from mindspore.nn.optim.adafactor import AdaFactor
from mindspore.nn.optim.adasum import AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
from mindspore.nn.optim.adamax import AdaMax
from mindspore.nn.optim.adadelta import Adadelta
__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor',

@@ -18,8 +18,9 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import Validator as validator
from mindspore.common.api import ms_function
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_ada_grad_opt = C.MultitypeFuncGraph("ada_grad_opt")

@@ -19,8 +19,8 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_adadelta_opt = C.MultitypeFuncGraph("adadelta_opt")

@@ -28,7 +28,7 @@ from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import opt_init_args_register
from .optimizer import Optimizer
from mindspore.nn.optim.optimizer import Optimizer
def _rms(update_tensor):

@@ -28,9 +28,9 @@ from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from ._dist_optimizer_registry import _register_dist_optimizer
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
from mindspore.nn.optim._dist_optimizer_registry import _register_dist_optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")
_fused_adam_weight_decay = C.MultitypeFuncGraph("fused_adam_weight_decay")

@@ -22,8 +22,8 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
class ASGD(Optimizer):

@@ -19,9 +19,9 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.api import ms_function
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from ._dist_optimizer_registry import _register_dist_optimizer
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
from mindspore.nn.optim._dist_optimizer_registry import _register_dist_optimizer
_ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")

@@ -25,8 +25,8 @@ from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")

@@ -21,8 +21,8 @@ from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from mindspore.common.api import ms_function
from .optimizer import _grad_scale, Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import _grad_scale, Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_lars_opt = C.MultitypeFuncGraph("lars_opt")

@@ -25,9 +25,9 @@ from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from ._dist_optimizer_registry import _register_dist_optimizer
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
from mindspore.nn.optim._dist_optimizer_registry import _register_dist_optimizer
_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")

@@ -21,9 +21,9 @@ from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from ._dist_optimizer_registry import _register_dist_optimizer
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
from mindspore.nn.optim._dist_optimizer_registry import _register_dist_optimizer
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")

@@ -35,7 +35,7 @@ from mindspore.parallel._ps_context import _is_ps_mode, _enable_distributed_mind
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
from ._dist_optimizer_registry import generate_dist_optimizer_list
from mindspore.nn.optim._dist_optimizer_registry import generate_dist_optimizer_list
__all__ = ['Optimizer', 'opt_init_args_register']
@@ -185,8 +185,8 @@ class Optimizer(Cell):
"""initialize optimizer attributions"""
weight_decay = self._preprocess_weight_decay(weight_decay)
if self.is_group_lr:
self.learning_rate = CellList(self.group_lr, auto_prefix=False) if self.dynamic_lr \
else ParameterTuple(self.group_lr)
self.learning_rate = CellList(self.group_lr, auto_prefix=False) \
if self.dynamic_lr else ParameterTuple(self.group_lr)
else:
self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')

@@ -20,8 +20,8 @@ from mindspore.common import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.api import ms_function
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")

@@ -18,8 +18,8 @@ from __future__ import absolute_import
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import Validator as validator
from mindspore.common.api import ms_function
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
_centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")

@@ -21,8 +21,8 @@ import mindspore.common.dtype as mstype
from mindspore.common.api import ms_function
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
class Rprop(Optimizer):

@@ -21,8 +21,8 @@ from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.optim.optimizer import opt_init_args_register
_sgd_opt = C.MultitypeFuncGraph("sgd_opt")

@@ -15,7 +15,7 @@
"""Sparse related transformation."""
from __future__ import absolute_import
from .sparse import (SparseToDense, SparseTensorDenseMatmul)
from mindspore.nn.sparse.sparse import (SparseToDense, SparseTensorDenseMatmul)
__all__ = [
"SparseToDense",

@@ -16,7 +16,7 @@
from __future__ import absolute_import
from mindspore.ops import operations as P
from ..cell import Cell
from mindspore.nn.cell import Cell
class SparseToDense(Cell):

@@ -19,11 +19,12 @@ Use the Wrapper to combine the loss or build the training steps.
"""
from __future__ import absolute_import
from .cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, WithEvalCell, \
ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell
from .loss_scale import TrainOneStepWithLossScaleCell, DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from .grad_reducer import DistributedGradReducer
from ..layer.timedistributed import TimeDistributed
from mindspore.nn.wrap.cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, \
WithEvalCell, ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell
from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell,\
DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.nn.layer.timedistributed import TimeDistributed
__all__ = [

@@ -24,15 +24,15 @@ from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
from mindspore.context import ParallelMode
from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops.primitive import constexpr
from ...ops import composite as C
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
@@ -478,6 +478,7 @@ class _MicroBatch(Cell):
self.strided_slice = P.StridedSlice()
def construct(self, i, *inputs):
"""construct for _MicroBatch."""
micro_inputs = ()
for each_input in inputs:
input_shape = self.shape(each_input)

@@ -18,14 +18,14 @@ from __future__ import absolute_import
import mindspore.context as context
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from .cell_wrapper import TrainOneStepCell
from ..cell import Cell
from ...common import Tensor, RowTensor
from ...common.parameter import Parameter
from ...ops import functional as F
from ...ops import composite as C
from ...ops import operations as P
from ...common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell
from mindspore.nn.cell import Cell
from mindspore.common import Tensor, RowTensor
from mindspore.common.parameter import Parameter
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()

@@ -14,17 +14,17 @@
# ============================================================================
"""Auto mixed precision."""
from __future__ import absolute_import
from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from ..ops import functional as F
from ..parallel._utils import _get_pipeline_stages
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from .. import boost
from .. import context
from mindspore import nn
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _TrainPipelineAccuStepCell
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_pipeline_stages
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from mindspore import boost, context
AMP_WHITE_LIST = (
@@ -46,7 +46,7 @@ AMP_BLACK_LIST = (
class _OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16"
"""Wrap cell for amp. Cast network output back to float16."""
def __init__(self, op):
super(_OutputTo16, self).__init__(auto_prefix=False)

@@ -13,27 +13,27 @@
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""
from __future__ import absolute_import
from ._callback import Callback
from ._callback import CallbackManager as _CallbackManager
from ._callback import InternalCallbackParam as _InternalCallbackParam
from ._callback import RunContext
from ._callback import checkpoint_cb_for_save_op as _checkpoint_cb_for_save_op
from ._callback import set_cur_net as _set_cur_net
from ._checkpoint import CheckpointConfig
from ._checkpoint import CheckpointManager as _CheckpointManager
from ._checkpoint import ModelCheckpoint
from ._loss_monitor import LossMonitor
from ._time_monitor import TimeMonitor
from ._summary_collector import SummaryCollector
from ._lr_scheduler_callback import LearningRateScheduler
from ._landscape import SummaryLandscape
from ._fl_manager import FederatedLearningManager
from ._history import History
from ._lambda_callback import LambdaCallback
from ._early_stop import EarlyStopping
from ._reduce_lr_on_plateau import ReduceLROnPlateau
from mindspore.train.callback._callback import Callback
from mindspore.train.callback._callback import CallbackManager as _CallbackManager
from mindspore.train.callback._callback import InternalCallbackParam as _InternalCallbackParam
from mindspore.train.callback._callback import RunContext
from mindspore.train.callback._callback import checkpoint_cb_for_save_op as _checkpoint_cb_for_save_op
from mindspore.train.callback._callback import set_cur_net as _set_cur_net
from mindspore.train.callback._checkpoint import CheckpointConfig
from mindspore.train.callback._checkpoint import CheckpointManager as _CheckpointManager
from mindspore.train.callback._checkpoint import ModelCheckpoint
from mindspore.train.callback._loss_monitor import LossMonitor
from mindspore.train.callback._time_monitor import TimeMonitor
from mindspore.train.callback._summary_collector import SummaryCollector
from mindspore.train.callback._lr_scheduler_callback import LearningRateScheduler
from mindspore.train.callback._landscape import SummaryLandscape
from mindspore.train.callback._fl_manager import FederatedLearningManager
from mindspore.train.callback._history import History
from mindspore.train.callback._lambda_callback import LambdaCallback
from mindspore.train.callback._early_stop import EarlyStopping
from mindspore.train.callback._reduce_lr_on_plateau import ReduceLROnPlateau
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",

@@ -29,9 +29,9 @@ from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank, _enable_distributed_mindrt
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
from ._callback import Callback, set_cur_net
from ...common.tensor import Tensor
from ...common.parameter import Parameter
from mindspore.train.callback._callback import Callback, set_cur_net
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
_cur_dir = os.getcwd()
_save_dir = _cur_dir

@@ -29,7 +29,7 @@ from mindspore.ops import ReduceOp
from mindspore.communication import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._callback import Callback, _handle_loss
from mindspore.train.callback._callback import Callback, _handle_loss
_smaller_better_metrics = ['hausdorff_distance', 'mae', 'mse', 'loss', 'perplexity',
@@ -133,7 +133,7 @@ class EarlyStopping(Callback):
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.mode == 'min' or \
(self.mode == 'auto' and self.monitor in _smaller_better_metrics) else -np.Inf
(self.mode == 'auto' and self.monitor in _smaller_better_metrics) else -np.Inf
self.best_weights_param_dict = None
def on_train_epoch_end(self, run_context):
@@ -152,8 +152,8 @@ class EarlyStopping(Callback):
parallel_mode = auto_parallel_context().get_parallel_mode()
rank_size = 1 if parallel_mode == ParallelMode.STAND_ALONE else get_group_size()
current = current_value if rank_size == 1 else \
self._reduce(Tensor(current_value.astype(np.float32))) / rank_size
current = current_value if \
rank_size == 1 else self._reduce(Tensor(current_value.astype(np.float32))) / rank_size
if current is None:
return

@@ -272,8 +272,8 @@ class FederatedLearningManager(Callback):
"""
Set the value of last parameters for adaptive synchronization.
"""
self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
if self._as_prefix not in _.name}
self._last_param = {_.name: deepcopy(_.asnumpy())
for _ in self._model.trainable_params() if self._as_prefix not in _.name}
def step_end(self, run_context):
"""

@@ -18,7 +18,7 @@ from __future__ import absolute_import
import numpy as np
from mindspore.common.tensor import Tensor
from ._callback import Callback
from mindspore.train.callback._callback import Callback
class History(Callback):

@@ -15,7 +15,7 @@
"""Lambda Callback class."""
from __future__ import absolute_import
from ._callback import Callback
from mindspore.train.callback._callback import Callback
class LambdaCallback(Callback):

@@ -18,7 +18,7 @@ from __future__ import absolute_import
import numpy as np
from mindspore._checkparam import Validator
from ._callback import Callback, _handle_loss
from mindspore.train.callback._callback import Callback, _handle_loss
class LossMonitor(Callback):

@@ -27,7 +27,7 @@ from mindspore import nn, ops
from mindspore.communication import get_group_size
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from ._callback import Callback, _handle_loss
from mindspore.train.callback._callback import Callback, _handle_loss
_smaller_better_metrics = ['hausdorff_distance', 'mae', 'mse', 'loss', 'perplexity',
@@ -129,7 +129,7 @@ class ReduceLROnPlateau(Callback):
self.cooldown_counter = 0
self.wait = 0
self.best = np.Inf if self.mode == 'min' or \
(self.mode == 'auto' and self.monitor in _smaller_better_metrics) else -np.Inf
(self.mode == 'auto' and self.monitor in _smaller_better_metrics) else -np.Inf
def on_train_epoch_end(self, run_context):
"""

@@ -40,7 +40,7 @@ from mindspore.train.callback._dataset_graph import DatasetGraph
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.loss.loss import LossBase
from mindspore.train._utils import check_value_type, _make_directory
from ..._c_expression import security
from mindspore._c_expression import security
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
HYPER_CONFIG_LEN_LIMIT = 100000

@@ -18,7 +18,7 @@ from __future__ import absolute_import
import time
from mindspore._checkparam import Validator
from ._callback import Callback
from mindspore.train.callback._callback import Callback
class TimeMonitor(Callback):

@@ -23,12 +23,13 @@ from mindspore.common.api import _cell_graph_executor
from mindspore.common._utils import is_shape_unknown
from mindspore.dataset.engine import offload
import mindspore.dataset as ds
from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, \
from mindspore import context, nn
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
_to_full_shapes, _get_pipeline_stages
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, \
_enable_distributed_mindrt
from ..ops import operations as P
from mindspore.ops import operations as P
def _send_data(dataset, epoch_num):

@@ -15,8 +15,8 @@
"""Loss scale manager abstract class."""
from __future__ import absolute_import
from .._checkparam import Validator as validator
from .. import nn
from mindspore._checkparam import Validator as validator
from mindspore import nn
class LossScaleManager:

@@ -24,29 +24,29 @@ import copy
import numpy as np
from mindspore import log as logger
from .serialization import save_checkpoint, load_checkpoint
from .callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics, get_metric_fn
from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor
from .callback import __all__ as internal_cb_names
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
from mindspore.common.tensor import Tensor
from mindspore.nn.metrics import get_metrics, get_metric_fn
from mindspore._checkparam import check_input_data, check_output_data, Validator
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor
from mindspore.train.callback import __all__ as internal_cb_names
from mindspore import context
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
_reset_op_id_with_offset
from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, _cache_enable, \
_enable_distributed_mindrt
from ..nn.metrics import Loss
from .. import nn
from ..boost import AutoBoost
from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
from ..parallel._recovery_context import _set_recovery_context, _get_recovery_context
from .dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _is_ps_mode, \
_cache_enable, _enable_distributed_mindrt
from mindspore.nn.metrics import Loss
from mindspore import nn
from mindspore.boost import AutoBoost
from mindspore.context import ParallelMode
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
from mindspore.common.api import _pynative_executor
from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
from . import amp
from ..common.api import _pynative_executor
from ..dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
def _transfer_tensor_to_tuple(inputs):

@@ -55,7 +55,7 @@ from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.train._utils import read_proto
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,

@@ -16,7 +16,8 @@
Summary related classes and functions. User can use SummaryRecord to dump the summary data, the summary is a series of
operations to collect data for analysis and visualization.
"""
from __future__ import absolute_import
from .summary_record import SummaryRecord
from mindspore.train.summary.summary_record import SummaryRecord
__all__ = ["SummaryRecord"]

@@ -13,9 +13,11 @@
# limitations under the License.
# ============================================================================
"""Generate the lineage event which conform to proto format."""
from __future__ import absolute_import
import time
from ..lineage_pb2 import LineageEvent
from mindspore.train.lineage_pb2 import LineageEvent
def serialize_to_lineage_event(name, value):

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Generate the summary event which conform to proto format."""
from __future__ import absolute_import
import io
import platform
import time
@@ -25,9 +27,9 @@ from mindspore import context
from mindspore.communication.management import get_rank
from mindspore.communication.management import GlobalComm
from ..._checkparam import Validator
from ..anf_ir_pb2 import DataType, ModelProto
from ..summary_pb2 import Event
from mindspore._checkparam import Validator
from mindspore.train.anf_ir_pb2 import DataType, ModelProto
from mindspore.train.summary_pb2 import Event
# define the MindSpore image format
MS_IMAGE_TENSOR_FORMAT = 'NCHW'

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Write events to disk in a base directory."""
from __future__ import absolute_import
import os
import time
import signal
@@ -24,9 +26,9 @@ import psutil
import mindspore.log as logger
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
from ._lineage_adapter import serialize_to_lineage_event
from ._summary_adapter import package_graph_event, package_summary_event
from .writer import LineageWriter, SummaryWriter, ExportWriter
from mindspore.train.summary._lineage_adapter import serialize_to_lineage_event
from mindspore.train.summary._summary_adapter import package_graph_event, package_summary_event
from mindspore.train.summary.writer import LineageWriter, SummaryWriter, ExportWriter
try:
from multiprocessing import get_context

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Summary's enumeration file."""
from __future__ import absolute_import
from enum import Enum

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Record the summary event."""
from __future__ import absolute_import
import atexit
import os
import re
@@ -23,13 +25,13 @@ from collections import defaultdict
from mindspore import log as logger
from mindspore.nn import Cell
from ..._c_expression import Tensor, security
from ..._checkparam import Validator
from ...common.api import _cell_graph_executor
from .._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
from ._summary_adapter import get_event_file_name, package_graph_event
from ._writer_pool import WriterPool
from .enums import PluginEnum
from mindspore._c_expression import Tensor, security
from mindspore._checkparam import Validator
from mindspore.common.api import _cell_graph_executor
from mindspore.train._utils import _check_lineage_value, _check_to_numpy, _make_directory, check_value_type
from mindspore.train.summary._summary_adapter import get_event_file_name, package_graph_event
from mindspore.train.summary._writer_pool import WriterPool
from mindspore.train.summary.enums import PluginEnum
# for the moment, this lock is for caution's sake,
# there are actually no any concurrences happening.

@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Writes events to disk in a logdir."""
from __future__ import absolute_import
import os
import stat
from urllib.parse import quote
@@ -23,11 +25,11 @@ from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
from mindspore.train.summary.enums import PluginEnum, WriterPluginEnum
from mindspore import log as logger
from .._utils import _make_directory
from ._summary_adapter import package_init_event
from ..._c_expression import security
from mindspore.train._utils import _make_directory
from mindspore.train.summary._summary_adapter import package_init_event
from mindspore._c_expression import security
if not security.enable_security():
from ..._c_expression import EventWriter_
from mindspore._c_expression import EventWriter_
FREE_DISK_SPACE_TIMES = 32

@@ -15,6 +15,6 @@
"""convert to second order related classes and functions."""
from __future__ import absolute_import
from .convert_utils import ConvertNetUtils, ConvertModelUtils
from mindspore.train.train_thor.convert_utils import ConvertNetUtils, ConvertModelUtils
__all__ = ["ConvertNetUtils", "ConvertModelUtils"]

@@ -32,7 +32,7 @@ from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset
from .dataset_helper import DatasetHelper
from mindspore.train.train_thor.dataset_helper import DatasetHelper
def _convert_to_ms_type(types):