Linalg Part 2: Compute Operations
We now describe the main compute operations `linalg.dot`, `linalg.matvec` and
`linalg.matmul`. These operations are a subset of a more general tensor
contraction class of operations. In this tutorial, we define a tensor
contraction as a generic operation which:
- Reads a `getNumInputs()` number of input ssa-values of ViewType.
- Writes into a `getNumOutputs()` number of output ssa-values of ViewType.
- Can be written in scalar loop form as a perfect loop nest with
  `getNumParallelDims()` outermost loops with parallel semantics and
  `getNumReductionDims()` innermost dimensions with reduction semantics.
- Has a scalar form that is specific to each particular specialization.
Operation Definition
In this section we do not discuss the specific properties of tensor
contractions but only define the `linalg.dot`, `linalg.matvec` and
`linalg.matmul` operations as opaque operations with side-effects (reads and
writes into input and output views).
These operations take input and output views of the proper rank as operands.
For the purpose of illustration, assume all the elemental types are fixed to
`f32`. The invariants checked by the op-specific `verify` functions are:
- `linalg.dot` reads two one-dimensional `view<?xf32>` and writes a
  zero-dimensional `view<f32>` (i.e. a scalar).
- `linalg.matvec` reads a two-dimensional `view<?x?xf32>` matrix and a
  one-dimensional `view<?xf32>` vector and writes a one-dimensional
  `view<?xf32>`.
- `linalg.matmul` reads two two-dimensional `view<?x?xf32>` matrices and
  writes a two-dimensional `view<?x?xf32>` matrix.
Other operations on higher-order tensors can be defined and would behave
similarly with respect to IR verification and interactions with ViewType
operands. The generic form of verification and pretty-printing is defined on
the `TensorContractionBase` class.
Note that in order to give `TensorContractionBase` access to the `mlir::Op` in
a generic fashion, we use a CRTP pattern where:

```c++
template <class ConcreteOp> class TensorContractionBase { ... };

class DotOp : public TensorContractionBase<DotOp>,
              public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
                              mlir::OpTrait::ZeroResult> { ... };
```
In turn, this allows the generic behavior of `TensorContractionBase` to be
implemented once and reused across ops. The generic `verify` method is:

```c++
template <class ConcreteOp>
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
  if (getNumInputs() <= 0)
    return concreteOp->emitOpError("expected at least one input");
  ...
}
```
Each specialized operation then calls into the generic verification method before applying its own verification steps.
```c++
LogicalResult linalg::MatmulOp::verify() {
  if (failed(TensorContractionBaseType::verify()))
    return failure();
  auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
  unsigned index = 0;
  for (auto *v : {A, B, C}) {
    if (getViewRank(v) != 2)
      return emitOpError("operand " + Twine(index) + " must be of rank 2");
    ++index;
  }
  return success();
}
```
Note that, in a more future-proof design, it is considered a best practice to define operations that share similar behavior declaratively with TableGen.
All `TensorContractionBase` ops pretty-print similarly. In the case of
`linalg.matmul` the pretty-printed form is:
`linalg.matmul(%A, %B, %C) : view<?x?xf32>`
Putting it all together
The example demonstrates how to construct simple IR snippets that pass the
verifier checks. It allocates three memref buffers from `index` function
arguments and uses those buffers as backing data structures for views that get
passed to the `dot`, `matvec` and `matmul` operations.