Fix clang-tidy issues in mlir/ (NFC)

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D115956
This commit is contained in:
Mehdi Amini 2021-12-20 19:45:05 +00:00
parent 3e5b1b77d5
commit 02b6fb218e
115 changed files with 594 additions and 569 deletions

View File

@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -27,7 +27,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
cl::value_desc("filename")); cl::value_desc("filename"));
namespace { namespace {
enum Action { None, DumpAST }; enum Action { None, DumpAST };
} } // namespace
static cl::opt<enum Action> static cl::opt<enum Action>
emitAction("emit", cl::desc("Select the kind of output desired"), emitAction("emit", cl::desc("Select the kind of output desired"),

View File

@ -58,8 +58,8 @@ public:
// add them to the module. // add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) { for (FunctionAST &f : moduleAST) {
auto func = mlirGen(F); auto func = mlirGen(f);
if (!func) if (!func)
return nullptr; return nullptr;
theModule.push_back(func); theModule.push_back(func);
@ -113,16 +113,16 @@ private:
// This is a generic function, the return type will be inferred later. // This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors. // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{})); getType(VarType{}));
auto func_type = builder.getFunctionType(arg_types, llvm::None); auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), func_type); return mlir::FuncOp::create(location, proto.getName(), funcType);
} }
/// Emit a new function and add it to the MLIR module. /// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) { mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations. // Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype. // Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto())); mlir::FuncOp function(mlirGen(*funcAST.getProto()));
@ -371,7 +371,7 @@ private:
/// Future expressions will be able to reference this variable through symbol /// Future expressions will be able to reference this variable through symbol
/// table lookup. /// table lookup.
mlir::Value mlirGen(VarDeclExprAST &vardecl) { mlir::Value mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal(); auto *init = vardecl.getInitVal();
if (!init) { if (!init) {
emitError(loc(vardecl.loc()), emitError(loc(vardecl.loc()),
"missing initializer in variable declaration"); "missing initializer in variable declaration");
@ -398,7 +398,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error. /// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) { mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
for (auto &expr : blockAST) { for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and // Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested // print. These can only appear in block list and not in nested

View File

@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -38,7 +38,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace { namespace {
enum InputType { Toy, MLIR }; enum InputType { Toy, MLIR };
} } // namespace
static cl::opt<enum InputType> inputType( static cl::opt<enum InputType> inputType(
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"), "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
@ -47,7 +47,7 @@ static cl::opt<enum InputType> inputType(
namespace { namespace {
enum Action { None, DumpAST, DumpMLIR }; enum Action { None, DumpAST, DumpMLIR };
} } // namespace
static cl::opt<enum Action> emitAction( static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"), "emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
@ -89,8 +89,8 @@ int dumpMLIR() {
// Otherwise, the input is '.mlir'. // Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return -1; return -1;
} }

View File

@ -58,8 +58,8 @@ public:
// add them to the module. // add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) { for (FunctionAST &f : moduleAST) {
auto func = mlirGen(F); auto func = mlirGen(f);
if (!func) if (!func)
return nullptr; return nullptr;
theModule.push_back(func); theModule.push_back(func);
@ -113,16 +113,16 @@ private:
// This is a generic function, the return type will be inferred later. // This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors. // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{})); getType(VarType{}));
auto func_type = builder.getFunctionType(arg_types, llvm::None); auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), func_type); return mlir::FuncOp::create(location, proto.getName(), funcType);
} }
/// Emit a new function and add it to the MLIR module. /// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) { mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations. // Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype. // Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto())); mlir::FuncOp function(mlirGen(*funcAST.getProto()));
@ -371,7 +371,7 @@ private:
/// Future expressions will be able to reference this variable through symbol /// Future expressions will be able to reference this variable through symbol
/// table lookup. /// table lookup.
mlir::Value mlirGen(VarDeclExprAST &vardecl) { mlir::Value mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal(); auto *init = vardecl.getInitVal();
if (!init) { if (!init) {
emitError(loc(vardecl.loc()), emitError(loc(vardecl.loc()),
"missing initializer in variable declaration"); "missing initializer in variable declaration");
@ -398,7 +398,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error. /// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) { mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
for (auto &expr : blockAST) { for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and // Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested // print. These can only appear in block list and not in nested

View File

@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -40,7 +40,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace { namespace {
enum InputType { Toy, MLIR }; enum InputType { Toy, MLIR };
} } // namespace
static cl::opt<enum InputType> inputType( static cl::opt<enum InputType> inputType(
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"), "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
@ -49,7 +49,7 @@ static cl::opt<enum InputType> inputType(
namespace { namespace {
enum Action { None, DumpAST, DumpMLIR }; enum Action { None, DumpAST, DumpMLIR };
} } // namespace
static cl::opt<enum Action> emitAction( static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"), "emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
@ -86,8 +86,8 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
// Otherwise, the input is '.mlir'. // Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return -1; return -1;
} }

View File

@ -58,8 +58,8 @@ public:
// add them to the module. // add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) { for (FunctionAST &f : moduleAST) {
auto func = mlirGen(F); auto func = mlirGen(f);
if (!func) if (!func)
return nullptr; return nullptr;
theModule.push_back(func); theModule.push_back(func);
@ -113,16 +113,16 @@ private:
// This is a generic function, the return type will be inferred later. // This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors. // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{})); getType(VarType{}));
auto func_type = builder.getFunctionType(arg_types, llvm::None); auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), func_type); return mlir::FuncOp::create(location, proto.getName(), funcType);
} }
/// Emit a new function and add it to the MLIR module. /// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) { mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations. // Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype. // Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto())); mlir::FuncOp function(mlirGen(*funcAST.getProto()));
@ -375,7 +375,7 @@ private:
/// Future expressions will be able to reference this variable through symbol /// Future expressions will be able to reference this variable through symbol
/// table lookup. /// table lookup.
mlir::Value mlirGen(VarDeclExprAST &vardecl) { mlir::Value mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal(); auto *init = vardecl.getInitVal();
if (!init) { if (!init) {
emitError(loc(vardecl.loc()), emitError(loc(vardecl.loc()),
"missing initializer in variable declaration"); "missing initializer in variable declaration");
@ -402,7 +402,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error. /// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) { mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
for (auto &expr : blockAST) { for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and // Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested // print. These can only appear in block list and not in nested

View File

@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -41,7 +41,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace { namespace {
enum InputType { Toy, MLIR }; enum InputType { Toy, MLIR };
} } // namespace
static cl::opt<enum InputType> inputType( static cl::opt<enum InputType> inputType(
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"), "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
@ -50,7 +50,7 @@ static cl::opt<enum InputType> inputType(
namespace { namespace {
enum Action { None, DumpAST, DumpMLIR }; enum Action { None, DumpAST, DumpMLIR };
} } // namespace
static cl::opt<enum Action> emitAction( static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"), "emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
@ -87,8 +87,8 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
// Otherwise, the input is '.mlir'. // Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return -1; return -1;
} }

View File

@ -58,8 +58,8 @@ public:
// add them to the module. // add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) { for (FunctionAST &f : moduleAST) {
auto func = mlirGen(F); auto func = mlirGen(f);
if (!func) if (!func)
return nullptr; return nullptr;
theModule.push_back(func); theModule.push_back(func);
@ -113,16 +113,16 @@ private:
// This is a generic function, the return type will be inferred later. // This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors. // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{})); getType(VarType{}));
auto func_type = builder.getFunctionType(arg_types, llvm::None); auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), func_type); return mlir::FuncOp::create(location, proto.getName(), funcType);
} }
/// Emit a new function and add it to the MLIR module. /// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) { mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations. // Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype. // Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto())); mlir::FuncOp function(mlirGen(*funcAST.getProto()));
@ -375,7 +375,7 @@ private:
/// Future expressions will be able to reference this variable through symbol /// Future expressions will be able to reference this variable through symbol
/// table lookup. /// table lookup.
mlir::Value mlirGen(VarDeclExprAST &vardecl) { mlir::Value mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal(); auto *init = vardecl.getInitVal();
if (!init) { if (!init) {
emitError(loc(vardecl.loc()), emitError(loc(vardecl.loc()),
"missing initializer in variable declaration"); "missing initializer in variable declaration");
@ -402,7 +402,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error. /// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) { mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
for (auto &expr : blockAST) { for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and // Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested // print. These can only appear in block list and not in nested

View File

@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -43,7 +43,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace { namespace {
enum InputType { Toy, MLIR }; enum InputType { Toy, MLIR };
} } // namespace
static cl::opt<enum InputType> inputType( static cl::opt<enum InputType> inputType(
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"), "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
@ -52,7 +52,7 @@ static cl::opt<enum InputType> inputType(
namespace { namespace {
enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine }; enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine };
} } // namespace
static cl::opt<enum Action> emitAction( static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"), "emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
@ -91,8 +91,8 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
// Otherwise, the input is '.mlir'. // Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return -1; return -1;
} }

View File

@ -58,8 +58,8 @@ public:
// add them to the module. // add them to the module.
theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) { for (FunctionAST &f : moduleAST) {
auto func = mlirGen(F); auto func = mlirGen(f);
if (!func) if (!func)
return nullptr; return nullptr;
theModule.push_back(func); theModule.push_back(func);
@ -113,16 +113,16 @@ private:
// This is a generic function, the return type will be inferred later. // This is a generic function, the return type will be inferred later.
// Arguments type are uniformly unranked tensors. // Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
getType(VarType{})); getType(VarType{}));
auto func_type = builder.getFunctionType(arg_types, llvm::None); auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), func_type); return mlir::FuncOp::create(location, proto.getName(), funcType);
} }
/// Emit a new function and add it to the MLIR module. /// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) { mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations. // Create a scope in the symbol table to hold variable declarations.
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
// Create an MLIR function for the given prototype. // Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto())); mlir::FuncOp function(mlirGen(*funcAST.getProto()));
@ -375,7 +375,7 @@ private:
/// Future expressions will be able to reference this variable through symbol /// Future expressions will be able to reference this variable through symbol
/// table lookup. /// table lookup.
mlir::Value mlirGen(VarDeclExprAST &vardecl) { mlir::Value mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal(); auto *init = vardecl.getInitVal();
if (!init) { if (!init) {
emitError(loc(vardecl.loc()), emitError(loc(vardecl.loc()),
"missing initializer in variable declaration"); "missing initializer in variable declaration");
@ -402,7 +402,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error. /// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) { mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable); ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
for (auto &expr : blockAST) { for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and // Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested // print. These can only appear in block list and not in nested

View File

@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -49,7 +49,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace { namespace {
enum InputType { Toy, MLIR }; enum InputType { Toy, MLIR };
} } // namespace
static cl::opt<enum InputType> inputType( static cl::opt<enum InputType> inputType(
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"), "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
@ -66,7 +66,7 @@ enum Action {
DumpLLVMIR, DumpLLVMIR,
RunJIT RunJIT
}; };
} } // namespace
static cl::opt<enum Action> emitAction( static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"), "emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
@ -110,8 +110,8 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Otherwise, the input is '.mlir'. // Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return -1; return -1;
} }

View File

@ -169,14 +169,14 @@ private:
return nullptr; return nullptr;
argTypes.push_back(type); argTypes.push_back(type);
} }
auto func_type = builder.getFunctionType(argTypes, llvm::None); auto funcType = builder.getFunctionType(argTypes, llvm::None);
return mlir::FuncOp::create(location, proto.getName(), func_type); return mlir::FuncOp::create(location, proto.getName(), funcType);
} }
/// Emit a new function and add it to the MLIR module. /// Emit a new function and add it to the MLIR module.
mlir::FuncOp mlirGen(FunctionAST &funcAST) { mlir::FuncOp mlirGen(FunctionAST &funcAST) {
// Create a scope in the symbol table to hold variable declarations. // Create a scope in the symbol table to hold variable declarations.
SymbolTableScopeT var_scope(symbolTable); SymbolTableScopeT varScope(symbolTable);
// Create an MLIR function for the given prototype. // Create an MLIR function for the given prototype.
mlir::FuncOp function(mlirGen(*funcAST.getProto())); mlir::FuncOp function(mlirGen(*funcAST.getProto()));
@ -286,7 +286,7 @@ private:
return llvm::None; return llvm::None;
auto structVars = structAST->getVariables(); auto structVars = structAST->getVariables();
auto it = llvm::find_if(structVars, [&](auto &var) { const auto *it = llvm::find_if(structVars, [&](auto &var) {
return var->getName() == name->getName(); return var->getName() == name->getName();
}); });
if (it == structVars.end()) if (it == structVars.end())
@ -569,7 +569,7 @@ private:
/// Future expressions will be able to reference this variable through symbol /// Future expressions will be able to reference this variable through symbol
/// table lookup. /// table lookup.
mlir::Value mlirGen(VarDeclExprAST &vardecl) { mlir::Value mlirGen(VarDeclExprAST &vardecl) {
auto init = vardecl.getInitVal(); auto *init = vardecl.getInitVal();
if (!init) { if (!init) {
emitError(loc(vardecl.loc()), emitError(loc(vardecl.loc()),
"missing initializer in variable declaration"); "missing initializer in variable declaration");
@ -612,7 +612,7 @@ private:
/// Codegen a list of expression, return failure if one of them hit an error. /// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) { mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
SymbolTableScopeT var_scope(symbolTable); SymbolTableScopeT varScope(symbolTable);
for (auto &expr : blockAST) { for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and // Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested // print. These can only appear in block list and not in nested

View File

@ -121,7 +121,7 @@ void ASTDumper::dump(NumberExprAST *num) {
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
void printLitHelper(ExprAST *litOrNum) { void printLitHelper(ExprAST *litOrNum) {
// Inside a literal expression we can have either a number or another literal // Inside a literal expression we can have either a number or another literal
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) { if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
llvm::errs() << num->getValue(); llvm::errs() << num->getValue();
return; return;
} }

View File

@ -49,7 +49,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
namespace { namespace {
enum InputType { Toy, MLIR }; enum InputType { Toy, MLIR };
} } // namespace
static cl::opt<enum InputType> inputType( static cl::opt<enum InputType> inputType(
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"), "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
@ -66,7 +66,7 @@ enum Action {
DumpLLVMIR, DumpLLVMIR,
RunJIT RunJIT
}; };
} } // namespace
static cl::opt<enum Action> emitAction( static cl::opt<enum Action> emitAction(
"emit", cl::desc("Select the kind of output desired"), "emit", cl::desc("Select the kind of output desired"),
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
@ -110,8 +110,8 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
// Otherwise, the input is '.mlir'. // Otherwise, the input is '.mlir'.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code EC = fileOrErr.getError()) { if (std::error_code ec = fileOrErr.getError()) {
llvm::errs() << "Could not open input file: " << EC.message() << "\n"; llvm::errs() << "Could not open input file: " << ec.message() << "\n";
return -1; return -1;
} }

View File

@ -168,7 +168,7 @@ struct DFSState {
}; };
} // namespace } // namespace
static void DFSPostorder(Operation *root, DFSState *state) { static void dfsPostorder(Operation *root, DFSState *state) {
SmallVector<Operation *> queue(1, root); SmallVector<Operation *> queue(1, root);
std::vector<Operation *> ops; std::vector<Operation *> ops;
while (!queue.empty()) { while (!queue.empty()) {
@ -200,7 +200,7 @@ mlir::topologicalSort(const SetVector<Operation *> &toSort) {
DFSState state(toSort); DFSState state(toSort);
for (auto *s : toSort) { for (auto *s : toSort) {
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
DFSPostorder(s, &state); dfsPostorder(s, &state);
} }
// Reorder and return. // Reorder and return.

View File

@ -1278,10 +1278,10 @@ bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
/// where each lists loops from outer-most to inner-most in loop nest. /// where each lists loops from outer-most to inner-most in loop nest.
unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) { unsigned mlir::getNumCommonSurroundingLoops(Operation &a, Operation &b) {
SmallVector<AffineForOp, 4> loopsA, loopsB; SmallVector<AffineForOp, 4> loopsA, loopsB;
getLoopIVs(A, &loopsA); getLoopIVs(a, &loopsA);
getLoopIVs(B, &loopsB); getLoopIVs(b, &loopsB);
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
unsigned numCommonLoops = 0; unsigned numCommonLoops = 0;

View File

@ -17,7 +17,6 @@ namespace py = pybind11;
using namespace mlir; using namespace mlir;
using namespace mlir::python; using namespace mlir::python;
using llvm::None;
using llvm::Optional; using llvm::Optional;
using llvm::SmallVector; using llvm::SmallVector;
using llvm::Twine; using llvm::Twine;
@ -510,7 +509,8 @@ public:
if (mlirTypeIsAF32(elementType)) { if (mlirTypeIsAF32(elementType)) {
// f32 // f32
return bufferInfo<float>(shapedType); return bufferInfo<float>(shapedType);
} else if (mlirTypeIsAF64(elementType)) { }
if (mlirTypeIsAF64(elementType)) {
// f64 // f64
return bufferInfo<double>(shapedType); return bufferInfo<double>(shapedType);
} else if (mlirTypeIsAF16(elementType)) { } else if (mlirTypeIsAF16(elementType)) {
@ -712,12 +712,12 @@ public:
SmallVector<MlirNamedAttribute> mlirNamedAttributes; SmallVector<MlirNamedAttribute> mlirNamedAttributes;
mlirNamedAttributes.reserve(attributes.size()); mlirNamedAttributes.reserve(attributes.size());
for (auto &it : attributes) { for (auto &it : attributes) {
auto &mlir_attr = it.second.cast<PyAttribute &>(); auto &mlirAttr = it.second.cast<PyAttribute &>();
auto name = it.first.cast<std::string>(); auto name = it.first.cast<std::string>();
mlirNamedAttributes.push_back(mlirNamedAttributeGet( mlirNamedAttributes.push_back(mlirNamedAttributeGet(
mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
toMlirStringRef(name)), toMlirStringRef(name)),
mlir_attr)); mlirAttr));
} }
MlirAttribute attr = MlirAttribute attr =
mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),

View File

@ -1267,7 +1267,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
if (segmentSpec == 1 || segmentSpec == 0) { if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element. // Unpack unary element.
try { try {
auto operandValue = py::cast<PyValue *>(std::get<0>(it.value())); auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
if (operandValue) { if (operandValue) {
operands.push_back(operandValue); operands.push_back(operandValue);
operandSegmentLengths.push_back(1); operandSegmentLengths.push_back(1);
@ -2286,10 +2286,10 @@ void mlir::python::populateIRCore(py::module &m) {
.def_property_readonly( .def_property_readonly(
"body", "body",
[](PyModule &self) { [](PyModule &self) {
PyOperationRef module_op = PyOperation::forOperation( PyOperationRef moduleOp = PyOperation::forOperation(
self.getContext(), mlirModuleGetOperation(self.get()), self.getContext(), mlirModuleGetOperation(self.get()),
self.getRef().releaseObject()); self.getRef().releaseObject());
PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
return returnBlock; return returnBlock;
}, },
"Return the block for this module") "Return the block for this module")

View File

@ -51,9 +51,8 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
} catch (py::error_already_set &e) { } catch (py::error_already_set &e) {
if (e.matches(PyExc_ModuleNotFoundError)) { if (e.matches(PyExc_ModuleNotFoundError)) {
continue; continue;
} else {
throw;
} }
throw;
} }
break; break;
} }
@ -136,11 +135,10 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
// Positive cache. // Positive cache.
rawOpViewClassMapCache[operationName] = foundIt->second; rawOpViewClassMapCache[operationName] = foundIt->second;
return foundIt->second; return foundIt->second;
} else {
// Negative cache.
rawOpViewClassMap[operationName] = py::none();
return llvm::None;
} }
// Negative cache.
rawOpViewClassMap[operationName] = py::none();
return llvm::None;
} }
} }

View File

@ -8,8 +8,6 @@
#include "PybindUtils.h" #include "PybindUtils.h"
namespace py = pybind11;
pybind11::error_already_set pybind11::error_already_set
mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) { mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
auto messageStr = message.str(); auto messageStr = message.str();

View File

@ -10,8 +10,6 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
namespace py = pybind11;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Module initialization. // Module initialization.
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------

View File

@ -818,7 +818,7 @@ void mlirSymbolTableErase(MlirSymbolTable symbolTable,
MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
MlirStringRef newSymbol, MlirStringRef newSymbol,
MlirOperation from) { MlirOperation from) {
auto cppFrom = unwrap(from); auto *cppFrom = unwrap(from);
auto *context = cppFrom->getContext(); auto *context = cppFrom->getContext();
auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context); auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context);
auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context); auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context);

View File

@ -468,10 +468,10 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
Value structPtr = Value structPtr =
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr); builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
Type int32_type = typeConverter.convertType(builder.getI32Type()); Type int32Type = typeConverter.convertType(builder.getI32Type());
Value zero = Value zero =
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type, Value three = builder.create<LLVM::ConstantOp>(loc, int32Type,
builder.getI32IntegerAttr(3)); builder.getI32IntegerAttr(3));
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy), return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
structPtr, ValueRange({zero, three})); structPtr, ValueRange({zero, three}));

View File

@ -90,8 +90,8 @@ static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
DenseMap<Value, RootOrderingCost> &costs = outer->second; DenseMap<Value, RootOrderingCost> &costs = outer->second;
Value bestSource; Value bestSource;
std::pair<unsigned, unsigned> bestCost; std::pair<unsigned, unsigned> bestCost;
auto inner = costs.begin(), inner_e = costs.end(); auto inner = costs.begin(), innerE = costs.end();
while (inner != inner_e) { while (inner != innerE) {
Value source = inner->first; Value source = inner->first;
if (cycleSet.contains(source)) { if (cycleSet.contains(source)) {
// Going-away edge => get its cost and erase it. // Going-away edge => get its cost and erase it.

View File

@ -259,8 +259,8 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
// from 0 to N with step 1. Therefore, loop induction variables are replaced // from 0 to N with step 1. Therefore, loop induction variables are replaced
// with (gpu-thread/block-id * S) + LB. // with (gpu-thread/block-id * S) + LB.
builder.setInsertionPointToStart(&launchOp.body().front()); builder.setInsertionPointToStart(&launchOp.body().front());
auto lbArgumentIt = lbs.begin(); auto *lbArgumentIt = lbs.begin();
auto stepArgumentIt = steps.begin(); auto *stepArgumentIt = steps.begin();
for (auto en : llvm::enumerate(ivs)) { for (auto en : llvm::enumerate(ivs)) {
Value id = Value id =
en.index() < numBlockDims en.index() < numBlockDims
@ -640,7 +640,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
} else if (op == launchOp.getOperation()) { } else if (op == launchOp.getOperation()) {
// Found our sentinel value. We have finished the operations from one // Found our sentinel value. We have finished the operations from one
// nesting level, pop one level back up. // nesting level, pop one level back up.
auto parent = rewriter.getInsertionPoint()->getParentOp(); auto *parent = rewriter.getInsertionPoint()->getParentOp();
rewriter.setInsertionPointAfter(parent); rewriter.setInsertionPointAfter(parent);
leftNestingScope = true; leftNestingScope = true;
seenSideeffects = false; seenSideeffects = false;

View File

@ -455,11 +455,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
ivs.reserve(parallelOp.getNumLoops()); ivs.reserve(parallelOp.getNumLoops());
bool first = true; bool first = true;
SmallVector<Value, 4> loopResults(iterArgs); SmallVector<Value, 4> loopResults(iterArgs);
for (auto loop_operands : for (auto loopOperands :
llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
parallelOp.getUpperBound(), parallelOp.getStep())) { parallelOp.getUpperBound(), parallelOp.getStep())) {
Value iv, lower, upper, step; Value iv, lower, upper, step;
std::tie(iv, lower, upper, step) = loop_operands; std::tie(iv, lower, upper, step) = loopOperands;
ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
ivs.push_back(forOp.getInductionVar()); ivs.push_back(forOp.getInductionVar());
auto iterRange = forOp.getRegionIterArgs(); auto iterRange = forOp.getRegionIterArgs();

View File

@ -1390,7 +1390,7 @@ public:
auto dstType = typeConverter.convertType(op.getType()); auto dstType = typeConverter.convertType(op.getType());
auto scalarType = dstType.cast<VectorType>().getElementType(); auto scalarType = dstType.cast<VectorType>().getElementType();
auto componentsArray = components.getValue(); auto componentsArray = components.getValue();
auto context = rewriter.getContext(); auto *context = rewriter.getContext();
auto llvmI32Type = IntegerType::get(context, 32); auto llvmI32Type = IntegerType::get(context, 32);
Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType); Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
for (unsigned i = 0; i < componentsArray.size(); i++) { for (unsigned i = 0; i < componentsArray.size(); i++) {

View File

@ -2173,16 +2173,16 @@ public:
rewriter.create<linalg::YieldOp>(loc, result); rewriter.create<linalg::YieldOp>(loc, result);
return success(); return success();
} else { }
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0); y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1); y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0); y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1); y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
if (resultElementTy.getIntOrFloatBitWidth() > 32) { if (resultElementTy.getIntOrFloatBitWidth() > 32) {
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx); dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy); dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
} }
auto unitVal = rewriter.create<arith::ConstantOp>( auto unitVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift)); loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
@ -2206,7 +2206,6 @@ public:
rewriter.create<linalg::YieldOp>(loc, result); rewriter.create<linalg::YieldOp>(loc, result);
return success(); return success();
}
} }
return failure(); return failure();

