Add a linalg.transpose op

A linalg.transpose op is a pure metadata operation that takes a view + permutation map and produces
another view of the same underlying data, with a different reindexing. This is a
pure metadata operation that does not touch the underlying data.

Example:

```
  %t = linalg.transpose %v (i, j) -> (j, i) : !linalg.view<?x?xf32>
```

PiperOrigin-RevId: 265139429
This commit is contained in:
Nicolas Vasilache 2019-08-23 14:47:46 -07:00 committed by A. Unique TensorFlower
parent 32052c8417
commit 2c2c9ffd80
4 changed files with 96 additions and 10 deletions

View File

@ -19,12 +19,13 @@
//
//===----------------------------------------------------------------------===//
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
#ifdef LINALG_OPS
#else
#define LINALG_OPS
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
// Base class for Linalg dialect ops that do not correspond to library calls.
class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Linalg_Dialect, mnemonic, traits> {
@ -398,6 +399,37 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
}];
}
def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
Arguments<(ins View:$view, AffineMapAttr:$permutation)>,
Results<(outs View)> {
let summary = "transpose operation produces a new view (metadata-only)";
let description = [{
The "linalg.transpose" op produces a linalg.view whose sizes and strides are
a permutation of the original. This is a pure metadata transformation.
Example:
%1 = linalg.transpose %0 (i, j) -> (j, i) : !linalg.view<?x?xf32>
}];
let builders = [OpBuilder<
"Builder *b, OperationState *result, Value *view, "
"AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
let verifier = [{
if (!permutation().isPermutation())
return emitOpError("expected a permutation map");
if (permutation().getNumDims() != getViewType().getRank())
return emitOpError("expected a permutation map of same rank as the view");
return success();
}];
let extraClassDeclaration = [{
static StringRef getPermutationAttrName() { return "permutation"; }
ViewType getViewType() { return view()->getType().cast<ViewType>(); }
}];
}
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
Results<(outs View)> {

View File

@ -20,6 +20,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/EDSC/Helpers.h"
#include "mlir/IR/AffineExpr.h"
@ -30,8 +32,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/FoldUtils.h"
@ -599,6 +599,39 @@ static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
parser->addTypeToList(viewType, result->types));
}
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
void mlir::linalg::TransposeOp::build(Builder *b, OperationState *result,
Value *view, AffineMapAttr permutation,
ArrayRef<NamedAttribute> attrs) {
// TODO(ntv): once views have static dimensions, compute the permuted type.
build(b, result, view->getType(), view, attrs);
result->addAttribute(TransposeOp::getPermutationAttrName(), permutation);
}
static void print(OpAsmPrinter *p, TransposeOp op) {
*p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
p->printOptionalAttrDict(op.getAttrs(),
{TransposeOp::getPermutationAttrName()});
*p << " : " << op.view()->getType();
}
static ParseResult parseTransposeOp(OpAsmParser *parser,
OperationState *result) {
OpAsmParser::OperandType view;
AffineMapAttr permutation;
Type type;
return failure(parser->parseOperand(view) ||
parser->parseAttribute(permutation,
TransposeOp::getPermutationAttrName(),
result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(view, type, result->operands) ||
parser->addTypeToList(type, result->types));
}
//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

View File

@ -85,6 +85,20 @@ func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
// -----
func @transpose_not_permutation(%v : !linalg.view<?x?xf32>) {
// expected-error @+1 {{expected a permutation map}}
linalg.transpose %v (i, j) -> (i, i) : !linalg.view<?x?xf32>
}
// -----
func @transpose_bad_rank(%v : !linalg.view<?x?xf32>) {
// expected-error @+1 {{expected a permutation map of same rank as the view}}
linalg.transpose %v (i) -> (i) : !linalg.view<?x?xf32>
}
// -----
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
// expected-error @+2 {{expected view type}}
%r = linalg.range %min:%max:%step : !linalg.range

View File

@ -1,7 +1,7 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
// CHECK: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
// CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
// CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
func @range(%arg0: index, %arg1: index, %arg2: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
@ -121,6 +121,13 @@ func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
// CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
// CHECK: linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, f32
func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
%0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : !linalg.view<?x?x?xf32>
return
}
// CHECK-LABEL: func @transpose
// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : !linalg.view<?x?x?xf32>
func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
return