Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ming/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ming.version import __version__, __version_info__
from ming.config import configure
from ming.datastore import create_engine, create_datastore
from ming.encryption import EncryptedObject

# Re-export direction keys
ASCENDING = pymongo.ASCENDING
Expand Down
256 changes: 255 additions & 1 deletion ming/encryption.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Generic
from typing import TYPE_CHECKING, TypeVar, Generic, Any

from ming.utils import classproperty
import ming.schema
Expand Down Expand Up @@ -87,6 +87,215 @@ def key_vault_namespace(self) -> str:
T = TypeVar('T')


class EncryptedObject(dict):
"""A dict-like wrapper that handles encryption/decryption for nested fields.

This class wraps a regular dict and provides transparent encryption/decryption
when accessing fields that have _encrypted counterparts.

This is automatically applied to dict fields that contain encrypted fields,
enabling nested field-level encryption in MongoDB documents.

**Example Usage:**

Define a document with nested encrypted fields:

.. code-block:: python

class User(Document):
class __mongometa__:
name = 'user'
session = my_session

_id = Field(schema.ObjectId)
username = Field(str)
# Dict field with encrypted nested fields
full_name = Field(dict(
first_name_encrypted=schema.Binary,
last_name_encrypted=schema.Binary
))

Create a document with unencrypted nested data:

.. code-block:: python

user = User.make_encr({
'_id': ObjectId(),
'username': 'jdoe',
'full_name': {
'first_name': 'John',
'last_name': 'Doe'
}
})
user.m.save()

Access decrypted values using dict notation:

.. code-block:: python

# Get decrypted values
print(user.full_name['first_name']) # 'John'
print(user.full_name['last_name']) # 'Doe'

# Set new encrypted values
user.full_name['first_name'] = 'Johnny'
user.m.save()

**How it Works:**

1. When you define a dict field with fields ending in ``_encrypted`` (e.g., ``first_name_encrypted``),
the system recognizes these as encrypted fields.

2. When creating a document with ``make_encr()``, any nested fields that have corresponding
``_encrypted`` fields in the schema are automatically encrypted.

3. When you access a field without the ``_encrypted`` suffix (e.g., ``'first_name'``),
EncryptedObject automatically decrypts the value from the ``first_name_encrypted`` field.

4. When you set a field without the ``_encrypted`` suffix, EncryptedObject automatically
encrypts the value and stores it in the corresponding ``_encrypted`` field.

**Multi-level Nesting:**

This works recursively for any level of nesting:

.. code-block:: python

class Profile(Document):
personal_info = Field(dict(
address=dict(
street_encrypted=schema.Binary,
city_encrypted=schema.Binary
)
))

# Access deeply nested encrypted fields
profile.personal_info['address']['street'] = '123 Main St'

:param data: The underlying dict data
:param encr_func: Function to encrypt data (str -> bytes)
:param decr_func: Function to decrypt data (bytes -> str)
:param field_schema: Dict mapping field names to their schemas (for nested dicts)
"""

def __init__(self, data: dict, encr_func, decr_func, field_schema: dict = None):
"""
:param data: The underlying dict data
:param encr_func: Function to encrypt data (str -> bytes)
:param decr_func: Function to decrypt data (bytes -> str)
:param field_schema: Dict mapping field names to their schemas (for nested dicts)
"""
super().__init__(data)
self._encr_func = encr_func
self._decr_func = decr_func
self._field_schema = field_schema or {}

# Wrap any nested dicts that have encrypted fields
self._wrap_nested_dicts()

def _wrap_nested_dicts(self):
"""Wrap nested dicts with EncryptedObject if they contain encrypted fields."""
for key, value in self.items():
if isinstance(value, dict) and not isinstance(value, EncryptedObject):
# Check if this dict has any encrypted fields
if self._has_encrypted_fields(value):
nested_schema = self._field_schema.get(key, {})
self[key] = EncryptedObject(value, self._encr_func, self._decr_func, nested_schema)

def _has_encrypted_fields(self, d: dict) -> bool:
"""Check if a dict has any fields ending with _encrypted."""
return any(k.endswith('_encrypted') for k in d.keys())

def _get_encrypted_field_name(self, key: str) -> str:
"""Get the encrypted field name for a decrypted field."""
return f"{key}_encrypted"

def _is_encrypted_field(self, key: str) -> bool:
"""Check if a field is an encrypted field (ends with _encrypted)."""
return key.endswith('_encrypted')

def _get_decrypted_field_name(self, key: str) -> str:
"""Get the decrypted field name from an encrypted field."""
if key.endswith('_encrypted'):
return key[:-10] # Remove '_encrypted' suffix
return key

def __getitem__(self, key: str) -> Any:
"""Get item with automatic decryption if accessing a decrypted field."""
# If accessing an encrypted field directly, return as-is
if self._is_encrypted_field(key):
return super().__getitem__(key)

# Check if there's an encrypted counterpart
encrypted_key = self._get_encrypted_field_name(key)
if encrypted_key in self:
# This is a decrypted field - decrypt the encrypted value
encrypted_value = super().__getitem__(encrypted_key)
return self._decr_func(encrypted_value)

# Regular field access
value = super().__getitem__(key)

# If the value is a dict with encrypted fields, wrap it
if isinstance(value, dict) and not isinstance(value, EncryptedObject):
if self._has_encrypted_fields(value):
nested_schema = self._field_schema.get(key, {})
value = EncryptedObject(value, self._encr_func, self._decr_func, nested_schema)
super().__setitem__(key, value)

return value

def __setitem__(self, key: str, value: Any):
"""Set item with automatic encryption if setting a decrypted field."""
# If setting an encrypted field directly, set as-is
if self._is_encrypted_field(key):
super().__setitem__(key, value)
return

# Check if there's an encrypted counterpart
encrypted_key = self._get_encrypted_field_name(key)
if encrypted_key in self:
# This is a decrypted field - encrypt the value and store in encrypted field
if value is not None:
encrypted_value = self._encr_func(value)
super().__setitem__(encrypted_key, encrypted_value)
else:
super().__setitem__(encrypted_key, None)
# Don't store the decrypted value
return

# Regular field - just set it
# If value is a dict with encrypted fields, wrap it
if isinstance(value, dict) and not isinstance(value, EncryptedObject):
if self._has_encrypted_fields(value):
nested_schema = self._field_schema.get(key, {})
value = EncryptedObject(value, self._encr_func, self._decr_func, nested_schema)

super().__setitem__(key, value)

def get(self, key: str, default=None) -> Any:
"""Get with default, handling decryption."""
try:
return self[key]
except KeyError:
return default

def __getattr__(self, name: str) -> Any:
"""Support attribute access like obj.field_name."""
try:
return self[name]
except KeyError:
raise AttributeError(name)

def __setattr__(self, name: str, value: Any):
"""Support attribute setting like obj.field_name = value."""
# Handle internal attributes
if name.startswith('_'):
super().__setattr__(name, value)
else:
self[name] = value


class DecryptedField(Generic[T]):

def __init__(self, field_type: type[T], encrypted_field: str):
Expand Down Expand Up @@ -130,6 +339,9 @@ class EncryptedMixin:

Generally, don't use this directly, but instead call the methods on the Document/MappedClass you're working with.
"""

# Make EncryptedObject accessible as a class attribute
EncryptedObject = EncryptedObject

@classproperty
def _datastore(cls) -> ming.datastore.DataStore:
Expand Down Expand Up @@ -204,10 +416,52 @@ def encrypt_some_fields(cls, data: dict) -> dict:
:return: a modified copy of the ``data`` param with the currently-unencrypted-but-encryptable fields replaced with ``_encrypted`` counterparts.
"""
encrypted_data = data.copy()

# Encrypt top-level decrypted fields
for fld in cls.decrypted_field_names():
if fld in encrypted_data:
val = encrypted_data.pop(fld)
encrypted_data[f'{fld}_encrypted'] = cls.encr(val)

# Handle nested dicts - recursively encrypt fields in dict values
if hasattr(cls, 'm') and hasattr(cls.m, 'field_index') and cls.m.field_index:
for key, value in encrypted_data.items():
if isinstance(value, dict) and key in cls.m.field_index:
field = cls.m.field_index[key]
if hasattr(field, 'schema') and hasattr(field.schema, 'fields'):
# This is an Object schema with defined fields
encrypted_data[key] = cls._encrypt_nested_dict(value, field.schema.fields)

return encrypted_data

@classmethod
def _encrypt_nested_dict(cls, data: dict, schema_fields: dict) -> dict:
"""Recursively encrypt fields in a nested dict based on schema.

:param data: The dict data to encrypt
:param schema_fields: The schema fields definition for this dict level
"""
encrypted_data = data.copy()

# Find which fields in the schema are encrypted fields (end with _encrypted)
encrypted_field_names = [k for k in schema_fields.keys() if k.endswith('_encrypted')]

# For each encrypted field, check if we have the decrypted version in data
for encrypted_field in encrypted_field_names:
decrypted_field = encrypted_field[:-10] # Remove '_encrypted'

if decrypted_field in encrypted_data:
# We have the decrypted version - encrypt it
val = encrypted_data.pop(decrypted_field)
encrypted_data[encrypted_field] = cls.encr(val)

# Recursively handle nested dicts
for key, value in encrypted_data.items():
if isinstance(value, dict) and key in schema_fields:
nested_schema = schema_fields[key]
if hasattr(nested_schema, 'fields'):
encrypted_data[key] = cls._encrypt_nested_dict(value, nested_schema.fields)

return encrypted_data

def decrypt_some_fields(self) -> dict:
Expand Down
21 changes: 20 additions & 1 deletion ming/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,9 +391,28 @@ def __init__(self, field):
def __get__(self, inst, cls=None):
if inst is None: return self
try:
return inst[self.name]
value = inst[self.name]
# If the value is a dict with encrypted fields (direct or nested), wrap it with EncryptedObject
from .encryption import EncryptedObject
if isinstance(value, dict) and not isinstance(value, EncryptedObject):
if self._has_encrypted_fields_recursive(value):
# Get encryption functions from the document instance
value = EncryptedObject(value, inst.encr, inst.decr)
# Store the wrapped value back
inst[self.name] = value
return value
except KeyError:
raise AttributeError(self.name)

def _has_encrypted_fields_recursive(self, d: dict) -> bool:
"""Check if a dict has any fields ending with _encrypted, recursively."""
for k, v in d.items():
if k.endswith('_encrypted'):
return True
if isinstance(v, dict):
if self._has_encrypted_fields_recursive(v):
return True
return False

def __set__(self, inst, value):
inst[self.name] = value
Expand Down
Loading