|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import base64 |
| 18 | +import copy |
18 | 19 | import json |
19 | 20 | import logging |
20 | 21 | import os |
@@ -559,6 +560,91 @@ async def _get_content( |
559 | 560 | return content_objects |
560 | 561 |
|
561 | 562 |
|
| 563 | +def _is_ollama_chat_provider( |
| 564 | + model: Optional[str], custom_llm_provider: Optional[str] |
| 565 | +) -> bool: |
| 566 | + """Returns True when requests should be normalized for ollama_chat.""" |
| 567 | + if custom_llm_provider and custom_llm_provider.lower() == "ollama_chat": |
| 568 | + return True |
| 569 | + if model and model.lower().startswith("ollama_chat"): |
| 570 | + return True |
| 571 | + return False |
| 572 | + |
| 573 | + |
| 574 | +def _flatten_ollama_content( |
| 575 | + content: OpenAIMessageContent | str | None, |
| 576 | +) -> str | OpenAIMessageContent | None: |
| 577 | + """Flattens multipart content to text for ollama_chat compatibility. |
| 578 | +
|
| 579 | + Ollama's chat endpoint rejects arrays for `content`. We keep textual parts, |
| 580 | + join them with newlines, and fall back to a JSON string for non-text content. |
| 581 | + If both text and non-text parts are present, only the text parts are kept. |
| 582 | + """ |
| 583 | + if not isinstance(content, list): |
| 584 | + return content |
| 585 | + |
| 586 | + text_parts = [] |
| 587 | + for block in content: |
| 588 | + if isinstance(block, dict) and block.get("type") == "text": |
| 589 | + text_value = block.get("text") |
| 590 | + if text_value: |
| 591 | + text_parts.append(text_value) |
| 592 | + |
| 593 | + if text_parts: |
| 594 | + return _NEW_LINE.join(text_parts) |
| 595 | + |
| 596 | + try: |
| 597 | + return json.dumps(content) |
| 598 | + except TypeError: |
| 599 | + return str(content) |
| 600 | + |
| 601 | + |
def _normalize_ollama_chat_messages(
    messages: list[Message],
    *,
    model: Optional[str] = None,
    custom_llm_provider: Optional[str] = None,
) -> list[Message]:
    """Normalizes message payloads for the ollama_chat provider.

    Ollama's chat endpoint expects string `content`, so multipart content is
    flattened via `_flatten_ollama_content`. Other providers are left
    untouched (the input list is returned unchanged).

    Args:
      messages: Outgoing messages; entries may be dicts or message objects
        (e.g. pydantic models exposing `model_copy`).
      model: Effective model string, e.g. "ollama_chat/llama3".
      custom_llm_provider: Explicit provider override, if any.

    Returns:
      A new list of shallow-copied messages with flattened content when the
      ollama_chat provider applies; otherwise the original `messages` list.
    """
    if not _is_ollama_chat_provider(model, custom_llm_provider):
        return messages

    normalized_messages: list[Message] = []
    for message in messages:
        if isinstance(message, dict):
            message_copy = dict(message)
            # Only flatten when the key is present. Using `.get` and assigning
            # unconditionally would inject a spurious `content: None` entry
            # into messages that legitimately omit `content` (e.g. assistant
            # tool-call messages).
            if "content" in message_copy:
                message_copy["content"] = _flatten_ollama_content(
                    message_copy["content"]
                )
            normalized_messages.append(message_copy)
            continue

        # Shallow-copy objects so the caller's message instances are never
        # mutated in place.
        message_copy = (
            message.model_copy()
            if hasattr(message, "model_copy")
            else copy.copy(message)
        )
        if hasattr(message_copy, "content"):
            flattened_content = _flatten_ollama_content(message_copy.content)
            try:
                message_copy.content = flattened_content
            except AttributeError as e:
                # Some message types expose a read-only `content`; keep the
                # copy as-is rather than failing the whole request.
                logger.debug(
                    "Failed to set 'content' attribute on message of type %s: %s",
                    type(message_copy).__name__,
                    e,
                )
        normalized_messages.append(message_copy)

    return normalized_messages
| 646 | + |
| 647 | + |
562 | 648 | def _build_tool_call_from_json_dict( |
563 | 649 | candidate: Any, *, index: int |
564 | 650 | ) -> Optional[ChatCompletionMessageToolCall]: |
@@ -1350,18 +1436,23 @@ async def generate_content_async( |
1350 | 1436 | _append_fallback_user_content_if_missing(llm_request) |
1351 | 1437 | logger.debug(_build_request_log(llm_request)) |
1352 | 1438 |
|
1353 | | - model = llm_request.model or self.model |
| 1439 | + effective_model = llm_request.model or self.model |
1354 | 1440 | messages, tools, response_format, generation_params = ( |
1355 | | - await _get_completion_inputs(llm_request, model) |
| 1441 | + await _get_completion_inputs(llm_request, effective_model) |
| 1442 | + ) |
| 1443 | + normalized_messages = _normalize_ollama_chat_messages( |
| 1444 | + messages, |
| 1445 | + model=effective_model, |
| 1446 | + custom_llm_provider=self._additional_args.get("custom_llm_provider"), |
1356 | 1447 | ) |
1357 | 1448 |
|
1358 | 1449 | if "functions" in self._additional_args: |
1359 | 1450 | # LiteLLM does not support both tools and functions together. |
1360 | 1451 | tools = None |
1361 | 1452 |
|
1362 | 1453 | completion_args = { |
1363 | | - "model": model, |
1364 | | - "messages": messages, |
| 1454 | + "model": effective_model, |
| 1455 | + "messages": normalized_messages, |
1365 | 1456 | "tools": tools, |
1366 | 1457 | "response_format": response_format, |
1367 | 1458 | } |
|
0 commit comments