forked from mindspore-Ecosystem/mindspore
add overflow check for make_range and optimize isinstance processing
This commit is contained in:
parent
c95ed54fe1
commit
14f6c95c28
|
@ -173,7 +173,7 @@ def check_type_same(x_type, base_type):
|
|||
"""Check x_type is same as base_type."""
|
||||
if mstype.issubclass_(x_type, base_type):
|
||||
return True
|
||||
raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.")
|
||||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
|
|
|
@ -489,15 +489,25 @@ AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
if (slide.step <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
|
||||
}
|
||||
|
||||
for (int i = slide.start; i < slide.stop; i += slide.step) {
|
||||
args.push_back(abstract::FromValue(i));
|
||||
if (i > 0 && INT_MAX - i < slide.step) {
|
||||
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
|
||||
"will cause integer overflow.";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (slide.step >= 0) {
|
||||
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
|
||||
}
|
||||
|
||||
for (int i = slide.start; i > slide.stop; i += slide.step) {
|
||||
args.push_back(abstract::FromValue(i));
|
||||
if (i < 0 && INT_MIN - i > slide.step) {
|
||||
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
|
||||
"will cause integer overflow.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -268,7 +268,7 @@ def _tensor_index_by_tuple_slice(data, t):
|
|||
def tensor_index_by_tuple(data, tuple_index):
|
||||
"""Tensor getitem by tuple of various types"""
|
||||
if len(tuple_index) == 1:
|
||||
return data[tuple_index[0]]
|
||||
return data[tuple_index[0]]
|
||||
indexes_types = hyper_map(F.typeof, tuple_index)
|
||||
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
|
||||
if index_elements_type == const_utils.NO_TENSOR:
|
||||
|
|
|
@ -40,17 +40,17 @@ def test_number_not_in_tuple():
|
|||
if self.number_in not in self.tuple_:
|
||||
ret += 1
|
||||
if self.number_not_in not in self.tuple_:
|
||||
ret += 1
|
||||
ret += 2
|
||||
if self.number_in not in self.list_:
|
||||
ret += 3
|
||||
if self.number_not_in not in self.list_:
|
||||
ret += 3
|
||||
ret += 4
|
||||
if self.str_in not in self.dict_:
|
||||
ret += 5
|
||||
if self.str_not_in not in self.dict_:
|
||||
ret += 5
|
||||
ret += 6
|
||||
return ret
|
||||
|
||||
net = Net()
|
||||
output = net()
|
||||
assert output == 9
|
||||
assert output == 12
|
||||
|
|
Loading…
Reference in New Issue