forked from OSchip/llvm-project
Support folding of StandardOps with DenseElementsAttr.
PiperOrigin-RevId: 282270243
This commit is contained in:
parent
ae821fe626
commit
d2284f1f0b
|
@ -244,25 +244,41 @@ template <class AttrElementT,
|
|||
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
||||
const CalculationT &calculate) {
|
||||
assert(operands.size() == 2 && "binary op takes two operands");
|
||||
if (!operands[0] || !operands[1])
|
||||
return {};
|
||||
if (operands[0].getType() != operands[1].getType())
|
||||
return {};
|
||||
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
|
||||
auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
|
||||
if (!rhs || lhs.getType() != rhs.getType())
|
||||
return {};
|
||||
if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
|
||||
auto lhs = operands[0].cast<AttrElementT>();
|
||||
auto rhs = operands[1].cast<AttrElementT>();
|
||||
|
||||
return AttrElementT::get(lhs.getType(),
|
||||
calculate(lhs.getValue(), rhs.getValue()));
|
||||
} else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
|
||||
auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
|
||||
if (!rhs || lhs.getType() != rhs.getType())
|
||||
return {};
|
||||
|
||||
auto elementResult = constFoldBinaryOp<AttrElementT>(
|
||||
{lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
|
||||
if (!elementResult)
|
||||
return {};
|
||||
} else if (operands[0].isa<SplatElementsAttr>() &&
|
||||
operands[1].isa<SplatElementsAttr>()) {
|
||||
// Both operands are splats so we can avoid expanding the values out and
|
||||
// just fold based on the splat value.
|
||||
auto lhs = operands[0].cast<SplatElementsAttr>();
|
||||
auto rhs = operands[1].cast<SplatElementsAttr>();
|
||||
|
||||
auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
|
||||
rhs.getSplatValue<ElementValueT>());
|
||||
return DenseElementsAttr::get(lhs.getType(), elementResult);
|
||||
} else if (operands[0].isa<ElementsAttr>() &&
|
||||
operands[1].isa<ElementsAttr>()) {
|
||||
// Operands are ElementsAttr-derived; perform an element-wise fold by
|
||||
// expanding the values.
|
||||
auto lhs = operands[0].cast<ElementsAttr>();
|
||||
auto rhs = operands[1].cast<ElementsAttr>();
|
||||
|
||||
auto lhsIt = lhs.getValues<ElementValueT>().begin();
|
||||
auto rhsIt = rhs.getValues<ElementValueT>().begin();
|
||||
SmallVector<ElementValueT, 4> elementResults;
|
||||
elementResults.reserve(lhs.getNumElements());
|
||||
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
|
||||
elementResults.push_back(calculate(*lhsIt, *rhsIt));
|
||||
return DenseElementsAttr::get(lhs.getType(), elementResults);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
|
|
@ -50,6 +50,34 @@ func @addf_splat_tensor() -> tensor<4xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @addf_dense_tensor
|
||||
func @addf_dense_tensor() -> tensor<4xf32> {
|
||||
%0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
|
||||
%1 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 5.{{0*}}e+00, 7.{{0*}}e+00, 9.{{0*}}e+00]> : tensor<4xf32>
|
||||
%2 = addf %0, %1 : tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: return [[C]]
|
||||
return %2 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @addf_dense_and_splat_tensors
|
||||
func @addf_dense_and_splat_tensors() -> tensor<4xf32> {
|
||||
%0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
|
||||
%1 = constant dense<1.5> : tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 4.{{0*}}e+00, 5.{{0*}}e+00, 6.{{0*}}e+00]> : tensor<4xf32>
|
||||
%2 = addf %0, %1 : tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: return [[C]]
|
||||
return %2 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @simple_addi
|
||||
func @simple_addi() -> i32 {
|
||||
%0 = constant 1 : i32
|
||||
|
|
Loading…
Reference in New Issue