Skip to content

Commit 72aa97a

Browse files
add s3 kv store (#393)
1 parent 5fa1e1f commit 72aa97a

1 file changed

Lines changed: 271 additions & 0 deletions

File tree

mlx/s3_kv_store.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
import json
2+
import posixpath
3+
import re
4+
import argparse
5+
from typing import Optional, Dict, List, Any, Tuple
6+
from urllib.parse import quote, unquote
7+
import boto3
8+
from botocore.exceptions import ClientError
9+
10+
INDEX_SEPARATOR = "__i__"
11+
KV_SEPARATOR = "="
12+
FILENAME_SUFFIX = ".json"
13+
14+
15+
def _encode_component(s: str) -> str:
16+
return quote(s, safe="")
17+
18+
19+
def _decode_component(s: str) -> str:
20+
return unquote(s)
21+
22+
23+
def _build_filename(key: str, indexes: Optional[Dict[str, str]] = None) -> str:
    """Compose the object filename for *key* plus optional index metadata.

    Layout: ``<key>__i__<ik>=<iv>__i__...json`` with every component
    percent-encoded.  Index pairs are appended in sorted-key order so the
    same (key, indexes) combination always produces the same filename.
    """
    segments = [_encode_component(key)]
    if indexes:
        segments.extend(
            f"{_encode_component(name)}{KV_SEPARATOR}{_encode_component(str(indexes[name]))}"
            for name in sorted(indexes)
        )
    return INDEX_SEPARATOR.join(segments) + FILENAME_SUFFIX
30+
31+
32+
def _parse_filename(filename: str) -> Tuple[str, Dict[str, str]]:
    """Invert `_build_filename`: recover ``(logical_key, indexes)``.

    Args:
        filename: the basename of a store object.

    Returns:
        The decoded logical key and a dict of decoded index pairs.

    Raises:
        ValueError: if *filename* does not end with the ``.json`` suffix.

    Segments lacking a ``=`` separator are silently skipped, tolerating
    hand-crafted or legacy filenames.
    """
    if not filename.endswith(FILENAME_SUFFIX):
        raise ValueError("invalid filename (missing .json suffix)")
    core = filename[:-len(FILENAME_SUFFIX)]
    # str.split always returns at least one element, so parts[0] is safe;
    # the original "if not parts: raise" guard was unreachable dead code.
    parts = core.split(INDEX_SEPARATOR)
    key = _decode_component(parts[0])
    indexes: Dict[str, str] = {}
    for part in parts[1:]:
        if KV_SEPARATOR not in part:
            continue
        k_enc, v_enc = part.split(KV_SEPARATOR, 1)
        indexes[_decode_component(k_enc)] = _decode_component(v_enc)
    return key, indexes
49+
50+
51+
class S3KVStore:
    """JSON key-value store backed by an S3 bucket.

    Each logical key is persisted as one S3 object under
    ``{store_name}/{filename}``, where the filename encodes the key plus
    optional string "index" metadata (see `_build_filename`).  Indexes let
    callers filter and search from listings alone, without reading object
    bodies.
    """

    def __init__(self, bucket: str, store_name: str, s3_client: Optional[Any] = None, endpoint_url: Optional[str] = None, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None):
        """Create a store rooted at ``s3://{bucket}/{store_name}/``.

        An injected *s3_client* (useful for testing) takes precedence over
        building a boto3 client from the remaining endpoint/credential
        arguments.
        """
        self.bucket = bucket
        self.store_name = store_name.strip("/")
        if s3_client is None:
            self.s3 = boto3.client(
                "s3",
                endpoint_url=endpoint_url,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
            )
        else:
            self.s3 = s3_client

    def _prefix(self) -> str:
        # Trailing slash keeps listings scoped to this store's "directory".
        return f"{self.store_name}/" if self.store_name else ""

    def _s3_key_for_filename(self, filename: str) -> str:
        """Map a store filename to its full S3 object key."""
        return posixpath.join(self._prefix(), filename)

    def list(self, prefix: Optional[str] = None, max_keys: int = 1000) -> List[Dict[str, Any]]:
        """Return metadata dicts for every object in the store.

        Args:
            prefix: if given, keep only items whose decoded logical key
                *starts with* it (prefix match, not exact — see
                `_find_objects_for_key` for exact matching).
            max_keys: S3 page size only; all pages are always consumed.

        Objects whose filenames do not parse are silently skipped.
        """
        s3_prefix = self._prefix()
        continuation_token = None
        results: List[Dict[str, Any]] = []

        while True:
            kwargs = {"Bucket": self.bucket, "Prefix": s3_prefix, "MaxKeys": max_keys}
            if continuation_token:
                kwargs["ContinuationToken"] = continuation_token
            resp = self.s3.list_objects_v2(**kwargs)
            for obj in resp.get("Contents", []):
                full_key = obj["Key"]
                filename = posixpath.basename(full_key)
                try:
                    logical_key, indexes = _parse_filename(filename)
                except ValueError:
                    # Not one of ours (or corrupt) — ignore.
                    continue
                if prefix and not logical_key.startswith(prefix):
                    continue
                results.append({
                    "s3_key": full_key,
                    "filename": filename,
                    "key": logical_key,
                    "indexes": indexes,
                    "size": obj.get("Size", 0),
                    "last_modified": obj.get("LastModified"),
                })
            if not resp.get("IsTruncated"):
                break
            continuation_token = resp.get("NextContinuationToken")

        return results

    def _match_indexes(self, item_indexes: Dict[str, str], filt: Dict[str, Any]) -> bool:
        """True if *item_indexes* satisfies every constraint in *filt*.

        A filter value may be a scalar (equality after ``str()``), a
        list/tuple/set (membership), or a compiled regex (``search``).
        """
        for fk, fv in filt.items():
            if fk not in item_indexes:
                return False
            val = item_indexes[fk]
            if isinstance(fv, (list, tuple, set)):
                if val not in {str(x) for x in fv}:
                    return False
            elif isinstance(fv, re.Pattern):
                if not fv.search(val):
                    return False
            else:
                if val != str(fv):
                    return False
        return True

    def get(self, key: str, index_filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Fetch and JSON-decode the unique value stored under *key*.

        Raises:
            KeyError: no object matches.
            ValueError: several objects match — disambiguate via *index_filter*.
            IOError: the underlying S3 read failed.
        """
        matches = self._find_objects_for_key(key, index_filter=index_filter)
        if not matches:
            raise KeyError(f"key not found: {key} (filter={index_filter})")
        if len(matches) > 1:
            raise ValueError(f"multiple objects match key={key}; refine using index_filter: {matches}")
        s3_key = matches[0]["s3_key"]
        try:
            resp = self.s3.get_object(Bucket=self.bucket, Key=s3_key)
            body = resp["Body"].read()
            return json.loads(body.decode("utf-8"))
        except ClientError as e:
            # Chain the botocore error so the root cause stays visible.
            raise IOError(f"s3 get_object failed: {e}") from e

    def put(self, key: str, value: Dict[str, Any], indexes: Optional[Dict[str, Any]] = None, overwrite: bool = False) -> str:
        """Store *value* (JSON-serialized) under *key*; return the S3 key.

        With ``overwrite=True``, existing objects for exactly this key are
        deleted first.  With ``overwrite=False``, FileExistsError is raised
        if the target object already exists.
        """
        if overwrite:
            # Exact-key matches only (see _find_objects_for_key), so keys
            # that merely share a prefix are never clobbered.
            for obj in self._find_objects_for_key(key):
                self.s3.delete_object(Bucket=self.bucket, Key=obj["s3_key"])

        filename = _build_filename(key, {k: str(v) for k, v in (indexes or {}).items()})
        s3_key = self._s3_key_for_filename(filename)
        if not overwrite:
            try:
                self.s3.head_object(Bucket=self.bucket, Key=s3_key)
                raise FileExistsError(f"object already exists: {s3_key}")
            except ClientError as e:
                # 404-style codes mean "free to write"; anything else is a real failure.
                if e.response["Error"]["Code"] not in ("404", "NotFound", "NoSuchKey"):
                    raise

        payload = json.dumps(value, ensure_ascii=False).encode("utf-8")
        self.s3.put_object(Bucket=self.bucket, Key=s3_key, Body=payload, ContentType="application/json")
        return s3_key

    def update(self, key: str, value: Dict[str, Any], index_filter: Optional[Dict[str, Any]] = None, new_indexes: Optional[Dict[str, Any]] = None) -> str:
        """Replace the unique object matching ``(key, index_filter)``.

        If *new_indexes* is given, the object is rewritten under the new
        index set.  The write happens before the delete, so a crash in
        between can leave both copies but never neither.
        """
        matches = self._find_objects_for_key(key, index_filter=index_filter)
        if not matches:
            raise KeyError(f"no object matches key={key} index_filter={index_filter}")
        if len(matches) > 1:
            raise ValueError(f"multiple objects match key={key} index_filter={index_filter}: {matches}")

        old = matches[0]
        target_indexes = new_indexes if new_indexes is not None else old["indexes"]
        new_filename = _build_filename(key, {k: str(v) for k, v in (target_indexes or {}).items()})
        new_s3_key = self._s3_key_for_filename(new_filename)
        payload = json.dumps(value, ensure_ascii=False).encode("utf-8")
        self.s3.put_object(Bucket=self.bucket, Key=new_s3_key, Body=payload, ContentType="application/json")
        if old["s3_key"] != new_s3_key:
            self.s3.delete_object(Bucket=self.bucket, Key=old["s3_key"])
        return new_s3_key

    def delete(self, key: str, index_filter: Optional[Dict[str, Any]] = None) -> int:
        """Delete every object matching ``(key, index_filter)``; return the count."""
        count = 0
        for obj in self._find_objects_for_key(key, index_filter=index_filter):
            self.s3.delete_object(Bucket=self.bucket, Key=obj["s3_key"])
            count += 1
        return count

    def search(self, index_filter: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return all items in the store whose indexes satisfy *index_filter*."""
        return [it for it in self.list() if self._match_indexes(it["indexes"], index_filter)]

    def _find_objects_for_key(self, key: str, index_filter: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Return listed items whose logical key is exactly *key*.

        BUG FIX: ``list(prefix=key)`` is only a coarse ``startswith``
        pre-filter — previously get/put/update/delete on "foo" would also
        match (and even delete) objects stored under "foobar".  An exact
        logical-key comparison is now applied before *index_filter*.
        """
        candidates = [c for c in self.list(prefix=key) if c["key"] == key]
        if index_filter is None:
            return candidates
        return [c for c in candidates if self._match_indexes(c["indexes"], index_filter)]
189+
190+
191+
# ---------------- CLI ----------------
def _load_payload(args: argparse.Namespace) -> Dict[str, Any]:
    """Parse the JSON value for put/update from --value-file or --value."""
    if args.value_file:
        # Context manager ensures the file handle is closed promptly.
        with open(args.value_file, encoding="utf-8") as fh:
            return json.load(fh)
    if args.value is None:
        # Fail with a clear message instead of json.loads(None) -> TypeError.
        raise SystemExit("error: one of --value or --value-file is required")
    return json.loads(args.value)


def main():
    """Command-line interface over S3KVStore: put/get/update/delete/list/search."""
    parser = argparse.ArgumentParser(description="S3 KV Store CLI")
    parser.add_argument("bucket")
    parser.add_argument("store")
    parser.add_argument("--endpoint")

    sub = parser.add_subparsers(dest="cmd", required=True)

    # put
    sp = sub.add_parser("put")
    sp.add_argument("key")
    # argparse applies `type` to string defaults too, so "{}" becomes {}.
    sp.add_argument("--indexes", type=json.loads, default="{}")
    sp.add_argument("--value")
    sp.add_argument("--value-file")
    sp.add_argument("--overwrite", action="store_true")

    # get
    sp = sub.add_parser("get")
    sp.add_argument("key")
    sp.add_argument("--filter", type=json.loads, default="{}")

    # update
    sp = sub.add_parser("update")
    sp.add_argument("key")
    sp.add_argument("--filter", type=json.loads, default="{}")
    sp.add_argument("--new-indexes", type=json.loads, default=None)
    sp.add_argument("--value")
    sp.add_argument("--value-file")

    # delete
    sp = sub.add_parser("delete")
    sp.add_argument("key")
    sp.add_argument("--filter", type=json.loads, default="{}")

    # list
    sp = sub.add_parser("list")
    sp.add_argument("--prefix")

    # search
    sp = sub.add_parser("search")
    sp.add_argument("--filter", type=json.loads, required=True)

    args = parser.parse_args()
    # BUG FIX: the original instantiated an undefined name `MLX`; the store
    # class defined in this module is S3KVStore.
    store = S3KVStore(bucket=args.bucket, store_name=args.store, endpoint_url=args.endpoint)

    if args.cmd == "put":
        key = store.put(args.key, _load_payload(args), indexes=args.indexes, overwrite=args.overwrite)
        print(key)

    elif args.cmd == "get":
        value = store.get(args.key, index_filter=args.filter)
        print(json.dumps(value, indent=2))

    elif args.cmd == "update":
        key = store.update(args.key, _load_payload(args), index_filter=args.filter, new_indexes=args.new_indexes)
        print(key)

    elif args.cmd == "delete":
        count = store.delete(args.key, index_filter=args.filter)
        print(f"Deleted {count} object(s)")

    elif args.cmd == "list":
        items = store.list(prefix=args.prefix)
        print(json.dumps(items, indent=2, default=str))

    elif args.cmd == "search":
        items = store.search(args.filter)
        print(json.dumps(items, indent=2, default=str))


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)