forked from mindspore-Ecosystem/mindspore
amend assign
parent 943a9a020b
commit 8812c8fe9b
@@ -26,6 +26,7 @@
 namespace mindspore {
 namespace ops {
 
 AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
   MS_EXCEPTION_IF_NULL(primitive);

@@ -42,6 +43,27 @@ AbstractBasePtr InferImplAssign(const abstract::AnalysisEnginePtr &, const Primi
     return args_spec_list[1]->Broaden();
   }
   (void)CheckAndConvertUtils::CheckTensorTypeValid("variable", variable_type, check_types, prim_name);
+
+  auto variable_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(args_spec_list[0]->BuildShape())[kShape];
+  auto value_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(args_spec_list[1]->BuildShape())[kShape];
+  if (variable_shape.size() != value_shape.size()) {
+    if (variable_shape.size() == 1 && variable_shape[0] == 1 && value_shape.empty()) {
+      return args_spec_list[0];
+    } else if (value_shape.size() == 1 && value_shape[0] == 1 && variable_shape.empty()) {
+      return args_spec_list[0];
+    } else {
+      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the rank of value is " << value_shape.size()
+                               << ". It should be same with variable's rank " << variable_shape.size() << ".";
+    }
+  }
+  for (uint64_t i = 0; i < variable_shape.size(); i++) {
+    if (variable_shape[i] != value_shape[i]) {
+      MS_EXCEPTION(ValueError) << "For " << prim_name << ", the shape of value is "
+                               << args_spec_list[1]->BuildShape()->ToString()
+                               << ". It should be same with variable's shape "
+                               << args_spec_list[0]->BuildShape()->ToString() << ".";
+    }
+  }
   return args_spec_list[0];
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign, nullptr, true);

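The relaxed rank check added above treats a shape-(1,) variable and a 0-D value (or the reverse) as compatible, instead of raising ValueError; any other rank or per-dimension mismatch still raises. A minimal Python sketch of a call that the amended infer now accepts, assuming the standard mindspore.ops.Assign front end (the names below are illustrative, not part of the patch):

    # Illustrative only: rank mismatch between a shape-(1,) variable and a 0-D value
    # is now tolerated; other shape mismatches still raise ValueError.
    import mindspore as ms
    from mindspore import Parameter, Tensor, ops

    variable = Parameter(Tensor([1.0], ms.float32), name="variable")  # shape (1,)
    value = Tensor(2.0, ms.float32)                                   # shape ()
    assign = ops.Assign()
    print(assign(variable, value))
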
|
||||
|
|
|
@@ -26,6 +26,7 @@ from mindspore._checkparam import Rel
 from mindspore.nn.optim.optimizer import opt_init_args_register
 from .optimizer import Optimizer
 
 
 def _get_lr(step, RMS, learning_rate, relative_step, warmup_init, scale_parameter, eps):
     """update optimizer learning rete"""
     rel_step_sz = learning_rate

@@ -100,8 +101,8 @@ def _run_opt_with_one_number(eps, clip_threshold, decay_rate, beta1,
     exp_avg_sq_update = F.cast(exp_avg_sq, grad_dtype)
 
     if scale_lr:
-        RMS = _rms(p_data_fp32)
-        learning_rate_update = _get_lr(step, RMS, learning_rate, relative_step, warmup_init, scale_parameter, eps)
+        rms = _rms(p_data_fp32)
+        learning_rate_update = _get_lr(step, rms, learning_rate, relative_step, warmup_init, scale_parameter, eps)
         learning_rate_update = F.assign(learning_rate, F.cast(learning_rate_update, F.dtype(learning_rate)))
     else:
         learning_rate_update = learning_rate * 1.0

@@ -161,7 +162,8 @@ class AdaFactor(Optimizer):
     r"""
     Updates gradients by the Adaptive Learning Rates with Sublinear Memory Cost (Adafactor) algorithm.
 
-    The Adafactor algorithm is proposed in `Adafactor: Adafactor: Adaptive Learning Rates with Sublinear Memory Cost <https://arxiv.org/abs/1804.04235>`_.
+    The Adafactor algorithm is proposed in `Adafactor: Adafactor: Adaptive Learning Rates with Sublinear Memory
+    Cost <https://arxiv.org/abs/1804.04235>`_.
 
     .. warning::
         This is an experimental prototype that is subject to change and/or deletion.

@@ -37,7 +37,7 @@ def get_bprop_masked_select(self):
         dvalue = sum_op(dvalue)
         dinput = F.cast(dinput, F.dtype(input_data))
         if is_instance_op(value, mstype.number) is True:
-            dvalue = zeros_like(value)
+            dvalue = 0
         else:
            dvalue = F.cast(dvalue, F.dtype(value))
         return dinput, zeros_like(mask), dvalue

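This change and the Lerp change below follow the same convention: when an input is a plain Python number rather than a Tensor, its returned "gradient" is the constant 0 instead of zeros_like(...). A hedged sketch of that convention in a user-defined Cell.bprop (an assumed example, not code from this patch):

    # Illustrative only: alpha is a plain Python number, so its gradient slot gets 0.
    from mindspore import nn

    class ScaleByNumber(nn.Cell):
        def construct(self, x, alpha):
            return x * alpha

        def bprop(self, x, alpha, out, dout):
            dx = dout * alpha
            dalpha = 0  # not zeros_like(alpha), since alpha is not a Tensor
            return dx, dalpha
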
@@ -21,7 +21,6 @@ from .. import functional as F
 from .. import operations as P
 from .._grad.grad_base import bprop_getters
 from .._grad.grad_math_ops import binop_grad_common
-from ..composite.multitype_ops.zeros_like_impl import zeros_like
 
 @bprop_getters.register(P.Lerp)
 def get_bprop_index_lerp(self):

@@ -37,7 +36,7 @@ def get_bprop_index_lerp(self):
         dweight = mul_op(dout, sub_op(end, start))
         dstart, dend = binop_grad_common(start, end, dstart, dend)
         if is_instance_op(weight, mstype.number) is True:
-            dweight = zeros_like(weight)
+            dweight = 0
         else:
             _, dweight = binop_grad_common(start, weight, dstart, dweight)
             dweight = F.cast(dweight, F.dtype(weight))

@@ -44,7 +44,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
                    gradient, variable, moment):
     """ tensor_run_opt """
     success = True
-    new_weight = opt(variable, moment, learning_rate, gradient, momentum)[0]
+    new_weight = opt(variable, moment, learning_rate, gradient, momentum)
     success = F.depend(success, F.assign(variable, new_weight))
     return success