From 32481cdb743ab979e74265b02d036c426efe0723 Mon Sep 17 00:00:00 2001
From: yanzhenxiang2020
Date: Thu, 10 Jun 2021 11:50:17 +0800
Subject: [PATCH] fix DynamicStitch check and StandardNormal docs

---
 mindspore/core/abstract/prim_arrays.cc |  4 ++--
 mindspore/ops/operations/_inner_ops.py |  2 +-
 mindspore/ops/operations/nn_ops.py     |  2 --
 mindspore/ops/operations/random_ops.py | 12 ++++++++----
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index 31e3f174b2a..eb27ca34fc2 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -1187,8 +1187,8 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
   for (size_t i = 1; i < data.size(); ++i) {
     auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
     auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
-    if (indicesi_shape.size() >= datai_shape.size()) {
-      MS_LOG(EXCEPTION) << "The rank of indices[i] must be < rank of data[i]!";
+    if (indicesi_shape.size() > datai_shape.size()) {
+      MS_LOG(EXCEPTION) << "The rank of indices[i] must be <= rank of data[i]!";
     }
     indices_total_size += indicesi_shape.size();
   }
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 2c346a819a5..9fa6ae022df 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -1082,7 +1082,7 @@ class DynamicStitch(PrimitiveWithCheck):
         for i in range(0, indices_num):
             indices_dim = len(indices_shape[i])
             data_dim = len(data_shape[i])
-            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LT, self.name)
+            validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LE, self.name)
             if data_shape[i][:indices_dim] != data_shape[i][:indices_dim]:
                 raise ValueError(f"data[{i}].shape: {data_shape} does not start with indices[{i}].shape: {data_shape}")
 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index e115a85339e..0fd11805e12 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -6893,8 +6893,6 @@ class Dropout(PrimitiveWithCheck):
         >>> output, mask = dropout(x)
         >>> print(output)
         [0. 32. 0. 0.]
-        >>> print(mask)
-        [0. 1. 0. 0.]
     """
 
     @prim_attr_register
diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py
index be9042e8f2f..f77e1f1bcc1 100644
--- a/mindspore/ops/operations/random_ops.py
+++ b/mindspore/ops/operations/random_ops.py
@@ -24,6 +24,9 @@ class StandardNormal(PrimitiveWithInfer):
     r"""
     Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
 
+    Returns the tensor with the given shape, the random numbers in it drawn from normal distributions
+    whose mean is 0 and standard deviation is 1.
+
     Args:
         seed (int): Random seed, must be non-negative. Default: 0.
         seed2 (int): Random seed2, must be non-negative. Default: 0.
@@ -43,12 +46,13 @@ class StandardNormal(PrimitiveWithInfer):
        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
-        >>> shape = (4, 16)
+        >>> shape = (3, 4)
         >>> stdnormal = ops.StandardNormal(seed=2)
         >>> output = stdnormal(shape)
-        >>> result = output.shape
-        >>> print(result)
-        (4, 16)
+        >>> print(output)
+        [[-1.3031056   0.64198005 -0.65207404 -1.767485  ]
+         [-0.91792876  0.6508565  -0.9098478  -0.14092612]
+         [ 0.7806437   1.1585592   1.9676613  -0.00440959]]
     """
 
     @prim_attr_register
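
A minimal sketch for exercising both changes locally (not part of the patch itself): the
StandardNormal call mirrors the updated docstring, while the DynamicStitch call is an
assumption that the primitive can be invoked eagerly in PyNative mode with a list of index
tensors and a list of data tensors (TensorFlow-style dynamic_stitch semantics); it targets
the boundary case rank(indices[i]) == rank(data[i]) that the relaxed Rel.LE check now
accepts. The `inner` alias is local to this sketch.

    # Hedged sketch; assumes the patched revision and PyNative mode.
    import mindspore as ms
    import mindspore.ops as ops
    from mindspore import Tensor
    from mindspore.ops.operations import _inner_ops as inner

    # StandardNormal: mirrors the updated docstring example.
    stdnormal = ops.StandardNormal(seed=2)
    output = stdnormal((3, 4))
    print(output.shape)  # (3, 4); values drawn from N(0, 1)

    # DynamicStitch boundary case newly accepted by the relaxed check:
    # rank(indices[i]) == rank(data[i]) == 1, i.e. scalar elements are stitched.
    indices = [Tensor([0, 2], ms.int32), Tensor([1, 3], ms.int32)]
    data = [Tensor([1.0, 3.0], ms.float32), Tensor([2.0, 4.0], ms.float32)]
    stitch = inner.DynamicStitch()
    result = stitch(indices, data)  # expected [1. 2. 3. 4.] under TF-style semantics
    print(result)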