Skip to content

Commit edf3a17

Browse files
Fix feedback from code review - remove duplicate code and clean up architecture
Co-authored-by: erikdubbelboer <522870+erikdubbelboer@users.noreply.github.com>
1 parent c7dfb25 commit edf3a17

5 files changed

Lines changed: 75 additions & 83 deletions

File tree

internal/signaling/handler.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,6 @@ func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.Cre
5858
logger := logging.GetLogger(ctx)
5959
logger.Debug("upgrading connection")
6060

61-
// Extract remote address for rate limiting
62-
remoteAddr := r.RemoteAddr
63-
if r.Header.Get("X-Forwarded-For") != "" {
64-
remoteAddr = strings.TrimSpace(strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0])
65-
}
66-
6761
ctx, cancel := context.WithCancel(ctx)
6862
defer cancel()
6963

@@ -85,7 +79,6 @@ func Handler(ctx context.Context, store stores.Store, cloudflare *cloudflare.Cre
8579

8680
retrievedIDCallback: manager.Reconnected,
8781
rateLimiter: passwordRateLimiter,
88-
remoteAddr: remoteAddr,
8982
}
9083
defer func() {
9184
logger.Info("peer websocket closed", zap.String("peer", peer.ID), zap.String("game", peer.Game), zap.String("origin", r.Header.Get("Origin")))

internal/signaling/peer.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ type Peer struct {
2727

2828
retrievedIDCallback func(context.Context, string, string, string) (bool, []string, error)
2929
rateLimiter *util.PasswordRateLimiter
30-
remoteAddr string
3130

3231
ID string
3332
Secret string
@@ -464,9 +463,8 @@ func (p *Peer) HandleJoinPacket(ctx context.Context, packet JoinPacket) error {
464463

465464
// Check rate limit for password attempts if rate limiter is available
466465
if p.rateLimiter != nil {
467-
logger.Debug("checking rate limit", zap.String("remote_addr", p.remoteAddr))
468-
if !p.rateLimiter.IsAllowed(ctx, p.remoteAddr) {
469-
logger.Warn("rate limit exceeded for password attempt", zap.String("remote_addr", p.remoteAddr))
466+
if !p.rateLimiter.IsAllowed(ctx, "") {
467+
logger.Warn("rate limit exceeded for password attempt")
470468
util.ReplyError(ctx, p.conn, util.ErrorWithCode(fmt.Errorf("too many password attempts"), "rate-limited"))
471469
return nil
472470
}
@@ -480,8 +478,7 @@ func (p *Peer) HandleJoinPacket(ctx context.Context, packet JoinPacket) error {
480478
} else if err == stores.ErrInvalidPassword {
481479
// Record failed password attempt for rate limiting
482480
if p.rateLimiter != nil {
483-
logger.Debug("recording failed password attempt", zap.String("remote_addr", p.remoteAddr))
484-
p.rateLimiter.RecordFailedAttempt(ctx, p.remoteAddr)
481+
p.rateLimiter.RecordFailedAttempt(ctx, "")
485482
}
486483
util.ReplyError(ctx, p.conn, util.ErrorWithCode(err, "invalid-password"))
487484
return nil

internal/util/config_test.go

Lines changed: 0 additions & 63 deletions
This file was deleted.

internal/util/ratelimit.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ import (
66
"time"
77
)
88

9-
// GetRemoteAddr extracts the remote IP address from context
10-
// This duplicates the logic from metrics package to avoid circular imports
11-
func GetRemoteAddr(ctx context.Context) string {
12-
// Use the same context key pattern as metrics package
13-
type metricsContextKey int
14-
remoteAddrKey := metricsContextKey(1)
15-
9+
// metricsContextKey matches the key type used in metrics package to avoid import cycle
10+
type metricsContextKey int
11+
12+
const remoteAddrKey = metricsContextKey(1)
13+
14+
// getRemoteAddr extracts the remote IP address from context
15+
// This matches the implementation in metrics package to avoid circular imports
16+
func getRemoteAddr(ctx context.Context) string {
1617
if addr, ok := ctx.Value(remoteAddrKey).(string); ok {
1718
return addr
1819
}
@@ -48,8 +49,12 @@ func NewPasswordRateLimiter(maxAttempts int, windowSize time.Duration) *Password
4849
}
4950

5051
// IsAllowed checks if a password attempt from the given IP is allowed
52+
// If remoteAddr is empty, it will be extracted from the context
5153
// Returns true if attempt is allowed, false if rate limited
5254
func (rl *PasswordRateLimiter) IsAllowed(ctx context.Context, remoteAddr string) bool {
55+
if remoteAddr == "" {
56+
remoteAddr = getRemoteAddr(ctx)
57+
}
5358
if remoteAddr == "" {
5459
return true // Allow if we can't determine IP
5560
}
@@ -77,7 +82,11 @@ func (rl *PasswordRateLimiter) IsAllowed(ctx context.Context, remoteAddr string)
7782
}
7883

7984
// RecordFailedAttempt records a failed password attempt for the given IP
85+
// If remoteAddr is empty, it will be extracted from the context
8086
func (rl *PasswordRateLimiter) RecordFailedAttempt(ctx context.Context, remoteAddr string) {
87+
if remoteAddr == "" {
88+
remoteAddr = getRemoteAddr(ctx)
89+
}
8190
if remoteAddr == "" {
8291
return // Nothing to record if we can't determine IP
8392
}

internal/util/ratelimit_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,60 @@ func TestPasswordRateLimiter_EmptyIP(t *testing.T) {
142142
if !rl.IsAllowed(ctx, "") {
143143
t.Error("Expected empty IP to still be allowed")
144144
}
145+
}
146+
147+
func TestPasswordRateLimiter_Configuration(t *testing.T) {
148+
// Test that different configurations work
149+
tests := []struct {
150+
name string
151+
maxAttempts int
152+
windowSize time.Duration
153+
attempts int
154+
shouldBlock bool
155+
}{
156+
{
157+
name: "strict limit - 2 attempts per minute",
158+
maxAttempts: 2,
159+
windowSize: time.Minute,
160+
attempts: 2,
161+
shouldBlock: true,
162+
},
163+
{
164+
name: "lenient limit - 10 attempts per minute",
165+
maxAttempts: 10,
166+
windowSize: time.Minute,
167+
attempts: 5,
168+
shouldBlock: false,
169+
},
170+
{
171+
name: "very strict - 1 attempt per minute",
172+
maxAttempts: 1,
173+
windowSize: time.Minute,
174+
attempts: 1,
175+
shouldBlock: true,
176+
},
177+
}
178+
179+
for _, tt := range tests {
180+
t.Run(tt.name, func(t *testing.T) {
181+
rl := NewPasswordRateLimiter(tt.maxAttempts, tt.windowSize)
182+
defer rl.Close()
183+
184+
ctx := context.Background()
185+
ip := "192.168.1.100"
186+
187+
// Record the specified number of attempts
188+
for i := 0; i < tt.attempts; i++ {
189+
rl.RecordFailedAttempt(ctx, ip)
190+
}
191+
192+
// Check if next attempt should be blocked
193+
allowed := rl.IsAllowed(ctx, ip)
194+
blocked := !allowed
195+
196+
if blocked != tt.shouldBlock {
197+
t.Errorf("Expected blocked=%v but got blocked=%v", tt.shouldBlock, blocked)
198+
}
199+
})
200+
}
145201
}

0 commit comments

Comments
 (0)