!40248 synchronize the code of the reverse_sequence operator in r1.8
Merge pull request !40248 from wtcheng/master
commit 6745857a39

@@ -399,6 +399,7 @@ Array Operation
mindspore.ops.rank
mindspore.ops.repeat_elements
mindspore.ops.reshape
mindspore.ops.reverse_sequence
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.sequence_mask

@@ -0,0 +1,18 @@
mindspore.Tensor.reverse_sequence
=================================

.. py:method:: reverse_sequence(seq_lengths, seq_dim, batch_dim=0)

    Partially reverses the input sequence.

    Args:
        - **seq_lengths** (Tensor) - Specifies the reversal lengths. Must be a 1-D vector with int32 or int64 data type.
        - **seq_dim** (int) - The dimension along which reversal is performed. Required.
        - **batch_dim** (int) - The dimension along which slicing is performed. Default: 0.

    Returns:
        Tensor, with the same shape and data type as the input.

    Raises:
        - **TypeError** - If `seq_dim` or `batch_dim` is not an int.
        - **ValueError** - If `batch_dim` is greater than or equal to the length of the shape of `x`.

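A minimal usage sketch of the Tensor method documented above, mirroring the examples added elsewhere in this commit (shown here for illustration only):

import numpy as np
import mindspore
from mindspore import Tensor

x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
seq_lengths = Tensor(np.array([1, 2, 3]))
# Reverse the first seq_lengths[i] elements of row i along seq_dim=1.
output = x.reverse_sequence(seq_lengths, seq_dim=1)
print(output)
# [[1. 2. 3.]
#  [5. 4. 6.]
#  [9. 8. 7.]]
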
@@ -182,6 +182,7 @@ Array Operation
mindspore.Tensor.repeat
mindspore.Tensor.reshape
mindspore.Tensor.resize
mindspore.Tensor.reverse_sequence
mindspore.Tensor.scatter_add
mindspore.Tensor.scatter_div
mindspore.Tensor.scatter_max

@@ -0,0 +1,19 @@
mindspore.ops.reverse_sequence
==============================

.. py:function:: mindspore.ops.reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0)

    Partially reverses the input sequence.

    Args:
        - **x** (Tensor) - The input to reverse. All numeric types, including bool, are supported.
        - **seq_lengths** (Tensor) - Specifies the reversal lengths. Must be a 1-D vector with int32 or int64 data type.
        - **seq_dim** (int) - The dimension along which reversal is performed. Required.
        - **batch_dim** (int) - The dimension along which slicing is performed. Default: 0.

    Returns:
        Tensor, with the same shape and data type as the input.

    Raises:
        - **TypeError** - If `seq_dim` or `batch_dim` is not an int.
        - **ValueError** - If `batch_dim` is greater than or equal to the length of the shape of `x`.

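For the functional form documented above, a short sketch taken from the examples added later in this commit (illustrative only):

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
seq_lengths = Tensor(np.array([1, 2, 3]))
# Slice along batch_dim=1 (columns) and reverse along seq_dim=0 (rows).
output = ops.reverse_sequence(x, seq_lengths, seq_dim=0, batch_dim=1)
print(output)
# [[1. 5. 9.]
#  [4. 2. 6.]
#  [7. 8. 3.]]
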
@@ -401,6 +401,7 @@ Array Operation
mindspore.ops.rank
mindspore.ops.repeat_elements
mindspore.ops.reshape
mindspore.ops.reverse_sequence
mindspore.ops.scatter_nd
mindspore.ops.select
mindspore.ops.sequence_mask

@@ -192,6 +192,7 @@ BuiltInTypeMap &GetMethodMap() {
{"transpose", std::string("transpose")},  // P.transpose
{"flatten", std::string("flatten")},  // P.reshape(,-1)
{"reshape", std::string("reshape")},  // P.reshape()
{"reverse_sequence", std::string("reverse_sequence")},  // P.ReverseSequence()
{"bitwise_and", std::string("bitwise_and")},  // P.BitwiseAnd()
{"bitwise_or", std::string("bitwise_or")},  // P.BitwiseOr()
{"bitwise_xor", std::string("bitwise_xor")},  // P.BitwiseXor()

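The new method-map entry is what lets the Tensor method resolve inside graph mode. A sketch of what that enables, under the assumption that the r1.8 API in this commit is installed (the Cell and variable names here are illustrative, not part of the commit):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

class ReverseNet(nn.Cell):
    """Calls the Tensor method that the map entry above makes resolvable in construct()."""
    def construct(self, x, seq_lengths):
        return x.reverse_sequence(seq_lengths, 1)  # seq_dim=1

net = ReverseNet()
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), ms.float32)
seq_lengths = Tensor(np.array([2, 3]), ms.int32)
print(net(x, seq_lengths))
# [[2. 1. 3.]
#  [6. 5. 4.]]
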
@@ -40,6 +40,71 @@ int64_t ReverseSequence::get_batch_dim() const {
  return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameReverseSequence, ReverseSequence);
namespace {
abstract::ShapePtr ReverseSequenceInferShape(const PrimitivePtr &primitive,
                                             const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReverseSequence", input_args, 0);
  MS_EXCEPTION_IF_NULL(x_shape_ptr);
  auto seq_lengths_shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReverseSequence", input_args, 1);
  MS_EXCEPTION_IF_NULL(seq_lengths_shape_ptr);
  auto x_shape = x_shape_ptr->shape();
  auto seq_lengths_shape = seq_lengths_shape_ptr->shape();

  auto seq_dim_ptr = primitive->GetAttr("seq_dim");
  MS_EXCEPTION_IF_NULL(seq_dim_ptr);
  auto seq_dim = GetValue<int64_t>(seq_dim_ptr);
  auto batch_dim_ptr = primitive->GetAttr("batch_dim");
  MS_EXCEPTION_IF_NULL(batch_dim_ptr);
  auto batch_dim = GetValue<int64_t>(batch_dim_ptr);

  if (seq_dim >= SizeToLong(x_shape.size())) {
    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'seq_dim' should be < x rank: " << x_shape.size()
                             << ", but got " << seq_dim << ".";
  }
  if (batch_dim >= SizeToLong(x_shape.size())) {
    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'batch_dim' should be < x rank: " << x_shape.size()
                             << ", but got " << batch_dim << ".";
  }
  if (batch_dim == seq_dim) {
    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'batch_dim' should be != 'seq_dim': " << seq_dim
                             << ", but got " << batch_dim << ".";
  }
  if (seq_lengths_shape.size() != 1) {
    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'seq_lengths' rank should be equal to 1, but got "
                             << seq_lengths_shape.size() << ".";
  }
  if (seq_lengths_shape[0] != x_shape[batch_dim]) {
    MS_EXCEPTION(ValueError)
      << "For 'ReverseSequence', the 'seq_lengths' vector size should be = input size along batch_dim: "
      << x_shape[batch_dim] << ", but got " << seq_lengths_shape[0] << ".";
  }
  return std::make_shared<abstract::Shape>(x_shape);
}

TypePtr ReverseSequenceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
    MS_LOG(EXCEPTION) << "For '" << prim->name()
                      << "', the input args are necessary for inferring shape and type, but they are missing.";
  }
  const std::set<TypePtr> seq_lengths_valid_types = {kInt32, kInt64};
  (void)CheckAndConvertUtils::CheckTensorTypeValid("seq_lengths", input_args[1]->BuildType(), seq_lengths_valid_types,
                                                   prim->name());

  const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kUInt16, kUInt32, kUInt64,
                                           kInt8, kInt16, kInt32, kInt64, kComplex64, kComplex128, kBool};
  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), x_valid_types, prim->name());
}
}  // namespace
AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = ReverseSequenceInferType(primitive, input_args);
  auto infer_shape = ReverseSequenceInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ReverseSequence, prim::kPrimReverseSequence, ReverseSequenceInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore

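A compact restatement of the constraints the C++ infer functions above enforce, written as a NumPy-level check for illustration (the function name is hypothetical, not part of the commit):

import numpy as np

def check_reverse_sequence_args(x, seq_lengths, seq_dim, batch_dim=0):
    # Mirrors ReverseSequenceInferShape: both dims must index into x, must differ,
    # and seq_lengths must be 1-D with one entry per slice along batch_dim.
    if seq_dim >= x.ndim:
        raise ValueError(f"'seq_dim' should be < x rank: {x.ndim}, but got {seq_dim}.")
    if batch_dim >= x.ndim:
        raise ValueError(f"'batch_dim' should be < x rank: {x.ndim}, but got {batch_dim}.")
    if batch_dim == seq_dim:
        raise ValueError(f"'batch_dim' should be != 'seq_dim': {seq_dim}, but got {batch_dim}.")
    if seq_lengths.ndim != 1:
        raise ValueError(f"'seq_lengths' rank should be equal to 1, but got {seq_lengths.ndim}.")
    if seq_lengths.shape[0] != x.shape[batch_dim]:
        raise ValueError("'seq_lengths' vector size should equal the input size along batch_dim: "
                         f"{x.shape[batch_dim]}, but got {seq_lengths.shape[0]}.")
    return x.shape  # the output shape equals the input shape
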
@@ -371,6 +371,65 @@ def reshape(x, *shape):
    return F.reshape(x, new_shape)


def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
    """
    Reverses variable length slices.

    Args:
        x (Tensor): The input to reverse, supporting all number types including bool.
        seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
        seq_dim (int): The dimension where reversal is performed. Required.
        batch_dim (int): The input is sliced in this dimension. Default: 0.

    Returns:
        Reversed tensor with the same shape and data type as input.

    Raises:
        TypeError: If `seq_dim` or `batch_dim` is not an int.
        ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
        >>> print(output)
        [[1. 2. 3.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> output = x.reverse_sequence(seq_lengths, seq_dim=0, batch_dim=1)
        >>> print(output)
        [[1. 5. 9.]
         [4. 2. 6.]
         [7. 8. 3.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([2, 2, 3]))
        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
        >>> print(output)
        [[2. 1. 3.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([3, 2, 3]))
        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
        >>> print(output)
        [[3. 2. 1.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([4, 4]))
        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
        >>> print(output)
        [[4. 3. 2. 1.]
         [8. 7. 6. 5.]]
    """
    return F.reverse_sequence(x, seq_lengths, seq_dim, batch_dim)


def ravel(x):
    """
    Return a contiguous flattened tensor.

@@ -1896,6 +1896,64 @@ class Tensor(Tensor_):
        new_shape = validator.check_reshape_shp(shape)
        return tensor_operator_registry.get('reshape')()(self, new_shape)

    def reverse_sequence(self, seq_lengths, seq_dim, batch_dim=0):
        """
        Reverses variable length slices.

        Args:
            seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
            seq_dim (int): The dimension where reversal is performed. Required.
            batch_dim (int): The input is sliced in this dimension. Default: 0.

        Returns:
            Reversed tensor with the same shape and data type as input.

        Raises:
            TypeError: If `seq_dim` or `batch_dim` is not an int.
            ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
            >>> seq_lengths = Tensor(np.array([1, 2, 3]))
            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
            >>> print(output)
            [[1. 2. 3.]
             [5. 4. 6.]
             [9. 8. 7.]]
            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
            >>> seq_lengths = Tensor(np.array([1, 2, 3]))
            >>> output = x.reverse_sequence(seq_lengths, seq_dim=0, batch_dim=1)
            >>> print(output)
            [[1. 5. 9.]
             [4. 2. 6.]
             [7. 8. 3.]]
            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
            >>> seq_lengths = Tensor(np.array([2, 2, 3]))
            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
            >>> print(output)
            [[2. 1. 3.]
             [5. 4. 6.]
             [9. 8. 7.]]
            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
            >>> seq_lengths = Tensor(np.array([3, 2, 3]))
            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
            >>> print(output)
            [[3. 2. 1.]
             [5. 4. 6.]
             [9. 8. 7.]]
            >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
            >>> seq_lengths = Tensor(np.array([4, 4]))
            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
            >>> print(output)
            [[4. 3. 2. 1.]
             [8. 7. 6. 5.]]
        """
        self._init_check()
        return tensor_operator_registry.get('reverse_sequence')(seq_dim, batch_dim)(self, seq_lengths)

    def ravel(self):
        """
        Return a contiguous flattened tensor.

@@ -435,6 +435,65 @@ def get_reshape_vmap_rule(prim, axis_size):
    return vmap_rule


@vmap_rules_getters.register(P.ReverseSequence)
def get_reverse_sequence_vmap_rule(prim, axis_size):
    """VmapRule for `ReverseSequence` operation."""
    if isinstance(prim, str):
        prim = Primitive(prim)
    reshape = P.Reshape()
    batch_dim = prim.batch_dim_
    seq_dim = prim.seq_dim_

    @constexpr
    def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
        if seq_dim_ == dim:
            seq_dim_ += 1
            if seq_dim_ == batch_dim_:
                batch_dim_ += 1
        elif batch_dim_ == dim:
            batch_dim_ += 1
            if seq_dim_ == batch_dim_:
                seq_dim_ += 1
        return batch_dim_, seq_dim_

    @constexpr
    def get_seq_dim(dim, batch_dim_, seq_dim_):
        if seq_dim_ < dim and seq_dim_ < batch_dim_:
            seq_dim_ = seq_dim_ + 1
        elif seq_dim_ > dim and seq_dim_ > batch_dim_:
            seq_dim_ = seq_dim_ - 1
        else:
            seq_dim_ = seq_dim_
        return seq_dim_

    def vmap_rule(x_bdim, seq_lengths_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, seq_lengths_bdim)
        if is_all_none:
            return result
        x, dim = x_bdim
        seq_lengths, seq_lengths_dim = seq_lengths_bdim
        seq_lengths = mnp.moveaxis(seq_lengths, seq_lengths_dim, 0)
        origin_shape = x.shape
        batch_dim_ = batch_dim
        seq_dim_ = seq_dim
        batch_dim_, seq_dim_ = get_batch_seq_dim(dim, batch_dim_, seq_dim_)
        x = mnp.moveaxis(x, [dim, batch_dim_], [0, 1])
        shape = x.shape
        shape = (shape[0] * shape[1],) + tuple(_ for _ in shape[2:])
        x = reshape(x, shape)
        seq_dim_ = get_seq_dim(dim, batch_dim_, seq_dim_)
        seq_lengths = reshape(seq_lengths, (-1,))
        x = P.ReverseSequence(seq_dim=seq_dim_)(x, seq_lengths)
        shape = x.shape
        shape = (origin_shape[dim], origin_shape[batch_dim_],) + tuple(_ for _ in shape[1:])
        out = reshape(x, shape)
        if batch_dim_ not in (0, 1):
            out = mnp.moveaxis(out, 1, batch_dim_)
        return out, 0

    return vmap_rule


@vmap_rules_getters.register(P.Flatten)
def get_flatten_vmap_rule(prim, axis_size):
    """VmapRule for `Flatten` operation."""

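To make the axis bookkeeping in the vmap rule above concrete, a NumPy-only sketch of the merge-and-reshape trick it relies on, under the assumption that the vmap axis is 0, the batch axis adjusts to 1, and the sequence axis to 2 (the shapes and names here are illustrative, not part of the commit):

import numpy as np

V, B, T = 2, 3, 4                               # vmap size, batch size, sequence length
x = np.arange(V * B * T, dtype=np.float32).reshape(V, B, T)
seq_lengths = np.array([[1, 2, 3], [4, 1, 2]])  # one length per (vmap, batch) pair

merged = x.reshape(V * B, T)                    # fold the vmap axis into the batch axis
merged_len = seq_lengths.reshape(-1)            # flatten seq_lengths the same way

out = merged.copy()                             # reference partial reverse per row
for i, n in enumerate(merged_len):
    out[i, :n] = merged[i, :n][::-1]

out = out.reshape(V, B, T)                      # split the axes back; output is batched on axis 0
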
@@ -45,6 +45,7 @@ from .array_func import (
    rank,
    reshape,
    reshape_,
    reverse_sequence,
    flatten,
    concat,
    stack,

@@ -859,6 +859,65 @@ def reshape(input_x, input_shape):
    return reshape_(input_x, input_shape)


def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
    """
    Reverses variable length slices.

    Args:
        x (Tensor): The input to reverse, supporting all number types including bool.
        seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
        seq_dim (int): The dimension where reversal is performed. Required.
        batch_dim (int): The input is sliced in this dimension. Default: 0.

    Returns:
        Reversed tensor with the same shape and data type as input.

    Raises:
        TypeError: If `seq_dim` or `batch_dim` is not an int.
        ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
        >>> print(output)
        [[1. 2. 3.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=0, batch_dim=1)
        >>> print(output)
        [[1. 5. 9.]
         [4. 2. 6.]
         [7. 8. 3.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([2, 2, 3]))
        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
        >>> print(output)
        [[2. 1. 3.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([3, 2, 3]))
        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
        >>> print(output)
        [[3. 2. 1.]
         [5. 4. 6.]
         [9. 8. 7.]]
        >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
        >>> seq_lengths = Tensor(np.array([4, 4]))
        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
        >>> print(output)
        [[4. 3. 2. 1.]
         [8. 7. 6. 5.]]
    """
    return P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim)(x, seq_lengths)


def flatten(input_x):
    r"""
    Flattens a tensor without changing its batch size on the 0-th axis.

@@ -354,6 +354,7 @@ tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('prod', prod)
tensor_operator_registry.register('round', P.Round)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('reverse_sequence', P.ReverseSequence)
tensor_operator_registry.register('xlogy', P.Xlogy)
tensor_operator_registry.register('flatten', P.Flatten)
tensor_operator_registry.register('transpose', P.Transpose)

@@ -5791,7 +5791,7 @@ class ReverseSequence(PrimitiveWithInfer):
        ValueError: If value of `batch_dim` is equal to or greater than length of shape of `x`.

    Supported Platforms:
        ``Ascend`` ``GPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)

@@ -5844,20 +5844,6 @@ class ReverseSequence(PrimitiveWithInfer):
        validator.check_value_type("batch_dim", batch_dim, [int], self.name)
        self.batch_dim_ = batch_dim

    def infer_shape(self, x, seq_lengths):
        validator.check_int_range(self.seq_dim_, 0, len(x), Rel.INC_LEFT, "seq_dim", self.name)
        validator.check_int_range(self.batch_dim_, 0, len(x), Rel.INC_LEFT, "batch_dim", self.name)
        validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
        validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
        validator.check("seq_lengths vector size", seq_lengths[0],
                        "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
        return x

    def infer_dtype(self, x, seq_lengths):
        validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
        validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
        return x


class EditDistance(Primitive):
    r"""

@@ -23,6 +23,7 @@ from mindspore.common import Parameter, ParameterTuple
from mindspore.train.callback import Callback
from mindspore.ops import operations as P
from mindspore._checkparam import Validator, Rel
from mindspore import log as logger


class _StartFLJob(nn.Cell):

@@ -163,8 +164,11 @@ class FederatedLearningManager(Callback):
        self._abs_grads_ema = dict()

        if self._is_adaptive_sync():
            self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
                                if self._as_prefix not in _.name}
            self._last_param = {
                _.name: deepcopy(_.asnumpy())
                for _ in self._model.trainable_params()
                if self._as_prefix not in _.name
            }
            for param in self._model.trainable_params():
                if self._as_prefix not in param.name:
                    self._model_size += np.product(param.shape)

@@ -201,13 +205,13 @@ class FederatedLearningManager(Callback):
            try:
                abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
            except KeyError:
                print("{} is not in self._last_param".format(param.name))
                logger.warning("{} is not in self._last_param".format(param.name))
        for param in self._model.trainable_params():
            if self._as_prefix in param.name:
                try:
                    param.set_data(Parameter(abs_grads[param.name]))
                except KeyError:
                    print("{} is not in abs_grads".format(param.name))
                    logger.warning("{} is not in abs_grads".format(param.name))

    def _as_analyze_gradient(self):
        """

@@ -225,32 +229,40 @@ class FederatedLearningManager(Callback):
            try:
                grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
            except KeyError:
                print("{} is not in self._last_param".format(param.name))
                logger.warning("{} is not in self._last_param".format(param.name))
        for last_p in self._last_param:
            try:
                self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
            except KeyError:
                print("{} is not in self._grads_ema".format(last_p))
                logger.warning("{} is not in self._grads_ema".format(last_p))
                continue
            try:
                self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[
                    last_p]
            except KeyError:
                print("{} is not in self._abs_grads_ema".format(last_p))
                logger.warning("{} is not in self._abs_grads_ema".format(last_p))
                continue
            try:
                divide_base = np.where(self._abs_grads_ema[last_p] == 0,
                                       np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
            except KeyError:
                print("{} is not in self._abs_grads_ema".format(last_p))
                logger.warning("{} is not in self._abs_grads_ema".format(last_p))
                continue
            try:
                layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
                if divide_base != 0:
                    layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
                else:
                    logger.warning("You cannot divide by 0!")
                    return
                consistent_rate_sum += np.sum(layer_consistent_rate)
            except KeyError:
                print("{} is not in self._grads_ema".format(last_p))
                logger.warning("{} is not in self._grads_ema".format(last_p))

        consistent_rate = float(consistent_rate_sum / self._model_size)
        if self._model_size != 0:
            consistent_rate = float(consistent_rate_sum / self._model_size)
        else:
            logger.warning("You cannot divide by 0!")
            return

        if self._min_consistent_rate > consistent_rate:
            self._min_consistent_rate = consistent_rate

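The guards added above all avoid dividing by zero when computing the consistency rate. A standalone NumPy sketch of the same pattern (the function and variable names are hypothetical, for illustration only):

import numpy as np

def consistency_rate(grads_ema, abs_grads_ema, model_size):
    # Element-wise: replace zero denominators before dividing, as the np.where call above does.
    divide_base = np.where(abs_grads_ema == 0, np.ones(abs_grads_ema.shape), abs_grads_ema)
    layer_rate = np.abs(grads_ema) / divide_base
    if model_size == 0:  # scalar guard, matching the added model-size check above
        return None
    return float(np.sum(layer_rate) / model_size)
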
@@ -272,8 +284,11 @@ class FederatedLearningManager(Callback):
        """
        Set the value of last parameters for adaptive synchronization.
        """
        self._last_param = {_.name: deepcopy(_.asnumpy())
                            for _ in self._model.trainable_params() if self._as_prefix not in _.name}
        self._last_param = {
            _.name: deepcopy(_.asnumpy())
            for _ in self._model.trainable_params()
            if self._as_prefix not in _.name
        }

    def step_end(self, run_context):
        """

@@ -309,4 +324,4 @@ class FederatedLearningManager(Callback):
            self._round_id += 1
            self._as_set_last_param()

        print("sync step is: {}".format(self._global_step))
        logger.info("sync step is: {}".format(self._global_step))