From 4c99f4f6499078ad58435494d329309ee2bc4c8a Mon Sep 17 00:00:00 2001 From: liuxiao93 Date: Wed, 19 Aug 2020 15:03:10 +0800 Subject: [PATCH] Add EditDistance op for GE. --- .../ccsrc/transform/graph_ir/op_adapter_map.h | 1 + .../graph_ir/op_declare/array_ops_declare.cc | 8 ++ .../graph_ir/op_declare/array_ops_declare.h | 3 + mindspore/ops/_op_impl/tbe/__init__.py | 1 - mindspore/ops/_op_impl/tbe/add.py | 37 -------- mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/array_ops.py | 87 +++++++++++++++++++ tests/ut/python/ops/test_ops.py | 21 +++++ 8 files changed, 122 insertions(+), 39 deletions(-) delete mode 100644 mindspore/ops/_op_impl/tbe/add.py diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h index 6b806b331d1..99f67f6dc14 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_map.h @@ -190,6 +190,7 @@ constexpr const char kNameSquareSumAll[] = "SquareSumAll"; constexpr const char kNameAscendQuant[] = "Quant"; constexpr const char kNameAscendDequant[] = "Dequant"; constexpr const char kNameReverseSequence[] = "ReverseSequence"; +constexpr const char kNameEditDistance[] = "EditDistance"; constexpr const char kNameCase[] = "Case"; class OpAdapterMap { diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc index 63f55c218c3..9197e9929fa 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.cc @@ -87,4 +87,12 @@ ATTR_MAP(ReverseSequence) = {{"seq_dim", ATTR_DESC(seq_dim, AnyTraits())}, {"batch_dim", ATTR_DESC(batch_dim, AnyTraits())}}; OUTPUT_MAP(ReverseSequence) = {{0, OUTPUT_DESC(y)}}; REG_ADPT_DESC(ReverseSequence, kNameReverseSequence, ADPT_DESC(ReverseSequence)) + +// EditDistance +INPUT_MAP(EditDistance) = {{1, INPUT_DESC(hypothesis_indices)}, {2, INPUT_DESC(hypothesis_values)}, + {3, INPUT_DESC(hypothesis_shape)}, {4, INPUT_DESC(truth_indices)}, + {5, INPUT_DESC(truth_values)}, {6, INPUT_DESC(truth_shape)}}; +ATTR_MAP(EditDistance) = {{"normalize", ATTR_DESC(normalize, AnyTraits())}}; +OUTPUT_MAP(EditDistance) = {{0, OUTPUT_DESC(output)}}; +REG_ADPT_DESC(EditDistance, kNameEditDistance, ADPT_DESC(EditDistance)) } // namespace mindspore::transform diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.h index c00f33c67ab..2738fe9f4c6 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.h +++ b/mindspore/ccsrc/transform/graph_ir/op_declare/array_ops_declare.h @@ -54,5 +54,8 @@ DECLARE_OP_ADAPTER(Data) DECLARE_OP_ADAPTER(ReverseSequence) DECLARE_OP_USE_OUTPUT(ReverseSequence) + +DECLARE_OP_ADAPTER(EditDistance) +DECLARE_OP_USE_OUTPUT(EditDistance) } // namespace mindspore::transform #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_ARRAY_OPS_DECLARE_H_ diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 4f272326e3c..a869a183fea 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -23,7 +23,6 @@ from .acos_grad import _acos_grad_tbe from .acosh import _acosh_tbe from .acosh_grad import _acosh_grad_tbe from .adam_apply_one_with_decay import _adam_apply_one_with_decay_tbe -from .add import _add_tbe from .apply_centered_rms_prop import _apply_centered_rms_prop_tbe from .add_n import _add_n_tbe from .accumulate_n_v2 import _accumulate_n_v2_tbe diff --git a/mindspore/ops/_op_impl/tbe/add.py b/mindspore/ops/_op_impl/tbe/add.py deleted file mode 100644 index 01eb6e881e2..00000000000 --- a/mindspore/ops/_op_impl/tbe/add.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Add op""" -from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType - -add_op_info = TBERegOp("Add") \ - .fusion_type("ELEMWISE") \ - .async_flag(False) \ - .binfile_name("add.so") \ - .compute_cost(10) \ - .kernel_name("add") \ - .partial_flag(True) \ - .op_pattern("dynamicFormat") \ - .input(0, "x1", False, "required", "all") \ - .input(1, "x2", False, "required", "all") \ - .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ - .get_op_info() - - -@op_info_register(add_op_info) -def _add_tbe(): - """Add TBE register""" - return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index ca03ad2edf7..7bf414d4ae2 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -29,7 +29,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, TransShape, ParallelConcat, Padding, ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint, - Squeeze, StridedSlice, Tile, TensorScatterUpdate, + Squeeze, StridedSlice, Tile, TensorScatterUpdate, EditDistance, Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup, @@ -92,6 +92,7 @@ from .sparse_ops import SparseToDense __all__ = [ 'ReverseSequence', + 'EditDistance', 'CropAndResize', 'TensorAdd', 'Argmax', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 02402095bb6..eb066c44b75 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -3470,6 +3470,93 @@ class ReverseSequence(PrimitiveWithInfer): return x +class EditDistance(PrimitiveWithInfer): + """ + Computes the Levebshtein Edit Distance. It is used to measure the similarity of two sequences. + + Args: + normalize (bool): If True, edit distances are normalized by length of truth. Default: True. + + Inputs: + - **hypothesis_indices** (Tensor) - The indices of the hypothesis list SparseTensor. With int64 data type. + The shape of tensor is :math:`(N, R)`. + - **hypothesis_values** (Tensor) - The values of the hypothesis list SparseTensor. + Must be 1-D vector with length of N. + - **hypothesis_shape** (Tensor) - The values of the hypothesis list SparseTensor. + Must be R-length vector with int64 data type. Only constant value is allowed. + - **truth_indices** (Tensor) - The indices of the truth list SparseTensor. With int64 data type. + The shape of tensor is :math:`(M, R)`. + - **truth_values** (Tensor) - The values of the truth list SparseTensor. Must be 1-D vector with length of M. + - **truth_shape** (Tensor) - The values of the truth list SparseTensor. + Must be R-length vector with int64 data type. Only constant value is allowed. + + Outputs: + Tensor, a dense tensor with rank `R-1` and float32 data type. + + Examples: + >>> class EditDistance(nn.Cell): + >>> def __init__(self, hypothesis_shape, truth_shape, normalize=True): + >>> super(EditDistance, self).__init__() + >>> self.edit_distance = P.EditDistance(normalize) + >>> self.hypothesis_shape = hypothesis_shape + >>> self.truth_shape = truth_shape + >>> + >>> def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values): + >>> return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape, + >>> truth_indices, truth_values, self.truth_shape) + >>> + >>> hypothesis_indices = Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)) + >>> hypothesis_values = Tensor(np.array([1, 2, 3]).astype(np.float32)) + >>> hypothesis_shape = Tensor(np.array([1, 1, 2]).astype(np.int64)) + >>> truth_indices = Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)) + >>> truth_values = Tensor(np.array([1, 3, 2, 1]).astype(np.float32)) + >>> truth_shape = Tensor(np.array([2, 2, 2]).astype(np.int64)) + >>> edit_distance = EditDistance(hypothesis_shape, truth_shape) + >>> out = edit_distance(hypothesis_indices, hypothesis_values, truth_indices, truth_values) + >>> [[1.0, 1.0], [1.0, 1.0]] + """ + + @prim_attr_register + def __init__(self, normalize=True): + """init EditDistance""" + self.normalize = validator.check_value_type("normalize", normalize, [bool], self.name) + + def __infer__(self, h_indices, h_values, h_shape, truth_indices, truth_values, truth_shape): + validator.check_const_input('hypothesis_shape', h_shape['value'], self.name) + validator.check_const_input('truth_shape', truth_shape['value'], self.name) + args_int = {"hypothesis_indices": h_indices['dtype'], "hypothesis_shape": h_shape['dtype'], + "truth_indices": truth_indices['dtype'], "truth_shape": truth_shape['dtype']} + validator.check_tensor_type_same(args_int, [mstype.int64], self.name) + args = {"hypothesis_values": h_values['dtype'], "truth_values": truth_values['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + hypothesis_indices_shp, truth_indices_shp = h_indices['shape'], truth_indices['shape'] + validator.check("hypothesis_indices rank", len(hypothesis_indices_shp), "expected", 2, Rel.EQ, self.name) + validator.check("truth_indices rank", len(truth_indices_shp), "expected", 2, Rel.EQ, self.name) + validator.check("hypothesis_values rank", len(h_values['shape']), "expected", 1, Rel.EQ, self.name) + validator.check("hypothesis_shape rank", len(h_shape['shape']), "expected", 1, Rel.EQ, self.name) + validator.check("truth_values rank", len(truth_values['shape']), "expected", 1, Rel.EQ, self.name) + validator.check("truth_shape rank", len(truth_shape['shape']), "expected", 1, Rel.EQ, self.name) + validator.check("hypothesis_values shape", h_values['shape'][0], + "hypothesis_indices shape[0]", hypothesis_indices_shp[0], Rel.EQ, self.name) + validator.check("hypothesis_shape", h_shape['shape'][0], + "hypothesis_indices shape[1]", hypothesis_indices_shp[1], Rel.EQ, self.name) + validator.check("truth_values shape", truth_values['shape'][0], + "truth_indices shape[0]", truth_indices_shp[0], Rel.EQ, self.name) + validator.check("hypothesis_shape", h_shape['shape'][0], + "truth_shape", truth_shape['shape'][0], Rel.EQ, self.name) + hypothesis_shape_v = h_shape['value'].asnumpy() + truth_shape_v = truth_shape['value'].asnumpy() + out_shape_rank = len(hypothesis_shape_v) - 1 + out_shape = [] + for i in range(out_shape_rank): + out_shape.append(max(hypothesis_shape_v[i], truth_shape_v[i])) + + return {'shape': tuple(out_shape), + 'dtype': mstype.tensor_type(mstype.float32), + 'value': None} + + class TransShape(PrimitiveWithInfer): """ Transform the shape of input tensor to target shape. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 4f3c3302bb6..313c5f58772 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -684,6 +684,18 @@ class ParallelConcatNet(nn.Cell): return self.parallel_concat((x1, x2)) +class EditDistance(nn.Cell): + def __init__(self, hypothesis_shape, truth_shape, normalize=True): + super(EditDistance, self).__init__() + self.edit_distance = P.EditDistance(normalize) + self.hypothesis_shape = hypothesis_shape + self.truth_shape =truth_shape + + def construct(self, hypothesis_indices, hypothesis_values, truth_indices, truth_values): + return self.edit_distance(hypothesis_indices, hypothesis_values, self.hypothesis_shape, + truth_indices, truth_values, self.truth_shape) + + test_case_math_ops = [ ('BitwiseAnd', { 'block': P.BitwiseAnd(), @@ -1978,6 +1990,15 @@ test_case_array_ops = [ 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), Tensor(np.array([1, 2, 3]).astype(np.int32))], 'desc_bprop': [[3, 3]]}), + ('EditDistance', { + 'block': EditDistance(Tensor(np.array([1, 1, 2]).astype(np.int64)), + Tensor(np.array([2, 2, 2]).astype(np.int64))), + 'desc_inputs': [Tensor(np.array([[0, 0, 0], [1, 0, 1], [1, 1, 1]]).astype(np.int64)), + Tensor(np.array([1, 2, 3]).astype(np.float32)), + Tensor(np.array([[0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1]]).astype(np.int64)), + Tensor(np.array([1, 3, 2, 1]).astype(np.float32))], + 'skip': ['backward'], + }), ('LinSpace', { 'block': inner.LinSpace(), 'desc_inputs': [Tensor([5, 5.5], mstype.float32),