diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 6b1794c2439..2e5163ac1a4 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -184,6 +184,13 @@ class Cast(PrimitiveWithInfer): """init Cast""" self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) + def check_elim(self, x, dtype): + if isinstance(x, Tensor): + if x.dtype() == dtype: + return (True, x) + return (False, None) + raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) + def __infer__(self, x, t): src_type = x['dtype'] dst_type = t['value'] @@ -1310,6 +1317,15 @@ class Tile(PrimitiveWithInfer): """init Tile""" self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output']) + def check_elim(self, base_tensor, multiplier): + if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): + raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) + def is_all_zeros(v_tuple): + return all(v == 1 for v in v_tuple) + if is_all_zeros(multiplier): + return (True, base_tensor) + return (False, None) + def __infer__(self, x, multiples): multiples_v = multiples['value'] x_shp = x['shape'] diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index af3789201f9..0bba0485cc1 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -705,6 +705,13 @@ class AddN(PrimitiveWithInfer): def __init__(self): self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) + def check_elim(self, inputs): + if len(inputs) != 1: + return (False, None) + if isinstance(inputs[0], Tensor): + return (True, inputs[0]) + raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0]))) + def infer_shape(self, inputs): cls_name = self.name validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index f456421f704..a56f69ceeb9 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -140,9 +140,24 @@ class Primitive(Primitive_): return self.attrs[item] raise AttributeError(item) + def check_elim(self, *args): + """ + Check whether or not certain inputs should go into backend. Subclass in need should override this method. + + Args: + Same as arguments of current Primitive + + Returns: + A tuple of two elements, first element indicates whether or not we should filter out current arguments; + seconde element is the output in case where we should filter out the arguments. + """ + return (False, None) + def __call__(self, *args): - output = _run_op(self, self.name, args) - return output + should_elim, output = self.check_elim(*args) + if should_elim: + return output + return _run_op(self, self.name, args) def __getstate__(self): return self.__dict__