forked from OSchip/llvm-project
[mlir][OpDSL] Refactor function handling.
Prepare the OpDSL function handling to introduce more function classes. A follow up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares the split by adding a function kind enum to handle different function types using a single class on the various levels of the stack (for example, there is now one TensorFn and one ScalarFn). Depends On D119718 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D120108
This commit is contained in:
parent
9f5f08476e
commit
4d4cb17da8
File diff suppressed because it is too large
Load Diff
|
@ -133,17 +133,26 @@ class TensorUse(TensorExpression):
|
|||
f"[{', '.join([repr(i) for i in self.indices])}]")
|
||||
|
||||
|
||||
class TensorArithFn(TensorExpression):
|
||||
"""Application of an arithmetic function."""
|
||||
class TensorFn(TensorExpression):
|
||||
"""Application of a tensor function."""
|
||||
|
||||
def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
|
||||
self.arith_fn = arith_fn
|
||||
self.args = tuple(args)
|
||||
def __init__(self, kind: "FunctionKind", name: Optional[str],
|
||||
operand_def: Optional["OperandDef"], type_var: Optional[TypeVar],
|
||||
args: Sequence[TensorExpression]):
|
||||
if bool(name) + bool(operand_def) != 1:
|
||||
raise ValueError("One of 'name', 'operand_def' must be specified")
|
||||
self.name = name
|
||||
self.kind = kind
|
||||
self.operand_def = operand_def
|
||||
self.type_var = type_var
|
||||
self.args = args
|
||||
|
||||
def to_scalar_expression(self) -> ScalarExpression:
|
||||
return ScalarArithFn(self.arith_fn.fn_name,
|
||||
*[arg.to_scalar_expression() for arg in self.args
|
||||
]).expr()
|
||||
if self.operand_def:
|
||||
assert self.operand_def.name, "TensorFn not registered with an op"
|
||||
attr_name = self.operand_def.name if self.operand_def else None
|
||||
args = [arg.to_scalar_expression() for arg in self.args]
|
||||
return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr()
|
||||
|
||||
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
|
||||
super().visit_tensor_exprs(callback)
|
||||
|
@ -151,37 +160,9 @@ class TensorArithFn(TensorExpression):
|
|||
arg.visit_tensor_exprs(callback)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
|
||||
|
||||
|
||||
class TensorTypeFn(TensorExpression):
|
||||
"""Application of a type conversion function."""
|
||||
|
||||
def __init__(self, type_fn: Optional["TypeFn"],
|
||||
operand_def: Optional["OperandDef"], type_var: TypeVar,
|
||||
arg: TensorExpression):
|
||||
if bool(type_fn) + bool(operand_def) != 1:
|
||||
raise ValueError("Either 'type_fn' or 'operand_def' must be specified")
|
||||
self.type_fn = type_fn
|
||||
self.operand_def = operand_def
|
||||
self.type_var = type_var
|
||||
self.arg = arg
|
||||
|
||||
def to_scalar_expression(self) -> ScalarExpression:
|
||||
if self.operand_def:
|
||||
assert self.operand_def.name, "TypeFnAttr not registered with an op"
|
||||
fn_name = self.type_fn.fn_name if self.type_fn else None
|
||||
attr_name = self.operand_def.name if self.operand_def else None
|
||||
return ScalarTypeFn(fn_name, attr_name, self.type_var,
|
||||
self.arg.to_scalar_expression()).expr()
|
||||
|
||||
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
|
||||
super().visit_tensor_exprs(callback)
|
||||
self.arg.visit_tensor_exprs(callback)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]"
|
||||
f"({self.type_var}, {self.arg})")
|
||||
name = self.operand_def.name if self.operand_def else self.name
|
||||
return (f"{self.kind.name}.{name}(type_var={self.type_var}, "
|
||||
f"args={', '.join(repr(a) for a in self.args)})")
|
||||
|
||||
|
||||
class TensorReduceFn(TensorExpression):
|
||||
|
@ -194,7 +175,7 @@ class TensorReduceFn(TensorExpression):
|
|||
args: Sequence[TensorExpression]):
|
||||
self.reduce_use = reduce_use
|
||||
self.lhs = None # type: Optional[TensorUse]
|
||||
self.args = tuple(args)
|
||||
self.args = args
|
||||
|
||||
def to_scalar_expression(self) -> ScalarExpression:
|
||||
if self.lhs is None:
|
||||
|
@ -202,7 +183,8 @@ class TensorReduceFn(TensorExpression):
|
|||
f"bound to its lhs: {self}")
|
||||
full_args = [self.lhs.to_scalar_expression()
|
||||
] + [arg.to_scalar_expression() for arg in self.args]
|
||||
return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
|
||||
return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None,
|
||||
None, full_args).expr()
|
||||
|
||||
def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
|
||||
for arg in self.args:
|
||||
|
@ -259,6 +241,11 @@ class index(TensorExpression):
|
|||
###############################################################################
|
||||
|
||||
|
||||
class FunctionKind(Enum):
|
||||
ARITH = 0
|
||||
TYPE = 1
|
||||
|
||||
|
||||
class TypeFnType:
|
||||
"""Type conversion function.
|
||||
|
||||
|
@ -269,8 +256,8 @@ class TypeFnType:
|
|||
def __init__(self, fn_name: str):
|
||||
self.fn_name = fn_name
|
||||
|
||||
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
|
||||
return TensorTypeFn(self, None, type_var, arg)
|
||||
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn":
|
||||
return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg])
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.fn_name}"
|
||||
|
@ -301,8 +288,8 @@ class ArithFnType:
|
|||
def __init__(self, fn_name: str):
|
||||
self.fn_name = fn_name
|
||||
|
||||
def __call__(self, *args) -> "TensorArithFn":
|
||||
return TensorArithFn(self, args)
|
||||
def __call__(self, *args) -> "TensorFn":
|
||||
return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.fn_name}"
|
||||
|
@ -562,8 +549,8 @@ class TypeFnAttrDef:
|
|||
self.operand_def = OperandDef(
|
||||
OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name)
|
||||
|
||||
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn:
|
||||
return TensorTypeFn(None, self.operand_def, type_var, arg)
|
||||
def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn:
|
||||
return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg])
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
|
|
@ -270,19 +270,19 @@ class _BodyBuilder:
|
|||
dim_attr = IntegerAttr.get(
|
||||
IntegerType.get_signless(64), expr.scalar_index.dim)
|
||||
return linalg.IndexOp(dim_attr).result
|
||||
elif expr.arith_fn:
|
||||
fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}")
|
||||
elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH:
|
||||
fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}")
|
||||
operand_values = [
|
||||
self.expression(operand) for operand in expr.arith_fn.operands
|
||||
self.expression(operand) for operand in expr.scalar_fn.operands
|
||||
]
|
||||
return fn(*operand_values)
|
||||
elif expr.type_fn:
|
||||
fn_name = expr.type_fn.fn_name
|
||||
if expr.type_fn.attr_name:
|
||||
fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name]
|
||||
elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE:
|
||||
fn_name = expr.scalar_fn.fn_name
|
||||
if expr.scalar_fn.attr_name:
|
||||
fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name]
|
||||
fn = self._get_function(f"_typefn_{fn_name}")
|
||||
operand = self.expression(expr.type_fn.operand)
|
||||
return fn(expr.type_fn.type_var.name, operand)
|
||||
operand_value = self.expression(expr.scalar_fn.operands[0])
|
||||
return fn(expr.scalar_fn.type_var.name, operand_value)
|
||||
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
|
||||
|
||||
def yield_outputs(self, *output_names: str):
|
||||
|
|
|
@ -15,13 +15,13 @@ can be easily consumed from the C++ side, not necessarily for ergonomics.
|
|||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from .yaml_helper import *
|
||||
from .comprehension import *
|
||||
from .types import *
|
||||
from .yaml_helper import *
|
||||
|
||||
__all__ = [
|
||||
"ScalarAssign",
|
||||
"ScalarArithFn",
|
||||
"ScalarTypeFn",
|
||||
"ScalarFn",
|
||||
"ScalarArg",
|
||||
"ScalarConst",
|
||||
"ScalarIndex",
|
||||
|
@ -29,36 +29,27 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
class ScalarArithFn:
|
||||
"""A type of ScalarExpression that applies an arithmetic function."""
|
||||
class ScalarFn:
|
||||
"""A type of ScalarExpression that applies a function."""
|
||||
|
||||
def __init__(self, fn_name: str, *operands: "ScalarExpression"):
|
||||
self.fn_name = fn_name
|
||||
self.operands = operands
|
||||
|
||||
def expr(self) -> "ScalarExpression":
|
||||
return ScalarExpression(arith_fn=self)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})"
|
||||
|
||||
|
||||
class ScalarTypeFn:
|
||||
"""A type of ScalarExpression that applies a type conversion function."""
|
||||
|
||||
def __init__(self, fn_name: Optional[str], attr_name: Optional[str],
|
||||
type_var: TypeVar, operand: "ScalarExpression"):
|
||||
def __init__(self, kind: "FunctionKind", fn_name: Optional[str],
|
||||
attr_name: Optional[str], type_var: Optional["TypeVar"],
|
||||
operands: Sequence["ScalarExpression"]):
|
||||
if bool(fn_name) + bool(attr_name) != 1:
|
||||
raise ValueError("One of 'fn_name', 'attr_name' must be specified")
|
||||
self.kind = kind
|
||||
self.fn_name = fn_name
|
||||
self.attr_name = attr_name
|
||||
self.type_var = type_var
|
||||
self.operand = operand
|
||||
self.operands = operands
|
||||
|
||||
def expr(self) -> "ScalarExpression":
|
||||
return ScalarExpression(type_fn=self)
|
||||
return ScalarExpression(scalar_fn=self)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>"
|
||||
f"({self.type_var}, {self.operand})")
|
||||
name = self.fn_name if self.fn_name else self.attr_name
|
||||
return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
|
||||
f"operands=[{', '.join(self.operands)}])")
|
||||
|
||||
|
||||
class ScalarArg:
|
||||
|
@ -104,51 +95,38 @@ class ScalarExpression(YAMLObject):
|
|||
"""An expression on scalar values.
|
||||
|
||||
Can be one of:
|
||||
- ScalarArithFn
|
||||
- ScalarTypeFn
|
||||
- ScalarFn
|
||||
- ScalarArg
|
||||
- ScalarConst
|
||||
- ScalarIndex
|
||||
- ScalarSymbolicCast
|
||||
"""
|
||||
yaml_tag = "!ScalarExpression"
|
||||
|
||||
def __init__(self,
|
||||
arith_fn: Optional[ScalarArithFn] = None,
|
||||
type_fn: Optional[ScalarTypeFn] = None,
|
||||
scalar_fn: Optional[ScalarFn] = None,
|
||||
scalar_arg: Optional[ScalarArg] = None,
|
||||
scalar_const: Optional[ScalarConst] = None,
|
||||
scalar_index: Optional[ScalarIndex] = None):
|
||||
if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) +
|
||||
if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) +
|
||||
bool(scalar_index)) != 1:
|
||||
raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', "
|
||||
"'scalar_const', 'scalar_index', must be specified")
|
||||
self.arith_fn = arith_fn
|
||||
self.type_fn = type_fn
|
||||
raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
|
||||
"'scalar_index' must be specified")
|
||||
self.scalar_fn = scalar_fn
|
||||
self.scalar_arg = scalar_arg
|
||||
self.scalar_const = scalar_const
|
||||
self.scalar_index = scalar_index
|
||||
|
||||
def to_yaml_custom_dict(self):
|
||||
if self.arith_fn:
|
||||
return dict(
|
||||
arith_fn=dict(
|
||||
fn_name=self.arith_fn.fn_name,
|
||||
operands=list(self.arith_fn.operands),
|
||||
))
|
||||
if self.type_fn:
|
||||
# Note that even though operands must be arity 1, we write it the
|
||||
# same way as for apply because it allows handling code to be more
|
||||
# generic vs having a special form.
|
||||
type_fn_dict = dict(
|
||||
type_var=self.type_fn.type_var.name,
|
||||
operands=[self.type_fn.operand],
|
||||
)
|
||||
if self.type_fn.fn_name:
|
||||
type_fn_dict["fn_name"] = self.type_fn.fn_name
|
||||
if self.type_fn.attr_name:
|
||||
type_fn_dict["attr_name"] = self.type_fn.attr_name
|
||||
return dict(type_fn=type_fn_dict)
|
||||
if self.scalar_fn:
|
||||
scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
|
||||
if self.scalar_fn.fn_name:
|
||||
scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
|
||||
if self.scalar_fn.attr_name:
|
||||
scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
|
||||
if self.scalar_fn.type_var:
|
||||
scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
|
||||
scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
|
||||
return dict(scalar_fn=scalar_fn_dict)
|
||||
elif self.scalar_arg:
|
||||
return dict(scalar_arg=self.scalar_arg.arg)
|
||||
elif self.scalar_const:
|
||||
|
|
|
@ -39,23 +39,26 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
arith_fn:
|
||||
scalar_fn:
|
||||
kind: arith
|
||||
fn_name: add
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
type_fn:
|
||||
scalar_fn:
|
||||
kind: type
|
||||
attr_name: cast
|
||||
type_var: T
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_const: '42 : i64'
|
||||
attr_name: cast
|
||||
- !ScalarExpression
|
||||
type_fn:
|
||||
scalar_fn:
|
||||
kind: type
|
||||
attr_name: cast
|
||||
type_var: T
|
||||
operands:
|
||||
- !ScalarExpression
|
||||
scalar_index: 1
|
||||
attr_name: cast
|
||||
|
||||
# ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1"
|
||||
|
||||
|
@ -236,7 +239,8 @@ structured_op: !LinalgStructuredOpConfig
|
|||
- !ScalarAssign
|
||||
arg: O
|
||||
value: !ScalarExpression
|
||||
type_fn:
|
||||
scalar_fn:
|
||||
kind: type
|
||||
fn_name: cast
|
||||
type_var: U
|
||||
operands:
|
||||
|
|
|
@ -9,22 +9,24 @@ from mlir.dialects.linalg.opdsl.lang import *
|
|||
# CHECK: -
|
||||
# CHECK: arg: C
|
||||
# CHECK: value:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: fn_name: add
|
||||
# CHECK: operands:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: fn_name: mul
|
||||
# CHECK: operands:
|
||||
# CHECK: type_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: attr_name: cast
|
||||
# CHECK: type_var: U
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_arg: A
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: attr_name: cast
|
||||
# CHECK: type_fn:
|
||||
# CHECK: type_var: U
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_arg: B
|
||||
# CHECK: attr_name: cast
|
||||
@linalg_structured_op
|
||||
def matmul(
|
||||
A=TensorDef(T, S.M, S.K),
|
||||
|
@ -39,21 +41,28 @@ def matmul(
|
|||
# CHECK: assignments:
|
||||
# CHECK: -
|
||||
# CHECK: arg: O
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: arith
|
||||
# CHECK: fn_name: sub
|
||||
# CHECK: operands:
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: arith
|
||||
# CHECK: fn_name: add
|
||||
# CHECK: operands:
|
||||
# CHECK: type_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: type_var: T
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_const: '3.1415926535897931 : f64'
|
||||
# CHECK: type_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: fn_name: cast
|
||||
# CHECK: type_var: T
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_const: '42 : i64'
|
||||
# CHECK: type_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: type
|
||||
# CHECK: fn_name: cast
|
||||
# CHECK: type_var: T
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64'
|
||||
|
@ -70,7 +79,8 @@ def constants(O=TensorDef(T, S.M, S.K, output=True)):
|
|||
# CHECK: assignments:
|
||||
# CHECK: -
|
||||
# CHECK: arg: O
|
||||
# CHECK: arith_fn:
|
||||
# CHECK: scalar_fn:
|
||||
# CHECK: kind: arith
|
||||
# CHECK: fn_name: add
|
||||
# CHECK: operands:
|
||||
# CHECK: scalar_index: 1
|
||||
|
|
|
@ -90,28 +90,23 @@ struct LinalgIndexingMapsConfig {
|
|||
|
||||
struct ScalarExpression;
|
||||
|
||||
struct ScalarArithFn {
|
||||
std::string fnName;
|
||||
// NOTE: Must be pure heap allocated container (not SmallVector)
|
||||
// due to recursive data type.
|
||||
std::vector<ScalarExpression> operands;
|
||||
};
|
||||
enum class ScalarFnKind { Arith, Type };
|
||||
|
||||
struct ScalarTypeFn {
|
||||
std::string typeVar;
|
||||
struct ScalarFn {
|
||||
ScalarFnKind kind;
|
||||
Optional<std::string> fnName;
|
||||
Optional<std::string> attrName;
|
||||
Optional<std::string> typeVar;
|
||||
// NOTE: This must be of arity 1, but to break the self-referential cycle,
|
||||
// we use a heap allocated vector.
|
||||
std::vector<ScalarExpression> operands;
|
||||
Optional<std::string> fnName;
|
||||
Optional<std::string> attrName;
|
||||
};
|
||||
|
||||
struct ScalarExpression {
|
||||
Optional<std::string> arg;
|
||||
Optional<std::string> constant;
|
||||
Optional<int64_t> index;
|
||||
Optional<ScalarArithFn> arithFn;
|
||||
Optional<ScalarTypeFn> typeFn;
|
||||
Optional<ScalarFn> scalarFn;
|
||||
};
|
||||
|
||||
struct ScalarAssign {
|
||||
|
@ -265,16 +260,23 @@ struct MappingTraits<ScalarAssign> {
|
|||
/// - `scalar_arg`: An operation argument.
|
||||
/// - `scalar_const`: A constant definition.
|
||||
/// - `scalar_index`: An iteration index.
|
||||
/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`).
|
||||
/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`).
|
||||
/// - `scalar_fn`: A named function (see `ScalarFn`).
|
||||
template <>
|
||||
struct MappingTraits<ScalarExpression> {
|
||||
static void mapping(IO &io, ScalarExpression &info) {
|
||||
io.mapOptional("scalar_arg", info.arg);
|
||||
io.mapOptional("scalar_const", info.constant);
|
||||
io.mapOptional("scalar_index", info.index);
|
||||
io.mapOptional("arith_fn", info.arithFn);
|
||||
io.mapOptional("type_fn", info.typeFn);
|
||||
io.mapOptional("scalar_fn", info.scalarFn);
|
||||
}
|
||||
};
|
||||
|
||||
/// Scalar function kind enum.
|
||||
template <>
|
||||
struct ScalarEnumerationTraits<ScalarFnKind> {
|
||||
static void enumeration(IO &io, ScalarFnKind &value) {
|
||||
io.enumCase(value, "arith", ScalarFnKind::Arith);
|
||||
io.enumCase(value, "type", ScalarFnKind::Type);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -284,20 +286,13 @@ struct MappingTraits<ScalarExpression> {
|
|||
/// - `add(lhs, rhs)`
|
||||
/// - `mul(lhs, rhs)`
|
||||
template <>
|
||||
struct MappingTraits<ScalarArithFn> {
|
||||
static void mapping(IO &io, ScalarArithFn &info) {
|
||||
io.mapRequired("fn_name", info.fnName);
|
||||
io.mapRequired("operands", info.operands);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MappingTraits<ScalarTypeFn> {
|
||||
static void mapping(IO &io, ScalarTypeFn &info) {
|
||||
io.mapRequired("type_var", info.typeVar);
|
||||
io.mapRequired("operands", info.operands);
|
||||
struct MappingTraits<ScalarFn> {
|
||||
static void mapping(IO &io, ScalarFn &info) {
|
||||
io.mapRequired("kind", info.kind);
|
||||
io.mapOptional("fn_name", info.fnName);
|
||||
io.mapOptional("attr_name", info.attrName);
|
||||
io.mapOptional("type_var", info.typeVar);
|
||||
io.mapRequired("operands", info.operands);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1060,11 +1055,12 @@ if ({0}Iter != attrs.end()) {{
|
|||
cppIdent, *expression.index));
|
||||
return cppIdent;
|
||||
}
|
||||
if (expression.arithFn) {
|
||||
if (expression.scalarFn &&
|
||||
expression.scalarFn->kind == ScalarFnKind::Arith) {
|
||||
// Apply function.
|
||||
// Recursively generate operands.
|
||||
SmallVector<std::string> operandCppValues;
|
||||
for (ScalarExpression &operand : expression.arithFn->operands) {
|
||||
for (ScalarExpression &operand : expression.scalarFn->operands) {
|
||||
auto operandCppValue = generateExpression(operand);
|
||||
if (!operandCppValue)
|
||||
return None;
|
||||
|
@ -1073,28 +1069,30 @@ if ({0}Iter != attrs.end()) {{
|
|||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(
|
||||
llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent,
|
||||
expression.arithFn->fnName,
|
||||
expression.scalarFn->fnName,
|
||||
interleaveToString(operandCppValues, ", ")));
|
||||
return cppIdent;
|
||||
}
|
||||
if (expression.typeFn) {
|
||||
if (expression.scalarFn &&
|
||||
expression.scalarFn->kind == ScalarFnKind::Type) {
|
||||
// Symbolic cast.
|
||||
// Operands must be arity 1.
|
||||
if (expression.typeFn->operands.size() != 1) {
|
||||
if (expression.scalarFn->operands.size() != 1) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "type conversion operand arity must be 1";
|
||||
return None;
|
||||
}
|
||||
Optional<std::string> operandCppValue =
|
||||
generateExpression(expression.typeFn->operands[0]);
|
||||
generateExpression(expression.scalarFn->operands[0]);
|
||||
if (!operandCppValue)
|
||||
return None;
|
||||
|
||||
assert(expression.scalarFn->typeVar.hasValue());
|
||||
Optional<std::string> typeCppValue =
|
||||
findTypeValue(expression.typeFn->typeVar, args);
|
||||
findTypeValue(expression.scalarFn->typeVar.getValue(), args);
|
||||
if (!typeCppValue) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "type variable " << expression.typeFn->typeVar
|
||||
<< "type variable " << expression.scalarFn->typeVar.getValue()
|
||||
<< ", used in a type conversion, must map to a predefined or "
|
||||
<< "an argument type but it does not";
|
||||
return None;
|
||||
|
@ -1102,17 +1100,17 @@ if ({0}Iter != attrs.end()) {{
|
|||
|
||||
// Use the function name or the attribute to build the type function.
|
||||
std::string typeFunc = llvm::formatv(
|
||||
"TypeFn::{0}", expression.typeFn->fnName.getValueOr(""));
|
||||
if (expression.typeFn->attrName) {
|
||||
"TypeFn::{0}", expression.scalarFn->fnName.getValueOr(""));
|
||||
if (expression.scalarFn->attrName) {
|
||||
if (llvm::none_of(args, [&](LinalgOperandDef &arg) {
|
||||
return arg.kind == LinalgOperandDefKind::TypeFnAttr &&
|
||||
arg.name == expression.typeFn->attrName.getValue();
|
||||
arg.name == expression.scalarFn->attrName.getValue();
|
||||
})) {
|
||||
emitError(genContext.getLoc())
|
||||
<< "missing type function attribute "
|
||||
<< expression.typeFn->attrName.getValue();
|
||||
<< expression.scalarFn->attrName.getValue();
|
||||
}
|
||||
typeFunc = llvm::formatv("{0}Val", *expression.typeFn->attrName);
|
||||
typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName);
|
||||
}
|
||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(llvm::formatv(
|
||||
|
|
Loading…
Reference in New Issue