[FIRRTL] Clean up inferReturnTypes implementation

This is a big change to clean up our implementation of
`inferReturnTypes` in an effort to get ourselves ready to enable
properties in the FIRRTL dialect. We had structured our API so that
`inferReturnType` would call out to a simpler version of itself with
some of the useless arguments removed.  This prevented us from using op
adaptors to abstract over whether or not an inherent attribute was
contained inside the attr-dict or the properties.

This change keeps the old structure of a two level API, but with
different boundaries: The large API takes all arguments, creates an
adapter, pulls out the necessary attributes and then calls in to the
simpler interface. FIRParser uses the simpler API when inferring return
types.  The simpler interface is now specific to each operation and not
common with other operations.

This last change caused a problem for `parsePrimExp`, which relied on a
generic interface to create all expressions. This function is now
templated over exactly how many arguments the specific prim op takes,
parses exactly that many, and splats them out when calling
`inferReturnType`.  As an upside to this, we can also call a more
specific builder for each operation, which should speed up building
operations when we move to properties.
This commit is contained in:
Andrew Young 2024-10-29 18:18:48 -07:00
parent c68cbf5361
commit fd56341db4
6 changed files with 487 additions and 416 deletions

View File

@ -38,7 +38,9 @@ class FIRRTLExprOp<string mnemonic, list<Trait> traits = []> :
code inferTypeDecl = [{
/// Infer the return type of this operation.
static FIRRTLType inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
mlir::DictionaryAttr attrs,
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc);
}];
@ -60,8 +62,11 @@ class FIRRTLExprOp<string mnemonic, list<Trait> traits = []> :
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
SmallVectorImpl<Type> &results) {
return impl::inferReturnTypes(context, loc, operands, attrs, properties,
regions, results, &inferReturnType);
auto type = inferReturnType(operands, attrs, properties, regions, loc);
if (!type)
return failure();
results.push_back(type);
return success();
}
}];
@ -270,6 +275,22 @@ class BaseSubfieldOp<string name, Type btype, Type rtype> : FIRRTLExprOp<name> {
return build($_builder, $_state, input, *fieldIndex);
}]>
];
let inferTypeDecl = [{
static FIRRTLType inferReturnType(Type inType , uint32_t fieldIndex,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
mlir::DictionaryAttr attrs,
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(adaptor.getInput().getType(),
adaptor.getFieldIndex(),
loc);
}
}];
let firrtlExtraClassDeclaration = [{
using InputType = }] # btype # [{;
@ -316,6 +337,21 @@ def SubindexOp : FIRRTLExprOp<"subindex"> {
let assemblyFormat =
"$input `[` $index `]` attr-dict `:` qualified(type($input))";
let inferTypeDecl = [{
static FIRRTLType inferReturnType(Type inType , uint32_t fieldIndex,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
mlir::DictionaryAttr attrs,
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(adaptor.getInput().getType(),
adaptor.getIndex(), loc);
}
}];
let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() {
@ -343,6 +379,21 @@ def OpenSubindexOp : FIRRTLExprOp<"opensubindex"> {
let assemblyFormat =
"$input `[` $index `]` attr-dict `:` qualified(type($input))";
let inferTypeDecl = [{
static FIRRTLType inferReturnType(Type inType , uint32_t fieldIndex,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
mlir::DictionaryAttr attrs,
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(adaptor.getInput().getType(),
adaptor.getIndex(), loc);
}
}];
let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() {
@ -371,6 +422,20 @@ def SubaccessOp : FIRRTLExprOp<"subaccess"> {
"$input `[` $index `]` attr-dict `:` qualified(type($input)) `,` qualified(type($index))";
let hasCanonicalizer = true;
let inferTypeDecl = [{
static FIRRTLType inferReturnType(Type inType, Type indexType,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands, DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto inType = adaptor.getInput().getType();
auto indexType = adaptor.getIndex().getType();
return inferReturnType(inType, indexType, loc);
}
}];
}
def IsTagOp : FIRRTLExprOp<"istag"> {
@ -523,24 +588,33 @@ class BinaryPrimOp<string mnemonic, Type lhsType, Type rhsType, Type resultType,
}];
// Give concrete operations a chance to set a type inference callback. If left
// empty, a declaration for `inferBinaryReturnType` will be emitted that the
// empty, a declaration for `inferReturnType` will be emitted that the
// operation is expected to implement.
code inferType = "";
let inferTypeDecl = !if(!empty(inferType), [{
/// Infer the return type of this binary operation.
static FIRRTLType inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc);
}], "") # !subst("$_infer", !if(!empty(inferType), "inferBinaryReturnType",
inferType), [{
static FIRRTLType inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc);
}], !subst("$_infer", inferType, [{
/// Infer the return type of this binary operation.
static FIRRTLType inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
return $_infer(lhs, rhs, loc);
}
}])) # [{
/// Infer the return type of this operation.
static FIRRTLType inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
std::optional<Location> loc) {
return $_infer(firrtl::type_cast<FIRRTLType>(operands[0].getType()),
firrtl::type_cast<FIRRTLType>(operands[1].getType()),
loc);
mlir::DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc
) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(firrtl::type_cast<FIRRTLType>(adaptor.getLhs().getType()),
firrtl::type_cast<FIRRTLType>(adaptor.getRhs().getType()),
loc);
}
}]);
}];
}
// A binary operation on two integer-typed arguments of the same kind.
@ -639,22 +713,32 @@ class UnaryPrimOp<string mnemonic, Type srcType, Type resultType,
"$input attr-dict `:` functional-type($input, $result)";
// Give concrete operations a chance to set a type inference callback. If left
// empty, a declaration for `inferUnaryReturnType` will be emitted that the
// empty, a declaration for `inferReturnType` will be emitted that the
// operation is expected to implement.
code inferType = "";
let inferTypeDecl = !if(!empty(inferType), [{
/// Infer the return type of this unary operation.
static FIRRTLType inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc);
}], "") # !subst("$_infer", !if(!empty(inferType), "inferUnaryReturnType",
inferType), [{
/// Infer the return type of this binary operation.
static FIRRTLType inferReturnType(FIRRTLType input,
std::optional<Location> loc);
}], !subst("$_infer", inferType, [{
/// Infer the return type of this binary operation.
static FIRRTLType inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
return $_infer(input, loc);
}
}])) # [{
/// Infer the return type of this operation.
static FIRRTLType inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
std::optional<Location> loc) {
return $_infer(firrtl::type_cast<FIRRTLType>(operands[0].getType()), loc);
mlir::DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc
) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(firrtl::type_cast<FIRRTLType>(adaptor.getInput().getType()),
loc);
}
}]);
}];
}
@ -716,6 +800,22 @@ def BitsPrimOp : PrimOp<"bits"> {
The result is `hi - lo + 1` bits wide.
}];
let inferTypeDecl = [{
static FIRRTLType inferReturnType(FIRRTLType input, int64_t high,
int64_t low, std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto input = firrtl::type_cast<FIRRTLType>(adaptor.getInput().getType());
auto high = adaptor.getHiAttr().getValue().getSExtValue();
auto low = adaptor.getLoAttr().getValue().getSExtValue();
return inferReturnType(input, high, low, loc);
}
}];
let hasCanonicalizer = true;
}
@ -726,6 +826,22 @@ def HeadPrimOp : PrimOp<"head"> {
let assemblyFormat =
"$input `,` $amount attr-dict `:` functional-type($input, $result)";
let inferTypeDecl = [{
static FIRRTLType inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto input = firrtl::type_cast<FIRRTLType>(adaptor.getInput().getType());
auto amount = adaptor.getAmountAttr().getValue().getSExtValue();
return inferReturnType(input, amount, loc);
}
}];
let hasCanonicalizeMethod = true;
}
@ -737,6 +853,23 @@ def MuxPrimOp : PrimOp<"mux"> {
let assemblyFormat =
"`(` operands `)` attr-dict `:` functional-type(operands, $result)";
let inferTypeDecl = [{
static FIRRTLType inferReturnType(FIRRTLType sel, FIRRTLType high,
FIRRTLType low, std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto sel = firrtl::type_cast<FIRRTLType>(adaptor.getSel().getType());
auto high = firrtl::type_cast<FIRRTLType>(adaptor.getHigh().getType());
auto low = firrtl::type_cast<FIRRTLType>(adaptor.getLow().getType());
return inferReturnType(sel, high, low, loc);
}
}];
let hasCanonicalizer = true;
}
@ -753,6 +886,21 @@ def PadPrimOp : PrimOp<"pad"> {
width of `input`, then input is unmodified.
}];
let inferTypeDecl = [{
static FIRRTLType inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto input = firrtl::type_cast<FIRRTLType>(adaptor.getInput().getType());
auto amount = adaptor.getAmountAttr().getValue().getSExtValue();
return inferReturnType(input, amount, loc);
}
}];
}
class ShiftPrimOp<string mnemonic> : PrimOp<mnemonic> {
@ -762,6 +910,21 @@ class ShiftPrimOp<string mnemonic> : PrimOp<mnemonic> {
let assemblyFormat =
"$input `,` $amount attr-dict `:` functional-type($input, $result)";
let inferTypeDecl = [{
static FIRRTLType inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto input = firrtl::type_cast<FIRRTLType>(adaptor.getInput().getType());
auto amount = adaptor.getAmountAttr().getValue().getSExtValue();
return inferReturnType(input, amount, loc);
}
}];
}
def ShlPrimOp : ShiftPrimOp<"shl"> {
@ -795,6 +958,22 @@ def TailPrimOp : PrimOp<"tail"> {
width of e. The result is `width(input)-amount` bits wide.
}];
let inferTypeDecl = [{
static FIRRTLType inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
auto input = firrtl::type_cast<FIRRTLType>(adaptor.getInput().getType());
auto amount = adaptor.getAmountAttr().getValue().getSExtValue();
return inferReturnType(input, amount, loc);
}
}];
let hasCanonicalizeMethod = true;
}
@ -1007,9 +1186,19 @@ def ObjectSubfieldOp : FIRRTLOp<"object.subfield",
/// Infer the return type of this operation.
/// Note: In contrast to other ops, this function infers a generic Type,
/// in order to support foreign types in ports.
static Type inferReturnType(Type inType , uint32_t fieldIndex,
std::optional<Location> loc);
static Type inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
std::optional<Location> loc);
mlir::DictionaryAttr attrs,
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(adaptor.getInput().getType(),
adaptor.getIndex(), loc);
}
/// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() {
@ -1349,6 +1538,21 @@ def RefSubOp : FIRRTLExprOp<"ref.sub"> {
let assemblyFormat =
"$input `[` $index `]` attr-dict `:` qualified(type($input))";
let inferTypeDecl = [{
static FIRRTLType inferReturnType(Type inType , uint32_t fieldIndex,
std::optional<Location> loc);
static FIRRTLType inferReturnType(ValueRange operands,
mlir::DictionaryAttr attrs,
mlir::OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Adaptor adaptor(operands, attrs, properties, regions);
return inferReturnType(adaptor.getInput().getType(),
adaptor.getIndex(), loc);
}
}];
let firrtlExtraClassDeclaration = [{
/// Return a `FieldRef` to the accessed field.
FieldRef getAccessedField() {

View File

@ -139,15 +139,6 @@ MatchingConnectOp getSingleConnectUserOf(Value value);
namespace impl {
LogicalResult verifySameOperandsIntTypeKind(Operation *op);
// Type inference adaptor for FIRRTL operations.
LogicalResult inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attrs, mlir::OpaqueProperties properties,
mlir::RegionRange regions, SmallVectorImpl<Type> &results,
llvm::function_ref<FIRRTLType(ValueRange, ArrayRef<NamedAttribute>,
std::optional<Location>)>
callback);
// Common type inference functions.
FIRRTLType inferAddSubResult(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc);

View File

@ -3368,19 +3368,19 @@ LogicalResult NodeOp::inferReturnTypes(
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
if (operands.empty())
return failure();
inferredReturnTypes.push_back(operands[0].getType());
for (auto &attr : attributes)
if (attr.getName() == Forceable::getForceableAttrName()) {
auto forceableType =
firrtl::detail::getForceableResultType(true, operands[0].getType());
if (!forceableType) {
if (location)
::mlir::emitError(*location, "cannot force a node of type ")
<< operands[0].getType();
return failure();
}
inferredReturnTypes.push_back(forceableType);
Adaptor adaptor(operands, attributes, properties, regions);
inferredReturnTypes.push_back(adaptor.getInput().getType());
if (adaptor.getForceable()) {
auto forceableType = firrtl::detail::getForceableResultType(
true, adaptor.getInput().getType());
if (!forceableType) {
if (location)
::mlir::emitError(*location, "cannot force a node of type ")
<< operands[0].getType();
return failure();
}
inferredReturnTypes.push_back(forceableType);
}
return success();
}
@ -3983,50 +3983,6 @@ void MatchOp::build(OpBuilder &builder, OperationState &result, Value input,
// Expressions
//===----------------------------------------------------------------------===//
/// Type inference adaptor that narrows from the very generic MLIR
/// `InferTypeOpInterface` to what we need in the FIRRTL dialect: just operands
/// and attributes, no context or regions. Also, we only ever produce a single
/// result value, so the FIRRTL-specific type inference ops directly return the
/// inferred type rather than pushing into the `results` vector.
LogicalResult impl::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attrs, mlir::OpaqueProperties properties,
RegionRange regions, SmallVectorImpl<Type> &results,
llvm::function_ref<FIRRTLType(ValueRange, ArrayRef<NamedAttribute>,
std::optional<Location>)>
callback) {
auto type = callback(
operands, attrs ? attrs.getValue() : ArrayRef<NamedAttribute>{}, loc);
if (type) {
results.push_back(type);
return success();
}
return failure();
}
/// Get an attribute by name from a list of named attributes. Return null if no
/// attribute is found with that name.
static Attribute maybeGetAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
for (auto attr : attrs)
if (attr.getName() == name)
return attr.getValue();
return {};
}
/// Get an attribute by name from a list of named attributes. Aborts if the
/// attribute does not exist.
static Attribute getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
if (auto attr = maybeGetAttr(attrs, name))
return attr;
llvm::report_fatal_error("attribute '" + name + "' not found");
}
/// Same as above, but casts the attribute to a specific type.
template <typename AttrClass>
AttrClass getAttr(ArrayRef<NamedAttribute> attrs, StringRef name) {
return cast<AttrClass>(getAttr(attrs, name));
}
/// Return true if the specified operation is a firrtl expression.
bool firrtl::isExpression(Operation *op) {
struct IsExprClassifier : public ExprVisitor<IsExprClassifier, bool> {
@ -4507,11 +4463,13 @@ ParseResult IsTagOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
FIRRTLType IsTagOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType IsTagOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
return UIntType::get(operands[0].getContext(), 1,
isConst(operands[0].getType()));
Adaptor adaptor(operands, attrs, properties, regions);
return UIntType::get(attrs.getContext(), 1,
isConst(adaptor.getInput().getType()));
}
template <typename OpTy>
@ -4704,12 +4662,9 @@ LogicalResult ConstCastOp::verify() {
return success();
}
FIRRTLType SubfieldOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType SubfieldOp::inferReturnType(Type type, uint32_t fieldIndex,
std::optional<Location> loc) {
auto inType = type_cast<BundleType>(operands[0].getType());
auto fieldIndex =
getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
auto inType = type_cast<BundleType>(type);
if (fieldIndex >= inType.getNumElements())
return emitInferRetTypeError(loc,
@ -4721,12 +4676,9 @@ FIRRTLType SubfieldOp::inferReturnType(ValueRange operands,
return inType.getElementTypePreservingConst(fieldIndex);
}
FIRRTLType OpenSubfieldOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType OpenSubfieldOp::inferReturnType(Type type, uint32_t fieldIndex,
std::optional<Location> loc) {
auto inType = type_cast<OpenBundleType>(operands[0].getType());
auto fieldIndex =
getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
auto inType = type_cast<OpenBundleType>(type);
if (fieldIndex >= inType.getNumElements())
return emitInferRetTypeError(loc,
@ -4747,46 +4699,36 @@ bool OpenSubfieldOp::isFieldFlipped() {
return bundle.getElement(getFieldIndex()).isFlip;
}
FIRRTLType SubindexOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType SubindexOp::inferReturnType(Type type, uint32_t fieldIndex,
std::optional<Location> loc) {
Type inType = operands[0].getType();
auto fieldIdx =
getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
if (fieldIdx < vectorType.getNumElements())
if (auto vectorType = type_dyn_cast<FVectorType>(type)) {
if (fieldIndex < vectorType.getNumElements())
return vectorType.getElementTypePreservingConst();
return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
"' in vector type ", inType);
return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
"' in vector type ", type);
}
return emitInferRetTypeError(loc, "subindex requires vector operand");
}
FIRRTLType OpenSubindexOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType OpenSubindexOp::inferReturnType(Type type, uint32_t fieldIndex,
std::optional<Location> loc) {
Type inType = operands[0].getType();
auto fieldIdx =
getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
if (auto vectorType = type_dyn_cast<OpenVectorType>(inType)) {
if (fieldIdx < vectorType.getNumElements())
if (auto vectorType = type_dyn_cast<OpenVectorType>(type)) {
if (fieldIndex < vectorType.getNumElements())
return vectorType.getElementTypePreservingConst();
return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
"' in vector type ", inType);
return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
"' in vector type ", type);
}
return emitInferRetTypeError(loc, "subindex requires vector operand");
}
FIRRTLType SubtagOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType SubtagOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
auto inType = type_cast<FEnumType>(operands[0].getType());
auto fieldIndex =
getAttr<IntegerAttr>(attrs, "fieldIndex").getValue().getZExtValue();
Adaptor adaptor(operands, attrs, properties, regions);
auto inType = type_cast<FEnumType>(adaptor.getInput().getType());
auto fieldIndex = adaptor.getFieldIndex();
if (fieldIndex >= inType.getNumElements())
return emitInferRetTypeError(loc,
@ -4799,12 +4741,8 @@ FIRRTLType SubtagOp::inferReturnType(ValueRange operands,
return elementType.getConstType(elementType.isConst() || inType.isConst());
}
FIRRTLType SubaccessOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType SubaccessOp::inferReturnType(Type inType, Type indexType,
std::optional<Location> loc) {
auto inType = operands[0].getType();
auto indexType = operands[1].getType();
if (!type_isa<UIntType>(indexType))
return emitInferRetTypeError(loc, "subaccess index must be UInt type, not ",
indexType);
@ -4820,9 +4758,12 @@ FIRRTLType SubaccessOp::inferReturnType(ValueRange operands,
}
FIRRTLType TagExtractOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
auto inType = type_cast<FEnumType>(operands[0].getType());
Adaptor adaptor(operands, attrs, properties, regions);
auto inType = type_cast<FEnumType>(adaptor.getInput().getType());
auto i = llvm::Log2_32_Ceil(inType.getNumElements());
return UIntType::get(inType.getContext(), i);
}
@ -4855,7 +4796,9 @@ void MultibitMuxOp::print(OpAsmPrinter &p) {
}
FIRRTLType MultibitMuxOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
if (operands.size() < 2)
return emitInferRetTypeError(loc, "at least one input is required");
@ -4877,26 +4820,24 @@ LogicalResult ObjectSubfieldOp::inferReturnTypes(
MLIRContext *context, std::optional<mlir::Location> location,
ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
RegionRange regions, llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
auto type = inferReturnType(operands, attributes.getValue(), location);
auto type =
inferReturnType(operands, attributes, properties, regions, location);
if (!type)
return failure();
inferredReturnTypes.push_back(type);
return success();
}
Type ObjectSubfieldOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
Type ObjectSubfieldOp::inferReturnType(Type inType, uint32_t fieldIndex,
std::optional<Location> loc) {
auto classType = dyn_cast<ClassType>(operands[0].getType());
auto classType = dyn_cast<ClassType>(inType);
if (!classType)
return emitInferRetTypeError(loc, "base object is not a class");
auto index = getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
if (classType.getNumElements() <= index)
if (classType.getNumElements() <= fieldIndex)
return emitInferRetTypeError(loc, "element index is greater than the "
"number of fields in the object");
return classType.getElement(index).type;
return classType.getElement(fieldIndex).type;
}
void ObjectSubfieldOp::print(OpAsmPrinter &p) {
@ -5005,8 +4946,8 @@ FIRRTLType impl::inferAddSubResult(FIRRTLType lhs, FIRRTLType rhs,
isConstResult);
}
FIRRTLType MulPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType MulPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
int32_t lhsWidth, rhsWidth, resultWidth = -1;
bool isConstResult = false;
if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
@ -5019,8 +4960,8 @@ FIRRTLType MulPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
isConstResult);
}
FIRRTLType DivPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType DivPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
int32_t lhsWidth, rhsWidth;
bool isConstResult = false;
if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
@ -5035,8 +4976,8 @@ FIRRTLType DivPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
return SIntType::get(lhs.getContext(), resultWidth, isConstResult);
}
FIRRTLType RemPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType RemPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
int32_t lhsWidth, rhsWidth, resultWidth = -1;
bool isConstResult = false;
if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
@ -5094,8 +5035,8 @@ FIRRTLType impl::inferComparisonResult(FIRRTLType lhs, FIRRTLType rhs,
return UIntType::get(lhs.getContext(), 1, isConst(lhs) && isConst(rhs));
}
FIRRTLType CatPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType CatPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
int32_t lhsWidth, rhsWidth, resultWidth = -1;
bool isConstResult = false;
if (!isSameIntTypeKind(lhs, rhs, lhsWidth, rhsWidth, isConstResult, loc))
@ -5106,8 +5047,8 @@ FIRRTLType CatPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
return UIntType::get(lhs.getContext(), resultWidth, isConstResult);
}
FIRRTLType DShlPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType DShlPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
auto lhsi = type_dyn_cast<IntType>(lhs);
auto rhsui = type_dyn_cast<UIntType>(rhs);
if (!rhsui || !lhsi)
@ -5136,8 +5077,8 @@ FIRRTLType DShlPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
lhsi.isConst() && rhsui.isConst());
}
FIRRTLType DShlwPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType DShlwPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
auto lhsi = type_dyn_cast<IntType>(lhs);
auto rhsu = type_dyn_cast<UIntType>(rhs);
if (!lhsi || !rhsu)
@ -5146,8 +5087,8 @@ FIRRTLType DShlwPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
return lhsi.getConstType(lhsi.isConst() && rhsu.isConst());
}
FIRRTLType DShrPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
FIRRTLType DShrPrimOp::inferReturnType(FIRRTLType lhs, FIRRTLType rhs,
std::optional<Location> loc) {
auto lhsi = type_dyn_cast<IntType>(lhs);
auto rhsu = type_dyn_cast<UIntType>(rhs);
if (!lhsi || !rhsu)
@ -5160,14 +5101,13 @@ FIRRTLType DShrPrimOp::inferBinaryReturnType(FIRRTLType lhs, FIRRTLType rhs,
// Unary Primitives
//===----------------------------------------------------------------------===//
FIRRTLType
SizeOfIntrinsicOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType SizeOfIntrinsicOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
return UIntType::get(input.getContext(), 32);
}
FIRRTLType AsSIntPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType AsSIntPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
auto base = type_dyn_cast<FIRRTLBaseType>(input);
if (!base)
return emitInferRetTypeError(loc, "operand must be a scalar base type");
@ -5177,8 +5117,8 @@ FIRRTLType AsSIntPrimOp::inferUnaryReturnType(FIRRTLType input,
return SIntType::get(input.getContext(), width, base.isConst());
}
FIRRTLType AsUIntPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType AsUIntPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
auto base = type_dyn_cast<FIRRTLBaseType>(input);
if (!base)
return emitInferRetTypeError(loc, "operand must be a scalar base type");
@ -5188,9 +5128,8 @@ FIRRTLType AsUIntPrimOp::inferUnaryReturnType(FIRRTLType input,
return UIntType::get(input.getContext(), width, base.isConst());
}
FIRRTLType
AsAsyncResetPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType AsAsyncResetPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
auto base = type_dyn_cast<FIRRTLBaseType>(input);
if (!base)
return emitInferRetTypeError(loc,
@ -5201,13 +5140,13 @@ AsAsyncResetPrimOp::inferUnaryReturnType(FIRRTLType input,
return AsyncResetType::get(input.getContext(), base.isConst());
}
FIRRTLType AsClockPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType AsClockPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
return ClockType::get(input.getContext(), isConst(input));
}
FIRRTLType CvtPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType CvtPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
if (auto uiType = type_dyn_cast<UIntType>(input)) {
auto width = uiType.getWidthOrSentinel();
if (width != -1)
@ -5221,8 +5160,8 @@ FIRRTLType CvtPrimOp::inferUnaryReturnType(FIRRTLType input,
return emitInferRetTypeError(loc, "operand must have integer type");
}
FIRRTLType NegPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType NegPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
auto inputi = type_dyn_cast<IntType>(input);
if (!inputi)
return emitInferRetTypeError(loc, "operand must have integer type");
@ -5232,8 +5171,8 @@ FIRRTLType NegPrimOp::inferUnaryReturnType(FIRRTLType input,
return SIntType::get(input.getContext(), width, inputi.isConst());
}
FIRRTLType NotPrimOp::inferUnaryReturnType(FIRRTLType input,
std::optional<Location> loc) {
FIRRTLType NotPrimOp::inferReturnType(FIRRTLType input,
std::optional<Location> loc) {
auto inputi = type_dyn_cast<IntType>(input);
if (!inputi)
return emitInferRetTypeError(loc, "operand must have integer type");
@ -5252,13 +5191,9 @@ FIRRTLType impl::inferReductionResult(FIRRTLType input,
// Other Operations
//===----------------------------------------------------------------------===//
FIRRTLType BitsPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType BitsPrimOp::inferReturnType(FIRRTLType input, int64_t high,
int64_t low,
std::optional<Location> loc) {
auto input = operands[0].getType();
auto high = getAttr<IntegerAttr>(attrs, "hi").getValue().getSExtValue();
auto low = getAttr<IntegerAttr>(attrs, "lo").getValue().getSExtValue();
auto inputi = type_dyn_cast<IntType>(input);
if (!inputi)
return emitInferRetTypeError(
@ -5285,11 +5220,8 @@ FIRRTLType BitsPrimOp::inferReturnType(ValueRange operands,
return UIntType::get(input.getContext(), high - low + 1, inputi.isConst());
}
FIRRTLType HeadPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType HeadPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc) {
auto input = operands[0].getType();
auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
auto inputi = type_dyn_cast<IntType>(input);
if (amount < 0 || !inputi)
@ -5396,19 +5328,20 @@ static FIRRTLBaseType inferMuxReturnType(FIRRTLBaseType high,
", false value type: ", low);
}
FIRRTLType MuxPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType MuxPrimOp::inferReturnType(FIRRTLType sel, FIRRTLType high,
FIRRTLType low,
std::optional<Location> loc) {
auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
auto highType = type_dyn_cast<FIRRTLBaseType>(high);
auto lowType = type_dyn_cast<FIRRTLBaseType>(low);
if (!highType || !lowType)
return emitInferRetTypeError(loc, "operands must be base type");
return inferMuxReturnType(highType, lowType, isConst(operands[0].getType()),
loc);
return inferMuxReturnType(highType, lowType, isConst(sel), loc);
}
FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
auto highType = type_dyn_cast<FIRRTLBaseType>(operands[1].getType());
auto lowType = type_dyn_cast<FIRRTLBaseType>(operands[2].getType());
@ -5419,7 +5352,9 @@ FIRRTLType Mux2CellIntrinsicOp::inferReturnType(ValueRange operands,
}
FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
SmallVector<FIRRTLBaseType> types;
FIRRTLBaseType result;
@ -5439,12 +5374,8 @@ FIRRTLType Mux4CellIntrinsicOp::inferReturnType(ValueRange operands,
return result;
}
FIRRTLType PadPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType PadPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc) {
auto input = operands[0].getType();
auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
auto inputi = type_dyn_cast<IntType>(input);
if (amount < 0 || !inputi)
return emitInferRetTypeError(
@ -5459,12 +5390,8 @@ FIRRTLType PadPrimOp::inferReturnType(ValueRange operands,
inputi.isConst());
}
FIRRTLType ShlPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType ShlPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc) {
auto input = operands[0].getType();
auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
auto inputi = type_dyn_cast<IntType>(input);
if (amount < 0 || !inputi)
return emitInferRetTypeError(
@ -5478,12 +5405,8 @@ FIRRTLType ShlPrimOp::inferReturnType(ValueRange operands,
inputi.isConst());
}
FIRRTLType ShrPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType ShrPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc) {
auto input = operands[0].getType();
auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
auto inputi = type_dyn_cast<IntType>(input);
if (amount < 0 || !inputi)
return emitInferRetTypeError(
@ -5500,11 +5423,8 @@ FIRRTLType ShrPrimOp::inferReturnType(ValueRange operands,
inputi.isConst());
}
FIRRTLType TailPrimOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType TailPrimOp::inferReturnType(FIRRTLType input, int64_t amount,
std::optional<Location> loc) {
auto input = operands[0].getType();
auto amount = getAttr<IntegerAttr>(attrs, "amount").getValue().getSExtValue();
auto inputi = type_dyn_cast<IntType>(input);
if (amount < 0 || !inputi)
@ -6121,7 +6041,9 @@ void RWProbeOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
}
FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Type inType = operands[0].getType();
auto inRefType = type_dyn_cast<RefType>(inType);
@ -6131,8 +6053,9 @@ FIRRTLType RefResolveOp::inferReturnType(ValueRange operands,
return inRefType.getType();
}
FIRRTLType RefSendOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType RefSendOp::inferReturnType(ValueRange operands, DictionaryAttr attrs,
OpaqueProperties properties,
mlir::RegionRange regions,
std::optional<Location> loc) {
Type inType = operands[0].getType();
auto inBaseType = type_dyn_cast<FIRRTLBaseType>(inType);
@ -6142,39 +6065,38 @@ FIRRTLType RefSendOp::inferReturnType(ValueRange operands,
return RefType::get(inBaseType.getPassiveType());
}
FIRRTLType RefSubOp::inferReturnType(ValueRange operands,
ArrayRef<NamedAttribute> attrs,
FIRRTLType RefSubOp::inferReturnType(Type type, uint32_t fieldIndex,
std::optional<Location> loc) {
auto refType = type_dyn_cast<RefType>(operands[0].getType());
auto refType = type_dyn_cast<RefType>(type);
if (!refType)
return emitInferRetTypeError(loc, "input must be of reference type");
auto inType = refType.getType();
auto fieldIdx =
getAttr<IntegerAttr>(attrs, "index").getValue().getZExtValue();
// TODO: Determine ref.sub + rwprobe behavior, test.
// Probably best to demote to non-rw, but that has implications
// for any LowerTypes behavior being relied on.
// Allow for now, as need to LowerTypes things generally.
if (auto vectorType = type_dyn_cast<FVectorType>(inType)) {
if (fieldIdx < vectorType.getNumElements())
if (fieldIndex < vectorType.getNumElements())
return RefType::get(
vectorType.getElementType().getConstType(
vectorType.isConst() || vectorType.getElementType().isConst()),
refType.getForceable(), refType.getLayer());
return emitInferRetTypeError(loc, "out of range index '", fieldIdx,
return emitInferRetTypeError(loc, "out of range index '", fieldIndex,
"' in RefType of vector type ", refType);
}
if (auto bundleType = type_dyn_cast<BundleType>(inType)) {
if (fieldIdx >= bundleType.getNumElements()) {
if (fieldIndex >= bundleType.getNumElements()) {
return emitInferRetTypeError(loc,
"subfield element index is greater than "
"the number of fields in the bundle type");
}
return RefType::get(bundleType.getElement(fieldIdx).type.getConstType(
bundleType.isConst() ||
bundleType.getElement(fieldIdx).type.isConst()),
refType.getForceable(), refType.getLayer());
return RefType::get(
bundleType.getElement(fieldIndex)
.type.getConstType(
bundleType.isConst() ||
bundleType.getElement(fieldIndex).type.isConst()),
refType.getForceable(), refType.getLayer());
}
return emitInferRetTypeError(

View File

@ -42,6 +42,7 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <utility>
using namespace circt;
using namespace firrtl;
@ -1772,11 +1773,67 @@ private:
ParseResult parsePostFixFieldId(Value &result);
ParseResult parsePostFixIntSubscript(Value &result);
ParseResult parsePostFixDynamicSubscript(Value &result);
ParseResult parsePrimExp(Value &result);
ParseResult parseIntegerLiteralExp(Value &result);
ParseResult parseListExp(Value &result);
ParseResult parseListConcatExp(Value &result);
template <typename T, size_t M, size_t N, size_t... Ms, size_t... Ns>
ParseResult parsePrim(std::index_sequence<Ms...>, std::index_sequence<Ns...>,
Value &result) {
auto loc = getToken().getLoc();
locationProcessor.setLoc(loc);
consumeToken();
auto vals = std::array<Value, M>();
auto ints = std::array<int64_t, N>();
// Parse all the values.
bool first = true;
for (size_t i = 0; i < M; ++i) {
if (!first)
if (parseToken(FIRToken::comma, "expected ','"))
return failure();
if (parseExp(vals[i], "expected expression in primitive operand"))
return failure();
first = false;
}
// Parse all the attributes.
for (size_t i = 0; i < N; ++i) {
if (!first)
if (parseToken(FIRToken::comma, "expected ','"))
return failure();
if (parseIntLit(ints[i], "expected integer in primitive operand"))
return failure();
first = false;
}
if (parseToken(FIRToken::r_paren, "expected ')'"))
return failure();
// Infer the type.
auto type = T::inferReturnType(cast<FIRRTLType>(vals[Ms].getType())...,
ints[Ns]..., {});
if (!type) {
// Only call translateLocation on an error case, it is expensive.
T::inferReturnType(cast<FIRRTLType>(vals[Ms].getType())..., ints[Ns]...,
translateLocation(loc));
return failure();
}
// Create the operation.
auto op = builder.create<T>(type, vals[Ms]..., ints[Ns]...);
result = op.getResult();
return success();
}
template <typename T, unsigned M, unsigned N>
ParseResult parsePrimExp(Value &result) {
auto ms = std::make_index_sequence<M>();
auto ns = std::make_index_sequence<N>();
return parsePrim<T, M, N>(ms, ns, result);
}
std::optional<ParseResult> parseExpWithLeadingKeyword(FIRToken keyword);
// Stmt Parsing
@ -1917,15 +1974,28 @@ void FIRStmtParser::emitInvalidate(Value val, Flow flow) {
// NOLINTNEXTLINE(misc-no-recursion)
ParseResult FIRStmtParser::parseExpImpl(Value &result, const Twine &message,
bool isLeadingStmt) {
switch (getToken().getKind()) {
// Handle all primitive's.
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS, NUMOPERANDS) \
case FIRToken::lp_##SPELLING:
#include "FIRTokenKinds.def"
if (parsePrimExp(result))
auto token = getToken();
auto kind = token.getKind();
switch (kind) {
case FIRToken::lp_integer_add:
case FIRToken::lp_integer_mul:
case FIRToken::lp_integer_shr:
case FIRToken::lp_integer_shl:
if (requireFeature({4, 0, 0}, "Integer arithmetic expressions"))
return failure();
break;
default:
break;
}
switch (kind) {
// Handle all primitive's.
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS, NUMOPERANDS, NUMATTRIBUTES) \
case FIRToken::lp_##SPELLING: \
if (parsePrimExp<CLASS, NUMOPERANDS, NUMATTRIBUTES>(result)) \
return failure(); \
break;
#include "FIRTokenKinds.def"
case FIRToken::l_brace_bar:
if (isLeadingStmt)
@ -2113,6 +2183,18 @@ ParseResult FIRStmtParser::parseExpImpl(Value &result, const Twine &message,
break;
}
}
// Don't add code here, the common cases of these switch statements will be
// merged. This allows for fixing up primops after they have been created.
switch (kind) {
case FIRToken::lp_shr:
// For FIRRTL versions earlier than 4.0.0, insert pad(_, 1) around any
// unsigned shr This ensures the minimum width is 1 (but can be greater)
if (version < FIRVersion(4, 0, 0) && type_isa<UIntType>(result.getType()))
result = builder.create<PadPrimOp>(result, 1);
break;
default:
break;
}
return parseOptionalExpPostscript(result);
}
@ -2162,10 +2244,11 @@ FIRStmtParser::emitCachedSubAccess(Value base, ArrayRef<NamedAttribute> attrs,
unsigned indexNo, SMLoc loc) {
// Make sure the field name matches up with the input value's type and
// compute the result type for the expression.
auto resultType = subop::inferReturnType({base}, attrs, {});
auto baseType = cast<FIRRTLType>(base.getType());
auto resultType = subop::inferReturnType(baseType, indexNo, {});
if (!resultType) {
// Emit the error at the right location. translateLocation is expensive.
(void)subop::inferReturnType({base}, attrs, translateLocation(loc));
(void)subop::inferReturnType(baseType, indexNo, translateLocation(loc));
return failure();
}
@ -2296,10 +2379,11 @@ ParseResult FIRStmtParser::parsePostFixDynamicSubscript(Value &result) {
// Make sure the index expression is valid and compute the result type for the
// expression.
auto resultType = SubaccessOp::inferReturnType({result, index}, {}, {});
auto resultType =
SubaccessOp::inferReturnType(result.getType(), index.getType(), {});
if (!resultType) {
// Emit the error at the right location. translateLocation is expensive.
(void)SubaccessOp::inferReturnType({result, index}, {},
(void)SubaccessOp::inferReturnType(result.getType(), index.getType(),
translateLocation(loc));
return failure();
}
@ -2310,136 +2394,6 @@ ParseResult FIRStmtParser::parsePostFixDynamicSubscript(Value &result) {
return success();
}
/// prim ::= primop exp* intLit* ')'
ParseResult FIRStmtParser::parsePrimExp(Value &result) {
auto kind = getToken().getKind();
auto loc = getToken().getLoc();
consumeToken();
// Parse the operands and constant integer arguments.
SmallVector<Value, 3> operands;
SmallVector<int64_t, 3> integers;
if (parseListUntil(FIRToken::r_paren, [&]() -> ParseResult {
// Handle the integer constant case if present.
if (getToken().isAny(FIRToken::integer, FIRToken::signed_integer,
FIRToken::string)) {
integers.push_back(0);
return parseIntLit(integers.back(), "expected integer");
}
// Otherwise it must be a value operand. These must all come before the
// integers.
if (!integers.empty())
return emitError("expected more integer constants"), failure();
Value operand;
if (parseExp(operand, "expected expression in primitive operand"))
return failure();
locationProcessor.setLoc(loc);
operands.push_back(operand);
return success();
}))
return failure();
locationProcessor.setLoc(loc);
SmallVector<FIRRTLType, 3> opTypes;
for (auto v : operands)
opTypes.push_back(type_cast<FIRRTLType>(v.getType()));
unsigned numOperandsExpected;
SmallVector<StringAttr, 2> attrNames;
// Get information about the primitive in question.
switch (kind) {
default:
emitError(loc, "primitive not supported yet");
return failure();
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS, NUMOPERANDS) \
case FIRToken::lp_##SPELLING: \
numOperandsExpected = NUMOPERANDS; \
break;
#include "FIRTokenKinds.def"
}
// Don't add code here, we want these two switch statements to be fused by
// the compiler.
switch (kind) {
default:
break;
case FIRToken::lp_bits:
attrNames.push_back(getConstants().hiIdentifier); // "hi"
attrNames.push_back(getConstants().loIdentifier); // "lo"
break;
case FIRToken::lp_head:
case FIRToken::lp_pad:
case FIRToken::lp_shl:
case FIRToken::lp_shr:
case FIRToken::lp_tail:
attrNames.push_back(getConstants().amountIdentifier);
break;
case FIRToken::lp_integer_add:
case FIRToken::lp_integer_mul:
case FIRToken::lp_integer_shr:
case FIRToken::lp_integer_shl:
if (requireFeature({4, 0, 0}, "Integer arithmetic expressions", loc))
return failure();
break;
}
if (operands.size() != numOperandsExpected) {
assert(numOperandsExpected <= 3);
static const char *numberName[] = {"zero", "one", "two", "three"};
const char *optionalS = &"s"[numOperandsExpected == 1];
return emitError(loc, "operation requires ")
<< numberName[numOperandsExpected] << " operand" << optionalS;
}
if (integers.size() != attrNames.size()) {
emitError(loc, "expected ") << attrNames.size() << " constant arguments";
return failure();
}
NamedAttrList attrs;
for (size_t i = 0, e = attrNames.size(); i != e; ++i)
attrs.append(attrNames[i], builder.getI32IntegerAttr(integers[i]));
switch (kind) {
default:
emitError(loc, "primitive not supported yet");
return failure();
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS, NUMOPERANDS) \
case FIRToken::lp_##SPELLING: { \
auto resultTy = CLASS::inferReturnType(operands, attrs, {}); \
if (!resultTy) { \
/* only call translateLocation on an error case, it is expensive. */ \
CLASS::inferReturnType(operands, attrs, translateLocation(loc)); \
return failure(); \
} \
result = builder.create<CLASS>(resultTy, operands, attrs); \
break; \
}
#include "FIRTokenKinds.def"
}
// Don't add code here, the common cases of these switch statements will be
// merged. This allows for fixing up primops after they have been created.
switch (kind) {
default:
break;
case FIRToken::lp_shr:
// For FIRRTL versions earlier than 4.0.0, insert pad(_, 1) around any
// unsigned shr This ensures the minimum width is 1 (but can be greater)
if (version < FIRVersion(4, 0, 0) && type_isa<UIntType>(result.getType()))
result = builder.create<PadPrimOp>(result, 1);
break;
}
return success();
}
/// integer-literal-exp ::= 'UInt' optional-width '(' intLit ')'
/// ::= 'SInt' optional-width '(' intLit ')'
ParseResult FIRStmtParser::parseIntegerLiteralExp(Value &result) {

View File

@ -37,7 +37,7 @@
#define TOK_LPKEYWORD(SPELLING)
#endif
#ifndef TOK_LPKEYWORD_PRIM
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS, NUMOPERANDS) TOK_LPKEYWORD(SPELLING)
#define TOK_LPKEYWORD_PRIM(SPELLING, CLASS, NUMOPERANDS, NUMATTRIBUTES) TOK_LPKEYWORD(SPELLING)
#endif
// Markers
@ -185,45 +185,45 @@ TOK_LPKEYWORD(rwprobe)
TOK_LPKEYWORD(intrinsic)
// These are for LPKEYWORD cases that correspond to a primitive operation.
TOK_LPKEYWORD_PRIM(add, AddPrimOp, 2)
TOK_LPKEYWORD_PRIM(and, AndPrimOp, 2)
TOK_LPKEYWORD_PRIM(andr, AndRPrimOp, 1)
TOK_LPKEYWORD_PRIM(asAsyncReset, AsAsyncResetPrimOp, 1)
TOK_LPKEYWORD_PRIM(asClock, AsClockPrimOp, 1)
TOK_LPKEYWORD_PRIM(asSInt, AsSIntPrimOp, 1)
TOK_LPKEYWORD_PRIM(asUInt, AsUIntPrimOp, 1)
TOK_LPKEYWORD_PRIM(bits, BitsPrimOp, 1)
TOK_LPKEYWORD_PRIM(cat, CatPrimOp, 2)
TOK_LPKEYWORD_PRIM(cvt, CvtPrimOp, 1)
TOK_LPKEYWORD_PRIM(div, DivPrimOp, 2)
TOK_LPKEYWORD_PRIM(dshl, DShlPrimOp, 2)
TOK_LPKEYWORD_PRIM(dshlw, DShlwPrimOp, 2)
TOK_LPKEYWORD_PRIM(dshr, DShrPrimOp, 2)
TOK_LPKEYWORD_PRIM(eq, EQPrimOp, 2)
TOK_LPKEYWORD_PRIM(geq, GEQPrimOp, 2)
TOK_LPKEYWORD_PRIM(gt, GTPrimOp, 2)
TOK_LPKEYWORD_PRIM(head, HeadPrimOp, 1)
TOK_LPKEYWORD_PRIM(leq, LEQPrimOp, 2)
TOK_LPKEYWORD_PRIM(lt, LTPrimOp, 2)
TOK_LPKEYWORD_PRIM(mul, MulPrimOp, 2)
TOK_LPKEYWORD_PRIM(mux, MuxPrimOp, 3)
TOK_LPKEYWORD_PRIM(neg, NegPrimOp, 1)
TOK_LPKEYWORD_PRIM(neq, NEQPrimOp, 2)
TOK_LPKEYWORD_PRIM(not, NotPrimOp, 1)
TOK_LPKEYWORD_PRIM(or, OrPrimOp, 2)
TOK_LPKEYWORD_PRIM(orr, OrRPrimOp, 1)
TOK_LPKEYWORD_PRIM(pad, PadPrimOp, 1)
TOK_LPKEYWORD_PRIM(rem, RemPrimOp, 2)
TOK_LPKEYWORD_PRIM(shl, ShlPrimOp, 1)
TOK_LPKEYWORD_PRIM(shr, ShrPrimOp, 1)
TOK_LPKEYWORD_PRIM(sub, SubPrimOp, 2)
TOK_LPKEYWORD_PRIM(tail, TailPrimOp, 1)
TOK_LPKEYWORD_PRIM(xor, XorPrimOp, 2)
TOK_LPKEYWORD_PRIM(xorr, XorRPrimOp, 1)
TOK_LPKEYWORD_PRIM(integer_add, IntegerAddOp, 2)
TOK_LPKEYWORD_PRIM(integer_mul, IntegerMulOp, 2)
TOK_LPKEYWORD_PRIM(integer_shr, IntegerShrOp, 2)
TOK_LPKEYWORD_PRIM(integer_shl, IntegerShlOp, 2)
TOK_LPKEYWORD_PRIM(add, AddPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(and, AndPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(andr, AndRPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(asAsyncReset, AsAsyncResetPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(asClock, AsClockPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(asSInt, AsSIntPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(asUInt, AsUIntPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(bits, BitsPrimOp, 1, 2)
TOK_LPKEYWORD_PRIM(cat, CatPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(cvt, CvtPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(div, DivPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(dshl, DShlPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(dshlw, DShlwPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(dshr, DShrPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(eq, EQPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(geq, GEQPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(gt, GTPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(head, HeadPrimOp, 1, 1)
TOK_LPKEYWORD_PRIM(leq, LEQPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(lt, LTPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(mul, MulPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(mux, MuxPrimOp, 3, 0)
TOK_LPKEYWORD_PRIM(neg, NegPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(neq, NEQPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(not, NotPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(or, OrPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(orr, OrRPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(pad, PadPrimOp, 1, 1)
TOK_LPKEYWORD_PRIM(rem, RemPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(shl, ShlPrimOp, 1, 1)
TOK_LPKEYWORD_PRIM(shr, ShrPrimOp, 1, 1)
TOK_LPKEYWORD_PRIM(sub, SubPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(tail, TailPrimOp, 1, 1)
TOK_LPKEYWORD_PRIM(xor, XorPrimOp, 2, 0)
TOK_LPKEYWORD_PRIM(xorr, XorRPrimOp, 1, 0)
TOK_LPKEYWORD_PRIM(integer_add, IntegerAddOp, 2, 0)
TOK_LPKEYWORD_PRIM(integer_mul, IntegerMulOp, 2, 0)
TOK_LPKEYWORD_PRIM(integer_shr, IntegerShrOp, 2, 0)
TOK_LPKEYWORD_PRIM(integer_shl, IntegerShlOp, 2, 0)
#undef TOK_MARKER
#undef TOK_IDENTIFIER

View File

@ -144,7 +144,7 @@ circuit trailing_comma :
public module trailing_comma :
input in0 : SInt<8>
input in1 : SInt<8>
; expected-error @+1 {{expected expression in primitive operand}}
; expected-error @+1 {{expected ')'}}
node n = add(in0, in1,)
;// -----
@ -164,7 +164,7 @@ circuit invalid_add :
public module invalid_add :
input in : SInt<8>
input c : Clock
; expected-error @+1 {{operation requires two operands}}
; expected-error @+1 {{expected ')'}}
node n = add(in, in, in)
;// -----
@ -266,7 +266,7 @@ circuit Issue418:
input a: UInt<1>
output b: UInt<1>
; expected-error @+1 {{operation requires one operand}}
; expected-error @+1 {{expected ')'}}
connect b, not(a, a)
;// -----