diff --git a/spec/src/commit.toml b/spec/src/commit.toml index 89fa133c6..fdfc19dfb 100644 --- a/spec/src/commit.toml +++ b/spec/src/commit.toml @@ -18,19 +18,19 @@ pad = 0 name = "address" type = "DWordWL" desc = "Address of first byte to commit." -pad = ["arr", 0, 0, 0, 0] +pad = 0 [[variables.auxiliary]] name = "address_incr" type = "DWordHL" desc = "$#`address` + 1$" -pad = ["arr", 1, 0, 0, 0] +pad = 1 [[variables.auxiliary]] name = "count" type = "DWordWL" desc = "number of bytes to commit" -pad = ["arr", 1, 0, 0, 0] +pad = 1 [[variables.auxiliary]] name = "count_decr" diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index 85c3aaf70..4b2c538ca 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -62,7 +62,7 @@ pad = 0 name = "half_instruction_length" type = "Byte" desc = "Half the number of bytes consumed by this instruction, commonly used to indicate whether the instruction is of C type, i.e., whether it is 2 bytes long (= 1) instead of 4 (= 2)" -pad = 2 +pad = 0 [[variables.input]] name = "word_instr" diff --git a/spec/src/dvrm.toml b/spec/src/dvrm.toml index 6ff9b994c..569ef7d28 100644 --- a/spec/src/dvrm.toml +++ b/spec/src/dvrm.toml @@ -27,7 +27,7 @@ pad = 0 name = "q" type = "DWordHL" desc = "The quotient; $#`n` / #`d`$ rounded towards zero." -pad = 0 +pad = ["arr", 65535, 65535, 65535, 65535] [[variables.output]] name = "r" diff --git a/spec/src/sha256round.toml b/spec/src/sha256round.toml index 45da4d452..0c05cdf4a 100644 --- a/spec/src/sha256round.toml +++ b/spec/src/sha256round.toml @@ -126,18 +126,6 @@ type = "Word" desc = "`w[index]`" pad = 0 -[[variables.virtual]] -name = "carry_a" -type = "Byte" -desc = "The carry from `out_a`" -def = ["*", ["^", 2, -32], ["-", ["+", "temp1", "temp2"], ["cast", "out_a", "Word"]]] - -[[variables.virtual]] -name = "carry_e" -type = "Byte" -desc = "The carry from `out_e`" -def = ["*", ["^", 2, -32], ["-", ["+", "d", "temp1"], ["cast", "out_e", "Word"]]] - [[variables.virtual]] name = "ch" type = "Word" @@ -162,6 +150,18 @@ type = "BaseField" desc = "`temp2` value" def = ["+", "S0", "maj"] +[[variables.virtual]] +name = "carry_a" +type = "Byte" +desc = "The carry from `out_a`" +def = ["*", ["^", 2, -32], ["-", ["+", "temp1", "temp2"], ["cast", "out_a", "Word"]]] + +[[variables.virtual]] +name = "carry_e" +type = "Byte" +desc = "The carry from `out_e`" +def = ["*", ["^", 2, -32], ["-", ["+", "d", "temp1"], ["cast", "out_e", "Word"]]] + [[variables.multiplicity]] name = "μ" type = "Bit" diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 44cbbed83..ed3d9472e 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1,3 +1,4 @@ +import contextlib import copy import sys import tomllib @@ -9,19 +10,25 @@ class ErrorReporter: reported: bool - location: str + location: list[str] def __init__(self, location: str): self.reported = False - self.location = location + self.location = [location] def update_location(self, loc: str): self.reported = False - self.location = loc + self.location = [loc] + + @contextlib.contextmanager + def context(self, ctx: str): + self.location.append(ctx) + yield + self.location.pop() def error(self, message: str): self.reported = True - print(f"ERROR {self.location}: {message}", file=sys.stderr) + print(f"ERROR {'/'.join(self.location)}: {message}", file=sys.stderr) def asserts(self, condition: bool, message: str): if not condition: @@ -86,6 +93,7 @@ def constant_fits(cst: int, target: Type) -> bool: | MulExpr | AddExpr | SubExpr + | ModExpr | PowExpr | SumExpr | NotExpr @@ -98,7 +106,7 @@ def constant_fits(cst: int, target: Type) -> bool: @dataclass class Environment: config: "Config" - valmap: dict[str, Range] + valmap: dict[str, Type] typemap: dict[str, Type] def with_val(self, key: str, val: Range) -> Self: @@ -175,6 +183,19 @@ def typecheck(self, env: Environment) -> Type: constant_fits(base.get_const(), self.type), f"Casting const to type it doesn't fit: {self!r}", ) + if isinstance(self.type, list): + return [ + CastExpr(LitExpr(base.get_const() if i == 0 else 0), t).typecheck(env) + for i, t in enumerate(self.type) + ] + return base + if isinstance(base, list) and all(b == Range.const(0) for b in base): + # Workaround for casts of constant zero, to make padding work nicely + # This may become cleaner if we eventually get to the cast rework from #326 + if isinstance(self.type, Range): + return Range.const(0) + else: + return [CastExpr(LitExpr(0), t).typecheck(env) for t in self.type] return self.type @@ -182,14 +203,14 @@ def typecheck(self, env: Environment) -> Type: class MulExpr: factors: list[Expr] - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): reporter.error(f"Multiplication of non-scalar types: {self!r}") return DEFAULT_TYPE elif not isinstance(a, Range): - return [self.type_match(x, b) for x in a] + return [self.typecheck_binop(x, b) for x in a] elif isinstance(b, list): - return self.type_match(b, a) + return self.typecheck_binop(b, a) else: extrema = [x * y for x in [a.low, a.high] for y in [b.low, b.high]] return Range(min(extrema), max(extrema)) @@ -198,7 +219,7 @@ def typecheck(self, env: Environment) -> Type: reporter.asserts(self.factors != [], f"Empty product: {self!r}") t: Type = Range.const(1) for f in self.factors: - t = self.type_match(t, f.typecheck(env)) + t = self.typecheck_binop(t, f.typecheck(env)) return t @@ -206,12 +227,12 @@ def typecheck(self, env: Environment) -> Type: class AddExpr: terms: list[Expr] - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Adding array types of different length {self!r}") return [DEFAULT_TYPE for _ in b] - return [self.type_match(x, y) for x, y in zip(a, b)] + return [self.typecheck_binop(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Adding of scalar and array types {self!r}") return DEFAULT_TYPE @@ -224,7 +245,7 @@ def typecheck(self, env: Environment) -> Type: return Range.const(0) t: Type = self.terms[0].typecheck(env) for term in self.terms[1:]: - t = self.type_match(t, term.typecheck(env)) + t = self.typecheck_binop(t, term.typecheck(env)) return t @@ -233,12 +254,12 @@ class SubExpr: head: Expr subs: list[Expr] - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Subtracting array types of different length {self!r}") return [DEFAULT_TYPE for _ in a] - return [self.type_match(x, y) for x, y in zip(a, b)] + return [self.typecheck_binop(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Subtraction of scalar and array types {self!r}") return DEFAULT_TYPE @@ -253,7 +274,7 @@ def typecheck(self, env: Environment) -> Type: return t return Range(-t.high, -t.low) for term in self.subs: - t = self.type_match(t, term.typecheck(env)) + t = self.typecheck_binop(t, term.typecheck(env)) return t @@ -278,7 +299,7 @@ def typecheck(self, env: Environment) -> Type: elt = elt.get_const() return Range.const(elt % modulus) else: - return Range(0, modulus-1) + return Range(0, modulus - 1) @dataclass @@ -293,9 +314,7 @@ def typecheck(self, env: Environment) -> Type: reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") return DEFAULT_TYPE if isinstance(exp, list) or not exp.is_const(): - reporter.error( - f"Invalid exponentiation with non-const exponent: {self.exp!r}" - ) + reporter.error(f"Invalid exponentiation with non-const exponent: {self.exp!r}") return DEFAULT_TYPE val = pow(base.get_const(), exp.get_const(), env.config.variables.prime) return Range.const(val) @@ -306,12 +325,12 @@ class SumExpr: iter: "Iter" terms: Expr - def type_match(self, a: Type, b: Type) -> Type: + def typecheck_binop(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Summing array types of different length {self!r}") return [DEFAULT_TYPE for _ in b] - return [self.type_match(x, y) for x, y in zip(a, b)] + return [self.typecheck_binop(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Summing of scalar and array types {self!r}") return DEFAULT_TYPE @@ -321,7 +340,7 @@ def type_match(self, a: Type, b: Type) -> Type: def typecheck(self, env: Environment) -> Type: t: Type = Range.const(0) for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): - t = self.type_match(t, tc) + t = self.typecheck_binop(t, tc) return t @@ -349,9 +368,7 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: case int(x): return LitExpr(x) case str(x): - reporter.asserts( - x.isidentifier(), f"Invalid identifier name for variable {x!r}" - ) + reporter.asserts(x.isidentifier(), f"Invalid identifier name for variable {x!r}") return VarExpr(x) case ["opsel", str(x)]: if x not in OPSEL: @@ -371,9 +388,7 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: case ["+", *terms]: return AddExpr([build_expr(config, t) for t in terms]) case ["-", head, *subs]: - return SubExpr( - build_expr(config, head), [build_expr(config, s) for s in subs] - ) + return SubExpr(build_expr(config, head), [build_expr(config, s) for s in subs]) case ["mod", elt, modulus]: return ModExpr(build_expr(config, elt), build_expr(config, modulus)) case ["^", base, exp]: @@ -388,6 +403,17 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: return DummyExpr() +def check_padding_fits(config: "Config", type: Type, pad: Expr): + def fits(v, t): + if isinstance(v, Range): + return v.is_const() and constant_fits(v.get_const(), t) + else: + return isinstance(v, list) and isinstance(t, list) and len(v) == len(t) and all(map(fits, v, t)) + + val = pad.typecheck(Environment(config, {}, {})) + reporter.asserts(fits(val, type), f"Invalid padding {pad!r} for type {type!r}") + + @dataclass class Iter: name: str @@ -396,18 +422,12 @@ class Iter: def __init__(self, config: "Config", name: str, start: object, stop: object): self.name = name - reporter.asserts( - isinstance(self.name, str), f"iter name is not a string: {self.name!r}" - ) - reporter.asserts( - self.name.isidentifier(), f"Not a valid identifier: {self.name!r}" - ) + reporter.asserts(isinstance(self.name, str), f"iter name is not a string: {self.name!r}") + reporter.asserts(self.name.isidentifier(), f"Not a valid identifier: {self.name!r}") self.start = build_expr(config, start) self.stop = build_expr(config, stop) - def typecheck[T]( - self, env: Environment, callback: Callable[[Environment], Iterable[T]] - ) -> Iterable[T]: + def typecheck[T](self, env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: start = self.start.typecheck(env) if isinstance(start, list) or not start.is_const(): reporter.error(f"Starting value of iterator not a const: {self!r}") @@ -444,9 +464,7 @@ def clean_iter(it): return Iter(config, *arr) if "iters" in obj: - reporter.asserts( - "iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}" - ) + reporter.asserts("iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}") return [clean_iter(it) for it in obj["iters"]] elif "iter" in obj: return [clean_iter(obj["iter"])] @@ -475,14 +493,10 @@ def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict) f"Invalid range: {data!r}", ) start, stop = data["range"] - if not isinstance(start, int) and not ( - isinstance(start, str) and start.isdigit() - ): + if not isinstance(start, int) and not (isinstance(start, str) and start.isdigit()): reporter.error(f"Range start not an int: {data!r}") start = 0 - if not isinstance(stop, int) and not ( - isinstance(stop, str) and stop.isdigit() - ): + if not isinstance(stop, int) and not (isinstance(stop, str) and stop.isdigit()): reporter.error(f"Range end not an int: {data!r}") stop = start reporter.asserts(int(start) <= int(stop), f"Inverted range: {data!r}") @@ -553,9 +567,7 @@ class ConfigMetadata: def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.version = data["version"] - reporter.asserts( - isinstance(self.version, int), f"version {self.version!r} is not an int" - ) + reporter.asserts(isinstance(self.version, int), f"version {self.version!r} is not an int") @dataclass @@ -610,6 +622,8 @@ def __init__(self, config: Config, category: str, data: dict): self.desc = data["desc"] reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") self.pad = build_expr(None, data.get("pad", 0)) + with reporter.context(self.name): + check_padding_fits(config, self.type, self.pad) self.precomputed = data.get("precomputed", False) reporter.asserts( isinstance(self.precomputed, bool), @@ -617,9 +631,7 @@ def __init__(self, config: Config, category: str, data: dict): ) -def all_iters[T]( - its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]] -) -> Iterable[T]: +def all_iters[T](its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: if not its: yield from callback(env) else: @@ -637,20 +649,14 @@ class VirtualDef: # A list of polynomials with each a set of iters they range over defs: list[PolyWithIters] - def __init__(self, config: Config, name: str, tp: Type, data: dict): + def __init__(self, config: Config, data: dict): if "poly" in data: idx = data.get("idx", None) - self.defs = [ - PolyWithIters( - build_expr(config, data["poly"]), iters_of(data, config, name=idx) - ) - ] + self.defs = [PolyWithIters(build_expr(config, data["poly"]), iters_of(data, config, name=idx))] elif "polys" in data: idx = data.get("idx", None) self.defs = [ - PolyWithIters( - build_expr(config, poly["poly"]), iters_of(poly, config, name=idx) - ) + PolyWithIters(build_expr(config, poly["poly"]), iters_of(poly, config, name=idx)) for poly in data["polys"] ] else: @@ -667,7 +673,7 @@ def __init__(self, config: Config, category: str, data: dict): data = copy.deepcopy(data) def_ = data.pop("def", {}) super().__init__(config, category, data) - self.def_ = VirtualDef(config, self.name, self.type, def_) + self.def_ = VirtualDef(config, def_) def typecheck(self, env: Environment) -> Type: def handle_iters( @@ -683,9 +689,7 @@ def handle_iters( for s in seen: ln = min(len(s), len(indices)) if s[:ln] == tuple(indices[:ln]): - reporter.error( - f"Double definition for virtual column: {self!r} at index {indices}" - ) + reporter.error(f"Double definition for virtual column: {self!r} at index {indices}") break val = poly.typecheck(env) @@ -703,21 +707,15 @@ def handle_iters( # But threading the extra needed state through overly complicates everything start = it.start.typecheck(env) if isinstance(start, list) or not start.is_const(): - reporter.error( - f"Starting value of virtual def iter not a const: {self!r}" - ) + reporter.error(f"Starting value of virtual def iter not a const: {self!r}") start = Range.const(0) stop = it.stop.typecheck(env) if isinstance(stop, list) or not stop.is_const(): - reporter.error( - f"Ending value of virtual def iter not a const: {self!r}" - ) + reporter.error(f"Ending value of virtual def iter not a const: {self!r}") stop = Range.const(start.get_const()) if isinstance(expected, Range): - reporter.error( - f"Virtual definition has an iter for a scalar: {self!r}" - ) + reporter.error(f"Virtual definition has an iter for a scalar: {self!r}") return if not 0 <= start.get_const() <= stop.get_const() < len(expected): @@ -760,9 +758,7 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): ) assigned_type = self.def_.defs[0].poly.typecheck(env) if not isinstance(assigned_type, Range): - reporter.error( - f"Assigning non-scalar type to scalar virtual column: {self!r}" - ) + reporter.error(f"Assigning non-scalar type to scalar virtual column: {self!r}") return self.type # Check type fits? # Leaving this out because it produces too much noise with one-hot assumptions @@ -771,13 +767,37 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): # Check no indices are covered twice seen: set[tuple] = set() for poly_iters in self.def_.defs: - handle_iters( - env, poly_iters.iters, poly_iters.poly, self.type, [], seen - ) + handle_iters(env, poly_iters.iters, poly_iters.poly, self.type, [], seen) # Check everything is covered check_covered(self.type, seen, []) return self.type + def populate_env(self, env: Environment): + # We start off general, and assume that the defs + # are ordered in such a way that each one at most + # depends on the ones before it + env.valmap[self.name] = copy.deepcopy(self.type) + + def assign(env, container, its, v): + idx = env.valmap[its[0].name].get_const() + if len(its) == 1: + container[idx] = v + else: + assign(env, container[idx], its[1:], v) + + for poly_iters in self.def_.defs: + if not poly_iters.iters: + env.valmap[self.name] = poly_iters.poly.typecheck(env) + continue + + for _ in all_iters( + poly_iters.iters, + env, + lambda e: [assign(e, env.valmap[self.name], poly_iters.iters, poly_iters.poly.typecheck(e))], + ): + # Consume the iterator + pass + @dataclass class Assumption: @@ -785,9 +805,7 @@ class Assumption: iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected( - data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"} - ) + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"}) self.desc = data["desc"] self.iters = iters_of(data, config) @@ -800,9 +818,7 @@ class ArithConstraint: iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected( - data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"} - ) + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) assert data["kind"] == "arith" self.constraint = data["constraint"] reporter.asserts( @@ -810,9 +826,7 @@ def __init__(self, config: Config, data: dict): f"Constraint not a string: {self.constraint!r}", ) self.desc = data.get("desc", "") - reporter.asserts( - isinstance(self.desc, str), f"desc is not a string: {self.desc!r}" - ) + reporter.asserts(isinstance(self.desc, str), f"desc is not a string: {self.desc!r}") self.poly = build_expr(config, data["poly"]) self.iters = iters_of(data, config) @@ -827,9 +841,7 @@ def check_includes_zero(t: Type): f"Unsatisfiable constraint, 0 not in range: {self!r} {t}", ) else: - reporter.error( - f"Non-scalar value for polynomial constraint: {self!r} {t}" - ) + reporter.error(f"Non-scalar value for polynomial constraint: {self!r} {t}") for t in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): check_includes_zero(t) @@ -839,6 +851,7 @@ def check_includes_zero(t: Type): @dataclass class Signature: tag: str + condition: Optional[Type] input: list[Type] output: Optional[Type] @@ -849,11 +862,10 @@ def matches(self, other: Self) -> bool: return False if (self.output is None) != (other.output is None): return False - if ( - self.output is not None - and other.output is not None - and not structure_matches(self.output, other.output) - ): + if self.output is not None and other.output is not None and not structure_matches(self.output, other.output): + return False + # Used as `sig.matches(expected)`, so `self` is the concrete signature found in the toml + if self.condition is not None and other.condition is None: return False return structure_matches(self.input, other.input) @@ -889,13 +901,9 @@ def __init__(self, config: Config, data: dict): ) assert data["kind"] == self.kind self.tag = data["tag"] - reporter.asserts( - isinstance(self.tag, str), f"tag is not a string: {self.tag!r}" - ) + reporter.asserts(isinstance(self.tag, str), f"tag is not a string: {self.tag!r}") self.desc = data.get("desc", "") - reporter.asserts( - isinstance(self.desc, str), f"Description is not a string: {self.desc!r}" - ) + reporter.asserts(isinstance(self.desc, str), f"Description is not a string: {self.desc!r}") self.input = [build_expr(config, inp) for inp in data["input"]] if "output" in data: self.output = build_expr(config, data["output"]) @@ -914,11 +922,13 @@ def __init__(self, config: Config, data: dict): def typecheck(self, env: Environment) -> Iterable[Signature]: def callback(e: Environment) -> Iterable[Signature]: # TODO: Should we be able to check cond/multiplicity somehow? + condition = None if self.conditional is not None: - self.conditional.typecheck(e) + condition = self.conditional.typecheck(e) return [ self.signature( self.tag, + condition, [inp.typecheck(e) for inp in self.input], self.output.typecheck(e) if self.output else None, ) @@ -951,13 +961,11 @@ class InteractionConstraint(InteractionLike): @dataclass class DummyConstraint: - def typecheck(self, env: Environment) -> list[Never]: + def typecheck(self, _env: Environment) -> list[Never]: return [] -type Constraint = ( - ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint -) +type Constraint = ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint def build_constraint(config, data: dict) -> Constraint: @@ -983,15 +991,11 @@ class Chip: def __init__(self, config: Config, data: dict): """Construct a chip from toml-parsed data""" - assert_no_unexpected( - data, set(type(self).__annotations__.keys()) | {"constraint_groups"} - ) + assert_no_unexpected(data, set(type(self).__annotations__.keys()) | {"constraint_groups"}) assert_no_unexpected(data["variables"], config.variables.categories.all) self.config = config self.name = data["name"] - reporter.asserts( - isinstance(self.name, str), f"name is not a string: {self.name!r}" - ) + reporter.asserts(isinstance(self.name, str), f"name is not a string: {self.name!r}") reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") self.variables = [ (Variable if cat != "virtual" else VirtualVariable)(config, cat, var) @@ -1030,27 +1034,56 @@ def typecheck(self) -> Iterable[Signature]: env = Environment(self.config, {}, typemap) for v in self.variables: if isinstance(v, VirtualVariable): - v.typecheck(env) + with reporter.context(v.name): + v.typecheck(env) for c in self.constraints: - yield from c.typecheck(env) + with reporter.context(repr(c)): + yield from c.typecheck(env) + + def check_assignment( + self, + template_checkers: dict[str, "TemplateChecker"], + values: Optional[dict[str, Type]] = None, + ): + env = Environment(self.config, {}, {}) + if values is None: + for v in self.variables: + if not isinstance(v, VirtualVariable): + t = v.type + if isinstance(t, list) and len(t) == 1: + t = t[0] + env.valmap[v.name] = CastExpr(v.pad, t).typecheck(env) + else: + for v in self.variables: + if not isinstance(v, VirtualVariable): + if v.name not in values: + reporter.error(f"Unable to find variable name {v.name!r} when checking assignment") + return + env.valmap[v.name] = values[v.name] + for v in self.variables: + if isinstance(v, VirtualVariable): + v.populate_env(env) + + for c in self.constraints: + for sig in c.typecheck(env): + # Recurse on templates + if isinstance(sig, TemplateSignature) and sig.tag in template_checkers: + with reporter.context(repr(c)): + template_checkers[sig.tag](template_checkers, sig.condition, sig.input, sig.output) def build_signature(config: Config, data: dict) -> Signature: - assert_no_unexpected( - data, {"tag", "kind", "input", "output", "cond", "multiplicity"} - ) + assert_no_unexpected(data, {"tag", "kind", "input", "output", "cond"}) Sig: type[Signature] + cond: Optional[Type] = None match data["kind"]: case "template": - reporter.asserts( - "multiplicity" not in data, - f"Template signature with multiplicity: {data!r}", - ) + if "cond" in data: + cond = build_type(config, data["cond"]) Sig = TemplateSignature case "interaction": - reporter.asserts( - "cond" not in data, f"Template signature with cond: {data!r}" - ) + reporter.asserts("cond" not in data, f"Template signature with cond: {data!r}") + cond = Range.const(1) Sig = InteractionSignature case other: reporter.error(f"Signature of invalid kind '{other}': {data!r}") @@ -1062,7 +1095,7 @@ def build_signature(config: Config, data: dict) -> Signature: output = build_type(config, data["output"]) else: output = None - return Sig(tag, input, output) + return Sig(tag, cond, input, output) def read_signatures(config, filename) -> list[Signature]: @@ -1074,9 +1107,36 @@ def read_signatures(config, filename) -> list[Signature]: def check_signatures(found: Iterable[Signature], expected: list[Signature]): for sig in found: - reporter.asserts( - any(sig.matches(exp) for exp in expected), f"Unexpected signature: {sig}" - ) + reporter.asserts(any(sig.matches(exp) for exp in expected), f"Unexpected signature: {sig}") + + +# A Function taking a mapping of available templates (for recursive expansion), +# an optional condition variable, a list of inputs and an output, +# and checks that it satisfies the template the checker represents. +type TemplateChecker = Callable[[dict[str, TemplateChecker], Optional[Type], list[Type], Optional[Type]], None] + +def build_template_checker(chip: Chip) -> TemplateChecker: + def check(template_checkers: dict[str, TemplateChecker], cond: Optional[Type], input: list[Type], output: Optional[Type]) -> None: + input = input[:] + values = {} + for v in chip.variables: + match v.category: + case "input": + values[v.name] = input.pop(0) + case "output": + if output is None: + reporter.error(f"No output available for template output variable {v.name!r}") + return + values[v.name] = output + case "condition": + values[v.name] = cond if cond else Range.const(1) + case "virtual": + pass + case other: + reporter.error(f"Cannot check template with variable of category {other!r}") + chip.check_assignment(template_checkers, values) + + return check if __name__ == "__main__": @@ -1087,18 +1147,28 @@ def check_signatures(found: Iterable[Signature], expected: list[Signature]): reported = False chips: list[Chip] = [] + template_checkers: dict[str, TemplateChecker] = {} for file in sys.argv[3:]: if file in sys.argv[1:3]: continue - chips.append(Chip.from_file(config, file)) + chip = Chip.from_file(config, file) + chips.append(chip) + template_checkers[chip.name] = build_template_checker(chip) reported |= reporter.reported if reported: sys.exit(1) + if "ADD" in template_checkers: + template_checkers["SUB"] = lambda checkers, cond, input, output: checkers["ADD"](checkers, cond, [output, input[1]], input[0]) + for chip in chips: reporter.update_location(f"Chip {chip.name}") check_signatures(chip.typecheck(), signatures) reported |= reporter.reported + for chip in chips: + reporter.update_location(f"Padding {chip.name}") + chip.check_assignment(template_checkers) + reported |= reporter.reported if reported: sys.exit(1) else: