Support folding of StandardOps with DenseElementsAttr.

PiperOrigin-RevId: 282270243
This commit is contained in:
Ben Vanik 2019-11-24 18:50:54 -08:00 committed by A. Unique TensorFlower
parent ae821fe626
commit d2284f1f0b
2 changed files with 57 additions and 13 deletions

View File

@ -244,25 +244,41 @@ template <class AttrElementT,
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
const CalculationT &calculate) {
assert(operands.size() == 2 && "binary op takes two operands");
if (!operands[0] || !operands[1])
return {};
if (operands[0].getType() != operands[1].getType())
return {};
if (auto lhs = operands[0].dyn_cast_or_null<AttrElementT>()) {
auto rhs = operands[1].dyn_cast_or_null<AttrElementT>();
if (!rhs || lhs.getType() != rhs.getType())
return {};
if (operands[0].isa<AttrElementT>() && operands[1].isa<AttrElementT>()) {
auto lhs = operands[0].cast<AttrElementT>();
auto rhs = operands[1].cast<AttrElementT>();
return AttrElementT::get(lhs.getType(),
calculate(lhs.getValue(), rhs.getValue()));
} else if (auto lhs = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>();
if (!rhs || lhs.getType() != rhs.getType())
return {};
auto elementResult = constFoldBinaryOp<AttrElementT>(
{lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
if (!elementResult)
return {};
} else if (operands[0].isa<SplatElementsAttr>() &&
operands[1].isa<SplatElementsAttr>()) {
// Both operands are splats so we can avoid expanding the values out and
// just fold based on the splat value.
auto lhs = operands[0].cast<SplatElementsAttr>();
auto rhs = operands[1].cast<SplatElementsAttr>();
auto elementResult = calculate(lhs.getSplatValue<ElementValueT>(),
rhs.getSplatValue<ElementValueT>());
return DenseElementsAttr::get(lhs.getType(), elementResult);
} else if (operands[0].isa<ElementsAttr>() &&
operands[1].isa<ElementsAttr>()) {
// Operands are ElementsAttr-derived; perform an element-wise fold by
// expanding the values.
auto lhs = operands[0].cast<ElementsAttr>();
auto rhs = operands[1].cast<ElementsAttr>();
auto lhsIt = lhs.getValues<ElementValueT>().begin();
auto rhsIt = rhs.getValues<ElementValueT>().begin();
SmallVector<ElementValueT, 4> elementResults;
elementResults.reserve(lhs.getNumElements());
for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt)
elementResults.push_back(calculate(*lhsIt, *rhsIt));
return DenseElementsAttr::get(lhs.getType(), elementResults);
}
return {};
}

View File

@ -50,6 +50,34 @@ func @addf_splat_tensor() -> tensor<4xf32> {
// -----
// CHECK-LABEL: func @addf_dense_tensor
func @addf_dense_tensor() -> tensor<4xf32> {
%0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
%1 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
// CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 5.{{0*}}e+00, 7.{{0*}}e+00, 9.{{0*}}e+00]> : tensor<4xf32>
%2 = addf %0, %1 : tensor<4xf32>
// CHECK-NEXT: return [[C]]
return %2 : tensor<4xf32>
}
// -----
// CHECK-LABEL: func @addf_dense_and_splat_tensors
func @addf_dense_and_splat_tensors() -> tensor<4xf32> {
%0 = constant dense<[1.5, 2.5, 3.5, 4.5]> : tensor<4xf32>
%1 = constant dense<1.5> : tensor<4xf32>
// CHECK-NEXT: [[C:%.+]] = constant dense<[3.{{0*}}e+00, 4.{{0*}}e+00, 5.{{0*}}e+00, 6.{{0*}}e+00]> : tensor<4xf32>
%2 = addf %0, %1 : tensor<4xf32>
// CHECK-NEXT: return [[C]]
return %2 : tensor<4xf32>
}
// -----
// CHECK-LABEL: func @simple_addi
func @simple_addi() -> i32 {
%0 = constant 1 : i32