skip operations which are not supported in the backend in ME

This commit is contained in:
BowenK 2020-06-09 15:32:04 +08:00
parent 5958c4abc6
commit 8c4bcb84b2
3 changed files with 40 additions and 2 deletions

View File

@ -184,6 +184,13 @@ class Cast(PrimitiveWithInfer):
"""init Cast"""
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
def check_elim(self, x, dtype):
if isinstance(x, Tensor):
if x.dtype() == dtype:
return (True, x)
return (False, None)
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
def __infer__(self, x, t):
src_type = x['dtype']
dst_type = t['value']
@ -1310,6 +1317,15 @@ class Tile(PrimitiveWithInfer):
"""init Tile"""
self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])
def check_elim(self, base_tensor, multiplier):
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
def is_all_zeros(v_tuple):
return all(v == 1 for v in v_tuple)
if is_all_zeros(multiplier):
return (True, base_tensor)
return (False, None)
def __infer__(self, x, multiples):
multiples_v = multiples['value']
x_shp = x['shape']

View File

@ -705,6 +705,13 @@ class AddN(PrimitiveWithInfer):
def __init__(self):
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def check_elim(self, inputs):
if len(inputs) != 1:
return (False, None)
if isinstance(inputs[0], Tensor):
return (True, inputs[0])
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
def infer_shape(self, inputs):
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)

View File

@ -140,9 +140,24 @@ class Primitive(Primitive_):
return self.attrs[item]
raise AttributeError(item)
def check_elim(self, *args):
"""
Check whether or not certain inputs should go into backend. Subclass in need should override this method.
Args:
Same as arguments of current Primitive
Returns:
A tuple of two elements, first element indicates whether or not we should filter out current arguments;
seconde element is the output in case where we should filter out the arguments.
"""
return (False, None)
def __call__(self, *args):
output = _run_op(self, self.name, args)
should_elim, output = self.check_elim(*args)
if should_elim:
return output
return _run_op(self, self.name, args)
def __getstate__(self):
return self.__dict__