!40248 synchronize the code of the reverse_sequence operator in r1.8

Merge pull request !40248 from wtcheng/master
This commit is contained in:
i-robot 2022-08-11 03:31:23 +00:00 committed by Gitee
commit 6745857a39
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
15 changed files with 374 additions and 30 deletions

View File

@ -399,6 +399,7 @@ Array操作
mindspore.ops.rank
mindspore.ops.repeat_elements
mindspore.ops.reshape
mindspore.ops.reverse_sequence
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.sequence_mask

View File

@ -0,0 +1,18 @@
mindspore.Tensor.reverse_sequence
=======================
.. py:method:: reverse_sequence(seq_lengths, seq_dim, batch_dim=0)
对输入序列进行部分反转。
参数:
- **seq_lengths** (Tensor) - 指定反转长度为一维向量其数据类型为int32或int64。
- **seq_dim** (int) - 指定反转的维度,此值为必填参数。
- **batch_dim** (int) - 指定切片维度。默认值0。
返回:
Tensorshape和数据类型与输入相同。
异常:
- **TypeError** - `seq_dim``batch_dim` 不是int。
- **ValueError** - `batch_dim` 大于或等于 `x` 的shape长度。

View File

@ -182,6 +182,7 @@ Array操作
mindspore.Tensor.repeat
mindspore.Tensor.reshape
mindspore.Tensor.resize
mindspore.Tensor.reverse_sequence
mindspore.Tensor.scatter_add
mindspore.Tensor.scatter_div
mindspore.Tensor.scatter_max

View File

@ -0,0 +1,19 @@
mindspore.ops.reverse_sequence
==============================
.. py:function:: mindspore.ops.reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0)
对输入序列进行部分反转。
参数:
- **x** (Tensor) - 输入需反转的数据其数据类型支持包括bool在内的所有数值型。
- **seq_lengths** (Tensor) - 指定反转长度为一维向量其数据类型为int32或int64。
- **seq_dim** (int) - 指定反转的维度,此值为必填参数。
- **batch_dim** (int) - 指定切片维度。默认值0。
返回:
Tensorshape和数据类型与输入相同。
异常:
- **TypeError** - `seq_dim``batch_dim` 不是int。
- **ValueError** - `batch_dim` 大于或等于 `x` 的shape长度。

View File

@ -401,6 +401,7 @@ Array Operation
mindspore.ops.rank
mindspore.ops.repeat_elements
mindspore.ops.reshape
mindspore.ops.reverse_sequence
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.sequence_mask

View File

@ -192,6 +192,7 @@ BuiltInTypeMap &GetMethodMap() {
{"transpose", std::string("transpose")}, // P.transpose
{"flatten", std::string("flatten")}, // P.reshape(,-1)
{"reshape", std::string("reshape")}, // P.reshape()
{"reverse_sequence", std::string("reverse_sequence")}, // P.ReverseSequence()
{"bitwise_and", std::string("bitwise_and")}, // P.BitwiseAnd()
{"bitwise_or", std::string("bitwise_or")}, // P.BitwiseOr()
{"bitwise_xor", std::string("bitwise_xor")}, // P.BitwiseXor()

View File

@ -40,6 +40,71 @@ int64_t ReverseSequence::get_batch_dim() const {
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameReverseSequence, ReverseSequence);
namespace {
abstract::ShapePtr ReverseSequenceInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReverseSequence", input_args, 0);
MS_EXCEPTION_IF_NULL(x_shape_ptr);
auto seq_lengths_shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReverseSequence", input_args, 1);
MS_EXCEPTION_IF_NULL(seq_lengths_shape_ptr);
auto x_shape = x_shape_ptr->shape();
auto seq_lengths_shape = seq_lengths_shape_ptr->shape();
auto seq_dim_ptr = primitive->GetAttr("seq_dim");
MS_EXCEPTION_IF_NULL(seq_dim_ptr);
auto seq_dim = GetValue<int64_t>(seq_dim_ptr);
auto batch_dim_ptr = primitive->GetAttr("batch_dim");
MS_EXCEPTION_IF_NULL(batch_dim_ptr);
auto batch_dim = GetValue<int64_t>(batch_dim_ptr);
if (seq_dim >= SizeToLong(x_shape.size())) {
MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'seq_dim' should be < x rank: " << x_shape.size()
<< ", but got " << seq_dim << ".";
}
if (batch_dim >= SizeToLong(x_shape.size())) {
MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'batch_dim' should be < x rank: " << x_shape.size()
<< ", but got " << batch_dim << ".";
}
if (batch_dim == seq_dim) {
MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'batch_dim' should be != 'seq_dim': " << seq_dim
<< ", but got " << batch_dim << ".";
}
if (seq_lengths_shape.size() != 1) {
MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'seq_lengths' rank should be = 'expected': 1 , but got "
<< seq_lengths_shape.size() << ".";
}
if (seq_lengths_shape[0] != x_shape[batch_dim]) {
MS_EXCEPTION(ValueError)
<< "For 'ReverseSequence', the 'seq_lengths' vector size should be = input size along batch_dim: "
<< x_shape[batch_dim] << ", but got " << seq_lengths_shape[0] << ".";
}
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr ReverseSequenceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
MS_LOG(EXCEPTION) << "For '" << prim->name()
<< ", the input args used for infer shape and type is necessary, but missing it.";
}
const std::set<TypePtr> seq_lengths_valid_types = {kInt32, kInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("seq_lengths", input_args[1]->BuildType(), seq_lengths_valid_types,
prim->name());
const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kUInt16, kUInt32, kUInt64,
kInt8, kInt16, kInt32, kInt64, kComplex64, kComplex128, kBool};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), x_valid_types, prim->name());
}
} // namespace
AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = ReverseSequenceInferType(primitive, input_args);
auto infer_shape = ReverseSequenceInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReverseSequence, prim::kPrimReverseSequence, ReverseSequenceInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -371,6 +371,65 @@ def reshape(x, *shape):
return F.reshape(x, new_shape)
def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
"""
Reverses variable length slices.
Args:
x (Tensor): The input to reverse, supporting all number types including bool.
seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
seq_dim (int): The dimension where reversal is performed. Required.
batch_dim (int): The input is sliced in this dimension. Default: 0.
Returns:
Reversed tensor with the same shape and data type as input.
Raises:
TypeError: If `seq_dim` or `batch_dim` is not an int.
ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[1. 2. 3.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=0, batch_dim=1)
>>> print(output)
[[1. 5. 9.]
[4. 2. 6.]
[7. 8. 3.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([2, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[2. 1. 3.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([3, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[3. 2. 1.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([4, 4]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[4. 3. 2. 1.]
[8. 7. 6. 5.]]
"""
return F.reverse_sequence(x, seq_lengths, seq_dim, batch_dim)
def ravel(x):
"""
Return a contiguous flattened tensor.

View File

@ -1896,6 +1896,64 @@ class Tensor(Tensor_):
new_shape = validator.check_reshape_shp(shape)
return tensor_operator_registry.get('reshape')()(self, new_shape)
def reverse_sequence(self, seq_lengths, seq_dim, batch_dim=0):
"""
Reverses variable length slices.
Args:
seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
seq_dim (int): The dimension where reversal is performed. Required.
batch_dim (int): The input is sliced in this dimension. Default: 0.
Returns:
Reversed tensor with the same shape and data type as input.
Raises:
TypeError: If `seq_dim` or `batch_dim` is not an int.
ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[1. 2. 3.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=0, batch_dim=1)
>>> print(output)
[[1. 5. 9.]
[4. 2. 6.]
[7. 8. 3.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([2, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[2. 1. 3.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([3, 2, 3]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[3. 2. 1.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([4, 4]))
>>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
>>> print(output)
[[4. 3. 2. 1.]
[8. 7. 6. 5.]]
"""
self._init_check()
return tensor_operator_registry.get('reverse_sequence')(seq_dim, batch_dim)(self, seq_lengths)
def ravel(self):
"""
Return a contiguous flattened tensor.

View File

@ -435,6 +435,65 @@ def get_reshape_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(P.ReverseSequence)
def get_reverse_sequence_vmap_rule(prim, axis_size):
"""VmapRule for `ReverseSequence` operation."""
if isinstance(prim, str):
prim = Primitive(prim)
reshape = P.Reshape()
batch_dim = prim.batch_dim_
seq_dim = prim.seq_dim_
@constexpr
def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
if seq_dim_ == dim:
seq_dim_ += 1
if seq_dim_ == batch_dim_:
batch_dim_ += 1
elif batch_dim_ == dim:
batch_dim_ += 1
if seq_dim_ == batch_dim_:
seq_dim_ += 1
return batch_dim_, seq_dim_
@constexpr
def get_seq_dim(dim, batch_dim_, seq_dim_):
if seq_dim_ < dim and seq_dim_ < batch_dim_:
seq_dim_ = seq_dim_ + 1
elif seq_dim_ > dim and seq_dim_ > batch_dim_:
seq_dim_ = seq_dim_ - 1
else:
seq_dim_ = seq_dim_
return seq_dim_
def vmap_rule(x_bdim, seq_lengths_bdim):
is_all_none, result = vmap_general_preprocess(prim, x_bdim, seq_lengths_bdim)
if is_all_none:
return result
x, dim = x_bdim
seq_lengths, seq_lengths_dim = seq_lengths_bdim
seq_lengths = mnp.moveaxis(seq_lengths, seq_lengths_dim, 0)
origin_shape = x.shape
batch_dim_ = batch_dim
seq_dim_ = seq_dim
batch_dim_, seq_dim_ = get_batch_seq_dim(dim, batch_dim_, seq_dim_)
x = mnp.moveaxis(x, [dim, batch_dim_], [0, 1])
shape = x.shape
shape = (shape[0] * shape[1],) + tuple(_ for _ in shape[2:])
x = reshape(x, shape)
seq_dim_ = get_seq_dim(dim, batch_dim_, seq_dim_)
seq_lengths = reshape(seq_lengths, (-1,))
x = P.ReverseSequence(seq_dim=seq_dim_)(x, seq_lengths)
shape = x.shape
shape = (origin_shape[dim], origin_shape[batch_dim_],) + tuple(_ for _ in shape[1:])
out = reshape(x, shape)
if batch_dim_ not in (0, 1):
out = mnp.moveaxis(out, 1, batch_dim_)
return out, 0
return vmap_rule
@vmap_rules_getters.register(P.Flatten)
def get_flatten_vmap_rule(prim, axis_size):
"""VmapRule for `Flatten` operation."""

View File

@ -45,6 +45,7 @@ from .array_func import (
rank,
reshape,
reshape_,
reverse_sequence,
flatten,
concat,
stack,

View File

@ -859,6 +859,65 @@ def reshape(input_x, input_shape):
return reshape_(input_x, input_shape)
def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
"""
Reverses variable length slices.
Args:
x (Tensor): The input to reverse, supporting all number types including bool.
seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
seq_dim (int): The dimension where reversal is performed. Required.
batch_dim (int): The input is sliced in this dimension. Default: 0.
Returns:
Reversed tensor with the same shape and data type as input.
Raises:
TypeError: If `seq_dim` or `batch_dim` is not an int.
ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
>>> print(output)
[[1. 2. 3.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([1, 2, 3]))
>>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=0, batch_dim=1)
>>> print(output)
[[1. 5. 9.]
[4. 2. 6.]
[7. 8. 3.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([2, 2, 3]))
>>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
>>> print(output)
[[2. 1. 3.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([3, 2, 3]))
>>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
>>> print(output)
[[3. 2. 1.]
[5. 4. 6.]
[9. 8. 7.]]
>>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
>>> seq_lengths = Tensor(np.array([4, 4]))
>>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
>>> print(output)
[[4. 3. 2. 1.]
[8. 7. 6. 5.]]
"""
return P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim)(x, seq_lengths)
def flatten(input_x):
r"""
Flattens a tensor without changing its batch size on the 0-th axis.

View File

@ -354,6 +354,7 @@ tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('prod', prod)
tensor_operator_registry.register('round', P.Round)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('reverse_sequence', P.ReverseSequence)
tensor_operator_registry.register('xlogy', P.Xlogy)
tensor_operator_registry.register('flatten', P.Flatten)
tensor_operator_registry.register('transpose', P.Transpose)

View File

@ -5791,7 +5791,7 @@ class ReverseSequence(PrimitiveWithInfer):
ValueError: If value of `batch_dim` is equal to or greater than length of shape of `x` .
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
@ -5844,20 +5844,6 @@ class ReverseSequence(PrimitiveWithInfer):
validator.check_value_type("batch_dim", batch_dim, [int], self.name)
self.batch_dim_ = batch_dim
def infer_shape(self, x, seq_lengths):
validator.check_int_range(self.seq_dim_, 0, len(x), Rel.INC_LEFT, "seq_dim", self.name)
validator.check_int_range(self.batch_dim_, 0, len(x), Rel.INC_LEFT, "batch_dim", self.name)
validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
validator.check("seq_lengths vector size", seq_lengths[0],
"input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
return x
def infer_dtype(self, x, seq_lengths):
validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
return x
class EditDistance(Primitive):
r"""

View File

@ -23,6 +23,7 @@ from mindspore.common import Parameter, ParameterTuple
from mindspore.train.callback import Callback
from mindspore.ops import operations as P
from mindspore._checkparam import Validator, Rel
from mindspore import log as logger
class _StartFLJob(nn.Cell):
@ -163,8 +164,11 @@ class FederatedLearningManager(Callback):
self._abs_grads_ema = dict()
if self._is_adaptive_sync():
self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
if self._as_prefix not in _.name}
self._last_param = {
_.name: deepcopy(_.asnumpy())
for _ in self._model.trainable_params()
if self._as_prefix not in _.name
}
for param in self._model.trainable_params():
if self._as_prefix not in param.name:
self._model_size += np.product(param.shape)
@ -201,13 +205,13 @@ class FederatedLearningManager(Callback):
try:
abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
except KeyError:
print("{} is not in self._last_param".format(param.name))
logger.warning("{} is not in self._last_param".format(param.name))
for param in self._model.trainable_params():
if self._as_prefix in param.name:
try:
param.set_data(Parameter(abs_grads[param.name]))
except KeyError:
print("{} is not in abs_grads".format(param.name))
logger.warning("{} is not in abs_grads".format(param.name))
def _as_analyze_gradient(self):
"""
@ -225,32 +229,40 @@ class FederatedLearningManager(Callback):
try:
grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
except KeyError:
print("{} is not in self._last_param".format(param.name))
logger.warning("{} is not in self._last_param".format(param.name))
for last_p in self._last_param:
try:
self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
except KeyError:
print("{} is not in self._grads_ema".format(last_p))
logger.warning("{} is not in self._grads_ema".format(last_p))
continue
try:
self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[
last_p]
except KeyError:
print("{} is not in self._abs_grads_ema".format(last_p))
logger.warning("{} is not in self._abs_grads_ema".format(last_p))
continue
try:
divide_base = np.where(self._abs_grads_ema[last_p] == 0,
np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
except KeyError:
print("{} is not in self._abs_grads_ema".format(last_p))
logger.warning("{} is not in self._abs_grads_ema".format(last_p))
continue
try:
layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
if divide_base != 0:
layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
else:
logger.warning("You cannot divide by 0!")
return
consistent_rate_sum += np.sum(layer_consistent_rate)
except KeyError:
print("{} is not in self._grads_ema".format(last_p))
logger.warning("{} is not in self._grads_ema".format(last_p))
consistent_rate = float(consistent_rate_sum / self._model_size)
if self._model_size != 0:
consistent_rate = float(consistent_rate_sum / self._model_size)
else:
logger.warning("You cannot divide by 0!")
return
if self._min_consistent_rate > consistent_rate:
self._min_consistent_rate = consistent_rate
@ -272,8 +284,11 @@ class FederatedLearningManager(Callback):
"""
Set the value of last parameters for adaptive synchronization.
"""
self._last_param = {_.name: deepcopy(_.asnumpy())
for _ in self._model.trainable_params() if self._as_prefix not in _.name}
self._last_param = {
_.name: deepcopy(_.asnumpy())
for _ in self._model.trainable_params()
if self._as_prefix not in _.name
}
def step_end(self, run_context):
"""
@ -309,4 +324,4 @@ class FederatedLearningManager(Callback):
self._round_id += 1
self._as_set_last_param()
print("sync step is: {}".format(self._global_step))
logger.info("sync step is: {}".format(self._global_step))