[mlir] Introduce ml_program dialect.

Differential Revision: https://reviews.llvm.org/D120203
This commit is contained in:
Stella Laurenzo 2022-04-13 20:16:04 -07:00
parent 836e610d93
commit 61352a580a
16 changed files with 563 additions and 0 deletions

View File

@ -15,6 +15,7 @@ add_subdirectory(Math)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
add_subdirectory(MLProgram)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS MLProgramOps.td)
add_mlir_dialect(MLProgramOps ml_program)
add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)

View File

@ -0,0 +1,34 @@
//===- MLProgram.h - MLProgram dialect ----------------------------*- C++-*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
// MLProgramDialect
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// MLProgram Dialect Operations
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc"
#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_

View File

@ -0,0 +1,33 @@
//===- MLProgramBase.td - Base defs for ml_program dialect --*- tablegen -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLPROGRAM_BASE
#define MLPROGRAM_BASE
include "mlir/IR/OpBase.td"
def MLProgram_Dialect : Dialect {
let name = "ml_program";
let cppNamespace = "::mlir::ml_program";
let description = [{
The MLProgram dialect contains structural operations and types for
defining a compiled Machine-Learning program, as created from common
ML frameworks, such as TensorFlow, PyTorch, JAX, etc. It does not itself
define computation ops common to such frameworks but establishes a common
programming model for establishing modules, functions, globals and
memory model components appropriate for such an abstract level of detail.
This dialect is under active development, and while stability is an
eventual goal, it is not guaranteed at this juncture. Given the early state,
it is recommended to inquire further prior to using this dialect.
}];
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
}
#endif // MLPROGRAM_BASE

View File

@ -0,0 +1,218 @@
//===- MLProgramOps.td - Structural ML Program Ops ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLPROGRAM_OPS
#define MLPROGRAM_OPS
include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/SymbolInterfaces.td"
class MLProgram_Op<string mnemonic, list<Trait> traits = []> :
Op<MLProgram_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
def MLProgram_FuncOp : MLProgram_Op<"func", [
CallableOpInterface, FunctionOpInterface, IsolatedFromAbove,
RegionKindInterface, Symbol
]> {
let summary = "Function containing a single `SSACFG` region";
let description = [{
This simple function container represents callables in an ML program where
the body is an `SSACFG` region. It must be terminated by a `return` op which
yields values with the same arity and types as the `FunctionType` results
of the containing `func`.
This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
it cannot represent nested symbols.
Example:
```mlir
ml_program.func private @some_extern(i32) -> i32
ml_program.func @compute(%arg0 : i32) -> i32 {
ml_program.return %arg0 : i32
}
```
}];
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
/// Returns the region on the current operation that is callable. This may
/// return null in the case of an external callable object, e.g. an external
/// function.
::mlir::Region *getCallableRegion() {
return isExternal() ? nullptr : &getBody();
}
/// Returns the results types that the callable region produces when
/// executed.
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
//===------------------------------------------------------------------===//
// RegionKindInterface Methods
//===------------------------------------------------------------------===//
static ::mlir::RegionKind getRegionKind(unsigned index) {
return ::mlir::RegionKind::SSACFG;
}
//===------------------------------------------------------------------===//
// SymbolOpInterface Methods
//===------------------------------------------------------------------===//
bool isDeclaration() { return isExternal(); }
}];
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
CallableOpInterface, FunctionOpInterface, HasOnlyGraphRegion,
IsolatedFromAbove, RegionKindInterface, SingleBlock, Symbol
]> {
let summary = "An function containing a single `Graph` region";
let description = [{
This simple function container represents callables in an ML program where
the body is a `Graph` region containing a single block. It must be
terminated by an `output` op which yields values with the same arity and
types as the `FunctionType` results of the containing `subgraph`.
This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
it cannot represented nested symbols.
Example:
```mlir
ml_program.subgraph private @some_extern(i32) -> i32
ml_program.subgraph @compute(%arg0 : i32) -> i32 {
ml_program.output %arg0 : i32
}
```
}];
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<FunctionType>:$function_type,
OptionalAttr<StrAttr>:$sym_visibility);
let regions = (region AnyRegion:$body);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// CallableOpInterface
//===------------------------------------------------------------------===//
/// Returns the region on the current operation that is callable. This may
/// return null in the case of an external callable object, e.g. an external
/// function.
::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
/// Returns the results types that the callable region produces when
/// executed.
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
//===------------------------------------------------------------------===//
// FunctionOpInterface Methods
//===------------------------------------------------------------------===//
/// Returns the argument types of this function.
ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
/// Returns the result types of this function.
ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
//===------------------------------------------------------------------===//
// SymbolOpInterface Methods
//===------------------------------------------------------------------===//
bool isDeclaration() { return isExternal(); }
}];
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// OutputOp
//===----------------------------------------------------------------------===//
def MLProgram_OutputOp : MLProgram_Op<"output", [
NoSideEffect, HasParent<"SubgraphOp">, ReturnLike, Terminator
]> {
let summary = "Outputs values from a subgraph function";
let description = [{
The `output` operation terminates a subgraph by yielding values
to the caller.
The operation takes variable number of operands and produces no results.
The operand number and types must match the signature of the function
that contains the operation.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<(ins), [{
build($_builder, $_state, llvm::None);
}]>];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
def MLProgram_ReturnOp : MLProgram_Op<"return", [
NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator
]> {
let summary = "Returns values from a `func` function";
let description = [{
The `return` operation terminates a `func` function by yielding values
to the caller.
The operation takes variable number of operands and produces no results.
The operand number and types must match the signature of the function
that contains the operation.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<(ins), [{
build($_builder, $_state, llvm::None);
}]>];
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let hasVerifier = 1;
}
#endif // MLPROGRAM_OPS

View File

@ -33,6 +33,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@ -77,6 +78,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
linalg::LinalgDialect,
math::MathDialect,
memref::MemRefDialect,
ml_program::MLProgramDialect,
scf::SCFDialect,
omp::OpenMPDialect,
pdl::PDLDialect,

View File

@ -15,6 +15,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
add_subdirectory(MLProgram)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)

View File

@ -0,0 +1 @@
add_subdirectory(IR)

View File

@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRMLProgram
MLProgramOps.cpp
MLProgramDialect.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram
DEPENDS
MLIRMLProgramOpsIncGen
LINK_LIBS PUBLIC
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
)

View File

@ -0,0 +1,21 @@
//===- MLProgramDialect.cpp - MLProgram dialect implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
using namespace mlir;
using namespace mlir::ml_program;
#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
void ml_program::MLProgramDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
>();
}

