From a6e626f02ebff4854adb3b7bdbe831545dd69dfd Mon Sep 17 00:00:00 2001
From: zhaozhenlong
Date: Tue, 31 Mar 2020 09:34:09 +0800
Subject: [PATCH] add operator diag and diag_part

---
 mindspore/ccsrc/transform/convert.cc    |  6 +-
 mindspore/ccsrc/transform/op_declare.cc | 10 +++
 mindspore/ccsrc/transform/op_declare.h  |  4 ++
 mindspore/ops/_grad/grad_array_ops.py   | 22 +++++++
 mindspore/ops/operations/__init__.py    |  3 +-
 mindspore/ops/operations/array_ops.py   | 85 +++++++++++++++++++++----
 tests/ut/python/ops/test_ops.py         | 10 +++
 7 files changed, 125 insertions(+), 15 deletions(-)

diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index 87bfc8f6d86..b8d07594b73 100755
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -179,6 +179,8 @@ const char kNameLARSUpdate[] = "LARSUpdate";
 const char kNameRound[] = "Round";
 const char kNamePrint[] = "Print";
 const char kNameApplyFtrl[] = "ApplyFtrl";
+const char kNameDiag[] = "Diag";
+const char kNameDiagPart[] = "DiagPart";
 
 // -----------------OpAdapter initialization--------------
 std::unordered_map &DfGraphConvertor::get_adpt_map() {
@@ -359,7 +361,9 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma
     {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)},
     {string(kNameSign), ADPT_DESC(Sign)},
     {string(kNameRound), ADPT_DESC(Round)},
-    {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)}};
+    {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrl)},
+    {string(kNameDiag), ADPT_DESC(Diag)},
+    {string(kNameDiagPart), ADPT_DESC(DiagPart)}};
 #ifdef ENABLE_GE
   adpt_map[string(kNamePrint)] = ADPT_DESC(Print);
 #endif
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 028b1297562..24043072fa4 100755
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -1160,6 +1160,16 @@ INPUT_MAP(ApplyFtrl) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INP
 ATTR_MAP(ApplyFtrl) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}};
 OUTPUT_MAP(ApplyFtrl) = {{0, OUTPUT_DESC(var)}};
 
+// Diag
+INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(Diag) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}};
+
+// DiagPart
+INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
+
 #ifdef ENABLE_GE
 // Print
 INPUT_MAP(Print) = EMPTY_INPUT_MAP;
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index 9e4f407ebb0..65e4cdc9985 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -437,6 +437,10 @@ DECLARE_OP_ADAPTER(ApplyFtrl)
 DECLARE_OP_USE_OUTPUT(ApplyFtrl)
 DECLARE_OP_ADAPTER(SparseApplyFtrlD)
 DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
+DECLARE_OP_ADAPTER(Diag)
+DECLARE_OP_USE_OUTPUT(Diag)
+DECLARE_OP_ADAPTER(DiagPart)
+DECLARE_OP_USE_OUTPUT(DiagPart)
 #ifdef ENABLE_GE
 DECLARE_OP_ADAPTER(Print)
 DECLARE_OP_USE_DYN_INPUT(Print)
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index cf6247023e6..79841cf27aa 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -408,3 +408,25 @@ def get_bprop_depth_to_space(self):
         return (op(dout),)
 
     return bprop
+
+
+@bprop_getters.register(P.Diag)
+def get_bprop_diag(self):
+    """Generate bprop for Diag"""
+    op = P.DiagPart()
+
+    def bprop(x, out, dout):
+        return (op(dout),)
+
+    return bprop
+
+
+@bprop_getters.register(P.DiagPart)
+def get_bprop_diag_part(self):
"""Generate bprop for DiagPart""" + op = P.Diag() + + def bprop(x, out, dout): + return (op(dout),) + + return bprop diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 77bb6d0ff32..92ffdaf199c 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -20,7 +20,7 @@ A collection of operators to build nerual networks or computing functions. """ from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, - Diag, DType, ExpandDims, Eye, + Diag, DiagPart, DType, ExpandDims, Eye, Fill, GatherNd, GatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, @@ -208,6 +208,7 @@ __all__ = [ "Cos", "ACos", "Diag", + "DiagPart", 'Eye', 'Assign', 'AssignAdd', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index d0d3d5006c3..6740f172b45 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1615,37 +1615,96 @@ class StridedSlice(PrimitiveWithInfer): class Diag(PrimitiveWithInfer): r""" - Extract or construct a diagonal array. + Construct a diagonal tensor with a given diagonal values. - If input is a 2-D tensor, returns the diagonal of the input with the given offset. If - input is a 1-D tensor, returns the array of diagonals. If you use this function - to extract the diagonal and want to write to the result array, see the more - detailed documentation for "numpy.diagonal", whether you return a copy or a - view depends on the version of numpy you are using. + Assume `input_x` has dimensions :math:`[D_1,... D_k]`, the output is a tensor of + rank 2k with dimensions :math:`[D_1,..., D_k, D_1,..., D_k]` where: + :math:`output[i_1,..., i_k, i_1,..., i_k] = input_x[i_1,..., i_k]` and 0 everywhere else. Inputs: - - **input_x** (Tensor) - 1-D tensor or 2-D tensor. + - **input_x** (Tensor) - The input tensor. Outputs: Tensor. Examples: + >>> input_x = Tensor([1, 2, 3, 4]) >>> diag = P.Diag() >>> diag(x) + [[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]] """ @prim_attr_register def __init__(self): """init Diag""" - def infer_type(self, x): - args = {"x_dtype": x} - validator.check_subclass('input_x', x, mstype.tensor) - validator.check_type_same(args, mstype.number_type) - return x + def infer_dtype(self, x_type): + validator.check_subclass('input_x', x_type, mstype.tensor) + return x_type + + def infer_shape(self, x_shape): + validator.check("x rank", len(x_shape), "", 1, Rel.GE) + ret_shape = copy.deepcopy(x_shape) + ret_shape = ret_shape + ret_shape + return ret_shape def infer_value(self, x): - validator.check("shape_length", len(x.shape()), "length", [1, 2], Rel.IN) + if x is None: + return None + validator.check("input x rank", len(x.shape()), "", 1) + ret = np.diag(x.asnumpy()) + return Tensor(ret) + + +class DiagPart(PrimitiveWithInfer): + r""" + + Extract the diagonal part from given tensor. + + Assume input has dimensions :math:`[D_1,..., D_k, D_1,..., D_k]`, the output is a tensor + of rank k with dimensions :math:`[D_1,..., D_k]` where: + :math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`. + + Inputs: + - **input_x** (Tensor) - The input Tensor. + + Outputs: + Tensor. 
+
+    Examples:
+        >>> input_x = Tensor([[1, 0, 0, 0],
+        >>>                   [0, 2, 0, 0],
+        >>>                   [0, 0, 3, 0],
+        >>>                   [0, 0, 0, 4]])
+        >>> diag_part = P.DiagPart()
+        >>> diag_part(input_x)
+        [1, 2, 3, 4]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init DiagPart"""
+
+    def infer_dtype(self, x_type):
+        validator.check_subclass('input_x', x_type, mstype.tensor)
+        return x_type
+
+    def infer_shape(self, x_shape):
+        if not x_shape or \
+                len(x_shape) % 2 != 0:
+            raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, "
+                             f"with shapes {x_shape}")
+        length = len(x_shape) // 2
+        ret_shape = x_shape[0:length]
+        return ret_shape
+
+    def infer_value(self, x):
+        if x is None:
+            return None
+        validator.check("x rank", len(x.shape()), "", 2)
         ret = np.diag(x.asnumpy())
         return Tensor(ret)
 
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 8d7dd950723..7e32ac94546 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -947,6 +947,16 @@ test_case_array_ops = [
         Tensor(np.array([1], np.float32)),
         Tensor(np.array([1], np.float32)))],
         'desc_bprop': [[3,]]}),
+    ('Diag', {
+        'block': P.Diag(),
+        'desc_inputs': [[4]],
+        'desc_bprop': [[4, 4]],
+    }),
+    ('DiagPart', {
+        'block': P.DiagPart(),
+        'desc_inputs': [[4, 4]],
+        'desc_bprop': [[4]],
+    }),
 ]
 
 test_case_other_ops = [
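
Note (illustrative, not part of the patch): the bprop registrations pair the two new ops because Diag and DiagPart undo each other along the diagonal, so each op's gradient is simply the other op applied to dout. A minimal NumPy sketch of that reasoning, mirroring the 1-D / 2-D cases checked in infer_value above:

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.diag(x)                          # what Diag computes for a 1-D input
assert np.array_equal(np.diag(y), x)    # DiagPart recovers the original values

# Backprop through Diag: only the diagonal entries of dout depend on x,
# so the incoming gradient reduces to DiagPart(dout).
dout = np.arange(16.0).reshape(4, 4)
dx = np.diag(dout)
assert dx.shape == x.shape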