!5734 fix bug of adam optimizer

Merge pull request !5734 from fary86/fix_adam_opt_bug
This commit is contained in:
mindspore-ci-bot 2020-09-04 10:14:42 +08:00 committed by Gitee
commit 8c249c1805
2 changed files with 6 additions and 5 deletions

View File

@@ -219,7 +219,7 @@ class _Context:
self.set_param(ms_ctx_param.profiling_options, option)
def set_variable_memory_max_size(self, variable_memory_max_size):
if not check_input_format(variable_memory_max_size):
if not _check_input_format(variable_memory_max_size):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
@@ -230,7 +230,7 @@ class _Context:
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
def set_max_device_memory(self, max_device_memory):
if not check_input_format(max_device_memory):
if not _check_input_format(max_device_memory):
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
max_device_memory_value = float(max_device_memory[:-2])
if max_device_memory_value == 0:
@@ -289,7 +289,7 @@ class _Context:
thread_info.debug_runtime = enable
def check_input_format(x):
def _check_input_format(x):
import re
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
result = re.match(pattern, x)

View File

@@ -51,6 +51,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
Returns:
Tensor, the new value of v after updating.
"""
success = True
if optim_filter:
op_mul = P.Mul()
op_square = P.Square()
@@ -80,8 +81,8 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_param
return gradient
success = F.depend(success, next_param)
return success
@_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",