Skip to content

Commit 69fd4b2

Browse files
committed
Add TestServiceRestart to verify service restarts upon failure on Windows
This test monitors the service's process ID (PID) to confirm that the Windows Service Control Manager (SCM) restarts the service when it fails. It starts the service, records its initial PID, and then waits for the service to fail and restart.
1 parent df4a4da commit 69fd4b2

File tree

3 files changed

+133
-4
lines changed

3 files changed

+133
-4
lines changed

windows/svc/example/service.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ package main
88

99
import (
1010
"fmt"
11+
"os"
1112
"strings"
1213
"time"
1314

15+
"golang.org/x/sys/windows"
1416
"golang.org/x/sys/windows/svc"
1517
"golang.org/x/sys/windows/svc/debug"
1618
"golang.org/x/sys/windows/svc/eventlog"
@@ -27,9 +29,17 @@ func (m *exampleService) Execute(args []string, r <-chan svc.ChangeRequest, chan
2729
slowtick := time.Tick(2 * time.Second)
2830
tick := fasttick
2931
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
32+
33+
// Simulate failure after 5 seconds
34+
failureTimer := time.NewTimer(5 * time.Second)
35+
defer failureTimer.Stop()
36+
3037
loop:
3138
for {
3239
select {
40+
case <-failureTimer.C:
41+
// Simulate failure by returning a non-zero exit code
42+
return false, uint32(windows.ERROR_UNEXP_NET_ERR)
3343
case <-tick:
3444
beep()
3545
elog.Info(1, "beep")
@@ -81,6 +91,11 @@ func runService(name string, isDebug bool) {
8191
err = run(name, &exampleService{})
8292
if err != nil {
8393
elog.Error(1, fmt.Sprintf("%s service failed: %v", name, err))
94+
if exitErr, ok := err.(*svc.ExitError); ok {
95+
os.Exit(int(exitErr.Code))
96+
} else {
97+
os.Exit(1)
98+
}
8499
return
85100
}
86101
elog.Info(1, fmt.Sprintf("%s service stopped", name))

windows/svc/service.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package svc
99

1010
import (
1111
"errors"
12+
"fmt"
1213
"sync"
1314
"unsafe"
1415

@@ -132,17 +133,26 @@ type ctlEvent struct {
132133

133134
// service provides access to windows service api.
134135
type service struct {
135-
name string
136-
h windows.Handle
137-
c chan ctlEvent
138-
handler Handler
136+
name string
137+
h windows.Handle
138+
c chan ctlEvent
139+
handler Handler
140+
exitCode uint32
139141
}
140142

141143
type exitCode struct {
142144
isSvcSpecific bool
143145
errno uint32
144146
}
145147

148+
type ExitError struct {
149+
Code uint32
150+
}
151+
152+
func (e *ExitError) Error() string {
153+
return fmt.Sprintf("service exited with error code %d", e.Code)
154+
}
155+
146156
func (s *service) updateStatus(status *Status, ec *exitCode) error {
147157
if s.h == 0 {
148158
return errors.New("updateStatus with no service status handle")
@@ -274,6 +284,7 @@ loop:
274284
}
275285

276286
theService.updateStatus(&Status{State: Stopped}, &ec)
287+
theService.exitCode = ec.errno
277288

278289
return windows.NO_ERROR
279290
}

windows/svc/svc_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,106 @@ func TestIsWindowsServiceWhenParentExits(t *testing.T) {
239239
}
240240
}
241241
}
242+
243+
func TestServiceRestart(t *testing.T) {
244+
if os.Getenv("GO_BUILDER_NAME") == "" {
245+
// Don't install services on arbitrary users' machines.
246+
t.Skip("Skipping test that modifies system services: GO_BUILDER_NAME not set")
247+
}
248+
if testing.Short() {
249+
t.Skip("Skipping test in short mode that modifies system services")
250+
}
251+
252+
const name = "svctestservice"
253+
254+
m, err := mgr.Connect()
255+
if err != nil {
256+
t.Fatalf("SCM connection failed: %v", err)
257+
}
258+
defer m.Disconnect()
259+
260+
// Build the service executable
261+
exepath := filepath.Join(t.TempDir(), "a.exe")
262+
o, err := exec.Command("go", "build", "-o", exepath, "golang.org/x/sys/windows/svc/example").CombinedOutput()
263+
if err != nil {
264+
t.Fatalf("Failed to build service program: %v\n%v", err, string(o))
265+
}
266+
267+
// Ensure any existing service is stopped and deleted
268+
stopAndDeleteIfInstalled(t, m, name)
269+
270+
// Create the service
271+
s, err := m.CreateService(name, exepath, mgr.Config{DisplayName: "x-sys svc test service"})
272+
if err != nil {
273+
t.Fatalf("CreateService(%s) failed: %v", name, err)
274+
}
275+
defer s.Close()
276+
277+
// Set the service to restart on failure
278+
actions := []mgr.RecoveryAction{
279+
{Type: mgr.ServiceRestart, Delay: 1 * time.Second}, // Restart after 1 second
280+
}
281+
err = s.SetRecoveryActions(actions, 0)
282+
if err != nil {
283+
t.Fatalf("Failed to set service recovery actions: %v", err)
284+
}
285+
286+
// Set the flag to perform recovery actions on non-crash failures
287+
err = s.SetRecoveryActionsOnNonCrashFailures(true)
288+
if err != nil {
289+
t.Fatalf("Failed to set RecoveryActionsOnNonCrashFailures: %v", err)
290+
}
291+
292+
// Start the service
293+
testState(t, s, svc.Stopped)
294+
err = s.Start()
295+
if err != nil {
296+
t.Fatalf("Start(%s) failed: %v", s.Name, err)
297+
}
298+
299+
// Wait for the service to start
300+
waitState(t, s, svc.Running)
301+
302+
// Get the initial process ID
303+
status, err := s.Query()
304+
if err != nil {
305+
t.Fatalf("Query(%s) failed: %v", s.Name, err)
306+
}
307+
initialPID := status.ProcessId
308+
t.Logf("Initial PID: %d", initialPID)
309+
310+
// Wait up to 30 seconds for the PID to change, indicating a restart
311+
var newPID uint32
312+
success := false
313+
for i := 0; i < 30; i++ {
314+
time.Sleep(1 * time.Second)
315+
316+
status, err = s.Query()
317+
if err != nil {
318+
t.Fatalf("Query(%s) failed: %v", s.Name, err)
319+
}
320+
newPID = status.ProcessId
321+
322+
if newPID != 0 && newPID != initialPID {
323+
success = true
324+
t.Logf("Service restarted successfully, new PID: %d", newPID)
325+
break
326+
}
327+
}
328+
329+
if !success {
330+
t.Fatalf("Service did not restart within the expected time")
331+
}
332+
333+
// Cleanup: Stop and delete the service
334+
_, err = s.Control(svc.Stop)
335+
if err != nil {
336+
t.Fatalf("Control(%s) failed: %v", s.Name, err)
337+
}
338+
waitState(t, s, svc.Stopped)
339+
340+
err = s.Delete()
341+
if err != nil {
342+
t.Fatalf("Delete failed: %v", err)
343+
}
344+
}

0 commit comments

Comments
 (0)