add epillsis (numpy version) support.

This commit is contained in:
Exusial 2022-08-29 14:09:16 +08:00
parent 6339f62f56
commit 5eaccf538d
1 changed files with 18 additions and 0 deletions

View File

@ -488,6 +488,24 @@ def einsum(string, *args):
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
p = einsum_expr.replace(" ", "").split(',')
s = p[:-1] + p[-1].split('->')
rec_shape = []
ellip_expr = None
const_rep = '1234567890' # assume tensor shape no more than 10 dimensions
for idx, expr in enumerate(s[:-1]):
if "..." in expr:
assert "..." in s[-1]
else:
continue
shp = inputs[idx].shape
ellipsis_pos = len(expr.replace("...", ""))
nellip_expr = const_rep[0 : len(shp) - ellipsis_pos]
if ellip_expr is None:
ellip_expr = nellip_expr
else:
assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis."
s[idx] = expr.replace("...", ellip_expr)
if ellip_expr:
s[-1] = s[-1].replace("...", ellip_expr)
if s[-1]=='':
return ()
else: