forked from mindspore-Ecosystem/mindspore
fix param check for check_elim
This commit is contained in:
parent
7f891f62e5
commit
62aed2ff30
|
@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_
|
|||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..._c_expression import typing
|
||||
|
||||
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
|
||||
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
|
||||
|
@ -196,8 +197,7 @@ class Cast(PrimitiveWithInfer):
|
|||
data = x.default_input
|
||||
if data.dtype == dtype:
|
||||
return (True, x)
|
||||
return (False, None)
|
||||
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
|
||||
return (False, None)
|
||||
|
||||
def __infer__(self, x, t):
|
||||
src_type = x['dtype']
|
||||
|
@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer):
|
|||
|
||||
def check_elim(self, base_tensor, multiplier):
|
||||
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
|
||||
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
|
||||
def is_all_zeros(v_tuple):
|
||||
return all(v == 1 for v in v_tuple)
|
||||
if is_all_zeros(multiplier):
|
||||
raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
|
||||
if all(v == 1 for v in multiplier):
|
||||
return (True, base_tensor)
|
||||
return (False, None)
|
||||
|
||||
|
@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer):
|
|||
validator.check_value_type("shape", multiples_v, [tuple], self.name)
|
||||
for i, multiple in enumerate(multiples_v):
|
||||
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
|
||||
valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32]
|
||||
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
|
||||
validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
|
||||
len_sub = len(multiples_v) - len(x_shp)
|
||||
multiples_w = None
|
||||
if len_sub == 0:
|
||||
|
|
Loading…
Reference in New Issue