[mlir] Allow constructing a ValueRange from an ArrayRef<BlockArgument>

Summary: This was a missed case when ValueRange was originally added, and allows for constructing a ValueRange from the arguments of a block.

Differential Revision: https://reviews.llvm.org/D74363
This commit is contained in:
River Riddle 2020-02-12 09:46:21 -08:00 committed by River Riddle
parent 26edb21c29
commit c832145960
4 changed files with 10 additions and 12 deletions

View File

@ -220,13 +220,11 @@ public:
return getBlocks().front().getArgument(idx);
}
// Supports non-const operand iteration.
/// Support argument iteration.
using args_iterator = Block::args_iterator;
args_iterator args_begin() { return front().args_begin(); }
args_iterator args_end() { return front().args_end(); }
iterator_range<args_iterator> getArguments() {
return {args_begin(), args_end()};
}
Block::BlockArgListType getArguments() { return front().getArguments(); }
//===--------------------------------------------------------------------===//
// Argument Attributes

View File

@ -658,6 +658,8 @@ public:
: ValueRange(OperandRange(values)) {}
ValueRange(iterator_range<ResultRange::iterator> values)
: ValueRange(ResultRange(values)) {}
ValueRange(ArrayRef<BlockArgument> values)
: ValueRange(ArrayRef<Value>(values.data(), values.size())) {}
ValueRange(ArrayRef<Value> values = llvm::None);
ValueRange(OperandRange values);
ValueRange(ResultRange values);

View File

@ -332,10 +332,9 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
}
static void print(OpAsmPrinter &p, ParallelOp op) {
p << op.getOperationName() << " (";
p.printOperands(op.getBody()->getArguments());
p << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
<< op.step() << ")";
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
<< op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
<< ")";
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op.getAttrs());
if (!op.results().empty())

View File

@ -897,10 +897,9 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
linalg_dilated_conv_nhwc(
makeValueHandles(llvm::to_vector<3>(f.getArguments())),
/*depth_multiplier=*/7,
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
linalg_dilated_conv_nhwc(makeValueHandles(f.getArguments()),
/*depth_multiplier=*/7,
/*strides=*/{3, 4}, /*dilations=*/{5, 6});
f.print(llvm::outs());
f.erase();