Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions internal/controller/tcp/cira/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
weakCipherSuiteCount = 3
keepAliveInterval = 30
keepAliveTimeout = 90
apfSessionTimeout = 3 * time.Second
)

var (
Expand Down Expand Up @@ -145,8 +146,10 @@ func (s *Server) handleConnection(conn net.Conn) {
conn: conn,
tlsConn: tlsConn,
handler: NewAPFHandler(s.devices, s.log),
session: &apf.Session{},
log: s.log,
session: &apf.Session{
Timer: time.NewTimer(apfSessionTimeout),
},
log: s.log,
}
ctx.processor = apf.NewProcessor(ctx.handler)

Expand All @@ -162,6 +165,17 @@ func (ctx *connectionContext) cleanup() {
delete(wsman.Connections, deviceID)
mu.Unlock()
}

// Stop and clean up the session timer
if ctx.session != nil && ctx.session.Timer != nil {
if !ctx.session.Timer.Stop() {
// Drain the channel if the timer fired
select {
case <-ctx.session.Timer.C:
default:
}
}
}
}

func (s *Server) processConnection(ctx *connectionContext) {
Expand Down
183 changes: 183 additions & 0 deletions internal/controller/tcp/cira/tunnel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package cira

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/device-management-toolkit/go-wsman-messages/v2/pkg/apf"

"github.com/device-management-toolkit/console/internal/usecase/devices/wsman"
"github.com/device-management-toolkit/console/pkg/logger"
)

type cleanupTestCase struct {
name string
setupSession func() *apf.Session
authenticated bool
deviceID string
wantPanic bool
wantTimerStopped bool
}

var cleanupTests = []cleanupTestCase{
{
name: "cleanup with nil session",
setupSession: func() *apf.Session { return nil },
authenticated: false,
deviceID: "",
wantPanic: false,
wantTimerStopped: false,
},
{
name: "cleanup with session but nil timer",
setupSession: func() *apf.Session {
return &apf.Session{Timer: nil}
},
authenticated: false,
deviceID: "",
wantPanic: false,
wantTimerStopped: false,
},
{
name: "cleanup with timer that stops successfully",
setupSession: func() *apf.Session {
return &apf.Session{Timer: time.NewTimer(1 * time.Hour)}
},
authenticated: false,
deviceID: "",
wantPanic: false,
wantTimerStopped: true,
},
{
name: "cleanup with timer that fails to stop should drain channel",
setupSession: func() *apf.Session {
timer := time.NewTimer(1 * time.Nanosecond)

time.Sleep(2 * time.Millisecond)

return &apf.Session{Timer: timer}
},
authenticated: false,
deviceID: "",
wantPanic: false,
wantTimerStopped: false,
},
{
name: "cleanup with timer stop failure and empty channel hits default case",
setupSession: func() *apf.Session {
timer := time.NewTimer(100 * time.Nanosecond)

time.Sleep(2 * time.Millisecond)

select {
case <-timer.C:
default:
}

return &apf.Session{Timer: timer}
},
authenticated: false,
deviceID: "",
wantPanic: false,
wantTimerStopped: false,
},
{
name: "cleanup with authenticated connection removes from connections map",
setupSession: func() *apf.Session {
return &apf.Session{Timer: time.NewTimer(10 * time.Second)}
},
authenticated: true,
deviceID: "test-device",
wantPanic: false,
wantTimerStopped: true,
},
}

func TestConnectionContext_cleanup(t *testing.T) {
t.Parallel()

for _, tt := range cleanupTests {
tt := tt // capture range variable

t.Run(tt.name, func(t *testing.T) {
t.Parallel()
runCleanupTest(t, tt)
})
}
}

func runCleanupTest(t *testing.T, tt cleanupTestCase) {
t.Helper()

// Setup
session := tt.setupSession()
ctx := setupConnectionContext(t, session, tt.authenticated, tt.deviceID)

setupConnectionsMap(t, tt.authenticated, tt.deviceID)

// Execute
require.NotPanics(t, func() {
ctx.cleanup()
})

// Verify
verifyTimerState(t, session, tt.wantTimerStopped)
verifyConnectionRemoved(t, tt.authenticated, tt.deviceID)
}

func setupConnectionContext(t *testing.T, session *apf.Session, authenticated bool, deviceID string) *connectionContext {
t.Helper()

// Create a proper APFHandler with mock deviceID
log := logger.New("error")
handler := NewAPFHandler(nil, log) // devices.Feature can be nil for cleanup test
handler.deviceID = deviceID // Set deviceID directly for test

return &connectionContext{
session: session,
authenticated: authenticated,
handler: handler,
}
}

func setupConnectionsMap(t *testing.T, authenticated bool, deviceID string) {
t.Helper()

if authenticated && deviceID != "" {
mu.Lock()

wsman.Connections[deviceID] = &wsman.ConnectionEntry{}

mu.Unlock()
}
}

func verifyTimerState(t *testing.T, session *apf.Session, wantTimerStopped bool) {
t.Helper()

if wantTimerStopped && session != nil && session.Timer != nil {
select {
case <-session.Timer.C:
// Timer was stopped and channel was drained, or timer expired naturally
default:
// Timer was stopped before it could fire
}
}
}

func verifyConnectionRemoved(t *testing.T, authenticated bool, deviceID string) {
t.Helper()

if authenticated && deviceID != "" {
mu.Lock()

_, exists := wsman.Connections[deviceID]

mu.Unlock()

assert.False(t, exists, "Connection should be removed from map")
}
}
19 changes: 12 additions & 7 deletions internal/usecase/devices/wsman/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,6 @@ func (c *ConnectionEntry) hardwareGets() (GetHWResults, error) {
return results, err
}

results.ChipResult, err = c.WsmanMessages.CIM.Chip.Get()
if err != nil {
return results, err
}

results.BiosResult, err = c.WsmanMessages.CIM.BIOSElement.Get()
if err != nil {
return results, err
Expand Down Expand Up @@ -465,6 +460,16 @@ func (c *ConnectionEntry) hardwarePulls() (PullHWResults, error) {
return results, err
}

chipEnumerateResult, err := c.WsmanMessages.CIM.Chip.Enumerate()
if err != nil {
return results, err
}

results.ChipResult, err = c.WsmanMessages.CIM.Chip.Pull(chipEnumerateResult.Body.EnumerateResponse.EnumerationContext)
if err != nil {
return results, err
}

return results, nil
}

Expand All @@ -481,7 +486,7 @@ func (c *ConnectionEntry) GetHardwareInfo() (interface{}, error) {

hwResults := HWResults{
ChassisResult: getHWResults.ChassisResult,
ChipResult: getHWResults.ChipResult,
ChipResult: pullHWResults.ChipResult,
CardResult: getHWResults.CardResult,
PhysicalMemoryResult: pullHWResults.PhysicalMemoryResult,
BiosResult: getHWResults.BiosResult,
Expand All @@ -493,13 +498,13 @@ func (c *ConnectionEntry) GetHardwareInfo() (interface{}, error) {

type GetHWResults struct {
ChassisResult chassis.Response
ChipResult chip.Response
CardResult card.Response
BiosResult bios.Response
ProcessorResult processor.Response
}
type PullHWResults struct {
PhysicalMemoryResult physical.Response
ChipResult chip.Response
}
type HWResults struct {
ChassisResult chassis.Response
Expand Down
Loading