[MLIR] Fix shape inference in toy tutorial

The implementation of shape inference in the toy tutorial did not conform to the correct algorithmic description.
The result was only correct because all operations appear to be processed in sequence.

Differential Revision: https://reviews.llvm.org/D77382
This commit is contained in:
Frederik Gossen 2020-04-04 04:33:58 +00:00 committed by Mehdi Amini
parent b801577c59
commit 1a2370bfb8
4 changed files with 36 additions and 4 deletions

View File

@ -62,7 +62,7 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@ -88,6 +88,14 @@ public:
}
}
/// A utility method that returns if the given operation has all of its
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
});
}
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {

View File

@ -62,7 +62,7 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@ -88,6 +88,14 @@ public:
}
}
/// A utility method that returns if the given operation has all of its
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
});
}
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {

View File

@ -62,7 +62,7 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@ -88,6 +88,14 @@ public:
}
}
/// A utility method that returns if the given operation has all of its
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
});
}
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {

View File

@ -62,7 +62,7 @@ public:
while (!opWorklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
if (nextop == opWorklist.end())
break;
@ -88,6 +88,14 @@ public:
}
}
/// A utility method that returns if the given operation has all of its
/// operands inferred.
static bool allOperandsInferred(Operation *op) {
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
return operandType.isa<RankedTensorType>();
});
}
/// A utility method that returns if the given operation has a dynamically
/// shaped result.
static bool returnsDynamicShape(Operation *op) {