Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 73 additions & 10 deletions testflows/snapshots/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .compare import Compare
from .errors import *
from .v1 import snapshot as snapshot_v1
from .v2 import snapshot as snapshot_v2

__all__ = ["snapshot"]

Expand Down Expand Up @@ -53,6 +54,48 @@ def get_snapshot_filename(frame, path, id):
return filename


def get_snapshot_filename_v2(frame, path, id, filename):
"""Return snapshot filename for V2 (JSON) snapshots.

When ``filename`` is provided, it is used directly (joined with ``path``
if ``path`` is also given). Otherwise, the filename is derived from
the caller's frame info similar to V1 but with a ``.json`` extension.

:param frame: caller's stack frame
:param path: custom snapshot directory, default: ``./snapshots``
:param id: unique id of the snapshot file, default: ``None``
:param filename: explicit filename to use, default: ``None``
"""
if filename is not None:
# When an explicit filename is provided, just join it with path
if path is None:
frame_info = inspect.getframeinfo(frame)
path = os.path.join(os.path.dirname(frame_info.filename), "snapshots")

if not os.path.exists(path):
os.makedirs(path)

return os.path.join(path, filename)

# Fall back to auto-generated filename with .json extension
frame_info = inspect.getframeinfo(frame)

id_parts = [os.path.basename(frame_info.filename)]
if id is not None:
id_parts.append(str(id).lower())
id_parts.append("json")

file_id = ".".join(id_parts)

if path is None:
path = os.path.join(os.path.dirname(frame_info.filename), "snapshots")

if not os.path.exists(path):
os.makedirs(path)

return os.path.join(path, file_id)


