forked from OSchip/llvm-project
Add ViewOp to the StandardOps dialect, which casts a 1D/i8 element type memref type to an N-D memref type.
Proposed in RFC: https://groups.google.com/a/tensorflow.org/forum/#!searchin/mlir/std.view%7Csort:date/mlir/-wKHANzDNTg/4K6nUAp8AAAJ Supports creating the N-D memref type with dynamic sizes and at a dynamic offset within the 1D base memref. This change contains op definition/parsing/printing and tests. Follow up changes will handle constant shape/layout map folding and llvm lowering. PiperOrigin-RevId: 278869990
This commit is contained in:
parent
0d545921ea
commit
c38dca7f4b
|
@ -1133,6 +1133,68 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
|
|||
}];
|
||||
}
|
||||
|
||||
def ViewOp : Std_Op<"view"> {
|
||||
let summary = "memref view operation";
|
||||
let description = [{
|
||||
The "view" operation converts a 1-D memref with i8 element type,
|
||||
to an N-D memref with arbitrary element type. In addition, the ViewOp
|
||||
supports the following arguments:
|
||||
*) A dynamic size operand must be specified for each dynamic dimension
|
||||
in the resulting view memref type.
|
||||
*) A single dynamic offset operand can be specified which represents a
|
||||
a dynamic offset within the base 1-D memref at which to create the
|
||||
resulting memref view.
|
||||
|
||||
// Allocate a flat 1D/i8 memref.
|
||||
%0 = alloc() : memref<2048xi8>
|
||||
|
||||
// ViewOp with static sizes and offset.
|
||||
%1 = view %0[][] : memref<2048xi8> to memref<64x4xf32>
|
||||
|
||||
// ViewOp with one dynamic size and a dynamic offset.
|
||||
%2 = view %0[%size0][%offset_1024]
|
||||
: memref<2048xi8> to memref<?x4xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)
|
||||
|
||||
// ViewOp creating 3D shape where two of the dim sizes are dynamic.
|
||||
// *) The dynamic size for the second dimension induces a dynamic
|
||||
// stride for the first dimension, which is represented by the
|
||||
// symbol 's0' in the layout map of the ViewOp result memref type.
|
||||
// Note that this dynamic stride will be computed from the view
|
||||
// shape and dynamic sizes.
|
||||
// *) The dynamic offset specified in the ViewOp is applied to the
|
||||
// base 1-D memref, and is represented by the symbol 's1' in the
|
||||
// layout map of the ViewOp result memref type.
|
||||
%3 = view %0[%size0, %size1][%offset_1024]
|
||||
: memref<2048xi8> to memref<?x?x4xf32,
|
||||
(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * 4 + d2 + s1)
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyMemRef:$source, Variadic<Index>:$operands);
|
||||
let results = (outs AnyMemRef);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// The result of a memref_shape_cast is always a memref.
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
|
||||
/// Returns the dynamic offset for this shape cast operation if specified.
|
||||
/// Returns nullptr if no dynamic offset was specified.
|
||||
Value *getDynamicOffset() {
|
||||
unsigned offsetPos = 1 + getType().getNumDynamicDims();
|
||||
if (offsetPos >= getNumOperands())
|
||||
return nullptr;
|
||||
return getOperand(offsetPos);
|
||||
}
|
||||
|
||||
/// Returns the dynamic sizes for this shape cast operation.
|
||||
operand_range getDynamicSizes() {
|
||||
return {operand_begin() + 1,
|
||||
operand_begin() + 1 + getType().getNumDynamicDims()};
|
||||
}
|
||||
}];
|
||||
// TODO(andydavis) Add canonicalizer to fold constants into shape and map.
|
||||
}
|
||||
|
||||
|
||||
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
|
||||
let summary = "integer binary xor";
|
||||
let hasFolder = 1;
|
||||
|
|
|
@ -529,7 +529,7 @@ void mlir::linalg::ViewOp::build(Builder *b, OperationState &result,
|
|||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ViewOp op) {
|
||||
static void print(OpAsmPrinter &p, mlir::linalg::ViewOp op) {
|
||||
p << op.getOperationName() << " " << *op.buffer() << "[";
|
||||
interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; });
|
||||
p << "] ";
|
||||
|
|
|
@ -490,14 +490,15 @@ public:
|
|||
class ViewOpConversion : public LLVMOpLowering {
|
||||
public:
|
||||
explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
|
||||
: LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {}
|
||||
: LLVMOpLowering(mlir::linalg::ViewOp::getOperationName(), context,
|
||||
lowering_) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ViewOpOperandAdaptor adaptor(operands);
|
||||
mlir::linalg::ViewOpOperandAdaptor adaptor(operands);
|
||||
|
||||
auto viewOp = cast<ViewOp>(op);
|
||||
auto viewOp = cast<mlir::linalg::ViewOp>(op);
|
||||
BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(),
|
||||
rewriter, lowering);
|
||||
LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty;
|
||||
|
|
|
@ -2340,6 +2340,83 @@ static LogicalResult verify(TruncateIOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::OperandType srcInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
Type srcType, dstType;
|
||||
return failure(
|
||||
parser.parseOperand(srcInfo) ||
|
||||
parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(srcType) ||
|
||||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
|
||||
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
|
||||
parser.resolveOperands(offsetInfo, indexType, result.operands) ||
|
||||
parser.parseKeywordType("to", dstType) ||
|
||||
parser.addTypeToList(dstType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ViewOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
|
||||
p.printOperands(op.getDynamicSizes());
|
||||
p << "][";
|
||||
auto *dynamicOffset = op.getDynamicOffset();
|
||||
if (dynamicOffset != nullptr)
|
||||
p.printOperand(dynamicOffset);
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ViewOp op) {
|
||||
auto baseType = op.getOperand(0)->getType().dyn_cast<MemRefType>();
|
||||
auto viewType = op.getResult()->getType().dyn_cast<MemRefType>();
|
||||
|
||||
// Operand 0 type and ViewOp result type must be memref.
|
||||
if (!baseType || !viewType)
|
||||
return op.emitError("operand type ") << baseType << " and result type "
|
||||
<< viewType << " are must be memref";
|
||||
|
||||
// The base memref should be rank 1 with i8 element type.
|
||||
if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8))
|
||||
return op.emitError("unsupported shape for base memref type ") << baseType;
|
||||
|
||||
// The base memref should have identity layout map (or none).
|
||||
if (baseType.getAffineMaps().size() > 1 ||
|
||||
(baseType.getAffineMaps().size() == 1 &&
|
||||
!baseType.getAffineMaps()[0].isIdentity()))
|
||||
return op.emitError("unsupported map for base memref type ") << baseType;
|
||||
|
||||
// The base memref and the view memref should be in the same memory space.
|
||||
if (baseType.getMemorySpace() != viewType.getMemorySpace())
|
||||
return op.emitError("different memory spaces specified for base memref "
|
||||
"type ")
|
||||
<< baseType << " and view memref type " << viewType;
|
||||
|
||||
// Verify that the result memref type has a strided layout map. is strided
|
||||
int64_t offset;
|
||||
llvm::SmallVector<int64_t, 4> strides;
|
||||
if (failed(mlir::getStridesAndOffset(viewType, strides, offset)))
|
||||
return op.emitError("result type ") << viewType << " is not strided";
|
||||
|
||||
// Verify that we have the correct number of operands for the result type.
|
||||
unsigned memrefOperandCount = 1;
|
||||
unsigned numDynamicDims = viewType.getNumDynamicDims();
|
||||
unsigned dynamicOffsetCount =
|
||||
offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0;
|
||||
if (op.getNumOperands() !=
|
||||
memrefOperandCount + numDynamicDims + dynamicOffsetCount)
|
||||
return op.emitError("incorrect number of operands for type ") << viewType;
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ZeroExtendIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
// CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0)
|
||||
// CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
|
||||
// CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0)
|
||||
// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 4 + d1 + s0)
|
||||
// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1)
|
||||
|
||||
// CHECK-LABEL: func @func_with_ops(%arg0: f32) {
|
||||
func @func_with_ops(f32) {
|
||||
|
@ -472,6 +474,36 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @memref_view(%arg0
|
||||
func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%0 = alloc() : memref<2048xi8>
|
||||
// Test two dynamic sizes and dynamic offset.
|
||||
// CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP0]]>
|
||||
%1 = view %0[%arg0, %arg1][%arg2]
|
||||
: memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
|
||||
|
||||
// Test two dynamic sizes and static offset.
|
||||
// CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref<?x?xf32, #[[VIEW_MAP1]]>
|
||||
%2 = view %0[%arg0, %arg1][]
|
||||
: memref<2048xi8> to memref<?x?xf32, (d0, d1) -> (d0 * 4 + d1)>
|
||||
|
||||
// Test one dynamic size and dynamic offset.
|
||||
// CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP0]]>
|
||||
%3 = view %0[%arg1][%arg2]
|
||||
: memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
|
||||
|
||||
// Test one dynamic size and static offset.
|
||||
// CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref<?x16xf32, #[[VIEW_MAP1]]>
|
||||
%4 = view %0[%arg0][]
|
||||
: memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)>
|
||||
|
||||
// Test static sizes and static offset.
|
||||
// CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]>
|
||||
%5 = view %0[][]
|
||||
: memref<2048xi8> to memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_dimop(%arg0
|
||||
func @test_dimop(%arg0: tensor<4x4x?xf32>) {
|
||||
// CHECK: %0 = dim %arg0, 2 : tensor<4x4x?xf32>
|
||||
|
|
|
@ -901,3 +901,55 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
|
|||
// expected-error@-1 {{expects different type than prior uses}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%0 = alloc() : memref<2048xi8>
|
||||
// expected-error@+1 {{incorrect number of operands for type}}
|
||||
%1 = view %0[%arg0, %arg1][]
|
||||
: memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%0 = alloc() : memref<2048xi8>
|
||||
// expected-error@+1 {{is not strided}}
|
||||
%1 = view %0[%arg0, %arg1][]
|
||||
: memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0, d1, s0)>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%0 = alloc() : memref<2048xf32>
|
||||
// expected-error@+1 {{unsupported shape for base memref}}
|
||||
%1 = view %0[%arg0, %arg1][]
|
||||
: memref<2048xf32> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%0 = alloc() : memref<2048xi8, (d0) -> (d0 floordiv 8, d0 mod 8)>
|
||||
// expected-error@+1 {{unsupported map for base memref}}
|
||||
%1 = view %0[%arg0, %arg1][]
|
||||
: memref<2048xi8, (d0) -> (d0 floordiv 8, d0 mod 8)> to
|
||||
memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
%0 = alloc() : memref<2048xi8, 2>
|
||||
// expected-error@+1 {{different memory spaces}}
|
||||
%1 = view %0[%arg0, %arg1][]
|
||||
: memref<2048xi8, 2> to
|
||||
memref<?x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0), 1>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue