forked from jittor/jittor
add epillsis (numpy version) support.
This commit is contained in:
parent
6339f62f56
commit
5eaccf538d
|
@ -488,6 +488,24 @@ def einsum(string, *args):
|
|||
shps = np_cpu.concatenate([in_.shape for in_ in inputs])
|
||||
p = einsum_expr.replace(" ", "").split(',')
|
||||
s = p[:-1] + p[-1].split('->')
|
||||
rec_shape = []
|
||||
ellip_expr = None
|
||||
const_rep = '1234567890' # assume tensor shape no more than 10 dimensions
|
||||
for idx, expr in enumerate(s[:-1]):
|
||||
if "..." in expr:
|
||||
assert "..." in s[-1]
|
||||
else:
|
||||
continue
|
||||
shp = inputs[idx].shape
|
||||
ellipsis_pos = len(expr.replace("...", ""))
|
||||
nellip_expr = const_rep[0 : len(shp) - ellipsis_pos]
|
||||
if ellip_expr is None:
|
||||
ellip_expr = nellip_expr
|
||||
else:
|
||||
assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis."
|
||||
s[idx] = expr.replace("...", ellip_expr)
|
||||
if ellip_expr:
|
||||
s[-1] = s[-1].replace("...", ellip_expr)
|
||||
if s[-1]=='':
|
||||
return ()
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue