[Calyx][OM][Pipeline] Use free variants of isa/cast/dyn_cast

Refer to https://mlir.llvm.org/deprecation/
This commit is contained in:
Martin Erhart 2024-04-28 16:53:58 +02:00
parent ad1b56c01b
commit 594e3fb83f
15 changed files with 99 additions and 109 deletions

View File

@ -60,7 +60,7 @@ public:
return success();
// If the operation has the static attribute, verify it is zero.
APInt staticValue = staticAttribute.cast<IntegerAttr>().getValue();
APInt staticValue = cast<IntegerAttr>(staticAttribute).getValue();
assert(staticValue == 0 && "If combinational, it should take 0 cycles.");
return success();

View File

@ -15,7 +15,7 @@ include "mlir/IR/EnumAttr.td"
// "Forward-declare" these HW attributes rather than including or duplicating
// them here. This lets us to refer to them in ODS, but delegates to HW in C++.
// These are used to represent parameters for the PrimitiveOp.
def ParamDeclAttr : Attr<CPred<"$_self.isa<hw::ParamDeclAttr>()">>;
def ParamDeclAttr : Attr<CPred<"llvm::isa<hw::ParamDeclAttr>($_self)">>;
def ParamDeclArrayAttr : TypedArrayAttrBase<ParamDeclAttr, "parameter array">;
def UndefinedOp : CalyxOp<"undef", [
@ -104,7 +104,7 @@ def ComponentOp : CalyxOp<"component", [
/// an error if the attribute is invalid.
LogicalResult verifyType() {
auto type = getFunctionTypeAttr().getValue();
if (!type.isa<FunctionType>())
if (!llvm::isa<FunctionType>(type))
return emitOpError("requires '" +
getFunctionTypeAttrName().getValue() +
"' attribute of function type");
@ -215,7 +215,7 @@ def CombComponentOp : CalyxOp<"comb_component", [
/// an error if the attribute is invalid.
LogicalResult verifyType() {
auto type = getFunctionTypeAttr().getValue();
if (!type.isa<FunctionType>())
if (!llvm::isa<FunctionType>(type))
return emitOpError("requires '" +
getFunctionTypeAttrName().getValue() +
"' attribute of function type");

View File

@ -296,7 +296,7 @@ def StageOp : Op<Pipeline_Dialect, "stage", [
// Returns the register name for a given register index.
StringAttr getRegisterName(unsigned regIdx) {
if(auto names = getOperation()->getAttrOfType<ArrayAttr>("registerNames")) {
auto name = names[regIdx].cast<StringAttr>();
auto name = llvm::cast<StringAttr>(names[regIdx]);
if(!name.strref().empty())
return name;
}
@ -307,7 +307,7 @@ def StageOp : Op<Pipeline_Dialect, "stage", [
// Returns the passthrough name for a given passthrough index.
StringAttr getPassthroughName(unsigned passthroughIdx) {
if(auto names = getOperation()->getAttrOfType<ArrayAttr>("passthroughNames")) {
auto name = names[passthroughIdx].cast<StringAttr>();
auto name = llvm::cast<StringAttr>(names[passthroughIdx]);
if(!name.strref().empty())
return name;
}

View File

@ -30,7 +30,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(OM, om, OMDialect)
//===----------------------------------------------------------------------===//
/// Is the Type a ClassType.
bool omTypeIsAClassType(MlirType type) { return unwrap(type).isa<ClassType>(); }
bool omTypeIsAClassType(MlirType type) { return isa<ClassType>(unwrap(type)); }
/// Get the TypeID for a ClassType.
MlirTypeID omClassTypeGetTypeID() { return wrap(ClassType::getTypeID()); }
@ -62,7 +62,7 @@ MlirTypeID omFrozenPathTypeGetTypeID(void) {
/// Is the Type a StringType.
bool omTypeIsAStringType(MlirType type) {
return unwrap(type).isa<StringType>();
return isa<StringType>(unwrap(type));
}
/// Get a StringType.
@ -72,7 +72,7 @@ MlirType omStringTypeGet(MlirContext ctx) {
/// Return a key type of a map.
MlirType omMapTypeGetKeyType(MlirType type) {
return wrap(unwrap(type).cast<MapType>().getKeyType());
return wrap(cast<MapType>(unwrap(type)).getKeyType());
}
//===----------------------------------------------------------------------===//
@ -117,7 +117,7 @@ OMEvaluatorValue omEvaluatorInstantiate(OMEvaluator evaluator,
Evaluator *cppEvaluator = unwrap(evaluator);
// Unwrap the className, which the client must supply as a StringAttr.
StringAttr cppClassName = unwrap(className).cast<StringAttr>();
StringAttr cppClassName = cast<StringAttr>(unwrap(className));
// Unwrap the actual parameters.
SmallVector<std::shared_ptr<evaluator::EvaluatorValue>> cppActualParams;
@ -189,7 +189,7 @@ OMEvaluatorValue omEvaluatorObjectGetField(OMEvaluatorValue object,
// supply as a StringAttr.
FailureOr<EvaluatorValuePtr> result =
llvm::cast<Object>(unwrap(object).get())
->getField(unwrap(name).cast<StringAttr>());
->getField(cast<StringAttr>(unwrap(name)));
// If getField failed, return a null EvaluatorValue. A Diagnostic will be
// emitted in this case.
@ -355,12 +355,12 @@ omEvaluatorValueGetReferenceValue(OMEvaluatorValue evaluatorValue) {
//===----------------------------------------------------------------------===//
bool omAttrIsAReferenceAttr(MlirAttribute attr) {
return unwrap(attr).isa<ReferenceAttr>();
return isa<ReferenceAttr>(unwrap(attr));
}
MlirAttribute omReferenceAttrGetInnerRef(MlirAttribute referenceAttr) {
return wrap(
(Attribute)unwrap(referenceAttr).cast<ReferenceAttr>().getInnerRef());
(Attribute)cast<ReferenceAttr>(unwrap(referenceAttr)).getInnerRef());
}
//===----------------------------------------------------------------------===//
@ -368,7 +368,7 @@ MlirAttribute omReferenceAttrGetInnerRef(MlirAttribute referenceAttr) {
//===----------------------------------------------------------------------===//
bool omAttrIsAIntegerAttr(MlirAttribute attr) {
return unwrap(attr).isa<circt::om::IntegerAttr>();
return isa<circt::om::IntegerAttr>(unwrap(attr));
}
MlirAttribute omIntegerAttrGetInt(MlirAttribute attr) {
@ -396,7 +396,7 @@ MlirStringRef omIntegerAttrToString(MlirAttribute attr) {
//===----------------------------------------------------------------------===//
bool omAttrIsAListAttr(MlirAttribute attr) {
return unwrap(attr).isa<ListAttr>();
return isa<ListAttr>(unwrap(attr));
}
intptr_t omListAttrGetNumElements(MlirAttribute attr) {
@ -413,9 +413,7 @@ MlirAttribute omListAttrGetElement(MlirAttribute attr, intptr_t pos) {
// MapAttr API.
//===----------------------------------------------------------------------===//
bool omAttrIsAMapAttr(MlirAttribute attr) {
return unwrap(attr).isa<MapAttr>();
}
bool omAttrIsAMapAttr(MlirAttribute attr) { return isa<MapAttr>(unwrap(attr)); }
intptr_t omMapAttrGetNumElements(MlirAttribute attr) {
auto mapAttr = llvm::cast<MapAttr>(unwrap(attr));

View File

@ -138,7 +138,7 @@ LogicalResult CalyxRemoveGroupsFromFSM::outlineMachine() {
llvm::MapVector<Value, SmallVector<Operation *>> referencedValues;
machineOp.walk([&](Operation *op) {
for (auto &operand : op->getOpOperands()) {
if (auto barg = operand.get().dyn_cast<BlockArgument>()) {
if (auto barg = dyn_cast<BlockArgument>(operand.get())) {
if (barg.getOwner()->getParentOp() == machineOp)
continue;
@ -203,8 +203,8 @@ LogicalResult CalyxRemoveGroupsFromFSM::outlineMachine() {
// First we inspect the groupDoneInputsAttr map and create backedges.
for (auto &namedAttr : groupDoneInputsAttr.getValue()) {
auto name = namedAttr.getName();
auto idx = namedAttr.getValue().cast<IntegerAttr>();
auto inputIdx = idx.cast<IntegerAttr>().getInt();
auto idx = cast<IntegerAttr>(namedAttr.getValue());
auto inputIdx = cast<IntegerAttr>(idx).getInt();
if (fsmInputMap.count(inputIdx))
return emitError(machineOp.getLoc())
<< "MachineOp has duplicate input index " << idx;
@ -266,7 +266,7 @@ LogicalResult CalyxRemoveGroupsFromFSM::outlineMachine() {
// Record the FSM output group go signals.
for (auto namedAttr : groupGoOutputsAttr.getValue()) {
auto name = namedAttr.getName();
auto idx = namedAttr.getValue().cast<IntegerAttr>().getInt();
auto idx = cast<IntegerAttr>(namedAttr.getValue()).getInt();
groupGoSignals[name] = fsmInstance.getResult(idx);
}

View File

@ -774,7 +774,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
FunctionType funcType = funcOp.getFunctionType();
unsigned extMemCounter = 0;
for (auto arg : enumerate(funcOp.getArguments())) {
if (arg.value().getType().isa<MemRefType>()) {
if (isa<MemRefType>(arg.value().getType())) {
/// External memories
auto memName =
"ext_mem" + std::to_string(extMemoryCompPortIndices.size());
@ -835,11 +835,10 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
extMemPorts.readData = compOp.getArgument(inPortsIt++);
extMemPorts.done = compOp.getArgument(inPortsIt);
extMemPorts.writeData = compOp.getArgument(outPortsIt++);
unsigned nAddresses = extMemPortIndices.getFirst()
.getType()
.cast<MemRefType>()
.getShape()
.size();
unsigned nAddresses =
cast<MemRefType>(extMemPortIndices.getFirst().getType())
.getShape()
.size();
for (unsigned j = 0; j < nAddresses; ++j)
extMemPorts.addrPorts.push_back(compOp.getArgument(outPortsIt++));
extMemPorts.writeEn = compOp.getArgument(outPortsIt);
@ -964,7 +963,7 @@ class BuildPipelineRegs : public calyx::FuncOpPartialLoweringPattern {
// Create a register for passing this result to later stages.
Value value = operand.get();
Type resultType = value.getType();
assert(resultType.isa<IntegerType>() &&
assert(isa<IntegerType>(resultType) &&
"unsupported pipeline result type");
auto name = SmallString<20>("stage_");
name += std::to_string(stage.getStageNumber());
@ -1125,7 +1124,7 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
// Get the group and register that is temporarily being written to.
auto doneOp = group.getDoneOp();
auto tempReg =
cast<calyx::RegisterOp>(doneOp.getSrc().cast<OpResult>().getOwner());
cast<calyx::RegisterOp>(cast<OpResult>(doneOp.getSrc()).getOwner());
auto tempIn = tempReg.getIn();
auto tempWriteEn = tempReg.getWriteEn();

View File

@ -1044,7 +1044,7 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
FunctionType funcType = funcOp.getFunctionType();
unsigned extMemCounter = 0;
for (auto arg : enumerate(funcOp.getArguments())) {
if (arg.value().getType().isa<MemRefType>()) {
if (isa<MemRefType>(arg.value().getType())) {
/// External memories
auto memName =
"ext_mem" + std::to_string(extMemoryCompPortIndices.size());
@ -1119,11 +1119,10 @@ struct FuncOpConversion : public calyx::FuncOpPartialLoweringPattern {
extMemPorts.readData = compOp.getArgument(inPortsIt++);
extMemPorts.done = compOp.getArgument(inPortsIt);
extMemPorts.writeData = compOp.getArgument(outPortsIt++);
unsigned nAddresses = extMemPortIndices.getFirst()
.getType()
.cast<MemRefType>()
.getShape()
.size();
unsigned nAddresses =
cast<MemRefType>(extMemPortIndices.getFirst().getType())
.getShape()
.size();
for (unsigned j = 0; j < nAddresses; ++j)
extMemPorts.addrPorts.push_back(compOp.getArgument(outPortsIt++));
extMemPorts.writeEn = compOp.getArgument(outPortsIt);

View File

@ -133,8 +133,8 @@ static std::string valueName(Operation *scopeOp, Value v) {
/// port on a cell interface.
static bool isPort(Value value) {
Operation *definingOp = value.getDefiningOp();
return value.isa<BlockArgument>() ||
(definingOp && isa<CellInterface>(definingOp));
return isa<BlockArgument>(value) ||
isa_and_nonnull<CellInterface>(definingOp);
}
/// Gets the port for a given BlockArgument.
@ -357,7 +357,7 @@ static void eraseControlWithGroupAndConditional(OpTy op,
rewriter.eraseOp(group);
}
// Check the conditional after the Group, since it will be driven within.
if (!cond.isa<BlockArgument>() && cond.getDefiningOp()->use_empty())
if (!isa<BlockArgument>(cond) && cond.getDefiningOp()->use_empty())
rewriter.eraseOp(cond.getDefiningOp());
}
@ -374,7 +374,7 @@ static void eraseControlWithConditional(OpTy op, PatternRewriter &rewriter) {
rewriter.eraseOp(op);
// Check if conditional is still needed, and remove if it isn't
if (!cond.isa<BlockArgument>() && cond.getDefiningOp()->use_empty())
if (!isa<BlockArgument>(cond) && cond.getDefiningOp()->use_empty())
rewriter.eraseOp(cond.getDefiningOp());
}
@ -628,7 +628,7 @@ static Value getBlockArgumentWithName(StringRef name, ComponentOp op) {
ArrayAttr portNames = op.getPortNames();
for (size_t i = 0, e = portNames.size(); i != e; ++i) {
auto portName = portNames[i].cast<StringAttr>();
auto portName = cast<StringAttr>(portNames[i]);
if (portName.getValue() == name)
return op.getBodyBlock()->getArgument(i);
}
@ -666,10 +666,9 @@ SmallVector<PortInfo> ComponentOp::getPortInfo() {
SmallVector<PortInfo> results;
for (size_t i = 0, e = portNamesAttr.size(); i != e; ++i) {
results.push_back(PortInfo{portNamesAttr[i].cast<StringAttr>(),
portTypes[i],
results.push_back(PortInfo{cast<StringAttr>(portNamesAttr[i]), portTypes[i],
direction::get(portDirectionsAttr[i]),
portAttrs[i].cast<DictionaryAttr>()});
cast<DictionaryAttr>(portAttrs[i])});
}
return results;
}
@ -771,7 +770,7 @@ void ComponentOp::getAsmBlockArgumentNames(
auto ports = getPortNames();
auto *block = &getRegion()->front();
for (size_t i = 0, e = block->getNumArguments(); i != e; ++i)
setNameFn(block->getArgument(i), ports[i].cast<StringAttr>().getValue());
setNameFn(block->getArgument(i), cast<StringAttr>(ports[i]).getValue());
}
//===----------------------------------------------------------------------===//
@ -785,10 +784,9 @@ SmallVector<PortInfo> CombComponentOp::getPortInfo() {
SmallVector<PortInfo> results;
for (size_t i = 0, e = portNamesAttr.size(); i != e; ++i) {
results.push_back(PortInfo{portNamesAttr[i].cast<StringAttr>(),
portTypes[i],
results.push_back(PortInfo{cast<StringAttr>(portNamesAttr[i]), portTypes[i],
direction::get(portDirectionsAttr[i]),
portAttrs[i].cast<DictionaryAttr>()});
cast<DictionaryAttr>(portAttrs[i])});
}
return results;
}
@ -884,7 +882,7 @@ void CombComponentOp::getAsmBlockArgumentNames(
auto ports = getPortNames();
auto *block = &getRegion()->front();
for (size_t i = 0, e = block->getNumArguments(); i != e; ++i)
setNameFn(block->getArgument(i), ports[i].cast<StringAttr>().getValue());
setNameFn(block->getArgument(i), cast<StringAttr>(ports[i]).getValue());
}
//===----------------------------------------------------------------------===//
@ -1430,12 +1428,12 @@ static void getCellAsmResultNames(OpAsmSetValueNameFn setNameFn, Operation *op,
static LogicalResult verifyPortDirection(Operation *op, Value value,
bool isDestination) {
Operation *definingOp = value.getDefiningOp();
bool isComponentPort = value.isa<BlockArgument>(),
isCellInterfacePort = definingOp && isa<CellInterface>(definingOp);
bool isComponentPort = isa<BlockArgument>(value),
isCellInterfacePort = isa_and_nonnull<CellInterface>(definingOp);
assert((isComponentPort || isCellInterfacePort) && "Not a port.");
PortInfo port = isComponentPort
? getPortInfo(value.cast<BlockArgument>())
? getPortInfo(cast<BlockArgument>(value))
: cast<CellInterface>(definingOp).portInfo(value);
bool isSource = !isDestination;
@ -1624,7 +1622,7 @@ void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
SmallVector<StringRef> InstanceOp::portNames() {
SmallVector<StringRef> portNames;
for (Attribute name : getReferencedComponent().getPortNames())
portNames.push_back(name.cast<StringAttr>().getValue());
portNames.push_back(cast<StringAttr>(name).getValue());
return portNames;
}
@ -1695,8 +1693,8 @@ verifyPrimitiveOpType(PrimitiveOp instance,
<< " but got " << numParams;
for (size_t i = 0; i != numExpected; ++i) {
auto param = parameters[i].cast<circt::hw::ParamDeclAttr>();
auto modParam = modParameters[i].cast<circt::hw::ParamDeclAttr>();
auto param = cast<circt::hw::ParamDeclAttr>(parameters[i]);
auto modParam = cast<circt::hw::ParamDeclAttr>(modParameters[i]);
auto paramName = param.getName();
if (paramName != modParam.getName())
@ -1884,7 +1882,7 @@ static void printParameterList(OpAsmPrinter &p, Operation *op,
p << '<';
llvm::interleaveComma(parameters, p, [&](Attribute param) {
auto paramAttr = param.cast<hw::ParamDeclAttr>();
auto paramAttr = cast<hw::ParamDeclAttr>(param);
p << paramAttr.getName().getValue() << ": " << paramAttr.getType();
if (auto value = paramAttr.getValue()) {
p << " = ";
@ -2066,8 +2064,8 @@ LogicalResult MemoryOp::verify() {
<< numAddrs;
for (size_t i = 0; i < numDims; ++i) {
int64_t size = opSizes[i].cast<IntegerAttr>().getInt();
int64_t addrSize = opAddrSizes[i].cast<IntegerAttr>().getInt();
int64_t size = cast<IntegerAttr>(opSizes[i]).getInt();
int64_t addrSize = cast<IntegerAttr>(opAddrSizes[i]).getInt();
if (llvm::Log2_64_Ceil(size) > addrSize)
return emitOpError("address size (")
<< addrSize << ") for dimension " << i
@ -2169,8 +2167,8 @@ LogicalResult SeqMemoryOp::verify() {
<< numAddrs;
for (size_t i = 0; i < numDims; ++i) {
int64_t size = opSizes[i].cast<IntegerAttr>().getInt();
int64_t addrSize = opAddrSizes[i].cast<IntegerAttr>().getInt();
int64_t size = cast<IntegerAttr>(opSizes[i]).getInt();
int64_t addrSize = cast<IntegerAttr>(opAddrSizes[i]).getInt();
if (llvm::Log2_64_Ceil(size) > addrSize)
return emitOpError("address size (")
<< addrSize << ") for dimension " << i
@ -2794,7 +2792,7 @@ LogicalResult InvokeOp::verify() {
// inputs are required to be destination ports.
if (failed(verifyInvokeOpValue(*this, port, true)))
return emitOpError() << "'@" << callee << "' has input '"
<< portName.cast<StringAttr>().getValue()
<< cast<StringAttr>(portName).getValue()
<< "', which is a source port. The inputs are "
"required to be destination ports.";
// The go port should not appear in the parameter list.
@ -2804,12 +2802,12 @@ LogicalResult InvokeOp::verify() {
// Check the direction of these source ports.
if (failed(verifyInvokeOpValue(*this, input, false)))
return emitOpError() << "'@" << callee << "' has output '"
<< inputName.cast<StringAttr>().getValue()
<< cast<StringAttr>(inputName).getValue()
<< "', which is a destination port. The inputs are "
"required to be source ports.";
if (failed(verifyComplexLogic(*this, input)))
return emitOpError() << "'@" << callee << "' has '"
<< inputName.cast<StringAttr>().getValue()
<< cast<StringAttr>(inputName).getValue()
<< "', which is not a port or constant. Complex "
"logic should be conducted in the guard.";
if (input == doneValue)
@ -2818,8 +2816,8 @@ LogicalResult InvokeOp::verify() {
// Check if the connection uses the callee's port.
if (port.getDefiningOp() != operation && input.getDefiningOp() != operation)
return emitOpError() << "the connection "
<< portName.cast<StringAttr>().getValue() << " = "
<< inputName.cast<StringAttr>().getValue()
<< cast<StringAttr>(portName).getValue() << " = "
<< cast<StringAttr>(inputName).getValue()
<< " is not defined as an input port of '@" << callee
<< "'.";
}

View File

@ -184,7 +184,7 @@ struct Emitter {
for (auto sourceLoc : llvm::enumerate(metadata)) {
// <index>: <source-location>\n
os << std::to_string(sourceLoc.index()) << colon();
os << sourceLoc.value().cast<StringAttr>().getValue() << endl();
os << cast<StringAttr>(sourceLoc.value()).getValue() << endl();
}
os << metadataRBrace();
@ -324,7 +324,7 @@ private:
bool isBooleanAttribute =
llvm::find(booleanAttributes, identifier) != booleanAttributes.end();
if (attr.getValue().isa<UnitAttr>()) {
if (isa<UnitAttr>(attr.getValue())) {
assert(isBooleanAttribute &&
"Non-boolean attributes must provide an integer value.");
if (!atFormat) {
@ -332,7 +332,7 @@ private:
} else {
buffer << addressSymbol() << identifier;
}
} else if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) {
} else if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
APInt value = intAttr.getValue();
if (!atFormat) {
buffer << quote() << identifier << quote() << equals() << value;
@ -438,7 +438,7 @@ private:
/// Emits the value of a guard or assignment.
void emitValue(Value value, bool isIndented) {
if (auto blockArg = value.dyn_cast<BlockArgument>()) {
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
// Emit component block argument.
StringAttr portName = getPortInfo(blockArg).name;
(isIndented ? indent() : os) << portName.getValue();
@ -701,7 +701,7 @@ void Emitter::emitPrimitiveExtern(hw::HWModuleExternOp op) {
if (!op.getParameters().empty()) {
os << LSquare();
llvm::interleaveComma(op.getParameters(), os, [&](Attribute param) {
auto paramAttr = param.cast<hw::ParamDeclAttr>();
auto paramAttr = cast<hw::ParamDeclAttr>(param);
os << paramAttr.getName().str();
});
os << RSquare();
@ -730,10 +730,8 @@ void Emitter::emitPrimitivePorts(hw::HWModuleExternOp op) {
// We only care about the bit width in the emitted .futil file.
// Emit parameterized or non-parameterized bit width.
if (hw::isParametricType(port.type)) {
hw::ParamDeclRefAttr bitWidth =
port.type.template cast<hw::IntType>()
.getWidth()
.template dyn_cast<hw::ParamDeclRefAttr>();
hw::ParamDeclRefAttr bitWidth = dyn_cast<hw::ParamDeclRefAttr>(
cast<hw::IntType>(port.type).getWidth());
os << bitWidth.getName().str();
} else {
unsigned int bitWidth = port.type.getIntOrFloatBitWidth();
@ -764,11 +762,11 @@ void Emitter::emitPrimitive(PrimitiveOp op) {
if (op.getParameters().has_value()) {
llvm::interleaveComma(*op.getParameters(), os, [&](Attribute param) {
auto paramAttr = param.cast<hw::ParamDeclAttr>();
auto paramAttr = cast<hw::ParamDeclAttr>(param);
auto value = paramAttr.getValue();
if (auto intAttr = value.dyn_cast<IntegerAttr>()) {
if (auto intAttr = dyn_cast<IntegerAttr>(value)) {
os << intAttr.getInt();
} else if (auto fpAttr = value.dyn_cast<FloatAttr>()) {
} else if (auto fpAttr = dyn_cast<FloatAttr>(value)) {
os << fpAttr.getValue().convertToFloat();
} else {
llvm_unreachable("Primitive parameter type not supported");
@ -805,14 +803,14 @@ void Emitter::emitMemory(MemoryOp memory) {
<< std::to_string(dimension) << LParen() << memory.getWidth()
<< comma();
for (Attribute size : memory.getSizes()) {
APInt memSize = size.cast<IntegerAttr>().getValue();
APInt memSize = cast<IntegerAttr>(size).getValue();
memSize.print(os, /*isSigned=*/false);
os << comma();
}
ArrayAttr addrSizes = memory.getAddrSizes();
for (size_t i = 0, e = addrSizes.size(); i != e; ++i) {
APInt addrSize = addrSizes[i].cast<IntegerAttr>().getValue();
APInt addrSize = cast<IntegerAttr>(addrSizes[i]).getValue();
addrSize.print(os, /*isSigned=*/false);
if (i + 1 == e)
continue;
@ -833,14 +831,14 @@ void Emitter::emitSeqMemory(SeqMemoryOp memory) {
<< std::to_string(dimension) << LParen() << memory.getWidth()
<< comma();
for (Attribute size : memory.getSizes()) {
APInt memSize = size.cast<IntegerAttr>().getValue();
APInt memSize = cast<IntegerAttr>(size).getValue();
memSize.print(os, /*isSigned=*/false);
os << comma();
}
ArrayAttr addrSizes = memory.getAddrSizes();
for (size_t i = 0, e = addrSizes.size(); i != e; ++i) {
APInt addrSize = addrSizes[i].cast<IntegerAttr>().getValue();
APInt addrSize = cast<IntegerAttr>(addrSizes[i]).getValue();
addrSize.print(os, /*isSigned=*/false);
if (i + 1 == e)
continue;

View File

@ -33,7 +33,7 @@ void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName,
Value memref, unsigned memoryID,
SmallVectorImpl<calyx::PortInfo> &inPorts,
SmallVectorImpl<calyx::PortInfo> &outPorts) {
MemRefType memrefType = memref.getType().cast<MemRefType>();
MemRefType memrefType = cast<MemRefType>(memref.getType());
// Ports constituting a memory interface are added a set of attributes under
// a "mem : {...}" dictionary. These attributes allows for deducing which
@ -375,7 +375,7 @@ calyx::RegisterOp ComponentLoweringStateInterface::getReturnReg(unsigned idx) {
void ComponentLoweringStateInterface::registerMemoryInterface(
Value memref, const calyx::MemoryInterface &memoryInterface) {
assert(memref.getType().isa<MemRefType>());
assert(isa<MemRefType>(memref.getType()));
assert(memories.find(memref) == memories.end() &&
"Memory already registered for memref");
memories[memref] = memoryInterface;
@ -383,7 +383,7 @@ void ComponentLoweringStateInterface::registerMemoryInterface(
calyx::MemoryInterface
ComponentLoweringStateInterface::getMemoryInterface(Value memref) {
assert(memref.getType().isa<MemRefType>());
assert(isa<MemRefType>(memref.getType()));
auto it = memories.find(memref);
assert(it != memories.end() && "No memory registered for memref");
return it->second;
@ -655,7 +655,7 @@ void InlineCombGroups::recurseInlineCombGroups(
// return values have at the current point of conversion not yet
// been rewritten to their register outputs, see comment in
// LateSSAReplacement)
if (src.isa<BlockArgument>() ||
if (isa<BlockArgument>(src) ||
isa<calyx::RegisterOp, calyx::MemoryOp, calyx::SeqMemoryOp,
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
@ -730,7 +730,7 @@ BuildBasicBlockRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
for (auto arg : enumerate(block->getArguments())) {
Type argType = arg.value().getType();
assert(argType.isa<IntegerType>() && "unsupported block argument type");
assert(isa<IntegerType>(argType) && "unsupported block argument type");
unsigned width = argType.getIntOrFloatBitWidth();
std::string index = std::to_string(arg.index());
std::string name = loweringState().blockName(block) + "_arg" + index;
@ -753,7 +753,7 @@ BuildReturnRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
for (auto argType : enumerate(funcOp.getResultTypes())) {
auto convArgType = calyx::convIndexType(rewriter, argType.value());
assert(convArgType.isa<IntegerType>() && "unsupported return type");
assert(isa<IntegerType>(convArgType) && "unsupported return type");
unsigned width = convArgType.getIntOrFloatBitWidth();
std::string name = "ret_arg" + std::to_string(argType.index());
auto reg =

View File

@ -56,7 +56,7 @@ LogicalResult
circt::om::ListAttr::verify(function_ref<InFlightDiagnostic()> emitError,
mlir::Type elementType, mlir::ArrayAttr elements) {
return success(llvm::all_of(elements, [&](mlir::Attribute attr) {
auto typedAttr = attr.dyn_cast<mlir::TypedAttr>();
auto typedAttr = llvm::dyn_cast<mlir::TypedAttr>(attr);
if (!typedAttr) {
emitError()
<< "an element of a list attribute must be a typed attr but got "

View File

@ -365,7 +365,7 @@ circt::om::ObjectFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// If there are more fields, verify the current field is of ClassType, and
// look up the ClassOp for that field.
if (i < e - 1) {
auto classType = fieldDef.getType().dyn_cast<ClassType>();
auto classType = dyn_cast<ClassType>(fieldDef.getType());
if (!classType)
return emitOpError("nested field access into ")
<< field << " requires a ClassType, but found "
@ -463,7 +463,7 @@ LogicalResult TupleGetOp::inferReturnTypes(
if (operands.empty() || !idx)
return failure();
auto tupleTypes = operands[0].getType().cast<TupleType>().getTypes();
auto tupleTypes = cast<TupleType>(operands[0].getType()).getTypes();
if (tupleTypes.size() <= idx.getValue().getLimitedValue()) {
if (location)
mlir::emitError(*location,
@ -485,8 +485,8 @@ void circt::om::MapCreateOp::print(OpAsmPrinter &p) {
p << " ";
p.printOperands(getInputs());
p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << getType().cast<circt::om::MapType>().getKeyType() << ", "
<< getType().cast<circt::om::MapType>().getValueType();
p << " : " << cast<circt::om::MapType>(getType()).getKeyType() << ", "
<< cast<circt::om::MapType>(getType()).getValueType();
}
ParseResult circt::om::MapCreateOp::parse(OpAsmParser &parser,

View File

@ -70,9 +70,9 @@ Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
Block *circt::pipeline::getParentStageInPipeline(ScheduledPipelineOp pipeline,
Value v) {
if (v.isa<BlockArgument>())
if (isa<BlockArgument>(v))
return getParentStageInPipeline(pipeline,
v.cast<BlockArgument>().getOwner());
cast<BlockArgument>(v).getOwner());
return getParentStageInPipeline(pipeline, v.getDefiningOp());
}
@ -108,7 +108,7 @@ static void printOutputList(OpAsmPrinter &p, TypeRange types, ArrayAttr names) {
p << "(";
llvm::interleaveComma(llvm::zip(types, names), p, [&](auto it) {
auto [type, name] = it;
p.printKeywordOrString(name.template cast<StringAttr>().str());
p.printKeywordOrString(cast<StringAttr>(name).str());
p << " : " << type;
});
p << ")";
@ -371,7 +371,7 @@ getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region,
auto arg = block.getArguments()[regI];
if (regNames) {
auto nameAttr = (*regNames)[regI].dyn_cast<StringAttr>();
auto nameAttr = dyn_cast<StringAttr>((*regNames)[regI]);
if (nameAttr && !nameAttr.strref().empty()) {
setNameFn(arg, nameAttr);
continue;
@ -387,7 +387,7 @@ getPipelineAsmBlockArgumentNames(TPipelineOp op, mlir::Region &region,
if (passthroughNames) {
auto nameAttr =
(*passthroughNames)[passthroughI].dyn_cast<StringAttr>();
dyn_cast<StringAttr>((*passthroughNames)[passthroughI]);
if (nameAttr && !nameAttr.strref().empty()) {
setNameFn(arg, nameAttr);
continue;
@ -543,7 +543,7 @@ LogicalResult ScheduledPipelineOp::verify() {
bool err = true;
if (block.getNumArguments() != 0) {
auto lastArgType =
block.getArguments().back().getType().dyn_cast<IntegerType>();
dyn_cast<IntegerType>(block.getArguments().back().getType());
err = !lastArgType || lastArgType.getWidth() != 1;
}
if (err)
@ -631,7 +631,7 @@ StageKind ScheduledPipelineOp::getStageKind(size_t stageIndex) {
if (stageIndex < stallability->size()) {
bool stageIsStallable =
(*stallability)[stageIndex].cast<BoolAttr>().getValue();
cast<BoolAttr>((*stallability)[stageIndex]).getValue();
if (!stageIsStallable) {
// This is a non-stallable stage.
return StageKind::NonStallable;
@ -786,14 +786,13 @@ void printStageRegisters(OpAsmPrinter &p, Operation *op, ValueRange registers,
size_t idx = it.index();
auto &[reg, type, nClockGatesAttr] = it.value();
if (names) {
if (auto nameAttr = names[idx].dyn_cast<StringAttr>();
if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
nameAttr && !nameAttr.strref().empty())
p << nameAttr << " = ";
}
p << reg << " : " << type;
int64_t nClockGates =
nClockGatesAttr.template cast<IntegerAttr>().getInt();
int64_t nClockGates = cast<IntegerAttr>(nClockGatesAttr).getInt();
if (nClockGates == 0)
return;
p << " gated by [";
@ -818,7 +817,7 @@ void printPassthroughs(OpAsmPrinter &p, Operation *op, ValueRange passthroughs,
size_t idx = it.index();
auto &[reg, type] = it.value();
if (names) {
if (auto nameAttr = names[idx].dyn_cast<StringAttr>();
if (auto nameAttr = dyn_cast<StringAttr>(names[idx]);
nameAttr && !nameAttr.strref().empty())
p << nameAttr << " = ";
}

View File

@ -299,11 +299,10 @@ OperationOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
LinkedOperatorTypeAttr OperationOp::getLinkedOperatorTypeAttr() {
if (ArrayAttr properties = getSspPropertiesAttr()) {
const auto *it = llvm::find_if(properties, [](Attribute a) {
return a.isa<LinkedOperatorTypeAttr>();
});
const auto *it = llvm::find_if(
properties, [](Attribute a) { return isa<LinkedOperatorTypeAttr>(a); });
if (it != properties.end())
return (*it).cast<LinkedOperatorTypeAttr>();
return cast<LinkedOperatorTypeAttr>(*it);
}
return {};
}

View File

@ -454,8 +454,8 @@ std::optional<unsigned> Dependence::getSourceIndex() const {
if (!isDefUse())
return std::nullopt;
assert(defUse->get().isa<OpResult>() && "source is not an operation");
return defUse->get().dyn_cast<OpResult>().getResultNumber();
assert(isa<OpResult>(defUse->get()) && "source is not an operation");
return dyn_cast<OpResult>(defUse->get()).getResultNumber();
}
std::optional<unsigned> Dependence::getDestinationIndex() const {