[Matrix] Add tests for addition transpose optimizations

Add precommit tests that record the current output, before the transpose optimizations around additions are implemented.

Differential Revision: https://reviews.llvm.org/D133656
This commit is contained in:
Francis Visoiu Mistrih 2022-09-10 22:13:41 -07:00
parent 0ce96e06ee
commit 0fcc99ade4
1 changed files with 166 additions and 0 deletions

View File

@ -94,6 +94,172 @@ entry:
ret void
}
; A^T + B^T -> (A + B)^T
;
; Loads two <9 x double> values, transposes each one as a 3x3 matrix, adds the
; two transposes element-wise and stores the sum through %C.
; The expected output recorded below still has both transposes feeding the
; fadd, i.e. the same shape as the input.
define void @at_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
; CHECK-LABEL: @at_plus_bt(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
; CHECK-NEXT: [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT: [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[AT]], [[BT]]
; CHECK-NEXT: store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
%a = load <9 x double>, <9 x double>* %Aptr
%b = load <9 x double>, <9 x double>* %Bptr
; Transpose both operands before the addition.
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
%fadd = fadd <9 x double> %at, %bt
store <9 x double> %fadd, <9 x double>* %C
ret void
}
; (A + B)^T -> A^T + B^T -> (A + B)^T
;
; Loads two <9 x double> values, adds them element-wise, transposes the sum as
; a 3x3 matrix and stores the result through %C.
; The expected output recorded below keeps a single transpose applied to the
; fadd result, matching the input.
define void @a_plus_b_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
; CHECK-LABEL: @a_plus_b_t(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[A]], [[B]]
; CHECK-NEXT: [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
; CHECK-NEXT: store <9 x double> [[T]], <9 x double>* [[C:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
%a = load <9 x double>, <9 x double>* %Aptr
%b = load <9 x double>, <9 x double>* %Bptr
%fadd = fadd <9 x double> %a, %b
; Transpose the sum rather than the individual operands.
%t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
store <9 x double> %t, <9 x double>* %C
ret void
}
; A^T * B^T + C^T * D^T -> (B * A + D * C)^T
;
; Loads four <9 x double> values, transposes each as a 3x3 matrix, multiplies
; them pairwise (A^T * B^T and C^T * D^T), adds the two products and stores the
; sum through %E.
; In the expected output recorded below each product has been rewritten as a
; multiply with swapped, untransposed operands followed by one transpose
; ((B * A)^T and (D * C)^T); the two transposes are still separate and both
; feed the fadd.
define void @atbt_plus_ctdt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, <9 x double>* %E) {
; CHECK-LABEL: @atbt_plus_ctdt(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
; CHECK-NEXT: [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
; CHECK-NEXT: [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
; CHECK-NEXT: [[TMP2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[TMP3:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP2]], i32 3, i32 3)
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[TMP3]]
; CHECK-NEXT: store <9 x double> [[FADD]], <9 x double>* [[E:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
%a = load <9 x double>, <9 x double>* %Aptr
%b = load <9 x double>, <9 x double>* %Bptr
%c = load <9 x double>, <9 x double>* %Cptr
%d = load <9 x double>, <9 x double>* %Dptr
; Transpose every operand individually.
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
%ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
%dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
; Two products of transposed operands, then their sum.
%atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
%ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
%fadd = fadd <9 x double> %atbt, %ctdt
store <9 x double> %fadd, <9 x double>* %E
ret void
}
; -(A^T) + B^T
;
; Loads two <9 x double> values, transposes each as a 3x3 matrix, negates the
; first transpose with fneg, adds the second transpose and stores the sum
; through %C.
; The expected output recorded below is the same instruction sequence as the
; input: both transposes and the fneg are kept.
define void @negat_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
; CHECK-LABEL: @negat_plus_bt(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
; CHECK-NEXT: [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
; CHECK-NEXT: [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
; CHECK-NEXT: [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
; CHECK-NEXT: store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
%a = load <9 x double>, <9 x double>* %Aptr
%b = load <9 x double>, <9 x double>* %Bptr
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
; The negation sits between the transpose of %a and the addition.
%negat = fneg <9 x double> %at
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
%fadd = fadd <9 x double> %negat, %bt
store <9 x double> %fadd, <9 x double>* %C
ret void
}
; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * C * k)
;
; Loads four <9 x double> values and transposes each as a 3x3 matrix. It then
; forms A^T * B^T and (splat(k) * C^T) * D^T, adds the two products, transposes
; the sum and stores it through %E. The scalar %k enters the computation as a
; splat <9 x double> passed to llvm.matrix.multiply.
; In the expected output recorded below only the A^T * B^T product has been
; rewritten, as a multiply of B by A followed by a transpose. The transposes of
; C and D, the two multiplies involving the splat, and the final transpose of
; the sum are all still present.
define void @atbt_plus_kctdt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, double %k, <9 x double>* %E) {
; CHECK-LABEL: @atbt_plus_kctdt_t(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
; CHECK-NEXT: [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
; CHECK-NEXT: [[CT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[C]], i32 3, i32 3)
; CHECK-NEXT: [[DT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[D]], i32 3, i32 3)
; CHECK-NEXT: [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT: [[KCT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[CT]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[KCTDT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[KCT]], <9 x double> [[DT]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[KCTDT]]
; CHECK-NEXT: [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
; CHECK-NEXT: store <9 x double> [[T]], <9 x double>* [[E:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
%a = load <9 x double>, <9 x double>* %Aptr
%b = load <9 x double>, <9 x double>* %Bptr
%c = load <9 x double>, <9 x double>* %Cptr
%d = load <9 x double>, <9 x double>* %Dptr
; Transpose every operand individually.
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
%ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
%dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
%atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
; Broadcast %k into every lane: insert it at lane 0, then shuffle with an
; all-zero mask so each result lane reads lane 0.
%veck = insertelement <9 x double> poison, double %k, i64 0
%splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
; Second term: (splat(k) * C^T) * D^T.
%kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
%kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
%fadd = fadd <9 x double> %atbt, %kctdt
; Outer transpose applied to the whole sum.
%t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
store <9 x double> %t, <9 x double>* %E
ret void
}
; (A^T * (k * B^T))^T => (B * k) * A
;
; Loads two <9 x double> values, transposes each as a 3x3 matrix, multiplies a
; splat of %k by B^T, multiplies A^T by that product, transposes the result and
; stores it through %C. The scalar %k enters the computation as a splat
; <9 x double> passed to llvm.matrix.multiply.
; The expected output recorded below contains no transpose at all: it
; multiplies B by the splat and then multiplies that product by A.
define void @atkbt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, double %k, <9 x double>* %C) {
; CHECK-LABEL: @atkbt_t(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
; CHECK-NEXT: [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
; CHECK-NEXT: store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
; CHECK-NEXT: ret void
;
entry:
%a = load <9 x double>, <9 x double>* %Aptr
%b = load <9 x double>, <9 x double>* %Bptr
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
; Broadcast %k into every lane: insert it at lane 0, then shuffle with an
; all-zero mask so each result lane reads lane 0.
%veck = insertelement <9 x double> poison, double %k, i64 0
%splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
; Inner product splat(k) * B^T, then A^T times that product.
%kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
%atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
; Outer transpose applied to the whole product.
%t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
store <9 x double> %t, <9 x double>* %C
ret void
}
; Declarations of the matrix intrinsics called by the tests.
; NOTE(review): the tests call the transpose intrinsic under the name with the
; ".v9f64.v9f64" suffix, while the expected output prints it as
; "@llvm.matrix.transpose.v9f64"; presumably the name is re-mangled when the IR
; is read -- confirm.
; NOTE(review): the <9 x i32> transpose declaration is not referenced by any
; function in this part of the file; presumably an earlier test uses it --
; confirm before removing.
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)