View File

@ -28,9 +28,9 @@ namespace {
struct GpuAllReduceRewriter { struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>; using AccumulatorFactory = std::function<Value(Value, Value)>;
GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_, GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
PatternRewriter &rewriter_) PatternRewriter &rewriter)
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_), : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()), loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
indexType(IndexType::get(reduceOp.getContext())), indexType(IndexType::get(reduceOp.getContext())),
int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {} int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

View File

@ -313,7 +313,7 @@ private:
// a SymbolTable by the caller. SymbolTable needs to be refactored to // a SymbolTable by the caller. SymbolTable needs to be refactored to
// prevent manual building of Ops with symbols in code using SymbolTables // prevent manual building of Ops with symbols in code using SymbolTables
// and then this needs to use the OpBuilder. // and then this needs to use the OpBuilder.
auto context = getOperation().getContext(); auto *context = getOperation().getContext();
OpBuilder builder(context); OpBuilder builder(context);
auto kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(), auto kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
kernelFunc.getName()); kernelFunc.getName());

View File

@ -266,13 +266,14 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
unsigned size = kDefaultPointerSizeBits; unsigned size = kDefaultPointerSizeBits;
unsigned abi = kDefaultPointerAlignment; unsigned abi = kDefaultPointerAlignment;
auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>(); auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>();
auto it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { const auto *it =
if (auto type = entry.getKey().dyn_cast<Type>()) { llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
return type.cast<LLVMPointerType>().getAddressSpace() == if (auto type = entry.getKey().dyn_cast<Type>()) {
newType.getAddressSpace(); return type.cast<LLVMPointerType>().getAddressSpace() ==
} newType.getAddressSpace();
return false; }
}); return false;
});
if (it == oldLayout.end()) { if (it == oldLayout.end()) {
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) { llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
if (auto type = entry.getKey().dyn_cast<Type>()) { if (auto type = entry.getKey().dyn_cast<Type>()) {
@ -440,14 +441,15 @@ LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
namespace { namespace {
enum class StructDLEntryPos { Abi = 0, Preferred = 1 }; enum class StructDLEntryPos { Abi = 0, Preferred = 1 };
} } // namespace
static Optional<unsigned> static Optional<unsigned>
getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type, getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
StructDLEntryPos pos) { StructDLEntryPos pos) {
auto currentEntry = llvm::find_if(params, [](DataLayoutEntryInterface entry) { const auto *currentEntry =
return entry.isTypeEntry(); llvm::find_if(params, [](DataLayoutEntryInterface entry) {
}); return entry.isTypeEntry();
});
if (currentEntry == params.end()) if (currentEntry == params.end())
return llvm::None; return llvm::None;
@ -509,7 +511,7 @@ bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
if (!newEntry.isTypeEntry()) if (!newEntry.isTypeEntry())
continue; continue;
auto previousEntry = const auto *previousEntry =
llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) { llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
return entry.isTypeEntry(); return entry.isTypeEntry();
}); });

View File

@ -228,6 +228,7 @@ public:
return operand; return operand;
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__add(Value lhs, Value rhs) { Value applyfn__add(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -237,6 +238,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__exp(Value x) { Value applyfn__exp(Value x) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(x)) if (isFloatingPoint(x))
@ -244,6 +246,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__log(Value x) { Value applyfn__log(Value x) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(x)) if (isFloatingPoint(x))
@ -251,6 +254,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__sub(Value lhs, Value rhs) { Value applyfn__sub(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -260,6 +264,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__mul(Value lhs, Value rhs) { Value applyfn__mul(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -269,6 +274,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__max(Value lhs, Value rhs) { Value applyfn__max(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -278,6 +284,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__max_unsigned(Value lhs, Value rhs) { Value applyfn__max_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -287,6 +294,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__min(Value lhs, Value rhs) { Value applyfn__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -296,6 +304,7 @@ public:
llvm_unreachable("unsupported non numeric type"); llvm_unreachable("unsupported non numeric type");
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
Value applyfn__min_unsigned(Value lhs, Value rhs) { Value applyfn__min_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder(); OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs)) if (isFloatingPoint(lhs))
@ -1829,12 +1838,12 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
return failure(); return failure();
// Parse input tensors. // Parse input tensors.
SmallVector<OpAsmParser::OperandType, 4> inputs, input_region_args; SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs;
SmallVector<Type, 4> inputTypes; SmallVector<Type, 4> inputTypes;
if (succeeded(parser.parseOptionalKeyword("ins"))) { if (succeeded(parser.parseOptionalKeyword("ins"))) {
llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation(); llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
if (parser.parseAssignmentListWithTypes(input_region_args, inputs, if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs,
inputTypes)) inputTypes))
return failure(); return failure();
@ -1844,12 +1853,12 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
} }
// Parse output tensors. // Parse output tensors.
SmallVector<OpAsmParser::OperandType, 4> outputs, output_region_args; SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs;
SmallVector<Type, 4> outputTypes; SmallVector<Type, 4> outputTypes;
if (succeeded(parser.parseOptionalKeyword("outs"))) { if (succeeded(parser.parseOptionalKeyword("outs"))) {
llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation(); llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
if (parser.parseAssignmentListWithTypes(output_region_args, outputs, if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs,
outputTypes)) outputTypes))
return failure(); return failure();
@ -1905,15 +1914,15 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
// Parse the body. // Parse the body.
Region *body = result.addRegion(); Region *body = result.addRegion();
SmallVector<Type, 4> region_types(ivs.size(), builder.getIndexType()); SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
region_types.append(inputTypes); regionTypes.append(inputTypes);
region_types.append(outputTypes); regionTypes.append(outputTypes);
SmallVector<OpAsmParser::OperandType, 4> region_args(ivs); SmallVector<OpAsmParser::OperandType, 4> regionArgs(ivs);
region_args.append(input_region_args); regionArgs.append(inputRegionArgs);
region_args.append(output_region_args); regionArgs.append(outputRegionArgs);
if (parser.parseRegion(*body, region_args, region_types)) if (parser.parseRegion(*body, regionArgs, regionTypes))
return failure(); return failure();
// Parse optional attributes. // Parse optional attributes.

View File

@ -127,7 +127,7 @@ class ConvertElementwiseToLinalgPass
: public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> { : public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
void runOnOperation() final { void runOnOperation() final {
auto func = getOperation(); auto *func = getOperation();
auto *context = &getContext(); auto *context = &getContext();
ConversionTarget target(*context); ConversionTarget target(*context);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);

View File

@ -1426,9 +1426,9 @@ namespace {
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
/// ``` /// ```
/// kw is unrolled, w is unrolled iff dilationW > 1. /// kw is unrolled, w is unrolled iff dilationW > 1.
struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> { struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
Conv1D_NWC_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW, Conv1DNwcGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
int dilationW) int dilationW)
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false), : StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
strideW(strideW), dilationW(dilationW) { strideW(strideW), dilationW(dilationW) {
// Determine whether `linalgOp` can be generated with this generator // Determine whether `linalgOp` can be generated with this generator
@ -1594,7 +1594,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
/// ``` /// ```
/// kw is always unrolled. /// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1. /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
FailureOr<Operation *> dilated_conv() { FailureOr<Operation *> dilatedConv() {
if (!valid) if (!valid)
return failure(); return failure();
@ -1730,7 +1730,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c}, /*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}})) /*resIndex*/ {n, w, c}}))
return dilated_conv(); return dilatedConv();
return failure(); return failure();
} }
@ -1752,7 +1752,7 @@ vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1; auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1; auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation()); LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
Conv1D_NWC_Generator e(b, linalgOp, stride, dilation); Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
auto res = e.generateConv(); auto res = e.generateConv();
if (succeeded(res)) if (succeeded(res))
return res; return res;

View File

@ -195,7 +195,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
// Decomposes given floating point value `arg` into a normalized fraction and // Decomposes given floating point value `arg` into a normalized fraction and
// an integral power of two (see std::frexp). Returned values have float type. // an integral power of two (see std::frexp). Returned values have float type.
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg, static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
bool is_positive = false) { bool isPositive = false) {
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type"); assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
ArrayRef<int64_t> shape = vectorShape(arg); ArrayRef<int64_t> shape = vectorShape(arg);
@ -222,7 +222,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1); Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
// Compute exponent. // Compute exponent.
Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg); Value arg0 = isPositive ? arg : builder.create<math::AbsOp>(arg);
Value biasedExponentBits = builder.create<arith::ShRUIOp>( Value biasedExponentBits = builder.create<arith::ShRUIOp>(
builder.create<arith::BitcastOp>(i32Vec, arg0), builder.create<arith::BitcastOp>(i32Vec, arg0),
bcast(i32Cst(builder, 23))); bcast(i32Cst(builder, 23)));

View File

@ -375,13 +375,13 @@ parseReductionVarList(OpAsmParser &parser,
/// Print Reduction clause /// Print Reduction clause
static void printReductionVarList(OpAsmPrinter &p, static void printReductionVarList(OpAsmPrinter &p,
Optional<ArrayAttr> reductions, Optional<ArrayAttr> reductions,
OperandRange reduction_vars) { OperandRange reductionVars) {
p << "reduction("; p << "reduction(";
for (unsigned i = 0, e = reductions->size(); i < e; ++i) { for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
if (i != 0) if (i != 0)
p << ", "; p << ", ";
p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
<< reduction_vars[i].getType(); << reductionVars[i].getType();
} }
p << ") "; p << ") ";
} }
@ -389,9 +389,9 @@ static void printReductionVarList(OpAsmPrinter &p,
/// Verifies Reduction Clause /// Verifies Reduction Clause
static LogicalResult verifyReductionVarList(Operation *op, static LogicalResult verifyReductionVarList(Operation *op,
Optional<ArrayAttr> reductions, Optional<ArrayAttr> reductions,
OperandRange reduction_vars) { OperandRange reductionVars) {
if (reduction_vars.size() != 0) { if (reductionVars.size() != 0) {
if (!reductions || reductions->size() != reduction_vars.size()) if (!reductions || reductions->size() != reductionVars.size())
return op->emitOpError() return op->emitOpError()
<< "expected as many reduction symbol references " << "expected as many reduction symbol references "
"as reduction variables"; "as reduction variables";
@ -402,7 +402,7 @@ static LogicalResult verifyReductionVarList(Operation *op,
} }
DenseSet<Value> accumulators; DenseSet<Value> accumulators;
for (auto args : llvm::zip(reduction_vars, *reductions)) { for (auto args : llvm::zip(reductionVars, *reductions)) {
Value accum = std::get<0>(args); Value accum = std::get<0>(args);
if (!accumulators.insert(accum).second) if (!accumulators.insert(accum).second)

View File

@ -271,8 +271,8 @@ bool OperationOp::hasTypeInference() {
static LogicalResult verify(PatternOp pattern) { static LogicalResult verify(PatternOp pattern) {
Region &body = pattern.body(); Region &body = pattern.body();
Operation *term = body.front().getTerminator(); Operation *term = body.front().getTerminator();
auto rewrite_op = dyn_cast<RewriteOp>(term); auto rewriteOp = dyn_cast<RewriteOp>(term);
if (!rewrite_op) { if (!rewriteOp) {
return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
.attachNote(term->getLoc()) .attachNote(term->getLoc())
.append("see terminator defined here"); .append("see terminator defined here");

View File

@ -74,9 +74,9 @@ void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
build(builder, state, range, successor); build(builder, state, range, successor);
if (initLoop) { if (initLoop) {
// Create the block and the loop variable. // Create the block and the loop variable.
auto range_type = range.getType().cast<pdl::RangeType>(); auto rangeType = range.getType().cast<pdl::RangeType>();
state.regions.front()->emplaceBlock(); state.regions.front()->emplaceBlock();
state.regions.front()->addArgument(range_type.getElementType()); state.regions.front()->addArgument(rangeType.getElementType());
} }
} }

View File

@ -104,11 +104,13 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
if (candidateType == getStorageType()) { if (candidateType == getStorageType()) {
// i.e. i32 -> quant<"uniform[i8:f32]{1.0}"> // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
return *this; return *this;
} else if (candidateType.isa<RankedTensorType>()) { }
if (candidateType.isa<RankedTensorType>()) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
return RankedTensorType::get( return RankedTensorType::get(
candidateType.cast<RankedTensorType>().getShape(), getStorageType()); candidateType.cast<RankedTensorType>().getShape(), getStorageType());
} else if (candidateType.isa<UnrankedTensorType>()) { }
if (candidateType.isa<UnrankedTensorType>()) {
// i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">> // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
return UnrankedTensorType::get(getStorageType()); return UnrankedTensorType::get(getStorageType());
} else if (candidateType.isa<VectorType>()) { } else if (candidateType.isa<VectorType>()) {
@ -124,7 +126,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
if (quantizedType.isa<QuantizedType>()) { if (quantizedType.isa<QuantizedType>()) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
return quantizedType.cast<QuantizedType>().getStorageType(); return quantizedType.cast<QuantizedType>().getStorageType();
} else if (quantizedType.isa<ShapedType>()) { }
if (quantizedType.isa<ShapedType>()) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
ShapedType sType = quantizedType.cast<ShapedType>(); ShapedType sType = quantizedType.cast<ShapedType>();
if (!sType.getElementType().isa<QuantizedType>()) { if (!sType.getElementType().isa<QuantizedType>()) {
@ -134,7 +137,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
sType.getElementType().cast<QuantizedType>().getStorageType(); sType.getElementType().cast<QuantizedType>().getStorageType();
if (quantizedType.isa<RankedTensorType>()) { if (quantizedType.isa<RankedTensorType>()) {
return RankedTensorType::get(sType.getShape(), storageType); return RankedTensorType::get(sType.getShape(), storageType);
} else if (quantizedType.isa<UnrankedTensorType>()) { }
if (quantizedType.isa<UnrankedTensorType>()) {
return UnrankedTensorType::get(storageType); return UnrankedTensorType::get(storageType);
} else if (quantizedType.isa<VectorType>()) { } else if (quantizedType.isa<VectorType>()) {
return VectorType::get(sType.getShape(), storageType); return VectorType::get(sType.getShape(), storageType);
@ -148,7 +152,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
if (candidateType == getExpressedType()) { if (candidateType == getExpressedType()) {
// i.e. f32 -> quant<"uniform[i8:f32]{1.0}"> // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
return *this; return *this;
} else if (candidateType.isa<ShapedType>()) { }
if (candidateType.isa<ShapedType>()) {
ShapedType candidateShapedType = candidateType.cast<ShapedType>(); ShapedType candidateShapedType = candidateType.cast<ShapedType>();
if (candidateShapedType.getElementType() != getExpressedType()) { if (candidateShapedType.getElementType() != getExpressedType()) {
return nullptr; return nullptr;
@ -157,7 +162,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
if (candidateType.isa<RankedTensorType>()) { if (candidateType.isa<RankedTensorType>()) {
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
return RankedTensorType::get(candidateShapedType.getShape(), *this); return RankedTensorType::get(candidateShapedType.getShape(), *this);
} else if (candidateType.isa<UnrankedTensorType>()) { }
if (candidateType.isa<UnrankedTensorType>()) {
// i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">> // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
return UnrankedTensorType::get(*this); return UnrankedTensorType::get(*this);
} else if (candidateType.isa<VectorType>()) { } else if (candidateType.isa<VectorType>()) {
@ -173,7 +179,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
if (quantizedType.isa<QuantizedType>()) { if (quantizedType.isa<QuantizedType>()) {
// i.e. quant<"uniform[i8:f32]{1.0}"> -> f32 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
return quantizedType.cast<QuantizedType>().getExpressedType(); return quantizedType.cast<QuantizedType>().getExpressedType();
} else if (quantizedType.isa<ShapedType>()) { }
if (quantizedType.isa<ShapedType>()) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
ShapedType sType = quantizedType.cast<ShapedType>(); ShapedType sType = quantizedType.cast<ShapedType>();
if (!sType.getElementType().isa<QuantizedType>()) { if (!sType.getElementType().isa<QuantizedType>()) {
@ -183,7 +190,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
sType.getElementType().cast<QuantizedType>().getExpressedType(); sType.getElementType().cast<QuantizedType>().getExpressedType();
if (quantizedType.isa<RankedTensorType>()) { if (quantizedType.isa<RankedTensorType>()) {
return RankedTensorType::get(sType.getShape(), expressedType); return RankedTensorType::get(sType.getShape(), expressedType);
} else if (quantizedType.isa<UnrankedTensorType>()) { }
if (quantizedType.isa<UnrankedTensorType>()) {
return UnrankedTensorType::get(expressedType); return UnrankedTensorType::get(expressedType);
} else if (quantizedType.isa<VectorType>()) { } else if (quantizedType.isa<VectorType>()) {
return VectorType::get(sType.getShape(), expressedType); return VectorType::get(sType.getShape(), expressedType);

View File

@ -126,7 +126,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false; bool hadFailure = false;
auto func = getFunction(); auto func = getFunction();
RewritePatternSet patterns(func.getContext()); RewritePatternSet patterns(func.getContext());
auto ctx = func.getContext(); auto *ctx = func.getContext();
patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>( patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
ctx, &hadFailure); ctx, &hadFailure);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns)); (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

View File

@ -140,10 +140,10 @@ UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
Location loc, unsigned numBits, int32_t quantizedDimension, Location loc, unsigned numBits, int32_t quantizedDimension,
ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange, ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
Type expressedType, bool isSigned) { Type expressedType, bool isSigned) {
size_t axis_size = rmins.size(); size_t axisSize = rmins.size();
if (axis_size != rmaxs.size()) { if (axisSize != rmaxs.size()) {
return (emitError(loc, "mismatched per-axis min and max size: ") return (emitError(loc, "mismatched per-axis min and max size: ")
<< axis_size << " vs. " << rmaxs.size(), << axisSize << " vs. " << rmaxs.size(),
nullptr); nullptr);
} }
@ -159,9 +159,9 @@ UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
SmallVector<double, 4> scales; SmallVector<double, 4> scales;
SmallVector<int64_t, 4> zeroPoints; SmallVector<int64_t, 4> zeroPoints;
scales.reserve(axis_size); scales.reserve(axisSize);
zeroPoints.reserve(axis_size); zeroPoints.reserve(axisSize);
for (size_t axis = 0; axis != axis_size; ++axis) { for (size_t axis = 0; axis != axisSize; ++axis) {
double rmin = rmins[axis]; double rmin = rmins[axis];
double rmax = rmaxs[axis]; double rmax = rmaxs[axis];
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) { if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {

View File

@ -106,17 +106,17 @@ Attribute mlir::quant::quantizeAttrUniform(
realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter); realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
outConvertedType = converted.getType(); outConvertedType = converted.getType();
return converted; return converted;
} else if (realValue.isa<SparseElementsAttr>()) { }
if (realValue.isa<SparseElementsAttr>()) {
// Sparse tensor or vector constant. // Sparse tensor or vector constant.
auto converted = convertSparseElementsAttr( auto converted = convertSparseElementsAttr(
realValue.cast<SparseElementsAttr>(), quantizedElementType, converter); realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
outConvertedType = converted.getType(); outConvertedType = converted.getType();
return converted; return converted;
} else {
// Nothing else matched: try to convert a primitive.
return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
outConvertedType);
} }
// Nothing else matched: try to convert a primitive.
return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
outConvertedType);
} }
/// Convert an attribute from a type based on /// Convert an attribute from a type based on
@ -132,9 +132,9 @@ Attribute mlir::quant::quantizeAttr(Attribute realValue,
UniformQuantizedValueConverter converter(uniformQuantized); UniformQuantizedValueConverter converter(uniformQuantized);
return quantizeAttrUniform(realValue, uniformQuantized, converter, return quantizeAttrUniform(realValue, uniformQuantized, converter,
outConvertedType); outConvertedType);
}
} else if (auto uniformQuantizedPerAxis = if (auto uniformQuantizedPerAxis =
quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) { quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
auto converted = converter.convert(realValue); auto converted = converter.convert(realValue);
// TODO: why we need this outConvertedType? remove it? // TODO: why we need this outConvertedType? remove it?
@ -142,7 +142,6 @@ Attribute mlir::quant::quantizeAttr(Attribute realValue,
outConvertedType = converted.getType(); outConvertedType = converted.getType();
} }
return converted; return converted;
} else {
return nullptr;
} }
return nullptr;
} }

View File

@ -74,7 +74,7 @@ static Attribute extractCompositeElement(Attribute composite,
namespace { namespace {
#include "SPIRVCanonicalization.inc" #include "SPIRVCanonicalization.inc"
} } // namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// spv.AccessChainOp // spv.AccessChainOp

View File

@ -3250,13 +3250,13 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
return success(); return success();
} }
static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) { static void print(spirv::CooperativeMatrixLoadNVOp m, OpAsmPrinter &printer) {
printer << " " << M.pointer() << ", " << M.stride() << ", " printer << " " << m.pointer() << ", " << m.stride() << ", "
<< M.columnmajor(); << m.columnmajor();
// Print optional memory access attribute. // Print optional memory access attribute.
if (auto memAccess = M.memory_access()) if (auto memAccess = m.memory_access())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
printer << " : " << M.pointer().getType() << " as " << M.getType(); printer << " : " << m.pointer().getType() << " as " << m.getType();
} }
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,

View File

@ -31,7 +31,7 @@ using namespace mlir::shape;
namespace { namespace {
#include "ShapeCanonicalization.inc" #include "ShapeCanonicalization.inc"
} } // namespace
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
return RankedTensorType::get({rank}, IndexType::get(ctx)); return RankedTensorType::get({rank}, IndexType::get(ctx));
@ -50,7 +50,8 @@ LogicalResult shape::getShapeVec(Value input,
return failure(); return failure();
shapeValues = llvm::to_vector<6>(type.getShape()); shapeValues = llvm::to_vector<6>(type.getShape());
return success(); return success();
} else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) { }
if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>()); shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
return success(); return success();
} else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) { } else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {

View File

@ -540,7 +540,8 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(), rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
condbr.getTrueOperands()); condbr.getTrueOperands());
return success(); return success();
} else if (matchPattern(condbr.getCondition(), m_Zero())) { }
if (matchPattern(condbr.getCondition(), m_Zero())) {
// False branch taken. // False branch taken.
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(), rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
condbr.getFalseOperands()); condbr.getFalseOperands());

View File

@ -152,11 +152,11 @@ struct ReifyExpandOrCollapseShapeOp
reifyResultShapes(Operation *op, OpBuilder &b, reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const { ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto reshape_op = cast<OpTy>(op); auto reshapeOp = cast<OpTy>(op);
auto result_shape = getReshapeOutputShapeFromInputShape( auto resultShape = getReshapeOutputShapeFromInputShape(
b, loc, reshape_op.src(), reshape_op.getResultType().getShape(), b, loc, reshapeOp.src(), reshapeOp.getResultType().getShape(),
reshape_op.getReassociationMaps()); reshapeOp.getReassociationMaps());
reifiedReturnShapes.push_back(getAsValues(b, loc, result_shape)); reifiedReturnShapes.push_back(getAsValues(b, loc, resultShape));
return success(); return success();
} }
}; };

View File

@ -634,7 +634,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
// ReshapeOp // ReshapeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static int64_t GetNumElements(ShapedType type) { static int64_t getNumElements(ShapedType type) {
int64_t numElements = 1; int64_t numElements = 1;
for (auto dim : type.getShape()) for (auto dim : type.getShape())
numElements *= dim; numElements *= dim;
@ -657,7 +657,7 @@ static LogicalResult verify(ReshapeOp op) {
if (resultRankedType) { if (resultRankedType) {
if (operandRankedType && resultRankedType.hasStaticShape() && if (operandRankedType && resultRankedType.hasStaticShape() &&
operandRankedType.hasStaticShape()) { operandRankedType.hasStaticShape()) {
if (GetNumElements(operandRankedType) != GetNumElements(resultRankedType)) if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
return op.emitOpError("source and destination tensor should have the " return op.emitOpError("source and destination tensor should have the "
"same number of elements"); "same number of elements");
} }

View File

@ -97,9 +97,9 @@ public:
// Traverse all `elements` and create `memref.store` ops. // Traverse all `elements` and create `memref.store` ops.
ImplicitLocOpBuilder b(loc, rewriter); ImplicitLocOpBuilder b(loc, rewriter);
auto element_it = adaptor.elements().begin(); auto elementIt = adaptor.elements().begin();
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]); SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
CreateStores(/*dim=*/0, buffer, shape, constants, element_it, indices, b); createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
rewriter.replaceOp(op, {buffer}); rewriter.replaceOp(op, {buffer});
return success(); return success();
@ -108,21 +108,21 @@ public:
private: private:
// Implements backtracking to traverse indices of the output buffer while // Implements backtracking to traverse indices of the output buffer while
// iterating over op.elements(). // iterating over op.elements().
void CreateStores(int dim, Value buffer, ArrayRef<int64_t> shape, void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
ArrayRef<Value> constants, ValueRange::iterator &element_it, ArrayRef<Value> constants, ValueRange::iterator &elementIt,
SmallVectorImpl<Value> &indices, SmallVectorImpl<Value> &indices,
ImplicitLocOpBuilder b) const { ImplicitLocOpBuilder b) const {
if (dim == static_cast<int>(shape.size()) - 1) { if (dim == static_cast<int>(shape.size()) - 1) {
for (int i = 0; i < shape.back(); ++i) { for (int i = 0; i < shape.back(); ++i) {
indices.back() = constants[i]; indices.back() = constants[i];
b.create<memref::StoreOp>(*element_it, buffer, indices); b.create<memref::StoreOp>(*elementIt, buffer, indices);
++element_it; ++elementIt;
} }
return; return;
} }
for (int i = 0; i < shape[dim]; ++i) { for (int i = 0; i < shape[dim]; ++i) {
indices[dim] = constants[i]; indices[dim] = constants[i];
CreateStores(dim + 1, buffer, shape, constants, element_it, indices, b); createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
} }
} }
}; };

View File

@ -771,8 +771,8 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
OperationState &result, OperationState &result,
Type outputType, Value input, Type outputType, Value input,
Value paddings, Value paddings,
Value pad_const) { Value padConst) {
result.addOperands({input, paddings, pad_const}); result.addOperands({input, paddings, padConst});
auto quantAttr = buildPadOpQuantizationAttr(builder, input); auto quantAttr = buildPadOpQuantizationAttr(builder, input);
if (quantAttr) if (quantAttr)
result.addAttribute("quantization_info", quantAttr); result.addAttribute("quantization_info", quantAttr);

View File

@ -33,9 +33,9 @@ static void getValuesFromIntArrayAttribute(ArrayAttr attr,
} }
template <typename TosaOp, typename... Args> template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
Args &&...args) { Args &&...args) {
auto op = rewriter.create<TosaOp>(loc, result_ty, args...); auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
InferShapedTypeOpInterface shapeInterface = InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op.getOperation()); dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
@ -57,12 +57,12 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
auto result = op->getResult(0); auto result = op->getResult(0);
auto predictedShape = returnedShapes[0]; auto predictedShape = returnedShapes[0];
auto currentKnowledge = auto currentKnowledge =
mlir::tosa::ValueKnowledge::getKnowledgeFromType(result_ty); mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
// Compute the knowledge based on the inferred type. // Compute the knowledge based on the inferred type.
auto inferredKnowledge = auto inferredKnowledge =
mlir::tosa::ValueKnowledge::getPessimisticValueState(); mlir::tosa::ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType(); inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank(); inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) { if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) { for (auto dim : predictedShape.getDims()) {
@ -73,8 +73,8 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
// Compute the new type based on the joined version. // Compute the new type based on the joined version.
auto newKnowledge = auto newKnowledge =
mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge); mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
auto new_ty = newKnowledge.getType(); auto newTy = newKnowledge.getType();
result.setType(new_ty); result.setType(newTy);
return op; return op;
} }
@ -205,19 +205,19 @@ public:
weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0; weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get( DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding); RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
Value weightPaddingVal = CreateOpAndInfer<tosa::ConstOp>( Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr); rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
if (op.quantization_info().hasValue()) { if (op.quantization_info().hasValue()) {
auto quantInfo = op.quantization_info().getValue(); auto quantInfo = op.quantization_info().getValue();
weight = CreateOpAndInfer<tosa::PadOp>( weight = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr, weightPaddingVal, nullptr,
PadOpQuantizationAttr::get(quantInfo.weight_zp(), PadOpQuantizationAttr::get(quantInfo.weight_zp(),
rewriter.getContext())); rewriter.getContext()));
} else { } else {
weight = CreateOpAndInfer<tosa::PadOp>(rewriter, loc, weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
UnrankedTensorType::get(weightETy), UnrankedTensorType::get(weightETy),
weight, weightPaddingVal); weight, weightPaddingVal);
} }
@ -231,7 +231,7 @@ public:
outputChannels, weightHeight / stride[0], outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1], stride[0], weightWidth / stride[1],
stride[1], inputChannels}; stride[1], inputChannels};
weight = CreateOpAndInfer<tosa::ReshapeOp>( weight = createOpAndInfer<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getI64ArrayAttr(weightReshapeDims0)); rewriter.getI64ArrayAttr(weightReshapeDims0));
@ -240,7 +240,7 @@ public:
loc, RankedTensorType::get({6}, rewriter.getI32Type()), loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5})); rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
weight = CreateOpAndInfer<tosa::TransposeOp>( weight = createOpAndInfer<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter, loc, UnrankedTensorType::get(weightETy), weight,
transposeWeightVal); transposeWeightVal);
@ -248,15 +248,15 @@ public:
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = { llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0], outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels}; weightWidth / stride[1], inputChannels};
weight = CreateOpAndInfer<tosa::ReshapeOp>( weight = createOpAndInfer<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getI64ArrayAttr(weightReshapeDims1)); rewriter.getI64ArrayAttr(weightReshapeDims1));
ShapedType restridedWeightTy = weight.getType().cast<ShapedType>(); ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
weight = CreateOpAndInfer<tosa::ReverseOp>( weight = createOpAndInfer<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getI64IntegerAttr(1)); rewriter.getI64IntegerAttr(1));
weight = CreateOpAndInfer<tosa::ReverseOp>( weight = createOpAndInfer<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight, rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getI64IntegerAttr(2)); rewriter.getI64IntegerAttr(2));
@ -270,18 +270,18 @@ public:
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get( DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding); RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
Value inputPaddingVal = CreateOpAndInfer<tosa::ConstOp>( Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr); rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
if (op.quantization_info().hasValue()) { if (op.quantization_info().hasValue()) {
auto quantInfo = op.quantization_info().getValue(); auto quantInfo = op.quantization_info().getValue();
input = CreateOpAndInfer<tosa::PadOp>( input = createOpAndInfer<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input, rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr, inputPaddingVal, nullptr,
PadOpQuantizationAttr::get(quantInfo.input_zp(), PadOpQuantizationAttr::get(quantInfo.input_zp(),
rewriter.getContext())); rewriter.getContext()));
} else { } else {
input = CreateOpAndInfer<tosa::PadOp>(rewriter, loc, input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
UnrankedTensorType::get(inputETy), UnrankedTensorType::get(inputETy),
input, inputPaddingVal); input, inputPaddingVal);
} }
@ -299,7 +299,7 @@ public:
// Perform the convolution using the zero bias. // Perform the convolution using the zero bias.
Value conv2d; Value conv2d;
if (op.quantization_info().hasValue()) { if (op.quantization_info().hasValue()) {
conv2d = CreateOpAndInfer<tosa::Conv2DOp>( conv2d = createOpAndInfer<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input, rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias, weight, zeroBias,
/*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}), /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
@ -308,7 +308,7 @@ public:
op.quantization_info().getValue()) op.quantization_info().getValue())
.getResult(); .getResult();
} else { } else {
conv2d = CreateOpAndInfer<tosa::Conv2DOp>( conv2d = createOpAndInfer<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input, rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias, weight, zeroBias,
/*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}), /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
@ -327,7 +327,7 @@ public:
// Factor striding out of the convolution result. // Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = { llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels}; batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
conv2d = CreateOpAndInfer<tosa::ReshapeOp>( conv2d = createOpAndInfer<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getI64ArrayAttr(convReshapeDims0)); rewriter.getI64ArrayAttr(convReshapeDims0));
@ -336,14 +336,14 @@ public:
loc, RankedTensorType::get({6}, rewriter.getI32Type()), loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5})); rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
conv2d = CreateOpAndInfer<tosa::TransposeOp>( conv2d = createOpAndInfer<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d, rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
transposeConvVal); transposeConvVal);
// Fuse striding behavior back into width / height. // Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = { llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels}; batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
conv2d = CreateOpAndInfer<tosa::ReshapeOp>( conv2d = createOpAndInfer<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getI64ArrayAttr(convReshapeDims1)); rewriter.getI64ArrayAttr(convReshapeDims1));
@ -354,14 +354,14 @@ public:
sliceBegin[1] = pad[0]; sliceBegin[1] = pad[0];
sliceBegin[2] = pad[1]; sliceBegin[2] = pad[1];
auto slice = CreateOpAndInfer<tosa::SliceOp>( auto slice = createOpAndInfer<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getI64ArrayAttr(sliceBegin), rewriter.getI64ArrayAttr(sliceBegin),
rewriter.getI64ArrayAttr(resultTy.getShape())) rewriter.getI64ArrayAttr(resultTy.getShape()))
.getResult(); .getResult();
auto addBias = auto addBias =
CreateOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias); createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
rewriter.replaceOp(op, addBias.getResult()); rewriter.replaceOp(op, addBias.getResult());

View File

@ -223,7 +223,7 @@ void propagateShapesInRegion(Region &region) {
// Check whether this use case is replaceable. We define an op as // Check whether this use case is replaceable. We define an op as
// being replaceable if it is used by a ReturnOp or a TosaOp. // being replaceable if it is used by a ReturnOp or a TosaOp.
bool replaceable = true; bool replaceable = true;
for (auto user : result.getUsers()) { for (auto *user : result.getUsers()) {
if (isa<ReturnOp>(user)) if (isa<ReturnOp>(user))
continue; continue;
if (user->getDialect()->getNamespace() == if (user->getDialect()->getNamespace() ==

View File

@ -1179,7 +1179,7 @@ struct UnrolledOuterProductGenerator
return builder.create<vector::TransposeOp>(loc, v, perm); return builder.create<vector::TransposeOp>(loc, v, perm);
} }
Value outer_prod(Value lhs, Value rhs, Value res, int reductionSize) { Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
assert(reductionSize > 0); assert(reductionSize > 0);
for (int64_t k = 0; k < reductionSize; ++k) { for (int64_t k = 0; k < reductionSize; ++k) {
Value a = builder.create<vector::ExtractOp>(loc, lhs, k); Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
@ -1199,31 +1199,31 @@ struct UnrolledOuterProductGenerator
bindDims(builder.getContext(), m, n, k); bindDims(builder.getContext(), m, n, k);
// Classical row-major matmul: Just permute the lhs. // Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}})) if (layout({{m, k}, {k, n}, {m, n}}))
return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction. // TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) { if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs); Value tlhs = t(lhs);
return outer_prod(tlhs, t(rhs), res, lhsType.getDimSize(1)); return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
} }
// No need to permute anything. // No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}})) if (layout({{k, m}, {k, n}, {m, n}}))
return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
// Just permute the rhs. // Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}})) if (layout({{k, m}, {n, k}, {m, n}}))
return outer_prod(lhs, t(rhs), res, lhsType.getDimSize(0)); return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
// Transposed output: swap RHS and LHS. // Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs. // Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}})) if (layout({{m, k}, {k, n}, {n, m}}))
return outer_prod(rhs, t(lhs), res, lhsType.getDimSize(1)); return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction. // TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) { if (layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = t(rhs); Value trhs = t(rhs);
return outer_prod(trhs, t(lhs), res, lhsType.getDimSize(1)); return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
} }
if (layout({{k, m}, {k, n}, {n, m}})) if (layout({{k, m}, {k, n}, {n, m}}))
return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
if (layout({{k, m}, {n, k}, {n, m}})) if (layout({{k, m}, {n, k}, {n, m}}))
return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
return failure(); return failure();
} }
@ -1236,16 +1236,16 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose. // Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}})) if (layout({{m, k}, {k}, {m}}))
return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go. // Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}})) if (layout({{k, m}, {k}, {m}}))
return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose. // Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}})) if (layout({{k}, {m, k}, {m}}))
return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go. // Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}})) if (layout({{k}, {k, m}, {m}}))
return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
return failure(); return failure();
} }
@ -1260,16 +1260,16 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose. // Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}})) if (layout({{m, k}, {k}, {m}}))
return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go. // Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}})) if (layout({{k, m}, {k}, {m}}))
return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose. // Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}})) if (layout({{k}, {m, k}, {m}}))
return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go. // Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}})) if (layout({{k}, {k, m}, {m}}))
return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
return failure(); return failure();
} }

View File

@ -31,8 +31,9 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
auto asmDialectAttr = auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel); LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
auto asmTp = "vblendps $0, $1, $2, {0}"; const auto *asmTp = "vblendps $0, $1, $2, {0}";
auto asmCstr = "=x,x,x"; // Careful: constraint parser is very brittle: no ws! const auto *asmCstr =
"=x,x,x"; // Careful: constraint parser is very brittle: no ws!
SmallVector<Value> asmVals{v1, v2}; SmallVector<Value> asmVals{v1, v2};
auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str(); auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
auto asmOp = b.create<LLVM::InlineAsmOp>( auto asmOp = b.create<LLVM::InlineAsmOp>(
@ -116,18 +117,18 @@ void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
"expects all types to be vector<8xf32>"); "expects all types to be vector<8xf32>");
#endif #endif
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]); Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]); Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]); Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]); Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>()); Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>()); Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>()); Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>()); Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>()); vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>()); vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>()); vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>()); vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
} }
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model. /// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
@ -140,46 +141,46 @@ void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
[&](Type t) { return t == vt; }) && [&](Type t) { return t == vt; }) &&
"expects all types to be vector<8xf32>"); "expects all types to be vector<8xf32>");
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]); Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]); Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]); Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]); Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]); Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]); Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]); Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]); Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
using inline_asm::mm256BlendPsAsm; using inline_asm::mm256BlendPsAsm;
Value sh0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 3, 2>()); Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
Value sh2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 3, 2>()); Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
Value sh4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 3, 2>()); Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
Value sh6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 3, 2>()); Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());
Value S0 = Value s0 =
mm256BlendPsAsm(ib, T0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>()); mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
Value S1 = Value s1 =
mm256BlendPsAsm(ib, T2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>()); mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
Value S2 = Value s2 =
mm256BlendPsAsm(ib, T1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>()); mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
Value S3 = Value s3 =
mm256BlendPsAsm(ib, T3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>()); mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
Value S4 = Value s4 =
mm256BlendPsAsm(ib, T4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>()); mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
Value S5 = Value s5 =
mm256BlendPsAsm(ib, T6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>()); mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
Value S6 = Value s6 =
mm256BlendPsAsm(ib, T5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>()); mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
Value S7 = Value s7 =
mm256BlendPsAsm(ib, T7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>()); mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>()); vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>()); vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>()); vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>()); vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>()); vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>()); vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>()); vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>()); vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
} }
/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and /// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and

View File

@ -463,8 +463,10 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a work around for older versions of Visual Studio. // a work around for older versions of Visual Studio.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols); extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);
// NOLINTNEXTLINE(*-identifier-naming): externally called.
void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) { void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
auto exportSymbol = [&](llvm::StringRef name, auto ptr) { auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
assert(exportSymbols.count(name) == 0 && "symbol already exists"); assert(exportSymbols.count(name) == 0 && "symbol already exists");
@ -517,6 +519,7 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId); &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
} }
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); } extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
} // namespace runtime } // namespace runtime

View File

@ -58,27 +58,27 @@ using llvm::orc::ThreadSafeModule;
using llvm::orc::TMOwningSimpleCompiler; using llvm::orc::TMOwningSimpleCompiler;
/// Wrap a string into an llvm::StringError. /// Wrap a string into an llvm::StringError.
static Error make_string_error(const Twine &message) { static Error makeStringError(const Twine &message) {
return llvm::make_error<StringError>(message.str(), return llvm::make_error<StringError>(message.str(),
llvm::inconvertibleErrorCode()); llvm::inconvertibleErrorCode());
} }
void SimpleObjectCache::notifyObjectCompiled(const Module *M, void SimpleObjectCache::notifyObjectCompiled(const Module *m,
MemoryBufferRef ObjBuffer) { MemoryBufferRef objBuffer) {
cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy( cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier()); objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
} }
std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) { std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
auto I = cachedObjects.find(M->getModuleIdentifier()); auto i = cachedObjects.find(m->getModuleIdentifier());
if (I == cachedObjects.end()) { if (i == cachedObjects.end()) {
LLVM_DEBUG(dbgs() << "No object for " << M->getModuleIdentifier() LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
<< " in cache. Compiling.\n"); << " in cache. Compiling.\n");
return nullptr; return nullptr;
} }
LLVM_DEBUG(dbgs() << "Object for " << M->getModuleIdentifier() LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
<< " loaded from cache.\n"); << " loaded from cache.\n");
return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef()); return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
} }
void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) { void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
@ -114,7 +114,8 @@ bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
// Setup the machine properties from the current architecture. // Setup the machine properties from the current architecture.
auto targetTriple = llvm::sys::getDefaultTargetTriple(); auto targetTriple = llvm::sys::getDefaultTargetTriple();
std::string errorMessage; std::string errorMessage;
auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage); const auto *target =
llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
if (!target) { if (!target) {
errs() << "NO target: " << errorMessage << "\n"; errs() << "NO target: " << errorMessage << "\n";
return true; return true;
@ -160,7 +161,7 @@ static void packFunctionArguments(Module *module) {
// Given a function `foo(<...>)`, define the interface function // Given a function `foo(<...>)`, define the interface function
// `mlir_foo(i8**)`. // `mlir_foo(i8**)`.
auto newType = llvm::FunctionType::get( auto *newType = llvm::FunctionType::get(
builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(), builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
/*isVarArg=*/false); /*isVarArg=*/false);
auto newName = makePackedFunctionName(func.getName()); auto newName = makePackedFunctionName(func.getName());
@ -170,7 +171,7 @@ static void packFunctionArguments(Module *module) {
// Extract the arguments from the type-erased argument list and cast them to // Extract the arguments from the type-erased argument list and cast them to
// the proper types. // the proper types.
auto bb = llvm::BasicBlock::Create(ctx); auto *bb = llvm::BasicBlock::Create(ctx);
bb->insertInto(interfaceFunc); bb->insertInto(interfaceFunc);
builder.SetInsertPoint(bb); builder.SetInsertPoint(bb);
llvm::Value *argList = interfaceFunc->arg_begin(); llvm::Value *argList = interfaceFunc->arg_begin();
@ -237,7 +238,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
auto llvmModule = llvmModuleBuilder ? llvmModuleBuilder(m, *ctx) auto llvmModule = llvmModuleBuilder ? llvmModuleBuilder(m, *ctx)
: translateModuleToLLVMIR(m, *ctx); : translateModuleToLLVMIR(m, *ctx);
if (!llvmModule) if (!llvmModule)
return make_string_error("could not convert to LLVM IR"); return makeStringError("could not convert to LLVM IR");
// FIXME: the triple should be passed to the translation or dialect conversion // FIXME: the triple should be passed to the translation or dialect conversion
// instead of this. Currently, the LLVM module created above has no triple // instead of this. Currently, the LLVM module created above has no triple
// associated with it. // associated with it.
@ -249,7 +250,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
// Callback to create the object layer with symbol resolution to current // Callback to create the object layer with symbol resolution to current
// process and dynamically linked libraries. // process and dynamically linked libraries.
auto objectLinkingLayerCreator = [&](ExecutionSession &session, auto objectLinkingLayerCreator = [&](ExecutionSession &session,
const Triple &TT) { const Triple &tt) {
auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>( auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
session, []() { return std::make_unique<SectionMemoryManager>(); }); session, []() { return std::make_unique<SectionMemoryManager>(); });
@ -276,7 +277,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
<< "\nError: " << mb.getError().message() << "\n"; << "\nError: " << mb.getError().message() << "\n";
continue; continue;
} }
auto &JD = session.createBareJITDylib(std::string(libPath)); auto &jd = session.createBareJITDylib(std::string(libPath));
auto loaded = DynamicLibrarySearchGenerator::Load( auto loaded = DynamicLibrarySearchGenerator::Load(
libPath.data(), dataLayout.getGlobalPrefix()); libPath.data(), dataLayout.getGlobalPrefix());
if (!loaded) { if (!loaded) {
@ -284,8 +285,8 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
<< "\n"; << "\n";
continue; continue;
} }
JD.addGenerator(std::move(*loaded)); jd.addGenerator(std::move(*loaded));
cantFail(objectLayer->add(JD, std::move(mb.get()))); cantFail(objectLayer->add(jd, std::move(mb.get())));
} }
return objectLayer; return objectLayer;
@ -293,14 +294,14 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
// Callback to inspect the cache and recompile on demand. This follows Lang's // Callback to inspect the cache and recompile on demand. This follows Lang's
// LLJITWithObjectCache example. // LLJITWithObjectCache example.
auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB) auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
-> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> { -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
if (jitCodeGenOptLevel) if (jitCodeGenOptLevel)
JTMB.setCodeGenOptLevel(jitCodeGenOptLevel.getValue()); jtmb.setCodeGenOptLevel(jitCodeGenOptLevel.getValue());
auto TM = JTMB.createTargetMachine(); auto tm = jtmb.createTargetMachine();
if (!TM) if (!tm)
return TM.takeError(); return tm.takeError();
return std::make_unique<TMOwningSimpleCompiler>(std::move(*TM), return std::make_unique<TMOwningSimpleCompiler>(std::move(*tm),
engine->cache.get()); engine->cache.get());
}; };
@ -350,13 +351,13 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
llvm::raw_string_ostream os(errorMessage); llvm::raw_string_ostream os(errorMessage);
llvm::handleAllErrors(expectedSymbol.takeError(), llvm::handleAllErrors(expectedSymbol.takeError(),
[&os](llvm::ErrorInfoBase &ei) { ei.log(os); }); [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
return make_string_error(os.str()); return makeStringError(os.str());
} }
auto rawFPtr = expectedSymbol->getAddress(); auto rawFPtr = expectedSymbol->getAddress();
auto fptr = reinterpret_cast<void *>(rawFPtr); auto *fptr = reinterpret_cast<void *>(rawFPtr);
if (!fptr) if (!fptr)
return make_string_error("looked up function is null"); return makeStringError("looked up function is null");
return fptr; return fptr;
} }

View File

@ -125,7 +125,7 @@ static OwningModuleRef parseMLIRInput(StringRef inputFilename,
return OwningModuleRef(parseSourceFile(sourceMgr, context)); return OwningModuleRef(parseSourceFile(sourceMgr, context));
} }
static inline Error make_string_error(const Twine &message) { static inline Error makeStringError(const Twine &message) {
return llvm::make_error<llvm::StringError>(message.str(), return llvm::make_error<llvm::StringError>(message.str(),
llvm::inconvertibleErrorCode()); llvm::inconvertibleErrorCode());
} }
@ -239,7 +239,7 @@ static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
CompileAndExecuteConfig config) { CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.empty()) if (!mainFunction || mainFunction.empty())
return make_string_error("entry point not found"); return makeStringError("entry point not found");
void *empty = nullptr; void *empty = nullptr;
return compileAndExecute(options, module, entryPoint, config, &empty); return compileAndExecute(options, module, entryPoint, config, &empty);
} }
@ -253,7 +253,7 @@ Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
.getReturnType() .getReturnType()
.dyn_cast<IntegerType>(); .dyn_cast<IntegerType>();
if (!resultType || resultType.getWidth() != 32) if (!resultType || resultType.getWidth() != 32)
return make_string_error("only single i32 function result supported"); return makeStringError("only single i32 function result supported");
return Error::success(); return Error::success();
} }
template <> template <>
@ -263,7 +263,7 @@ Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
.getReturnType() .getReturnType()
.dyn_cast<IntegerType>(); .dyn_cast<IntegerType>();
if (!resultType || resultType.getWidth() != 64) if (!resultType || resultType.getWidth() != 64)
return make_string_error("only single i64 function result supported"); return makeStringError("only single i64 function result supported");
return Error::success(); return Error::success();
} }
template <> template <>
@ -272,7 +272,7 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
.cast<LLVM::LLVMFunctionType>() .cast<LLVM::LLVMFunctionType>()
.getReturnType() .getReturnType()
.isa<Float32Type>()) .isa<Float32Type>())
return make_string_error("only single f32 function result supported"); return makeStringError("only single f32 function result supported");
return Error::success(); return Error::success();
} }
template <typename Type> template <typename Type>
@ -281,10 +281,10 @@ Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
CompileAndExecuteConfig config) { CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal()) if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found"); return makeStringError("entry point not found");
if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0) if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
return make_string_error("function inputs not supported"); return makeStringError("function inputs not supported");
if (Error error = checkCompatibleReturnType<Type>(mainFunction)) if (Error error = checkCompatibleReturnType<Type>(mainFunction))
return error; return error;
@ -384,7 +384,7 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
? compileAndExecuteFn(options, m.get(), ? compileAndExecuteFn(options, m.get(),
options.mainFuncName.getValue(), options.mainFuncName.getValue(),
compileAndExecuteConfig) compileAndExecuteConfig)
: make_string_error("unsupported function type"); : makeStringError("unsupported function type");
int exitCode = EXIT_SUCCESS; int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error), llvm::handleAllErrors(std::move(error),

View File

@ -16,6 +16,8 @@
#include "mlir/ExecutionEngine/RunnerUtils.h" #include "mlir/ExecutionEngine/RunnerUtils.h"
#include <chrono> #include <chrono>
// NOLINTBEGIN(*-identifier-naming)
extern "C" void extern "C" void
_mlir_ciface_print_memref_shape_i8(UnrankedMemRefType<int8_t> *M) { _mlir_ciface_print_memref_shape_i8(UnrankedMemRefType<int8_t> *M) {
std::cout << "Unranked Memref "; std::cout << "Unranked Memref ";
@ -163,3 +165,5 @@ extern "C" int64_t verifyMemRefF64(int64_t rank, void *actualPtr,
UnrankedMemRefType<double> expectedDesc = {rank, expectedPtr}; UnrankedMemRefType<double> expectedDesc = {rank, expectedPtr};
return _mlir_ciface_verifyMemRefF64(&actualDesc, &expectedDesc); return _mlir_ciface_verifyMemRefF64(&actualDesc, &expectedDesc);
} }
// NOLINTEND(*-identifier-naming)

View File

@ -209,7 +209,7 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
SmallVector<AffineExpr, 4> affExprs; SmallVector<AffineExpr, 4> affExprs;
for (auto index : permutation) for (auto index : permutation)
affExprs.push_back(getAffineDimExpr(index, context)); affExprs.push_back(getAffineDimExpr(index, context));
auto m = std::max_element(permutation.begin(), permutation.end()); const auto *m = std::max_element(permutation.begin(), permutation.end());
auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context); auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
assert(permutationMap.isPermutation() && "Invalid permutation vector"); assert(permutationMap.isPermutation() && "Invalid permutation vector");
return permutationMap; return permutationMap;

View File

@ -1105,7 +1105,7 @@ void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
// Find the correct index using a binary search, as the groups are ordered. // Find the correct index using a binary search, as the groups are ordered.
ArrayRef<int> resultGroups = resultGroupIt->second; ArrayRef<int> resultGroups = resultGroupIt->second;
auto it = llvm::upper_bound(resultGroups, resultNo); const auto *it = llvm::upper_bound(resultGroups, resultNo);
int groupResultNo = 0, groupSize = 0; int groupResultNo = 0, groupSize = 0;
// If there are no smaller elements, the last result group is the lookup. // If there are no smaller elements, the last result group is the lookup.
@ -1240,8 +1240,8 @@ public:
raw_ostream &getStream() { return os; } raw_ostream &getStream() { return os; }
template <typename Container, typename UnaryFunctor> template <typename Container, typename UnaryFunctor>
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
llvm::interleaveComma(c, os, each_fn); llvm::interleaveComma(c, os, eachFn);
} }
/// This enum describes the different kinds of elision for the type of an /// This enum describes the different kinds of elision for the type of an

