# Linalg Part 2: Compute Operations

We now describe the main compute operations `linalg.dot`, `linalg.matvec` and `linalg.matmul`. These operations are a subset of the more general class of tensor contractions. In this tutorial, we define a tensor contraction as a generic operation which:

1. Reads `getNumInputs()` input SSA values of `ViewType`.
2. Writes into `getNumOutputs()` output SSA values of `ViewType`.
3. Can be written in scalar loop form as a perfect loop nest, with `getNumParallelDims()` outermost loops with parallel semantics and `getNumReductionDims()` innermost loops with reduction semantics.
4. Has a scalar form that is specific to each particular specialization.
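
As an illustration of property 3, the scalar loop form of `linalg.matmul` can be sketched in plain C++ (an illustrative analogue, not code from the tutorial): the two outermost loops are the parallel dimensions and the innermost loop is the reduction dimension.

```cpp
#include <vector>

// Scalar loop form of linalg.matmul: C(i, j) += A(i, k) * B(k, j).
// The i and j loops are the getNumParallelDims() == 2 parallel dimensions;
// the k loop is the getNumReductionDims() == 1 reduction dimension.
void matmulScalarForm(const std::vector<float> &A, const std::vector<float> &B,
                      std::vector<float> &C, int M, int N, int K) {
  for (int i = 0; i < M; ++i)     // parallel
    for (int j = 0; j < N; ++j) { // parallel
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) // reduction
        acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = acc;
    }
}
```

`linalg.matvec` is the analogous nest with one parallel and one reduction loop, and `linalg.dot` has a single reduction loop and no parallel loops.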

## Operation Definition

In this section we do not discuss the specific properties of tensor contractions; we only define the `linalg.dot`, `linalg.matvec` and `linalg.matmul` operations as opaque operations with side effects (reads and writes into input and output views).

These operations take input and output views of the proper rank as operands. For the purpose of illustration, assume all the elemental types are fixed to `f32`. The invariants checked by the op-specific `verify` functions are:

1. `linalg.dot` reads two one-dimensional `view<?xf32>` operands and writes a zero-dimensional `view<f32>` (i.e. a scalar).
2. `linalg.matvec` reads a two-dimensional `view<?x?xf32>` matrix and a one-dimensional `view<?xf32>` vector and writes a one-dimensional `view<?xf32>`.
3. `linalg.matmul` reads two two-dimensional `view<?x?xf32>` matrices and writes a two-dimensional `view<?x?xf32>` matrix.
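
These rank invariants can be made explicit in a small standalone table. The `ContractionSignature` struct and `ranksMatch` helper below are hypothetical illustrations, not part of the Linalg API:

```cpp
#include <array>

// Expected view ranks for each op: {input ranks...} -> output rank.
// These are illustrative, not Linalg API types.
struct ContractionSignature {
  std::array<unsigned, 2> inputRanks;
  unsigned outputRank;
};

// dot:    (view<?xf32>, view<?xf32>)       -> view<f32>
// matvec: (view<?x?xf32>, view<?xf32>)     -> view<?xf32>
// matmul: (view<?x?xf32>, view<?x?xf32>)   -> view<?x?xf32>
constexpr ContractionSignature kDot{{1, 1}, 0};
constexpr ContractionSignature kMatvec{{2, 1}, 1};
constexpr ContractionSignature kMatmul{{2, 2}, 2};

// Toy rank check mirroring what the op-specific verifiers enforce.
bool ranksMatch(const ContractionSignature &sig,
                std::array<unsigned, 2> inputs, unsigned output) {
  return sig.inputRanks == inputs && sig.outputRank == output;
}
```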

Other operations on higher-order tensors can be defined and would behave similarly with respect to IR verification and interactions with `ViewType` operands. The generic form of verification and pretty-printing is defined on the `TensorContractionBase` class.

Note that in order to give `TensorContractionBase` access to the `mlir::Op` in a generic fashion, we use the CRTP idiom:

```cpp
template <class ConcreteOp> class TensorContractionBase { ... };

class DotOp : public TensorContractionBase<DotOp>,
              public mlir::Op<DotOp, mlir::OpTrait::VariadicOperands,
                              mlir::OpTrait::ZeroResult> { ... };
```

In turn, this allows the generic behavior of `TensorContractionBase` to be implemented once and reused across ops. The generic `verify` method is:

```cpp
template <class ConcreteOp>
mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
  auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
  if (getNumInputs() <= 0) {
    concreteOp->emitOpError("expected at least one input");
    return mlir::failure();
  }
  ...
}
```
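
Stripped of the MLIR machinery, the CRTP pattern used here can be reproduced in a few lines of standalone C++. The sketch below (hypothetical names, plain booleans instead of `LogicalResult` and diagnostics) shows how the base class reaches into the concrete op via `static_cast`, so the generic check is written only once:

```cpp
#include <string>

// CRTP sketch: the base class calls into the concrete op through a
// static_cast, so shared verification logic is written only once.
// Illustrative only; not the MLIR classes themselves.
template <class ConcreteOp> class TensorContractionBaseSketch {
public:
  // Generic check reused by every concrete op.
  bool verifyBase() {
    auto *concreteOp = static_cast<ConcreteOp *>(this);
    if (concreteOp->getNumInputs() == 0) {
      lastError = "expected at least one input";
      return false;
    }
    return true;
  }
  std::string lastError;
};

class DotOpSketch : public TensorContractionBaseSketch<DotOpSketch> {
public:
  unsigned getNumInputs() const { return 2; } // dot reads two input views
  // The specialized verify runs the generic check first, then its own.
  bool verify() { return verifyBase() /* && op-specific checks */; }
};
```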

Each specialized operation then calls into the generic verification method before applying its own verification steps.

```cpp
LogicalResult linalg::MatmulOp::verify() {
  if (failed(TensorContractionBaseType::verify()))
    return failure();
  auto *A = getOperand(0), *B = getOperand(1), *C = getOperand(2);
  unsigned index = 0;
  for (auto *v : {A, B, C}) {
    if (getViewRank(v) != 2)
      return emitOpError("operand " + Twine(index) + " must be of rank 2");
    ++index;
  }
  return success();
}
```

Note that, in a more future-proof design, operations that share behavior in this way are best defined with TableGen.

All `TensorContractionBase` ops pretty-print similarly. In the case of `linalg.matmul`, the pretty-printed form is:

```mlir
linalg.matmul(%A, %B, %C) : view<?x?xf32>
```

## Putting it all together

The example demonstrates how to construct simple IR snippets that pass the verifier checks. It allocates three memref buffers from index function arguments and uses those buffers as the backing data structures for views, which are then passed to the `dot`, `matvec` and `matmul` operations.