Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions benchmarks/plutus_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""PlutusData / Transaction roundtrip benchmark + profiler.

Measures the real cost of decode (from_cbor), encode (to_cbor), and to_json for
both typed PlutusData and untyped RawPlutusData, across synthetic complexity
sweeps, plus Transaction fixtures. Designed to locate the indexing bottleneck.
"""

import cProfile
import io
import pstats
import sys
import time
from dataclasses import dataclass
from typing import Dict, List

from pycardano.cbor import cbor2
from pycardano.plutus import PlutusData, RawPlutusData
from pycardano.serialization import IndefiniteList, default_encoder

try:
from importlib.metadata import version

BACKEND = f"{cbor2.__name__} {version('cbor2pure') if cbor2.__name__=='cbor2pure' else version('cbor2')}"
except Exception:
BACKEND = cbor2.__name__


def best_us(fn, iters, repeats=5):
for _ in range(min(50, iters)):
fn()
best = float("inf")
for _ in range(repeats):
t0 = time.perf_counter()
for _ in range(iters):
fn()
best = min(best, time.perf_counter() - t0)
return best / iters * 1e6


def make_constr(cid, fields):
return cbor2.CBORTag(121 + cid if cid < 7 else 1280 + (cid - 7), fields)


def gen_deep(depth):
node = make_constr(0, [0])
for _ in range(depth):
node = make_constr(1, [node])
return node


def gen_wide(n):
return make_constr(0, [i if i % 2 else b"\x01\x02\x03\x04" for i in range(n)])


def gen_list(n):
return make_constr(0, [IndefiniteList(list(range(n)))])


def gen_map(n):
return make_constr(0, [{i: bytes([i % 256]) * 8 for i in range(n)}])


def gen_realistic(depth, width):
def build(d):
fields = [d, b"\xde\xad\xbe\xef" * 4, IndefiniteList(list(range(width)))]
if d > 0:
fields.append(build(d - 1))
fields.append({j: build(0) for j in range(2)})
return make_constr(d % 6, fields)

return build(depth)


SYNTH = {
"deep(d=40)": gen_deep(40),
"wide(n=200)": gen_wide(200),
"list(n=500)": gen_list(500),
"map(n=200)": gen_map(200),
"realistic(d=6,w=10)": gen_realistic(6, 10),
}


def encode_datum(obj):
return cbor2.dumps(obj, default=default_encoder)


def run_section(title, items, iters):
print(f"\n=== {title} ===")
print(
f" {'case':24} {'bytes':>7} {'decode us':>11} {'encode us':>11} {'to_json us':>11}"
)
for name, raw in items:
dec = best_us(lambda r=raw: RawPlutusData.from_cbor(r), iters)
obj = RawPlutusData.from_cbor(raw)
enc = best_us(lambda o=obj: o.to_cbor(), iters)
try:
tj = best_us(lambda o=obj: o.to_json(), iters)
except Exception:
tj = float("nan")
print(f" {name:24} {len(raw):>7} {dec:>11.1f} {enc:>11.1f} {tj:>11.1f}")


@dataclass
class Inner(PlutusData):
CONSTR_ID = 0
a: int
b: bytes


@dataclass
class Mid(PlutusData):
CONSTR_ID = 1
x: int
items: List[Inner]
mapping: Dict[int, Inner]


@dataclass
class Outer(PlutusData):
CONSTR_ID = 2
a: bytes
mid: Mid
leaves: List[Inner]


def build_typed(n):
inners = [Inner(a=i, b=bytes([i % 256]) * 8) for i in range(n)]
mid = Mid(x=7, items=inners, mapping={i: inners[i] for i in range(min(n, 20))})
return Outer(a=b"\xab" * 28, mid=mid, leaves=inners)


def run_typed(iters):
print("\n=== Typed PlutusData decode vs untyped on the SAME bytes ===")
print(
f" {'n_inner':>8} {'bytes':>7} {'typed dec us':>13} {'raw dec us':>11} {'typed/raw':>10}"
)
for n in (10, 50, 200):
obj = build_typed(n)
raw = obj.to_cbor()
td = best_us(lambda r=raw: Outer.from_cbor(r), iters)
rd = best_us(lambda r=raw: RawPlutusData.from_cbor(r), iters)
print(f" {n:>8} {len(raw):>7} {td:>13.1f} {rd:>11.1f} {td/rd:>9.1f}x")


def main():
iters = int(sys.argv[1]) if len(sys.argv) > 1 else 1000
print(
f"### backend={BACKEND} | python {sys.version.split()[0]} | iters={iters} ###"
)

synth_items = [(n, encode_datum(o)) for n, o in SYNTH.items()]
run_section("RawPlutusData (untyped - typical indexer path)", synth_items, iters)
run_typed(iters)

heaviest = max(synth_items, key=lambda kv: len(kv[1]))
print(
f"\n=== cProfile: RawPlutusData.from_cbor on '{heaviest[0]}' ({len(heaviest[1])}B) ==="
)
pr = cProfile.Profile()
pr.enable()
for _ in range(2000):
RawPlutusData.from_cbor(heaviest[1])
pr.disable()
s = io.StringIO()
pstats.Stats(pr, stream=s).sort_stats("tottime").print_stats(16)
for line in s.getvalue().splitlines():
if (
"pycardano" in line
or "cbor2" in line
or "function calls" in line
or "{" in line
):
print(" " + line.strip()[:115])


if __name__ == "__main__":
main()
31 changes: 18 additions & 13 deletions pycardano/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,19 +418,24 @@ def from_primitive(cls: Type[Address], value: Union[bytes, str]) -> Address:
raise DecodingException(f"Failed to decode address string: {e}")

# At this point, value is always bytes
# Check if it's a Byron address (CBOR with tag 24)
try:
decoded = cbor2.loads(value)
if isinstance(decoded, (tuple, list)) and len(decoded) == 2:
if isinstance(decoded[0], CBORTag) and decoded[0].tag == 24:
# This is definitely a Byron address - validate and decode it
return cls._from_byron_cbor(value)
except DecodingException:
# Byron decoding failed with validation error - re-raise it
raise
except Exception:
# Not Byron CBOR (general CBOR decode error), continue with Shelley decoding
pass
# Check if it's a Byron address (CBOR with tag 24). A Byron address is a
# 2-element CBOR array whose first element is tag 24, i.e. its bytes start with
# b"\x82\xd8\x18" (array(2) + tag(24)). Guarding on that prefix avoids a
# speculative cbor2.loads() on every (Shelley) address, whose header byte is
# never 0x82.
if value[:3] == b"\x82\xd8\x18":
try:
decoded = cbor2.loads(value)
if isinstance(decoded, (tuple, list)) and len(decoded) == 2:
if isinstance(decoded[0], CBORTag) and decoded[0].tag == 24:
# This is definitely a Byron address - validate and decode it
return cls._from_byron_cbor(value)
except DecodingException:
# Byron decoding failed with validation error - re-raise it
raise
except Exception:
# Not Byron CBOR (general CBOR decode error), continue with Shelley
pass

# Shelley address decoding (existing logic)
header = value[0]
Expand Down
44 changes: 42 additions & 2 deletions pycardano/plutus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from dataclasses import dataclass, field, fields
from enum import Enum
from hashlib import sha256
from typing import Any, List, Optional, Type, Union
from typing import Any, List, Optional, Tuple, Type, Union
from weakref import WeakKeyDictionary

from cbor2 import CBORTag
from nacl.encoding import RawEncoder
Expand Down Expand Up @@ -515,6 +516,19 @@ def id_map(cls, skip_constructor=False):
raise TypeError(f"Unexpected type for automatic constructor generation: {cls}")


# Per-class cache of the dataclass ``fields`` tuple for PlutusData subclasses.
# The set of fields and their declared types are class-invariant, so the
# (class-invariant) field-type validity check in ``PlutusData.__post_init__``
# only needs to run once per class instead of once per instance. We key on the
# class object via a WeakKeyDictionary so dynamically-created classes (e.g. the
# many dataclasses defined inside test functions) do not leak and never collide.
# Presence of a class in this cache means its field types have already been
# validated; the cached value is the ``fields(cls)`` tuple, reused per instance.
_plutusdata_fields_cache: "WeakKeyDictionary[type, Tuple[Any, ...]]" = (
WeakKeyDictionary()
)


@dataclass(repr=False)
class PlutusData(ArrayCBORSerializable):
"""
Expand Down Expand Up @@ -555,6 +569,30 @@ def CONSTR_ID(cls):
return getattr(cls, k)

def __post_init__(self):
cls = type(self)
# The field set and their declared types are class-invariant, so the
# field-type validity check is identical for every instance of a class.
# Once a class has been validated we cache its ``fields`` tuple; presence
# in the cache means the field types have already passed validation, so
# subsequent instances only run the per-instance bytes-length check.
cls_fields = _plutusdata_fields_cache.get(cls)
if cls_fields is not None:
# Fast path: class already validated. Only the bytes-length check,
# which depends on instance data, needs to run.
for f in cls_fields:
data = getattr(self, f.name)
if isinstance(data, bytes) and len(data) > 64:
raise InvalidArgumentException(
f"The size of {data} exceeds {self.MAX_BYTES_SIZE} bytes. "
"Use pycardano.serialization.ByteString for long bytes."
)
return

# Slow path: first instance of this class. Preserve the original
# behavior exactly, including the interleaved order of the type check
# and the bytes-length check across fields. Only cache the validated
# ``fields`` tuple if every field type passes the (class-invariant)
# validity check.
valid_types = (
RawPlutusData,
PlutusData,
Expand All @@ -564,7 +602,8 @@ def __post_init__(self):
ByteString,
bytes,
)
for f in fields(self):
cls_fields = fields(self)
for f in cls_fields:
if inspect.isclass(f.type) and not issubclass(f.type, valid_types):
raise TypeError(
f"Invalid field type: {f.type}. A field in PlutusData should be one of {valid_types}"
Expand All @@ -576,6 +615,7 @@ def __post_init__(self):
f"The size of {data} exceeds {self.MAX_BYTES_SIZE} bytes. "
"Use pycardano.serialization.ByteString for long bytes."
)
_plutusdata_fields_cache[cls] = cls_fields

def to_shallow_primitive(self) -> CBORTag:
primitives: Primitive = super().to_shallow_primitive()
Expand Down
Loading
Loading