forked from mindspore-Ecosystem/mindspore
!5734 fix bug of adam optimizer
Merge pull request !5734 from fary86/fix_adam_opt_bug
commit 8c249c1805
@@ -219,7 +219,7 @@ class _Context:
         self.set_param(ms_ctx_param.profiling_options, option)
 
     def set_variable_memory_max_size(self, variable_memory_max_size):
-        if not check_input_format(variable_memory_max_size):
+        if not _check_input_format(variable_memory_max_size):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
         if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
             raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@@ -230,7 +230,7 @@ class _Context:
         self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
 
     def set_max_device_memory(self, max_device_memory):
-        if not check_input_format(max_device_memory):
+        if not _check_input_format(max_device_memory):
             raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
         max_device_memory_value = float(max_device_memory[:-2])
         if max_device_memory_value == 0:
@@ -289,7 +289,7 @@ class _Context:
         thread_info.debug_runtime = enable
 
 
-def check_input_format(x):
+def _check_input_format(x):
     import re
     pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
     result = re.match(pattern, x)
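Note (not part of the diff): the three hunks above rename the module-level helper check_input_format to _check_input_format and update its two call sites in _Context, making the validator clearly private to the context module. A minimal standalone sketch of what that helper does follows; the regex is copied from the hunk, but the return convention and the demo values are assumptions, since the hunk is truncated right after the re.match call.

import re

# Pattern copied from the hunk above; it accepts sizes such as "5GB", "3.5GB" or "0.5GB".
_PATTERN = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'

def _check_input_format(x):
    # Assumed return convention: treating "no match" as invalid is an inference,
    # not taken from the commit (the hunk stops at the re.match line).
    return re.match(_PATTERN, x) is not None

for value in ("5GB", "3.5GB", "0.5GB", "31MB", "GB"):
    print(value, _check_input_format(value))   # True, True, True, False, False

Note that re.match only anchors at the start of the string, so a trailing-garbage value such as "5GBx" would also pass this sketch.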
@@ -51,6 +51,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
     Returns:
         Tensor, the new value of v after updating.
     """
+    success = True
     if optim_filter:
         op_mul = P.Mul()
         op_square = P.Square()
@@ -80,8 +81,8 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
         next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
         next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
         next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
-        return next_param
-    return gradient
+        success = F.depend(success, next_param)
+    return success
 
 
 @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
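Note (not part of the diff): in the optimizer hunks, success = True is introduced before the optim_filter branch, the update branch folds next_param into success via F.depend, and both paths now return success, so every parameter slot yields a value of the same type whether or not it passes optim_filter, which appears to be the adam bug this commit fixes. For orientation only, a NumPy sketch of the AdamWeightDecay step that the elided middle of _update_run_op computes; it is written from the standard formula, not extracted from this commit, and the function name and decay_flag handling are illustrative assumptions.

import numpy as np

def adam_weight_decay_step(param, m, v, gradient, lr, beta1, beta2, eps,
                           weight_decay, decay_flag=True):
    # Standard Adam moment updates.
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * np.square(gradient)
    update = next_m / (np.sqrt(next_v) + eps)
    if decay_flag:
        # Decoupled weight decay is added to the update before learning-rate scaling.
        update = update + weight_decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v

In the graph code above, the F.assign calls write next_param, next_m and next_v back into the Parameter buffers, and F.depend only orders those writes before the returned success value.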