Add EditDistance op for GE.
This commit is contained in:
parent
c3d3750195
commit
4c99f4f649
|
@ -190,6 +190,7 @@ constexpr const char kNameSquareSumAll[] = "SquareSumAll";
|
|||
constexpr const char kNameAscendQuant[] = "Quant";
|
||||
constexpr const char kNameAscendDequant[] = "Dequant";
|
||||
constexpr const char kNameReverseSequence[] = "ReverseSequence";
|
||||
constexpr const char kNameEditDistance[] = "EditDistance";
|
||||
constexpr const char kNameCase[] = "Case";
|
||||
|
||||
class OpAdapterMap {
|
||||
|
|
|
@ -87,4 +87,12 @@ ATTR_MAP(ReverseSequence) = {{"seq_dim", ATTR_DESC(seq_dim, AnyTraits<int>())},
|
|||
{"batch_dim", ATTR_DESC(batch_dim, AnyTraits<int>())}};
|
||||
OUTPUT_MAP(ReverseSequence) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ReverseSequence, kNameReverseSequence, ADPT_DESC(ReverseSequence))
|
||||
|
||||
// EditDistance
|
||||
INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(hypothesis_values)},
|
||||
{3, INPUT_DESC(hypothesis_shape)}, {4, INPUT_DESC(truth_indices)},
|
||||
{5, INPUT_DESC(truth_values)}, {6, INPUT_DESC(truth_shape)}};
|
||||
ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}};
|
||||
REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -54,5 +54,8 @@ DECLARE_OP_ADAPTER(Data)
|
|||
|
||||
DECLARE_OP_ADAPTER(ReverseSequence)
|
||||
DECLARE_OP_USE_OUTPUT(ReverseSequence)
|
||||
|
||||
DECLARE_OP_ADAPTER(EditDistance)
|
||||
DECLARE_OP_USE_OUTPUT(EditDistance)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_
|
||||
|
|
|
@ -23,7 +23,6 @@ from .acos_grad import _acos_grad_tbe
|
|||
from .acosh import _acosh_tbe
|
||||
from .acosh_grad import _acosh_grad_tbe
|
||||
from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe
|
||||
from .add import _add_tbe
|
||||
from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe
|
||||
from .add_n import _add_n_tbe
|
||||
from .accumulate_n_v2 import _accumulate_n_v2_tbe
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Add op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
add_op_info = TBERegOp("Add") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("add") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("dynamicFormat") \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "x2", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(add_op_info)
|
||||
def _add_tbe():
|
||||
"""Add TBE register"""
|
||||
return
|
|
@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||
Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding,
|
||||
ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance,
|
||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
|
||||
|
@ -92,6 +92,7 @@ from .sparse_ops import SparseToDense
|
|||
|
||||
__all__ = [
|
||||
'ReverseSequence',
|
||||
'EditDistance',
|
||||
'CropAndResize',
|
||||
'TensorAdd',
|
||||
'Argmax',
|
||||
|
|
|
@ -3470,6 +3470,93 @@ class ReverseSequence(PrimitiveWithInfer):
|
|||
return x
|
||||
|
||||
|
||||
class EditDistance(PrimitiveWithInfer):
|
||||
"""
|
||||
Computes the Levebshtein Edit Distance. It is used to measure the similarity of two sequences.
|
||||
|
||||
Args:
|
||||
normalize (bool): If True, edit distances are normalized by length of truth. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type.
|
||||
The shape of tensor is :math:`(N, R)`.
|
||||
- **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor.
|
||||
Must be 1-D vector with length of N.
|
||||
- **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor.
|
||||
Must be R-length vector with int64 data type. Only constant value is allowed.
|
||||
- **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type.
|
||||
The shape of tensor is :math:`(M, R)`.
|
||||
- **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M.
|
||||
- **truth_shape** (Tensor) - The values of the truth list SparseTensor.
|
||||
Must be R-length vector with int64 data type. Only constant value is allowed.
|
||||
|
||||
Outputs:
|
||||
Tensor, a dense tensor with rank `R-1` and float32 data type.
|
||||
|
||||
Examples:
|
||||
>>> class EditDistance(nn.Cell):
|
||||
>>> def __init__(self, hypothesis_shape, truth_shape, normalize=True):
|
||||
>>> super(EditDistance, self).__init__()
|
||||
>>> self.edit_distance = P.EditDistance(normalize)
|
||||
>>> self.hypothesis_shape = hypothesis_shape
|
||||
>>> self.truth_shape = truth_shape
|
||||
>>>
|
||||
>>> def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
|
||||
>>> return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
|
||||
>>> truth_indices, truth_values, self.truth_shape)
|
||||
>>>
|
||||
>>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64))
|
||||
>>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32))
|
||||
>>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64))
|
||||
>>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64))
|
||||
>>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32))
|
||||
>>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64))
|
||||
>>> edit_distance = EditDistance(hypothesis_shape, truth_shape)
|
||||
>>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values)
|
||||
>>> [[1.0, 1.0], [1.0, 1.0]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, normalize=True):
|
||||
"""init EditDistance"""
|
||||
self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name)
|
||||
|
||||
def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape):
|
||||
validator.check_const_input('hypothesis_shape', h_shape['value'], self.name)
|
||||
validator.check_const_input('truth_shape', truth_shape['value'], self.name)
|
||||
args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'],
|
||||
"truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']}
|
||||
validator.check_tensor_type_same(args_int, [mstype.int64], self.name)
|
||||
args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']}
|
||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||
|
||||
hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape']
|
||||
validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name)
|
||||
validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name)
|
||||
validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name)
|
||||
validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name)
|
||||
validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name)
|
||||
validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name)
|
||||
validator.check("hypothesis_values shape", h_values['shape'][0],
|
||||
"hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name)
|
||||
validator.check("hypothesis_shape", h_shape['shape'][0],
|
||||
"hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name)
|
||||
validator.check("truth_values shape", truth_values['shape'][0],
|
||||
"truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name)
|
||||
validator.check("hypothesis_shape", h_shape['shape'][0],
|
||||
"truth_shape", truth_shape['shape'][0], Rel.EQ, self.name)
|
||||
hypothesis_shape_v = h_shape['value'].asnumpy()
|
||||
truth_shape_v = truth_shape['value'].asnumpy()
|
||||
out_shape_rank = len(hypothesis_shape_v) - 1
|
||||
out_shape = []
|
||||
for i in range(out_shape_rank):
|
||||
out_shape.append(max(hypothesis_shape_v[i], truth_shape_v[i]))
|
||||
|
||||
return {'shape': tuple(out_shape),
|
||||
'dtype': mstype.tensor_type(mstype.float32),
|
||||
'value': None}
|
||||
|
||||
|
||||
class TransShape(PrimitiveWithInfer):
|
||||
"""
|
||||
Transform the shape of input tensor to target shape.
|
||||
|
|
|
@ -684,6 +684,18 @@ class ParallelConcatNet(nn.Cell):
|
|||
return self.parallel_concat((x1, x2))
|
||||
|
||||
|
||||
class EditDistance(nn.Cell):
|
||||
def __init__(self, hypothesis_shape, truth_shape, normalize=True):
|
||||
super(EditDistance, self).__init__()
|
||||
self.edit_distance = P.EditDistance(normalize)
|
||||
self.hypothesis_shape = hypothesis_shape
|
||||
self.truth_shape =truth_shape
|
||||
|
||||
def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values):
|
||||
return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape,
|
||||
truth_indices, truth_values, self.truth_shape)
|
||||
|
||||
|
||||
test_case_math_ops = [
|
||||
('BitwiseAnd', {
|
||||
'block': P.BitwiseAnd(),
|
||||
|
@ -1978,6 +1990,15 @@ test_case_array_ops = [
|
|||
'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)),
|
||||
Tensor(np.array([1, 2, 3]).astype(np.int32))],
|
||||
'desc_bprop': [[3, 3]]}),
|
||||
('EditDistance', {
|
||||
'block': EditDistance(Tensor(np.array([1, 1, 2]).astype(np.int64)),
|
||||
Tensor(np.array([2, 2, 2]).astype(np.int64))),
|
||||
'desc_inputs': [Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)),
|
||||
Tensor(np.array([1, 2, 3]).astype(np.float32)),
|
||||
Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)),
|
||||
Tensor(np.array([1, 3, 2, 1]).astype(np.float32))],
|
||||
'skip': ['backward'],
|
||||
}),
|
||||
('LinSpace', {
|
||||
'block': inner.LinSpace(),
|
||||
'desc_inputs': [Tensor([5, 5.5], mstype.float32),
|
||||
|
|
Loading…
Reference in New Issue