Skip to content

Commit 4d2b56f

Browse files
committed
fix: prevent segmentation fault in CIRA tunnel session timer (#776)
Initialize APF session timer to prevent nil pointer dereference during connection cleanup.
1 parent ba02a58 commit 4d2b56f

2 files changed

Lines changed: 199 additions & 2 deletions

File tree

internal/controller/tcp/cira/tunnel.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const (
2828
weakCipherSuiteCount = 3
2929
keepAliveInterval = 30
3030
keepAliveTimeout = 90
31+
apfSessionTimeout = 3 * time.Second
3132
)
3233

3334
var (
@@ -145,8 +146,10 @@ func (s *Server) handleConnection(conn net.Conn) {
145146
conn: conn,
146147
tlsConn: tlsConn,
147148
handler: NewAPFHandler(s.devices, s.log),
148-
session: &apf.Session{},
149-
log: s.log,
149+
session: &apf.Session{
150+
Timer: time.NewTimer(apfSessionTimeout),
151+
},
152+
log: s.log,
150153
}
151154
ctx.processor = apf.NewProcessor(ctx.handler)
152155

@@ -162,6 +165,17 @@ func (ctx *connectionContext) cleanup() {
162165
delete(wsman.Connections, deviceID)
163166
mu.Unlock()
164167
}
168+
169+
// Stop and clean up the session timer
170+
if ctx.session != nil && ctx.session.Timer != nil {
171+
if !ctx.session.Timer.Stop() {
172+
// Drain the channel if the timer fired
173+
select {
174+
case <-ctx.session.Timer.C:
175+
default:
176+
}
177+
}
178+
}
165179
}
166180

167181
func (s *Server) processConnection(ctx *connectionContext) {
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
package cira
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/device-management-toolkit/go-wsman-messages/v2/pkg/apf"
11+
12+
"github.com/device-management-toolkit/console/internal/usecase/devices/wsman"
13+
"github.com/device-management-toolkit/console/pkg/logger"
14+
)
15+
16+
type cleanupTestCase struct {
17+
name string
18+
setupSession func() *apf.Session
19+
authenticated bool
20+
deviceID string
21+
wantPanic bool
22+
wantTimerStopped bool
23+
}
24+
25+
var cleanupTests = []cleanupTestCase{
26+
{
27+
name: "cleanup with nil session",
28+
setupSession: func() *apf.Session { return nil },
29+
authenticated: false,
30+
deviceID: "",
31+
wantPanic: false,
32+
wantTimerStopped: false,
33+
},
34+
{
35+
name: "cleanup with session but nil timer",
36+
setupSession: func() *apf.Session {
37+
return &apf.Session{Timer: nil}
38+
},
39+
authenticated: false,
40+
deviceID: "",
41+
wantPanic: false,
42+
wantTimerStopped: false,
43+
},
44+
{
45+
name: "cleanup with timer that stops successfully",
46+
setupSession: func() *apf.Session {
47+
return &apf.Session{Timer: time.NewTimer(1 * time.Hour)}
48+
},
49+
authenticated: false,
50+
deviceID: "",
51+
wantPanic: false,
52+
wantTimerStopped: true,
53+
},
54+
{
55+
name: "cleanup with timer that fails to stop should drain channel",
56+
setupSession: func() *apf.Session {
57+
timer := time.NewTimer(1 * time.Nanosecond)
58+
59+
time.Sleep(2 * time.Millisecond)
60+
61+
return &apf.Session{Timer: timer}
62+
},
63+
authenticated: false,
64+
deviceID: "",
65+
wantPanic: false,
66+
wantTimerStopped: false,
67+
},
68+
{
69+
name: "cleanup with timer stop failure and empty channel hits default case",
70+
setupSession: func() *apf.Session {
71+
timer := time.NewTimer(100 * time.Nanosecond)
72+
73+
time.Sleep(2 * time.Millisecond)
74+
75+
select {
76+
case <-timer.C:
77+
default:
78+
}
79+
80+
return &apf.Session{Timer: timer}
81+
},
82+
authenticated: false,
83+
deviceID: "",
84+
wantPanic: false,
85+
wantTimerStopped: false,
86+
},
87+
{
88+
name: "cleanup with authenticated connection removes from connections map",
89+
setupSession: func() *apf.Session {
90+
return &apf.Session{Timer: time.NewTimer(10 * time.Second)}
91+
},
92+
authenticated: true,
93+
deviceID: "test-device",
94+
wantPanic: false,
95+
wantTimerStopped: true,
96+
},
97+
}
98+
99+
func TestConnectionContext_cleanup(t *testing.T) {
100+
t.Parallel()
101+
102+
for _, tt := range cleanupTests {
103+
tt := tt // capture range variable
104+
105+
t.Run(tt.name, func(t *testing.T) {
106+
t.Parallel()
107+
runCleanupTest(t, tt)
108+
})
109+
}
110+
}
111+
112+
func runCleanupTest(t *testing.T, tt cleanupTestCase) {
113+
t.Helper()
114+
115+
// Setup
116+
session := tt.setupSession()
117+
ctx := setupConnectionContext(t, session, tt.authenticated, tt.deviceID)
118+
119+
setupConnectionsMap(t, tt.authenticated, tt.deviceID)
120+
121+
// Execute
122+
require.NotPanics(t, func() {
123+
ctx.cleanup()
124+
})
125+
126+
// Verify
127+
verifyTimerState(t, session, tt.wantTimerStopped)
128+
verifyConnectionRemoved(t, tt.authenticated, tt.deviceID)
129+
}
130+
131+
func setupConnectionContext(t *testing.T, session *apf.Session, authenticated bool, deviceID string) *connectionContext {
132+
t.Helper()
133+
134+
// Create a proper APFHandler with mock deviceID
135+
log := logger.New("error")
136+
handler := NewAPFHandler(nil, log) // devices.Feature can be nil for cleanup test
137+
handler.deviceID = deviceID // Set deviceID directly for test
138+
139+
return &connectionContext{
140+
session: session,
141+
authenticated: authenticated,
142+
handler: handler,
143+
}
144+
}
145+
146+
func setupConnectionsMap(t *testing.T, authenticated bool, deviceID string) {
147+
t.Helper()
148+
149+
if authenticated && deviceID != "" {
150+
mu.Lock()
151+
152+
wsman.Connections[deviceID] = &wsman.ConnectionEntry{}
153+
154+
mu.Unlock()
155+
}
156+
}
157+
158+
func verifyTimerState(t *testing.T, session *apf.Session, wantTimerStopped bool) {
159+
t.Helper()
160+
161+
if wantTimerStopped && session != nil && session.Timer != nil {
162+
select {
163+
case <-session.Timer.C:
164+
// Timer was stopped and channel was drained, or timer expired naturally
165+
default:
166+
// Timer was stopped before it could fire
167+
}
168+
}
169+
}
170+
171+
func verifyConnectionRemoved(t *testing.T, authenticated bool, deviceID string) {
172+
t.Helper()
173+
174+
if authenticated && deviceID != "" {
175+
mu.Lock()
176+
177+
_, exists := wsman.Connections[deviceID]
178+
179+
mu.Unlock()
180+
181+
assert.False(t, exists, "Connection should be removed from map")
182+
}
183+
}

0 commit comments

Comments
 (0)