Skip to content

Commit 8fe4c4b

Browse files
committed
Make sure to call after hook on error
1 parent fb80d42 commit 8fe4c4b

File tree

1 file changed

+36
-54
lines changed

1 file changed

+36
-54
lines changed

redis.go

Lines changed: 36 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -48,83 +48,65 @@ func (hs *hooks) AddHook(hook Hook) {
4848
func (hs hooks) process(
4949
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
5050
) error {
51-
ctx, err := hs.beforeProcess(ctx, cmd)
52-
if err != nil {
53-
cmd.SetErr(err)
54-
return err
51+
if len(hs.hooks) == 0 {
52+
return fn(ctx, cmd)
5553
}
5654

57-
cmdErr := fn(ctx, cmd)
55+
var hookIndex int
56+
var retErr error
5857

59-
if err := hs.afterProcess(ctx, cmd); err != nil {
60-
cmd.SetErr(err)
61-
return err
58+
for ; hookIndex < len(hs.hooks); hookIndex++ {
59+
ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
60+
if retErr != nil {
61+
cmd.SetErr(retErr)
62+
break
63+
}
6264
}
6365

64-
return cmdErr
65-
}
66-
67-
func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
68-
for _, h := range hs.hooks {
69-
var err error
70-
ctx, err = h.BeforeProcess(ctx, cmd)
71-
if err != nil {
72-
return nil, err
73-
}
66+
if retErr == nil {
67+
retErr = fn(ctx, cmd)
7468
}
75-
return ctx, nil
76-
}
7769

78-
func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) error {
79-
var firstErr error
80-
for i := len(hs.hooks) - 1; i >= 0; i-- {
81-
h := hs.hooks[i]
82-
if err := h.AfterProcess(ctx, cmd); err != nil && firstErr == nil {
83-
firstErr = err
70+
for ; hookIndex >= 0; hookIndex-- {
71+
if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
72+
retErr = err
73+
cmd.SetErr(retErr)
8474
}
8575
}
86-
return firstErr
76+
77+
return retErr
8778
}
8879

8980
func (hs hooks) processPipeline(
9081
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
9182
) error {
92-
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
93-
if err != nil {
94-
setCmdsErr(cmds, err)
95-
return err
83+
if len(hs.hooks) == 0 {
84+
return fn(ctx, cmds)
9685
}
9786

98-
cmdsErr := fn(ctx, cmds)
87+
var hookIndex int
88+
var retErr error
9989

100-
if err := hs.afterProcessPipeline(ctx, cmds); err != nil {
101-
setCmdsErr(cmds, err)
102-
return err
90+
for ; hookIndex < len(hs.hooks); hookIndex++ {
91+
ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
92+
if retErr != nil {
93+
setCmdsErr(cmds, retErr)
94+
break
95+
}
10396
}
10497

105-
return cmdsErr
106-
}
107-
108-
func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
109-
for _, h := range hs.hooks {
110-
var err error
111-
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
112-
if err != nil {
113-
return nil, err
114-
}
98+
if retErr == nil {
99+
retErr = fn(ctx, cmds)
115100
}
116-
return ctx, nil
117-
}
118101

119-
func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) error {
120-
var firstErr error
121-
for i := len(hs.hooks) - 1; i >= 0; i-- {
122-
h := hs.hooks[i]
123-
if err := h.AfterProcessPipeline(ctx, cmds); err != nil && firstErr == nil {
124-
firstErr = err
102+
for ; hookIndex >= 0; hookIndex-- {
103+
if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
104+
retErr = err
105+
setCmdsErr(cmds, retErr)
125106
}
126107
}
127-
return firstErr
108+
109+
return retErr
128110
}
129111

130112
func (hs hooks) processTxPipeline(

0 commit comments

Comments
 (0)