diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 333214efcc05..66474348ca22 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -171,11 +171,14 @@ struct DialectConversionRewriter final : public PatternRewriter { void replaceOp(Operation *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead) override { assert(newValues.size() == op->getNumResults()); - // Create mappings for any type changes. - for (unsigned i = 0, e = newValues.size(); i < e; ++i) - if (newValues[i] && - op->getResult(i)->getType() != newValues[i]->getType()) + + // Create mappings for each of the new result values. + for (unsigned i = 0, e = newValues.size(); i < e; ++i) { + assert((newValues[i] || op->getResult(i)->use_empty()) && + "result value has remaining uses that must be replaced"); + if (newValues[i]) mapping.map(op->getResult(i), newValues[i]); + } // Record the requested operation replacement. replacements.emplace_back(op, newValues);