!40248 synchronize the code of the reverse_sequence operator in r1.8
Merge pull request !40248 from wtcheng/master
commit 6745857a39
@@ -399,6 +399,7 @@ Array操作
     mindspore.ops.rank
     mindspore.ops.repeat_elements
     mindspore.ops.reshape
+    mindspore.ops.reverse_sequence
     mindspore.ops.scatter_nd
     mindspore.ops.select
     mindspore.ops.sequence_mask
@@ -0,0 +1,18 @@
+mindspore.Tensor.reverse_sequence
+=================================
+
+.. py:method:: reverse_sequence(seq_lengths, seq_dim, batch_dim=0)
+
+    Partially reverses the input sequence.
+
+    Parameters:
+        - **seq_lengths** (Tensor) - Specifies the reversal lengths; a 1-D vector of type int32 or int64.
+        - **seq_dim** (int) - The dimension along which the reversal is performed. This parameter is required.
+        - **batch_dim** (int) - The dimension along which the input is sliced. Default: 0.
+
+    Returns:
+        Tensor, with the same shape and data type as the input.
+
+    Raises:
+        - **TypeError** - `seq_dim` or `batch_dim` is not an int.
+        - **ValueError** - `batch_dim` is greater than or equal to the length of the shape of `x`.
@@ -182,6 +182,7 @@ Array操作
     mindspore.Tensor.repeat
     mindspore.Tensor.reshape
     mindspore.Tensor.resize
+    mindspore.Tensor.reverse_sequence
     mindspore.Tensor.scatter_add
     mindspore.Tensor.scatter_div
     mindspore.Tensor.scatter_max
@@ -0,0 +1,19 @@
+mindspore.ops.reverse_sequence
+==============================
+
+.. py:function:: mindspore.ops.reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0)
+
+    Partially reverses the input sequence.
+
+    Parameters:
+        - **x** (Tensor) - The input to reverse; all numeric types, including bool, are supported.
+        - **seq_lengths** (Tensor) - Specifies the reversal lengths; a 1-D vector of type int32 or int64.
+        - **seq_dim** (int) - The dimension along which the reversal is performed. This parameter is required.
+        - **batch_dim** (int) - The dimension along which the input is sliced. Default: 0.
+
+    Returns:
+        Tensor, with the same shape and data type as the input.
+
+    Raises:
+        - **TypeError** - `seq_dim` or `batch_dim` is not an int.
+        - **ValueError** - `batch_dim` is greater than or equal to the length of the shape of `x`.
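The page above documents parameters only; the following minimal sketch mirrors the examples in the operator docstring further down in this diff (imports and values are illustrative, not part of the commit):

# Minimal usage sketch for ops.reverse_sequence; mirrors the docstring examples below.
import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
seq_lengths = Tensor(np.array([1, 2, 3]))  # reverse the first 1, 2, 3 elements of each row
output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
print(output)
# [[1. 2. 3.]
#  [5. 4. 6.]
#  [9. 8. 7.]]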
@@ -401,6 +401,7 @@ Array Operation
     mindspore.ops.rank
     mindspore.ops.repeat_elements
     mindspore.ops.reshape
+    mindspore.ops.reverse_sequence
     mindspore.ops.scatter_nd
     mindspore.ops.select
     mindspore.ops.sequence_mask
@@ -192,6 +192,7 @@ BuiltInTypeMap &GetMethodMap() {
     {"transpose", std::string("transpose")},                    // P.transpose
     {"flatten", std::string("flatten")},                        // P.reshape(,-1)
     {"reshape", std::string("reshape")},                        // P.reshape()
+    {"reverse_sequence", std::string("reverse_sequence")},      // P.ReverseSequence()
     {"bitwise_and", std::string("bitwise_and")},                // P.BitwiseAnd()
     {"bitwise_or", std::string("bitwise_or")},                  // P.BitwiseOr()
     {"bitwise_xor", std::string("bitwise_xor")},                // P.BitwiseXor()
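With the method-map entry above, `reverse_sequence` becomes resolvable as a Tensor method inside graph-compiled code; a hedged sketch (the network and values are illustrative, not part of this commit):

# Illustrative only: the "reverse_sequence" entry lets graph mode resolve the Tensor method.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

class ReverseNet(nn.Cell):
    def construct(self, x, seq_lengths):
        # Parsed through GetMethodMap() to the standard_method implementation in graph mode.
        return x.reverse_sequence(seq_lengths, seq_dim=1)

net = ReverseNet()
x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), ms.float32)
seq_lengths = Tensor(np.array([1, 2, 3]).astype(np.int32))
print(net(x, seq_lengths))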
@@ -40,6 +40,71 @@ int64_t ReverseSequence::get_batch_dim() const {
   return GetValue<int64_t>(value_ptr);
 }
 
-REGISTER_PRIMITIVE_C(kNameReverseSequence, ReverseSequence);
+namespace {
+abstract::ShapePtr ReverseSequenceInferShape(const PrimitivePtr &primitive,
+                                             const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto x_shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReverseSequence", input_args, 0);
+  MS_EXCEPTION_IF_NULL(x_shape_ptr);
+  auto seq_lengths_shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReverseSequence", input_args, 1);
+  MS_EXCEPTION_IF_NULL(seq_lengths_shape_ptr);
+  auto x_shape = x_shape_ptr->shape();
+  auto seq_lengths_shape = seq_lengths_shape_ptr->shape();
+
+  auto seq_dim_ptr = primitive->GetAttr("seq_dim");
+  MS_EXCEPTION_IF_NULL(seq_dim_ptr);
+  auto seq_dim = GetValue<int64_t>(seq_dim_ptr);
+  auto batch_dim_ptr = primitive->GetAttr("batch_dim");
+  MS_EXCEPTION_IF_NULL(batch_dim_ptr);
+  auto batch_dim = GetValue<int64_t>(batch_dim_ptr);
+
+  if (seq_dim >= SizeToLong(x_shape.size())) {
+    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'seq_dim' should be < x rank: " << x_shape.size()
+                             << ", but got " << seq_dim << ".";
+  }
+  if (batch_dim >= SizeToLong(x_shape.size())) {
+    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'batch_dim' should be < x rank: " << x_shape.size()
+                             << ", but got " << batch_dim << ".";
+  }
+  if (batch_dim == seq_dim) {
+    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'batch_dim' should be != 'seq_dim': " << seq_dim
+                             << ", but got " << batch_dim << ".";
+  }
+  if (seq_lengths_shape.size() != 1) {
+    MS_EXCEPTION(ValueError) << "For 'ReverseSequence', the 'seq_lengths' rank should be = 'expected': 1 , but got "
+                             << seq_lengths_shape.size() << ".";
+  }
+  if (seq_lengths_shape[0] != x_shape[batch_dim]) {
+    MS_EXCEPTION(ValueError)
+      << "For 'ReverseSequence', the 'seq_lengths' vector size should be = input size along batch_dim: "
+      << x_shape[batch_dim] << ", but got " << seq_lengths_shape[0] << ".";
+  }
+  return std::make_shared<abstract::Shape>(x_shape);
+}
+
+TypePtr ReverseSequenceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
+    MS_LOG(EXCEPTION) << "For '" << prim->name()
+                      << ", the input args used for infer shape and type is necessary, but missing it.";
+  }
+  const std::set<TypePtr> seq_lengths_valid_types = {kInt32, kInt64};
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("seq_lengths", input_args[1]->BuildType(), seq_lengths_valid_types,
+                                                   prim->name());
+
+  const std::set<TypePtr> x_valid_types = {kFloat16, kFloat32, kFloat64, kUInt8, kUInt16, kUInt32, kUInt64,
+                                           kInt8, kInt16, kInt32, kInt64, kComplex64, kComplex128, kBool};
+  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), x_valid_types, prim->name());
+}
+}  // namespace
+AbstractBasePtr ReverseSequenceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t input_num = 2;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  auto infer_type = ReverseSequenceInferType(primitive, input_args);
+  auto infer_shape = ReverseSequenceInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(ReverseSequence, prim::kPrimReverseSequence, ReverseSequenceInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
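The shape rules enforced above are compact; the sketch below restates them in Python purely for reference (it is not part of the commit):

# Reference-only restatement of the C++ checks in ReverseSequenceInferShape.
def check_reverse_sequence_shapes(x_shape, seq_lengths_shape, seq_dim, batch_dim):
    rank = len(x_shape)
    if seq_dim >= rank:
        raise ValueError(f"'seq_dim' should be < x rank {rank}, but got {seq_dim}.")
    if batch_dim >= rank:
        raise ValueError(f"'batch_dim' should be < x rank {rank}, but got {batch_dim}.")
    if batch_dim == seq_dim:
        raise ValueError("'batch_dim' must differ from 'seq_dim'.")
    if len(seq_lengths_shape) != 1:
        raise ValueError("'seq_lengths' must be a 1-D tensor.")
    if seq_lengths_shape[0] != x_shape[batch_dim]:
        raise ValueError("'seq_lengths' size must equal the input size along 'batch_dim'.")
    return x_shape  # the output shape equals the input shape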
@@ -371,6 +371,65 @@ def reshape(x, *shape):
     return F.reshape(x, new_shape)
 
 
+def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
+    """
+    Reverses variable length slices.
+
+    Args:
+        x (Tensor): The input to reverse, supporting all number types including bool.
+        seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
+        seq_dim (int): The dimension where reversal is performed. Required.
+        batch_dim (int): The input is sliced in this dimension. Default: 0.
+
+    Returns:
+        Reversed tensor with the same shape and data type as input.
+
+    Raises:
+        TypeError: If `seq_dim` or `batch_dim` is not an int.
+        ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
+        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[1. 2. 3.]
+         [5. 4. 6.]
+         [9. 8. 7.]]
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
+        >>> output = x.reverse_sequence(seq_lengths, seq_dim=0, batch_dim=1)
+        >>> print(output)
+        [[1. 5. 9.]
+         [4. 2. 6.]
+         [7. 8. 3.]]
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([2, 2, 3]))
+        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[2. 1. 3.]
+         [5. 4. 6.]
+         [9. 8. 7.]]
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([3, 2, 3]))
+        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[3. 2. 1.]
+         [5. 4. 6.]
+         [9. 8. 7.]]
+        >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([4, 4]))
+        >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[4. 3. 2. 1.]
+         [8. 7. 6. 5.]]
+    """
+    return F.reverse_sequence(x, seq_lengths, seq_dim, batch_dim)
+
+
 def ravel(x):
     """
     Return a contiguous flattened tensor.
@@ -1896,6 +1896,64 @@ class Tensor(Tensor_):
         new_shape = validator.check_reshape_shp(shape)
         return tensor_operator_registry.get('reshape')()(self, new_shape)
 
+    def reverse_sequence(self, seq_lengths, seq_dim, batch_dim=0):
+        """
+        Reverses variable length slices.
+
+        Args:
+            seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
+            seq_dim (int): The dimension where reversal is performed. Required.
+            batch_dim (int): The input is sliced in this dimension. Default: 0.
+
+        Returns:
+            Reversed tensor with the same shape and data type as input.
+
+        Raises:
+            TypeError: If `seq_dim` or `batch_dim` is not an int.
+            ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.
+
+        Supported Platforms:
+            ``Ascend`` ``GPU`` ``CPU``
+
+        Examples:
+            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+            >>> seq_lengths = Tensor(np.array([1, 2, 3]))
+            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+            >>> print(output)
+            [[1. 2. 3.]
+             [5. 4. 6.]
+             [9. 8. 7.]]
+            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+            >>> seq_lengths = Tensor(np.array([1, 2, 3]))
+            >>> output = x.reverse_sequence(seq_lengths, seq_dim=0, batch_dim=1)
+            >>> print(output)
+            [[1. 5. 9.]
+             [4. 2. 6.]
+             [7. 8. 3.]]
+            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+            >>> seq_lengths = Tensor(np.array([2, 2, 3]))
+            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+            >>> print(output)
+            [[2. 1. 3.]
+             [5. 4. 6.]
+             [9. 8. 7.]]
+            >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+            >>> seq_lengths = Tensor(np.array([3, 2, 3]))
+            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+            >>> print(output)
+            [[3. 2. 1.]
+             [5. 4. 6.]
+             [9. 8. 7.]]
+            >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
+            >>> seq_lengths = Tensor(np.array([4, 4]))
+            >>> output = x.reverse_sequence(seq_lengths, seq_dim=1)
+            >>> print(output)
+            [[4. 3. 2. 1.]
+             [8. 7. 6. 5.]]
+        """
+        self._init_check()
+        return tensor_operator_registry.get('reverse_sequence')(seq_dim, batch_dim)(self, seq_lengths)
+
     def ravel(self):
         """
         Return a contiguous flattened tensor.
@@ -435,6 +435,65 @@ def get_reshape_vmap_rule(prim, axis_size):
     return vmap_rule
 
 
+@vmap_rules_getters.register(P.ReverseSequence)
+def get_reverse_sequence_vmap_rule(prim, axis_size):
+    """VmapRule for `ReverseSequence` operation."""
+    if isinstance(prim, str):
+        prim = Primitive(prim)
+    reshape = P.Reshape()
+    batch_dim = prim.batch_dim_
+    seq_dim = prim.seq_dim_
+
+    @constexpr
+    def get_batch_seq_dim(dim, batch_dim_, seq_dim_):
+        if seq_dim_ == dim:
+            seq_dim_ += 1
+            if seq_dim_ == batch_dim_:
+                batch_dim_ += 1
+        elif batch_dim_ == dim:
+            batch_dim_ += 1
+            if seq_dim_ == batch_dim_:
+                seq_dim_ += 1
+        return batch_dim_, seq_dim_
+
+    @constexpr
+    def get_seq_dim(dim, batch_dim_, seq_dim_):
+        if seq_dim_ < dim and seq_dim_ < batch_dim_:
+            seq_dim_ = seq_dim_ + 1
+        elif seq_dim_ > dim and seq_dim_ > batch_dim_:
+            seq_dim_ = seq_dim_ - 1
+        else:
+            seq_dim_ = seq_dim_
+        return seq_dim_
+
+    def vmap_rule(x_bdim, seq_lengths_bdim):
+        is_all_none, result = vmap_general_preprocess(prim, x_bdim, seq_lengths_bdim)
+        if is_all_none:
+            return result
+        x, dim = x_bdim
+        seq_lengths, seq_lengths_dim = seq_lengths_bdim
+        seq_lengths = mnp.moveaxis(seq_lengths, seq_lengths_dim, 0)
+        origin_shape = x.shape
+        batch_dim_ = batch_dim
+        seq_dim_ = seq_dim
+        batch_dim_, seq_dim_ = get_batch_seq_dim(dim, batch_dim_, seq_dim_)
+        x = mnp.moveaxis(x, [dim, batch_dim_], [0, 1])
+        shape = x.shape
+        shape = (shape[0] * shape[1],) + tuple(_ for _ in shape[2:])
+        x = reshape(x, shape)
+        seq_dim_ = get_seq_dim(dim, batch_dim_, seq_dim_)
+        seq_lengths = reshape(seq_lengths, (-1,))
+        x = P.ReverseSequence(seq_dim=seq_dim_)(x, seq_lengths)
+        shape = x.shape
+        shape = (origin_shape[dim], origin_shape[batch_dim_],) + tuple(_ for _ in shape[1:])
+        out = reshape(x, shape)
+        if batch_dim_ not in (0, 1):
+            out = mnp.moveaxis(out, 1, batch_dim_)
+        return out, 0
+
+    return vmap_rule
+
+
 @vmap_rules_getters.register(P.Flatten)
 def get_flatten_vmap_rule(prim, axis_size):
     """VmapRule for `Flatten` operation."""
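The rule registered above is what `mindspore.ops.vmap` dispatches to when ReverseSequence appears inside a vmapped function; a hedged usage sketch (shapes, lengths, and `in_axes` are illustrative, not part of the commit):

# Illustrative sketch: batching reverse_sequence over a leading axis with vmap.
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def rev(x, seq_lengths):
    return ops.reverse_sequence(x, seq_lengths, seq_dim=1)

x = Tensor(np.arange(2 * 3 * 4).reshape(2, 3, 4), ms.float32)                # batch of two 3x4 inputs
seq_lengths = Tensor(np.array([[1, 2, 3], [4, 4, 4]]).astype(np.int64))      # one length vector per batch element
out = ops.vmap(rev, in_axes=(0, 0))(x, seq_lengths)
print(out.shape)  # (2, 3, 4)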
@@ -45,6 +45,7 @@ from .array_func import (
     rank,
     reshape,
     reshape_,
+    reverse_sequence,
     flatten,
     concat,
     stack,
@@ -859,6 +859,65 @@ def reshape(input_x, input_shape):
     return reshape_(input_x, input_shape)
 
 
+def reverse_sequence(x, seq_lengths, seq_dim, batch_dim=0):
+    """
+    Reverses variable length slices.
+
+    Args:
+        x (Tensor): The input to reverse, supporting all number types including bool.
+        seq_lengths (Tensor): Must be a 1-D vector with int32 or int64 types.
+        seq_dim (int): The dimension where reversal is performed. Required.
+        batch_dim (int): The input is sliced in this dimension. Default: 0.
+
+    Returns:
+        Reversed tensor with the same shape and data type as input.
+
+    Raises:
+        TypeError: If `seq_dim` or `batch_dim` is not an int.
+        ValueError: If value of `batch_dim` is equal to or greater than length of shape of input.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
+        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[1. 2. 3.]
+         [5. 4. 6.]
+         [9. 8. 7.]]
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([1, 2, 3]))
+        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=0, batch_dim=1)
+        >>> print(output)
+        [[1. 5. 9.]
+         [4. 2. 6.]
+         [7. 8. 3.]]
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([2, 2, 3]))
+        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[2. 1. 3.]
+         [5. 4. 6.]
+         [9. 8. 7.]]
+        >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([3, 2, 3]))
+        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[3. 2. 1.]
+         [5. 4. 6.]
+         [9. 8. 7.]]
+        >>> x = Tensor(np.array([[1, 2, 3, 4], [5, 6, 7, 8]]), mindspore.float32)
+        >>> seq_lengths = Tensor(np.array([4, 4]))
+        >>> output = ops.reverse_sequence(x, seq_lengths, seq_dim=1)
+        >>> print(output)
+        [[4. 3. 2. 1.]
+         [8. 7. 6. 5.]]
+    """
+    return P.ReverseSequence(seq_dim=seq_dim, batch_dim=batch_dim)(x, seq_lengths)
+
+
 def flatten(input_x):
     r"""
     Flattens a tensor without changing its batch size on the 0-th axis.
@@ -354,6 +354,7 @@ tensor_operator_registry.register('mean', P.ReduceMean)
 tensor_operator_registry.register('prod', prod)
 tensor_operator_registry.register('round', P.Round)
 tensor_operator_registry.register('reshape', P.Reshape)
+tensor_operator_registry.register('reverse_sequence', P.ReverseSequence)
 tensor_operator_registry.register('xlogy', P.Xlogy)
 tensor_operator_registry.register('flatten', P.Flatten)
 tensor_operator_registry.register('transpose', P.Transpose)
@@ -5791,7 +5791,7 @@ class ReverseSequence(PrimitiveWithInfer):
         ValueError: If value of `batch_dim` is equal to or greater than length of shape of `x` .
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), mindspore.float32)
@@ -5844,20 +5844,6 @@ class ReverseSequence(PrimitiveWithInfer):
         validator.check_value_type("batch_dim", batch_dim, [int], self.name)
         self.batch_dim_ = batch_dim
 
-    def infer_shape(self, x, seq_lengths):
-        validator.check_int_range(self.seq_dim_, 0, len(x), Rel.INC_LEFT, "seq_dim", self.name)
-        validator.check_int_range(self.batch_dim_, 0, len(x), Rel.INC_LEFT, "batch_dim", self.name)
-        validator.check("batch_dim", self.batch_dim_, "seq_dim", self.seq_dim_, Rel.NE, self.name)
-        validator.check("seq_lengths rank", len(seq_lengths), "expected", 1, Rel.EQ, self.name)
-        validator.check("seq_lengths vector size", seq_lengths[0],
-                        "input size along batch_dim", x[self.batch_dim_], Rel.EQ, self.name)
-        return x
-
-    def infer_dtype(self, x, seq_lengths):
-        validator.check_tensor_dtype_valid("x_dtype", x, mstype.number_type + (mstype.bool_,), self.name)
-        validator.check_tensor_dtype_valid("seq_lengths_dtype", seq_lengths, [mstype.int32, mstype.int64], self.name)
-        return x
-
 
 class EditDistance(Primitive):
     r"""
@@ -23,6 +23,7 @@ from mindspore.common import Parameter, ParameterTuple
 from mindspore.train.callback import Callback
 from mindspore.ops import operations as P
 from mindspore._checkparam import Validator, Rel
+from mindspore import log as logger
 
 
 class _StartFLJob(nn.Cell):
@@ -163,8 +164,11 @@ class FederatedLearningManager(Callback):
         self._abs_grads_ema = dict()
 
         if self._is_adaptive_sync():
-            self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
-                                if self._as_prefix not in _.name}
+            self._last_param = {
+                _.name: deepcopy(_.asnumpy())
+                for _ in self._model.trainable_params()
+                if self._as_prefix not in _.name
+            }
             for param in self._model.trainable_params():
                 if self._as_prefix not in param.name:
                     self._model_size += np.product(param.shape)
@@ -201,13 +205,13 @@ class FederatedLearningManager(Callback):
                 try:
                     abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
                 except KeyError:
-                    print("{} is not in self._last_param".format(param.name))
+                    logger.warning("{} is not in self._last_param".format(param.name))
         for param in self._model.trainable_params():
             if self._as_prefix in param.name:
                 try:
                     param.set_data(Parameter(abs_grads[param.name]))
                 except KeyError:
-                    print("{} is not in abs_grads".format(param.name))
+                    logger.warning("{} is not in abs_grads".format(param.name))
 
     def _as_analyze_gradient(self):
         """
@@ -225,32 +229,40 @@ class FederatedLearningManager(Callback):
             try:
                 grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
             except KeyError:
-                print("{} is not in self._last_param".format(param.name))
+                logger.warning("{} is not in self._last_param".format(param.name))
         for last_p in self._last_param:
             try:
                 self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
             except KeyError:
-                print("{} is not in self._grads_ema".format(last_p))
+                logger.warning("{} is not in self._grads_ema".format(last_p))
                 continue
             try:
                 self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[
                     last_p]
             except KeyError:
-                print("{} is not in self._abs_grads_ema".format(last_p))
+                logger.warning("{} is not in self._abs_grads_ema".format(last_p))
                 continue
             try:
                 divide_base = np.where(self._abs_grads_ema[last_p] == 0,
                                        np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
             except KeyError:
-                print("{} is not in self._abs_grads_ema".format(last_p))
+                logger.warning("{} is not in self._abs_grads_ema".format(last_p))
                 continue
             try:
-                layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
+                if divide_base != 0:
+                    layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
+                else:
+                    logger.warning("You cannot divide by 0!")
+                    return
                 consistent_rate_sum += np.sum(layer_consistent_rate)
             except KeyError:
-                print("{} is not in self._grads_ema".format(last_p))
+                logger.warning("{} is not in self._grads_ema".format(last_p))
 
-        consistent_rate = float(consistent_rate_sum / self._model_size)
+        if self._model_size != 0:
+            consistent_rate = float(consistent_rate_sum / self._model_size)
+        else:
+            logger.warning("You cannot divide by 0!")
+            return
 
         if self._min_consistent_rate > consistent_rate:
             self._min_consistent_rate = consistent_rate
@@ -272,8 +284,11 @@ class FederatedLearningManager(Callback):
         """
         Set the value of last parameters for adaptive synchronization.
         """
-        self._last_param = {_.name: deepcopy(_.asnumpy())
-                            for _ in self._model.trainable_params() if self._as_prefix not in _.name}
+        self._last_param = {
+            _.name: deepcopy(_.asnumpy())
+            for _ in self._model.trainable_params()
+            if self._as_prefix not in _.name
+        }
 
     def step_end(self, run_context):
         """
@@ -309,4 +324,4 @@ class FederatedLearningManager(Callback):
                 self._round_id += 1
                 self._as_set_last_param()
 
-            print("sync step is: {}".format(self._global_step))
+            logger.info("sync step is: {}".format(self._global_step))