Skip to content

Commit 8de3f78

Browse files
committed
CB-12 Use asyncio and bind application interfaces
1 parent a0b8ed4 commit 8de3f78

4 files changed

Lines changed: 76 additions & 42 deletions

File tree

app/core/di/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from dependency_injection import get_dependency
1+
from .dependency_injection import get_dependency
22

33
__all__ = ["get_dependency"]

app/core/di/dependency_injection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from injector import Binder, Injector
22

3+
from app.application.generation import GenerationService, DefaultGenerationService
4+
from app.application.routing import RoutingService, DefaultRoutingService
5+
36

47
# Configure the dependency injection container
58
def configure(binder: Binder) -> None:
6-
# binder.bind(MyServiceInterface, to=MyService, scope=singleton)
9+
binder.bind(GenerationService, to=DefaultGenerationService)
10+
binder.bind(RoutingService, to=DefaultRoutingService)
711
pass
812

913

app/core/interceptors/logging_interceptor.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
import time
2-
from typing import Callable
3-
2+
from typing import Callable, Any
43
import grpc
5-
from grpc_interceptor import ServerInterceptor
64

75
from app.core.logger import logger
86

97

10-
class LoggingInterceptor(ServerInterceptor):
11-
def intercept(self, method: Callable, request: any, context: grpc.ServicerContext, method_name: str) -> any:
8+
class LoggingInterceptor(grpc.aio.ServerInterceptor):
9+
async def intercept_service(
10+
self,
11+
continuation: Callable[[grpc.HandlerCallDetails], Any],
12+
handler_call_details: grpc.HandlerCallDetails,
13+
) -> Any:
1214
start_time = time.time()
15+
method_name = handler_call_details.method
1316
logger.info(f"Request: {method_name}")
1417

1518
try:
16-
response = method(request, context)
19+
response = await continuation(handler_call_details)
1720
process_time = time.time() - start_time
18-
logger.info(f"Response: Success (Time: {process_time:.2f}s) - {response}")
21+
logger.info(f"Response: Success (Time: {process_time:.2f}s)")
1922
return response
2023
except Exception as e:
2124
process_time = time.time() - start_time

app/main.py

Lines changed: 60 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,75 @@
1+
import asyncio
12
import os
3+
import uuid
24
from concurrent import futures
35

46
import grpc
57
from grpc_reflection.v1alpha import reflection
68

9+
from app.application.generation import GenerationService
10+
from app.application.routing import RoutingService
11+
from app.core.di import get_dependency
712
from app.core.interceptors import LoggingInterceptor
8-
from app.generated.v1.executor import service_pb2_grpc, generation_pb2, routing_pb2, service_pb2, common_pb2
9-
1013
from app.core.logger import logger
14+
from app.domain.data import TextData, OptionData
15+
from app.domain.generation import GenerationOptions, GenerationRequest
16+
from app.domain.routing import RoutingRequest
17+
from app.generated.v1.executor import service_pb2_grpc, generation_pb2, routing_pb2, service_pb2, common_pb2
1118

1219

13-
# Service implementation
1420
class ExecutorService(service_pb2_grpc.ExecutorServiceServicer):
15-
def Generate(self, request: generation_pb2.GenerationRequest,
16-
context: grpc.ServicerContext) -> generation_pb2.GenerationResponse:
17-
text = request.input.text
18-
use_memory = request.options.use_memory
19-
response_schema = request.options.response_json_schema
20-
21-
# Your generation logic here
22-
generated_text = f"Generated response for: {text}" # Placeholder
23-
24-
return generation_pb2.GenerationResponse(
25-
generated_output=common_pb2.TextData(text=generated_text)
21+
def __init__(self):
22+
self.generation_service = get_dependency(GenerationService)
23+
self.routing_service = get_dependency(RoutingService)
24+
25+
def Generate(
26+
self,
27+
request: generation_pb2.GenerationRequest,
28+
context: grpc.ServicerContext) -> generation_pb2.GenerationResponse:
29+
request = GenerationRequest(
30+
input=TextData(text=request.input.text),
31+
options=GenerationOptions(
32+
use_memory=request.options.use_memory,
33+
response_json_schema=request.options.response_json_schema
34+
),
35+
context_id=uuid.UUID(request.context_id)
2636
)
2737

28-
def Route(self, request: routing_pb2.RoutingRequest,
29-
context: grpc.ServicerContext) -> routing_pb2.RoutingResponse:
30-
text = request.input.text
31-
options = [opt.option for opt in request.options]
32-
33-
# Your routing logic here
34-
selected = options[0] if options else ""
35-
is_fallback = False
36-
37-
return routing_pb2.RoutingResponse(
38-
selected_option=common_pb2.OptionData(option=selected),
39-
is_fallback=is_fallback
38+
try:
39+
result = asyncio.run(self.generation_service.generate(request))
40+
logger.info(f"Generated output: {result.generated_output.text}")
41+
return generation_pb2.GenerationResponse(
42+
generated_output=common_pb2.TextData(text=result.generated_output.text)
43+
)
44+
except Exception as e:
45+
return generation_pb2.GenerationResponse(
46+
error=str(e)
47+
)
48+
49+
def Route(
50+
self,
51+
request: routing_pb2.RoutingRequest,
52+
context: grpc.ServicerContext) -> routing_pb2.RoutingResponse:
53+
request = RoutingRequest(
54+
input=TextData(text=request.input.text),
55+
options=[OptionData(option=opt.option) for opt in request.options]
4056
)
4157

42-
43-
def serve():
44-
server = grpc.server(
45-
futures.ThreadPoolExecutor(max_workers=10),
58+
try:
59+
result = asyncio.run(self.routing_service.route(request))
60+
logger.info(f"Selected option: {result.selected_option.option}")
61+
return routing_pb2.RoutingResponse(
62+
selected_option=common_pb2.OptionData(option=result.selected_option.option),
63+
is_fallback=result.is_fallback
64+
)
65+
except Exception as e:
66+
return routing_pb2.RoutingResponse(
67+
error=str(e)
68+
)
69+
70+
71+
async def serve():
72+
server = grpc.aio.server(
4673
interceptors=[LoggingInterceptor()]
4774
)
4875

@@ -63,9 +90,9 @@ def serve():
6390
# Start the server
6491
server.add_insecure_port('[::]:50051')
6592
logger.info("Starting Chatbot Builder Executor gRPC server on port 50051 with reflection enabled")
66-
server.start()
67-
server.wait_for_termination()
93+
await server.start()
94+
await server.wait_for_termination()
6895

6996

7097
if __name__ == '__main__':
71-
serve()
98+
asyncio.run(serve())

0 commit comments

Comments
 (0)