Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3778.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`Group.tree()` no longer requires the `rich` dependency. Tree rendering now uses built-in ANSI bold for terminals and HTML bold for Jupyter. New parameters: `plain=True` for unstyled output, and `max_nodes` (default 500) to truncate large hierarchies with early bailout.
2 changes: 0 additions & 2 deletions docs/user-guide/groups.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,3 @@ Groups also have the [`zarr.Group.tree`][] method, e.g.:
print(root.tree())
```

!!! note
[`zarr.Group.tree`][] requires the optional [rich](https://rich.readthedocs.io/en/stable/) dependency. It can be installed with the `[tree]` extra.
2 changes: 1 addition & 1 deletion docs/user-guide/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ These can be installed using `pip install "zarr[<extra>]"`, e.g. `pip install "z
- `gpu`: support for GPUs
- `remote`: support for reading/writing to remote data stores

Additional optional dependencies include `rich`, `universal_pathlib`. These must be installed separately.
Additional optional dependencies include `universal_pathlib`. These must be installed separately.

## conda

Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ gpu = [
"cupy-cuda12x",
]
cli = ["typer"]
optional = ["rich", "universal-pathlib"]
optional = ["universal-pathlib"]

[project.scripts]
zarr = "zarr._cli.cli:app"
Expand Down Expand Up @@ -122,7 +122,6 @@ docs = [
"towncrier",
# Optional dependencies to run examples
"numcodecs[msgpack]",
"rich",
"s3fs>=2023.10.0",
"astroid<4",
"pytest",
Expand All @@ -131,7 +130,6 @@ dev = [
{include-group = "test"},
{include-group = "remote-tests"},
{include-group = "docs"},
"rich",
"universal-pathlib",
"mypy",
]
Expand Down
1 change: 0 additions & 1 deletion src/zarr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def print_packages(packages: list[str]) -> None:
"s3fs",
"gcsfs",
"universal-pathlib",
"rich",
"obstore",
]

Expand Down
130 changes: 99 additions & 31 deletions src/zarr/core/_tree.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import io
import os
import sys
from collections import deque
from collections.abc import Sequence
from html import escape as html_escape
from typing import Any

from zarr.core.group import AsyncGroup

try:
import rich
import rich.console
import rich.tree
except ImportError as e:
raise ImportError("'rich' is required for Group.tree") from e


class TreeRepr:
"""
Expand All @@ -21,45 +15,119 @@ class TreeRepr:
of Zarr's public API.
"""

def __init__(self, tree: rich.tree.Tree) -> None:
self._tree = tree
def __init__(self, text: str, html: str, truncated: str = "") -> None:
self._text = text
self._html = html
self._truncated = truncated

def __repr__(self) -> str:
color_system = os.environ.get("OVERRIDE_COLOR_SYSTEM", rich.get_console().color_system)
console = rich.console.Console(file=io.StringIO(), color_system=color_system)
console.print(self._tree)
return str(console.file.getvalue())
if self._truncated:
return self._truncated + self._text
return self._text

def _repr_mimebundle_(
self,
include: Sequence[str],
exclude: Sequence[str],
include: Sequence[str] | None = None,
exclude: Sequence[str] | None = None,
**kwargs: Any,
) -> dict[str, str]:
text = self._truncated + self._text if self._truncated else self._text
# For jupyter support.
# Unsure why mypy infers the return type to by Any
return self._tree._repr_mimebundle_(include=include, exclude=exclude, **kwargs) # type: ignore[no-any-return]
html_body = self._truncated + self._html if self._truncated else self._html
html = (
'<pre style="white-space:pre;overflow-x:auto;line-height:normal;'
"font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">"
f"{html_body}</pre>\n"
)
return {"text/plain": text, "text/html": html}


async def group_tree_async(
group: AsyncGroup,
max_depth: int | None = None,
max_nodes: int = 500,
plain: bool = False,
) -> TreeRepr:
members: list[tuple[str, Any]] = []
truncated = False
async for item in group.members(max_depth=max_depth):
if len(members) == max_nodes:
truncated = True
break
members.append(item)
members.sort(key=lambda key_node: key_node[0])

async def group_tree_async(group: AsyncGroup, max_depth: int | None = None) -> TreeRepr:
tree = rich.tree.Tree(label=f"[bold]{group.name}[/bold]")
nodes = {"": tree}
members = sorted([x async for x in group.members(max_depth=max_depth)])
# Set up styling tokens: ANSI bold for terminals, HTML <b> for Jupyter,
# or empty strings when plain=True (useful for LLMs, logging, files).
if plain:
ansi_open = ansi_close = html_open = html_close = ""
else:
# Avoid emitting ANSI escape codes when output is piped or in CI.
use_ansi = sys.stdout.isatty()
ansi_open = "\x1b[1m" if use_ansi else ""
ansi_close = "\x1b[0m" if use_ansi else ""
html_open = "<b>"
html_close = "</b>"

# Group members by parent key so we can render the tree level by level.
nodes: dict[str, list[tuple[str, Any]]] = {}
for key, node in members:
if key.count("/") == 0:
parent_key = ""
else:
parent_key = key.rsplit("/", 1)[0]
parent = nodes[parent_key]
nodes.setdefault(parent_key, []).append((key, node))

