@@ -34,15 +34,14 @@ nodes:

- id: quiz
op_name: quiz
type: aggregate
type: map_batch
dependencies:
- build_kg
execution_params:
replicas: 1
batch_size: 128
params:
quiz_samples: 2 # number of quiz samples to generate
concurrency_limit: 200

- id: judge
op_name: judge
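The quiz node switches from a single `aggregate` step to a `map_batch` operator with `batch_size: 128`, and the `concurrency_limit` param is dropped (concurrency now lives inside the service; see `quiz_service.py` below). A toy sketch of what map-batch semantics imply, with the slicing scheduler invented here purely for illustration:

```python
import pandas as pd

def map_batch(df: pd.DataFrame, process, batch_size: int = 128):
    """Feed the operator fixed-size slices instead of the whole table at once
    (an illustrative stand-in, not GraphGen's actual scheduler)."""
    for start in range(0, len(df), batch_size):
        yield process(df.iloc[start:start + batch_size])

rows = pd.DataFrame({"node": [{"entity_name": f"e{i}"} for i in range(300)],
                     "edge": [[]] * 300})
print([len(out) for out in map_batch(rows, lambda b: b)])  # [128, 128, 44]
```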
14 changes: 12 additions & 2 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -99,7 +99,7 @@ async def merge_nodes(
self,
node_data: tuple[str, List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
) -> dict:
entity_name, node_data = node_data
entity_types = []
source_ids = []
@@ -131,16 +131,18 @@ async def merge_nodes(

node_data = {
"entity_type": entity_type,
"entity_name": entity_name,
"description": description,
"source_id": source_id,
}
kg_instance.upsert_node(entity_name, node_data=node_data)
return node_data

async def merge_edges(
self,
edges_data: tuple[Tuple[str, str], List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
) -> dict:
(src_id, tgt_id), edge_data = edges_data

source_ids = []
@@ -175,11 +177,19 @@ async def merge_edges(
f"({src_id}, {tgt_id})", description
)

edge_data = {
"src_id": src_id,
"tgt_id": tgt_id,
"description": description,
"source_id": source_id, # for traceability
}

kg_instance.upsert_edge(
src_id,
tgt_id,
edge_data={"source_id": source_id, "description": description},
)
return edge_data

async def _handle_kg_summary(
self,
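With `merge_nodes` and `merge_edges` returning the merged records instead of `None`, callers can forward nodes and edges downstream (see `build_kg_service.py` below). A self-contained sketch of the merge-and-return pattern; the field names mirror the diff, while the most-frequent-type rule and the join separators are assumptions:

```python
from collections import Counter

def merge_node(entity_name: str, fragments: list[dict]) -> dict:
    # Pick the most common type across extracted fragments (assumed tie-break),
    # then union descriptions and source ids into single strings.
    entity_type = Counter(f["entity_type"] for f in fragments).most_common(1)[0][0]
    merged = {
        "entity_type": entity_type,
        "entity_name": entity_name,
        "description": " | ".join(sorted({f["description"] for f in fragments})),
        "source_id": ",".join(sorted({f["source_id"] for f in fragments})),
    }
    # kg_instance.upsert_node(entity_name, node_data=merged) would run here;
    # returning `merged` is what lets the caller collect node records.
    return merged

print(merge_node("GraphGen", [
    {"entity_type": "software", "description": "builds KGs", "source_id": "chunk-1"},
    {"entity_type": "software", "description": "quizzes descriptions", "source_id": "chunk-2"},
]))
```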
8 changes: 5 additions & 3 deletions graphgen/models/partitioner/ece_partitioner.py
@@ -1,3 +1,4 @@
import math
import random
from collections import deque
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -34,17 +35,18 @@ def _sort_units(units: list, edge_sampling: str) -> list:
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
:return: sorted units
"""
default_loss = -math.log(0.1)
if edge_sampling == "random":
random.shuffle(units)
elif edge_sampling == "min_loss":
units = sorted(
units,
key=lambda x: x[-1]["loss"],
key=lambda x: x[-1].get("loss", default_loss),
)
elif edge_sampling == "max_loss":
units = sorted(
units,
key=lambda x: x[-1]["loss"],
key=lambda x: x[-1].get("loss", default_loss),
reverse=True,
)
else:
@@ -142,7 +144,7 @@ def _add_unit(u):
return Community(
id=seed_unit[1],
nodes=list(community_nodes.keys()),
edges=[tuple(sorted(e)) for e in community_edges]
edges=[tuple(sorted(e)) for e in community_edges],
)

for unit in tqdm(all_units, desc="ECE partition"):
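Previously `_sort_units` indexed `x[-1]["loss"]` directly and raised `KeyError` for units that never received a quiz loss; the new `default_loss = -math.log(0.1)` fallback ranks them as if their comprehension probability were 0.1. A quick illustration with made-up units:

```python
import math

units = [
    (("a", "b"), {"loss": 0.5}),   # confidently understood edge
    (("b", "c"), {}),              # no quiz loss recorded for this unit
    (("c", "d"), {"loss": 3.1}),   # poorly understood edge
]

default_loss = -math.log(0.1)  # ~2.303, stands in for a missing "loss"

by_min_loss = sorted(units, key=lambda x: x[-1].get("loss", default_loss))
print([u[0] for u in by_min_loss])  # [('a', 'b'), ('b', 'c'), ('c', 'd')]
```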
23 changes: 18 additions & 5 deletions graphgen/operators/build_kg/build_kg_service.py
@@ -28,10 +28,13 @@ def process(self, batch: pd.DataFrame) -> pd.DataFrame:
docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs]

# consume the chunks and build kg
self.build_kg(docs)
return pd.DataFrame([{"status": "kg_building_completed"}])
nodes, edges = self.build_kg(docs)
return pd.DataFrame(
[{"node": node, "edge": []} for node in nodes]
+ [{"node": [], "edge": edge} for edge in edges]
)

def build_kg(self, chunks: List[Chunk]) -> None:
def build_kg(self, chunks: List[Chunk]) -> tuple:
"""
Build knowledge graph (KG) and merge into kg_instance
"""
@@ -42,24 +45,34 @@ def build_kg(self, chunks: List[Chunk]) -> None:
if chunk.type in ("image", "video", "table", "formula")
]

nodes = []
edges = []

if len(text_chunks) == 0:
logger.info("All text chunks are already in the storage")
else:
logger.info("[Text Entity and Relation Extraction] processing ...")
build_text_kg(
text_nodes, text_edges = build_text_kg(
llm_client=self.llm_client,
kg_instance=self.graph_storage,
chunks=text_chunks,
max_loop=self.max_loop,
)
nodes += text_nodes
edges += text_edges
if len(mm_chunks) == 0:
logger.info("All multi-modal chunks are already in the storage")
else:
logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
build_mm_kg(
mm_nodes, mm_edges = build_mm_kg(
llm_client=self.llm_client,
kg_instance=self.graph_storage,
chunks=mm_chunks,
)
nodes += mm_nodes
edges += mm_edges

self.graph_storage.index_done_callback()
logger.info("Knowledge graph building completed.")

return nodes, edges
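`process` now emits one row per node or edge instead of a bare status marker, so a downstream operator can consume KG records straight from the frame. A sketch of the resulting shape with invented sample data:

```python
import pandas as pd

nodes = [{"entity_name": "GraphGen", "entity_type": "software",
          "description": "builds knowledge graphs", "source_id": "chunk-1"}]
edges = [{"src_id": "GraphGen", "tgt_id": "quiz", "description": "generates",
          "source_id": "chunk-1"}]

# One record per row; the unused column carries an empty placeholder list.
df = pd.DataFrame(
    [{"node": n, "edge": []} for n in nodes]
    + [{"node": [], "edge": e} for e in edges]
)
print(df.to_string())
```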
8 changes: 5 additions & 3 deletions graphgen/operators/build_kg/build_mm_kg.py
@@ -12,7 +12,7 @@ def build_mm_kg(
llm_client: BaseLLMWrapper,
kg_instance: BaseGraphStorage,
chunks: List[Chunk],
):
) -> tuple:
"""
Build multi-modal KG and merge into kg_instance
:param llm_client: Synthesizer LLM model to extract entities and relationships
@@ -37,14 +37,16 @@
for k, v in e.items():
edges[tuple(sorted(k))].extend(v)

run_concurrent(
nodes = run_concurrent(
lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance),
list(nodes.items()),
desc="Inserting entities into storage",
)

run_concurrent(
edges = run_concurrent(
lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance),
list(edges.items()),
desc="Inserting relationships into storage",
)

return nodes, edges
8 changes: 5 additions & 3 deletions graphgen/operators/build_kg/build_text_kg.py
@@ -13,7 +13,7 @@ def build_text_kg(
kg_instance: BaseGraphStorage,
chunks: List[Chunk],
max_loop: int = 3,
):
) -> tuple:
"""
:param llm_client: Synthesizer LLM model to extract entities and relationships
:param kg_instance
@@ -39,14 +39,16 @@
for k, v in e.items():
edges[tuple(sorted(k))].extend(v)

run_concurrent(
nodes = run_concurrent(
lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance),
list(nodes.items()),
desc="Inserting entities into storage",
)

run_concurrent(
edges = run_concurrent(
lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance),
list(edges.items()),
desc="Inserting relationships into storage",
)

return nodes, edges
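Both `build_text_kg` and `build_mm_kg` now keep the value `run_concurrent` returns instead of discarding it. A minimal stand-in for `run_concurrent` (its real graphgen signature is assumed here) showing why per-item results come back as a list in input order:

```python
import asyncio
from typing import Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def run_concurrent(fn: Callable[[T], Awaitable[R]], items: Iterable[T],
                   desc: str = "", unit: str = "") -> list:
    # Run one coroutine per item; gather() preserves input order.
    async def _gather() -> list:
        return list(await asyncio.gather(*(fn(x) for x in items)))
    return asyncio.run(_gather())

async def shout(s: str) -> str:
    return s.upper()

print(run_concurrent(shout, ["node", "edge"]))  # ['NODE', 'EDGE']
```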
12 changes: 6 additions & 6 deletions graphgen/operators/evaluate/evaluate_service.py
@@ -95,10 +95,10 @@ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
answer=str(item.get("answer", "")),
)
if not qa_pair.question or not qa_pair.answer:
self.logger.error("Empty question or answer, skipping.")
logger.error("Empty question or answer, skipping.")
return {}
except Exception as e:
self.logger.error("Error in QAPair creation: %s", str(e))
logger.error("Error in QAPair creation: %s", str(e))
return {}

for metric, evaluator in self.qa_evaluators.items():
@@ -110,7 +110,7 @@ async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]:
else:
item[metric] = float(score)
except Exception as e:
self.logger.error("Error in %s evaluation: %s", metric, str(e))
logger.error("Error in %s evaluation: %s", metric, str(e))
item[metric] = None
return item

@@ -136,7 +136,7 @@ def transform_messages_format(items: list[dict]) -> list[dict]:
return []

if not self.qa_evaluators:
self.logger.warning("No QA evaluators initialized, skipping QA evaluation")
logger.warning("No QA evaluators initialized, skipping QA evaluation")
return []

items = transform_messages_format(items)
@@ -155,11 +155,11 @@ def _evaluate_kg(self) -> Dict[str, Any]:

for metric, evaluator in self.kg_evaluators.items():
try:
self.logger.info("Running %s evaluation...", metric)
logger.info("Running %s evaluation...", metric)
score = evaluator.evaluate()
results[metric] = score
except Exception as e:
self.logger.error("Error in %s evaluation: %s", metric, str(e))
logger.error("Error in %s evaluation: %s", metric, str(e))
results[metric] = {"error": str(e)}
return results

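The `self.logger` → `logger` swap suggests the service never defined a `logger` attribute, so these error and warning paths would likely have raised `AttributeError`; the fix falls back to the usual module-level logger pattern:

```python
import logging

logger = logging.getLogger(__name__)  # one module-level logger for the file

class EvaluateService:
    def _evaluate_kg(self) -> dict:
        logger.info("Running %s evaluation...", "accuracy")  # no self.logger needed
        return {}
```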
7 changes: 6 additions & 1 deletion graphgen/operators/partition/partition_service.py
@@ -79,9 +79,13 @@ def partition(self) -> Iterable[pd.DataFrame]:
else:
raise ValueError(f"Unsupported partition method: {method}")

communities = partitioner.partition(g=self.kg_instance, **method_params)
communities: Iterable = partitioner.partition(
g=self.kg_instance, **method_params
)

count = 0
for community in communities:
count += 1
batch = partitioner.community2batch(community, g=self.kg_instance)
batch = self._attach_additional_data_to_node(batch)

@@ -91,6 +95,7 @@ def partition(self) -> Iterable[pd.DataFrame]:
"edges": [batch[1]],
}
)
logger.info("Total communities partitioned: %d", count)

def _pre_tokenize(self) -> None:
"""Pre-tokenize all nodes and edges to add token length information."""
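Because `partitioner.partition` yields communities lazily, the service cannot take `len()` up front; the diff counts communities as they stream through and logs the total once iteration ends. The same pattern in isolation:

```python
from typing import Iterable, Iterator

def stream_with_count(communities: Iterable[dict]) -> Iterator[dict]:
    # Stay lazy, but still report how many items were yielded at the end.
    count = 0
    for community in communities:
        count += 1
        yield community
    print(f"Total communities partitioned: {count}")

for _ in stream_with_count({"id": i} for i in range(3)):
    pass  # downstream batching would happen here
```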
78 changes: 33 additions & 45 deletions graphgen/operators/quiz/quiz_service.py
@@ -1,5 +1,3 @@
from collections.abc import Iterable

import pandas as pd

from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator
@@ -15,7 +13,6 @@ def __init__(
graph_backend: str = "kuzu",
kv_backend: str = "rocksdb",
quiz_samples: int = 1,
concurrency_limit: int = 200,
):
super().__init__(working_dir=working_dir, op_name="quiz_service")
self.quiz_samples = quiz_samples
@@ -28,21 +25,16 @@ def __init__(
backend=kv_backend, working_dir=working_dir, namespace="quiz"
)
self.generator = QuizGenerator(self.llm_client)
self.concurrency_limit = concurrency_limit

def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
# this operator does not consume any batch data
# but for compatibility we keep the interface
_ = batch.to_dict(orient="records")
def process(self, batch: pd.DataFrame) -> pd.DataFrame:
data = batch.to_dict(orient="records")
self.graph_storage.reload()
yield from self.quiz()
return self.quiz(data)

async def _process_single_quiz(self, item: tuple) -> dict | None:
# if quiz in quiz_storage exists already, directly get it
index, desc = item
_quiz_id = compute_dict_hash({"index": index, "description": desc})
if self.quiz_storage.get_by_id(_quiz_id):
return None

Contributor review comment (medium):
The check for an existing quiz in quiz_storage has been removed. This will cause quizzes to be regenerated every time _process_single_quiz is called for the same item, even if a quiz already exists. This could lead to redundant processing and increased costs. Consider adding the check back:

if self.quiz_storage.get_by_id(_quiz_id):
    return None

tasks = []
for i in range(self.quiz_samples):
@@ -68,47 +60,43 @@ async def _process_single_quiz(self, item: tuple) -> dict | None:
logger.error("Error when quizzing description %s: %s", item, e)
return None

def quiz(self) -> Iterable[pd.DataFrame]:
def quiz(self, batch) -> pd.DataFrame:
"""
Get all nodes and edges and quiz their descriptions using QuizGenerator.
"""
edges = self.graph_storage.get_all_edges()
nodes = self.graph_storage.get_all_nodes()

items = []

for edge in edges:
edge_data = edge[2]
desc = edge_data["description"]
items.append(((edge[0], edge[1]), desc))
for item in batch:
node_data = item.get("node", [])
edge_data = item.get("edge", [])

for node in nodes:
node_data = node[1]
desc = node_data["description"]
items.append((node[0], desc))
if node_data:
node_id = node_data["entity_name"]
desc = node_data["description"]
items.append((node_id, desc))
if edge_data:
edge_key = (edge_data["src_id"], edge_data["tgt_id"])
desc = edge_data["description"]
items.append((edge_key, desc))

logger.info("Total descriptions to quiz: %d", len(items))

for i in range(0, len(items), self.concurrency_limit):
batch_items = items[i : i + self.concurrency_limit]
batch_results = run_concurrent(
self._process_single_quiz,
batch_items,
desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})",
unit="description",
)
results = run_concurrent(
self._process_single_quiz,
items,
desc=f"Quizzing batch of {len(items)} descriptions",
unit="description",
)
valid_results = [res for res in results if res]

final_results = []
for new_result in batch_results:
if new_result:
self.quiz_storage.upsert(
{
new_result["_quiz_id"]: {
"description": new_result["description"],
"quizzes": new_result["quizzes"],
}
}
)
final_results.append(new_result)
self.quiz_storage.index_done_callback()
yield pd.DataFrame(final_results)
for res in valid_results:
self.quiz_storage.upsert(
{
res["_quiz_id"]: {
"description": res["description"],
"quizzes": res["quizzes"],
}
}
)
self.quiz_storage.index_done_callback()
return pd.DataFrame(valid_results)
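The quiz service now consumes the node/edge rows emitted upstream rather than re-reading all of graph storage, and replaces the hand-rolled `concurrency_limit` batching with a single `run_concurrent` pass. A sketch of the new item-extraction path with invented rows:

```python
batch = [
    {"node": {"entity_name": "GraphGen", "description": "builds KGs"}, "edge": []},
    {"node": [], "edge": {"src_id": "GraphGen", "tgt_id": "quiz",
                          "description": "generates quiz samples"}},
]

items = []
for row in batch:
    if row.get("node"):  # node rows carry entity_name/description
        items.append((row["node"]["entity_name"], row["node"]["description"]))
    if row.get("edge"):  # edge rows carry src_id/tgt_id/description
        e = row["edge"]
        items.append(((e["src_id"], e["tgt_id"]), e["description"]))

print(items)
```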