Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/cryptojwt/jwk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,16 @@ class JWK:
required = ["kty"]

def __init__(
self, kty="", alg="", use="", kid="", x5c=None, x5t="", x5u="", key_ops=None, **kwargs
self,
kty="",
alg="",
use="",
kid="",
x5c=None,
x5t="",
x5u="",
key_ops=None,
**kwargs,
):
self.extra_args = kwargs

Expand Down Expand Up @@ -220,6 +229,9 @@ def verify(self):
raise ValueError("kid of wrong value type")
return True

def __hash__(self) -> int:
return hash((self.thumbprint("SHA-256"), self.kid))

def __eq__(self, other):
"""
Compare 2 Key instances to find out if they represent the same key
Expand Down
3 changes: 3 additions & 0 deletions src/cryptojwt/jwk/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def encryption_key(self):
"""
return self.pub_key

def __hash__(self) -> int:
return super().__hash__()

def __eq__(self, other):
"""
Verify that the other key has the same properties as myself.
Expand Down
15 changes: 14 additions & 1 deletion src/cryptojwt/jwk/hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,17 @@ class SYMKey(JWK):
required = ["k", "kty"]

def __init__(
self, kty="oct", alg="", use="", kid="", x5c=None, x5t="", x5u="", k="", key="", **kwargs
self,
kty="oct",
alg="",
use="",
kid="",
x5c=None,
x5t="",
x5u="",
k="",
key="",
**kwargs,
):
JWK.__init__(self, kty, alg, use, kid, x5c, x5t, x5u, **kwargs)
self.k = k
Expand Down Expand Up @@ -117,6 +127,9 @@ def encryption_key(self, alg, **kwargs):

return _enc_key

def __hash__(self) -> int:
return super().__hash__()

def __eq__(self, other):
"""
Compare 2 JWK instances to find out if they represent the same key
Expand Down
22 changes: 17 additions & 5 deletions src/cryptojwt/jwk/okp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
)

OKPPublicKey = Union[
ed25519.Ed25519PublicKey, ed448.Ed448PublicKey, x25519.X25519PublicKey, x448.X448PublicKey
ed25519.Ed25519PublicKey,
ed448.Ed448PublicKey,
x25519.X25519PublicKey,
x448.X448PublicKey,
]
OKPPrivateKey = Union[
ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey, x25519.X25519PrivateKey, x448.X448PrivateKey
ed25519.Ed25519PrivateKey,
ed448.Ed448PrivateKey,
x25519.X25519PrivateKey,
x448.X448PrivateKey,
]

OKP_CRV2PUBLIC = {
Expand Down Expand Up @@ -155,7 +161,8 @@ def deserialize(self):
def _serialize_public(self, key):
self.x = b64e(
key.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
).decode("ascii")

Expand Down Expand Up @@ -257,6 +264,9 @@ def encryption_key(self):
"""
return self.pub_key

def __hash__(self) -> int:
return super().__hash__()

def __eq__(self, other):
"""
Verify that the other key has the same properties as myself.
Expand Down Expand Up @@ -304,9 +314,11 @@ def cmp_keys(a, b, key_type):
return False
else:
if a.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
) != b.public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
):
return False
return True
Expand Down
3 changes: 3 additions & 0 deletions src/cryptojwt/jwk/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ def load(self, filename):
"""
return self.load_key(import_private_rsa_key_from_file(filename))

def __hash__(self) -> int:
return super().__hash__()

def __eq__(self, other):
"""
Verify that this other key is the same as myself.
Expand Down
39 changes: 37 additions & 2 deletions tests/test_02_jwk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import pytest
from cryptography.hazmat.primitives.asymmetric import ec, ed25519, rsa

from cryptojwt.exception import DeSerializationNotPossible, UnsupportedAlgorithm, WrongUsage
from cryptojwt.exception import (
DeSerializationNotPossible,
UnsupportedAlgorithm,
WrongUsage,
)
from cryptojwt.jwk import JWK, certificate_fingerprint, pem_hash, pems_to_x5c
from cryptojwt.jwk.ec import ECKey, new_ec_key
from cryptojwt.jwk.hmac import SYMKey, new_sym_key, sha256_digest
Expand Down Expand Up @@ -735,7 +739,11 @@ def test_import_public_key_from_pem_file(filename, key_type):
assert isinstance(pub_key, key_type)


OKPKEY = {"crv": "Ed25519", "kty": "OKP", "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"}
OKPKEY = {
"crv": "Ed25519",
"kty": "OKP",
"x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo",
}
OKPKEY_SHA256 = "kPrK_qmxVWaYVA9wwBF6Iuo3vVzz7TxHCTwXBygrS4k"


Expand Down Expand Up @@ -809,3 +817,30 @@ def test_key_from_jwk_dict_okp_ed448():
_key = key_from_jwk_dict(jwk)
assert isinstance(_key, OKPKey)
assert _key.has_private_key()


def test_jwk_set():
keyset: set[JWK] = set()

key_a = new_ec_key("P-256", kid="key_a")
key_b = new_ec_key("P-256", kid="key_b")
key_c = new_ec_key("P-256", kid="key_b")

key1 = ECKey(**key_a.serialize())
key2 = ECKey(**key_b.serialize())
key3 = ECKey(**key_a.serialize())
key4 = ECKey(**key_c.serialize())

keyset.add(key1)
keyset.add(key1)
keyset.add(key2)
keyset.add(key2)
assert len(keyset) == 2

keyset.add(key3) # should not add a new item since key1 == key3
assert len(keyset) == 2

assert key1 in keyset
assert key2 in keyset
assert key3 in keyset
assert key4 not in keyset
Loading