forked from OSchip/llvm-project
[MLIR] Fix shape inference in toy tutorial
The implementation of shape inference in the toy tutorial did not conform to the correct algorithmic description. The result was only correct because all operations appear to be processed in sequence. Differential Revision: https://reviews.llvm.org/D77382
This commit is contained in:
parent
b801577c59
commit
1a2370bfb8
|
@ -62,7 +62,7 @@ public:
|
|||
while (!opWorklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
|
||||
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
|
||||
if (nextop == opWorklist.end())
|
||||
break;
|
||||
|
||||
|
@ -88,6 +88,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has all of its
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has a dynamically
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
|
|
|
@ -62,7 +62,7 @@ public:
|
|||
while (!opWorklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
|
||||
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
|
||||
if (nextop == opWorklist.end())
|
||||
break;
|
||||
|
||||
|
@ -88,6 +88,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has all of its
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has a dynamically
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
|
|
|
@ -62,7 +62,7 @@ public:
|
|||
while (!opWorklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
|
||||
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
|
||||
if (nextop == opWorklist.end())
|
||||
break;
|
||||
|
||||
|
@ -88,6 +88,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has all of its
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has a dynamically
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
|
|
|
@ -62,7 +62,7 @@ public:
|
|||
while (!opWorklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(opWorklist, returnsDynamicShape);
|
||||
auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
|
||||
if (nextop == opWorklist.end())
|
||||
break;
|
||||
|
||||
|
@ -88,6 +88,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has all of its
|
||||
/// operands inferred.
|
||||
static bool allOperandsInferred(Operation *op) {
|
||||
return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
|
||||
return operandType.isa<RankedTensorType>();
|
||||
});
|
||||
}
|
||||
|
||||
/// A utility method that returns if the given operation has a dynamically
|
||||
/// shaped result.
|
||||
static bool returnsDynamicShape(Operation *op) {
|
||||
|
|
Loading…
Reference in New Issue