Revert "feat: transaction support powerful more"

This reverts commit cf1908f1e5.
This commit is contained in:
Qishuai Liu 2024-07-25 18:10:54 +09:00
parent 1c7e60d938
commit 7d5de3b8fe
No known key found for this signature in database
8 changed files with 30 additions and 427 deletions

View File

@ -140,29 +140,6 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) {
}
func getCallerInfo(db database, retry bool) string {
if !db.enableCallerInfo {
return ""
}
extraInfo := ""
if retry {
extraInfo += " (retry)"
}
for i := 0; true; i++ {
_, file, line, ok := runtime.Caller(i)
if !ok {
break
}
if file == "" || strings.Contains(file, "/sqlingo@v") {
continue
}
segs := strings.Split(file, "/")
name := segs[len(segs)-1]
return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo)
}
return ""
}
func getTxCallerInfo(db transaction, retry bool) string {
if !db.enableCallerInfo {
return ""
}

View File

@ -58,13 +58,11 @@ type Database interface {
Update(table Table) updateWithSet
// Initiate a DELETE FROM statement
DeleteFrom(table Table) deleteWithTable
}
// Begin Start a new transaction and returning a Transaction object.
// the DDL operations using the returned Transaction object will
// regard as one time transaction.
// User must manually call Commit() or Rollback() to end the transaction,
// after that, more DDL operations or TCL will return error.
Begin() (Transaction, error)
type txOrDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
var (
@ -74,6 +72,7 @@ var (
type database struct {
db *sql.DB
tx *sql.Tx
logger LoggerFunc
dialect dialect
retryPolicy func(error) bool
@ -186,6 +185,13 @@ func (d database) GetDB() *sql.DB {
return d.db
}
func (d database) getTxOrDB() txOrDB {
if d.tx != nil {
return d.tx
}
return d.db
}
func (d database) Query(sqlString string) (Cursor, error) {
return d.QueryContext(context.Background(), sqlString)
}
@ -196,7 +202,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e
sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString
rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry)
if err != nil {
isRetry = d.retryPolicy != nil && d.retryPolicy(err)
isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err)
if isRetry {
continue
}
@ -221,7 +227,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlString string, retry
interceptor := d.interceptor
var rows *sql.Rows
invoker := func(ctx context.Context, sql string) (err error) {
rows, err = d.GetDB().QueryContext(ctx, sql)
rows, err = d.getTxOrDB().QueryContext(ctx, sql)
return
}
@ -258,7 +264,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res
var result sql.Result
invoker := func(ctx context.Context, sql string) (err error) {
result, err = d.GetDB().ExecContext(ctx, sql)
result, err = d.getTxOrDB().ExecContext(ctx, sql)
return
}
var err error

View File

@ -148,11 +148,9 @@ func (e expression) GetTable() Table {
}
type scope struct {
// Transaction should be nil if without transaction begin
Transaction *transaction
Database *database
Tables []Table
lastJoin *join
Database *database
Tables []Table
lastJoin *join
}
func staticExpression(sql string, priority priority, isBool bool) expression {

View File

@ -61,9 +61,7 @@ func newField(table Table, fieldName string) actualField {
expression: expression{
builder: func(scope scope) (string, error) {
dialect := dialectUnknown
if scope.Transaction != nil {
dialect = scope.Transaction.dialect
} else if scope.Database != nil {
if scope.Database != nil {
dialect = scope.Database.dialect
}
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {

View File

@ -610,17 +610,11 @@ func (s selectStatus) FetchCursor() (Cursor, error) {
return nil, err
}
var c Cursor
if s.base.scope.Transaction != nil {
c, err = s.base.scope.Transaction.QueryContext(s.ctx, sqlString)
} else {
c, err = s.base.scope.Database.QueryContext(s.ctx, sqlString)
}
cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString)
if err != nil {
return nil, err
}
return c, nil
return cursor, nil
}
func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) {

View File

@ -24,9 +24,6 @@ func (t table) GetName() string {
}
func (t table) GetSQL(scope scope) string {
if scope.Transaction != nil {
return t.sqlDialects[scope.Transaction.dialect]
}
return t.sqlDialects[scope.Database.dialect]
}

View File

@ -3,12 +3,12 @@ package sqlingo
import (
"context"
"database/sql"
"time"
)
// Transaction is the interface of a transaction with underlying sql.Tx object.
// It provides methods to execute DDL and TCL operations.
type Transaction interface {
GetDB() *sql.DB
GetTx() *sql.Tx
Query(sql string) (Cursor, error)
Execute(sql string) (sql.Result, error)
@ -19,13 +19,10 @@ type Transaction interface {
InsertInto(table Table) insertWithTable
Update(table Table) updateWithSet
DeleteFrom(table Table) deleteWithTable
ReplaceInto(table Table) insertWithTable
// ReplaceInto(table Table) insertWithTable
Commit() error
Rollback() error
Savepoint(name string) error
RollbackTo(name string) error
ReleaseSavepoint(name string) error
}
func (d *database) GetTx() *sql.Tx {
return d.tx
}
func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error {
@ -44,16 +41,8 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T
}()
if f != nil {
db := transaction{
tx: tx,
logger: d.logger,
dialect: d.dialect,
retryPolicy: d.retryPolicy,
enableCallerInfo: d.enableCallerInfo,
interceptor: d.interceptor,
}
db := *d
db.tx = tx
err = f(&db)
if err != nil {
return err
@ -67,222 +56,3 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T
isCommitted = true
return nil
}
// Begin starts a new transaction and returning a Transaction object.
// the DDL operations using the returned Transaction object will
// regard as one time transaction.
// User must manually call Commit() or Rollback() to end the transaction,
// after that, more DDL operations or TCL will return error.
func (d *database) Begin() (Transaction, error) {
var err error
tx, err := d.db.Begin()
if err != nil {
return nil, err
}
// copy extra to transaction
t := &transaction{
tx: tx,
logger: d.logger,
dialect: d.dialect,
retryPolicy: d.retryPolicy,
enableCallerInfo: d.enableCallerInfo,
interceptor: d.interceptor,
}
return t, nil
}
type transaction struct {
tx *sql.Tx
logger LoggerFunc
dialect dialect
retryPolicy func(error) bool
enableCallerInfo bool
interceptor InterceptorFunc
}
func (t transaction) GetTx() *sql.Tx {
return t.tx
}
func (t transaction) Query(sql string) (Cursor, error) {
return t.QueryContext(context.Background(), sql)
}
func (t transaction) QueryContext(ctx context.Context, sqlString string) (Cursor, error) {
isRetry := false
for {
sqlStringWithCallerInfo := getTxCallerInfo(t, isRetry) + sqlString
rows, err := t.queryContextOnce(ctx, sqlStringWithCallerInfo)
if err != nil {
isRetry = t.tx == nil && t.retryPolicy != nil && t.retryPolicy(err)
if isRetry {
continue
}
return nil, err
}
return cursor{rows: rows}, nil
}
}
func (t transaction) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo string) (*sql.Rows, error) {
if ctx == nil {
ctx = context.Background()
}
startTime := time.Now()
defer func() {
endTime := time.Now()
if t.logger != nil {
t.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), true, false)
}
}()
interceptor := t.interceptor
var rows *sql.Rows
invoker := func(ctx context.Context, sql string) (err error) {
rows, err = t.GetTx().QueryContext(ctx, sql)
return
}
var err error
if interceptor == nil {
err = invoker(ctx, sqlStringWithCallerInfo)
} else {
err = interceptor(ctx, sqlStringWithCallerInfo, invoker)
}
if err != nil {
return nil, err
}
return rows, nil
}
func (t transaction) Execute(sql string) (sql.Result, error) {
return t.ExecuteContext(context.Background(), sql)
}
func (t transaction) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) {
if ctx == nil {
ctx = context.Background()
}
sqlStringWithCallerInfo := getTxCallerInfo(t, false) + sqlString
startTime := time.Now()
defer func() {
endTime := time.Now()
if t.logger != nil {
t.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), true, false)
}
}()
var result sql.Result
invoker := func(ctx context.Context, sql string) (err error) {
result, err = t.GetTx().ExecContext(ctx, sql)
return
}
var err error
if t.interceptor == nil {
err = invoker(ctx, sqlStringWithCallerInfo)
} else {
err = t.interceptor(ctx, sqlStringWithCallerInfo, invoker)
}
if err != nil {
return nil, err
}
return result, err
}
func (t transaction) Select(fields ...interface{}) selectWithFields {
return selectStatus{
base: selectBase{
scope: scope{
Transaction: &t,
},
fields: getFields(fields),
},
}
}
func (t transaction) SelectDistinct(fields ...interface{}) selectWithFields {
return selectStatus{
base: selectBase{
scope: scope{
Transaction: &t,
},
fields: getFields(fields),
distinct: true,
},
}
}
func (t transaction) SelectFrom(tables ...Table) selectWithTables {
return selectStatus{
base: selectBase{
scope: scope{
Transaction: &t,
Tables: tables,
},
},
}
}
func (t transaction) InsertInto(table Table) insertWithTable {
return insertStatus{
scope: scope{
Transaction: &t,
Tables: []Table{table},
},
}
}
func (t transaction) Update(table Table) updateWithSet {
return updateStatus{
scope: scope{
Transaction: &t,
Tables: []Table{table}},
}
}
func (t transaction) DeleteFrom(table Table) deleteWithTable {
return deleteStatus{
scope: scope{
Transaction: &t,
Tables: []Table{table},
},
}
}
func (t transaction) ReplaceInto(table Table) insertWithTable {
return insertStatus{
method: "REPLACE",
scope: scope{
Transaction: &t,
Tables: []Table{table},
},
}
}
func (t transaction) Commit() error {
return t.GetTx().Commit()
}
func (t transaction) Rollback() error {
return t.GetTx().Rollback()
}
// Savepoint todo defend sql injection
func (t transaction) Savepoint(name string) error {
_, err := t.GetTx().Exec("SAVEPOINT " + name)
return err
}
// RollbackTo todo defend sql injection
func (t transaction) RollbackTo(name string) error {
_, err := t.GetTx().Exec("ROLLBACK TO " + name)
return err
}
// ReleaseSavepoint todo defend sql injection
func (t transaction) ReleaseSavepoint(name string) error {
_, err := t.GetTx().Exec("RELEASE SAVEPOINT " + name)
return err
}

View File

@ -2,7 +2,6 @@ package sqlingo
import (
"context"
"database/sql"
"errors"
"testing"
)
@ -29,6 +28,9 @@ func (m *mockTx) Rollback() error {
func TestTransaction(t *testing.T) {
db := newMockDatabase()
err := db.BeginTx(nil, nil, func(tx Transaction) error {
if tx.GetDB() != db.GetDB() {
t.Error()
}
if tx.GetTx() == nil {
t.Error()
}
@ -79,142 +81,3 @@ func TestTransaction(t *testing.T) {
t.Error("should get error here")
}
}
func TestTransaction_Commit(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}
if err = tx.Commit(); err != nil {
t.Error(err)
}
if !sharedMockConn.mockTx.isCommitted {
t.Error()
}
}
func TestTransaction_Rollback(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}
if err = tx.Rollback(); err != nil {
t.Error(err)
}
if !sharedMockConn.mockTx.isRolledBack {
t.Error()
}
}
func TestTransaction_Done(t *testing.T) {
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}
if err = tx.Commit(); err != nil {
t.Error(err)
}
if err = tx.Rollback(); !errors.Is(err, sql.ErrTxDone) {
t.Error(err)
}
if err = tx.Commit(); !errors.Is(err, sql.ErrTxDone) {
t.Error(err)
}
if _, err = tx.Select(1).FetchAll(); !errors.Is(err, sql.ErrTxDone) {
t.Error(err)
}
}
func TestTransaction_Execute(t *testing.T) {
var sqlCount = make(map[string]int)
db := newMockDatabase()
tx, err := db.Begin()
if err != nil {
t.Error(err)
}
db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error {
sqlCount[sql]++
return invoker(ctx, sql)
})
if _, err = tx.Execute("SQL 1 NOT SET INTERCEPTOR"); err != nil {
t.Error(err)
}
if sqlCount["SQL 1 NOT SET INTERCEPTOR"] != 0 {
t.Error()
}
if err = tx.Rollback(); err != nil {
t.Error(err)
}
tx, err = db.Begin()
if err != nil {
t.Error(err)
}
if _, err = tx.Execute("SQL 2 SET INTERCEPTOR"); err != nil {
t.Error(err)
}
if sqlCount["SQL 2 SET INTERCEPTOR"] != 1 {
t.Error()
}
if err = tx.Commit(); err != nil {
t.Error(err)
}
}
// TestTransaction_CRUD tests the CRUD operations in a transaction, cause sql build is tested on database,
// so we only insure there is no panic here.
func TestTransaction_CRUD(t *testing.T) {
db := newMockDatabase()
db.EnableCallerInfo(true)
tx, err := db.Begin()
if err != nil {
t.Error(err)
}
_, err = tx.Select().From(table1).FetchAll()
if err != nil {
t.Error(err)
}
if _, err = tx.SelectFrom(table1).FetchAll(); err != nil {
t.Error(err)
}
if _, err = tx.SelectDistinct(field2).From(table1).FetchAll(); err != nil {
t.Error(err)
}
if _, err = tx.InsertInto(Test).Values(1, 2).Execute(); err != nil {
t.Error(err)
}
if _, err = tx.ReplaceInto(Test).Values(1, 2).Execute(); err != nil {
t.Error(err)
}
if _, err = tx.DeleteFrom(table1).Where().Execute(); err != nil {
t.Error(err)
}
if _, err = tx.Update(table1).Set(field1, 1).Where().Execute(); err != nil {
t.Error(err)
}
if err = tx.Rollback(); err != nil {
t.Error(err)
}
}