View File

@ -316,7 +316,7 @@ Block *Block::getUniquePredecessor() {
Block *Block::splitBlock(iterator splitBefore) { Block *Block::splitBlock(iterator splitBefore) {
// Start by creating a new basic block, and insert it immediate after this // Start by creating a new basic block, and insert it immediate after this
// one in the containing region. // one in the containing region.
auto newBB = new Block(); auto *newBB = new Block();
getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB); getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB);
// Move all of the operations from the split point to the end of the region // Move all of the operations from the split point to the end of the region

View File

@ -121,10 +121,10 @@ findDuplicateElement(ArrayRef<NamedAttribute> value) {
if (value.size() == 2) if (value.size() == 2)
return value[0].getName() == value[1].getName() ? value[0] : none; return value[0].getName() == value[1].getName() ? value[0] : none;
auto it = std::adjacent_find(value.begin(), value.end(), const auto *it = std::adjacent_find(value.begin(), value.end(),
[](NamedAttribute l, NamedAttribute r) { [](NamedAttribute l, NamedAttribute r) {
return l.getName() == r.getName(); return l.getName() == r.getName();
}); });
return it != value.end() ? *it : none; return it != value.end() ? *it : none;
} }

View File

@ -44,7 +44,6 @@ using namespace mlir;
using namespace mlir::detail; using namespace mlir::detail;
using llvm::hash_combine; using llvm::hash_combine;
using llvm::hash_combine_range;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MLIRContext CommandLine Options // MLIRContext CommandLine Options

View File

@ -349,28 +349,28 @@ void Operation::updateOrderIfNecessary() {
auto llvm::ilist_detail::SpecificNodeAccess< auto llvm::ilist_detail::SpecificNodeAccess<
typename llvm::ilist_detail::compute_node_options< typename llvm::ilist_detail::compute_node_options<
::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * { ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
return NodeAccess::getNodePtr<OptionsT>(N); return NodeAccess::getNodePtr<OptionsT>(n);
} }
auto llvm::ilist_detail::SpecificNodeAccess< auto llvm::ilist_detail::SpecificNodeAccess<
typename llvm::ilist_detail::compute_node_options< typename llvm::ilist_detail::compute_node_options<
::mlir::Operation>::type>::getNodePtr(const_pointer N) ::mlir::Operation>::type>::getNodePtr(const_pointer n)
-> const node_type * { -> const node_type * {
return NodeAccess::getNodePtr<OptionsT>(N); return NodeAccess::getNodePtr<OptionsT>(n);
} }
auto llvm::ilist_detail::SpecificNodeAccess< auto llvm::ilist_detail::SpecificNodeAccess<
typename llvm::ilist_detail::compute_node_options< typename llvm::ilist_detail::compute_node_options<
::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer { ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
return NodeAccess::getValuePtr<OptionsT>(N); return NodeAccess::getValuePtr<OptionsT>(n);
} }
auto llvm::ilist_detail::SpecificNodeAccess< auto llvm::ilist_detail::SpecificNodeAccess<
typename llvm::ilist_detail::compute_node_options< typename llvm::ilist_detail::compute_node_options<
::mlir::Operation>::type>::getValuePtr(const node_type *N) ::mlir::Operation>::type>::getValuePtr(const node_type *n)
-> const_pointer { -> const_pointer {
return NodeAccess::getValuePtr<OptionsT>(N); return NodeAccess::getValuePtr<OptionsT>(n);
} }
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) { void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
@ -378,9 +378,9 @@ void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
} }
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() { Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr)))); size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this)); iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset); return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
} }
/// This is a trait method invoked when an operation is added to a block. We /// This is a trait method invoked when an operation is added to a block. We
@ -1024,8 +1024,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
if (op->getNumRegions() > 1) if (op->getNumRegions() > 1)
return op->emitOpError("region #") return op->emitOpError("region #")
<< region.getRegionNumber() << " should have no arguments"; << region.getRegionNumber() << " should have no arguments";
else return op->emitOpError("region should have no arguments");
return op->emitOpError("region should have no arguments");
} }
} }
return success(); return success();

View File

@ -34,8 +34,8 @@ NamedAttrList::NamedAttrList(DictionaryAttr attributes)
dictionarySorted.setPointerAndInt(attributes, true); dictionarySorted.setPointerAndInt(attributes, true);
} }
NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) { NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
assign(in_start, in_end); assign(inStart, inEnd);
} }
ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; } ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
@ -66,8 +66,8 @@ void NamedAttrList::append(StringRef name, Attribute attr) {
} }
/// Replaces the attributes with new list of attributes. /// Replaces the attributes with new list of attributes.
void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) { void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
DictionaryAttr::sort(ArrayRef<NamedAttribute>{in_start, in_end}, attrs); DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
dictionarySorted.setPointerAndInt(nullptr, true); dictionarySorted.setPointerAndInt(nullptr, true);
} }

View File

@ -152,10 +152,10 @@ void Region::dropAllReferences() {
} }
Region *llvm::ilist_traits<::mlir::Block>::getParentRegion() { Region *llvm::ilist_traits<::mlir::Block>::getParentRegion() {
size_t Offset( size_t offset(
size_t(&((Region *)nullptr->*Region::getSublistAccess(nullptr)))); size_t(&((Region *)nullptr->*Region::getSublistAccess(nullptr))));
iplist<Block> *Anchor(static_cast<iplist<Block> *>(this)); iplist<Block> *anchor(static_cast<iplist<Block> *>(this));
return reinterpret_cast<Region *>(reinterpret_cast<char *>(Anchor) - Offset); return reinterpret_cast<Region *>(reinterpret_cast<char *>(anchor) - offset);
} }
/// This is a trait method invoked when a basic block is added to a region. /// This is a trait method invoked when a basic block is added to a region.

View File

@ -76,9 +76,9 @@ static bool wouldOpBeTriviallyDeadImpl(Operation *rootOp) {
// Otherwise, if the op has recursive side effects we can treat the // Otherwise, if the op has recursive side effects we can treat the
// operation itself as having no effects. // operation itself as having no effects.
} else if (hasRecursiveEffects) {
continue;
} }
if (hasRecursiveEffects)
continue;
// If there were no effect interfaces, we treat this op as conservatively // If there were no effect interfaces, we treat this op as conservatively
// having effects. // having effects.

View File

