Skip to content

Commit c9ef6c7

Browse files
committed
merge: sync downloader gateway SSE streaming
2 parents 0c1d2eb + 42b2d31 commit c9ef6c7

4 files changed

Lines changed: 139 additions & 12 deletions

File tree

src/main/kotlin/dev/typetype/server/routes/DownloaderGatewayRoutes.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ private suspend fun forwardDownloaderRequest(call: ApplicationCall, gateway: Dow
5151
val requestHeaders = call.request.headers.names().associateWith { call.request.headers[it].orEmpty() }
5252
val body = if (hasRequestBody(method)) call.receiveText().toByteArray() else null
5353

54+
if (isSseRequest(path, requestHeaders)) {
55+
forwardDownloaderSseRequest(call, gateway, method, path, query, requestHeaders, body)
56+
return
57+
}
58+
5459
val response = runCatching { gateway.forward(method, path, query, requestHeaders, body) }
5560
.getOrElse {
5661
call.respond(HttpStatusCode.BadGateway, ErrorResponse("downloader unavailable"))
@@ -84,6 +89,12 @@ private suspend fun forwardDownloaderRequest(call: ApplicationCall, gateway: Dow
8489

8590
private fun hasRequestBody(method: String): Boolean = method == "POST" || method == "PUT" || method == "PATCH"
8691

92+
private fun isSseRequest(path: String, headers: Map<String, String>): Boolean {
93+
if (path.endsWith("/events")) return true
94+
val accept = headers.entries.firstOrNull { it.key.equals("Accept", ignoreCase = true) }?.value.orEmpty()
95+
return accept.contains("text/event-stream", ignoreCase = true)
96+
}
97+
8798
private fun shouldProxyArtifact(path: String, response: dev.typetype.server.services.DownloaderGatewayResponse): Boolean {
8899
if (!path.endsWith("/artifact")) return false
89100
if (response.status != 302 && response.status != 307) return false
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package dev.typetype.server.routes
2+
3+
import dev.typetype.server.models.ErrorResponse
4+
import dev.typetype.server.services.DownloaderGatewayService
5+
import io.ktor.http.ContentType
6+
import io.ktor.http.HttpStatusCode
7+
import io.ktor.server.application.ApplicationCall
8+
import io.ktor.server.response.respond
9+
import io.ktor.server.response.respondOutputStream
10+
11+
suspend fun forwardDownloaderSseRequest(
12+
call: ApplicationCall,
13+
gateway: DownloaderGatewayService,
14+
method: String,
15+
path: String,
16+
query: String?,
17+
requestHeaders: Map<String, String>,
18+
body: ByteArray?,
19+
) {
20+
val upstream = runCatching { gateway.openForward(method, path, query, requestHeaders, body) }
21+
.getOrElse {
22+
call.respond(HttpStatusCode.BadGateway, ErrorResponse("downloader unavailable"))
23+
return
24+
}
25+
26+
upstream.use { response ->
27+
response.headers.names().forEach { name ->
28+
if (shouldForwardResponseHeader(name)) {
29+
response.headers(name).forEach { value ->
30+
call.response.headers.append(name, value, safeOnly = false)
31+
}
32+
}
33+
}
34+
35+
val status = HttpStatusCode.fromValue(response.code)
36+
val contentType = response.header("Content-Type")
37+
?.let { runCatching { ContentType.parse(it) }.getOrNull() }
38+
?: ContentType.Text.EventStream
39+
40+
response.body?.byteStream()?.use { input ->
41+
call.respondOutputStream(contentType = contentType, status = status) {
42+
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
43+
while (true) {
44+
val read = input.read(buffer)
45+
if (read < 0) break
46+
write(buffer, 0, read)
47+
flush()
48+
}
49+
}
50+
} ?: call.respond(HttpStatusCode.fromValue(response.code))
51+
}
52+
}
53+
54+
private fun shouldForwardResponseHeader(name: String): Boolean {
55+
val lower = name.lowercase()
56+
return lower != "content-length" && lower != "transfer-encoding" && lower != "connection"
57+
}

src/main/kotlin/dev/typetype/server/services/DownloaderGatewayService.kt

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import okhttp3.MediaType.Companion.toMediaTypeOrNull
44
import okhttp3.OkHttpClient
55
import okhttp3.Request
66
import okhttp3.RequestBody.Companion.toRequestBody
7+
import okhttp3.Response
78

89
data class DownloaderGatewayResponse(
910
val status: Int,
@@ -17,6 +18,18 @@ class DownloaderGatewayService(
1718
private val client: OkHttpClient = OkHttpClient.Builder().followRedirects(false).followSslRedirects(false).build(),
1819
) {
1920
fun forward(method: String, path: String, query: String?, headers: Map<String, String>, body: ByteArray?): DownloaderGatewayResponse {
21+
openForward(method, path, query, headers, body).use { response ->
22+
val responseHeaders = response.headers.names().flatMap { name -> response.headers(name).map { name to it } }
23+
return DownloaderGatewayResponse(
24+
status = response.code,
25+
contentType = response.header("Content-Type"),
26+
headers = responseHeaders,
27+
body = response.body?.bytes() ?: ByteArray(0),
28+
)
29+
}
30+
}
31+
32+
fun openForward(method: String, path: String, query: String?, headers: Map<String, String>, body: ByteArray?): Response {
2033
val url = buildUrl(path, query)
2134
val requestBody = if (hasBody(method)) {
2235
val mediaType = headers["Content-Type"]?.toMediaTypeOrNull()
@@ -28,7 +41,11 @@ class DownloaderGatewayService(
2841
val requestBuilder = Request.Builder().url(url).method(method, requestBody)
2942
headers.forEach { (name, value) -> if (shouldForwardRequestHeader(name)) requestBuilder.addHeader(name, value) }
3043

31-
client.newCall(requestBuilder.build()).execute().use { response ->
44+
return client.newCall(requestBuilder.build()).execute()
45+
}
46+
47+
fun fetchAbsolute(url: String, headers: Map<String, String>): DownloaderGatewayResponse {
48+
openFetchAbsolute(url, headers).use { response ->
3249
val responseHeaders = response.headers.names().flatMap { name -> response.headers(name).map { name to it } }
3350
return DownloaderGatewayResponse(
3451
status = response.code,
@@ -39,19 +56,10 @@ class DownloaderGatewayService(
3956
}
4057
}
4158

42-
fun fetchAbsolute(url: String, headers: Map<String, String>): DownloaderGatewayResponse {
59+
fun openFetchAbsolute(url: String, headers: Map<String, String>): Response {
4360
val requestBuilder = Request.Builder().url(url).method("GET", null)
4461
headers["Range"]?.takeIf { it.isNotBlank() }?.let { requestBuilder.addHeader("Range", it) }
45-
46-
client.newCall(requestBuilder.build()).execute().use { response ->
47-
val responseHeaders = response.headers.names().flatMap { name -> response.headers(name).map { name to it } }
48-
return DownloaderGatewayResponse(
49-
status = response.code,
50-
contentType = response.header("Content-Type"),
51-
headers = responseHeaders,
52-
body = response.body?.bytes() ?: ByteArray(0),
53-
)
54-
}
62+
return client.newCall(requestBuilder.build()).execute()
5563
}
5664

5765
private fun buildUrl(path: String, query: String?): String {
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package dev.typetype.server
2+
3+
import dev.typetype.server.routes.downloaderGatewayRoutes
4+
import dev.typetype.server.services.DownloaderGatewayService
5+
import io.ktor.client.request.get
6+
import io.ktor.client.request.header
7+
import io.ktor.http.ContentType
8+
import io.ktor.http.HttpHeaders
9+
import io.ktor.http.HttpStatusCode
10+
import io.ktor.server.routing.routing
11+
import io.ktor.server.testing.testApplication
12+
import com.sun.net.httpserver.HttpServer
13+
import java.net.InetSocketAddress
14+
import org.junit.jupiter.api.Assertions.assertEquals
15+
import org.junit.jupiter.api.Assertions.assertTrue
16+
import org.junit.jupiter.api.Test
17+
18+
class DownloaderGatewayRoutesTest {
19+
@Test
20+
fun `sse downloader route proxies text event stream response`() = testApplication {
21+
val upstream = HttpServer.create(InetSocketAddress(0), 0)
22+
upstream.createContext("/jobs/test/events") { exchange ->
23+
val payload = "data: {\"status\":\"running\"}\n\n".toByteArray()
24+
exchange.responseHeaders.add(HttpHeaders.ContentType, ContentType.Text.EventStream.toString())
25+
exchange.responseHeaders.add(HttpHeaders.CacheControl, "no-cache")
26+
exchange.sendResponseHeaders(200, payload.size.toLong())
27+
exchange.responseBody.use { it.write(payload) }
28+
}
29+
upstream.start()
30+
31+
val port = upstream.address.port
32+
val gateway = DownloaderGatewayService("http://127.0.0.1:$port")
33+
34+
application {
35+
routing {
36+
downloaderGatewayRoutes(gateway)
37+
}
38+
}
39+
40+
try {
41+
val response = client.get("/downloader/jobs/test/events") {
42+
header(HttpHeaders.Accept, ContentType.Text.EventStream.toString())
43+
}
44+
assertEquals(HttpStatusCode.OK, response.status)
45+
assertTrue(response.headers[HttpHeaders.ContentType].orEmpty().contains("text/event-stream"))
46+
assertEquals("no-cache", response.headers[HttpHeaders.CacheControl])
47+
} finally {
48+
upstream.stop(0)
49+
}
50+
}
51+
}

0 commit comments

Comments
 (0)