Skip to content

Commit 04ce390

Browse files
fix: expose connection identifier on connection objects (#360)
* fix: expose connection identifier on connection objects * updated the get local connection fields to pass tests * ruff * added a comment on KW_ONLY
1 parent 9435884 commit 04ce390

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

openhexa/sdk/workspaces/connection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
class Connection:
88
"""Abstract base class for connections."""
99

10-
pass
10+
_: dataclasses.KW_ONLY # Ensures `identifier` is always passed as a keyword argument, preventing positional conflicts with subclass fields
11+
identifier: str = ""
1112

1213

1314
@dataclasses.dataclass
@@ -100,7 +101,7 @@ def __repr__(self):
100101

101102

102103
@dataclasses.dataclass
103-
class IASOConnection:
104+
class IASOConnection(Connection):
104105
"""IASO connection.
105106
106107
See https://github.com/BLSQ/iaso for more information.

openhexa/sdk/workspaces/current_workspace.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ def _get_local_connection_fields(self, env_variable_prefix: str):
224224
connection_fields = {}
225225
connection_type = os.getenv(env_variable_prefix).upper()
226226

227-
# Get fields for the connection type
228-
_fields = fields(ConnectionClasses[connection_type])
227+
# Get fields for the connection type, excluding base Connection fields
228+
_fields = [f for f in fields(ConnectionClasses[connection_type]) if f.name != "identifier"]
229229

230230
if _fields:
231231
for field in _fields:
@@ -305,14 +305,15 @@ def get_connection(
305305
# different from the offline ones
306306
if connection_type == "S3":
307307
secret_access_key = connection_fields.pop("access_key_secret")
308-
return S3Connection(secret_access_key=secret_access_key, **connection_fields)
308+
return S3Connection(secret_access_key=secret_access_key, identifier=identifier, **connection_fields)
309309

310310
if connection_type == "POSTGRESQL":
311311
db_name = connection_fields.pop("db_name")
312312
port = int(connection_fields.pop("port"))
313313
return PostgreSQLConnection(
314314
database_name=db_name,
315315
port=port,
316+
identifier=identifier,
316317
**connection_fields,
317318
)
318319

@@ -323,9 +324,9 @@ def get_connection(
323324
bases=(CustomConnection,),
324325
repr=False,
325326
)
326-
return dataclass(**connection_fields)
327+
return dataclass(identifier=identifier, **connection_fields)
327328

328-
return ConnectionClasses[connection_type](**connection_fields)
329+
return ConnectionClasses[connection_type](identifier=identifier, **connection_fields)
329330

330331
def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection:
331332
"""Get a DHIS2 connection by its identifier.

0 commit comments

Comments
 (0)