!18125 fix DynamicStitch check and StandardNormal docs

Merge pull request !18125 from yanzhenxiang2020/fix_dynamicstitch_standard
This commit is contained in:
i-robot 2021-06-11 10:19:45 +08:00 committed by Gitee
commit f87f2ea121
4 changed files with 11 additions and 9 deletions

View File

@ -1187,8 +1187,8 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
for (size_t i = 1; i < data.size(); ++i) {
auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
if (indicesi_shape.size() >= datai_shape.size()) {
MS_LOG(EXCEPTION) << "The rank of indices[i] must be < rank of data[i]!";
if (indicesi_shape.size() > datai_shape.size()) {
MS_LOG(EXCEPTION) << "The rank of indices[i] must be <= rank of data[i]!";
}
indices_total_size += indicesi_shape.size();
}

View File

@ -1082,7 +1082,7 @@ class DynamicStitch(PrimitiveWithCheck):
for i in range(0, indices_num):
indices_dim = len(indices_shape[i])
data_dim = len(data_shape[i])
validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LT, self.name)
validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LE, self.name)
if data_shape[i][:indices_dim] != data_shape[i][:indices_dim]:
raise ValueError(f"data[{i}].shape: {data_shape} does not start with indices[{i}].shape: {data_shape}")

View File

@ -6893,8 +6893,6 @@ class Dropout(PrimitiveWithCheck):
>>> output, mask = dropout(x)
>>> print(output)
[0. 32. 0. 0.]
>>> print(mask)
[0. 1. 0. 0.]
"""
@prim_attr_register

View File

@ -24,6 +24,9 @@ class StandardNormal(PrimitiveWithInfer):
r"""
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
Returns the tensor with the given shape, the random numbers in it drawn from normal distributions
whose mean is 0 and standard deviation is 1.
Args:
seed (int): Random seed, must be non-negative. Default: 0.
seed2 (int): Random seed2, must be non-negative. Default: 0.
@ -43,12 +46,13 @@ class StandardNormal(PrimitiveWithInfer):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> shape = (4, 16)
>>> shape = (3, 4)
>>> stdnormal = ops.StandardNormal(seed=2)
>>> output = stdnormal(shape)
>>> result = output.shape
>>> print(result)
(4, 16)
>>> print(output)
[[-1.3031056 0.64198005 -0.65207404 -1.767485 ]
[-0.91792876 0.6508565 -0.9098478 -0.14092612]
[ 0.7806437 1.1585592 1.9676613 -0.00440959]]
"""
@prim_attr_register