[mlir][OpDSL] Refactor function handling.

Prepare the OpDSL function handling to introduce more function classes. A follow-up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares the split by adding a function kind enum that handles the different function types using a single class at each level of the stack (for example, there is now one TensorFn and one ScalarFn).

Depends On D119718

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120108
This commit is contained in:
gysit 2022-02-25 15:04:38 +00:00
parent 9f5f08476e
commit 4d4cb17da8
7 changed files with 537 additions and 403 deletions

View File

@ -133,17 +133,26 @@ class TensorUse(TensorExpression):
f"[{', '.join([repr(i) for i in self.indices])}]")
class TensorArithFn(TensorExpression):
"""Application of an arithmetic function."""
class TensorFn(TensorExpression):
"""Application of a tensor function."""
def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
self.arith_fn = arith_fn
self.args = tuple(args)
def __init__(self, kind: "FunctionKind", name: Optional[str],
operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
args: Sequence[TensorExpression]):
if bool(name) + bool(operand_def) != 1:
raise ValueError("One of 'name', 'operand_def' must be specified")
self.name = name
self.kind = kind
self.operand_def = operand_def
self.type_var = type_var
self.args = args
def to_scalar_expression(self) -> ScalarExpression:
return ScalarArithFn(self.arith_fn.fn_name,
*[arg.to_scalar_expression() for arg in self.args
]).expr()
if self.operand_def:
assert self.operand_def.name, "TensorFn not registered with an op"
attr_name = self.operand_def.name if self.operand_def else None
args = [arg.to_scalar_expression() for arg in self.args]
return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
super().visit_tensor_exprs(callback)
@ -151,37 +160,9 @@ class TensorArithFn(TensorExpression):
arg.visit_tensor_exprs(callback)
def __repr__(self):
return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
class TensorTypeFn(TensorExpression):
"""Application of a type conversion function."""
def __init__(self, type_fn: Optional["TypeFn"],
operand_def: Optional["OperandDef"], type_var: TypeVar,
arg: TensorExpression):
if bool(type_fn) + bool(operand_def) != 1:
raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
self.type_fn = type_fn
self.operand_def = operand_def
self.type_var = type_var
self.arg = arg
def to_scalar_expression(self) -> ScalarExpression:
if self.operand_def:
assert self.operand_def.name, "TypeFnAttr not registered with an op"
fn_name = self.type_fn.fn_name if self.type_fn else None
attr_name = self.operand_def.name if self.operand_def else None
return ScalarTypeFn(fn_name, attr_name, self.type_var,
self.arg.to_scalar_expression()).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
super().visit_tensor_exprs(callback)
self.arg.visit_tensor_exprs(callback)
def __repr__(self):
return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
f"({self.type_var}, {self.arg})")
name = self.operand_def.name if self.operand_def else self.name
return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
f"args={', '.join(repr(a) for a in self.args)})")
class TensorReduceFn(TensorExpression):
@ -194,7 +175,7 @@ class TensorReduceFn(TensorExpression):
args: Sequence[TensorExpression]):
self.reduce_use = reduce_use
self.lhs = None # type: Optional[TensorUse]
self.args = tuple(args)
self.args = args
def to_scalar_expression(self) -> ScalarExpression:
if self.lhs is None:
@ -202,7 +183,8 @@ class TensorReduceFn(TensorExpression):
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
None, full_args).expr()
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
for arg in self.args:
@ -259,6 +241,11 @@ class index(TensorExpression):
###############################################################################
class FunctionKind(Enum):
  """Kind of a function applied in a tensor/scalar expression.

  Used by TensorFn and ScalarFn to distinguish arithmetic functions
  (ARITH) from type conversion functions (TYPE) while sharing a single
  class per level of the stack.
  """
  ARITH = 0
  TYPE = 1
class TypeFnType:
"""Type conversion function.
@ -269,8 +256,8 @@ class TypeFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
return TensorTypeFn(self, None, type_var, arg)
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
def __repr__(self):
return f"{self.fn_name}"
@ -301,8 +288,8 @@ class ArithFnType:
def __init__(self, fn_name: str):
self.fn_name = fn_name
def __call__(self, *args) -> "TensorArithFn":
return TensorArithFn(self, args)
def __call__(self, *args) -> "TensorFn":
return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
def __repr__(self):
return f"{self.fn_name}"
@ -562,8 +549,8 @@ class TypeFnAttrDef:
self.operand_def = OperandDef(
OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
return TensorTypeFn(None, self.operand_def, type_var, arg)
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
###############################################################################

View File

@ -270,19 +270,19 @@ class _BodyBuilder:
dim_attr = IntegerAttr.get(
IntegerType.get_signless(64), expr.scalar_index.dim)
return linalg.IndexOp(dim_attr).result
elif expr.arith_fn:
fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
operand_values = [
self.expression(operand) for operand in expr.arith_fn.operands
self.expression(operand) for operand in expr.scalar_fn.operands
]
return fn(*operand_values)
elif expr.type_fn:
fn_name = expr.type_fn.fn_name
if expr.type_fn.attr_name:
fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
fn_name = expr.scalar_fn.fn_name
if expr.scalar_fn.attr_name:
fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
fn = self._get_function(f"_typefn_{fn_name}")
operand = self.expression(expr.type_fn.operand)
return fn(expr.type_fn.type_var.name, operand)
operand_value = self.expression(expr.scalar_fn.operands[0])
return fn(expr.scalar_fn.type_var.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
def yield_outputs(self, *output_names: str):

View File

@ -15,13 +15,13 @@ can be easily consumed from the C++ side, not necessarily for ergonomics.
from typing import Optional, Sequence
from .yaml_helper import *
from .comprehension import *
from .types import *
from .yaml_helper import *
__all__ = [
"ScalarAssign",
"ScalarArithFn",
"ScalarTypeFn",
"ScalarFn",
"ScalarArg",
"ScalarConst",
"ScalarIndex",
@ -29,36 +29,27 @@ __all__ = [
]
class ScalarArithFn:
"""A type of ScalarExpression that applies an arithmetic function."""
class ScalarFn:
"""A type of ScalarExpression that applies a function."""
def __init__(self, fn_name: str, *operands: "ScalarExpression"):
self.fn_name = fn_name
self.operands = operands
def expr(self) -> "ScalarExpression":
return ScalarExpression(arith_fn=self)
def __repr__(self):
return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
class ScalarTypeFn:
"""A type of ScalarExpression that applies a type conversion function."""
def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
type_var: TypeVar, operand: "ScalarExpression"):
def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
attr_name: Optional[str], type_var: Optional["TypeVar"],
operands: Sequence["ScalarExpression"]):
if bool(fn_name) + bool(attr_name) != 1:
raise ValueError("One of 'fn_name', 'attr_name' must be specified")
self.kind = kind
self.fn_name = fn_name
self.attr_name = attr_name
self.type_var = type_var
self.operand = operand
self.operands = operands
def expr(self) -> "ScalarExpression":
return ScalarExpression(type_fn=self)
return ScalarExpression(scalar_fn=self)
def __repr__(self):
return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
f"({self.type_var}, {self.operand})")
name = self.fn_name if self.fn_name else self.attr_name
return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
f"operands=[{', '.join(self.operands)}])")
class ScalarArg:
@ -104,51 +95,38 @@ class ScalarExpression(YAMLObject):
"""An expression on scalar values.
Can be one of:
- ScalarArithFn
- ScalarTypeFn
- ScalarFn
- ScalarArg
- ScalarConst
- ScalarIndex
- ScalarSymbolicCast
"""
yaml_tag = "!ScalarExpression"
def __init__(self,
arith_fn: Optional[ScalarArithFn] = None,
type_fn: Optional[ScalarTypeFn] = None,
scalar_fn: Optional[ScalarFn] = None,
scalar_arg: Optional[ScalarArg] = None,
scalar_const: Optional[ScalarConst] = None,
scalar_index: Optional[ScalarIndex] = None):
if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
bool(scalar_index)) != 1:
raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
"'scalar_const', 'scalar_index', must be specified")
self.arith_fn = arith_fn
self.type_fn = type_fn
raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
"'scalar_index' must be specified")
self.scalar_fn = scalar_fn
self.scalar_arg = scalar_arg
self.scalar_const = scalar_const
self.scalar_index = scalar_index
def to_yaml_custom_dict(self):
if self.arith_fn:
return dict(
arith_fn=dict(
fn_name=self.arith_fn.fn_name,
operands=list(self.arith_fn.operands),
))
if self.type_fn:
# Note that even though operands must be arity 1, we write it the
# same way as for apply because it allows handling code to be more
# generic vs having a special form.
type_fn_dict = dict(
type_var=self.type_fn.type_var.name,
operands=[self.type_fn.operand],
)
if self.type_fn.fn_name:
type_fn_dict["fn_name"] = self.type_fn.fn_name
if self.type_fn.attr_name:
type_fn_dict["attr_name"] = self.type_fn.attr_name
return dict(type_fn=type_fn_dict)
if self.scalar_fn:
scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
if self.scalar_fn.fn_name:
scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
if self.scalar_fn.attr_name:
scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
if self.scalar_fn.type_var:
scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
return dict(scalar_fn=scalar_fn_dict)
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
elif self.scalar_const:

View File

@ -39,23 +39,26 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
arith_fn:
scalar_fn:
kind: arith
fn_name: add
operands:
- !ScalarExpression
type_fn:
scalar_fn:
kind: type
attr_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_const: '42 : i64'
attr_name: cast
- !ScalarExpression
type_fn:
scalar_fn:
kind: type
attr_name: cast
type_var: T
operands:
- !ScalarExpression
scalar_index: 1
attr_name: cast
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
@ -236,7 +239,8 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarAssign
arg: O
value: !ScalarExpression
type_fn:
scalar_fn:
kind: type
fn_name: cast
type_var: U
operands:

View File

@ -9,22 +9,24 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK: -
# CHECK: arg: C
# CHECK: value:
# CHECK: arith_fn:
# CHECK: scalar_fn:
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: arith_fn:
# CHECK: scalar_fn:
# CHECK: fn_name: mul
# CHECK: operands:
# CHECK: type_fn:
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: attr_name: cast
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: A
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: attr_name: cast
# CHECK: type_fn:
# CHECK: type_var: U
# CHECK: operands:
# CHECK: scalar_arg: B
# CHECK: attr_name: cast
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
@ -39,21 +41,28 @@ def matmul(
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
# CHECK: arith_fn:
# CHECK: scalar_fn:
# CHECK: kind: arith
# CHECK: fn_name: sub
# CHECK: operands:
# CHECK: arith_fn:
# CHECK: scalar_fn:
# CHECK: kind: arith
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: type_fn:
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '3.1415926535897931 : f64'
# CHECK: type_fn:
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '42 : i64'
# CHECK: type_fn:
# CHECK: scalar_fn:
# CHECK: kind: type
# CHECK: fn_name: cast
# CHECK: type_var: T
# CHECK: operands:
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
@ -70,7 +79,8 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
# CHECK: assignments:
# CHECK: -
# CHECK: arg: O
# CHECK: arith_fn:
# CHECK: scalar_fn:
# CHECK: kind: arith
# CHECK: fn_name: add
# CHECK: operands:
# CHECK: scalar_index: 1

View File

@ -90,28 +90,23 @@ struct LinalgIndexingMapsConfig {
struct ScalarExpression;
struct ScalarArithFn {
std::string fnName;
// NOTE: Must be pure heap allocated container (not SmallVector)
// due to recursive data type.
std::vector<ScalarExpression> operands;
};
enum class ScalarFnKind { Arith, Type };
struct ScalarTypeFn {
std::string typeVar;
struct ScalarFn {
ScalarFnKind kind;
Optional<std::string> fnName;
Optional<std::string> attrName;
Optional<std::string> typeVar;
// NOTE: This must be of arity 1, but to break the self-referential cycle,
// we use a heap allocated vector.
std::vector<ScalarExpression> operands;
Optional<std::string> fnName;
Optional<std::string> attrName;
};
struct ScalarExpression {
Optional<std::string> arg;
Optional<std::string> constant;
Optional<int64_t> index;
Optional<ScalarArithFn> arithFn;
Optional<ScalarTypeFn> typeFn;
Optional<ScalarFn> scalarFn;
};
struct ScalarAssign {
@ -265,16 +260,23 @@ struct MappingTraits<ScalarAssign> {
/// - `scalar_arg`: An operation argument.
/// - `scalar_const`: A constant definition.
/// - `scalar_index`: An iteration index.
/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
/// - `scalar_fn`: A named function (see `ScalarFn`).
template <>
struct MappingTraits<ScalarExpression> {
static void mapping(IO &io, ScalarExpression &info) {
io.mapOptional("scalar_arg", info.arg);
io.mapOptional("scalar_const", info.constant);
io.mapOptional("scalar_index", info.index);
io.mapOptional("arith_fn", info.arithFn);
io.mapOptional("type_fn", info.typeFn);
io.mapOptional("scalar_fn", info.scalarFn);
}
};
/// Scalar function kind enum.
template <>
struct ScalarEnumerationTraits<ScalarFnKind> {
static void enumeration(IO &io, ScalarFnKind &value) {
io.enumCase(value, "arith", ScalarFnKind::Arith);
io.enumCase(value, "type", ScalarFnKind::Type);
}
};
@ -284,20 +286,13 @@ struct MappingTraits<ScalarExpression> {
/// - `add(lhs, rhs)`
/// - `mul(lhs, rhs)`
template <>
struct MappingTraits<ScalarArithFn> {
static void mapping(IO &io, ScalarArithFn &info) {
io.mapRequired("fn_name", info.fnName);
io.mapRequired("operands", info.operands);
}
};
template <>
struct MappingTraits<ScalarTypeFn> {
static void mapping(IO &io, ScalarTypeFn &info) {
io.mapRequired("type_var", info.typeVar);
io.mapRequired("operands", info.operands);
struct MappingTraits<ScalarFn> {
static void mapping(IO &io, ScalarFn &info) {
io.mapRequired("kind", info.kind);
io.mapOptional("fn_name", info.fnName);
io.mapOptional("attr_name", info.attrName);
io.mapOptional("type_var", info.typeVar);
io.mapRequired("operands", info.operands);
}
};
@ -1060,11 +1055,12 @@ if ({0}Iter != attrs.end()) {{
cppIdent, *expression.index));
return cppIdent;
}
if (expression.arithFn) {
if (expression.scalarFn &&
expression.scalarFn->kind == ScalarFnKind::Arith) {
// Apply function.
// Recursively generate operands.
SmallVector<std::string> operandCppValues;
for (ScalarExpression &operand : expression.arithFn->operands) {
for (ScalarExpression &operand : expression.scalarFn->operands) {
auto operandCppValue = generateExpression(operand);
if (!operandCppValue)
return None;
@ -1073,28 +1069,30 @@ if ({0}Iter != attrs.end()) {{
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(
llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
expression.arithFn->fnName,
expression.scalarFn->fnName,
interleaveToString(operandCppValues, ", ")));
return cppIdent;
}
if (expression.typeFn) {
if (expression.scalarFn &&
expression.scalarFn->kind == ScalarFnKind::Type) {
// Symbolic cast.
// Operands must be arity 1.
if (expression.typeFn->operands.size() != 1) {
if (expression.scalarFn->operands.size() != 1) {
emitError(genContext.getLoc())
<< "type conversion operand arity must be 1";
return None;
}
Optional<std::string> operandCppValue =
generateExpression(expression.typeFn->operands[0]);
generateExpression(expression.scalarFn->operands[0]);
if (!operandCppValue)
return None;
assert(expression.scalarFn->typeVar.hasValue());
Optional<std::string> typeCppValue =
findTypeValue(expression.typeFn->typeVar, args);
findTypeValue(expression.scalarFn->typeVar.getValue(), args);
if (!typeCppValue) {
emitError(genContext.getLoc())
<< "type variable " << expression.typeFn->typeVar
<< "type variable " << expression.scalarFn->typeVar.getValue()
<< ", used in a type conversion, must map to a predefined or "
<< "an argument type but it does not";
return None;
@ -1102,17 +1100,17 @@ if ({0}Iter != attrs.end()) {{
// Use the function name or the attribute to build the type function.
std::string typeFunc = llvm::formatv(
"TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
if (expression.typeFn->attrName) {
"TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
if (expression.scalarFn->attrName) {
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
arg.name == expression.typeFn->attrName.getValue();
arg.name == expression.scalarFn->attrName.getValue();
})) {
emitError(genContext.getLoc())
<< "missing type function attribute "
<< expression.typeFn->attrName.getValue();
<< expression.scalarFn->attrName.getValue();
}
typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
}
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
stmts.push_back(llvm::formatv(