From 0eb72d76f076d5bdeb08093721b9104a40e1ad28 Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Wed, 1 Apr 2020 21:08:40 +0800 Subject: [PATCH] import comment and function of op print --- mindspore/ccsrc/transform/op_adapter.h | 2 +- mindspore/ops/operations/debug_ops.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index 3dd299f83de..524cdfb4aa1 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -513,7 +513,7 @@ class OpAdapter : public BaseOpAdapter { return; } } else { - MS_LOG(ERROR) << "Update output desc failed, unknow output shape type"; + MS_LOG(WARNING) << "Update output desc failed, unknow output shape type"; return; } MS_EXCEPTION_IF_NULL(node); diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 6640ef87ca5..a69dcc2df1f 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -14,6 +14,7 @@ # ============================================================================ """debug_ops""" +from ..._checkparam import ParamValidator as validator from ...common import dtype as mstype from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer @@ -157,19 +158,20 @@ class InsertGradientOf(PrimitiveWithInfer): class Print(PrimitiveWithInfer): """ - Output tensor to stdout. + Output tensor or string to stdout. Inputs: - - **input_x** (Tensor) - The graph node to attach to. + - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports + multiple strings and tensors which are separated by ','. Examples: >>> class PrintDemo(nn.Cell): - >>> def __init__(self,): + >>> def __init__(self): >>> super(PrintDemo, self).__init__() >>> self.print = P.Print() >>> - >>> def construct(self, x): - >>> self.print(x) + >>> def construct(self, x, y): + >>> self.print('Print Tensor x and Tensor y:', x, y) >>> return x """ @@ -181,4 +183,6 @@ class Print(PrimitiveWithInfer): return [1] def infer_dtype(self, *inputs): + for dtype in inputs: + validator.check_subclass("input", dtype, (mstype.tensor, mstype.string)) return mstype.int32