diff --git a/README.md b/README.md index a68e506..944002c 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,9 @@ uv add git+https://github.com/sourcegraph/src-py-lib.git - `src_py_lib.clients.graphql` — shared GraphQL execution with automatic cursor pagination, batched alias lookups, and schema introspection export. - `src_py_lib.clients.sourcegraph` — Sourcegraph GraphQL client with token - validation and shared config fields for `SRC_ENDPOINT` (default: - `https://sourcegraph.com`) and `SRC_ACCESS_TOKEN`. + validation, endpoint normalization, connection streaming, and shared config + fields for `SRC_ENDPOINT` (default: `https://sourcegraph.com`) and + `SRC_ACCESS_TOKEN`. - `src_py_lib.clients.linear` — Linear GraphQL client with automatic cursor handling, token validation, shared config fields, and injectable HTTP policy. - `src_py_lib.clients.slack` — Slack Web API client with token validation, @@ -70,7 +71,7 @@ import src_py_lib as src class LinearExportConfig(src.LinearClientConfig): output_dir: Path = src.config_field( - Path("."), + default=Path("."), env_var="LINEAR_EXPORT_OUTPUT_DIR", cli_flag="--output-dir", metavar="PATH", @@ -85,9 +86,13 @@ print(f"Writing files under {config.output_dir}") Config precedence is: code defaults, `.env`, shell environment, then CLI overrides. API client modules can provide shared Config base classes such as `LinearClientConfig`, and `parse_args` resolves `op://...` references by -default. Pass a custom `argparse.ArgumentParser` to `parse_args` when a -CLI also has non-Config flags. Mark sensitive fields with `secret=True` so -snapshots do not expose resolved values. +default. `config_field(default=...)` supports aliases, store-true / +store-false command flags, optional values, numeric bounds, and string patterns +for simple CLIs. Pass a custom `argparse.ArgumentParser` to `parse_args` only when you +need parsing beyond Config fields. Help text preserves description and +argument-help newlines, and reserves enough option-column width for long config +flags. Mark sensitive fields with `secret=True` so snapshots do not expose +resolved values. ## Logging example @@ -95,22 +100,31 @@ Configure logging once at process startup. Prefer configuring the root logger (`logger_name=""`, the default) so project modules and shared `src_py_lib` modules such as `src_py_lib.utils.http` are captured by the same terminal and JSONL handlers. Use `logging()` in CLIs to configure logging, add the command field to all -structured events, and emit standard startup metadata. +structured events, and emit standard run/startup/run-end metadata. Use `debug()`, `info()`, `warning()`, `error()`, and `critical()` for one-off structured events. Use `event()` blocks around timed work; they emit `trace`, -`span`, and nested `parent_span` fields. +`span`, and nested `parent_span` fields. Use `start_level="debug"` to hide +noisy start events while keeping end timing visible, and +`omit_success_status=True` for very high-volume success events. Use `stage()` +for workflow context such as `stage="apply"`. When the root logger is configured, noisy `httpx`/`httpcore` records are suppressed; `HTTPClient` emits structured `http_request` events instead. -Set `SRC_LOG_LEVEL=INFO` for a run to omit DEBUG events from the log file. +Run-end events include HTTP attempt/byte/status/retry counters. Set +`LoggingSettings.resource_sample_interval_seconds` to emit DEBUG +`resource_sample` events and include process resource totals on run end. Set +`SRC_LOG_LEVEL=INFO` for a run to omit DEBUG events from the log file. +`LoggingConfig` includes `--verbose/-v`, `--quiet/-q`, and `--silent/-s` +shortcuts (also available as `SRC_LOG_VERBOSE`, `SRC_LOG_QUIET`, and +`SRC_LOG_SILENT`). Use `logging_settings_from_config()` to build +`LoggingSettings` from those conventions. ```python import src_py_lib as src -from src_py_lib.clients.sourcegraph import SourcegraphClient with src.logging({"src_token": "provided"}): src.info("sync_started", repository_count=3) - client = SourcegraphClient("https://sourcegraph.example.com", "token") + client = src.SourcegraphClient("https://sourcegraph.example.com", "token") data = client.graphql("query Viewer { currentUser { username } }") ``` diff --git a/src/src_py_lib/__init__.py b/src/src_py_lib/__init__.py index 7c64ebf..66e2710 100644 --- a/src/src_py_lib/__init__.py +++ b/src/src_py_lib/__init__.py @@ -3,8 +3,10 @@ from __future__ import annotations import sys +from collections.abc import Callable, Mapping from contextlib import AbstractContextManager from pathlib import Path +from typing import Any from src_py_lib.clients.github import GitHubClient, PullRequest, gh_cli_token, pr_ref_from_url from src_py_lib.clients.google_sheets import ( @@ -18,6 +20,7 @@ GraphQLError, aliased_batched_query, introspect_schema, + stream_connection_nodes, ) from src_py_lib.clients.linear import ( LinearClient, @@ -31,6 +34,12 @@ SlackPacer, slack_client_from_config, ) +from src_py_lib.clients.sourcegraph import ( + SourcegraphClient, + SourcegraphClientConfig, + normalize_sourcegraph_endpoint, + sourcegraph_client_from_config, +) from src_py_lib.utils.config import ( Config, ConfigError, @@ -63,7 +72,11 @@ log, log_context, logging_context, + logging_settings_from_config, + resolve_log_level_name, + stage, startup_event, + submit_with_log_context, warning, ) from src_py_lib.utils.tsv import write_tsv @@ -75,6 +88,8 @@ def logging( command: str | None = None, git_cwd: Path | str | None = None, logging_config: LoggingSettings | None = None, + run_fields: Mapping[str, Any] | None = None, + run_summary: Callable[[], Mapping[str, Any]] | None = None, ) -> AbstractContextManager[Path | None]: """Configure standard CLI logging and emit startup metadata.""" return logging_context( @@ -82,6 +97,8 @@ def logging( config, git_cwd=git_cwd, logging_config=logging_config, + run_fields=run_fields, + run_summary=run_summary, ) @@ -109,6 +126,8 @@ def _script_name() -> str: "SlackClientConfig", "SlackError", "SlackPacer", + "SourcegraphClient", + "SourcegraphClientConfig", "aliased_batched_query", "config_field", "config_snapshot", @@ -131,14 +150,21 @@ def _script_name() -> str: "load_json_cache", "load_json_subset", "logging", + "logging_settings_from_config", "log", "log_context", + "normalize_sourcegraph_endpoint", "parse_args", "pr_ref_from_url", "quota_project_from_adc", + "resolve_log_level_name", "save_json_cache", "slack_client_from_config", + "sourcegraph_client_from_config", + "stage", "startup_event", + "stream_connection_nodes", + "submit_with_log_context", "warning", "write_tsv", ] diff --git a/src/src_py_lib/clients/graphql.py b/src/src_py_lib/clients/graphql.py index 00e6d2b..fe4e226 100644 --- a/src/src_py_lib/clients/graphql.py +++ b/src/src_py_lib/clients/graphql.py @@ -4,7 +4,7 @@ import json import re -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterator, Mapping, Sequence from dataclasses import dataclass, field from pathlib import Path from typing import cast @@ -113,6 +113,17 @@ class GraphQLError(RuntimeError): """Raised for GraphQL transport or application errors.""" + def __init__( + self, + message: str, + *, + status_code: int | None = None, + is_application_error: bool = False, + ) -> None: + super().__init__(message) + self.status_code = status_code + self.is_application_error = is_application_error + @dataclass class GraphQLClient: @@ -174,6 +185,49 @@ def execute_next_page(next_variables: JSONDict) -> JSONDict: ) return data + def stream_connection_nodes( + self, + query: str, + variables: Mapping[str, JSONValue] | None = None, + *, + connection_path: Sequence[str], + page_size: int | None = None, + first_variable: str = "first", + after_variable: str = "after", + ) -> Iterator[JSONDict]: + """Stream one GraphQL connection's nodes page by page. + + `connection_path` is the response path to the connection object that + contains `nodes` and `pageInfo`, for example `("viewer", "items")`. + Unlike `execute(..., follow_pages=True)`, this does not accumulate all + nodes in memory before returning. + """ + page_number = 1 + + def execute_page( + operation: str, page_variables: Mapping[str, JSONValue] | None + ) -> JSONDict: + nonlocal page_number + data = self._execute_once( + operation, + dict(page_variables or {}), + page_number=page_number, + first_variable=first_variable, + after_variable=after_variable, + ) + page_number += 1 + return data + + yield from stream_connection_nodes( + execute_page, + query, + variables, + connection_path=connection_path, + page_size=page_size, + first_variable=first_variable, + after_variable=after_variable, + ) + def _execute_once( self, query: str, @@ -200,7 +254,8 @@ def _execute_once( payload = self.http.json("POST", self.url, headers=self.headers, json_body=body) except HTTPClientError as exception: raise GraphQLError( - f"{self.label} GraphQL request failed: {exception}" + f"{self.label} GraphQL request failed: {exception}", + status_code=exception.status_code, ) from exception errors = payload.get("errors") data = json_dict(payload.get("data")) @@ -208,7 +263,10 @@ def _execute_once( if errors: fields["graphql_errors"] = len(errors) if isinstance(errors, list) else 1 if errors and not (self.tolerate_partial_errors and data): - raise GraphQLError(f"{self.label} GraphQL errors: {errors}") + raise GraphQLError( + f"{self.label} GraphQL errors: {errors}", + is_application_error=True, + ) return data @@ -218,6 +276,49 @@ def operation_name(query: str) -> str: return match.group(1) if match else "anonymous" +def stream_connection_nodes( + execute: Callable[[str, Mapping[str, JSONValue] | None], JSONDict], + query: str, + variables: Mapping[str, JSONValue] | None = None, + *, + connection_path: Sequence[str], + page_size: int | None = None, + first_variable: str = "first", + after_variable: str = "after", +) -> Iterator[JSONDict]: + """Stream one GraphQL connection's nodes through any execute callable.""" + page_variables: JSONDict = dict(variables) if variables is not None else {} + if page_size is not None: + page_variables[first_variable] = page_size + query_uses_after_variable = _query_uses_variable(query, after_variable) + if query_uses_after_variable and after_variable not in page_variables: + page_variables[after_variable] = None + + path = tuple(connection_path) + current_cursor = page_variables.get(after_variable) + while True: + data = execute(query, dict(page_variables)) + page = _node_page_at_path(data, path) + for node in json_list(page.get("nodes")): + yield json_dict(node) + + page_info = json_dict(page.get("pageInfo")) + has_next_page = page_info.get("hasNextPage") + if not isinstance(has_next_page, bool): + raise GraphQLError( + f"GraphQL pagination path {_path_label(path)} missing pageInfo.hasNextPage" + ) + if not has_next_page: + return + if not query_uses_after_variable: + raise GraphQLError( + f"GraphQL query returned more pages but does not use ${after_variable}" + ) + next_cursor = _next_page_cursor(page_info, path, current_cursor) + page_variables[after_variable] = next_cursor + current_cursor = next_cursor + + def _int_variable(variables: JSONDict, name: str) -> int | None: value = variables.get(name) return value if isinstance(value, int) else None @@ -301,9 +402,7 @@ def _fetch_remaining_pages( target_page = _node_page_at_path(data, path) target_nodes = json_list(target_page.get("nodes")) page_info = json_dict(target_page.get("pageInfo")) - after = json_str(page_info, "endCursor") - if not after: - raise GraphQLError(f"GraphQL pagination path {'.'.join(path)} missing pageInfo.endCursor") + after = _next_page_cursor(page_info, path, variables.get(after_variable)) while after: page_variables = dict(variables) @@ -322,11 +421,7 @@ def _fetch_remaining_pages( ) if not has_next_page: return - after = json_str(next_page_info, "endCursor") - if not after: - raise GraphQLError( - f"GraphQL pagination path {'.'.join(path)} missing pageInfo.endCursor" - ) + after = _next_page_cursor(next_page_info, path, after) def _next_page_paths(data: JSONDict) -> list[tuple[str, ...]]: @@ -355,10 +450,27 @@ def _node_page_at_path(data: JSONDict, path: tuple[str, ...]) -> JSONDict: current = json_dict(current).get(key) page = json_dict(current) if not page: - label = ".".join(path) or "" - raise GraphQLError(f"GraphQL response did not include pagination path {label}") + raise GraphQLError(f"GraphQL response did not include pagination path {_path_label(path)}") return page +def _next_page_cursor(page_info: JSONDict, path: tuple[str, ...], current_cursor: object) -> str: + next_cursor = json_str(page_info, "endCursor") + if not next_cursor: + raise GraphQLError( + f"GraphQL pagination path {_path_label(path)} missing pageInfo.endCursor" + ) + if isinstance(current_cursor, str) and next_cursor == current_cursor: + raise GraphQLError( + f"GraphQL pagination path {_path_label(path)} stalled: " + f"pageInfo.endCursor did not advance from {current_cursor!r}" + ) + return next_cursor + + +def _path_label(path: tuple[str, ...]) -> str: + return ".".join(path) or "" + + def _query_uses_variable(query: str, variable: str) -> bool: return re.search(rf"\${re.escape(variable)}\b", query) is not None diff --git a/src/src_py_lib/clients/linear.py b/src/src_py_lib/clients/linear.py index 3c8e7ad..fe81257 100644 --- a/src/src_py_lib/clients/linear.py +++ b/src/src_py_lib/clients/linear.py @@ -49,11 +49,11 @@ class LinearClientConfig(Config): """Config fields needed to build a Linear API client.""" linear_api_token: str = config_field( - "", + default="", env_var="LINEAR_API_TOKEN", cli_flag="--linear-api-token", metavar="TOKEN", - help="Linear API token or op:// secret reference.", + help="Linear API token or op:// secret reference", secret=True, required=True, ) diff --git a/src/src_py_lib/clients/slack.py b/src/src_py_lib/clients/slack.py index 08fe078..0ac5844 100644 --- a/src/src_py_lib/clients/slack.py +++ b/src/src_py_lib/clients/slack.py @@ -27,11 +27,11 @@ class SlackClientConfig(Config): """Config fields needed to build a Slack API client.""" slack_bot_token: str = config_field( - "", + default="", env_var="SLACK_BOT_TOKEN", cli_flag="--slack-bot-token", metavar="TOKEN", - help="Slack bot token or op:// secret reference.", + help="Slack bot token or op:// secret reference", secret=True, required=True, ) diff --git a/src/src_py_lib/clients/sourcegraph.py b/src/src_py_lib/clients/sourcegraph.py index c4f51e7..ec9b158 100644 --- a/src/src_py_lib/clients/sourcegraph.py +++ b/src/src_py_lib/clients/sourcegraph.py @@ -2,12 +2,14 @@ from __future__ import annotations +from collections.abc import Iterator, Mapping, Sequence from dataclasses import dataclass, field +from urllib.parse import urlsplit -from src_py_lib.clients.graphql import GraphQLClient +from src_py_lib.clients.graphql import GraphQLClient, stream_connection_nodes from src_py_lib.utils.config import Config, config_field from src_py_lib.utils.http import HTTPClient -from src_py_lib.utils.json_types import JSONDict, json_dict +from src_py_lib.utils.json_types import JSONDict, JSONValue, json_dict DEFAULT_SOURCEGRAPH_ENDPOINT = "https://sourcegraph.com" SOURCEGRAPH_VALIDATE_QUERY = """ @@ -19,22 +21,42 @@ """ +def normalize_sourcegraph_endpoint(endpoint: str, *, require_https: bool = False) -> str: + """Return a stable Sourcegraph base URL, or raise ValueError.""" + normalized_endpoint = endpoint.strip().rstrip("/") + endpoint_parts = urlsplit(normalized_endpoint) + if require_https and endpoint_parts.scheme != "https": + raise ValueError( + f"Sourcegraph endpoint must be an https:// URL (got {endpoint_parts.scheme!r})" + ) + if endpoint_parts.scheme not in {"http", "https"}: + raise ValueError( + "Sourcegraph endpoint must be an http:// or https:// URL " + f"(got {endpoint_parts.scheme!r})" + ) + if not endpoint_parts.hostname: + raise ValueError( + f"could not parse hostname from Sourcegraph endpoint {normalized_endpoint!r}" + ) + return normalized_endpoint + + class SourcegraphClientConfig(Config): """Config fields needed to build a Sourcegraph API client.""" src_endpoint: str = config_field( - DEFAULT_SOURCEGRAPH_ENDPOINT, + default=DEFAULT_SOURCEGRAPH_ENDPOINT, env_var="SRC_ENDPOINT", cli_flag="--src-endpoint", metavar="URL", - help=f"Sourcegraph instance URL (default: {DEFAULT_SOURCEGRAPH_ENDPOINT}).", + help=f"Sourcegraph instance URL (default: {DEFAULT_SOURCEGRAPH_ENDPOINT})", ) src_access_token: str = config_field( - "", + default="", env_var="SRC_ACCESS_TOKEN", cli_flag="--src-access-token", metavar="TOKEN", - help="Sourcegraph access token or op:// secret reference.", + help="Sourcegraph access token, or op:// secret reference", secret=True, required=True, ) @@ -52,9 +74,33 @@ class SourcegraphClient: token: str http: HTTPClient = field(default_factory=HTTPClient) - def graphql(self, query: str, variables: JSONDict | None = None) -> JSONDict: + def __post_init__(self) -> None: + self.endpoint = normalize_sourcegraph_endpoint(self.endpoint) + + def graphql(self, query: str, variables: Mapping[str, JSONValue] | None = None) -> JSONDict: return self._client().execute(query, variables) + def stream_connection_nodes( + self, + query: str, + variables: Mapping[str, JSONValue] | None = None, + *, + connection_path: Sequence[str], + page_size: int | None = None, + first_variable: str = "first", + after_variable: str = "after", + ) -> Iterator[JSONDict]: + """Stream one Sourcegraph GraphQL connection's nodes.""" + return stream_connection_nodes( + self.graphql, + query, + variables, + connection_path=connection_path, + page_size=page_size, + first_variable=first_variable, + after_variable=after_variable, + ) + def validate(self) -> JSONDict: """Validate the token with a cheap current user query and return the user.""" current_user = json_dict(self.graphql(SOURCEGRAPH_VALIDATE_QUERY).get("currentUser")) @@ -66,7 +112,7 @@ def validate(self) -> JSONDict: def _client(self) -> GraphQLClient: return GraphQLClient( - url=f"{self.endpoint.rstrip('/')}/.api/graphql", + url=f"{self.endpoint}/.api/graphql", headers={"Authorization": f"token {self.token}"}, label="Sourcegraph", http=self.http, diff --git a/src/src_py_lib/utils/config.py b/src/src_py_lib/utils/config.py index e3a8bb3..d8a0a81 100644 --- a/src/src_py_lib/utils/config.py +++ b/src/src_py_lib/utils/config.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, replace from pathlib import Path from types import UnionType -from typing import Any, Final, Union, cast, get_args, get_origin +from typing import Any, Final, Literal, Union, cast, get_args, get_origin from dotenv import dotenv_values from pydantic import BaseModel, ConfigDict, Field, ValidationError @@ -30,10 +30,26 @@ ) DEFAULT_CONFIG_ENV_FILE: Final[Path] = Path(".env") +CONFIG_HELP_MIN_POSITION: Final[int] = 24 +CONFIG_HELP_MAX_POSITION_LIMIT: Final[int] = 48 +CONFIG_HELP_PADDING: Final[int] = 4 _CONFIG_OPTION_KEY: Final[str] = "src_py_lib_config_option" _MISSING: Final[object] = object() +class ConfigHelpFormatter(argparse.RawTextHelpFormatter): + """Help formatter for Config-backed CLIs.""" + + def __init__( + self, + prog: str, + indent_increment: int = 2, + max_help_position: int = CONFIG_HELP_MIN_POSITION, + width: int | None = None, + ) -> None: + super().__init__(prog, indent_increment, max_help_position, width) + + class ConfigError(RuntimeError): """Raised when Config loading, validation, or reference resolution fails.""" @@ -45,6 +61,10 @@ class ConfigOption: field_name: str env_var: str cli_flag: str = "" + cli_aliases: tuple[str, ...] = () + cli_action: Literal["auto", "store_true", "store_false"] = "auto" + cli_nargs: str | int | None = None + cli_const: object | None = None metavar: str | None = None help: str = "" secret: bool = False @@ -58,30 +78,53 @@ class Config(BaseModel): def config_field( - default: Any = ..., *, + default: Any = ..., env_var: str, cli_flag: str | None = None, + cli_aliases: Sequence[str] = (), + cli_action: Literal["auto", "store_true", "store_false"] = "auto", + cli_nargs: str | int | None = None, + cli_const: object | None = None, metavar: str | None = None, help: str = "", secret: bool = False, required: bool = False, + gt: int | float | None = None, + ge: int | float | None = None, + lt: int | float | None = None, + le: int | float | None = None, + pattern: str | None = None, ) -> Any: """Return a Pydantic field with Config environment and CLI metadata.""" option = ConfigOption( field_name="", env_var=env_var, cli_flag=cli_flag or "", + cli_aliases=tuple(cli_aliases), + cli_action=cli_action, + cli_nargs=cli_nargs, + cli_const=cli_const, metavar=metavar, help=help, secret=secret, required=required, ) - return Field( - default, - description=help or None, - json_schema_extra=_config_json_schema_extra(option), - ) + field_kwargs: dict[str, Any] = { + "description": help or None, + "json_schema_extra": _config_json_schema_extra(option), + } + if gt is not None: + field_kwargs["gt"] = gt + if ge is not None: + field_kwargs["ge"] = ge + if lt is not None: + field_kwargs["lt"] = lt + if le is not None: + field_kwargs["le"] = le + if pattern is not None: + field_kwargs["pattern"] = pattern + return Field(default, **field_kwargs) def config_options(config_cls: type[Config]) -> tuple[ConfigOption, ...]: @@ -150,7 +193,7 @@ def add_config_arguments( """Add Config CLI flags to an argparse parser.""" group = parser.add_argument_group( "Config", - "These options override matching environment variables and .env values.", + "These options override matching environment variables and .env values", ) if include_env_file: group.add_argument( @@ -158,27 +201,28 @@ def add_config_arguments( dest="env_file", default=None, metavar="PATH", - help="Read Config .env values from PATH (default: .env).", + help="Read Config .env values from PATH (default: .env)", ) for option in config_options(config_cls): field_info = config_cls.model_fields[option.field_name] + argument_kwargs: dict[str, Any] = { + "dest": option.field_name, + "default": None, + "help": option.help, + } + if option.metavar is not None: + argument_kwargs["metavar"] = option.metavar + if option.cli_nargs is not None: + argument_kwargs["nargs"] = option.cli_nargs + if option.cli_const is not None: + argument_kwargs["const"] = option.cli_const if _is_bool_annotation(field_info.annotation): - group.add_argument( - option.cli_flag, - dest=option.field_name, - action=argparse.BooleanOptionalAction, - default=None, - help=option.help, - ) - else: - group.add_argument( - option.cli_flag, - dest=option.field_name, - default=None, - metavar=option.metavar, - help=option.help, - ) + if option.cli_action == "auto": + argument_kwargs["action"] = argparse.BooleanOptionalAction + else: + argument_kwargs["action"] = option.cli_action + group.add_argument(option.cli_flag, *option.cli_aliases, **argument_kwargs) def config_parse_args[ConfigT: Config]( @@ -195,7 +239,11 @@ def config_parse_args[ConfigT: Config]( require: Iterable[str] = (), ) -> ConfigT: """Parse Config CLI flags and return a validated Config model.""" - argument_parser = parser or argparse.ArgumentParser(description=description) + max_help_position = _config_help_max_position(config_cls, include_env_file=include_env_file) + argument_parser = parser or argparse.ArgumentParser( + description=description, + formatter_class=_config_help_formatter(max_help_position), + ) add_config_arguments(argument_parser, config_cls, include_env_file=include_env_file) args = argument_parser.parse_args(argv) try: @@ -212,6 +260,77 @@ def config_parse_args[ConfigT: Config]( argument_parser.error(str(exception)) +def _config_help_formatter(max_help_position: int) -> type[argparse.HelpFormatter]: + """Return a formatter class with this parser's computed help position.""" + + class DynamicConfigHelpFormatter(ConfigHelpFormatter): + def __init__(self, prog: str) -> None: + super().__init__(prog, max_help_position=max_help_position) + + return DynamicConfigHelpFormatter + + +def _config_help_max_position( + config_cls: type[Config], + *, + include_env_file: bool, +) -> int: + """Return help-column width based on this Config's CLI arguments.""" + invocation_lengths = [len("--env-file PATH")] if include_env_file else [] + invocation_lengths.extend( + _config_option_invocation_length(config_cls, option) + for option in config_options(config_cls) + ) + longest_invocation = max(invocation_lengths, default=0) + return min( + max(CONFIG_HELP_MIN_POSITION, longest_invocation + CONFIG_HELP_PADDING), + CONFIG_HELP_MAX_POSITION_LIMIT, + ) + + +def _config_option_invocation_length(config_cls: type[Config], option: ConfigOption) -> int: + """Return argparse-style option invocation length for help alignment.""" + field_info = config_cls.model_fields[option.field_name] + option_strings = _config_option_strings(option, field_info) + if _config_option_takes_value(option, field_info): + arguments = _config_option_arguments(option) + return len(", ".join(f"{option_string} {arguments}" for option_string in option_strings)) + return len(", ".join(option_strings)) + + +def _config_option_strings(option: ConfigOption, field_info: FieldInfo) -> tuple[str, ...]: + """Return option strings as argparse will display them.""" + if _is_bool_annotation(field_info.annotation) and option.cli_action == "auto": + long_options = tuple( + f"--no-{option_string.removeprefix('--')}" + for option_string in (option.cli_flag, *option.cli_aliases) + if option_string.startswith("--") + ) + return (option.cli_flag, *long_options, *option.cli_aliases) + return (option.cli_flag, *option.cli_aliases) + + +def _config_option_takes_value(option: ConfigOption, field_info: FieldInfo) -> bool: + """Return whether argparse displays a value placeholder for this option.""" + if not _is_bool_annotation(field_info.annotation): + return True + return option.cli_action == "auto" and option.cli_nargs is not None + + +def _config_option_arguments(option: ConfigOption) -> str: + """Return the argparse-style value placeholder for this option.""" + metavar = option.metavar or option.field_name.upper() + if option.cli_nargs == "?": + return f"[{metavar}]" + if option.cli_nargs == "*": + return f"[{metavar} ...]" + if option.cli_nargs == "+": + return f"{metavar} [{metavar} ...]" + if isinstance(option.cli_nargs, int): + return " ".join(metavar for _ in range(option.cli_nargs)) + return metavar + + def config_overrides_from_args( config_cls: type[Config], args: argparse.Namespace ) -> dict[str, object]: @@ -320,10 +439,14 @@ def _config_option_from_field(field_name: str, field_info: FieldInfo) -> ConfigO ) -def _config_option_payload(option: ConfigOption) -> dict[str, str | bool | None]: +def _config_option_payload(option: ConfigOption) -> dict[str, object]: return { "env_var": option.env_var, "cli_flag": option.cli_flag, + "cli_aliases": list(option.cli_aliases), + "cli_action": option.cli_action, + "cli_nargs": option.cli_nargs, + "cli_const": option.cli_const, "metavar": option.metavar, "help": option.help, "secret": option.secret, @@ -340,12 +463,19 @@ def _config_option_from_payload(payload: Mapping[str, object]) -> ConfigOption | if not isinstance(env_var, str) or not env_var: return None cli_flag = payload.get("cli_flag") + cli_aliases = payload.get("cli_aliases") + cli_action = payload.get("cli_action") + cli_nargs = payload.get("cli_nargs") metavar = payload.get("metavar") help_text = payload.get("help") return ConfigOption( field_name="", env_var=env_var, cli_flag=cli_flag if isinstance(cli_flag, str) else "", + cli_aliases=_string_tuple(cli_aliases), + cli_action=_cli_action(cli_action), + cli_nargs=cli_nargs if isinstance(cli_nargs, str | int) else None, + cli_const=payload.get("cli_const"), metavar=metavar if isinstance(metavar, str) else None, help=help_text if isinstance(help_text, str) else "", secret=payload.get("secret") is True, @@ -353,6 +483,18 @@ def _config_option_from_payload(payload: Mapping[str, object]) -> ConfigOption | ) +def _string_tuple(value: object) -> tuple[str, ...]: + if not isinstance(value, Sequence) or isinstance(value, str | bytes): + return () + return tuple(item for item in cast(Sequence[object], value) if isinstance(item, str)) + + +def _cli_action(value: object) -> Literal["auto", "store_true", "store_false"]: + if value in {"store_true", "store_false"}: + return cast(Literal["store_true", "store_false"], value) + return "auto" + + def _selected_raw_value( option: ConfigOption, env_file_values: Mapping[str, str], diff --git a/src/src_py_lib/utils/http.py b/src/src_py_lib/utils/http.py index f721ad3..44ed2ec 100644 --- a/src/src_py_lib/utils/http.py +++ b/src/src_py_lib/utils/http.py @@ -14,7 +14,7 @@ import httpx from src_py_lib.utils.json_types import JSONDict, json_dict -from src_py_lib.utils.logging import event +from src_py_lib.utils.logging import event, record_http_attempt, record_http_retry DEFAULT_TIMEOUT_SECONDS: Final[float] = 30.0 DEFAULT_MAX_CONNECTIONS: Final[int] = 20 @@ -134,6 +134,11 @@ def request( http_version = _response_http_version(response) if http_version is not None: fields["http_version"] = http_version + record_http_attempt( + request_bytes=len(body or b""), + response_bytes=len(payload), + status_code=response.status_code, + ) if response.status_code >= 400: body_text = _body_preview(payload) if not self._should_retry(response.status_code, attempt): @@ -144,24 +149,29 @@ def request( body=body_text, headers=dict(response.headers), ) + record_http_retry() self._sleep_before_retry(attempt, response.headers.get("Retry-After")) else: return payload except HTTPClientError: raise except httpx.TimeoutException as exception: + record_http_attempt(request_bytes=len(body or b""), transport_error=True) if not self._should_retry(None, attempt): raise HTTPClientError( f"HTTP request timed out for {method} {_safe_url(request_url)}: " f"{_exception_message(exception)}" ) from exception + record_http_retry() self._sleep_before_retry(attempt, None) except httpx.TransportError as exception: + record_http_attempt(request_bytes=len(body or b""), transport_error=True) if not self._should_retry(None, attempt): raise HTTPClientError( f"HTTP request failed for {method} {_safe_url(request_url)}: " f"{_exception_message(exception)}" ) from exception + record_http_retry() self._sleep_before_retry(attempt, None) raise AssertionError("HTTP retry loop exited without returning or raising") diff --git a/src/src_py_lib/utils/logging.py b/src/src_py_lib/utils/logging.py index aa63609..385178d 100644 --- a/src/src_py_lib/utils/logging.py +++ b/src/src_py_lib/utils/logging.py @@ -16,12 +16,19 @@ import os import secrets import subprocess +import sys import threading import time -from collections.abc import Generator, Iterable, Mapping -from dataclasses import dataclass +from collections.abc import Callable, Generator, Iterable, Mapping +from concurrent.futures import Executor, Future +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Final, cast +from typing import Any, Final, Self, cast + +if sys.platform != "win32": + import resource + +from pydantic import model_validator from src_py_lib.utils.config import Config, config_field, config_snapshot @@ -30,7 +37,11 @@ DEFAULT_RETAIN_FILES: Final[int] = 50 DEFAULT_LOG_FILE_LEVEL: Final[str] = "debug" SRC_LOG_LEVEL: Final[str] = "SRC_LOG_LEVEL" +SRC_LOG_VERBOSE: Final[str] = "SRC_LOG_VERBOSE" +SRC_LOG_QUIET: Final[str] = "SRC_LOG_QUIET" +SRC_LOG_SILENT: Final[str] = "SRC_LOG_SILENT" TRACE_SPAN_BYTES: Final[int] = 4 +MEBIBYTE: Final[int] = 1024 * 1024 SECRET_FIELD_FRAGMENTS: Final[tuple[str, ...]] = ( "api_key", "authorization", @@ -50,6 +61,7 @@ "logger", "event", "phase", + "stage", "message", ) @@ -73,17 +85,105 @@ class LoggingSettings: run: str = RUN retain_log_files: int = DEFAULT_RETAIN_FILES suppress_http_dependency_logs: bool = True + resource_sample_interval_seconds: float | None = None class LoggingConfig(Config): """Config fields for logging-related CLI and environment options.""" src_log_level: str | None = config_field( - None, + default="INFO", env_var=SRC_LOG_LEVEL, cli_flag="--src-log-level", metavar="LEVEL", - help="Minimum level for log events (default: DEBUG; e.g. INFO hides debug events).", + help="Log level (default: INFO)", + ) + verbose: bool = config_field( + default=False, + env_var=SRC_LOG_VERBOSE, + cli_flag="--verbose", + cli_aliases=("-v",), + cli_action="store_true", + help="Alias for --src-log-level DEBUG", + ) + quiet: bool = config_field( + default=False, + env_var=SRC_LOG_QUIET, + cli_flag="--quiet", + cli_aliases=("-q",), + cli_action="store_true", + help="Alias for --src-log-level WARNING", + ) + silent: bool = config_field( + default=False, + env_var=SRC_LOG_SILENT, + cli_flag="--silent", + cli_aliases=("-s",), + cli_action="store_true", + help="Alias for --src-log-level ERROR", + ) + + @model_validator(mode="after") + def validate_log_level_alias(self) -> Self: + """Require at most one alias for the terminal/log-file level.""" + if sum((self.verbose, self.quiet, self.silent)) > 1: + raise ValueError("choose only one of --verbose/-v, --quiet/-q, or --silent/-s") + return self + + +def resolve_log_level_name( + config: object | None = None, + *, + log_level: str | None = None, + verbose: bool | None = None, + quiet: bool | None = None, + silent: bool | None = None, +) -> str | None: + """Resolve common CLI log-level alias to a level name. + + Alias flags intentionally only map to strings. Explicit log-level + values are returned unchanged so `configure_logging()` owns parsing + and fallback behavior. + """ + resolved_verbose = verbose if verbose is not None else bool(getattr(config, "verbose", False)) + resolved_quiet = quiet if quiet is not None else bool(getattr(config, "quiet", False)) + resolved_silent = silent if silent is not None else bool(getattr(config, "silent", False)) + if resolved_verbose: + return "DEBUG" + if resolved_quiet: + return "WARNING" + if resolved_silent: + return "ERROR" + if log_level is not None: + return log_level + return _src_log_level_from_config(config) + + +def logging_settings_from_config( + config: object | None = None, + *, + terminal_default: str = "INFO", + log_file_default: str | None = DEFAULT_LOG_FILE_LEVEL, + logger_name: str = "", + log_file: Path | None = None, + logs_dir: Path | None = DEFAULT_LOGS_DIR, + run: str = RUN, + retain_log_files: int = DEFAULT_RETAIN_FILES, + suppress_http_dependency_logs: bool = True, + resource_sample_interval_seconds: float | None = None, +) -> LoggingSettings: + """Return `LoggingSettings` using common CLI log-level alias.""" + explicit_level = resolve_log_level_name(config) + return LoggingSettings( + logger_name=logger_name, + terminal_level=explicit_level or terminal_default, + log_file_level=explicit_level or log_file_default, + log_file=log_file, + logs_dir=logs_dir, + run=run, + retain_log_files=retain_log_files, + suppress_http_dependency_logs=suppress_http_dependency_logs, + resource_sample_interval_seconds=resource_sample_interval_seconds, ) @@ -98,6 +198,116 @@ class _SpanContext: "src_py_lib_span_context", default=None ) +_HTTP_METRICS_LOCK: Final[threading.Lock] = threading.Lock() +_http_request_attempt_count = 0 +_http_request_bytes_total = 0 +_http_response_bytes_total = 0 +_http_retry_count = 0 +_http_2xx_count = 0 +_http_3xx_count = 0 +_http_4xx_count = 0 +_http_429_count = 0 +_http_5xx_count = 0 +_http_transport_error_count = 0 + + +@dataclass +class ResourceSampler: + """Emit optional process resource samples and summarize usage at run end.""" + + interval_seconds: float + _stop: threading.Event = field(init=False, default_factory=threading.Event) + _thread: threading.Thread | None = field(init=False, default=None) + _started_at: float = field(init=False, default_factory=time.perf_counter) + _last_sample_at: float = field(init=False, default_factory=time.perf_counter) + _last_cpu_seconds: float = field(init=False, default=0.0) + _start_usage: Any = field(init=False, default=None) + _peak_rss_bytes: int = field(init=False, default=0) + + def __post_init__(self) -> None: + if self.interval_seconds < 0: + raise ValueError("resource_sample_interval_seconds must be >= 0") + self._start_usage = _resource_usage() + if self._start_usage is not None: + self._last_cpu_seconds = _cpu_seconds(self._start_usage) + + def start(self) -> None: + """Start periodic sampling, if enabled by a positive interval.""" + if self.interval_seconds <= 0: + return + context = contextvars.copy_context() + self._thread = threading.Thread( + target=context.run, + args=(self._loop,), + name="ResourceSampler", + daemon=True, + ) + self._thread.start() + self.emit_sample() + + def emit_sample(self) -> None: + """Emit one DEBUG `resource_sample` event.""" + log("debug", "resource_sample", **self._sample_fields()) + + def stop_and_summary(self) -> dict[str, Any]: + """Stop periodic sampling and return run-end resource fields.""" + if self.interval_seconds > 0: + self.emit_sample() + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=2.0) + usage = _resource_usage() + summary: dict[str, Any] = { + "cpu_count_logical": os.cpu_count() or 0, + "num_threads": threading.active_count(), + } + file_descriptors = _num_file_descriptors() + if file_descriptors is not None: + summary["num_fds"] = file_descriptors + rss_bytes = _rss_bytes(usage) + if rss_bytes is not None: + self._peak_rss_bytes = max(self._peak_rss_bytes, rss_bytes) + if self._peak_rss_bytes: + summary["peak_rss_mb"] = _bytes_to_mib(self._peak_rss_bytes) + if usage is not None and self._start_usage is not None: + summary["cpu_user_seconds"] = round( + float(usage.ru_utime) - float(self._start_usage.ru_utime), 3 + ) + summary["cpu_system_seconds"] = round( + float(usage.ru_stime) - float(self._start_usage.ru_stime), 3 + ) + summary["io_read_count"] = int(usage.ru_inblock) - int(self._start_usage.ru_inblock) + summary["io_write_count"] = int(usage.ru_oublock) - int(self._start_usage.ru_oublock) + return summary + + def _loop(self) -> None: + while not self._stop.wait(self.interval_seconds): + self.emit_sample() + + def _sample_fields(self) -> dict[str, Any]: + now = time.perf_counter() + usage = _resource_usage() + fields: dict[str, Any] = { + "num_threads": threading.active_count(), + } + rss_bytes = _rss_bytes(usage) + if rss_bytes is not None: + self._peak_rss_bytes = max(self._peak_rss_bytes, rss_bytes) + fields["rss_mb"] = _bytes_to_mib(rss_bytes) + file_descriptors = _num_file_descriptors() + if file_descriptors is not None: + fields["num_fds"] = file_descriptors + if usage is not None: + cpu_seconds = _cpu_seconds(usage) + elapsed = max(now - self._last_sample_at, 0.001) + fields["process_cpu_percent"] = round( + max(cpu_seconds - self._last_cpu_seconds, 0.0) / elapsed * 100.0, + 1, + ) + self._last_cpu_seconds = cpu_seconds + self._last_sample_at = now + return fields + class _DropStructuredEvents(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: @@ -170,6 +380,7 @@ def configure_logging(config: LoggingSettings | None = None) -> Path | None: Returns the JSON log-file path when file logging is enabled. """ config = config or LoggingSettings() + reset_observability_metrics() terminal_level = _log_level(config.terminal_level) log_file_level = _log_file_level(config.log_file_level) log_file = config.log_file @@ -210,6 +421,79 @@ def configure_logging(config: LoggingSettings | None = None) -> Path | None: return log_file +def reset_observability_metrics() -> None: + """Reset process-wide HTTP counters used by `logging_context()` run summaries.""" + global _http_request_attempt_count, _http_request_bytes_total, _http_response_bytes_total + global _http_retry_count, _http_2xx_count, _http_3xx_count, _http_4xx_count + global _http_429_count, _http_5xx_count, _http_transport_error_count + with _HTTP_METRICS_LOCK: + _http_request_attempt_count = 0 + _http_request_bytes_total = 0 + _http_response_bytes_total = 0 + _http_retry_count = 0 + _http_2xx_count = 0 + _http_3xx_count = 0 + _http_4xx_count = 0 + _http_429_count = 0 + _http_5xx_count = 0 + _http_transport_error_count = 0 + + +def record_http_attempt( + *, + request_bytes: int, + response_bytes: int = 0, + status_code: int | None = None, + transport_error: bool = False, +) -> None: + """Record one HTTP attempt for the current run summary.""" + global _http_request_attempt_count, _http_request_bytes_total, _http_response_bytes_total + global _http_2xx_count, _http_3xx_count, _http_4xx_count, _http_429_count + global _http_5xx_count, _http_transport_error_count + with _HTTP_METRICS_LOCK: + _http_request_attempt_count += 1 + _http_request_bytes_total += request_bytes + _http_response_bytes_total += response_bytes + if transport_error: + _http_transport_error_count += 1 + if status_code is None: + return + if 200 <= status_code < 300: + _http_2xx_count += 1 + elif 300 <= status_code < 400: + _http_3xx_count += 1 + elif 400 <= status_code < 500: + _http_4xx_count += 1 + if status_code == 429: + _http_429_count += 1 + elif status_code >= 500: + _http_5xx_count += 1 + + +def record_http_retry() -> None: + """Record that an HTTP attempt will be retried.""" + global _http_retry_count + with _HTTP_METRICS_LOCK: + _http_retry_count += 1 + + +def observability_summary() -> dict[str, Any]: + """Return process-wide counters accumulated since logging was configured.""" + with _HTTP_METRICS_LOCK: + return { + "http_request_attempt_count": _http_request_attempt_count, + "http_request_bytes_total": _http_request_bytes_total, + "http_response_bytes_total": _http_response_bytes_total, + "http_retry_count": _http_retry_count, + "http_2xx_count": _http_2xx_count, + "http_3xx_count": _http_3xx_count, + "http_4xx_count": _http_4xx_count, + "http_429_count": _http_429_count, + "http_5xx_count": _http_5xx_count, + "http_transport_error_count": _http_transport_error_count, + } + + @contextlib.contextmanager def logging_context( name: str, @@ -217,21 +501,57 @@ def logging_context( *, git_cwd: Path | str | None = None, logging_config: LoggingSettings | None = None, + run_fields: Mapping[str, Any] | None = None, + run_summary: Callable[[], Mapping[str, Any]] | None = None, ) -> Generator[Path | None]: """Configure logging, install command context, and emit startup metadata.""" resolved_logging_config = logging_config or LoggingSettings( log_file_level=_src_log_level_from_config(config) ) log_file = configure_logging(resolved_logging_config) + sampler = _resource_sampler(resolved_logging_config) + started = time.perf_counter() + error: BaseException | None = None with log_context(command=name): - startup_event( - command=name, - config=config, - log_file=log_file, - git_cwd=_git_cwd_path(git_cwd), - logger_name=resolved_logging_config.logger_name, - ) - yield log_file + if sampler is not None: + sampler.start() + start_fields = {"phase": "start", **dict(run_fields or {})} + info("run", logger_name=resolved_logging_config.logger_name, **start_fields) + try: + startup_event( + command=name, + config=config, + log_file=log_file, + git_cwd=_git_cwd_path(git_cwd), + logger_name=resolved_logging_config.logger_name, + ) + yield log_file + except BaseException as exception: + error = exception + raise + finally: + error_type = _run_error_type(error) + summary: dict[str, Any] = {} + if sampler is not None: + summary.update(sampler.stop_and_summary()) + summary.update(observability_summary()) + summary["exit_code"] = _run_exit_code(error) + if run_summary is not None: + summary.update(dict(run_summary())) + end_fields = { + "phase": "end", + "duration_ms": round((time.perf_counter() - started) * 1000.0), + "status": "error" if error_type else "ok", + "error_type": error_type, + **dict(run_fields or {}), + **summary, + } + log( + "error" if error_type else "info", + "run", + logger_name=resolved_logging_config.logger_name, + **end_fields, + ) def default_log_file(logs_dir: Path = DEFAULT_LOGS_DIR, *, run: str = RUN) -> Path: @@ -293,9 +613,22 @@ def log_context(**fields: Any) -> Generator[None]: _CONTEXT.reset(reset_token) +@contextlib.contextmanager +def stage(name: str, **fields: Any) -> Generator[None]: + """Add a workflow stage field for nested logs and structured events.""" + with log_context(stage=name, **fields): + yield + + @contextlib.contextmanager def event( - key: str, *, level: str = "info", logger_name: str = "", **fields: Any + key: str, + *, + level: str = "info", + start_level: str | None = None, + omit_success_status: bool = False, + logger_name: str = "", + **fields: Any, ) -> Generator[dict[str, Any]]: """Emit start/end structured events around a block of work.""" parent = _SPAN_CONTEXT.get() @@ -306,7 +639,7 @@ def event( ) reset_token = _SPAN_CONTEXT.set(span) try: - log(level, key, logger_name=logger_name, phase="start", **fields) + log(start_level or level, key, logger_name=logger_name, phase="start", **fields) started = time.perf_counter() extra: dict[str, Any] = {} error: BaseException | None = None @@ -321,9 +654,13 @@ def event( **extra, "phase": "end", "duration_ms": round((time.perf_counter() - started) * 1000.0), - "status": "error" if error else "ok", - "error_type": type(error).__name__ if error else None, } + if error: + end_fields["status"] = "error" + end_fields["error_type"] = type(error).__name__ + elif not omit_success_status: + end_fields["status"] = "ok" + end_fields["error_type"] = None log( "error" if error else level, key, @@ -334,6 +671,17 @@ def event( _SPAN_CONTEXT.reset(reset_token) +def submit_with_log_context( + executor: Executor, + function: Callable[..., Any], + *args: Any, + **kwargs: Any, +) -> Future[Any]: + """Submit work to an executor with current logging ContextVars propagated.""" + context = contextvars.copy_context() + return executor.submit(context.run, function, *args, **kwargs) + + def sanitized_config_snapshot(config: object) -> dict[str, Any]: """Return a log-safe snapshot of dataclass/object/dict config values.""" if isinstance(config, Mapping): @@ -454,13 +802,13 @@ def _log_level(value: int | str) -> int: return value normalized = value.strip().upper() if not normalized: - return logging.DEBUG + return logging.INFO if normalized.isdecimal(): return int(normalized) levels = logging.getLevelNamesMapping() level = levels.get(normalized) if level is None: - return logging.DEBUG + return logging.INFO return level @@ -541,6 +889,76 @@ def _secret_state(value: object) -> str: return "reference" if isinstance(value, str) and value.startswith("op://") else "provided" +def _resource_sampler(config: LoggingSettings) -> ResourceSampler | None: + interval_seconds = config.resource_sample_interval_seconds + return ResourceSampler(interval_seconds) if interval_seconds is not None else None + + +def _run_error_type(exception: BaseException | None) -> str | None: + if exception is None: + return None + if isinstance(exception, SystemExit) and exception.code in (None, 0): + return None + return type(exception).__name__ + + +def _run_exit_code(exception: BaseException | None) -> int: + if exception is None: + return 0 + if isinstance(exception, SystemExit): + return exception.code if isinstance(exception.code, int) else 1 + return 1 + + +def _resource_usage() -> Any | None: + if sys.platform == "win32": + return None + return resource.getrusage(resource.RUSAGE_SELF) + + +def _cpu_seconds(usage: Any) -> float: + return float(usage.ru_utime) + float(usage.ru_stime) + + +def _rss_bytes(usage: Any | None) -> int | None: + current = _linux_current_rss_bytes() + if current is not None: + return current + if usage is None: + return None + # Linux reports ru_maxrss in KiB; macOS reports bytes. + max_rss = int(usage.ru_maxrss) + return max_rss if sys.platform == "darwin" else max_rss * 1024 + + +def _linux_current_rss_bytes() -> int | None: + statm = Path("/proc/self/statm") + if not statm.exists(): + return None + try: + fields = statm.read_text(encoding="utf-8").split() + if len(fields) < 2: + return None + return int(fields[1]) * os.sysconf("SC_PAGE_SIZE") + except (OSError, ValueError): + return None + + +def _num_file_descriptors() -> int | None: + for directory in (Path("/proc/self/fd"), Path("/dev/fd")): + if not directory.exists(): + continue + try: + return len(list(directory.iterdir())) + except OSError: + continue + return None + + +def _bytes_to_mib(byte_count: int) -> float: + return round(byte_count / MEBIBYTE, 2) + + def _prune_old_log_files(logs_dir: Path, retain_files: int) -> None: if retain_files <= 0 or not logs_dir.exists(): return diff --git a/tests/test_import.py b/tests/test_import.py index de08888..856d024 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -21,8 +21,11 @@ def test_root_public_api_exports_common_entrypoints(self) -> None: self.assertIsNotNone(src_py_lib.LinearClientConfig) self.assertIsNotNone(src_py_lib.LoggingConfig) self.assertIsNotNone(src_py_lib.LoggingSettings) + self.assertIsNotNone(src_py_lib.resolve_log_level_name) self.assertIsNotNone(src_py_lib.SlackClient) self.assertIsNotNone(src_py_lib.SlackPacer) + self.assertIsNotNone(src_py_lib.SourcegraphClient) + self.assertIsNotNone(src_py_lib.SourcegraphClientConfig) self.assertIsNotNone(src_py_lib.config_field) self.assertIsNotNone(src_py_lib.gh_cli_token) self.assertIsNotNone(src_py_lib.gcloud_adc_access_token) @@ -31,12 +34,16 @@ def test_root_public_api_exports_common_entrypoints(self) -> None: self.assertIsNotNone(src_py_lib.json_str) self.assertIsNotNone(src_py_lib.log) self.assertIsNotNone(src_py_lib.logging) + self.assertIsNotNone(src_py_lib.logging_settings_from_config) self.assertIsNotNone(src_py_lib.linear_client_from_config) self.assertIsNotNone(src_py_lib.load_json_cache) + self.assertIsNotNone(src_py_lib.normalize_sourcegraph_endpoint) self.assertIsNotNone(src_py_lib.parse_args) self.assertIsNotNone(src_py_lib.quota_project_from_adc) self.assertIsNotNone(src_py_lib.save_json_cache) self.assertIsNotNone(src_py_lib.slack_client_from_config) + self.assertIsNotNone(src_py_lib.sourcegraph_client_from_config) + self.assertIsNotNone(src_py_lib.stream_connection_nodes) self.assertIsNotNone(src_py_lib.write_tsv) diff --git a/tests/test_logging_http_clients.py b/tests/test_logging_http_clients.py index 8f0a809..0566ca1 100644 --- a/tests/test_logging_http_clients.py +++ b/tests/test_logging_http_clients.py @@ -10,7 +10,7 @@ import tempfile import unittest from collections.abc import Mapping -from contextlib import redirect_stderr +from contextlib import redirect_stderr, redirect_stdout from pathlib import Path from typing import Any from unittest.mock import patch @@ -20,7 +20,12 @@ import src_py_lib as src from src_py_lib.clients.github import GitHubClient, graphql_api_url, pr_ref_from_url from src_py_lib.clients.google_sheets import GoogleSheetsClient -from src_py_lib.clients.graphql import GraphQLClient, GraphQLError, introspect_schema +from src_py_lib.clients.graphql import ( + GraphQLClient, + GraphQLError, + introspect_schema, + stream_connection_nodes, +) from src_py_lib.clients.linear import LinearClient, LinearClientConfig, linear_client_from_config from src_py_lib.clients.one_password import ( OnePasswordClient, @@ -31,6 +36,7 @@ from src_py_lib.clients.sourcegraph import ( SourcegraphClient, SourcegraphClientConfig, + normalize_sourcegraph_endpoint, sourcegraph_client_from_config, ) from src_py_lib.utils.config import ( @@ -61,6 +67,8 @@ info, log, log_context, + logging_settings_from_config, + resolve_log_level_name, startup_event, warning, ) @@ -116,39 +124,39 @@ class ExampleConfig(Config): """Config model used by Config tests.""" token: str = config_field( - "", + default="", env_var="EXAMPLE_TOKEN", cli_flag="--token", metavar="TOKEN", - help="Example token.", + help="Example token", secret=True, ) page_size: int = config_field( - 25, + default=25, env_var="EXAMPLE_PAGE_SIZE", cli_flag="--page-size", metavar="N", - help="Example page size.", + help="Example page size", ) include_archived: bool = config_field( - False, + default=False, env_var="EXAMPLE_INCLUDE_ARCHIVED", cli_flag="--include-archived", - help="Include archived examples.", + help="Include archived examples", ) output_dir: Path = config_field( - Path("out"), + default=Path("out"), env_var="EXAMPLE_OUTPUT_DIR", cli_flag="--output-dir", metavar="PATH", - help="Example output directory.", + help="Example output directory", ) labels: tuple[str, ...] = config_field( - (), + default=(), env_var="EXAMPLE_LABELS", cli_flag="--labels", metavar="CSV", - help="Example labels.", + help="Example labels", ) @@ -156,39 +164,108 @@ class RequiredConfig(Config): """Config model with a required secret field.""" token: str = config_field( - "", + default="", env_var="REQUIRED_TOKEN", cli_flag="--token", metavar="TOKEN", - help="Required token.", + help="Required token", secret=True, required=True, ) name: str = config_field( - "", + default="", env_var="REQUIRED_NAME", cli_flag="--name", metavar="NAME", - help="Non-secret required config name.", + help="Non-secret required config name", + ) + + +class MultilineHelpConfig(Config): + """Config model with multiline CLI help text.""" + + notes: str = config_field( + default="", + env_var="MULTILINE_HELP_NOTES", + cli_flag="--notes", + metavar="TEXT", + help="First line.\nSecond line.\n Indented detail.", ) class SnapshotOrderConfig(Config): """Config model whose field names and env-var names sort differently.""" - alpha: str = config_field("a", env_var="ZZZ_ALPHA") - zulu: str = config_field("z", env_var="AAA_ZULU") + alpha: str = config_field(default="a", env_var="ZZZ_ALPHA") + zulu: str = config_field(default="z", env_var="AAA_ZULU") + + +class BoundedConfig(Config): + """Config model with numeric bounds.""" + + page_size: int = config_field( + default=25, + env_var="BOUNDED_PAGE_SIZE", + cli_flag="--page-size", + metavar="N", + ge=1, + ) + sample_interval: float = config_field( + default=10.0, + env_var="BOUNDED_SAMPLE_INTERVAL", + cli_flag="--sample-interval", + metavar="SECS", + ge=0, + ) + + +class PatternConfig(Config): + """Config model with a string pattern constraint.""" + + date: str | None = config_field( + default=None, + env_var="PATTERN_DATE", + cli_flag="--date", + metavar="YYYY-MM-DD", + pattern=r"^\d{4}-\d{2}-\d{2}$", + ) + + +class CommandStyleConfig(Config): + """Config model with command-style flags.""" + + get: bool = config_field( + default=False, + env_var="COMMAND_STYLE_GET", + cli_flag="--get", + cli_action="store_true", + ) + verbose: bool = config_field( + default=False, + env_var="COMMAND_STYLE_VERBOSE", + cli_flag="--verbose", + cli_aliases=("-v",), + cli_action="store_true", + ) + schema_path: Path | None = config_field( + default=None, + env_var="COMMAND_STYLE_SCHEMA_PATH", + cli_flag="--get-schema", + cli_nargs="?", + cli_const="schema.gql", + metavar="FILE", + ) class LinearExampleConfig(LinearClientConfig): """Config model composed from Linear client fields and app fields.""" page_size: int = config_field( - 25, + default=25, env_var="LINEAR_EXAMPLE_PAGE_SIZE", cli_flag="--page-size", metavar="N", - help="Example page size.", + help="Example page size", ) @@ -196,11 +273,11 @@ class SourcegraphExampleConfig(SourcegraphClientConfig): """Config model composed from Sourcegraph client fields and app fields.""" repo_query: str = config_field( - "", + default="", env_var="SOURCEGRAPH_EXAMPLE_REPO_QUERY", cli_flag="--repo-query", metavar="QUERY", - help="Example Sourcegraph repository query.", + help="Example Sourcegraph repository query", ) @@ -374,10 +451,84 @@ def test_argparse_helpers_add_flags_and_collect_overrides(self) -> None: }, ) + def test_config_arguments_support_aliases_actions_and_optional_values(self) -> None: + parser = argparse.ArgumentParser() + add_config_arguments(parser, CommandStyleConfig) + + default_schema_args = parser.parse_args(["--get", "-v", "--get-schema"]) + named_schema_args = parser.parse_args(["--get-schema", "custom.gql"]) + + default_schema_config = load_config_from_args( + CommandStyleConfig, + default_schema_args, + env={}, + resolve_op_refs=False, + ) + named_schema_config = load_config_from_args( + CommandStyleConfig, + named_schema_args, + env={}, + resolve_op_refs=False, + ) + + self.assertTrue(default_schema_config.get) + self.assertTrue(default_schema_config.verbose) + self.assertEqual(default_schema_config.schema_path, Path.cwd() / "schema.gql") + self.assertEqual(named_schema_config.schema_path, Path.cwd() / "custom.gql") + + def test_config_field_supports_numeric_bounds(self) -> None: + config = load_config( + BoundedConfig, + env_file=None, + env={"BOUNDED_PAGE_SIZE": "1", "BOUNDED_SAMPLE_INTERVAL": "0"}, + resolve_op_refs=False, + ) + + self.assertEqual(config.page_size, 1) + self.assertEqual(config.sample_interval, 0) + with self.assertRaisesRegex(ConfigError, "greater than or equal to 1"): + load_config( + BoundedConfig, + env_file=None, + env={"BOUNDED_PAGE_SIZE": "0"}, + resolve_op_refs=False, + ) + with self.assertRaisesRegex(ConfigError, "greater than or equal to 0"): + load_config( + BoundedConfig, + env_file=None, + env={"BOUNDED_SAMPLE_INTERVAL": "-0.1"}, + resolve_op_refs=False, + ) + + def test_config_field_supports_string_pattern(self) -> None: + config = load_config( + PatternConfig, + env_file=None, + env={"PATTERN_DATE": "2026-01-31"}, + resolve_op_refs=False, + ) + + self.assertEqual(config.date, "2026-01-31") + with self.assertRaisesRegex(ConfigError, "String should match pattern"): + load_config( + PatternConfig, + env_file=None, + env={"PATTERN_DATE": "2026-1-31"}, + resolve_op_refs=False, + ) + with self.assertRaisesRegex(ConfigError, "String should match pattern"): + load_config( + PatternConfig, + env_file=None, + env={"PATTERN_DATE": "2026-01-31T00:00:00Z"}, + resolve_op_refs=False, + ) + def test_logging_config_mixin_adds_log_level_from_cli_and_env(self) -> None: parser = argparse.ArgumentParser() add_config_arguments(parser, LoggingExampleConfig) - args = parser.parse_args(["--src-log-level", "INFO"]) + args = parser.parse_args(["--src-log-level", "INFO", "-v"]) cli_config = load_config_from_args( LoggingExampleConfig, @@ -393,8 +544,63 @@ def test_logging_config_mixin_adds_log_level_from_cli_and_env(self) -> None: ) self.assertEqual(cli_config.src_log_level, "INFO") + self.assertTrue(cli_config.verbose) self.assertEqual(env_config.src_log_level, "ERROR") + def test_logging_config_rejects_multiple_log_level_alias(self) -> None: + with self.assertRaisesRegex(ConfigError, "choose only one of --verbose"): + load_config( + LoggingExampleConfig, + env_file=None, + env={"SRC_LOG_VERBOSE": "true", "SRC_LOG_QUIET": "true"}, + resolve_op_refs=False, + ) + + def test_resolve_log_level_name_maps_cli_alias(self) -> None: + self.assertEqual(resolve_log_level_name(verbose=True), "DEBUG") + self.assertEqual(resolve_log_level_name(quiet=True), "WARNING") + self.assertEqual(resolve_log_level_name(silent=True), "ERROR") + self.assertEqual(resolve_log_level_name(log_level="trace"), "trace") + self.assertIsNone(resolve_log_level_name(object())) + + config = LoggingExampleConfig(src_log_level="INFO") + self.assertEqual(resolve_log_level_name(config), "INFO") + verbose_config = LoggingExampleConfig(src_log_level="INFO", verbose=True) + self.assertEqual(resolve_log_level_name(verbose_config), "DEBUG") + quiet_config = config_parse_args( + LoggingExampleConfig, + argv=["-q"], + env={}, + resolve_op_refs=False, + ) + self.assertEqual(resolve_log_level_name(quiet_config), "WARNING") + env_config = load_config( + LoggingExampleConfig, + env_file=None, + env={"SRC_LOG_SILENT": "true"}, + resolve_op_refs=False, + ) + self.assertTrue(env_config.silent) + self.assertEqual(resolve_log_level_name(env_config), "ERROR") + + def test_logging_settings_from_config_maps_common_cli_levels(self) -> None: + default_settings = logging_settings_from_config( + resource_sample_interval_seconds=2.5, + ) + self.assertEqual(default_settings.terminal_level, "INFO") + self.assertEqual(default_settings.log_file_level, "debug") + self.assertEqual(default_settings.resource_sample_interval_seconds, 2.5) + + quiet_config = LoggingExampleConfig(src_log_level="INFO", quiet=True) + quiet_settings = logging_settings_from_config(quiet_config) + self.assertEqual(quiet_settings.terminal_level, "WARNING") + self.assertEqual(quiet_settings.log_file_level, "WARNING") + + log_level_config = LoggingExampleConfig(src_log_level="ERROR") + log_level_settings = logging_settings_from_config(log_level_config) + self.assertEqual(log_level_settings.terminal_level, "ERROR") + self.assertEqual(log_level_settings.log_file_level, "ERROR") + def test_config_parse_args_loads_config_and_reports_config_errors(self) -> None: config = config_parse_args( ExampleConfig, @@ -414,6 +620,61 @@ def test_config_parse_args_loads_config_and_reports_config_errors(self) -> None: self.assertEqual(raised.exception.code, 2) self.assertIn("REQUIRED_TOKEN", stderr.getvalue()) + def test_config_parse_args_preserves_description_newlines_in_help(self) -> None: + description = "Example CLI.\n\nSteps:\n 1. Collect data.\n 2. Export data." + stdout = io.StringIO() + + with redirect_stdout(stdout), self.assertRaises(SystemExit) as raised: + config_parse_args( + ExampleConfig, + argv=["--help"], + description=description, + env={}, + resolve_op_refs=False, + ) + + self.assertEqual(raised.exception.code, 0) + self.assertIn(description, stdout.getvalue()) + + def test_config_parse_args_keeps_long_options_on_help_line(self) -> None: + stdout = io.StringIO() + + with redirect_stdout(stdout), self.assertRaises(SystemExit) as raised: + config_parse_args( + SourcegraphExampleConfig, + argv=["--help"], + env={}, + resolve_op_refs=False, + ) + + self.assertEqual(raised.exception.code, 0) + help_text = stdout.getvalue() + self.assertNotIn("--src-access-token TOKEN\n", help_text) + self.assertRegex(help_text, r"--src-access-token TOKEN +Sourcegraph access token") + + def test_config_parse_args_preserves_argument_help_newlines(self) -> None: + stdout = io.StringIO() + + with redirect_stdout(stdout), self.assertRaises(SystemExit) as raised: + config_parse_args( + MultilineHelpConfig, + argv=["--help"], + env={}, + resolve_op_refs=False, + ) + + self.assertEqual(raised.exception.code, 0) + help_text = stdout.getvalue() + self.assertIn("First line.\n", help_text) + self.assertRegex(help_text, r"\n +Second line\.\n") + self.assertRegex(help_text, r"\n + Indented detail\.") + + def test_config_field_requires_named_default(self) -> None: + config_field_any: Any = config_field + + with self.assertRaises(TypeError): + config_field_any("", env_var="POSITIONAL_DEFAULT") + def test_required_values_and_reference_resolution(self) -> None: with self.assertRaisesRegex(ConfigError, "REQUIRED_TOKEN"): load_config(RequiredConfig, env_file=None, env={}) @@ -559,7 +820,7 @@ def test_log_and_level_helpers_use_string_levels(self) -> None: ) ) try: - log("bogus", "fallback_debug", logger_name=logger_name) + log("bogus", "fallback_info", logger_name=logger_name) warning("warning_event", logger_name=logger_name) error("error_event", logger_name=logger_name) critical("critical_event", logger_name=logger_name) @@ -571,7 +832,7 @@ def test_log_and_level_helpers_use_string_levels(self) -> None: rows = [json.loads(line) for line in log_file.read_text().splitlines()] levels = {row["event"]: row["level"] for row in rows} - self.assertEqual(levels["fallback_debug"], "DEBUG") + self.assertEqual(levels["fallback_info"], "INFO") self.assertEqual(levels["warning_event"], "WARNING") self.assertEqual(levels["error_event"], "ERROR") self.assertEqual(levels["critical_event"], "CRITICAL") @@ -758,6 +1019,129 @@ def test_event_context_adds_trace_and_span_fields(self) -> None: self.assertEqual(inner_log["span"], inner_start["span"]) self.assertEqual(inner_log["parent_span"], outer_start["span"]) + def test_event_can_lower_start_level_and_omit_success_status(self) -> None: + with tempfile.TemporaryDirectory() as directory: + log_file = Path(directory) / "events.json" + logger_name = "src_py_lib_test_quiet_event" + configure_logging( + LoggingSettings( + logger_name=logger_name, + terminal_level="critical", + log_file_level="info", + log_file=log_file, + run="test-run", + ) + ) + try: + with event( + "quiet_start", + logger_name=logger_name, + level="info", + start_level="debug", + omit_success_status=True, + ): + pass + finally: + logger = logging.getLogger(logger_name) + for handler in list(logger.handlers): + logger.removeHandler(handler) + handler.close() + + rows = [json.loads(line) for line in log_file.read_text().splitlines()] + quiet_rows = [row for row in rows if row["event"] == "quiet_start"] + self.assertEqual(len(quiet_rows), 1) + self.assertEqual(quiet_rows[0]["phase"], "end") + self.assertNotIn("status", quiet_rows[0]) + self.assertNotIn("error_type", quiet_rows[0]) + + def test_logging_context_emits_run_summary_resource_and_http_metrics(self) -> None: + attempts = 0 + + def handler(_request: httpx.Request) -> httpx.Response: + nonlocal attempts + attempts += 1 + if attempts == 1: + return httpx.Response(429, json={"retry": True}, headers={"Retry-After": "0"}) + return httpx.Response(200, json={"ok": True}) + + with tempfile.TemporaryDirectory() as directory: + log_file = Path(directory) / "events.json" + try: + with src.logging( + command="unit-test", + logging_config=LoggingSettings( + terminal_level="critical", + log_file_level="debug", + log_file=log_file, + run="test-run", + resource_sample_interval_seconds=0, + ), + run_fields={"endpoint": "https://example.com"}, + run_summary=lambda: {"custom_count": 7}, + ): + client = HTTPClient( + max_attempts=2, + retry_base_delay_seconds=0, + retry_max_delay_seconds=0, + transport=httpx.MockTransport(handler), + ) + self.assertEqual( + client.json( + "POST", + "https://example.com/api", + json_body={"hello": "world"}, + ), + {"ok": True}, + ) + finally: + logger = logging.getLogger("") + for handler_ in list(logger.handlers): + logger.removeHandler(handler_) + handler_.close() + + rows = [json.loads(line) for line in log_file.read_text().splitlines()] + run_end = next(row for row in rows if row["event"] == "run" and row["phase"] == "end") + self.assertEqual(run_end["status"], "ok") + self.assertEqual(run_end["exit_code"], 0) + self.assertEqual(run_end["endpoint"], "https://example.com") + self.assertEqual(run_end["custom_count"], 7) + self.assertEqual(run_end["http_request_attempt_count"], 2) + self.assertEqual(run_end["http_retry_count"], 1) + self.assertEqual(run_end["http_2xx_count"], 1) + self.assertEqual(run_end["http_429_count"], 1) + self.assertGreater(run_end["http_request_bytes_total"], 0) + self.assertGreater(run_end["http_response_bytes_total"], 0) + self.assertIn("cpu_count_logical", run_end) + + def test_logging_context_records_system_exit_code(self) -> None: + with tempfile.TemporaryDirectory() as directory: + log_file = Path(directory) / "events.json" + try: + with ( + self.assertRaises(SystemExit), + src.logging( + command="unit-test", + logging_config=LoggingSettings( + terminal_level="critical", + log_file_level="debug", + log_file=log_file, + run="test-run", + ), + ), + ): + raise SystemExit(3) + finally: + logger = logging.getLogger("") + for handler_ in list(logger.handlers): + logger.removeHandler(handler_) + handler_.close() + + rows = [json.loads(line) for line in log_file.read_text().splitlines()] + run_end = next(row for row in rows if row["event"] == "run" and row["phase"] == "end") + self.assertEqual(run_end["status"], "error") + self.assertEqual(run_end["error_type"], "SystemExit") + self.assertEqual(run_end["exit_code"], 3) + def test_httpx_request_logs_are_debug_events(self) -> None: with tempfile.TemporaryDirectory() as directory: log_file = Path(directory) / "events.json" @@ -951,16 +1335,74 @@ def handler(_request: httpx.Request) -> httpx.Response: class ClientTest(unittest.TestCase): + def test_normalize_sourcegraph_endpoint(self) -> None: + self.assertEqual( + normalize_sourcegraph_endpoint(" https://sourcegraph.example.com/ "), + "https://sourcegraph.example.com", + ) + self.assertEqual( + normalize_sourcegraph_endpoint("http://localhost:3080/"), + "http://localhost:3080", + ) + with self.assertRaisesRegex(ValueError, "https:// URL"): + normalize_sourcegraph_endpoint("http://localhost:3080", require_https=True) + with self.assertRaisesRegex(ValueError, "http:// or https:// URL"): + normalize_sourcegraph_endpoint("sourcegraph.example.com") + def test_sourcegraph_client_builds_graphql_request(self) -> None: http = RecordingHTTP([{"data": {"currentUser": {"username": "alice"}}}]) - client = SourcegraphClient("https://sourcegraph.example.com/", "token", http=http) + client = SourcegraphClient(" https://sourcegraph.example.com/ ", "token", http=http) data = client.graphql("query Viewer { currentUser { username } }") + self.assertEqual(client.endpoint, "https://sourcegraph.example.com") self.assertEqual(data, {"currentUser": {"username": "alice"}}) self.assertEqual(http.calls[0]["method"], "POST") self.assertEqual(http.calls[0]["url"], "https://sourcegraph.example.com/.api/graphql") self.assertEqual(http.calls[0]["headers"], {"Authorization": "token token"}) + def test_sourcegraph_client_streams_connection_nodes(self) -> None: + http = RecordingHTTP( + [ + { + "data": { + "users": { + "nodes": [{"username": "alice"}], + "pageInfo": {"hasNextPage": True, "endCursor": "cursor-1"}, + } + } + }, + { + "data": { + "users": { + "nodes": [{"username": "bob"}], + "pageInfo": {"hasNextPage": False, "endCursor": None}, + } + } + }, + ] + ) + client = SourcegraphClient("https://sourcegraph.example.com", "token", http=http) + nodes = list( + client.stream_connection_nodes( + """ + query Users($first: Int, $after: String) { + users(first: $first, after: $after) { + nodes { username } + pageInfo { hasNextPage endCursor } + } + } + """, + connection_path=("users",), + page_size=1, + ) + ) + + self.assertEqual(nodes, [{"username": "alice"}, {"username": "bob"}]) + first_body = json_dict(http.calls[0]["json_body"]) + second_body = json_dict(http.calls[1]["json_body"]) + self.assertEqual(first_body["variables"], {"first": 1, "after": None}) + self.assertEqual(second_body["variables"], {"first": 1, "after": "cursor-1"}) + def test_sourcegraph_client_validate_queries_current_user(self) -> None: http = RecordingHTTP([{"data": {"currentUser": {"username": "alice"}}}]) client = SourcegraphClient("https://sourcegraph.example.com/", "token", http=http) @@ -1025,6 +1467,119 @@ def test_graphql_client_paginates_cursor_results(self) -> None: {"userId": "u1", "first": 2, "after": "cursor-1"}, ) + def test_graphql_client_streams_connection_nodes(self) -> None: + http = RecordingHTTP( + [ + { + "data": { + "viewer": { + "items": { + "nodes": [{"id": "1"}], + "pageInfo": { + "hasNextPage": True, + "endCursor": "cursor-1", + }, + } + } + } + }, + { + "data": { + "viewer": { + "items": { + "nodes": [{"id": "2"}], + "pageInfo": { + "hasNextPage": False, + "endCursor": None, + }, + } + } + } + }, + ] + ) + client = GraphQLClient("https://example.com/graphql", {}, "Example", http=http) + query = """ +query Items($first: Int!, $after: String, $userId: ID!) { + viewer { items { nodes { id } pageInfo { hasNextPage endCursor } } } +} +""" + + nodes = list( + client.stream_connection_nodes( + query, + variables={"userId": "u1"}, + connection_path=("viewer", "items"), + page_size=2, + ) + ) + + self.assertEqual(nodes, [{"id": "1"}, {"id": "2"}]) + self.assertEqual( + http.calls[0]["json_body"]["variables"], + {"userId": "u1", "first": 2, "after": None}, + ) + self.assertEqual( + http.calls[1]["json_body"]["variables"], + {"userId": "u1", "first": 2, "after": "cursor-1"}, + ) + + def test_stream_connection_nodes_accepts_execute_callback(self) -> None: + calls: list[dict[str, Any]] = [] + responses: list[JSONDict] = [ + { + "viewer": { + "items": { + "nodes": [{"id": "1"}], + "pageInfo": { + "hasNextPage": True, + "endCursor": "cursor-1", + }, + } + } + }, + { + "viewer": { + "items": { + "nodes": [{"id": "2"}], + "pageInfo": { + "hasNextPage": False, + "endCursor": None, + }, + } + } + }, + ] + + def execute(query: str, variables: Mapping[str, Any] | None) -> JSONDict: + calls.append({"query": query, "variables": dict(variables or {})}) + return responses.pop(0) + + query = """ +query Items($first: Int!, $after: String, $userId: ID!) { + viewer { items { nodes { id } pageInfo { hasNextPage endCursor } } } +} +""" + + nodes = list( + stream_connection_nodes( + execute, + query, + variables={"userId": "u1"}, + connection_path=("viewer", "items"), + page_size=2, + ) + ) + + self.assertEqual(nodes, [{"id": "1"}, {"id": "2"}]) + self.assertEqual( + [call["variables"] for call in calls], + [ + {"userId": "u1", "first": 2, "after": None}, + {"userId": "u1", "first": 2, "after": "cursor-1"}, + ], + ) + def test_graphql_client_emits_query_debug_events(self) -> None: http = RecordingHTTP( [ @@ -1127,6 +1682,78 @@ def test_graphql_client_requires_end_cursor_for_next_page(self) -> None: page_size=100, ) + def test_graphql_client_rejects_stalled_cursor(self) -> None: + http = RecordingHTTP( + [ + { + "data": { + "items": { + "nodes": [], + "pageInfo": {"hasNextPage": True, "endCursor": "cursor-1"}, + } + } + }, + { + "data": { + "items": { + "nodes": [], + "pageInfo": {"hasNextPage": True, "endCursor": "cursor-1"}, + } + } + }, + ] + ) + client = GraphQLClient("https://example.com/graphql", {}, "Example", http=http) + query = """ +query Items($first: Int!, $after: String) { + items { nodes { id } pageInfo { hasNextPage endCursor } } +} +""" + + with self.assertRaisesRegex(GraphQLError, "stalled"): + client.execute( + query, + page_size=100, + ) + + def test_graphql_client_preserves_http_status_on_transport_errors(self) -> None: + class FailingHTTP(RecordingHTTP): + def json( + self, + method: str, + url: str, + *, + headers: Mapping[str, str] | None = None, + query: Mapping[str, str | int | float | bool | None] | None = None, + json_body: object | None = None, + ) -> dict[str, Any]: + raise HTTPClientError("unavailable", status_code=503) + + client = GraphQLClient("https://example.com/graphql", {}, "Example", http=FailingHTTP()) + + with self.assertRaises(GraphQLError) as raised: + client.execute("query Viewer { viewer { login } }", follow_pages=False) + + self.assertEqual(raised.exception.status_code, 503) + self.assertFalse(raised.exception.is_application_error) + + def test_graphql_client_marks_application_errors(self) -> None: + http = RecordingHTTP( + [ + { + "data": {}, + "errors": [{"message": "field does not exist"}], + } + ] + ) + client = GraphQLClient("https://example.com/graphql", {}, "Example", http=http) + + with self.assertRaises(GraphQLError) as raised: + client.execute("query Broken { missingField }", follow_pages=False) + + self.assertIsNone(raised.exception.status_code) + self.assertTrue(raised.exception.is_application_error) + def test_github_pr_ref_from_url(self) -> None: self.assertEqual( pr_ref_from_url("https://github.com/sourcegraph/amp/pull/1234"),