@@ -9,26 +9,51 @@ import (
99// txContextKey gorm database transaction context key
1010type txContextKey struct {}
1111
12+ type DataSourceManager interface {
13+ GetDataSource () * gorm.DB
14+ }
15+
1216type Transaction interface {
13- WithContext (context.Context , func (ctx context.Context , tx * gorm.DB ) error ) error
17+ Transaction (context.Context , func (ctx context.Context ) error ) error
18+ WithContext (context.Context ) * gorm.DB
1419}
1520
1621type transactionManager struct {
17- db * gorm. DB
22+ dsm DataSourceManager
1823}
1924
20- func NewTransactionManager (dataSource * gorm. DB ) Transaction {
25+ func NewTransactionManager (dsm DataSourceManager ) Transaction {
2126 return & transactionManager {
22- db : dataSource ,
27+ dsm : dsm ,
2328 }
2429}
2530
26- func (tm * transactionManager ) WithContext (ctx context.Context , fn func (ctx context.Context , tx * gorm.DB ) error ) error {
31+ func (tm * transactionManager ) WithContext (ctx context.Context ) * gorm.DB {
32+ if ctx == nil {
33+ return tm .dsm .GetDataSource ()
34+ }
35+
2736 if tx , ok := ctx .Value (txContextKey {}).(* gorm.DB ); ok {
28- return fn (ctx , tx )
37+ return tx
38+ }
39+
40+ return tm .dsm .GetDataSource ().WithContext (ctx )
41+ }
42+
43+ func (tm * transactionManager ) Transaction (ctx context.Context , fn func (ctx context.Context ) error ) error {
44+ if ctx == nil {
45+ return tm .with (context .Background (), fn )
46+ }
47+
48+ if _ , ok := ctx .Value (txContextKey {}).(* gorm.DB ); ok {
49+ return fn (ctx )
2950 }
3051
31- return tm .db .WithContext (ctx ).Transaction (func (tx * gorm.DB ) error {
32- return fn (context .WithValue (ctx , txContextKey {}, tx ), tx )
52+ return tm .with (ctx , fn )
53+ }
54+
55+ func (tm * transactionManager ) with (ctx context.Context , fn func (ctx context.Context ) error ) error {
56+ return tm .dsm .GetDataSource ().Transaction (func (tx * gorm.DB ) error {
57+ return fn (context .WithValue (ctx , txContextKey {}, tx ))
3358 })
3459}
0 commit comments