diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b3e82df..7ad738d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: - name: SonarQube Scan (Push) if: github.event_name == 'push' - uses: SonarSource/sonarcloud-github-action@v1.5 + uses: SonarSource/sonarqube-scan-action@v6.0.0 env: SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} with: @@ -51,7 +51,7 @@ jobs: - name: SonarQube Scan (Pull Request) if: github.event_name == 'pull_request' - uses: SonarSource/sonarcloud-github-action@v1.5 + uses: SonarSource/sonarqube-scan-action@v6.0.0 env: SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} with: diff --git a/sse/sse.go b/sse/sse.go index 5c533de..74782b2 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" "net/http" "sync" "time" @@ -25,6 +26,9 @@ type Client struct { client http.Client timeout time.Duration logger logging.LoggerInterface + bodyMu sync.Mutex + body io.ReadCloser + cancel context.CancelFunc } // NewClient creates new SSEClient @@ -49,28 +53,35 @@ func NewClient(url string, keepAlive int, dialTimeout int, logger logging.Logger return client, nil } -func (l *Client) readEvents(in *bufio.Reader, out chan<- RawEvent) { +func (l *Client) readEvents(ctx context.Context, in *bufio.Reader, out chan<- RawEvent) { eventBuilder := NewEventBuilder() + defer close(out) + defer l.logger.Info("SSE reader goroutine exited") + for { - line, err := in.ReadString(endOfLineChar) - l.logger.Debug("Incoming SSE line: ", line) - if err != nil { - if l.lifecycle.IsRunning() { // If it's supposed to be running, log an error - l.logger.Error(err) - } - close(out) + select { + case <-ctx.Done(): return - } - if line != endOfLineStr { - eventBuilder.AddLine(line) - continue + default: + line, err := in.ReadString(endOfLineChar) + l.logger.Debug("Incoming SSE line: ", line) + if err != nil { + if l.lifecycle.IsRunning() { + l.logger.Error(err) + } + return + } + if line != endOfLineStr { + eventBuilder.AddLine(line) + continue + } + + if event := eventBuilder.Build(); event != nil { + out <- event + } + eventBuilder.Reset() } - l.logger.Debug("Building SSE event") - if event := eventBuilder.Build(); event != nil { - out <- event - } - eventBuilder.Reset() } } @@ -81,12 +92,23 @@ func (l *Client) Do(params map[string]string, headers map[string]string, callbac return ErrNotIdle } - activeGoroutines := sync.WaitGroup{} + var activeGoroutines sync.WaitGroup ctx, cancel := context.WithCancel(context.Background()) + + l.bodyMu.Lock() + l.cancel = cancel + l.bodyMu.Unlock() + defer func() { l.logger.Info("SSE streaming exiting") + cancel() + + l.bodyMu.Lock() + l.cancel = nil + l.bodyMu.Unlock() + activeGoroutines.Wait() l.lifecycle.ShutdownComplete() }() @@ -96,19 +118,21 @@ func (l *Client) Do(params map[string]string, headers map[string]string, callbac return &ErrConnectionFailed{wrapped: fmt.Errorf("error building request: %w", err)} } - l.logger.Debug("[GET] ", req.URL.String()) - l.logger.Debug(fmt.Sprintf("Headers: %v", req.Header)) - resp, err := l.client.Do(req) if err != nil { - l.logger.Error("Error performing get: ", req.URL.String(), err.Error()) return &ErrConnectionFailed{wrapped: fmt.Errorf("error issuing request: %w", err)} } - if resp.StatusCode != 200 { - l.logger.Error(fmt.Sprintf("GET method: Status Code: %d - %s", resp.StatusCode, resp.Status)) - return &ErrConnectionFailed{wrapped: fmt.Errorf("sse request status code: %d", resp.StatusCode)} + + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return &ErrConnectionFailed{ + wrapped: fmt.Errorf("sse request status code: %d", resp.StatusCode), + } } - defer resp.Body.Close() + + l.bodyMu.Lock() + l.body = resp.Body + l.bodyMu.Unlock() if !l.lifecycle.InitializationComplete() { return nil @@ -116,19 +140,27 @@ func (l *Client) Do(params map[string]string, headers map[string]string, callbac reader := bufio.NewReader(resp.Body) eventChannel := make(chan RawEvent, 1000) - go l.readEvents(reader, eventChannel) - // Create timeout timer in case SSE dont receive notifications or keepalive messages + activeGoroutines.Add(1) + go func() { + defer activeGoroutines.Done() + l.readEvents(ctx, reader, eventChannel) + }() + keepAliveTimer := time.NewTimer(l.timeout) defer keepAliveTimer.Stop() for { select { + case <-ctx.Done(): + return nil + case <-l.lifecycle.ShutdownRequested(): - l.logger.Info("Shutting down listener") return nil + case event, ok := <-eventChannel: keepAliveTimer.Reset(l.timeout) + if !ok { if l.lifecycle.IsRunning() { return ErrReadingStream @@ -137,15 +169,16 @@ func (l *Client) Do(params map[string]string, headers map[string]string, callbac } if event.IsEmpty() { - continue // don't forward empty/comment events + continue } + activeGoroutines.Add(1) - go func() { + go func(ev RawEvent) { defer activeGoroutines.Done() - callback(event) - }() - case <-keepAliveTimer.C: // Timeout - l.logger.Warning("SSE idle timeout.") + callback(ev) + }(event) + + case <-keepAliveTimer.C: l.lifecycle.AbnormalShutdown() return ErrTimeout } @@ -159,6 +192,17 @@ func (l *Client) Shutdown(blocking bool) { return } + l.bodyMu.Lock() + if l.cancel != nil { + l.cancel() + l.cancel = nil + } + if l.body != nil { + _ = l.body.Close() + l.body = nil + } + l.bodyMu.Unlock() + if blocking { l.lifecycle.AwaitShutdownComplete() } diff --git a/sse/sse_test.go b/sse/sse_test.go index df6f2b3..da9dd9f 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/splitio/go-toolkit/v5/logging" + "github.com/stretchr/testify/require" ) func TestSSEErrorConnecting(t *testing.T) { @@ -218,6 +219,130 @@ func TestConnectionEOF(t *testing.T) { mockedClient.Shutdown(true) } +type fakeRawEvent struct { + id int +} + +func (f fakeRawEvent) ID() string { return fmt.Sprintf("%d", f.id) } +func (f fakeRawEvent) Event() string { return "test" } +func (f fakeRawEvent) Data() string { return "data" } +func (f fakeRawEvent) Retry() int64 { return 0 } +func (f fakeRawEvent) IsError() bool { return false } +func (f fakeRawEvent) IsEmpty() bool { return false } + +func TestProcessEventsClosureBugWithInterface(t *testing.T) { + const n = 200 + + events := make([]RawEvent, n) + for i := 0; i < n; i++ { + events[i] = fakeRawEvent{id: i} + } + + received := make([]string, 0, n) + var mu sync.Mutex + + processEventsBug(events, func(e RawEvent) { + mu.Lock() + received = append(received, e.ID()) + mu.Unlock() + }) + + if len(received) != n { + t.Fatalf("expected %d events, got %d", n, len(received)) + } + + unique := map[string]bool{} + for _, id := range received { + unique[id] = true + } + + if len(unique) != n { + t.Fatalf( + "expected %d unique events, got %d (closure bug exposed)", + n, + len(unique), + ) + } +} + +func processEventsBug(events []RawEvent, callback func(RawEvent)) { + var wg sync.WaitGroup + + for _, event := range events { + wg.Add(1) + go func(ev RawEvent) { + defer wg.Done() + callback(ev) + }(event) + } + + wg.Wait() +} + +func TestShutdownDoesNotHangWhenSSEIsIdle(t *testing.T) { + // Fake SSE server: accepts connection, sends headers, then blocks forever + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + require.True(t, ok) + flusher.Flush() + + // Block until client closes the connection + <-r.Context().Done() + })) + defer server.Close() + + logger := logging.NewLogger(nil) + + client, err := NewClient( + server.URL, + 70, // keepAlive + 0, // dialTimeout + logger, + ) + require.NoError(t, err) + + done := make(chan struct{}) + + // Start streaming + go func() { + _ = client.Do( + map[string]string{"channels": "test"}, + nil, + func(e RawEvent) {}, + ) + close(done) + }() + + // Give the client time to connect and block on read + time.Sleep(100 * time.Millisecond) + + shutdownDone := make(chan struct{}) + + go func() { + client.Shutdown(true) + close(shutdownDone) + }() + + select { + case <-shutdownDone: + // OK + case <-time.After(500 * time.Millisecond): + t.Fatal("Shutdown(true) blocked — SSE reader did not exit") + } + + // Ensure Do() also returns + select { + case <-done: + // OK + case <-time.After(500 * time.Millisecond): + t.Fatal("Do() did not return after shutdown") + } +} + /* func TestCustom(t *testing.T) { url := `https://streaming.split.io/event-stream` @@ -247,7 +372,8 @@ func TestCustom(t *testing.T) { <-ready fmt.Println(1) go func() { - err := client.Do( + err := client.Do +( map[string]string{ "accessToken": accessToken, "v": "1.1", diff --git a/struct/traits/lifecycle/lifecycle.go b/struct/traits/lifecycle/lifecycle.go index 765b7bc..703536e 100644 --- a/struct/traits/lifecycle/lifecycle.go +++ b/struct/traits/lifecycle/lifecycle.go @@ -49,6 +49,10 @@ func (l *Manager) InitializationComplete() bool { func (l *Manager) BeginShutdown() bool { // If we're currently initializing but not yet running, just change the status. if atomic.CompareAndSwapInt32(&l.status, StatusStarting, StatusInitializationCancelled) { + l.c.L.Lock() + atomic.StoreInt32(&l.status, StatusIdle) + l.c.Broadcast() + l.c.L.Unlock() return true } @@ -56,7 +60,11 @@ func (l *Manager) BeginShutdown() bool { return false } - l.shutdown <- struct{}{} + select { + case l.shutdown <- struct{}{}: + default: + } + return true } diff --git a/struct/traits/lifecycle/lifecycle_test.go b/struct/traits/lifecycle/lifecycle_test.go index e0ee14f..2b09b02 100644 --- a/struct/traits/lifecycle/lifecycle_test.go +++ b/struct/traits/lifecycle/lifecycle_test.go @@ -4,6 +4,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestLifecycleManager(t *testing.T) { @@ -232,3 +234,23 @@ func TestShutdownRequestWhileInitNotComplete(t *testing.T) { t.Error("the goroutine should have not executed further than the InitializationComplete check.") } } + +func TestInitializationCancelledEventuallyBecomesIdle(t *testing.T) { + var m Manager + m.Setup() + + require.True(t, m.BeginInitialization()) + require.True(t, m.BeginShutdown()) // cancela init + + done := make(chan struct{}) + go func() { + m.AwaitShutdownComplete() + close(done) + }() + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatal("initialization cancellation never transitions to Idle") + } +}