add overflow check for make_range and optimize isinstance processing

This commit is contained in:
buxue 2020-09-17 11:34:55 +08:00
parent c95ed54fe1
commit 14f6c95c28
4 changed files with 16 additions and 6 deletions

View File

@ -173,7 +173,7 @@ def check_type_same(x_type, base_type):
"""Check x_type is same as base_type."""
if mstype.issubclass_(x_type, base_type):
return True
raise TypeError(f"The arg 'x' should be a {base_type}, but got {x_type}.")
return False
@constexpr

View File

@ -489,15 +489,25 @@ AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr
if (slide.step <= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
}
for (int i = slide.start; i < slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i));
if (i > 0 && INT_MAX - i < slide.step) {
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
"will cause integer overflow.";
}
}
} else {
if (slide.step >= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
}
for (int i = slide.start; i > slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i));
if (i < 0 && INT_MIN - i > slide.step) {
MS_EXCEPTION(ValueError) << "For make range, the required cycles number is greater than max cycles number, "
"will cause integer overflow.";
}
}
}

View File

@ -268,7 +268,7 @@ def _tensor_index_by_tuple_slice(data, t):
def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types"""
if len(tuple_index) == 1:
return data[tuple_index[0]]
return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM)
if index_elements_type == const_utils.NO_TENSOR:

View File

@ -40,17 +40,17 @@ def test_number_not_in_tuple():
if self.number_in not in self.tuple_:
ret += 1
if self.number_not_in not in self.tuple_:
ret += 1
ret += 2
if self.number_in not in self.list_:
ret += 3
if self.number_not_in not in self.list_:
ret += 3
ret += 4
if self.str_in not in self.dict_:
ret += 5
if self.str_not_in not in self.dict_:
ret += 5
ret += 6
return ret
net = Net()
output = net()
assert output == 9
assert output == 12