diff --git a/README.md b/README.md index d72d9ba..28f43d7 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ https://user-images.githubusercontent.com/24236723/233631602-6a69d83c-83ef-41ed- ### Build video chat with: * [End2End](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat#running-usage) * [ChatGPT](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat_text/video_chat_with_ChatGPT#running-usage) +* [MiniMax](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat_with_ChatGPT#using-minimax-as-llm-provider) — Use MiniMax M2.7 as an alternative LLM provider * [StableLM](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat_text/video_chat_with_StableLM#running-usage) * [MOSS](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat_text/video_chat_with_MOSS#running-usage) * [MiniGPT-4](https://github.com/OpenGVLab/Ask-Anything/tree/main/video_chat_text/video_miniGPT4#running-usage) diff --git a/video_chat_with_ChatGPT/README.md b/video_chat_with_ChatGPT/README.md index dc073e0..11eb393 100644 --- a/video_chat_with_ChatGPT/README.md +++ b/video_chat_with_ChatGPT/README.md @@ -40,13 +40,30 @@ cd ./pretrained_models/flan-t5-large-finetuned-openai-summarize_from_feedback git lfs pull cd ../.. -# Configure the necessary ChatGPT APIs -export OPENAI_API_KEY={Your_Private_Openai_Key} - -# Run the VideoChat gradio demo. -python app.py +# Configure the necessary ChatGPT APIs +export OPENAI_API_KEY={Your_Private_Openai_Key} + +# Run the VideoChat gradio demo. +python app.py +``` + +## Using MiniMax as LLM Provider + +You can use [MiniMax](https://www.minimaxi.com) as an alternative LLM provider instead of OpenAI. MiniMax offers the MiniMax-M2.7 model via an OpenAI-compatible API. + +```shell +# Set your MiniMax API key +export MINIMAX_API_KEY={Your_MiniMax_API_Key} + +# Optionally set the default provider via environment variable +export LLM_PROVIDER=minimax + +# Run the demo +python app.py ``` +You can also select the LLM provider from the **LLM Provider** dropdown in the Gradio UI at runtime. + # Acknowledgement The project is based on [InternVideo](https://github.com/OpenGVLab/InternVideo), [Tag2Text](https://github.com/xinyu1205/Tag2Text), [GRiT](https://github.com/JialianW/GRiT), [mrm8488](https://huggingface.co/mrm8488/flan-t5-large-finetuned-openai-summarize_from_feedback) and [ChatGPT](https://openai.com/blog/chatgpt). Thanks for the authors for their efforts. diff --git a/video_chat_with_ChatGPT/app.py b/video_chat_with_ChatGPT/app.py index 73d18f8..d934890 100644 --- a/video_chat_with_ChatGPT/app.py +++ b/video_chat_with_ChatGPT/app.py @@ -8,6 +8,7 @@ from util import * import gradio as gr from chatbot import * +from chatbot import LLM_PROVIDERS from load_internvideo import * device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') from simplet5 import SimpleT5 @@ -116,7 +117,7 @@ def set_example_video(example: list) -> dict: with gr.Column(): input_video_path = gr.inputs.Video(label="Input Video") input_tag = gr.Textbox(lines=1, label="User Prompt (Optional, Enter with commas)",visible=False) - + with gr.Row(): with gr.Column(sclae=0.3, min_width=0): caption = gr.Button("✍ Upload") @@ -124,9 +125,14 @@ def set_example_video(example: list) -> dict: with gr.Column(scale=0.7, min_width=0): loadinglabel = gr.Label(label="State") with gr.Column(): + llm_provider = gr.Dropdown( + choices=list(LLM_PROVIDERS.keys()), + value="openai", + label="LLM Provider", + ) openai_api_key_textbox = gr.Textbox( - value=os.environ["OPENAI_API_KEY"], - placeholder="Paste your OpenAI API key here to start (sk-...)", + value=os.environ.get("OPENAI_API_KEY", ""), + placeholder="Paste your API key here to start", show_label=False, lines=1, type="password", @@ -156,7 +162,7 @@ def set_example_video(example: list) -> dict: caption.click(lambda: [], None, state) caption.click(inference,[input_video_path,input_tag],[model_tag_output, user_tag_output, image_caption_output, dense_caption_output,video_caption_output, chat_video, loadinglabel]) - chat_video.click(bot.init_agent, [openai_api_key_textbox, image_caption_output, dense_caption_output, video_caption_output, model_tag_output, state], [input_raws,chatbot, state, openai_api_key_textbox]) + chat_video.click(bot.init_agent, [openai_api_key_textbox, image_caption_output, dense_caption_output, video_caption_output, model_tag_output, state, llm_provider], [input_raws,chatbot, state, openai_api_key_textbox]) txt.submit(bot.run_text, [txt, state], [chatbot, state]) txt.submit(lambda: "", None, txt) diff --git a/video_chat_with_ChatGPT/chatbot.py b/video_chat_with_ChatGPT/chatbot.py index 1cb94b6..1759575 100644 --- a/video_chat_with_ChatGPT/chatbot.py +++ b/video_chat_with_ChatGPT/chatbot.py @@ -2,10 +2,64 @@ from langchain.agents.tools import Tool from langchain.chains.conversation.memory import ConversationBufferMemory from langchain.llms.openai import OpenAI +from langchain.chat_models import ChatOpenAI +import os import re import gradio as gr import openai +# Supported LLM providers and their default models +LLM_PROVIDERS = { + "openai": { + "default_model": "gpt-4", + "api_base": None, # uses default OpenAI endpoint + }, + "minimax": { + "default_model": "MiniMax-M2.7", + "api_base": "https://api.minimax.io/v1", + }, +} + + +def create_llm(provider, api_key, model_name=None, temperature=0): + """Create an LLM instance based on the selected provider. + + Args: + provider: LLM provider name ("openai" or "minimax"). + api_key: API key for the chosen provider. + model_name: Model name override. Uses provider default when None. + temperature: Sampling temperature. + + Returns: + A LangChain LLM or ChatModel instance. + """ + provider = provider.lower() + if provider not in LLM_PROVIDERS: + raise ValueError( + f"Unsupported provider '{provider}'. " + f"Supported: {list(LLM_PROVIDERS.keys())}" + ) + + cfg = LLM_PROVIDERS[provider] + model = model_name or cfg["default_model"] + + if provider == "minimax": + # MiniMax requires temperature in (0.0, 1.0] + temperature = max(0.01, min(temperature, 1.0)) + return ChatOpenAI( + model_name=model, + openai_api_key=api_key, + openai_api_base=cfg["api_base"], + temperature=temperature, + ) + + # Default: OpenAI + return OpenAI( + temperature=temperature, + openai_api_key=api_key, + model_name=model, + ) + def cut_dialogue_history(history_memory, keep_last_n_words=400): if history_memory is None or len(history_memory) == 0: @@ -32,14 +86,14 @@ def run_text(self, text, state): self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500) res = self.agent({"input": text.strip()}) res['output'] = res['output'].replace("\\", "/") - response = res['output'] + response = res['output'] state = state + [(text, response)] print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n" f"Current Memory: {self.agent.memory.buffer}") return state, state - def init_agent(self, openai_api_key, image_caption, dense_caption, video_caption, tags, state): + def init_agent(self, api_key, image_caption, dense_caption, video_caption, tags, state, provider="openai"): chat_history ='' PREFIX = "ChatVideo is a chatbot that chats with you based on video descriptions." FORMAT_INSTRUCTIONS = """ @@ -65,10 +119,18 @@ def init_agent(self, openai_api_key, image_caption, dense_caption, video_caption {agent_scratchpad} """ self.memory.clear() - if not openai_api_key.startswith('sk-'): - return gr.update(visible = False),state, state, "Please paste your key here !" - self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key,model_name="gpt-4") - # openai.api_base = 'https://api.openai-proxy.com/v1/' + + # Resolve provider from argument or environment + provider = (provider or os.environ.get("LLM_PROVIDER", "openai")).lower() + + if not api_key or not api_key.strip(): + return gr.update(visible=False), state, state, "Please paste your API key!" + + # Provider-specific API key validation + if provider == "openai" and not api_key.startswith("sk-"): + return gr.update(visible=False), state, state, "Please paste your OpenAI key (sk-...)!" + + self.llm = create_llm(provider=provider, api_key=api_key) self.agent = initialize_agent( self.tools, self.llm, @@ -78,7 +140,7 @@ def init_agent(self, openai_api_key, image_caption, dense_caption, video_caption return_intermediate_steps=True, agent_kwargs={'prefix': PREFIX, 'format_instructions': FORMAT_INSTRUCTIONS, 'suffix': SUFFIX}, ) state = state + [("I upload a video, Please watch it first! ","I have watch this video, Let's chat!")] - return gr.update(visible = True),state, state, openai_api_key + return gr.update(visible = True),state, state, api_key if __name__=="__main__": import pdb diff --git a/video_chat_with_ChatGPT/tests/__init__.py b/video_chat_with_ChatGPT/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/video_chat_with_ChatGPT/tests/test_chatbot.py b/video_chat_with_ChatGPT/tests/test_chatbot.py new file mode 100644 index 0000000..63e6fb8 --- /dev/null +++ b/video_chat_with_ChatGPT/tests/test_chatbot.py @@ -0,0 +1,238 @@ +"""Unit tests for chatbot.py – LLM provider support.""" + +import os +import sys +from types import ModuleType +from unittest.mock import MagicMock, patch + +import pytest + +# Allow imports from parent directory +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +# Ensure langchain sub-modules required by chatbot.py are importable +# (the project pins langchain==0.0.101 whose import paths differ from the +# version installed on CI). +_STUBS = {} +for _mod_name in [ + "langchain.agents.initialize", + "langchain.agents.tools", + "langchain.chains.conversation.memory", +]: + if _mod_name not in sys.modules: + _stub = ModuleType(_mod_name) + sys.modules[_mod_name] = _stub + _STUBS[_mod_name] = _stub + +# Provide stub symbols so `from langchain.agents.tools import Tool` works +if "langchain.agents.tools" in _STUBS: + _STUBS["langchain.agents.tools"].Tool = MagicMock +if "langchain.agents.initialize" in _STUBS: + _STUBS["langchain.agents.initialize"].initialize_agent = MagicMock +if "langchain.chains.conversation.memory" in _STUBS: + _STUBS["langchain.chains.conversation.memory"].ConversationBufferMemory = MagicMock + +from chatbot import LLM_PROVIDERS, ConversationBot, create_llm, cut_dialogue_history + + +# --------------------------------------------------------------------------- +# cut_dialogue_history +# --------------------------------------------------------------------------- + + +class TestCutDialogueHistory: + def test_empty_history(self): + assert cut_dialogue_history("") == "" + + def test_none_history(self): + assert cut_dialogue_history(None) is None + + def test_short_history(self): + short = "hello world" + assert cut_dialogue_history(short, keep_last_n_words=100) == short + + def test_long_history_trimmed(self): + history = "\n".join([f"line {i}" for i in range(100)]) + result = cut_dialogue_history(history, keep_last_n_words=10) + assert len(result.split()) <= 20 # trimmed + + +# --------------------------------------------------------------------------- +# LLM_PROVIDERS registry +# --------------------------------------------------------------------------- + + +class TestLLMProviders: + def test_openai_registered(self): + assert "openai" in LLM_PROVIDERS + + def test_minimax_registered(self): + assert "minimax" in LLM_PROVIDERS + + def test_minimax_api_base(self): + assert LLM_PROVIDERS["minimax"]["api_base"] == "https://api.minimax.io/v1" + + def test_minimax_default_model(self): + assert LLM_PROVIDERS["minimax"]["default_model"] == "MiniMax-M2.7" + + def test_openai_default_model(self): + assert LLM_PROVIDERS["openai"]["default_model"] == "gpt-4" + + +# --------------------------------------------------------------------------- +# create_llm +# --------------------------------------------------------------------------- + + +class TestCreateLLM: + def test_unsupported_provider_raises(self): + with pytest.raises(ValueError, match="Unsupported provider"): + create_llm("nonexistent", "key123") + + @patch("chatbot.OpenAI") + def test_openai_provider(self, mock_openai): + mock_openai.return_value = MagicMock() + llm = create_llm("openai", "sk-test123") + mock_openai.assert_called_once_with( + temperature=0, openai_api_key="sk-test123", model_name="gpt-4" + ) + + @patch("chatbot.ChatOpenAI") + def test_minimax_provider(self, mock_chat): + mock_chat.return_value = MagicMock() + llm = create_llm("minimax", "mm-key-abc") + mock_chat.assert_called_once_with( + model_name="MiniMax-M2.7", + openai_api_key="mm-key-abc", + openai_api_base="https://api.minimax.io/v1", + temperature=0.01, + ) + + @patch("chatbot.ChatOpenAI") + def test_minimax_temperature_clamped_above_zero(self, mock_chat): + mock_chat.return_value = MagicMock() + create_llm("minimax", "key", temperature=0) + _, kwargs = mock_chat.call_args + assert kwargs["temperature"] >= 0.01 + + @patch("chatbot.ChatOpenAI") + def test_minimax_temperature_clamped_at_max(self, mock_chat): + mock_chat.return_value = MagicMock() + create_llm("minimax", "key", temperature=2.0) + _, kwargs = mock_chat.call_args + assert kwargs["temperature"] <= 1.0 + + @patch("chatbot.ChatOpenAI") + def test_minimax_custom_model(self, mock_chat): + mock_chat.return_value = MagicMock() + create_llm("minimax", "key", model_name="MiniMax-M2.7-highspeed") + _, kwargs = mock_chat.call_args + assert kwargs["model_name"] == "MiniMax-M2.7-highspeed" + + @patch("chatbot.OpenAI") + def test_openai_custom_model(self, mock_openai): + mock_openai.return_value = MagicMock() + create_llm("openai", "sk-test", model_name="gpt-3.5-turbo") + _, kwargs = mock_openai.call_args + assert kwargs["model_name"] == "gpt-3.5-turbo" + + def test_case_insensitive_provider(self): + with patch("chatbot.ChatOpenAI") as mock_chat: + mock_chat.return_value = MagicMock() + create_llm("MiniMax", "key") + mock_chat.assert_called_once() + + +# --------------------------------------------------------------------------- +# ConversationBot.init_agent +# --------------------------------------------------------------------------- + + +class TestConversationBotInitAgent: + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_openai(self, mock_create, mock_init_agent): + mock_create.return_value = MagicMock() + mock_init_agent.return_value = MagicMock() + bot = ConversationBot() + result = bot.init_agent("sk-test", "cap", "dense", "vid", "tags", [], "openai") + mock_create.assert_called_once_with(provider="openai", api_key="sk-test") + assert len(result) == 4 + + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_minimax(self, mock_create, mock_init_agent): + mock_create.return_value = MagicMock() + mock_init_agent.return_value = MagicMock() + bot = ConversationBot() + result = bot.init_agent("mm-key", "cap", "dense", "vid", "tags", [], "minimax") + mock_create.assert_called_once_with(provider="minimax", api_key="mm-key") + + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_empty_key_rejected(self, mock_create, mock_init_agent): + bot = ConversationBot() + result = bot.init_agent("", "cap", "dense", "vid", "tags", [], "openai") + mock_create.assert_not_called() + assert "API key" in result[3] + + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_openai_bad_key_rejected(self, mock_create, mock_init_agent): + bot = ConversationBot() + result = bot.init_agent("bad-key", "cap", "dense", "vid", "tags", [], "openai") + mock_create.assert_not_called() + assert "sk-" in result[3] + + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_minimax_no_sk_required(self, mock_create, mock_init_agent): + mock_create.return_value = MagicMock() + mock_init_agent.return_value = MagicMock() + bot = ConversationBot() + result = bot.init_agent("mm-any-key", "c", "d", "v", "t", [], "minimax") + mock_create.assert_called_once() + + @patch.dict(os.environ, {"LLM_PROVIDER": "minimax"}) + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_env_var_provider(self, mock_create, mock_init_agent): + mock_create.return_value = MagicMock() + mock_init_agent.return_value = MagicMock() + bot = ConversationBot() + # provider=None should fallback to env var + result = bot.init_agent("mm-key", "c", "d", "v", "t", [], None) + mock_create.assert_called_once_with(provider="minimax", api_key="mm-key") + + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_init_agent_state_appended(self, mock_create, mock_init_agent): + mock_create.return_value = MagicMock() + mock_init_agent.return_value = MagicMock() + bot = ConversationBot() + result = bot.init_agent("sk-test", "c", "d", "v", "t", [], "openai") + # state should contain the welcome message + state = result[1] + assert len(state) == 1 + assert "upload a video" in state[0][0].lower() + + +# --------------------------------------------------------------------------- +# ConversationBot.run_text +# --------------------------------------------------------------------------- + + +class TestConversationBotRunText: + @patch("chatbot.initialize_agent") + @patch("chatbot.create_llm") + def test_run_text(self, mock_create, mock_init_agent): + mock_create.return_value = MagicMock() + mock_agent = MagicMock() + mock_agent.return_value = {"output": "Test response"} + mock_agent.memory = MagicMock() + mock_agent.memory.buffer = "" + mock_init_agent.return_value = mock_agent + bot = ConversationBot() + bot.init_agent("sk-test", "c", "d", "v", "t", [], "openai") + result_state, _ = bot.run_text("hello", []) + assert result_state[-1][1] == "Test response" diff --git a/video_chat_with_ChatGPT/tests/test_integration.py b/video_chat_with_ChatGPT/tests/test_integration.py new file mode 100644 index 0000000..edb191d --- /dev/null +++ b/video_chat_with_ChatGPT/tests/test_integration.py @@ -0,0 +1,68 @@ +"""Integration tests for MiniMax LLM provider. + +These tests hit the real MiniMax API and are skipped when MINIMAX_API_KEY +is not set. +""" + +import os +import sys +from types import ModuleType +from unittest.mock import MagicMock + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +# Stub old langchain imports that may not exist in newer versions +_STUBS = {} +for _mod_name in [ + "langchain.agents.initialize", + "langchain.agents.tools", + "langchain.chains.conversation.memory", +]: + if _mod_name not in sys.modules: + _stub = ModuleType(_mod_name) + sys.modules[_mod_name] = _stub + _STUBS[_mod_name] = _stub + +if "langchain.agents.tools" in _STUBS: + _STUBS["langchain.agents.tools"].Tool = MagicMock +if "langchain.agents.initialize" in _STUBS: + _STUBS["langchain.agents.initialize"].initialize_agent = MagicMock +if "langchain.chains.conversation.memory" in _STUBS: + _STUBS["langchain.chains.conversation.memory"].ConversationBufferMemory = MagicMock + +MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY", "") +skip_no_key = pytest.mark.skipif( + not MINIMAX_API_KEY, reason="MINIMAX_API_KEY not set" +) + + +@skip_no_key +class TestMiniMaxIntegration: + def test_create_llm_minimax(self): + from chatbot import create_llm + + llm = create_llm("minimax", MINIMAX_API_KEY) + assert llm is not None + + def test_minimax_chat_completion(self): + from chatbot import create_llm + + llm = create_llm("minimax", MINIMAX_API_KEY, temperature=0.5) + # ChatOpenAI supports predict() / invoke() + response = llm.predict("Say hello in one word.") + assert isinstance(response, str) + assert len(response) > 0 + + def test_minimax_m27_highspeed(self): + from chatbot import create_llm + + llm = create_llm( + "minimax", MINIMAX_API_KEY, + model_name="MiniMax-M2.7-highspeed", + temperature=0.5, + ) + response = llm.predict("What is 2+2? Answer with just the number.") + assert isinstance(response, str) + assert "4" in response