forked from OSchip/llvm-project
Add a linalg.transpose op
A linalg.transpose op is a pure metadata operation that takes a view + permutation map and produces another view of the same underlying data, with a different reindexing. This is a pure metadata operation that does not touch the underlying data. Example: ``` %t = linalg.transpose %v (i, j) -> (j, i) : !linalg.view<?x?xf32> ``` PiperOrigin-RevId: 265139429
This commit is contained in:
parent
32052c8417
commit
2c2c9ffd80
|
@ -19,12 +19,13 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
|
||||||
|
|
||||||
#ifdef LINALG_OPS
|
#ifdef LINALG_OPS
|
||||||
#else
|
#else
|
||||||
#define LINALG_OPS
|
#define LINALG_OPS
|
||||||
|
|
||||||
|
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
|
||||||
|
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||||
|
|
||||||
// Base class for Linalg dialect ops that do not correspond to library calls.
|
// Base class for Linalg dialect ops that do not correspond to library calls.
|
||||||
class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
Op<Linalg_Dialect, mnemonic, traits> {
|
Op<Linalg_Dialect, mnemonic, traits> {
|
||||||
|
@ -398,6 +399,37 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
|
||||||
|
Arguments<(ins View:$view, AffineMapAttr:$permutation)>,
|
||||||
|
Results<(outs View)> {
|
||||||
|
let summary = "transpose operation produces a new view (metadata-only)";
|
||||||
|
let description = [{
|
||||||
|
The "linalg.transpose" op produces a linalg.view whose sizes and strides are
|
||||||
|
a permutation of the original. This is a pure metadata transformation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
%1 = linalg.transpose %0 (i, j) -> (j, i) : !linalg.view<?x?xf32>
|
||||||
|
}];
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"Builder *b, OperationState *result, Value *view, "
|
||||||
|
"AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
|
||||||
|
|
||||||
|
let verifier = [{
|
||||||
|
if (!permutation().isPermutation())
|
||||||
|
return emitOpError("expected a permutation map");
|
||||||
|
if (permutation().getNumDims() != getViewType().getRank())
|
||||||
|
return emitOpError("expected a permutation map of same rank as the view");
|
||||||
|
return success();
|
||||||
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static StringRef getPermutationAttrName() { return "permutation"; }
|
||||||
|
ViewType getViewType() { return view()->getType().cast<ViewType>(); }
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
|
def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
|
||||||
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
|
Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
|
||||||
Results<(outs View)> {
|
Results<(outs View)> {
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||||
#include "mlir/EDSC/Helpers.h"
|
#include "mlir/EDSC/Helpers.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
@ -30,8 +32,6 @@
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
|
||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Support/STLExtras.h"
|
#include "mlir/Support/STLExtras.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
@ -599,6 +599,39 @@ static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
|
||||||
parser->addTypeToList(viewType, result->types));
|
parser->addTypeToList(viewType, result->types));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TransposeOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
void mlir::linalg::TransposeOp::build(Builder *b, OperationState *result,
|
||||||
|
Value *view, AffineMapAttr permutation,
|
||||||
|
ArrayRef<NamedAttribute> attrs) {
|
||||||
|
// TODO(ntv): once views have static dimensions, compute the permuted type.
|
||||||
|
build(b, result, view->getType(), view, attrs);
|
||||||
|
result->addAttribute(TransposeOp::getPermutationAttrName(), permutation);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter *p, TransposeOp op) {
|
||||||
|
*p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
|
||||||
|
p->printOptionalAttrDict(op.getAttrs(),
|
||||||
|
{TransposeOp::getPermutationAttrName()});
|
||||||
|
*p << " : " << op.view()->getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseTransposeOp(OpAsmParser *parser,
|
||||||
|
OperationState *result) {
|
||||||
|
OpAsmParser::OperandType view;
|
||||||
|
AffineMapAttr permutation;
|
||||||
|
Type type;
|
||||||
|
return failure(parser->parseOperand(view) ||
|
||||||
|
parser->parseAttribute(permutation,
|
||||||
|
TransposeOp::getPermutationAttrName(),
|
||||||
|
result->attributes) ||
|
||||||
|
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||||
|
parser->parseColonType(type) ||
|
||||||
|
parser->resolveOperand(view, type, result->operands) ||
|
||||||
|
parser->addTypeToList(type, result->types));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ViewOp
|
// ViewOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -85,6 +85,20 @@ func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @transpose_not_permutation(%v : !linalg.view<?x?xf32>) {
|
||||||
|
// expected-error @+1 {{expected a permutation map}}
|
||||||
|
linalg.transpose %v (i, j) -> (i, i) : !linalg.view<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @transpose_bad_rank(%v : !linalg.view<?x?xf32>) {
|
||||||
|
// expected-error @+1 {{expected a permutation map of same rank as the view}}
|
||||||
|
linalg.transpose %v (i) -> (i) : !linalg.view<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
|
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
|
||||||
// expected-error @+2 {{expected view type}}
|
// expected-error @+2 {{expected view type}}
|
||||||
%r = linalg.range %min:%max:%step : !linalg.range
|
%r = linalg.range %min:%max:%step : !linalg.range
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
|
||||||
|
|
||||||
// CHECK: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
|
// CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
|
||||||
// CHECK: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
|
// CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
|
||||||
|
|
||||||
func @range(%arg0: index, %arg1: index, %arg2: index) {
|
func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||||
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||||
|
@ -121,6 +121,13 @@ func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
|
||||||
// CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
|
// CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
|
||||||
// CHECK: linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, f32
|
// CHECK: linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, f32
|
||||||
|
|
||||||
|
func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
|
||||||
|
%0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : !linalg.view<?x?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK-LABEL: func @transpose
|
||||||
|
// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : !linalg.view<?x?x?xf32>
|
||||||
|
|
||||||
func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
|
func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
|
||||||
linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
|
linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
|
||||||
return
|
return
|
||||||
|
|
Loading…
Reference in New Issue