View File

@ -0,0 +1,107 @@
//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionImplementation.h"
using namespace mlir;
using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void FuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false, buildFuncType);
}
void SubgraphOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
}
//===----------------------------------------------------------------------===//
// OutputOp
//===----------------------------------------------------------------------===//
LogicalResult OutputOp::verify() {
auto function = cast<SubgraphOp>((*this)->getParentOp());
// The operand number and types must match the function signature.
const auto &results = function.getFunctionType().getResults();
if (getNumOperands() != results.size())
return emitOpError("has ")
<< getNumOperands() << " operands, but enclosing function (@"
<< function.getName() << ") outputs " << results.size();
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (getOperand(i).getType() != results[i])
return emitError() << "type of output operand " << i << " ("
<< getOperand(i).getType()
<< ") doesn't match function result type ("
<< results[i] << ")"
<< " in function @" << function.getName();
return success();
}
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
LogicalResult ReturnOp::verify() {
auto function = cast<FuncOp>((*this)->getParentOp());
// The operand number and types must match the function signature.
const auto &results = function.getFunctionType().getResults();
if (getNumOperands() != results.size())
return emitOpError("has ")
<< getNumOperands() << " operands, but enclosing function (@"
<< function.getName() << ") returns " << results.size();
for (unsigned i = 0, e = results.size(); i != e; ++i)
if (getOperand(i).getType() != results[i])
return emitError() << "type of return operand " << i << " ("
<< getOperand(i).getType()
<< ") doesn't match function result type ("
<< results[i] << ")"
<< " in function @" << function.getName();
return success();
}

View File

