forked from OSchip/llvm-project
1300 lines
46 KiB
C++
1300 lines
46 KiB
C++
//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements an ODS (and C++) generator from a YAML form
|
|
// derived from the mathematical expression of linalg named ops. Typically a
|
|
// math oriented DSL will be used to export the essential representation to
|
|
// this form, and maintaining the SOT at the math level (versus recreating it
|
|
// in MLIR) is deemed to have systemic value.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "llvm/ADT/Optional.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/ToolOutputFile.h"
|
|
#include "llvm/Support/YAMLTraits.h"
|
|
|
|
using namespace mlir;
|
|
|
|
using llvm::yaml::Input;
|
|
using llvm::yaml::MappingTraits;
|
|
using llvm::yaml::ScalarEnumerationTraits;
|
|
using llvm::yaml::ScalarTraits;
|
|
|
|
#define DEBUG_TYPE "linalg-ods-gen"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Mapping structs (correspond to data types in the YAML description).
|
|
// TODO: Since this is a schema/part of the contract, it should be moved to
|
|
// a real header.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
struct LinalgYAMLContext {
|
|
MLIRContext *mlirContext;
|
|
};
|
|
|
|
struct LinalgOpMetadata {
|
|
std::string name;
|
|
std::string cppClassName;
|
|
Optional<std::string> doc;
|
|
SmallVector<std::string> implements;
|
|
SmallVector<std::string> defines;
|
|
};
|
|
|
|
struct SerializedAffineMap {
|
|
AffineMapAttr affineMapAttr;
|
|
|
|
AffineMap affineMap() { return affineMapAttr.getValue(); }
|
|
};
|
|
|
|
enum class LinalgOperandDefKind {
|
|
InputTensor,
|
|
Scalar,
|
|
OutputTensor,
|
|
IndexAttr,
|
|
UnaryFnAttr,
|
|
BinaryFnAttr,
|
|
TypeFnAttr
|
|
};
|
|
|
|
struct LinalgOperandDef {
|
|
std::string name;
|
|
LinalgOperandDefKind kind;
|
|
Optional<std::string> typeVar;
|
|
Optional<SerializedAffineMap> shapeMap;
|
|
Optional<SerializedAffineMap> indexAttrMap;
|
|
Optional<SmallVector<int64_t>> defaultIndices;
|
|
Optional<std::string> defaultFn;
|
|
};
|
|
|
|
enum class LinalgIteratorTypeDef {
|
|
parallel,
|
|
reduction,
|
|
};
|
|
|
|
struct LinalgIndexingMapsConfig {
|
|
Optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
|
|
};
|
|
|
|
struct ScalarExpression;
|
|
|
|
enum class ScalarFnKind { Unary, Binary, Type };
|
|
|
|
struct ScalarFn {
|
|
ScalarFnKind kind;
|
|
Optional<std::string> fnName;
|
|
Optional<std::string> attrName;
|
|
Optional<std::string> typeVar;
|
|
// NOTE: This must be of arity 1, but to break the self-referential cycle,
|
|
// we use a heap allocated vector.
|
|
std::vector<ScalarExpression> operands;
|
|
};
|
|
|
|
struct ScalarExpression {
|
|
Optional<std::string> arg;
|
|
Optional<std::string> constant;
|
|
Optional<int64_t> index;
|
|
Optional<ScalarFn> scalarFn;
|
|
};
|
|
|
|
struct ScalarAssign {
|
|
std::string arg;
|
|
ScalarExpression value;
|
|
};
|
|
|
|
struct LinalgStructuredOpConfig {
|
|
SmallVector<LinalgOperandDef> args;
|
|
LinalgIndexingMapsConfig indexingMaps;
|
|
SmallVector<LinalgIteratorTypeDef> iteratorTypes;
|
|
std::vector<ScalarAssign> assignments;
|
|
};
|
|
|
|
struct LinalgOpConfig {
|
|
Optional<LinalgOpMetadata> metadata;
|
|
Optional<LinalgStructuredOpConfig> structuredOp;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Mapping traits.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef)
|
|
LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap)
|
|
LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef)
|
|
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign)
|
|
LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression)
|
|
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig)
|
|
|
|
namespace llvm {
|
|
namespace yaml {
|
|
|
|
/// Top-level type containing op metadata and one of a concrete op type.
|
|
/// Currently, the only defined op type is `structured_op` (maps to
|
|
/// `LinalgStructuredOpConfig`).
|
|
template <>
|
|
struct MappingTraits<LinalgOpConfig> {
|
|
static void mapping(IO &io, LinalgOpConfig &info) {
|
|
io.mapOptional("metadata", info.metadata);
|
|
io.mapOptional("structured_op", info.structuredOp);
|
|
}
|
|
};
|
|
|
|
/// A structured op models (at most) a single contraction by modeling
|
|
/// - A list of named arguments (`LinalgOperandDef`), which can be inputs,
|
|
/// outputs, or index attributes.
|
|
/// - List of indexing maps (see `LinalgIndexingMaps`).
|
|
/// - Iterator types (see `LinalgIteratorTypeDef`).
|
|
/// - List of scalar level assignment (see `ScalarAssign`).
|
|
template <>
|
|
struct MappingTraits<LinalgStructuredOpConfig> {
|
|
static void mapping(IO &io, LinalgStructuredOpConfig &info) {
|
|
io.mapRequired("args", info.args);
|
|
io.mapRequired("indexing_maps", info.indexingMaps);
|
|
io.mapRequired("iterator_types", info.iteratorTypes);
|
|
io.mapRequired("assignments", info.assignments);
|
|
}
|
|
};
|
|
|
|
/// Maps a named tensor, scalar or attribute argument to an operation,
|
|
/// consisting of:
|
|
/// - `name`: Must be unique within the operation.
|
|
/// - `usage`: How the argument is used (input, output, attribute, etc).
|
|
/// - `type_var`: The symbolic type variable that binds to the element or self
|
|
/// type of the tensor or scalar argument, respectively.
|
|
/// - `shape_map`: An optional AffineMap from all op symbols to the shape of
|
|
/// the argument. Only tensor arguments have a `shape_map`. Each shape must
|
|
/// be normalized over the same list of symbols and have no dimension
|
|
/// inputs.
|
|
/// - `index_attr_map`: An optional AffineMap from all op symbols to the
|
|
/// index attribute symbols. During op creation these symbols are replaced
|
|
/// by the corresponding `name` index attribue values. Only index attribute
|
|
/// arguments have an `index_attr_map`.
|
|
/// - `default_indices`: An optional default initialization for index
|
|
/// attribute arguments.
|
|
/// - `default_fn`: An optional default initialization for function attribute
|
|
/// arguments.
|
|
template <>
|
|
struct MappingTraits<LinalgOperandDef> {
|
|
static void mapping(IO &io, LinalgOperandDef &info) {
|
|
io.mapRequired("name", info.name);
|
|
io.mapRequired("kind", info.kind);
|
|
io.mapOptional("type_var", info.typeVar);
|
|
io.mapOptional("shape_map", info.shapeMap);
|
|
io.mapOptional("index_attr_map", info.indexAttrMap);
|
|
io.mapOptional("default_indices", info.defaultIndices);
|
|
io.mapOptional("default_fn", info.defaultFn);
|
|
}
|
|
};
|
|
|
|
/// Usage enum for a named argument.
|
|
template <>
|
|
struct ScalarEnumerationTraits<LinalgOperandDefKind> {
|
|
static void enumeration(IO &io, LinalgOperandDefKind &value) {
|
|
io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor);
|
|
io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar);
|
|
io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor);
|
|
io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr);
|
|
io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr);
|
|
io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr);
|
|
io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr);
|
|
}
|
|
};
|
|
|
|
/// Iterator type enum.
|
|
template <>
|
|
struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
|
static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
|
|
io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
|
|
io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
|
|
}
|
|
};
|
|
|
|
/// Metadata about the op (name, C++ name, and documentation).
|
|
template <>
|
|
struct MappingTraits<LinalgOpMetadata> {
|
|
static void mapping(IO &io, LinalgOpMetadata &info) {
|
|
io.mapRequired("name", info.name);
|
|
io.mapRequired("cpp_class_name", info.cppClassName);
|
|
io.mapOptional("doc", info.doc);
|
|
io.mapOptional("implements", info.implements);
|
|
io.mapOptional("defines", info.defines);
|
|
}
|
|
};
|
|
|
|
/// How the ops indexing maps are produced. Must be one of:
|
|
/// - static_indexing_maps: A static list of AffineMaps, possibly with
|
|
/// some symbols that bind to attributes of the op. Each indexing map must
|
|
/// be normalized over the same list of dimensions, and its symbols must
|
|
/// match the symbols for argument shapes.
|
|
template <>
|
|
struct MappingTraits<LinalgIndexingMapsConfig> {
|
|
static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
|
|
io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
|
|
}
|
|
};
|
|
|
|
/// Models an assignment to a named output.
|
|
/// - The `arg` name must match a named output.
|
|
/// - The `value` is a scalar expression for computing the value to
|
|
/// assign (see `ScalarExpression`).
|
|
template <>
|
|
struct MappingTraits<ScalarAssign> {
|
|
static void mapping(IO &io, ScalarAssign &info) {
|
|
io.mapRequired("arg", info.arg);
|
|
io.mapRequired("value", info.value);
|
|
}
|
|
};
|
|
|
|
/// A scalar expression (RHS of an assignment). Must be one of:
|
|
/// - `scalar_arg`: An operation argument.
|
|
/// - `scalar_const`: A constant definition.
|
|
/// - `scalar_index`: An iteration index.
|
|
/// - `scalar_fn`: A named function (see `ScalarFn`).
|
|
template <>
|
|
struct MappingTraits<ScalarExpression> {
|
|
static void mapping(IO &io, ScalarExpression &info) {
|
|
io.mapOptional("scalar_arg", info.arg);
|
|
io.mapOptional("scalar_const", info.constant);
|
|
io.mapOptional("scalar_index", info.index);
|
|
io.mapOptional("scalar_fn", info.scalarFn);
|
|
}
|
|
};
|
|
|
|
/// Scalar function kind enum.
|
|
template <>
|
|
struct ScalarEnumerationTraits<ScalarFnKind> {
|
|
static void enumeration(IO &io, ScalarFnKind &value) {
|
|
io.enumCase(value, "unary", ScalarFnKind::Unary);
|
|
io.enumCase(value, "binary", ScalarFnKind::Binary);
|
|
io.enumCase(value, "type", ScalarFnKind::Type);
|
|
}
|
|
};
|
|
|
|
/// A scalar expression that evaluates a named function.
|
|
/// Functions are generally "math" level and type polymorphic. Builtin
|
|
/// functions include:
|
|
/// - `add(lhs, rhs)`
|
|
/// - `mul(lhs, rhs)`
|
|
template <>
|
|
struct MappingTraits<ScalarFn> {
|
|
static void mapping(IO &io, ScalarFn &info) {
|
|
io.mapRequired("kind", info.kind);
|
|
io.mapOptional("fn_name", info.fnName);
|
|
io.mapOptional("attr_name", info.attrName);
|
|
io.mapOptional("type_var", info.typeVar);
|
|
io.mapRequired("operands", info.operands);
|
|
}
|
|
};
|
|
|
|
/// Helper mapping which accesses an AffineMapAttr as a serialized string of
|
|
/// the same.
|
|
template <>
|
|
struct ScalarTraits<SerializedAffineMap> {
|
|
static void output(const SerializedAffineMap &value, void *rawYamlContext,
|
|
raw_ostream &out) {
|
|
assert(value.affineMapAttr);
|
|
value.affineMapAttr.print(out);
|
|
}
|
|
static StringRef input(StringRef scalar, void *rawYamlContext,
|
|
SerializedAffineMap &value) {
|
|
assert(rawYamlContext);
|
|
auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
|
|
if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
|
|
.dyn_cast_or_null<AffineMapAttr>())
|
|
value.affineMapAttr = attr;
|
|
else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
|
|
return "could not parse as an affine map attribute";
|
|
return StringRef();
|
|
}
|
|
static QuotingType mustQuote(StringRef) { return QuotingType::None; }
|
|
};
|
|
|
|
} // namespace yaml
|
|
} // namespace llvm
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Generation utilities
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class GenerationContext {
|
|
public:
|
|
GenerationContext(MLIRContext *context, raw_ostream *odsOut,
|
|
raw_ostream *defnOut)
|
|
: context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
|
|
defnOut(defnOut) {}
|
|
|
|
MLIRContext *getContext() { return context; }
|
|
|
|
void setLoc(Location loc) { this->loc = loc; }
|
|
Location getLoc() { return loc; }
|
|
|
|
bool shouldGenerateOds() { return odsOut; }
|
|
bool shouldGenerateDefns() { return defnOut; }
|
|
|
|
raw_ostream &odss() {
|
|
assert(odsOut && "ODS stream not defined");
|
|
return *odsOut;
|
|
}
|
|
|
|
raw_ostream &defns() {
|
|
assert(defnOut && "Definition stream not defined");
|
|
return *defnOut;
|
|
}
|
|
|
|
private:
|
|
MLIRContext *context;
|
|
Location loc;
|
|
raw_ostream *odsOut;
|
|
raw_ostream *defnOut;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
static std::string generateCppExpression(SerializedAffineMap self,
|
|
StringRef contextName) {
|
|
std::string printedStr;
|
|
llvm::raw_string_ostream printedSs(printedStr);
|
|
self.affineMapAttr.print(printedSs);
|
|
printedSs.flush();
|
|
|
|
static const char exprFormat[] =
|
|
R"FMT(mlir::parseAttribute("{0}", {1}).cast<AffineMapAttr>().getValue())FMT";
|
|
return llvm::formatv(exprFormat, printedStr, contextName);
|
|
}
|
|
|
|
template <typename Container>
|
|
static std::string interleaveToString(Container &container,
|
|
StringRef separator) {
|
|
std::string result;
|
|
llvm::raw_string_ostream ss(result);
|
|
llvm::interleave(container, ss, separator);
|
|
ss.flush();
|
|
return result;
|
|
}
|
|
|
|
static Optional<int>
|
|
findTensorDefArgIndex(StringRef name, SmallVectorImpl<LinalgOperandDef> &args) {
|
|
for (const auto &it : llvm::enumerate(args)) {
|
|
if (it.value().name == name)
|
|
return it.index();
|
|
}
|
|
return None;
|
|
}
|
|
|
|
// Try to map the TypeVar to a predefined or an argument type.
|
|
static Optional<std::string>
|
|
findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
|
|
// Handle all predefined types.
|
|
if (typeVar == "I32")
|
|
return std::string("helper.getIntegerType(32)");
|
|
if (typeVar == "I64")
|
|
return std::string("helper.getIntegerType(64)");
|
|
if (typeVar == "F32")
|
|
return std::string("helper.getFloat32Type()");
|
|
if (typeVar == "F64")
|
|
return std::string("helper.getFloat64Type()");
|
|
|
|
// Search all argument types.
|
|
for (const auto &it : llvm::enumerate(args)) {
|
|
if (it.value().kind != LinalgOperandDefKind::InputTensor &&
|
|
it.value().kind != LinalgOperandDefKind::Scalar &&
|
|
it.value().kind != LinalgOperandDefKind::OutputTensor)
|
|
continue;
|
|
if (it.value().typeVar.getValue() == typeVar)
|
|
return llvm::formatv("block.getArgument({0}).getType()", it.index())
|
|
.str();
|
|
}
|
|
|
|
return None;
|
|
}
|
|
|
|
static ScalarAssign *findAssignment(StringRef name,
|
|
std::vector<ScalarAssign> &assignments) {
|
|
for (auto &assign : assignments) {
|
|
if (assign.arg == name)
|
|
return &assign;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
// Return true if the operand is a function attribute.
|
|
static bool isFunctionAttribute(LinalgOperandDefKind kind) {
|
|
return kind == LinalgOperandDefKind::UnaryFnAttr ||
|
|
kind == LinalgOperandDefKind::BinaryFnAttr ||
|
|
kind == LinalgOperandDefKind::TypeFnAttr;
|
|
}
|
|
|
|
// Return true if the operand is an attribute.
|
|
static bool isAttribute(LinalgOperandDefKind kind) {
|
|
return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind);
|
|
}
|
|
|
|
// Get the enum name for the given operand kind.
|
|
std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) {
|
|
switch (kind) {
|
|
case LinalgOperandDefKind::UnaryFnAttr:
|
|
return std::string("UnaryFn");
|
|
case LinalgOperandDefKind::BinaryFnAttr:
|
|
return std::string("BinaryFn");
|
|
case LinalgOperandDefKind::TypeFnAttr:
|
|
return std::string("TypeFn");
|
|
default:
|
|
break;
|
|
}
|
|
llvm_unreachable("unsupported function attribute kind");
|
|
}
|
|
|
|
// Get the enum name for the given function kind.
|
|
std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
|
|
switch (kind) {
|
|
case ScalarFnKind::Unary:
|
|
return std::string("UnaryFn");
|
|
case ScalarFnKind::Binary:
|
|
return std::string("BinaryFn");
|
|
case ScalarFnKind::Type:
|
|
return std::string("TypeFn");
|
|
}
|
|
llvm_unreachable("unsupported function kind");
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Templates
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// A single line banner format. Parameters:
|
|
// {0}: Single line comment
|
|
static const char bannerFormat[] = R"FMT(
|
|
//===----------------------------------------------------------------------===//
|
|
// {0}
|
|
//===----------------------------------------------------------------------===//
|
|
)FMT";
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Named generic op generation.
|
|
// These ops map at most a single contraction that complies with the limitations
|
|
// of a linalg.generic.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Template for Linalg named ops' ODS definitions. Parameters:
|
|
// {0}: ODS/C++ op name
|
|
// {1}: assembly op mnemonic
|
|
// {2}: op interface list
|
|
// {3}: documentation (summary + description)
|
|
// {4}: op attribute list
|
|
// {5}: builder methods taking standalone attribute parameters
|
|
// {6}: additional method defintions
|
|
// {7}: additional methods for attributes used by indexing maps
|
|
static const char structuredOpOdsHeaderFormat[] = R"FMT(
|
|
//===----------------------------------------------------------------------===//
|
|
// Op definition for {0}
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
|
|
/*extraInterfaces=*/[{2}])> {
|
|
{3}
|
|
let arguments = (ins
|
|
Variadic<AnyType>:$inputs,
|
|
Variadic<AnyShaped>:$outputs{4}
|
|
);
|
|
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
|
|
let regions = (region AnyRegion:$region);
|
|
|
|
let skipDefaultBuilders = 1;
|
|
let builders = [
|
|
OpBuilder<
|
|
(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
|
[{{
|
|
$_state.addOperands(inputs);
|
|
$_state.addOperands(outputs);
|
|
SmallVector<Type> resultTensorTypes;
|
|
copy_if(outputs.getTypes(),
|
|
std::back_inserter(resultTensorTypes),
|
|
[](Type type) {{ return type.isa<RankedTensorType>(); });
|
|
$_state.addTypes(resultTensorTypes);
|
|
$_state.addAttribute(
|
|
"operand_segment_sizes",
|
|
$_builder.getI32VectorAttr({{
|
|
static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
$_state.addAttributes(attributes);
|
|
createAndFillStructuredOpRegion<{0}>(
|
|
$_builder,
|
|
$_state,
|
|
TypeRange(inputs),
|
|
TypeRange(outputs));
|
|
}]>,
|
|
OpBuilder<
|
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
|
"ValueRange":$outputs,
|
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
|
[{{
|
|
$_state.addOperands(inputs);
|
|
$_state.addOperands(outputs);
|
|
$_state.addTypes(resultTensorTypes);
|
|
$_state.addAttributes(attributes);
|
|
$_state.addAttribute(
|
|
"operand_segment_sizes",
|
|
$_builder.getI32VectorAttr({{
|
|
static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
createAndFillStructuredOpRegion<{0}>(
|
|
$_builder,
|
|
$_state,
|
|
TypeRange(inputs),
|
|
TypeRange(outputs));
|
|
}]>,
|
|
OpBuilder<
|
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
|
[{{
|
|
$_state.addOperands(operands);
|
|
$_state.addAttributes(attributes);
|
|
$_state.addTypes(resultTensorTypes);
|
|
(void)$_state.addRegion();
|
|
}]>
|
|
{5}
|
|
];
|
|
let hasCustomAssemblyFormat = 1;
|
|
let hasFolder = 1;
|
|
{6}
|
|
|
|
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
|
// Auto-generated.
|
|
ArrayAttr iterator_types();
|
|
ArrayAttr indexing_maps();
|
|
static void regionBuilder(ImplicitLocOpBuilder &b,
|
|
Block &block, ArrayRef<NamedAttribute> attrs);
|
|
static std::function<void(ImplicitLocOpBuilder &,
|
|
Block &, ArrayRef<NamedAttribute>)>
|
|
getRegionBuilder() {{
|
|
return regionBuilder;
|
|
}
|
|
|
|
// Generic methods.
|
|
static unsigned getNumRegionArgs();
|
|
std::string getLibraryCallName();
|
|
{7}
|
|
}];
|
|
}
|
|
)FMT";
|
|
|
|
// Builder method taking attribute parameters. Parameters:
|
|
// {0}: Class name
|
|
// {1}: Comma interleaved attribute parameters
|
|
// {2}: Attribute initialization
|
|
static const char structuredOpBuilderFormat[] = R"FMT(
|
|
, OpBuilder<
|
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
|
"ValueRange":$outputs, {1},
|
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
|
[{{
|
|
$_state.addOperands(inputs);
|
|
$_state.addOperands(outputs);
|
|
$_state.addTypes(resultTensorTypes);
|
|
{2}
|
|
$_state.addAttributes(attributes);
|
|
$_state.addAttribute(
|
|
"operand_segment_sizes",
|
|
$_builder.getI32VectorAttr({{
|
|
static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
createAndFillStructuredOpRegion<{0}>(
|
|
$_builder,
|
|
$_state,
|
|
TypeRange(inputs),
|
|
TypeRange(outputs));
|
|
}]>
|
|
)FMT";
|
|
|
|
// The iterator_types() method for structured ops. Parameters:
|
|
// {0}: Class name
|
|
// {1}: Comma interleaved iterator type names.
|
|
static const char structuredOpIteratorTypesFormat[] =
|
|
R"FMT(
|
|
ArrayAttr {0}::iterator_types() {{
|
|
return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
|
|
}
|
|
)FMT";
|
|
|
|
// The iterator_types() method for rank polymorphic structured ops. Parameters:
|
|
// {0}: Class name
|
|
static const char rankPolyStructuredOpIteratorTypesFormat[] =
|
|
R"FMT(
|
|
ArrayAttr {0}::iterator_types() {{
|
|
int64_t rank = getRank(getOutputOperand(0));
|
|
return Builder(getContext()).getStrArrayAttr(
|
|
SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
|
|
}
|
|
)FMT";
|
|
|
|
// The indexing_maps() method for structured ops. Parameters:
|
|
// {0}: Class name
|
|
// {1}: Comma-separated list of dimension variable names.
|
|
// {2}: Statements
|
|
static const char structuredOpIndexingMapsFormat[] = R"FMT(
|
|
ArrayAttr {0}::indexing_maps() {{
|
|
static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
|
|
ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
|
|
if (cached)
|
|
return cached;
|
|
|
|
MLIRContext *context = getContext();
|
|
auto symbolBindings = getSymbolBindings(*this);
|
|
SmallVector<AffineMap> maps;
|
|
{2}
|
|
cached = Builder(context).getAffineMapArrayAttr(maps);
|
|
getOperation()->setAttr(memoizeAttr, cached);
|
|
return cached;
|
|
}
|
|
)FMT";
|
|
|
|
// The indexing_maps() method for rank polymorphic structured ops. Parameters:
|
|
// {0}: Class name
|
|
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
|
|
ArrayAttr {0}::indexing_maps() {{
|
|
MLIRContext *context = getContext();
|
|
AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
|
|
AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
|
|
getNumParallelLoops(), context);
|
|
SmallVector<AffineMap> indexingMaps;
|
|
for (OpOperand *opOperand : getInputAndOutputOperands())
|
|
indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
|
|
return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
|
|
}
|
|
)FMT";
|
|
|
|
// Implementations of fold and getEffects.
|
|
// Parameters:
|
|
// {0}: Class name
|
|
const char structuredOpFoldersFormat[] = R"FMT(
|
|
LogicalResult {0}::fold(ArrayRef<Attribute>,
|
|
SmallVectorImpl<OpFoldResult> &) {{
|
|
return foldMemRefCast(*this);
|
|
}
|
|
void {0}::getEffects(SmallVectorImpl<
|
|
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
|
|
SmallVector<Value> inputBuffers = getInputBufferOperands();
|
|
SmallVector<Value> outputBuffers = getOutputBufferOperands();
|
|
getGenericEffectsImpl(effects,
|
|
getOperation()->getResults(), inputBuffers, outputBuffers);
|
|
}
|
|
)FMT";
|
|
|
|
// Implementation of parse/print.
|
|
// Parameters:
|
|
// {0}: Class name
|
|
static const char structuredOpParserFormat[] = R"FMT(
|
|
ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{
|
|
return ::parseNamedStructuredOp<{0}>(parser, result);
|
|
}
|
|
void {0}::print(OpAsmPrinter &p) {{
|
|
::printNamedStructuredOp(p, *this);
|
|
}
|
|
)FMT";
|
|
|
|
static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
|
GenerationContext &genContext) {
|
|
if (!genContext.shouldGenerateOds())
|
|
return success();
|
|
|
|
raw_ostream &os = genContext.odss();
|
|
|
|
std::string interfaceNameList;
|
|
std::string attrList;
|
|
std::string attrMethods;
|
|
std::string attrBuilder;
|
|
|
|
std::string doc;
|
|
if (opConfig.metadata->doc) {
|
|
static const char structuredOpDocFmt[] = R"FMT(
|
|
let summary = [{ {0} }];
|
|
let description = [{
|
|
{1}
|
|
}];
|
|
)FMT";
|
|
StringRef summary, description;
|
|
std::tie(summary, description) =
|
|
StringRef(*opConfig.metadata->doc).trim().split('\n');
|
|
doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim());
|
|
}
|
|
|
|
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
|
|
|
|
std::string definitionList;
|
|
for (const std::string &definition : opConfig.metadata->defines) {
|
|
static const char definitionFmt[] = "let {0} = 1;\n";
|
|
definitionList.append(llvm::formatv(definitionFmt, definition));
|
|
}
|
|
|
|
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
|
return isAttribute(arg.kind);
|
|
})) {
|
|
SmallVector<std::string> attrDefs;
|
|
SmallVector<std::string> attrParams;
|
|
SmallVector<std::string> attrStmts;
|
|
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
|
static const char paramFmt[] = "\"Attribute\":${0}";
|
|
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
|
|
// Add the type conversion attributes to the op definition and builders.
|
|
if (isFunctionAttribute(arg.kind)) {
|
|
assert(arg.defaultFn.hasValue());
|
|
std::string enumName = convertOperandKindToEnumName(arg.kind);
|
|
static const char typeFmt[] = "{0}::{1}";
|
|
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
|
|
attrDefs.push_back(llvm::formatv(
|
|
defFmt, llvm::formatv("{0}Attr", enumName),
|
|
llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name));
|
|
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
|
|
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
|
|
}
|
|
// Add the index attributes to the op definition and builders.
|
|
if (arg.kind == LinalgOperandDefKind::IndexAttr) {
|
|
assert(arg.indexAttrMap.hasValue());
|
|
assert(arg.defaultIndices.hasValue());
|
|
size_t size = arg.indexAttrMap->affineMap().getNumResults();
|
|
assert(arg.defaultIndices.getValue().size() == size);
|
|
static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
|
|
static const char defFmt[] = "DefaultValuedAttr<{0}, \"{ {1} }\">:${2}";
|
|
std::string defaultVals;
|
|
llvm::raw_string_ostream ss(defaultVals);
|
|
llvm::interleave(
|
|
arg.defaultIndices.getValue(), ss,
|
|
[&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
|
|
", ");
|
|
attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
|
|
ss.str(), arg.name));
|
|
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
|
|
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
|
|
}
|
|
}
|
|
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
|
return arg.kind == LinalgOperandDefKind::IndexAttr;
|
|
})) {
|
|
attrMethods = R"(
|
|
bool hasDynamicIndexingMaps();
|
|
LogicalResult verifyIndexingMapRequiredAttributes();
|
|
)";
|
|
}
|
|
attrList = ",\n" + llvm::join(attrDefs, ",\n");
|
|
attrBuilder = llvm::formatv(
|
|
structuredOpBuilderFormat, opConfig.metadata->cppClassName,
|
|
llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n"));
|
|
}
|
|
|
|
os << llvm::formatv(structuredOpOdsHeaderFormat,
|
|
opConfig.metadata->cppClassName, opConfig.metadata->name,
|
|
interfaceNameList, doc, attrList, attrBuilder,
|
|
definitionList, attrMethods);
|
|
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
|
|
GenerationContext &genContext) {
|
|
if (!genContext.shouldGenerateDefns())
|
|
return success();
|
|
|
|
raw_ostream &os = genContext.defns();
|
|
StringRef className = opConfig.metadata->cppClassName;
|
|
|
|
// Implementation banner.
|
|
std::string bannerComment = llvm::formatv("Implementation of {0}", className);
|
|
os << llvm::formatv(bannerFormat, bannerComment);
|
|
|
|
// Compute the number of scalar and tensor arguments.
|
|
int64_t numOfArgs =
|
|
llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
|
return arg.kind == LinalgOperandDefKind::InputTensor ||
|
|
arg.kind == LinalgOperandDefKind::Scalar ||
|
|
arg.kind == LinalgOperandDefKind::OutputTensor;
|
|
});
|
|
|
|
// An operation that accesses only scalars and scalar/rank zero tensors is
|
|
// rank polymorhpic. We implement rank polymorphism by generating different
|
|
// indexing maps and iterators that match the rank of the first output tensor.
|
|
// An operation is rank polymorphic if the iteration domain has rank zero.
|
|
bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty();
|
|
|
|
// Generate the iterator_types() method.
|
|
if (!isRankPolymorphic) {
|
|
std::string iteratorsStr;
|
|
llvm::raw_string_ostream ss(iteratorsStr);
|
|
llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
|
|
[&](LinalgIteratorTypeDef it) {
|
|
switch (it) {
|
|
case LinalgIteratorTypeDef::parallel:
|
|
ss << "getParallelIteratorTypeName()";
|
|
break;
|
|
case LinalgIteratorTypeDef::reduction:
|
|
ss << "getReductionIteratorTypeName()";
|
|
break;
|
|
}
|
|
});
|
|
ss.flush();
|
|
os << llvm::formatv(structuredOpIteratorTypesFormat, className,
|
|
iteratorsStr);
|
|
} else {
|
|
os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className);
|
|
}
|
|
|
|
// Generating the indexing_maps() method.
|
|
if (auto &staticMaps =
|
|
opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
|
|
if (staticMaps->empty())
|
|
return emitError(genContext.getLoc()) << "op has no indexing maps";
|
|
if (!isRankPolymorphic) {
|
|
AffineMap firstMap = staticMaps->front().affineMap();
|
|
|
|
// Symbol bindings.
|
|
{
|
|
// For each symbol, generate a declaration for it, either with an
|
|
// AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
|
|
// an attribute).
|
|
// TODO: Possibly lift into a top-level method.
|
|
static const char structuredOpSymbolBindingsFormat[] = R"FMT(
|
|
static SmallVector<AffineExpr> getSymbolBindings({0} self) {
|
|
MLIRContext *context = self.getContext();
|
|
SmallVector<AffineExpr> exprs;
|
|
{1}
|
|
return exprs;
|
|
}
|
|
)FMT";
|
|
|
|
unsigned symbolCount = firstMap.getNumSymbols();
|
|
SmallVector<std::string> symbolBindings;
|
|
for (unsigned i = 0; i < symbolCount; ++i) {
|
|
symbolBindings.push_back(llvm::formatv(
|
|
" exprs.push_back(getAffineSymbolExpr({0}, context));", i));
|
|
}
|
|
|
|
// Access an index attribute. Parameters:
|
|
// {0}: Attribute name
|
|
// {1}: Symbol position
|
|
// {2}: Attribute index
|
|
static const char structuredOpAccessAttrFormat[] = R"FMT(
|
|
int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}];
|
|
exprs.push_back(getAffineConstantExpr(cst{1}, context));
|
|
)FMT";
|
|
// Update all symbol bindings mapped to an attribute.
|
|
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
|
if (arg.kind != LinalgOperandDefKind::IndexAttr)
|
|
continue;
|
|
assert(arg.indexAttrMap.hasValue());
|
|
for (auto &en :
|
|
llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
|
|
if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
|
|
symbolBindings[symbol.getPosition()] =
|
|
llvm::formatv(structuredOpAccessAttrFormat, arg.name,
|
|
symbol.getPosition(), en.index());
|
|
}
|
|
}
|
|
}
|
|
|
|
std::string symbolBindingsStr;
|
|
llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
|
|
llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
|
|
symbolBindingsSs.flush();
|
|
|
|
os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
|
|
symbolBindingsStr);
|
|
}
|
|
|
|
// Indexing maps.
|
|
{
|
|
unsigned dimCount = firstMap.getNumDims();
|
|
|
|
// Generate a comma-separated list of dim identifiers to be passed to
|
|
// bindDims, ensuring tht AffineExpr identifiers are bound in the right
|
|
// order to the proper AffineDimExpr.
|
|
// This results in vars in scope like: d0, d1, d2...
|
|
SmallVector<unsigned> dimIndices;
|
|
for (unsigned i = 0; i < dimCount; ++i)
|
|
dimIndices.push_back(i);
|
|
std::string dimIdentsStr;
|
|
llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
|
|
llvm::interleaveComma(dimIndices, dimIdentsSs,
|
|
[&](unsigned i) { dimIdentsSs << "d" << i; });
|
|
dimIdentsSs.flush();
|
|
|
|
// Statements to add and simplify each affine map.
|
|
SmallVector<std::string> stmts;
|
|
for (auto &indexingMap : *staticMaps) {
|
|
// TODO: Assert that dim and symbol count match the first.
|
|
stmts.push_back(
|
|
llvm::formatv("maps.push_back({0});",
|
|
generateCppExpression(indexingMap, "context")));
|
|
stmts.push_back(llvm::formatv(
|
|
"maps.back() = "
|
|
"simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
|
|
"symbolBindings, {0}, 0));",
|
|
dimCount));
|
|
}
|
|
|
|
// TODO: This needs to be memoized and/or converted to non-parser based
|
|
// C++ codegen prior to real use.
|
|
os << llvm::formatv(structuredOpIndexingMapsFormat, className,
|
|
dimIdentsStr, interleaveToString(stmts, "\n "));
|
|
}
|
|
} else {
|
|
os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className);
|
|
}
|
|
} else {
|
|
return emitError(genContext.getLoc())
|
|
<< "generating code for non static indexing maps not currently "
|
|
"supported";
|
|
}
|
|
|
|
// getNumRegionArgs()
|
|
{
|
|
// Generates a getNumRegionArgs() method. Parameters:
|
|
// {0}: Class name
|
|
// {1}: Number of region args
|
|
static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
|
|
unsigned {0}::getNumRegionArgs() {{ return {1}; }
|
|
)FMT";
|
|
os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
|
|
numOfArgs);
|
|
}
|
|
|
|
// getLibraryCallName()
|
|
{
|
|
// Generates a getLibraryCallName method. Parameters:
|
|
// {0}: Class name
|
|
static const char structuredOpGetLibraryCallFormat[] = R"FMT(
|
|
std::string {0}::getLibraryCallName() {{
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
)FMT";
|
|
os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
|
|
}
|
|
|
|
// hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
|
|
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
|
return arg.kind == LinalgOperandDefKind::IndexAttr;
|
|
})) {
|
|
std::vector<std::string> attrVerifications;
|
|
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
|
|
if (arg.kind != LinalgOperandDefKind::IndexAttr)
|
|
continue;
|
|
assert(arg.indexAttrMap.hasValue());
|
|
// Verify index attribute. Paramters:
|
|
// {0}: Attribute name
|
|
// {1}: Attribute size
|
|
static const char attrFmt[] = R"FMT(
|
|
if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
|
|
if (!attr.getType().getElementType().isInteger(64))
|
|
return op->emitError("incorrect element type for index attribute '{0}'");
|
|
if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
|
|
return op->emitError("incorrect shape for index attribute '{0}'");
|
|
}
|
|
)FMT";
|
|
attrVerifications.push_back(llvm::formatv(
|
|
attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
|
|
}
|
|
|
|
// Generates the verifyIndexingMapRequiredAttributes method. Parameters:
|
|
// {0}: Class name
|
|
// {1}: Attribute verification
|
|
static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT(
|
|
bool {0}::hasDynamicIndexingMaps() {{ return true; }
|
|
LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
|
|
Operation *op = getOperation();
|
|
{1}
|
|
return success();
|
|
}
|
|
)FMT";
|
|
os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes,
|
|
className, llvm::join(attrVerifications, "\n"));
|
|
}
|
|
|
|
// regionBuilder()
|
|
{
|
|
// Generates a regionBuilder method. Parameters.
|
|
// {0}: Class name
|
|
// {1}: Number of args
|
|
// {2}: Attributes
|
|
// {3}: Statements
|
|
static const char structuredOpRegionBuilderFormat[] = R"FMT(
|
|
void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
|
Block &block, ArrayRef<NamedAttribute> attrs) {{
|
|
assert({1} > 0 && block.getNumArguments() == {1} &&
|
|
"{0} regionBuilder expects {1} (>=0) args");
|
|
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
|
|
SmallVector<Value> yields;
|
|
{2}
|
|
{3}
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
)FMT";
|
|
auto &args = opConfig.structuredOp->args;
|
|
auto &assignments = opConfig.structuredOp->assignments;
|
|
size_t generatedAssignmentCount = 0;
|
|
int localCounter = 0;
|
|
SmallVector<std::string> attrs;
|
|
SmallVector<std::string> stmts;
|
|
for (LinalgOperandDef &arg : args) {
|
|
if (!isFunctionAttribute(arg.kind))
|
|
continue;
|
|
// Obtain the type function attribute values. Parameters.
|
|
// {0}: enum name
|
|
// {1}: attribute name
|
|
// {2}: default type function name
|
|
static const char attrDef[] = R"FMT(
|
|
{0} {1}Val = {0}::{2};
|
|
auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{
|
|
return attr.getName() == "{1}"; });
|
|
if ({1}Iter != attrs.end()) {{
|
|
if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>())
|
|
{1}Val = attr.getValue();
|
|
}
|
|
)FMT";
|
|
std::string enumName = convertOperandKindToEnumName(arg.kind);
|
|
attrs.push_back(
|
|
llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn));
|
|
}
|
|
for (LinalgOperandDef &arg : args) {
|
|
if (arg.kind != LinalgOperandDefKind::OutputTensor)
|
|
continue;
|
|
|
|
// Find the assignment that correlates with the argument.
|
|
ScalarAssign *assignment = findAssignment(arg.name, assignments);
|
|
if (!assignment)
|
|
return emitError(genContext.getLoc())
|
|
<< "no assignment found for output argument " << arg.name;
|
|
++generatedAssignmentCount;
|
|
|
|
// Recursively generate the expression.
|
|
std::function<Optional<std::string>(ScalarExpression &)>
|
|
generateExpression =
|
|
[&](ScalarExpression &expression) -> Optional<std::string> {
|
|
if (expression.arg) {
|
|
// Argument reference.
|
|
Optional<int> argIndex = findTensorDefArgIndex(*expression.arg, args);
|
|
if (!argIndex) {
|
|
emitError(genContext.getLoc())
|
|
<< "scalar argument not defined on the op: " << *expression.arg;
|
|
return None;
|
|
}
|
|
return std::string(
|
|
llvm::formatv("block.getArgument({0})", *argIndex));
|
|
}
|
|
if (expression.constant) {
|
|
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
|
stmts.push_back(
|
|
llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT",
|
|
cppIdent, expression.constant));
|
|
return cppIdent;
|
|
}
|
|
if (expression.index) {
|
|
// Access an iteration index.
|
|
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
|
stmts.push_back(llvm::formatv("Value {0} = helper.index({1});",
|
|
cppIdent, *expression.index));
|
|
return cppIdent;
|
|
}
|
|
if (expression.scalarFn) {
|
|
std::string enumName =
|
|
convertFunctionKindToEnumName(expression.scalarFn->kind);
|
|
|
|
// Get the function or attribute name.
|
|
assert(expression.scalarFn->fnName || expression.scalarFn->attrName);
|
|
std::string funcType;
|
|
if (expression.scalarFn->fnName) {
|
|
funcType = llvm::formatv("{0}::{1}", enumName,
|
|
*expression.scalarFn->fnName);
|
|
}
|
|
if (expression.scalarFn->attrName) {
|
|
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
|
|
return isFunctionAttribute(arg.kind) &&
|
|
arg.name == expression.scalarFn->attrName.getValue();
|
|
})) {
|
|
emitError(genContext.getLoc())
|
|
<< "missing function attribute "
|
|
<< expression.scalarFn->attrName.getValue();
|
|
}
|
|
funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
|
|
}
|
|
assert(!funcType.empty());
|
|
|
|
// Add the optional type parameter to the operands.
|
|
SmallVector<std::string> operandCppValues;
|
|
if (expression.scalarFn->kind == ScalarFnKind::Type) {
|
|
assert(expression.scalarFn->typeVar.hasValue());
|
|
Optional<std::string> typeCppValue =
|
|
findTypeValue(expression.scalarFn->typeVar.getValue(), args);
|
|
if (!typeCppValue) {
|
|
emitError(genContext.getLoc())
|
|
<< "type variable " << expression.scalarFn->typeVar.getValue()
|
|
<< ", used in a type conversion, must map to a predefined or "
|
|
<< "an argument type but it does not";
|
|
return None;
|
|
}
|
|
operandCppValues.push_back(typeCppValue.getValue());
|
|
}
|
|
|
|
// Collect the scalar operands.
|
|
for (ScalarExpression &operand : expression.scalarFn->operands) {
|
|
auto operandCppValue = generateExpression(operand);
|
|
if (!operandCppValue)
|
|
return None;
|
|
operandCppValues.push_back(*operandCppValue);
|
|
}
|
|
|
|
// Call the function builder.
|
|
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
|
stmts.push_back(llvm::formatv(
|
|
"Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
|
|
funcType, interleaveToString(operandCppValues, ", ")));
|
|
return cppIdent;
|
|
}
|
|
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
|
|
return None;
|
|
};
|
|
Optional<std::string> cppValue = generateExpression(assignment->value);
|
|
if (!cppValue)
|
|
return failure();
|
|
stmts.push_back(llvm::formatv("yields.push_back({0});", cppValue));
|
|
}
|
|
|
|
if (generatedAssignmentCount != assignments.size())
|
|
return emitError(genContext.getLoc())
|
|
<< "mismatched number of assignments vs output arguments";
|
|
|
|
os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs,
|
|
interleaveToString(attrs, "\n "),
|
|
interleaveToString(stmts, "\n "));
|
|
}
|
|
|
|
// Parser and printer.
|
|
os << llvm::formatv(structuredOpParserFormat, className);
|
|
|
|
// Canonicalizers and folders.
|
|
os << llvm::formatv(structuredOpFoldersFormat, className);
|
|
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult generateOp(LinalgOpConfig &opConfig,
|
|
GenerationContext &genContext) {
|
|
// Switch on op type being generated.
|
|
if (opConfig.structuredOp) {
|
|
return success(
|
|
succeeded(generateNamedGenericOpOds(opConfig, genContext)) &&
|
|
succeeded(generateNamedGenericOpDefns(opConfig, genContext)));
|
|
}
|
|
return emitError(genContext.getLoc()) << "unsupported operation type";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Command line options and main
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static llvm::cl::opt<std::string>
|
|
inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
|
llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
|
|
|
|
static llvm::cl::opt<std::string>
|
|
outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
|
|
llvm::cl::value_desc("filename"), llvm::cl::init(""));
|
|
|
|
static llvm::cl::opt<std::string>
|
|
outputCppImplFilename("o-impl",
|
|
llvm::cl::desc("C++ implementation file name"),
|
|
llvm::cl::value_desc("filename"), llvm::cl::init(""));
|
|
|
|
int main(int argc, char **argv) {
|
|
llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
|
|
|
|
// Set up the input file.
|
|
std::string errorMessage;
|
|
std::unique_ptr<llvm::MemoryBuffer> file =
|
|
mlir::openInputFile(inputFilename, &errorMessage);
|
|
if (!file) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return 1;
|
|
}
|
|
|
|
MLIRContext mlirContext;
|
|
LinalgYAMLContext yamlContext{&mlirContext};
|
|
|
|
std::vector<LinalgOpConfig> opConfigs;
|
|
|
|
// Parse input.
|
|
Input yin(file->getBuffer(), &yamlContext);
|
|
yin >> opConfigs;
|
|
|
|
if (yin.error())
|
|
return 1;
|
|
|
|
// Open output files.
|
|
std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
|
|
if (!outputOdsDeclFilename.empty()) {
|
|
outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage);
|
|
if (!outputOdsDecl) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
|
|
if (!outputCppImplFilename.empty()) {
|
|
outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage);
|
|
if (!outputCppImpl) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
if (!outputOdsDecl && !outputCppImpl) {
|
|
llvm::errs() << "error: No output files specified\n";
|
|
return 1;
|
|
}
|
|
|
|
// Generate.
|
|
GenerationContext genContext(&mlirContext,
|
|
outputOdsDecl ? &outputOdsDecl->os() : nullptr,
|
|
outputCppImpl ? &outputCppImpl->os() : nullptr);
|
|
|
|
for (auto &opConfig : opConfigs) {
|
|
if (!opConfig.metadata) {
|
|
emitError(genContext.getLoc())
|
|
<< "missing operation metadata on subsequent op";
|
|
return 1;
|
|
}
|
|
|
|
genContext.setLoc(NameLoc::get(
|
|
StringAttr::get(&mlirContext, opConfig.metadata->cppClassName)));
|
|
if (failed(generateOp(opConfig, genContext))) {
|
|
return 1;
|
|
}
|
|
}
|
|
|
|
if (outputOdsDecl)
|
|
outputOdsDecl->keep();
|
|
if (outputCppImpl)
|
|
outputCppImpl->keep();
|
|
|
|
return 0;
|
|
}
|