From 9409f832451e4fbd354947e7d3f553e3cee999eb Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Fri, 19 Jun 2020 16:23:37 +0800 Subject: [PATCH] fix params KeyError in group params --- mindspore/nn/optim/optimizer.py | 27 ++++++++++++++++++++++++--- mindspore/ops/operations/debug_ops.py | 6 ------ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 9bfc3a284b..999572e264 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -219,8 +219,32 @@ class Optimizer(Cell): raise TypeError("Learning rate should be float, Tensor or Iterable.") return lr + def _check_group_params(self, parameters): + """Check group params.""" + parse_keys = ['params', 'lr', 'weight_decay', 'order_params'] + for group_param in parameters: + invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys())) + if invalid_key: + raise KeyError(f'The key "{invalid_key}" cannot be recognized in group params.') + + if 'order_params' in group_param.keys(): + if len(group_param.keys()) > 1: + raise ValueError("The order params dict in group parameters should " + "only include the 'order_params' key.") + if not isinstance(group_param['order_params'], Iterable): + raise TypeError("The value of 'order_params' should be an Iterable type.") + continue + + if not group_param['params']: + raise ValueError("Optimizer got an empty group parameter list.") + + for param in group_param['params']: + if not isinstance(param, Parameter): + raise TypeError("The group param should be an iterator of Parameter type.") + def _parse_group_params(self, parameters, learning_rate): """Parse group params.""" + self._check_group_params(parameters) if self.dynamic_lr: dynamic_lr_length = learning_rate.size() else: @@ -250,9 +274,6 @@ class Optimizer(Cell): if dynamic_lr_length not in (lr_length, 0): raise ValueError("The dynamic learning rate in group should be the same size.") - if not group_param['params']: - raise ValueError("Optimizer got an empty group parameter list.") - dynamic_lr_length = lr_length self.dynamic_lr_length = dynamic_lr_length diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 91f56e0e19..ec6abd369f 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -309,12 +309,6 @@ class Print(PrimitiveWithInfer): Output tensor or string to stdout. Note: - The print operation cannot support the following cases currently. - - 1. The type of tensor is float64 or bool. - - 2. The data of tensor is a scalar type. - In pynative mode, please use python print function. Inputs: