forked from OSchip/llvm-project
[mlir] Harden verifiers for DMA ops
DMA operation classes in the Standard dialect (`DmaStartOp` and `DmaWaitOp`) provide helper functions that make numerous assumptions about the number and order of operands, and about their types. However, these assumptions were not checked in the verifier, leading to assertion failures or crashes when helper functions were used on ill-formed ops. Some of the assuptions were checked in the custom parser (and thus could not check assumption violations in ops constructed programmatically, e.g., during rewrites) and others were not checked at all. Introduce the verifiers for all these assumptions and drop unnecessary checks in the parser that are now covered by the verifier. Addresses PR45560. Differential Revision: https://reviews.llvm.org/D79408
This commit is contained in:
parent
0195b3a909
commit
9d273c0ef0
|
@ -286,6 +286,7 @@ public:
|
|||
void print(OpAsmPrinter &p);
|
||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
LogicalResult verify();
|
||||
};
|
||||
|
||||
/// Prints dimension and symbol list.
|
||||
|
|
|
@ -1444,49 +1444,82 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
parser.resolveOperands(tagIndexInfos, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
auto memrefType0 = types[0].dyn_cast<MemRefType>();
|
||||
if (!memrefType0)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected source to be of memref type");
|
||||
|
||||
auto memrefType1 = types[1].dyn_cast<MemRefType>();
|
||||
if (!memrefType1)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected destination to be of memref type");
|
||||
|
||||
auto memrefType2 = types[2].dyn_cast<MemRefType>();
|
||||
if (!memrefType2)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected tag to be of memref type");
|
||||
|
||||
if (isStrided) {
|
||||
if (parser.resolveOperands(strideInfo, indexType, result.operands))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that source/destination index list size matches associated rank.
|
||||
if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
|
||||
static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"memref rank not equal to indices count");
|
||||
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"tag memref rank not equal to indices count");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DmaStartOp::verify() {
|
||||
unsigned numOperands = getNumOperands();
|
||||
|
||||
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
|
||||
// the number of elements.
|
||||
if (numOperands < 4)
|
||||
return emitOpError("expected at least 4 operands");
|
||||
|
||||
// Check types of operands. The order of these calls is important: the later
|
||||
// calls rely on some type properties to compute the operand position.
|
||||
// 1. Source memref.
|
||||
if (!getSrcMemRef().getType().isa<MemRefType>())
|
||||
return emitOpError("expected source to be of memref type");
|
||||
if (numOperands < getSrcMemRefRank() + 4)
|
||||
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
|
||||
<< " operands";
|
||||
if (!getSrcIndices().empty() &&
|
||||
!llvm::all_of(getSrcIndices().getTypes(),
|
||||
[](Type t) { return t.isIndex(); }))
|
||||
return emitOpError("expected source indices to be of index type");
|
||||
|
||||
// 2. Destination memref.
|
||||
if (!getDstMemRef().getType().isa<MemRefType>())
|
||||
return emitOpError("expected destination to be of memref type");
|
||||
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
|
||||
if (numOperands < numExpectedOperands)
|
||||
return emitOpError() << "expected at least " << numExpectedOperands
|
||||
<< " operands";
|
||||
if (!getDstIndices().empty() &&
|
||||
!llvm::all_of(getDstIndices().getTypes(),
|
||||
[](Type t) { return t.isIndex(); }))
|
||||
return emitOpError("expected destination indices to be of index type");
|
||||
|
||||
// 3. Number of elements.
|
||||
if (!getNumElements().getType().isIndex())
|
||||
return emitOpError("expected num elements to be of index type");
|
||||
|
||||
// 4. Tag memref.
|
||||
if (!getTagMemRef().getType().isa<MemRefType>())
|
||||
return emitOpError("expected tag to be of memref type");
|
||||
numExpectedOperands += getTagMemRefRank();
|
||||
if (numOperands < numExpectedOperands)
|
||||
return emitOpError() << "expected at least " << numExpectedOperands
|
||||
<< " operands";
|
||||
if (!getTagIndices().empty() &&
|
||||
!llvm::all_of(getTagIndices().getTypes(),
|
||||
[](Type t) { return t.isIndex(); }))
|
||||
return emitOpError("expected tag indices to be of index type");
|
||||
|
||||
// DMAs from different memory spaces supported.
|
||||
if (getSrcMemorySpace() == getDstMemorySpace())
|
||||
return emitOpError("DMA should be between different memory spaces");
|
||||
|
||||
if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
|
||||
getDstMemRefRank() + 3 + 1 &&
|
||||
getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
|
||||
getDstMemRefRank() + 3 + 1 + 2) {
|
||||
// Optional stride-related operands must be either both present or both
|
||||
// absent.
|
||||
if (numOperands != numExpectedOperands &&
|
||||
numOperands != numExpectedOperands + 2)
|
||||
return emitOpError("incorrect number of operands");
|
||||
|
||||
// 5. Strides.
|
||||
if (isStrided()) {
|
||||
if (!getStride().getType().isIndex() ||
|
||||
!getNumElementsPerStride().getType().isIndex())
|
||||
return emitOpError(
|
||||
"expected stride and num elements per stride to be of type index");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1536,15 +1569,6 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
parser.resolveOperand(numElementsInfo, indexType, result.operands))
|
||||
return failure();
|
||||
|
||||
auto memrefType = type.dyn_cast<MemRefType>();
|
||||
if (!memrefType)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected tag to be of memref type");
|
||||
|
||||
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"tag memref rank not equal to indices count");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1554,6 +1578,32 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
return foldMemRefCast(*this);
|
||||
}
|
||||
|
||||
LogicalResult DmaWaitOp::verify() {
|
||||
// Mandatory non-variadic operands are tag and the number of elements.
|
||||
if (getNumOperands() < 2)
|
||||
return emitOpError() << "expected at least 2 operands";
|
||||
|
||||
// Check types of operands. The order of these calls is important: the later
|
||||
// calls rely on some type properties to compute the operand position.
|
||||
if (!getTagMemRef().getType().isa<MemRefType>())
|
||||
return emitOpError() << "expected tag to be of memref type";
|
||||
|
||||
if (getNumOperands() != 2 + getTagMemRefRank())
|
||||
return emitOpError() << "expected " << 2 + getTagMemRefRank()
|
||||
<< " operands";
|
||||
|
||||
if (!getTagIndices().empty() &&
|
||||
!llvm::all_of(getTagIndices().getTypes(),
|
||||
[](Type t) { return t.isIndex(); }))
|
||||
return emitOpError() << "expected tag indices to be of index type";
|
||||
|
||||
if (!getNumElements().getType().isIndex())
|
||||
return emitOpError()
|
||||
<< "expected the number of elements to be of index type";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -303,6 +303,13 @@ func @invalid_cmp_shape(%idx : () -> ()) {
|
|||
|
||||
// -----
|
||||
|
||||
func @dma_start_not_enough_operands() {
|
||||
// expected-error@+1 {{expected at least 4 operands}}
|
||||
"std.dma_start"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
|
||||
// expected-error@+1 {{expected source to be of memref type}}
|
||||
dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
|
||||
|
@ -310,6 +317,24 @@ func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
|
|||
|
||||
// -----
|
||||
|
||||
func @dma_start_not_enough_operands_for_src(
|
||||
%src: memref<2x2x2xf32>, %idx: index) {
|
||||
// expected-error@+1 {{expected at least 7 operands}}
|
||||
"std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_src_index_wrong_type(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<i32,2>, %flt: f32) {
|
||||
// expected-error@+1 {{expected source indices to be of index type}}
|
||||
"std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
|
||||
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
|
||||
%mref = alloc() : memref<8 x f32>
|
||||
// expected-error@+1 {{expected destination to be of memref type}}
|
||||
|
@ -318,6 +343,36 @@ func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
|
|||
|
||||
// -----
|
||||
|
||||
func @dma_start_not_enough_operands_for_dst(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<i32,2>) {
|
||||
// expected-error@+1 {{expected at least 7 operands}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_dst_index_wrong_type(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<i32,2>, %flt: f32) {
|
||||
// expected-error@+1 {{expected destination indices to be of index type}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_dst_index_wrong_type(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<i32,2>, %flt: f32) {
|
||||
// expected-error@+1 {{expected num elements to be of index type}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
|
||||
%mref = alloc() : memref<8 x f32>
|
||||
// expected-error@+1 {{expected tag to be of memref type}}
|
||||
|
@ -326,9 +381,80 @@ func @dma_no_tag_memref(%tag : f32, %c0 : index) {
|
|||
|
||||
// -----
|
||||
|
||||
func @dma_start_not_enough_operands_for_tag(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<2xi32,2>) {
|
||||
// expected-error@+1 {{expected at least 8 operands}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_dst_index_wrong_type(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<2xi32,2>, %flt: f32) {
|
||||
// expected-error@+1 {{expected tag indices to be of index type}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_same_space(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>,
|
||||
%tag: memref<i32,2>) {
|
||||
// expected-error@+1 {{DMA should be between different memory spaces}}
|
||||
dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref<i32,2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_too_many_operands(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<i32,2>) {
|
||||
// expected-error@+1 {{incorrect number of operands}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_start_wrong_stride_type(
|
||||
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
|
||||
%tag: memref<i32,2>, %flt: f32) {
|
||||
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
|
||||
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
|
||||
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_wait_not_enough_operands() {
|
||||
// expected-error@+1 {{expected at least 2 operands}}
|
||||
"std.dma_wait"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
|
||||
// expected-error@+1 {{expected tag to be of memref type}}
|
||||
dma_wait %tag[%c0], %arg0 : f32
|
||||
"std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
|
||||
// expected-error@+1 {{expected tag indices to be of index type}}
|
||||
"std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
|
||||
// expected-error@+1 {{expected the number of elements to be of index type}}
|
||||
"std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue