From 437a720500ff95889c1318651d6a3c0298d1c3dc Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Fri, 6 Jan 2023 19:39:58 +0800
Subject: [PATCH] add multinomial operator

---
 python/jittor/__init__.py            |  2 +-
 python/jittor/misc.py                | 48 ++++++++++++++++++++++++++++
 python/jittor/src/ops/safe_clip_op.h |  2 +-
 python/jittor/test/test_misc_op.py   | 11 +++++++
 4 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 68a4855c..229da570 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.6.8'
+__version__ = '1.3.6.9'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
diff --git a/python/jittor/misc.py b/python/jittor/misc.py
index 95577897..8e417c3b 100644
--- a/python/jittor/misc.py
+++ b/python/jittor/misc.py
@@ -2066,3 +2066,51 @@ The returned var has the same number of dimensions as the original var (x). The
     '''
     return x.getitem(((slice(None),)*dim)+(index,))
 jt.index_select = index_select
+
+def multinomial(weights: jt.Var, num_samples: int, replacement: bool=False) -> jt.Var:
+    ''' Returns a var where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of input weights.
+
+    The weights are relative and do not have to sum to one: with replacement
+    the random draws are scaled by the last cumulative-sum entry, and without
+    replacement only the top num_samples keys ``rand ** (1/weights)`` are kept.
+
+    :param weights: the input probability.
+    :param num_samples: number of samples.
+    :param replacement: whether to draw with replacement or not.
+    :return: a var holding num_samples sampled indices for each row of weights.
+
+    Example::
+
+        weights = jt.float32([0, 10, 3, 0])
+        x = jt.multinomial(weights, 2)
+        assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1])
+        x = jt.multinomial(weights, 4, replacement=True)
+        assert x.shape == (4, )
+
+        weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]])
+        x = jt.multinomial(weights, 1)
+        assert jt.all_equal(x, [[2],[1],[0]])
+
+    '''
+    if replacement:
+        # Scale the random draws by the last cumulative-sum entry, then locate
+        # the cumulative-sum interval (l, r] that contains each draw.
+        cum_probs = jt.cumsum(weights)[..., None, :]
+        cum_probs_l = cum_probs[..., :-1]
+        cum_probs_r = cum_probs[..., 1:]
+        shape = weights.shape[:-1] + (num_samples, 1)
+        rand = jt.rand(shape) * cum_probs[..., :1, -1:]
+        one_hot = jt.logical_and(cum_probs_l < rand, rand <= cum_probs_r)
+        # Interval j stands for category j + 1; a draw inside no interval sums
+        # to 0 and therefore selects category 0.
+        index = one_hot.index(one_hot.ndim - 1) + 1
+        return (one_hot * index).sum(-1)
+    else:
+        # A-Res algorithm
+        # Pavlos S. Efraimidis and Paul G. Spirakis, 2006, Weighted random sampling with a reservoir
+        assert num_samples <= weights.shape[-1], "num_samples larger than the input"
+        rand = jt.rand(weights.shape) ** (1/weights)
+        # NOTE(review): safe_clip() relies on the new [-1e300, 1e300] defaults;
+        # presumably it bounds keys produced from zero weights (1/0) - confirm.
+        _, indices = jt.topk(rand.safe_clip(), num_samples)
+        return indices
diff --git a/python/jittor/src/ops/safe_clip_op.h b/python/jittor/src/ops/safe_clip_op.h
index 97f9d44a..ac95ac99 100644
--- a/python/jittor/src/ops/safe_clip_op.h
+++ b/python/jittor/src/ops/safe_clip_op.h
@@ -22,7 +22,7 @@ struct SafeClipOp : Op {
 
      */
     // @pybind(safe_clip)
-    SafeClipOp(Var* x, float64 left, float64 right);
+    SafeClipOp(Var* x, float64 left=-1e300, float64 right=1e300);
 
     const char* name() const override { return "safe_clip"; }
     VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py
index 08360746..fa51a7ec 100644
--- a/python/jittor/test/test_misc_op.py
+++ b/python/jittor/test/test_misc_op.py
@@ -391,5 +391,16 @@ class TestOther(unittest.TestCase):
         y = jt.index_select(x, 1, indices)
         assert jt.all_equal(y, x[:, indices])
 
+    def test_multinomial(self):
+        weights = jt.float32([0, 10, 3, 0])
+        x = jt.multinomial(weights, 2)
+        assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1])
+        x = jt.multinomial(weights, 4, replacement=True)
+        assert x.shape == (4, )
+
+        weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]])
+        x = jt.multinomial(weights, 1)
+        assert jt.all_equal(x, [[2],[1],[0]])
+
 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file