From eb191481d516f978aaf9e02d8a773007a3f60d28 Mon Sep 17 00:00:00 2001 From: Pascal Tomecek Date: Tue, 3 Feb 2026 09:57:56 -0500 Subject: [PATCH 1/3] Add first cut of panel ui components for viewing the registry and models Signed-off-by: Pascal Tomecek --- ccflow/base.py | 27 ++ ccflow/tests/test_base.py | 26 ++ ccflow/tests/ui/__init__.py | 0 ccflow/tests/ui/test_cli.py | 136 +++++++ ccflow/tests/ui/test_model.py | 555 ++++++++++++++++++++++++++++ ccflow/tests/ui/test_registry.py | 602 +++++++++++++++++++++++++++++++ ccflow/tests/ui/utils.py | 29 ++ ccflow/ui/__init__.py | 3 + ccflow/ui/cli.py | 172 +++++++++ ccflow/ui/model.py | 255 +++++++++++++ ccflow/ui/registry.py | 181 ++++++++++ pyproject.toml | 6 + 12 files changed, 1992 insertions(+) create mode 100644 ccflow/tests/ui/__init__.py create mode 100644 ccflow/tests/ui/test_cli.py create mode 100644 ccflow/tests/ui/test_model.py create mode 100644 ccflow/tests/ui/test_registry.py create mode 100644 ccflow/tests/ui/utils.py create mode 100644 ccflow/ui/__init__.py create mode 100644 ccflow/ui/cli.py create mode 100644 ccflow/ui/model.py create mode 100644 ccflow/ui/registry.py diff --git a/ccflow/base.py b/ccflow/base.py index 2a35454..03d5b18 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -241,6 +241,24 @@ def get_widget( # Can't use self.model_dump_json or self.model_dump because they don't expose the fallback argument return JSON(self.__pydantic_serializer__.to_python(self, **kwargs), **(widget_kwargs or {})) + def __panel__(self): + """Return a Panel viewable for this model. + + Requires ccflow UI dependencies (panel, panel_material_ui). + """ + from ccflow.ui.model import ModelViewer + + return ModelViewer(model=self) + + def get_panel(self): + """Get a Panel pane for this model. + + Requires panel to be installed. + """ + import panel as pn + + return pn.panel(self) + @model_validator(mode="wrap") def _base_model_validator(cls, v, handler, info): if isinstance(v, str): @@ -400,6 +418,15 @@ def models(self) -> MappingProxyType: """Return an immutable pointer to the models dictionary.""" return MappingProxyType(self._models) + def __panel__(self): + """Return a Panel viewable for this registry. + + Requires ccflow UI dependencies (panel, panel_material_ui). + """ + from ccflow.ui.registry import ModelRegistryViewer + + return ModelRegistryViewer(self) + @classmethod def root(cls) -> Self: """Return a static instance of the root registry.""" diff --git a/ccflow/tests/test_base.py b/ccflow/tests/test_base.py index e53e0af..26623fb 100644 --- a/ccflow/tests/test_base.py +++ b/ccflow/tests/test_base.py @@ -173,6 +173,32 @@ def test_widget(self): ), ) + def test_panel(self): + from ccflow import ModelRegistry + from ccflow.ui.model import ModelViewer + from ccflow.ui.registry import ModelRegistryViewer + + m = ModelA(x="foo") + panel_obj = m.__panel__() + self.assertIsInstance(panel_obj, ModelViewer) + + registry = ModelRegistry(name="test", models={"a": m}) + registry_panel_obj = registry.__panel__() + self.assertIsInstance(registry_panel_obj, ModelRegistryViewer) + + def test_get_panel(self): + import panel as pn + + from ccflow import ModelRegistry + + m = ModelA(x="foo") + panel_pane = m.get_panel() + self.assertIsInstance(panel_pane, pn.viewable.Viewable) + + registry = ModelRegistry(name="test", models={"a": m}) + registry_pane = registry.get_panel() + self.assertIsInstance(registry_pane, pn.viewable.Viewable) + class TestLocalRegistration(TestCase): def test_local_class_registered_for_base_model(self): diff --git a/ccflow/tests/ui/__init__.py b/ccflow/tests/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ccflow/tests/ui/test_cli.py b/ccflow/tests/ui/test_cli.py new file mode 100644 index 0000000..5d6388b --- /dev/null +++ b/ccflow/tests/ui/test_cli.py @@ -0,0 +1,136 @@ +"""Unit tests for ccflow.ui.cli module.""" + +from ccflow.ui.cli import _get_ui_args_parser + + +class TestGetUIArgsParser: + """Tests for _get_ui_args_parser function.""" + + def test_parser_has_config_args(self): + """Test parser includes config arguments.""" + parser = _get_ui_args_parser() + + # Parse with config args + args = parser.parse_args( + [ + "--config-path", + "/path/to/config", + "--config-name", + "base", + ] + ) + + assert args.config_path == "/path/to/config" + assert args.config_name == "base" + + def test_parser_defaults(self): + """Test parser default values.""" + parser = _get_ui_args_parser() + args = parser.parse_args(["-cp", ".", "-cn", "test"]) + + assert args.address == "127.0.0.1" + assert args.port == 8080 + assert args.browser_width == 400 + assert args.browser_height == 700 + assert args.viewer_width is None + assert args.show is False + + def test_parser_overrides(self): + """Test parser accepts override arguments.""" + parser = _get_ui_args_parser() + args = parser.parse_args( + [ + "-cp", + ".", + "-cn", + "test", + "key1=value1", + "key2=value2", + ] + ) + + assert args.overrides == ["key1=value1", "key2=value2"] + + def test_parser_ui_args(self): + """Test parser UI server arguments.""" + parser = _get_ui_args_parser() + args = parser.parse_args( + [ + "-cp", + ".", + "-cn", + "test", + "--address", + "0.0.0.0", + "--port", + "9000", + "--show", + ] + ) + + assert args.address == "0.0.0.0" + assert args.port == 9000 + assert args.show is True + + def test_parser_viewer_layout_args(self): + """Test parser viewer layout arguments.""" + parser = _get_ui_args_parser() + args = parser.parse_args( + [ + "-cp", + ".", + "-cn", + "test", + "--browser-width", + "500", + "--browser-height", + "800", + "--viewer-width", + "600", + ] + ) + + assert args.browser_width == 500 + assert args.browser_height == 800 + assert args.viewer_width == 600 + + def test_parser_websocket_origin(self): + """Test parser websocket origin argument.""" + parser = _get_ui_args_parser() + args = parser.parse_args( + [ + "-cp", + ".", + "-cn", + "test", + "--allow-websocket-origin", + "localhost:8080", + "example.com", + ] + ) + + assert args.allow_websocket_origin == ["localhost:8080", "example.com"] + + def test_parser_config_dir_args(self): + """Test parser config directory arguments.""" + parser = _get_ui_args_parser() + args = parser.parse_args( + [ + "-cp", + "/root/config", + "-cn", + "base", + "-cd", + "/extra/config", + "-cdcn", + "override", + "--basepath", + "/search/from/here", + ] + ) + + assert args.config_path == "/root/config" + assert args.config_name == "base" + assert args.config_dir == "/extra/config" + assert args.config_dir_config_name == "override" + assert args.basepath == "/search/from/here" diff --git a/ccflow/tests/ui/test_model.py b/ccflow/tests/ui/test_model.py new file mode 100644 index 0000000..043dbd6 --- /dev/null +++ b/ccflow/tests/ui/test_model.py @@ -0,0 +1,555 @@ +"""Unit tests for ccflow.ui.model module.""" + +import panel as pn +from pydantic import Field + +from ccflow import BaseModel, CallableModel, ContextBase, Flow, GenericResult, MetaData, ModelRegistry +from ccflow.ui.model import ModelConfigViewer, ModelTypeViewer, ModelViewer + +from .utils import find_components_by_type + + +class SimpleModel(BaseModel): + """A simple test model with documentation.""" + + name: str + value: int = 0 + + +class NoDocModel(BaseModel): + field: str + + +class DescribedModel(BaseModel): + """Model with field descriptions.""" + + name: str = Field(description="The name field") + count: int = Field(description="The count field") + + +class ContainerModel(BaseModel): + """Model that contains another model.""" + + inner: SimpleModel + + +class SampleContext(ContextBase): + """Sample context for callable models.""" + + input_value: str = "" + + +class SampleResult(GenericResult): + """Sample result for callable models.""" + + output_value: str = "" + + +class SimpleCallableModel(CallableModel): + """A simple callable model for testing.""" + + multiplier: int = 1 + + @Flow.call + def __call__(self, context: SampleContext) -> SampleResult: + return SampleResult(output_value=context.input_value * self.multiplier) + + +class TestModelTypeViewer: + """Tests for ModelTypeViewer class.""" + + def test_init_returns_viewable(self): + """Test ModelTypeViewer returns a Panel viewable.""" + viewer = ModelTypeViewer() + panel = viewer.__panel__() + assert isinstance(panel, pn.viewable.Viewable) + + def test_panel_contains_html_pane(self): + """Test that the panel contains an HTML pane for displaying content.""" + viewer = ModelTypeViewer() + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + assert len(html_panes) > 0 + + def test_on_type_change_with_none(self): + """Test that setting model_type to None clears the display.""" + viewer = ModelTypeViewer() + viewer.model_type = SimpleModel + viewer.model_type = None + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + # All HTML panes should be empty + for pane in html_panes: + assert pane.object == "" + + def test_on_type_change_with_model(self): + """Test that setting model_type displays type information.""" + viewer = ModelTypeViewer() + viewer.model_type = SimpleModel + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + # Check that type name is displayed + assert "SimpleModel" in html_content + # Check that documentation is displayed + assert "A simple test model with documentation" in html_content + # Check that fields are displayed + assert "name" in html_content + assert "value" in html_content + assert "str" in html_content + assert "int" in html_content + + def test_on_type_change_without_docstring(self): + """Test model type without a docstring.""" + viewer = ModelTypeViewer() + viewer.model_type = NoDocModel + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + assert "NoDocModel" in html_content + assert "field" in html_content + # Should not have documentation section + assert "Class Documentation:" not in html_content + + def test_on_type_change_with_field_descriptions(self): + """Test that field descriptions are displayed.""" + viewer = ModelTypeViewer() + viewer.model_type = DescribedModel + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + assert "The name field" in html_content + assert "The count field" in html_content + + +class TestModelConfigViewer: + """Tests for ModelConfigViewer class.""" + + def test_init_returns_viewable(self): + """Test ModelConfigViewer returns a Panel viewable.""" + viewer = ModelConfigViewer() + panel = viewer.__panel__() + assert isinstance(panel, pn.viewable.Viewable) + + def test_panel_contains_html_pane(self): + """Test that the panel contains an HTML pane for metadata display.""" + viewer = ModelConfigViewer() + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + assert len(html_panes) > 0 + + def test_on_model_change_with_none(self): + """Test that setting model to None clears the metadata display.""" + viewer = ModelConfigViewer() + model = SimpleModel(name="test") + viewer.model = model + viewer.model = None + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + for pane in html_panes: + assert pane.object == "" + + def test_on_model_change_with_model(self): + """Test that setting model displays metadata.""" + viewer = ModelConfigViewer() + model = SimpleModel(name="test", value=42) + viewer.model = model + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + # Should not crash, content depends on model metadata + assert len(html_panes) > 0 + + def test_on_model_change_with_description(self): + """Test model with meta description.""" + viewer = ModelConfigViewer() + model = SimpleCallableModel( + multiplier=2, + meta=MetaData(description="This is a test model description"), + ) + viewer.model = model + + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + assert "This is a test model description" in html_content + + def test_render_dependencies_empty(self): + """Test _render_dependencies with no dependencies.""" + viewer = ModelConfigViewer() + model = SimpleModel(name="test") + result = viewer._render_dependencies(model) + assert result == "" + + def test_render_dependencies_with_deps(self): + """Test _render_dependencies with dependencies.""" + root = ModelRegistry.root() + registry = ModelRegistry(name="test_dep_registry") + root.add("test_dep_registry", registry) + + try: + model = SimpleModel(name="test") + registry.add("my_model", model) + + container = ContainerModel(inner=model) + registry.add("container", container) + + viewer = ModelConfigViewer() + result = viewer._render_dependencies(container) + + assert "Registry Dependencies" in result + assert "my_model" in result + finally: + root.remove("test_dep_registry") + + +class TestModelViewer: + """Tests for ModelViewer class.""" + + def test_init_returns_viewable(self): + """Test ModelViewer returns a Panel viewable.""" + viewer = ModelViewer() + panel = viewer.__panel__() + assert isinstance(panel, pn.viewable.Viewable) + + def test_panel_contains_json_editor(self): + """Test that the panel contains a JSON editor widget.""" + viewer = ModelViewer() + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert len(json_editors) > 0 + + def test_json_editor_initially_empty(self): + """Test that JSON editor is initially empty.""" + viewer = ModelViewer() + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert json_editors[0].value == {} + + def test_on_model_change_with_none(self): + """Test that setting model to None clears the JSON editor.""" + viewer = ModelViewer() + viewer.model = SimpleModel(name="test") + viewer.model = None + + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert json_editors[0].value == {} + + def test_on_model_change_with_base_model(self): + """Test that setting a BaseModel populates the JSON editor.""" + viewer = ModelViewer() + model = SimpleModel(name="test", value=42) + viewer.model = model + + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + json_value = json_editors[0].value + + assert "name" in json_value + assert json_value["name"] == "test" + assert json_value["value"] == 42 + + def test_on_model_change_with_callable_model(self): + """Test that setting a CallableModel sets up type viewers correctly.""" + viewer = ModelViewer() + model = SimpleCallableModel(multiplier=2) + viewer.model = model + + # Verify the internal type viewers are properly configured + assert viewer._config_viewer.model == model + assert viewer._type_viewer.model_type is type(model) + assert viewer._context_type_viewer.model_type == model.context_type + assert viewer._result_type_viewer.model_type == model.result_type + + def test_on_model_change_updates_viewers(self): + """Test that changing model updates the viewers.""" + viewer = ModelViewer() + + # Set a callable model first + callable_model = SimpleCallableModel(multiplier=2) + viewer.model = callable_model + + # Then set a base model + base_model = SimpleModel(name="test") + viewer.model = base_model + + # Viewers should be updated for the new model + assert viewer._config_viewer.model == base_model + assert viewer._type_viewer.model_type is type(base_model) + + def test_json_serialization(self): + """Test that JSON editor correctly serializes model data.""" + viewer = ModelViewer() + model = SimpleModel(name="test_name", value=123) + viewer.model = model + + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + json_value = json_editors[0].value + + assert json_value["name"] == "test_name" + assert json_value["value"] == 123 + + +class TestModelSwitching: + """Tests for switching between different models to ensure proper state reset.""" + + def test_switch_between_base_models_same_type(self): + """Test switching between two BaseModel instances of the same type.""" + viewer = ModelViewer() + + model1 = SimpleModel(name="first", value=1) + model2 = SimpleModel(name="second", value=2) + + # Set first model + viewer.model = model1 + assert viewer._config_viewer.model == model1 + assert viewer._type_viewer.model_type == SimpleModel + + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert json_editors[0].value["name"] == "first" + assert json_editors[0].value["value"] == 1 + + # Switch to second model + viewer.model = model2 + assert viewer._config_viewer.model == model2 + assert viewer._type_viewer.model_type == SimpleModel + + # Verify JSON editor updated (not still showing old model) + assert json_editors[0].value["name"] == "second" + assert json_editors[0].value["value"] == 2 + + def test_switch_between_base_models_different_types(self): + """Test switching between two BaseModel instances of different types.""" + viewer = ModelViewer() + + model1 = SimpleModel(name="simple", value=42) + model2 = DescribedModel(name="described", count=100) + + # Set first model + viewer.model = model1 + assert viewer._type_viewer.model_type == SimpleModel + + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert json_editors[0].value["name"] == "simple" + assert "value" in json_editors[0].value + + # Switch to different type + viewer.model = model2 + assert viewer._config_viewer.model == model2 + assert viewer._type_viewer.model_type == DescribedModel + + # JSON editor should show new model's fields, not old ones + assert json_editors[0].value["name"] == "described" + assert json_editors[0].value["count"] == 100 + assert "value" not in json_editors[0].value + + def test_switch_from_callable_to_base_model(self): + """Test switching from CallableModel to BaseModel removes context/result tabs.""" + viewer = ModelViewer() + + callable_model = SimpleCallableModel(multiplier=5) + base_model = SimpleModel(name="base", value=10) + + # Set callable model first + viewer.model = callable_model + assert viewer._context_type_viewer.model_type == callable_model.context_type + assert viewer._result_type_viewer.model_type == callable_model.result_type + + # Should have 4 tabs: Summary, Model Type, Context Type, Result Type + assert len(viewer._tabs) == 4 + + # Switch to base model + viewer.model = base_model + assert viewer._config_viewer.model == base_model + assert viewer._type_viewer.model_type == SimpleModel + + # Should have only 2 tabs now: Summary, Model Type + assert len(viewer._tabs) == 2 + + # JSON editor should show base model data + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert json_editors[0].value["name"] == "base" + assert json_editors[0].value["value"] == 10 + assert "multiplier" not in json_editors[0].value + + def test_switch_from_base_to_callable_model(self): + """Test switching from BaseModel to CallableModel adds context/result tabs.""" + viewer = ModelViewer() + + base_model = SimpleModel(name="base", value=10) + callable_model = SimpleCallableModel(multiplier=5) + + # Set base model first + viewer.model = base_model + assert len(viewer._tabs) == 2 + + # Switch to callable model + viewer.model = callable_model + assert viewer._config_viewer.model == callable_model + assert viewer._type_viewer.model_type == SimpleCallableModel + assert viewer._context_type_viewer.model_type == callable_model.context_type + assert viewer._result_type_viewer.model_type == callable_model.result_type + + # Should have 4 tabs now + assert len(viewer._tabs) == 4 + + # JSON editor should show callable model data + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert "multiplier" in json_editors[0].value + assert json_editors[0].value["multiplier"] == 5 + + def test_switch_to_none_clears_state(self): + """Test switching from a model to None clears all state.""" + viewer = ModelViewer() + + model = SimpleCallableModel(multiplier=3) + + # Set model + viewer.model = model + assert len(viewer._tabs) == 4 + assert viewer._json_container.visible is True + + # Clear model + viewer.model = None + assert len(viewer._tabs) == 0 + assert viewer._json_container.visible is False + + panel = viewer.__panel__() + json_editors = find_components_by_type(panel, pn.widgets.JSONEditor) + assert json_editors[0].value == {} + + def test_tabs_reset_to_first_on_model_switch(self): + """Test that tab selection resets to first tab when switching models.""" + viewer = ModelViewer() + + model1 = SimpleCallableModel(multiplier=1) + model2 = SimpleCallableModel(multiplier=2) + + # Set first model and change active tab + viewer.model = model1 + viewer._tabs.active = 2 # Select "Context Type" tab + + # Switch to second model + viewer.model = model2 + + # Tab should reset to first (Summary) + assert viewer._tabs.active == 0 + + +class TestModelTypeViewerSwitching: + """Tests for ModelTypeViewer switching between types.""" + + def test_switch_model_types(self): + """Test switching between different model types updates display.""" + viewer = ModelTypeViewer() + + # Set first type + viewer.model_type = SimpleModel + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + assert "SimpleModel" in html_content + assert "name" in html_content + assert "value" in html_content + + # Switch to different type + viewer.model_type = DescribedModel + html_content = "".join(pane.object for pane in html_panes) + + # Should show new type, not old + assert "DescribedModel" in html_content + assert "count" in html_content + assert "SimpleModel" not in html_content + # "name" is in both, but "value" should not be present + assert "value" not in html_content + + def test_switch_to_none_clears_display(self): + """Test switching to None clears the display.""" + viewer = ModelTypeViewer() + + viewer.model_type = SimpleModel + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + + # Verify content is present + html_content = "".join(pane.object for pane in html_panes) + assert "SimpleModel" in html_content + + # Clear + viewer.model_type = None + html_content = "".join(pane.object for pane in html_panes) + assert html_content == "" + + +class TestModelConfigViewerSwitching: + """Tests for ModelConfigViewer switching between models.""" + + def test_switch_models_with_different_metadata(self): + """Test switching between models with different metadata.""" + viewer = ModelConfigViewer() + + model1 = SimpleCallableModel( + multiplier=1, + meta=MetaData(description="First model description"), + ) + model2 = SimpleCallableModel( + multiplier=2, + meta=MetaData(description="Second model description"), + ) + + # Set first model + viewer.model = model1 + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + assert "First model description" in html_content + + # Switch to second model + viewer.model = model2 + html_content = "".join(pane.object for pane in html_panes) + + # Should show new description, not old + assert "Second model description" in html_content + assert "First model description" not in html_content + + def test_switch_from_model_with_description_to_without(self): + """Test switching from model with description to one without.""" + viewer = ModelConfigViewer() + + model_with_desc = SimpleCallableModel( + multiplier=1, + meta=MetaData(description="Has a description"), + ) + model_without_desc = SimpleModel(name="no desc") + + # Set model with description + viewer.model = model_with_desc + panel = viewer.__panel__() + html_panes = find_components_by_type(panel, pn.pane.HTML) + html_content = "".join(pane.object for pane in html_panes) + + assert "Has a description" in html_content + + # Switch to model without description + viewer.model = model_without_desc + html_content = "".join(pane.object for pane in html_panes) + + # Old description should not be present + assert "Has a description" not in html_content diff --git a/ccflow/tests/ui/test_registry.py b/ccflow/tests/ui/test_registry.py new file mode 100644 index 0000000..c1266f9 --- /dev/null +++ b/ccflow/tests/ui/test_registry.py @@ -0,0 +1,602 @@ +"""Unit tests for ccflow.ui.registry module.""" + +from unittest import mock + +import panel as pn + +from ccflow import BaseModel, ModelRegistry +from ccflow.ui.registry import ModelRegistryViewer, RegistryBrowser + +from .utils import find_components_by_type + + +class SimpleModel(BaseModel): + """A simple test model.""" + + name: str + value: int = 0 + + +class AnotherModel(BaseModel): + """Another test model.""" + + data: str = "" + + +class TestRegistryBrowser: + """Tests for RegistryBrowser class.""" + + def test_init_returns_viewable(self): + """Test RegistryBrowser returns a Panel viewable.""" + registry = ModelRegistry(name="test") + browser = RegistryBrowser(registry) + panel = browser.__panel__() + assert isinstance(panel, pn.viewable.Viewable) + + def test_panel_contains_autocomplete_search(self): + """Test that the panel contains an autocomplete search widget.""" + registry = ModelRegistry(name="test") + browser = RegistryBrowser(registry) + panel = browser.__panel__() + autocomplete = find_components_by_type(panel, pn.widgets.AutocompleteInput) + assert len(autocomplete) > 0 + + def test_init_with_empty_registry(self): + """Test RegistryBrowser initialization with empty registry.""" + registry = ModelRegistry(name="test") + browser = RegistryBrowser(registry) + + assert browser.selected_model is None + # Search options should be empty + panel = browser.__panel__() + autocomplete = find_components_by_type(panel, pn.widgets.AutocompleteInput) + assert autocomplete[0].options == [] + + def test_init_with_models(self): + """Test RegistryBrowser initialization with models in registry.""" + registry = ModelRegistry(name="test") + model1 = SimpleModel(name="test1", value=1) + model2 = AnotherModel(data="test2") + registry.add("model1", model1) + registry.add("model2", model2) + + browser = RegistryBrowser(registry) + + # Search options should contain the model names + panel = browser.__panel__() + autocomplete = find_components_by_type(panel, pn.widgets.AutocompleteInput) + assert "model1" in autocomplete[0].options + assert "model2" in autocomplete[0].options + + def test_build_tree_simple(self): + """Test _build_tree with flat registry.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=1) + registry.add("my_model", model) + + browser = RegistryBrowser(registry) + tree_items = browser._tree_items + + assert len(tree_items) == 1 + assert tree_items[0]["label"] == "my_model" + assert tree_items[0]["model"] == model + assert tree_items[0]["_index_path"] == (0,) + + def test_build_tree_nested(self): + """Test _build_tree with nested registries.""" + root = ModelRegistry(name="root") + sub = ModelRegistry(name="sub") + model = SimpleModel(name="test", value=1) + + sub.add("nested_model", model) + root.add("subregistry", sub) + + browser = RegistryBrowser(root) + tree_items = browser._tree_items + + assert len(tree_items) == 1 + assert tree_items[0]["label"] == "subregistry" + assert "items" in tree_items[0] + assert len(tree_items[0]["items"]) == 1 + assert tree_items[0]["items"][0]["label"] == "nested_model" + assert tree_items[0]["items"][0]["model"] == model + + def test_build_node_index(self): + """Test _build_node_index creates correct path mappings.""" + root = ModelRegistry(name="root") + sub = ModelRegistry(name="sub") + model1 = SimpleModel(name="test1", value=1) + model2 = SimpleModel(name="test2", value=2) + + root.add("top_model", model1) + sub.add("nested_model", model2) + root.add("subregistry", sub) + + browser = RegistryBrowser(root) + + assert "top_model" in browser._node_index + assert "subregistry/nested_model" in browser._node_index + assert browser._node_index["top_model"]["model"] == model1 + assert browser._node_index["subregistry/nested_model"]["model"] == model2 + + def test_expanded_from_index_path(self): + """Test _expanded_from_index_path generates correct expanded paths.""" + # Single level - no expansion needed + result = RegistryBrowser._expanded_from_index_path((0,)) + assert result == [] + + # Two levels + result = RegistryBrowser._expanded_from_index_path((0, 1)) + assert result == [(0,)] + + # Three levels + result = RegistryBrowser._expanded_from_index_path((0, 1, 2)) + assert result == [(0,), (0, 1)] + + # Four levels + result = RegistryBrowser._expanded_from_index_path((1, 2, 3, 4)) + assert result == [(1,), (1, 2), (1, 2, 3)] + + def test_on_search_select_empty(self): + """Test _on_search_select with empty value.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=1) + registry.add("my_model", model) + browser = RegistryBrowser(registry) + + event = mock.Mock() + event.new = "" + + browser._on_search_select(event) + assert browser.selected_model is None + + def test_on_search_select_invalid_path(self): + """Test _on_search_select with non-existent path.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=1) + registry.add("my_model", model) + browser = RegistryBrowser(registry) + + event = mock.Mock() + event.new = "nonexistent" + + browser._on_search_select(event) + assert browser.selected_model is None + + def test_on_search_select_valid_path(self): + """Test _on_search_select with valid path.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=1) + registry.add("my_model", model) + browser = RegistryBrowser(registry) + + event = mock.Mock() + event.new = "my_model" + + browser._on_search_select(event) + + # Tree value should be set + assert len(browser._tree.value) == 1 + assert browser._tree.value[0]["label"] == "my_model" + + def test_on_tree_select_empty(self): + """Test _on_tree_select with empty selection.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=1) + registry.add("my_model", model) + browser = RegistryBrowser(registry) + + event = mock.Mock() + event.new = [] + + browser._on_tree_select(event) + assert browser.selected_model is None + + def test_on_tree_select_model(self): + """Test _on_tree_select with model selection.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=1) + registry.add("my_model", model) + browser = RegistryBrowser(registry) + + event = mock.Mock() + event.new = [{"label": "my_model", "model": model}] + + browser._on_tree_select(event) + assert browser.selected_model == model + + def test_on_tree_select_registry(self): + """Test _on_tree_select with registry selection (no model key).""" + registry = ModelRegistry(name="test") + sub = ModelRegistry(name="sub") + registry.add("subregistry", sub) + browser = RegistryBrowser(registry) + + event = mock.Mock() + event.new = [{"label": "subregistry", "items": []}] + + browser._on_tree_select(event) + assert browser.selected_model is None + + def test_search_options_sorted(self): + """Test that search widget options are sorted.""" + root = ModelRegistry(name="root") + sub = ModelRegistry(name="sub") + model1 = SimpleModel(name="test1", value=1) + model2 = SimpleModel(name="test2", value=2) + + root.add("zebra", model1) + sub.add("alpha", model2) + root.add("subregistry", sub) + + browser = RegistryBrowser(root) + panel = browser.__panel__() + autocomplete = find_components_by_type(panel, pn.widgets.AutocompleteInput) + + assert autocomplete[0].options == sorted(autocomplete[0].options) + + +class TestModelRegistryViewer: + """Tests for ModelRegistryViewer class.""" + + def test_init_returns_viewable(self): + """Test ModelRegistryViewer returns a Panel viewable.""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry) + panel = viewer.__panel__() + assert isinstance(panel, pn.viewable.Viewable) + + def test_panel_is_row_layout(self): + """Test that the panel is a Row layout (browser + viewer side by side).""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry) + panel = viewer.__panel__() + assert isinstance(panel, pn.Row) + + def test_init_with_custom_dimensions(self): + """Test initialization with custom width/height.""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer( + registry, + browser_width=500, + browser_height=800, + viewer_width=600, + ) + + assert viewer.browser_width == 500 + assert viewer.browser_height == 800 + assert viewer.viewer_width == 600 + + def test_browser_viewer_wiring(self): + """Test that browser selection updates viewer.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=42) + registry.add("my_model", model) + + viewer = ModelRegistryViewer(registry) + + # Simulate browser selection + viewer._browser.selected_model = model + + # Viewer should be updated + assert viewer._viewer.model == model + + def test_default_browser_dimensions(self): + """Test default browser dimensions.""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry) + + assert viewer.browser_width == 400 + assert viewer.browser_height == 700 + assert viewer.viewer_width is None + + def test_make_browser_column(self): + """Test _make_browser_column creates proper column.""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry) + + column = viewer._make_browser_column() + assert isinstance(column, pn.Column) + assert column.width == viewer.browser_width + assert column.height == viewer.browser_height + assert column.scroll is True + + def test_make_viewer_column_with_width(self): + """Test _make_viewer_column with specified width.""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry, viewer_width=600) + + column = viewer._make_viewer_column() + assert isinstance(column, pn.Column) + assert column.width == 600 + + def test_make_viewer_column_without_width(self): + """Test _make_viewer_column without specified width (stretch).""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry) + + column = viewer._make_viewer_column() + assert isinstance(column, pn.Column) + assert column.sizing_mode == "stretch_width" + + def test_model_param_default_none(self): + """Test that model param starts as None.""" + registry = ModelRegistry(name="test") + viewer = ModelRegistryViewer(registry) + assert viewer.model is None + + def test_model_param_updated_on_selection(self): + """Test that model param is updated when a model is selected.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=42) + registry.add("my_model", model) + + viewer = ModelRegistryViewer(registry) + viewer._browser.selected_model = model + + assert viewer.model == model + + def test_model_param_cleared_on_deselection(self): + """Test that model param is cleared when selection is cleared.""" + registry = ModelRegistry(name="test") + model = SimpleModel(name="test", value=42) + registry.add("my_model", model) + + viewer = ModelRegistryViewer(registry) + viewer._browser.selected_model = model + assert viewer.model == model + + viewer._browser.selected_model = None + assert viewer.model is None + + +class TestIntegration: + """Integration tests for UI components.""" + + def test_full_workflow(self): + """Test complete workflow from registry to viewer.""" + # Create a registry with nested structure + root = ModelRegistry(name="root") + models_reg = ModelRegistry(name="models") + configs_reg = ModelRegistry(name="configs") + + model1 = SimpleModel(name="model1", value=1) + model2 = SimpleModel(name="model2", value=2) + config1 = AnotherModel(data="config1") + + models_reg.add("first", model1) + models_reg.add("second", model2) + configs_reg.add("main", config1) + + root.add("models", models_reg) + root.add("configs", configs_reg) + + # Create viewer + viewer = ModelRegistryViewer(root) + + # Verify browser has all paths indexed + browser = viewer._browser + assert "models/first" in browser._node_index + assert "models/second" in browser._node_index + assert "configs/main" in browser._node_index + + # Simulate selecting a model + browser.selected_model = model1 + + # Verify viewer is updated + assert viewer._viewer.model == model1 + + def test_nested_registry_expansion(self): + """Test that nested paths generate correct expansion.""" + root = ModelRegistry(name="root") + level1 = ModelRegistry(name="level1") + level2 = ModelRegistry(name="level2") + model = SimpleModel(name="deep", value=99) + + level2.add("deep_model", model) + level1.add("level2", level2) + root.add("level1", level1) + + browser = RegistryBrowser(root) + + # Find the deep model node + node = browser._node_index["level1/level2/deep_model"] + + # Check expansion path + expanded = browser._expanded_from_index_path(node["_index_path"]) + assert len(expanded) == 2 # level1 and level1/level2 + + def test_switch_model_selection(self): + """Test switching between different models updates viewer correctly.""" + root = ModelRegistry(name="root") + model1 = SimpleModel(name="first", value=1) + model2 = SimpleModel(name="second", value=2) + + root.add("model1", model1) + root.add("model2", model2) + + viewer = ModelRegistryViewer(root) + + # Select first model + viewer._browser.selected_model = model1 + assert viewer._viewer.model == model1 + assert viewer._viewer._config_viewer.model == model1 + + # Switch to second model + viewer._browser.selected_model = model2 + assert viewer._viewer.model == model2 + assert viewer._viewer._config_viewer.model == model2 + + # Verify JSON editor shows second model's data + json_editors = find_components_by_type(viewer._viewer.__panel__(), pn.widgets.JSONEditor) + assert json_editors[0].value["name"] == "second" + assert json_editors[0].value["value"] == 2 + + def test_switch_between_different_model_types(self): + """Test switching between models of different types in the viewer.""" + root = ModelRegistry(name="root") + simple_model = SimpleModel(name="simple", value=42) + another_model = AnotherModel(data="test data") + + root.add("simple", simple_model) + root.add("another", another_model) + + viewer = ModelRegistryViewer(root) + + # Select SimpleModel + viewer._browser.selected_model = simple_model + assert viewer._viewer._type_viewer.model_type == SimpleModel + + json_editors = find_components_by_type(viewer._viewer.__panel__(), pn.widgets.JSONEditor) + assert "name" in json_editors[0].value + assert "value" in json_editors[0].value + + # Switch to AnotherModel + viewer._browser.selected_model = another_model + assert viewer._viewer._type_viewer.model_type == AnotherModel + + # JSON should now show AnotherModel's fields + assert "data" in json_editors[0].value + assert "value" not in json_editors[0].value + assert json_editors[0].value["data"] == "test data" + + def test_deselect_model_clears_viewer(self): + """Test that deselecting a model (selecting registry) clears the viewer.""" + root = ModelRegistry(name="root") + sub = ModelRegistry(name="sub") + model = SimpleModel(name="test", value=1) + + sub.add("model", model) + root.add("subregistry", sub) + + viewer = ModelRegistryViewer(root) + + # Select model first + viewer._browser.selected_model = model + assert viewer._viewer.model == model + assert viewer._viewer._json_container.visible is True + + # Deselect (simulate selecting registry node which has no model) + viewer._browser.selected_model = None + assert viewer._viewer.model is None + assert viewer._viewer._json_container.visible is False + + def test_tree_selection_updates_viewer(self): + """Test that tree selection properly updates the viewer through the wiring.""" + root = ModelRegistry(name="root") + model1 = SimpleModel(name="first", value=1) + model2 = SimpleModel(name="second", value=2) + + root.add("model1", model1) + root.add("model2", model2) + + viewer = ModelRegistryViewer(root) + browser = viewer._browser + + # Simulate tree selection of first model + event = mock.Mock() + event.new = [{"label": "model1", "model": model1}] + browser._on_tree_select(event) + + assert browser.selected_model == model1 + assert viewer._viewer.model == model1 + + # Simulate tree selection of second model + event.new = [{"label": "model2", "model": model2}] + browser._on_tree_select(event) + + assert browser.selected_model == model2 + assert viewer._viewer.model == model2 + + def test_search_then_switch_models(self): + """Test using search to select models and then switching.""" + root = ModelRegistry(name="root") + model1 = SimpleModel(name="alpha", value=1) + model2 = SimpleModel(name="beta", value=2) + + root.add("alpha_model", model1) + root.add("beta_model", model2) + + viewer = ModelRegistryViewer(root) + browser = viewer._browser + + # Search and select first model + event = mock.Mock() + event.new = "alpha_model" + browser._on_search_select(event) + + # Tree should be updated + assert len(browser._tree.value) == 1 + assert browser._tree.value[0]["label"] == "alpha_model" + + # Simulate the tree select callback that would happen + tree_event = mock.Mock() + tree_event.new = browser._tree.value + browser._on_tree_select(tree_event) + + assert viewer._viewer.model == model1 + + # Now search and select second model + event.new = "beta_model" + browser._on_search_select(event) + + tree_event.new = browser._tree.value + browser._on_tree_select(tree_event) + + assert viewer._viewer.model == model2 + + # Verify viewer shows second model + json_editors = find_components_by_type(viewer._viewer.__panel__(), pn.widgets.JSONEditor) + assert json_editors[0].value["name"] == "beta" + + def test_rapid_model_switching(self): + """Test rapidly switching between multiple models.""" + root = ModelRegistry(name="root") + models = [SimpleModel(name=f"model_{i}", value=i) for i in range(5)] + + for i, model in enumerate(models): + root.add(f"model_{i}", model) + + viewer = ModelRegistryViewer(root) + + # Rapidly switch through all models + for i, model in enumerate(models): + viewer._browser.selected_model = model + assert viewer._viewer.model == model + assert viewer._viewer._config_viewer.model == model + + json_editors = find_components_by_type(viewer._viewer.__panel__(), pn.widgets.JSONEditor) + assert json_editors[0].value["name"] == f"model_{i}" + assert json_editors[0].value["value"] == i + + def test_switch_models_in_nested_registries(self): + """Test switching between models in different nested registries.""" + root = ModelRegistry(name="root") + reg_a = ModelRegistry(name="reg_a") + reg_b = ModelRegistry(name="reg_b") + + model_a = SimpleModel(name="in_a", value=100) + model_b = AnotherModel(data="in_b") + + reg_a.add("model", model_a) + reg_b.add("model", model_b) + root.add("registry_a", reg_a) + root.add("registry_b", reg_b) + + viewer = ModelRegistryViewer(root) + + # Verify both paths are indexed + assert "registry_a/model" in viewer._browser._node_index + assert "registry_b/model" in viewer._browser._node_index + + # Select model from registry_a + viewer._browser.selected_model = model_a + assert viewer._viewer._type_viewer.model_type == SimpleModel + + # Switch to model from registry_b (different type) + viewer._browser.selected_model = model_b + assert viewer._viewer._type_viewer.model_type == AnotherModel + + json_editors = find_components_by_type(viewer._viewer.__panel__(), pn.widgets.JSONEditor) + assert "data" in json_editors[0].value + assert "value" not in json_editors[0].value diff --git a/ccflow/tests/ui/utils.py b/ccflow/tests/ui/utils.py new file mode 100644 index 0000000..f1a4391 --- /dev/null +++ b/ccflow/tests/ui/utils.py @@ -0,0 +1,29 @@ +"""Utility functions for UI tests.""" + +import panel as pn + + +def find_components_by_type(layout, component_type): + """Recursively find all components of a given type in a Panel layout. + + Args: + layout: A Panel layout or component to search + component_type: The type of component to find + + Returns: + A list of all components matching the given type + """ + found = [] + if isinstance(layout, component_type): + found.append(layout) + if hasattr(layout, "objects"): + for obj in layout.objects: + found.extend(find_components_by_type(obj, component_type)) + if hasattr(layout, "__iter__") and not isinstance(layout, str): + try: + for item in layout: + if hasattr(item, "objects") or isinstance(item, pn.viewable.Viewable): + found.extend(find_components_by_type(item, component_type)) + except TypeError: + pass + return found diff --git a/ccflow/ui/__init__.py b/ccflow/ui/__init__.py new file mode 100644 index 0000000..417aeab --- /dev/null +++ b/ccflow/ui/__init__.py @@ -0,0 +1,3 @@ +from .cli import * +from .model import * +from .registry import * diff --git a/ccflow/ui/cli.py b/ccflow/ui/cli.py new file mode 100644 index 0000000..a0801d8 --- /dev/null +++ b/ccflow/ui/cli.py @@ -0,0 +1,172 @@ +"""CLI for serving ModelRegistryViewer as a Panel application.""" + +import argparse +import inspect +from pathlib import Path +from typing import Callable, Optional + +import panel as pn + +from ccflow import ModelRegistry +from ccflow.utils.hydra import load_config + +from .registry import ModelRegistryViewer + +__all__ = ("registry_viewer_cli",) + + +def _get_ui_args_parser() -> argparse.ArgumentParser: + """Create argument parser with UI server configuration options.""" + parser = argparse.ArgumentParser( + add_help=True, + description="Serve ModelRegistryViewer as a Panel application", + ) + + # Registry loading arguments (similar to utils.hydra) + parser.add_argument( + "overrides", + nargs="*", + help="Key=value arguments to override config values", + ) + parser.add_argument( + "--config-path", + "-cp", + help="Path to the Hydra config directory", + ) + parser.add_argument( + "--config-name", + "-cn", + help="Name of the config file (without .yaml extension)", + ) + parser.add_argument( + "--config-dir", + "-cd", + help="Additional config directory to add to search path", + ) + parser.add_argument( + "--config-dir-config-name", + "-cdcn", + help="Config name to look for within config-dir", + ) + parser.add_argument( + "--basepath", + help="Base path for searching config directories", + ) + + # UI server arguments + parser.add_argument( + "--address", + type=str, + default="127.0.0.1", + help="Address to bind the server to (default: 127.0.0.1)", + ) + parser.add_argument( + "--port", + type=int, + default=8080, + help="Port to bind the server to (default: 8080)", + ) + parser.add_argument( + "--allow-websocket-origin", + type=str, + nargs="+", + default=["*"], + help="Allowed websocket origins (default: *)", + ) + parser.add_argument( + "--show", + action="store_true", + help="Open browser automatically", + ) + + # Viewer layout arguments + parser.add_argument( + "--browser-width", + type=int, + default=400, + help="Width of the registry browser panel (default: 400)", + ) + parser.add_argument( + "--browser-height", + type=int, + default=700, + help="Height of the registry browser panel (default: 700)", + ) + parser.add_argument( + "--viewer-width", + type=int, + default=None, + help="Fixed width for model viewer panel (default: stretch)", + ) + + return parser + + +def registry_viewer_cli( + config_path: str = "", + config_name: str = "", + hydra_main: Optional[Callable] = None, +): + """CLI entry point for serving ModelRegistryViewer. + + Parameters + ---------- + config_path + The config_path specified in hydra.main() + config_name + The config_name specified in hydra.main() + hydra_main + The function decorated with hydra.main(). Used to resolve config_path + relative to the decorated function's file location. + """ + parser = _get_ui_args_parser() + args = parser.parse_args() + + # Resolve config path (same logic as cfg_explain_cli) + if args.config_path: + root_config_dir = args.config_path + elif hydra_main and config_path: + root_config_dir = str(Path(inspect.getfile(hydra_main.__wrapped__)).parent / config_path) + else: + raise ValueError("Must provide --config-path.") + + # Resolve config name + if args.config_name: + root_config_name = args.config_name + elif config_name: + root_config_name = config_name + else: + raise ValueError("Must provide --config-name.") + + # Load config using hydra utilities + result = load_config( + root_config_dir=root_config_dir, + root_config_name=root_config_name, + config_dir=args.config_dir, + config_name=args.config_dir_config_name, + overrides=args.overrides, + basepath=args.basepath, + ) + + # Load registry from config + registry = ModelRegistry.root() + registry.load_config(cfg=result.cfg, overwrite=True) + + # Create app factory for per-session instances + def create_app(): + viewer = ModelRegistryViewer( + registry, + browser_width=args.browser_width, + browser_height=args.browser_height, + viewer_width=args.viewer_width, + ) + return viewer.__panel__() + + # Serve the panel app (callable = fresh instance per session) + pn.serve( + create_app, + address=args.address, + port=args.port, + allow_websocket_origin=args.allow_websocket_origin, + show=args.show, + ) diff --git a/ccflow/ui/model.py b/ccflow/ui/model.py new file mode 100644 index 0000000..d320a18 --- /dev/null +++ b/ccflow/ui/model.py @@ -0,0 +1,255 @@ +import html + +import bleach +import panel as pn + +# Register extensions +import panel_material_ui # noqa: F401 +import panel_material_ui as pmui +import param +from pydantic._internal._repr import display_as_type + +import ccflow + +pn.extension() +pn.extension("jsoneditor") + + +__all__ = ("ModelTypeViewer", "ModelViewer", "ModelConfigViewer") + + +class ModelTypeViewer(param.Parameterized): + """ + Displays type name, class docstring, and fields for a Pydantic model type. + """ + + model_type = param.Parameter(default=None) + + def __init__(self, **params): + super().__init__(**params) + + self._pane = pn.pane.HTML("", width=1200) + self._layout = pn.Column( + self._pane, + ) + + self.param.watch(self._on_type_change, "model_type") + + def __panel__(self): + return self._layout + + def _on_type_change(self, event): + model_cls = event.new + if model_cls is None: + self._pane.object = "" + return + + type_name = display_as_type(model_cls) + + # Class documentation + docs = (model_cls.__doc__ or "").strip() + docs_html = "" + if docs: + escaped = html.escape(docs).replace("\n", "
") + docs_html = f""" +
+
Class Documentation:
+
{escaped}
+
+ """ + + # Fields + fields = getattr(model_cls, "model_fields", {}) + field_items = [] + + for name, field in fields.items(): + field_type = display_as_type(field.annotation) + desc = field.description or "" + field_items.append( + f"
  • {html.escape(name)} ({html.escape(field_type)}){': ' + html.escape(desc) if desc else ''}
  • " + ) + + fields_html = "" + if field_items: + fields_html = f""" +
    +
    Fields:
    +
      + {"".join(field_items)} +
    +
    + """ + + self._pane.object = f""" +
    +
    + Type: + {html.escape(type_name)} +
    + {docs_html} + {fields_html} +
    + """ + + +class ModelConfigViewer(param.Parameterized): + """ + Displays instance-level metadata (description + dependencies). + """ + + model = param.Parameter(default=None) + + def __init__(self, **params): + super().__init__(**params) + + self._metadata = pn.pane.HTML("", width=1200) + + self._layout = pn.Column( + self._metadata, + ) + + self.param.watch(self._on_model_change, "model") + + def __panel__(self): + return self._layout + + # ------------------------------------------------------------ + + def _render_dependencies(self, model): + deps = model.get_registry_dependencies() + if not deps: + return "" + + # Collect all values, deduplicate, and sort + all_paths = [] + for group in deps: + if len(group) == 1: + all_paths.append(group[0]) + else: + all_paths.append(" | ".join(group)) + + # Unique elements, sorted + rows = sorted(set(all_paths)) + + items = "".join(f"
  • {html.escape(row)}
  • " for row in rows) + + return f""" +
    +
    + Registry Dependencies +
    +
      + {items} +
    +
    + """ + + def _on_model_change(self, event): + model = event.new + if model is None: + self._metadata.object = "" + return + + description = model.meta.description.strip() if hasattr(model, "meta") and model.meta.description else "" + + desc_html = "" + if description: + desc_html = f""" +
    +
    Instance Description
    +
    {bleach.linkify(html.escape(description))}
    +
    + """ + + self._metadata.object = desc_html + self._render_dependencies(model) + + +class ModelViewer(param.Parameterized): + """ + Displays a tabbed view of a ccflow Model instance, including description, registry dependencies, docstrings and json representation. + """ + + model = param.Parameter(default=None) + + def __init__(self, **params): + super().__init__(**params) + + # Sub-viewers (no JSONEditor inside) + self._config_viewer = ModelConfigViewer() + self._type_viewer = ModelTypeViewer() + self._context_type_viewer = ModelTypeViewer() + self._result_type_viewer = ModelTypeViewer() + + # Material UI Tabs (metadata only) + self._tabs = pmui.Tabs( + active=0, + sizing_mode="stretch_width", + ) + + # JSON editor (stable, but hidden until a model is selected) + self._json_editor = pn.widgets.JSONEditor( + value={}, + mode="view", + menu=False, + width=600, + ) + + self._json_container = pn.Column( + "## Parameters", + self._json_editor, + visible=False, # hidden initially + ) + + self._layout = pn.Column( + "## Model Viewer", + self._tabs, + pn.Spacer(height=12), + self._json_container, + ) + + self.param.watch(self._on_model_change, "model") + + def __panel__(self): + return self._layout + + # ------------------------------------------------------------ + + def _on_model_change(self, event): + model = event.new + self._tabs.clear() + + if model is None: + # hide JSON editor if no model + self._json_editor.value = {} + self._json_container.visible = False + return + + # ---------------- Config tab ---------------- + self._config_viewer.model = model + self._tabs.append(("Summary", self._config_viewer)) + + # ---------------- Model Type tab ---------------- + self._type_viewer.model_type = type(model) + self._tabs.append(("Model Type", self._type_viewer)) + + # ---------------- CallableModel extras ---------------- + if isinstance(model, ccflow.CallableModel): + self._context_type_viewer.model_type = model.context_type + self._tabs.append(("Context Type", self._context_type_viewer)) + + self._result_type_viewer.model_type = model.result_type + self._tabs.append(("Result Type", self._result_type_viewer)) + + # Default to Config tab + self._tabs.active = 0 + + # Update & show JSONEditor + self._json_editor.value = model.__pydantic_serializer__.to_python(model, fallback=str, mode="json") + self._json_container.visible = True diff --git a/ccflow/ui/registry.py b/ccflow/ui/registry.py new file mode 100644 index 0000000..380b441 --- /dev/null +++ b/ccflow/ui/registry.py @@ -0,0 +1,181 @@ +import panel as pn + +# Register extensions +import panel_material_ui # noqa: F401 +import panel_material_ui as pmui +import param + +from .model import ModelViewer + +pn.extension() + +__all__ = ("RegistryBrowser", "ModelRegistryViewer") + + +class RegistryBrowser(param.Parameterized): + selected_model = param.Parameter(default=None) + + def __init__(self, registry, **params): + super().__init__(**params) + self._registry = registry + + self._tree_items = self._build_tree(registry) + self._node_index = self._build_node_index(self._tree_items) + + self._tree = pmui.Tree( + items=self._tree_items, + multi_select=False, + ) + + self._search = pn.widgets.AutocompleteInput( + name="Search", + options=sorted(self._node_index.keys()), + placeholder="Search full path…", + case_sensitive=False, + search_strategy="includes", + min_characters=1, + sizing_mode="stretch_width", + ) + + self._search.param.watch(self._on_search_select, "value") + self._tree.param.watch(self._on_tree_select, "value") + + self._layout = pn.Column( + "## Registry", + self._search, + self._tree, + ) + + def __panel__(self): + return self._layout + + # ---------------- Tree construction ---------------- + + def _build_tree(self, registry, index_prefix=()): + import ccflow + + items = [] + for i, (name, model) in enumerate(registry.models.items()): + index_path = index_prefix + (i,) + entry = { + "label": name, + "_index_path": index_path, + } + + if isinstance(model, ccflow.ModelRegistry): + entry["items"] = self._build_tree(model, index_prefix=index_path) + else: + entry["model"] = model + + items.append(entry) + + return items + + def _build_node_index(self, tree_items): + index = {} + + def walk(items, prefix=""): + for node in items: + path = f"{prefix}/{node['label']}" if prefix else node["label"] + if "model" in node: + index[path] = node + walk(node.get("items", []), path) + + walk(tree_items) + return index + + @staticmethod + def _expanded_from_index_path(index_path): + return [index_path[:i] for i in range(1, len(index_path))] + + # ---------------- Callbacks ---------------- + + def _on_search_select(self, event): + path = event.new + if not path: + return + node = self._node_index.get(path) + if not node: + return + self._tree.expanded = self._expanded_from_index_path(node["_index_path"]) + self._tree.value = [node] + self._search.value = "" + + def _on_tree_select(self, event): + self.selected_model = event.new[0].get("model") if event.new else None + + +class ModelRegistryViewer(param.Parameterized): + """ + Top-level viewer that composes the RegistryBrowser and ModelViewer + into a scrollable two-panel layout. + """ + + # ---------------- Layout parameters ---------------- + browser_width = param.Integer( + default=400, + bounds=(200, None), + doc="Width of the registry browser panel (px)", + ) + + browser_height = param.Integer( + default=700, + bounds=(300, None), + doc="Height of the registry browser panel (px)", + ) + + viewer_width = param.Integer( + default=None, + allow_None=True, + doc="Optional fixed width for the model viewer panel (px)", + ) + + model = param.Parameter( + default=None, + doc="The currently selected model from the registry browser", + ) + + def __init__(self, registry, **params): + super().__init__(**params) + + # Core components + self._browser = RegistryBrowser(registry) + self._viewer = ModelViewer() + + # Wire browser → viewer and model param + def _on_selection(e): + self.model = e.new + self._viewer.model = e.new + + self._browser.param.watch(_on_selection, "selected_model") + + # Build layout + self._layout = pn.Row( + self._make_browser_column(), + self._make_viewer_column(), + ) + + def __panel__(self): + return self._layout + + # ---------------- Internal helpers ---------------- + + def _make_browser_column(self): + return pn.Column( + self._browser, + width=self.browser_width, + height=self.browser_height, + scroll=True, # ✅ only left panel scrolls + ) + + def _make_viewer_column(self): + if self.viewer_width is not None: + return pn.Column( + self._viewer, + width=self.viewer_width, + ) + else: + return pn.Column( + self._viewer, + sizing_mode="stretch_width", + ) diff --git a/pyproject.toml b/pyproject.toml index e0ba5f6..986571f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,12 @@ dependencies = [ [project.optional-dependencies] full = [ + "bleach", "cexprtk", "duckdb", "pandas", + "panel", + "panel_material_ui", "plotly", "polars", "ray", @@ -75,10 +78,13 @@ develop = [ "uv", "wheel", # Full deps + "bleach", "cexprtk", "csp>=0.8.0,<1", "duckdb", "pandas", + "panel", + "panel_material_ui", "plotly", "polars", "ray", From 279788d8500319dd56f71cdc4950aeabd0b4c610 Mon Sep 17 00:00:00 2001 From: Pascal Tomecek Date: Wed, 4 Feb 2026 17:21:03 -0500 Subject: [PATCH 2/3] Clean up import error handling for optional UI imports Signed-off-by: Pascal Tomecek --- ccflow/base.py | 14 ++++++++++++-- ccflow/ui/model.py | 13 ++++++++----- ccflow/ui/registry.py | 4 +--- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/ccflow/base.py b/ccflow/base.py index 03d5b18..2309577 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -246,7 +246,12 @@ def __panel__(self): Requires ccflow UI dependencies (panel, panel_material_ui). """ - from ccflow.ui.model import ModelViewer + try: + from ccflow.ui.model import ModelViewer + except ImportError: + raise ImportError( + "panel and other optional dependencies must be installed to use ModelViewer. Pip install ccflow[full] to install all optional dependencies." + ) from None return ModelViewer(model=self) @@ -255,7 +260,12 @@ def get_panel(self): Requires panel to be installed. """ - import panel as pn + try: + import panel as pn + except ImportError: + raise ImportError( + "panel and other optional dependencies must be installed to use get_panel(). Pip install ccflow[full] to install all optional dependencies." + ) from None return pn.panel(self) diff --git a/ccflow/ui/model.py b/ccflow/ui/model.py index d320a18..a631e3b 100644 --- a/ccflow/ui/model.py +++ b/ccflow/ui/model.py @@ -1,10 +1,7 @@ import html -import bleach import panel as pn - -# Register extensions -import panel_material_ui # noqa: F401 +import panel_material_ui # noqa: F401 Must be imported like this to register the extension import panel_material_ui as pmui import param from pydantic._internal._repr import display_as_type @@ -161,10 +158,16 @@ def _on_model_change(self, event): desc_html = "" if description: + try: + import bleach + + description = bleach.linkify(html.escape(description)) + except ImportError: + description = html.escape(description) desc_html = f"""
    Instance Description
    -
    {bleach.linkify(html.escape(description))}
    +
    {description}
    """ diff --git a/ccflow/ui/registry.py b/ccflow/ui/registry.py index 380b441..6e6c810 100644 --- a/ccflow/ui/registry.py +++ b/ccflow/ui/registry.py @@ -1,7 +1,5 @@ import panel as pn - -# Register extensions -import panel_material_ui # noqa: F401 +import panel_material_ui # noqa: F401 Must be imported like this to register the extension import panel_material_ui as pmui import param From bbcd76cd64f66bd7de20e3c9da45e37931dce00d Mon Sep 17 00:00:00 2001 From: Pascal Tomecek Date: Mon, 9 Feb 2026 16:13:06 -0500 Subject: [PATCH 3/3] Refactor Panel CLI creation to re-use more common components Signed-off-by: Pascal Tomecek --- ccflow/tests/ui/test_cli.py | 132 +++++----------------- ccflow/tests/utils/test_hydra.py | 188 ++++++++++++++++++++++++++++++- ccflow/ui/cli.py | 82 ++------------ ccflow/utils/hydra.py | 124 +++++++++++++++----- 4 files changed, 324 insertions(+), 202 deletions(-) diff --git a/ccflow/tests/ui/test_cli.py b/ccflow/tests/ui/test_cli.py index 5d6388b..d209aea 100644 --- a/ccflow/tests/ui/test_cli.py +++ b/ccflow/tests/ui/test_cli.py @@ -4,83 +4,47 @@ class TestGetUIArgsParser: - """Tests for _get_ui_args_parser function.""" + """Tests for _get_ui_args_parser function. - def test_parser_has_config_args(self): - """Test parser includes config arguments.""" + Note: Default values for hydra config args and panel server args are tested + in ccflow/tests/utils/test_hydra.py. These tests focus on viewer-specific + arguments and verifying the parser composition works correctly. + """ + + def test_parser_composition(self): + """Test parser includes args from both helper functions.""" parser = _get_ui_args_parser() + args = parser.parse_args([]) - # Parse with config args - args = parser.parse_args( - [ - "--config-path", - "/path/to/config", - "--config-name", - "base", - ] - ) + # From add_hydra_config_args + assert hasattr(args, "overrides") + assert hasattr(args, "config_path") + assert hasattr(args, "config_name") - assert args.config_path == "/path/to/config" - assert args.config_name == "base" + # From add_panel_server_args + assert hasattr(args, "address") + assert hasattr(args, "port") + assert hasattr(args, "show") - def test_parser_defaults(self): - """Test parser default values.""" + # Viewer-specific + assert hasattr(args, "browser_width") + assert hasattr(args, "browser_height") + assert hasattr(args, "viewer_width") + + def test_viewer_layout_defaults(self): + """Test default values for viewer-specific arguments.""" parser = _get_ui_args_parser() - args = parser.parse_args(["-cp", ".", "-cn", "test"]) + args = parser.parse_args([]) - assert args.address == "127.0.0.1" - assert args.port == 8080 assert args.browser_width == 400 assert args.browser_height == 700 assert args.viewer_width is None - assert args.show is False - def test_parser_overrides(self): - """Test parser accepts override arguments.""" + def test_viewer_layout_custom_values(self): + """Test setting custom values for viewer layout arguments.""" parser = _get_ui_args_parser() args = parser.parse_args( [ - "-cp", - ".", - "-cn", - "test", - "key1=value1", - "key2=value2", - ] - ) - - assert args.overrides == ["key1=value1", "key2=value2"] - - def test_parser_ui_args(self): - """Test parser UI server arguments.""" - parser = _get_ui_args_parser() - args = parser.parse_args( - [ - "-cp", - ".", - "-cn", - "test", - "--address", - "0.0.0.0", - "--port", - "9000", - "--show", - ] - ) - - assert args.address == "0.0.0.0" - assert args.port == 9000 - assert args.show is True - - def test_parser_viewer_layout_args(self): - """Test parser viewer layout arguments.""" - parser = _get_ui_args_parser() - args = parser.parse_args( - [ - "-cp", - ".", - "-cn", - "test", "--browser-width", "500", "--browser-height", @@ -94,43 +58,9 @@ def test_parser_viewer_layout_args(self): assert args.browser_height == 800 assert args.viewer_width == 600 - def test_parser_websocket_origin(self): - """Test parser websocket origin argument.""" - parser = _get_ui_args_parser() - args = parser.parse_args( - [ - "-cp", - ".", - "-cn", - "test", - "--allow-websocket-origin", - "localhost:8080", - "example.com", - ] - ) - - assert args.allow_websocket_origin == ["localhost:8080", "example.com"] - - def test_parser_config_dir_args(self): - """Test parser config directory arguments.""" + def test_overrides_positional(self): + """Test overrides are captured as positional arguments.""" parser = _get_ui_args_parser() - args = parser.parse_args( - [ - "-cp", - "/root/config", - "-cn", - "base", - "-cd", - "/extra/config", - "-cdcn", - "override", - "--basepath", - "/search/from/here", - ] - ) + args = parser.parse_args(["key1=value1", "key2=value2"]) - assert args.config_path == "/root/config" - assert args.config_name == "base" - assert args.config_dir == "/extra/config" - assert args.config_dir_config_name == "override" - assert args.basepath == "/search/from/here" + assert args.overrides == ["key1=value1", "key2=value2"] diff --git a/ccflow/tests/utils/test_hydra.py b/ccflow/tests/utils/test_hydra.py index 06e2d9d..7c9886e 100644 --- a/ccflow/tests/utils/test_hydra.py +++ b/ccflow/tests/utils/test_hydra.py @@ -1,10 +1,196 @@ +import argparse import sys from pathlib import Path +from unittest.mock import MagicMock import pytest from hydra import compose, initialize -from ccflow.utils.hydra import get_args_parser_default, load_config +from ccflow.utils.hydra import ( + add_hydra_config_args, + add_panel_server_args, + get_args_parser_default, + load_config, + resolve_config_paths, +) + + +class TestAddHydraConfigArgs: + """Tests for add_hydra_config_args helper function.""" + + def test_adds_all_arguments(self): + """Test that all expected arguments are added.""" + parser = argparse.ArgumentParser() + add_hydra_config_args(parser) + args = parser.parse_args([]) + + assert hasattr(args, "overrides") + assert hasattr(args, "config_path") + assert hasattr(args, "config_name") + assert hasattr(args, "config_dir") + assert hasattr(args, "config_dir_config_name") + assert hasattr(args, "basepath") + + def test_default_values(self): + """Test default values for all arguments.""" + parser = argparse.ArgumentParser() + add_hydra_config_args(parser) + args = parser.parse_args([]) + + assert args.overrides == [] + assert args.config_path is None + assert args.config_name is None + assert args.config_dir is None + assert args.config_dir_config_name is None + assert args.basepath is None + + def test_short_flags(self): + """Test short flag aliases work correctly.""" + parser = argparse.ArgumentParser() + add_hydra_config_args(parser) + args = parser.parse_args(["-cp", "/path", "-cn", "name", "-cd", "/dir", "-cdcn", "dirname"]) + + assert args.config_path == "/path" + assert args.config_name == "name" + assert args.config_dir == "/dir" + assert args.config_dir_config_name == "dirname" + + def test_overrides_positional(self): + """Test overrides are captured as positional arguments.""" + parser = argparse.ArgumentParser() + add_hydra_config_args(parser) + args = parser.parse_args(["key1=value1", "key2=value2", "+group=option"]) + + assert args.overrides == ["key1=value1", "key2=value2", "+group=option"] + + +class TestAddPanelServerArgs: + """Tests for add_panel_server_args helper function.""" + + def test_adds_all_arguments(self): + """Test that all expected arguments are added.""" + parser = argparse.ArgumentParser() + add_panel_server_args(parser) + args = parser.parse_args([]) + + assert hasattr(args, "address") + assert hasattr(args, "port") + assert hasattr(args, "allow_websocket_origin") + assert hasattr(args, "basic_auth") + assert hasattr(args, "cookie_secret") + assert hasattr(args, "show") + + def test_default_values(self): + """Test default values for all arguments.""" + parser = argparse.ArgumentParser() + add_panel_server_args(parser) + args = parser.parse_args([]) + + assert args.address == "127.0.0.1" + assert args.port == 8080 + assert args.allow_websocket_origin == ["*"] + assert args.basic_auth is None + assert args.cookie_secret == "secret" + assert args.show is False + + def test_sets_epilog(self): + """Test that epilog is set on the parser.""" + parser = argparse.ArgumentParser() + add_panel_server_args(parser) + + assert parser.epilog is not None + assert "server" in parser.epilog.lower() + + def test_custom_values(self): + """Test setting custom values for arguments.""" + parser = argparse.ArgumentParser() + add_panel_server_args(parser) + args = parser.parse_args( + [ + "--address", + "0.0.0.0", + "--port", + "9000", + "--basic-auth", + "user:pass", + "--cookie-secret", + "mysecret", + "--show", + ] + ) + + assert args.address == "0.0.0.0" + assert args.port == 9000 + assert args.basic_auth == "user:pass" + assert args.cookie_secret == "mysecret" + assert args.show is True + + def test_websocket_origin_multiple(self): + """Test multiple websocket origins.""" + parser = argparse.ArgumentParser() + add_panel_server_args(parser) + args = parser.parse_args(["--allow-websocket-origin", "localhost:8080", "example.com"]) + + assert args.allow_websocket_origin == ["localhost:8080", "example.com"] + + +class TestResolveConfigPaths: + """Tests for resolve_config_paths helper function.""" + + def test_uses_args_config_path(self): + """Test that args.config_path takes precedence.""" + args = argparse.Namespace(config_path="/from/args", config_name="name") + root_dir, root_name = resolve_config_paths(args, config_path="/default", config_name="default") + + assert root_dir == "/from/args" + assert root_name == "name" + + def test_uses_args_config_name(self): + """Test that args.config_name takes precedence.""" + args = argparse.Namespace(config_path="/path", config_name="from_args") + root_dir, root_name = resolve_config_paths(args, config_path="", config_name="default") + + assert root_name == "from_args" + + def test_falls_back_to_config_name_default(self): + """Test fallback to config_name parameter when args.config_name is None.""" + args = argparse.Namespace(config_path="/path", config_name=None) + root_dir, root_name = resolve_config_paths(args, config_path="", config_name="default_name") + + assert root_name == "default_name" + + def test_raises_without_config_path(self): + """Test ValueError when no config_path available.""" + args = argparse.Namespace(config_path=None, config_name="name") + + with pytest.raises(ValueError, match="Must provide --config-path"): + resolve_config_paths(args, config_path="", config_name="name", hydra_main=None) + + def test_raises_without_config_name(self): + """Test ValueError when no config_name available.""" + args = argparse.Namespace(config_path="/path", config_name=None) + + with pytest.raises(ValueError, match="Must provide --config-name"): + resolve_config_paths(args, config_path="", config_name="", hydra_main=None) + + def test_uses_hydra_main_for_config_path(self): + """Test that hydra_main is used to resolve config_path.""" + # Create a mock hydra_main with __wrapped__ attribute + mock_func = MagicMock() + mock_func.__wrapped__ = lambda: None + # Mock inspect.getfile to return a known path + import ccflow.utils.hydra as hydra_module + + original_getfile = hydra_module.inspect.getfile + + try: + hydra_module.inspect.getfile = lambda x: "/some/module/path.py" + args = argparse.Namespace(config_path=None, config_name="name") + root_dir, root_name = resolve_config_paths(args, config_path="config", config_name="name", hydra_main=mock_func) + + assert root_dir == "/some/module/config" + finally: + hydra_module.inspect.getfile = original_getfile @pytest.fixture diff --git a/ccflow/ui/cli.py b/ccflow/ui/cli.py index a0801d8..a6b7d3f 100644 --- a/ccflow/ui/cli.py +++ b/ccflow/ui/cli.py @@ -1,14 +1,12 @@ """CLI for serving ModelRegistryViewer as a Panel application.""" import argparse -import inspect -from pathlib import Path from typing import Callable, Optional import panel as pn from ccflow import ModelRegistry -from ccflow.utils.hydra import load_config +from ccflow.utils.hydra import add_hydra_config_args, add_panel_server_args, load_config, resolve_config_paths from .registry import ModelRegistryViewer @@ -22,64 +20,13 @@ def _get_ui_args_parser() -> argparse.ArgumentParser: description="Serve ModelRegistryViewer as a Panel application", ) - # Registry loading arguments (similar to utils.hydra) - parser.add_argument( - "overrides", - nargs="*", - help="Key=value arguments to override config values", - ) - parser.add_argument( - "--config-path", - "-cp", - help="Path to the Hydra config directory", - ) - parser.add_argument( - "--config-name", - "-cn", - help="Name of the config file (without .yaml extension)", - ) - parser.add_argument( - "--config-dir", - "-cd", - help="Additional config directory to add to search path", - ) - parser.add_argument( - "--config-dir-config-name", - "-cdcn", - help="Config name to look for within config-dir", - ) - parser.add_argument( - "--basepath", - help="Base path for searching config directories", - ) + # Standard hydra config loading arguments + add_hydra_config_args(parser) - # UI server arguments - parser.add_argument( - "--address", - type=str, - default="127.0.0.1", - help="Address to bind the server to (default: 127.0.0.1)", - ) - parser.add_argument( - "--port", - type=int, - default=8080, - help="Port to bind the server to (default: 8080)", - ) - parser.add_argument( - "--allow-websocket-origin", - type=str, - nargs="+", - default=["*"], - help="Allowed websocket origins (default: *)", - ) - parser.add_argument( - "--show", - action="store_true", - help="Open browser automatically", - ) + # Standard Panel server arguments + add_panel_server_args(parser) - # Viewer layout arguments + # Viewer-specific arguments parser.add_argument( "--browser-width", type=int, @@ -122,21 +69,8 @@ def registry_viewer_cli( parser = _get_ui_args_parser() args = parser.parse_args() - # Resolve config path (same logic as cfg_explain_cli) - if args.config_path: - root_config_dir = args.config_path - elif hydra_main and config_path: - root_config_dir = str(Path(inspect.getfile(hydra_main.__wrapped__)).parent / config_path) - else: - raise ValueError("Must provide --config-path.") - - # Resolve config name - if args.config_name: - root_config_name = args.config_name - elif config_name: - root_config_name = config_name - else: - raise ValueError("Must provide --config-name.") + # Resolve config paths using shared helper + root_config_dir, root_config_name = resolve_config_paths(args, config_path, config_name, hydra_main) # Load config using hydra utilities result = load_config( diff --git a/ccflow/utils/hydra.py b/ccflow/utils/hydra.py index 70da1b0..f457c39 100644 --- a/ccflow/utils/hydra.py +++ b/ccflow/utils/hydra.py @@ -24,6 +24,9 @@ __all__ = ( "ConfigLoadResult", "load_config", + "add_hydra_config_args", + "add_panel_server_args", + "resolve_config_paths", "get_args_parser_default", "get_args_parser_default_ui", "ui_launcher_default", @@ -237,27 +240,33 @@ def yaml_load(*args, **kwargs): return result -def get_args_parser_default() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(add_help=True, description="Hydra Config Audit Tool") +def add_hydra_config_args(parser: argparse.ArgumentParser) -> None: + """Add standard hydra config loading arguments to a parser. + + Adds the following arguments: + - overrides: Positional arguments for key=value config overrides + - --config-path/-cp: Path to the hydra config directory + - --config-name/-cn: Name of the config file + - --config-dir/-cd: Additional config directory to add to search path + - --config-dir-config-name/-cdcn: Config name within config-dir + - --basepath: Base path for searching config directories + """ parser.add_argument( "overrides", nargs="*", help="Any key=value arguments to override config values (use dots for.nested=overrides)", ) - parser.add_argument( "--config-path", "-cp", help="""Overrides the config_path specified in hydra.main(). The config_path is absolute or relative to the Python file declaring @hydra.main()""", ) - parser.add_argument( "--config-name", "-cn", help="Overrides the config_name specified in hydra.main()", ) - parser.add_argument( "--config-dir", "-cd", @@ -272,16 +281,19 @@ def get_args_parser_default() -> argparse.ArgumentParser: "--basepath", help="The base path to start searching for the `config_dir` (if not the cwd). This is useful when you want to load from an absolute (rather than relative) path.", ) - parser.add_argument( - "--no-gui", - action="store_true", - help="Disable the GUI", - ) - return parser -def get_args_parser_default_ui() -> argparse.ArgumentParser: - parser = get_args_parser_default() +def add_panel_server_args(parser: argparse.ArgumentParser) -> None: + """Add standard Panel server arguments to a parser. + + Adds the following arguments: + - --address: Address to bind the server to (default: 127.0.0.1) + - --port: Port to bind the server to (default: 8080) + - --allow-websocket-origin: Allowed websocket origins (default: *) + - --basic-auth: Enable basic authentication (username:password format) + - --cookie-secret: Cookie secret for the server + - --show: Open browser automatically + """ parser.add_argument( "--address", type=str, @@ -316,6 +328,11 @@ def get_args_parser_default_ui() -> argparse.ArgumentParser: default=None, ) parser.add_argument("--cookie-secret", type=str, default="secret", help="Cookie secret for the server.") + parser.add_argument( + "--show", + action="store_true", + help="Open browser automatically", + ) parser.epilog = dedent("""\ This will launch the server that can be used to view the configuration. The server will be accessible at http://
    : by default. @@ -324,6 +341,73 @@ def get_args_parser_default_ui() -> argparse.ArgumentParser: You can enable basic authentication using the --basic-auth argument. You can specify the cookie secret using the --cookie-secret argument. """) + + +def resolve_config_paths( + args: argparse.Namespace, + config_path: str = "", + config_name: str = "", + hydra_main: Optional[Callable] = None, +) -> tuple: + """Resolve root_config_dir and root_config_name from CLI args or defaults. + + This helper extracts the common logic for resolving config paths from either + CLI arguments or default values provided by the decorated hydra.main function. + + Parameters + ---------- + args + Parsed argparse namespace containing config_path and config_name attributes + config_path + Default config_path, typically from hydra.main() decorator + config_name + Default config_name, typically from hydra.main() decorator + hydra_main + The function decorated with hydra.main(). Used to resolve config_path + relative to the decorated function's file location. + + Returns + ------- + tuple + (root_config_dir, root_config_name) + + Raises + ------ + ValueError + If neither args.config_path nor hydra_main+config_path are provided + If neither args.config_name nor config_name are provided + """ + if args.config_path: + root_config_dir = args.config_path + elif hydra_main and config_path: + root_config_dir = str(Path(inspect.getfile(hydra_main.__wrapped__)).parent / config_path) + else: + raise ValueError("Must provide --config-path.") + + if args.config_name: + root_config_name = args.config_name + elif config_name: + root_config_name = config_name + else: + raise ValueError("Must provide --config-name.") + + return root_config_dir, root_config_name + + +def get_args_parser_default() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(add_help=True, description="Hydra Config Audit Tool") + add_hydra_config_args(parser) + parser.add_argument( + "--no-gui", + action="store_true", + help="Disable the GUI", + ) + return parser + + +def get_args_parser_default_ui() -> argparse.ArgumentParser: + parser = get_args_parser_default() + add_panel_server_args(parser) return parser @@ -360,19 +444,7 @@ def cfg_explain_cli( parser = get_args_parser_default_ui() args = parser.parse_args() - if args.config_path: - root_config_dir = args.config_path - elif hydra_main and config_path: - root_config_dir = str(Path(inspect.getfile(hydra_main.__wrapped__)).parent / config_path) - else: - raise ValueError("Must provide --config-path.") - - if args.config_name: - root_config_name = args.config_name - elif config_name: - root_config_name = config_name - else: - raise ValueError("Must provide --config-name.") + root_config_dir, root_config_name = resolve_config_paths(args, config_path, config_name, hydra_main) result = load_config( root_config_dir=root_config_dir,