[MLIR] Matrix: support resizing horizontally

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D109897
This commit is contained in:
Arjun P 2021-09-17 13:14:50 +05:30
parent 1a5ab3e97c
commit f263ea1571
3 changed files with 45 additions and 2 deletions

View File

@ -116,8 +116,14 @@ public:
void negateColumn(unsigned column);
/// Resize the matrix to the specified dimensions. If a dimension is smaller,
/// the values are truncated; if it is bigger, the new values are default
/// initialized.
/// the values are truncated; if it is bigger, the new values are initialized
/// to zero.
///
/// Due to the representation of the matrix, resizing vertically (adding rows)
/// is less expensive than increasing the number of columns beyond
/// nReservedColumns.
void resize(unsigned newNRows, unsigned newNColumns);
void resizeHorizontally(unsigned newNColumns);
void resizeVertically(unsigned newNRows);
/// Add an extra row at the bottom of the matrix and return its position.

View File

@ -65,6 +65,18 @@ unsigned Matrix::appendExtraRow() {
return nRows - 1;
}
void Matrix::resizeHorizontally(unsigned newNColumns) {
if (newNColumns < nColumns)
removeColumns(newNColumns, nColumns - newNColumns);
if (newNColumns > nColumns)
insertColumns(nColumns, newNColumns - nColumns);
}
void Matrix::resize(unsigned newNRows, unsigned newNColumns) {
resizeHorizontally(newNColumns);
resizeVertically(newNRows);
}
void Matrix::resizeVertically(unsigned newNRows) {
nRows = newNRows;
data.resize(nRows * nReservedColumns);

View File

@ -166,4 +166,29 @@ TEST(MatrixTest, insertRows) {
EXPECT_EQ(mat(row, col), row == 5 ? 0 : 10 * row + col);
}
TEST(MatrixTest, resize) {
Matrix mat(5, 5);
EXPECT_EQ(mat.getNumRows(), 5u);
EXPECT_EQ(mat.getNumColumns(), 5u);
for (unsigned row = 0; row < 5; ++row)
for (unsigned col = 0; col < 5; ++col)
mat(row, col) = 10 * row + col;
mat.resize(3, 3);
ASSERT_TRUE(mat.hasConsistentState());
EXPECT_EQ(mat.getNumRows(), 3u);
EXPECT_EQ(mat.getNumColumns(), 3u);
for (unsigned row = 0; row < 3; ++row)
for (unsigned col = 0; col < 3; ++col)
EXPECT_EQ(mat(row, col), int(10 * row + col));
mat.resize(7, 7);
ASSERT_TRUE(mat.hasConsistentState());
EXPECT_EQ(mat.getNumRows(), 7u);
EXPECT_EQ(mat.getNumColumns(), 7u);
for (unsigned row = 0; row < 7; ++row)
for (unsigned col = 0; col < 7; ++col)
EXPECT_EQ(mat(row, col), row >= 3 || col >= 3 ? 0 : int(10 * row + col));
}
} // namespace mlir