diff --git a/src/cryptojwt/jwk/__init__.py b/src/cryptojwt/jwk/__init__.py index 4294f50..6ccc952 100644 --- a/src/cryptojwt/jwk/__init__.py +++ b/src/cryptojwt/jwk/__init__.py @@ -26,7 +26,16 @@ class JWK: required = ["kty"] def __init__( - self, kty="", alg="", use="", kid="", x5c=None, x5t="", x5u="", key_ops=None, **kwargs + self, + kty="", + alg="", + use="", + kid="", + x5c=None, + x5t="", + x5u="", + key_ops=None, + **kwargs, ): self.extra_args = kwargs @@ -220,6 +229,9 @@ def verify(self): raise ValueError("kid of wrong value type") return True + def __hash__(self) -> int: + return hash((self.thumbprint("SHA-256"), self.kid)) + def __eq__(self, other): """ Compare 2 Key instances to find out if they represent the same key diff --git a/src/cryptojwt/jwk/ec.py b/src/cryptojwt/jwk/ec.py index 25b84cd..4031060 100644 --- a/src/cryptojwt/jwk/ec.py +++ b/src/cryptojwt/jwk/ec.py @@ -232,6 +232,9 @@ def encryption_key(self): """ return self.pub_key + def __hash__(self) -> int: + return super().__hash__() + def __eq__(self, other): """ Verify that the other key has the same properties as myself. diff --git a/src/cryptojwt/jwk/hmac.py b/src/cryptojwt/jwk/hmac.py index e4ad9b7..a170744 100644 --- a/src/cryptojwt/jwk/hmac.py +++ b/src/cryptojwt/jwk/hmac.py @@ -40,7 +40,17 @@ class SYMKey(JWK): required = ["k", "kty"] def __init__( - self, kty="oct", alg="", use="", kid="", x5c=None, x5t="", x5u="", k="", key="", **kwargs + self, + kty="oct", + alg="", + use="", + kid="", + x5c=None, + x5t="", + x5u="", + k="", + key="", + **kwargs, ): JWK.__init__(self, kty, alg, use, kid, x5c, x5t, x5u, **kwargs) self.k = k @@ -117,6 +127,9 @@ def encryption_key(self, alg, **kwargs): return _enc_key + def __hash__(self) -> int: + return super().__hash__() + def __eq__(self, other): """ Compare 2 JWK instances to find out if they represent the same key diff --git a/src/cryptojwt/jwk/okp.py b/src/cryptojwt/jwk/okp.py index 1425f90..0a631f2 100644 --- a/src/cryptojwt/jwk/okp.py +++ b/src/cryptojwt/jwk/okp.py @@ -15,10 +15,16 @@ ) OKPPublicKey = Union[ - ed25519.Ed25519PublicKey, ed448.Ed448PublicKey, x25519.X25519PublicKey, x448.X448PublicKey + ed25519.Ed25519PublicKey, + ed448.Ed448PublicKey, + x25519.X25519PublicKey, + x448.X448PublicKey, ] OKPPrivateKey = Union[ - ed25519.Ed25519PrivateKey, ed448.Ed448PrivateKey, x25519.X25519PrivateKey, x448.X448PrivateKey + ed25519.Ed25519PrivateKey, + ed448.Ed448PrivateKey, + x25519.X25519PrivateKey, + x448.X448PrivateKey, ] OKP_CRV2PUBLIC = { @@ -155,7 +161,8 @@ def deserialize(self): def _serialize_public(self, key): self.x = b64e( key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, ) ).decode("ascii") @@ -257,6 +264,9 @@ def encryption_key(self): """ return self.pub_key + def __hash__(self) -> int: + return super().__hash__() + def __eq__(self, other): """ Verify that the other key has the same properties as myself. @@ -304,9 +314,11 @@ def cmp_keys(a, b, key_type): return False else: if a.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, ) != b.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, ): return False return True diff --git a/src/cryptojwt/jwk/rsa.py b/src/cryptojwt/jwk/rsa.py index 4bd0c03..3847d17 100644 --- a/src/cryptojwt/jwk/rsa.py +++ b/src/cryptojwt/jwk/rsa.py @@ -414,6 +414,9 @@ def load(self, filename): """ return self.load_key(import_private_rsa_key_from_file(filename)) + def __hash__(self) -> int: + return super().__hash__() + def __eq__(self, other): """ Verify that this other key is the same as myself. diff --git a/tests/test_02_jwk.py b/tests/test_02_jwk.py index 73c782e..6e83f3d 100755 --- a/tests/test_02_jwk.py +++ b/tests/test_02_jwk.py @@ -9,7 +9,11 @@ import pytest from cryptography.hazmat.primitives.asymmetric import ec, ed25519, rsa -from cryptojwt.exception import DeSerializationNotPossible, UnsupportedAlgorithm, WrongUsage +from cryptojwt.exception import ( + DeSerializationNotPossible, + UnsupportedAlgorithm, + WrongUsage, +) from cryptojwt.jwk import JWK, certificate_fingerprint, pem_hash, pems_to_x5c from cryptojwt.jwk.ec import ECKey, new_ec_key from cryptojwt.jwk.hmac import SYMKey, new_sym_key, sha256_digest @@ -735,7 +739,11 @@ def test_import_public_key_from_pem_file(filename, key_type): assert isinstance(pub_key, key_type) -OKPKEY = {"crv": "Ed25519", "kty": "OKP", "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"} +OKPKEY = { + "crv": "Ed25519", + "kty": "OKP", + "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo", +} OKPKEY_SHA256 = "kPrK_qmxVWaYVA9wwBF6Iuo3vVzz7TxHCTwXBygrS4k" @@ -809,3 +817,30 @@ def test_key_from_jwk_dict_okp_ed448(): _key = key_from_jwk_dict(jwk) assert isinstance(_key, OKPKey) assert _key.has_private_key() + + +def test_jwk_set(): + keyset: set[JWK] = set() + + key_a = new_ec_key("P-256", kid="key_a") + key_b = new_ec_key("P-256", kid="key_b") + key_c = new_ec_key("P-256", kid="key_b") + + key1 = ECKey(**key_a.serialize()) + key2 = ECKey(**key_b.serialize()) + key3 = ECKey(**key_a.serialize()) + key4 = ECKey(**key_c.serialize()) + + keyset.add(key1) + keyset.add(key1) + keyset.add(key2) + keyset.add(key2) + assert len(keyset) == 2 + + keyset.add(key3) # should not add a new item since key1 == key3 + assert len(keyset) == 2 + + assert key1 in keyset + assert key2 in keyset + assert key3 in keyset + assert key4 not in keyset