Skip to content

Commit 73aaf53

Browse files
Cayman Williams authored and committed
orjson replacement with json package (#72)
* initial orjson replacement * format/lint * fix imports * fix a few json dumps calls * bump version * Fix monitor cluster methods --------- Co-authored-by: Cayman Williams <cayman@synccomputing.com>
1 parent b639427 commit 73aaf53

16 files changed

Lines changed: 138 additions & 113 deletions

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ dependencies = [
2929
"boto3~=1.26.0",
3030
"pydantic~=1.10.0",
3131
"httpx~=0.23.0",
32-
"orjson~=3.8.0",
3332
"click~=8.1.0",
3433
"tenacity==8.2.2",
3534
"azure-identity==1.13.0",

sync/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Library for leveraging the power of Sync"""
2-
__version__ = "0.5.1"
2+
__version__ = "0.5.2"
33

44
TIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

sync/awsdatabricks.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import json
12
import logging
23
from time import sleep
34
from typing import List, Tuple
45
from urllib.parse import urlparse
56

67
import boto3 as boto
78
import botocore
8-
import orjson
99
from botocore.exceptions import ClientError
1010

1111
import sync._databricks
@@ -56,6 +56,7 @@
5656
Response,
5757
)
5858
from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
59+
from sync.utils.json import DefaultDateTimeEncoder
5960

6061
__all__ = [
6162
"get_access_report",
@@ -273,7 +274,7 @@ def _load_aws_cluster_info(cluster: dict) -> Tuple[Response[dict], Response[dict
273274
cluster_info_file_response = _get_cluster_instances_from_dbfs(cluster_info_file_key)
274275

275276
cluster_info = (
276-
orjson.loads(cluster_info_file_response) if cluster_info_file_response else None
277+
json.loads(cluster_info_file_response) if cluster_info_file_response else None
277278
)
278279

279280
# If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
@@ -409,12 +410,16 @@ def write_file(body: bytes):
409410
all_timelines = retired_timelines + list(active_timelines_by_id.values())
410411

411412
write_file(
412-
orjson.dumps(
413-
{
414-
"instances": list(all_inst_by_id.values()),
415-
"instance_timelines": all_timelines,
416-
"volumes": list(recorded_volumes_by_id.values()),
417-
}
413+
bytes(
414+
json.dumps(
415+
{
416+
"instances": list(all_inst_by_id.values()),
417+
"instance_timelines": all_timelines,
418+
"volumes": list(recorded_volumes_by_id.values()),
419+
},
420+
cls=DefaultDateTimeEncoder,
421+
),
422+
"utf-8",
418423
)
419424
)
420425
except Exception as e:

sync/awsemr.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import datetime
66
import io
7+
import json
78
import logging
89
import re
910
from copy import deepcopy
@@ -12,7 +13,6 @@
1213
from uuid import uuid4
1314

1415
import boto3 as boto
15-
import orjson
1616
from dateutil.parser import parse as dateparse
1717

1818
from sync import TIME_FORMAT
@@ -28,6 +28,7 @@
2828
ProjectError,
2929
Response,
3030
)
31+
from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
3132

3233
logger = logging.getLogger(__name__)
3334

@@ -364,7 +365,7 @@ def get_project_cluster_report( # noqa: C901
364365
s3.download_fileobj(parsed_project_url.netloc, config_key, config)
365366
return Response(
366367
result=(
367-
orjson.loads(config.getvalue().decode()),
368+
json.loads(config.getvalue().decode()),
368369
f"s3://{parsed_project_url.netloc}/{log_key}",
369370
)
370371
)
@@ -753,10 +754,7 @@ def _upload_object(obj: dict, s3_url: str) -> Response[str]:
753754
s3 = boto.client("s3")
754755
s3.upload_fileobj(
755756
io.BytesIO(
756-
orjson.dumps(
757-
obj,
758-
option=orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS | orjson.OPT_NAIVE_UTC,
759-
)
757+
bytes(json.dumps(obj, cls=DateTimeEncoderNaiveUTCDropMicroseconds), "utf-8")
760758
),
761759
parsed_url.netloc,
762760
obj_key,

sync/azuredatabricks.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import json
12
import logging
23
import os
34
import sys
45
from time import sleep
5-
from typing import List, Dict, Type, TypeVar, Optional
6+
from typing import Dict, List, Optional, Type, TypeVar
67
from urllib.parse import urlparse
78

8-
import orjson
99
from azure.common.credentials import get_cli_profile
1010
from azure.core.exceptions import ClientAuthenticationError
1111
from azure.identity import DefaultAzureCredential
@@ -59,6 +59,7 @@
5959
Response,
6060
)
6161
from sync.utils.dbfs import format_dbfs_filepath, write_dbfs_file
62+
from sync.utils.json import DefaultDateTimeEncoder
6263

6364
__all__ = [
6465
"get_access_report",
@@ -266,9 +267,7 @@ def _get_cluster_instances(cluster: dict) -> Response[dict]:
266267
)
267268

268269
cluster_instances = (
269-
orjson.loads(cluster_instances_file_response)
270-
if cluster_instances_file_response
271-
else None
270+
json.loads(cluster_instances_file_response) if cluster_instances_file_response else None
272271
)
273272

274273
# If this cluster does not have the "Sync agent" configured, attempt a best-effort snapshot of the instances that
@@ -371,11 +370,15 @@ def write_file(body: bytes):
371370
all_timelines = retired_timelines + list(active_timelines_by_id.values())
372371

373372
write_file(
374-
orjson.dumps(
375-
{
376-
"instances": list(all_vms_by_id.values()),
377-
"timelines": all_timelines,
378-
}
373+
bytes(
374+
json.dumps(
375+
{
376+
"instances": list(all_vms_by_id.values()),
377+
"timelines": all_timelines,
378+
},
379+
cls=DefaultDateTimeEncoder,
380+
),
381+
"utf-8",
379382
)
380383
)
381384

sync/cli/_databricks.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import json
12
from typing import Tuple
23

34
import click
4-
import orjson
55

66
from sync.api.projects import (
77
create_project_recommendation,
@@ -11,6 +11,7 @@
1111
from sync.cli.util import validate_project
1212
from sync.config import CONFIG
1313
from sync.models import DatabricksComputeType, DatabricksPlanType, Platform, Preference
14+
from sync.utils.json import DateTimeEncoderNaiveUTC
1415

1516
pass_platform = click.make_pass_decorator(Platform)
1617

@@ -202,9 +203,10 @@ def get_recommendation(project: dict, recommendation_id: str):
202203
click.echo("Recommendation generation failed.", err=True)
203204
else:
204205
click.echo(
205-
orjson.dumps(
206+
json.dumps(
206207
recommendation,
207-
option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
208+
indent=2,
209+
cls=DateTimeEncoderNaiveUTC,
208210
)
209211
)
210212
else:
@@ -223,9 +225,10 @@ def get_submission(project: dict, submission_id: str):
223225
click.echo("Submission generation failed.", err=True)
224226
else:
225227
click.echo(
226-
orjson.dumps(
228+
json.dumps(
227229
submission,
228-
option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
230+
indent=2,
231+
cls=DateTimeEncoderNaiveUTC,
229232
)
230233
)
231234
else:
@@ -277,9 +280,10 @@ def get_cluster_report(
277280
config = config_response.result
278281
if config:
279282
click.echo(
280-
orjson.dumps(
283+
json.dumps(
281284
config.dict(exclude_none=True),
282-
option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z,
285+
indent=2,
286+
cls=DateTimeEncoderNaiveUTC,
283287
)
284288
)
285289
else:

sync/cli/awsemr.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
import json
12
from io import TextIOWrapper
23
from typing import Dict
34

45
import click
5-
import orjson
66

77
from sync import awsemr
88
from sync.api.predictions import get_prediction
99
from sync.cli.util import validate_project
1010
from sync.config import CONFIG
1111
from sync.models import Platform, Preference
12+
from sync.utils.json import DateTimeEncoderNaiveUTC
1213

1314

1415
@click.group
@@ -34,7 +35,7 @@ def run_job_flow(job_flow: TextIOWrapper, project: dict = None, region: str = No
3435
"""Run a job flow
3536
3637
JOB_FLOW is a file containing the RunJobFlow request object"""
37-
job_flow_obj = orjson.loads(job_flow.read())
38+
job_flow_obj = json.loads(job_flow.read())
3839

3940
run_response = awsemr.run_and_record_job_flow(
4041
job_flow_obj, project["id"] if project else None, region
@@ -125,11 +126,7 @@ def get_cluster_report(cluster_id: str, region: str = None):
125126
config_response = awsemr.get_cluster_report(cluster_id, region)
126127
config = config_response.result
127128
if config:
128-
click.echo(
129-
orjson.dumps(
130-
config, option=orjson.OPT_INDENT_2 | orjson.OPT_NAIVE_UTC | orjson.OPT_UTC_Z
131-
)
132-
)
129+
click.echo(json.dumps(config, indent=2, cls=DateTimeEncoderNaiveUTC))
133130
else:
134131
click.echo(f"Failed to create prediction. {config_response.error}", err=True)
135132

sync/cli/predictions.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import io
2+
import json
23
from pathlib import Path
34
from urllib.parse import urlparse
45

56
import boto3 as boto
67
import click
7-
import orjson
88

99
from sync.api.predictions import (
1010
create_prediction,
@@ -17,6 +17,7 @@
1717
from sync.cli.util import validate_project
1818
from sync.config import CONFIG
1919
from sync.models import Platform, Preference
20+
from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
2021

2122

2223
@click.group
@@ -48,12 +49,12 @@ def generate(
4849
parsed_report_arg = urlparse(report)
4950
if parsed_report_arg.scheme == "":
5051
with open(report) as report_fobj:
51-
report = orjson.loads(report_fobj.read())
52+
report = json.loads(report_fobj.read())
5253
elif parsed_report_arg.scheme == "s3":
5354
s3 = boto.client("s3")
5455
report_io = io.BytesIO()
5556
s3.download_fileobj(parsed_report_arg.netloc, parsed_report_arg.path.lstrip("/"), report_io)
56-
report = orjson.loads(report_io.getvalue())
57+
report = json.loads(report_io.getvalue())
5758
else:
5859
ctx.fail("Unsupported report argument")
5960

@@ -83,13 +84,7 @@ def generate(
8384
prediction = prediction_response.result
8485
if prediction:
8586
click.echo(
86-
orjson.dumps(
87-
prediction,
88-
option=orjson.OPT_INDENT_2
89-
| orjson.OPT_UTC_Z
90-
| orjson.OPT_NAIVE_UTC
91-
| orjson.OPT_OMIT_MICROSECONDS,
92-
)
87+
json.dumps(prediction, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds)
9388
)
9489
else:
9590
click.echo(str(response.error), err=True)
@@ -108,12 +103,12 @@ def create(ctx: click.Context, platform: Platform, event_log: str, report: str,
108103
parsed_report_arg = urlparse(report)
109104
if parsed_report_arg.scheme == "":
110105
with open(report) as report_fobj:
111-
report = orjson.loads(report_fobj.read())
106+
report = json.loads(report_fobj.read())
112107
elif parsed_report_arg.scheme == "s3":
113108
s3 = boto.client("s3")
114109
report_io = io.BytesIO()
115110
s3.download_fileobj(parsed_report_arg.netloc, parsed_report_arg.path.lstrip("/"), report_io)
116-
report = orjson.loads(report_io.getvalue())
111+
report = json.loads(report_io.getvalue())
117112
else:
118113
ctx.fail("Unsupported report argument")
119114

@@ -161,15 +156,7 @@ def status(prediction_id: str):
161156
def get(prediction_id: str, preference: Preference):
162157
"""Retrieve a prediction"""
163158
response = get_prediction(prediction_id, preference.value)
164-
click.echo(
165-
orjson.dumps(
166-
response.result,
167-
option=orjson.OPT_INDENT_2
168-
| orjson.OPT_UTC_Z
169-
| orjson.OPT_NAIVE_UTC
170-
| orjson.OPT_OMIT_MICROSECONDS,
171-
)
172-
)
159+
click.echo(json.dumps(response.result, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
173160

174161

175162
@predictions.command

sync/cli/projects.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import json
2+
13
import click
2-
import orjson
34

45
from sync.api.projects import (
56
create_project,
@@ -12,6 +13,7 @@
1213
from sync.cli.util import validate_project
1314
from sync.config import CONFIG
1415
from sync.models import Preference
16+
from sync.utils.json import DateTimeEncoderNaiveUTCDropMicroseconds
1517

1618

1719
@click.group
@@ -40,12 +42,7 @@ def get(project: dict):
4042
response = get_project(project["id"])
4143
project = response.result
4244
if project:
43-
click.echo(
44-
orjson.dumps(
45-
project,
46-
option=orjson.OPT_INDENT_2 | orjson.OPT_UTC_Z | orjson.OPT_OMIT_MICROSECONDS,
47-
)
48-
)
45+
click.echo(json.dumps(project, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
4946
else:
5047
click.echo(str(response.error), err=True)
5148

@@ -183,14 +180,6 @@ def get_latest_prediction(project: dict, preference: Preference):
183180
prediction_response = get_prediction(project["id"], preference)
184181
prediction = prediction_response.result
185182
if prediction:
186-
click.echo(
187-
orjson.dumps(
188-
prediction,
189-
option=orjson.OPT_INDENT_2
190-
| orjson.OPT_UTC_Z
191-
| orjson.OPT_NAIVE_UTC
192-
| orjson.OPT_OMIT_MICROSECONDS,
193-
)
194-
)
183+
click.echo(json.dumps(prediction, indent=2, cls=DateTimeEncoderNaiveUTCDropMicroseconds))
195184
else:
196185
click.echo(str(prediction_response.error), err=True)

0 commit comments

Comments (0)