Skip to content

Commit b3bc16e

Browse files
committed
GraphRAG Local Search
1 parent 0820598 commit b3bc16e

16 files changed

Lines changed: 436 additions & 72 deletions

File tree

shinkai-libs/shinkai-graphrag/src/context_builder/community_context.rs

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,9 @@ use polars::{
1111
};
1212
use rand::prelude::SliceRandom;
1313

14-
use super::{context_builder::ContextBuilderParams, indexer_entities::Entity, indexer_reports::CommunityReport};
14+
use crate::models::{CommunityReport, Entity};
15+
16+
use super::context_builder::GlobalSearchContextBuilderParams;
1517

1618
pub struct GlobalCommunityContext {
1719
community_reports: Vec<CommunityReport>,
@@ -34,9 +36,9 @@ impl GlobalCommunityContext {
3436

3537
pub async fn build_context(
3638
&self,
37-
context_builder_params: ContextBuilderParams,
39+
context_builder_params: GlobalSearchContextBuilderParams,
3840
) -> anyhow::Result<(Vec<String>, HashMap<String, DataFrame>)> {
39-
let ContextBuilderParams {
41+
let GlobalSearchContextBuilderParams {
4042
use_community_summary,
4143
column_delimiter,
4244
shuffle_data,

shinkai-libs/shinkai-graphrag/src/context_builder/context_builder.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
#[derive(Debug, Clone)]
2-
pub struct ContextBuilderParams {
3-
//conversation_history: Option<ConversationHistory>,
2+
pub struct GlobalSearchContextBuilderParams {
43
pub use_community_summary: bool,
54
pub column_delimiter: String,
65
pub shuffle_data: bool,
@@ -12,6 +11,7 @@ pub struct ContextBuilderParams {
1211
pub normalize_community_weight: bool,
1312
pub max_tokens: usize,
1413
pub context_name: String,
14+
//conversation_history: Option<ConversationHistory>,
1515
// conversation_history_user_turns_only: bool,
1616
// conversation_history_max_turns: Option<i32>,
1717
}

shinkai-libs/shinkai-graphrag/src/context_builder/indexer_entities.rs

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
use std::collections::{HashMap, HashSet};
1+
use std::collections::HashSet;
22

33
use polars::prelude::*;
44
use polars_lazy::dsl::col;
5-
use serde::{Deserialize, Serialize};
5+
6+
use crate::models::Entity;
67

78
use super::indexer_reports::filter_under_community_level;
89

@@ -52,23 +53,6 @@ pub fn read_indexer_entities(
5253
Ok(entities)
5354
}
5455

55-
#[derive(Debug, Clone, Deserialize, Serialize)]
56-
pub struct Entity {
57-
pub id: String,
58-
pub short_id: Option<String>,
59-
pub title: String,
60-
pub entity_type: Option<String>,
61-
pub description: Option<String>,
62-
pub description_embedding: Option<Vec<f64>>,
63-
pub name_embedding: Option<Vec<f64>>,
64-
pub graph_embedding: Option<Vec<f64>>,
65-
pub community_ids: Option<Vec<String>>,
66-
pub text_unit_ids: Option<Vec<String>>,
67-
pub document_ids: Option<Vec<String>>,
68-
pub rank: Option<i32>,
69-
pub attributes: Option<HashMap<String, String>>,
70-
}
71-
7256
pub fn read_entities(
7357
df: DataFrame,
7458
id_col: &str,

shinkai-libs/shinkai-graphrag/src/context_builder/indexer_reports.rs

Lines changed: 3 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
use std::collections::{HashMap, HashSet};
1+
use std::collections::HashSet;
22

33
use polars::prelude::*;
44
use polars_lazy::dsl::col;
5-
use serde::{Deserialize, Serialize};
5+
6+
use crate::models::CommunityReport;
67

78
use super::indexer_entities::get_field;
89

@@ -59,20 +60,6 @@ pub fn filter_under_community_level(df: &DataFrame, community_level: u32) -> any
5960
Ok(result)
6061
}
6162

62-
#[derive(Debug, Clone, Deserialize, Serialize)]
63-
pub struct CommunityReport {
64-
pub id: String,
65-
pub short_id: Option<String>,
66-
pub title: String,
67-
pub community_id: String,
68-
pub summary: String,
69-
pub full_content: String,
70-
pub rank: Option<f64>,
71-
pub summary_embedding: Option<Vec<f64>>,
72-
pub full_content_embedding: Option<Vec<f64>>,
73-
pub attributes: Option<HashMap<String, String>>,
74-
}
75-
7663
pub fn read_community_reports(
7764
df: DataFrame,
7865
id_col: &str,
Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
11
pub mod context_builder;
22
pub mod llm;
3+
pub mod models;
34
pub mod search;
5+
pub mod vector_stores;
Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,59 @@
1+
use std::collections::HashMap;
2+
3+
/// A community report record.
///
/// Plain data holder: identifiers, a title, a summary and the full report text, plus
/// optional rank, embeddings and free-form string attributes. Every `Option` field may
/// be absent.
///
/// `PartialEq` is derived so records can be compared directly (e.g. in tests) and
/// `Default` so callers can use struct-update syntax (`..Default::default()`).
/// `Eq` is intentionally not derived because the struct contains `f64` values.
///
/// NOTE(review): the field set appears to mirror the GraphRAG community-report data
/// model — confirm against the indexer output schema.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CommunityReport {
    /// Identifier of the report.
    pub id: String,
    /// Optional secondary identifier (presumably a short, human-readable one — confirm).
    pub short_id: Option<String>,
    /// Title of the report.
    pub title: String,
    /// Identifier of the community this report refers to.
    pub community_id: String,
    /// Summary text of the report.
    pub summary: String,
    /// Full text of the report.
    pub full_content: String,
    /// Optional rank value (presumably used for ordering reports — confirm).
    pub rank: Option<f64>,
    /// Optional embedding vector associated with `summary`.
    pub summary_embedding: Option<Vec<f64>>,
    /// Optional embedding vector associated with `full_content`.
    pub full_content_embedding: Option<Vec<f64>>,
    /// Optional additional string key/value attributes.
    pub attributes: Option<HashMap<String, String>>,
}
16+
17+
/// An entity record.
///
/// Plain data holder: an identifier and title, plus optional type, description,
/// embeddings, related ID lists, rank and free-form string attributes. Every `Option`
/// field may be absent.
///
/// `PartialEq` is derived so records can be compared directly (e.g. in tests) and
/// `Default` so callers can use struct-update syntax (`..Default::default()`).
/// `Eq` is intentionally not derived because the struct contains `f64` values.
///
/// NOTE(review): the field set appears to mirror the GraphRAG entity data model —
/// confirm against the indexer output schema.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Entity {
    /// Identifier of the entity.
    pub id: String,
    /// Optional secondary identifier (presumably a short, human-readable one — confirm).
    pub short_id: Option<String>,
    /// Title of the entity.
    pub title: String,
    /// Optional type label of the entity.
    pub entity_type: Option<String>,
    /// Optional textual description.
    pub description: Option<String>,
    /// Optional embedding vector associated with `description`.
    pub description_embedding: Option<Vec<f64>>,
    /// Optional embedding vector associated with the entity name.
    pub name_embedding: Option<Vec<f64>>,
    /// Optional graph embedding vector.
    pub graph_embedding: Option<Vec<f64>>,
    /// Optional list of community IDs associated with the entity.
    pub community_ids: Option<Vec<String>>,
    /// Optional list of text unit IDs associated with the entity.
    pub text_unit_ids: Option<Vec<String>>,
    /// Optional list of document IDs associated with the entity.
    pub document_ids: Option<Vec<String>>,
    /// Optional rank value (presumably used for ordering entities — confirm).
    pub rank: Option<i32>,
    /// Optional additional string key/value attributes.
    pub attributes: Option<HashMap<String, String>>,
}
33+
34+
/// A relationship record connecting a `source` to a `target`.
///
/// Plain data holder: an identifier and the two endpoints, plus optional weight,
/// description, embedding, related ID lists and free-form string attributes. Every
/// `Option` field may be absent.
///
/// `PartialEq` is derived so records can be compared directly (e.g. in tests) and
/// `Default` so callers can use struct-update syntax (`..Default::default()`).
/// `Eq` is intentionally not derived because the struct contains `f64` values.
///
/// NOTE(review): the field set appears to mirror the GraphRAG relationship data model —
/// confirm against the indexer output schema.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Relationship {
    /// Identifier of the relationship.
    pub id: String,
    /// Optional secondary identifier (presumably a short, human-readable one — confirm).
    pub short_id: Option<String>,
    /// Source endpoint (presumably an entity name or ID — confirm).
    pub source: String,
    /// Target endpoint (presumably an entity name or ID — confirm).
    pub target: String,
    /// Optional edge weight.
    pub weight: Option<f64>,
    /// Optional textual description.
    pub description: Option<String>,
    /// Optional embedding vector associated with `description`.
    pub description_embedding: Option<Vec<f64>>,
    /// Optional list of text unit IDs associated with the relationship.
    pub text_unit_ids: Option<Vec<String>>,
    /// Optional list of document IDs associated with the relationship.
    pub document_ids: Option<Vec<String>>,
    /// Optional additional string key/value attributes.
    pub attributes: Option<HashMap<String, String>>,
}
47+
48+
/// A text unit record.
///
/// Plain data holder: an identifier and the text itself, plus optional embedding,
/// related ID lists, token count and free-form string attributes. Every `Option` field
/// may be absent.
///
/// `PartialEq` is derived so records can be compared directly (e.g. in tests) and
/// `Default` so callers can use struct-update syntax (`..Default::default()`).
/// `Eq` is intentionally not derived because the struct contains `f64` values.
///
/// NOTE(review): the field set appears to mirror the GraphRAG text-unit data model —
/// confirm against the indexer output schema.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TextUnit {
    /// Identifier of the text unit.
    pub id: String,
    /// Optional secondary identifier (presumably a short, human-readable one — confirm).
    pub short_id: Option<String>,
    /// Text content of the unit.
    pub text: String,
    /// Optional embedding vector associated with `text`.
    pub text_embedding: Option<Vec<f64>>,
    /// Optional list of entity IDs associated with the unit.
    pub entity_ids: Option<Vec<String>>,
    /// Optional list of relationship IDs associated with the unit.
    pub relationship_ids: Option<Vec<String>>,
    /// Optional token count (presumably of `text` — confirm).
    pub n_tokens: Option<i32>,
    /// Optional list of document IDs associated with the unit.
    pub document_ids: Option<Vec<String>>,
    /// Optional additional string key/value attributes.
    pub attributes: Option<HashMap<String, String>>,
}
Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,29 @@
1+
use std::collections::HashMap;
2+
3+
use polars::frame::DataFrame;
4+
5+
#[derive(Debug, Clone)]
6+
pub enum ResponseType {
7+
String(String),
8+
KeyPoints(Vec<KeyPoint>),
9+
}
10+
11+
#[derive(Debug, Clone)]
12+
pub enum ContextData {
13+
String(String),
14+
DataFrames(Vec<DataFrame>),
15+
Dictionary(HashMap<String, DataFrame>),
16+
}
17+
18+
/// Textual context attached to a search result.
///
/// Carried as a single string, as a list of strings, or as strings keyed by a string
/// name.
///
/// `PartialEq`/`Eq` are derived (all payloads are standard-library types) so values can
/// be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextText {
    /// Context given as a single string.
    String(String),
    /// Context given as a list of strings.
    Strings(Vec<String>),
    /// Context given as strings keyed by name.
    Dictionary(HashMap<String, String>),
}
24+
25+
/// A single key point: an answer text paired with an integer score.
///
/// `PartialEq`/`Eq` are derived so key points can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyPoint {
    /// Text of the key point.
    pub answer: String,
    /// Score assigned to the key point (presumably an importance rating — confirm).
    pub score: i32,
}

shinkai-libs/shinkai-graphrag/src/search/global_search/global_search.rs

Lines changed: 4 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
use futures::future::join_all;
2-
use polars::frame::DataFrame;
32
use serde_json::Value;
43
use std::collections::HashMap;
54
use std::time::Instant;
65

76
use crate::context_builder::community_context::GlobalCommunityContext;
8-
use crate::context_builder::context_builder::{ContextBuilderParams, ConversationHistory};
7+
use crate::context_builder::context_builder::{ConversationHistory, GlobalSearchContextBuilderParams};
98
use crate::llm::llm::{BaseLLM, BaseLLMCallback, LLMParams, MessageType};
9+
use crate::search::base::{ContextData, ContextText, KeyPoint, ResponseType};
1010
use crate::search::global_search::prompts::NO_DATA_ANSWER;
1111

1212
use super::prompts::{GENERAL_KNOWLEDGE_INSTRUCTION, MAP_SYSTEM_PROMPT, REDUCE_SYSTEM_PROMPT};
@@ -21,32 +21,6 @@ pub struct SearchResult {
2121
pub prompt_tokens: usize,
2222
}
2323

24-
#[derive(Debug, Clone)]
25-
pub enum ResponseType {
26-
String(String),
27-
KeyPoints(Vec<KeyPoint>),
28-
}
29-
30-
#[derive(Debug, Clone)]
31-
pub enum ContextData {
32-
String(String),
33-
DataFrames(Vec<DataFrame>),
34-
Dictionary(HashMap<String, DataFrame>),
35-
}
36-
37-
#[derive(Debug, Clone)]
38-
pub enum ContextText {
39-
String(String),
40-
Strings(Vec<String>),
41-
Dictionary(HashMap<String, String>),
42-
}
43-
44-
#[derive(Debug, Clone)]
45-
pub struct KeyPoint {
46-
pub answer: String,
47-
pub score: i32,
48-
}
49-
5024
pub struct GlobalSearchResult {
5125
pub response: ResponseType,
5226
pub context_data: ContextData,
@@ -88,7 +62,7 @@ pub struct GlobalSearch {
8862
llm: Box<dyn BaseLLM>,
8963
context_builder: GlobalCommunityContext,
9064
num_tokens_fn: fn(&str) -> usize,
91-
context_builder_params: ContextBuilderParams,
65+
context_builder_params: GlobalSearchContextBuilderParams,
9266
map_system_prompt: String,
9367
reduce_system_prompt: String,
9468
response_type: String,
@@ -114,7 +88,7 @@ pub struct GlobalSearchParams {
11488
pub max_data_tokens: usize,
11589
pub map_llm_params: LLMParams,
11690
pub reduce_llm_params: LLMParams,
117-
pub context_builder_params: ContextBuilderParams,
91+
pub context_builder_params: GlobalSearchContextBuilderParams,
11892
}
11993

12094
impl GlobalSearch {
Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,96 @@
1+
use std::{collections::HashMap, time::Instant};
2+
3+
use crate::{
4+
llm::llm::{BaseLLM, LLMParams, MessageType},
5+
search::base::{ContextData, ContextText, ResponseType},
6+
};
7+
8+
use super::{
9+
mixed_context::{LocalSearchContextBuilderParams, LocalSearchMixedContext},
10+
prompts::LOCAL_SEARCH_SYSTEM_PROMPT,
11+
};
12+
13+
pub struct LocalSearchResult {
14+
pub response: ResponseType,
15+
pub context_data: ContextData,
16+
pub context_text: ContextText,
17+
pub completion_time: f64,
18+
pub llm_calls: usize,
19+
pub prompt_tokens: usize,
20+
}
21+
22+
pub struct LocalSearch {
23+
llm: Box<dyn BaseLLM>,
24+
context_builder: LocalSearchMixedContext,
25+
num_tokens_fn: fn(&str) -> usize,
26+
system_prompt: String,
27+
response_type: String,
28+
llm_params: LLMParams,
29+
context_builder_params: LocalSearchContextBuilderParams,
30+
}
31+
32+
impl LocalSearch {
33+
pub fn new(
34+
llm: Box<dyn BaseLLM>,
35+
context_builder: LocalSearchMixedContext,
36+
num_tokens_fn: fn(&str) -> usize,
37+
llm_params: LLMParams,
38+
context_builder_params: LocalSearchContextBuilderParams,
39+
response_type: String,
40+
system_prompt: Option<String>,
41+
) -> Self {
42+
let system_prompt = system_prompt.unwrap_or(LOCAL_SEARCH_SYSTEM_PROMPT.to_string());
43+
44+
LocalSearch {
45+
llm,
46+
context_builder,
47+
num_tokens_fn,
48+
system_prompt,
49+
response_type,
50+
llm_params,
51+
context_builder_params,
52+
}
53+
}
54+
55+
pub async fn asearch(&self, query: String) -> anyhow::Result<LocalSearchResult> {
56+
let start_time = Instant::now();
57+
let (context_text, context_records) = self
58+
.context_builder
59+
.build_context(self.context_builder_params.clone())
60+
.await?;
61+
62+
let search_prompt = self
63+
.system_prompt
64+
.replace("{context_data}", &context_text)
65+
.replace("{response_type}", &self.response_type);
66+
67+
let mut search_messages = Vec::new();
68+
search_messages.push(HashMap::from([
69+
("role".to_string(), "system".to_string()),
70+
("content".to_string(), search_prompt.clone()),
71+
]));
72+
search_messages.push(HashMap::from([
73+
("role".to_string(), "user".to_string()),
74+
("content".to_string(), query.to_string()),
75+
]));
76+
77+
let search_response = self
78+
.llm
79+
.agenerate(
80+
MessageType::Dictionary(search_messages),
81+
false,
82+
None,
83+
self.llm_params.clone(),
84+
)
85+
.await?;
86+
87+
Ok(LocalSearchResult {
88+
response: ResponseType::String(search_response),
89+
context_data: ContextData::Dictionary(context_records),
90+
context_text: ContextText::String(context_text),
91+
completion_time: start_time.elapsed().as_secs_f64(),
92+
llm_calls: 1,
93+
prompt_tokens: (self.num_tokens_fn)(&search_prompt),
94+
})
95+
}
96+
}

0 commit comments

Comments
 (0)