From f227d6ffb0b9b5fd6db34a1de5b92b3cac4bf0d0 Mon Sep 17 00:00:00 2001 From: wangshuide2020 Date: Tue, 29 Jun 2021 21:18:13 +0800 Subject: [PATCH] add optional axis attr for flatten and update example of PSNR. --- mindspore/nn/layer/image.py | 6 +++--- mindspore/ops/_op_impl/tbe/flatten.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index fae79c8b2d2..6e3fd1994f5 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -444,11 +444,11 @@ class PSNR(Cell): Examples: >>> net = nn.PSNR() - >>> img1 = Tensor(np.random.random((1, 3, 16, 16))) - >>> img2 = Tensor(np.random.random((1, 3, 16, 16))) + >>> img1 = Tensor([[[[1, 2, 3, 4], [1, 2, 3, 4]]]]) + >>> img2 = Tensor([[[[3, 4, 5, 6], [3, 4, 5, 6]]]]) >>> output = net(img1, img2) >>> print(output) - [7.915369] + [-6.0206] """ def __init__(self, max_val=1.0): super(PSNR, self).__init__() diff --git a/mindspore/ops/_op_impl/tbe/flatten.py b/mindspore/ops/_op_impl/tbe/flatten.py index 0413422fb04..f165f2fce20 100644 --- a/mindspore/ops/_op_impl/tbe/flatten.py +++ b/mindspore/ops/_op_impl/tbe/flatten.py @@ -23,6 +23,7 @@ flatten_op_info = TBERegOp("Flatten") \ .compute_cost(10) \ .kernel_name("flatten") \ .partial_flag(True) \ + .attr("axis", "optional", "int", "all", "1") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ .dtype_format(DataType.I8_Default, DataType.I8_Default) \