forked from mindspore-Ecosystem/mindspore
skip operations which are not supported in the backend in ME
This commit is contained in:
parent
5958c4abc6
commit
8c4bcb84b2
|
@ -184,6 +184,13 @@ class Cast(PrimitiveWithInfer):
|
|||
"""init Cast"""
|
||||
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
|
||||
|
||||
def check_elim(self, x, dtype):
|
||||
if isinstance(x, Tensor):
|
||||
if x.dtype() == dtype:
|
||||
return (True, x)
|
||||
return (False, None)
|
||||
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
|
||||
|
||||
def __infer__(self, x, t):
|
||||
src_type = x['dtype']
|
||||
dst_type = t['value']
|
||||
|
@ -1310,6 +1317,15 @@ class Tile(PrimitiveWithInfer):
|
|||
"""init Tile"""
|
||||
self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])
|
||||
|
||||
def check_elim(self, base_tensor, multiplier):
|
||||
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
|
||||
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
|
||||
def is_all_zeros(v_tuple):
|
||||
return all(v == 1 for v in v_tuple)
|
||||
if is_all_zeros(multiplier):
|
||||
return (True, base_tensor)
|
||||
return (False, None)
|
||||
|
||||
def __infer__(self, x, multiples):
|
||||
multiples_v = multiples['value']
|
||||
x_shp = x['shape']
|
||||
|
|
|
@ -705,6 +705,13 @@ class AddN(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
|
||||
|
||||
def check_elim(self, inputs):
|
||||
if len(inputs) != 1:
|
||||
return (False, None)
|
||||
if isinstance(inputs[0], Tensor):
|
||||
return (True, inputs[0])
|
||||
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))
|
||||
|
||||
def infer_shape(self, inputs):
|
||||
cls_name = self.name
|
||||
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
|
||||
|
|
|
@ -140,9 +140,24 @@ class Primitive(Primitive_):
|
|||
return self.attrs[item]
|
||||
raise AttributeError(item)
|
||||
|
||||
def check_elim(self, *args):
|
||||
"""
|
||||
Check whether or not certain inputs should go into backend. Subclass in need should override this method.
|
||||
|
||||
Args:
|
||||
Same as arguments of current Primitive
|
||||
|
||||
Returns:
|
||||
A tuple of two elements, first element indicates whether or not we should filter out current arguments;
|
||||
seconde element is the output in case where we should filter out the arguments.
|
||||
"""
|
||||
return (False, None)
|
||||
|
||||
def __call__(self, *args):
|
||||
output = _run_op(self, self.name, args)
|
||||
return output
|
||||
should_elim, output = self.check_elim(*args)
|
||||
if should_elim:
|
||||
return output
|
||||
return _run_op(self, self.name, args)
|
||||
|
||||
def __getstate__(self):
|
||||
return self.__dict__
|
||||
|
|
Loading…
Reference in New Issue