[AIG] Add LowerWordToBits pass

This commit ads LowerWordToBits pass that perform bit-blasting for
AndInverterOp.
This commit is contained in:
Hideto Ueno 2024-10-26 18:29:28 +09:00
parent f2fc8c7214
commit 4f0edf4cfe
5 changed files with 126 additions and 1 deletions

View File

@ -15,4 +15,9 @@ def LowerVariadic : Pass<"aig-lower-variadic", "hw::HWModuleOp"> {
let summary = "Lower variadic AndInverter operations to binary AndInverter";
}
def LowerWordToBits : Pass<"aig-lower-word-to-bits", "hw::HWModuleOp"> {
let summary = "Lower multi-bit AndInverter to single-bit ones";
let dependentDialects = ["comb::CombDialect"];
}
#endif // CIRCT_DIALECT_AIG_AIGPASSES_TD

View File

@ -13,4 +13,4 @@ add_circt_dialect_library(CIRCTAIG
MLIRAIGIncGen
)
add_subdirectory(Transforms)
add_subdirectory(Transforms)

View File

@ -1,5 +1,6 @@
add_circt_dialect_library(CIRCTAIGTransforms
LowerVariadic.cpp
LowerWordToBits.cpp
DEPENDS
CIRCTAIGPassesIncGen

View File

@ -0,0 +1,103 @@
//===- LowerWordToBits.cpp - Bit-Blasting Words to Bits ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers multi-bit AIG operations to single-bit ones.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/AIG/AIGPasses.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "aig-lower-word-to-bits"
namespace circt {
namespace aig {
#define GEN_PASS_DEF_LOWERWORDTOBITS
#include "circt/Dialect/AIG/AIGPasses.h.inc"
} // namespace aig
} // namespace circt
using namespace circt;
using namespace aig;
//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//
namespace {
struct WordRewritePattern : public OpRewritePattern<AndInverterOp> {
using OpRewritePattern<AndInverterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AndInverterOp op,
PatternRewriter &rewriter) const override {
auto width = op.getType().getIntOrFloatBitWidth();
if (width <= 1)
return failure();
SmallVector<Value> results;
// We iterate over the width in reverse order to match the endianness of
// `comb.concat`.
for (int64_t i = width - 1; i >= 0; --i) {
SmallVector<Value> operands;
for (auto operand : op.getOperands()) {
// Reuse bits if we can extract from `comb.concat` operands.
if (auto concat = operand.getDefiningOp<comb::ConcatOp>()) {
// For the simplicity, we only handle the case where all the
// `comb.concat` operands are single-bit.
if (concat.getNumOperands() == width &&
llvm::all_of(concat.getOperandTypes(), [](Type type) {
return type.getIntOrFloatBitWidth() == 1;
})) {
// Be careful with the endianness here.
operands.push_back(concat.getOperand(width - i - 1));
continue;
}
}
// Otherwise, we need to extract the bit.
operands.push_back(
rewriter.create<comb::ExtractOp>(op.getLoc(), operand, i, 1));
}
results.push_back(rewriter.create<AndInverterOp>(op.getLoc(), operands,
op.getInvertedAttr()));
}
rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Lower Word to Bits pass
//===----------------------------------------------------------------------===//
namespace {
struct LowerWordToBitsPass
: public impl::LowerWordToBitsBase<LowerWordToBitsPass> {
void runOnOperation() override;
};
} // namespace
void LowerWordToBitsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<WordRewritePattern>(&getContext());
mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns));
mlir::GreedyRewriteConfig config;
// Use top-down traversal to reuse bits from `comb.concat`.
config.useTopDownTraversal = true;
if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns,
config)))
return signalPassFailure();
}

View File

@ -0,0 +1,16 @@
// RUN: circt-opt %s --aig-lower-word-to-bits | FileCheck %s
// CHECK: hw.module @Basic
hw.module @Basic(in %a: i2, in %b: i2, out f: i2) {
%0 = aig.and_inv not %a, %b : i2
%1 = aig.and_inv not %0, not %0 : i2
// CHECK-NEXT: %[[EXTRACT_A_1:.+]] = comb.extract %a from 1
// CHECK-NEXT: %[[EXTRACT_B_1:.+]] = comb.extract %b from 1
// CHECK-NEXT: %[[AND_INV_0:.+]] = aig.and_inv not %[[EXTRACT_A_1]], %[[EXTRACT_B_1]]
// CHECK-NEXT: %[[EXTRACT_A_0:.+]] = comb.extract %a from 0
// CHECK-NEXT: %[[EXTRACT_B_0:.+]] = comb.extract %b from 0
// CHECK-NEXT: %[[AND_INV_1:.+]] = aig.and_inv not %[[EXTRACT_A_0]], %[[EXTRACT_B_0]]
// CHECK-NEXT: %[[AND_INV_2:.+]] = aig.and_inv not %[[AND_INV_0]], not %[[AND_INV_0]]
// CHECK-NEXT: %[[AND_INV_3:.+]] = aig.and_inv not %[[AND_INV_1]], not %[[AND_INV_1]]
// CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[AND_INV_2]], %[[AND_INV_3]]
hw.output %1 : i2
}