-
Notifications
You must be signed in to change notification settings - Fork 67
Expand file tree
/
Copy pathclients.py
More file actions
240 lines (209 loc) · 8.54 KB
/
clients.py
File metadata and controls
240 lines (209 loc) · 8.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import requests
import base64
import time
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from enum import Enum
import json
from requests.exceptions import HTTPError
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.exceptions import InvalidSignature
from typing import Optional, Dict, Any
class KalshiClient:
    """Minimal market-data facade whose transport is supplied by a subclass.

    This class only knows which endpoint to call; the HTTP request itself is
    delegated to ``_get``, which subclasses must override.
    """

    def _get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Perform a GET request for ``path``; subclasses must override this.

        Args:
            path: API path to request.
            params: Optional query parameters.

        Raises:
            NotImplementedError: Always; this class has no transport of its own.
        """
        # Without this hook, get_markets() failed with an opaque AttributeError.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get(path, params)"
        )

    def get_markets(self, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Retrieve the market list from the ``/markets`` endpoint.

        Args:
            params: Optional query parameters, forwarded unchanged to ``_get``.

        Returns:
            Whatever ``_get`` returns for the ``/markets`` path.
        """
        return self._get("/markets", params)
import websockets
class Environment(Enum):
    """Which Kalshi deployment a client should connect to.

    The member value is the short lowercase environment label.
    """

    DEMO = "demo"  # resolved to the demo-api.kalshi.co hosts by the clients
    PROD = "prod"  # resolved to the api.elections.kalshi.com hosts by the clients
class KalshiBaseClient:
    """Base client class for interacting with the Kalshi API.

    Stores the API credentials, resolves the HTTP/WebSocket base URLs for the
    chosen environment, and produces the RSA-PSS signed authentication headers
    that the HTTP and WebSocket subclasses attach to their requests.
    """

    def __init__(
        self,
        key_id: str,
        private_key: rsa.RSAPrivateKey,
        environment: Environment = Environment.DEMO,
    ):
        """Initializes the client with the provided API key and private key.

        Args:
            key_id (str): Your Kalshi API key ID.
            private_key (rsa.RSAPrivateKey): Your RSA private key, used to sign
                every request.
            environment (Environment): The API environment to use (DEMO or PROD).

        Raises:
            ValueError: If ``environment`` is not a known ``Environment`` member.
        """
        self.key_id = key_id
        self.private_key = private_key
        self.environment = environment
        # Naive wall-clock time of the most recent API call; the HTTP
        # client's rate limiter reads and updates it.
        self.last_api_call = datetime.now()

        if environment == Environment.DEMO:
            self.HTTP_BASE_URL = "https://demo-api.kalshi.co"
            self.WS_BASE_URL = "wss://demo-api.kalshi.co"
        elif environment == Environment.PROD:
            self.HTTP_BASE_URL = "https://api.elections.kalshi.com"
            self.WS_BASE_URL = "wss://api.elections.kalshi.com"
        else:
            raise ValueError("Invalid environment")

    def request_headers(self, method: str, path: str) -> Dict[str, Any]:
        """Generates the required authentication headers for API requests.

        The signed message is the millisecond Unix timestamp, followed by the
        method, followed by ``path`` with any query string removed.

        Args:
            method: HTTP method name exactly as it should be signed (e.g. "GET").
            path: Request path, optionally including a ``?query`` suffix.

        Returns:
            Header dict with content type, key id, signature and timestamp.
        """
        timestamp_str = str(int(time.time() * 1000))
        # Only the path component is signed; query parameters are excluded.
        path_without_query = path.partition("?")[0]
        signature = self.sign_pss_text(timestamp_str + method + path_without_query)
        return {
            "Content-Type": "application/json",
            "KALSHI-ACCESS-KEY": self.key_id,
            "KALSHI-ACCESS-SIGNATURE": signature,
            "KALSHI-ACCESS-TIMESTAMP": timestamp_str,
        }

    def sign_pss_text(self, text: str) -> str:
        """Signs the text using RSA-PSS and returns the base64 encoded signature.

        Uses SHA-256 with MGF1(SHA-256), and a salt length equal to the digest
        length.

        Args:
            text: The message to sign; it is UTF-8 encoded before signing.

        Returns:
            The signature bytes, base64-encoded and decoded to ``str``.
        """
        # No ``except InvalidSignature`` here: that exception is raised by
        # ``verify()``, never by ``sign()``, so such a handler is dead code.
        # Genuine signing errors propagate with their original type.
        signature = self.private_key.sign(
            text.encode("utf-8"),
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.DIGEST_LENGTH,
            ),
            hashes.SHA256(),
        )
        return base64.b64encode(signature).decode("utf-8")
