Skip to content

Commit 623db67

Browse files
committed
refactor WithTx and MustTx functions
1 parent 7a9c837 commit 623db67

File tree

21 files changed

+64
-35
lines changed

21 files changed

+64
-35
lines changed

models/avatars/avatar.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func saveEmailHash(email string) string {
9797
Hash: emailHash,
9898
}
9999
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
100-
if err := db.WithTx(func(ctx context.Context) error {
100+
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
101101
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
102102
if has || err != nil {
103103
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time

models/db/context.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package db
77
import (
88
"context"
99
"database/sql"
10+
"errors"
1011

1112
"xorm.io/xorm"
1213
"xorm.io/xorm/schemas"
@@ -97,13 +98,32 @@ func TxContext() (*Context, Committer, error) {
9798
return newContext(DefaultContext, sess, true), sess, nil
9899
}
99100

101+
var ErrAlreadyInTransaction = errors.New("database connection has already been in a transaction")
102+
100103
// WithTx represents executing database operations on a transaction
101-
// you can optionally change the context to a parent one
102-
func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error {
103-
parentCtx := DefaultContext
104-
if len(stdCtx) != 0 && stdCtx[0] != nil {
105-
// TODO: make sure parent context has no open session
106-
parentCtx = stdCtx[0]
104+
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
105+
if InTransaction(parentCtx) {
106+
return ErrAlreadyInTransaction
107+
}
108+
109+
sess := x.NewSession()
110+
defer sess.Close()
111+
if err := sess.Begin(); err != nil {
112+
return err
113+
}
114+
115+
if err := f(newContext(parentCtx, sess, true)); err != nil {
116+
return err
117+
}
118+
119+
return sess.Commit()
120+
}
121+
122+
// MustTx represents executing database operations on a transaction, if the transaction exist,
123+
// this function will reuse it otherwise will create a new one and close it when finished.
124+
func MustTx(parentCtx context.Context, f func(ctx context.Context) error) error {
125+
if InTransaction(parentCtx) {
126+
return f(newContext(parentCtx, GetEngine(parentCtx), true))
107127
}
108128

109129
sess := x.NewSession()

models/db/context_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,16 @@ import (
1717
func TestInTransaction(t *testing.T) {
1818
assert.NoError(t, unittest.PrepareTestDatabase())
1919
assert.False(t, db.InTransaction(db.DefaultContext))
20-
assert.NoError(t, db.WithTx(func(ctx context.Context) error {
20+
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
21+
assert.True(t, db.InTransaction(ctx))
22+
return nil
23+
}))
24+
25+
ctx, committer, err := db.TxContext()
26+
assert.NoError(t, err)
27+
defer committer.Close()
28+
assert.True(t, db.InTransaction(ctx))
29+
assert.Error(t, db.WithTx(ctx, func(ctx context.Context) error {
2130
assert.True(t, db.InTransaction(ctx))
2231
return nil
2332
}))

models/db/index_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
5959
assert.EqualValues(t, 62, maxIndex)
6060

6161
// commit transaction
62-
err = db.WithTx(func(ctx context.Context) error {
62+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
6363
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
6464
assert.NoError(t, err)
6565
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
@@ -73,7 +73,7 @@ func TestSyncMaxResourceIndex(t *testing.T) {
7373
assert.EqualValues(t, 73, maxIndex)
7474

7575
// rollback transaction
76-
err = db.WithTx(func(ctx context.Context) error {
76+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
7777
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
7878
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
7979
assert.NoError(t, err)
@@ -102,7 +102,7 @@ func TestGetNextResourceIndex(t *testing.T) {
102102
assert.EqualValues(t, 2, maxIndex)
103103

104104
// commit transaction
105-
err = db.WithTx(func(ctx context.Context) error {
105+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
106106
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
107107
assert.NoError(t, err)
108108
assert.EqualValues(t, 3, maxIndex)
@@ -114,7 +114,7 @@ func TestGetNextResourceIndex(t *testing.T) {
114114
assert.EqualValues(t, 3, maxIndex)
115115

116116
// rollback transaction
117-
err = db.WithTx(func(ctx context.Context) error {
117+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
118118
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
119119
assert.NoError(t, err)
120120
assert.EqualValues(t, 4, maxIndex)

models/project/issue.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func (p *Project) NumOpenIssues() int {
7878

7979
// MoveIssuesOnProjectBoard moves or keeps issues in a column and sorts them inside that column
8080
func MoveIssuesOnProjectBoard(board *Board, sortedIssueIDs map[int64]int64) error {
81-
return db.WithTx(func(ctx context.Context) error {
81+
return db.WithTx(db.DefaultContext,func(ctx context.Context) error {
8282
sess := db.GetEngine(ctx)
8383

8484
issueIDs := make([]int64, 0, len(sortedIssueIDs))

models/system/appstate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func init() {
2525

2626
// SaveAppStateContent saves the app state item to database
2727
func SaveAppStateContent(key, content string) error {
28-
return db.WithTx(func(ctx context.Context) error {
28+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
2929
eng := db.GetEngine(ctx)
3030
// try to update existing row
3131
res, err := eng.Exec("UPDATE app_state SET revision=revision+1, content=? WHERE id=?", content, key)

models/system/setting.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func SetSetting(setting *Setting) error {
168168
}
169169

170170
func upsertSettingValue(key, value string, version int) error {
171-
return db.WithTx(func(ctx context.Context) error {
171+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
172172
e := db.GetEngine(ctx)
173173

174174
// here we use a general method to do a safe upsert for different databases (and most transaction levels)

models/user/setting.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func SetUserSetting(userID int64, key, value string) error {
137137
}
138138

139139
func upsertUserSettingValue(userID int64, key, value string) error {
140-
return db.WithTx(func(ctx context.Context) error {
140+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
141141
e := db.GetEngine(ctx)
142142

143143
// here we use a general method to do a safe upsert for different databases (and most transaction levels)

modules/repository/collaborator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func addCollaborator(ctx context.Context, repo *repo_model.Repository, u *user_m
3737

3838
// AddCollaborator adds new collaboration to a repository with default access mode.
3939
func AddCollaborator(repo *repo_model.Repository, u *user_model.User) error {
40-
return db.WithTx(func(ctx context.Context) error {
40+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
4141
return addCollaborator(ctx, repo, u)
4242
})
4343
}

modules/repository/create.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func CreateRepository(doer, u *user_model.User, opts CreateRepoOptions) (*repo_m
211211

212212
var rollbackRepo *repo_model.Repository
213213

214-
if err := db.WithTx(func(ctx context.Context) error {
214+
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
215215
if err := CreateRepositoryByExample(ctx, doer, u, repo, false); err != nil {
216216
return err
217217
}

modules/repository/repo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ func pullMirrorReleaseSync(repo *repo_model.Repository, gitRepo *git.Repository)
485485
if err != nil {
486486
return fmt.Errorf("unable to GetTagInfos in pull-mirror Repo[%d:%s/%s]: %w", repo.ID, repo.OwnerName, repo.Name, err)
487487
}
488-
err = db.WithTx(func(ctx context.Context) error {
488+
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
489489
//
490490
// clear out existing releases
491491
//

routers/api/packages/container/blob.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func saveAsPackageBlob(hsr packages_module.HashedSizeReader, pi *packages_servic
2828

2929
contentStore := packages_module.NewContentStore()
3030

31-
err := db.WithTx(func(ctx context.Context) error {
31+
err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
3232
created := true
3333
p := &packages_model.Package{
3434
OwnerID: pi.Owner.ID,
@@ -117,7 +117,7 @@ func saveAsPackageBlob(hsr packages_module.HashedSizeReader, pi *packages_servic
117117
}
118118

119119
func deleteBlob(ownerID int64, image, digest string) error {
120-
return db.WithTx(func(ctx context.Context) error {
120+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
121121
pfds, err := container_model.GetContainerBlobs(ctx, &container_model.BlobSearchOptions{
122122
OwnerID: ownerID,
123123
Image: image,

services/attachment/attachment.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func NewAttachment(attach *repo_model.Attachment, file io.Reader) (*repo_model.A
2525
return nil, fmt.Errorf("attachment %s should belong to a repository", attach.Name)
2626
}
2727

28-
err := db.WithTx(func(ctx context.Context) error {
28+
err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
2929
attach.UUID = uuid.New().String()
3030
size, err := storage.Attachments.Save(attach.RelativePath(), file, -1)
3131
if err != nil {

services/automerge/automerge.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func addToQueue(pr *issues_model.PullRequest, sha string) {
6363

6464
// ScheduleAutoMerge if schedule is false and no error, pull can be merged directly
6565
func ScheduleAutoMerge(ctx context.Context, doer *user_model.User, pull *issues_model.PullRequest, style repo_model.MergeStyle, message string) (scheduled bool, err error) {
66-
err = db.WithTx(func(ctx context.Context) error {
66+
err = db.WithTx(ctx, func(ctx context.Context) error {
6767
lastCommitStatus, err := pull_service.GetPullRequestCommitStatusState(ctx, pull)
6868
if err != nil {
6969
return err
@@ -81,20 +81,20 @@ func ScheduleAutoMerge(ctx context.Context, doer *user_model.User, pull *issues_
8181

8282
_, err = issues_model.CreateAutoMergeComment(ctx, issues_model.CommentTypePRScheduledToAutoMerge, pull, doer)
8383
return err
84-
}, ctx)
84+
})
8585
return scheduled, err
8686
}
8787

8888
// RemoveScheduledAutoMerge cancels a previously scheduled pull request
8989
func RemoveScheduledAutoMerge(ctx context.Context, doer *user_model.User, pull *issues_model.PullRequest) error {
90-
return db.WithTx(func(ctx context.Context) error {
90+
return db.WithTx(ctx, func(ctx context.Context) error {
9191
if err := pull_model.DeleteScheduledAutoMerge(ctx, pull.ID); err != nil {
9292
return err
9393
}
9494

9595
_, err := issues_model.CreateAutoMergeComment(ctx, issues_model.CommentTypePRUnScheduledToAutoMerge, pull, doer)
9696
return err
97-
}, ctx)
97+
})
9898
}
9999

100100
// MergeScheduledPullRequest merges a previously scheduled pull request when all checks succeeded

services/org/repo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TeamAddRepository(t *organization.Team, repo *repo_model.Repository) (err e
2222
return nil
2323
}
2424

25-
return db.WithTx(func(ctx context.Context) error {
25+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
2626
return models.AddRepository(ctx, t, repo)
2727
})
2828
}

services/pull/check.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func AddToTaskQueue(pr *issues_model.PullRequest) {
6363

6464
// CheckPullMergable check if the pull mergable based on all conditions (branch protection, merge options, ...)
6565
func CheckPullMergable(stdCtx context.Context, doer *user_model.User, perm *access_model.Permission, pr *issues_model.PullRequest, manuallMerge, force bool) error {
66-
return db.WithTx(func(ctx context.Context) error {
66+
return db.WithTx(stdCtx, func(ctx context.Context) error {
6767
if pr.HasMerged {
6868
return ErrHasMerged
6969
}
@@ -122,7 +122,7 @@ func CheckPullMergable(stdCtx context.Context, doer *user_model.User, perm *acce
122122
}
123123

124124
return nil
125-
}, stdCtx)
125+
})
126126
}
127127

128128
// isSignedIfRequired check if merge will be signed if required

services/pull/merge.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ func MergedManually(pr *issues_model.PullRequest, doer *user_model.User, baseGit
828828
pullWorkingPool.CheckIn(fmt.Sprint(pr.ID))
829829
defer pullWorkingPool.CheckOut(fmt.Sprint(pr.ID))
830830

831-
if err := db.WithTx(func(ctx context.Context) error {
831+
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
832832
prUnit, err := pr.BaseRepo.GetUnitCtx(ctx, unit.TypePullRequests)
833833
if err != nil {
834834
return err

services/repository/adopt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func AdoptRepository(doer, u *user_model.User, opts repo_module.CreateRepoOption
5454
IsEmpty: !opts.AutoInit,
5555
}
5656

57-
if err := db.WithTx(func(ctx context.Context) error {
57+
if err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
5858
repoPath := repo_model.RepoPath(u.Name, repo.Name)
5959
isExist, err := util.IsExist(repoPath)
6060
if err != nil {

services/repository/fork.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func ForkRepository(ctx context.Context, doer, owner *user_model.User, opts Fork
112112
panic(panicErr)
113113
}()
114114

115-
err = db.WithTx(func(txCtx context.Context) error {
115+
err = db.WithTx(ctx, func(txCtx context.Context) error {
116116
if err = repo_module.CreateRepositoryByExample(txCtx, doer, owner, repo, false); err != nil {
117117
return err
118118
}
@@ -184,7 +184,7 @@ func ForkRepository(ctx context.Context, doer, owner *user_model.User, opts Fork
184184

185185
// ConvertForkToNormalRepository convert the provided repo from a forked repo to normal repo
186186
func ConvertForkToNormalRepository(repo *repo_model.Repository) error {
187-
err := db.WithTx(func(ctx context.Context) error {
187+
err := db.WithTx(db.DefaultContext, func(ctx context.Context) error {
188188
repo, err := repo_model.GetRepositoryByIDCtx(ctx, repo.ID)
189189
if err != nil {
190190
return err

services/repository/push.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ func pushUpdates(optsList []*repo_module.PushUpdateOptions) error {
290290

291291
// PushUpdateAddDeleteTags updates a number of added and delete tags
292292
func PushUpdateAddDeleteTags(repo *repo_model.Repository, gitRepo *git.Repository, addTags, delTags []string) error {
293-
return db.WithTx(func(ctx context.Context) error {
293+
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
294294
if err := repo_model.PushUpdateDeleteTagsContext(ctx, repo, delTags); err != nil {
295295
return err
296296
}

services/repository/template.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func GenerateRepository(doer, owner *user_model.User, templateRepo *repo_model.R
4848
}
4949

5050
var generateRepo *repo_model.Repository
51-
if err = db.WithTx(func(ctx context.Context) error {
51+
if err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
5252
generateRepo, err = repo_module.GenerateRepository(ctx, doer, owner, templateRepo, opts)
5353
if err != nil {
5454
return err

0 commit comments

Comments
 (0)