diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index fa9328f869eb..d8cc4ed19c17 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -156,7 +156,7 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { dest->setAttrs(newAttrs.takeVector()); // Clone the body. - body.cloneInto(&dest->body, mapper, dest->getContext()); + body.cloneInto(&dest->body, mapper, getContext()); } /// Create a deep copy of this function and all of its blocks, remapping diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index d5f430fa6299..b029cb58ed08 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -159,7 +159,6 @@ Type TypeConverter::convertIntegerType(IntegerType type) { } Type TypeConverter::convertFloatType(FloatType type) { - MLIRContext *context = type.getContext(); switch (type.getKind()) { case mlir::StandardTypes::F32: return wrap(builder.getFloatTy()); @@ -168,8 +167,8 @@ Type TypeConverter::convertFloatType(FloatType type) { case mlir::StandardTypes::F16: return wrap(builder.getHalfTy()); case mlir::StandardTypes::BF16: - return context->emitError(UnknownLoc::get(context), - "unsupported type: BF16"), + return mlirContext->emitError(UnknownLoc::get(mlirContext), + "unsupported type: BF16"), Type(); default: llvm_unreachable("non-float type in convertFloatType"); @@ -236,11 +235,11 @@ FunctionType TypeConverter::convertFunctionSignatureType(FunctionType type) { // If function does not return anything, return immediately. if (type.getNumResults() == 0) - return FunctionType::get(argTypes, {}, type.getContext()); + return FunctionType::get(argTypes, {}, mlirContext); // Otherwise pack the result types into a struct. if (auto result = getPackedResultType(type.getResults())) - return FunctionType::get(argTypes, {result}, type.getContext()); + return FunctionType::get(argTypes, {result}, mlirContext); return {}; } @@ -271,9 +270,8 @@ Type TypeConverter::convertMemRefType(MemRefType type) { // Convert a 1D vector type to an LLVM vector type. Type TypeConverter::convertVectorType(VectorType type) { if (type.getRank() != 1) { - MLIRContext *context = type.getContext(); - context->emitError(UnknownLoc::get(context), - "only 1D vectors are supported"); + mlirContext->emitError(UnknownLoc::get(mlirContext), + "only 1D vectors are supported"); return {}; } @@ -300,12 +298,11 @@ Type TypeConverter::convertType(Type type) { if (auto llvmType = type.dyn_cast()) return llvmType; - MLIRContext *context = type.getContext(); std::string message; llvm::raw_string_ostream os(message); os << "unsupported type: "; type.print(os); - context->emitError(UnknownLoc::get(context), os.str()); + mlirContext->emitError(UnknownLoc::get(mlirContext), os.str()); return {}; } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 545797590583..8a2002ce3687 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -45,7 +45,7 @@ void Canonicalizer::runOnFunction() { // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when // building the worklist. For now, we just grab everything. - auto *context = func.getContext(); + auto *context = &getContext(); for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(patterns, context); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index c15108530fb7..2a877c456805 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -745,7 +745,7 @@ void MaterializeVectorsPass::runOnFunction() { // Get the hardware vector type. // TODO(ntv): get elemental type from super-vector type rather than force f32. auto subVectorType = - VectorType::get(state.hwVectorSize, FloatType::getF32(f->getContext())); + VectorType::get(state.hwVectorSize, FloatType::getF32(&getContext())); // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index 7a9cfc7f5abd..4ff5367abbb2 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -71,7 +71,7 @@ struct SimplifyAffineStructures FlatAffineConstraints fac(set); if (fac.isEmpty()) return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(), - set.getContext()); + &getContext()); return set; } diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index bc2b4b930230..9d6b7a0ba272 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -30,7 +30,7 @@ struct StripDebugInfo : public FunctionPass { void StripDebugInfo::runOnFunction() { Function &func = getFunction(); - UnknownLoc unknownLoc = UnknownLoc::get(func.getContext()); + UnknownLoc unknownLoc = UnknownLoc::get(&getContext()); // Strip the debug info from the function and its instructions. func.setLoc(unknownLoc);