[mlir][linalg] Fix FoldConstantTranspose execution inefficiency

* Move SmallVectors outside of inner loops to avoid frequent
  allocations and deallocations
* Calculate linearized index and call flat range getters to
  avoid internal shape querying behind `getValue`.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D112099
This commit is contained in:
Lei Zhang 2021-10-28 09:45:07 -04:00
parent d29ccbecd0
commit c788cad83b
1 changed files with 93 additions and 33 deletions

View File

@ -1286,12 +1286,16 @@ private:
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
struct APIntOrFloat {
Optional<APInt> apInt;
Optional<APFloat> apFloat;
};
struct APIntOrFloatArray {
SmallVector<APInt> apInts;
SmallVector<APFloat> apFloats;
};
using RegionComputationFn =
std::function<APIntOrFloatArray(APIntOrFloatArray)>;
std::function<APIntOrFloat(const APIntOrFloatArray &)>;
FoldConstantBase(MLIRContext *context,
const ControlElementwiseOpsFusionFn &controlFn,
@ -1403,57 +1407,109 @@ public:
auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
auto outputShape = outputType.getShape();
// Transpose the input constant. Because we don't know its rank in advance,
// we need to loop over the range [0, element count) and delinearize the
// index.
for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
SmallVector<uint64_t> indices(loopBounds.size(), 0);
int totalCount = linearIndex0;
// Allocate small vectors for index delinearization. Initial values do not
// matter here as they will be overwritten later.
SmallVector<uint64_t> indices(loopBounds.size(), 0);
SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
SmallVector<SmallVector<uint64_t>> srcIndices(
numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
uint64_t dstLinearIndex = 0;
// Allocate spaces for compute function inputs. Initial values do not matter
// here as they will be overwritten later.
APIntOrFloatArray computeFnInputs;
auto inputShapes = llvm::to_vector<4>(
llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
return operand->get().getType().cast<ShapedType>().getShape();
}));
// Given a `linearIndex`, remap it to a linear index to access linalg op
// inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
// `srcLinearIndices`, `dstLinearIndex` in place.
auto computeRemappedLinearIndex = [&](int linearIndex) {
int totalCount = linearIndex;
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
indices[dim] = totalCount % loopBounds[dim];
totalCount /= loopBounds[dim];
}
SmallVector<SmallVector<uint64_t>> srcIndices;
for (int i = 0; i < numInputs; ++i)
srcIndices.emplace_back(loopBounds.size(), 0);
SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
for (int i = 0; i < numInputs; ++i)
srcIndices[i][dim] = indices[inputDims[i][dim]];
dstIndices[dim] = indices[outputDims[dim]];
}
uint64_t linearIndex1 = dstIndices.front();
for (int dim = 1; dim < outputType.getRank(); ++dim)
linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
dstLinearIndex = dstIndices.front();
for (int i = 0; i < numInputs; ++i)
srcLinearIndices[i] = srcIndices[i].front();
// Collect constant elements for all inputs at this loop iteration.
SmallVector<APInt> intValues;
SmallVector<APFloat> fpValues;
if (elementType.isa<FloatType>()) {
for (int dim = 1; dim < outputType.getRank(); ++dim) {
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
for (int i = 0; i < numInputs; ++i)
fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
} else {
for (int i = 0; i < numInputs; ++i)
intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
srcLinearIndices[i] =
srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
}
};
// Invoke the computation to get the corresponding constant output
// element.
APIntOrFloatArray inputs = {intValues, fpValues};
APIntOrFloatArray outputs = computeFn(inputs);
bool isFloat = elementType.isa<FloatType>();
if (isFloat) {
SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
inputFpIterators;
for (int i = 0; i < numInputs; ++i)
inputFpIterators.push_back(inputValues[i].getValues<APFloat>());
if (elementType.isa<FloatType>()) {
fpOutputValues[linearIndex1] = outputs.apFloats.front();
} else {
intOutputValues[linearIndex1] = outputs.apInts.front();
computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
// Transpose the input constant. Because we don't know its rank in
// advance, we need to loop over the range [0, element count) and
// delinearize the index.
for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
computeRemappedLinearIndex(linearIndex);
// Collect constant elements for all inputs at this loop iteration.
for (int i = 0; i < numInputs; ++i) {
computeFnInputs.apFloats[i] =
*(inputFpIterators[i].begin() + srcLinearIndices[i]);
}
// Invoke the computation to get the corresponding constant output
// element.
APIntOrFloat outputs = computeFn(computeFnInputs);
fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
}
} else {
SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
inputIntIterators;
for (int i = 0; i < numInputs; ++i)
inputIntIterators.push_back(inputValues[i].getValues<APInt>());
computeFnInputs.apInts.resize(numInputs);
// Transpose the input constant. Because we don't know its rank in
// advance, we need to loop over the range [0, element count) and
// delinearize the index.
for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
computeRemappedLinearIndex(linearIndex);
// Collect constant elements for all inputs at this loop iteration.
for (int i = 0; i < numInputs; ++i) {
computeFnInputs.apInts[i] =
*(inputIntIterators[i].begin() + srcLinearIndices[i]);
}
// Invoke the computation to get the corresponding constant output
// element.
APIntOrFloat outputs = computeFn(computeFnInputs);
intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
}
}
DenseIntOrFPElementsAttr outputAttr;
if (elementType.isa<FloatType>()) {
if (isFloat) {
outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
} else {
outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
@ -1494,7 +1550,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
}
// No computation; just return the orginal value.
return [](APIntOrFloatArray inputs) { return inputs; };
return [](const APIntOrFloatArray &inputs) {
if (inputs.apFloats.empty())
return APIntOrFloat{inputs.apInts.front(), llvm::None};
return APIntOrFloat{llvm::None, inputs.apFloats.front()};
};
}
ControlElementwiseOpsFusionFn controlFn;