forked from mindspore-Ecosystem/mindspore
!10022 fix example and Validator for LogUniformCandidateSampler
From: @yanzhenxiang2020 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
eda6ce12ed
|
@ -4134,7 +4134,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
"""
|
||||
Generates coordinate matrices from given coordinate tensors.
|
||||
|
||||
Given N one-dimensional coordinate tensors, returns a list outputs of N N-D
|
||||
Given N one-dimensional coordinate tensors, returns a tuple outputs of N N-D
|
||||
coordinate tensors for evaluating expressions on an N-D grid.
|
||||
|
||||
|
||||
|
@ -4144,12 +4144,15 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
instructions for the first two dimensions are swapped.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Union[tuple, list]) - A Tuple or list of N 1-D Tensor objects.
|
||||
The length of input_x should be greater than 1
|
||||
- **input** (Union[tuple]) - A Tuple of N 1-D Tensor objects.
|
||||
The length of input should be greater than 1
|
||||
|
||||
Outputs:
|
||||
Tensors, A Tuple of N N-D Tensor objects.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
|
||||
>>> y = Tensor(np.array([5, 6, 7]).astype(np.int32))
|
||||
|
@ -4158,7 +4161,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
>>> meshgrid = ops.Meshgrid(indexing="xy")
|
||||
>>> output = meshgrid(inputs)
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3, 4, 6], dtype=Int32, value=
|
||||
(Tensor(shape=[3, 4, 5], dtype=Int32, value=
|
||||
[[[1, 1, 1, 1, 1],
|
||||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
|
@ -4171,7 +4174,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
[2, 2, 2, 2, 2],
|
||||
[3, 3, 3, 3, 3],
|
||||
[4, 4, 4, 4, 4]]]),
|
||||
Tensor(shape=[3, 4, 6], dtype=Int32, value=
|
||||
Tensor(shape=[3, 4, 5], dtype=Int32, value=
|
||||
[[[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5],
|
||||
[5, 5, 5, 5, 5],
|
||||
|
@ -4184,7 +4187,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7],
|
||||
[7, 7, 7, 7, 7]]]),
|
||||
Tensor(shape=[3, 4, 6], dtype=Int32, value=
|
||||
Tensor(shape=[3, 4, 5], dtype=Int32, value=
|
||||
[[[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
[8, 9, 0, 1, 2],
|
||||
|
@ -4208,12 +4211,12 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
self.indexing = indexing
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_value_type("shape", x_shape, [tuple, list], self.name)
|
||||
validator.check_int(len(x_shape), 2, Rel.GE, "len of input_x", self.name)
|
||||
validator.check_value_type("shape", x_shape, [tuple], self.name)
|
||||
validator.check_int(len(x_shape), 2, Rel.GE, "len of input", self.name)
|
||||
n = len(x_shape)
|
||||
shape_0 = []
|
||||
for s in x_shape:
|
||||
validator.check_int(len(s), 1, Rel.EQ, 'each_input_rank', self.name)
|
||||
validator.check_int(len(s), 1, Rel.EQ, 'each input rank', self.name)
|
||||
shape_0.append(s[0])
|
||||
if self.indexing == "xy":
|
||||
shape_0[0], shape_0[1] = shape_0[1], shape_0[0]
|
||||
|
@ -4221,7 +4224,7 @@ class Meshgrid(PrimitiveWithInfer):
|
|||
return out_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_subclass("input_x[0]", x_type[0], mstype.tensor, self.name)
|
||||
validator.check_subclass("input[0]", x_type[0], mstype.tensor, self.name)
|
||||
n = len(x_type)
|
||||
for i in range(1, n):
|
||||
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, self.name, TypeError)
|
||||
|
|
|
@ -621,14 +621,14 @@ class LogUniformCandidateSampler(PrimitiveWithInfer):
|
|||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> sampler = ops.LogUniformCandidateSampler(1, 5, True, 5)
|
||||
>>> sampler = ops.LogUniformCandidateSampler(2, 5, True, 5)
|
||||
>>> output1, output2, output3 = sampler(Tensor(np.array([[1, 7], [0, 4], [3, 3]])))
|
||||
>>> print(output1, output2, output3)
|
||||
[3, 2, 0, 4, 1],
|
||||
[[9.23129916e-01, 4.93363708e-01],
|
||||
[9.92489874e-01, 6.58063710e-01],
|
||||
[7.35534430e-01, 7.35534430e-01]],
|
||||
[7.35534430e-01, 8.26258004e-01, 9.92489874e-01, 6.58063710e-01, 9.23129916e-01]
|
||||
[3 2 0 4 1]
|
||||
[[0.92312991 0.49336370]
|
||||
[0.99248987 0.65806371]
|
||||
[0.73553443 0.73553443]]
|
||||
[0.73553443 0.82625800 0.99248987 0.65806371 0.92312991]
|
||||
|
||||
"""
|
||||
|
||||
|
@ -645,14 +645,14 @@ class LogUniformCandidateSampler(PrimitiveWithInfer):
|
|||
self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
|
||||
self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name)
|
||||
if unique:
|
||||
Validator.check_number("range_max", range_max, num_sampled, Rel.GE, self.name)
|
||||
Validator.check("range_max", range_max, "num_sampled", num_sampled, Rel.GE, self.name)
|
||||
self.range_max = range_max
|
||||
self.unique = unique
|
||||
self.seed = seed
|
||||
self.seed = Validator.check_number("seed", seed, 0, Rel.GE, self.name)
|
||||
|
||||
def infer_shape(self, true_classes_shape):
|
||||
Validator.check("true_classes shape rank", len(true_classes_shape), "expect", 2, Rel.EQ, self.name)
|
||||
Validator.check_int(true_classes_shape[1], self.num_true, Rel.EQ, 'true_classes_shape', self.name)
|
||||
Validator.check_int(len(true_classes_shape), 2, Rel.EQ, "dim of true_classes", self.name)
|
||||
Validator.check("true_classes_shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
|
||||
return (self.num_sampled,), true_classes_shape, (self.num_sampled,)
|
||||
|
||||
def infer_dtype(self, true_classes_type):
|
||||
|
|
Loading…
Reference in New Issue