@ -525,13 +525,14 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
bool isColon = getToken().is(Token::colon); bool isColon = getToken().is(Token::colon);
if (!isArrow && !isColon) { if (!isArrow && !isColon) {
return emitError("expected '->' or ':'"); return emitError("expected '->' or ':'");
} else if (isArrow) { }
if (isArrow) {
parseToken(Token::arrow, "expected '->' or '['"); parseToken(Token::arrow, "expected '->' or '['");
map = parseAffineMapRange(numDims, numSymbols); map = parseAffineMapRange(numDims, numSymbols);
return map ? success() : failure(); return map ? success() : failure();
} else if (parseToken(Token::colon, "expected ':' or '['")) {
return failure();
} }
if (parseToken(Token::colon, "expected ':' or '['"))
return failure();
if ((set = parseIntegerSetConstraints(numDims, numSymbols))) if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
return success(); return success();

View File

@ -358,8 +358,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
PassInstrumentor *pi = am.getPassInstrumentor(); PassInstrumentor *pi = am.getPassInstrumentor();
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
pass}; pass};
auto dynamic_pipeline_callback = [&](OpPassManager &pipeline, auto dynamicPipelineCallback = [&](OpPassManager &pipeline,
Operation *root) -> LogicalResult { Operation *root) -> LogicalResult {
if (!op->isAncestor(root)) if (!op->isAncestor(root))
return root->emitOpError() return root->emitOpError()
<< "Trying to schedule a dynamic pipeline on an " << "Trying to schedule a dynamic pipeline on an "
@ -379,7 +379,7 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
verifyPasses, parentInitGeneration, verifyPasses, parentInitGeneration,
pi, &parentInfo); pi, &parentInfo);
}; };
pass->passState.emplace(op, am, dynamic_pipeline_callback); pass->passState.emplace(op, am, dynamicPipelineCallback);
// Instrument before the pass has run. // Instrument before the pass has run.
if (pi) if (pi)
@ -437,7 +437,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
const PassInstrumentation::PipelineParentInfo *parentInfo) { const PassInstrumentation::PipelineParentInfo *parentInfo) {
assert((!instrumentor || parentInfo) && assert((!instrumentor || parentInfo) &&
"expected parent info if instrumentor is provided"); "expected parent info if instrumentor is provided");
auto scope_exit = llvm::make_scope_exit([&] { auto scopeExit = llvm::make_scope_exit([&] {
// Clear out any computed operation analyses. These analyses won't be used // Clear out any computed operation analyses. These analyses won't be used
// any more in this pipeline, and this helps reduce the current working set // any more in this pipeline, and this helps reduce the current working set
// of memory. If preserving these analyses becomes important in the future // of memory. If preserving these analyses becomes important in the future
@ -460,7 +460,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
/// type, or nullptr if one does not exist. /// type, or nullptr if one does not exist.
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs, static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
StringRef name) { StringRef name) {
auto it = llvm::find_if( auto *it = llvm::find_if(
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; }); mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
return it == mgrs.end() ? nullptr : &*it; return it == mgrs.end() ? nullptr : &*it;
} }
@ -470,7 +470,7 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs, static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
StringAttr name, StringAttr name,
MLIRContext &context) { MLIRContext &context) {
auto it = llvm::find_if( auto *it = llvm::find_if(
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; }); mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
return it == mgrs.end() ? nullptr : &*it; return it == mgrs.end() ? nullptr : &*it;
} }

View File

@ -253,7 +253,7 @@ StringRef StructFieldAttr::getName() const {
} }
Attribute StructFieldAttr::getType() const { Attribute StructFieldAttr::getType() const {
auto init = def->getValueInit("type"); auto *init = def->getValueInit("type");
return Attribute(cast<llvm::DefInit>(init)); return Attribute(cast<llvm::DefInit>(init));
} }

View File

@ -38,7 +38,7 @@ std::string Dialect::getCppClassName() const {
static StringRef getAsStringOrEmpty(const llvm::Record &record, static StringRef getAsStringOrEmpty(const llvm::Record &record,
StringRef fieldName) { StringRef fieldName) {
if (auto valueInit = record.getValueInit(fieldName)) { if (auto *valueInit = record.getValueInit(fieldName)) {
if (llvm::isa<llvm::StringInit>(valueInit)) if (llvm::isa<llvm::StringInit>(valueInit))
return record.getValueAsString(fieldName); return record.getValueAsString(fieldName);
} }

View File

@ -346,10 +346,9 @@ void Operator::populateTypeInferenceInfo(
if (getArg(*mi).is<NamedAttribute *>()) { if (getArg(*mi).is<NamedAttribute *>()) {
// TODO: Handle attributes. // TODO: Handle attributes.
continue; continue;
} else {
resultTypeMapping[i].emplace_back(*mi);
found = true;
} }
resultTypeMapping[i].emplace_back(*mi);
found = true;
} }
return found; return found;
}; };

View File

@ -649,7 +649,7 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
std::vector<AppliedConstraint> ret; std::vector<AppliedConstraint> ret;
ret.reserve(listInit->size()); ret.reserve(listInit->size());
for (auto it : *listInit) { for (auto *it : *listInit) {
auto *dagInit = dyn_cast<llvm::DagInit>(it); auto *dagInit = dyn_cast<llvm::DagInit>(it);
if (!dagInit) if (!dagInit)
PrintFatalError(&def, "all elements in Pattern multi-entity " PrintFatalError(&def, "all elements in Pattern multi-entity "

View File

@ -188,7 +188,7 @@ buildPredicateTree(const Pred &root,
// Build child subtrees. // Build child subtrees.
auto combined = static_cast<const CombinedPred &>(root); auto combined = static_cast<const CombinedPred &>(root);
for (const auto *record : combined.getChildren()) { for (const auto *record : combined.getChildren()) {
auto childTree = auto *childTree =
buildPredicateTree(Pred(record), allocator, allSubstitutions); buildPredicateTree(Pred(record), allocator, allSubstitutions);
rootNode->children.push_back(childTree); rootNode->children.push_back(childTree);
} }
@ -241,7 +241,7 @@ propagateGroundTruth(PredNode *node,
for (auto &child : children) { for (auto &child : children) {
// First, simplify the child. This maintains the predicate as it was. // First, simplify the child. This maintains the predicate as it was.
auto simplifiedChild = auto *simplifiedChild =
propagateGroundTruth(child, knownTruePreds, knownFalsePreds); propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
// Just add the child if we don't know how to simplify the current node. // Just add the child if we don't know how to simplify the current node.
@ -273,8 +273,9 @@ propagateGroundTruth(PredNode *node,
node->kind = collapseKind; node->kind = collapseKind;
node->children.clear(); node->children.clear();
return node; return node;
} else if (simplifiedChild->kind == eraseKind || }
eraseList.count(simplifiedChild->predicate) != 0) { if (simplifiedChild->kind == eraseKind ||
eraseList.count(simplifiedChild->predicate) != 0) {
continue; continue;
} }
node->children.push_back(simplifiedChild); node->children.push_back(simplifiedChild);
@ -350,7 +351,7 @@ static std::string getCombinedCondition(const PredNode &root) {
std::string CombinedPred::getConditionImpl() const { std::string CombinedPred::getConditionImpl() const {
llvm::SpecificBumpPtrAllocator<PredNode> allocator; llvm::SpecificBumpPtrAllocator<PredNode> allocator;
auto predicateTree = buildPredicateTree(*this, allocator, {}); auto *predicateTree = buildPredicateTree(*this, allocator, {});
predicateTree = predicateTree =
propagateGroundTruth(predicateTree, propagateGroundTruth(predicateTree,
/*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),

View File

@ -26,7 +26,7 @@ using namespace mlir::tblgen;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Trait Trait::create(const llvm::Init *init) { Trait Trait::create(const llvm::Init *init) {
auto def = cast<llvm::DefInit>(init)->getDef(); auto *def = cast<llvm::DefInit>(init)->getDef();
if (def->isSubClassOf("PredTrait")) if (def->isSubClassOf("PredTrait"))
return Trait(Kind::Pred, def); return Trait(Kind::Pred, def);
if (def->isSubClassOf("GenInternalTrait")) if (def->isSubClassOf("GenInternalTrait"))

View File

@ -61,7 +61,7 @@ public:
LogicalResult processFunction(llvm::Function *f); LogicalResult processFunction(llvm::Function *f);
/// Imports GV as a GlobalOp, creating it if it doesn't exist. /// Imports GV as a GlobalOp, creating it if it doesn't exist.
GlobalOp processGlobal(llvm::GlobalVariable *GV); GlobalOp processGlobal(llvm::GlobalVariable *gv);
private: private:
/// Returns personality of `f` as a FlatSymbolRefAttr. /// Returns personality of `f` as a FlatSymbolRefAttr.
@ -145,7 +145,8 @@ Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
os << "llvm-imported-inst-%"; os << "llvm-imported-inst-%";
inst->printAsOperand(os, /*PrintType=*/false); inst->printAsOperand(os, /*PrintType=*/false);
return FileLineColLoc::get(context, os.str(), 0, 0); return FileLineColLoc::get(context, os.str(), 0, 0);
} else if (!loc) { }
if (!loc) {
return unknownLoc; return unknownLoc;
} }
// FIXME: Obtain the filename from DILocationInfo. // FIXME: Obtain the filename from DILocationInfo.
@ -304,47 +305,47 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
return nullptr; return nullptr;
} }
GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) { GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) {
auto it = globals.find(GV); auto it = globals.find(gv);
if (it != globals.end()) if (it != globals.end())
return it->second; return it->second;
OpBuilder b(module.getBody(), getGlobalInsertPt()); OpBuilder b(module.getBody(), getGlobalInsertPt());
Attribute valueAttr; Attribute valueAttr;
if (GV->hasInitializer()) if (gv->hasInitializer())
valueAttr = getConstantAsAttr(GV->getInitializer()); valueAttr = getConstantAsAttr(gv->getInitializer());
Type type = processType(GV->getValueType()); Type type = processType(gv->getValueType());
if (!type) if (!type)
return nullptr; return nullptr;
uint64_t alignment = 0; uint64_t alignment = 0;
llvm::MaybeAlign maybeAlign = GV->getAlign(); llvm::MaybeAlign maybeAlign = gv->getAlign();
if (maybeAlign.hasValue()) { if (maybeAlign.hasValue()) {
llvm::Align align = maybeAlign.getValue(); llvm::Align align = maybeAlign.getValue();
alignment = align.value(); alignment = align.value();
} }
GlobalOp op = GlobalOp op =
b.create<GlobalOp>(UnknownLoc::get(context), type, GV->isConstant(), b.create<GlobalOp>(UnknownLoc::get(context), type, gv->isConstant(),
convertLinkageFromLLVM(GV->getLinkage()), convertLinkageFromLLVM(gv->getLinkage()),
GV->getName(), valueAttr, alignment); gv->getName(), valueAttr, alignment);
if (GV->hasInitializer() && !valueAttr) { if (gv->hasInitializer() && !valueAttr) {
Region &r = op.getInitializerRegion(); Region &r = op.getInitializerRegion();
currentEntryBlock = b.createBlock(&r); currentEntryBlock = b.createBlock(&r);
b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
Value v = processConstant(GV->getInitializer()); Value v = processConstant(gv->getInitializer());
if (!v) if (!v)
return nullptr; return nullptr;
b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v})); b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
} }
if (GV->hasAtLeastLocalUnnamedAddr()) if (gv->hasAtLeastLocalUnnamedAddr())
op.setUnnamedAddrAttr(UnnamedAddrAttr::get( op.setUnnamedAddrAttr(UnnamedAddrAttr::get(
context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr()))); context, convertUnnamedAddrFromLLVM(gv->getUnnamedAddr())));
if (GV->hasSection()) if (gv->hasSection())
op.setSectionAttr(b.getStringAttr(GV->getSection())); op.setSectionAttr(b.getStringAttr(gv->getSection()));
return globals[GV] = op; return globals[gv] = op;
} }
Value Importer::processConstant(llvm::Constant *c) { Value Importer::processConstant(llvm::Constant *c) {
@ -366,9 +367,9 @@ Value Importer::processConstant(llvm::Constant *c) {
return nullptr; return nullptr;
return instMap[c] = bEntry.create<NullOp>(unknownLoc, type); return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
} }
if (auto *GV = dyn_cast<llvm::GlobalVariable>(c)) if (auto *gv = dyn_cast<llvm::GlobalVariable>(c))
return bEntry.create<AddressOfOp>(UnknownLoc::get(context), return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
processGlobal(GV)); processGlobal(gv));
if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) { if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
llvm::Instruction *i = ce->getAsInstruction(); llvm::Instruction *i = ce->getAsInstruction();
@ -526,8 +527,8 @@ LogicalResult
Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target, Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
SmallVectorImpl<Value> &blockArguments) { SmallVectorImpl<Value> &blockArguments) {
for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) { for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
auto *PN = cast<llvm::PHINode>(&*inst); auto *pn = cast<llvm::PHINode>(&*inst);
Value value = processValue(PN->getIncomingValueForBlock(br->getParent())); Value value = processValue(pn->getIncomingValueForBlock(br->getParent()));
if (!value) if (!value)
return failure(); return failure();
blockArguments.push_back(value); blockArguments.push_back(value);
@ -777,10 +778,10 @@ FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
// If it doesn't have a name, currently, only function pointers that are // If it doesn't have a name, currently, only function pointers that are
// bitcast to i8* are parsed. // bitcast to i8* are parsed.
if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) { if (auto *ce = dyn_cast<llvm::ConstantExpr>(pf)) {
if (ce->getOpcode() == llvm::Instruction::BitCast && if (ce->getOpcode() == llvm::Instruction::BitCast &&
ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) { ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0))) if (auto *func = dyn_cast<llvm::Function>(ce->getOperand(0)))
return SymbolRefAttr::get(b.getContext(), func->getName()); return SymbolRefAttr::get(b.getContext(), func->getName());
} }
} }

View File

@ -55,12 +55,11 @@ static llvm::Constant *createSourceLocStrFromLocation(Location loc,
unsigned lineNo = fileLoc.getLine(); unsigned lineNo = fileLoc.getLine();
unsigned colNo = fileLoc.getColumn(); unsigned colNo = fileLoc.getColumn();
return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo); return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
} else {
std::string locStr;
llvm::raw_string_ostream locOS(locStr);
locOS << loc;
return builder.getOrCreateSrcLocStr(locOS.str());
} }
std::string locStr;
llvm::raw_string_ostream locOS(locStr);
locOS << loc;
return builder.getOrCreateSrcLocStr(locOS.str());
} }
/// Create the location struct from the operation location information. /// Create the location struct from the operation location information.
@ -81,9 +80,8 @@ static llvm::Constant *createMappingInformation(Location loc,
if (auto nameLoc = loc.dyn_cast<NameLoc>()) { if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
StringRef name = nameLoc.getName(); StringRef name = nameLoc.getName();
return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name); return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
} else {
return createSourceLocStrFromLocation(loc, builder, "unknown");
} }
return createSourceLocStrFromLocation(loc, builder, "unknown");
} }
/// Return the runtime function used to lower the given operation. /// Return the runtime function used to lower the given operation.

View File

@ -861,11 +861,11 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
} }
// Convert an Atomic Ordering attribute to llvm::AtomicOrdering. // Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
llvm::AtomicOrdering convertAtomicOrdering(Optional<StringRef> AOAttr) { llvm::AtomicOrdering convertAtomicOrdering(Optional<StringRef> aoAttr) {
if (!AOAttr.hasValue()) if (!aoAttr.hasValue())
return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
return StringSwitch<llvm::AtomicOrdering>(AOAttr.getValue()) return StringSwitch<llvm::AtomicOrdering>(aoAttr.getValue())
.Case("seq_cst", llvm::AtomicOrdering::SequentiallyConsistent) .Case("seq_cst", llvm::AtomicOrdering::SequentiallyConsistent)
.Case("acq_rel", llvm::AtomicOrdering::AcquireRelease) .Case("acq_rel", llvm::AtomicOrdering::AcquireRelease)
.Case("acquire", llvm::AtomicOrdering::Acquire) .Case("acquire", llvm::AtomicOrdering::Acquire)
@ -889,7 +889,7 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation.translateLoc(opInst.getLoc(), subprogram); moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(), llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
llvm::DebugLoc(diLoc)); llvm::DebugLoc(diLoc));
llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order()); llvm::AtomicOrdering ao = convertAtomicOrdering(readOp.memory_order());
llvm::Value *address = moduleTranslation.lookupValue(readOp.address()); llvm::Value *address = moduleTranslation.lookupValue(readOp.address());
llvm::OpenMPIRBuilder::InsertPointTy currentIP = builder.saveIP(); llvm::OpenMPIRBuilder::InsertPointTy currentIP = builder.saveIP();
@ -903,9 +903,9 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
// Restore the IP and insert Atomic Read. // Restore the IP and insert Atomic Read.
builder.restoreIP(currentIP); builder.restoreIP(currentIP);
llvm::OpenMPIRBuilder::AtomicOpValue V = {v, false, false}; llvm::OpenMPIRBuilder::AtomicOpValue atomicV = {v, false, false};
llvm::OpenMPIRBuilder::AtomicOpValue X = {address, false, false}; llvm::OpenMPIRBuilder::AtomicOpValue x = {address, false, false};
builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO)); builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, x, atomicV, ao));
return success(); return success();
} }

View File

@ -29,17 +29,17 @@ using mlir::LLVM::detail::createIntrinsicCall;
// take a single int32 argument. It is likely that the interface of this // take a single int32 argument. It is likely that the interface of this
// function will change to make it more generic. // function will change to make it more generic.
static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
StringRef fn_name, int parameter) { StringRef fnName, int parameter) {
llvm::Module *module = builder.GetInsertBlock()->getModule(); llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::FunctionType *function_type = llvm::FunctionType::get( llvm::FunctionType *functionType = llvm::FunctionType::get(
llvm::Type::getInt64Ty(module->getContext()), // return type. llvm::Type::getInt64Ty(module->getContext()), // return type.
llvm::Type::getInt32Ty(module->getContext()), // parameter type. llvm::Type::getInt32Ty(module->getContext()), // parameter type.
false); // no variadic arguments. false); // no variadic arguments.
llvm::Function *fn = dyn_cast<llvm::Function>( llvm::Function *fn = dyn_cast<llvm::Function>(
module->getOrInsertFunction(fn_name, function_type).getCallee()); module->getOrInsertFunction(fnName, functionType).getCallee());
llvm::Value *fn_op0 = llvm::ConstantInt::get( llvm::Value *fnOp0 = llvm::ConstantInt::get(
llvm::Type::getInt32Ty(module->getContext()), parameter); llvm::Type::getInt32Ty(module->getContext()), parameter);
return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fn_op0)); return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
} }
namespace { namespace {

View File

@ -242,10 +242,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) { if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
elementType = arrayTy->getElementType(); elementType = arrayTy->getElementType();
numElements = arrayTy->getNumElements(); numElements = arrayTy->getNumElements();
} else if (auto fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) { } else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
elementType = fVectorTy->getElementType(); elementType = fVectorTy->getElementType();
numElements = fVectorTy->getNumElements(); numElements = fVectorTy->getNumElements();
} else if (auto sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) { } else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
elementType = sVectorTy->getElementType(); elementType = sVectorTy->getElementType();
numElements = sVectorTy->getMinNumElements(); numElements = sVectorTy->getMinNumElements();
} else { } else {

View File

@ -1525,7 +1525,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
// Handle named results. // Handle named results.
auto elementNames = tupleType.getElementNames(); auto elementNames = tupleType.getElementNames();
auto it = llvm::find(elementNames, name); const auto *it = llvm::find(elementNames, name);
if (it != elementNames.end()) if (it != elementNames.end())
return tupleType.getElementTypes()[it - elementNames.begin()]; return tupleType.getElementTypes()[it - elementNames.begin()];
} }

View File

@ -133,7 +133,7 @@ static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc,
} }
// Check the uses. // Check the uses.
auto useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) { const auto *useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
return contains(range, loc); return contains(range, loc);
}); });
if (useIt != def.uses.end()) { if (useIt != def.uses.end()) {

View File

@ -42,20 +42,20 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
MLIRContext &context) { MLIRContext &context) {
// Override the default '-h' and use the default PrintHelpMessage() which // Override the default '-h' and use the default PrintHelpMessage() which
// won't print options in categories. // won't print options in categories.
static llvm::cl::opt<bool> Help("h", llvm::cl::desc("Alias for -help"), static llvm::cl::opt<bool> help("h", llvm::cl::desc("Alias for -help"),
llvm::cl::Hidden); llvm::cl::Hidden);
static llvm::cl::OptionCategory MLIRReduceCategory("mlir-reduce options"); static llvm::cl::OptionCategory mlirReduceCategory("mlir-reduce options");
static llvm::cl::opt<std::string> inputFilename( static llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::Positional, llvm::cl::desc("<input file>"),
llvm::cl::cat(MLIRReduceCategory)); llvm::cl::cat(mlirReduceCategory));
static llvm::cl::opt<std::string> outputFilename( static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename for the reduced test case"), "o", llvm::cl::desc("Output filename for the reduced test case"),
llvm::cl::init("-"), llvm::cl::cat(MLIRReduceCategory)); llvm::cl::init("-"), llvm::cl::cat(mlirReduceCategory));
llvm::cl::HideUnrelatedOptions(MLIRReduceCategory); llvm::cl::HideUnrelatedOptions(mlirReduceCategory);
llvm::InitLLVM y(argc, argv); llvm::InitLLVM y(argc, argv);
@ -65,7 +65,7 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
llvm::cl::ParseCommandLineOptions(argc, argv, llvm::cl::ParseCommandLineOptions(argc, argv,
"MLIR test case reduction tool.\n"); "MLIR test case reduction tool.\n");
if (Help) { if (help) {
llvm::cl::PrintHelpMessage(); llvm::cl::PrintHelpMessage();
return success(); return success();
} }

View File

@ -301,14 +301,15 @@ public:
memrefEdgeCount[value]--; memrefEdgeCount[value]--;
} }
// Remove 'srcId' from 'inEdges[dstId]'. // Remove 'srcId' from 'inEdges[dstId]'.
for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
if ((*it).id == srcId && (*it).value == value) { if ((*it).id == srcId && (*it).value == value) {
inEdges[dstId].erase(it); inEdges[dstId].erase(it);
break; break;
} }
} }
// Remove 'dstId' from 'outEdges[srcId]'. // Remove 'dstId' from 'outEdges[srcId]'.
for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) { for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
++it) {
if ((*it).id == dstId && (*it).value == value) { if ((*it).id == dstId && (*it).value == value) {
outEdges[srcId].erase(it); outEdges[srcId].erase(it);
break; break;

View File

@ -85,7 +85,7 @@ LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) {
// Helper to check whether an operation is loop invariant wrt. SSA properties. // Helper to check whether an operation is loop invariant wrt. SSA properties.
auto isDefinedOutsideOfBody = [&](Value value) { auto isDefinedOutsideOfBody = [&](Value value) {
auto definingOp = value.getDefiningOp(); auto *definingOp = value.getDefiningOp();
return (definingOp && !!willBeMovedSet.count(definingOp)) || return (definingOp && !!willBeMovedSet.count(definingOp)) ||
looplike.isDefinedOutsideOfLoop(value); looplike.isDefinedOutsideOfLoop(value);
}; };

View File

@ -517,6 +517,6 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
newRegion->takeBody(oldRegion); newRegion->takeBody(oldRegion);
} }
return bb.createOperation(result); return bb.createOperation(result);
} else }
return oldOp; return oldOp;
} }

View File

@ -191,7 +191,7 @@ static void findMatchingStartFinishInsts(
// Check for dependence with outgoing DMAs. Doing this conservatively. // Check for dependence with outgoing DMAs. Doing this conservatively.
// TODO: use the dependence analysis to check for // TODO: use the dependence analysis to check for
// dependences between an incoming and outgoing DMA in the same iteration. // dependences between an incoming and outgoing DMA in the same iteration.
auto it = outgoingDmaOps.begin(); auto *it = outgoingDmaOps.begin();
for (; it != outgoingDmaOps.end(); ++it) { for (; it != outgoingDmaOps.end(); ++it) {
if (it->getDstMemRef() == dmaStartOp.getSrcMemRef()) if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
break; break;

View File

@ -168,7 +168,7 @@ LogicalResult OperationFolder::tryToFold(
if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) { if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
std::stable_partition( std::stable_partition(
op->getOpOperands().begin(), op->getOpOperands().end(), op->getOpOperands().begin(), op->getOpOperands().end(),
[&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); }); [&](OpOperand &o) { return !matchPattern(o.get(), m_Constant()); });
} }
// Check to see if any operands to the operation is constant and whether // Check to see if any operands to the operation is constant and whether

View File

@ -56,7 +56,8 @@ static bool isDependentLoadOrStoreOp(Operation *op,
if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) { if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
return values.count(loadOp.getMemRef()) > 0 && return values.count(loadOp.getMemRef()) > 0 &&
values[loadOp.getMemRef()] == true; values[loadOp.getMemRef()] == true;
} else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) { }
if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
return values.count(storeOp.getMemRef()) > 0; return values.count(storeOp.getMemRef()) > 0;
} }
return false; return false;

View File

@ -3034,7 +3034,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
auto updateRegion = auto updateRegion =
[&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
&targetRegions) { &targetRegions) {
const auto it = targetRegions.find(region->memref); const auto *const it = targetRegions.find(region->memref);
if (it == targetRegions.end()) if (it == targetRegions.end())
return false; return false;

View File

@ -67,7 +67,7 @@ struct TestAliasAnalysisPass
// Check for aliasing behavior between each of the values. // Check for aliasing behavior between each of the values.
for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it) for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
for (auto innerIt = valsToCheck.begin(); innerIt != it; ++innerIt) for (auto *innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it); printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
} }

View File

@ -52,9 +52,9 @@ struct TestMathPolynomialApproximationPass
void TestMathPolynomialApproximationPass::runOnFunction() { void TestMathPolynomialApproximationPass::runOnFunction() {
RewritePatternSet patterns(&getContext()); RewritePatternSet patterns(&getContext());
MathPolynomialApproximationOptions approx_options; MathPolynomialApproximationOptions approxOptions;
approx_options.enableAvx2 = enableAvx2; approxOptions.enableAvx2 = enableAvx2;
populateMathPolynomialApproximationPatterns(patterns, approx_options); populateMathPolynomialApproximationPatterns(patterns, approxOptions);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} }

View File

@ -689,24 +689,24 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
Region &body = *result.addRegion(); Region &body = *result.addRegion();
body.push_back(new Block); body.push_back(new Block);
Block &block = body.back(); Block &block = body.back();
Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
if (!wrapped_op) if (!wrappedOp)
return failure(); return failure();
// Create a return terminator in the inner region, pass as operand to the // Create a return terminator in the inner region, pass as operand to the
// terminator the returned values from the wrapped operation. // terminator the returned values from the wrapped operation.
SmallVector<Value, 8> return_operands(wrapped_op->getResults()); SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
OpBuilder builder(parser.getContext()); OpBuilder builder(parser.getContext());
builder.setInsertionPointToEnd(&block); builder.setInsertionPointToEnd(&block);
builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands); builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
// Get the results type for the wrapping op from the terminator operands. // Get the results type for the wrapping op from the terminator operands.
Operation &return_op = body.back().back(); Operation &returnOp = body.back().back();
result.types.append(return_op.operand_type_begin(), result.types.append(returnOp.operand_type_begin(),
return_op.operand_type_end()); returnOp.operand_type_end());
// Use the location of the wrapped op for the "test.wrapping_region" op. // Use the location of the wrapped op for the "test.wrapping_region" op.
result.location = wrapped_op->getLoc(); result.location = wrappedOp->getLoc();
return success(); return success();
} }

View File

@ -808,7 +808,7 @@ def OpFuncRef : TEST_Op<"op_funcref"> {
// That way, we will know if operations is called once or twice. // That way, we will know if operations is called once or twice.
def OpMGetNullAttr : NativeCodeCall<"Attribute()">; def OpMGetNullAttr : NativeCodeCall<"Attribute()">;
def OpMAttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">; def OpMAttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">;
def OpMVal : NativeCodeCall<"OpMTest($_builder, $0)">; def OpMVal : NativeCodeCall<"opMTest($_builder, $0)">;
def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ), def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ),
[(OpMAttributeIsNull:$optAttr)]>; [(OpMAttributeIsNull:$optAttr)]>;

Some files were not shown because too many files have changed in this diff Show More