forked from OSchip/llvm-project
[Matrix] Add tests for addition transpose optimizations
Tests before transpose optimizations around additions. Differential Revision: https://reviews.llvm.org/D133656
This commit is contained in:
parent
0ce96e06ee
commit
0fcc99ade4
|
@ -94,6 +94,172 @@ entry:
|
|||
ret void
|
||||
}
|
||||
|
||||
; A^T + B^T -> (A + B)^T
; Loads two <9 x double> values from %Aptr and %Bptr, transposes each as a 3x3
; matrix, adds the transposes and stores the sum to %C.
; NOTE: the CHECK lines below still expect both transposes followed by the
; fadd, i.e. the same shape as the input; the single-transpose form named
; above is not what this test currently checks.
|
||||
define void @at_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
|
||||
; CHECK-LABEL: @at_plus_bt(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[AT]], [[BT]]
|
||||
; CHECK-NEXT: store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%a = load <9 x double>, <9 x double>* %Aptr
|
||||
%b = load <9 x double>, <9 x double>* %Bptr
|
||||
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
|
||||
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
|
||||
%fadd = fadd <9 x double> %at, %bt
|
||||
store <9 x double> %fadd, <9 x double>* %C
|
||||
ret void
|
||||
}
|
||||
|
||||
; (A + B)^T -> A^T + B^T -> (A + B)^T
; Adds the two loaded <9 x double> values first and transposes the sum as a
; 3x3 matrix before storing it to %C.
; The CHECK lines expect one fadd followed by one transpose, matching the
; final form in the chain above.
|
||||
define void @a_plus_b_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
|
||||
; CHECK-LABEL: @a_plus_b_t(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[A]], [[B]]
|
||||
; CHECK-NEXT: [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
|
||||
; CHECK-NEXT: store <9 x double> [[T]], <9 x double>* [[C:%.*]], align 128
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%a = load <9 x double>, <9 x double>* %Aptr
|
||||
%b = load <9 x double>, <9 x double>* %Bptr
|
||||
%fadd = fadd <9 x double> %a, %b
|
||||
%t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
|
||||
store <9 x double> %t, <9 x double>* %C
|
||||
ret void
|
||||
}
|
||||
|
||||
; A^T * B^T + C^T * D^T -> (B * A + D * C)^T
; Multiplies A^T by B^T and C^T by D^T (all 3x3), adds both products and
; stores the result to %E.
; NOTE: the CHECK lines expect each product as multiply(B, A) / multiply(D, C)
; followed by its own transpose, with the fadd applied to the two transposed
; products; the single outer transpose named above is not what is checked.
|
||||
define void @atbt_plus_ctdt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, <9 x double>* %E) {
|
||||
; CHECK-LABEL: @atbt_plus_ctdt(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[D]], <9 x double> [[C]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP2]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[TMP3]]
|
||||
; CHECK-NEXT: store <9 x double> [[FADD]], <9 x double>* [[E:%.*]], align 128
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%a = load <9 x double>, <9 x double>* %Aptr
|
||||
%b = load <9 x double>, <9 x double>* %Bptr
|
||||
%c = load <9 x double>, <9 x double>* %Cptr
|
||||
%d = load <9 x double>, <9 x double>* %Dptr
|
||||
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
|
||||
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
|
||||
%ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
|
||||
%dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
|
||||
%atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
|
||||
%ctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %ct, <9 x double> %dt, i32 3, i32 3, i32 3)
|
||||
%fadd = fadd <9 x double> %atbt, %ctdt
|
||||
store <9 x double> %fadd, <9 x double>* %E
|
||||
ret void
|
||||
}
|
||||
|
||||
; -(A^T) + B^T
; Negates the transpose of A with fneg, adds the transpose of B and stores the
; sum to %C.
; The CHECK lines expect the same instruction sequence as the input
; (transpose, fneg, transpose, fadd).
|
||||
define void @negat_plus_bt(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %C) {
|
||||
; CHECK-LABEL: @negat_plus_bt(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[AT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[A]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[NEGAT:%.*]] = fneg <9 x double> [[AT]]
|
||||
; CHECK-NEXT: [[BT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[B]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[NEGAT]], [[BT]]
|
||||
; CHECK-NEXT: store <9 x double> [[FADD]], <9 x double>* [[C:%.*]], align 128
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%a = load <9 x double>, <9 x double>* %Aptr
|
||||
%b = load <9 x double>, <9 x double>* %Bptr
|
||||
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
|
||||
%negat = fneg <9 x double> %at
|
||||
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
|
||||
%fadd = fadd <9 x double> %negat, %bt
|
||||
store <9 x double> %fadd, <9 x double>* %C
|
||||
ret void
|
||||
}
|
||||
|
||||
; (A^T * B^T + k * C^T * D^T)^T -> (B * A) + (D * C * k)
; %k is broadcast to a <9 x double> splat (insertelement + shufflevector) and
; combined with C^T, and then with D^T, through llvm.matrix.multiply calls.
; That product is added to A^T * B^T and the sum is transposed before the
; store to %E.
; NOTE: the CHECK lines expect only A^T * B^T to be rewritten, as
; multiply(B, A) followed by a transpose. The transposes of C and D, the two
; splat multiplies and the outer transpose of the fadd are expected unchanged,
; so the fully transpose-free form named above is not what is checked.
|
||||
define void @atbt_plus_kctdt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, <9 x double>* %Cptr, <9 x double>* %Dptr, double %k, <9 x double>* %E) {
|
||||
; CHECK-LABEL: @atbt_plus_kctdt_t(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[CPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[D:%.*]] = load <9 x double>, <9 x double>* [[DPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[CT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[C]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[DT:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[D]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[A]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[TMP0]], i32 3, i32 3)
|
||||
; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
|
||||
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
|
||||
; CHECK-NEXT: [[KCT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[SPLAT]], <9 x double> [[CT]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: [[KCTDT:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[KCT]], <9 x double> [[DT]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: [[FADD:%.*]] = fadd <9 x double> [[TMP1]], [[KCTDT]]
|
||||
; CHECK-NEXT: [[T:%.*]] = call <9 x double> @llvm.matrix.transpose.v9f64(<9 x double> [[FADD]], i32 3, i32 3)
|
||||
; CHECK-NEXT: store <9 x double> [[T]], <9 x double>* [[E:%.*]], align 128
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%a = load <9 x double>, <9 x double>* %Aptr
|
||||
%b = load <9 x double>, <9 x double>* %Bptr
|
||||
%c = load <9 x double>, <9 x double>* %Cptr
|
||||
%d = load <9 x double>, <9 x double>* %Dptr
|
||||
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
|
||||
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
|
||||
%ct = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %c, i32 3, i32 3)
|
||||
%dt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %d, i32 3, i32 3)
|
||||
%atbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %bt, i32 3, i32 3, i32 3)
|
||||
%veck = insertelement <9 x double> poison, double %k, i64 0
|
||||
%splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
|
||||
%kct = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %ct, i32 3, i32 3, i32 3)
|
||||
%kctdt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %kct, <9 x double> %dt, i32 3, i32 3, i32 3)
|
||||
%fadd = fadd <9 x double> %atbt, %kctdt
|
||||
%t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %fadd, i32 3, i32 3)
|
||||
store <9 x double> %t, <9 x double>* %E
|
||||
ret void
|
||||
}
|
||||
|
||||
; (A^T * (k * B^T))^T -> (B * k) * A
; %k is broadcast to a <9 x double> splat and multiplied with B^T; A^T is then
; multiplied by that product and the result is transposed before the store
; to %C.
; The CHECK lines expect no transpose calls at all: multiply(B, splat)
; followed by a multiply of that result with A, matching the form named above.
|
||||
define void @atkbt_t(<9 x double>* %Aptr, <9 x double>* %Bptr, double %k, <9 x double>* %C) {
|
||||
; CHECK-LABEL: @atkbt_t(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[A:%.*]] = load <9 x double>, <9 x double>* [[APTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[B:%.*]] = load <9 x double>, <9 x double>* [[BPTR:%.*]], align 128
|
||||
; CHECK-NEXT: [[VECK:%.*]] = insertelement <9 x double> poison, double [[K:%.*]], i64 0
|
||||
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <9 x double> [[VECK]], <9 x double> poison, <9 x i32> zeroinitializer
|
||||
; CHECK-NEXT: [[MMUL1:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[B]], <9 x double> [[SPLAT]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: [[MMUL:%.*]] = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> [[MMUL1]], <9 x double> [[A]], i32 3, i32 3, i32 3)
|
||||
; CHECK-NEXT: store <9 x double> [[MMUL]], <9 x double>* [[C:%.*]], align 128
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%a = load <9 x double>, <9 x double>* %Aptr
|
||||
%b = load <9 x double>, <9 x double>* %Bptr
|
||||
%at = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %a, i32 3, i32 3)
|
||||
%bt = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %b, i32 3, i32 3)
|
||||
%veck = insertelement <9 x double> poison, double %k, i64 0
|
||||
%splat = shufflevector <9 x double> %veck, <9 x double> poison, <9 x i32> zeroinitializer
|
||||
%kbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %splat, <9 x double> %bt, i32 3, i32 3, i32 3)
|
||||
%atkbt = call <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double> %at, <9 x double> %kbt, i32 3, i32 3, i32 3)
|
||||
%t = call <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double> %atkbt, i32 3, i32 3)
|
||||
store <9 x double> %t, <9 x double>* %C
|
||||
ret void
|
||||
}
|
||||
|
||||
; Declarations of the matrix intrinsics; the two <9 x double> ones are called
; by the tests above.
; NOTE(review): the <9 x i32> transpose is not called by any function visible
; in this part of the file - presumably an earlier test uses it; confirm.
declare <9 x double> @llvm.matrix.multiply.v9f64.v9f64.v9f64(<9 x double>, <9 x double>, i32, i32, i32)
|
||||
declare <9 x double> @llvm.matrix.transpose.v9f64.v9f64(<9 x double>, i32, i32)
|
||||
declare <9 x i32> @llvm.matrix.transpose.v9i32.v9i32(<9 x i32>, i32, i32)
|
||||
|
|
Loading…
Reference in New Issue