Skip to content

Commit a68b5ff

Browse files
author
Sergey Petrunin
committed
Rewrite stmt cache for transactions
1 parent 599acb0 commit a68b5ff

1 file changed

Lines changed: 43 additions & 8 deletions

File tree

stmtcacher.go

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,55 @@ func (sc *stmtCacher) QueryRow(query string, args ...interface{}) RowScanner {
7171
return stmt.QueryRow(args...)
7272
}
7373

74-
type DBProxyBeginner interface {
74+
// DBTransactionProxy wraps transaction and includes DBProxy interface
75+
type DBTransactionProxy interface {
7576
DBProxy
76-
Begin() (*sql.Tx, error)
77+
Begin() error
78+
Commit() error
79+
Rollback() error
7780
}
7881

79-
type stmtCacheProxy struct {
82+
type stmtCacheTransactionProxy struct {
8083
DBProxy
81-
db *sql.DB
84+
db *sql.DB
85+
transaction *sql.Tx
8286
}
8387

84-
func NewStmtCacheProxy(db *sql.DB) DBProxyBeginner {
85-
return &stmtCacheProxy{DBProxy: NewStmtCacher(db), db: db}
88+
// NewStmtCacheTransactionProxy returns a DBTransactionProxy
89+
// wrapping an open transaction in stmtCacher.
90+
// You should use Begin() each time you want a new transaction and
91+
// cache will be valid only for that transaction.
92+
//
93+
// Usage example:
94+
// proxy := sq.NewStmtCacheTransactionProxy(db)
95+
// mydb := sq.StatementBuilder.RunWith(proxy)
96+
// insertUsers := mydb.Insert("users").Columns("name").Values("username")
97+
// proxy.Commit()
98+
// proxy.Begin()
99+
// insertPets := mydb.Insert("pets").Columns("name", "username").Values("petname", "username")
100+
// proxy.Commit()
101+
func NewStmtCacheTransactionProxy(db *sql.DB) (proxy DBTransactionProxy, err error) {
102+
proxy = &stmtCacheTransactionProxy{db: db}
103+
return proxy, proxy.Begin()
104+
}
105+
106+
func (sp *stmtCacheTransactionProxy) Begin() (err error) {
107+
tr, err := sp.db.Begin()
108+
109+
if err != nil {
110+
return
111+
}
112+
113+
sp.DBProxy = NewStmtCacher(tr)
114+
sp.transaction = tr
115+
116+
return
117+
}
118+
119+
func (sp *stmtCacheTransactionProxy) Commit() error {
120+
return sp.transaction.Commit()
86121
}
87122

88-
func (sp *stmtCacheProxy) Begin() (*sql.Tx, error) {
89-
return sp.db.Begin()
123+
func (sp *stmtCacheTransactionProxy) Rollback() error {
124+
return sp.transaction.Rollback()
90125
}

0 commit comments

Comments
 (0)