forked from mindspore-Ecosystem/mindspore
change cumsum op python to allow for float64 input
This commit is contained in:
parent
1965ecb9a1
commit
c0ebc36c78
|
@ -972,7 +972,7 @@ class CumSum(PrimitiveWithInfer):
|
|||
if axis['value'] is None:
|
||||
raise ValueError(f"For {self.name}, axis must be const.")
|
||||
validator.check_value_type('axis', axis['value'], [int], cls_name)
|
||||
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
||||
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.float64]
|
||||
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
|
||||
return {'shape': x_shp,
|
||||
'dtype': x['dtype'],
|
||||
|
|
Loading…
Reference in New Issue