forked from jittor/jittor
add multinomial operator
This commit is contained in:
parent
5a4ae74a3a
commit
437a720500
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.6.8'
|
||||
__version__ = '1.3.6.9'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -2066,3 +2066,43 @@ The returned var has the same number of dimensions as the original var (x). The
|
|||
'''
|
||||
return x.getitem(((slice(None),)*dim)+(index,))
|
||||
jt.index_select = index_select
|
||||
|
||||
def multinomial(weights: jt.Var, num_samples: int, replacement: bool=False) -> jt.Var:
|
||||
''' Returns a var where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of input weights.
|
||||
|
||||
:param weights: the input probability.
|
||||
:param num_samples: number of samples.
|
||||
:param replacement: whether to draw with replacement or not.
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
weights = jt.float32([0, 10, 3, 0])
|
||||
x = jt.multinomial(weights, 2)
|
||||
assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1])
|
||||
x = jt.multinomial(weights, 4, replacement=True)
|
||||
assert x.shape == (4, )
|
||||
|
||||
weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]])
|
||||
x = jt.multinomial(weights, 1)
|
||||
assert jt.all_equal(x, [[2],[1],[0]])
|
||||
|
||||
'''
|
||||
if replacement:
|
||||
cum_probs = jt.cumsum(weights)[..., None, :]
|
||||
cum_probs_l = cum_probs[..., :-1]
|
||||
cum_probs_r = cum_probs[..., 1:]
|
||||
shape = weights.shape[:-1] + (num_samples, 1)
|
||||
rand = jt.rand(shape) * cum_probs[..., :1, -1:]
|
||||
one_hot = jt.logical_and(cum_probs_l < rand, rand <= cum_probs_r)
|
||||
index = one_hot.index(one_hot.ndim - 1) + 1
|
||||
return (one_hot * index).sum(-1)
|
||||
else:
|
||||
# A-Res algorithm
|
||||
# Pavlos S. Efraimidis and Paul G. Spirakis, 2006, Weighted random sampling with a reservoir
|
||||
assert num_samples <= weights.shape[-1], "num_samples larger than the input"
|
||||
rand = jt.rand(weights.shape) ** (1/weights)
|
||||
_, indices = jt.topk(rand.safe_clip(), num_samples)
|
||||
return indices
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ struct SafeClipOp : Op {
|
|||
|
||||
*/
|
||||
// @pybind(safe_clip)
|
||||
SafeClipOp(Var* x, float64 left, float64 right);
|
||||
SafeClipOp(Var* x, float64 left=-1e300, float64 right=1e300);
|
||||
|
||||
const char* name() const override { return "safe_clip"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
|
|
|
@ -391,5 +391,16 @@ class TestOther(unittest.TestCase):
|
|||
y = jt.index_select(x, 1, indices)
|
||||
assert jt.all_equal(y, x[:, indices])
|
||||
|
||||
def test_multinorm(self):
|
||||
weights = jt.float32([0, 10, 3, 0])
|
||||
x = jt.multinomial(weights, 2)
|
||||
assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1])
|
||||
x = jt.multinomial(weights, 4, replacement=True)
|
||||
assert x.shape == (4, )
|
||||
|
||||
weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]])
|
||||
x = jt.multinomial(weights, 1)
|
||||
assert jt.all_equal(x, [[2],[1],[0]])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue