[AIG] Add LowerVariadic pass

This commit adds AIG LowerVariadic pass to lower variadic AndInverter
op to have at most two operands. This makes IR closer to traditinal
AIG representation combined with LowerWordToBits pass
This commit is contained in:
Hideto Ueno 2024-10-26 18:28:19 +09:00
parent c2c00c69e0
commit f2fc8c7214
5 changed files with 131 additions and 1 deletions

View File

@ -11,4 +11,8 @@
include "mlir/Pass/PassBase.td"
def LowerVariadic : Pass<"aig-lower-variadic", "hw::HWModuleOp"> {
let summary = "Lower variadic AndInverter operations to binary AndInverter";
}
#endif // CIRCT_DIALECT_AIG_AIGPASSES_TD

View File

@ -10,6 +10,7 @@ add_circt_dialect_library(CIRCTAIG
CIRCTHW
DEPENDS
CIRCTAIGPassesIncGen
MLIRAIGIncGen
)
add_subdirectory(Transforms)

View File

@ -0,0 +1,11 @@
add_circt_dialect_library(CIRCTAIGTransforms
LowerVariadic.cpp
DEPENDS
CIRCTAIGPassesIncGen
LINK_LIBS PUBLIC
CIRCTAIG
CIRCTComb
CIRCTHW
)

View File

@ -0,0 +1,103 @@
//===- LowerVariadic.cpp - Lowering Variadic to Binary Ops ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers variadic AndInverter operations to binary AndInverter
// operations.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/AIG/AIGPasses.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "aig-lower-variadic"
namespace circt {
namespace aig {
#define GEN_PASS_DEF_LOWERVARIADIC
#include "circt/Dialect/AIG/AIGPasses.h.inc"
} // namespace aig
} // namespace circt
using namespace circt;
using namespace aig;
//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//
namespace {
static Value lowerVariadicAndInverterOp(AndInverterOp op, OperandRange operands,
ArrayRef<bool> inverts,
PatternRewriter &rewriter) {
switch (operands.size()) {
case 0:
assert(0 && "cannot be called with empty operand range");
break;
case 1:
if (inverts[0])
return rewriter.create<AndInverterOp>(op.getLoc(), operands[0], true);
else
return operands[0];
case 2:
return rewriter.create<AndInverterOp>(op.getLoc(), operands[0], operands[1],
inverts[0], inverts[1]);
default:
auto firstHalf = operands.size() / 2;
auto lhs =
lowerVariadicAndInverterOp(op, operands.take_front(firstHalf),
inverts.take_front(firstHalf), rewriter);
auto rhs =
lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf),
inverts.drop_front(firstHalf), rewriter);
return rewriter.create<AndInverterOp>(op.getLoc(), lhs, rhs);
}
}
struct VariadicOpConversion : OpRewritePattern<aig::AndInverterOp> {
using OpRewritePattern<aig::AndInverterOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AndInverterOp op,
PatternRewriter &rewriter) const override {
if (op.getInputs().size() <= 2)
return failure();
// TODO: This is a naive implementation that creates a balanced binary tree.
// We can improve by analyzing the dataflow and creating a tree that
// improves the critical path or area.
rewriter.replaceOp(op,
lowerVariadicAndInverterOp(op, op.getOperands(),
op.getInverted(), rewriter));
return success();
}
};
} // namespace
static void populateLowerVariadicPatterns(RewritePatternSet &patterns) {
patterns.add<VariadicOpConversion>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
// Lower Variadic pass
//===----------------------------------------------------------------------===//
namespace {
struct LowerVariadicPass : public impl::LowerVariadicBase<LowerVariadicPass> {
void runOnOperation() override;
};
} // namespace
void LowerVariadicPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLowerVariadicPatterns(patterns);
mlir::FrozenRewritePatternSet frozen(std::move(patterns));
if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), frozen)))
return signalPassFailure();
}

View File

@ -0,0 +1,11 @@
// RUN: circt-opt %s --aig-lower-variadic | FileCheck %s
// CHECK: hw.module @Basic
hw.module @Basic(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, out f: i2) {
// CHECK: %[[RES0:.+]] = aig.and_inv not %a, %b : i2
// CHECK-NEXT: %[[RES1:.+]] = aig.and_inv not %d, %e : i2
// CHECK-NEXT: %[[RES2:.+]] = aig.and_inv %c, %[[RES1]] : i2
// CHECK-NEXT: %[[RES3:.+]] = aig.and_inv %[[RES0]], %[[RES2]] : i2
// CHECK-NEXT: hw.output %[[RES3]] : i2
%0 = aig.and_inv not %a, %b, %c, not %d, %e : i2
hw.output %0 : i2
}