From 5eaccf538d3f7e90aa9b96116afc201cb57cdb53 Mon Sep 17 00:00:00 2001 From: Exusial Date: Mon, 29 Aug 2022 14:09:16 +0800 Subject: [PATCH] add epillsis (numpy version) support. --- python/jittor/linalg.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py index 320efcb2..522719bc 100644 --- a/python/jittor/linalg.py +++ b/python/jittor/linalg.py @@ -488,6 +488,24 @@ def einsum(string, *args): shps = np_cpu.concatenate([in_.shape for in_ in inputs]) p = einsum_expr.replace(" ", "").split(',') s = p[:-1] + p[-1].split('->') + rec_shape = [] + ellip_expr = None + const_rep = '1234567890' # assume tensor shape no more than 10 dimensions + for idx, expr in enumerate(s[:-1]): + if "..." in expr: + assert "..." in s[-1] + else: + continue + shp = inputs[idx].shape + ellipsis_pos = len(expr.replace("...", "")) + nellip_expr = const_rep[0 : len(shp) - ellipsis_pos] + if ellip_expr is None: + ellip_expr = nellip_expr + else: + assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis." + s[idx] = expr.replace("...", ellip_expr) + if ellip_expr: + s[-1] = s[-1].replace("...", ellip_expr) if s[-1]=='': return () else: