From 6b9e38d8171d715f5d1229e52a11c6f985ef92c5 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Mon, 15 Jun 2026 15:27:12 +0200 Subject: [PATCH 1/8] Check padding for structure/type. Closes #351 and #634 --- spec/src/commit.toml | 6 +- spec/tooling/chip.py | 146 +++++++++++++------------------------------ 2 files changed, 48 insertions(+), 104 deletions(-) diff --git a/spec/src/commit.toml b/spec/src/commit.toml index 89fa133c6..fdfc19dfb 100644 --- a/spec/src/commit.toml +++ b/spec/src/commit.toml @@ -18,19 +18,19 @@ pad = 0 name = "address" type = "DWordWL" desc = "Address of first byte to commit." -pad = ["arr", 0, 0, 0, 0] +pad = 0 [[variables.auxiliary]] name = "address_incr" type = "DWordHL" desc = "$#`address` + 1$" -pad = ["arr", 1, 0, 0, 0] +pad = 1 [[variables.auxiliary]] name = "count" type = "DWordWL" desc = "number of bytes to commit" -pad = ["arr", 1, 0, 0, 0] +pad = 1 [[variables.auxiliary]] name = "count_decr" diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 44cbbed83..800bc7975 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -278,7 +278,7 @@ def typecheck(self, env: Environment) -> Type: elt = elt.get_const() return Range.const(elt % modulus) else: - return Range(0, modulus-1) + return Range(0, modulus - 1) @dataclass @@ -293,9 +293,7 @@ def typecheck(self, env: Environment) -> Type: reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") return DEFAULT_TYPE if isinstance(exp, list) or not exp.is_const(): - reporter.error( - f"Invalid exponentiation with non-const exponent: {self.exp!r}" - ) + reporter.error(f"Invalid exponentiation with non-const exponent: {self.exp!r}") return DEFAULT_TYPE val = pow(base.get_const(), exp.get_const(), env.config.variables.prime) return Range.const(val) @@ -349,9 +347,7 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: case int(x): return LitExpr(x) case str(x): - reporter.asserts( - x.isidentifier(), f"Invalid identifier name for variable {x!r}" - ) + reporter.asserts(x.isidentifier(), f"Invalid identifier name for variable {x!r}") return VarExpr(x) case ["opsel", str(x)]: if x not in OPSEL: @@ -371,9 +367,7 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: case ["+", *terms]: return AddExpr([build_expr(config, t) for t in terms]) case ["-", head, *subs]: - return SubExpr( - build_expr(config, head), [build_expr(config, s) for s in subs] - ) + return SubExpr(build_expr(config, head), [build_expr(config, s) for s in subs]) case ["mod", elt, modulus]: return ModExpr(build_expr(config, elt), build_expr(config, modulus)) case ["^", base, exp]: @@ -388,6 +382,17 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: return DummyExpr() +def check_padding_fits(var_name: str, config: "Config", type: Type, pad: Expr): + def fits(v, t): + if isinstance(v, Range): + return v.is_const() and constant_fits(v.get_const(), t) + else: + return isinstance(v, list) and isinstance(t, list) and len(v) == len(t) and all(map(fits, v, t)) + + val = pad.typecheck(Environment(config, {}, {})) + reporter.asserts(fits(val, type), f"{var_name!r}: Invalid padding {pad!r} for type {type!r}") + + @dataclass class Iter: name: str @@ -396,18 +401,12 @@ class Iter: def __init__(self, config: "Config", name: str, start: object, stop: object): self.name = name - reporter.asserts( - isinstance(self.name, str), f"iter name is not a string: {self.name!r}" - ) - reporter.asserts( - self.name.isidentifier(), f"Not a valid identifier: {self.name!r}" - ) + reporter.asserts(isinstance(self.name, str), f"iter name is not a string: {self.name!r}") + reporter.asserts(self.name.isidentifier(), f"Not a valid identifier: {self.name!r}") self.start = build_expr(config, start) self.stop = build_expr(config, stop) - def typecheck[T]( - self, env: Environment, callback: Callable[[Environment], Iterable[T]] - ) -> Iterable[T]: + def typecheck[T](self, env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: start = self.start.typecheck(env) if isinstance(start, list) or not start.is_const(): reporter.error(f"Starting value of iterator not a const: {self!r}") @@ -444,9 +443,7 @@ def clean_iter(it): return Iter(config, *arr) if "iters" in obj: - reporter.asserts( - "iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}" - ) + reporter.asserts("iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}") return [clean_iter(it) for it in obj["iters"]] elif "iter" in obj: return [clean_iter(obj["iter"])] @@ -475,14 +472,10 @@ def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict) f"Invalid range: {data!r}", ) start, stop = data["range"] - if not isinstance(start, int) and not ( - isinstance(start, str) and start.isdigit() - ): + if not isinstance(start, int) and not (isinstance(start, str) and start.isdigit()): reporter.error(f"Range start not an int: {data!r}") start = 0 - if not isinstance(stop, int) and not ( - isinstance(stop, str) and stop.isdigit() - ): + if not isinstance(stop, int) and not (isinstance(stop, str) and stop.isdigit()): reporter.error(f"Range end not an int: {data!r}") stop = start reporter.asserts(int(start) <= int(stop), f"Inverted range: {data!r}") @@ -553,9 +546,7 @@ class ConfigMetadata: def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.version = data["version"] - reporter.asserts( - isinstance(self.version, int), f"version {self.version!r} is not an int" - ) + reporter.asserts(isinstance(self.version, int), f"version {self.version!r} is not an int") @dataclass @@ -610,6 +601,7 @@ def __init__(self, config: Config, category: str, data: dict): self.desc = data["desc"] reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") self.pad = build_expr(None, data.get("pad", 0)) + check_padding_fits(self.name, config, self.type, self.pad) self.precomputed = data.get("precomputed", False) reporter.asserts( isinstance(self.precomputed, bool), @@ -617,9 +609,7 @@ def __init__(self, config: Config, category: str, data: dict): ) -def all_iters[T]( - its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]] -) -> Iterable[T]: +def all_iters[T](its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: if not its: yield from callback(env) else: @@ -640,17 +630,11 @@ class VirtualDef: def __init__(self, config: Config, name: str, tp: Type, data: dict): if "poly" in data: idx = data.get("idx", None) - self.defs = [ - PolyWithIters( - build_expr(config, data["poly"]), iters_of(data, config, name=idx) - ) - ] + self.defs = [PolyWithIters(build_expr(config, data["poly"]), iters_of(data, config, name=idx))] elif "polys" in data: idx = data.get("idx", None) self.defs = [ - PolyWithIters( - build_expr(config, poly["poly"]), iters_of(poly, config, name=idx) - ) + PolyWithIters(build_expr(config, poly["poly"]), iters_of(poly, config, name=idx)) for poly in data["polys"] ] else: @@ -683,9 +667,7 @@ def handle_iters( for s in seen: ln = min(len(s), len(indices)) if s[:ln] == tuple(indices[:ln]): - reporter.error( - f"Double definition for virtual column: {self!r} at index {indices}" - ) + reporter.error(f"Double definition for virtual column: {self!r} at index {indices}") break val = poly.typecheck(env) @@ -703,21 +685,15 @@ def handle_iters( # But threading the extra needed state through overly complicates everything start = it.start.typecheck(env) if isinstance(start, list) or not start.is_const(): - reporter.error( - f"Starting value of virtual def iter not a const: {self!r}" - ) + reporter.error(f"Starting value of virtual def iter not a const: {self!r}") start = Range.const(0) stop = it.stop.typecheck(env) if isinstance(stop, list) or not stop.is_const(): - reporter.error( - f"Ending value of virtual def iter not a const: {self!r}" - ) + reporter.error(f"Ending value of virtual def iter not a const: {self!r}") stop = Range.const(start.get_const()) if isinstance(expected, Range): - reporter.error( - f"Virtual definition has an iter for a scalar: {self!r}" - ) + reporter.error(f"Virtual definition has an iter for a scalar: {self!r}") return if not 0 <= start.get_const() <= stop.get_const() < len(expected): @@ -760,9 +736,7 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): ) assigned_type = self.def_.defs[0].poly.typecheck(env) if not isinstance(assigned_type, Range): - reporter.error( - f"Assigning non-scalar type to scalar virtual column: {self!r}" - ) + reporter.error(f"Assigning non-scalar type to scalar virtual column: {self!r}") return self.type # Check type fits? # Leaving this out because it produces too much noise with one-hot assumptions @@ -771,9 +745,7 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): # Check no indices are covered twice seen: set[tuple] = set() for poly_iters in self.def_.defs: - handle_iters( - env, poly_iters.iters, poly_iters.poly, self.type, [], seen - ) + handle_iters(env, poly_iters.iters, poly_iters.poly, self.type, [], seen) # Check everything is covered check_covered(self.type, seen, []) return self.type @@ -785,9 +757,7 @@ class Assumption: iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected( - data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"} - ) + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"}) self.desc = data["desc"] self.iters = iters_of(data, config) @@ -800,9 +770,7 @@ class ArithConstraint: iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected( - data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"} - ) + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) assert data["kind"] == "arith" self.constraint = data["constraint"] reporter.asserts( @@ -810,9 +778,7 @@ def __init__(self, config: Config, data: dict): f"Constraint not a string: {self.constraint!r}", ) self.desc = data.get("desc", "") - reporter.asserts( - isinstance(self.desc, str), f"desc is not a string: {self.desc!r}" - ) + reporter.asserts(isinstance(self.desc, str), f"desc is not a string: {self.desc!r}") self.poly = build_expr(config, data["poly"]) self.iters = iters_of(data, config) @@ -827,9 +793,7 @@ def check_includes_zero(t: Type): f"Unsatisfiable constraint, 0 not in range: {self!r} {t}", ) else: - reporter.error( - f"Non-scalar value for polynomial constraint: {self!r} {t}" - ) + reporter.error(f"Non-scalar value for polynomial constraint: {self!r} {t}") for t in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): check_includes_zero(t) @@ -849,11 +813,7 @@ def matches(self, other: Self) -> bool: return False if (self.output is None) != (other.output is None): return False - if ( - self.output is not None - and other.output is not None - and not structure_matches(self.output, other.output) - ): + if self.output is not None and other.output is not None and not structure_matches(self.output, other.output): return False return structure_matches(self.input, other.input) @@ -889,13 +849,9 @@ def __init__(self, config: Config, data: dict): ) assert data["kind"] == self.kind self.tag = data["tag"] - reporter.asserts( - isinstance(self.tag, str), f"tag is not a string: {self.tag!r}" - ) + reporter.asserts(isinstance(self.tag, str), f"tag is not a string: {self.tag!r}") self.desc = data.get("desc", "") - reporter.asserts( - isinstance(self.desc, str), f"Description is not a string: {self.desc!r}" - ) + reporter.asserts(isinstance(self.desc, str), f"Description is not a string: {self.desc!r}") self.input = [build_expr(config, inp) for inp in data["input"]] if "output" in data: self.output = build_expr(config, data["output"]) @@ -955,9 +911,7 @@ def typecheck(self, env: Environment) -> list[Never]: return [] -type Constraint = ( - ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint -) +type Constraint = ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint def build_constraint(config, data: dict) -> Constraint: @@ -983,15 +937,11 @@ class Chip: def __init__(self, config: Config, data: dict): """Construct a chip from toml-parsed data""" - assert_no_unexpected( - data, set(type(self).__annotations__.keys()) | {"constraint_groups"} - ) + assert_no_unexpected(data, set(type(self).__annotations__.keys()) | {"constraint_groups"}) assert_no_unexpected(data["variables"], config.variables.categories.all) self.config = config self.name = data["name"] - reporter.asserts( - isinstance(self.name, str), f"name is not a string: {self.name!r}" - ) + reporter.asserts(isinstance(self.name, str), f"name is not a string: {self.name!r}") reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") self.variables = [ (Variable if cat != "virtual" else VirtualVariable)(config, cat, var) @@ -1036,9 +986,7 @@ def typecheck(self) -> Iterable[Signature]: def build_signature(config: Config, data: dict) -> Signature: - assert_no_unexpected( - data, {"tag", "kind", "input", "output", "cond", "multiplicity"} - ) + assert_no_unexpected(data, {"tag", "kind", "input", "output", "cond", "multiplicity"}) Sig: type[Signature] match data["kind"]: case "template": @@ -1048,9 +996,7 @@ def build_signature(config: Config, data: dict) -> Signature: ) Sig = TemplateSignature case "interaction": - reporter.asserts( - "cond" not in data, f"Template signature with cond: {data!r}" - ) + reporter.asserts("cond" not in data, f"Template signature with cond: {data!r}") Sig = InteractionSignature case other: reporter.error(f"Signature of invalid kind '{other}': {data!r}") @@ -1074,9 +1020,7 @@ def read_signatures(config, filename) -> list[Signature]: def check_signatures(found: Iterable[Signature], expected: list[Signature]): for sig in found: - reporter.asserts( - any(sig.matches(exp) for exp in expected), f"Unexpected signature: {sig}" - ) + reporter.asserts(any(sig.matches(exp) for exp in expected), f"Unexpected signature: {sig}") if __name__ == "__main__": From f0276146a38e6b8bd82acfd47b165cfa52363c95 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 13:19:12 +0200 Subject: [PATCH 2/8] Check padding satisfies as many constraints as we can reasonably check --- spec/src/cpu.toml | 2 +- spec/src/dvrm.toml | 2 +- spec/src/sha256round.toml | 24 +++---- spec/tooling/chip.py | 136 ++++++++++++++++++++++++++++++++++---- 4 files changed, 138 insertions(+), 26 deletions(-) diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index 85c3aaf70..4b2c538ca 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -62,7 +62,7 @@ pad = 0 name = "half_instruction_length" type = "Byte" desc = "Half the number of bytes consumed by this instruction, commonly used to indicate whether the instruction is of C type, i.e., whether it is 2 bytes long (= 1) instead of 4 (= 2)" -pad = 2 +pad = 0 [[variables.input]] name = "word_instr" diff --git a/spec/src/dvrm.toml b/spec/src/dvrm.toml index 6ff9b994c..569ef7d28 100644 --- a/spec/src/dvrm.toml +++ b/spec/src/dvrm.toml @@ -27,7 +27,7 @@ pad = 0 name = "q" type = "DWordHL" desc = "The quotient; $#`n` / #`d`$ rounded towards zero." -pad = 0 +pad = ["arr", 65535, 65535, 65535, 65535] [[variables.output]] name = "r" diff --git a/spec/src/sha256round.toml b/spec/src/sha256round.toml index 45da4d452..0c05cdf4a 100644 --- a/spec/src/sha256round.toml +++ b/spec/src/sha256round.toml @@ -126,18 +126,6 @@ type = "Word" desc = "`w[index]`" pad = 0 -[[variables.virtual]] -name = "carry_a" -type = "Byte" -desc = "The carry from `out_a`" -def = ["*", ["^", 2, -32], ["-", ["+", "temp1", "temp2"], ["cast", "out_a", "Word"]]] - -[[variables.virtual]] -name = "carry_e" -type = "Byte" -desc = "The carry from `out_e`" -def = ["*", ["^", 2, -32], ["-", ["+", "d", "temp1"], ["cast", "out_e", "Word"]]] - [[variables.virtual]] name = "ch" type = "Word" @@ -162,6 +150,18 @@ type = "BaseField" desc = "`temp2` value" def = ["+", "S0", "maj"] +[[variables.virtual]] +name = "carry_a" +type = "Byte" +desc = "The carry from `out_a`" +def = ["*", ["^", 2, -32], ["-", ["+", "temp1", "temp2"], ["cast", "out_a", "Word"]]] + +[[variables.virtual]] +name = "carry_e" +type = "Byte" +desc = "The carry from `out_e`" +def = ["*", ["^", 2, -32], ["-", ["+", "d", "temp1"], ["cast", "out_e", "Word"]]] + [[variables.multiplicity]] name = "μ" type = "Bit" diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 800bc7975..8f33048d5 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -9,19 +9,25 @@ class ErrorReporter: reported: bool - location: str + location: list[str] def __init__(self, location: str): self.reported = False - self.location = location + self.location = [location] def update_location(self, loc: str): self.reported = False - self.location = loc + self.location = [loc] + + def push_context(self, ctx: str): + self.location.append(ctx) + + def pop_context(self): + self.location.pop() def error(self, message: str): self.reported = True - print(f"ERROR {self.location}: {message}", file=sys.stderr) + print(f"ERROR {'/'.join(self.location)}: {message}", file=sys.stderr) def asserts(self, condition: bool, message: str): if not condition: @@ -175,6 +181,19 @@ def typecheck(self, env: Environment) -> Type: constant_fits(base.get_const(), self.type), f"Casting const to type it doesn't fit: {self!r}", ) + if isinstance(self.type, list): + return [ + CastExpr(LitExpr(base.get_const() if i == 0 else 0), t).typecheck(env) + for i, t in enumerate(self.type) + ] + return base + if isinstance(base, list) and all(b == Range.const(0) for b in base): + # Workaround for casts of constant zero, to make padding work nicely + # This may become cleaner if we eventually get to the cast rework from #326 + if isinstance(self.type, Range): + return Range.const(0) + else: + return [CastExpr(LitExpr(0), t).typecheck(env) for t in self.type] return self.type @@ -750,6 +769,32 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): check_covered(self.type, seen, []) return self.type + def populate_env(self, env: Environment): + # We start off general, and assume that the defs + # are ordered in such a way that each one at most + # depends on the ones before it + env.valmap[self.name] = copy.deepcopy(self.type) + + def assign(env, container, its, v): + idx = env.valmap[its[0].name].get_const() + if len(its) == 1: + container[idx] = v + else: + assign(env, container[idx], its[1:], v) + + for poly_iters in self.def_.defs: + if not poly_iters.iters: + env.valmap[self.name] = poly_iters.poly.typecheck(env) + continue + + for _ in all_iters( + poly_iters.iters, + env, + lambda e: [assign(e, env.valmap[self.name], poly_iters.iters, poly_iters.poly.typecheck(e))], + ): + # Consume the iterator + pass + @dataclass class Assumption: @@ -803,6 +848,7 @@ def check_includes_zero(t: Type): @dataclass class Signature: tag: str + condition: Optional[Type] input: list[Type] output: Optional[Type] @@ -870,11 +916,13 @@ def __init__(self, config: Config, data: dict): def typecheck(self, env: Environment) -> Iterable[Signature]: def callback(e: Environment) -> Iterable[Signature]: # TODO: Should we be able to check cond/multiplicity somehow? + condition = None if self.conditional is not None: - self.conditional.typecheck(e) + condition = self.conditional.typecheck(e) return [ self.signature( self.tag, + condition, [inp.typecheck(e) for inp in self.input], self.output.typecheck(e) if self.output else None, ) @@ -980,20 +1028,55 @@ def typecheck(self) -> Iterable[Signature]: env = Environment(self.config, {}, typemap) for v in self.variables: if isinstance(v, VirtualVariable): + reporter.push_context(v.name) v.typecheck(env) + reporter.pop_context() for c in self.constraints: + reporter.push_context(repr(c)) yield from c.typecheck(env) + reporter.pop_context() + + def check_assignment( + self, + check_template: dict[str, Callable[[Optional[Type], list[Type], Type], None]], + values: Optional[dict[str, Type]] = None, + ): + env = Environment(self.config, {}, {}) + if values is None: + for v in self.variables: + if not isinstance(v, VirtualVariable): + t = v.type + if isinstance(t, list) and len(t) == 1: + t = t[0] + env.valmap[v.name] = CastExpr(v.pad, t).typecheck(env) + else: + for v in self.variables: + if not isinstance(v, VirtualVariable): + if v.name not in values: + reporter.error(f"Unable to find variable name {v.name!r} when checking assignment") + return + env.valmap[v.name] = values[v.name] + for v in self.variables: + if isinstance(v, VirtualVariable): + v.populate_env(env) + + for c in self.constraints: + for sig in c.typecheck(env): + # Recurse on templates + if isinstance(sig, TemplateSignature): + reporter.push_context(repr(c)) + check_template[sig.tag](sig.condition, sig.input, sig.output) + reporter.pop_context() def build_signature(config: Config, data: dict) -> Signature: - assert_no_unexpected(data, {"tag", "kind", "input", "output", "cond", "multiplicity"}) + assert_no_unexpected(data, {"tag", "kind", "input", "output", "cond"}) Sig: type[Signature] + cond: Optional[Type] = None match data["kind"]: case "template": - reporter.asserts( - "multiplicity" not in data, - f"Template signature with multiplicity: {data!r}", - ) + if "cond" in data: + cond = build_type(config, data["cond"]) Sig = TemplateSignature case "interaction": reporter.asserts("cond" not in data, f"Template signature with cond: {data!r}") @@ -1008,7 +1091,7 @@ def build_signature(config: Config, data: dict) -> Signature: output = build_type(config, data["output"]) else: output = None - return Sig(tag, input, output) + return Sig(tag, cond, input, output) def read_signatures(config, filename) -> list[Signature]: @@ -1023,6 +1106,27 @@ def check_signatures(found: Iterable[Signature], expected: list[Signature]): reporter.asserts(any(sig.matches(exp) for exp in expected), f"Unexpected signature: {sig}") +def template_checker(check_template: dict[str, Callable[[Optional[Type], list[Type], Type], None]], chip: Chip): + def check(cond: Optional[Type], input: list[Type], output: Type): + input = input[:] + values = {} + for v in chip.variables: + match v.category: + case "input": + values[v.name] = input.pop(0) + case "output": + values[v.name] = output + case "condition": + values[v.name] = cond if cond else Range.const(1) + case "virtual": + pass + case other: + reporter.error(f"Cannot check template with variable of category {other!r}") + chip.check_assignment(check_template, values) + + return check + + if __name__ == "__main__": config = Config.from_file(sys.argv[1]) signatures = read_signatures(config, sys.argv[2]) @@ -1031,18 +1135,26 @@ def check_signatures(found: Iterable[Signature], expected: list[Signature]): reported = False chips: list[Chip] = [] + template_checkers: dict[str, Callable[[Optional[Type], list[Type], Type], None]] = {} for file in sys.argv[3:]: if file in sys.argv[1:3]: continue - chips.append(Chip.from_file(config, file)) + chips.append(chip := Chip.from_file(config, file)) + template_checkers[chip.name] = template_checker(template_checkers, chip) reported |= reporter.reported if reported: sys.exit(1) + template_checkers["SUB"] = lambda cond, input, output: template_checkers["ADD"](cond, [output, input[1]], input[0]) + for chip in chips: reporter.update_location(f"Chip {chip.name}") check_signatures(chip.typecheck(), signatures) reported |= reporter.reported + for chip in chips: + reporter.update_location(f"Padding {chip.name}") + chip.check_assignment(template_checkers) + reported |= reporter.reported if reported: sys.exit(1) else: From 6f500eb788a3841e46d00c615f14e69014ccca5d Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 13:23:52 +0200 Subject: [PATCH 3/8] Avoid crashing on runs that don't check all chips --- spec/tooling/chip.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 8f33048d5..815ee6902 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1063,7 +1063,7 @@ def check_assignment( for c in self.constraints: for sig in c.typecheck(env): # Recurse on templates - if isinstance(sig, TemplateSignature): + if isinstance(sig, TemplateSignature) and sig.tag in check_template: reporter.push_context(repr(c)) check_template[sig.tag](sig.condition, sig.input, sig.output) reporter.pop_context() @@ -1145,7 +1145,8 @@ def check(cond: Optional[Type], input: list[Type], output: Type): if reported: sys.exit(1) - template_checkers["SUB"] = lambda cond, input, output: template_checkers["ADD"](cond, [output, input[1]], input[0]) + if "ADD" in template_checkers: + template_checkers["SUB"] = lambda cond, input, output: template_checkers["ADD"](cond, [output, input[1]], input[0]) for chip in chips: reporter.update_location(f"Chip {chip.name}") From f089adfc038b408782d3e2816c3983a6dcf56c10 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 13:33:44 +0200 Subject: [PATCH 4/8] Turn push_context into a context manager --- spec/tooling/chip.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 815ee6902..c8b142c9c 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1,3 +1,4 @@ +import contextlib import copy import sys import tomllib @@ -19,10 +20,10 @@ def update_location(self, loc: str): self.reported = False self.location = [loc] - def push_context(self, ctx: str): + @contextlib.contextmanager + def context(self, ctx: str): self.location.append(ctx) - - def pop_context(self): + yield self.location.pop() def error(self, message: str): @@ -1028,13 +1029,11 @@ def typecheck(self) -> Iterable[Signature]: env = Environment(self.config, {}, typemap) for v in self.variables: if isinstance(v, VirtualVariable): - reporter.push_context(v.name) - v.typecheck(env) - reporter.pop_context() + with reporter.context(v.name): + v.typecheck(env) for c in self.constraints: - reporter.push_context(repr(c)) - yield from c.typecheck(env) - reporter.pop_context() + with reporter.context(repr(c)): + yield from c.typecheck(env) def check_assignment( self, @@ -1064,9 +1063,8 @@ def check_assignment( for sig in c.typecheck(env): # Recurse on templates if isinstance(sig, TemplateSignature) and sig.tag in check_template: - reporter.push_context(repr(c)) - check_template[sig.tag](sig.condition, sig.input, sig.output) - reporter.pop_context() + with reporter.context(repr(c)): + check_template[sig.tag](sig.condition, sig.input, sig.output) def build_signature(config: Config, data: dict) -> Signature: From ae425a0e3abdffba1873fa9d75df9d8bb00c3a24 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 13:34:05 +0200 Subject: [PATCH 5/8] Check that no template cond is given when it isn't possible --- spec/tooling/chip.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index c8b142c9c..a69ea2427 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -862,6 +862,9 @@ def matches(self, other: Self) -> bool: return False if self.output is not None and other.output is not None and not structure_matches(self.output, other.output): return False + # Used as `sig.matches(expected)`, so `self` is the concrete signature found in the toml + if self.condition is not None and other.condition is None: + return False return structure_matches(self.input, other.input) @@ -1078,6 +1081,7 @@ def build_signature(config: Config, data: dict) -> Signature: Sig = TemplateSignature case "interaction": reporter.asserts("cond" not in data, f"Template signature with cond: {data!r}") + cond = Range.const(1) Sig = InteractionSignature case other: reporter.error(f"Signature of invalid kind '{other}': {data!r}") From 65667016ff316e4a82192f3778bff86b509eb9cc Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 15:41:00 +0200 Subject: [PATCH 6/8] rename type_match to typecheck_binop --- spec/tooling/chip.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index a69ea2427..07f3f4a1e 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -202,14 +202,14 @@ def typecheck(self, env: Environment) -> Type: class MulExpr: factors: list[Expr] - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): reporter.error(f"Multiplication of non-scalar types: {self!r}") return DEFAULT_TYPE elif not isinstance(a, Range): - return [self.type_match(x, b) for x in a] + return [self.typecheck_binop(x, b) for x in a] elif isinstance(b, list): - return self.type_match(b, a) + return self.typecheck_binop(b, a) else: extrema = [x * y for x in [a.low, a.high] for y in [b.low, b.high]] return Range(min(extrema), max(extrema)) @@ -218,7 +218,7 @@ def typecheck(self, env: Environment) -> Type: reporter.asserts(self.factors != [], f"Empty product: {self!r}") t: Type = Range.const(1) for f in self.factors: - t = self.type_match(t, f.typecheck(env)) + t = self.typecheck_binop(t, f.typecheck(env)) return t @@ -226,12 +226,12 @@ def typecheck(self, env: Environment) -> Type: class AddExpr: terms: list[Expr] - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Adding array types of different length {self!r}") return [DEFAULT_TYPE for _ in b] - return [self.type_match(x, y) for x, y in zip(a, b)] + return [self.typecheck_binop(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Adding of scalar and array types {self!r}") return DEFAULT_TYPE @@ -244,7 +244,7 @@ def typecheck(self, env: Environment) -> Type: return Range.const(0) t: Type = self.terms[0].typecheck(env) for term in self.terms[1:]: - t = self.type_match(t, term.typecheck(env)) + t = self.typecheck_binop(t, term.typecheck(env)) return t @@ -253,12 +253,12 @@ class SubExpr: head: Expr subs: list[Expr] - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Subtracting array types of different length {self!r}") return [DEFAULT_TYPE for _ in a] - return [self.type_match(x, y) for x, y in zip(a, b)] + return [self.typecheck_binop(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Subtraction of scalar and array types {self!r}") return DEFAULT_TYPE @@ -273,7 +273,7 @@ def typecheck(self, env: Environment) -> Type: return t return Range(-t.high, -t.low) for term in self.subs: - t = self.type_match(t, term.typecheck(env)) + t = self.typecheck_binop(t, term.typecheck(env)) return t @@ -324,12 +324,12 @@ class SumExpr: iter: "Iter" terms: Expr - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Summing array types of different length {self!r}") return [DEFAULT_TYPE for _ in b] - return [self.type_match(x, y) for x, y in zip(a, b)] + return [self.typecheck_binop(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Summing of scalar and array types {self!r}") return DEFAULT_TYPE @@ -339,7 +339,7 @@ def type_match(self, a: Type, b: Type) -> Type: def typecheck(self, env: Environment) -> Type: t: Type = Range.const(0) for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): - t = self.type_match(t, tc) + t = self.typecheck_binop(t, tc) return t From dc81fbc948108b0e5f7977bac23cb79e317dcb4b Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 15:42:18 +0200 Subject: [PATCH 7/8] Cleanup some signatures --- spec/tooling/chip.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 07f3f4a1e..0a9b75c95 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -93,6 +93,7 @@ def constant_fits(cst: int, target: Type) -> bool: | MulExpr | AddExpr | SubExpr + | ModExpr | PowExpr | SumExpr | NotExpr @@ -105,7 +106,7 @@ def constant_fits(cst: int, target: Type) -> bool: @dataclass class Environment: config: "Config" - valmap: dict[str, Range] + valmap: dict[str, Type] typemap: dict[str, Type] def with_val(self, key: str, val: Range) -> Self: @@ -402,7 +403,7 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: return DummyExpr() -def check_padding_fits(var_name: str, config: "Config", type: Type, pad: Expr): +def check_padding_fits(config: "Config", type: Type, pad: Expr): def fits(v, t): if isinstance(v, Range): return v.is_const() and constant_fits(v.get_const(), t) @@ -410,7 +411,7 @@ def fits(v, t): return isinstance(v, list) and isinstance(t, list) and len(v) == len(t) and all(map(fits, v, t)) val = pad.typecheck(Environment(config, {}, {})) - reporter.asserts(fits(val, type), f"{var_name!r}: Invalid padding {pad!r} for type {type!r}") + reporter.asserts(fits(val, type), f"Invalid padding {pad!r} for type {type!r}") @dataclass @@ -621,7 +622,8 @@ def __init__(self, config: Config, category: str, data: dict): self.desc = data["desc"] reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") self.pad = build_expr(None, data.get("pad", 0)) - check_padding_fits(self.name, config, self.type, self.pad) + with reporter.context(self.name): + check_padding_fits(config, self.type, self.pad) self.precomputed = data.get("precomputed", False) reporter.asserts( isinstance(self.precomputed, bool), @@ -647,7 +649,7 @@ class VirtualDef: # A list of polynomials with each a set of iters they range over defs: list[PolyWithIters] - def __init__(self, config: Config, name: str, tp: Type, data: dict): + def __init__(self, config: Config, data: dict): if "poly" in data: idx = data.get("idx", None) self.defs = [PolyWithIters(build_expr(config, data["poly"]), iters_of(data, config, name=idx))] @@ -671,7 +673,7 @@ def __init__(self, config: Config, category: str, data: dict): data = copy.deepcopy(data) def_ = data.pop("def", {}) super().__init__(config, category, data) - self.def_ = VirtualDef(config, self.name, self.type, def_) + self.def_ = VirtualDef(config, def_) def typecheck(self, env: Environment) -> Type: def handle_iters( @@ -959,7 +961,7 @@ class InteractionConstraint(InteractionLike): @dataclass class DummyConstraint: - def typecheck(self, env: Environment) -> list[Never]: + def typecheck(self, _env: Environment) -> list[Never]: return [] From c5371348ed15cdc84d0ac3a36a3fef999fa9154a Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 16 Jun 2026 15:42:38 +0200 Subject: [PATCH 8/8] Rework template_checkers recursion --- spec/tooling/chip.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 0a9b75c95..ed3d9472e 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1042,7 +1042,7 @@ def typecheck(self) -> Iterable[Signature]: def check_assignment( self, - check_template: dict[str, Callable[[Optional[Type], list[Type], Type], None]], + template_checkers: dict[str, "TemplateChecker"], values: Optional[dict[str, Type]] = None, ): env = Environment(self.config, {}, {}) @@ -1067,9 +1067,9 @@ def check_assignment( for c in self.constraints: for sig in c.typecheck(env): # Recurse on templates - if isinstance(sig, TemplateSignature) and sig.tag in check_template: + if isinstance(sig, TemplateSignature) and sig.tag in template_checkers: with reporter.context(repr(c)): - check_template[sig.tag](sig.condition, sig.input, sig.output) + template_checkers[sig.tag](template_checkers, sig.condition, sig.input, sig.output) def build_signature(config: Config, data: dict) -> Signature: @@ -1110,8 +1110,13 @@ def check_signatures(found: Iterable[Signature], expected: list[Signature]): reporter.asserts(any(sig.matches(exp) for exp in expected), f"Unexpected signature: {sig}") -def template_checker(check_template: dict[str, Callable[[Optional[Type], list[Type], Type], None]], chip: Chip): - def check(cond: Optional[Type], input: list[Type], output: Type): +# A Function taking a mapping of available templates (for recursive expansion), +# an optional condition variable, a list of inputs and an output, +# and checks that it satisfies the template the checker represents. +type TemplateChecker = Callable[[dict[str, TemplateChecker], Optional[Type], list[Type], Optional[Type]], None] + +def build_template_checker(chip: Chip) -> TemplateChecker: + def check(template_checkers: dict[str, TemplateChecker], cond: Optional[Type], input: list[Type], output: Optional[Type]) -> None: input = input[:] values = {} for v in chip.variables: @@ -1119,6 +1124,9 @@ def check(cond: Optional[Type], input: list[Type], output: Type): case "input": values[v.name] = input.pop(0) case "output": + if output is None: + reporter.error(f"No output available for template output variable {v.name!r}") + return values[v.name] = output case "condition": values[v.name] = cond if cond else Range.const(1) @@ -1126,7 +1134,7 @@ def check(cond: Optional[Type], input: list[Type], output: Type): pass case other: reporter.error(f"Cannot check template with variable of category {other!r}") - chip.check_assignment(check_template, values) + chip.check_assignment(template_checkers, values) return check @@ -1139,18 +1147,19 @@ def check(cond: Optional[Type], input: list[Type], output: Type): reported = False chips: list[Chip] = [] - template_checkers: dict[str, Callable[[Optional[Type], list[Type], Type], None]] = {} + template_checkers: dict[str, TemplateChecker] = {} for file in sys.argv[3:]: if file in sys.argv[1:3]: continue - chips.append(chip := Chip.from_file(config, file)) - template_checkers[chip.name] = template_checker(template_checkers, chip) + chip = Chip.from_file(config, file) + chips.append(chip) + template_checkers[chip.name] = build_template_checker(chip) reported |= reporter.reported if reported: sys.exit(1) if "ADD" in template_checkers: - template_checkers["SUB"] = lambda cond, input, output: template_checkers["ADD"](cond, [output, input[1]], input[0]) + template_checkers["SUB"] = lambda checkers, cond, input, output: checkers["ADD"](checkers, cond, [output, input[1]], input[0]) for chip in chips: reporter.update_location(f"Chip {chip.name}")