From feda8bf785779f90fed239b8cf6c8dd4cf13d8f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 26 May 2025 11:11:38 +0200 Subject: [PATCH 01/17] WIP --- src/docstub/_docstrings.py | 40 +++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index c4aa13e..413705f 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -2,6 +2,7 @@ import logging import traceback +from contextlib import contextmanager from dataclasses import dataclass, field from functools import cached_property from pathlib import Path @@ -31,7 +32,7 @@ _lark = lark.Lark(_grammar, propagate_positions=True, strict=True) -def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: +def _find_one_token(tree, *, name): """Find token with a specific type name in tree. Parameters @@ -285,25 +286,13 @@ def doctype_to_annotation(self, doctype): A set containing tuples. Each tuple contains a qualname, its start and its end index relative to the given `doctype`. 
""" - try: - self._collected_imports = set() - self._unknown_qualnames = [] + with self._prepare_transformation(): tree = _lark.parse(doctype) value = super().transform(tree=tree) annotation = Annotation( value=value, imports=frozenset(self._collected_imports) ) return annotation, self._unknown_qualnames - except ( - lark.exceptions.LexError, - lark.exceptions.ParseError, - QualnameIsKeyword, - ): - self.stats["syntax_errors"] += 1 - raise - finally: - self._collected_imports = None - self._unknown_qualnames = None def qualname(self, tree): """ @@ -509,6 +498,29 @@ def __default__(self, data, children, meta): out = children return out + @contextmanager + def _prepare_transformation(self): + """Reset transformation state before entering context and restore it on exit.""" + collected_imports = self._collected_imports + unknown_qualnames = self._unknown_qualnames + + try: + self._collected_imports = set() + self._unknown_qualnames = [] + yield + + except ( + lark.exceptions.LexError, + lark.exceptions.ParseError, + QualnameIsKeyword, + ): + self.stats["syntax_errors"] += 1 + raise + + finally: + self._collected_imports = collected_imports + self._unknown_qualnames = unknown_qualnames + def _match_import(self, qualname, *, meta): """Match `qualname` to known imports or alias to "Incomplete". 
From 444ef8abce29886445a495c6ad3c9341a67d9949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Tue, 27 May 2025 18:29:43 +0200 Subject: [PATCH 02/17] WIP add _doctype.py --- src/docstub/_doctype.py | 280 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 src/docstub/_doctype.py diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py new file mode 100644 index 0000000..547c0be --- /dev/null +++ b/src/docstub/_doctype.py @@ -0,0 +1,280 @@ +"""Parsing of doctypes""" + +import logging +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + +import lark +import lark.visitors + +logger = logging.getLogger(__name__) + + +grammar_path = Path(__file__).parent / "doctype.lark" + +with grammar_path.open() as file: + _grammar = file.read() + +_lark = lark.Lark(_grammar, propagate_positions=True, strict=True) + + +def flatten_recursive(iterable): + for item in iterable: + if not isinstance(item, str) and isinstance(item, Iterable): + yield from flatten_recursive(item) + else: + yield item + + +def insert_between(iterable, *, sep): + out = [] + for item in iterable: + out.append(item) + out.append(sep) + return out[:-1] + + +class Token(str): + """A token representing an atomic part of a doctype.""" + + __slots__ = ("value", "kind") + + def __new__(cls, value, *, kind): + self = super().__new__(cls, value) + self.kind = kind + return self + + def __repr__(self): + return f"{type(self).__name__}('{self}', kind={self.kind!r})" + + @classmethod + def find_iter(cls, iterable, *, kind): + for item in flatten_recursive(iterable): + if isinstance(item, cls) and item.kind == kind: + yield item + + @classmethod + def find_one(cls, iterable, *, kind): + matching = list(cls.find_iter(iterable, kind=kind)) + if len(matching) != 1: + msg = ( + f"expected exactly one {cls.__name__} with {kind=}, got {len(matching)}" + ) + raise ValueError(msg) + return matching[0] + + 
+@lark.visitors.v_args(tree=True) +class DoctypeTransformer(lark.visitors.Transformer): + def qualname(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : lark.Token + """ + children = tree.children + _qualname = ".".join(children) + _qualname = Token(_qualname, kind="qualname") + return _qualname + + def rst_role(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : lark.Token + """ + qualname = Token.find_one(tree.children, kind="qualname") + return qualname + + def union(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : list[str] + """ + sep = Token(" | ", kind="union_sep") + out = insert_between(tree.children, sep=sep) + return out + + def subscription(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : str + """ + return self._format_subscription(tree.children, name="subscription") + + def natlang_literal(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : str + """ + items = [Token("Literal", kind="qualname"), *tree.children] + out = self._format_subscription(items, "nl_literal") + + if len(tree.children) == 1: + logger.warning( + "natural language literal with one item `%s`, " + "consider using `%s` to improve readability", + tree.children[0], + "".join(out), + ) + return out + + def natlang_container(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : str + """ + return self._format_subscription(tree.children, name="nl_container") + + def natlang_array(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : str + """ + array_name = Token.find_one(tree.children, kind="array_name") + items = tree.children.copy() + items.remove(array_name) + items.insert(0, Token(array_name, kind="qualname")) + return self._format_subscription(items, name="nl_array") + + def array_name(self, 
tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : lark.Token + """ + # Treat `array_name` as `qualname`, but mark it as an array name, + # so we know which one to treat as the container in `array_expression` + # This currently relies on a hack that only allows specific names + # in `array_expression` (see `ARRAY_NAME` terminal in gramar) + qualname = self.qualname(tree) + qualname = Token(qualname, kind="array_name") + return qualname + + def shape(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : lark.visitors._DiscardType + """ + logger.debug("dropping shape information") + return lark.Discard + + def optional(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : lark.visitors._DiscardType + """ + logger.debug("dropping optional / default info") + return lark.Discard + + def extra_info(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : lark.visitors._DiscardType + """ + logger.debug("dropping extra info") + return lark.Discard + + def _format_subscription(self, sequence, name): + sep = Token(", ", kind=f"{name}_sep") + container, *content = sequence + content = insert_between(content, sep=sep) + assert content + out = [ + container, + Token("[", kind=f"{name}_start"), + *content, + Token("]", kind=f"{name}_stop"), + ] + return out + + def __default_token__(self, token): + return Token(token.value, kind=token.type.lower()) + + +@dataclass(frozen=True, slots=True) +class ParsedDoctype: + tokens: tuple[Token, ...] + raw_doctype: str + + @classmethod + def parse(cls, doctype): + """Turn a type description in a docstring into a type annotation. + + Parameters + ---------- + doctype : str + The doctype to parse. 
+ + Returns + ------- + annotation_list : list of Token + + Examples + -------- + >>> ParsedDoctype.parse("tuple of int or ndarray of dtype (float or int)") + + """ + tree = _lark.parse(doctype) + result = DoctypeTransformer().transform(tree=tree) + result = tuple(flatten_recursive(result)) + return cls(result, raw_doctype=doctype) + + def __str__(self): + return "".join(self.tokens) + + def __repr__(self): + return f"<{type(self).__name__}: '{self}'>" From a2ffae613c0da762e11c6b4eb8a476d7f848da6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 28 May 2025 13:09:04 +0200 Subject: [PATCH 03/17] WIP use Flags to mark doctype tokens --- src/docstub/_doctype.py | 133 ++++++++++++++++++++++++++++++--------- src/docstub/doctype.lark | 2 +- 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 547c0be..777d609 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -1,5 +1,7 @@ """Parsing of doctypes""" +import enum +import itertools import logging from collections.abc import Iterable from dataclasses import dataclass @@ -35,14 +37,42 @@ def insert_between(iterable, *, sep): return out[:-1] +class TokenFlag(enum.Flag): + # docstub: off + NAME = enum.auto() + NATLANG = enum.auto() + SUBSCRIPT = enum.auto() + LITERAL = enum.auto() + GENERATOR = enum.auto() + ARRAY = enum.auto() + UNION = enum.auto() + START = enum.auto() + STOP = enum.auto() + SEP = enum.auto() + # docstub: on + + @classmethod + def _missing_(cls, value): + forbidden = { + *itertools.combinations([cls.START, cls.STOP, cls.SEP, cls.NAME], 2) + } + for pair in forbidden: + if value is (pair[0].value | pair[1].value): + raise ValueError(f"{pair[0].name}|{pair[1].name} not allowed") + return super()._missing_(value) + + class Token(str): """A token representing an atomic part of a doctype.""" - __slots__ = ("value", "kind") + flag = TokenFlag + + __slots__ = ("value", "kind", "pos") - def __new__(cls, 
value, *, kind): + def __new__(cls, value, *, kind, pos=None): self = super().__new__(cls, value) - self.kind = kind + self.kind = TokenFlag(kind) + self.pos = pos return self def __repr__(self): @@ -50,8 +80,9 @@ def __repr__(self): @classmethod def find_iter(cls, iterable, *, kind): + kind = TokenFlag(kind) for item in flatten_recursive(iterable): - if isinstance(item, cls) and item.kind == kind: + if isinstance(item, cls) and all(k & item.kind for k in kind): yield item @classmethod @@ -59,7 +90,8 @@ def find_one(cls, iterable, *, kind): matching = list(cls.find_iter(iterable, kind=kind)) if len(matching) != 1: msg = ( - f"expected exactly one {cls.__name__} with {kind=}, got {len(matching)}" + f"expected exactly one {cls.__name__} with {kind=}, " + f"got {len(matching)}: {matching}" ) raise ValueError(msg) return matching[0] @@ -79,7 +111,11 @@ def qualname(self, tree): """ children = tree.children _qualname = ".".join(children) - _qualname = Token(_qualname, kind="qualname") + _qualname = Token( + _qualname, + kind=Token.flag.NAME, + pos=(tree.meta.start_pos, tree.meta.end_pos), + ) return _qualname def rst_role(self, tree): @@ -92,7 +128,7 @@ def rst_role(self, tree): ------- out : lark.Token """ - qualname = Token.find_one(tree.children, kind="qualname") + qualname = Token.find_one(tree.children, kind=Token.flag.NAME) return qualname def union(self, tree): @@ -105,7 +141,7 @@ def union(self, tree): ------- out : list[str] """ - sep = Token(" | ", kind="union_sep") + sep = Token(" | ", kind=Token.flag.UNION | Token.flag.SEP) out = insert_between(tree.children, sep=sep) return out @@ -119,7 +155,7 @@ def subscription(self, tree): ------- out : str """ - return self._format_subscription(tree.children, name="subscription") + return self._format_subscription(tree.children) def natlang_literal(self, tree): """ @@ -131,8 +167,13 @@ def natlang_literal(self, tree): ------- out : str """ - items = [Token("Literal", kind="qualname"), *tree.children] - out = 
self._format_subscription(items, "nl_literal") + items = [ + Token("Literal", kind=Token.flag.LITERAL | Token.flag.NAME), + *tree.children, + ] + out = self._format_subscription( + items, kind=Token.flag.LITERAL | Token.flag.NATLANG + ) if len(tree.children) == 1: logger.warning( @@ -143,6 +184,14 @@ def natlang_literal(self, tree): ) return out + def literal_item(self, tree): + item, *other = tree.children + assert not other + kind = Token.flag.LITERAL + if isinstance(item, Token): + kind |= item.kind + return Token(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos)) + def natlang_container(self, tree): """ Parameters @@ -153,7 +202,7 @@ def natlang_container(self, tree): ------- out : str """ - return self._format_subscription(tree.children, name="nl_container") + return self._format_subscription(tree.children, kind=Token.flag.NATLANG) def natlang_array(self, tree): """ @@ -165,11 +214,15 @@ def natlang_array(self, tree): ------- out : str """ - array_name = Token.find_one(tree.children, kind="array_name") + array_name = Token.find_one( + tree.children, kind=Token.flag.ARRAY | Token.flag.NAME + ) items = tree.children.copy() items.remove(array_name) - items.insert(0, Token(array_name, kind="qualname")) - return self._format_subscription(items, name="nl_array") + items.insert(0, array_name) + return self._format_subscription( + items, kind=Token.flag.ARRAY | Token.flag.NATLANG + ) def array_name(self, tree): """ @@ -186,7 +239,7 @@ def array_name(self, tree): # This currently relies on a hack that only allows specific names # in `array_expression` (see `ARRAY_NAME` terminal in gramar) qualname = self.qualname(tree) - qualname = Token(qualname, kind="array_name") + qualname = Token(qualname, kind=Token.flag.NAME | Token.flag.ARRAY) return qualname def shape(self, tree): @@ -228,22 +281,24 @@ def extra_info(self, tree): logger.debug("dropping extra info") return lark.Discard - def _format_subscription(self, sequence, name): - sep = Token(", ", 
kind=f"{name}_sep") + def _format_subscription(self, sequence, kind=None): + if kind is None: + kind = Token.flag.SUBSCRIPT + else: + kind |= Token.flag.SUBSCRIPT + + sep = Token(", ", kind=kind | Token.flag.SEP) container, *content = sequence content = insert_between(content, sep=sep) assert content out = [ container, - Token("[", kind=f"{name}_start"), + Token("[", kind=kind | Token.flag.START), *content, - Token("]", kind=f"{name}_stop"), + Token("]", kind=kind | Token.flag.STOP), ] return out - def __default_token__(self, token): - return Token(token.value, kind=token.type.lower()) - @dataclass(frozen=True, slots=True) class ParsedDoctype: @@ -265,16 +320,38 @@ def parse(cls, doctype): Examples -------- - >>> ParsedDoctype.parse("tuple of int or ndarray of dtype (float or int)") + >>> doctype = ParsedDoctype.parse( + ... "tuple of int or ndarray of dtype (float or int)" + ... ) + >>> doctype + >>> doctype.qualnames + (Token('tuple', kind='qualname'), + Token('int', kind='qualname'), + Token('ndarray', kind='qualname'), + Token('float', kind='qualname'), + Token('int', kind='qualname')) """ tree = _lark.parse(doctype) - result = DoctypeTransformer().transform(tree=tree) - result = tuple(flatten_recursive(result)) - return cls(result, raw_doctype=doctype) + tokens = DoctypeTransformer().transform(tree=tree) + tokens = tuple(flatten_recursive(tokens)) + return cls(tokens, raw_doctype=doctype) def __str__(self): return "".join(self.tokens) def __repr__(self): - return f"<{type(self).__name__}: '{self}'>" + return f"<{type(self).__name__} '{self}'>" + + @property + def qualnames(self): + return tuple(Token.find_iter(self.tokens, kind=Token.flag.NAME)) + + def print_map_tokens_to_raw(self): + for token in self.tokens: + if token.pos is not None: + start, stop = token.pos + print(self.raw_doctype) + print(" " * start + "^" * (stop - start)) + print(" " * start + token) + print() diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index d62d389..d4c0408 
100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -58,7 +58,7 @@ natlang_literal: "{" literal_item ("," literal_item)* "}" // An single item in a literal expression (or `optional`). We must also allow // for qualified names, since a "class" or enum can be used as a literal too. -?literal_item: ELLIPSES | STRING | SIGNED_NUMBER | qualname +literal_item: ELLIPSES | STRING | SIGNED_NUMBER | qualname // Natural language forms of the subscription expression for containers. From 38839e74939ff0d1dea92c7e514762ed364dd178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Thu, 28 Aug 2025 15:11:06 +0200 Subject: [PATCH 04/17] WIP small tweaks --- src/docstub/_doctype.py | 46 +++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 777d609..1336e92 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -5,6 +5,7 @@ import logging from collections.abc import Iterable from dataclasses import dataclass +from functools import lru_cache from pathlib import Path import lark @@ -67,7 +68,7 @@ class Token(str): flag = TokenFlag - __slots__ = ("value", "kind", "pos") + __slots__ = ("kind", "pos", "value") def __new__(cls, value, *, kind, pos=None): self = super().__new__(cls, value) @@ -302,10 +303,13 @@ def _format_subscription(self, sequence, kind=None): @dataclass(frozen=True, slots=True) class ParsedDoctype: + """Parsed representation of a doctype, a type description in a docstring.""" + tokens: tuple[Token, ...] raw_doctype: str @classmethod + @lru_cache(maxsize=100) def parse(cls, doctype): """Turn a type description in a docstring into a type annotation. @@ -316,27 +320,39 @@ def parse(cls, doctype): Returns ------- - annotation_list : list of Token + parsed : Self Examples -------- - >>> doctype = ParsedDoctype.parse( + >>> parsed = ParsedDoctype.parse( ... "tuple of int or ndarray of dtype (float or int)" ... 
) - >>> doctype - - >>> doctype.qualnames - (Token('tuple', kind='qualname'), - Token('int', kind='qualname'), - Token('ndarray', kind='qualname'), - Token('float', kind='qualname'), - Token('int', kind='qualname')) + >>> parsed + + >>> str(parsed) + 'tuple[int] | ndarray[float | int]' + >>> parsed.format({"ndarray": "np.ndarray"}) + 'tuple[int] | np.ndarray[float | int]' + >>> parsed.qualnames # doctest: +NORMALIZE_WHITESPACE + (Token('tuple', kind=), + Token('int', kind=), + Token('ndarray', kind=), + Token('float', kind=), + Token('int', kind=)) """ tree = _lark.parse(doctype) tokens = DoctypeTransformer().transform(tree=tree) tokens = tuple(flatten_recursive(tokens)) return cls(tokens, raw_doctype=doctype) + def format(self, replace_names=None): + replace_names = replace_names or {} + tokens = [ + replace_names.get(token, token) if token.kind == TokenFlag.NAME else token + for token in self.tokens + ] + return "".join(tokens) + def __str__(self): return "".join(self.tokens) @@ -351,7 +367,7 @@ def print_map_tokens_to_raw(self): for token in self.tokens: if token.pos is not None: start, stop = token.pos - print(self.raw_doctype) - print(" " * start + "^" * (stop - start)) - print(" " * start + token) - print() + print(self.raw_doctype) # noqa: T201 + print(" " * start + "^" * (stop - start)) # noqa: T201 + print(" " * start + token) # noqa: T201 + print() # noqa: T201 From 8ec9501ddad93852bbf58fb5a4d4ccc04b204267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 12 Jan 2026 14:17:24 +0100 Subject: [PATCH 05/17] WIP --- src/docstub-stubs/_analysis.pyi | 2 + src/docstub-stubs/_app_generate_stubs.pyi | 4 +- src/docstub-stubs/_cli.pyi | 1 - src/docstub-stubs/_docstrings.pyi | 54 +-- src/docstub-stubs/_report.pyi | 23 +- src/docstub-stubs/_stubs.pyi | 15 +- src/docstub-stubs/_utils.pyi | 5 +- src/docstub/_analysis.py | 12 +- src/docstub/_app_generate_stubs.py | 23 +- src/docstub/_docstrings.py | 536 +++++----------------- 
src/docstub/_doctype.py | 387 +++++++++------- src/docstub/_report.py | 170 +++++++ src/docstub/_stubs.py | 39 +- src/docstub/_utils.py | 37 -- src/docstub/doctype.lark | 8 +- 15 files changed, 586 insertions(+), 730 deletions(-) diff --git a/src/docstub-stubs/_analysis.pyi b/src/docstub-stubs/_analysis.pyi index 4e439d9..34a3cec 100644 --- a/src/docstub-stubs/_analysis.pyi +++ b/src/docstub-stubs/_analysis.pyi @@ -14,6 +14,7 @@ from typing import Any, ClassVar import libcst as cst import libcst.matchers as cstm +from ._report import Stats from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum logger: logging.Logger @@ -83,6 +84,7 @@ class TypeMatcher: types: dict[str, PyImport] | None = ..., type_prefixes: dict[str, PyImport] | None = ..., type_nicknames: dict[str, str] | None = ..., + stats: Stats | None = ..., ) -> None: ... def _resolve_nickname(self, name: str) -> str: ... def match(self, search: str) -> tuple[str | None, PyImport | None]: ... diff --git a/src/docstub-stubs/_app_generate_stubs.pyi b/src/docstub-stubs/_app_generate_stubs.pyi index 3839bac..cf30bd6 100644 --- a/src/docstub-stubs/_app_generate_stubs.pyi +++ b/src/docstub-stubs/_app_generate_stubs.pyi @@ -6,7 +6,6 @@ from collections import Counter from collections.abc import Iterable, Sequence from contextlib import contextmanager from pathlib import Path -from typing import Literal from ._analysis import PyImport, TypeCollector, TypeMatcher, common_known_types from ._cache import CACHE_DIR_NAME, FileCache @@ -18,9 +17,8 @@ from ._path_utils import ( walk_source_and_targets, walk_source_package, ) -from ._report import setup_logging +from ._report import Stats, setup_logging from ._stubs import Py2StubTransformer, try_format_stub -from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index d6e0988..42e9e7f 100644 --- a/src/docstub-stubs/_cli.pyi +++ 
b/src/docstub-stubs/_cli.pyi @@ -4,7 +4,6 @@ import logging import sys from collections.abc import Callable, Sequence from pathlib import Path -from typing import Literal import click from _typeshed import Incomplete diff --git a/src/docstub-stubs/_docstrings.pyi b/src/docstub-stubs/_docstrings.pyi index ea5e0c7..4b2d410 100644 --- a/src/docstub-stubs/_docstrings.pyi +++ b/src/docstub-stubs/_docstrings.pyi @@ -5,8 +5,6 @@ import traceback from collections.abc import Generator, Iterable from dataclasses import dataclass, field from functools import cached_property -from pathlib import Path -from typing import Any, ClassVar import click import lark @@ -14,19 +12,15 @@ import lark.visitors import numpydoc.docscrape as npds from ._analysis import PyImport, TypeMatcher -from ._report import ContextReporter -from ._utils import DocstubError, escape_qualname +from ._doctype import BlacklistedQualname, Expression, Token, TokenKind, parse_doctype +from ._report import ContextReporter, Stats +from ._utils import escape_qualname logger: logging.Logger -here: Path -grammar_path: Path - -with grammar_path.open() as file: - _grammar: str - -_lark: lark.Lark - +def update_qualnames( + expr: Expression, *, _parents: tuple[Expression, ...] = ... +) -> Generator[tuple[tuple[Expression, ...], Token], str]: ... def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: ... @dataclass(frozen=True, slots=True, kw_only=True) class Annotation: @@ -54,51 +48,23 @@ class Annotation: FallbackAnnotation: Annotation -class QualnameIsKeyword(DocstubError): - pass - -class DoctypeTransformer(lark.visitors.Transformer): - matcher: TypeMatcher - stats: dict[str, Any] - - blacklisted_qualnames: ClassVar[frozenset[str]] - - def __init__( - self, *, matcher: TypeMatcher | None = ..., **kwargs: dict[Any, Any] - ) -> None: ... - def doctype_to_annotation( - self, doctype: str, *, reporter: ContextReporter | None = ... - ) -> tuple[Annotation, list[tuple[str, int, int]]]: ... 
- def qualname(self, tree: lark.Tree) -> lark.Token: ... - def rst_role(self, tree: lark.Tree) -> lark.Token: ... - def union(self, tree: lark.Tree) -> str: ... - def subscription(self, tree: lark.Tree) -> str: ... - def natlang_literal(self, tree: lark.Tree) -> str: ... - def natlang_container(self, tree: lark.Tree) -> str: ... - def natlang_array(self, tree: lark.Tree) -> str: ... - def array_name(self, tree: lark.Tree) -> lark.Token: ... - def shape(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... - def optional_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... - def __default__( - self, data: lark.Token, children: list[lark.Token], meta: lark.tree.Meta - ) -> lark.Token | list[lark.Token]: ... - def _match_import(self, qualname: str, *, meta: lark.tree.Meta) -> str: ... - def _uncombine_numpydoc_params( params: list[npds.Parameter], ) -> Generator[npds.Parameter]: ... +def _red_partial_underline(doctype: str, *, start: int, stop: int) -> str: ... class DocstringAnnotations: docstring: str - transformer: DoctypeTransformer + matcher: TypeMatcher reporter: ContextReporter def __init__( self, docstring: str, *, - transformer: DoctypeTransformer, + matcher: TypeMatcher, reporter: ContextReporter | None = ..., + stats: Stats | None = ..., ) -> None: ... def _doctype_to_annotation( self, doctype: str, ds_line: int = ... 
diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index 1515c66..7815695 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -2,11 +2,13 @@ import dataclasses import logging +from collections.abc import Hashable, Iterator, Mapping, Sequence from pathlib import Path from textwrap import indent -from typing import Any, ClassVar, Literal, Self, TextIO +from typing import Any, ClassVar, Self, TextIO import click +from pre_commit.envcontext import UNSET from ._cli_help import should_strip_ansi @@ -79,3 +81,22 @@ class LogCounter(logging.NullHandler): def setup_logging( *, verbosity: Literal[-2, -1, 0, 1, 2, 3], group_errors: bool ) -> tuple[ReportHandler, LogCounter]: ... +def update_with_add_values( + *mappings: Mapping[Hashable, int | Sequence], out: dict | None = ... +) -> dict: ... + +class Stats(Mapping): + class _UNSET: + pass + + def __init__(self, stats: dict[str, list[Any] | str] | None = ...) -> None: ... + def __getitem__(self, key: str) -> list[Any] | int: ... + def __iter__(self) -> Iterator: ... + def __len__(self) -> int: ... + def inc_counter(self, key: str, *, inc: int = ...) -> None: ... + def append_to_list(self, key: str, value: Any) -> None: ... + @classmethod + def merge(cls, *stats: Self) -> Self: ... + def __repr__(self) -> str: ... + def pop(self, key: str, *, default: Any = ...) -> list[Any] | int: ... + def pop_all(self) -> dict[str, list[Any] | int]: ... 
diff --git a/src/docstub-stubs/_stubs.pyi b/src/docstub-stubs/_stubs.pyi index 87c2554..3a552e4 100644 --- a/src/docstub-stubs/_stubs.pyi +++ b/src/docstub-stubs/_stubs.pyi @@ -12,14 +12,10 @@ import libcst.matchers as cstm from _typeshed import Incomplete from ._analysis import PyImport, TypeMatcher -from ._docstrings import ( - Annotation, - DocstringAnnotations, - DoctypeTransformer, - FallbackAnnotation, -) -from ._report import ContextReporter -from ._utils import module_name_from_path, update_with_add_values +from ._docstrings import Annotation, DocstringAnnotations, FallbackAnnotation +from ._doctype import DoctypeTransformer +from ._report import ContextReporter, Stats +from ._utils import module_name_from_path logger: logging.Logger @@ -73,9 +69,6 @@ class Py2StubTransformer(cst.CSTTransformer): @property def is_inside_function_def(self) -> bool: ... def python_to_stub(self, source: str, *, module_path: Path | None = ...) -> str: ... - def collect_stats( - self, *, reset_after: bool = ... - ) -> dict[str, int | list[str]]: ... def visit_ClassDef(self, node: cst.ClassDef) -> Literal[True]: ... def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef diff --git a/src/docstub-stubs/_utils.pyi b/src/docstub-stubs/_utils.pyi index c29d6a8..8b8b4bd 100644 --- a/src/docstub-stubs/_utils.pyi +++ b/src/docstub-stubs/_utils.pyi @@ -2,7 +2,7 @@ import itertools import re -from collections.abc import Callable, Hashable, Mapping, Sequence +from collections.abc import Callable from functools import lru_cache, wraps from pathlib import Path from zlib import crc32 @@ -12,9 +12,6 @@ def escape_qualname(name: str) -> str: ... def _resolve_path_before_caching(func: Callable) -> Callable: ... def module_name_from_path(path: Path) -> str: ... def pyfile_checksum(path: Path) -> str: ... -def update_with_add_values( - *mappings: Mapping[Hashable, int | Sequence], out: dict | None = ... -) -> dict: ... 
class DocstubError(Exception): pass diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 32a6daf..056a998 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -13,6 +13,7 @@ import libcst.matchers as cstm from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum +from ._report import Stats logger: logging.Logger = logging.getLogger(__name__) @@ -492,6 +493,7 @@ def __init__( types=None, type_prefixes=None, type_nicknames=None, + stats=None, ): """ Parameters @@ -499,15 +501,13 @@ def __init__( types : dict[str, PyImport] type_prefixes : dict[str, PyImport] type_nicknames : dict[str, str] + stats : ~.Stats, optional """ self.types = common_known_types() | (types or {}) self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} - self.stats = { - "matched_type_names": 0, - "unknown_type_names": [], - } + self.stats = stats or Stats() self.current_file = None @@ -623,8 +623,8 @@ def match(self, search): type_name = type_name[type_name.find(py_import.target) :] if type_name is not None: - self.stats["matched_type_names"] += 1 + self.stats.inc_counter("matched_type_names") else: - self.stats["unknown_type_names"].append(search) + self.stats.append_to_list("unknown_type_names", search) return type_name, py_import diff --git a/src/docstub/_app_generate_stubs.py b/src/docstub/_app_generate_stubs.py index 1800b41..7828071 100644 --- a/src/docstub/_app_generate_stubs.py +++ b/src/docstub/_app_generate_stubs.py @@ -24,9 +24,8 @@ walk_source_and_targets, walk_source_package, ) -from ._report import setup_logging +from ._report import setup_logging, Stats from ._stubs import Py2StubTransformer, try_format_stub -from ._utils import update_with_add_values from ._version import __version__ logger: logging.Logger = logging.getLogger(__name__) @@ -234,7 +233,9 @@ def _generate_single_stub(task): logger.info("Wrote %s", stub_path) fo.write(stub_content) - stats = stub_transformer.collect_stats() + 
stats = Stats.merge( + stub_transformer.stats.pop_all(), stub_transformer.matcher.stats.pop_all() + ) return stats @@ -350,7 +351,7 @@ def generate_stubs( stats_per_task = executor.map( _generate_single_stub, task_args, chunksize=chunk_size ) - stats = update_with_add_values(*stats_per_task) + stats = Stats.merge(*stats_per_task) py_typed_out = out_dir / "py.typed" if not py_typed_out.exists(): @@ -368,24 +369,26 @@ def generate_stubs( total_warnings = error_counter.warning_count total_errors = error_counter.error_count - logger.info("Recognized type names: %i", stats["matched_type_names"]) - logger.info("Transformed doctypes: %i", stats["transformed_doctypes"]) + logger.info("Recognized type names: %i", stats.pop("matched_type_names", default=0)) + logger.info("Transformed doctypes: %i", stats.pop("transformed_doctypes", default=0)) if total_warnings: logger.warning("Warnings: %i", total_warnings) - if stats["doctype_syntax_errors"]: + if "doctype_syntax_errors" in stats: assert total_errors - logger.warning("Syntax errors: %i", stats["doctype_syntax_errors"]) - if stats["unknown_type_names"]: + logger.warning("Syntax errors: %i", stats.pop("doctype_syntax_errors")) + if "unknown_type_names" in stats: assert total_errors logger.warning( "Unknown type names: %i (locations: %i)", len(set(stats["unknown_type_names"])), len(stats["unknown_type_names"]), - extra={"details": _format_unknown_names(stats["unknown_type_names"])}, + extra={"details": _format_unknown_names(stats.pop("unknown_type_names"))}, ) if total_errors: logger.error("Total errors: %i", total_errors) + assert len(stats) == 0 + total_fails = total_errors if fail_on_warning: total_fails += total_warnings diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 4602336..341e76e 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -2,10 +2,8 @@ import logging import traceback -from contextlib import contextmanager from dataclasses import dataclass, field from 
functools import cached_property -from pathlib import Path import click import lark @@ -16,21 +14,62 @@ # It should be possible to transform docstrings without matching to valid # types and imports. I think that could very well be done at a higher level, # e.g. in the stubs module. -from ._analysis import PyImport, TypeMatcher -from ._report import ContextReporter -from ._utils import DocstubError, escape_qualname +from ._analysis import PyImport +from ._report import ContextReporter, Stats +from ._utils import escape_qualname +from ._doctype import parse_doctype, Expression, Token, TokenKind, BlacklistedQualname + logger: logging.Logger = logging.getLogger(__name__) -here: Path = Path(__file__).parent -grammar_path: Path = here / "doctype.lark" +def update_qualnames(expr, *, _parents=tuple()): + """Yield and receive names in `expr`. + + This generator works as a coroutine. + + Parameters + ---------- + expr : Expression + _parents : tuple of (Expression, ...) + + Yields + ------ + parents : tuple of (Expression, ...) + name_token : Token + + Receives + -------- + new_name : str + Examples + -------- + >>> from docstub._doctype import parse_doctype + >>> expr = parse_doctype("tuple of (tuple or str, ...)") + >>> updater = update_qualnames(expr) + >>> for parents, name in updater: + ... if name == "tuple" and parents[-1].rule == "union": + ... updater.send("list") + ... if name == "str": + ... 
updater.send("bytes") + >>> expr.as_code() + 'tuple[list | bytes, ...]' + """ + _parents += (expr,) + children = expr.children.copy() -with grammar_path.open() as file: - _grammar: str = file.read() + for i, child in enumerate(children): + if hasattr(child, "children"): + yield from update_qualnames(child, _parents=_parents) -_lark: lark.Lark = lark.Lark(_grammar, propagate_positions=True, strict=True) + elif child.kind == TokenKind.NAME: + new_name = yield _parents, child + if new_name is not None: + new_token = Token(new_name, kind=child.kind) + expr.children[i] = new_token + # `send` was called, yield `None` to return from `send`, + # otherwise send would return the next child + yield def _find_one_token(tree, *, name): @@ -185,390 +224,6 @@ def _aggregate_annotations(*types): ) -class QualnameIsKeyword(DocstubError): - """Raised when a qualname is a blacklisted Python keyword.""" - - -@lark.visitors.v_args(tree=True) -class DoctypeTransformer(lark.visitors.Transformer): - """Transformer for docstring type descriptions (doctypes). - - Attributes - ---------- - matcher : ~.TypeMatcher - stats : dict[str, Any] - blacklisted_qualnames : ClassVar[frozenset[str]] - All Python keywords [1]_ are blacklisted from use in qualnames except for ``True`` - ``False`` and ``None``. - - References - ---------- - .. [1] https://docs.python.org/3/reference/lexical_analysis.html#keywords - - Examples - -------- - >>> transformer = DoctypeTransformer() - >>> annotation, unknown_names = transformer.doctype_to_annotation( - ... "tuple of (int or ndarray)" - ... 
) - >>> annotation.value - 'tuple[int | ndarray]' - >>> unknown_names - [('ndarray', 17, 24)] - """ - - blacklisted_qualnames = frozenset( - { - "await", - "else", - "import", - "pass", - "break", - "except", - "in", - "raise", - "class", - "finally", - "is", - "return", - "and", - "continue", - "for", - "lambda", - "try", - "as", - "def", - "from", - "nonlocal", - "while", - "assert", - "del", - "global", - "not", - "with", - "async", - "elif", - "if", - "or", - "yield", - } - ) - - def __init__(self, *, matcher=None, **kwargs): - """ - Parameters - ---------- - matcher : ~.TypeMatcher, optional - kwargs : dict[Any, Any], optional - Keyword arguments passed to the init of the parent class. - """ - if matcher is None: - matcher = TypeMatcher() - - self.matcher = matcher - - self._reporter = None - self._collected_imports = None - self._unknown_qualnames = None - - super().__init__(**kwargs) - - self.stats = { - "doctype_syntax_errors": 0, - "transformed_doctypes": 0, - } - - def doctype_to_annotation(self, doctype, *, reporter=None): - """Turn a type description in a docstring into a type annotation. - - Parameters - ---------- - doctype : str - The doctype to parse. - reporter : ~.ContextReporter - - Returns - ------- - annotation : Annotation - The parsed annotation. - unknown_qualnames : list[tuple[str, int, int]] - A set containing tuples. Each tuple contains a qualname, its start and its - end index relative to the given `doctype`. 
- """ - try: - self._reporter = reporter or ContextReporter(logger=logger) - self._collected_imports = set() - self._unknown_qualnames = [] - tree = _lark.parse(doctype) - value = super().transform(tree=tree) - annotation = Annotation( - value=value, imports=frozenset(self._collected_imports) - ) - self.stats["transformed_doctypes"] += 1 - return annotation, self._unknown_qualnames - except ( - lark.exceptions.LexError, - lark.exceptions.ParseError, - QualnameIsKeyword, - ): - self.stats["doctype_syntax_errors"] += 1 - raise - finally: - self._reporter = None - self._collected_imports = None - self._unknown_qualnames = None - - def qualname(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.Token - """ - children = tree.children - _qualname = ".".join(children) - - _qualname = self._match_import(_qualname, meta=tree.meta) - - if _qualname in self.blacklisted_qualnames: - msg = ( - f"qualname {_qualname!r} in docstring type description " - "is a reserved Python keyword and not allowed" - ) - raise QualnameIsKeyword(msg) - - _qualname = lark.Token(type="QUALNAME", value=_qualname) - return _qualname - - def rst_role(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.Token - """ - qualname = _find_one_token(tree, name="QUALNAME") - return qualname - - def union(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - out = " | ".join(tree.children) - return out - - def subscription(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - _container, *_content = tree.children - _content = ", ".join(_content) - assert _content - out = f"{_container}[{_content}]" - return out - - def natlang_literal(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - out = ", ".join(tree.children) - out = f"Literal[{out}]" - - if 
len(tree.children) == 1: - self._reporter.warn( - "Natural language literal with one item: `{%s}`", - tree.children[0], - details=f"Consider using `{out}` to improve readability", - ) - - if self.matcher is not None: - _, py_import = self.matcher.match("Literal") - if py_import.has_import: - self._collected_imports.add(py_import) - return out - - def natlang_container(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - return self.subscription(tree) - - def natlang_array(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : str - """ - name = _find_one_token(tree, name="ARRAY_NAME") - children = [child for child in tree.children if child != name] - if children: - name = f"{name}[{', '.join(children)}]" - return str(name) - - def array_name(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.Token - """ - # Treat `array_name` as `qualname`, but mark it as an array name, - # so we know which one to treat as the container in `array_expression` - # This currently relies on a hack that only allows specific names - # in `array_expression` (see `ARRAY_NAME` terminal in gramar) - qualname = self.qualname(tree) - qualname = lark.Token("ARRAY_NAME", str(qualname)) - return qualname - - def shape(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.visitors._DiscardType - """ - # self._reporter.debug("Dropping shape information %r", tree) - return lark.Discard - - def optional_info(self, tree): - """ - Parameters - ---------- - tree : lark.Tree - - Returns - ------- - out : lark.visitors._DiscardType - """ - # self._reporter.debug("Dropping optional info %r", tree) - return lark.Discard - - def __default__(self, data, children, meta): - """Unpack children of rule nodes by default. - - Parameters - ---------- - data : lark.Token - The rule-token of the current node. 
- children : list[lark.Token] - The children of the current node. - meta : lark.tree.Meta - Meta information for the current node. - - Returns - ------- - out : lark.Token or list[lark.Token] - Either a token or list of tokens. - """ - if isinstance(children, list) and len(children) == 1: - out = children[0] - if hasattr(out, "type"): - out.type = data.upper() # Turn rule into "token" - else: - out = children - return out - - @contextmanager - def _prepare_transformation(self): - """Reset transformation state before entering context and restore it on exit.""" - collected_imports = self._collected_imports - unknown_qualnames = self._unknown_qualnames - - try: - self._collected_imports = set() - self._unknown_qualnames = [] - yield - - except ( - lark.exceptions.LexError, - lark.exceptions.ParseError, - QualnameIsKeyword, - ): - self.stats["syntax_errors"] += 1 - raise - - finally: - self._collected_imports = collected_imports - self._unknown_qualnames = unknown_qualnames - - def _match_import(self, qualname, *, meta): - """Match `qualname` to known imports or alias to "Incomplete". - - Parameters - ---------- - qualname : str - meta : lark.tree.Meta - Location metadata for the `qualname`, used to report possible errors. - - Returns - ------- - matched_qualname : str - Possibly modified or normalized qualname. 
- """ - if self.matcher is not None: - annotation_name, py_import = self.matcher.match(qualname) - else: - annotation_name = None - py_import = None - - if py_import and py_import.has_import: - self._collected_imports.add(py_import) - - if annotation_name: - matched_qualname = annotation_name - else: - # Unknown qualname, alias to `Incomplete` - self._unknown_qualnames.append((qualname, meta.start_pos, meta.end_pos)) - matched_qualname = escape_qualname(qualname) - any_alias = PyImport( - from_="_typeshed", - import_="Incomplete", - as_=matched_qualname, - ) - self._collected_imports.add(any_alias) - return matched_qualname - - def _uncombine_numpydoc_params(params): """Split combined NumPyDoc parameters. @@ -595,13 +250,33 @@ def _uncombine_numpydoc_params(params): yield param +def _red_partial_underline(doctype, *, start, stop): + """Underline a part of a string with red '^'. + + Parameters + ---------- + doctype : str + start : int + stop : int + + Returns + ------- + underlined : str + """ + width = stop - start + assert width > 0 + underline = click.style("^" * width, fg="red", bold=True) + underlined = f"{doctype}\n{' ' * start}{underline}\n" + return underlined + + class DocstringAnnotations: """Collect annotations in a given docstring. 
Attributes ---------- docstring : str - transformer : DoctypeTransformer + matcher : ~.TypeMatcher reporter : ~.ContextReporter Examples @@ -619,17 +294,19 @@ class DocstringAnnotations: dict_keys(['a', 'b', 'c']) """ - def __init__(self, docstring, *, transformer, reporter=None): + def __init__(self, docstring, *, matcher, reporter=None, stats=None): """ Parameters ---------- docstring : str - transformer : DoctypeTransformer + matcher : ~.TypeMatcher reporter : ~.ContextReporter, optional + stats : ~.Stats, optional """ self.docstring = docstring self.np_docstring = npds.NumpyDocString(docstring) - self.transformer = transformer + self.matcher = matcher + self.stats = Stats() if stats is None else stats if reporter is None: reporter = ContextReporter(logger=logger, line=0) @@ -655,34 +332,73 @@ def _doctype_to_annotation(self, doctype, ds_line=0): reporter = self.reporter.copy_with(line_offset=ds_line) try: - annotation, unknown_qualnames = self.transformer.doctype_to_annotation( - doctype, reporter=reporter - ) + expression = parse_doctype(doctype) + self.stats.inc_counter("transformed_doctypes") reporter.debug( - "Transformed doctype", details=(" %s\n-> %s", doctype, annotation) + "Transformed doctype", details=(" %s\n-> %s", doctype, expression) ) - except (lark.exceptions.LexError, lark.exceptions.ParseError) as error: + imports = set() + unknown_qualnames = set() + updater = update_qualnames(expression) + for _, token in updater: + search_name = str(token) + matched_name, py_import = self.matcher.match(search_name) + if matched_name is None: + assert py_import is None + unknown_qualnames.add((search_name, *token.pos)) + matched_name = escape_qualname(search_name) + _ = updater.send(matched_name) + assert _ is None + + if py_import is None: + incomplete_alias = PyImport( + from_="_typeshed", + import_="Incomplete", + as_=matched_name, + ) + imports.add(incomplete_alias) + elif py_import.has_import: + imports.add(py_import) + + annotation = 
Annotation(value=str(expression), imports=frozenset(imports)) + + except ( + lark.exceptions.LexError, + lark.exceptions.ParseError, + ) as error: details = None if hasattr(error, "get_context"): details = error.get_context(doctype) details = details.replace("^", click.style("^", fg="red", bold=True)) + self.stats.inc_counter("doctype_syntax_errors") reporter.error( "Invalid syntax in docstring type annotation", details=details ) return FallbackAnnotation - except lark.visitors.VisitError as e: - tb = "\n".join(traceback.format_exception(e.orig_exc)) - details = f"doctype: {doctype!r}\n\n{tb}" - reporter.error("Unexpected error while parsing doctype", details=details) + except lark.visitors.VisitError as error: + original_error = error.orig_exc + if isinstance(original_error, BlacklistedQualname): + msg = "Blacklisted keyword argument in doctype" + details = _red_partial_underline( + doctype, + start=error.obj.meta.start_pos, + stop=error.obj.meta.end_pos, + ) + else: + msg = "Unexpected error while parsing doctype" + tb = traceback.format_exception(original_error) + tb = "\n".join(tb) + details = f"doctype: {doctype!r}\n\n{tb}" + reporter.error(msg, details=details) return FallbackAnnotation else: for name, start_col, stop_col in unknown_qualnames: - width = stop_col - start_col - error_underline = click.style("^" * width, fg="red", bold=True) - details = f"{doctype}\n{' ' * start_col}{error_underline}\n" + details = _red_partial_underline( + doctype, start=start_col, stop=stop_col + ) reporter.error(f"Unknown name in doctype: {name!r}", details=details) return annotation diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 1336e92..f29140b 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -1,36 +1,63 @@ -"""Parsing of doctypes""" +"""Parsing of doctypes.""" import enum -import itertools import logging +import keyword from collections.abc import Iterable from dataclasses import dataclass -from functools import lru_cache from 
pathlib import Path +from textwrap import indent +from typing import Final import lark import lark.visitors -logger = logging.getLogger(__name__) +from ._utils import DocstubError -grammar_path = Path(__file__).parent / "doctype.lark" +logger: Final = logging.getLogger(__name__) + + +grammar_path: Final = Path(__file__).parent / "doctype.lark" with grammar_path.open() as file: - _grammar = file.read() + _grammar: Final = file.read() -_lark = lark.Lark(_grammar, propagate_positions=True, strict=True) +_lark: Final = lark.Lark(_grammar, propagate_positions=True, strict=True) def flatten_recursive(iterable): + """Flatten nested iterables yield the contained strings. + + Parameters + ---------- + iterable : Iterable[Iterable or str] + + Yields + ------ + item : str + """ for item in iterable: - if not isinstance(item, str) and isinstance(item, Iterable): + if isinstance(item, str): + yield item + elif isinstance(item, Iterable): yield from flatten_recursive(item) else: - yield item + raise ValueError(f"unexpected type: {item!r}") def insert_between(iterable, *, sep): + """Insert `sep` inbetween elements of `iterable`. 
+ + Parameters + ---------- + iterable : Iterable + sep : Any + + Returns + ------- + out : list[Any] + """ out = [] for item in iterable: out.append(item) @@ -38,68 +65,109 @@ def insert_between(iterable, *, sep): return out[:-1] -class TokenFlag(enum.Flag): +class TokenKind(enum.StrEnum): # docstub: off NAME = enum.auto() - NATLANG = enum.auto() - SUBSCRIPT = enum.auto() LITERAL = enum.auto() - GENERATOR = enum.auto() - ARRAY = enum.auto() - UNION = enum.auto() - START = enum.auto() - STOP = enum.auto() - SEP = enum.auto() + SYNTAX = enum.auto() # docstub: on - @classmethod - def _missing_(cls, value): - forbidden = { - *itertools.combinations([cls.START, cls.STOP, cls.SEP, cls.NAME], 2) - } - for pair in forbidden: - if value is (pair[0].value | pair[1].value): - raise ValueError(f"{pair[0].name}|{pair[1].name} not allowed") - return super()._missing_(value) - class Token(str): - """A token representing an atomic part of a doctype.""" + """A token representing an atomic part of a doctype. 
- flag = TokenFlag + Attributes + ---------- + __slots__ : Final + """ - __slots__ = ("kind", "pos", "value") + __slots__ = ("value", "kind", "pos") def __new__(cls, value, *, kind, pos=None): + """ + Parameters + ---------- + value : str + kind : TokenKind or str + pos : tuple of (int, int), optional + """ self = super().__new__(cls, value) - self.kind = TokenFlag(kind) + self.kind = TokenKind(kind) self.pos = pos return self def __repr__(self): - return f"{type(self).__name__}('{self}', kind={self.kind!r})" - - @classmethod - def find_iter(cls, iterable, *, kind): - kind = TokenFlag(kind) - for item in flatten_recursive(iterable): - if isinstance(item, cls) and all(k & item.kind for k in kind): - yield item - - @classmethod - def find_one(cls, iterable, *, kind): - matching = list(cls.find_iter(iterable, kind=kind)) - if len(matching) != 1: - msg = ( - f"expected exactly one {cls.__name__} with {kind=}, " - f"got {len(matching)}: {matching}" - ) - raise ValueError(msg) - return matching[0] + return f"{type(self).__name__}('{self}', kind='{self.kind}')" + + def __getnewargs_ex__(self): + """""" + kwargs = {"value": str(self), "kind": self.kind, "pos": self.pos} + return tuple(), kwargs + + +@dataclass(slots=True) +class Expression: + """A named expression made up of sub expressions and tokens.""" + + rule: str + children: list[Expression | Token] + + @property + def tokens(self): + """All tokens in the expression.""" + return list(flatten_recursive(self)) + + @property + def names(self): + """Name tokens in the expression.""" + return [token for token in self.tokens if token.kind == TokenKind.NAME] + + def __iter__(self): + yield from self.children + + def format_tree(self): + formatted_children = ( + c.format_tree() if hasattr(c, "format_tree") else repr(c) + for c in self.children + ) + formatted_children = ",\n".join(formatted_children) + formatted_children = indent(formatted_children, prefix=" ") + return ( + f"{type(self).__name__}({self.rule!r}, 
children=[\n{formatted_children}])" + ) + + def __repr__(self): + return f"<{type(self).__name__}: '{self.as_code()}' rule='{self.rule}'>" + + def __str__(self): + return "".join(self.tokens) + + def as_code(self): + return str(self) + + +BLACKLISTED_QUALNAMES: Final = set(keyword.kwlist) - {"None", "True", "False"} + + +class BlacklistedQualname(DocstubError): + """Raised when a qualname is a forbidden keyword.""" @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): + + def start(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expression + """ + return Expression(rule="start", children=tree.children) + def qualname(self, tree): """ Parameters @@ -108,29 +176,32 @@ def qualname(self, tree): Returns ------- - out : lark.Token + out : Token """ children = tree.children _qualname = ".".join(children) + + if _qualname in BLACKLISTED_QUALNAMES: + raise BlacklistedQualname(_qualname) + _qualname = Token( _qualname, - kind=Token.flag.NAME, + kind=TokenKind.NAME, pos=(tree.meta.start_pos, tree.meta.end_pos), ) return _qualname - def rst_role(self, tree): + def ELLIPSES(self, token): """ Parameters ---------- - tree : lark.Tree + token : lark.Token Returns ------- - out : lark.Token + out : Token """ - qualname = Token.find_one(tree.children, kind=Token.flag.NAME) - return qualname + return Token(token, kind=TokenKind.LITERAL) def union(self, tree): """ @@ -140,11 +211,11 @@ def union(self, tree): Returns ------- - out : list[str] + out : Expression """ - sep = Token(" | ", kind=Token.flag.UNION | Token.flag.SEP) - out = insert_between(tree.children, sep=sep) - return out + sep = Token(" | ", kind=TokenKind.SYNTAX) + expr = Expression(rule="union", children=insert_between(tree.children, sep=sep)) + return expr def subscription(self, tree): """ @@ -154,7 +225,7 @@ def subscription(self, tree): Returns ------- - out : str + out : Expression """ return self._format_subscription(tree.children) @@ 
-166,15 +237,13 @@ def natlang_literal(self, tree): Returns ------- - out : str + out : Expression """ items = [ - Token("Literal", kind=Token.flag.LITERAL | Token.flag.NAME), + Token("Literal", kind=TokenKind.SYNTAX), *tree.children, ] - out = self._format_subscription( - items, kind=Token.flag.LITERAL | Token.flag.NATLANG - ) + out = self._format_subscription(items, rule="natlang_literal") if len(tree.children) == 1: logger.warning( @@ -186,11 +255,20 @@ def natlang_literal(self, tree): return out def literal_item(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Token + """ item, *other = tree.children assert not other - kind = Token.flag.LITERAL + kind = TokenKind.LITERAL if isinstance(item, Token): - kind |= item.kind + kind = item.kind return Token(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos)) def natlang_container(self, tree): @@ -201,9 +279,9 @@ def natlang_container(self, tree): Returns ------- - out : str + out : Expression """ - return self._format_subscription(tree.children, kind=Token.flag.NATLANG) + return self._format_subscription(tree.children, rule="natlang_container") def natlang_array(self, tree): """ @@ -213,17 +291,9 @@ def natlang_array(self, tree): Returns ------- - out : str + out : Expression """ - array_name = Token.find_one( - tree.children, kind=Token.flag.ARRAY | Token.flag.NAME - ) - items = tree.children.copy() - items.remove(array_name) - items.insert(0, array_name) - return self._format_subscription( - items, kind=Token.flag.ARRAY | Token.flag.NATLANG - ) + return self._format_subscription(tree.children, rule="natlang_array") def array_name(self, tree): """ @@ -233,16 +303,25 @@ def array_name(self, tree): Returns ------- - out : lark.Token + out : Token """ - # Treat `array_name` as `qualname`, but mark it as an array name, - # so we know which one to treat as the container in `array_expression` # This currently relies on a hack that only allows specific names # in 
`array_expression` (see `ARRAY_NAME` terminal in gramar) qualname = self.qualname(tree) - qualname = Token(qualname, kind=Token.flag.NAME | Token.flag.ARRAY) return qualname + def dtype(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expression + """ + return Expression(rule="dtype", children=tree.children) + def shape(self, tree): """ Parameters @@ -256,7 +335,7 @@ def shape(self, tree): logger.debug("dropping shape information") return lark.Discard - def optional(self, tree): + def optional_info(self, tree): """ Parameters ---------- @@ -282,92 +361,58 @@ def extra_info(self, tree): logger.debug("dropping extra info") return lark.Discard - def _format_subscription(self, sequence, kind=None): - if kind is None: - kind = Token.flag.SUBSCRIPT - else: - kind |= Token.flag.SUBSCRIPT - - sep = Token(", ", kind=kind | Token.flag.SEP) - container, *content = sequence - content = insert_between(content, sep=sep) - assert content - out = [ - container, - Token("[", kind=kind | Token.flag.START), - *content, - Token("]", kind=kind | Token.flag.STOP), - ] - return out - - -@dataclass(frozen=True, slots=True) -class ParsedDoctype: - """Parsed representation of a doctype, a type description in a docstring.""" - - tokens: tuple[Token, ...] - raw_doctype: str - - @classmethod - @lru_cache(maxsize=100) - def parse(cls, doctype): - """Turn a type description in a docstring into a type annotation. - + def _format_subscription(self, sequence, rule="subscription"): + """ Parameters ---------- - doctype : str - The doctype to parse. + sequence : Sequence[str] + rule : str, optional Returns ------- - parsed : Self - - Examples - -------- - >>> parsed = ParsedDoctype.parse( - ... "tuple of int or ndarray of dtype (float or int)" - ... 
) - >>> parsed - - >>> str(parsed) - 'tuple[int] | ndarray[float | int]' - >>> parsed.format({"ndarray": "np.ndarray"}) - 'tuple[int] | np.ndarray[float | int]' - >>> parsed.qualnames # doctest: +NORMALIZE_WHITESPACE - (Token('tuple', kind=), - Token('int', kind=), - Token('ndarray', kind=), - Token('float', kind=), - Token('int', kind=)) - """ - tree = _lark.parse(doctype) - tokens = DoctypeTransformer().transform(tree=tree) - tokens = tuple(flatten_recursive(tokens)) - return cls(tokens, raw_doctype=doctype) - - def format(self, replace_names=None): - replace_names = replace_names or {} - tokens = [ - replace_names.get(token, token) if token.kind == TokenFlag.NAME else token - for token in self.tokens - ] - return "".join(tokens) - - def __str__(self): - return "".join(self.tokens) - - def __repr__(self): - return f"<{type(self).__name__} '{self}'>" - - @property - def qualnames(self): - return tuple(Token.find_iter(self.tokens, kind=Token.flag.NAME)) - - def print_map_tokens_to_raw(self): - for token in self.tokens: - if token.pos is not None: - start, stop = token.pos - print(self.raw_doctype) # noqa: T201 - print(" " * start + "^" * (stop - start)) # noqa: T201 - print(" " * start + token) # noqa: T201 - print() # noqa: T201 + out : Expression + """ + sep = Token(", ", kind=TokenKind.SYNTAX) + container, *content = sequence + content = insert_between(content, sep=sep) + assert content + expr = Expression( + rule=rule, + children=[ + container, + Token("[", kind=TokenKind.SYNTAX), + *content, + Token("]", kind=TokenKind.SYNTAX), + ], + ) + return expr + + +def parse_doctype(doctype): + """Turn a type description in a docstring into a type annotation. + + Parameters + ---------- + doctype : str + The doctype to parse. + + Returns + ------- + parsed : Expression + + Raises + ------ + lark.exceptions.VisitError + Raised when the transformation is interrupted by an exception. + See :cls:`lark.exceptions.VisitError`. 
+ + Examples + -------- + >>> parse_doctype("tuple of (int, ...)") + + >>> parse_doctype("ndarray of dtype (float or int)") + + """ + tree = _lark.parse(doctype) + expression = DoctypeTransformer().transform(tree=tree) + return expression diff --git a/src/docstub/_report.py b/src/docstub/_report.py index 86957c2..c09dd41 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -4,8 +4,10 @@ import logging from pathlib import Path from textwrap import indent +from collections.abc import Mapping import click +from pre_commit.envcontext import UNSET from ._cli_help import should_strip_ansi @@ -445,3 +447,171 @@ def setup_logging(*, verbosity, group_errors): logging.captureWarnings(True) return reporter, log_counter + + +def update_with_add_values(*mappings, out=None): + """Merge mappings while adding together their values. + + Parameters + ---------- + mappings : Mapping[Hashable, int or Sequence] + out : dict, optional + + Returns + ------- + out : dict, optional + + Examples + -------- + >>> stats_1 = {"errors": 2, "warnings": 0, "unknown": ["string", "integer"]} + >>> stats_2 = {"unknown": ["func"], "errors": 1} + >>> update_with_add_values(stats_1, stats_2) + {'errors': 3, 'warnings': 0, 'unknown': ['string', 'integer', 'func']} + + >>> _ = update_with_add_values(stats_1, out=stats_2) + >>> stats_2 + {'unknown': ['func', 'string', 'integer'], 'errors': 3, 'warnings': 0} + + >>> update_with_add_values({"lines": (1, 33)}, {"lines": (42,)}) + {'lines': (1, 33, 42)} + """ + if out is None: + out = {} + for m in mappings: + for key, value in m.items(): + if hasattr(value, "__add__"): + out[key] = out.setdefault(key, type(value)()) + value + else: + raise TypeError(f"Don't know how to 'add' {value!r}") + return out + + +class Stats(Mapping): + """Collect statistics + + Examples + -------- + >>> stats = Stats() + >>> stats.inc_counter("counter") + >>> stats.inc_counter("counter", inc=2) + >>> stats.append_to_list("names", "Foo") + >>> 
stats.append_to_list("names", "Bar") + >>> dict(stats) + {'counter': 3, 'names': ['Foo', 'Bar']} + + >>> other_stats = Stats( + ... {"counter": 3, "modules": ["pathlib"], "names": ["baz"]} + ... ) + >>> merged = stats.merge(stats, other_stats) + >>> dict(merged) + {'counter': 6, 'names': ['Foo', 'Bar', 'baz'], 'modules': ['pathlib']} + """ + + class _UNSET: + """Sentinel signaling that an argument wasn't set.""" + + def __init__(self, stats=None): + """ + Parameters + ---------- + stats : dict[str, list[Any] or str] + """ + self._stats = {} if stats is None else stats + + def __getitem__(self, key): + """Retrieve a statistic. + + Parameters + ---------- + key : str + + Returns + ------- + value : list[Any] or int + """ + return self._stats[key] + + def __iter__(self): + """ + Returns + ------- + out : Iterator + """ + yield from self._stats + + def __len__(self) -> int: + return len(self._stats) + + def inc_counter(self, key, *, inc=1): + """Increase counter of a statistic. + + Parameters + ---------- + key : str + inc : int, optional + """ + if key not in self._stats: + self._stats[key] = 0 + assert isinstance(inc, int) + self._stats[key] += inc + + def append_to_list(self, key, value): + """Append `value` to statistic. + + Parameters + ---------- + key : str + value : Any + """ + if key not in self._stats: + self._stats[key] = [] + self._stats[key].append(value) + + @classmethod + def merge(cls, *stats): + """ + + Parameters + ---------- + *stats : Self + + Returns + ------- + merged : Self + """ + out = update_with_add_values(*stats) + out = cls(out) + return out + + def __repr__(self) -> str: + keys = ", ".join(self._stats.keys()) + return f"<{type(self).__name__}: {keys}>" + + def pop(self, key, *, default=_UNSET): + """Return and remove a statistic from this container. + + Parameters + ---------- + key : str + default : Any, optional + If given, falls back to the given default value if `key` is not + found. 
+ + Returns + ------- + value : list[Any] or int + """ + if key in self._stats or default is UNSET: + return self._stats.pop(key) + return default + + def pop_all(self): + """Return and remove all statistics from this container. + + Returns + ------- + stats : dict[str, list[Any] or int] + """ + out = self._stats + self._stats = {} + return out diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 862e40e..64ca8ee 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -15,9 +15,9 @@ import libcst.matchers as cstm from ._analysis import PyImport -from ._docstrings import DocstringAnnotations, DoctypeTransformer, FallbackAnnotation -from ._report import ContextReporter -from ._utils import module_name_from_path, update_with_add_values +from ._docstrings import DocstringAnnotations, FallbackAnnotation +from ._report import ContextReporter, Stats +from ._utils import module_name_from_path logger: logging.Logger = logging.getLogger(__name__) @@ -305,8 +305,9 @@ def __init__(self, *, matcher=None): ---------- matcher : ~.TypeMatcher """ - self.transformer = DoctypeTransformer(matcher=matcher) + self.matcher = matcher self.reporter = ContextReporter(logger=logger) + self.stats = Stats() # Relevant docstring for the current context self._scope_stack = None # Entered module, class or function scopes self._pytypes_stack = None # Collected pytypes for each stack @@ -332,10 +333,10 @@ def current_source(self, value): value : Path """ self._current_source = value - # TODO pass current_source directly when using the transformer / matcher + # TODO pass current_source directly when using the matcher # instead of assigning it here! 
- if self.transformer is not None and self.transformer.matcher is not None: - self.transformer.matcher.current_file = value + if self.matcher is not None: + self.matcher.current_file = value @property def is_inside_function_def(self): @@ -385,26 +386,6 @@ def python_to_stub(self, source, *, module_path=None): self._required_imports = None self.current_source = None - def collect_stats(self, *, reset_after=True): - """Return statistics from processing files. - - Parameters - ---------- - reset_after : bool, optional - Whether to reset counters and statistics after returning. - - Returns - ------- - stats : dict of {str: int or list[str]} - """ - collected = [self.transformer.stats, self.transformer.matcher.stats] - merged = update_with_add_values(*collected) - if reset_after is True: - for stats in collected: - for key in stats: - stats[key] = type(stats[key])() - return merged - def visit_ClassDef(self, node): """Collect pytypes from class docstring and add scope to stack. @@ -908,13 +889,15 @@ def _annotations_from_node(self, node): try: annotations = DocstringAnnotations( docstring_node.evaluated_value, - transformer=self.transformer, + matcher=self.matcher, reporter=reporter, + stats=self.stats, ) except (SystemExit, KeyboardInterrupt): raise except Exception: reporter.error("could not parse docstring", exc_info=True) + return annotations def _create_annotated_assign( diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index 297d881..a10c91e 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -159,42 +159,5 @@ def pyfile_checksum(path): return key -def update_with_add_values(*mappings, out=None): - """Merge mappings while adding together their values. 
- - Parameters - ---------- - mappings : Mapping[Hashable, int or Sequence] - out : dict, optional - - Returns - ------- - out : dict, optional - - Examples - -------- - >>> stats_1 = {"errors": 2, "warnings": 0, "unknown": ["string", "integer"]} - >>> stats_2 = {"unknown": ["func"], "errors": 1} - >>> update_with_add_values(stats_1, stats_2) - {'errors': 3, 'warnings': 0, 'unknown': ['string', 'integer', 'func']} - - >>> _ = update_with_add_values(stats_1, out=stats_2) - >>> stats_2 - {'unknown': ['func', 'string', 'integer'], 'errors': 3, 'warnings': 0} - - >>> update_with_add_values({"lines": (1, 33)}, {"lines": (42,)}) - {'lines': (1, 33, 42)} - """ - if out is None: - out = {} - for m in mappings: - for key, value in m.items(): - if hasattr(value, "__add__"): - out[key] = out.setdefault(key, type(value)()) + value - else: - raise TypeError(f"Don't know how to 'add' {value!r}") - return out - - class DocstubError(Exception): """An error raised by docstub.""" diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index db3bb91..1b74b76 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -6,13 +6,13 @@ // https://lark-parser.readthedocs.io/en/latest/grammar.html -?start: annotation_with_meta +start: _annotation_with_meta // The basic structure of a full docstring annotation as it comes after the // `name : `. It includes additional meta information that is optional and // currently ignored. -?annotation_with_meta: type ("," optional_info)? +_annotation_with_meta: type ("," optional_info)? // A type annotation. Can range from a simple qualified name to a complex @@ -32,7 +32,7 @@ // [1] https://docutils.sourceforge.io/docs/ref/rst/roles.html // qualname: (/~/ ".")? (NAME ".")* NAME - | (":" (NAME ":")? NAME ":")? "`" qualname "`" -> rst_role + | (":" (NAME ":")? NAME ":")? "`" qualname "`" // An union of different types, joined either by "or" or "|". 
@@ -102,7 +102,7 @@ ARRAY_NAME: "array" | "ndarray" | "array-like" | "array_like" // The dtype used in an array expression. -?dtype: qualname | "(" union ")" +dtype: qualname | "(" union ")" // The shape used in an array expression. Possibly to liberal right now in From ff52bc0f5f63016d3599e343ee5171c9fff6992e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 12 Jan 2026 19:11:12 +0100 Subject: [PATCH 06/17] WIP --- src/docstub-stubs/_docstrings.pyi | 21 +- src/docstub-stubs/_doctype.pyi | 88 ++++++++ src/docstub/_docstrings.py | 240 ++++++++++------------ src/docstub/_doctype.py | 137 ++++++++----- src/docstub/doctype.lark | 16 +- tests/test_docstrings.py | 321 ++---------------------------- tests/test_doctype.py | 270 +++++++++++++++++++++++++ 7 files changed, 588 insertions(+), 505 deletions(-) create mode 100644 src/docstub-stubs/_doctype.pyi create mode 100644 tests/test_doctype.py diff --git a/src/docstub-stubs/_docstrings.pyi b/src/docstub-stubs/_docstrings.pyi index 4b2d410..a5805f0 100644 --- a/src/docstub-stubs/_docstrings.pyi +++ b/src/docstub-stubs/_docstrings.pyi @@ -12,16 +12,15 @@ import lark.visitors import numpydoc.docscrape as npds from ._analysis import PyImport, TypeMatcher -from ._doctype import BlacklistedQualname, Expression, Token, TokenKind, parse_doctype +from ._doctype import BlacklistedQualname, Expr, Term, TermKind, parse_doctype from ._report import ContextReporter, Stats from ._utils import escape_qualname logger: logging.Logger -def update_qualnames( - expr: Expression, *, _parents: tuple[Expression, ...] = ... -) -> Generator[tuple[tuple[Expression, ...], Token], str]: ... -def _find_one_token(tree: lark.Tree, *, name: str) -> lark.Token: ... +def _update_qualnames( + expr: Expr, *, _parents: tuple[Expr, ...] = ... +) -> Generator[tuple[tuple[Expr, ...], Term], str]: ... 
@dataclass(frozen=True, slots=True, kw_only=True) class Annotation: @@ -52,6 +51,13 @@ def _uncombine_numpydoc_params( params: list[npds.Parameter], ) -> Generator[npds.Parameter]: ... def _red_partial_underline(doctype: str, *, start: int, stop: int) -> str: ... +def doctype_to_annotation( + doctype: str, + *, + matcher: TypeMatcher | None = ..., + reporter: ContextReporter | None = ..., + stats: Stats | None = ..., +) -> Annotation: ... class DocstringAnnotations: docstring: str @@ -62,13 +68,10 @@ class DocstringAnnotations: self, docstring: str, *, - matcher: TypeMatcher, + matcher: TypeMatcher | None = ..., reporter: ContextReporter | None = ..., stats: Stats | None = ..., ) -> None: ... - def _doctype_to_annotation( - self, doctype: str, ds_line: int = ... - ) -> Annotation: ... @cached_property def attributes(self) -> dict[str, Annotation]: ... @cached_property diff --git a/src/docstub-stubs/_doctype.pyi b/src/docstub-stubs/_doctype.pyi new file mode 100644 index 0000000..6de9ece --- /dev/null +++ b/src/docstub-stubs/_doctype.pyi @@ -0,0 +1,88 @@ +# File generated with docstub + +import enum +import keyword +import logging +from collections.abc import Generator, Iterable, Sequence +from dataclasses import dataclass +from pathlib import Path +from textwrap import indent +from typing import Any, Final + +import lark +import lark.visitors +from _typeshed import Incomplete + +from ._utils import DocstubError + +logger: Final + +grammar_path: Final + +with grammar_path.open() as file: + _grammar: Final + +_lark: Final + +def flatten_recursive(iterable: Iterable[Iterable | str]) -> Generator[str]: ... +def insert_between(iterable: Iterable, *, sep: Any) -> list[Any]: ... + +class TermKind(enum.StrEnum): + + NAME = enum.auto() + LITERAL = enum.auto() + SYNTAX = enum.auto() + +class Term(str): + + __slots__: Final + + def __new__( + cls, value: str, *, kind: TermKind | str, pos: tuple[int, int] | None = ... + ) -> None: ... + def __repr__(self) -> str: ... 
+ def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ... + +@dataclass(slots=True) +class Expr: + + rule: str + children: list[Expr | Term] + + @property + def terms(self) -> list[Term]: ... + @property + def names(self) -> list[Term]: ... + def __iter__(self) -> Generator[Expr | Term]: ... + def format_tree(self) -> str: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def as_code(self) -> str: ... + +BLACKLISTED_QUALNAMES: Final + +class BlacklistedQualname(DocstubError): + pass + +class DoctypeTransformer(lark.visitors.Transformer): + def start(self, tree: lark.Tree) -> Expr: ... + def qualname(self, tree: lark.Tree) -> Term: ... + def ELLIPSES(self, token: lark.Token) -> Term: ... + def union(self, tree: lark.Tree) -> Expr: ... + def subscription(self, tree: lark.Tree) -> Expr: ... + def natlang_literal(self, tree: lark.Tree) -> Expr: ... + def literal_item(self, tree: lark.Tree) -> Term: ... + def natlang_container(self, tree: lark.Tree) -> Expr: ... + def natlang_array(self, tree: lark.Tree) -> Expr: ... + def array_name(self, tree: lark.Tree) -> Term: ... + def dtype(self, tree: lark.Tree) -> Expr: ... + def shape(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... + def optional_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... + def extra_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... + def _format_subscription( + self, sequence: Sequence[str], rule: str = ... + ) -> Expr: ... + +_transformer: Final + +def parse_doctype(doctype: str) -> Expr: ... diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 341e76e..4a1125d 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -14,29 +14,29 @@ # It should be possible to transform docstrings without matching to valid # types and imports. I think that could very well be done at a higher level, # e.g. in the stubs module. 
-from ._analysis import PyImport +from ._analysis import PyImport, TypeMatcher from ._report import ContextReporter, Stats from ._utils import escape_qualname -from ._doctype import parse_doctype, Expression, Token, TokenKind, BlacklistedQualname +from ._doctype import parse_doctype, Expr, Term, TermKind, BlacklistedQualname logger: logging.Logger = logging.getLogger(__name__) -def update_qualnames(expr, *, _parents=tuple()): +def _update_qualnames(expr, *, _parents=tuple()): """Yield and receive names in `expr`. This generator works as a coroutine. Parameters ---------- - expr : Expression - _parents : tuple of (Expression, ...) + expr : Expr + _parents : tuple of (~._doctype.Expr, ...) Yields ------ - parents : tuple of (Expression, ...) - name_token : Token + parents : tuple of (~._doctype.Expr, ...) + name : ~._doctype.Term Receives -------- @@ -46,7 +46,7 @@ def update_qualnames(expr, *, _parents=tuple()): -------- >>> from docstub._doctype import parse_doctype >>> expr = parse_doctype("tuple of (tuple or str, ...)") - >>> updater = update_qualnames(expr) + >>> updater = _update_qualnames(expr) >>> for parents, name in updater: ... if name == "tuple" and parents[-1].rule == "union": ... updater.send("list") @@ -60,42 +60,18 @@ def update_qualnames(expr, *, _parents=tuple()): for i, child in enumerate(children): if hasattr(child, "children"): - yield from update_qualnames(child, _parents=_parents) + yield from _update_qualnames(child, _parents=_parents) - elif child.kind == TokenKind.NAME: + elif child.kind == TermKind.NAME: new_name = yield _parents, child if new_name is not None: - new_token = Token(new_name, kind=child.kind) - expr.children[i] = new_token + new_term = Term(new_name, kind=child.kind) + expr.children[i] = new_term # `send` was called, yield `None` to return from `send`, # otherwise send would return the next child yield -def _find_one_token(tree, *, name): - """Find token with a specific type name in tree. 
- - Parameters - ---------- - tree : lark.Tree - name : str - Name of the token to find in the children of `tree`. - - Returns - ------- - token : lark.Token - """ - tokens = [ - child - for child in tree.children - if hasattr(child, "type") and child.type == name - ] - if len(tokens) != 1: - msg = f"expected exactly one Token of type {name}, found {len(tokens)}" - raise ValueError(msg) - return tokens[0] - - @dataclass(frozen=True, slots=True, kw_only=True) class Annotation: """Python-ready type annotation with attached import information.""" @@ -270,6 +246,94 @@ def _red_partial_underline(doctype, *, start, stop): return underlined +def doctype_to_annotation(doctype, *, matcher=None, reporter=None, stats=None): + """Convert a type description to a Python-ready type. + + Parameters + ---------- + doctype : str + matcher : ~.TypeMatcher, optional + reporter : ~.ContextReporter, optional + stats : ~.Stats, optional + + Returns + ------- + annotation : Annotation + The transformed type, ready to be inserted into a stub file, with + necessary imports attached. 
+ """ + matcher = matcher or TypeMatcher() + reporter = reporter or ContextReporter(logger=logger) + stats = Stats() if stats is None else stats + + try: + expression = parse_doctype(doctype) + stats.inc_counter("transformed_doctypes") + reporter.debug( + "Transformed doctype", details=(" %s\n-> %s", doctype, expression) + ) + + imports = set() + unknown_qualnames = set() + updater = _update_qualnames(expression) + for _, name in updater: + search_name = str(name) + matched_name, py_import = matcher.match(search_name) + if matched_name is None: + assert py_import is None + unknown_qualnames.add((search_name, *name.pos)) + matched_name = escape_qualname(search_name) + _ = updater.send(matched_name) + assert _ is None + + if py_import is None: + incomplete_alias = PyImport( + from_="_typeshed", + import_="Incomplete", + as_=matched_name, + ) + imports.add(incomplete_alias) + elif py_import.has_import: + imports.add(py_import) + + annotation = Annotation(value=str(expression), imports=frozenset(imports)) + + except ( + lark.exceptions.LexError, + lark.exceptions.ParseError, + ) as error: + details = None + if hasattr(error, "get_context"): + details = error.get_context(doctype) + details = details.replace("^", click.style("^", fg="red", bold=True)) + stats.inc_counter("doctype_syntax_errors") + reporter.error("Invalid syntax in docstring type annotation", details=details) + return FallbackAnnotation + + except lark.visitors.VisitError as error: + original_error = error.orig_exc + if isinstance(original_error, BlacklistedQualname): + msg = "Blacklisted keyword argument in doctype" + details = _red_partial_underline( + doctype, + start=error.obj.meta.start_pos, + stop=error.obj.meta.end_pos, + ) + else: + msg = "Unexpected error while parsing doctype" + tb = traceback.format_exception(original_error) + tb = "\n".join(tb) + details = f"doctype: {doctype!r}\n\n{tb}" + reporter.error(msg, details=details) + return FallbackAnnotation + + else: + for name, start_col, 
stop_col in unknown_qualnames: + details = _red_partial_underline(doctype, start=start_col, stop=stop_col) + reporter.error(f"Unknown name in doctype: {name!r}", details=details) + return annotation + + class DocstringAnnotations: """Collect annotations in a given docstring. @@ -294,114 +358,24 @@ class DocstringAnnotations: dict_keys(['a', 'b', 'c']) """ - def __init__(self, docstring, *, matcher, reporter=None, stats=None): + def __init__(self, docstring, *, matcher=None, reporter=None, stats=None): """ Parameters ---------- docstring : str - matcher : ~.TypeMatcher + matcher : ~.TypeMatcher, optional reporter : ~.ContextReporter, optional stats : ~.Stats, optional """ self.docstring = docstring self.np_docstring = npds.NumpyDocString(docstring) - self.matcher = matcher + self.matcher = matcher or TypeMatcher() self.stats = Stats() if stats is None else stats if reporter is None: reporter = ContextReporter(logger=logger, line=0) self.reporter = reporter.copy_with(logger=logger) - def _doctype_to_annotation(self, doctype, ds_line=0): - """Convert a type description to a Python-ready type. - - Parameters - ---------- - doctype : str - The type description of a parameter or return value, as extracted from - a docstring. - ds_line : int, optional - The line number relative to the docstring. - - Returns - ------- - annotation : Annotation - The transformed type, ready to be inserted into a stub file, with - necessary imports attached. 
- """ - reporter = self.reporter.copy_with(line_offset=ds_line) - - try: - expression = parse_doctype(doctype) - self.stats.inc_counter("transformed_doctypes") - reporter.debug( - "Transformed doctype", details=(" %s\n-> %s", doctype, expression) - ) - - imports = set() - unknown_qualnames = set() - updater = update_qualnames(expression) - for _, token in updater: - search_name = str(token) - matched_name, py_import = self.matcher.match(search_name) - if matched_name is None: - assert py_import is None - unknown_qualnames.add((search_name, *token.pos)) - matched_name = escape_qualname(search_name) - _ = updater.send(matched_name) - assert _ is None - - if py_import is None: - incomplete_alias = PyImport( - from_="_typeshed", - import_="Incomplete", - as_=matched_name, - ) - imports.add(incomplete_alias) - elif py_import.has_import: - imports.add(py_import) - - annotation = Annotation(value=str(expression), imports=frozenset(imports)) - - except ( - lark.exceptions.LexError, - lark.exceptions.ParseError, - ) as error: - details = None - if hasattr(error, "get_context"): - details = error.get_context(doctype) - details = details.replace("^", click.style("^", fg="red", bold=True)) - self.stats.inc_counter("doctype_syntax_errors") - reporter.error( - "Invalid syntax in docstring type annotation", details=details - ) - return FallbackAnnotation - - except lark.visitors.VisitError as error: - original_error = error.orig_exc - if isinstance(original_error, BlacklistedQualname): - msg = "Blacklisted keyword argument in doctype" - details = _red_partial_underline( - doctype, - start=error.obj.meta.start_pos, - stop=error.obj.meta.end_pos, - ) - else: - msg = "Unexpected error while parsing doctype" - tb = traceback.format_exception(original_error) - tb = "\n".join(tb) - details = f"doctype: {doctype!r}\n\n{tb}" - reporter.error(msg, details=details) - return FallbackAnnotation - - else: - for name, start_col, stop_col in unknown_qualnames: - details = 
_red_partial_underline( - doctype, start=start_col, stop=stop_col - ) - reporter.error(f"Unknown name in doctype: {name!r}", details=details) - return annotation - @cached_property def attributes(self): """Return the attributes found in the docstring. @@ -569,7 +543,13 @@ def _section_annotations(self, name): continue ds_line = self._find_docstring_line(param.name, param.type) - annotation = self._doctype_to_annotation(param.type, ds_line=ds_line) + + annotation = doctype_to_annotation( + doctype=param.type, + matcher=self.matcher, + reporter=self.reporter.copy_with(line_offset=ds_line), + stats=self.stats, + ) annotated_params[param.name.strip()] = annotation return annotated_params diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index f29140b..5d734ae 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -23,7 +23,9 @@ with grammar_path.open() as file: _grammar: Final = file.read() -_lark: Final = lark.Lark(_grammar, propagate_positions=True, strict=True) +# TODO try passing `transformer=DoctypeTransformer()`, may be faster [1] +# [1] https://lark-parser.readthedocs.io/en/latest/classes.html#:~:text=after%20the%20parse%2C-,but%20faster,-) +_lark: Final = lark.Lark(_grammar, propagate_positions=True, parser="lalr") def flatten_recursive(iterable): @@ -65,7 +67,7 @@ def insert_between(iterable, *, sep): return out[:-1] -class TokenKind(enum.StrEnum): +class TermKind(enum.StrEnum): # docstub: off NAME = enum.auto() LITERAL = enum.auto() @@ -73,8 +75,8 @@ class TokenKind(enum.StrEnum): # docstub: on -class Token(str): - """A token representing an atomic part of a doctype. +class Term(str): + """A terminal / symbol representing an atomic part of a doctype. 
Attributes ---------- @@ -88,44 +90,79 @@ def __new__(cls, value, *, kind, pos=None): Parameters ---------- value : str - kind : TokenKind or str + kind : TermKind or str pos : tuple of (int, int), optional """ self = super().__new__(cls, value) - self.kind = TokenKind(kind) + self.kind = TermKind(kind) self.pos = pos return self - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}('{self}', kind='{self.kind}')" def __getnewargs_ex__(self): - """""" + """ + Returns + ------- + args : tuple of (Any, ...) + kwargs : dict of {str: Any} + """ kwargs = {"value": str(self), "kind": self.kind, "pos": self.pos} return tuple(), kwargs @dataclass(slots=True) -class Expression: - """A named expression made up of sub expressions and tokens.""" +class Expr: + """An expression that forms or is part of a doctype. + + Parameters + ---------- + rule : + The name of the (grammar) rule corresponding to this expression. + children : list of (Expr or Term) + Sub-expressions or terms that make up this expression. + """ rule: str - children: list[Expression | Token] + children: list[Expr | Term] @property - def tokens(self): - """All tokens in the expression.""" + def terms(self): + """All terms in the expression. + + Returns + ------- + terms : list of Term + """ return list(flatten_recursive(self)) @property def names(self): - """Name tokens in the expression.""" - return [token for token in self.tokens if token.kind == TokenKind.NAME] + """Name terms in the expression. + + Returns + ------- + names : list of Term + """ + return [term for term in self.terms if term.kind == TermKind.NAME] def __iter__(self): + """Iterate over children of this expression. + + Yields + ------ + child : Expr or Term + """ yield from self.children def format_tree(self): + """Format full hierarchy as a tree. 
+ + Returns + ------- + formatted : str + """ formatted_children = ( c.format_tree() if hasattr(c, "format_tree") else repr(c) for c in self.children @@ -136,13 +173,13 @@ def format_tree(self): f"{type(self).__name__}({self.rule!r}, children=[\n{formatted_children}])" ) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__}: '{self.as_code()}' rule='{self.rule}'>" - def __str__(self): - return "".join(self.tokens) + def __str__(self) -> str: + return "".join(self.terms) - def as_code(self): + def as_code(self) -> str: return str(self) @@ -155,7 +192,6 @@ class BlacklistedQualname(DocstubError): @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): - def start(self, tree): """ Parameters @@ -164,9 +200,9 @@ def start(self, tree): Returns ------- - out : Expression + out : Expr """ - return Expression(rule="start", children=tree.children) + return Expr(rule="start", children=tree.children) def qualname(self, tree): """ @@ -176,7 +212,7 @@ def qualname(self, tree): Returns ------- - out : Token + out : Term """ children = tree.children _qualname = ".".join(children) @@ -184,9 +220,9 @@ def qualname(self, tree): if _qualname in BLACKLISTED_QUALNAMES: raise BlacklistedQualname(_qualname) - _qualname = Token( + _qualname = Term( _qualname, - kind=TokenKind.NAME, + kind=TermKind.NAME, pos=(tree.meta.start_pos, tree.meta.end_pos), ) return _qualname @@ -199,9 +235,9 @@ def ELLIPSES(self, token): Returns ------- - out : Token + out : Term """ - return Token(token, kind=TokenKind.LITERAL) + return Term(token, kind=TermKind.LITERAL) def union(self, tree): """ @@ -211,10 +247,10 @@ def union(self, tree): Returns ------- - out : Expression + out : Expr """ - sep = Token(" | ", kind=TokenKind.SYNTAX) - expr = Expression(rule="union", children=insert_between(tree.children, sep=sep)) + sep = Term(" | ", kind=TermKind.SYNTAX) + expr = Expr(rule="union", children=insert_between(tree.children, sep=sep)) return expr def 
subscription(self, tree): @@ -225,7 +261,7 @@ def subscription(self, tree): Returns ------- - out : Expression + out : Expr """ return self._format_subscription(tree.children) @@ -237,10 +273,10 @@ def natlang_literal(self, tree): Returns ------- - out : Expression + out : Expr """ items = [ - Token("Literal", kind=TokenKind.SYNTAX), + Term("Literal", kind=TermKind.SYNTAX), *tree.children, ] out = self._format_subscription(items, rule="natlang_literal") @@ -262,14 +298,14 @@ def literal_item(self, tree): Returns ------- - out : Token + out : Term """ item, *other = tree.children assert not other - kind = TokenKind.LITERAL - if isinstance(item, Token): + kind = TermKind.LITERAL + if isinstance(item, Term): kind = item.kind - return Token(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos)) + return Term(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos)) def natlang_container(self, tree): """ @@ -279,7 +315,7 @@ def natlang_container(self, tree): Returns ------- - out : Expression + out : Expr """ return self._format_subscription(tree.children, rule="natlang_container") @@ -291,7 +327,7 @@ def natlang_array(self, tree): Returns ------- - out : Expression + out : Expr """ return self._format_subscription(tree.children, rule="natlang_array") @@ -303,7 +339,7 @@ def array_name(self, tree): Returns ------- - out : Token + out : Term """ # This currently relies on a hack that only allows specific names # in `array_expression` (see `ARRAY_NAME` terminal in gramar) @@ -318,9 +354,9 @@ def dtype(self, tree): Returns ------- - out : Expression + out : Expr """ - return Expression(rule="dtype", children=tree.children) + return Expr(rule="dtype", children=tree.children) def shape(self, tree): """ @@ -370,24 +406,27 @@ def _format_subscription(self, sequence, rule="subscription"): Returns ------- - out : Expression + out : Expr """ - sep = Token(", ", kind=TokenKind.SYNTAX) + sep = Term(", ", kind=TermKind.SYNTAX) container, *content = sequence content = 
insert_between(content, sep=sep) assert content - expr = Expression( + expr = Expr( rule=rule, children=[ container, - Token("[", kind=TokenKind.SYNTAX), + Term("[", kind=TermKind.SYNTAX), *content, - Token("]", kind=TokenKind.SYNTAX), + Term("]", kind=TermKind.SYNTAX), ], ) return expr +_transformer: Final = DoctypeTransformer() + + def parse_doctype(doctype): """Turn a type description in a docstring into a type annotation. @@ -398,7 +437,7 @@ def parse_doctype(doctype): Returns ------- - parsed : Expression + parsed : Expr Raises ------ @@ -414,5 +453,5 @@ def parse_doctype(doctype): """ tree = _lark.parse(doctype) - expression = DoctypeTransformer().transform(tree=tree) + expression = _transformer.transform(tree=tree) return expression diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 1b74b76..50f094e 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -12,12 +12,12 @@ start: _annotation_with_meta // The basic structure of a full docstring annotation as it comes after the // `name : `. It includes additional meta information that is optional and // currently ignored. -_annotation_with_meta: type ("," optional_info)? +_annotation_with_meta: _type ("," optional_info)? // A type annotation. Can range from a simple qualified name to a complex // nested construct of types. -?type: qualname +_type: qualname | union | subscription | natlang_literal @@ -36,7 +36,7 @@ qualname: (/~/ ".")? (NAME ".")* NAME // An union of different types, joined either by "or" or "|". -union: type (_OR type)+ +union: _type (_OR _type)+ // Operator used in unions. @@ -44,7 +44,7 @@ _OR: "or" | "|" // An expression where an object is subscribed with "A[v, ...]". -subscription: qualname "[" type ("," type)* ("," ELLIPSES)? "]" +subscription: qualname "[" _type ("," _type)* ("," ELLIPSES)? 
"]" // Allow Python's ellipses object @@ -65,7 +65,7 @@ literal_item: ELLIPSES | STRING | SIGNED_NUMBER | qualname // These forms allow nesting with other expressions. But it's discouraged to do // so extensively to maintain readability. natlang_container: qualname "of" qualname _PLURAL_S? - | qualname "of" "(" type ")" + | qualname "of" "(" _type ")" | _natlang_tuple | _natlang_mapping @@ -78,12 +78,12 @@ _PLURAL_S: /(? str: - return name.replace("-", "_").replace(".", "_") - - doctype = fmt.format(name=name, dtype=dtype, shape=shape) - expected = expected_fmt.format( - name=escape(name), dtype=escape(dtype), shape=shape - ) - - transformer = DoctypeTransformer() - annotation, _ = transformer.doctype_to_annotation(doctype) - - assert annotation.value == expected - # fmt: on - - @pytest.mark.parametrize( - ("doctype", "expected"), - [ - ("ndarray of dtype (int or float)", "ndarray[int | float]"), - ], - ) - def test_natlang_array_specific(self, doctype, expected): - transformer = DoctypeTransformer() - annotation, _ = transformer.doctype_to_annotation(doctype) - assert annotation.value == expected - - @pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"]) - def test_natlang_array_invalid_shape(self, shape): - doctype = f"array of shape {shape}" - transformer = DoctypeTransformer() - with pytest.raises(lark.exceptions.UnexpectedInput): - _ = transformer.doctype_to_annotation(doctype) - - def test_unknown_name(self): - # Simple unknown name is aliased to typing.Any - transformer = DoctypeTransformer() - annotation, unknown_names = transformer.doctype_to_annotation("a") - assert annotation.value == "a" - assert annotation.imports == { - PyImport(import_="Incomplete", from_="_typeshed", as_="a") - } - assert unknown_names == [("a", 0, 1)] - - def test_unknown_qualname(self): - # Unknown qualified name is escaped and aliased to typing.Any as well - transformer = DoctypeTransformer() - annotation, unknown_names = 
transformer.doctype_to_annotation("a.b") - assert annotation.value == "a_b" - assert annotation.imports == { - PyImport(import_="Incomplete", from_="_typeshed", as_="a_b") - } - assert unknown_names == [("a.b", 0, 3)] - - def test_multiple_unknown_names(self): - # Multiple names are aliased to typing.Any - transformer = DoctypeTransformer() - annotation, unknown_names = transformer.doctype_to_annotation("a.b of c") - assert annotation.value == "a_b[c]" - assert annotation.imports == { - PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"), - PyImport(import_="Incomplete", from_="_typeshed", as_="c"), - } - assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)] - - class Test_DocstringAnnotations: def test_empty_docstring(self): docstring = dedent("""No sections in this docstring.""") - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.attributes == {} assert annotations.parameters == {} assert annotations.returns is None @@ -346,8 +60,7 @@ def test_parameters(self, doctype, expected): b : """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert len(annotations.parameters) == 1 assert annotations.parameters["a"].value == expected assert "b" not in annotations.parameters @@ -368,8 +81,7 @@ def test_returns(self, doctypes, expected): b : {} """ ).format(*doctypes) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.returns is not None assert annotations.returns.value == expected @@ -382,8 +94,7 @@ def test_yields(self, caplog): b : str """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert 
annotations.returns is not None assert annotations.returns.value == "Generator[tuple[int, str]]" assert annotations.returns.imports == { @@ -404,8 +115,7 @@ def test_receives(self, caplog): d : bytes """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.returns is not None assert ( annotations.returns.value @@ -433,8 +143,7 @@ def test_full_generator(self, caplog): e : bool """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.returns is not None assert annotations.returns.value == ( "Generator[tuple[int, str], tuple[float, bytes], bool]" @@ -456,8 +165,7 @@ def test_yields_and_returns(self, caplog): e : bool """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.returns is not None assert annotations.returns.value == ("Generator[tuple[int, str], None, bool]") assert annotations.returns.imports == { @@ -473,8 +181,7 @@ def test_duplicate_parameters(self, caplog): a : str """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert len(annotations.parameters) == 1 assert annotations.parameters["a"].value == "int" @@ -487,8 +194,7 @@ def test_duplicate_returns(self, caplog): a : str """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.returns is not None assert annotations.returns is not None assert annotations.returns.value == "int" @@ -502,8 +208,7 @@ def test_args_kwargs(self): **kwargs : str """ ) - transformer = DoctypeTransformer() - annotations = 
DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert "args" in annotations.parameters assert "*args" not in annotations.parameters assert "kwargs" in annotations.parameters @@ -521,8 +226,7 @@ def test_missing_whitespace(self, caplog): a: int """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert annotations.parameters["a"].value == "int" assert len(caplog.records) == 1 assert "Possibly missing whitespace" in caplog.text @@ -536,8 +240,7 @@ def test_combined_numpydoc_params(self): d, e : """ ) - transformer = DoctypeTransformer() - annotations = DocstringAnnotations(docstring, transformer=transformer) + annotations = DocstringAnnotations(docstring) assert len(annotations.parameters) == 3 assert annotations.parameters["a"].value == "bool" assert annotations.parameters["b"].value == "bool" diff --git a/tests/test_doctype.py b/tests/test_doctype.py new file mode 100644 index 0000000..5a7ff44 --- /dev/null +++ b/tests/test_doctype.py @@ -0,0 +1,270 @@ +import logging + +import pytest +import lark +import lark.exceptions + +from docstub._doctype import parse_doctype, Term, TermKind, BLACKLISTED_QUALNAMES + + +class Test_parse_doctype: + @pytest.mark.parametrize( + "doctype", + [ + "((float))", + "(float,)", + "(, )", + "...", + "(..., ...)", + "{}", + "{:}", + "{a:}", + "{:b}", + "{'a',}", + "a or (b or c)", + ",, optional", + ], + ) + def test_edge_case_errors(self, doctype): + with pytest.raises(lark.exceptions.UnexpectedInput): + parse_doctype(doctype) + + @pytest.mark.parametrize("doctype", BLACKLISTED_QUALNAMES) + def test_reserved_keywords(self, doctype): + with pytest.raises(lark.exceptions.VisitError): + parse_doctype(doctype) + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("int or float", "int | float"), + ("int or float or str", "int | float | str"), + ], + ) + def 
test_natlang_union(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + # Conventional + ("list[float]", "list[float]"), + ("dict[str, Union[int, str]]", "dict[str, Union[int, str]]"), + ("tuple[int, ...]", "tuple[int, ...]"), + ("Sequence[int | float]", "Sequence[int | float]"), + # Natural language variant with "of" and optional plural "(s)" + ("list of int", "list[int]"), + ("list of int(s)", "list[int]"), + # Natural tuple variant + ("tuple of (float, int, str)", "tuple[float, int, str]"), + ("tuple of (float, ...)", "tuple[float, ...]"), + # Natural dict variant + ("dict of {str: int}", "dict[str, int]"), + ("dict of {str: int | float}", "dict[str, int | float]"), + ("dict of {str: int or float}", "dict[str, int | float]"), + ("dict[list of str]", "dict[list[str]]"), + ], + ) + def test_subscription(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + # Natural language variant with "of" and optional plural "(s)" + ("list of int", "list[int]"), + ("list of int(s)", "list[int]"), + ("list of (int or float)", "list[int | float]"), + ("list of (list of int)", "list[list[int]]"), + # Natural tuple variant + ("tuple of (float, int, str)", "tuple[float, int, str]"), + ("tuple of (float, ...)", "tuple[float, ...]"), + # Natural dict variant + ("dict of {str: int}", "dict[str, int]"), + ("dict of {str: int | float}", "dict[str, int | float]"), + ("dict of {str: int or float}", "dict[str, int | float]"), + # Nesting is possible but probably rarely a good idea + ("list of (list of int(s))", "list[list[int]]"), + ("tuple of (tuple of (float, ...), ...)", "tuple[tuple[float, ...], ...]"), + ("dict of {str: dict of {str: float}}", "dict[str, dict[str, float]]"), + ("dict of {str: list of (list of int(s))}", "dict[str, list[list[int]]]"), + ], + ) + def 
test_natlang_container(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + @pytest.mark.parametrize( + "doctype", + [ + "list of int (s)", + "list of ((float))", + "list of (float,)", + "list of (, )", + "list of ...", + "list of (..., ...)", + "dict of {}", + "dict of {:}", + "dict of {a:}", + "dict of {:b}", + ], + ) + def test_subscription_error(self, doctype): + with pytest.raises(lark.exceptions.UnexpectedInput): + parse_doctype(doctype) + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("{0}", "Literal[0]"), + ("{-1, 1}", "Literal[-1, 1]"), + ("{None}", "Literal[None]"), + ("{True, False}", "Literal[True, False]"), + ("""{'a', "bar"}""", """Literal['a', "bar"]"""), + # Enum + ("{SomeEnum.FIRST}", "Literal[SomeEnum_FIRST]"), + ("{`SomeEnum.FIRST`, 1}", "Literal[SomeEnum_FIRST, 1]"), + ("{:ref:`SomeEnum.FIRST`, 2}", "Literal[SomeEnum_FIRST, 2]"), + ("{:py:ref:`SomeEnum.FIRST`, 3}", "Literal[SomeEnum_FIRST, 3]"), + # Nesting + ("dict[{'a', 'b'}, int]", "dict[Literal['a', 'b'], int]"), + # These aren't officially valid as an argument to `Literal` (yet) + # https://typing.python.org/en/latest/spec/literal.html + # TODO figure out how docstub should deal with these + ("{-2., 1.}", "Literal[-2., 1.]"), + pytest.param( + "{-inf, inf, nan}", + "Literal[, 1.]", + marks=pytest.mark.xfail(reason="unsure how to support"), + ), + ], + ) + def test_literals(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + def test_single_natlang_literal_warning(self, caplog): + expr = parse_doctype("{True}") + assert expr.as_code() == "Literal[True]" + assert caplog.messages == ["Natural language literal with one item: `{True}`"] + assert caplog.records[0].levelno == logging.WARNING + assert ( + caplog.records[0].details + == "Consider using `Literal[True]` to improve readability" + ) + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("int", "int"), + ("int | None", "int 
| None"), + ("tuple of (int, float)", "tuple[int, float]"), + ("{'a', 'b'}", "Literal['a', 'b']"), + ], + ) + @pytest.mark.parametrize( + "optional_info", + [ + "", + ", optional", + ", default -1", + ", default: -1", + ", default = 1", + ", in range (0, 1), optional", + ", optional, in range [0, 1]", + ", see parameter `image`, optional", + ], + ) + def test_optional_info(self, doctype, expected, optional_info): + doctype_with_optional = doctype + optional_info + expr = parse_doctype(doctype_with_optional) + assert expr.as_code() == expected + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("`Generator`", "Generator"), + (":class:`Generator`", "Generator"), + (":py:class:`Generator`", "Generator"), + (":py:class:`Generator`[int]", "Generator[int]"), + (":py:ref:`~.Foo`[int]", "_Foo[int]"), + ("list[:py:class:`Generator`]", "list[Generator]"), + ], + ) + def test_rst_role(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + # fmt: off + @pytest.mark.parametrize( + ("fmt", "expected_fmt"), + [ + ("{name} of shape {shape} and dtype {dtype}", "{name}[{dtype}]"), + ("{name} of dtype {dtype} and shape {shape}", "{name}[{dtype}]"), + ("{name} of {dtype}", "{name}[{dtype}]"), + ], + ) + @pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"]) + @pytest.mark.parametrize("dtype", ["int", "np.int8"]) + @pytest.mark.parametrize("shape", + ["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)", "([P,] M, N)"] + ) + def test_natlang_array(self, fmt, expected_fmt, name, dtype, shape): + + def escape(name: str) -> str: + return name.replace("-", "_").replace(".", "_") + + doctype = fmt.format(name=name, dtype=dtype, shape=shape) + expected = expected_fmt.format( + name=escape(name), dtype=escape(dtype), shape=shape + ) + expr = parse_doctype(doctype) + assert expr.as_code() == expected + # fmt: on + + @pytest.mark.parametrize( + ("doctype", "expected"), + [ + ("ndarray of dtype (int or float)", 
"ndarray[int | float]"), + ], + ) + def test_natlang_array_specific(self, doctype, expected): + expr = parse_doctype(doctype) + assert expr.as_code() == expected + + @pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"]) + def test_natlang_array_invalid_shape(self, shape): + doctype = f"array of shape {shape}" + with pytest.raises(lark.exceptions.UnexpectedInput): + _ = parse_doctype(doctype) + + def test_unknown_name(self): + # Simple unknown name is aliased to typing.Any + annotation, unknown_names = parse_doctype("a") + assert annotation.value == "a" + assert annotation.imports == { + PyImport(import_="Incomplete", from_="_typeshed", as_="a") + } + assert unknown_names == [("a", 0, 1)] + + def test_unknown_qualname(self): + # Unknown qualified name is escaped and aliased to typing.Any as well + annotation, unknown_names = parse_doctype("a.b") + assert annotation.value == "a_b" + assert annotation.imports == { + PyImport(import_="Incomplete", from_="_typeshed", as_="a_b") + } + assert unknown_names == [("a.b", 0, 3)] + + def test_multiple_unknown_names(self): + # Multiple names are aliased to typing.Any + annotation, unknown_names = parse_doctype("a.b of c") + assert annotation.value == "a_b[c]" + assert annotation.imports == { + PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"), + PyImport(import_="Incomplete", from_="_typeshed", as_="c"), + } + assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)] + From a27715ab4a40f1bb30224b1478b73cc83c07e89c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Tue, 13 Jan 2026 00:09:08 +0100 Subject: [PATCH 07/17] WIP --- src/docstub/_doctype.py | 7 +++---- tests/test_doctype.py | 13 +++---------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 5d734ae..4884a53 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -1,8 +1,8 @@ """Parsing of doctypes.""" import enum -import logging import keyword 
+import logging from collections.abc import Iterable from dataclasses import dataclass from pathlib import Path @@ -14,7 +14,6 @@ from ._utils import DocstubError - logger: Final = logging.getLogger(__name__) @@ -25,7 +24,7 @@ # TODO try passing `transformer=DoctypeTransformer()`, may be faster [1] # [1] https://lark-parser.readthedocs.io/en/latest/classes.html#:~:text=after%20the%20parse%2C-,but%20faster,-) -_lark: Final = lark.Lark(_grammar, propagate_positions=True, parser="lalr") +_lark: Final = lark.Lark(_grammar, propagate_positions=True) def flatten_recursive(iterable): @@ -83,7 +82,7 @@ class Term(str): __slots__ : Final """ - __slots__ = ("value", "kind", "pos") + __slots__ = ("kind", "pos", "value") def __new__(cls, value, *, kind, pos=None): """ diff --git a/tests/test_doctype.py b/tests/test_doctype.py index 5a7ff44..f9d7dfc 100644 --- a/tests/test_doctype.py +++ b/tests/test_doctype.py @@ -1,10 +1,10 @@ import logging -import pytest import lark import lark.exceptions +import pytest -from docstub._doctype import parse_doctype, Term, TermKind, BLACKLISTED_QUALNAMES +from docstub._doctype import BLACKLISTED_QUALNAMES, parse_doctype class Test_parse_doctype: @@ -212,14 +212,8 @@ def test_rst_role(self, doctype, expected): ["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)", "([P,] M, N)"] ) def test_natlang_array(self, fmt, expected_fmt, name, dtype, shape): - - def escape(name: str) -> str: - return name.replace("-", "_").replace(".", "_") - doctype = fmt.format(name=name, dtype=dtype, shape=shape) - expected = expected_fmt.format( - name=escape(name), dtype=escape(dtype), shape=shape - ) + expected = expected_fmt.format(name=name, dtype=dtype, shape=shape) expr = parse_doctype(doctype) assert expr.as_code() == expected # fmt: on @@ -267,4 +261,3 @@ def test_multiple_unknown_names(self): PyImport(import_="Incomplete", from_="_typeshed", as_="c"), } assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)] - From 2557ddbae5cc901d2e2ceaec44fd6cf54567d0e2 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Fri, 16 Jan 2026 20:23:02 +0100 Subject: [PATCH 08/17] WIP --- src/docstub/_doctype.py | 53 +++++++++++++++++++++++++++++++++++----- src/docstub/doctype.lark | 8 ++++++ tests/test_doctype.py | 31 +++++++++++++++++++---- 3 files changed, 81 insertions(+), 11 deletions(-) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 17a2000..67b01f1 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import Final +from typing import Final, Self import lark import lark.visitors @@ -124,7 +124,7 @@ class Expr: """ rule: str - children: list[Expr | Term] + children: list[Self | Term] @property def terms(self): @@ -226,6 +226,29 @@ def qualname(self, tree): ) return _qualname + def qualname(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Term + """ + children = tree.children + _qualname = ".".join(children) + + if _qualname in BLACKLISTED_QUALNAMES: + raise BlacklistedQualname(_qualname) + + _qualname = Term( + _qualname, + kind=TermKind.NAME, + pos=(tree.meta.start_pos, tree.meta.end_pos), + ) + return _qualname + def ELLIPSES(self, token): """ Parameters @@ -262,7 +285,7 @@ def subscription(self, tree): ------- out : Expr """ - return self._format_subscription(tree.children) + return self._format_subscription(tree.children, rule="subscription") def param_spec(self, tree): """ @@ -295,6 +318,23 @@ def callable(self, tree): """ return self._format_subscription(tree.children, rule="callable") + def literal(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + items = [ + Term("Literal", kind=TermKind.NAME), + *tree.children, + ] + out = self._format_subscription(items, rule="literal") + return out + def natlang_literal(self, tree): """ Parameters @@ -306,7 +346,7 @@ def 
natlang_literal(self, tree): out : Expr """ items = [ - Term("Literal", kind=TermKind.SYNTAX), + Term("Literal", kind=TermKind.NAME), *tree.children, ] out = self._format_subscription(items, rule="natlang_literal") @@ -335,7 +375,8 @@ def literal_item(self, tree): kind = TermKind.LITERAL if isinstance(item, Term): kind = item.kind - return Term(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos)) + out = Term(item, kind=kind, pos=(tree.meta.start_pos, tree.meta.end_pos)) + return out def natlang_container(self, tree): """ @@ -427,7 +468,7 @@ def extra_info(self, tree): logger.debug("dropping extra info") return lark.Discard - def _format_subscription(self, sequence, rule="subscription"): + def _format_subscription(self, sequence, *, rule): """ Parameters ---------- diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index cbec2af..6e1e064 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -19,6 +19,7 @@ _annotation_with_meta: _type ("," optional_info)? // nested construct of types. _type: qualname | union + | literal | subscription | callable | natlang_literal @@ -68,6 +69,13 @@ ELLIPSES: "..." natlang_literal: "{" literal_item ("," literal_item)* "}" +// A literal expression as supported by Python proper. The rule "subscription" +// isn't allowed to contain "literal_items", so we need to define this. +// Assign a higher priority so that things like `Literal[Some.ENUM]` is marked +// as a literal expression. +literal.1: "Literal[" literal_item ("," literal_item)* "]" + + // An single item in a literal expression (or `optional`). We must also allow // for qualified names, since a "class" or enum can be used as a literal too. 
literal_item: ELLIPSES | STRING | SIGNED_NUMBER | qualname diff --git a/tests/test_doctype.py b/tests/test_doctype.py index 5a37e2f..660800e 100644 --- a/tests/test_doctype.py +++ b/tests/test_doctype.py @@ -115,6 +115,27 @@ def test_subscription_error(self, doctype): with pytest.raises(lark.exceptions.UnexpectedInput): parse_doctype(doctype) + @pytest.mark.parametrize( + ("doctype"), + [ + "Literal[0]", + "Literal[-1, 1]", + "Literal[None]", + "Literal[True, False]", + """Literal['a', "bar"]""", + # Enum + "Literal[SomeEnum.FIRST]", + "Literal[SomeEnum.FIRST, 1]", + "Literal[SomeEnum.FIRST, 2]", + "Literal[SomeEnum.FIRST, 3]", + # Nesting + "dict[Literal['a', 'b'], int]", + ], + ) + def test_literals(self, doctype): + expr = parse_doctype(doctype) + assert expr.as_code() == doctype + @pytest.mark.parametrize( ("doctype", "expected"), [ @@ -124,10 +145,10 @@ def test_subscription_error(self, doctype): ("{True, False}", "Literal[True, False]"), ("""{'a', "bar"}""", """Literal['a', "bar"]"""), # Enum - ("{SomeEnum.FIRST}", "Literal[SomeEnum_FIRST]"), - ("{`SomeEnum.FIRST`, 1}", "Literal[SomeEnum_FIRST, 1]"), - ("{:ref:`SomeEnum.FIRST`, 2}", "Literal[SomeEnum_FIRST, 2]"), - ("{:py:ref:`SomeEnum.FIRST`, 3}", "Literal[SomeEnum_FIRST, 3]"), + ("{SomeEnum.FIRST}", "Literal[SomeEnum.FIRST]"), + ("{`SomeEnum.FIRST`, 1}", "Literal[SomeEnum.FIRST, 1]"), + ("{:ref:`SomeEnum.FIRST`, 2}", "Literal[SomeEnum.FIRST, 2]"), + ("{:py:ref:`SomeEnum.FIRST`, 3}", "Literal[SomeEnum.FIRST, 3]"), # Nesting ("dict[{'a', 'b'}, int]", "dict[Literal['a', 'b'], int]"), # These aren't officially valid as an argument to `Literal` (yet) @@ -141,7 +162,7 @@ def test_subscription_error(self, doctype): ), ], ) - def test_literals(self, doctype, expected): + def test_natlang_literals(self, doctype, expected): expr = parse_doctype(doctype) assert expr.as_code() == expected From 8fb0b8419e5d3278ad84968a99ec6ce077e08720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 17 Jan 2026 
00:32:10 +0100 Subject: [PATCH 09/17] WIP --- src/docstub/_doctype.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 67b01f1..3dc44dc 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -172,6 +172,10 @@ def format_tree(self): f"{type(self).__name__}({self.rule!r}, children=[\n{formatted_children}])" ) + def print_tree(self): + """Print full hierarchy as a tree.""" + print(self.format_tree()) + def __repr__(self) -> str: return f"<{type(self).__name__}: '{self.as_code()}' rule='{self.rule}'>" From fee233acef6040f0c71e0bd72c5cdc92ad0857ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sat, 17 Jan 2026 15:12:05 +0100 Subject: [PATCH 10/17] WIP --- src/docstub-stubs/_app_generate_stubs.pyi | 1 + src/docstub-stubs/_cli.pyi | 1 + src/docstub-stubs/_docstrings.pyi | 1 + src/docstub-stubs/_doctype.pyi | 17 +++++++---- src/docstub-stubs/_report.pyi | 2 +- src/docstub/_doctype.py | 35 ++++++++++++++++++---- src/docstub/doctype.lark | 4 +-- tests/test_docstrings.py | 34 ++++++++++++++++++++- tests/test_doctype.py | 36 +++++------------------ 9 files changed, 88 insertions(+), 43 deletions(-) diff --git a/src/docstub-stubs/_app_generate_stubs.pyi b/src/docstub-stubs/_app_generate_stubs.pyi index cf30bd6..ae4c95b 100644 --- a/src/docstub-stubs/_app_generate_stubs.pyi +++ b/src/docstub-stubs/_app_generate_stubs.pyi @@ -6,6 +6,7 @@ from collections import Counter from collections.abc import Iterable, Sequence from contextlib import contextmanager from pathlib import Path +from typing import Literal from ._analysis import PyImport, TypeCollector, TypeMatcher, common_known_types from ._cache import CACHE_DIR_NAME, FileCache diff --git a/src/docstub-stubs/_cli.pyi b/src/docstub-stubs/_cli.pyi index 42e9e7f..d6e0988 100644 --- a/src/docstub-stubs/_cli.pyi +++ b/src/docstub-stubs/_cli.pyi @@ -4,6 +4,7 @@ import logging import sys from collections.abc import Callable, Sequence from 
pathlib import Path +from typing import Literal import click from _typeshed import Incomplete diff --git a/src/docstub-stubs/_docstrings.pyi b/src/docstub-stubs/_docstrings.pyi index a5805f0..f1a9c57 100644 --- a/src/docstub-stubs/_docstrings.pyi +++ b/src/docstub-stubs/_docstrings.pyi @@ -10,6 +10,7 @@ import click import lark import lark.visitors import numpydoc.docscrape as npds +from _typeshed import Incomplete as Expr from ._analysis import PyImport, TypeMatcher from ._doctype import BlacklistedQualname, Expr, Term, TermKind, parse_doctype diff --git a/src/docstub-stubs/_doctype.pyi b/src/docstub-stubs/_doctype.pyi index 6de9ece..a9ce3bc 100644 --- a/src/docstub-stubs/_doctype.pyi +++ b/src/docstub-stubs/_doctype.pyi @@ -7,11 +7,12 @@ from collections.abc import Generator, Iterable, Sequence from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import Any, Final +from typing import Any, Final, Literal, Self import lark import lark.visitors from _typeshed import Incomplete +from _typeshed import Incomplete as Expression from ._utils import DocstubError @@ -47,14 +48,17 @@ class Term(str): class Expr: rule: str - children: list[Expr | Term] + children: list[Self | Term] @property def terms(self) -> list[Term]: ... @property def names(self) -> list[Term]: ... + @property + def sub_expressions(self) -> list[Expression] | Literal[1]: ... def __iter__(self) -> Generator[Expr | Term]: ... def format_tree(self) -> str: ... + def print_tree(self) -> None: ... def __repr__(self) -> str: ... def __str__(self) -> str: ... def as_code(self) -> str: ... @@ -67,9 +71,14 @@ class BlacklistedQualname(DocstubError): class DoctypeTransformer(lark.visitors.Transformer): def start(self, tree: lark.Tree) -> Expr: ... def qualname(self, tree: lark.Tree) -> Term: ... + def qualname(self, tree: lark.Tree) -> Term: ... + def rst_role(self, tree: lark.Tree) -> Expr: ... def ELLIPSES(self, token: lark.Token) -> Term: ... 
def union(self, tree: lark.Tree) -> Expr: ... def subscription(self, tree: lark.Tree) -> Expr: ... + def param_spec(self, tree: lark.Tree) -> Expr: ... + def callable(self, tree: lark.Tree) -> Expr: ... + def literal(self, tree: lark.Tree) -> Expr: ... def natlang_literal(self, tree: lark.Tree) -> Expr: ... def literal_item(self, tree: lark.Tree) -> Term: ... def natlang_container(self, tree: lark.Tree) -> Expr: ... @@ -79,9 +88,7 @@ class DoctypeTransformer(lark.visitors.Transformer): def shape(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... def optional_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... def extra_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... - def _format_subscription( - self, sequence: Sequence[str], rule: str = ... - ) -> Expr: ... + def _format_subscription(self, sequence: Sequence[str], *, rule: str) -> Expr: ... _transformer: Final diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index c480860..bf574b9 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -5,7 +5,7 @@ import logging from collections.abc import Hashable, Iterator, Mapping, Sequence from pathlib import Path from textwrap import indent -from typing import Any, ClassVar, Self, TextIO +from typing import Any, ClassVar, Literal, Self, TextIO import click from pre_commit.envcontext import UNSET diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 3dc44dc..4b52e60 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -146,6 +146,20 @@ def names(self): """ return [term for term in self.terms if term.kind == TermKind.NAME] + @property + def sub_expressions(self): + """Iterate expressions inside the current one. + + Returns + ------- + names : list of Expr or {1} + """ + cls = type(self) + for child in self.children: + if isinstance(child, cls): + yield child + yield from child.sub_expressions + def __iter__(self): """Iterate over children of this expression. 
@@ -253,6 +267,21 @@ def qualname(self, tree): ) return _qualname + def rst_role(self, tree): + """ + Parameters + ---------- + tree : lark.Tree + + Returns + ------- + out : Expr + """ + # Drop rst_prefix + children = [c for c in tree.children if isinstance(c, Term)] + expr = Expr(rule="rst_role", children=children) + return expr + def ELLIPSES(self, token): """ Parameters @@ -332,11 +361,7 @@ def literal(self, tree): ------- out : Expr """ - items = [ - Term("Literal", kind=TermKind.NAME), - *tree.children, - ] - out = self._format_subscription(items, rule="literal") + out = self._format_subscription(tree.children, rule="literal") return out def natlang_literal(self, tree): diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 6e1e064..30e7acb 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -34,7 +34,7 @@ _type: qualname // [1] https://docutils.sourceforge.io/docs/ref/rst/roles.html // qualname: (/~/ ".")? (NAME ".")* NAME - | (":" (NAME ":")? NAME ":")? "`" qualname "`" + | (":" (NAME ":")? NAME ":")? "`" qualname "`" -> rst_role // An union of different types, joined either by "or" or "|". @@ -73,7 +73,7 @@ natlang_literal: "{" literal_item ("," literal_item)* "}" // isn't allowed to contain "literal_items", so we need to define this. // Assign a higher priority so that things like `Literal[Some.ENUM]` is marked // as a literal expression. -literal.1: "Literal[" literal_item ("," literal_item)* "]" +literal.1: qualname "[" literal_item ("," literal_item)* "]" // An single item in a literal expression (or `optional`). 
+class Test_doctype_to_annotation:
+
+    def test_unknown_name(self):
+        # Simple unknown name is aliased to typing.Any
+        annotation, unknown_names = doctype_to_annotation("a")
+        assert annotation.value == "a"
+        assert annotation.imports == {
+            PyImport(import_="Incomplete", from_="_typeshed", as_="a")
+        }
+        assert unknown_names == [("a", 0, 1)]
+
+    def test_unknown_qualname(self):
+        # Unknown qualified name is escaped and aliased to typing.Any as well
+        annotation, unknown_names = doctype_to_annotation("a.b")
+        assert annotation.value == "a_b"
+        assert annotation.imports == {
+            PyImport(import_="Incomplete", from_="_typeshed", as_="a_b")
+        }
+        assert unknown_names == [("a.b", 0, 3)]
+
+    def test_multiple_unknown_names(self):
+        # Multiple names are aliased to typing.Any
+        annotation, unknown_names = doctype_to_annotation("a.b of c")
+        assert annotation.value == "a_b[c]"
+        assert annotation.imports == {
+            PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"),
+            PyImport(import_="Incomplete", from_="_typeshed", as_="c"),
+        }
+        assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)]
-130,11 +130,15 @@ def test_subscription_error(self, doctype): "Literal[SomeEnum.FIRST, 3]", # Nesting "dict[Literal['a', 'b'], int]", + # Custom qualname for literal + "MyLiteral[0]", + "MyLiteral[SomeEnum.FIRST]", ], ) def test_literals(self, doctype): expr = parse_doctype(doctype) assert expr.as_code() == doctype + assert "literal" in [e.rule for e in expr.sub_expressions] @pytest.mark.parametrize( ("doctype", "expected"), @@ -165,6 +169,7 @@ def test_literals(self, doctype): def test_natlang_literals(self, doctype, expected): expr = parse_doctype(doctype) assert expr.as_code() == expected + assert "natlang_literal" in [e.rule for e in expr.sub_expressions] def test_single_natlang_literal_warning(self, caplog): expr = parse_doctype("{True}") @@ -220,6 +225,7 @@ def test_optional_info(self, doctype, expected, optional_info): def test_callable(self, doctype): expr = parse_doctype(doctype) assert expr.as_code() == doctype + assert "callable" in [e.rule for e in expr.sub_expressions] @pytest.mark.parametrize( "doctype", @@ -240,7 +246,7 @@ def test_callable_error(self, doctype): (":class:`Generator`", "Generator"), (":py:class:`Generator`", "Generator"), (":py:class:`Generator`[int]", "Generator[int]"), - (":py:ref:`~.Foo`[int]", "_Foo[int]"), + (":py:ref:`~.Foo`[int]", "~.Foo[int]"), ("list[:py:class:`Generator`]", "list[Generator]"), ], ) @@ -284,31 +290,3 @@ def test_natlang_array_invalid_shape(self, shape): doctype = f"array of shape {shape}" with pytest.raises(lark.exceptions.UnexpectedInput): _ = parse_doctype(doctype) - - def test_unknown_name(self): - # Simple unknown name is aliased to typing.Any - annotation, unknown_names = parse_doctype("a") - assert annotation.value == "a" - assert annotation.imports == { - PyImport(import_="Incomplete", from_="_typeshed", as_="a") - } - assert unknown_names == [("a", 0, 1)] - - def test_unknown_qualname(self): - # Unknown qualified name is escaped and aliased to typing.Any as well - annotation, unknown_names = 
parse_doctype("a.b") - assert annotation.value == "a_b" - assert annotation.imports == { - PyImport(import_="Incomplete", from_="_typeshed", as_="a_b") - } - assert unknown_names == [("a.b", 0, 3)] - - def test_multiple_unknown_names(self): - # Multiple names are aliased to typing.Any - annotation, unknown_names = parse_doctype("a.b of c") - assert annotation.value == "a_b[c]" - assert annotation.imports == { - PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"), - PyImport(import_="Incomplete", from_="_typeshed", as_="c"), - } - assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)] From 6a1dd0495b4811e448ca3f63ec00d1fe448fc828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 18 Jan 2026 13:41:13 +0100 Subject: [PATCH 11/17] WIP --- src/docstub-stubs/_docstrings.pyi | 1 - src/docstub-stubs/_doctype.pyi | 12 ++++---- src/docstub-stubs/_report.pyi | 32 ++++++++++++++++---- src/docstub/_docstrings.py | 4 +-- src/docstub/_doctype.py | 49 +++++++++++-------------------- src/docstub/_report.py | 12 ++++---- src/docstub/doctype.lark | 2 +- tests/test_docstrings.py | 20 +++++++------ tests/test_doctype.py | 20 +++++++++---- 9 files changed, 83 insertions(+), 69 deletions(-) diff --git a/src/docstub-stubs/_docstrings.pyi b/src/docstub-stubs/_docstrings.pyi index f1a9c57..a5805f0 100644 --- a/src/docstub-stubs/_docstrings.pyi +++ b/src/docstub-stubs/_docstrings.pyi @@ -10,7 +10,6 @@ import click import lark import lark.visitors import numpydoc.docscrape as npds -from _typeshed import Incomplete as Expr from ._analysis import PyImport, TypeMatcher from ._doctype import BlacklistedQualname, Expr, Term, TermKind, parse_doctype diff --git a/src/docstub-stubs/_doctype.pyi b/src/docstub-stubs/_doctype.pyi index a9ce3bc..f15175d 100644 --- a/src/docstub-stubs/_doctype.pyi +++ b/src/docstub-stubs/_doctype.pyi @@ -7,13 +7,13 @@ from collections.abc import Generator, Iterable, Sequence from dataclasses import dataclass from pathlib import Path from 
textwrap import indent -from typing import Any, Final, Literal, Self +from typing import Any, Final, Self import lark import lark.visitors from _typeshed import Incomplete -from _typeshed import Incomplete as Expression +from ._report import ContextReporter from ._utils import DocstubError logger: Final @@ -55,7 +55,7 @@ class Expr: @property def names(self) -> list[Term]: ... @property - def sub_expressions(self) -> list[Expression] | Literal[1]: ... + def sub_expressions(self) -> list[Self]: ... def __iter__(self) -> Generator[Expr | Term]: ... def format_tree(self) -> str: ... def print_tree(self) -> None: ... @@ -69,9 +69,9 @@ class BlacklistedQualname(DocstubError): pass class DoctypeTransformer(lark.visitors.Transformer): + def __init__(self, *, reporter: ContextReporter | None = ...) -> None: ... def start(self, tree: lark.Tree) -> Expr: ... def qualname(self, tree: lark.Tree) -> Term: ... - def qualname(self, tree: lark.Tree) -> Term: ... def rst_role(self, tree: lark.Tree) -> Expr: ... def ELLIPSES(self, token: lark.Token) -> Term: ... def union(self, tree: lark.Tree) -> Expr: ... @@ -90,6 +90,4 @@ class DoctypeTransformer(lark.visitors.Transformer): def extra_info(self, tree: lark.Tree) -> lark.visitors._DiscardType: ... def _format_subscription(self, sequence: Sequence[str], *, rule: str) -> Expr: ... -_transformer: Final - -def parse_doctype(doctype: str) -> Expr: ... +def parse_doctype(doctype: str, *, reporter: ContextReporter | None = ...) -> Expr: ... diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index bf574b9..dce4589 100644 --- a/src/docstub-stubs/_report.pyi +++ b/src/docstub-stubs/_report.pyi @@ -35,23 +35,43 @@ class ContextReporter: short: str, *args: Any, log_level: int, - details: str | None = ..., + details: str | tuple[Any, ...] | None = ..., **log_kw: Any ) -> None: ... 
def debug( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def info( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def warn( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def error( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def critical( - self, short: str, *args: Any, details: str | None = ..., **log_kw: Any + self, + short: str, + *args: Any, + details: str | tuple[Any, ...] | None = ..., + **log_kw: Any ) -> None: ... def __post_init__(self) -> None: ... @staticmethod diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 97dfabc..cd25354 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -29,7 +29,7 @@ def _update_qualnames(expr, *, _parents=()): Parameters ---------- - expr : Expr + expr : ~.Expr _parents : tuple of (~._doctype.Expr, ...) 
Yields @@ -266,7 +266,7 @@ def doctype_to_annotation(doctype, *, matcher=None, reporter=None, stats=None): stats = Stats() if stats is None else stats try: - expression = parse_doctype(doctype) + expression = parse_doctype(doctype, reporter=reporter) stats.inc_counter("transformed_doctypes") reporter.debug( "Transformed doctype", details=(" %s\n-> %s", doctype, expression) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 4b52e60..5274511 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -12,6 +12,7 @@ import lark import lark.visitors +from ._report import ContextReporter from ._utils import DocstubError logger: Final = logging.getLogger(__name__) @@ -152,7 +153,7 @@ def sub_expressions(self): Returns ------- - names : list of Expr or {1} + names : list of Self """ cls = type(self) for child in self.children: @@ -188,7 +189,7 @@ def format_tree(self): def print_tree(self): """Print full hierarchy as a tree.""" - print(self.format_tree()) + print(self.format_tree()) # noqa: T201 def __repr__(self) -> str: return f"<{type(self).__name__}: '{self.as_code()}' rule='{self.rule}'>" @@ -209,19 +210,15 @@ class BlacklistedQualname(DocstubError): @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): - def start(self, tree): + def __init__(self, *, reporter=None): """ Parameters ---------- - tree : lark.Tree - - Returns - ------- - out : Expr + reporter : ~.ContextReporter """ - return Expr(rule="start", children=tree.children) + self.reporter = reporter or ContextReporter(logger=logger) - def qualname(self, tree): + def start(self, tree): """ Parameters ---------- @@ -229,20 +226,9 @@ def qualname(self, tree): Returns ------- - out : Term + out : Expr """ - children = tree.children - _qualname = ".".join(children) - - if _qualname in BLACKLISTED_QUALNAMES: - raise BlacklistedQualname(_qualname) - - _qualname = Term( - _qualname, - kind=TermKind.NAME, - pos=(tree.meta.start_pos, tree.meta.end_pos), - ) - 
return _qualname + return Expr(rule="start", children=tree.children) def qualname(self, tree): """ @@ -381,11 +367,11 @@ def natlang_literal(self, tree): out = self._format_subscription(items, rule="natlang_literal") if len(tree.children) == 1: - logger.warning( - "natural language literal with one item `%s`, " - "consider using `%s` to improve readability", + details = ("Consider using `%s` to improve readability", "".join(out)) + self.reporter.warn( + "Natural language literal with one item: `{%s}`", tree.children[0], - "".join(out), + details=details, ) return out @@ -524,16 +510,14 @@ def _format_subscription(self, sequence, *, rule): return expr -_transformer: Final = DoctypeTransformer() - - -def parse_doctype(doctype): +def parse_doctype(doctype, *, reporter=None): """Turn a type description in a docstring into a type annotation. Parameters ---------- doctype : str The doctype to parse. + reporter : ~.ContextReporter, optional Returns ------- @@ -553,5 +537,6 @@ def parse_doctype(doctype): """ tree = _lark.parse(doctype) - expression = _transformer.transform(tree=tree) + transformer = DoctypeTransformer(reporter=reporter) + expression = transformer.transform(tree=tree) return expression diff --git a/src/docstub/_report.py b/src/docstub/_report.py index e496723..1f19bd1 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -95,7 +95,7 @@ def report(self, short, *args, log_level, details=None, **log_kw): Optional formatting arguments for `short`. log_level : int The logging level. - details : str, optional + details : str or tuple of (Any, ...), optional An optional multiline report with more details. **log_kw : Any """ @@ -118,7 +118,7 @@ def debug(self, short, *args, details=None, **log_kw): A short summarizing report that shouldn't wrap over multiple lines. *args : Any Optional formatting arguments for `short`. - details : str, optional + details : str or tuple of (Any, ...), optional An optional multiline report with more details. 
**log_kw : Any """ @@ -135,7 +135,7 @@ def info(self, short, *args, details=None, **log_kw): A short summarizing report that shouldn't wrap over multiple lines. *args : Any Optional formatting arguments for `short`. - details : str, optional + details : str or tuple of (Any, ...), optional An optional multiline report with more details. **log_kw : Any """ @@ -152,7 +152,7 @@ def warn(self, short, *args, details=None, **log_kw): A short summarizing report that shouldn't wrap over multiple lines. *args : Any Optional formatting arguments for `short`. - details : str, optional + details : str or tuple of (Any, ...), optional An optional multiline report with more details. **log_kw : Any """ @@ -169,7 +169,7 @@ def error(self, short, *args, details=None, **log_kw): A short summarizing report that shouldn't wrap over multiple lines. *args : Any Optional formatting arguments for `short`. - details : str, optional + details : str or tuple of (Any, ...), optional An optional multiline report with more details. **log_kw : Any """ @@ -186,7 +186,7 @@ def critical(self, short, *args, details=None, **log_kw): A short summarizing report that shouldn't wrap over multiple lines. *args : Any Optional formatting arguments for `short`. - details : str, optional + details : str or tuple of (Any, ...), optional An optional multiline report with more details. **log_kw : Any """ diff --git a/src/docstub/doctype.lark b/src/docstub/doctype.lark index 30e7acb..4bd6826 100644 --- a/src/docstub/doctype.lark +++ b/src/docstub/doctype.lark @@ -78,7 +78,7 @@ literal.1: qualname "[" literal_item ("," literal_item)* "]" // An single item in a literal expression (or `optional`). We must also allow // for qualified names, since a "class" or enum can be used as a literal too. -literal_item: ELLIPSES | STRING | SIGNED_NUMBER | qualname +literal_item: STRING | SIGNED_NUMBER | qualname // Natural language forms of the subscription expression for containers. 
diff --git a/tests/test_docstrings.py b/tests/test_docstrings.py index f63f784..b47777b 100644 --- a/tests/test_docstrings.py +++ b/tests/test_docstrings.py @@ -5,8 +5,8 @@ from docstub._analysis import PyImport from docstub._docstrings import ( Annotation, - doctype_to_annotation, DocstringAnnotations, + doctype_to_annotation, ) @@ -37,26 +37,25 @@ def test_unexpected_value(self): class Test_doctype_to_annotation: - - def test_unknown_name(self): + def test_unknown_name(self, caplog): # Simple unknown name is aliased to typing.Any annotation = doctype_to_annotation("a") assert annotation.value == "a" assert annotation.imports == { PyImport(import_="Incomplete", from_="_typeshed", as_="a") } - assert unknown_names == [("a", 0, 1)] + assert caplog.messages == ["Unknown name in doctype: 'a'"] - def test_unknown_qualname(self): + def test_unknown_qualname(self, caplog): # Unknown qualified name is escaped and aliased to typing.Any as well annotation = doctype_to_annotation("a.b") assert annotation.value == "a_b" assert annotation.imports == { PyImport(import_="Incomplete", from_="_typeshed", as_="a_b") } - assert unknown_names == [("a.b", 0, 3)] + assert caplog.messages == ["Unknown name in doctype: 'a.b'"] - def test_multiple_unknown_names(self): + def test_multiple_unknown_names(self, caplog): # Multiple names are aliased to typing.Any annotation = doctype_to_annotation("a.b of c") assert annotation.value == "a_b[c]" @@ -64,7 +63,10 @@ def test_multiple_unknown_names(self): PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"), PyImport(import_="Incomplete", from_="_typeshed", as_="c"), } - assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)] + assert sorted(caplog.messages) == [ + "Unknown name in doctype: 'a.b'", + "Unknown name in doctype: 'c'", + ] class Test_DocstringAnnotations: @@ -279,4 +281,4 @@ def test_combined_numpydoc_params(self): assert annotations.parameters["c"].value == "bool" assert "d" not in annotations.parameters - assert "e" not in 
annotations.parameters \ No newline at end of file + assert "e" not in annotations.parameters diff --git a/tests/test_doctype.py b/tests/test_doctype.py index b1707e2..0e31843 100644 --- a/tests/test_doctype.py +++ b/tests/test_doctype.py @@ -176,9 +176,9 @@ def test_single_natlang_literal_warning(self, caplog): assert expr.as_code() == "Literal[True]" assert caplog.messages == ["Natural language literal with one item: `{True}`"] assert caplog.records[0].levelno == logging.WARNING - assert ( - caplog.records[0].details - == "Consider using `Literal[True]` to improve readability" + assert caplog.records[0].details == ( + "Consider using `%s` to improve readability", + "Literal[True]", ) @pytest.mark.parametrize( @@ -217,15 +217,25 @@ def test_optional_info(self, doctype, expected, optional_info): "Callable[..., str]", "Callable[[], str]", "Callback[...]", + ], + ) + def test_callable(self, doctype): + expr = parse_doctype(doctype) + assert expr.as_code() == doctype + assert "callable" in [e.rule for e in expr.sub_expressions] + + @pytest.mark.parametrize( + "doctype", + [ "Callable[Concatenate[int, float], str]", "Callable[Concatenate[int, ...], str]", "Callable[P, str]", ], ) - def test_callable(self, doctype): + def test_callable_subscriptions_form(self, doctype): expr = parse_doctype(doctype) assert expr.as_code() == doctype - assert "callable" in [e.rule for e in expr.sub_expressions] + assert "callable" not in [e.rule for e in expr.sub_expressions] @pytest.mark.parametrize( "doctype", From 23e02528a3b4daf29d1ec4d0b0be335a43b97f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 18 Jan 2026 13:55:34 +0100 Subject: [PATCH 12/17] Fix accidental import from PyCharm --- src/docstub-stubs/_report.pyi | 1 - src/docstub/_report.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/docstub-stubs/_report.pyi b/src/docstub-stubs/_report.pyi index dce4589..6430e76 100644 --- a/src/docstub-stubs/_report.pyi +++ 
b/src/docstub-stubs/_report.pyi @@ -8,7 +8,6 @@ from textwrap import indent from typing import Any, ClassVar, Literal, Self, TextIO import click -from pre_commit.envcontext import UNSET from ._cli_help import should_strip_ansi from ._utils import naive_natsort_key diff --git a/src/docstub/_report.py b/src/docstub/_report.py index 1f19bd1..831b1f3 100644 --- a/src/docstub/_report.py +++ b/src/docstub/_report.py @@ -7,7 +7,6 @@ from textwrap import indent import click -from pre_commit.envcontext import UNSET from ._cli_help import should_strip_ansi from ._utils import naive_natsort_key @@ -602,7 +601,7 @@ def pop(self, key, *, default=_UNSET): ------- value : list[Any] or int """ - if key in self._stats or default is UNSET: + if key in self._stats or default is self._UNSET: return self._stats.pop(key) return default From b923ef30673dba640239c927e85bac8143031fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 18 Jan 2026 14:14:53 +0100 Subject: [PATCH 13/17] Fix stubtest and type checker errors --- src/docstub-stubs/_doctype.pyi | 18 ++++++++++-------- src/docstub/_doctype.py | 20 +++++++++++++------- stubtest_allow.txt | 1 + 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/docstub-stubs/_doctype.pyi b/src/docstub-stubs/_doctype.pyi index f15175d..a25595d 100644 --- a/src/docstub-stubs/_doctype.pyi +++ b/src/docstub-stubs/_doctype.pyi @@ -7,7 +7,7 @@ from collections.abc import Generator, Iterable, Sequence from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import Any, Final, Self +from typing import Any, ClassVar, Final, Self import lark import lark.visitors @@ -16,14 +16,14 @@ from _typeshed import Incomplete from ._report import ContextReporter from ._utils import DocstubError -logger: Final +logger: Final[logging.Logger] -grammar_path: Final +grammar_path: Final[Path] with grammar_path.open() as file: - _grammar: Final + _grammar: Final[str] -_lark: Final +_lark: 
Final[lark.Lark] def flatten_recursive(iterable: Iterable[Iterable | str]) -> Generator[str]: ... def insert_between(iterable: Iterable, *, sep: Any) -> list[Any]: ... @@ -35,12 +35,14 @@ class TermKind(enum.StrEnum): SYNTAX = enum.auto() class Term(str): + kind: TermKind + pos: tuple[int, int] | None - __slots__: Final + __slots__: Final[ClassVar[tuple[str, ...]]] def __new__( cls, value: str, *, kind: TermKind | str, pos: tuple[int, int] | None = ... - ) -> None: ... + ) -> Self: ... def __repr__(self) -> str: ... def __getnewargs_ex__(self) -> tuple[tuple[Any, ...], dict[str, Any]]: ... @@ -63,7 +65,7 @@ class Expr: def __str__(self) -> str: ... def as_code(self) -> str: ... -BLACKLISTED_QUALNAMES: Final +BLACKLISTED_QUALNAMES: Final[set[str]] class BlacklistedQualname(DocstubError): pass diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 5274511..f078768 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -15,17 +15,17 @@ from ._report import ContextReporter from ._utils import DocstubError -logger: Final = logging.getLogger(__name__) +logger: Final[logging.Logger] = logging.getLogger(__name__) -grammar_path: Final = Path(__file__).parent / "doctype.lark" +grammar_path: Final[Path] = Path(__file__).parent / "doctype.lark" with grammar_path.open() as file: - _grammar: Final = file.read() + _grammar: Final[str] = file.read() # TODO try passing `transformer=DoctypeTransformer()`, may be faster [1] # [1] https://lark-parser.readthedocs.io/en/latest/classes.html#:~:text=after%20the%20parse%2C-,but%20faster,-) -_lark: Final = lark.Lark(_grammar, propagate_positions=True) +_lark: Final[lark.Lark] = lark.Lark(_grammar, propagate_positions=True) def flatten_recursive(iterable): @@ -80,10 +80,12 @@ class Term(str): Attributes ---------- - __slots__ : Final + kind : TermKind + pos : tuple of (int, int) or None + __slots__ : Final[ClassVar[tuple[str, ...]]] """ - __slots__ = ("kind", "pos", "value") + __slots__ = ("kind", "pos") def 
__new__(cls, value, *, kind, pos=None): """ @@ -92,6 +94,10 @@ def __new__(cls, value, *, kind, pos=None): value : str kind : TermKind or str pos : tuple of (int, int), optional + + Returns + ------- + cls : Self """ self = super().__new__(cls, value) self.kind = TermKind(kind) @@ -201,7 +207,7 @@ def as_code(self) -> str: return str(self) -BLACKLISTED_QUALNAMES: Final = set(keyword.kwlist) - {"None", "True", "False"} +BLACKLISTED_QUALNAMES: Final[set[str]] = set(keyword.kwlist) - {"None", "True", "False"} class BlacklistedQualname(DocstubError): diff --git a/stubtest_allow.txt b/stubtest_allow.txt index 572288a..e9603e4 100644 --- a/stubtest_allow.txt +++ b/stubtest_allow.txt @@ -1,3 +1,4 @@ docstub\._version\..* docstub._cache.FuncSerializer.__type_params__ +docstub._doctype.TermKind..* (docstub\..*__conditional_annotations__)? From 95e70c11e29a95024ad7fa8269b27201483f944e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Sun, 18 Jan 2026 14:30:53 +0100 Subject: [PATCH 14/17] Don't annotate with Final and ClassVar at the same time https://peps.python.org/pep-0591/#:~:text=Variables%20should%20not%20be%20annotated%20with%20both%20ClassVar%20and%20Final. 
--- src/docstub-stubs/_doctype.pyi | 4 ++-- src/docstub/_doctype.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/docstub-stubs/_doctype.pyi b/src/docstub-stubs/_doctype.pyi index a25595d..97e462a 100644 --- a/src/docstub-stubs/_doctype.pyi +++ b/src/docstub-stubs/_doctype.pyi @@ -7,7 +7,7 @@ from collections.abc import Generator, Iterable, Sequence from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import Any, ClassVar, Final, Self +from typing import Any, Final, Self import lark import lark.visitors @@ -38,7 +38,7 @@ class Term(str): kind: TermKind pos: tuple[int, int] | None - __slots__: Final[ClassVar[tuple[str, ...]]] + __slots__: Final[tuple[str, ...]] def __new__( cls, value: str, *, kind: TermKind | str, pos: tuple[int, int] | None = ... diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index f078768..02348c6 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -82,7 +82,7 @@ class Term(str): ---------- kind : TermKind pos : tuple of (int, int) or None - __slots__ : Final[ClassVar[tuple[str, ...]]] + __slots__ : Final[tuple[str, ...]] """ __slots__ = ("kind", "pos") From 7b4ea9153ea9051b205d2fc5a1f6f5349310f434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 19 Jan 2026 11:34:49 +0100 Subject: [PATCH 15/17] Catch and log warnings with context from NumPyDoc --- src/docstub-stubs/_docstrings.pyi | 1 + src/docstub/_docstrings.py | 9 ++++++++- tests/test_docstrings.py | 20 ++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/docstub-stubs/_docstrings.pyi b/src/docstub-stubs/_docstrings.pyi index a5805f0..f116ada 100644 --- a/src/docstub-stubs/_docstrings.pyi +++ b/src/docstub-stubs/_docstrings.pyi @@ -2,6 +2,7 @@ import logging import traceback +import warnings from collections.abc import Generator, Iterable from dataclasses import dataclass, field from functools import cached_property diff --git 
a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index cd25354..8a6ceba 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -2,6 +2,7 @@ import logging import traceback +import warnings from dataclasses import dataclass, field from functools import cached_property @@ -366,7 +367,6 @@ def __init__(self, docstring, *, matcher=None, reporter=None, stats=None): stats : ~.Stats, optional """ self.docstring = docstring - self.np_docstring = npds.NumpyDocString(docstring) self.matcher = matcher or TypeMatcher() self.stats = Stats() if stats is None else stats @@ -374,6 +374,13 @@ def __init__(self, docstring, *, matcher=None, reporter=None, stats=None): reporter = ContextReporter(logger=logger, line=0) self.reporter = reporter.copy_with(logger=logger) + with warnings.catch_warnings(record=True) as records: + self.np_docstring = npds.NumpyDocString(docstring) + for message in records: + short = "Warning in NumPyDoc while parsing docstring" + details = message.message.args[0] + self.reporter.warn(short, details=details) + @cached_property def attributes(self): """Return the attributes found in the docstring. 
diff --git a/tests/test_docstrings.py b/tests/test_docstrings.py index b47777b..078aaa5 100644 --- a/tests/test_docstrings.py +++ b/tests/test_docstrings.py @@ -282,3 +282,23 @@ def test_combined_numpydoc_params(self): assert "d" not in annotations.parameters assert "e" not in annotations.parameters + + @pytest.mark.filterwarnings("default:Unknown section:UserWarning:numpydoc") + def test_unknown_section_logged(self, caplog): + docstring = dedent( + """ + Parameters + ---------- + a : bool + + To Do + ----- + An unknown section + """ + ) + annotations = DocstringAnnotations(docstring) + assert len(annotations.parameters) == 1 + assert annotations.parameters["a"].value == "bool" + + assert caplog.messages == ["Warning in NumPyDoc while parsing docstring"] + assert caplog.records[0].details == "Unknown section To Do" From ab036ba2b77aa12cc51d4f9674d46f57f1f397ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 19 Jan 2026 11:38:21 +0100 Subject: [PATCH 16/17] Deal with shape only natlang arrays and add doctests and documentation --- src/docstub/_doctype.py | 66 +++++++++++++++++++++++++++++++++++------ tests/test_doctype.py | 4 ++- 2 files changed, 60 insertions(+), 10 deletions(-) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index 02348c6..dfcdd48 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -1,4 +1,4 @@ -"""Parsing of doctypes.""" +"""Parsing & transformation of doctypes into Python-compatible syntax.""" import enum import keyword @@ -34,10 +34,18 @@ def flatten_recursive(iterable): Parameters ---------- iterable : Iterable[Iterable or str] + An iterable containing nested iterables or strings. Only strings are + supported as "leafs" for now. 
Yields ------ item : str + + Examples + -------- + >>> nested = ["only", ["strings", ("and", "iterables"), "are", ["allowed"]]] + >>> list(flatten_recursive(nested)) + ['only', 'strings', 'and', 'iterables', 'are', 'allowed'] """ for item in iterable: if isinstance(item, str): @@ -59,6 +67,12 @@ def insert_between(iterable, *, sep): Returns ------- out : list[Any] + + Examples + -------- + >>> code = ["a", "b", "c", ] + >>> list(insert_between(code, sep=" | ")) + ['a', ' | ', 'b', ' | ', 'c'] """ out = [] for item in iterable: @@ -68,6 +82,8 @@ def insert_between(iterable, *, sep): class TermKind(enum.StrEnum): + """Encodes the different kinds of :class:`Term`.""" + # docstub: off NAME = enum.auto() LITERAL = enum.auto() @@ -83,6 +99,17 @@ class Term(str): kind : TermKind pos : tuple of (int, int) or None __slots__ : Final[tuple[str, ...]] + + Examples + -------- + >>> ''.join( + ... [ + ... Term("int", kind="name"), + ... Term(" | ", kind="syntax"), + ... Term("float", kind="name") + ... ] + ... ) + 'int | float' """ __slots__ = ("kind", "pos") @@ -216,6 +243,16 @@ class BlacklistedQualname(DocstubError): @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): + """Transform parsed doctypes into Python-compatible syntax. 
+ + Examples + -------- + >>> tree = _lark.parse("int or tuple of (int, ...)") + >>> transformer = DoctypeTransformer() + >>> str(transformer.transform(tree=tree)) + 'int | tuple[int, ...]' + """ + def __init__(self, *, reporter=None): """ Parameters @@ -310,6 +347,7 @@ def subscription(self, tree): ------- out : Expr """ + assert len(tree.children) > 1 return self._format_subscription(tree.children, rule="subscription") def param_spec(self, tree): @@ -341,6 +379,7 @@ def callable(self, tree): ------- out : Expr """ + assert len(tree.children) > 1 return self._format_subscription(tree.children, rule="callable") def literal(self, tree): @@ -353,6 +392,7 @@ def literal(self, tree): ------- out : Expr """ + assert len(tree.children) > 1 out = self._format_subscription(tree.children, rule="literal") return out @@ -372,6 +412,7 @@ def natlang_literal(self, tree): ] out = self._format_subscription(items, rule="natlang_literal") + assert len(tree.children) >= 1 if len(tree.children) == 1: details = ("Consider using `%s` to improve readability", "".join(out)) self.reporter.warn( @@ -409,6 +450,7 @@ def natlang_container(self, tree): ------- out : Expr """ + assert len(tree.children) >= 1 return self._format_subscription(tree.children, rule="natlang_container") def natlang_array(self, tree): @@ -490,7 +532,8 @@ def extra_info(self, tree): return lark.Discard def _format_subscription(self, sequence, *, rule): - """ + """Format a `name[...]` style expression. 
+ Parameters ---------- sequence : Sequence[str] @@ -502,17 +545,20 @@ def _format_subscription(self, sequence, *, rule): """ sep = Term(", ", kind=TermKind.SYNTAX) container, *content = sequence - content = insert_between(content, sep=sep) - assert content - expr = Expr( - rule=rule, - children=[ + assert container + + if content: + content = insert_between(content, sep=sep) + children = [ container, Term("[", kind=TermKind.SYNTAX), *content, Term("]", kind=TermKind.SYNTAX), - ], - ) + ] + else: + children = [container] + + expr = Expr(rule=rule, children=children) return expr @@ -534,6 +580,8 @@ def parse_doctype(doctype, *, reporter=None): lark.exceptions.VisitError Raised when the transformation is interrupted by an exception. See :cls:`lark.exceptions.VisitError`. + BlacklistedQualname + Raised when a qualname is a forbidden keyword. Examples -------- diff --git a/tests/test_doctype.py b/tests/test_doctype.py index 0e31843..3d4b74f 100644 --- a/tests/test_doctype.py +++ b/tests/test_doctype.py @@ -270,7 +270,6 @@ def test_rst_role(self, doctype, expected): [ ("{name} of shape {shape} and dtype {dtype}", "{name}[{dtype}]"), ("{name} of dtype {dtype} and shape {shape}", "{name}[{dtype}]"), - ("{name} of {dtype}", "{name}[{dtype}]"), ], ) @pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"]) @@ -283,17 +282,20 @@ def test_natlang_array(self, fmt, expected_fmt, name, dtype, shape): expected = expected_fmt.format(name=name, dtype=dtype, shape=shape) expr = parse_doctype(doctype) assert expr.as_code() == expected + assert "natlang_array" in [e.rule for e in expr.sub_expressions] # fmt: on @pytest.mark.parametrize( ("doctype", "expected"), [ ("ndarray of dtype (int or float)", "ndarray[int | float]"), + ("ndarray of shape (M, N)", "ndarray"), ], ) def test_natlang_array_specific(self, doctype, expected): expr = parse_doctype(doctype) assert expr.as_code() == expected + assert "natlang_array" in [e.rule for e in expr.sub_expressions] 
@pytest.mark.parametrize("shape", ["(-1, 3)", "(1.0, 2)", "-3D", "-2-D"]) def test_natlang_array_invalid_shape(self, shape): From bf85d153155e2df7734e188b7c49afbc087c9ef6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Mon, 19 Jan 2026 12:00:33 +0100 Subject: [PATCH 17/17] Use local logger in DoctypeTransformer --- src/docstub/_doctype.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/docstub/_doctype.py b/src/docstub/_doctype.py index dfcdd48..e0841b4 100644 --- a/src/docstub/_doctype.py +++ b/src/docstub/_doctype.py @@ -259,7 +259,8 @@ def __init__(self, *, reporter=None): ---------- reporter : ~.ContextReporter """ - self.reporter = reporter or ContextReporter(logger=logger) + reporter = reporter or ContextReporter(logger=logger) + self.reporter = reporter.copy_with(logger=logger) def start(self, tree): """