From 5f22e08d0522693fec69da5ab0cb764d2925fa12 Mon Sep 17 00:00:00 2001 From: tangdezhi_123 Date: Wed, 22 Feb 2023 17:32:35 +0800 Subject: [PATCH] =?UTF-8?q?mindspore.ops.eye=E5=92=8Cpytorch.eye=E7=9A=84?= =?UTF-8?q?=E5=8F=82=E6=95=B0m=E4=B8=8D=E5=AF=B9=E6=A0=87=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E4=BF=AE=E6=AD=A3=E3=80=82=201.=20=E5=AF=B9=E5=8F=82?= =?UTF-8?q?=E6=95=B0m=E5=A2=9E=E5=8A=A0=E9=BB=98=E8=AE=A4=E5=80=BCNone,?= =?UTF-8?q?=E8=BF=94=E5=9B=9ETensor=E7=9A=84=E5=88=97=E6=95=B0=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E4=B8=8En=E7=9B=B8=E7=AD=89=E3=80=82=202.=20=E5=AF=B9?= =?UTF-8?q?=E5=8F=82=E6=95=B0dtype=E5=A2=9E=E5=8A=A0=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=80=BCNone=EF=BC=8C=E8=BF=94=E5=9B=9ETensor=E7=9A=84?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B=E9=BB=98=E8=AE=A4=E4=B8=BA?= =?UTF-8?q?mindspore.float32=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_python/ops/mindspore.ops.func_eye.rst | 8 +++--- .../mindspore/ops/function/array_func.py | 28 +++++++++++++++---- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/docs/api/api_python/ops/mindspore.ops.func_eye.rst b/docs/api/api_python/ops/mindspore.ops.func_eye.rst index 9a3b13ab803..680d0afbf60 100644 --- a/docs/api/api_python/ops/mindspore.ops.func_eye.rst +++ b/docs/api/api_python/ops/mindspore.ops.func_eye.rst @@ -1,7 +1,7 @@ mindspore.ops.eye ================== -.. py:function:: mindspore.ops.eye(n, m, t) +.. py:function:: mindspore.ops.eye(n, m=None, dtype=None) 创建一个主对角线上元素为1,其余元素为0的Tensor。 @@ -10,11 +10,11 @@ mindspore.ops.eye 参数: - **n** (int) - 指定返回Tensor的行数。仅支持常量值。 - - **m** (int) - 指定返回Tensor的列数。仅支持常量值。 - - **t** (mindspore.dtype) - 指定返回Tensor的数据类型。数据类型必须是 `bool_ `_ 或 `number `_ 。 + - **m** (int) - 指定返回Tensor的列数。仅支持常量值。默认值为None,返回Tensor的列数默认与n相等。 + - **dtype** (mindspore.dtype) - 指定返回Tensor的数据类型。数据类型必须是 `bool_ `_ 或 `number `_ 。默认值为None,返回Tensor的数据类型默认为mindspore.float32。 返回: - Tensor,主对角线上为1,其余的元素为0。它的shape由 `n` 和 `m` 指定。数据类型由 `t` 指定。 + Tensor,主对角线上为1,其余的元素为0。它的shape由 `n` 和 `m` 指定。数据类型由 `dtype` 指定。 异常: - **TypeError** - `m` 或 `n` 不是int。 diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index d2b8bd40966..3ab4cb2e094 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -20,6 +20,7 @@ import builtins import operator import numpy as np +import mindspore as ms import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops.primitive import constexpr @@ -285,7 +286,7 @@ def cat(tensors, axis=0): return _concat(tensors) -def eye(n, m, t): +def eye(n, m=None, dtype=None): """ Creates a tensor with ones on the diagonal and zeros in the rest. @@ -296,12 +297,13 @@ def eye(n, m, t): Args: n (int): The number of rows of returned tensor. Constant value only. m (int): The number of columns of returned tensor. Constant value only. - t (mindspore.dtype): MindSpore's dtype, the data type of the returned tensor. - The data type can be Number. + Default: if None, the number of columns is as the same as n. + dtype (mindspore.dtype): MindSpore's dtype, the data type of the returned tensor. + The data type can be Number. Default: if None, the data type of the returned tensor is mindspore.float32. Returns: Tensor, a tensor with ones on the diagonal and the rest of elements are zero. The shape of `output` depends on - the user's Inputs `n` and `m`. And the data type depends on Inputs `t`. + the user's Inputs `n` and `m`. And the data type depends on Inputs `dtype`. Raises: TypeError: If `m` or `n` is not an int. @@ -322,8 +324,24 @@ def eye(n, m, t): [[1. 0.]] >>> print(output.dtype) Float64 + >>> output = ops.eye(2, t=mindspore.int32) + >>> print(output) + [[1 0] + [0 1]] + >>> print(output.dtype) + Int32 + >>> output = ops.eye(2) + >>> print(output) + [[1. 0.] + [0. 1.]] + >>> print(output.dtype) + Float32 """ - return eye_(n, m, t) + if m is None: + m = n + if dtype is None: + dtype = ms.float32 + return eye_(n, m, dtype) def hamming_window(window_length, periodic=True, alpha=0.54, beta=0.46, *, dtype=None):