diff --git a/docs/developer_docs/extensions/development.md b/docs/developer_docs/extensions/development.md index cd38f1e47d29..a10f77431100 100644 --- a/docs/developer_docs/extensions/development.md +++ b/docs/developer_docs/extensions/development.md @@ -39,7 +39,11 @@ superset-extensions bundle: Packages the extension into a .supx file. superset-extensions dev: Automatically rebuilds the extension as files change. -superset-extensions validate: Validates the extension structure and metadata. +superset-extensions validate: Validates the extension structure and metadata consistency. + +superset-extensions update: Updates derived and generated files in the extension project. + Use --version [] to update the version (prompts if no value given). + Use --license [] to update the license (prompts if no value given). ``` When creating a new extension with `superset-extensions init`, the CLI generates a standardized folder structure: diff --git a/requirements/development.txt b/requirements/development.txt index 296fe15d8207..2b9fffd474b0 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -996,6 +996,8 @@ tabulate==0.9.0 # via # -c requirements/base-constraint.txt # apache-superset +tomli-w==1.2.0 + # via apache-superset-extensions-cli tomlkit==0.13.3 # via pylint tqdm==4.67.1 diff --git a/superset-extensions-cli/pyproject.toml b/superset-extensions-cli/pyproject.toml index 3f88e8809970..8f853a86e367 100644 --- a/superset-extensions-cli/pyproject.toml +++ b/superset-extensions-cli/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "jinja2>=3.1.6", "semver>=3.0.4", "tomli>=2.2.1; python_version < '3.11'", + "tomli-w>=1.2.0", "watchdog>=6.0.0", ] diff --git a/superset-extensions-cli/src/superset_extensions_cli/cli.py b/superset-extensions-cli/src/superset_extensions_cli/cli.py index 079f4e639a32..feab2ab5d8ab 100644 --- a/superset-extensions-cli/src/superset_extensions_cli/cli.py +++ b/superset-extensions-cli/src/superset_extensions_cli/cli.py @@ -50,6 +50,8 @@ validate_display_name, validate_publisher, validate_technical_name, + write_json, + write_toml, ) REMOTE_ENTRY_REGEX = re.compile(r"^remoteEntry\..+\.js$") @@ -292,6 +294,7 @@ def app() -> None: @app.command() def validate() -> None: + """Validate the extension structure and metadata consistency.""" validate_npm() cwd = Path.cwd() @@ -372,12 +375,167 @@ def validate() -> None: click.secho(" Convention requires: frontend/src/index.tsx", fg="yellow") sys.exit(1) + # Validate version and license consistency across extension.json, frontend, and backend + mismatches: list[str] = [] + frontend_pkg_path = cwd / "frontend" / "package.json" + frontend_pkg = None + if frontend_pkg_path.is_file(): + frontend_pkg = read_json(frontend_pkg_path) + if frontend_pkg: + if frontend_pkg.get("version") != extension.version: + mismatches.append( + f" frontend/package.json version: {frontend_pkg.get('version')} " + f"(expected {extension.version})" + ) + if extension.license and frontend_pkg.get("license") != extension.license: + mismatches.append( + f" frontend/package.json license: {frontend_pkg.get('license')} " + f"(expected {extension.license})" + ) + + backend_pyproject_path = cwd / "backend" / "pyproject.toml" + if backend_pyproject_path.is_file(): + backend_pyproject = read_toml(backend_pyproject_path) + if backend_pyproject: + project = backend_pyproject.get("project", {}) + if project.get("version") != extension.version: + mismatches.append( + f" backend/pyproject.toml version: {project.get('version')} " + f"(expected {extension.version})" + ) + if extension.license and project.get("license") != extension.license: + mismatches.append( + f" backend/pyproject.toml license: {project.get('license')} " + f"(expected {extension.license})" + ) + + if mismatches: + click.secho("❌ Metadata mismatch detected:", err=True, fg="red") + for mismatch in mismatches: + click.secho(mismatch, err=True, fg="red") + click.secho( + "Run `superset-extensions update` to sync from extension.json.", + fg="yellow", + ) + sys.exit(1) + click.secho("✅ Validation successful", fg="green") +@app.command() +@click.option( + "--version", + "version_opt", + is_flag=False, + flag_value="__prompt__", + default=None, + help="Set a new version. Prompts for value if none given.", +) +@click.option( + "--license", + "license_opt", + is_flag=False, + flag_value="__prompt__", + default=None, + help="Set a new license. Prompts for value if none given.", +) +def update(version_opt: str | None, license_opt: str | None) -> None: + """Update derived and generated files in the extension project.""" + cwd = Path.cwd() + + extension_json_path = cwd / "extension.json" + extension_data = read_json(extension_json_path) + if not extension_data: + click.secho("❌ extension.json not found.", err=True, fg="red") + sys.exit(1) + + try: + extension = ExtensionConfig.model_validate(extension_data) + except Exception as e: + click.secho(f"❌ Invalid extension.json: {e}", err=True, fg="red") + sys.exit(1) + + # Resolve version: prompt if flag used without value + if version_opt == "__prompt__": + version_opt = click.prompt("Version", default=extension.version) + target_version = ( + version_opt + if version_opt and version_opt != extension.version + else extension.version + ) + + # Resolve license: prompt if flag used without value + if license_opt == "__prompt__": + license_opt = click.prompt("License", default=extension.license or "") + target_license = ( + license_opt + if license_opt and license_opt != extension.license + else extension.license + ) + + updated: list[str] = [] + + # Update extension.json if version or license changed + ext_changed = False + if version_opt and version_opt != extension.version: + extension_data["version"] = target_version + ext_changed = True + if license_opt and license_opt != extension.license: + extension_data["license"] = target_license + ext_changed = True + if ext_changed: + try: + ExtensionConfig.model_validate(extension_data) + except Exception as e: + click.secho(f"❌ Invalid value: {e}", err=True, fg="red") + sys.exit(1) + write_json(extension_json_path, extension_data) + updated.append("extension.json") + + # Update frontend/package.json + frontend_pkg_path = cwd / "frontend" / "package.json" + if frontend_pkg_path.is_file(): + frontend_pkg = read_json(frontend_pkg_path) + if frontend_pkg: + pkg_changed = False + if frontend_pkg.get("version") != target_version: + frontend_pkg["version"] = target_version + pkg_changed = True + if target_license and frontend_pkg.get("license") != target_license: + frontend_pkg["license"] = target_license + pkg_changed = True + if pkg_changed: + write_json(frontend_pkg_path, frontend_pkg) + updated.append("frontend/package.json") + + # Update backend/pyproject.toml + backend_pyproject_path = cwd / "backend" / "pyproject.toml" + if backend_pyproject_path.is_file(): + backend_pyproject = read_toml(backend_pyproject_path) + if backend_pyproject: + project = backend_pyproject.setdefault("project", {}) + toml_changed = False + if project.get("version") != target_version: + project["version"] = target_version + toml_changed = True + if target_license and project.get("license") != target_license: + project["license"] = target_license + toml_changed = True + if toml_changed: + write_toml(backend_pyproject_path, backend_pyproject) + updated.append("backend/pyproject.toml") + + if updated: + for path in updated: + click.secho(f"✅ Updated {path}", fg="green") + else: + click.secho("✅ All files already up to date.", fg="green") + + @app.command() @click.pass_context def build(ctx: click.Context) -> None: + """Build extension assets.""" ctx.invoke(validate) cwd = Path.cwd() frontend_dir = cwd / "frontend" @@ -413,6 +571,7 @@ def build(ctx: click.Context) -> None: ) @click.pass_context def bundle(ctx: click.Context, output: Path | None) -> None: + """Package the extension into a .supx file.""" ctx.invoke(build) cwd = Path.cwd() @@ -453,6 +612,7 @@ def bundle(ctx: click.Context, output: Path | None) -> None: @app.command() @click.pass_context def dev(ctx: click.Context) -> None: + """Automatically rebuild the extension as files change.""" cwd = Path.cwd() frontend_dir = cwd / "frontend" backend_dir = cwd / "backend" @@ -647,6 +807,7 @@ def init( frontend_opt: bool | None, backend_opt: bool | None, ) -> None: + """Scaffold a new extension project.""" # Get extension names with graceful validation names = prompt_for_extension_info(display_name_opt, publisher_opt, name_opt) diff --git a/superset-extensions-cli/src/superset_extensions_cli/utils.py b/superset-extensions-cli/src/superset_extensions_cli/utils.py index bda0f3c4a807..c4fae04e6ec1 100644 --- a/superset-extensions-cli/src/superset_extensions_cli/utils.py +++ b/superset-extensions-cli/src/superset_extensions_cli/utils.py @@ -21,6 +21,8 @@ from pathlib import Path from typing import Any +import tomli_w + from superset_core.extensions.constants import ( DISPLAY_NAME_PATTERN, PUBLISHER_PATTERN, @@ -109,6 +111,14 @@ def read_json(path: Path) -> dict[str, Any] | None: return json.loads(path.read_text()) +def write_json(path: Path, data: dict[str, Any]) -> None: + path.write_text(json.dumps(data, indent=2) + "\n") + + +def write_toml(path: Path, data: dict[str, Any]) -> None: + path.write_text(tomli_w.dumps(data)) + + def _normalize_for_identifiers(name: str) -> str: """ Normalize display name to clean lowercase words. diff --git a/superset-extensions-cli/tests/conftest.py b/superset-extensions-cli/tests/conftest.py index 8c223be5051c..a4cc7f7521c8 100644 --- a/superset-extensions-cli/tests/conftest.py +++ b/superset-extensions-cli/tests/conftest.py @@ -17,10 +17,12 @@ from __future__ import annotations +import json import os from pathlib import Path import pytest +import tomli_w from click.testing import CliRunner @@ -138,3 +140,69 @@ def _setup(base_path: Path) -> None: (backend_dir / "__init__.py").write_text("# init") return _setup + + +@pytest.fixture +def extension_with_versions(): + """Create an extension directory structure with configurable versions and licenses.""" + + def _create( + base_path: Path, + ext_version: str = "1.0.0", + frontend_version: str | None = None, + backend_version: str | None = None, + ext_license: str | None = "Apache-2.0", + frontend_license: str | None = None, + backend_license: str | None = None, + ) -> None: + extension_json = { + "publisher": "test-org", + "name": "test-extension", + "displayName": "Test Extension", + "version": ext_version, + "permissions": [], + } + if ext_license is not None: + extension_json["license"] = ext_license + (base_path / "extension.json").write_text(json.dumps(extension_json)) + + if frontend_version is not None: + frontend_dir = base_path / "frontend" + frontend_dir.mkdir(exist_ok=True) + (frontend_dir / "src").mkdir(exist_ok=True) + (frontend_dir / "src" / "index.tsx").write_text("// entry") + pkg = { + "name": "@test-org/test-extension", + "version": frontend_version, + } + if frontend_license is not None: + pkg["license"] = frontend_license + elif ext_license is not None: + pkg["license"] = ext_license + (frontend_dir / "package.json").write_text(json.dumps(pkg, indent=2)) + + if backend_version is not None: + backend_dir = base_path / "backend" + backend_dir.mkdir(exist_ok=True) + src_dir = backend_dir / "src" / "test_org" / "test_extension" + src_dir.mkdir(parents=True, exist_ok=True) + (src_dir / "entrypoint.py").write_text("# entry") + project = { + "name": "test-org-test-extension", + "version": backend_version, + } + if backend_license is not None: + project["license"] = backend_license + elif ext_license is not None: + project["license"] = ext_license + pyproject = { + "project": project, + "tool": { + "apache_superset_extensions": { + "build": {"include": ["src/**/*.py"]} + } + }, + } + (backend_dir / "pyproject.toml").write_text(tomli_w.dumps(pyproject)) + + return _create diff --git a/superset-extensions-cli/tests/test_cli_build.py b/superset-extensions-cli/tests/test_cli_build.py index 90f63fdd5f50..5753eff48922 100644 --- a/superset-extensions-cli/tests/test_cli_build.py +++ b/superset-extensions-cli/tests/test_cli_build.py @@ -121,7 +121,7 @@ def test_build_command_success_flow( # Setup mocks mock_rebuild_frontend.return_value = "remoteEntry.abc123.js" mock_read_toml.return_value = { - "project": {"name": "test"}, + "project": {"name": "test", "version": "1.0.0"}, "tool": { "apache_superset_extensions": { "build": {"include": ["src/test_org/test_extension/**/*.py"]} @@ -162,7 +162,7 @@ def test_build_command_handles_frontend_build_failure( # Setup mocks mock_rebuild_frontend.return_value = None # Indicates failure mock_read_toml.return_value = { - "project": {"name": "test"}, + "project": {"name": "test", "version": "1.0.0"}, "tool": { "apache_superset_extensions": { "build": {"include": ["src/test_org/test_extension/**/*.py"]} diff --git a/superset-extensions-cli/tests/test_cli_update.py b/superset-extensions-cli/tests/test_cli_update.py new file mode 100644 index 000000000000..36b965fb54da --- /dev/null +++ b/superset-extensions-cli/tests/test_cli_update.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest +from superset_extensions_cli.cli import app +from superset_extensions_cli.utils import read_json, read_toml + + +@pytest.mark.cli +def test_update_syncs_versions( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test update syncs frontend and backend versions from extension.json.""" + extension_with_versions( + isolated_filesystem, + ext_version="2.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ) + + result = cli_runner.invoke(app, ["update"]) + + assert result.exit_code == 0 + assert "Updated frontend/package.json" in result.output + assert "Updated backend/pyproject.toml" in result.output + + frontend_pkg = read_json(isolated_filesystem / "frontend" / "package.json") + assert frontend_pkg["version"] == "2.0.0" + + backend_pyproject = read_toml(isolated_filesystem / "backend" / "pyproject.toml") + assert backend_pyproject["project"]["version"] == "2.0.0" + + +@pytest.mark.cli +def test_update_noop_when_all_match( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test update reports no changes when everything already matches.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ) + + result = cli_runner.invoke(app, ["update"]) + + assert result.exit_code == 0 + assert "All files already up to date" in result.output + + +@pytest.mark.cli +def test_update_fails_without_extension_json(cli_runner, isolated_filesystem): + """Test update fails when extension.json is missing.""" + result = cli_runner.invoke(app, ["update"]) + + assert result.exit_code != 0 + assert "extension.json not found" in result.output + + +@pytest.mark.cli +def test_update_with_version_flag( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test --version updates extension.json first, then syncs all files.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ) + + result = cli_runner.invoke(app, ["update", "--version", "3.0.0"]) + + assert result.exit_code == 0 + assert "Updated extension.json" in result.output + assert "Updated frontend/package.json" in result.output + assert "Updated backend/pyproject.toml" in result.output + + ext = read_json(isolated_filesystem / "extension.json") + assert ext["version"] == "3.0.0" + + frontend_pkg = read_json(isolated_filesystem / "frontend" / "package.json") + assert frontend_pkg["version"] == "3.0.0" + + backend_pyproject = read_toml(isolated_filesystem / "backend" / "pyproject.toml") + assert backend_pyproject["project"]["version"] == "3.0.0" + + +@pytest.mark.cli +def test_update_with_license_flag( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test --license updates license across all files.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ext_license="Apache-2.0", + ) + + result = cli_runner.invoke(app, ["update", "--license", "MIT"]) + + assert result.exit_code == 0 + assert "Updated extension.json" in result.output + assert "Updated frontend/package.json" in result.output + assert "Updated backend/pyproject.toml" in result.output + + ext = read_json(isolated_filesystem / "extension.json") + assert ext["license"] == "MIT" + + frontend_pkg = read_json(isolated_filesystem / "frontend" / "package.json") + assert frontend_pkg["license"] == "MIT" + + backend_pyproject = read_toml(isolated_filesystem / "backend" / "pyproject.toml") + assert backend_pyproject["project"]["license"] == "MIT" + + +@pytest.mark.cli +def test_update_version_prompt_default( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test --version without value prompts with current version as default.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ) + + # Hit enter to accept default — nothing should change + result = cli_runner.invoke(app, ["update", "--version"], input="\n") + + assert result.exit_code == 0 + assert "All files already up to date" in result.output + + +@pytest.mark.cli +def test_update_rejects_invalid_version( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test --version with an invalid semver string exits with error.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + ) + + result = cli_runner.invoke(app, ["update", "--version", "not-a-version"]) + + assert result.exit_code != 0 + assert "Invalid value" in result.output + + # Verify extension.json was not modified + ext = read_json(isolated_filesystem / "extension.json") + assert ext["version"] == "1.0.0" diff --git a/superset-extensions-cli/tests/test_cli_validate.py b/superset-extensions-cli/tests/test_cli_validate.py index 970a2ce13ce7..b796aedd9309 100644 --- a/superset-extensions-cli/tests/test_cli_validate.py +++ b/superset-extensions-cli/tests/test_cli_validate.py @@ -207,3 +207,66 @@ def test_validate_npm_with_empty_version_output_raises_error(mock_run, mock_whic # semver.compare will raise ValueError for empty version with pytest.raises(ValueError): validate_npm() + + +# Version Consistency Tests +@pytest.mark.cli +def test_validate_fails_on_version_mismatch( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test validate fails when frontend/backend versions differ from extension.json.""" + extension_with_versions( + isolated_filesystem, + ext_version="2.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ) + + with patch("superset_extensions_cli.cli.validate_npm"): + result = cli_runner.invoke(app, ["validate"]) + + assert result.exit_code != 0 + assert "Metadata mismatch" in result.output + assert "superset-extensions update" in result.output + + +@pytest.mark.cli +def test_validate_passes_with_matching_versions( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test validate passes when all versions match extension.json.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ) + + with patch("superset_extensions_cli.cli.validate_npm"): + result = cli_runner.invoke(app, ["validate"]) + + assert result.exit_code == 0 + assert "Validation successful" in result.output + + +@pytest.mark.cli +def test_validate_fails_on_license_mismatch( + cli_runner, isolated_filesystem, extension_with_versions +): + """Test validate fails when frontend/backend licenses differ from extension.json.""" + extension_with_versions( + isolated_filesystem, + ext_version="1.0.0", + frontend_version="1.0.0", + backend_version="1.0.0", + ext_license="Apache-2.0", + frontend_license="MIT", + backend_license="MIT", + ) + + with patch("superset_extensions_cli.cli.validate_npm"): + result = cli_runner.invoke(app, ["validate"]) + + assert result.exit_code != 0 + assert "Metadata mismatch" in result.output + assert "license" in result.output diff --git a/superset-extensions-cli/tests/test_utils.py b/superset-extensions-cli/tests/test_utils.py index 1df07c72a605..4e87eaf9303c 100644 --- a/superset-extensions-cli/tests/test_utils.py +++ b/superset-extensions-cli/tests/test_utils.py @@ -20,7 +20,7 @@ import json import pytest -from superset_extensions_cli.utils import read_json, read_toml +from superset_extensions_cli.utils import read_json, read_toml, write_json, write_toml # Read JSON Tests @@ -269,3 +269,32 @@ def test_read_toml_with_permission_denied(isolated_filesystem): toml_file.chmod(0o644) except (OSError, PermissionError): pass + + +# Write JSON Tests +@pytest.mark.unit +def test_write_json_round_trip(isolated_filesystem): + """Test write_json then read_json round-trip preserves content.""" + data = {"name": "test-extension", "version": "2.0.0", "nested": {"key": "value"}} + json_file = isolated_filesystem / "output.json" + + write_json(json_file, data) + result = read_json(json_file) + + assert result == data + + +# Write TOML Tests +@pytest.mark.unit +def test_write_toml_round_trip(isolated_filesystem): + """Test write_toml then read_toml round-trip preserves content.""" + data = { + "project": {"name": "test-package", "version": "1.0.0"}, + "tool": {"apache_superset_extensions": {"build": {"include": ["src/**/*.py"]}}}, + } + toml_file = isolated_filesystem / "output.toml" + + write_toml(toml_file, data) + result = read_toml(toml_file) + + assert result == data diff --git a/superset-frontend/packages/superset-core/src/theme/types.ts b/superset-frontend/packages/superset-core/src/theme/types.ts index a0b30ca59562..87bd226f874f 100644 --- a/superset-frontend/packages/superset-core/src/theme/types.ts +++ b/superset-frontend/packages/superset-core/src/theme/types.ts @@ -426,6 +426,7 @@ export interface ThemeControllerOptions { canUpdateTheme?: () => boolean; canUpdateMode?: () => boolean; isGlobalContext?: boolean; + initialMode?: ThemeMode; } export interface ThemeContextType { diff --git a/superset-frontend/src/embedded/EmbeddedContextProviders.tsx b/superset-frontend/src/embedded/EmbeddedContextProviders.tsx index 9ea9c0744ab7..0a832ed294af 100644 --- a/superset-frontend/src/embedded/EmbeddedContextProviders.tsx +++ b/superset-frontend/src/embedded/EmbeddedContextProviders.tsx @@ -26,7 +26,7 @@ import { DynamicPluginProvider } from 'src/components'; import { EmbeddedUiConfigProvider } from 'src/components/UiConfigContext'; import { SupersetThemeProvider } from 'src/theme/ThemeProvider'; import { ThemeController } from 'src/theme/ThemeController'; -import { type ThemeStorage } from '@apache-superset/core/theme'; +import { type ThemeStorage, ThemeMode } from '@apache-superset/core/theme'; import { store } from 'src/views/store'; import querystring from 'query-string'; @@ -52,6 +52,7 @@ class ThemeMemoryStorageAdapter implements ThemeStorage { const themeController = new ThemeController({ storage: new ThemeMemoryStorageAdapter(), + initialMode: ThemeMode.DEFAULT, }); export const getThemeController = (): ThemeController => themeController; diff --git a/superset-frontend/src/theme/ThemeController.ts b/superset-frontend/src/theme/ThemeController.ts index fafbe784914a..5234ff19aa57 100644 --- a/superset-frontend/src/theme/ThemeController.ts +++ b/superset-frontend/src/theme/ThemeController.ts @@ -102,15 +102,19 @@ export class ThemeController { // Track loaded font URLs to avoid duplicate injections private loadedFontUrls: Set = new Set(); + private initialMode: ThemeMode | undefined; + constructor({ storage = new LocalStorageAdapter(), modeStorageKey = STORAGE_KEYS.THEME_MODE, themeObject = supersetThemeObject, defaultTheme = (supersetThemeObject.theme as AnyThemeConfig) ?? {}, onChange = undefined, + initialMode = undefined, }: ThemeControllerOptions = {}) { this.storage = storage; this.modeStorageKey = modeStorageKey; + this.initialMode = initialMode; // Controller creates and owns the global theme this.globalTheme = themeObject; @@ -743,6 +747,13 @@ export class ThemeController { return ThemeMode.DEFAULT; } + // Use explicit initial mode if provided (e.g. embedded dashboards default to light) + if ( + this.initialMode !== undefined && + this.isValidThemeMode(this.initialMode) + ) + return this.initialMode; + // Default to system preference when both themes are available return ThemeMode.SYSTEM; } diff --git a/superset-frontend/src/theme/tests/ThemeController.test.ts b/superset-frontend/src/theme/tests/ThemeController.test.ts index b6a68a3a11a1..e905ed5e7c31 100644 --- a/superset-frontend/src/theme/tests/ThemeController.test.ts +++ b/superset-frontend/src/theme/tests/ThemeController.test.ts @@ -1686,3 +1686,115 @@ test('font loading: adds new font URLs when switching themes', () => { .querySelectorAll('style[data-superset-fonts]') .forEach(el => el.remove()); }); + +test('ThemeController uses initialMode when provided and no saved mode exists', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: DARK_THEME, + }), + ); + + const controller = createController({ initialMode: ThemeMode.DEFAULT }); + + expect(controller.getCurrentMode()).toBe(ThemeMode.DEFAULT); +}); + +test('ThemeController defaults to SYSTEM when initialMode is not provided', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: DARK_THEME, + }), + ); + + const controller = createController(); + + expect(controller.getCurrentMode()).toBe(ThemeMode.SYSTEM); +}); + +test('ThemeController saved mode takes precedence over initialMode', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: DARK_THEME, + }), + ); + + mockLocalStorage.getItem.mockReturnValue(ThemeMode.DARK); + + const controller = createController({ initialMode: ThemeMode.DEFAULT }); + + expect(controller.getCurrentMode()).toBe(ThemeMode.DARK); +}); + +test('ThemeController with initialMode DEFAULT applies light theme even when system prefers dark', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: DARK_THEME, + }), + ); + + mockMatchMedia.mockReturnValue({ + matches: true, // system prefers dark + addEventListener: jest.fn(), + removeEventListener: jest.fn(), + }); + + const controller = createController({ initialMode: ThemeMode.DEFAULT }); + + expect(controller.getCurrentMode()).toBe(ThemeMode.DEFAULT); + const lastCall = + mockSetConfig.mock.calls[mockSetConfig.mock.calls.length - 1][0]; + expect(lastCall.token.colorBgBase).toBe(DEFAULT_THEME.token!.colorBgBase); +}); + +test('ThemeController with initialMode still allows setThemeMode after init', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: DARK_THEME, + }), + ); + + const controller = createController({ initialMode: ThemeMode.DEFAULT }); + expect(controller.getCurrentMode()).toBe(ThemeMode.DEFAULT); + + controller.setThemeMode(ThemeMode.DARK); + expect(controller.getCurrentMode()).toBe(ThemeMode.DARK); + + controller.setThemeMode(ThemeMode.SYSTEM); + expect(controller.getCurrentMode()).toBe(ThemeMode.SYSTEM); +}); + +test('ThemeController initialMode is ignored when no dark theme exists', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: {}, + }), + ); + + const controller = createController({ initialMode: ThemeMode.SYSTEM }); + + // Should still be DEFAULT because there's no dark theme available + expect(controller.getCurrentMode()).toBe(ThemeMode.DEFAULT); +}); + +test('ThemeController invalid initialMode falls back to SYSTEM', () => { + mockGetBootstrapData.mockReturnValue( + createMockBootstrapData({ + default: DEFAULT_THEME, + dark: DARK_THEME, + }), + ); + + const controller = createController({ + initialMode: 'invalid' as ThemeMode, + }); + + // Invalid initialMode should be rejected by isValidThemeMode, + // falling through to the default SYSTEM mode + expect(controller.getCurrentMode()).toBe(ThemeMode.SYSTEM); +});