forked from mindspore-Ecosystem/mindspore
!18125 fix DynamicStitch check and StandardNormal docs
Merge pull request !18125 from yanzhenxiang2020/fix_dynamicstitch_standard
This commit is contained in:
commit
f87f2ea121
|
@ -1187,8 +1187,8 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
|
|||
for (size_t i = 1; i < data.size(); ++i) {
|
||||
auto indicesi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices[i]->BuildShape())[kShape];
|
||||
auto datai_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(data[i]->BuildShape())[kShape];
|
||||
if (indicesi_shape.size() >= datai_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "The rank of indices[i] must be < rank of data[i]!";
|
||||
if (indicesi_shape.size() > datai_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "The rank of indices[i] must be <= rank of data[i]!";
|
||||
}
|
||||
indices_total_size += indicesi_shape.size();
|
||||
}
|
||||
|
|
|
@ -1082,7 +1082,7 @@ class DynamicStitch(PrimitiveWithCheck):
|
|||
for i in range(0, indices_num):
|
||||
indices_dim = len(indices_shape[i])
|
||||
data_dim = len(data_shape[i])
|
||||
validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LT, self.name)
|
||||
validator.check(f"dim of indices[{i}]", indices_dim, f"dim of data[{i}]", data_dim, Rel.LE, self.name)
|
||||
if data_shape[i][:indices_dim] != data_shape[i][:indices_dim]:
|
||||
raise ValueError(f"data[{i}].shape: {data_shape} does not start with indices[{i}].shape: {data_shape}")
|
||||
|
||||
|
|
|
@ -6893,8 +6893,6 @@ class Dropout(PrimitiveWithCheck):
|
|||
>>> output, mask = dropout(x)
|
||||
>>> print(output)
|
||||
[0. 32. 0. 0.]
|
||||
>>> print(mask)
|
||||
[0. 1. 0. 0.]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -24,6 +24,9 @@ class StandardNormal(PrimitiveWithInfer):
|
|||
r"""
|
||||
Generates random numbers according to the standard Normal (or Gaussian) random number distribution.
|
||||
|
||||
Returns the tensor with the given shape, the random numbers in it drawn from normal distributions
|
||||
whose mean is 0 and standard deviation is 1.
|
||||
|
||||
Args:
|
||||
seed (int): Random seed, must be non-negative. Default: 0.
|
||||
seed2 (int): Random seed2, must be non-negative. Default: 0.
|
||||
|
@ -43,12 +46,13 @@ class StandardNormal(PrimitiveWithInfer):
|
|||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> shape = (4, 16)
|
||||
>>> shape = (3, 4)
|
||||
>>> stdnormal = ops.StandardNormal(seed=2)
|
||||
>>> output = stdnormal(shape)
|
||||
>>> result = output.shape
|
||||
>>> print(result)
|
||||
(4, 16)
|
||||
>>> print(output)
|
||||
[[-1.3031056 0.64198005 -0.65207404 -1.767485 ]
|
||||
[-0.91792876 0.6508565 -0.9098478 -0.14092612]
|
||||
[ 0.7806437 1.1585592 1.9676613 -0.00440959]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue