[mlir] Harden verifiers for DMA ops

DMA operation classes in the Standard dialect (`DmaStartOp` and `DmaWaitOp`)
provide helper functions that make numerous assumptions about the number and
order of operands, and about their types. However, these assumptions were not
checked in the verifier, leading to assertion failures or crashes when helper
functions were used on ill-formed ops. Some of the assuptions were checked in
the custom parser (and thus could not check assumption violations in ops
constructed programmatically, e.g., during rewrites) and others were not
checked at all. Introduce the verifiers for all these assumptions and drop
unnecessary checks in the parser that are now covered by the verifier.

Addresses PR45560.

Differential Revision: https://reviews.llvm.org/D79408
This commit is contained in:
Alex Zinenko 2020-05-05 14:09:35 +02:00
parent 0195b3a909
commit 9d273c0ef0
3 changed files with 214 additions and 37 deletions

View File

@ -286,6 +286,7 @@ public:
void print(OpAsmPrinter &p);
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
LogicalResult verify();
};
/// Prints dimension and symbol list.

View File

@ -1444,49 +1444,82 @@ ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperands(tagIndexInfos, indexType, result.operands))
return failure();
auto memrefType0 = types[0].dyn_cast<MemRefType>();
if (!memrefType0)
return parser.emitError(parser.getNameLoc(),
"expected source to be of memref type");
auto memrefType1 = types[1].dyn_cast<MemRefType>();
if (!memrefType1)
return parser.emitError(parser.getNameLoc(),
"expected destination to be of memref type");
auto memrefType2 = types[2].dyn_cast<MemRefType>();
if (!memrefType2)
return parser.emitError(parser.getNameLoc(),
"expected tag to be of memref type");
if (isStrided) {
if (parser.resolveOperands(strideInfo, indexType, result.operands))
return failure();
}
// Check that source/destination index list size matches associated rank.
if (static_cast<int64_t>(srcIndexInfos.size()) != memrefType0.getRank() ||
static_cast<int64_t>(dstIndexInfos.size()) != memrefType1.getRank())
return parser.emitError(parser.getNameLoc(),
"memref rank not equal to indices count");
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType2.getRank())
return parser.emitError(parser.getNameLoc(),
"tag memref rank not equal to indices count");
return success();
}
LogicalResult DmaStartOp::verify() {
unsigned numOperands = getNumOperands();
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements.
if (numOperands < 4)
return emitOpError("expected at least 4 operands");
// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
// 1. Source memref.
if (!getSrcMemRef().getType().isa<MemRefType>())
return emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4)
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
<< " operands";
if (!getSrcIndices().empty() &&
!llvm::all_of(getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError("expected source indices to be of index type");
// 2. Destination memref.
if (!getDstMemRef().getType().isa<MemRefType>())
return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
return emitOpError() << "expected at least " << numExpectedOperands
<< " operands";
if (!getDstIndices().empty() &&
!llvm::all_of(getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError("expected destination indices to be of index type");
// 3. Number of elements.
if (!getNumElements().getType().isIndex())
return emitOpError("expected num elements to be of index type");
// 4. Tag memref.
if (!getTagMemRef().getType().isa<MemRefType>())
return emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
return emitOpError() << "expected at least " << numExpectedOperands
<< " operands";
if (!getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError("expected tag indices to be of index type");
// DMAs from different memory spaces supported.
if (getSrcMemorySpace() == getDstMemorySpace())
return emitOpError("DMA should be between different memory spaces");
if (getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
getDstMemRefRank() + 3 + 1 &&
getNumOperands() != getTagMemRefRank() + getSrcMemRefRank() +
getDstMemRefRank() + 3 + 1 + 2) {
// Optional stride-related operands must be either both present or both
// absent.
if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2)
return emitOpError("incorrect number of operands");
// 5. Strides.
if (isStrided()) {
if (!getStride().getType().isIndex() ||
!getNumElementsPerStride().getType().isIndex())
return emitOpError(
"expected stride and num elements per stride to be of type index");
}
return success();
}
@ -1536,15 +1569,6 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) {
parser.resolveOperand(numElementsInfo, indexType, result.operands))
return failure();
auto memrefType = type.dyn_cast<MemRefType>();
if (!memrefType)
return parser.emitError(parser.getNameLoc(),
"expected tag to be of memref type");
if (static_cast<int64_t>(tagIndexInfos.size()) != memrefType.getRank())
return parser.emitError(parser.getNameLoc(),
"tag memref rank not equal to indices count");
return success();
}
@ -1554,6 +1578,32 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
return foldMemRefCast(*this);
}
LogicalResult DmaWaitOp::verify() {
// Mandatory non-variadic operands are tag and the number of elements.
if (getNumOperands() < 2)
return emitOpError() << "expected at least 2 operands";
// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
if (!getTagMemRef().getType().isa<MemRefType>())
return emitOpError() << "expected tag to be of memref type";
if (getNumOperands() != 2 + getTagMemRefRank())
return emitOpError() << "expected " << 2 + getTagMemRefRank()
<< " operands";
if (!getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return emitOpError() << "expected tag indices to be of index type";
if (!getNumElements().getType().isIndex())
return emitOpError()
<< "expected the number of elements to be of index type";
return success();
}
//===----------------------------------------------------------------------===//
// ExtractElementOp
//===----------------------------------------------------------------------===//

View File

@ -303,6 +303,13 @@ func @invalid_cmp_shape(%idx : () -> ()) {
// -----
func @dma_start_not_enough_operands() {
// expected-error@+1 {{expected at least 4 operands}}
"std.dma_start"() : () -> ()
}
// -----
func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
// expected-error@+1 {{expected source to be of memref type}}
dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32
@ -310,6 +317,24 @@ func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) {
// -----
func @dma_start_not_enough_operands_for_src(
%src: memref<2x2x2xf32>, %idx: index) {
// expected-error@+1 {{expected at least 7 operands}}
"std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> ()
}
// -----
func @dma_start_src_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected source indices to be of index type}}
"std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx)
: (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref<i32,2>, index) -> ()
}
// -----
func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
%mref = alloc() : memref<8 x f32>
// expected-error@+1 {{expected destination to be of memref type}}
@ -318,6 +343,36 @@ func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) {
// -----
func @dma_start_not_enough_operands_for_dst(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{expected at least 7 operands}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected destination indices to be of index type}}
"std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref<i32,2>, index) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected num elements to be of index type}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref<i32,2>) -> ()
}
// -----
func @dma_no_tag_memref(%tag : f32, %c0 : index) {
%mref = alloc() : memref<8 x f32>
// expected-error@+1 {{expected tag to be of memref type}}
@ -326,9 +381,80 @@ func @dma_no_tag_memref(%tag : f32, %c0 : index) {
// -----
func @dma_start_not_enough_operands_for_tag(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>) {
// expected-error@+1 {{expected at least 8 operands}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> ()
}
// -----
func @dma_start_dst_index_wrong_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<2xi32,2>, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> ()
}
// -----
func @dma_start_same_space(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>,
%tag: memref<i32,2>) {
// expected-error@+1 {{DMA should be between different memory spaces}}
dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref<i32,2>
}
// -----
func @dma_start_too_many_operands(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>) {
// expected-error@+1 {{incorrect number of operands}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, index, index) -> ()
}
// -----
func @dma_start_wrong_stride_type(
%src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>,
%tag: memref<i32,2>, %flt: f32) {
// expected-error@+1 {{expected stride and num elements per stride to be of type index}}
"std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt)
: (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<i32,2>, index, f32) -> ()
}
// -----
func @dma_wait_not_enough_operands() {
// expected-error@+1 {{expected at least 2 operands}}
"std.dma_wait"() : () -> ()
}
// -----
func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) {
// expected-error@+1 {{expected tag to be of memref type}}
dma_wait %tag[%c0], %arg0 : f32
"std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> ()
}
// -----
func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
// expected-error@+1 {{expected tag indices to be of index type}}
"std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> ()
}
// -----
func @dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) {
// expected-error@+1 {{expected the number of elements to be of index type}}
"std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> ()
}
// -----