# We want what the spec calls the node "name", the part excluding all leading
# /'s and path segments. But node.name includes all that, so we build it here.
# Render the tree iteratively (not recursively) to avoid hitting
# Python's recursion limit on deeply nested hierarchies.
# Each stack frame is (prefix_string, remaining_children_at_this_level).
text_lines = [f"{ansi_open}{group.name}{ansi_close}"]
html_lines = [f"{html_open}{html_escape(group.name)}{html_close}"]
stack = [("", deque(nodes.get("", [])))]
while stack:
prefix, remaining = stack[-1]
if not remaining:
stack.pop()
continue
key, node = remaining.popleft()
name = key.rsplit("/")[-1]
escaped_name = html_escape(name)
# if we popped the last item then remaining will
# now be empty - that's how we got past the if not remaining
# above, but this can still be true.
is_last = not remaining
connector = "└── " if is_last else "├── "
if isinstance(node, AsyncGroup):
label = f"[bold]{name}[/bold]"
text_lines.append(f"{prefix}{connector}{ansi_open}{name}{ansi_close}")
html_lines.append(f"{prefix}{connector}{html_open}{escaped_name}{html_close}")
else:
label = f"[bold]{name}[/bold] {node.shape} {node.dtype}"
nodes[key] = parent.add(label)

return TreeRepr(tree)
text_lines.append(
f"{prefix}{connector}{ansi_open}{name}{ansi_close} {node.shape} {node.dtype}"
)
html_lines.append(
f"{prefix}{connector}{html_open}{escaped_name}{html_close}"
f" {html_escape(str(node.shape))} {html_escape(str(node.dtype))}"
)
# Descend into children with an accumulated prefix:
# Example showing how prefix accumulates:
# /
# ├── a prefix = ""
# │ ├── b prefix = "" + "│ "
# │ │ └── x prefix = "" + "│ " + "│ "
# │ └── c prefix = "" + "│ "
# └── d prefix = ""
# └── e prefix = "" + " "
if children := nodes.get(key, []):
if is_last:
child_prefix = prefix + " "
else:
child_prefix = prefix + "│ "
stack.append((child_prefix, deque(children)))
text = "\n".join(text_lines) + "\n"
html = "\n".join(html_lines) + "\n"
note = (
f"Truncated at max_nodes={max_nodes}, some nodes and their children may be missing\n"
if truncated
else ""
)
return TreeRepr(text, html, truncated=note)
38 changes: 30 additions & 8 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,19 +1588,29 @@ async def array_values(
async for _, array in self.arrays():
yield array

async def tree(self, expand: bool | None = None, level: int | None = None) -> Any:
async def tree(
self,
expand: bool | None = None,
level: int | None = None,
max_nodes: int = 500,
plain: bool = False,
) -> Any:
"""
Return a tree-like representation of a hierarchy.

This requires the optional ``rich`` dependency.

Parameters
----------
expand : bool, optional
This keyword is not yet supported. A NotImplementedError is raised if
it's used.
level : int, optional
The maximum depth below this Group to display in the tree.
max_nodes : int
Maximum number of nodes to display before truncating. Default is 500.
plain : bool, optional
If True, return a plain-text tree without ANSI styling. This is
useful when the output will be consumed by an LLM or written to a
file. Default is False.

Returns
-------
Expand All @@ -1611,7 +1621,7 @@ async def tree(self, expand: bool | None = None, level: int | None = None) -> An

if expand is not None:
raise NotImplementedError("'expand' is not yet implemented.")
return await group_tree_async(self, max_depth=level)
return await group_tree_async(self, max_depth=level, max_nodes=max_nodes, plain=plain)

async def empty(self, *, name: str, shape: tuple[int, ...], **kwargs: Any) -> AnyAsyncArray:
"""Create an empty array with the specified shape in this Group. The contents will
Expand Down Expand Up @@ -2371,26 +2381,38 @@ def array_values(self) -> Generator[AnyArray, None]:
for _, array in self.arrays():
yield array

def tree(self, expand: bool | None = None, level: int | None = None) -> Any:
def tree(
self,
expand: bool | None = None,
level: int | None = None,
max_nodes: int = 500,
plain: bool = False,
) -> Any:
"""
Return a tree-like representation of a hierarchy.

This requires the optional ``rich`` dependency.

Parameters
----------
expand : bool, optional
This keyword is not yet supported. A NotImplementedError is raised if
it's used.
level : int, optional
The maximum depth below this Group to display in the tree.
max_nodes : int
Maximum number of nodes to display before truncating. Default is 500.
plain : bool, optional
If True, return a plain-text tree without ANSI styling. This is
useful when the output will be consumed by an LLM or written to a
file. Default is False.

Returns
-------
TreeRepr
A pretty-printable object displaying the hierarchy.
"""
return self._sync(self._async_group.tree(expand=expand, level=level))
return self._sync(
self._async_group.tree(expand=expand, level=level, max_nodes=max_nodes, plain=plain)
)

def create_group(self, name: str, **kwargs: Any) -> Group:
"""Create a sub-group.
Expand Down
1 change: 0 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,6 @@ def test_load_local(tmp_path: Path, path: str | None, load_read_only: bool) -> N


def test_tree() -> None:
pytest.importorskip("rich")
g1 = zarr.group()
g1.create_group("foo")
g3 = g1.create_group("bar")
Expand Down
Loading
Loading