@ -0,0 +1,33 @@
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -verify-diagnostics %s
ml_program.func @ssa_enforced(%arg0 : i32) -> i32 {
// expected-error @+1 {{does not dominate this use}}
%1 = "unregistered.dummy"(%0) : (i32) -> i32
// expected-note @+1 {{operand defined here}}
%0 = "unregistered.dummy"(%arg0) : (i32) -> i32
ml_program.return %0 : i32
}
// -----
ml_program.func @return_arity_match(%arg0 : i32) -> i32 {
// expected-error @+1 {{enclosing function (@return_arity_match) returns 1}}
ml_program.return %arg0, %arg0 : i32, i32
}
// -----
ml_program.func @return_type_match(%arg0 : i64) -> i32 {
// expected-error @+1 {{doesn't match function result}}
ml_program.return %arg0 : i64
}
// -----
ml_program.subgraph @output_arity_match(%arg0 : i32) -> i32 {
// expected-error @+1 {{enclosing function (@output_arity_match) outputs 1}}
ml_program.output %arg0, %arg0 : i32, i32
}
// -----
ml_program.subgraph @output_type_match(%arg0 : i64) -> i32 {
// expected-error @+1 {{doesn't match function result}}
ml_program.output %arg0 : i64
}

View File

@ -0,0 +1,20 @@
// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
// RUN: mlir-opt %s --allow-unregistered-dialect --mlir-print-op-generic | mlir-opt --allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: ml_program.func private @extern_func
ml_program.func private @extern_func(i32) -> i32
// CHECK-LABEL: ml_program.func @defined_func
ml_program.func @defined_func(%arg0 : i32) -> i32 {
ml_program.return %arg0 : i32
}
// CHECK-LABEL: ml_program.subgraph private @extern_subgraph
ml_program.subgraph private @extern_subgraph(i32) -> i32
// CHECK-LABEL: ml_program.subgraph @compute_subgraph
ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
%1 = "unregistered.dummy"(%0) : (i32) -> i32
%0 = "unregistered.dummy"(%arg0) : (i32) -> i32
ml_program.output %0 : i32
}

View File

@ -19,6 +19,7 @@
// CHECK-NEXT: llvm
// CHECK-NEXT: math
// CHECK-NEXT: memref
// CHECK-NEXT: ml_program
// CHECK-NEXT: nvvm
// CHECK-NEXT: omp
// CHECK-NEXT: pdl

View File

@ -5939,6 +5939,7 @@ cc_library(
":LinalgToSPIRV",
":LinalgToStandard",
":LinalgTransforms",
":MLProgramDialect",
":MathDialect",
":MathToLLVM",
":MathToLibm",
@ -8114,6 +8115,77 @@ cc_library(
],
)
##---------------------------------------------------------------------------##
# MLProgram dialect
##---------------------------------------------------------------------------##
td_library(
name = "MLProgramOpsTdFiles",
srcs = [
"include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
"include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
],
includes = ["include"],
deps = [
":CallInterfacesTdFiles",
":ControlFlowInterfacesTdFiles",
":FunctionInterfacesTdFiles",
":OpBaseTdFiles",
":RegionKindInterfaceIncGen",
":SideEffectInterfacesTdFiles",
],
)
gentbl_cc_library(
name = "MLProgramOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-decls"],
"include/mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc",
),
(
["-gen-op-defs"],
"include/mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc",
),
(
["-gen-dialect-decls"],
"include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc",
),
(
["-gen-dialect-defs"],
"include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
deps = [":MLProgramOpsTdFiles"],
)
cc_library(
name = "MLProgramDialect",
srcs = glob([
"lib/Dialect/MLProgram/IR/*.cpp",
"lib/Dialect/MLProgram/IR/*.h",
]),
hdrs = glob([
"include/mlir/Dialect/MLProgram/IR/*.h",
]),
includes = ["include"],
deps = [
":ControlFlowInterfaces",
":IR",
":MLProgramOpsIncGen",
":Pass",
":Support",
"//llvm:Support",
],
)
##---------------------------------------------------------------------------##
# Allocation interfaces
##---------------------------------------------------------------------------##
td_library(
name = "AllocationOpInterfaceTdFiles",
srcs = ["include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"],