Add ViewOp to the StandardOps dialect, which casts a 1-D memref with i8 element type to an N-D memref type.

Proposed in RFC: https://groups.google.com/a/tensorflow.org/forum/#!searchin/mlir/std.view%7Csort:date/mlir/-wKHANzDNTg/4K6nUAp8AAAJ

Supports creating the N-D memref type with dynamic sizes and at a dynamic offset within the 1-D base memref.
This change contains the op definition, parsing, printing, and tests. Follow-up changes will handle constant shape/layout map folding and LLVM lowering.

PiperOrigin-RevId: 278869990
Andy Davis 2019-11-06 08:53:39 -08:00 committed by A. Unique TensorFlower
parent 0d545921ea
commit c38dca7f4b
6 changed files with 228 additions and 4 deletions

View File

@@ -1133,6 +1133,68 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
  }];
}
def ViewOp : Std_Op<"view"> {
  let summary = "memref view operation";
  let description = [{
    The "view" operation converts a 1-D memref with i8 element type
    to an N-D memref with an arbitrary element type. In addition, the ViewOp
    supports the following arguments:
    *) A dynamic size operand must be specified for each dynamic dimension
       in the resulting view memref type.
    *) A single dynamic offset operand can be specified, which represents
       a dynamic offset within the base 1-D memref at which to create the
       resulting memref view.

    // Allocate a flat 1-D/i8 memref.
    %0 = alloc() : memref<2048xi8>

    // ViewOp with static sizes and offset.
    %1 = view %0[][] : memref<2048xi8> to memref<64x4xf32>

    // ViewOp with one dynamic size and a dynamic offset.
    %2 = view %0[%size0][%offset_1024]
      : memref<2048xi8> to memref<?x4xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>

    // ViewOp creating a 3-D shape where two of the dim sizes are dynamic.
    // *) The dynamic size for the second dimension induces a dynamic
    //    stride for the first dimension, which is represented by the
    //    symbol 's0' in the layout map of the ViewOp result memref type.
    //    Note that this dynamic stride will be computed from the view
    //    shape and dynamic sizes.
    // *) The dynamic offset specified in the ViewOp is applied to the
    //    base 1-D memref, and is represented by the symbol 's1' in the
    //    layout map of the ViewOp result memref type.
    %3 = view %0[%size0, %size1][%offset_1024]
      : memref<2048xi8> to
        memref<?x?x4xf32, (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * 4 + d2 + s1)>
  }];

  let arguments = (ins AnyMemRef:$source, Variadic<Index>:$operands);
  let results = (outs AnyMemRef);
  let extraClassDeclaration = [{
    /// The result of a view operation is always a memref.
    MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }

    /// Returns the dynamic offset for this view operation if specified.
    /// Returns nullptr if no dynamic offset was specified.
    Value *getDynamicOffset() {
      unsigned offsetPos = 1 + getType().getNumDynamicDims();
      if (offsetPos >= getNumOperands())
        return nullptr;
      return getOperand(offsetPos);
    }

    /// Returns the dynamic sizes for this view operation.
    operand_range getDynamicSizes() {
      return {operand_begin() + 1,
              operand_begin() + 1 + getType().getNumDynamicDims()};
    }
  }];

  // TODO(andydavis) Add canonicalizer to fold constants into shape and map.
}
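The accessors in extraClassDeclaration encode the operand layout [source, dynamic sizes..., optional dynamic offset]: getDynamicSizes skips operand 0 (the base memref), and getDynamicOffset probes one position past the sizes. A minimal consumer sketch under that assumption (helper name hypothetical; Value * matches the MLIR API at this revision):

// Hypothetical helper: partition a ViewOp's operands via the accessors above.
static void inspectView(ViewOp op) {
  Value *base = op.getOperand(0); // the flat 1-D i8 source memref
  (void)base;
  // One index operand per '?' dimension in the result type.
  for (Value *size : op.getDynamicSizes())
    (void)size;
  // Trailing operand, if present, binds the offset symbol in the layout map.
  if (Value *offset = op.getDynamicOffset())
    (void)offset;
}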
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
  let summary = "integer binary xor";
  let hasFolder = 1;

View File

@@ -529,7 +529,7 @@ void mlir::linalg::ViewOp::build(Builder *b, OperationState &result,
  result.addAttributes(attrs);
}
-static void print(OpAsmPrinter &p, ViewOp op) {
+static void print(OpAsmPrinter &p, mlir::linalg::ViewOp op) {
  p << op.getOperationName() << " " << *op.buffer() << "[";
  interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; });
  p << "] ";

View File

@@ -490,14 +490,15 @@ public:
class ViewOpConversion : public LLVMOpLowering {
public:
  explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
-      : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
+      : LLVMOpLowering(mlir::linalg::ViewOp::getOperationName(), context,
+                       lowering_) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
-    ViewOpOperandAdaptor adaptor(operands);
+    mlir::linalg::ViewOpOperandAdaptor adaptor(operands);
-    auto viewOp = cast<ViewOp>(op);
+    auto viewOp = cast<mlir::linalg::ViewOp>(op);
    BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(),
                                    rewriter, lowering);
    LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;

View File

@@ -2340,6 +2340,83 @@ static LogicalResult verify(TruncateIOp op) {
  return success();
}
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  return failure(
      parser.parseOperand(srcInfo) ||
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}
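Note the resolution order above: the source operand first, then the size operands, then the offset operand. That is exactly the [source, dynamic sizes..., optional dynamic offset] layout that getDynamicSizes/getDynamicOffset assume. A hedged construction sketch that produces the same operand order (helper name hypothetical; assumes the generic ODS-generated builder taking result types, operands, and attributes):

// Hypothetical helper: create a std.view with operands in parser order.
static Value *createView(OpBuilder &b, Location loc, MemRefType viewType,
                         Value *source, ArrayRef<Value *> dynamicSizes,
                         Value *dynamicOffset) {
  SmallVector<Value *, 8> operands;
  operands.push_back(source);                                // base memref
  operands.append(dynamicSizes.begin(), dynamicSizes.end()); // '?' dims
  if (dynamicOffset)
    operands.push_back(dynamicOffset);                       // offset symbol
  // Generic ODS builder form: (result types, operands, attributes).
  auto viewOp = b.create<ViewOp>(loc, ArrayRef<Type>{viewType}, operands,
                                 ArrayRef<NamedAttribute>{});
  return viewOp.getResult();
}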
static void print(OpAsmPrinter &p, ViewOp op) {
  p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
  p.printOperands(op.getDynamicSizes());
  p << "][";
  auto *dynamicOffset = op.getDynamicOffset();
  if (dynamicOffset != nullptr)
    p.printOperand(dynamicOffset);
  p << ']';
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
}
static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>();
  auto viewType = op.getResult()->getType().dyn_cast<MemRefType>();

  // Operand 0 type and ViewOp result type must be memref.
  if (!baseType || !viewType)
    return op.emitError("operand type ") << baseType << " and result type "
                                         << viewType << " must be memref";

  // The base memref should be rank 1 with i8 element type.
  if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8))
    return op.emitError("unsupported shape for base memref type ") << baseType;

  // The base memref should have an identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that the result memref type has a strided layout map.
  int64_t offset;
  llvm::SmallVector<int64_t, 4> strides;
  if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
    return op.emitError("result type ") << viewType << " is not strided";

  // Verify that we have the correct number of operands for the result type.
  unsigned memrefOperandCount = 1;
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  unsigned dynamicOffsetCount =
      offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0;
  if (op.getNumOperands() !=
      memrefOperandCount + numDynamicDims + dynamicOffsetCount)
    return op.emitError("incorrect number of operands for type ") << viewType;
  return success();
}
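As a concrete instance of the operand-count rule: for the result type memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>, getStridesAndOffset reports a dynamic offset (the s0 symbol), so the verifier expects 1 (base memref) + 2 (dynamic sizes) + 1 (dynamic offset) = 4 operands, matching 'view %0[%arg0, %arg1][%arg2]' in the tests below. The same computation, factored into a standalone sketch (hypothetical helper, not part of this patch):

// Hypothetical helper mirroring the verifier's operand-count rule.
static unsigned expectedViewOperandCount(MemRefType viewType) {
  int64_t offset;
  llvm::SmallVector<int64_t, 4> strides;
  if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
    return 0; // non-strided result types are rejected earlier
  unsigned count = 1 + viewType.getNumDynamicDims(); // base + dynamic sizes
  if (offset == MemRefType::getDynamicStrideOrOffset())
    ++count; // one extra index operand binds the dynamic offset symbol
  return count;
}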
//===----------------------------------------------------------------------===//
// ZeroExtendIOp
//===----------------------------------------------------------------------===//

View File

@@ -10,6 +10,8 @@
// CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0)
// CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
// CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0)
// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 4 + d1 + s0)
// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1)
// CHECK-LABEL: func @func_with_ops(%arg0: f32) {
func @func_with_ops(f32) {
@@ -472,6 +474,36 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>) {
return
}
// CHECK-LABEL: func @memref_view(%arg0
func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = alloc() : memref<2048xi8>
  // Test two dynamic sizes and dynamic offset.
  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP0]]>
  %1 = view %0[%arg0, %arg1][%arg2]
    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>

  // Test two dynamic sizes and static offset.
  // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP1]]>
  %2 = view %0[%arg0, %arg1][]
    : memref<2048xi8> to memref<?x?xf32, (d0, d1) -> (d0 * 4 + d1)>

  // Test one dynamic size and dynamic offset.
  // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP0]]>
  %3 = view %0[%arg1][%arg2]
    : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>

  // Test one dynamic size and static offset.
  // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref<?x16xf32, #[[VIEW_MAP1]]>
  %4 = view %0[%arg0][]
    : memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)>

  // Test static sizes and static offset.
  // CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]>
  %5 = view %0[][]
    : memref<2048xi8> to memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>
  return
}

// CHECK-LABEL: func @test_dimop(%arg0
func @test_dimop(%arg0: tensor<4x4x?xf32>) {
  // CHECK: %0 = dim %arg0, 2 : tensor<4x4x?xf32>

View File

@@ -901,3 +901,55 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
  // expected-error@-1 {{expects different type than prior uses}}
  return
}

// -----

func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = alloc() : memref<2048xi8>
  // expected-error@+1 {{incorrect number of operands for type}}
  %1 = view %0[%arg0, %arg1][]
    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
  return
}

// -----

func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = alloc() : memref<2048xi8>
  // expected-error@+1 {{is not strided}}
  %1 = view %0[%arg0, %arg1][]
    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0, d1, s0)>
  return
}

// -----

func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = alloc() : memref<2048xf32>
  // expected-error@+1 {{unsupported shape for base memref}}
  %1 = view %0[%arg0, %arg1][]
    : memref<2048xf32> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
  return
}

// -----

func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = alloc() : memref<2048xi8, (d0) -> (d0 floordiv 8, d0 mod 8)>
  // expected-error@+1 {{unsupported map for base memref}}
  %1 = view %0[%arg0, %arg1][]
    : memref<2048xi8, (d0) -> (d0 floordiv 8, d0 mod 8)> to
      memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
  return
}

// -----

func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
  %0 = alloc() : memref<2048xi8, 2>
  // expected-error@+1 {{different memory spaces}}
  %1 = view %0[%arg0, %arg1][]
    : memref<2048xi8, 2> to
      memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0), 1>
  return
}