mirror of https://github.com/lqs/sqlingo
Revert "feat: transaction support powerful more"
This reverts commit cf1908f1e5
.
This commit is contained in:
parent
1c7e60d938
commit
7d5de3b8fe
23
common.go
23
common.go
|
@ -140,29 +140,6 @@ func commaOrderBys(scope scope, orderBys []OrderBy) (string, error) {
|
|||
}
|
||||
|
||||
func getCallerInfo(db database, retry bool) string {
|
||||
if !db.enableCallerInfo {
|
||||
return ""
|
||||
}
|
||||
extraInfo := ""
|
||||
if retry {
|
||||
extraInfo += " (retry)"
|
||||
}
|
||||
for i := 0; true; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if file == "" || strings.Contains(file, "/sqlingo@v") {
|
||||
continue
|
||||
}
|
||||
segs := strings.Split(file, "/")
|
||||
name := segs[len(segs)-1]
|
||||
return fmt.Sprintf("/* %s:%d%s */ ", name, line, extraInfo)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getTxCallerInfo(db transaction, retry bool) string {
|
||||
if !db.enableCallerInfo {
|
||||
return ""
|
||||
}
|
||||
|
|
24
database.go
24
database.go
|
@ -58,13 +58,11 @@ type Database interface {
|
|||
Update(table Table) updateWithSet
|
||||
// Initiate a DELETE FROM statement
|
||||
DeleteFrom(table Table) deleteWithTable
|
||||
}
|
||||
|
||||
// Begin Start a new transaction and returning a Transaction object.
|
||||
// the DDL operations using the returned Transaction object will
|
||||
// regard as one time transaction.
|
||||
// User must manually call Commit() or Rollback() to end the transaction,
|
||||
// after that, more DDL operations or TCL will return error.
|
||||
Begin() (Transaction, error)
|
||||
type txOrDB interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -74,6 +72,7 @@ var (
|
|||
|
||||
type database struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
logger LoggerFunc
|
||||
dialect dialect
|
||||
retryPolicy func(error) bool
|
||||
|
@ -186,6 +185,13 @@ func (d database) GetDB() *sql.DB {
|
|||
return d.db
|
||||
}
|
||||
|
||||
func (d database) getTxOrDB() txOrDB {
|
||||
if d.tx != nil {
|
||||
return d.tx
|
||||
}
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d database) Query(sqlString string) (Cursor, error) {
|
||||
return d.QueryContext(context.Background(), sqlString)
|
||||
}
|
||||
|
@ -196,7 +202,7 @@ func (d database) QueryContext(ctx context.Context, sqlString string) (Cursor, e
|
|||
sqlStringWithCallerInfo := getCallerInfo(d, isRetry) + sqlString
|
||||
rows, err := d.queryContextOnce(ctx, sqlStringWithCallerInfo, isRetry)
|
||||
if err != nil {
|
||||
isRetry = d.retryPolicy != nil && d.retryPolicy(err)
|
||||
isRetry = d.tx == nil && d.retryPolicy != nil && d.retryPolicy(err)
|
||||
if isRetry {
|
||||
continue
|
||||
}
|
||||
|
@ -221,7 +227,7 @@ func (d database) queryContextOnce(ctx context.Context, sqlString string, retry
|
|||
interceptor := d.interceptor
|
||||
var rows *sql.Rows
|
||||
invoker := func(ctx context.Context, sql string) (err error) {
|
||||
rows, err = d.GetDB().QueryContext(ctx, sql)
|
||||
rows, err = d.getTxOrDB().QueryContext(ctx, sql)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -258,7 +264,7 @@ func (d database) ExecuteContext(ctx context.Context, sqlString string) (sql.Res
|
|||
|
||||
var result sql.Result
|
||||
invoker := func(ctx context.Context, sql string) (err error) {
|
||||
result, err = d.GetDB().ExecContext(ctx, sql)
|
||||
result, err = d.getTxOrDB().ExecContext(ctx, sql)
|
||||
return
|
||||
}
|
||||
var err error
|
||||
|
|
|
@ -148,11 +148,9 @@ func (e expression) GetTable() Table {
|
|||
}
|
||||
|
||||
type scope struct {
|
||||
// Transaction should be nil if without transaction begin
|
||||
Transaction *transaction
|
||||
Database *database
|
||||
Tables []Table
|
||||
lastJoin *join
|
||||
Database *database
|
||||
Tables []Table
|
||||
lastJoin *join
|
||||
}
|
||||
|
||||
func staticExpression(sql string, priority priority, isBool bool) expression {
|
||||
|
|
4
field.go
4
field.go
|
@ -61,9 +61,7 @@ func newField(table Table, fieldName string) actualField {
|
|||
expression: expression{
|
||||
builder: func(scope scope) (string, error) {
|
||||
dialect := dialectUnknown
|
||||
if scope.Transaction != nil {
|
||||
dialect = scope.Transaction.dialect
|
||||
} else if scope.Database != nil {
|
||||
if scope.Database != nil {
|
||||
dialect = scope.Database.dialect
|
||||
}
|
||||
if len(scope.Tables) != 1 || scope.lastJoin != nil || scope.Tables[0].GetName() != tableName {
|
||||
|
|
10
select.go
10
select.go
|
@ -610,17 +610,11 @@ func (s selectStatus) FetchCursor() (Cursor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var c Cursor
|
||||
if s.base.scope.Transaction != nil {
|
||||
c, err = s.base.scope.Transaction.QueryContext(s.ctx, sqlString)
|
||||
} else {
|
||||
c, err = s.base.scope.Database.QueryContext(s.ctx, sqlString)
|
||||
}
|
||||
|
||||
cursor, err := s.base.scope.Database.QueryContext(s.ctx, sqlString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
return cursor, nil
|
||||
}
|
||||
|
||||
func (s selectStatus) FetchFirst(dest ...interface{}) (ok bool, err error) {
|
||||
|
|
3
table.go
3
table.go
|
@ -24,9 +24,6 @@ func (t table) GetName() string {
|
|||
}
|
||||
|
||||
func (t table) GetSQL(scope scope) string {
|
||||
if scope.Transaction != nil {
|
||||
return t.sqlDialects[scope.Transaction.dialect]
|
||||
}
|
||||
return t.sqlDialects[scope.Database.dialect]
|
||||
}
|
||||
|
||||
|
|
242
transaction.go
242
transaction.go
|
@ -3,12 +3,12 @@ package sqlingo
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Transaction is the interface of a transaction with underlying sql.Tx object.
|
||||
// It provides methods to execute DDL and TCL operations.
|
||||
type Transaction interface {
|
||||
GetDB() *sql.DB
|
||||
GetTx() *sql.Tx
|
||||
Query(sql string) (Cursor, error)
|
||||
Execute(sql string) (sql.Result, error)
|
||||
|
@ -19,13 +19,10 @@ type Transaction interface {
|
|||
InsertInto(table Table) insertWithTable
|
||||
Update(table Table) updateWithSet
|
||||
DeleteFrom(table Table) deleteWithTable
|
||||
ReplaceInto(table Table) insertWithTable
|
||||
// ReplaceInto(table Table) insertWithTable
|
||||
Commit() error
|
||||
Rollback() error
|
||||
Savepoint(name string) error
|
||||
RollbackTo(name string) error
|
||||
ReleaseSavepoint(name string) error
|
||||
}
|
||||
|
||||
func (d *database) GetTx() *sql.Tx {
|
||||
return d.tx
|
||||
}
|
||||
|
||||
func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx Transaction) error) error {
|
||||
|
@ -44,16 +41,8 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T
|
|||
}()
|
||||
|
||||
if f != nil {
|
||||
db := transaction{
|
||||
tx: tx,
|
||||
logger: d.logger,
|
||||
dialect: d.dialect,
|
||||
retryPolicy: d.retryPolicy,
|
||||
enableCallerInfo: d.enableCallerInfo,
|
||||
interceptor: d.interceptor,
|
||||
}
|
||||
db := *d
|
||||
db.tx = tx
|
||||
|
||||
err = f(&db)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -67,222 +56,3 @@ func (d *database) BeginTx(ctx context.Context, opts *sql.TxOptions, f func(tx T
|
|||
isCommitted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Begin starts a new transaction and returning a Transaction object.
|
||||
// the DDL operations using the returned Transaction object will
|
||||
// regard as one time transaction.
|
||||
// User must manually call Commit() or Rollback() to end the transaction,
|
||||
// after that, more DDL operations or TCL will return error.
|
||||
func (d *database) Begin() (Transaction, error) {
|
||||
var err error
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// copy extra to transaction
|
||||
t := &transaction{
|
||||
tx: tx,
|
||||
logger: d.logger,
|
||||
dialect: d.dialect,
|
||||
retryPolicy: d.retryPolicy,
|
||||
enableCallerInfo: d.enableCallerInfo,
|
||||
interceptor: d.interceptor,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
tx *sql.Tx
|
||||
logger LoggerFunc
|
||||
dialect dialect
|
||||
retryPolicy func(error) bool
|
||||
enableCallerInfo bool
|
||||
interceptor InterceptorFunc
|
||||
}
|
||||
|
||||
func (t transaction) GetTx() *sql.Tx {
|
||||
return t.tx
|
||||
}
|
||||
|
||||
func (t transaction) Query(sql string) (Cursor, error) {
|
||||
return t.QueryContext(context.Background(), sql)
|
||||
}
|
||||
|
||||
func (t transaction) QueryContext(ctx context.Context, sqlString string) (Cursor, error) {
|
||||
isRetry := false
|
||||
for {
|
||||
sqlStringWithCallerInfo := getTxCallerInfo(t, isRetry) + sqlString
|
||||
|
||||
rows, err := t.queryContextOnce(ctx, sqlStringWithCallerInfo)
|
||||
if err != nil {
|
||||
isRetry = t.tx == nil && t.retryPolicy != nil && t.retryPolicy(err)
|
||||
if isRetry {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return cursor{rows: rows}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) queryContextOnce(ctx context.Context, sqlStringWithCallerInfo string) (*sql.Rows, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
endTime := time.Now()
|
||||
if t.logger != nil {
|
||||
t.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), true, false)
|
||||
}
|
||||
}()
|
||||
|
||||
interceptor := t.interceptor
|
||||
var rows *sql.Rows
|
||||
invoker := func(ctx context.Context, sql string) (err error) {
|
||||
rows, err = t.GetTx().QueryContext(ctx, sql)
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if interceptor == nil {
|
||||
err = invoker(ctx, sqlStringWithCallerInfo)
|
||||
} else {
|
||||
err = interceptor(ctx, sqlStringWithCallerInfo, invoker)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (t transaction) Execute(sql string) (sql.Result, error) {
|
||||
return t.ExecuteContext(context.Background(), sql)
|
||||
}
|
||||
|
||||
func (t transaction) ExecuteContext(ctx context.Context, sqlString string) (sql.Result, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
sqlStringWithCallerInfo := getTxCallerInfo(t, false) + sqlString
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
endTime := time.Now()
|
||||
if t.logger != nil {
|
||||
t.logger(sqlStringWithCallerInfo, endTime.Sub(startTime), true, false)
|
||||
}
|
||||
}()
|
||||
|
||||
var result sql.Result
|
||||
invoker := func(ctx context.Context, sql string) (err error) {
|
||||
result, err = t.GetTx().ExecContext(ctx, sql)
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if t.interceptor == nil {
|
||||
err = invoker(ctx, sqlStringWithCallerInfo)
|
||||
} else {
|
||||
err = t.interceptor(ctx, sqlStringWithCallerInfo, invoker)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (t transaction) Select(fields ...interface{}) selectWithFields {
|
||||
return selectStatus{
|
||||
base: selectBase{
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
},
|
||||
fields: getFields(fields),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) SelectDistinct(fields ...interface{}) selectWithFields {
|
||||
return selectStatus{
|
||||
base: selectBase{
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
},
|
||||
fields: getFields(fields),
|
||||
distinct: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) SelectFrom(tables ...Table) selectWithTables {
|
||||
return selectStatus{
|
||||
base: selectBase{
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
Tables: tables,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) InsertInto(table Table) insertWithTable {
|
||||
return insertStatus{
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
Tables: []Table{table},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) Update(table Table) updateWithSet {
|
||||
return updateStatus{
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
Tables: []Table{table}},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) DeleteFrom(table Table) deleteWithTable {
|
||||
return deleteStatus{
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
Tables: []Table{table},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) ReplaceInto(table Table) insertWithTable {
|
||||
return insertStatus{
|
||||
method: "REPLACE",
|
||||
scope: scope{
|
||||
Transaction: &t,
|
||||
Tables: []Table{table},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t transaction) Commit() error {
|
||||
return t.GetTx().Commit()
|
||||
}
|
||||
|
||||
func (t transaction) Rollback() error {
|
||||
return t.GetTx().Rollback()
|
||||
}
|
||||
|
||||
// Savepoint todo defend sql injection
|
||||
func (t transaction) Savepoint(name string) error {
|
||||
_, err := t.GetTx().Exec("SAVEPOINT " + name)
|
||||
return err
|
||||
}
|
||||
|
||||
// RollbackTo todo defend sql injection
|
||||
func (t transaction) RollbackTo(name string) error {
|
||||
_, err := t.GetTx().Exec("ROLLBACK TO " + name)
|
||||
return err
|
||||
}
|
||||
|
||||
// ReleaseSavepoint todo defend sql injection
|
||||
func (t transaction) ReleaseSavepoint(name string) error {
|
||||
_, err := t.GetTx().Exec("RELEASE SAVEPOINT " + name)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package sqlingo
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
@ -29,6 +28,9 @@ func (m *mockTx) Rollback() error {
|
|||
func TestTransaction(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
err := db.BeginTx(nil, nil, func(tx Transaction) error {
|
||||
if tx.GetDB() != db.GetDB() {
|
||||
t.Error()
|
||||
}
|
||||
if tx.GetTx() == nil {
|
||||
t.Error()
|
||||
}
|
||||
|
@ -79,142 +81,3 @@ func TestTransaction(t *testing.T) {
|
|||
t.Error("should get error here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction_Commit(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !sharedMockConn.mockTx.isCommitted {
|
||||
t.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction_Rollback(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = tx.Rollback(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !sharedMockConn.mockTx.isRolledBack {
|
||||
t.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction_Done(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = tx.Rollback(); !errors.Is(err, sql.ErrTxDone) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); !errors.Is(err, sql.ErrTxDone) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.Select(1).FetchAll(); !errors.Is(err, sql.ErrTxDone) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction_Execute(t *testing.T) {
|
||||
var sqlCount = make(map[string]int)
|
||||
db := newMockDatabase()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
db.SetInterceptor(func(ctx context.Context, sql string, invoker InvokerFunc) error {
|
||||
sqlCount[sql]++
|
||||
return invoker(ctx, sql)
|
||||
})
|
||||
|
||||
if _, err = tx.Execute("SQL 1 NOT SET INTERCEPTOR"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if sqlCount["SQL 1 NOT SET INTERCEPTOR"] != 0 {
|
||||
t.Error()
|
||||
}
|
||||
|
||||
if err = tx.Rollback(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if _, err = tx.Execute("SQL 2 SET INTERCEPTOR"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if sqlCount["SQL 2 SET INTERCEPTOR"] != 1 {
|
||||
t.Error()
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTransaction_CRUD tests the CRUD operations in a transaction, cause sql build is tested on database,
|
||||
// so we only insure there is no panic here.
|
||||
func TestTransaction_CRUD(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
db.EnableCallerInfo(true)
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
_, err = tx.Select().From(table1).FetchAll()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.SelectFrom(table1).FetchAll(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.SelectDistinct(field2).From(table1).FetchAll(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.InsertInto(Test).Values(1, 2).Execute(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.ReplaceInto(Test).Values(1, 2).Execute(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.DeleteFrom(table1).Where().Execute(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, err = tx.Update(table1).Set(field1, 1).Where().Execute(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err = tx.Rollback(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue