fix param check for check_elim

This commit is contained in:
BowenK 2020-06-20 16:07:56 +08:00
parent 7f891f62e5
commit 62aed2ff30
1 changed files with 5 additions and 8 deletions

View File

@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import typing
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
@ -196,8 +197,7 @@ class Cast(PrimitiveWithInfer):
data = x.default_input
if data.dtype == dtype:
return (True, x)
return (False, None)
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
return (False, None)
def __infer__(self, x, t):
src_type = x['dtype']
@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer):
def check_elim(self, base_tensor, multiplier):
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
def is_all_zeros(v_tuple):
return all(v == 1 for v in v_tuple)
if is_all_zeros(multiplier):
raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
if all(v == 1 for v in multiplier):
return (True, base_tensor)
return (False, None)
@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer):
validator.check_value_type("shape", multiples_v, [tuple], self.name)
for i, multiple in enumerate(multiples_v):
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
len_sub = len(multiples_v) - len(x_shp)
multiples_w = None
if len_sub == 0: