forked from OSchip/llvm-project
[mlir][python] Swap shape and element_type order for MemRefType.
* Matches how all of the other shaped types are declared. * No super principled reason fro this ordering beyond that it makes the one that was different be like the rest. * Also matches ordering of things like ndarray, et al. Reviewed By: ftynse, nicolasvasilache Differential Revision: https://reviews.llvm.org/D94812
This commit is contained in:
parent
7f36df0fb1
commit
b62c7e0474
|
@ -31,9 +31,9 @@ def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]:
|
|||
|
||||
|
||||
def build_matmul_buffers_func(func_name, m, k, n, dtype):
|
||||
lhs_type = MemRefType.get(dtype, [m, k])
|
||||
rhs_type = MemRefType.get(dtype, [k, n])
|
||||
result_type = MemRefType.get(dtype, [m, n])
|
||||
lhs_type = MemRefType.get([m, k], dtype)
|
||||
rhs_type = MemRefType.get([k, n], dtype)
|
||||
result_type = MemRefType.get([m, n], dtype)
|
||||
# TODO: There should be a one-liner for this.
|
||||
func_type = FunctionType.get([lhs_type, rhs_type, result_type], [])
|
||||
_, entry = FuncOp(func_name, func_type)
|
||||
|
@ -49,8 +49,6 @@ def build_matmul_buffers_func(func_name, m, k, n, dtype):
|
|||
|
||||
|
||||
def build_matmul_tensors_func(func_name, m, k, n, dtype):
|
||||
# TODO: MemRefType and TensorTypes should not have inverted dtype/shapes
|
||||
# from each other.
|
||||
lhs_type = RankedTensorType.get([m, k], dtype)
|
||||
rhs_type = RankedTensorType.get([k, n], dtype)
|
||||
result_type = RankedTensorType.get([m, n], dtype)
|
||||
|
|
|
@ -2832,7 +2832,7 @@ public:
|
|||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static(
|
||||
"get",
|
||||
[](PyType &elementType, std::vector<int64_t> shape,
|
||||
[](std::vector<int64_t> shape, PyType &elementType,
|
||||
std::vector<PyAffineMap> layout, unsigned memorySpace,
|
||||
DefaultingPyLocation loc) {
|
||||
SmallVector<MlirAffineMap> maps;
|
||||
|
@ -2856,7 +2856,7 @@ public:
|
|||
}
|
||||
return PyMemRefType(elementType.getContext(), t);
|
||||
},
|
||||
py::arg("element_type"), py::arg("shape"),
|
||||
py::arg("shape"), py::arg("element_type"),
|
||||
py::arg("layout") = py::list(), py::arg("memory_space") = 0,
|
||||
py::arg("loc") = py::none(), "Create a memref type")
|
||||
.def_property_readonly("layout", &PyMemRefType::getLayout,
|
||||
|
|
|
@ -326,7 +326,7 @@ def testMemRefType():
|
|||
f32 = F32Type.get()
|
||||
shape = [2, 3]
|
||||
loc = Location.unknown()
|
||||
memref = MemRefType.get(f32, shape, memory_space=2)
|
||||
memref = MemRefType.get(shape, f32, memory_space=2)
|
||||
# CHECK: memref type: memref<2x3xf32, 2>
|
||||
print("memref type:", memref)
|
||||
# CHECK: number of affine layout maps: 0
|
||||
|
@ -335,7 +335,7 @@ def testMemRefType():
|
|||
print("memory space:", memref.memory_space)
|
||||
|
||||
layout = AffineMap.get_permutation([1, 0])
|
||||
memref_layout = MemRefType.get(f32, shape, [layout])
|
||||
memref_layout = MemRefType.get(shape, f32, [layout])
|
||||
# CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
|
||||
print("memref type:", memref_layout)
|
||||
assert len(memref_layout.layout) == 1
|
||||
|
@ -346,7 +346,7 @@ def testMemRefType():
|
|||
|
||||
none = NoneType.get()
|
||||
try:
|
||||
memref_invalid = MemRefType.get(none, shape)
|
||||
memref_invalid = MemRefType.get(shape, none)
|
||||
except ValueError as e:
|
||||
# CHECK: invalid 'Type(none)' and expected floating point, integer, vector
|
||||
# CHECK: or complex type.
|
||||
|
|
Loading…
Reference in New Issue