def snapshot(
value,
id=None,
Expand All @@ -65,27 +108,33 @@ def snapshot(
version=snapshot_v1.VERSION,
frame=None,
compare=Compare.eq,
filename=None,
):
"""Compare value representation to a stored snapshot.

If snapshot does not exist, assertion passes else
representation of the value is compared to the stored snapshot.

Snapshot files have format:
For V1 (default), snapshot files have format:

<test file name>[.<id>].snapshot

For V2 (JSON), snapshot files have format:

<test file name>[.<id>].json (or custom ``filename``)

:param value: value to be used for snapshot
:param id: unique id of the snapshot file, default: `None`
:param id: unique id of the snapshot file, default: ``None``
:param output: function to output the representation of the value
:param path: custom snapshot path, default: `./snapshots`
:param name: name of the snapshot value inside the snapshots file, default: `snapshot`
:param encoder: custom snapshot encoder, default: `repr`
:param path: custom snapshot path, default: ``./snapshots``
:param name: name of the snapshot value inside the snapshots file, default: ``snapshot``
:param encoder: custom snapshot encoder, default: ``repr``
:param comment: (deprecated)
:param mode: mode of operation: CHECK, UPDATE, REWRITE, default: CHECK | UPDATE
:param version: snapshot version, default: snapshot_v1.VERSION
:param frame: caller frame, default: `None`
:param version: snapshot version, default: ``snapshot_v1.VERSION``
:param frame: caller frame, default: ``None``
:param compare: custom comparison function, default: equals
:param filename: explicit snapshot filename (V2 only), default: ``None``
"""
if frame is None:
frame = inspect.currentframe().f_back
Expand All @@ -98,11 +147,24 @@ def snapshot(
if output:
output(repr_value)

filename = get_snapshot_filename(frame=frame, path=path, id=id)

if version == snapshot_v1.VERSION:
snapshot_file = get_snapshot_filename(frame=frame, path=path, id=id)

return snapshot_v1(
filename=filename,
filename=snapshot_file,
repr_value=repr_value,
name=name,
mode=mode,
compare=compare,
)

if version == snapshot_v2.VERSION:
snapshot_file = get_snapshot_filename_v2(
frame=frame, path=path, id=id, filename=filename
)

return snapshot_v2(
filename=snapshot_file,
repr_value=repr_value,
name=name,
mode=mode,
Expand All @@ -122,3 +184,4 @@ def snapshot(

# define supported versions
snapshot.VERSION_V1 = snapshot_v1.VERSION
snapshot.VERSION_V2 = snapshot_v2.VERSION
168 changes: 168 additions & 0 deletions testflows/snapshots/v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2024 Katteli Inc.
# TestFlows.com Open-Source Software Testing Framework (http://testflows.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import textwrap
import difflib

from .errors import SnapshotError as SnapshotErrorBase
from .errors import SnapshotNotFoundError as SnapshotNotFoundErrorBase
from .parallel import RWLock
from .mode import *
from .compare import Compare

locks = {}


def get_lock(filename):
if filename not in locks:
locks[filename] = RWLock()
return locks[filename]


def _json_repr(value):
"""Return a stable JSON string representation of a value."""
return json.dumps(value, indent=2, sort_keys=True)


class SnapshotError(SnapshotErrorBase):
def __init__(self, filename, name, snapshot_value, actual_value):
self.snapshot_value = snapshot_value
self.actual_value = actual_value
self.filename = str(filename)
self.name = str(name)

def __bool__(self):
return False

def __repr__(self):
snapshot_str = _json_repr(self.snapshot_value)
actual_str = _json_repr(self.actual_value)

r = "SnapshotError("
r += "\nfilename=" + self.filename
r += "\nname=" + self.name
r += '\nsnapshot_value="""\n'
r += textwrap.indent(snapshot_str, " " * 4)
r += '""",\nactual_value="""\n'
r += textwrap.indent(actual_str, " " * 4)
r += '""",\ndiff="""\n'
r += textwrap.indent(
"\n".join(
[
line.strip("\n")
for line in difflib.unified_diff(
snapshot_str.splitlines(),
actual_str.splitlines(),
fromfile=self.filename,
tofile="actual",
)
]
),
" " * 4,
)
r += '\n""")'
return r


class SnapshotNotFoundError(SnapshotNotFoundErrorBase):
def __init__(self, filename, name, actual_value):
self.actual_value = actual_value
self.filename = str(filename)
self.name = str(name)

def __bool__(self):
return False

def __repr__(self):
r = "SnapshotNotFoundError("
r += "\nfilename=" + self.filename
r += "\nname=" + self.name
r += '\nactual_value="""\n'
r += textwrap.indent(_json_repr(self.actual_value), " " * 4)
r += '\n""")'
return r


def read_snapshot_file(filename):
"""Read and parse a JSON snapshot file."""
with open(filename, "r", encoding="utf-8") as fd:
return json.load(fd)


def write_snapshot_file(filename, data):
"""Write data to a JSON snapshot file with sorted keys."""
with open(filename, "w", encoding="utf-8") as fd:
json.dump(data, fd, indent=2, sort_keys=True)
fd.write("\n")


def snapshot(
filename,
repr_value,
name="snapshot",
mode=SNAPSHOT_MODE_CHECK | SNAPSHOT_MODE_UPDATE,
compare=Compare.eq,
):
"""Check value against a snapshot value stored in a JSON file.

The JSON file stores a single object where each key is a snapshot name
and each value is the stored snapshot value.

:param filename: path to the JSON snapshot file
:param repr_value: the encoded value to compare (JSON-serializable)
:param name: name of the snapshot entry within the file, default: ``snapshot``
:param mode: mode of operation: CHECK, UPDATE, REWRITE, default: CHECK | UPDATE
:param compare: custom comparison function, default: equals
"""
lock = get_lock(filename)

if os.path.exists(filename):
with lock.read():
data = read_snapshot_file(filename)

if name in data:
snapshot_value = data[name]
if not compare(snapshot_value, repr_value):
if mode & SNAPSHOT_MODE_CHECK:
return SnapshotError(filename, name, snapshot_value, repr_value)
else:
return True

if not (mode & SNAPSHOT_MODE_UPDATE):
return SnapshotNotFoundError(filename, name, repr_value)

# write or update snapshot entry
with lock.write():
if os.path.exists(filename):
data = read_snapshot_file(filename)
else:
data = {}

data[name] = repr_value
write_snapshot_file(filename, data)

if mode & SNAPSHOT_MODE_REWRITE:
# For JSON, rewriting simply re-reads and re-writes the file
# to ensure canonical formatting (sorted keys, consistent indent).
with lock.write():
data = read_snapshot_file(filename)
write_snapshot_file(filename, data)

return True


# define version
snapshot.VERSION = 2
21 changes: 15 additions & 6 deletions tests/actions/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from testflows.core import current, debug
from testflows.snapshots import snapshot
from testflows.snapshots.snapshots import get_snapshot_filename
from testflows.snapshots.snapshots import get_snapshot_filename, get_snapshot_filename_v2

import actions.expect

Expand All @@ -16,11 +16,20 @@ def __init__(self, **kwargs):
self.encoder = kwargs.pop("encoder", "repr")
self.mode = kwargs.pop("mode", snapshot.CHECK | snapshot.UPDATE)
self.version = kwargs.pop("version", snapshot.VERSION_V1)
self.filename = get_snapshot_filename(
frame=kwargs.pop("frame"),
path=kwargs.pop("path", None),
id=kwargs.pop("id", None),
)

frame = kwargs.pop("frame")
path = kwargs.pop("path", None)
id = kwargs.pop("id", None)
explicit_filename = kwargs.pop("filename", None)

if self.version == snapshot.VERSION_V2:
self.filename = get_snapshot_filename_v2(
frame=frame, path=path, id=id, filename=explicit_filename,
)
else:
self.filename = get_snapshot_filename(
frame=frame, path=path, id=id,
)

def __str__(self):
mode = []
Expand Down
29 changes: 26 additions & 3 deletions tests/actions/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from testflows.snapshots import snapshot
from testflows.snapshots.errors import SnapshotError
from testflows.snapshots.snapshots import get_snapshot_filename
from testflows.snapshots.snapshots import get_snapshot_filename, get_snapshot_filename_v2

import actions.model


@TestStep(Given)
def get_unique_id(self, frame=None, path=None):
def get_unique_id(self, frame=None, path=None, version=None):
"""Generate unique snapshot id and delete snapshot file for that id
at the end of the test."""

Expand All @@ -24,7 +24,30 @@ def get_unique_id(self, frame=None, path=None):
yield id

finally:
filename = get_snapshot_filename(frame=frame, path=path, id=id)
if version == snapshot.VERSION_V2:
filename = get_snapshot_filename_v2(frame=frame, path=path, id=id, filename=None)
else:
filename = get_snapshot_filename(frame=frame, path=path, id=id)
with By("deleting file for the snapshot id", description=f"{filename}"):
try:
os.remove(filename)
except FileNotFoundError:
pass


@TestStep(Given)
def get_unique_id_v2(self, frame=None, path=None):
"""Generate unique snapshot id for V2 and delete snapshot file at the end."""

if frame is None:
frame = inspect.currentframe()

try:
id = uuid.uuid4().hex
yield id

finally:
filename = get_snapshot_filename_v2(frame=frame, path=path, id=id, filename=None)
with By("deleting file for the snapshot id", description=f"{filename}"):
try:
os.remove(filename)
Expand Down
1 change: 1 addition & 0 deletions tests/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ def feature(self):
Feature(run=load("value", "feature"))
Feature(run=load("compare", "feature"))
Feature(run=load("mode", "feature"))
Feature(run=load("value_v2", "feature"))
Loading