Skip to content

Commit 7a6f68e

Browse files
authored
make ollamaInferenceEngine handle return_meta_data (#1956)
* make ollamaInferenceEngine handle return_meta_data
  Signed-off-by: lilacheden <lilach.edel@gmail.com>
* fix wml inference tests to use supported models
  Signed-off-by: lilacheden <lilach.edel@gmail.com>
* more model fixes to test_inference_engine
  Signed-off-by: lilacheden <lilach.edel@gmail.com>
* allow small diff in metric test
  Signed-off-by: lilacheden <lilach.edel@gmail.com>
---------
Signed-off-by: lilacheden <lilach.edel@gmail.com>
1 parent 3cab0a5 commit 7a6f68e

3 files changed

Lines changed: 17 additions & 6 deletions

File tree

src/unitxt/inference.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,18 @@ def _infer(
14611461
options=args,
14621462
)
14631463
results.append(response)
1464-
1464+
if return_meta_data:
1465+
return [
1466+
TextGenerationInferenceOutput(
1467+
prediction=element["message"]["content"],
1468+
generated_text=element["message"]["content"],
1469+
input_tokens=element.get("prompt_eval_count", 0),
1470+
output_tokens=element.get("eval_count", 0),
1471+
model_name=self.model,
1472+
inference_type=self.label,
1473+
)
1474+
for element in results
1475+
]
14651476
return [element["message"]["content"] for element in results]
14661477

14671478

tests/inference/test_inference_engine.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_llava_inference_engine(self):
159159

160160
def test_watsonx_inference(self):
161161
model = WMLInferenceEngineGeneration(
162-
model_name="google/flan-t5-xl",
162+
model_name="ibm/granite-3-8b-instruct",
163163
data_classification_policy=["public"],
164164
random_seed=111,
165165
min_new_tokens=1,
@@ -193,7 +193,7 @@ def test_watsonx_inference_with_external_client(self):
193193
from ibm_watsonx_ai.client import APIClient, Credentials
194194

195195
model = WMLInferenceEngineGeneration(
196-
model_name="google/flan-t5-xl",
196+
model_name="ibm/granite-3-8b-instruct",
197197
data_classification_policy=["public"],
198198
random_seed=111,
199199
min_new_tokens=1,
@@ -279,7 +279,7 @@ def test_option_selecting_by_log_prob_inference_engines(self):
279279
]
280280

281281
watsonx_engine = WMLInferenceEngineGeneration(
282-
model_name="meta-llama/llama-3-2-1b-instruct"
282+
model_name="ibm/granite-3-8b-instruct"
283283
)
284284

285285
for engine in [watsonx_engine]:
@@ -383,7 +383,7 @@ def test_lite_llm_inference_engine(self):
383383

384384
def test_lite_llm_inference_engine_without_task_data_not_failing(self):
385385
LiteLLMInferenceEngine(
386-
model="watsonx/meta-llama/llama-3-2-1b-instruct",
386+
model="watsonx/meta-llama/llama-3-2-11b-vision-instruct",
387387
max_tokens=2,
388388
temperature=0,
389389
top_p=1,

tests/library/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2708,7 +2708,7 @@ def test_perplexity(self):
27082708
metric=perplexity_question, predictions=prediction, references=references
27092709
)
27102710
self.assertAlmostEqual(
2711-
first_instance_target, outputs[0]["score"]["instance"]["score"]
2711+
first_instance_target, outputs[0]["score"]["instance"]["score"], places=5
27122712
)
27132713

27142714
def test_fuzzyner(self):

0 commit comments

Comments (0)