[mlir] Prevent segfault in Tensor canonicalization

This segfault could occur from out of bounds accesses when simplifying
tensor.extract with a constant index and a tensor created by
tensor.from_elements.

This IR is not necesarilly invalid as it might conditionally be
never executed.

Differential Revision: https://reviews.llvm.org/D95535
This commit is contained in:
Tres Popp 2021-01-27 17:28:14 +01:00
parent 1c762a81d2
commit 0c5e4a25ee
2 changed files with 50 additions and 0 deletions

View File

@ -248,6 +248,11 @@ struct ExtractElementFromTensorFromElements
APInt index;
if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
return failure();
// Prevent out of bounds accesses. This can happen in invalid code that will
// never execute.
if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
index.getSExtValue() < 0)
return failure();
rewriter.replaceOp(extract,
tensorFromElements.getOperand(index.getZExtValue()));
return success();

View File

@ -122,6 +122,51 @@ func @extract_from_tensor.from_elements(%element : index) -> index {
// -----
// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
func @extract_negative_from_tensor.from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
%c-1 = constant -1 : index
%tensor = tensor.from_elements %element : tensor<1xindex>
%extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
// CHECK: tensor.from_elements
// CHECK: %[[RESULT:.*]] = tensor.extract
// CHECK: return %[[RESULT]]
return %extracted_element : index
}
// -----
// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
func @extract_oob_from_tensor.from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
%c1 = constant 1 : index
%tensor = tensor.from_elements %element : tensor<1xindex>
%extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
// CHECK: tensor.from_elements
// CHECK: %[[RESULT:.*]] = tensor.extract
// CHECK: return %[[RESULT]]
return %extracted_element : index
}
// -----
// Ensure the optimization doesn't segfault from bad constants
// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
func @extract_oob_from_tensor.from_elements(%element : index) -> index {
// CHECK-SAME: ([[ARG:%.*]]: index)
%c2 = constant 2 : index
%tensor = tensor.from_elements %element : tensor<1xindex>
%extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
// CHECK: tensor.from_elements
// CHECK: %[[RESULT:.*]] = tensor.extract
// CHECK: return %[[RESULT]]
return %extracted_element : index
}
// -----
// CHECK-LABEL: func @extract_from_tensor.generate
// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {