Skip to content

Commit 6497dbd

Browse files
✨ add support for additional utilities
1 parent c749640 commit 6497dbd

File tree

15 files changed

+230
-75
lines changed

15 files changed

+230
-75
lines changed

mindee/client_v2.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from time import sleep
2-
from typing import Optional, Union
2+
from typing import Optional, Union, Type
33

44
from mindee.client_mixin import ClientMixin
55
from mindee.error.mindee_error import MindeeError
66
from mindee.error.mindee_http_error_v2 import handle_error_v2
7-
from mindee.input import UrlInputSource
7+
from mindee.input import UrlInputSource, UtilityParameters
88
from mindee.input.inference_parameters import InferenceParameters
99
from mindee.input.polling_options import PollingOptions
1010
from mindee.input.sources.local_input_source import LocalInputSource
@@ -15,6 +15,7 @@
1515
is_valid_post_response,
1616
)
1717
from mindee.parsing.v2.common_response import CommonStatus
18+
from mindee.parsing.v2.base_inference import BaseInference, BaseInferenceResponse, TypeInferenceV2
1819
from mindee.parsing.v2.inference_response import InferenceResponse
1920
from mindee.parsing.v2.job_response import JobResponse
2021

@@ -41,7 +42,8 @@ def __init__(self, api_key: Optional[str] = None) -> None:
4142
def enqueue_inference(
4243
self,
4344
input_source: Union[LocalInputSource, UrlInputSource],
44-
params: InferenceParameters,
45+
params: Union[InferenceParameters, UtilityParameters],
46+
slug: Optional[str] = None
4547
) -> JobResponse:
4648
"""
4749
Enqueues a document to a given model.
@@ -52,16 +54,18 @@ def enqueue_inference(
5254
:return: A valid inference response.
5355
"""
5456
logger.debug("Enqueuing inference using model: %s", params.model_id)
55-
5657
response = self.mindee_api.req_post_inference_enqueue(
57-
input_source=input_source, params=params
58+
input_source=input_source,
59+
params=params,
60+
slug=slug
5861
)
5962
dict_response = response.json()
6063

6164
if not is_valid_post_response(response):
6265
handle_error_v2(dict_response)
6366
return JobResponse(dict_response)
6467

68+
6569
def get_job(self, job_id: str) -> JobResponse:
6670
"""
6771
Get the status of an inference that was previously enqueued.
@@ -79,13 +83,18 @@ def get_job(self, job_id: str) -> JobResponse:
7983
dict_response = response.json()
8084
return JobResponse(dict_response)
8185

82-
def get_inference(self, inference_id: str) -> InferenceResponse:
86+
def get_inference(
87+
self,
88+
inference_id: str,
89+
inference_type: Type[BaseInference] = InferenceResponse
90+
) -> InferenceResponse:
8391
"""
8492
Get the result of an inference that was previously enqueued.
8593
8694
The inference will only be available after it has finished processing.
8795
8896
:param inference_id: UUID of the inference to retrieve.
97+
:param inference_type: Class of the product to instantiate.
8998
:return: An inference response.
9099
"""
91100
logger.debug("Fetching inference: %s", inference_id)
@@ -94,19 +103,20 @@ def get_inference(self, inference_id: str) -> InferenceResponse:
94103
if not is_valid_get_response(response):
95104
handle_error_v2(response.json())
96105
dict_response = response.json()
97-
return InferenceResponse(dict_response)
106+
return inference_type(dict_response)
98107

99-
def enqueue_and_get_inference(
108+
def _enqueue_and_get(
100109
self,
101110
input_source: Union[LocalInputSource, UrlInputSource],
102-
params: InferenceParameters,
103-
) -> InferenceResponse:
111+
params: Union[InferenceParameters, UtilityParameters],
112+
inference_type: Optional[Type[BaseInference]] = BaseInference
113+
) -> Union[InferenceResponse, BaseInferenceResponse]:
104114
"""
105115
Enqueues to an asynchronous endpoint and automatically polls for a response.
106116
107117
:param input_source: The document/source file to use. Can be local or remote.
108-
109118
:param params: Parameters to set when sending a file.
119+
:param inference_type: The product class to use for the response object.
110120
111121
:return: A valid inference response.
112122
"""
@@ -117,9 +127,14 @@ def enqueue_and_get_inference(
117127
params.polling_options.delay_sec,
118128
params.polling_options.max_retries,
119129
)
120-
enqueue_response = self.enqueue_inference(input_source, params)
130+
slug = inference_type if inference_type.get_slug() else None
131+
enqueue_response = self.enqueue_inference(
132+
input_source,
133+
params,
134+
slug
135+
)
121136
logger.debug(
122-
"Successfully enqueued inference with job id: %s", enqueue_response.job.id
137+
"Successfully enqueued document with job id: %s", enqueue_response.job.id
123138
)
124139
sleep(params.polling_options.initial_delay_sec)
125140
try_counter = 0
@@ -134,8 +149,48 @@ def enqueue_and_get_inference(
134149
f"Parsing failed for job {job_response.job.id}: {detail}"
135150
)
136151
if job_response.job.status == CommonStatus.PROCESSED.value:
137-
return self.get_inference(job_response.job.id)
152+
return self.get_inference(job_response.job.id, inference_type)
138153
try_counter += 1
139154
sleep(params.polling_options.delay_sec)
140155

141156
raise MindeeError(f"Couldn't retrieve document after {try_counter + 1} tries.")
157+
158+
def enqueue_and_get_inference(
    self,
    input_source: Union[LocalInputSource, UrlInputSource],
    params: InferenceParameters
) -> InferenceResponse:
    """
    Enqueue a document to an asynchronous endpoint and poll until a result is available.

    :param input_source: The document/source file to use. Can be local or remote.
    :param params: Parameters to set when sending a file.
    :return: A valid inference response.
    :raises MindeeError: If the polled result is not an ``InferenceResponse``.
    """
    # NOTE(review): ``_enqueue_and_get`` defaults ``inference_type`` to
    # ``BaseInference`` and instantiates the result from it — confirm the
    # default path really produces an ``InferenceResponse`` here.
    response = self._enqueue_and_get(input_source, params)
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``
    # and must not guard runtime behavior.
    if not isinstance(response, InferenceResponse):
        raise MindeeError(f'Invalid response type "{type(response)}"')
    return response
175+
176+
177+
def enqueue_and_get_utility(
    self,
    inference_type: Type[TypeInferenceV2],
    input_source: Union[LocalInputSource, UrlInputSource],
    params: UtilityParameters
) -> TypeInferenceV2:
    """
    Enqueue a document to an asynchronous utility endpoint and poll until a result is available.

    :param inference_type: The product class to use for the response object.
    :param input_source: The document/source file to use. Can be local or remote.
    :param params: Parameters to set when sending a file.
    :return: A valid inference response, an instance of ``inference_type``.
    :raises MindeeError: If the polled result is not an instance of ``inference_type``.
    """
    response = self._enqueue_and_get(input_source, params, inference_type)
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``
    # and must not guard runtime behavior.
    if not isinstance(response, inference_type):
        raise MindeeError(f'Invalid response type "{type(response)}"')
    return response

mindee/input/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from mindee.input.local_response import LocalResponse
2+
from mindee.input.base_parameters import BaseParameters
3+
from mindee.input.inference_parameters import InferenceParameters
4+
from mindee.input.utility_parameters import UtilityParameters
25
from mindee.input.page_options import PageOptions
36
from mindee.input.polling_options import PollingOptions
47
from mindee.input.sources.base_64_input import Base64Input
@@ -11,15 +14,17 @@
1114
from mindee.input.workflow_options import WorkflowOptions
1215

1316
__all__ = [
17+
"Base64Input",
18+
"BaseParameters",
19+
"BytesInput",
20+
"FileInput",
1421
"InputType",
1522
"LocalInputSource",
16-
"UrlInputSource",
23+
"LocalResponse",
24+
"PageOptions",
1725
"PathInput",
18-
"FileInput",
19-
"Base64Input",
20-
"BytesInput",
21-
"WorkflowOptions",
2226
"PollingOptions",
23-
"PageOptions",
24-
"LocalResponse",
27+
"UrlInputSource",
28+
"UtilityParameters",
29+
"WorkflowOptions",
2530
]

mindee/input/base_parameters.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional

# Import from the defining submodule, not the package root: sibling modules
# (e.g. inference_parameters.py) use ``mindee.input.polling_options``, and
# importing ``mindee`` from inside ``mindee.input`` risks a circular import.
from mindee.input.polling_options import PollingOptions


@dataclass
class BaseParameters(ABC):
    """Common parameters shared by all asynchronous v2 enqueue endpoints."""

    model_id: str
    """ID of the model, required."""
    alias: Optional[str] = None
    """Use an alias to link the file to your own DB. If empty, no alias will be used."""
    webhook_ids: Optional[List[str]] = None
    """IDs of webhooks to propagate the API response to."""
    polling_options: Optional[PollingOptions] = None
    """Options for polling. Set only if having timeout issues."""
    close_file: bool = True
    """Whether to close the file after parsing."""

mindee/input/inference_parameters.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass, asdict
33
from typing import List, Optional, Union
44

5-
from mindee.input.polling_options import PollingOptions
5+
from mindee.input.utility_parameters import BaseParameters
66

77

88
@dataclass
@@ -44,7 +44,7 @@ class DataSchemaField(StringDataClass):
4444
guidelines: Optional[str] = None
4545
"""Optional extraction guidelines."""
4646
nested_fields: Optional[dict] = None
47-
"""Subfields when type is `nested_object`. Leave empty for other types"""
47+
"""Subfields when type is `nested_object`. Leave empty for other types."""
4848

4949

5050
@dataclass
@@ -76,13 +76,10 @@ def __post_init__(self) -> None:
7676
elif isinstance(self.replace, str):
7777
self.replace = DataSchemaReplace(**json.loads(self.replace))
7878

79-
8079
@dataclass
81-
class InferenceParameters:
80+
class InferenceParameters(BaseParameters):
8281
"""Inference parameters to set when sending a file."""
8382

84-
model_id: str
85-
"""ID of the model, required."""
8683
rag: Optional[bool] = None
8784
"""Enhance extraction accuracy with Retrieval-Augmented Generation."""
8885
raw_text: Optional[bool] = None
@@ -94,14 +91,6 @@ class InferenceParameters:
9491
Boost the precision and accuracy of all extractions.
9592
Calculate confidence scores for all fields, and fill their ``confidence`` attribute.
9693
"""
97-
alias: Optional[str] = None
98-
"""Use an alias to link the file to your own DB. If empty, no alias will be used."""
99-
webhook_ids: Optional[List[str]] = None
100-
"""IDs of webhooks to propagate the API response to."""
101-
polling_options: Optional[PollingOptions] = None
102-
"""Options for polling. Set only if having timeout issues."""
103-
close_file: bool = True
104-
"""Whether to close the file after parsing."""
10594
text_context: Optional[str] = None
10695
"""
10796
Additional text context used by the model during inference.

mindee/input/utility_parameters.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from dataclasses import dataclass

from mindee.input.base_parameters import BaseParameters


@dataclass
class UtilityParameters(BaseParameters):
    """Parameters accepted by the asynchronous **inference** utility v2 endpoints.

    Adds no fields of its own: the common ones (model ID, alias, webhook IDs,
    polling options, file handling) are inherited from ``BaseParameters``.
    """

mindee/mindee_http/mindee_api_v2.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import requests
55

66
from mindee.error.mindee_error import MindeeApiV2Error
7-
from mindee.input import LocalInputSource, UrlInputSource
7+
from mindee.input import LocalInputSource, UrlInputSource, UtilityParameters
88
from mindee.input.inference_parameters import InferenceParameters
99
from mindee.logger import logger
1010
from mindee.mindee_http.base_settings import USER_AGENT
@@ -74,34 +74,37 @@ def set_from_env(self) -> None:
7474
def req_post_inference_enqueue(
7575
self,
7676
input_source: Union[LocalInputSource, UrlInputSource],
77-
params: InferenceParameters,
77+
params: Union[InferenceParameters, UtilityParameters],
78+
slug: Optional[str] = None
7879
) -> requests.Response:
7980
"""
8081
Make an asynchronous request to POST a document for prediction on the V2 API.
8182
8283
:param input_source: Input object.
8384
:param params: Options for the enqueueing of the document.
85+
:param slug: Slug to use for the enqueueing, defaults to 'inferences'.
8486
:return: requests response.
8587
"""
88+
slug = slug if slug else "inferences"
8689
data: Dict[str, Union[str, list]] = {"model_id": params.model_id}
87-
url = f"{self.url_root}/inferences/enqueue"
88-
89-
if params.rag is not None:
90-
data["rag"] = str(params.rag).lower()
91-
if params.raw_text is not None:
92-
data["raw_text"] = str(params.raw_text).lower()
93-
if params.confidence is not None:
94-
data["confidence"] = str(params.confidence).lower()
95-
if params.polygon is not None:
96-
data["polygon"] = str(params.polygon).lower()
90+
url = f"{self.url_root}/{slug}/enqueue"
91+
if isinstance(params, InferenceParameters):
92+
if params.rag is not None:
93+
data["rag"] = str(params.rag).lower()
94+
if params.raw_text is not None:
95+
data["raw_text"] = str(params.raw_text).lower()
96+
if params.confidence is not None:
97+
data["confidence"] = str(params.confidence).lower()
98+
if params.polygon is not None:
99+
data["polygon"] = str(params.polygon).lower()
100+
if params.text_context and len(params.text_context):
101+
data["text_context"] = params.text_context
102+
if params.data_schema is not None:
103+
data["data_schema"] = str(params.data_schema)
97104
if params.webhook_ids and len(params.webhook_ids) > 0:
98105
data["webhook_ids"] = params.webhook_ids
99106
if params.alias and len(params.alias):
100107
data["alias"] = params.alias
101-
if params.text_context and len(params.text_context):
102-
data["text_context"] = params.text_context
103-
if params.data_schema is not None:
104-
data["data_schema"] = str(params.data_schema)
105108

106109
if isinstance(input_source, LocalInputSource):
107110
files = {"file": input_source.read_contents(params.close_file)}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from mindee.parsing.v2.base_inference.base_inference import BaseInference
2+
from mindee.parsing.v2.base_inference.base_inference_response import BaseInferenceResponse, TypeInferenceV2
3+
from mindee.parsing.v2.base_inference.split.split_inference import SplitInference
4+
from mindee.parsing.v2.base_inference.split.split_response import SplitResponse
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from abc import ABC

from mindee.parsing.common.string_dict import StringDict
from mindee.parsing.v2 import InferenceModel, InferenceFile


class BaseInference(ABC):
    """Base class for all v2 inference results."""

    model: InferenceModel
    """Model info for the inference."""
    file: InferenceFile
    """File info for the inference."""
    id: str
    """ID of the inference."""
    # Default to "" (falsy) instead of a bare annotation: callers invoke
    # ``get_slug()`` on ``BaseInference`` itself (it is the default
    # ``inference_type`` in client_v2), and an unassigned annotation would
    # make that raise ``AttributeError``. Subclasses override with their slug.
    _slug: str = ""
    """Slug of the inference; empty when the subclass defines none."""

    def __init__(self, raw_response: StringDict):
        """
        :param raw_response: Raw server response for a single inference.
        """
        self.id = raw_response["id"]
        self.model = InferenceModel(raw_response["model"])
        self.file = InferenceFile(raw_response["file"])

    @classmethod
    def get_slug(cls) -> str:
        """Getter for the inference slug; empty string when undefined."""
        return cls._slug
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Generic, TypeVar

from mindee.parsing.common.string_dict import StringDict
from mindee.parsing.v2.base_inference.base_inference import BaseInference
from mindee.parsing.v2.common_response import CommonResponse

TypeInferenceV2 = TypeVar("TypeInferenceV2", bound=BaseInference)


class BaseInferenceResponse(CommonResponse, Generic[TypeInferenceV2]):
    """Generic base for v2 inference responses, parameterized on the inference type."""

    inference: TypeInferenceV2
    """The inference result carried by this response."""

    def __init__(self, raw_response: StringDict) -> None:
        """
        :param raw_response: Raw server response.
        """
        super().__init__(raw_response)
        self.inference = self._set_inference_type(raw_response["inference"])

    def _set_inference_type(self, inference_response: StringDict) -> TypeInferenceV2:
        """
        Build the concrete inference object for this response.

        Subclasses must override: the base class cannot know which
        ``BaseInference`` subclass to instantiate.

        :param inference_response: Server response for the inference payload.
        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError()

mindee/parsing/v2/base_inference/split/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)