class KalshiHttpClient(KalshiBaseClient):
    """Client for handling HTTP connections to the Kalshi API.

    Every request is rate limited, signed via ``request_headers``, and decoded
    from JSON. 4xx/5xx responses raise ``requests.HTTPError``.
    """

    def __init__(
        self,
        key_id: str,
        private_key: rsa.RSAPrivateKey,
        environment: Environment = Environment.DEMO,
        timeout: Optional[float] = None,
    ):
        """Initializes the HTTP client.

        Args:
            key_id: Your Kalshi API key ID.
            private_key: Your RSA private key.
            environment: The API environment to use (DEMO or PROD).
            timeout: Per-request timeout in seconds, passed to ``requests``.
                ``None`` (the default) means no timeout, matching the old
                behaviour. Set it to stop a stalled connection from hanging
                a call forever.
        """
        super().__init__(key_id, private_key, environment)
        self.host = self.HTTP_BASE_URL
        self.exchange_url = "/trade-api/v2/exchange"
        self.markets_url = "/trade-api/v2/markets"
        self.portfolio_url = "/trade-api/v2/portfolio"
        self.timeout = timeout

    def rate_limit(self) -> None:
        """Built-in rate limiter: keep consecutive API calls >= 100 ms apart.

        Sleeps only for whatever part of the 100 ms window has not yet
        elapsed, then records the current time as the last call.
        """
        THRESHOLD_IN_MILLISECONDS = 100
        threshold = timedelta(milliseconds=THRESHOLD_IN_MILLISECONDS)
        # last_api_call is wall-clock based. If the clock moved backwards the
        # elapsed time is negative; clamp it so we never sleep longer than
        # one window.
        elapsed = max(datetime.now() - self.last_api_call, timedelta(0))
        if elapsed < threshold:
            time.sleep((threshold - elapsed).total_seconds())
        self.last_api_call = datetime.now()

    def raise_if_bad_response(self, response: requests.Response) -> None:
        """Raises an HTTPError if the response status code indicates an error.

        Only statuses outside 200-299 reach ``raise_for_status``. That method
        itself raises only for 4xx/5xx codes.
        """
        # Half-open range: the old ``range(200, 299)`` wrongly excluded 299.
        if not 200 <= response.status_code < 300:
            response.raise_for_status()

    def _send(self, sender: Any, method: str, path: str, **kwargs: Any) -> Any:
        """Send one rate-limited, signed request and return its decoded JSON.

        Args:
            sender: The ``requests`` function to call (``requests.get`` etc.).
            method: HTTP method name used when signing the request.
            path: API path appended to ``self.host``.
            **kwargs: Extra arguments for ``sender`` (``json=`` or ``params=``).

        Raises:
            requests.HTTPError: For 4xx/5xx responses.
        """
        self.rate_limit()
        response = sender(
            self.host + path,
            headers=self.request_headers(method, path),
            timeout=self.timeout,
            **kwargs,
        )
        self.raise_if_bad_response(response)
        return response.json()

    def post(self, path: str, body: dict) -> Any:
        """Performs an authenticated POST request, sending ``body`` as JSON."""
        return self._send(requests.post, "POST", path, json=body)

    def get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Performs an authenticated GET request to the Kalshi API.

        ``params`` defaults to ``None`` rather than a shared mutable ``{}``;
        ``requests`` treats both the same way.
        """
        return self._send(requests.get, "GET", path, params=params)

    def delete(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Performs an authenticated DELETE request to the Kalshi API."""
        return self._send(requests.delete, "DELETE", path, params=params)

    def get_balance(self) -> Dict[str, Any]:
        """Retrieves the account balance."""
        return self.get(self.portfolio_url + "/balance")

    def get_exchange_status(self) -> Dict[str, Any]:
        """Retrieves the exchange status."""
        return self.get(self.exchange_url + "/status")

    def get_trades(
        self,
        ticker: Optional[str] = None,
        limit: Optional[int] = None,
        cursor: Optional[str] = None,
        max_ts: Optional[int] = None,
        min_ts: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Retrieves trades based on provided filters.

        Filters left as ``None`` are omitted from the query string.
        """
        candidate_filters = {
            "ticker": ticker,
            "limit": limit,
            "cursor": cursor,
            "max_ts": max_ts,
            "min_ts": min_ts,
        }
        params = {k: v for k, v in candidate_filters.items() if v is not None}
        return self.get(self.markets_url + "/trades", params=params)
class KalshiWebSocketClient(KalshiBaseClient):
    """Client for handling WebSocket connections to the Kalshi API.

    On connect it subscribes to the ``ticker`` channel, then dispatches each
    incoming message to ``on_message`` until the connection ends. ``on_close``
    and ``on_error`` report how it ended.
    """

    def __init__(
        self,
        key_id: str,
        private_key: rsa.RSAPrivateKey,
        environment: Environment = Environment.DEMO,
    ):
        """Initializes the WebSocket client (no connection is opened yet)."""
        super().__init__(key_id, private_key, environment)
        self.ws = None  # set to the live connection inside connect()
        self.url_suffix = "/trade-api/ws/v2"
        self.message_id = 1  # id for the next command sent to the server

    async def connect(self):
        """Establishes a WebSocket connection using authentication.

        Runs until the message loop in ``handler`` finishes. The connection is
        closed when the ``async with`` block exits.
        """
        host = self.WS_BASE_URL + self.url_suffix
        auth_headers = self.request_headers("GET", self.url_suffix)
        # NOTE(review): ``additional_headers`` is the keyword of the newer
        # websockets asyncio client; legacy versions used ``extra_headers``.
        # Confirm against the installed websockets version.
        async with websockets.connect(host, additional_headers=auth_headers) as websocket:
            self.ws = websocket
            await self.on_open()
            await self.handler()

    async def on_open(self):
        """Callback when WebSocket connection is opened."""
        print("WebSocket connection opened.")
        await self.subscribe_to_tickers()

    async def subscribe_to_tickers(self):
        """Subscribe to ticker updates for all markets."""
        subscription_message = {
            "id": self.message_id,
            "cmd": "subscribe",
            "params": {
                "channels": ["ticker"]
            }
        }
        await self.ws.send(json.dumps(subscription_message))
        self.message_id += 1

    async def handler(self):
        """Handle incoming messages until the connection ends.

        ``on_close`` runs for both abnormal and clean closes. ``on_error``
        receives any other exception raised while receiving or handling
        messages.
        """
        try:
            async for message in self.ws:
                await self.on_message(message)
        except websockets.ConnectionClosed as e:
            await self.on_close(e.code, e.reason)
        except Exception as e:
            await self.on_error(e)
        else:
            # A clean close (codes 1000/1001) ends the ``async for`` normally
            # instead of raising, so report it here as well.
            await self.on_close(self.ws.close_code, self.ws.close_reason)

    async def on_message(self, message):
        """Callback for handling incoming messages."""
        print("Received message:", message)

    async def on_error(self, error):
        """Callback for handling errors."""
        print("WebSocket error:", error)

    async def on_close(self, close_status_code, close_msg):
        """Callback when WebSocket connection is closed."""
        print("WebSocket connection closed with code:", close_status_code, "and message:", close_msg)