amend assign

jiangzhenguang 2021-08-25 20:31:37 +08:00
parent 943a9a020b
commit 8812c8fe9b
5 changed files with 30 additions and 7 deletions

View File

@@ -26,6 +26,7 @@
namespace mindspore {
namespace ops {
AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
@@ -42,6 +43,27 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
return args_spec_list[1]->Broaden();
}
(void)CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name);
auto variable_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(args_spec_list[0]->BuildShape())[kShape];
auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(args_spec_list[1]->BuildShape())[kShape];
if (variable_shape.size() != value_shape.size()) {
if (variable_shape.size() == 1 && variable_shape[0] == 1 && value_shape.empty()) {
return args_spec_list[0];
} else if (value_shape.size() == 1 && value_shape[0] == 1 && variable_shape.empty()) {
return args_spec_list[0];
} else {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the rank of value is " << value_shape.size()
<< ". It should be same with variable's rank " << variable_shape.size() << ".";
}
}
for (uint64_t i = 0; i < variable_shape.size(); i++) {
if (variable_shape[i] != value_shape[i]) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the shape of value is "
<< args_spec_list[1]->BuildShape()->ToString()
<< ". It should be same with variable's shape "
<< args_spec_list[0]->BuildShape()->ToString() << ".";
}
}
return args_spec_list[0];
}
REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign, nullptr, true);
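The added check tolerates a rank mismatch only for the shape-[1] tensor vs. scalar pairing; any other rank or dimension mismatch now raises ValueError during shape inference. A minimal, framework-free Python sketch of that rule (the function and variable names are illustrative, not part of the commit):

def check_assign_shapes(variable_shape, value_shape, prim_name="Assign"):
    # Mirrors the added C++ logic: ranks must match, except [1] vs. scalar.
    if len(variable_shape) != len(value_shape):
        if variable_shape == [1] and value_shape == []:
            return variable_shape
        if value_shape == [1] and variable_shape == []:
            return variable_shape
        raise ValueError(f"For {prim_name}, the rank of value is {len(value_shape)}. "
                         f"It should be same with variable's rank {len(variable_shape)}.")
    # With equal ranks, every dimension must match exactly.
    for var_dim, val_dim in zip(variable_shape, value_shape):
        if var_dim != val_dim:
            raise ValueError(f"For {prim_name}, the shape of value is {value_shape}. "
                             f"It should be same with variable's shape {variable_shape}.")
    return variable_shape

print(check_assign_shapes([2, 3], [2, 3]))  # ok -> [2, 3]
print(check_assign_shapes([1], []))         # ok: shape-[1] variable, scalar value
# check_assign_shapes([2, 3], [3, 2])       # would raise ValueError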

View File

@@ -26,6 +26,7 @@ from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import opt_init_args_register
from .optimizer import Optimizer
def _get_lr(step, RMS, learning_rate, relative_step, warmup_init, scale_parameter, eps):
"""update optimizer learning rete"""
rel_step_sz = learning_rate
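Only the first line of _get_lr is visible in this hunk. For context, a sketch of the relative-step rule from the Adafactor paper (Shazeer & Stern, 2018) that this helper implements; the constants and keyword handling below are assumptions based on the paper, not copied from the file:

import math

def get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps):
    # Relative step size: rho_t = min(1e-2, 1/sqrt(t)), optionally with a linear warmup;
    # the parameter scale is max(eps, RMS(param)) when scale_parameter is enabled.
    rel_step_sz = learning_rate
    if relative_step:
        min_step = 1e-6 * step if warmup_init else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(step))
    param_scale = max(eps, rms) if scale_parameter else 1.0
    return param_scale * rel_step_sz

print(get_lr(step=100, rms=0.5, learning_rate=1e-3,
             relative_step=True, warmup_init=False, scale_parameter=True, eps=1e-3))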
@@ -100,8 +101,8 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)
if scale_lr:
RMS = _rms(p_data_fp32)
learning_rate_update = _get_lr(step, RMS, learning_rate, relative_step, warmup_init, scale_parameter, eps)
rms = _rms(p_data_fp32)
learning_rate_update = _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps)
learning_rate_update = F.assign(learning_rate, F.cast(learning_rate_update, F.dtype(learning_rate)))
else:
learning_rate_update = learning_rate * 1.0
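The new line stores the computed step size back into the learning_rate Parameter via F.assign, so later steps see the updated value. A tiny standalone sketch of that assign-and-cast pattern, using the same functional API the file imports as F (the Parameter and values below are made up for illustration):

import mindspore as ms
from mindspore import Parameter, Tensor
from mindspore.ops import functional as F

learning_rate = Parameter(Tensor(1e-3, ms.float32), name="learning_rate")
new_lr = Tensor(5e-4, ms.float32)
# F.assign writes the cast value into the Parameter and returns the updated tensor.
updated = F.assign(learning_rate, F.cast(new_lr, F.dtype(learning_rate)))
print(updated, learning_rate.asnumpy())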
@@ -161,7 +162,8 @@ class AdaFactor(Optimizer):
r"""
Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.
The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost <https://arxiv.org/abs/1804.04235>`_.
The Adafactor algorithm is proposed in `Adafactor: Adaptive Learning Rates with Sublinear Memory
Cost <https://arxiv.org/abs/1804.04235>`_.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.

View File

@@ -37,7 +37,7 @@ def get_bprop_masked_select(self):
dvalue = sum_op(dvalue)
dinput = F.cast(dinput, F.dtype(input_data))
if is_instance_op(value, mstype.number) is True:
dvalue = zeros_like(value)
dvalue = 0
else:
dvalue = F.cast(dvalue, F.dtype(value))
return dinput, zeros_like(mask), dvalue
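When value is a plain Python number, this bprop now returns the literal 0 as its gradient instead of zeros_like(value). A small numpy sketch of the masked-fill style backward rule this bprop appears to follow; it is a simplified stand-in, not the MindSpore implementation:

import numpy as np

def masked_fill_backward(dout, mask, value_is_number):
    # Gradient w.r.t. the input: dout flows only where the mask did NOT overwrite.
    dinput = np.where(mask, 0.0, dout)
    # Gradient w.r.t. the fill value: sum of dout over the masked positions.
    dvalue = dout[mask].sum()
    if value_is_number:
        # A plain Python number has no tensor gradient; return the scalar 0,
        # which is what the changed line above does instead of zeros_like(value).
        dvalue = 0
    return dinput, np.zeros_like(mask), dvalue

dout = np.ones((2, 2), np.float32)
mask = np.array([[True, False], [False, True]])
print(masked_fill_backward(dout, mask, value_is_number=True))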

View File

@@ -21,7 +21,6 @@ from .. import functional as F
from .. import operations as P
from .._grad.grad_base import bprop_getters
from .._grad.grad_math_ops import binop_grad_common
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@bprop_getters.register(P.Lerp)
def get_bprop_index_lerp(self):
@@ -37,7 +36,7 @@ def get_bprop_index_lerp(self):
dweight = mul_op(dout, sub_op(end, start))
dstart, dend = binop_grad_common(start, end, dstart, dend)
if is_instance_op(weight, mstype.number) is True:
dweight = zeros_like(weight)
dweight = 0
else:
_, dweight = binop_grad_common(start, weight, dstart, dweight)
dweight = F.cast(dweight, F.dtype(weight))
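The Lerp bprop gets the same treatment: a plain-number weight now yields the gradient 0 rather than zeros_like(weight). For reference, a numpy sketch of the lerp gradients, where out = start + weight * (end - start); the names are illustrative only:

import numpy as np

def lerp_backward(start, end, weight, dout):
    # out = start + weight * (end - start)
    dstart = dout * (1.0 - weight)
    dend = dout * weight
    if isinstance(weight, (int, float)):
        dweight = 0                      # plain-number weight: scalar zero gradient
    else:
        dweight = dout * (end - start)   # tensor weight: elementwise gradient
    return dstart, dend, dweight

start, end = np.zeros(3, np.float32), np.ones(3, np.float32)
print(lerp_backward(start, end, 0.3, np.ones(3, np.float32)))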

View File

@@ -44,7 +44,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
gradient, variable, moment):
""" tensor_run_opt """
success = True
new_weight = opt(variable, moment, learning_rate, gradient, momentum)[0]
new_weight = opt(variable, moment, learning_rate, gradient, momentum)
success = F.depend(success, F.assign(variable, new_weight))
return success
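The test now uses the optimizer output directly: the change implies ApplyMomentum returns the updated variable as a single tensor, so the old [0] indexing is dropped. A minimal standalone sketch of the same update pattern; the single-tensor return is an assumption drawn from this change, and the tensors below are made up for illustration:

import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor
from mindspore.ops import functional as F
from mindspore.ops import operations as P

opt = P.ApplyMomentum()
variable = Parameter(Tensor(np.ones(3, np.float32)), name="variable")
moment = Parameter(Tensor(np.zeros(3, np.float32)), name="moment")
learning_rate = Tensor(0.1, ms.float32)
gradient = Tensor(np.ones(3, np.float32))
momentum = Tensor(0.9, ms.float32)

# ApplyMomentum updates `variable` and (assumed here) returns it as one tensor.
new_weight = opt(variable, moment, learning_rate, gradient, momentum)
# Re-assigning plus depend keeps the write explicit and ordered, as in the test.
success = F.depend(True, F.assign(variable, new_weight))
print(success, variable.asnumpy())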