-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrapidata_client.py
More file actions
182 lines (153 loc) · 7.72 KB
/
rapidata_client.py
File metadata and controls
182 lines (153 loc) · 7.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import json
from typing import Any
import requests
from packaging import version
from rapidata import __version__
import uuid
import random
from rapidata.service.openapi_service import OpenAPIService
from rapidata.rapidata_client.benchmark.rapidata_benchmark_manager import (
RapidataBenchmarkManager,
)
from rapidata.rapidata_client.audience.rapidata_audience_manager import (
RapidataAudienceManager,
)
from rapidata.rapidata_client.order.rapidata_order_manager import RapidataOrderManager
from rapidata.rapidata_client.validation.validation_set_manager import (
ValidationSetManager,
)
from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager
from rapidata.rapidata_client.config import (
logger,
tracer,
managed_print,
rapidata_config,
)
from rapidata.rapidata_client.datapoints._asset_uploader import AssetUploader
from rapidata.rapidata_client.job.rapidata_job_manager import RapidataJobManager
from rapidata.rapidata_client.flow.rapidata_flow_manager import RapidataFlowManager
from rapidata.rapidata_client.api.rapidata_api_client import optional_api_call
class RapidataClient:
"""The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""
def __init__(
self,
client_id: str | None = None,
client_secret: str | None = None,
environment: str = "rapidata.ai",
oauth_scope: str = "openid roles email",
cert_path: str | None = None,
token: dict | None = None,
leeway: int = 60,
):
"""Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
If this is not successful, it will open a browser window and ask you to log in, then save your new credentials in said json file.
Args:
client_id (str): The client ID for authentication.
client_secret (str): The client secret for authentication.
environment (str, optional): The API endpoint.
oauth_scope (str, optional): The scopes to use for authentication. In general this does not need to be changed.
cert_path (str, optional): An optional path to a certificate file useful for development.
token (dict, optional): If you already have a token that the client should use for authentication. Important, if set, this needs to be the complete token object containing the access token, token type and expiration time.
leeway (int, optional): An optional leeway to use to determine if a token is expired. Defaults to 60 seconds.
Attributes:
order (RapidataOrderManager): The RapidataOrderManager instance.
validation (ValidationSetManager): The ValidationSetManager instance.
flow (RapidataFlowManager): The RapidataFlowManager instance.
audience (RapidataAudienceManager): The RapidataAudienceManager instance.
job (JobManager): The JobManager instance.
mri (RapidataBenchmarkManager): The RapidataBenchmarkManager instance.
"""
tracer.set_session_id(
uuid.UUID(int=random.Random().getrandbits(128), version=4).hex
)
with tracer.start_as_current_span("RapidataClient.__init__"):
logger.debug("Checking version")
self._check_version()
if environment != "rapidata.ai":
rapidata_config.logging.enable_otlp = False
logger.debug("Initializing OpenAPIService")
self._openapi_service = OpenAPIService(
client_id=client_id,
client_secret=client_secret,
environment=environment,
oauth_scope=oauth_scope,
cert_path=cert_path,
token=token,
leeway=leeway,
)
self._asset_uploader = AssetUploader(openapi_service=self._openapi_service)
logger.debug("Initializing RapidataOrderManager")
self.order = RapidataOrderManager(openapi_service=self._openapi_service)
logger.debug("Initializing ValidationSetManager")
self.validation = ValidationSetManager(
openapi_service=self._openapi_service
)
logger.debug("Initializing FlowManager")
self.flow = RapidataFlowManager(openapi_service=self._openapi_service)
logger.debug("Initializing JobManager")
self.job = RapidataJobManager(openapi_service=self._openapi_service)
logger.debug("Initializing RapidataBenchmarkManager")
self.mri = RapidataBenchmarkManager(openapi_service=self._openapi_service)
logger.debug("Initializing RapidataAudienceManager")
self.audience = RapidataAudienceManager(
openapi_service=self._openapi_service
)
logger.debug("Initializing RapidataDemographicManager")
self._demographic = DemographicManager(
openapi_service=self._openapi_service
)
self._check_beta_features() # can't be in the trace for some reason
def reset_credentials(self):
"""Reset the credentials saved in the configuration file for the current environment."""
logger.info("Resetting credentials")
self._openapi_service.reset_credentials()
logger.info("Credentials reset")
def clear_all_caches(self):
"""Clear all caches for the client."""
self._asset_uploader.clear_cache()
logger.info("All caches cleared")
def _check_beta_features(self):
"""Enable beta features for the client."""
with optional_api_call("check beta features"):
with tracer.start_as_current_span("RapidataClient.check_beta_features"):
result: dict[str, Any] = json.loads(
self._openapi_service.api_client.call_api(
"GET",
f"https://auth.{self._openapi_service.environment}/connect/userinfo",
_request_timeout=1,
)
.read()
.decode("utf-8")
)
logger.debug("Userinfo: %s", result)
client_id = result.get("sub")
email = result.get("email")
if client_id and email:
tracer.set_user_info(client_id=client_id, email=email)
if "Admin" not in result.get("role", []):
logger.debug("User is not an admin, not enabling beta features")
return
logger.debug("User is an admin, enabling beta features")
rapidata_config.enableBetaFeatures = True
def _check_version(self):
with optional_api_call("version check"):
response = requests.get(
"https://api.github.com/repos/RapidataAI/rapidata-python-sdk/releases/latest",
headers={"Accept": "application/vnd.github.v3+json"},
timeout=1,
)
if response.status_code == 200:
latest_version = response.json()["tag_name"].lstrip("v")
if version.parse(latest_version) > version.parse(__version__):
managed_print(
f"""A new version of the Rapidata SDK is available: {latest_version}
Your current version is: {__version__}"""
)
else:
logger.debug(
"Current version is up to date. Version: %s", __version__
)
def __str__(self) -> str:
return f"RapidataClient(environment={self._openapi_service.environment})"
def __repr__(self) -> str:
return self.__str__()