Skip to content

Commit 79c1843

Browse files
committed
fix: GitHubClient must be used as async context manager
1 parent 7a88d8b commit 79c1843

2 files changed

Lines changed: 67 additions & 68 deletions

File tree

src/swe_forge/cli/benchmark.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -253,42 +253,41 @@ async def _mine_tasks(
253253
verbose: bool,
254254
) -> list[SweTask]:
255255
"""Mine tasks using the SWE pipeline."""
256-
gh_client = GitHubClient(token=github_token)
257-
258-
config = SwePipelineConfig(
259-
max_tasks=num_tasks,
260-
once=True,
261-
difficulty_filter=difficulty_filter,
262-
)
263-
264-
tasks: list[SweTask] = []
265-
266-
with Progress(
267-
SpinnerColumn(),
268-
TextColumn("[progress.description]{task.description}"),
269-
BarColumn(),
270-
TaskProgressColumn(),
271-
TimeElapsedColumn(),
272-
console=console,
273-
) as progress:
274-
task_id = progress.add_task("Mining tasks...", total=num_tasks)
275-
276-
async with SwePipeline(gh_client, config=config) as pipeline:
277-
from swe_forge.swe.pipeline import SwePipelineEventType
278-
279-
async for event in pipeline.run_with_progress():
280-
if event.event_type == SwePipelineEventType.TASK_EXTRACTED:
281-
task = event.data.get("task")
282-
if task and isinstance(task, SweTask):
283-
tasks.append(task)
284-
progress.update(
285-
task_id, advance=1, description=f"Mined {len(tasks)} tasks"
286-
)
287-
288-
if len(tasks) >= num_tasks:
289-
break
256+
async with GitHubClient(token=github_token) as gh_client:
257+
config = SwePipelineConfig(
258+
max_tasks=num_tasks,
259+
once=True,
260+
difficulty_filter=difficulty_filter,
261+
)
290262

291-
return tasks[:num_tasks]
263+
tasks: list[SweTask] = []
264+
265+
with Progress(
266+
SpinnerColumn(),
267+
TextColumn("[progress.description]{task.description}"),
268+
BarColumn(),
269+
TaskProgressColumn(),
270+
TimeElapsedColumn(),
271+
console=console,
272+
) as progress:
273+
task_id = progress.add_task("Mining tasks...", total=num_tasks)
274+
275+
async with SwePipeline(gh_client, config=config) as pipeline:
276+
from swe_forge.swe.pipeline import SwePipelineEventType
277+
278+
async for event in pipeline.run_with_progress():
279+
if event.event_type == SwePipelineEventType.TASK_EXTRACTED:
280+
task = event.data.get("task")
281+
if task and isinstance(task, SweTask):
282+
tasks.append(task)
283+
progress.update(
284+
task_id, advance=1, description=f"Mined {len(tasks)} tasks"
285+
)
286+
287+
if len(tasks) >= num_tasks:
288+
break
289+
290+
return tasks[:num_tasks]
292291

293292

294293
async def _run_harness(

src/swe_forge/cli/mine.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -303,39 +303,39 @@ class PipelineResult:
303303
tasks: list = field(default_factory=list)
304304
benchmark_metrics: object = None
305305

306-
gh_client = GitHubClient(token=token)
307-
gh_archive_client = GhArchiveClient(token=token) if not repo_filter else None
308-
309-
tasks: list[SweTask] = []
310-
metrics = None
311-
312-
with Progress(
313-
SpinnerColumn(),
314-
TextColumn("[progress.description]{task.description}"),
315-
BarColumn(),
316-
TaskProgressColumn(),
317-
TimeElapsedColumn(),
318-
console=console,
319-
) as progress:
320-
task_id = progress.add_task("Mining tasks...", total=config.max_tasks)
321-
322-
async with SwePipeline(
323-
gh_client, gh_archive_client=gh_archive_client, config=config
324-
) as pipeline:
325-
async for event in pipeline.run_with_progress():
326-
if event.event_type == SwePipelineEventType.TASK_EXTRACTED:
327-
task = event.data.get("task")
328-
if task and isinstance(task, SweTask):
329-
tasks.append(task)
330-
progress.update(
331-
task_id, advance=1, description=f"Mined {len(tasks)} tasks"
332-
)
333-
334-
elif event.event_type == SwePipelineEventType.PIPELINE_COMPLETED:
335-
metrics = event.data.get("metrics")
336-
progress.update(task_id, completed=len(tasks))
337-
338-
return PipelineResult(tasks=tasks, benchmark_metrics=metrics)
306+
async with GitHubClient(token=token) as gh_client:
307+
gh_archive_client = GhArchiveClient(token=token) if not repo_filter else None
308+
309+
tasks: list[SweTask] = []
310+
metrics = None
311+
312+
with Progress(
313+
SpinnerColumn(),
314+
TextColumn("[progress.description]{task.description}"),
315+
BarColumn(),
316+
TaskProgressColumn(),
317+
TimeElapsedColumn(),
318+
console=console,
319+
) as progress:
320+
task_id = progress.add_task("Mining tasks...", total=config.max_tasks)
321+
322+
async with SwePipeline(
323+
gh_client, gh_archive_client=gh_archive_client, config=config
324+
) as pipeline:
325+
async for event in pipeline.run_with_progress():
326+
if event.event_type == SwePipelineEventType.TASK_EXTRACTED:
327+
task = event.data.get("task")
328+
if task and isinstance(task, SweTask):
329+
tasks.append(task)
330+
progress.update(
331+
task_id, advance=1, description=f"Mined {len(tasks)} tasks"
332+
)
333+
334+
elif event.event_type == SwePipelineEventType.PIPELINE_COMPLETED:
335+
metrics = event.data.get("metrics")
336+
progress.update(task_id, completed=len(tasks))
337+
338+
return PipelineResult(tasks=tasks, benchmark_metrics=metrics)
339339

340340

341341
@app.command("complete")

0 commit comments

Comments
 (0)