[mlir][linalg] Fix FoldConstantTranspose execution inefficiency

* Move SmallVectors outside of inner loops to avoid frequent allocations
  and deallocations.
* Calculate linearized index and call flat range getters to avoid
  internal shape querying behind `getValue`.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D112099
parent d29ccbecd0
commit c788cad83b
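To make the two bullet points concrete, here is a minimal standalone C++ sketch (not part of the patch; it uses plain std::vector<float> in place of DenseElementsAttr, and all names and shapes are illustrative) of the optimized loop shape: scratch index vectors are allocated once before the element loop and overwritten in place, and each element is read through a single precomputed flat offset rather than a per-element multi-dimensional lookup.

#include <cassert>
#include <cstdint>
#include <vector>

// Transposes the flat, row-major buffer `src` of shape `srcShape` according
// to `perm`, where dstShape[d] == srcShape[perm[d]].
std::vector<float> transpose(const std::vector<float> &src,
                             const std::vector<int64_t> &srcShape,
                             const std::vector<int64_t> &perm) {
  const int rank = static_cast<int>(srcShape.size());
  std::vector<int64_t> dstShape(rank);
  for (int d = 0; d < rank; ++d)
    dstShape[d] = srcShape[perm[d]];

  int64_t numElements = 1;
  for (int64_t extent : srcShape)
    numElements *= extent;
  std::vector<float> dst(numElements);

  // Scratch index vectors are allocated once, outside the hot loop, and
  // overwritten on every iteration, mirroring how the patch hoists its
  // SmallVectors out of the per-element loop.
  std::vector<int64_t> dstIndices(rank, 0);
  std::vector<int64_t> srcIndices(rank, 0);

  for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
    // Delinearize the destination index.
    int64_t rest = linearIndex;
    for (int d = rank - 1; d >= 0; --d) {
      dstIndices[d] = rest % dstShape[d];
      rest /= dstShape[d];
    }
    // Remap to the source index and relinearize it, so the element can be
    // read with one flat access instead of a multi-dimensional lookup.
    for (int d = 0; d < rank; ++d)
      srcIndices[perm[d]] = dstIndices[d];
    int64_t srcLinearIndex = srcIndices.front();
    for (int d = 1; d < rank; ++d)
      srcLinearIndex = srcLinearIndex * srcShape[d] + srcIndices[d];

    dst[linearIndex] = src[srcLinearIndex];
  }
  return dst;
}

int main() {
  // 2x3 row-major input {{0, 1, 2}, {3, 4, 5}}, transposed to 3x2.
  std::vector<float> dst = transpose({0, 1, 2, 3, 4, 5}, {2, 3}, {1, 0});
  assert((dst == std::vector<float>{0, 3, 1, 4, 2, 5}));
  (void)dst;
  return 0;
}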
@@ -1286,12 +1286,16 @@ private:
 template <typename ConcreteType>
 class FoldConstantBase : public OpRewritePattern<GenericOp> {
 public:
+  struct APIntOrFloat {
+    Optional<APInt> apInt;
+    Optional<APFloat> apFloat;
+  };
   struct APIntOrFloatArray {
     SmallVector<APInt> apInts;
     SmallVector<APFloat> apFloats;
   };
   using RegionComputationFn =
-      std::function<APIntOrFloatArray(APIntOrFloatArray)>;
+      std::function<APIntOrFloat(const APIntOrFloatArray &)>;
 
   FoldConstantBase(MLIRContext *context,
                    const ControlElementwiseOpsFusionFn &controlFn,
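For context on the signature change above, the following standalone sketch (not part of the patch) shows how a compute function is written against the new RegionComputationFn: it consumes one element per input and returns a single APIntOrFloat instead of building a whole output array. The struct definitions are mirrored locally so the snippet builds with only LLVM's ADT headers; the addFn name and its add semantics are purely illustrative, while FoldConstantTranspose itself uses the identity function shown in the last hunk below.

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <functional>

using llvm::APFloat;
using llvm::APInt;
using llvm::Optional;
using llvm::SmallVector;

// Local mirrors of the structs declared in the hunk above.
struct APIntOrFloat {
  Optional<APInt> apInt;
  Optional<APFloat> apFloat;
};
struct APIntOrFloatArray {
  SmallVector<APInt> apInts;
  SmallVector<APFloat> apFloats;
};
using RegionComputationFn =
    std::function<APIntOrFloat(const APIntOrFloatArray &)>;

int main() {
  // Hypothetical elementwise "add": takes one element per input and returns
  // the single corresponding output element.
  RegionComputationFn addFn = [](const APIntOrFloatArray &inputs) {
    if (inputs.apFloats.empty())
      return APIntOrFloat{inputs.apInts[0] + inputs.apInts[1], llvm::None};
    return APIntOrFloat{llvm::None, inputs.apFloats[0] + inputs.apFloats[1]};
  };

  APIntOrFloatArray inputs;
  inputs.apFloats.push_back(APFloat(1.5f));
  inputs.apFloats.push_back(APFloat(2.5f));
  APIntOrFloat result = addFn(inputs);
  assert(result.apFloat && result.apFloat->convertToFloat() == 4.0f);
  (void)result;
  return 0;
}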
@@ -1403,57 +1407,109 @@ public:
     auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
     auto outputShape = outputType.getShape();
 
-    // Transpose the input constant. Because we don't know its rank in advance,
-    // we need to loop over the range [0, element count) and delinearize the
-    // index.
-    for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
-      SmallVector<uint64_t> indices(loopBounds.size(), 0);
-      int totalCount = linearIndex0;
-      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
-        indices[dim] = totalCount % loopBounds[dim];
-        totalCount /= loopBounds[dim];
-      }
-
-      SmallVector<SmallVector<uint64_t>> srcIndices;
-      for (int i = 0; i < numInputs; ++i)
-        srcIndices.emplace_back(loopBounds.size(), 0);
-      SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
-
-      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
-        for (int i = 0; i < numInputs; ++i)
-          srcIndices[i][dim] = indices[inputDims[i][dim]];
-        dstIndices[dim] = indices[outputDims[dim]];
-      }
-
-      uint64_t linearIndex1 = dstIndices.front();
-      for (int dim = 1; dim < outputType.getRank(); ++dim)
-        linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
-
-      // Collect constant elements for all inputs at this loop iteration.
-      SmallVector<APInt> intValues;
-      SmallVector<APFloat> fpValues;
-      if (elementType.isa<FloatType>()) {
-        for (int i = 0; i < numInputs; ++i)
-          fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
-      } else {
-        for (int i = 0; i < numInputs; ++i)
-          intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
-      }
-
-      // Invoke the computation to get the corresponding constant output
-      // element.
-      APIntOrFloatArray inputs = {intValues, fpValues};
-      APIntOrFloatArray outputs = computeFn(inputs);
-
-      if (elementType.isa<FloatType>()) {
-        fpOutputValues[linearIndex1] = outputs.apFloats.front();
-      } else {
-        intOutputValues[linearIndex1] = outputs.apInts.front();
-      }
-    }
+    // Allocate small vectors for index delinearization. Initial values do not
+    // matter here as they will be overwritten later.
+    SmallVector<uint64_t> indices(loopBounds.size(), 0);
+    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
+    SmallVector<SmallVector<uint64_t>> srcIndices(
+        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
+    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
+    uint64_t dstLinearIndex = 0;
+
+    // Allocate spaces for compute function inputs. Initial values do not matter
+    // here as they will be overwritten later.
+    APIntOrFloatArray computeFnInputs;
+
+    auto inputShapes = llvm::to_vector<4>(
+        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
+          return operand->get().getType().cast<ShapedType>().getShape();
+        }));
+
+    // Given a `linearIndex`, remap it to a linear index to access linalg op
+    // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
+    // `srcLinearIndices`, `dstLinearIndex` in place.
+    auto computeRemappedLinearIndex = [&](int linearIndex) {
+      int totalCount = linearIndex;
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        indices[dim] = totalCount % loopBounds[dim];
+        totalCount /= loopBounds[dim];
+      }
+
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        for (int i = 0; i < numInputs; ++i)
+          srcIndices[i][dim] = indices[inputDims[i][dim]];
+        dstIndices[dim] = indices[outputDims[dim]];
+      }
+
+      dstLinearIndex = dstIndices.front();
+      for (int i = 0; i < numInputs; ++i)
+        srcLinearIndices[i] = srcIndices[i].front();
+
+      for (int dim = 1; dim < outputType.getRank(); ++dim) {
+        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
+        for (int i = 0; i < numInputs; ++i)
+          srcLinearIndices[i] =
+              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
+      }
+    };
+
+    bool isFloat = elementType.isa<FloatType>();
+    if (isFloat) {
+      SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
+          inputFpIterators;
+      for (int i = 0; i < numInputs; ++i)
+        inputFpIterators.push_back(inputValues[i].getValues<APFloat>());
+
+      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
+
+      // Transpose the input constant. Because we don't know its rank in
+      // advance, we need to loop over the range [0, element count) and
+      // delinearize the index.
+      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+        computeRemappedLinearIndex(linearIndex);
+
+        // Collect constant elements for all inputs at this loop iteration.
+        for (int i = 0; i < numInputs; ++i) {
+          computeFnInputs.apFloats[i] =
+              *(inputFpIterators[i].begin() + srcLinearIndices[i]);
+        }
+
+        // Invoke the computation to get the corresponding constant output
+        // element.
+        APIntOrFloat outputs = computeFn(computeFnInputs);
+
+        fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
+      }
+    } else {
+      SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
+          inputIntIterators;
+      for (int i = 0; i < numInputs; ++i)
+        inputIntIterators.push_back(inputValues[i].getValues<APInt>());
+
+      computeFnInputs.apInts.resize(numInputs);
+
+      // Transpose the input constant. Because we don't know its rank in
+      // advance, we need to loop over the range [0, element count) and
+      // delinearize the index.
+      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+        computeRemappedLinearIndex(linearIndex);
+
+        // Collect constant elements for all inputs at this loop iteration.
+        for (int i = 0; i < numInputs; ++i) {
+          computeFnInputs.apInts[i] =
+              *(inputIntIterators[i].begin() + srcLinearIndices[i]);
+        }
+
+        // Invoke the computation to get the corresponding constant output
+        // element.
+        APIntOrFloat outputs = computeFn(computeFnInputs);
+
+        intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
+      }
+    }
 
     DenseIntOrFPElementsAttr outputAttr;
-    if (elementType.isa<FloatType>()) {
+    if (isFloat) {
       outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
     } else {
       outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
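The remapping done by computeRemappedLinearIndex above can be traced on a small case. The following standalone sketch (not part of the patch) replays the same delinearize, permute, and relinearize steps for a hypothetical single-input 2x3 -> 3x2 transpose with hand-picked indexing maps, and checks two of the resulting flat offsets.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Assumed setup: the loop space matches the 3x2 output (identity output
  // map, outputDims == {0, 1}); the input map is the permutation
  // (d0, d1) -> (d1, d0), so inputDims == {1, 0} and the input is 2x3.
  std::vector<uint64_t> loopBounds = {3, 2};
  std::vector<uint64_t> inputShape = {2, 3};
  std::vector<uint64_t> outputShape = {3, 2};
  std::vector<uint64_t> inputDims = {1, 0};  // input dim d reads loop dim inputDims[d]
  std::vector<uint64_t> outputDims = {0, 1};

  // Scratch buffers, allocated once as in the patch.
  std::vector<uint64_t> indices(loopBounds.size(), 0);
  std::vector<uint64_t> srcIndices(loopBounds.size(), 0);
  std::vector<uint64_t> dstIndices(loopBounds.size(), 0);

  auto remap = [&](uint64_t linearIndex, uint64_t &srcLinear,
                   uint64_t &dstLinear) {
    // Delinearize the loop index.
    uint64_t rest = linearIndex;
    for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
      indices[dim] = rest % loopBounds[dim];
      rest /= loopBounds[dim];
    }
    // Map loop indices to input/output indices through the indexing maps.
    for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
      srcIndices[dim] = indices[inputDims[dim]];
      dstIndices[dim] = indices[outputDims[dim]];
    }
    // Relinearize against each operand's own row-major shape.
    srcLinear = srcIndices.front();
    dstLinear = dstIndices.front();
    for (size_t dim = 1; dim < loopBounds.size(); ++dim) {
      srcLinear = srcLinear * inputShape[dim] + srcIndices[dim];
      dstLinear = dstLinear * outputShape[dim] + dstIndices[dim];
    }
  };

  uint64_t srcLinear = 0, dstLinear = 0;
  // Loop index 3 is output element (1, 1), which reads input element (1, 1):
  // src flat offset 1 * 3 + 1 = 4, dst flat offset 1 * 2 + 1 = 3.
  remap(3, srcLinear, dstLinear);
  assert(srcLinear == 4 && dstLinear == 3);
  // Loop index 4 is output element (2, 0), which reads input element (0, 2):
  // src flat offset 0 * 3 + 2 = 2, dst flat offset 2 * 2 + 0 = 4.
  remap(4, srcLinear, dstLinear);
  assert(srcLinear == 2 && dstLinear == 4);
  return 0;
}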
@@ -1494,7 +1550,11 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
     }
 
     // No computation; just return the orginal value.
-    return [](APIntOrFloatArray inputs) { return inputs; };
+    return [](const APIntOrFloatArray &inputs) {
+      if (inputs.apFloats.empty())
+        return APIntOrFloat{inputs.apInts.front(), llvm::None};
+      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
+    };
   }
 
   ControlElementwiseOpsFusionFn controlFn;