add multinomial operator

This commit is contained in:
Dun Liang 2023-01-06 19:39:58 +08:00
parent 5a4ae74a3a
commit 437a720500
4 changed files with 53 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.6.8'
__version__ = '1.3.6.9'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -2066,3 +2066,43 @@ The returned var has the same number of dimensions as the original var (x). The
'''
return x.getitem(((slice(None),)*dim)+(index,))
jt.index_select = index_select
def multinomial(weights: jt.Var, num_samples: int, replacement: bool=False) -> jt.Var:
''' Returns a var where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of input weights.
:param weights: the input probability.
:param num_samples: number of samples.
:param replacement: whether to draw with replacement or not.
Example::
weights = jt.float32([0, 10, 3, 0])
x = jt.multinomial(weights, 2)
assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1])
x = jt.multinomial(weights, 4, replacement=True)
assert x.shape == (4, )
weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]])
x = jt.multinomial(weights, 1)
assert jt.all_equal(x, [[2],[1],[0]])
'''
if replacement:
cum_probs = jt.cumsum(weights)[..., None, :]
cum_probs_l = cum_probs[..., :-1]
cum_probs_r = cum_probs[..., 1:]
shape = weights.shape[:-1] + (num_samples, 1)
rand = jt.rand(shape) * cum_probs[..., :1, -1:]
one_hot = jt.logical_and(cum_probs_l < rand, rand <= cum_probs_r)
index = one_hot.index(one_hot.ndim - 1) + 1
return (one_hot * index).sum(-1)
else:
# A-Res algorithm
# Pavlos S. Efraimidis and Paul G. Spirakis, 2006, Weighted random sampling with a reservoir
assert num_samples <= weights.shape[-1], "num_samples larger than the input"
rand = jt.rand(weights.shape) ** (1/weights)
_, indices = jt.topk(rand.safe_clip(), num_samples)
return indices

View File

@ -22,7 +22,7 @@ struct SafeClipOp : Op {
*/
// @pybind(safe_clip)
SafeClipOp(Var* x, float64 left, float64 right);
SafeClipOp(Var* x, float64 left=-1e300, float64 right=1e300);
const char* name() const override { return "safe_clip"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;

View File

@ -391,5 +391,16 @@ class TestOther(unittest.TestCase):
y = jt.index_select(x, 1, indices)
assert jt.all_equal(y, x[:, indices])
def test_multinorm(self):
weights = jt.float32([0, 10, 3, 0])
x = jt.multinomial(weights, 2)
assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1])
x = jt.multinomial(weights, 4, replacement=True)
assert x.shape == (4, )
weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]])
x = jt.multinomial(weights, 1)
assert jt.all_equal(x, [[2],[1],[0]])
if __name__ == "__main__":
unittest.main()