forked from OSchip/llvm-project
Fix clang-tidy issues in mlir/ (NFC)
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D115956
This commit is contained in:
parent
3e5b1b77d5
commit
02b6fb218e
|
@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
cl::value_desc("filename"));
|
cl::value_desc("filename"));
|
||||||
namespace {
|
namespace {
|
||||||
enum Action { None, DumpAST };
|
enum Action { None, DumpAST };
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
static cl::opt<enum Action>
|
static cl::opt<enum Action>
|
||||||
emitAction("emit", cl::desc("Select the kind of output desired"),
|
emitAction("emit", cl::desc("Select the kind of output desired"),
|
||||||
|
|
|
@ -58,8 +58,8 @@ public:
|
||||||
// add them to the module.
|
// add them to the module.
|
||||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||||
|
|
||||||
for (FunctionAST &F : moduleAST) {
|
for (FunctionAST &f : moduleAST) {
|
||||||
auto func = mlirGen(F);
|
auto func = mlirGen(f);
|
||||||
if (!func)
|
if (!func)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
theModule.push_back(func);
|
theModule.push_back(func);
|
||||||
|
@ -113,16 +113,16 @@ private:
|
||||||
|
|
||||||
// This is a generic function, the return type will be inferred later.
|
// This is a generic function, the return type will be inferred later.
|
||||||
// Arguments type are uniformly unranked tensors.
|
// Arguments type are uniformly unranked tensors.
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
|
||||||
getType(VarType{}));
|
getType(VarType{}));
|
||||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
auto funcType = builder.getFunctionType(argTypes, llvm::None);
|
||||||
return mlir::FuncOp::create(location, proto.getName(), func_type);
|
return mlir::FuncOp::create(location, proto.getName(), funcType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a new function and add it to the MLIR module.
|
/// Emit a new function and add it to the MLIR module.
|
||||||
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
||||||
// Create a scope in the symbol table to hold variable declarations.
|
// Create a scope in the symbol table to hold variable declarations.
|
||||||
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
|
||||||
|
|
||||||
// Create an MLIR function for the given prototype.
|
// Create an MLIR function for the given prototype.
|
||||||
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
||||||
|
@ -371,7 +371,7 @@ private:
|
||||||
/// Future expressions will be able to reference this variable through symbol
|
/// Future expressions will be able to reference this variable through symbol
|
||||||
/// table lookup.
|
/// table lookup.
|
||||||
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
||||||
auto init = vardecl.getInitVal();
|
auto *init = vardecl.getInitVal();
|
||||||
if (!init) {
|
if (!init) {
|
||||||
emitError(loc(vardecl.loc()),
|
emitError(loc(vardecl.loc()),
|
||||||
"missing initializer in variable declaration");
|
"missing initializer in variable declaration");
|
||||||
|
@ -398,7 +398,7 @@ private:
|
||||||
|
|
||||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||||
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
|
||||||
for (auto &expr : blockAST) {
|
for (auto &expr : blockAST) {
|
||||||
// Specific handling for variable declarations, return statement, and
|
// Specific handling for variable declarations, return statement, and
|
||||||
// print. These can only appear in block list and not in nested
|
// print. These can only appear in block list and not in nested
|
||||||
|
|
|
@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum InputType { Toy, MLIR };
|
enum InputType { Toy, MLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum InputType> inputType(
|
static cl::opt<enum InputType> inputType(
|
||||||
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
||||||
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
||||||
|
@ -47,7 +47,7 @@ static cl::opt<enum InputType> inputType(
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum Action { None, DumpAST, DumpMLIR };
|
enum Action { None, DumpAST, DumpMLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum Action> emitAction(
|
static cl::opt<enum Action> emitAction(
|
||||||
"emit", cl::desc("Select the kind of output desired"),
|
"emit", cl::desc("Select the kind of output desired"),
|
||||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||||
|
@ -89,8 +89,8 @@ int dumpMLIR() {
|
||||||
// Otherwise, the input is '.mlir'.
|
// Otherwise, the input is '.mlir'.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,8 @@ public:
|
||||||
// add them to the module.
|
// add them to the module.
|
||||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||||
|
|
||||||
for (FunctionAST &F : moduleAST) {
|
for (FunctionAST &f : moduleAST) {
|
||||||
auto func = mlirGen(F);
|
auto func = mlirGen(f);
|
||||||
if (!func)
|
if (!func)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
theModule.push_back(func);
|
theModule.push_back(func);
|
||||||
|
@ -113,16 +113,16 @@ private:
|
||||||
|
|
||||||
// This is a generic function, the return type will be inferred later.
|
// This is a generic function, the return type will be inferred later.
|
||||||
// Arguments type are uniformly unranked tensors.
|
// Arguments type are uniformly unranked tensors.
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
|
||||||
getType(VarType{}));
|
getType(VarType{}));
|
||||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
auto funcType = builder.getFunctionType(argTypes, llvm::None);
|
||||||
return mlir::FuncOp::create(location, proto.getName(), func_type);
|
return mlir::FuncOp::create(location, proto.getName(), funcType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a new function and add it to the MLIR module.
|
/// Emit a new function and add it to the MLIR module.
|
||||||
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
||||||
// Create a scope in the symbol table to hold variable declarations.
|
// Create a scope in the symbol table to hold variable declarations.
|
||||||
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
|
||||||
|
|
||||||
// Create an MLIR function for the given prototype.
|
// Create an MLIR function for the given prototype.
|
||||||
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
||||||
|
@ -371,7 +371,7 @@ private:
|
||||||
/// Future expressions will be able to reference this variable through symbol
|
/// Future expressions will be able to reference this variable through symbol
|
||||||
/// table lookup.
|
/// table lookup.
|
||||||
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
||||||
auto init = vardecl.getInitVal();
|
auto *init = vardecl.getInitVal();
|
||||||
if (!init) {
|
if (!init) {
|
||||||
emitError(loc(vardecl.loc()),
|
emitError(loc(vardecl.loc()),
|
||||||
"missing initializer in variable declaration");
|
"missing initializer in variable declaration");
|
||||||
|
@ -398,7 +398,7 @@ private:
|
||||||
|
|
||||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||||
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
|
||||||
for (auto &expr : blockAST) {
|
for (auto &expr : blockAST) {
|
||||||
// Specific handling for variable declarations, return statement, and
|
// Specific handling for variable declarations, return statement, and
|
||||||
// print. These can only appear in block list and not in nested
|
// print. These can only appear in block list and not in nested
|
||||||
|
|
|
@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum InputType { Toy, MLIR };
|
enum InputType { Toy, MLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum InputType> inputType(
|
static cl::opt<enum InputType> inputType(
|
||||||
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
||||||
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
||||||
|
@ -49,7 +49,7 @@ static cl::opt<enum InputType> inputType(
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum Action { None, DumpAST, DumpMLIR };
|
enum Action { None, DumpAST, DumpMLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum Action> emitAction(
|
static cl::opt<enum Action> emitAction(
|
||||||
"emit", cl::desc("Select the kind of output desired"),
|
"emit", cl::desc("Select the kind of output desired"),
|
||||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||||
|
@ -86,8 +86,8 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
||||||
// Otherwise, the input is '.mlir'.
|
// Otherwise, the input is '.mlir'.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,8 @@ public:
|
||||||
// add them to the module.
|
// add them to the module.
|
||||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||||
|
|
||||||
for (FunctionAST &F : moduleAST) {
|
for (FunctionAST &f : moduleAST) {
|
||||||
auto func = mlirGen(F);
|
auto func = mlirGen(f);
|
||||||
if (!func)
|
if (!func)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
theModule.push_back(func);
|
theModule.push_back(func);
|
||||||
|
@ -113,16 +113,16 @@ private:
|
||||||
|
|
||||||
// This is a generic function, the return type will be inferred later.
|
// This is a generic function, the return type will be inferred later.
|
||||||
// Arguments type are uniformly unranked tensors.
|
// Arguments type are uniformly unranked tensors.
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
|
||||||
getType(VarType{}));
|
getType(VarType{}));
|
||||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
auto funcType = builder.getFunctionType(argTypes, llvm::None);
|
||||||
return mlir::FuncOp::create(location, proto.getName(), func_type);
|
return mlir::FuncOp::create(location, proto.getName(), funcType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a new function and add it to the MLIR module.
|
/// Emit a new function and add it to the MLIR module.
|
||||||
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
||||||
// Create a scope in the symbol table to hold variable declarations.
|
// Create a scope in the symbol table to hold variable declarations.
|
||||||
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
|
||||||
|
|
||||||
// Create an MLIR function for the given prototype.
|
// Create an MLIR function for the given prototype.
|
||||||
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
||||||
|
@ -375,7 +375,7 @@ private:
|
||||||
/// Future expressions will be able to reference this variable through symbol
|
/// Future expressions will be able to reference this variable through symbol
|
||||||
/// table lookup.
|
/// table lookup.
|
||||||
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
||||||
auto init = vardecl.getInitVal();
|
auto *init = vardecl.getInitVal();
|
||||||
if (!init) {
|
if (!init) {
|
||||||
emitError(loc(vardecl.loc()),
|
emitError(loc(vardecl.loc()),
|
||||||
"missing initializer in variable declaration");
|
"missing initializer in variable declaration");
|
||||||
|
@ -402,7 +402,7 @@ private:
|
||||||
|
|
||||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||||
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
|
||||||
for (auto &expr : blockAST) {
|
for (auto &expr : blockAST) {
|
||||||
// Specific handling for variable declarations, return statement, and
|
// Specific handling for variable declarations, return statement, and
|
||||||
// print. These can only appear in block list and not in nested
|
// print. These can only appear in block list and not in nested
|
||||||
|
|
|
@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum InputType { Toy, MLIR };
|
enum InputType { Toy, MLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum InputType> inputType(
|
static cl::opt<enum InputType> inputType(
|
||||||
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
||||||
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
||||||
|
@ -50,7 +50,7 @@ static cl::opt<enum InputType> inputType(
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum Action { None, DumpAST, DumpMLIR };
|
enum Action { None, DumpAST, DumpMLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum Action> emitAction(
|
static cl::opt<enum Action> emitAction(
|
||||||
"emit", cl::desc("Select the kind of output desired"),
|
"emit", cl::desc("Select the kind of output desired"),
|
||||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||||
|
@ -87,8 +87,8 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
||||||
// Otherwise, the input is '.mlir'.
|
// Otherwise, the input is '.mlir'.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,8 @@ public:
|
||||||
// add them to the module.
|
// add them to the module.
|
||||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||||
|
|
||||||
for (FunctionAST &F : moduleAST) {
|
for (FunctionAST &f : moduleAST) {
|
||||||
auto func = mlirGen(F);
|
auto func = mlirGen(f);
|
||||||
if (!func)
|
if (!func)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
theModule.push_back(func);
|
theModule.push_back(func);
|
||||||
|
@ -113,16 +113,16 @@ private:
|
||||||
|
|
||||||
// This is a generic function, the return type will be inferred later.
|
// This is a generic function, the return type will be inferred later.
|
||||||
// Arguments type are uniformly unranked tensors.
|
// Arguments type are uniformly unranked tensors.
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
|
||||||
getType(VarType{}));
|
getType(VarType{}));
|
||||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
auto funcType = builder.getFunctionType(argTypes, llvm::None);
|
||||||
return mlir::FuncOp::create(location, proto.getName(), func_type);
|
return mlir::FuncOp::create(location, proto.getName(), funcType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a new function and add it to the MLIR module.
|
/// Emit a new function and add it to the MLIR module.
|
||||||
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
||||||
// Create a scope in the symbol table to hold variable declarations.
|
// Create a scope in the symbol table to hold variable declarations.
|
||||||
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
|
||||||
|
|
||||||
// Create an MLIR function for the given prototype.
|
// Create an MLIR function for the given prototype.
|
||||||
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
||||||
|
@ -375,7 +375,7 @@ private:
|
||||||
/// Future expressions will be able to reference this variable through symbol
|
/// Future expressions will be able to reference this variable through symbol
|
||||||
/// table lookup.
|
/// table lookup.
|
||||||
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
||||||
auto init = vardecl.getInitVal();
|
auto *init = vardecl.getInitVal();
|
||||||
if (!init) {
|
if (!init) {
|
||||||
emitError(loc(vardecl.loc()),
|
emitError(loc(vardecl.loc()),
|
||||||
"missing initializer in variable declaration");
|
"missing initializer in variable declaration");
|
||||||
|
@ -402,7 +402,7 @@ private:
|
||||||
|
|
||||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||||
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
|
||||||
for (auto &expr : blockAST) {
|
for (auto &expr : blockAST) {
|
||||||
// Specific handling for variable declarations, return statement, and
|
// Specific handling for variable declarations, return statement, and
|
||||||
// print. These can only appear in block list and not in nested
|
// print. These can only appear in block list and not in nested
|
||||||
|
|
|
@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum InputType { Toy, MLIR };
|
enum InputType { Toy, MLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum InputType> inputType(
|
static cl::opt<enum InputType> inputType(
|
||||||
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
||||||
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
||||||
|
@ -52,7 +52,7 @@ static cl::opt<enum InputType> inputType(
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine };
|
enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum Action> emitAction(
|
static cl::opt<enum Action> emitAction(
|
||||||
"emit", cl::desc("Select the kind of output desired"),
|
"emit", cl::desc("Select the kind of output desired"),
|
||||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||||
|
@ -91,8 +91,8 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
|
||||||
// Otherwise, the input is '.mlir'.
|
// Otherwise, the input is '.mlir'.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,8 @@ public:
|
||||||
// add them to the module.
|
// add them to the module.
|
||||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||||
|
|
||||||
for (FunctionAST &F : moduleAST) {
|
for (FunctionAST &f : moduleAST) {
|
||||||
auto func = mlirGen(F);
|
auto func = mlirGen(f);
|
||||||
if (!func)
|
if (!func)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
theModule.push_back(func);
|
theModule.push_back(func);
|
||||||
|
@ -113,16 +113,16 @@ private:
|
||||||
|
|
||||||
// This is a generic function, the return type will be inferred later.
|
// This is a generic function, the return type will be inferred later.
|
||||||
// Arguments type are uniformly unranked tensors.
|
// Arguments type are uniformly unranked tensors.
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
llvm::SmallVector<mlir::Type, 4> argTypes(proto.getArgs().size(),
|
||||||
getType(VarType{}));
|
getType(VarType{}));
|
||||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
auto funcType = builder.getFunctionType(argTypes, llvm::None);
|
||||||
return mlir::FuncOp::create(location, proto.getName(), func_type);
|
return mlir::FuncOp::create(location, proto.getName(), funcType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a new function and add it to the MLIR module.
|
/// Emit a new function and add it to the MLIR module.
|
||||||
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
||||||
// Create a scope in the symbol table to hold variable declarations.
|
// Create a scope in the symbol table to hold variable declarations.
|
||||||
ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(symbolTable);
|
||||||
|
|
||||||
// Create an MLIR function for the given prototype.
|
// Create an MLIR function for the given prototype.
|
||||||
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
||||||
|
@ -375,7 +375,7 @@ private:
|
||||||
/// Future expressions will be able to reference this variable through symbol
|
/// Future expressions will be able to reference this variable through symbol
|
||||||
/// table lookup.
|
/// table lookup.
|
||||||
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
||||||
auto init = vardecl.getInitVal();
|
auto *init = vardecl.getInitVal();
|
||||||
if (!init) {
|
if (!init) {
|
||||||
emitError(loc(vardecl.loc()),
|
emitError(loc(vardecl.loc()),
|
||||||
"missing initializer in variable declaration");
|
"missing initializer in variable declaration");
|
||||||
|
@ -402,7 +402,7 @@ private:
|
||||||
|
|
||||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||||
ScopedHashTableScope<StringRef, mlir::Value> var_scope(symbolTable);
|
ScopedHashTableScope<StringRef, mlir::Value> varScope(symbolTable);
|
||||||
for (auto &expr : blockAST) {
|
for (auto &expr : blockAST) {
|
||||||
// Specific handling for variable declarations, return statement, and
|
// Specific handling for variable declarations, return statement, and
|
||||||
// print. These can only appear in block list and not in nested
|
// print. These can only appear in block list and not in nested
|
||||||
|
|
|
@ -118,7 +118,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum InputType { Toy, MLIR };
|
enum InputType { Toy, MLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum InputType> inputType(
|
static cl::opt<enum InputType> inputType(
|
||||||
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
||||||
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
||||||
|
@ -66,7 +66,7 @@ enum Action {
|
||||||
DumpLLVMIR,
|
DumpLLVMIR,
|
||||||
RunJIT
|
RunJIT
|
||||||
};
|
};
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum Action> emitAction(
|
static cl::opt<enum Action> emitAction(
|
||||||
"emit", cl::desc("Select the kind of output desired"),
|
"emit", cl::desc("Select the kind of output desired"),
|
||||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||||
|
@ -110,8 +110,8 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||||
// Otherwise, the input is '.mlir'.
|
// Otherwise, the input is '.mlir'.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -169,14 +169,14 @@ private:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
argTypes.push_back(type);
|
argTypes.push_back(type);
|
||||||
}
|
}
|
||||||
auto func_type = builder.getFunctionType(argTypes, llvm::None);
|
auto funcType = builder.getFunctionType(argTypes, llvm::None);
|
||||||
return mlir::FuncOp::create(location, proto.getName(), func_type);
|
return mlir::FuncOp::create(location, proto.getName(), funcType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emit a new function and add it to the MLIR module.
|
/// Emit a new function and add it to the MLIR module.
|
||||||
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
mlir::FuncOp mlirGen(FunctionAST &funcAST) {
|
||||||
// Create a scope in the symbol table to hold variable declarations.
|
// Create a scope in the symbol table to hold variable declarations.
|
||||||
SymbolTableScopeT var_scope(symbolTable);
|
SymbolTableScopeT varScope(symbolTable);
|
||||||
|
|
||||||
// Create an MLIR function for the given prototype.
|
// Create an MLIR function for the given prototype.
|
||||||
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
mlir::FuncOp function(mlirGen(*funcAST.getProto()));
|
||||||
|
@ -286,7 +286,7 @@ private:
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
||||||
auto structVars = structAST->getVariables();
|
auto structVars = structAST->getVariables();
|
||||||
auto it = llvm::find_if(structVars, [&](auto &var) {
|
const auto *it = llvm::find_if(structVars, [&](auto &var) {
|
||||||
return var->getName() == name->getName();
|
return var->getName() == name->getName();
|
||||||
});
|
});
|
||||||
if (it == structVars.end())
|
if (it == structVars.end())
|
||||||
|
@ -569,7 +569,7 @@ private:
|
||||||
/// Future expressions will be able to reference this variable through symbol
|
/// Future expressions will be able to reference this variable through symbol
|
||||||
/// table lookup.
|
/// table lookup.
|
||||||
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
mlir::Value mlirGen(VarDeclExprAST &vardecl) {
|
||||||
auto init = vardecl.getInitVal();
|
auto *init = vardecl.getInitVal();
|
||||||
if (!init) {
|
if (!init) {
|
||||||
emitError(loc(vardecl.loc()),
|
emitError(loc(vardecl.loc()),
|
||||||
"missing initializer in variable declaration");
|
"missing initializer in variable declaration");
|
||||||
|
@ -612,7 +612,7 @@ private:
|
||||||
|
|
||||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||||
SymbolTableScopeT var_scope(symbolTable);
|
SymbolTableScopeT varScope(symbolTable);
|
||||||
for (auto &expr : blockAST) {
|
for (auto &expr : blockAST) {
|
||||||
// Specific handling for variable declarations, return statement, and
|
// Specific handling for variable declarations, return statement, and
|
||||||
// print. These can only appear in block list and not in nested
|
// print. These can only appear in block list and not in nested
|
||||||
|
|
|
@ -121,7 +121,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||||
void printLitHelper(ExprAST *litOrNum) {
|
void printLitHelper(ExprAST *litOrNum) {
|
||||||
// Inside a literal expression we can have either a number or another literal
|
// Inside a literal expression we can have either a number or another literal
|
||||||
if (auto num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
if (auto *num = llvm::dyn_cast<NumberExprAST>(litOrNum)) {
|
||||||
llvm::errs() << num->getValue();
|
llvm::errs() << num->getValue();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,7 @@ static cl::opt<std::string> inputFilename(cl::Positional,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum InputType { Toy, MLIR };
|
enum InputType { Toy, MLIR };
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum InputType> inputType(
|
static cl::opt<enum InputType> inputType(
|
||||||
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
"x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
||||||
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
||||||
|
@ -66,7 +66,7 @@ enum Action {
|
||||||
DumpLLVMIR,
|
DumpLLVMIR,
|
||||||
RunJIT
|
RunJIT
|
||||||
};
|
};
|
||||||
}
|
} // namespace
|
||||||
static cl::opt<enum Action> emitAction(
|
static cl::opt<enum Action> emitAction(
|
||||||
"emit", cl::desc("Select the kind of output desired"),
|
"emit", cl::desc("Select the kind of output desired"),
|
||||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||||
|
@ -110,8 +110,8 @@ int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
||||||
// Otherwise, the input is '.mlir'.
|
// Otherwise, the input is '.mlir'.
|
||||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
||||||
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
||||||
if (std::error_code EC = fileOrErr.getError()) {
|
if (std::error_code ec = fileOrErr.getError()) {
|
||||||
llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -168,7 +168,7 @@ struct DFSState {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static void DFSPostorder(Operation *root, DFSState *state) {
|
static void dfsPostorder(Operation *root, DFSState *state) {
|
||||||
SmallVector<Operation *> queue(1, root);
|
SmallVector<Operation *> queue(1, root);
|
||||||
std::vector<Operation *> ops;
|
std::vector<Operation *> ops;
|
||||||
while (!queue.empty()) {
|
while (!queue.empty()) {
|
||||||
|
@ -200,7 +200,7 @@ mlir::topologicalSort(const SetVector<Operation *> &toSort) {
|
||||||
DFSState state(toSort);
|
DFSState state(toSort);
|
||||||
for (auto *s : toSort) {
|
for (auto *s : toSort) {
|
||||||
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
|
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
|
||||||
DFSPostorder(s, &state);
|
dfsPostorder(s, &state);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reorder and return.
|
// Reorder and return.
|
||||||
|
|
|
@ -1278,10 +1278,10 @@ bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
|
||||||
|
|
||||||
/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
|
/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
|
||||||
/// where each lists loops from outer-most to inner-most in loop nest.
|
/// where each lists loops from outer-most to inner-most in loop nest.
|
||||||
unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
|
unsigned mlir::getNumCommonSurroundingLoops(Operation &a, Operation &b) {
|
||||||
SmallVector<AffineForOp, 4> loopsA, loopsB;
|
SmallVector<AffineForOp, 4> loopsA, loopsB;
|
||||||
getLoopIVs(A, &loopsA);
|
getLoopIVs(a, &loopsA);
|
||||||
getLoopIVs(B, &loopsB);
|
getLoopIVs(b, &loopsB);
|
||||||
|
|
||||||
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
|
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
|
||||||
unsigned numCommonLoops = 0;
|
unsigned numCommonLoops = 0;
|
||||||
|
|
|
@ -17,7 +17,6 @@ namespace py = pybind11;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::python;
|
using namespace mlir::python;
|
||||||
|
|
||||||
using llvm::None;
|
|
||||||
using llvm::Optional;
|
using llvm::Optional;
|
||||||
using llvm::SmallVector;
|
using llvm::SmallVector;
|
||||||
using llvm::Twine;
|
using llvm::Twine;
|
||||||
|
@ -510,7 +509,8 @@ public:
|
||||||
if (mlirTypeIsAF32(elementType)) {
|
if (mlirTypeIsAF32(elementType)) {
|
||||||
// f32
|
// f32
|
||||||
return bufferInfo<float>(shapedType);
|
return bufferInfo<float>(shapedType);
|
||||||
} else if (mlirTypeIsAF64(elementType)) {
|
}
|
||||||
|
if (mlirTypeIsAF64(elementType)) {
|
||||||
// f64
|
// f64
|
||||||
return bufferInfo<double>(shapedType);
|
return bufferInfo<double>(shapedType);
|
||||||
} else if (mlirTypeIsAF16(elementType)) {
|
} else if (mlirTypeIsAF16(elementType)) {
|
||||||
|
@ -712,12 +712,12 @@ public:
|
||||||
SmallVector<MlirNamedAttribute> mlirNamedAttributes;
|
SmallVector<MlirNamedAttribute> mlirNamedAttributes;
|
||||||
mlirNamedAttributes.reserve(attributes.size());
|
mlirNamedAttributes.reserve(attributes.size());
|
||||||
for (auto &it : attributes) {
|
for (auto &it : attributes) {
|
||||||
auto &mlir_attr = it.second.cast<PyAttribute &>();
|
auto &mlirAttr = it.second.cast<PyAttribute &>();
|
||||||
auto name = it.first.cast<std::string>();
|
auto name = it.first.cast<std::string>();
|
||||||
mlirNamedAttributes.push_back(mlirNamedAttributeGet(
|
mlirNamedAttributes.push_back(mlirNamedAttributeGet(
|
||||||
mlirIdentifierGet(mlirAttributeGetContext(mlir_attr),
|
mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
|
||||||
toMlirStringRef(name)),
|
toMlirStringRef(name)),
|
||||||
mlir_attr));
|
mlirAttr));
|
||||||
}
|
}
|
||||||
MlirAttribute attr =
|
MlirAttribute attr =
|
||||||
mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
|
mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
|
||||||
|
|
|
@ -1267,7 +1267,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
|
||||||
if (segmentSpec == 1 || segmentSpec == 0) {
|
if (segmentSpec == 1 || segmentSpec == 0) {
|
||||||
// Unpack unary element.
|
// Unpack unary element.
|
||||||
try {
|
try {
|
||||||
auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
|
auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
|
||||||
if (operandValue) {
|
if (operandValue) {
|
||||||
operands.push_back(operandValue);
|
operands.push_back(operandValue);
|
||||||
operandSegmentLengths.push_back(1);
|
operandSegmentLengths.push_back(1);
|
||||||
|
@ -2286,10 +2286,10 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"body",
|
"body",
|
||||||
[](PyModule &self) {
|
[](PyModule &self) {
|
||||||
PyOperationRef module_op = PyOperation::forOperation(
|
PyOperationRef moduleOp = PyOperation::forOperation(
|
||||||
self.getContext(), mlirModuleGetOperation(self.get()),
|
self.getContext(), mlirModuleGetOperation(self.get()),
|
||||||
self.getRef().releaseObject());
|
self.getRef().releaseObject());
|
||||||
PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
|
PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
|
||||||
return returnBlock;
|
return returnBlock;
|
||||||
},
|
},
|
||||||
"Return the block for this module")
|
"Return the block for this module")
|
||||||
|
|
|
@ -51,9 +51,8 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||||
} catch (py::error_already_set &e) {
|
} catch (py::error_already_set &e) {
|
||||||
if (e.matches(PyExc_ModuleNotFoundError)) {
|
if (e.matches(PyExc_ModuleNotFoundError)) {
|
||||||
continue;
|
continue;
|
||||||
} else {
|
|
||||||
throw;
|
|
||||||
}
|
}
|
||||||
|
throw;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -136,11 +135,10 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
|
||||||
// Positive cache.
|
// Positive cache.
|
||||||
rawOpViewClassMapCache[operationName] = foundIt->second;
|
rawOpViewClassMapCache[operationName] = foundIt->second;
|
||||||
return foundIt->second;
|
return foundIt->second;
|
||||||
} else {
|
|
||||||
// Negative cache.
|
|
||||||
rawOpViewClassMap[operationName] = py::none();
|
|
||||||
return llvm::None;
|
|
||||||
}
|
}
|
||||||
|
// Negative cache.
|
||||||
|
rawOpViewClassMap[operationName] = py::none();
|
||||||
|
return llvm::None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,8 +8,6 @@
|
||||||
|
|
||||||
#include "PybindUtils.h"
|
#include "PybindUtils.h"
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
pybind11::error_already_set
|
pybind11::error_already_set
|
||||||
mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
|
mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) {
|
||||||
auto messageStr = message.str();
|
auto messageStr = message.str();
|
||||||
|
|
|
@ -10,8 +10,6 @@
|
||||||
|
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Module initialization.
|
// Module initialization.
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|
|
@ -818,7 +818,7 @@ void mlirSymbolTableErase(MlirSymbolTable symbolTable,
|
||||||
MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
|
MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
|
||||||
MlirStringRef newSymbol,
|
MlirStringRef newSymbol,
|
||||||
MlirOperation from) {
|
MlirOperation from) {
|
||||||
auto cppFrom = unwrap(from);
|
auto *cppFrom = unwrap(from);
|
||||||
auto *context = cppFrom->getContext();
|
auto *context = cppFrom->getContext();
|
||||||
auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context);
|
auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context);
|
||||||
auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context);
|
auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context);
|
||||||
|
|
|
@ -468,10 +468,10 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
|
||||||
Value structPtr =
|
Value structPtr =
|
||||||
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
|
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
|
||||||
|
|
||||||
Type int32_type = typeConverter.convertType(builder.getI32Type());
|
Type int32Type = typeConverter.convertType(builder.getI32Type());
|
||||||
Value zero =
|
Value zero =
|
||||||
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
|
createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
|
||||||
Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
|
Value three = builder.create<LLVM::ConstantOp>(loc, int32Type,
|
||||||
builder.getI32IntegerAttr(3));
|
builder.getI32IntegerAttr(3));
|
||||||
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
|
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy),
|
||||||
structPtr, ValueRange({zero, three}));
|
structPtr, ValueRange({zero, three}));
|
||||||
|
|
|
@ -90,8 +90,8 @@ static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
|
||||||
DenseMap<Value, RootOrderingCost> &costs = outer->second;
|
DenseMap<Value, RootOrderingCost> &costs = outer->second;
|
||||||
Value bestSource;
|
Value bestSource;
|
||||||
std::pair<unsigned, unsigned> bestCost;
|
std::pair<unsigned, unsigned> bestCost;
|
||||||
auto inner = costs.begin(), inner_e = costs.end();
|
auto inner = costs.begin(), innerE = costs.end();
|
||||||
while (inner != inner_e) {
|
while (inner != innerE) {
|
||||||
Value source = inner->first;
|
Value source = inner->first;
|
||||||
if (cycleSet.contains(source)) {
|
if (cycleSet.contains(source)) {
|
||||||
// Going-away edge => get its cost and erase it.
|
// Going-away edge => get its cost and erase it.
|
||||||
|
|
|
@ -259,8 +259,8 @@ void AffineLoopToGpuConverter::createLaunch(AffineForOp rootForOp,
|
||||||
// from 0 to N with step 1. Therefore, loop induction variables are replaced
|
// from 0 to N with step 1. Therefore, loop induction variables are replaced
|
||||||
// with (gpu-thread/block-id * S) + LB.
|
// with (gpu-thread/block-id * S) + LB.
|
||||||
builder.setInsertionPointToStart(&launchOp.body().front());
|
builder.setInsertionPointToStart(&launchOp.body().front());
|
||||||
auto lbArgumentIt = lbs.begin();
|
auto *lbArgumentIt = lbs.begin();
|
||||||
auto stepArgumentIt = steps.begin();
|
auto *stepArgumentIt = steps.begin();
|
||||||
for (auto en : llvm::enumerate(ivs)) {
|
for (auto en : llvm::enumerate(ivs)) {
|
||||||
Value id =
|
Value id =
|
||||||
en.index() < numBlockDims
|
en.index() < numBlockDims
|
||||||
|
@ -640,7 +640,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||||
} else if (op == launchOp.getOperation()) {
|
} else if (op == launchOp.getOperation()) {
|
||||||
// Found our sentinel value. We have finished the operations from one
|
// Found our sentinel value. We have finished the operations from one
|
||||||
// nesting level, pop one level back up.
|
// nesting level, pop one level back up.
|
||||||
auto parent = rewriter.getInsertionPoint()->getParentOp();
|
auto *parent = rewriter.getInsertionPoint()->getParentOp();
|
||||||
rewriter.setInsertionPointAfter(parent);
|
rewriter.setInsertionPointAfter(parent);
|
||||||
leftNestingScope = true;
|
leftNestingScope = true;
|
||||||
seenSideeffects = false;
|
seenSideeffects = false;
|
||||||
|
|
|
@ -455,11 +455,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||||
ivs.reserve(parallelOp.getNumLoops());
|
ivs.reserve(parallelOp.getNumLoops());
|
||||||
bool first = true;
|
bool first = true;
|
||||||
SmallVector<Value, 4> loopResults(iterArgs);
|
SmallVector<Value, 4> loopResults(iterArgs);
|
||||||
for (auto loop_operands :
|
for (auto loopOperands :
|
||||||
llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
|
llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
|
||||||
parallelOp.getUpperBound(), parallelOp.getStep())) {
|
parallelOp.getUpperBound(), parallelOp.getStep())) {
|
||||||
Value iv, lower, upper, step;
|
Value iv, lower, upper, step;
|
||||||
std::tie(iv, lower, upper, step) = loop_operands;
|
std::tie(iv, lower, upper, step) = loopOperands;
|
||||||
ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
|
ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
|
||||||
ivs.push_back(forOp.getInductionVar());
|
ivs.push_back(forOp.getInductionVar());
|
||||||
auto iterRange = forOp.getRegionIterArgs();
|
auto iterRange = forOp.getRegionIterArgs();
|
||||||
|
|
|
@ -1390,7 +1390,7 @@ public:
|
||||||
auto dstType = typeConverter.convertType(op.getType());
|
auto dstType = typeConverter.convertType(op.getType());
|
||||||
auto scalarType = dstType.cast<VectorType>().getElementType();
|
auto scalarType = dstType.cast<VectorType>().getElementType();
|
||||||
auto componentsArray = components.getValue();
|
auto componentsArray = components.getValue();
|
||||||
auto context = rewriter.getContext();
|
auto *context = rewriter.getContext();
|
||||||
auto llvmI32Type = IntegerType::get(context, 32);
|
auto llvmI32Type = IntegerType::get(context, 32);
|
||||||
Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
|
Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
|
||||||
for (unsigned i = 0; i < componentsArray.size(); i++) {
|
for (unsigned i = 0; i < componentsArray.size(); i++) {
|
||||||
|
|
|
@ -2173,16 +2173,16 @@ public:
|
||||||
|
|
||||||
rewriter.create<linalg::YieldOp>(loc, result);
|
rewriter.create<linalg::YieldOp>(loc, result);
|
||||||
return success();
|
return success();
|
||||||
} else {
|
}
|
||||||
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
|
y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
|
||||||
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
|
y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
|
||||||
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
|
y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
|
||||||
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
|
y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
|
||||||
|
|
||||||
if (resultElementTy.getIntOrFloatBitWidth() > 32) {
|
if (resultElementTy.getIntOrFloatBitWidth() > 32) {
|
||||||
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
|
dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
|
||||||
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
|
dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto unitVal = rewriter.create<arith::ConstantOp>(
|
auto unitVal = rewriter.create<arith::ConstantOp>(
|
||||||
loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
|
loc, rewriter.getIntegerAttr(resultElementTy, 1 << shift));
|
||||||
|
@ -2206,7 +2206,6 @@ public:
|
||||||
|
|
||||||
rewriter.create<linalg::YieldOp>(loc, result);
|
rewriter.create<linalg::YieldOp>(loc, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -28,9 +28,9 @@ namespace {
|
||||||
struct GpuAllReduceRewriter {
|
struct GpuAllReduceRewriter {
|
||||||
using AccumulatorFactory = std::function<Value(Value, Value)>;
|
using AccumulatorFactory = std::function<Value(Value, Value)>;
|
||||||
|
|
||||||
GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
|
GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
|
||||||
PatternRewriter &rewriter_)
|
PatternRewriter &rewriter)
|
||||||
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
|
: funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
|
||||||
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
|
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
|
||||||
indexType(IndexType::get(reduceOp.getContext())),
|
indexType(IndexType::get(reduceOp.getContext())),
|
||||||
int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
|
int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
|
||||||
|
|
|
@ -313,7 +313,7 @@ private:
|
||||||
// a SymbolTable by the caller. SymbolTable needs to be refactored to
|
// a SymbolTable by the caller. SymbolTable needs to be refactored to
|
||||||
// prevent manual building of Ops with symbols in code using SymbolTables
|
// prevent manual building of Ops with symbols in code using SymbolTables
|
||||||
// and then this needs to use the OpBuilder.
|
// and then this needs to use the OpBuilder.
|
||||||
auto context = getOperation().getContext();
|
auto *context = getOperation().getContext();
|
||||||
OpBuilder builder(context);
|
OpBuilder builder(context);
|
||||||
auto kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
|
auto kernelModule = builder.create<gpu::GPUModuleOp>(kernelFunc.getLoc(),
|
||||||
kernelFunc.getName());
|
kernelFunc.getName());
|
||||||
|
|
|
@ -266,13 +266,14 @@ bool LLVMPointerType::areCompatible(DataLayoutEntryListRef oldLayout,
|
||||||
unsigned size = kDefaultPointerSizeBits;
|
unsigned size = kDefaultPointerSizeBits;
|
||||||
unsigned abi = kDefaultPointerAlignment;
|
unsigned abi = kDefaultPointerAlignment;
|
||||||
auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>();
|
auto newType = newEntry.getKey().get<Type>().cast<LLVMPointerType>();
|
||||||
auto it = llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
const auto *it =
|
||||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
||||||
return type.cast<LLVMPointerType>().getAddressSpace() ==
|
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
||||||
newType.getAddressSpace();
|
return type.cast<LLVMPointerType>().getAddressSpace() ==
|
||||||
}
|
newType.getAddressSpace();
|
||||||
return false;
|
}
|
||||||
});
|
return false;
|
||||||
|
});
|
||||||
if (it == oldLayout.end()) {
|
if (it == oldLayout.end()) {
|
||||||
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
llvm::find_if(oldLayout, [&](DataLayoutEntryInterface entry) {
|
||||||
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
if (auto type = entry.getKey().dyn_cast<Type>()) {
|
||||||
|
@ -440,14 +441,15 @@ LLVMStructType::getTypeSizeInBits(const DataLayout &dataLayout,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum class StructDLEntryPos { Abi = 0, Preferred = 1 };
|
enum class StructDLEntryPos { Abi = 0, Preferred = 1 };
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
static Optional<unsigned>
|
static Optional<unsigned>
|
||||||
getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
|
getStructDataLayoutEntry(DataLayoutEntryListRef params, LLVMStructType type,
|
||||||
StructDLEntryPos pos) {
|
StructDLEntryPos pos) {
|
||||||
auto currentEntry = llvm::find_if(params, [](DataLayoutEntryInterface entry) {
|
const auto *currentEntry =
|
||||||
return entry.isTypeEntry();
|
llvm::find_if(params, [](DataLayoutEntryInterface entry) {
|
||||||
});
|
return entry.isTypeEntry();
|
||||||
|
});
|
||||||
if (currentEntry == params.end())
|
if (currentEntry == params.end())
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
||||||
|
@ -509,7 +511,7 @@ bool LLVMStructType::areCompatible(DataLayoutEntryListRef oldLayout,
|
||||||
if (!newEntry.isTypeEntry())
|
if (!newEntry.isTypeEntry())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto previousEntry =
|
const auto *previousEntry =
|
||||||
llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
|
llvm::find_if(oldLayout, [](DataLayoutEntryInterface entry) {
|
||||||
return entry.isTypeEntry();
|
return entry.isTypeEntry();
|
||||||
});
|
});
|
||||||
|
|
|
@ -228,6 +228,7 @@ public:
|
||||||
return operand;
|
return operand;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__add(Value lhs, Value rhs) {
|
Value applyfn__add(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -237,6 +238,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__exp(Value x) {
|
Value applyfn__exp(Value x) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(x))
|
if (isFloatingPoint(x))
|
||||||
|
@ -244,6 +246,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__log(Value x) {
|
Value applyfn__log(Value x) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(x))
|
if (isFloatingPoint(x))
|
||||||
|
@ -251,6 +254,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__sub(Value lhs, Value rhs) {
|
Value applyfn__sub(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -260,6 +264,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__mul(Value lhs, Value rhs) {
|
Value applyfn__mul(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -269,6 +274,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__max(Value lhs, Value rhs) {
|
Value applyfn__max(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -278,6 +284,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__max_unsigned(Value lhs, Value rhs) {
|
Value applyfn__max_unsigned(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -287,6 +294,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__min(Value lhs, Value rhs) {
|
Value applyfn__min(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -296,6 +304,7 @@ public:
|
||||||
llvm_unreachable("unsupported non numeric type");
|
llvm_unreachable("unsupported non numeric type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
Value applyfn__min_unsigned(Value lhs, Value rhs) {
|
Value applyfn__min_unsigned(Value lhs, Value rhs) {
|
||||||
OpBuilder builder = getBuilder();
|
OpBuilder builder = getBuilder();
|
||||||
if (isFloatingPoint(lhs))
|
if (isFloatingPoint(lhs))
|
||||||
|
@ -1829,12 +1838,12 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Parse input tensors.
|
// Parse input tensors.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> inputs, input_region_args;
|
SmallVector<OpAsmParser::OperandType, 4> inputs, inputRegionArgs;
|
||||||
SmallVector<Type, 4> inputTypes;
|
SmallVector<Type, 4> inputTypes;
|
||||||
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
||||||
llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
|
llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
|
||||||
|
|
||||||
if (parser.parseAssignmentListWithTypes(input_region_args, inputs,
|
if (parser.parseAssignmentListWithTypes(inputRegionArgs, inputs,
|
||||||
inputTypes))
|
inputTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1844,12 +1853,12 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse output tensors.
|
// Parse output tensors.
|
||||||
SmallVector<OpAsmParser::OperandType, 4> outputs, output_region_args;
|
SmallVector<OpAsmParser::OperandType, 4> outputs, outputRegionArgs;
|
||||||
SmallVector<Type, 4> outputTypes;
|
SmallVector<Type, 4> outputTypes;
|
||||||
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
||||||
llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
|
llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
|
||||||
|
|
||||||
if (parser.parseAssignmentListWithTypes(output_region_args, outputs,
|
if (parser.parseAssignmentListWithTypes(outputRegionArgs, outputs,
|
||||||
outputTypes))
|
outputTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1905,15 +1914,15 @@ static ParseResult parseTiledLoopOp(OpAsmParser &parser,
|
||||||
// Parse the body.
|
// Parse the body.
|
||||||
Region *body = result.addRegion();
|
Region *body = result.addRegion();
|
||||||
|
|
||||||
SmallVector<Type, 4> region_types(ivs.size(), builder.getIndexType());
|
SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
|
||||||
region_types.append(inputTypes);
|
regionTypes.append(inputTypes);
|
||||||
region_types.append(outputTypes);
|
regionTypes.append(outputTypes);
|
||||||
|
|
||||||
SmallVector<OpAsmParser::OperandType, 4> region_args(ivs);
|
SmallVector<OpAsmParser::OperandType, 4> regionArgs(ivs);
|
||||||
region_args.append(input_region_args);
|
regionArgs.append(inputRegionArgs);
|
||||||
region_args.append(output_region_args);
|
regionArgs.append(outputRegionArgs);
|
||||||
|
|
||||||
if (parser.parseRegion(*body, region_args, region_types))
|
if (parser.parseRegion(*body, regionArgs, regionTypes))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Parse optional attributes.
|
// Parse optional attributes.
|
||||||
|
|
|
@ -127,7 +127,7 @@ class ConvertElementwiseToLinalgPass
|
||||||
: public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
|
: public ConvertElementwiseToLinalgBase<ConvertElementwiseToLinalgPass> {
|
||||||
|
|
||||||
void runOnOperation() final {
|
void runOnOperation() final {
|
||||||
auto func = getOperation();
|
auto *func = getOperation();
|
||||||
auto *context = &getContext();
|
auto *context = &getContext();
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
|
|
@ -1426,9 +1426,9 @@ namespace {
|
||||||
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
||||||
/// ```
|
/// ```
|
||||||
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
||||||
struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
|
struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
||||||
Conv1D_NWC_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
|
Conv1DNwcGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
|
||||||
int dilationW)
|
int dilationW)
|
||||||
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
|
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
|
||||||
strideW(strideW), dilationW(dilationW) {
|
strideW(strideW), dilationW(dilationW) {
|
||||||
// Determine whether `linalgOp` can be generated with this generator
|
// Determine whether `linalgOp` can be generated with this generator
|
||||||
|
@ -1594,7 +1594,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
|
||||||
/// ```
|
/// ```
|
||||||
/// kw is always unrolled.
|
/// kw is always unrolled.
|
||||||
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
|
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
|
||||||
FailureOr<Operation *> dilated_conv() {
|
FailureOr<Operation *> dilatedConv() {
|
||||||
if (!valid)
|
if (!valid)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -1730,7 +1730,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
|
||||||
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
||||||
/*rhsIndex*/ {kw, c},
|
/*rhsIndex*/ {kw, c},
|
||||||
/*resIndex*/ {n, w, c}}))
|
/*resIndex*/ {n, w, c}}))
|
||||||
return dilated_conv();
|
return dilatedConv();
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1752,7 +1752,7 @@ vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
|
||||||
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
|
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
|
||||||
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
|
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
|
||||||
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
|
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
|
||||||
Conv1D_NWC_Generator e(b, linalgOp, stride, dilation);
|
Conv1DNwcGenerator e(b, linalgOp, stride, dilation);
|
||||||
auto res = e.generateConv();
|
auto res = e.generateConv();
|
||||||
if (succeeded(res))
|
if (succeeded(res))
|
||||||
return res;
|
return res;
|
||||||
|
|
|
@ -195,7 +195,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
|
||||||
// Decomposes given floating point value `arg` into a normalized fraction and
|
// Decomposes given floating point value `arg` into a normalized fraction and
|
||||||
// an integral power of two (see std::frexp). Returned values have float type.
|
// an integral power of two (see std::frexp). Returned values have float type.
|
||||||
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
|
static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
|
||||||
bool is_positive = false) {
|
bool isPositive = false) {
|
||||||
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
|
assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
|
||||||
ArrayRef<int64_t> shape = vectorShape(arg);
|
ArrayRef<int64_t> shape = vectorShape(arg);
|
||||||
|
|
||||||
|
@ -222,7 +222,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
|
||||||
Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
|
Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
|
||||||
|
|
||||||
// Compute exponent.
|
// Compute exponent.
|
||||||
Value arg0 = is_positive ? arg : builder.create<math::AbsOp>(arg);
|
Value arg0 = isPositive ? arg : builder.create<math::AbsOp>(arg);
|
||||||
Value biasedExponentBits = builder.create<arith::ShRUIOp>(
|
Value biasedExponentBits = builder.create<arith::ShRUIOp>(
|
||||||
builder.create<arith::BitcastOp>(i32Vec, arg0),
|
builder.create<arith::BitcastOp>(i32Vec, arg0),
|
||||||
bcast(i32Cst(builder, 23)));
|
bcast(i32Cst(builder, 23)));
|
||||||
|
|
|
@ -375,13 +375,13 @@ parseReductionVarList(OpAsmParser &parser,
|
||||||
/// Print Reduction clause
|
/// Print Reduction clause
|
||||||
static void printReductionVarList(OpAsmPrinter &p,
|
static void printReductionVarList(OpAsmPrinter &p,
|
||||||
Optional<ArrayAttr> reductions,
|
Optional<ArrayAttr> reductions,
|
||||||
OperandRange reduction_vars) {
|
OperandRange reductionVars) {
|
||||||
p << "reduction(";
|
p << "reduction(";
|
||||||
for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
|
for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
|
||||||
if (i != 0)
|
if (i != 0)
|
||||||
p << ", ";
|
p << ", ";
|
||||||
p << (*reductions)[i] << " -> " << reduction_vars[i] << " : "
|
p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
|
||||||
<< reduction_vars[i].getType();
|
<< reductionVars[i].getType();
|
||||||
}
|
}
|
||||||
p << ") ";
|
p << ") ";
|
||||||
}
|
}
|
||||||
|
@ -389,9 +389,9 @@ static void printReductionVarList(OpAsmPrinter &p,
|
||||||
/// Verifies Reduction Clause
|
/// Verifies Reduction Clause
|
||||||
static LogicalResult verifyReductionVarList(Operation *op,
|
static LogicalResult verifyReductionVarList(Operation *op,
|
||||||
Optional<ArrayAttr> reductions,
|
Optional<ArrayAttr> reductions,
|
||||||
OperandRange reduction_vars) {
|
OperandRange reductionVars) {
|
||||||
if (reduction_vars.size() != 0) {
|
if (reductionVars.size() != 0) {
|
||||||
if (!reductions || reductions->size() != reduction_vars.size())
|
if (!reductions || reductions->size() != reductionVars.size())
|
||||||
return op->emitOpError()
|
return op->emitOpError()
|
||||||
<< "expected as many reduction symbol references "
|
<< "expected as many reduction symbol references "
|
||||||
"as reduction variables";
|
"as reduction variables";
|
||||||
|
@ -402,7 +402,7 @@ static LogicalResult verifyReductionVarList(Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseSet<Value> accumulators;
|
DenseSet<Value> accumulators;
|
||||||
for (auto args : llvm::zip(reduction_vars, *reductions)) {
|
for (auto args : llvm::zip(reductionVars, *reductions)) {
|
||||||
Value accum = std::get<0>(args);
|
Value accum = std::get<0>(args);
|
||||||
|
|
||||||
if (!accumulators.insert(accum).second)
|
if (!accumulators.insert(accum).second)
|
||||||
|
|
|
@ -271,8 +271,8 @@ bool OperationOp::hasTypeInference() {
|
||||||
static LogicalResult verify(PatternOp pattern) {
|
static LogicalResult verify(PatternOp pattern) {
|
||||||
Region &body = pattern.body();
|
Region &body = pattern.body();
|
||||||
Operation *term = body.front().getTerminator();
|
Operation *term = body.front().getTerminator();
|
||||||
auto rewrite_op = dyn_cast<RewriteOp>(term);
|
auto rewriteOp = dyn_cast<RewriteOp>(term);
|
||||||
if (!rewrite_op) {
|
if (!rewriteOp) {
|
||||||
return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
|
return pattern.emitOpError("expected body to terminate with `pdl.rewrite`")
|
||||||
.attachNote(term->getLoc())
|
.attachNote(term->getLoc())
|
||||||
.append("see terminator defined here");
|
.append("see terminator defined here");
|
||||||
|
|
|
@ -74,9 +74,9 @@ void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||||
build(builder, state, range, successor);
|
build(builder, state, range, successor);
|
||||||
if (initLoop) {
|
if (initLoop) {
|
||||||
// Create the block and the loop variable.
|
// Create the block and the loop variable.
|
||||||
auto range_type = range.getType().cast<pdl::RangeType>();
|
auto rangeType = range.getType().cast<pdl::RangeType>();
|
||||||
state.regions.front()->emplaceBlock();
|
state.regions.front()->emplaceBlock();
|
||||||
state.regions.front()->addArgument(range_type.getElementType());
|
state.regions.front()->addArgument(rangeType.getElementType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -104,11 +104,13 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
|
||||||
if (candidateType == getStorageType()) {
|
if (candidateType == getStorageType()) {
|
||||||
// i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
|
// i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
|
||||||
return *this;
|
return *this;
|
||||||
} else if (candidateType.isa<RankedTensorType>()) {
|
}
|
||||||
|
if (candidateType.isa<RankedTensorType>()) {
|
||||||
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||||
return RankedTensorType::get(
|
return RankedTensorType::get(
|
||||||
candidateType.cast<RankedTensorType>().getShape(), getStorageType());
|
candidateType.cast<RankedTensorType>().getShape(), getStorageType());
|
||||||
} else if (candidateType.isa<UnrankedTensorType>()) {
|
}
|
||||||
|
if (candidateType.isa<UnrankedTensorType>()) {
|
||||||
// i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
|
// i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
|
||||||
return UnrankedTensorType::get(getStorageType());
|
return UnrankedTensorType::get(getStorageType());
|
||||||
} else if (candidateType.isa<VectorType>()) {
|
} else if (candidateType.isa<VectorType>()) {
|
||||||
|
@ -124,7 +126,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
|
||||||
if (quantizedType.isa<QuantizedType>()) {
|
if (quantizedType.isa<QuantizedType>()) {
|
||||||
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
|
// i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
|
||||||
return quantizedType.cast<QuantizedType>().getStorageType();
|
return quantizedType.cast<QuantizedType>().getStorageType();
|
||||||
} else if (quantizedType.isa<ShapedType>()) {
|
}
|
||||||
|
if (quantizedType.isa<ShapedType>()) {
|
||||||
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||||
ShapedType sType = quantizedType.cast<ShapedType>();
|
ShapedType sType = quantizedType.cast<ShapedType>();
|
||||||
if (!sType.getElementType().isa<QuantizedType>()) {
|
if (!sType.getElementType().isa<QuantizedType>()) {
|
||||||
|
@ -134,7 +137,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
|
||||||
sType.getElementType().cast<QuantizedType>().getStorageType();
|
sType.getElementType().cast<QuantizedType>().getStorageType();
|
||||||
if (quantizedType.isa<RankedTensorType>()) {
|
if (quantizedType.isa<RankedTensorType>()) {
|
||||||
return RankedTensorType::get(sType.getShape(), storageType);
|
return RankedTensorType::get(sType.getShape(), storageType);
|
||||||
} else if (quantizedType.isa<UnrankedTensorType>()) {
|
}
|
||||||
|
if (quantizedType.isa<UnrankedTensorType>()) {
|
||||||
return UnrankedTensorType::get(storageType);
|
return UnrankedTensorType::get(storageType);
|
||||||
} else if (quantizedType.isa<VectorType>()) {
|
} else if (quantizedType.isa<VectorType>()) {
|
||||||
return VectorType::get(sType.getShape(), storageType);
|
return VectorType::get(sType.getShape(), storageType);
|
||||||
|
@ -148,7 +152,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
|
||||||
if (candidateType == getExpressedType()) {
|
if (candidateType == getExpressedType()) {
|
||||||
// i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
|
// i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
|
||||||
return *this;
|
return *this;
|
||||||
} else if (candidateType.isa<ShapedType>()) {
|
}
|
||||||
|
if (candidateType.isa<ShapedType>()) {
|
||||||
ShapedType candidateShapedType = candidateType.cast<ShapedType>();
|
ShapedType candidateShapedType = candidateType.cast<ShapedType>();
|
||||||
if (candidateShapedType.getElementType() != getExpressedType()) {
|
if (candidateShapedType.getElementType() != getExpressedType()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -157,7 +162,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
|
||||||
if (candidateType.isa<RankedTensorType>()) {
|
if (candidateType.isa<RankedTensorType>()) {
|
||||||
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||||
return RankedTensorType::get(candidateShapedType.getShape(), *this);
|
return RankedTensorType::get(candidateShapedType.getShape(), *this);
|
||||||
} else if (candidateType.isa<UnrankedTensorType>()) {
|
}
|
||||||
|
if (candidateType.isa<UnrankedTensorType>()) {
|
||||||
// i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
|
// i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
|
||||||
return UnrankedTensorType::get(*this);
|
return UnrankedTensorType::get(*this);
|
||||||
} else if (candidateType.isa<VectorType>()) {
|
} else if (candidateType.isa<VectorType>()) {
|
||||||
|
@ -173,7 +179,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
|
||||||
if (quantizedType.isa<QuantizedType>()) {
|
if (quantizedType.isa<QuantizedType>()) {
|
||||||
// i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
|
// i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
|
||||||
return quantizedType.cast<QuantizedType>().getExpressedType();
|
return quantizedType.cast<QuantizedType>().getExpressedType();
|
||||||
} else if (quantizedType.isa<ShapedType>()) {
|
}
|
||||||
|
if (quantizedType.isa<ShapedType>()) {
|
||||||
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
|
||||||
ShapedType sType = quantizedType.cast<ShapedType>();
|
ShapedType sType = quantizedType.cast<ShapedType>();
|
||||||
if (!sType.getElementType().isa<QuantizedType>()) {
|
if (!sType.getElementType().isa<QuantizedType>()) {
|
||||||
|
@ -183,7 +190,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
|
||||||
sType.getElementType().cast<QuantizedType>().getExpressedType();
|
sType.getElementType().cast<QuantizedType>().getExpressedType();
|
||||||
if (quantizedType.isa<RankedTensorType>()) {
|
if (quantizedType.isa<RankedTensorType>()) {
|
||||||
return RankedTensorType::get(sType.getShape(), expressedType);
|
return RankedTensorType::get(sType.getShape(), expressedType);
|
||||||
} else if (quantizedType.isa<UnrankedTensorType>()) {
|
}
|
||||||
|
if (quantizedType.isa<UnrankedTensorType>()) {
|
||||||
return UnrankedTensorType::get(expressedType);
|
return UnrankedTensorType::get(expressedType);
|
||||||
} else if (quantizedType.isa<VectorType>()) {
|
} else if (quantizedType.isa<VectorType>()) {
|
||||||
return VectorType::get(sType.getShape(), expressedType);
|
return VectorType::get(sType.getShape(), expressedType);
|
||||||
|
|
|
@ -126,7 +126,7 @@ void ConvertSimulatedQuantPass::runOnFunction() {
|
||||||
bool hadFailure = false;
|
bool hadFailure = false;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
RewritePatternSet patterns(func.getContext());
|
RewritePatternSet patterns(func.getContext());
|
||||||
auto ctx = func.getContext();
|
auto *ctx = func.getContext();
|
||||||
patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
|
patterns.add<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
|
||||||
ctx, &hadFailure);
|
ctx, &hadFailure);
|
||||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||||
|
|
|
@ -140,10 +140,10 @@ UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
|
||||||
Location loc, unsigned numBits, int32_t quantizedDimension,
|
Location loc, unsigned numBits, int32_t quantizedDimension,
|
||||||
ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
|
ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
|
||||||
Type expressedType, bool isSigned) {
|
Type expressedType, bool isSigned) {
|
||||||
size_t axis_size = rmins.size();
|
size_t axisSize = rmins.size();
|
||||||
if (axis_size != rmaxs.size()) {
|
if (axisSize != rmaxs.size()) {
|
||||||
return (emitError(loc, "mismatched per-axis min and max size: ")
|
return (emitError(loc, "mismatched per-axis min and max size: ")
|
||||||
<< axis_size << " vs. " << rmaxs.size(),
|
<< axisSize << " vs. " << rmaxs.size(),
|
||||||
nullptr);
|
nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,9 +159,9 @@ UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
|
||||||
|
|
||||||
SmallVector<double, 4> scales;
|
SmallVector<double, 4> scales;
|
||||||
SmallVector<int64_t, 4> zeroPoints;
|
SmallVector<int64_t, 4> zeroPoints;
|
||||||
scales.reserve(axis_size);
|
scales.reserve(axisSize);
|
||||||
zeroPoints.reserve(axis_size);
|
zeroPoints.reserve(axisSize);
|
||||||
for (size_t axis = 0; axis != axis_size; ++axis) {
|
for (size_t axis = 0; axis != axisSize; ++axis) {
|
||||||
double rmin = rmins[axis];
|
double rmin = rmins[axis];
|
||||||
double rmax = rmaxs[axis];
|
double rmax = rmaxs[axis];
|
||||||
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
|
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
|
||||||
|
|
|
@ -106,17 +106,17 @@ Attribute mlir::quant::quantizeAttrUniform(
|
||||||
realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
|
realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
|
||||||
outConvertedType = converted.getType();
|
outConvertedType = converted.getType();
|
||||||
return converted;
|
return converted;
|
||||||
} else if (realValue.isa<SparseElementsAttr>()) {
|
}
|
||||||
|
if (realValue.isa<SparseElementsAttr>()) {
|
||||||
// Sparse tensor or vector constant.
|
// Sparse tensor or vector constant.
|
||||||
auto converted = convertSparseElementsAttr(
|
auto converted = convertSparseElementsAttr(
|
||||||
realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
|
realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
|
||||||
outConvertedType = converted.getType();
|
outConvertedType = converted.getType();
|
||||||
return converted;
|
return converted;
|
||||||
} else {
|
|
||||||
// Nothing else matched: try to convert a primitive.
|
|
||||||
return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
|
|
||||||
outConvertedType);
|
|
||||||
}
|
}
|
||||||
|
// Nothing else matched: try to convert a primitive.
|
||||||
|
return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
|
||||||
|
outConvertedType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert an attribute from a type based on
|
/// Convert an attribute from a type based on
|
||||||
|
@ -132,9 +132,9 @@ Attribute mlir::quant::quantizeAttr(Attribute realValue,
|
||||||
UniformQuantizedValueConverter converter(uniformQuantized);
|
UniformQuantizedValueConverter converter(uniformQuantized);
|
||||||
return quantizeAttrUniform(realValue, uniformQuantized, converter,
|
return quantizeAttrUniform(realValue, uniformQuantized, converter,
|
||||||
outConvertedType);
|
outConvertedType);
|
||||||
|
}
|
||||||
} else if (auto uniformQuantizedPerAxis =
|
if (auto uniformQuantizedPerAxis =
|
||||||
quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
|
quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
|
||||||
UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
|
UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
|
||||||
auto converted = converter.convert(realValue);
|
auto converted = converter.convert(realValue);
|
||||||
// TODO: why we need this outConvertedType? remove it?
|
// TODO: why we need this outConvertedType? remove it?
|
||||||
|
@ -142,7 +142,6 @@ Attribute mlir::quant::quantizeAttr(Attribute realValue,
|
||||||
outConvertedType = converted.getType();
|
outConvertedType = converted.getType();
|
||||||
}
|
}
|
||||||
return converted;
|
return converted;
|
||||||
} else {
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ static Attribute extractCompositeElement(Attribute composite,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#include "SPIRVCanonicalization.inc"
|
#include "SPIRVCanonicalization.inc"
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.AccessChainOp
|
// spv.AccessChainOp
|
||||||
|
|
|
@ -3250,13 +3250,13 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
|
static void print(spirv::CooperativeMatrixLoadNVOp m, OpAsmPrinter &printer) {
|
||||||
printer << " " << M.pointer() << ", " << M.stride() << ", "
|
printer << " " << m.pointer() << ", " << m.stride() << ", "
|
||||||
<< M.columnmajor();
|
<< m.columnmajor();
|
||||||
// Print optional memory access attribute.
|
// Print optional memory access attribute.
|
||||||
if (auto memAccess = M.memory_access())
|
if (auto memAccess = m.memory_access())
|
||||||
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
|
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
|
||||||
printer << " : " << M.pointer().getType() << " as " << M.getType();
|
printer << " : " << m.pointer().getType() << " as " << m.getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
|
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
|
||||||
|
|
|
@ -31,7 +31,7 @@ using namespace mlir::shape;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#include "ShapeCanonicalization.inc"
|
#include "ShapeCanonicalization.inc"
|
||||||
}
|
} // namespace
|
||||||
|
|
||||||
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
|
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
|
||||||
return RankedTensorType::get({rank}, IndexType::get(ctx));
|
return RankedTensorType::get({rank}, IndexType::get(ctx));
|
||||||
|
@ -50,7 +50,8 @@ LogicalResult shape::getShapeVec(Value input,
|
||||||
return failure();
|
return failure();
|
||||||
shapeValues = llvm::to_vector<6>(type.getShape());
|
shapeValues = llvm::to_vector<6>(type.getShape());
|
||||||
return success();
|
return success();
|
||||||
} else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
|
}
|
||||||
|
if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
|
||||||
shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
|
shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
|
||||||
return success();
|
return success();
|
||||||
} else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
|
} else if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||||
|
|
|
@ -540,7 +540,8 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
||||||
condbr.getTrueOperands());
|
condbr.getTrueOperands());
|
||||||
return success();
|
return success();
|
||||||
} else if (matchPattern(condbr.getCondition(), m_Zero())) {
|
}
|
||||||
|
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
||||||
// False branch taken.
|
// False branch taken.
|
||||||
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
||||||
condbr.getFalseOperands());
|
condbr.getFalseOperands());
|
||||||
|
|
|
@ -152,11 +152,11 @@ struct ReifyExpandOrCollapseShapeOp
|
||||||
reifyResultShapes(Operation *op, OpBuilder &b,
|
reifyResultShapes(Operation *op, OpBuilder &b,
|
||||||
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
|
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto reshape_op = cast<OpTy>(op);
|
auto reshapeOp = cast<OpTy>(op);
|
||||||
auto result_shape = getReshapeOutputShapeFromInputShape(
|
auto resultShape = getReshapeOutputShapeFromInputShape(
|
||||||
b, loc, reshape_op.src(), reshape_op.getResultType().getShape(),
|
b, loc, reshapeOp.src(), reshapeOp.getResultType().getShape(),
|
||||||
reshape_op.getReassociationMaps());
|
reshapeOp.getReassociationMaps());
|
||||||
reifiedReturnShapes.push_back(getAsValues(b, loc, result_shape));
|
reifiedReturnShapes.push_back(getAsValues(b, loc, resultShape));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -634,7 +634,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// ReshapeOp
|
// ReshapeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static int64_t GetNumElements(ShapedType type) {
|
static int64_t getNumElements(ShapedType type) {
|
||||||
int64_t numElements = 1;
|
int64_t numElements = 1;
|
||||||
for (auto dim : type.getShape())
|
for (auto dim : type.getShape())
|
||||||
numElements *= dim;
|
numElements *= dim;
|
||||||
|
@ -657,7 +657,7 @@ static LogicalResult verify(ReshapeOp op) {
|
||||||
if (resultRankedType) {
|
if (resultRankedType) {
|
||||||
if (operandRankedType && resultRankedType.hasStaticShape() &&
|
if (operandRankedType && resultRankedType.hasStaticShape() &&
|
||||||
operandRankedType.hasStaticShape()) {
|
operandRankedType.hasStaticShape()) {
|
||||||
if (GetNumElements(operandRankedType) != GetNumElements(resultRankedType))
|
if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
|
||||||
return op.emitOpError("source and destination tensor should have the "
|
return op.emitOpError("source and destination tensor should have the "
|
||||||
"same number of elements");
|
"same number of elements");
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,9 +97,9 @@ public:
|
||||||
|
|
||||||
// Traverse all `elements` and create `memref.store` ops.
|
// Traverse all `elements` and create `memref.store` ops.
|
||||||
ImplicitLocOpBuilder b(loc, rewriter);
|
ImplicitLocOpBuilder b(loc, rewriter);
|
||||||
auto element_it = adaptor.elements().begin();
|
auto elementIt = adaptor.elements().begin();
|
||||||
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
|
SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
|
||||||
CreateStores(/*dim=*/0, buffer, shape, constants, element_it, indices, b);
|
createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b);
|
||||||
|
|
||||||
rewriter.replaceOp(op, {buffer});
|
rewriter.replaceOp(op, {buffer});
|
||||||
return success();
|
return success();
|
||||||
|
@ -108,21 +108,21 @@ public:
|
||||||
private:
|
private:
|
||||||
// Implements backtracking to traverse indices of the output buffer while
|
// Implements backtracking to traverse indices of the output buffer while
|
||||||
// iterating over op.elements().
|
// iterating over op.elements().
|
||||||
void CreateStores(int dim, Value buffer, ArrayRef<int64_t> shape,
|
void createStores(int dim, Value buffer, ArrayRef<int64_t> shape,
|
||||||
ArrayRef<Value> constants, ValueRange::iterator &element_it,
|
ArrayRef<Value> constants, ValueRange::iterator &elementIt,
|
||||||
SmallVectorImpl<Value> &indices,
|
SmallVectorImpl<Value> &indices,
|
||||||
ImplicitLocOpBuilder b) const {
|
ImplicitLocOpBuilder b) const {
|
||||||
if (dim == static_cast<int>(shape.size()) - 1) {
|
if (dim == static_cast<int>(shape.size()) - 1) {
|
||||||
for (int i = 0; i < shape.back(); ++i) {
|
for (int i = 0; i < shape.back(); ++i) {
|
||||||
indices.back() = constants[i];
|
indices.back() = constants[i];
|
||||||
b.create<memref::StoreOp>(*element_it, buffer, indices);
|
b.create<memref::StoreOp>(*elementIt, buffer, indices);
|
||||||
++element_it;
|
++elementIt;
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < shape[dim]; ++i) {
|
for (int i = 0; i < shape[dim]; ++i) {
|
||||||
indices[dim] = constants[i];
|
indices[dim] = constants[i];
|
||||||
CreateStores(dim + 1, buffer, shape, constants, element_it, indices, b);
|
createStores(dim + 1, buffer, shape, constants, elementIt, indices, b);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -771,8 +771,8 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
|
||||||
OperationState &result,
|
OperationState &result,
|
||||||
Type outputType, Value input,
|
Type outputType, Value input,
|
||||||
Value paddings,
|
Value paddings,
|
||||||
Value pad_const) {
|
Value padConst) {
|
||||||
result.addOperands({input, paddings, pad_const});
|
result.addOperands({input, paddings, padConst});
|
||||||
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
|
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
|
||||||
if (quantAttr)
|
if (quantAttr)
|
||||||
result.addAttribute("quantization_info", quantAttr);
|
result.addAttribute("quantization_info", quantAttr);
|
||||||
|
|
|
@ -33,9 +33,9 @@ static void getValuesFromIntArrayAttribute(ArrayAttr attr,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TosaOp, typename... Args>
|
template <typename TosaOp, typename... Args>
|
||||||
TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
|
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
|
||||||
Args &&...args) {
|
Args &&...args) {
|
||||||
auto op = rewriter.create<TosaOp>(loc, result_ty, args...);
|
auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
|
||||||
|
|
||||||
InferShapedTypeOpInterface shapeInterface =
|
InferShapedTypeOpInterface shapeInterface =
|
||||||
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
|
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
|
||||||
|
@ -57,12 +57,12 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
|
||||||
auto result = op->getResult(0);
|
auto result = op->getResult(0);
|
||||||
auto predictedShape = returnedShapes[0];
|
auto predictedShape = returnedShapes[0];
|
||||||
auto currentKnowledge =
|
auto currentKnowledge =
|
||||||
mlir::tosa::ValueKnowledge::getKnowledgeFromType(result_ty);
|
mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
|
||||||
|
|
||||||
// Compute the knowledge based on the inferred type.
|
// Compute the knowledge based on the inferred type.
|
||||||
auto inferredKnowledge =
|
auto inferredKnowledge =
|
||||||
mlir::tosa::ValueKnowledge::getPessimisticValueState();
|
mlir::tosa::ValueKnowledge::getPessimisticValueState();
|
||||||
inferredKnowledge.dtype = result_ty.cast<ShapedType>().getElementType();
|
inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
|
||||||
inferredKnowledge.hasRank = predictedShape.hasRank();
|
inferredKnowledge.hasRank = predictedShape.hasRank();
|
||||||
if (predictedShape.hasRank()) {
|
if (predictedShape.hasRank()) {
|
||||||
for (auto dim : predictedShape.getDims()) {
|
for (auto dim : predictedShape.getDims()) {
|
||||||
|
@ -73,8 +73,8 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
|
||||||
// Compute the new type based on the joined version.
|
// Compute the new type based on the joined version.
|
||||||
auto newKnowledge =
|
auto newKnowledge =
|
||||||
mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
|
mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
|
||||||
auto new_ty = newKnowledge.getType();
|
auto newTy = newKnowledge.getType();
|
||||||
result.setType(new_ty);
|
result.setType(newTy);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -205,19 +205,19 @@ public:
|
||||||
weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
|
weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
|
||||||
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
|
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
|
||||||
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
|
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
|
||||||
Value weightPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
|
Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
|
||||||
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
|
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
|
||||||
|
|
||||||
if (op.quantization_info().hasValue()) {
|
if (op.quantization_info().hasValue()) {
|
||||||
auto quantInfo = op.quantization_info().getValue();
|
auto quantInfo = op.quantization_info().getValue();
|
||||||
weight = CreateOpAndInfer<tosa::PadOp>(
|
weight = createOpAndInfer<tosa::PadOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
||||||
weightPaddingVal, nullptr,
|
weightPaddingVal, nullptr,
|
||||||
PadOpQuantizationAttr::get(quantInfo.weight_zp(),
|
PadOpQuantizationAttr::get(quantInfo.weight_zp(),
|
||||||
rewriter.getContext()));
|
rewriter.getContext()));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
weight = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
|
weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
|
||||||
UnrankedTensorType::get(weightETy),
|
UnrankedTensorType::get(weightETy),
|
||||||
weight, weightPaddingVal);
|
weight, weightPaddingVal);
|
||||||
}
|
}
|
||||||
|
@ -231,7 +231,7 @@ public:
|
||||||
outputChannels, weightHeight / stride[0],
|
outputChannels, weightHeight / stride[0],
|
||||||
stride[0], weightWidth / stride[1],
|
stride[0], weightWidth / stride[1],
|
||||||
stride[1], inputChannels};
|
stride[1], inputChannels};
|
||||||
weight = CreateOpAndInfer<tosa::ReshapeOp>(
|
weight = createOpAndInfer<tosa::ReshapeOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
||||||
rewriter.getI64ArrayAttr(weightReshapeDims0));
|
rewriter.getI64ArrayAttr(weightReshapeDims0));
|
||||||
|
|
||||||
|
@ -240,7 +240,7 @@ public:
|
||||||
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
|
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
|
||||||
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
|
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
|
||||||
|
|
||||||
weight = CreateOpAndInfer<tosa::TransposeOp>(
|
weight = createOpAndInfer<tosa::TransposeOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
||||||
transposeWeightVal);
|
transposeWeightVal);
|
||||||
|
|
||||||
|
@ -248,15 +248,15 @@ public:
|
||||||
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
|
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
|
||||||
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
|
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
|
||||||
weightWidth / stride[1], inputChannels};
|
weightWidth / stride[1], inputChannels};
|
||||||
weight = CreateOpAndInfer<tosa::ReshapeOp>(
|
weight = createOpAndInfer<tosa::ReshapeOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
||||||
rewriter.getI64ArrayAttr(weightReshapeDims1));
|
rewriter.getI64ArrayAttr(weightReshapeDims1));
|
||||||
ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
|
ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
|
||||||
|
|
||||||
weight = CreateOpAndInfer<tosa::ReverseOp>(
|
weight = createOpAndInfer<tosa::ReverseOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
||||||
rewriter.getI64IntegerAttr(1));
|
rewriter.getI64IntegerAttr(1));
|
||||||
weight = CreateOpAndInfer<tosa::ReverseOp>(
|
weight = createOpAndInfer<tosa::ReverseOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
|
||||||
rewriter.getI64IntegerAttr(2));
|
rewriter.getI64IntegerAttr(2));
|
||||||
|
|
||||||
|
@ -270,18 +270,18 @@ public:
|
||||||
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
|
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
|
||||||
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
|
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
|
||||||
|
|
||||||
Value inputPaddingVal = CreateOpAndInfer<tosa::ConstOp>(
|
Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
|
||||||
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
|
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
|
||||||
|
|
||||||
if (op.quantization_info().hasValue()) {
|
if (op.quantization_info().hasValue()) {
|
||||||
auto quantInfo = op.quantization_info().getValue();
|
auto quantInfo = op.quantization_info().getValue();
|
||||||
input = CreateOpAndInfer<tosa::PadOp>(
|
input = createOpAndInfer<tosa::PadOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(inputETy), input,
|
rewriter, loc, UnrankedTensorType::get(inputETy), input,
|
||||||
inputPaddingVal, nullptr,
|
inputPaddingVal, nullptr,
|
||||||
PadOpQuantizationAttr::get(quantInfo.input_zp(),
|
PadOpQuantizationAttr::get(quantInfo.input_zp(),
|
||||||
rewriter.getContext()));
|
rewriter.getContext()));
|
||||||
} else {
|
} else {
|
||||||
input = CreateOpAndInfer<tosa::PadOp>(rewriter, loc,
|
input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
|
||||||
UnrankedTensorType::get(inputETy),
|
UnrankedTensorType::get(inputETy),
|
||||||
input, inputPaddingVal);
|
input, inputPaddingVal);
|
||||||
}
|
}
|
||||||
|
@ -299,7 +299,7 @@ public:
|
||||||
// Perform the convolution using the zero bias.
|
// Perform the convolution using the zero bias.
|
||||||
Value conv2d;
|
Value conv2d;
|
||||||
if (op.quantization_info().hasValue()) {
|
if (op.quantization_info().hasValue()) {
|
||||||
conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
|
conv2d = createOpAndInfer<tosa::Conv2DOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(resultETy), input,
|
rewriter, loc, UnrankedTensorType::get(resultETy), input,
|
||||||
weight, zeroBias,
|
weight, zeroBias,
|
||||||
/*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
|
/*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
|
||||||
|
@ -308,7 +308,7 @@ public:
|
||||||
op.quantization_info().getValue())
|
op.quantization_info().getValue())
|
||||||
.getResult();
|
.getResult();
|
||||||
} else {
|
} else {
|
||||||
conv2d = CreateOpAndInfer<tosa::Conv2DOp>(
|
conv2d = createOpAndInfer<tosa::Conv2DOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(resultETy), input,
|
rewriter, loc, UnrankedTensorType::get(resultETy), input,
|
||||||
weight, zeroBias,
|
weight, zeroBias,
|
||||||
/*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
|
/*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
|
||||||
|
@ -327,7 +327,7 @@ public:
|
||||||
// Factor striding out of the convolution result.
|
// Factor striding out of the convolution result.
|
||||||
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
|
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
|
||||||
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
|
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
|
||||||
conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
|
conv2d = createOpAndInfer<tosa::ReshapeOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
|
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
|
||||||
rewriter.getI64ArrayAttr(convReshapeDims0));
|
rewriter.getI64ArrayAttr(convReshapeDims0));
|
||||||
|
|
||||||
|
@ -336,14 +336,14 @@ public:
|
||||||
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
|
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
|
||||||
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
|
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
|
||||||
|
|
||||||
conv2d = CreateOpAndInfer<tosa::TransposeOp>(
|
conv2d = createOpAndInfer<tosa::TransposeOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
|
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
|
||||||
transposeConvVal);
|
transposeConvVal);
|
||||||
|
|
||||||
// Fuse striding behavior back into width / height.
|
// Fuse striding behavior back into width / height.
|
||||||
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
|
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
|
||||||
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
|
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
|
||||||
conv2d = CreateOpAndInfer<tosa::ReshapeOp>(
|
conv2d = createOpAndInfer<tosa::ReshapeOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
|
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
|
||||||
rewriter.getI64ArrayAttr(convReshapeDims1));
|
rewriter.getI64ArrayAttr(convReshapeDims1));
|
||||||
|
|
||||||
|
@ -354,14 +354,14 @@ public:
|
||||||
sliceBegin[1] = pad[0];
|
sliceBegin[1] = pad[0];
|
||||||
sliceBegin[2] = pad[1];
|
sliceBegin[2] = pad[1];
|
||||||
|
|
||||||
auto slice = CreateOpAndInfer<tosa::SliceOp>(
|
auto slice = createOpAndInfer<tosa::SliceOp>(
|
||||||
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
|
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
|
||||||
rewriter.getI64ArrayAttr(sliceBegin),
|
rewriter.getI64ArrayAttr(sliceBegin),
|
||||||
rewriter.getI64ArrayAttr(resultTy.getShape()))
|
rewriter.getI64ArrayAttr(resultTy.getShape()))
|
||||||
.getResult();
|
.getResult();
|
||||||
|
|
||||||
auto addBias =
|
auto addBias =
|
||||||
CreateOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
|
createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
|
||||||
|
|
||||||
rewriter.replaceOp(op, addBias.getResult());
|
rewriter.replaceOp(op, addBias.getResult());
|
||||||
|
|
||||||
|
|
|
@ -223,7 +223,7 @@ void propagateShapesInRegion(Region ®ion) {
|
||||||
// Check whether this use case is replaceable. We define an op as
|
// Check whether this use case is replaceable. We define an op as
|
||||||
// being replaceable if it is used by a ReturnOp or a TosaOp.
|
// being replaceable if it is used by a ReturnOp or a TosaOp.
|
||||||
bool replaceable = true;
|
bool replaceable = true;
|
||||||
for (auto user : result.getUsers()) {
|
for (auto *user : result.getUsers()) {
|
||||||
if (isa<ReturnOp>(user))
|
if (isa<ReturnOp>(user))
|
||||||
continue;
|
continue;
|
||||||
if (user->getDialect()->getNamespace() ==
|
if (user->getDialect()->getNamespace() ==
|
||||||
|
|
|
@ -1179,7 +1179,7 @@ struct UnrolledOuterProductGenerator
|
||||||
return builder.create<vector::TransposeOp>(loc, v, perm);
|
return builder.create<vector::TransposeOp>(loc, v, perm);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
|
Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
|
||||||
assert(reductionSize > 0);
|
assert(reductionSize > 0);
|
||||||
for (int64_t k = 0; k < reductionSize; ++k) {
|
for (int64_t k = 0; k < reductionSize; ++k) {
|
||||||
Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
|
Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
|
||||||
|
@ -1199,31 +1199,31 @@ struct UnrolledOuterProductGenerator
|
||||||
bindDims(builder.getContext(), m, n, k);
|
bindDims(builder.getContext(), m, n, k);
|
||||||
// Classical row-major matmul: Just permute the lhs.
|
// Classical row-major matmul: Just permute the lhs.
|
||||||
if (layout({{m, k}, {k, n}, {m, n}}))
|
if (layout({{m, k}, {k, n}, {m, n}}))
|
||||||
return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
|
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||||
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
||||||
if (layout({{m, k}, {n, k}, {m, n}})) {
|
if (layout({{m, k}, {n, k}, {m, n}})) {
|
||||||
Value tlhs = t(lhs);
|
Value tlhs = t(lhs);
|
||||||
return outer_prod(tlhs, t(rhs), res, lhsType.getDimSize(1));
|
return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
|
||||||
}
|
}
|
||||||
// No need to permute anything.
|
// No need to permute anything.
|
||||||
if (layout({{k, m}, {k, n}, {m, n}}))
|
if (layout({{k, m}, {k, n}, {m, n}}))
|
||||||
return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
|
||||||
// Just permute the rhs.
|
// Just permute the rhs.
|
||||||
if (layout({{k, m}, {n, k}, {m, n}}))
|
if (layout({{k, m}, {n, k}, {m, n}}))
|
||||||
return outer_prod(lhs, t(rhs), res, lhsType.getDimSize(0));
|
return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
|
||||||
// Transposed output: swap RHS and LHS.
|
// Transposed output: swap RHS and LHS.
|
||||||
// Classical row-major matmul: permute the lhs.
|
// Classical row-major matmul: permute the lhs.
|
||||||
if (layout({{m, k}, {k, n}, {n, m}}))
|
if (layout({{m, k}, {k, n}, {n, m}}))
|
||||||
return outer_prod(rhs, t(lhs), res, lhsType.getDimSize(1));
|
return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
|
||||||
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
||||||
if (layout({{m, k}, {n, k}, {n, m}})) {
|
if (layout({{m, k}, {n, k}, {n, m}})) {
|
||||||
Value trhs = t(rhs);
|
Value trhs = t(rhs);
|
||||||
return outer_prod(trhs, t(lhs), res, lhsType.getDimSize(1));
|
return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
|
||||||
}
|
}
|
||||||
if (layout({{k, m}, {k, n}, {n, m}}))
|
if (layout({{k, m}, {k, n}, {n, m}}))
|
||||||
return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
|
||||||
if (layout({{k, m}, {n, k}, {n, m}}))
|
if (layout({{k, m}, {n, k}, {n, m}}))
|
||||||
return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
|
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1236,16 +1236,16 @@ struct UnrolledOuterProductGenerator
|
||||||
|
|
||||||
// Case mat-vec: transpose.
|
// Case mat-vec: transpose.
|
||||||
if (layout({{m, k}, {k}, {m}}))
|
if (layout({{m, k}, {k}, {m}}))
|
||||||
return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
|
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||||
// Case mat-trans-vec: ready to go.
|
// Case mat-trans-vec: ready to go.
|
||||||
if (layout({{k, m}, {k}, {m}}))
|
if (layout({{k, m}, {k}, {m}}))
|
||||||
return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
|
||||||
// Case vec-mat: swap and transpose.
|
// Case vec-mat: swap and transpose.
|
||||||
if (layout({{k}, {m, k}, {m}}))
|
if (layout({{k}, {m, k}, {m}}))
|
||||||
return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
|
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||||
// Case vec-mat-trans: swap and ready to go.
|
// Case vec-mat-trans: swap and ready to go.
|
||||||
if (layout({{k}, {k, m}, {m}}))
|
if (layout({{k}, {k, m}, {m}}))
|
||||||
return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1260,16 +1260,16 @@ struct UnrolledOuterProductGenerator
|
||||||
|
|
||||||
// Case mat-vec: transpose.
|
// Case mat-vec: transpose.
|
||||||
if (layout({{m, k}, {k}, {m}}))
|
if (layout({{m, k}, {k}, {m}}))
|
||||||
return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1));
|
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||||
// Case mat-trans-vec: ready to go.
|
// Case mat-trans-vec: ready to go.
|
||||||
if (layout({{k, m}, {k}, {m}}))
|
if (layout({{k, m}, {k}, {m}}))
|
||||||
return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
|
||||||
// Case vec-mat: swap and transpose.
|
// Case vec-mat: swap and transpose.
|
||||||
if (layout({{k}, {m, k}, {m}}))
|
if (layout({{k}, {m, k}, {m}}))
|
||||||
return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0));
|
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||||
// Case vec-mat-trans: swap and ready to go.
|
// Case vec-mat-trans: swap and ready to go.
|
||||||
if (layout({{k}, {k, m}, {m}}))
|
if (layout({{k}, {k, m}, {m}}))
|
||||||
return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,8 +31,9 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
|
||||||
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
|
ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
|
||||||
auto asmDialectAttr =
|
auto asmDialectAttr =
|
||||||
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
|
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
|
||||||
auto asmTp = "vblendps $0, $1, $2, {0}";
|
const auto *asmTp = "vblendps $0, $1, $2, {0}";
|
||||||
auto asmCstr = "=x,x,x"; // Careful: constraint parser is very brittle: no ws!
|
const auto *asmCstr =
|
||||||
|
"=x,x,x"; // Careful: constraint parser is very brittle: no ws!
|
||||||
SmallVector<Value> asmVals{v1, v2};
|
SmallVector<Value> asmVals{v1, v2};
|
||||||
auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
|
auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, /*width=*/2)).str();
|
||||||
auto asmOp = b.create<LLVM::InlineAsmOp>(
|
auto asmOp = b.create<LLVM::InlineAsmOp>(
|
||||||
|
@ -116,18 +117,18 @@ void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
|
||||||
"expects all types to be vector<8xf32>");
|
"expects all types to be vector<8xf32>");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
|
Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
|
||||||
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
|
Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
|
||||||
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
|
Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
|
||||||
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
|
Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
|
||||||
Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
|
Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||||
Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
|
Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||||
Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
|
Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());
|
||||||
Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
|
Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());
|
||||||
vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
|
vs[0] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<2, 0>());
|
||||||
vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
|
vs[1] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<2, 0>());
|
||||||
vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
|
vs[2] = mm256Permute2f128Ps(ib, s0, s1, MaskHelper::permute<3, 1>());
|
||||||
vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
|
vs[3] = mm256Permute2f128Ps(ib, s2, s3, MaskHelper::permute<3, 1>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
|
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
|
||||||
|
@ -140,46 +141,46 @@ void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
|
||||||
[&](Type t) { return t == vt; }) &&
|
[&](Type t) { return t == vt; }) &&
|
||||||
"expects all types to be vector<8xf32>");
|
"expects all types to be vector<8xf32>");
|
||||||
|
|
||||||
Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
|
Value t0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
|
||||||
Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
|
Value t1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
|
||||||
Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
|
Value t2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
|
||||||
Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
|
Value t3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
|
||||||
Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
|
Value t4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
|
||||||
Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
|
Value t5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
|
||||||
Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
|
Value t6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
|
||||||
Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
|
Value t7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
|
||||||
|
|
||||||
using inline_asm::mm256BlendPsAsm;
|
using inline_asm::mm256BlendPsAsm;
|
||||||
Value sh0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 3, 2>());
|
Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());
|
||||||
Value sh2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 3, 2>());
|
Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());
|
||||||
Value sh4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 3, 2>());
|
Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());
|
||||||
Value sh6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 3, 2>());
|
Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());
|
||||||
|
|
||||||
Value S0 =
|
Value s0 =
|
||||||
mm256BlendPsAsm(ib, T0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
||||||
Value S1 =
|
Value s1 =
|
||||||
mm256BlendPsAsm(ib, T2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
||||||
Value S2 =
|
Value s2 =
|
||||||
mm256BlendPsAsm(ib, T1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
||||||
Value S3 =
|
Value s3 =
|
||||||
mm256BlendPsAsm(ib, T3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
||||||
Value S4 =
|
Value s4 =
|
||||||
mm256BlendPsAsm(ib, T4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
||||||
Value S5 =
|
Value s5 =
|
||||||
mm256BlendPsAsm(ib, T6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
||||||
Value S6 =
|
Value s6 =
|
||||||
mm256BlendPsAsm(ib, T5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());
|
||||||
Value S7 =
|
Value s7 =
|
||||||
mm256BlendPsAsm(ib, T7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());
|
||||||
|
|
||||||
vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
|
vs[0] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<2, 0>());
|
||||||
vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
|
vs[1] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<2, 0>());
|
||||||
vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
|
vs[2] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<2, 0>());
|
||||||
vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
|
vs[3] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<2, 0>());
|
||||||
vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
|
vs[4] = mm256Permute2f128Ps(ib, s0, s4, MaskHelper::permute<3, 1>());
|
||||||
vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
|
vs[5] = mm256Permute2f128Ps(ib, s1, s5, MaskHelper::permute<3, 1>());
|
||||||
vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
|
vs[6] = mm256Permute2f128Ps(ib, s2, s6, MaskHelper::permute<3, 1>());
|
||||||
vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
|
vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
|
/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
|
||||||
|
|
|
@ -463,8 +463,10 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
|
||||||
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
|
// https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
|
||||||
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
|
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
|
||||||
// a work around for older versions of Visual Studio.
|
// a work around for older versions of Visual Studio.
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);
|
extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols);
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
|
void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
|
||||||
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
|
auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
|
||||||
assert(exportSymbols.count(name) == 0 && "symbol already exists");
|
assert(exportSymbols.count(name) == 0 && "symbol already exists");
|
||||||
|
@ -517,6 +519,7 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
|
||||||
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
|
&mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(*-identifier-naming): externally called.
|
||||||
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
|
extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
|
||||||
|
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
|
|
|
@ -58,27 +58,27 @@ using llvm::orc::ThreadSafeModule;
|
||||||
using llvm::orc::TMOwningSimpleCompiler;
|
using llvm::orc::TMOwningSimpleCompiler;
|
||||||
|
|
||||||
/// Wrap a string into an llvm::StringError.
|
/// Wrap a string into an llvm::StringError.
|
||||||
static Error make_string_error(const Twine &message) {
|
static Error makeStringError(const Twine &message) {
|
||||||
return llvm::make_error<StringError>(message.str(),
|
return llvm::make_error<StringError>(message.str(),
|
||||||
llvm::inconvertibleErrorCode());
|
llvm::inconvertibleErrorCode());
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleObjectCache::notifyObjectCompiled(const Module *M,
|
void SimpleObjectCache::notifyObjectCompiled(const Module *m,
|
||||||
MemoryBufferRef ObjBuffer) {
|
MemoryBufferRef objBuffer) {
|
||||||
cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
|
cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
|
||||||
ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier());
|
objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) {
|
std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
|
||||||
auto I = cachedObjects.find(M->getModuleIdentifier());
|
auto i = cachedObjects.find(m->getModuleIdentifier());
|
||||||
if (I == cachedObjects.end()) {
|
if (i == cachedObjects.end()) {
|
||||||
LLVM_DEBUG(dbgs() << "No object for " << M->getModuleIdentifier()
|
LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
|
||||||
<< " in cache. Compiling.\n");
|
<< " in cache. Compiling.\n");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
LLVM_DEBUG(dbgs() << "Object for " << M->getModuleIdentifier()
|
LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
|
||||||
<< " loaded from cache.\n");
|
<< " loaded from cache.\n");
|
||||||
return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef());
|
return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
|
void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
|
||||||
|
@ -114,7 +114,8 @@ bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
|
||||||
// Setup the machine properties from the current architecture.
|
// Setup the machine properties from the current architecture.
|
||||||
auto targetTriple = llvm::sys::getDefaultTargetTriple();
|
auto targetTriple = llvm::sys::getDefaultTargetTriple();
|
||||||
std::string errorMessage;
|
std::string errorMessage;
|
||||||
auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
|
const auto *target =
|
||||||
|
llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
|
||||||
if (!target) {
|
if (!target) {
|
||||||
errs() << "NO target: " << errorMessage << "\n";
|
errs() << "NO target: " << errorMessage << "\n";
|
||||||
return true;
|
return true;
|
||||||
|
@ -160,7 +161,7 @@ static void packFunctionArguments(Module *module) {
|
||||||
|
|
||||||
// Given a function `foo(<...>)`, define the interface function
|
// Given a function `foo(<...>)`, define the interface function
|
||||||
// `mlir_foo(i8**)`.
|
// `mlir_foo(i8**)`.
|
||||||
auto newType = llvm::FunctionType::get(
|
auto *newType = llvm::FunctionType::get(
|
||||||
builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
|
builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
|
||||||
/*isVarArg=*/false);
|
/*isVarArg=*/false);
|
||||||
auto newName = makePackedFunctionName(func.getName());
|
auto newName = makePackedFunctionName(func.getName());
|
||||||
|
@ -170,7 +171,7 @@ static void packFunctionArguments(Module *module) {
|
||||||
|
|
||||||
// Extract the arguments from the type-erased argument list and cast them to
|
// Extract the arguments from the type-erased argument list and cast them to
|
||||||
// the proper types.
|
// the proper types.
|
||||||
auto bb = llvm::BasicBlock::Create(ctx);
|
auto *bb = llvm::BasicBlock::Create(ctx);
|
||||||
bb->insertInto(interfaceFunc);
|
bb->insertInto(interfaceFunc);
|
||||||
builder.SetInsertPoint(bb);
|
builder.SetInsertPoint(bb);
|
||||||
llvm::Value *argList = interfaceFunc->arg_begin();
|
llvm::Value *argList = interfaceFunc->arg_begin();
|
||||||
|
@ -237,7 +238,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||||
auto llvmModule = llvmModuleBuilder ? llvmModuleBuilder(m, *ctx)
|
auto llvmModule = llvmModuleBuilder ? llvmModuleBuilder(m, *ctx)
|
||||||
: translateModuleToLLVMIR(m, *ctx);
|
: translateModuleToLLVMIR(m, *ctx);
|
||||||
if (!llvmModule)
|
if (!llvmModule)
|
||||||
return make_string_error("could not convert to LLVM IR");
|
return makeStringError("could not convert to LLVM IR");
|
||||||
// FIXME: the triple should be passed to the translation or dialect conversion
|
// FIXME: the triple should be passed to the translation or dialect conversion
|
||||||
// instead of this. Currently, the LLVM module created above has no triple
|
// instead of this. Currently, the LLVM module created above has no triple
|
||||||
// associated with it.
|
// associated with it.
|
||||||
|
@ -249,7 +250,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||||
// Callback to create the object layer with symbol resolution to current
|
// Callback to create the object layer with symbol resolution to current
|
||||||
// process and dynamically linked libraries.
|
// process and dynamically linked libraries.
|
||||||
auto objectLinkingLayerCreator = [&](ExecutionSession &session,
|
auto objectLinkingLayerCreator = [&](ExecutionSession &session,
|
||||||
const Triple &TT) {
|
const Triple &tt) {
|
||||||
auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
|
auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
|
||||||
session, []() { return std::make_unique<SectionMemoryManager>(); });
|
session, []() { return std::make_unique<SectionMemoryManager>(); });
|
||||||
|
|
||||||
|
@ -276,7 +277,7 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||||
<< "\nError: " << mb.getError().message() << "\n";
|
<< "\nError: " << mb.getError().message() << "\n";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto &JD = session.createBareJITDylib(std::string(libPath));
|
auto &jd = session.createBareJITDylib(std::string(libPath));
|
||||||
auto loaded = DynamicLibrarySearchGenerator::Load(
|
auto loaded = DynamicLibrarySearchGenerator::Load(
|
||||||
libPath.data(), dataLayout.getGlobalPrefix());
|
libPath.data(), dataLayout.getGlobalPrefix());
|
||||||
if (!loaded) {
|
if (!loaded) {
|
||||||
|
@ -284,8 +285,8 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||||
<< "\n";
|
<< "\n";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
JD.addGenerator(std::move(*loaded));
|
jd.addGenerator(std::move(*loaded));
|
||||||
cantFail(objectLayer->add(JD, std::move(mb.get())));
|
cantFail(objectLayer->add(jd, std::move(mb.get())));
|
||||||
}
|
}
|
||||||
|
|
||||||
return objectLayer;
|
return objectLayer;
|
||||||
|
@ -293,14 +294,14 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
|
||||||
|
|
||||||
// Callback to inspect the cache and recompile on demand. This follows Lang's
|
// Callback to inspect the cache and recompile on demand. This follows Lang's
|
||||||
// LLJITWithObjectCache example.
|
// LLJITWithObjectCache example.
|
||||||
auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB)
|
auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
|
||||||
-> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
|
-> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
|
||||||
if (jitCodeGenOptLevel)
|
if (jitCodeGenOptLevel)
|
||||||
JTMB.setCodeGenOptLevel(jitCodeGenOptLevel.getValue());
|
jtmb.setCodeGenOptLevel(jitCodeGenOptLevel.getValue());
|
||||||
auto TM = JTMB.createTargetMachine();
|
auto tm = jtmb.createTargetMachine();
|
||||||
if (!TM)
|
if (!tm)
|
||||||
return TM.takeError();
|
return tm.takeError();
|
||||||
return std::make_unique<TMOwningSimpleCompiler>(std::move(*TM),
|
return std::make_unique<TMOwningSimpleCompiler>(std::move(*tm),
|
||||||
engine->cache.get());
|
engine->cache.get());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -350,13 +351,13 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
|
||||||
llvm::raw_string_ostream os(errorMessage);
|
llvm::raw_string_ostream os(errorMessage);
|
||||||
llvm::handleAllErrors(expectedSymbol.takeError(),
|
llvm::handleAllErrors(expectedSymbol.takeError(),
|
||||||
[&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
|
[&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
|
||||||
return make_string_error(os.str());
|
return makeStringError(os.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rawFPtr = expectedSymbol->getAddress();
|
auto rawFPtr = expectedSymbol->getAddress();
|
||||||
auto fptr = reinterpret_cast<void *>(rawFPtr);
|
auto *fptr = reinterpret_cast<void *>(rawFPtr);
|
||||||
if (!fptr)
|
if (!fptr)
|
||||||
return make_string_error("looked up function is null");
|
return makeStringError("looked up function is null");
|
||||||
return fptr;
|
return fptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ static OwningModuleRef parseMLIRInput(StringRef inputFilename,
|
||||||
return OwningModuleRef(parseSourceFile(sourceMgr, context));
|
return OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline Error make_string_error(const Twine &message) {
|
static inline Error makeStringError(const Twine &message) {
|
||||||
return llvm::make_error<llvm::StringError>(message.str(),
|
return llvm::make_error<llvm::StringError>(message.str(),
|
||||||
llvm::inconvertibleErrorCode());
|
llvm::inconvertibleErrorCode());
|
||||||
}
|
}
|
||||||
|
@ -239,7 +239,7 @@ static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
|
||||||
CompileAndExecuteConfig config) {
|
CompileAndExecuteConfig config) {
|
||||||
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||||
if (!mainFunction || mainFunction.empty())
|
if (!mainFunction || mainFunction.empty())
|
||||||
return make_string_error("entry point not found");
|
return makeStringError("entry point not found");
|
||||||
void *empty = nullptr;
|
void *empty = nullptr;
|
||||||
return compileAndExecute(options, module, entryPoint, config, &empty);
|
return compileAndExecute(options, module, entryPoint, config, &empty);
|
||||||
}
|
}
|
||||||
|
@ -253,7 +253,7 @@ Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
|
||||||
.getReturnType()
|
.getReturnType()
|
||||||
.dyn_cast<IntegerType>();
|
.dyn_cast<IntegerType>();
|
||||||
if (!resultType || resultType.getWidth() != 32)
|
if (!resultType || resultType.getWidth() != 32)
|
||||||
return make_string_error("only single i32 function result supported");
|
return makeStringError("only single i32 function result supported");
|
||||||
return Error::success();
|
return Error::success();
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -263,7 +263,7 @@ Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
|
||||||
.getReturnType()
|
.getReturnType()
|
||||||
.dyn_cast<IntegerType>();
|
.dyn_cast<IntegerType>();
|
||||||
if (!resultType || resultType.getWidth() != 64)
|
if (!resultType || resultType.getWidth() != 64)
|
||||||
return make_string_error("only single i64 function result supported");
|
return makeStringError("only single i64 function result supported");
|
||||||
return Error::success();
|
return Error::success();
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -272,7 +272,7 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
|
||||||
.cast<LLVM::LLVMFunctionType>()
|
.cast<LLVM::LLVMFunctionType>()
|
||||||
.getReturnType()
|
.getReturnType()
|
||||||
.isa<Float32Type>())
|
.isa<Float32Type>())
|
||||||
return make_string_error("only single f32 function result supported");
|
return makeStringError("only single f32 function result supported");
|
||||||
return Error::success();
|
return Error::success();
|
||||||
}
|
}
|
||||||
template <typename Type>
|
template <typename Type>
|
||||||
|
@ -281,10 +281,10 @@ Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
|
||||||
CompileAndExecuteConfig config) {
|
CompileAndExecuteConfig config) {
|
||||||
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
|
||||||
if (!mainFunction || mainFunction.isExternal())
|
if (!mainFunction || mainFunction.isExternal())
|
||||||
return make_string_error("entry point not found");
|
return makeStringError("entry point not found");
|
||||||
|
|
||||||
if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
|
if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
|
||||||
return make_string_error("function inputs not supported");
|
return makeStringError("function inputs not supported");
|
||||||
|
|
||||||
if (Error error = checkCompatibleReturnType<Type>(mainFunction))
|
if (Error error = checkCompatibleReturnType<Type>(mainFunction))
|
||||||
return error;
|
return error;
|
||||||
|
@ -384,7 +384,7 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry,
|
||||||
? compileAndExecuteFn(options, m.get(),
|
? compileAndExecuteFn(options, m.get(),
|
||||||
options.mainFuncName.getValue(),
|
options.mainFuncName.getValue(),
|
||||||
compileAndExecuteConfig)
|
compileAndExecuteConfig)
|
||||||
: make_string_error("unsupported function type");
|
: makeStringError("unsupported function type");
|
||||||
|
|
||||||
int exitCode = EXIT_SUCCESS;
|
int exitCode = EXIT_SUCCESS;
|
||||||
llvm::handleAllErrors(std::move(error),
|
llvm::handleAllErrors(std::move(error),
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
#include "mlir/ExecutionEngine/RunnerUtils.h"
|
#include "mlir/ExecutionEngine/RunnerUtils.h"
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
|
// NOLINTBEGIN(*-identifier-naming)
|
||||||
|
|
||||||
extern "C" void
|
extern "C" void
|
||||||
_mlir_ciface_print_memref_shape_i8(UnrankedMemRefType<int8_t> *M) {
|
_mlir_ciface_print_memref_shape_i8(UnrankedMemRefType<int8_t> *M) {
|
||||||
std::cout << "Unranked Memref ";
|
std::cout << "Unranked Memref ";
|
||||||
|
@ -163,3 +165,5 @@ extern "C" int64_t verifyMemRefF64(int64_t rank, void *actualPtr,
|
||||||
UnrankedMemRefType<double> expectedDesc = {rank, expectedPtr};
|
UnrankedMemRefType<double> expectedDesc = {rank, expectedPtr};
|
||||||
return _mlir_ciface_verifyMemRefF64(&actualDesc, &expectedDesc);
|
return _mlir_ciface_verifyMemRefF64(&actualDesc, &expectedDesc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOLINTEND(*-identifier-naming)
|
||||||
|
|
|
@ -209,7 +209,7 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
|
||||||
SmallVector<AffineExpr, 4> affExprs;
|
SmallVector<AffineExpr, 4> affExprs;
|
||||||
for (auto index : permutation)
|
for (auto index : permutation)
|
||||||
affExprs.push_back(getAffineDimExpr(index, context));
|
affExprs.push_back(getAffineDimExpr(index, context));
|
||||||
auto m = std::max_element(permutation.begin(), permutation.end());
|
const auto *m = std::max_element(permutation.begin(), permutation.end());
|
||||||
auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
|
auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
|
||||||
assert(permutationMap.isPermutation() && "Invalid permutation vector");
|
assert(permutationMap.isPermutation() && "Invalid permutation vector");
|
||||||
return permutationMap;
|
return permutationMap;
|
||||||
|
|
|
@ -1105,7 +1105,7 @@ void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
|
||||||
|
|
||||||
// Find the correct index using a binary search, as the groups are ordered.
|
// Find the correct index using a binary search, as the groups are ordered.
|
||||||
ArrayRef<int> resultGroups = resultGroupIt->second;
|
ArrayRef<int> resultGroups = resultGroupIt->second;
|
||||||
auto it = llvm::upper_bound(resultGroups, resultNo);
|
const auto *it = llvm::upper_bound(resultGroups, resultNo);
|
||||||
int groupResultNo = 0, groupSize = 0;
|
int groupResultNo = 0, groupSize = 0;
|
||||||
|
|
||||||
// If there are no smaller elements, the last result group is the lookup.
|
// If there are no smaller elements, the last result group is the lookup.
|
||||||
|
@ -1240,8 +1240,8 @@ public:
|
||||||
raw_ostream &getStream() { return os; }
|
raw_ostream &getStream() { return os; }
|
||||||
|
|
||||||
template <typename Container, typename UnaryFunctor>
|
template <typename Container, typename UnaryFunctor>
|
||||||
inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
|
inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
|
||||||
llvm::interleaveComma(c, os, each_fn);
|
llvm::interleaveComma(c, os, eachFn);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This enum describes the different kinds of elision for the type of an
|
/// This enum describes the different kinds of elision for the type of an
|
||||||
|
|
|
@ -316,7 +316,7 @@ Block *Block::getUniquePredecessor() {
|
||||||
Block *Block::splitBlock(iterator splitBefore) {
|
Block *Block::splitBlock(iterator splitBefore) {
|
||||||
// Start by creating a new basic block, and insert it immediate after this
|
// Start by creating a new basic block, and insert it immediate after this
|
||||||
// one in the containing region.
|
// one in the containing region.
|
||||||
auto newBB = new Block();
|
auto *newBB = new Block();
|
||||||
getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB);
|
getParent()->getBlocks().insert(std::next(Region::iterator(this)), newBB);
|
||||||
|
|
||||||
// Move all of the operations from the split point to the end of the region
|
// Move all of the operations from the split point to the end of the region
|
||||||
|
|
|
@ -121,10 +121,10 @@ findDuplicateElement(ArrayRef<NamedAttribute> value) {
|
||||||
if (value.size() == 2)
|
if (value.size() == 2)
|
||||||
return value[0].getName() == value[1].getName() ? value[0] : none;
|
return value[0].getName() == value[1].getName() ? value[0] : none;
|
||||||
|
|
||||||
auto it = std::adjacent_find(value.begin(), value.end(),
|
const auto *it = std::adjacent_find(value.begin(), value.end(),
|
||||||
[](NamedAttribute l, NamedAttribute r) {
|
[](NamedAttribute l, NamedAttribute r) {
|
||||||
return l.getName() == r.getName();
|
return l.getName() == r.getName();
|
||||||
});
|
});
|
||||||
return it != value.end() ? *it : none;
|
return it != value.end() ? *it : none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,6 @@ using namespace mlir;
|
||||||
using namespace mlir::detail;
|
using namespace mlir::detail;
|
||||||
|
|
||||||
using llvm::hash_combine;
|
using llvm::hash_combine;
|
||||||
using llvm::hash_combine_range;
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MLIRContext CommandLine Options
|
// MLIRContext CommandLine Options
|
||||||
|
|
|
@ -349,28 +349,28 @@ void Operation::updateOrderIfNecessary() {
|
||||||
|
|
||||||
auto llvm::ilist_detail::SpecificNodeAccess<
|
auto llvm::ilist_detail::SpecificNodeAccess<
|
||||||
typename llvm::ilist_detail::compute_node_options<
|
typename llvm::ilist_detail::compute_node_options<
|
||||||
::mlir::Operation>::type>::getNodePtr(pointer N) -> node_type * {
|
::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * {
|
||||||
return NodeAccess::getNodePtr<OptionsT>(N);
|
return NodeAccess::getNodePtr<OptionsT>(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto llvm::ilist_detail::SpecificNodeAccess<
|
auto llvm::ilist_detail::SpecificNodeAccess<
|
||||||
typename llvm::ilist_detail::compute_node_options<
|
typename llvm::ilist_detail::compute_node_options<
|
||||||
::mlir::Operation>::type>::getNodePtr(const_pointer N)
|
::mlir::Operation>::type>::getNodePtr(const_pointer n)
|
||||||
-> const node_type * {
|
-> const node_type * {
|
||||||
return NodeAccess::getNodePtr<OptionsT>(N);
|
return NodeAccess::getNodePtr<OptionsT>(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto llvm::ilist_detail::SpecificNodeAccess<
|
auto llvm::ilist_detail::SpecificNodeAccess<
|
||||||
typename llvm::ilist_detail::compute_node_options<
|
typename llvm::ilist_detail::compute_node_options<
|
||||||
::mlir::Operation>::type>::getValuePtr(node_type *N) -> pointer {
|
::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer {
|
||||||
return NodeAccess::getValuePtr<OptionsT>(N);
|
return NodeAccess::getValuePtr<OptionsT>(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto llvm::ilist_detail::SpecificNodeAccess<
|
auto llvm::ilist_detail::SpecificNodeAccess<
|
||||||
typename llvm::ilist_detail::compute_node_options<
|
typename llvm::ilist_detail::compute_node_options<
|
||||||
::mlir::Operation>::type>::getValuePtr(const node_type *N)
|
::mlir::Operation>::type>::getValuePtr(const node_type *n)
|
||||||
-> const_pointer {
|
-> const_pointer {
|
||||||
return NodeAccess::getValuePtr<OptionsT>(N);
|
return NodeAccess::getValuePtr<OptionsT>(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
|
void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
|
||||||
|
@ -378,9 +378,9 @@ void llvm::ilist_traits<::mlir::Operation>::deleteNode(Operation *op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
|
Block *llvm::ilist_traits<::mlir::Operation>::getContainingBlock() {
|
||||||
size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
|
size_t offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
|
||||||
iplist<Operation> *Anchor(static_cast<iplist<Operation> *>(this));
|
iplist<Operation> *anchor(static_cast<iplist<Operation> *>(this));
|
||||||
return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
|
return reinterpret_cast<Block *>(reinterpret_cast<char *>(anchor) - offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This is a trait method invoked when an operation is added to a block. We
|
/// This is a trait method invoked when an operation is added to a block. We
|
||||||
|
@ -1024,8 +1024,7 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
|
||||||
if (op->getNumRegions() > 1)
|
if (op->getNumRegions() > 1)
|
||||||
return op->emitOpError("region #")
|
return op->emitOpError("region #")
|
||||||
<< region.getRegionNumber() << " should have no arguments";
|
<< region.getRegionNumber() << " should have no arguments";
|
||||||
else
|
return op->emitOpError("region should have no arguments");
|
||||||
return op->emitOpError("region should have no arguments");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -34,8 +34,8 @@ NamedAttrList::NamedAttrList(DictionaryAttr attributes)
|
||||||
dictionarySorted.setPointerAndInt(attributes, true);
|
dictionarySorted.setPointerAndInt(attributes, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) {
|
NamedAttrList::NamedAttrList(const_iterator inStart, const_iterator inEnd) {
|
||||||
assign(in_start, in_end);
|
assign(inStart, inEnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
|
ArrayRef<NamedAttribute> NamedAttrList::getAttrs() const { return attrs; }
|
||||||
|
@ -66,8 +66,8 @@ void NamedAttrList::append(StringRef name, Attribute attr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Replaces the attributes with new list of attributes.
|
/// Replaces the attributes with new list of attributes.
|
||||||
void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) {
|
void NamedAttrList::assign(const_iterator inStart, const_iterator inEnd) {
|
||||||
DictionaryAttr::sort(ArrayRef<NamedAttribute>{in_start, in_end}, attrs);
|
DictionaryAttr::sort(ArrayRef<NamedAttribute>{inStart, inEnd}, attrs);
|
||||||
dictionarySorted.setPointerAndInt(nullptr, true);
|
dictionarySorted.setPointerAndInt(nullptr, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -152,10 +152,10 @@ void Region::dropAllReferences() {
|
||||||
}
|
}
|
||||||
|
|
||||||
Region *llvm::ilist_traits<::mlir::Block>::getParentRegion() {
|
Region *llvm::ilist_traits<::mlir::Block>::getParentRegion() {
|
||||||
size_t Offset(
|
size_t offset(
|
||||||
size_t(&((Region *)nullptr->*Region::getSublistAccess(nullptr))));
|
size_t(&((Region *)nullptr->*Region::getSublistAccess(nullptr))));
|
||||||
iplist<Block> *Anchor(static_cast<iplist<Block> *>(this));
|
iplist<Block> *anchor(static_cast<iplist<Block> *>(this));
|
||||||
return reinterpret_cast<Region *>(reinterpret_cast<char *>(Anchor) - Offset);
|
return reinterpret_cast<Region *>(reinterpret_cast<char *>(anchor) - offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This is a trait method invoked when a basic block is added to a region.
|
/// This is a trait method invoked when a basic block is added to a region.
|
||||||
|
|
|
@ -76,9 +76,9 @@ static bool wouldOpBeTriviallyDeadImpl(Operation *rootOp) {
|
||||||
|
|
||||||
// Otherwise, if the op has recursive side effects we can treat the
|
// Otherwise, if the op has recursive side effects we can treat the
|
||||||
// operation itself as having no effects.
|
// operation itself as having no effects.
|
||||||
} else if (hasRecursiveEffects) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
if (hasRecursiveEffects)
|
||||||
|
continue;
|
||||||
|
|
||||||
// If there were no effect interfaces, we treat this op as conservatively
|
// If there were no effect interfaces, we treat this op as conservatively
|
||||||
// having effects.
|
// having effects.
|
||||||
|
|
|
@ -525,13 +525,14 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
|
||||||
bool isColon = getToken().is(Token::colon);
|
bool isColon = getToken().is(Token::colon);
|
||||||
if (!isArrow && !isColon) {
|
if (!isArrow && !isColon) {
|
||||||
return emitError("expected '->' or ':'");
|
return emitError("expected '->' or ':'");
|
||||||
} else if (isArrow) {
|
}
|
||||||
|
if (isArrow) {
|
||||||
parseToken(Token::arrow, "expected '->' or '['");
|
parseToken(Token::arrow, "expected '->' or '['");
|
||||||
map = parseAffineMapRange(numDims, numSymbols);
|
map = parseAffineMapRange(numDims, numSymbols);
|
||||||
return map ? success() : failure();
|
return map ? success() : failure();
|
||||||
} else if (parseToken(Token::colon, "expected ':' or '['")) {
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
|
if (parseToken(Token::colon, "expected ':' or '['"))
|
||||||
|
return failure();
|
||||||
|
|
||||||
if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
|
if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -358,8 +358,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
|
||||||
PassInstrumentor *pi = am.getPassInstrumentor();
|
PassInstrumentor *pi = am.getPassInstrumentor();
|
||||||
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
|
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
|
||||||
pass};
|
pass};
|
||||||
auto dynamic_pipeline_callback = [&](OpPassManager &pipeline,
|
auto dynamicPipelineCallback = [&](OpPassManager &pipeline,
|
||||||
Operation *root) -> LogicalResult {
|
Operation *root) -> LogicalResult {
|
||||||
if (!op->isAncestor(root))
|
if (!op->isAncestor(root))
|
||||||
return root->emitOpError()
|
return root->emitOpError()
|
||||||
<< "Trying to schedule a dynamic pipeline on an "
|
<< "Trying to schedule a dynamic pipeline on an "
|
||||||
|
@ -379,7 +379,7 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
|
||||||
verifyPasses, parentInitGeneration,
|
verifyPasses, parentInitGeneration,
|
||||||
pi, &parentInfo);
|
pi, &parentInfo);
|
||||||
};
|
};
|
||||||
pass->passState.emplace(op, am, dynamic_pipeline_callback);
|
pass->passState.emplace(op, am, dynamicPipelineCallback);
|
||||||
|
|
||||||
// Instrument before the pass has run.
|
// Instrument before the pass has run.
|
||||||
if (pi)
|
if (pi)
|
||||||
|
@ -437,7 +437,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
|
||||||
const PassInstrumentation::PipelineParentInfo *parentInfo) {
|
const PassInstrumentation::PipelineParentInfo *parentInfo) {
|
||||||
assert((!instrumentor || parentInfo) &&
|
assert((!instrumentor || parentInfo) &&
|
||||||
"expected parent info if instrumentor is provided");
|
"expected parent info if instrumentor is provided");
|
||||||
auto scope_exit = llvm::make_scope_exit([&] {
|
auto scopeExit = llvm::make_scope_exit([&] {
|
||||||
// Clear out any computed operation analyses. These analyses won't be used
|
// Clear out any computed operation analyses. These analyses won't be used
|
||||||
// any more in this pipeline, and this helps reduce the current working set
|
// any more in this pipeline, and this helps reduce the current working set
|
||||||
// of memory. If preserving these analyses becomes important in the future
|
// of memory. If preserving these analyses becomes important in the future
|
||||||
|
@ -460,7 +460,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
|
||||||
/// type, or nullptr if one does not exist.
|
/// type, or nullptr if one does not exist.
|
||||||
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
|
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
|
||||||
StringRef name) {
|
StringRef name) {
|
||||||
auto it = llvm::find_if(
|
auto *it = llvm::find_if(
|
||||||
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
|
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
|
||||||
return it == mgrs.end() ? nullptr : &*it;
|
return it == mgrs.end() ? nullptr : &*it;
|
||||||
}
|
}
|
||||||
|
@ -470,7 +470,7 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
|
||||||
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
|
static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
|
||||||
StringAttr name,
|
StringAttr name,
|
||||||
MLIRContext &context) {
|
MLIRContext &context) {
|
||||||
auto it = llvm::find_if(
|
auto *it = llvm::find_if(
|
||||||
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
|
mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
|
||||||
return it == mgrs.end() ? nullptr : &*it;
|
return it == mgrs.end() ? nullptr : &*it;
|
||||||
}
|
}
|
||||||
|
|
|
@ -253,7 +253,7 @@ StringRef StructFieldAttr::getName() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute StructFieldAttr::getType() const {
|
Attribute StructFieldAttr::getType() const {
|
||||||
auto init = def->getValueInit("type");
|
auto *init = def->getValueInit("type");
|
||||||
return Attribute(cast<llvm::DefInit>(init));
|
return Attribute(cast<llvm::DefInit>(init));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ std::string Dialect::getCppClassName() const {
|
||||||
|
|
||||||
static StringRef getAsStringOrEmpty(const llvm::Record &record,
|
static StringRef getAsStringOrEmpty(const llvm::Record &record,
|
||||||
StringRef fieldName) {
|
StringRef fieldName) {
|
||||||
if (auto valueInit = record.getValueInit(fieldName)) {
|
if (auto *valueInit = record.getValueInit(fieldName)) {
|
||||||
if (llvm::isa<llvm::StringInit>(valueInit))
|
if (llvm::isa<llvm::StringInit>(valueInit))
|
||||||
return record.getValueAsString(fieldName);
|
return record.getValueAsString(fieldName);
|
||||||
}
|
}
|
||||||
|
|
|
@ -346,10 +346,9 @@ void Operator::populateTypeInferenceInfo(
|
||||||
if (getArg(*mi).is<NamedAttribute *>()) {
|
if (getArg(*mi).is<NamedAttribute *>()) {
|
||||||
// TODO: Handle attributes.
|
// TODO: Handle attributes.
|
||||||
continue;
|
continue;
|
||||||
} else {
|
|
||||||
resultTypeMapping[i].emplace_back(*mi);
|
|
||||||
found = true;
|
|
||||||
}
|
}
|
||||||
|
resultTypeMapping[i].emplace_back(*mi);
|
||||||
|
found = true;
|
||||||
}
|
}
|
||||||
return found;
|
return found;
|
||||||
};
|
};
|
||||||
|
|
|
@ -649,7 +649,7 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
|
||||||
std::vector<AppliedConstraint> ret;
|
std::vector<AppliedConstraint> ret;
|
||||||
ret.reserve(listInit->size());
|
ret.reserve(listInit->size());
|
||||||
|
|
||||||
for (auto it : *listInit) {
|
for (auto *it : *listInit) {
|
||||||
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
||||||
if (!dagInit)
|
if (!dagInit)
|
||||||
PrintFatalError(&def, "all elements in Pattern multi-entity "
|
PrintFatalError(&def, "all elements in Pattern multi-entity "
|
||||||
|
|
|
@ -188,7 +188,7 @@ buildPredicateTree(const Pred &root,
|
||||||
// Build child subtrees.
|
// Build child subtrees.
|
||||||
auto combined = static_cast<const CombinedPred &>(root);
|
auto combined = static_cast<const CombinedPred &>(root);
|
||||||
for (const auto *record : combined.getChildren()) {
|
for (const auto *record : combined.getChildren()) {
|
||||||
auto childTree =
|
auto *childTree =
|
||||||
buildPredicateTree(Pred(record), allocator, allSubstitutions);
|
buildPredicateTree(Pred(record), allocator, allSubstitutions);
|
||||||
rootNode->children.push_back(childTree);
|
rootNode->children.push_back(childTree);
|
||||||
}
|
}
|
||||||
|
@ -241,7 +241,7 @@ propagateGroundTruth(PredNode *node,
|
||||||
|
|
||||||
for (auto &child : children) {
|
for (auto &child : children) {
|
||||||
// First, simplify the child. This maintains the predicate as it was.
|
// First, simplify the child. This maintains the predicate as it was.
|
||||||
auto simplifiedChild =
|
auto *simplifiedChild =
|
||||||
propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
|
propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
|
||||||
|
|
||||||
// Just add the child if we don't know how to simplify the current node.
|
// Just add the child if we don't know how to simplify the current node.
|
||||||
|
@ -273,8 +273,9 @@ propagateGroundTruth(PredNode *node,
|
||||||
node->kind = collapseKind;
|
node->kind = collapseKind;
|
||||||
node->children.clear();
|
node->children.clear();
|
||||||
return node;
|
return node;
|
||||||
} else if (simplifiedChild->kind == eraseKind ||
|
}
|
||||||
eraseList.count(simplifiedChild->predicate) != 0) {
|
if (simplifiedChild->kind == eraseKind ||
|
||||||
|
eraseList.count(simplifiedChild->predicate) != 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
node->children.push_back(simplifiedChild);
|
node->children.push_back(simplifiedChild);
|
||||||
|
@ -350,7 +351,7 @@ static std::string getCombinedCondition(const PredNode &root) {
|
||||||
|
|
||||||
std::string CombinedPred::getConditionImpl() const {
|
std::string CombinedPred::getConditionImpl() const {
|
||||||
llvm::SpecificBumpPtrAllocator<PredNode> allocator;
|
llvm::SpecificBumpPtrAllocator<PredNode> allocator;
|
||||||
auto predicateTree = buildPredicateTree(*this, allocator, {});
|
auto *predicateTree = buildPredicateTree(*this, allocator, {});
|
||||||
predicateTree =
|
predicateTree =
|
||||||
propagateGroundTruth(predicateTree,
|
propagateGroundTruth(predicateTree,
|
||||||
/*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
|
/*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
|
||||||
|
|
|
@ -26,7 +26,7 @@ using namespace mlir::tblgen;
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Trait Trait::create(const llvm::Init *init) {
|
Trait Trait::create(const llvm::Init *init) {
|
||||||
auto def = cast<llvm::DefInit>(init)->getDef();
|
auto *def = cast<llvm::DefInit>(init)->getDef();
|
||||||
if (def->isSubClassOf("PredTrait"))
|
if (def->isSubClassOf("PredTrait"))
|
||||||
return Trait(Kind::Pred, def);
|
return Trait(Kind::Pred, def);
|
||||||
if (def->isSubClassOf("GenInternalTrait"))
|
if (def->isSubClassOf("GenInternalTrait"))
|
||||||
|
|
|
@ -61,7 +61,7 @@ public:
|
||||||
LogicalResult processFunction(llvm::Function *f);
|
LogicalResult processFunction(llvm::Function *f);
|
||||||
|
|
||||||
/// Imports GV as a GlobalOp, creating it if it doesn't exist.
|
/// Imports GV as a GlobalOp, creating it if it doesn't exist.
|
||||||
GlobalOp processGlobal(llvm::GlobalVariable *GV);
|
GlobalOp processGlobal(llvm::GlobalVariable *gv);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Returns personality of `f` as a FlatSymbolRefAttr.
|
/// Returns personality of `f` as a FlatSymbolRefAttr.
|
||||||
|
@ -145,7 +145,8 @@ Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
|
||||||
os << "llvm-imported-inst-%";
|
os << "llvm-imported-inst-%";
|
||||||
inst->printAsOperand(os, /*PrintType=*/false);
|
inst->printAsOperand(os, /*PrintType=*/false);
|
||||||
return FileLineColLoc::get(context, os.str(), 0, 0);
|
return FileLineColLoc::get(context, os.str(), 0, 0);
|
||||||
} else if (!loc) {
|
}
|
||||||
|
if (!loc) {
|
||||||
return unknownLoc;
|
return unknownLoc;
|
||||||
}
|
}
|
||||||
// FIXME: Obtain the filename from DILocationInfo.
|
// FIXME: Obtain the filename from DILocationInfo.
|
||||||
|
@ -304,47 +305,47 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
|
GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) {
|
||||||
auto it = globals.find(GV);
|
auto it = globals.find(gv);
|
||||||
if (it != globals.end())
|
if (it != globals.end())
|
||||||
return it->second;
|
return it->second;
|
||||||
|
|
||||||
OpBuilder b(module.getBody(), getGlobalInsertPt());
|
OpBuilder b(module.getBody(), getGlobalInsertPt());
|
||||||
Attribute valueAttr;
|
Attribute valueAttr;
|
||||||
if (GV->hasInitializer())
|
if (gv->hasInitializer())
|
||||||
valueAttr = getConstantAsAttr(GV->getInitializer());
|
valueAttr = getConstantAsAttr(gv->getInitializer());
|
||||||
Type type = processType(GV->getValueType());
|
Type type = processType(gv->getValueType());
|
||||||
if (!type)
|
if (!type)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
uint64_t alignment = 0;
|
uint64_t alignment = 0;
|
||||||
llvm::MaybeAlign maybeAlign = GV->getAlign();
|
llvm::MaybeAlign maybeAlign = gv->getAlign();
|
||||||
if (maybeAlign.hasValue()) {
|
if (maybeAlign.hasValue()) {
|
||||||
llvm::Align align = maybeAlign.getValue();
|
llvm::Align align = maybeAlign.getValue();
|
||||||
alignment = align.value();
|
alignment = align.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalOp op =
|
GlobalOp op =
|
||||||
b.create<GlobalOp>(UnknownLoc::get(context), type, GV->isConstant(),
|
b.create<GlobalOp>(UnknownLoc::get(context), type, gv->isConstant(),
|
||||||
convertLinkageFromLLVM(GV->getLinkage()),
|
convertLinkageFromLLVM(gv->getLinkage()),
|
||||||
GV->getName(), valueAttr, alignment);
|
gv->getName(), valueAttr, alignment);
|
||||||
|
|
||||||
if (GV->hasInitializer() && !valueAttr) {
|
if (gv->hasInitializer() && !valueAttr) {
|
||||||
Region &r = op.getInitializerRegion();
|
Region &r = op.getInitializerRegion();
|
||||||
currentEntryBlock = b.createBlock(&r);
|
currentEntryBlock = b.createBlock(&r);
|
||||||
b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
|
b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
|
||||||
Value v = processConstant(GV->getInitializer());
|
Value v = processConstant(gv->getInitializer());
|
||||||
if (!v)
|
if (!v)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
|
b.create<ReturnOp>(op.getLoc(), ArrayRef<Value>({v}));
|
||||||
}
|
}
|
||||||
if (GV->hasAtLeastLocalUnnamedAddr())
|
if (gv->hasAtLeastLocalUnnamedAddr())
|
||||||
op.setUnnamedAddrAttr(UnnamedAddrAttr::get(
|
op.setUnnamedAddrAttr(UnnamedAddrAttr::get(
|
||||||
context, convertUnnamedAddrFromLLVM(GV->getUnnamedAddr())));
|
context, convertUnnamedAddrFromLLVM(gv->getUnnamedAddr())));
|
||||||
if (GV->hasSection())
|
if (gv->hasSection())
|
||||||
op.setSectionAttr(b.getStringAttr(GV->getSection()));
|
op.setSectionAttr(b.getStringAttr(gv->getSection()));
|
||||||
|
|
||||||
return globals[GV] = op;
|
return globals[gv] = op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value Importer::processConstant(llvm::Constant *c) {
|
Value Importer::processConstant(llvm::Constant *c) {
|
||||||
|
@ -366,9 +367,9 @@ Value Importer::processConstant(llvm::Constant *c) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
|
return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
|
||||||
}
|
}
|
||||||
if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
|
if (auto *gv = dyn_cast<llvm::GlobalVariable>(c))
|
||||||
return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
|
return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
|
||||||
processGlobal(GV));
|
processGlobal(gv));
|
||||||
|
|
||||||
if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
|
if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
|
||||||
llvm::Instruction *i = ce->getAsInstruction();
|
llvm::Instruction *i = ce->getAsInstruction();
|
||||||
|
@ -526,8 +527,8 @@ LogicalResult
|
||||||
Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
|
Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
|
||||||
SmallVectorImpl<Value> &blockArguments) {
|
SmallVectorImpl<Value> &blockArguments) {
|
||||||
for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
|
for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
|
||||||
auto *PN = cast<llvm::PHINode>(&*inst);
|
auto *pn = cast<llvm::PHINode>(&*inst);
|
||||||
Value value = processValue(PN->getIncomingValueForBlock(br->getParent()));
|
Value value = processValue(pn->getIncomingValueForBlock(br->getParent()));
|
||||||
if (!value)
|
if (!value)
|
||||||
return failure();
|
return failure();
|
||||||
blockArguments.push_back(value);
|
blockArguments.push_back(value);
|
||||||
|
@ -777,10 +778,10 @@ FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) {
|
||||||
|
|
||||||
// If it doesn't have a name, currently, only function pointers that are
|
// If it doesn't have a name, currently, only function pointers that are
|
||||||
// bitcast to i8* are parsed.
|
// bitcast to i8* are parsed.
|
||||||
if (auto ce = dyn_cast<llvm::ConstantExpr>(pf)) {
|
if (auto *ce = dyn_cast<llvm::ConstantExpr>(pf)) {
|
||||||
if (ce->getOpcode() == llvm::Instruction::BitCast &&
|
if (ce->getOpcode() == llvm::Instruction::BitCast &&
|
||||||
ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
|
ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
|
||||||
if (auto func = dyn_cast<llvm::Function>(ce->getOperand(0)))
|
if (auto *func = dyn_cast<llvm::Function>(ce->getOperand(0)))
|
||||||
return SymbolRefAttr::get(b.getContext(), func->getName());
|
return SymbolRefAttr::get(b.getContext(), func->getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,12 +55,11 @@ static llvm::Constant *createSourceLocStrFromLocation(Location loc,
|
||||||
unsigned lineNo = fileLoc.getLine();
|
unsigned lineNo = fileLoc.getLine();
|
||||||
unsigned colNo = fileLoc.getColumn();
|
unsigned colNo = fileLoc.getColumn();
|
||||||
return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
|
return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
|
||||||
} else {
|
|
||||||
std::string locStr;
|
|
||||||
llvm::raw_string_ostream locOS(locStr);
|
|
||||||
locOS << loc;
|
|
||||||
return builder.getOrCreateSrcLocStr(locOS.str());
|
|
||||||
}
|
}
|
||||||
|
std::string locStr;
|
||||||
|
llvm::raw_string_ostream locOS(locStr);
|
||||||
|
locOS << loc;
|
||||||
|
return builder.getOrCreateSrcLocStr(locOS.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create the location struct from the operation location information.
|
/// Create the location struct from the operation location information.
|
||||||
|
@ -81,9 +80,8 @@ static llvm::Constant *createMappingInformation(Location loc,
|
||||||
if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
|
if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
|
||||||
StringRef name = nameLoc.getName();
|
StringRef name = nameLoc.getName();
|
||||||
return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
|
return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
|
||||||
} else {
|
|
||||||
return createSourceLocStrFromLocation(loc, builder, "unknown");
|
|
||||||
}
|
}
|
||||||
|
return createSourceLocStrFromLocation(loc, builder, "unknown");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the runtime function used to lower the given operation.
|
/// Return the runtime function used to lower the given operation.
|
||||||
|
|
|
@ -861,11 +861,11 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
|
// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
|
||||||
llvm::AtomicOrdering convertAtomicOrdering(Optional<StringRef> AOAttr) {
|
llvm::AtomicOrdering convertAtomicOrdering(Optional<StringRef> aoAttr) {
|
||||||
if (!AOAttr.hasValue())
|
if (!aoAttr.hasValue())
|
||||||
return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
|
return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
|
||||||
|
|
||||||
return StringSwitch<llvm::AtomicOrdering>(AOAttr.getValue())
|
return StringSwitch<llvm::AtomicOrdering>(aoAttr.getValue())
|
||||||
.Case("seq_cst", llvm::AtomicOrdering::SequentiallyConsistent)
|
.Case("seq_cst", llvm::AtomicOrdering::SequentiallyConsistent)
|
||||||
.Case("acq_rel", llvm::AtomicOrdering::AcquireRelease)
|
.Case("acq_rel", llvm::AtomicOrdering::AcquireRelease)
|
||||||
.Case("acquire", llvm::AtomicOrdering::Acquire)
|
.Case("acquire", llvm::AtomicOrdering::Acquire)
|
||||||
|
@ -889,7 +889,7 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||||
moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
|
moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
|
||||||
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
|
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
|
||||||
llvm::DebugLoc(diLoc));
|
llvm::DebugLoc(diLoc));
|
||||||
llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order());
|
llvm::AtomicOrdering ao = convertAtomicOrdering(readOp.memory_order());
|
||||||
llvm::Value *address = moduleTranslation.lookupValue(readOp.address());
|
llvm::Value *address = moduleTranslation.lookupValue(readOp.address());
|
||||||
llvm::OpenMPIRBuilder::InsertPointTy currentIP = builder.saveIP();
|
llvm::OpenMPIRBuilder::InsertPointTy currentIP = builder.saveIP();
|
||||||
|
|
||||||
|
@ -903,9 +903,9 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||||
|
|
||||||
// Restore the IP and insert Atomic Read.
|
// Restore the IP and insert Atomic Read.
|
||||||
builder.restoreIP(currentIP);
|
builder.restoreIP(currentIP);
|
||||||
llvm::OpenMPIRBuilder::AtomicOpValue V = {v, false, false};
|
llvm::OpenMPIRBuilder::AtomicOpValue atomicV = {v, false, false};
|
||||||
llvm::OpenMPIRBuilder::AtomicOpValue X = {address, false, false};
|
llvm::OpenMPIRBuilder::AtomicOpValue x = {address, false, false};
|
||||||
builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
|
builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, x, atomicV, ao));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,17 +29,17 @@ using mlir::LLVM::detail::createIntrinsicCall;
|
||||||
// take a single int32 argument. It is likely that the interface of this
|
// take a single int32 argument. It is likely that the interface of this
|
||||||
// function will change to make it more generic.
|
// function will change to make it more generic.
|
||||||
static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
|
static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder,
|
||||||
StringRef fn_name, int parameter) {
|
StringRef fnName, int parameter) {
|
||||||
llvm::Module *module = builder.GetInsertBlock()->getModule();
|
llvm::Module *module = builder.GetInsertBlock()->getModule();
|
||||||
llvm::FunctionType *function_type = llvm::FunctionType::get(
|
llvm::FunctionType *functionType = llvm::FunctionType::get(
|
||||||
llvm::Type::getInt64Ty(module->getContext()), // return type.
|
llvm::Type::getInt64Ty(module->getContext()), // return type.
|
||||||
llvm::Type::getInt32Ty(module->getContext()), // parameter type.
|
llvm::Type::getInt32Ty(module->getContext()), // parameter type.
|
||||||
false); // no variadic arguments.
|
false); // no variadic arguments.
|
||||||
llvm::Function *fn = dyn_cast<llvm::Function>(
|
llvm::Function *fn = dyn_cast<llvm::Function>(
|
||||||
module->getOrInsertFunction(fn_name, function_type).getCallee());
|
module->getOrInsertFunction(fnName, functionType).getCallee());
|
||||||
llvm::Value *fn_op0 = llvm::ConstantInt::get(
|
llvm::Value *fnOp0 = llvm::ConstantInt::get(
|
||||||
llvm::Type::getInt32Ty(module->getContext()), parameter);
|
llvm::Type::getInt32Ty(module->getContext()), parameter);
|
||||||
return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fn_op0));
|
return builder.CreateCall(fn, ArrayRef<llvm::Value *>(fnOp0));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -242,10 +242,10 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
||||||
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
|
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
|
||||||
elementType = arrayTy->getElementType();
|
elementType = arrayTy->getElementType();
|
||||||
numElements = arrayTy->getNumElements();
|
numElements = arrayTy->getNumElements();
|
||||||
} else if (auto fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
|
} else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
|
||||||
elementType = fVectorTy->getElementType();
|
elementType = fVectorTy->getElementType();
|
||||||
numElements = fVectorTy->getNumElements();
|
numElements = fVectorTy->getNumElements();
|
||||||
} else if (auto sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
|
} else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
|
||||||
elementType = sVectorTy->getElementType();
|
elementType = sVectorTy->getElementType();
|
||||||
numElements = sVectorTy->getMinNumElements();
|
numElements = sVectorTy->getMinNumElements();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1525,7 +1525,7 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
|
||||||
|
|
||||||
// Handle named results.
|
// Handle named results.
|
||||||
auto elementNames = tupleType.getElementNames();
|
auto elementNames = tupleType.getElementNames();
|
||||||
auto it = llvm::find(elementNames, name);
|
const auto *it = llvm::find(elementNames, name);
|
||||||
if (it != elementNames.end())
|
if (it != elementNames.end())
|
||||||
return tupleType.getElementTypes()[it - elementNames.begin()];
|
return tupleType.getElementTypes()[it - elementNames.begin()];
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,7 +133,7 @@ static bool isDefOrUse(const AsmParserState::SMDefinition &def, llvm::SMLoc loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the uses.
|
// Check the uses.
|
||||||
auto useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
|
const auto *useIt = llvm::find_if(def.uses, [&](const llvm::SMRange &range) {
|
||||||
return contains(range, loc);
|
return contains(range, loc);
|
||||||
});
|
});
|
||||||
if (useIt != def.uses.end()) {
|
if (useIt != def.uses.end()) {
|
||||||
|
|
|
@ -42,20 +42,20 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
|
||||||
MLIRContext &context) {
|
MLIRContext &context) {
|
||||||
// Override the default '-h' and use the default PrintHelpMessage() which
|
// Override the default '-h' and use the default PrintHelpMessage() which
|
||||||
// won't print options in categories.
|
// won't print options in categories.
|
||||||
static llvm::cl::opt<bool> Help("h", llvm::cl::desc("Alias for -help"),
|
static llvm::cl::opt<bool> help("h", llvm::cl::desc("Alias for -help"),
|
||||||
llvm::cl::Hidden);
|
llvm::cl::Hidden);
|
||||||
|
|
||||||
static llvm::cl::OptionCategory MLIRReduceCategory("mlir-reduce options");
|
static llvm::cl::OptionCategory mlirReduceCategory("mlir-reduce options");
|
||||||
|
|
||||||
static llvm::cl::opt<std::string> inputFilename(
|
static llvm::cl::opt<std::string> inputFilename(
|
||||||
llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
||||||
llvm::cl::cat(MLIRReduceCategory));
|
llvm::cl::cat(mlirReduceCategory));
|
||||||
|
|
||||||
static llvm::cl::opt<std::string> outputFilename(
|
static llvm::cl::opt<std::string> outputFilename(
|
||||||
"o", llvm::cl::desc("Output filename for the reduced test case"),
|
"o", llvm::cl::desc("Output filename for the reduced test case"),
|
||||||
llvm::cl::init("-"), llvm::cl::cat(MLIRReduceCategory));
|
llvm::cl::init("-"), llvm::cl::cat(mlirReduceCategory));
|
||||||
|
|
||||||
llvm::cl::HideUnrelatedOptions(MLIRReduceCategory);
|
llvm::cl::HideUnrelatedOptions(mlirReduceCategory);
|
||||||
|
|
||||||
llvm::InitLLVM y(argc, argv);
|
llvm::InitLLVM y(argc, argv);
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv,
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv,
|
llvm::cl::ParseCommandLineOptions(argc, argv,
|
||||||
"MLIR test case reduction tool.\n");
|
"MLIR test case reduction tool.\n");
|
||||||
|
|
||||||
if (Help) {
|
if (help) {
|
||||||
llvm::cl::PrintHelpMessage();
|
llvm::cl::PrintHelpMessage();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -301,14 +301,15 @@ public:
|
||||||
memrefEdgeCount[value]--;
|
memrefEdgeCount[value]--;
|
||||||
}
|
}
|
||||||
// Remove 'srcId' from 'inEdges[dstId]'.
|
// Remove 'srcId' from 'inEdges[dstId]'.
|
||||||
for (auto it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
|
for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
|
||||||
if ((*it).id == srcId && (*it).value == value) {
|
if ((*it).id == srcId && (*it).value == value) {
|
||||||
inEdges[dstId].erase(it);
|
inEdges[dstId].erase(it);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Remove 'dstId' from 'outEdges[srcId]'.
|
// Remove 'dstId' from 'outEdges[srcId]'.
|
||||||
for (auto it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
|
for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
|
||||||
|
++it) {
|
||||||
if ((*it).id == dstId && (*it).value == value) {
|
if ((*it).id == dstId && (*it).value == value) {
|
||||||
outEdges[srcId].erase(it);
|
outEdges[srcId].erase(it);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -85,7 +85,7 @@ LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) {
|
||||||
|
|
||||||
// Helper to check whether an operation is loop invariant wrt. SSA properties.
|
// Helper to check whether an operation is loop invariant wrt. SSA properties.
|
||||||
auto isDefinedOutsideOfBody = [&](Value value) {
|
auto isDefinedOutsideOfBody = [&](Value value) {
|
||||||
auto definingOp = value.getDefiningOp();
|
auto *definingOp = value.getDefiningOp();
|
||||||
return (definingOp && !!willBeMovedSet.count(definingOp)) ||
|
return (definingOp && !!willBeMovedSet.count(definingOp)) ||
|
||||||
looplike.isDefinedOutsideOfLoop(value);
|
looplike.isDefinedOutsideOfLoop(value);
|
||||||
};
|
};
|
||||||
|
|
|
@ -517,6 +517,6 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
|
||||||
newRegion->takeBody(oldRegion);
|
newRegion->takeBody(oldRegion);
|
||||||
}
|
}
|
||||||
return bb.createOperation(result);
|
return bb.createOperation(result);
|
||||||
} else
|
}
|
||||||
return oldOp;
|
return oldOp;
|
||||||
}
|
}
|
||||||
|
|
|
@ -191,7 +191,7 @@ static void findMatchingStartFinishInsts(
|
||||||
// Check for dependence with outgoing DMAs. Doing this conservatively.
|
// Check for dependence with outgoing DMAs. Doing this conservatively.
|
||||||
// TODO: use the dependence analysis to check for
|
// TODO: use the dependence analysis to check for
|
||||||
// dependences between an incoming and outgoing DMA in the same iteration.
|
// dependences between an incoming and outgoing DMA in the same iteration.
|
||||||
auto it = outgoingDmaOps.begin();
|
auto *it = outgoingDmaOps.begin();
|
||||||
for (; it != outgoingDmaOps.end(); ++it) {
|
for (; it != outgoingDmaOps.end(); ++it) {
|
||||||
if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
|
if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -168,7 +168,7 @@ LogicalResult OperationFolder::tryToFold(
|
||||||
if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
|
if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
|
||||||
std::stable_partition(
|
std::stable_partition(
|
||||||
op->getOpOperands().begin(), op->getOpOperands().end(),
|
op->getOpOperands().begin(), op->getOpOperands().end(),
|
||||||
[&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); });
|
[&](OpOperand &o) { return !matchPattern(o.get(), m_Constant()); });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check to see if any operands to the operation is constant and whether
|
// Check to see if any operands to the operation is constant and whether
|
||||||
|
|
|
@ -56,7 +56,8 @@ static bool isDependentLoadOrStoreOp(Operation *op,
|
||||||
if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
|
if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
|
||||||
return values.count(loadOp.getMemRef()) > 0 &&
|
return values.count(loadOp.getMemRef()) > 0 &&
|
||||||
values[loadOp.getMemRef()] == true;
|
values[loadOp.getMemRef()] == true;
|
||||||
} else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
|
}
|
||||||
|
if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
|
||||||
return values.count(storeOp.getMemRef()) > 0;
|
return values.count(storeOp.getMemRef()) > 0;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -3034,7 +3034,7 @@ uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
|
||||||
auto updateRegion =
|
auto updateRegion =
|
||||||
[&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
|
[&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
|
||||||
&targetRegions) {
|
&targetRegions) {
|
||||||
const auto it = targetRegions.find(region->memref);
|
const auto *const it = targetRegions.find(region->memref);
|
||||||
if (it == targetRegions.end())
|
if (it == targetRegions.end())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,7 @@ struct TestAliasAnalysisPass
|
||||||
|
|
||||||
// Check for aliasing behavior between each of the values.
|
// Check for aliasing behavior between each of the values.
|
||||||
for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
|
for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
|
||||||
for (auto innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
|
for (auto *innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
|
||||||
printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
|
printAliasResult(aliasAnalysis.alias(*innerIt, *it), *innerIt, *it);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -52,9 +52,9 @@ struct TestMathPolynomialApproximationPass
|
||||||
|
|
||||||
void TestMathPolynomialApproximationPass::runOnFunction() {
|
void TestMathPolynomialApproximationPass::runOnFunction() {
|
||||||
RewritePatternSet patterns(&getContext());
|
RewritePatternSet patterns(&getContext());
|
||||||
MathPolynomialApproximationOptions approx_options;
|
MathPolynomialApproximationOptions approxOptions;
|
||||||
approx_options.enableAvx2 = enableAvx2;
|
approxOptions.enableAvx2 = enableAvx2;
|
||||||
populateMathPolynomialApproximationPatterns(patterns, approx_options);
|
populateMathPolynomialApproximationPatterns(patterns, approxOptions);
|
||||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -689,24 +689,24 @@ static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
|
||||||
Region &body = *result.addRegion();
|
Region &body = *result.addRegion();
|
||||||
body.push_back(new Block);
|
body.push_back(new Block);
|
||||||
Block &block = body.back();
|
Block &block = body.back();
|
||||||
Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
|
Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
|
||||||
if (!wrapped_op)
|
if (!wrappedOp)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Create a return terminator in the inner region, pass as operand to the
|
// Create a return terminator in the inner region, pass as operand to the
|
||||||
// terminator the returned values from the wrapped operation.
|
// terminator the returned values from the wrapped operation.
|
||||||
SmallVector<Value, 8> return_operands(wrapped_op->getResults());
|
SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
|
||||||
OpBuilder builder(parser.getContext());
|
OpBuilder builder(parser.getContext());
|
||||||
builder.setInsertionPointToEnd(&block);
|
builder.setInsertionPointToEnd(&block);
|
||||||
builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
|
builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
|
||||||
|
|
||||||
// Get the results type for the wrapping op from the terminator operands.
|
// Get the results type for the wrapping op from the terminator operands.
|
||||||
Operation &return_op = body.back().back();
|
Operation &returnOp = body.back().back();
|
||||||
result.types.append(return_op.operand_type_begin(),
|
result.types.append(returnOp.operand_type_begin(),
|
||||||
return_op.operand_type_end());
|
returnOp.operand_type_end());
|
||||||
|
|
||||||
// Use the location of the wrapped op for the "test.wrapping_region" op.
|
// Use the location of the wrapped op for the "test.wrapping_region" op.
|
||||||
result.location = wrapped_op->getLoc();
|
result.location = wrappedOp->getLoc();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -808,7 +808,7 @@ def OpFuncRef : TEST_Op<"op_funcref"> {
|
||||||
// That way, we will know if operations is called once or twice.
|
// That way, we will know if operations is called once or twice.
|
||||||
def OpMGetNullAttr : NativeCodeCall<"Attribute()">;
|
def OpMGetNullAttr : NativeCodeCall<"Attribute()">;
|
||||||
def OpMAttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">;
|
def OpMAttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">;
|
||||||
def OpMVal : NativeCodeCall<"OpMTest($_builder, $0)">;
|
def OpMVal : NativeCodeCall<"opMTest($_builder, $0)">;
|
||||||
def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ),
|
def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ),
|
||||||
[(OpMAttributeIsNull:$optAttr)]>;
|
[(OpMAttributeIsNull:$optAttr)]>;
|
||||||
|
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue