forked from OSchip/llvm-project
Add a linalg.transpose op
A linalg.transpose op is a pure metadata operation that takes a view + permutation map and produces another view of the same underlying data, with a different reindexing. This is a pure metadata operation that does not touch the underlying data. Example: ``` %t = linalg.transpose %v (i, j) -> (j, i) : !linalg.view<?x?xf32> ``` PiperOrigin-RevId: 265139429
This commit is contained in:
parent
32052c8417
commit
2c2c9ffd80
|
@ -19,12 +19,13 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
|
||||
#ifdef LINALG_OPS
|
||||
#else
|
||||
#define LINALG_OPS
|
||||
|
||||
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
|
||||
// Base class for Linalg dialect ops that do not correspond to library calls.
|
||||
class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Linalg_Dialect, mnemonic, traits> {
|
||||
|
@ -398,6 +399,37 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
|
|||
}];
|
||||
}
|
||||
|
||||
def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
|
||||
Arguments<(ins View:$view, AffineMapAttr:$permutation)>,
|
||||
Results<(outs View)> {
|
||||
let summary = "transpose operation produces a new view (metadata-only)";
|
||||
let description = [{
|
||||
The "linalg.transpose" op produces a linalg.view whose sizes and strides are
|
||||
a permutation of the original. This is a pure metadata transformation.
|
||||
|
||||
Example:
|
||||
|
||||
%1 = linalg.transpose %0 (i, j) -> (j, i) : !linalg.view<?x?xf32>
|
||||
}];
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState *result, Value *view, "
|
||||
"AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
|
||||
|
||||
let verifier = [{
|
||||
if (!permutation().isPermutation())
|
||||
return emitOpError("expected a permutation map");
|
||||
if (permutation().getNumDims() != getViewType().getRank())
|
||||
return emitOpError("expected a permutation map of same rank as the view");
|
||||
return success();
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getPermutationAttrName() { return "permutation"; }
|
||||
ViewType getViewType() { return view()->getType().cast<ViewType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
|
||||
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
|
||||
Results<(outs View)> {
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||
#include "mlir/EDSC/Helpers.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
@ -30,8 +32,6 @@
|
|||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
@ -599,6 +599,39 @@ static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
|
|||
parser->addTypeToList(viewType, result->types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransposeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void mlir::linalg::TransposeOp::build(Builder *b, OperationState *result,
|
||||
Value *view, AffineMapAttr permutation,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
// TODO(ntv): once views have static dimensions, compute the permuted type.
|
||||
build(b, result, view->getType(), view, attrs);
|
||||
result->addAttribute(TransposeOp::getPermutationAttrName(), permutation);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter *p, TransposeOp op) {
|
||||
*p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
|
||||
p->printOptionalAttrDict(op.getAttrs(),
|
||||
{TransposeOp::getPermutationAttrName()});
|
||||
*p << " : " << op.view()->getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTransposeOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
OpAsmParser::OperandType view;
|
||||
AffineMapAttr permutation;
|
||||
Type type;
|
||||
return failure(parser->parseOperand(view) ||
|
||||
parser->parseAttribute(permutation,
|
||||
TransposeOp::getPermutationAttrName(),
|
||||
result->attributes) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type) ||
|
||||
parser->resolveOperand(view, type, result->operands) ||
|
||||
parser->addTypeToList(type, result->types));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -85,6 +85,20 @@ func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
func @transpose_not_permutation(%v : !linalg.view<?x?xf32>) {
|
||||
// expected-error @+1 {{expected a permutation map}}
|
||||
linalg.transpose %v (i, j) -> (i, i) : !linalg.view<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @transpose_bad_rank(%v : !linalg.view<?x?xf32>) {
|
||||
// expected-error @+1 {{expected a permutation map of same rank as the view}}
|
||||
linalg.transpose %v (i) -> (i) : !linalg.view<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
|
||||
// expected-error @+2 {{expected view type}}
|
||||
%r = linalg.range %min:%max:%step : !linalg.range
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
|
||||
// CHECK: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
|
||||
// CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
|
||||
// CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
|
||||
|
||||
func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||
|
@ -121,6 +121,13 @@ func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
|
|||
// CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
|
||||
// CHECK: linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, f32
|
||||
|
||||
func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
|
||||
%0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : !linalg.view<?x?x?xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @transpose
|
||||
// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : !linalg.view<?x?x?xf32>
|
||||
|
||||
func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
|
||||
linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue