collectedEvents() {
+ synchronized (collectedEvents) {
+ return List.copyOf(collectedEvents);
+ }
+ }
+
+ /** Whether a billing event would be logged or buffered for the given model usage. */
@VisibleForTesting
boolean shouldEmit(ModelUsage modelUsage) {
- return apiFeatures.isFeatureEnabled(ApiFeature.BILLING_EVENTS_LOGGING)
- && BILLING_LOGGER.isInfoEnabled()
- && modelUsage != null;
+ if (modelUsage == null) {
+ return false;
+ }
+ boolean shouldLog =
+ apiFeatures.isFeatureEnabled(ApiFeature.BILLING_EVENTS_LOGGING)
+ && BILLING_LOGGER.isInfoEnabled();
+ boolean shouldBuffer = apiFeatures.isFeatureEnabled(ApiFeature.BILLING_EVENTS_RESPONSE);
+ return shouldLog || shouldBuffer;
}
/**
diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/provider/BillingResponseFilter.java b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/BillingResponseFilter.java
new file mode 100644
index 0000000000..e9a74e1736
--- /dev/null
+++ b/src/main/java/io/stargate/sgv2/jsonapi/service/provider/BillingResponseFilter.java
@@ -0,0 +1,74 @@
+package io.stargate.sgv2.jsonapi.service.provider;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import io.stargate.sgv2.jsonapi.api.request.RequestContext;
+import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.inject.Inject;
+import jakarta.ws.rs.container.ContainerResponseContext;
+import java.util.List;
+import org.jboss.resteasy.reactive.server.ServerResponseFilter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Adds the {@code Billing-Events} HTTP response header (a JSON array of {@link BillingEvent}s
+ * collected during the request) when {@link ApiFeature#BILLING_EVENTS_RESPONSE} is enabled.
+ *
+ * <p>If the feature is off, or no billing events were collected, the header is not added. A
+ * serialization failure is logged and the header is omitted, so a serialization bug never breaks
+ * the actual API response.
+ */
+@ApplicationScoped
+public class BillingResponseFilter {
+
+  /** HTTP response header that carries the JSON array of billing events. */
+  public static final String BILLING_EVENTS_HEADER = "Billing-Events";
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(BillingResponseFilter.class);
+
+  // ObjectWriter is thread-safe and expensive to build; share one across all requests.
+  private static final ObjectWriter OBJECT_WRITER = new ObjectMapper().writer();
+
+  private final RequestContext requestContext;
+
+  /**
+   * Creates the filter.
+   *
+   * @param requestContext source of the feature flags and the collected billing events; read on
+   *     every response
+   */
+  @Inject
+  public BillingResponseFilter(RequestContext requestContext) {
+    this.requestContext = requestContext;
+  }
+
+  /**
+   * Adds the {@link #BILLING_EVENTS_HEADER} header holding the collected billing events as JSON.
+   *
+   * <p>Does nothing when {@link ApiFeature#BILLING_EVENTS_RESPONSE} is disabled or no events were
+   * collected.
+   *
+   * @param responseContext the outgoing response whose headers may be extended
+   */
+  @ServerResponseFilter
+  public void addBillingHeader(ContainerResponseContext responseContext) {
+    if (!requestContext.apiFeatures().isFeatureEnabled(ApiFeature.BILLING_EVENTS_RESPONSE)) {
+      return;
+    }
+    List<BillingEvent> events = requestContext.billing().collectedEvents();
+    if (events.isEmpty()) {
+      return;
+    }
+    try {
+      responseContext
+          .getHeaders()
+          .add(BILLING_EVENTS_HEADER, OBJECT_WRITER.writeValueAsString(events));
+    } catch (JsonProcessingException e) {
+      // Best effort: never fail the API response because the billing header could not be built.
+      LOGGER.error("Failed to serialize {} billing events to response header", events.size(), e);
+    }
+  }
+}
diff --git a/src/main/resources/embedding-providers-config.yaml b/src/main/resources/embedding-providers-config.yaml
index 06da3fea6d..36d1f5e82e 100644
--- a/src/main/resources/embedding-providers-config.yaml
+++ b/src/main/resources/embedding-providers-config.yaml
@@ -314,6 +314,10 @@ stargate:
vector-dimension: 1024
properties:
max-tokens: 512
+ - name: nvidia/nv-embedqa-e5-v5
+ vector-dimension: 1024
+ properties:
+ max-tokens: 512
jinaAI:
#see https://jina.ai/embeddings/#apiform
display-name: Jina AI
diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingResponseFilterTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingResponseFilterTest.java
new file mode 100644
index 0000000000..8c44735c4f
--- /dev/null
+++ b/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingResponseFilterTest.java
@@ -0,0 +1,128 @@
+package io.stargate.sgv2.jsonapi.service.provider;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.stargate.sgv2.jsonapi.api.request.RequestContext;
+import io.stargate.sgv2.jsonapi.api.request.tenant.Tenant;
+import io.stargate.sgv2.jsonapi.config.BillingConfig;
+import io.stargate.sgv2.jsonapi.config.DatabaseType;
+import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
+import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures;
+import io.stargate.sgv2.jsonapi.config.feature.FeaturesConfig;
+import jakarta.ws.rs.container.ContainerResponseContext;
+import jakarta.ws.rs.core.MultivaluedHashMap;
+import jakarta.ws.rs.core.MultivaluedMap;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import org.junit.jupiter.api.Test;
+
+/** Unit tests for {@link BillingResponseFilter}, using mocked request and response contexts. */
+class BillingResponseFilterTest {
+
+  private static final ObjectMapper MAPPER = new ObjectMapper();
+
+  /** A {@link Billing} together with the {@link ApiFeatures} it was built with. */
+  private record BillingAndFeatures(Billing billing, ApiFeatures apiFeatures) {}
+
+  /** Builds a {@link Billing} over mocked config with the two billing feature flags set. */
+  private static BillingAndFeatures newBillingWith(boolean logging, boolean response) {
+    BillingConfig config = mock(BillingConfig.class);
+    when(config.product()).thenReturn("serverless");
+    when(config.resourceType()).thenReturn("serverless_database");
+    when(config.internalModelProviders()).thenReturn(List.of("nvidia"));
+    when(config.enabledEventTypes()).thenReturn(Optional.empty());
+
+    FeaturesConfig featuresConfig = mock(FeaturesConfig.class);
+    Map<ApiFeature, String> flags = new HashMap<>();
+    flags.put(ApiFeature.BILLING_EVENTS_LOGGING, String.valueOf(logging));
+    flags.put(ApiFeature.BILLING_EVENTS_RESPONSE, String.valueOf(response));
+    when(featuresConfig.flags()).thenReturn(flags);
+
+    ApiFeatures apiFeatures = ApiFeatures.fromConfigAndRequest(featuresConfig, null);
+    return new BillingAndFeatures(new Billing(config, apiFeatures), apiFeatures);
+  }
+
+  /** A fixed NVIDIA embedding usage for an Astra tenant. */
+  private static ModelUsage usage() {
+    return new ModelUsage(
+        ModelProvider.NVIDIA,
+        ModelType.EMBEDDING,
+        "test-model",
+        Tenant.create(DatabaseType.ASTRA, "tenant-1", "us-west-2"),
+        ModelInputType.INDEX,
+        10,
+        20,
+        100,
+        200,
+        1000L);
+  }
+
+  /** Creates a filter whose mocked {@link RequestContext} returns the given collaborators. */
+  private static BillingResponseFilter filterFor(Billing billing, ApiFeatures apiFeatures) {
+    RequestContext rc = mock(RequestContext.class);
+    when(rc.billing()).thenReturn(billing);
+    when(rc.apiFeatures()).thenReturn(apiFeatures);
+    return new BillingResponseFilter(rc);
+  }
+
+  /** Creates a mocked response context backed by the given header map. */
+  private static ContainerResponseContext responseContextWithHeaders(
+      MultivaluedMap<String, Object> headers) {
+    ContainerResponseContext response = mock(ContainerResponseContext.class);
+    when(response.getHeaders()).thenReturn(headers);
+    return response;
+  }
+
+  @Test
+  void addsHeaderWhenFeatureOnAndEventsPresent() throws Exception {
+    BillingAndFeatures bf = newBillingWith(false, true);
+    bf.billing().emitEvent(usage());
+    BillingResponseFilter filter = filterFor(bf.billing(), bf.apiFeatures());
+
+    MultivaluedMap<String, Object> headers = new MultivaluedHashMap<>();
+    filter.addBillingHeader(responseContextWithHeaders(headers));
+
+    Object headerValue = headers.getFirst(BillingResponseFilter.BILLING_EVENTS_HEADER);
+    assertThat(headerValue).isNotNull();
+    JsonNode parsed = MAPPER.readTree(headerValue.toString());
+    assertThat(parsed.isArray()).isTrue();
+    assertThat(parsed.size()).isEqualTo(3);
+    assertThat(parsed.get(0).get("event_type").asText()).isEqualTo("internal_model_total_tokens");
+  }
+
+  @Test
+  void skipsHeaderWhenFeatureOff() {
+    // RESPONSE off — header must not be added even if LOGGING was on for this request.
+    BillingAndFeatures bf = newBillingWith(true, false);
+    bf.billing().emitEvent(usage());
+    BillingResponseFilter filter = filterFor(bf.billing(), bf.apiFeatures());
+
+    MultivaluedMap<String, Object> headers = new MultivaluedHashMap<>();
+    ContainerResponseContext response = responseContextWithHeaders(headers);
+    filter.addBillingHeader(response);
+
+    assertThat(headers.containsKey(BillingResponseFilter.BILLING_EVENTS_HEADER)).isFalse();
+    // We should never touch the headers either (early return saves the work).
+    verify(response, never()).getHeaders();
+  }
+
+  @Test
+  void skipsHeaderWhenNoEventsCollected() {
+    // RESPONSE on, but no emitEvent calls — header skipped because buffer is empty.
+    BillingAndFeatures bf = newBillingWith(false, true);
+    BillingResponseFilter filter = filterFor(bf.billing(), bf.apiFeatures());
+
+    MultivaluedMap<String, Object> headers = new MultivaluedHashMap<>();
+    filter.addBillingHeader(responseContextWithHeaders(headers));
+
+    assertThat(headers.containsKey(BillingResponseFilter.BILLING_EVENTS_HEADER)).isFalse();
+  }
+}
diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingTest.java
index 685c5793ca..1ca1ca76d2 100644
--- a/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingTest.java
+++ b/src/test/java/io/stargate/sgv2/jsonapi/service/provider/BillingTest.java
@@ -13,6 +13,7 @@
import io.stargate.sgv2.jsonapi.config.feature.FeaturesConfig;
import io.vertx.core.MultiMap;
import java.util.EnumSet;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -53,9 +54,16 @@ private static Billing newBilling() {
   }
   private static ApiFeatures featuresWithBilling(boolean enabled) {
+    return features(enabled, false);
+  }
+
+  /** Builds {@link ApiFeatures} from a mocked config with both billing flags set explicitly. */
+  private static ApiFeatures features(boolean logging, boolean response) {
     FeaturesConfig config = mock(FeaturesConfig.class);
-    when(config.flags())
-        .thenReturn(Map.of(ApiFeature.BILLING_EVENTS_LOGGING, String.valueOf(enabled)));
+    Map<ApiFeature, String> flags = new HashMap<>();
+    flags.put(ApiFeature.BILLING_EVENTS_LOGGING, String.valueOf(logging));
+    flags.put(ApiFeature.BILLING_EVENTS_RESPONSE, String.valueOf(response));
+    when(config.flags()).thenReturn(flags);
     return ApiFeatures.fromConfigAndRequest(config, null);
   }
@@ -240,9 +247,75 @@ void shouldEmit_falseWhenFeatureDisabled() {
   void emitEvent_isNoOpWhenGatesFail() {
     // null usage is always a no-op
     newBilling(featuresWithBilling(true)).emitEvent(null);
-    // BILLING_EVENTS_LOGGING disabled is always a no-op
-    newBilling(featuresWithBilling(false))
-        .emitEvent(usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION)));
+    // both LOGGING and RESPONSE disabled — emitEvent is a no-op
+    Billing billing = newBilling(features(false, false));
+    billing.emitEvent(usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION)));
+    assertThat(billing.collectedEvents()).isEmpty();
+  }
+
+  @Test
+  void emitEvent_buffersEventsWhenResponseEnabled() {
+    // LOGGING off, RESPONSE on — events still build and land in the buffer.
+    Billing billing = newBilling(features(false, true));
+    billing.emitEvent(usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION)));
+
+    assertThat(billing.collectedEvents())
+        .extracting(BillingEvent::eventType)
+        .containsExactly(
+            BillingEventType.INTERNAL_MODEL_TOTAL_TOKENS,
+            BillingEventType.INTERNAL_MODEL_EGRESS_BYTES,
+            BillingEventType.INTERNAL_MODEL_INGRESS_BYTES);
+  }
+
+  @Test
+  void emitEvent_doesNotBufferWhenOnlyLoggingEnabled() {
+    // LOGGING on, RESPONSE off — buffer must stay empty (no memory leak from the buffer
+    // when the response feature is off).
+    Billing billing = newBilling(features(true, false));
+    billing.emitEvent(usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION)));
+
+    assertThat(billing.collectedEvents()).isEmpty();
+  }
+
+  @Test
+  void emitEvent_buffersAcrossMultipleCalls() {
+    Billing billing = newBilling(features(false, true));
+    billing.emitEvent(usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION)));
+    billing.emitEvent(usage(ModelProvider.OPENAI, ModelType.EMBEDDING, astraTenant(REGION)));
+
+    // 3 events per emitEvent call × 2 calls = 6 events total.
+    assertThat(billing.collectedEvents()).hasSize(6);
+  }
+
+  @Test
+  void collectedEvents_returnsImmutableSnapshot() {
+    Billing billing = newBilling(features(false, true));
+    billing.emitEvent(usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION)));
+
+    List<BillingEvent> snapshot = billing.collectedEvents();
+    // Snapshot must not reflect later writes — it's a defensive copy.
+    int before = snapshot.size();
+    billing.emitEvent(usage(ModelProvider.OPENAI, ModelType.EMBEDDING, astraTenant(REGION)));
+    assertThat(snapshot).hasSize(before);
+    // And the snapshot itself must not be modifiable.
+    org.assertj.core.api.Assertions.assertThatThrownBy(snapshot::clear)
+        .isInstanceOf(UnsupportedOperationException.class);
+  }
+
+  @Test
+  void shouldEmit_trueWhenOnlyResponseEnabled() {
+    Billing billing = newBilling(features(false, true));
+    ModelUsage modelUsage = usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION));
+
+    assertThat(billing.shouldEmit(modelUsage)).isTrue();
+  }
+
+  @Test
+  void shouldEmit_trueWhenBothFlagsEnabled() {
+    Billing billing = newBilling(features(true, true));
+    ModelUsage modelUsage = usage(ModelProvider.NVIDIA, ModelType.EMBEDDING, astraTenant(REGION));
+
+    assertThat(billing.shouldEmit(modelUsage)).isTrue();
   }
   /**