diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index 070cad4fa1f..33c065100c6 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================ """Auto mixed precision.""" -from easydict import EasyDict as edict - from .. import nn from .._checkparam import Validator as validator from .._checkparam import Rel @@ -162,23 +160,22 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): _check_kwargs(kwargs) config = dict(_config_level[level], **kwargs) - config = edict(config) - if config.cast_model_type == mstype.float16: + if config["cast_model_type"] == mstype.float16: network.to_float(mstype.float16) - if config.keep_batchnorm_fp32: + if config["keep_batchnorm_fp32"]: _do_keep_batchnorm_fp32(network) if loss_fn: - network = _add_loss_network(network, loss_fn, config.cast_model_type) + network = _add_loss_network(network, loss_fn, config["cast_model_type"]) if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): network = _VirtualDatasetCell(network) loss_scale = 1.0 - if config.loss_scale_manager is not None: - loss_scale_manager = config.loss_scale_manager + if config["loss_scale_manager"] is not None: + loss_scale_manager = config["loss_scale_manager"] loss_scale = loss_scale_manager.get_loss_scale() update_cell = loss_scale_manager.get_update_cell() if update_cell is not None: diff --git a/requirements.txt b/requirements.txt index acded899f76..967b6fcaad7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ protobuf >= 3.8.0 asttokens >= 1.1.13 pillow >= 6.2.0 scipy >= 1.5.2 -easydict >= 1.9 sympy >= 1.4 cffi >= 1.12.3 wheel >= 0.32.0 @@ -17,4 +16,5 @@ astunparse >= 1.6.3 packaging >= 20.0 pycocotools >= 2.0.2 # for st test tables >= 3.6.1 # for st test -psutil >= 5.7.0 +easydict >= 1.9 # for st test +psutil >= 5.7.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 56c6d72eac2..9cd543c4545 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,6 @@ required_package = [ 'asttokens >= 1.1.13', 'pillow >= 6.2.0', 'scipy >= 1.5.2', - 'easydict >= 1.9', 'sympy >= 1.4', 'cffi >= 1.12.3', 'wheel >= 0.32.0',