Skip to content

Commit 376e978

Browse files
committed
fix: is_final is the new ending of the navigator classification
1 parent 7185c8f commit 376e978

4 files changed

Lines changed: 127 additions & 80 deletions

File tree

explorations.py

Lines changed: 118 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919

2020
query = """
2121
MATCH path = (root)-[*]->(n)
22-
WHERE n.LEVEL = 5
22+
WHERE n.FINAL = 1
2323
AND n.embedding IS NOT NULL
2424
AND root.LEVEL = 0
2525
RETURN n.embedding as embedding,
2626
n.NAME as name,
27+
n.CODE as code,
2728
[node IN nodes(path) | node.CODE] as path_codes,
2829
[node IN nodes(path) | node.LEVEL] as path_levels
2930
"""
@@ -33,10 +34,14 @@
3334
embeddings = []
3435
names = []
3536
paths = []
37+
codes_dict = {}
3638

37-
for record in results:
39+
for idx, record in enumerate(results):
3840
embeddings.append(record["embedding"])
3941
names.append(record["name"])
42+
code = record["code"]
43+
code_clean = code.replace(".", "").replace(" ", "")
44+
codes_dict[code_clean] = idx
4045

4146
path_str = " → ".join([
4247
name for lvl, name in zip(record["path_levels"], record["path_codes"])
@@ -45,76 +50,29 @@
4550

4651

4752
print(f"Nœuds récupérés: {len(names)}")
48-
print(embeddings)
4953

54+
n_nace_nodes = len(embeddings)
5055

51-
# %%
52-
# Add a query in the embedding space
53-
54-
queries = ["Je vends des croissants", "Livreur de taxi", "Coiffeur"]
55-
emb_model = OpenAIEmbeddings(
56-
model=os.environ['EMBEDDING_MODEL'],
57-
openai_api_base=os.environ['URL_EMBEDDING_API'],
58-
openai_api_key="EMPTY",
59-
tiktoken_enabled=False,
60-
)
61-
for i, query in iter(queries):
62-
query_emb = emb_model.embed_query(query)
63-
embeddings.append(query_emb)
64-
names.append(f"Query {i}")
65-
paths.append(query)
66-
67-
68-
# %% UMAP
69-
reducer = umap.UMAP(random_state=42, n_neighbors=10, min_dist=0.1)
70-
embeddings = np.array(embeddings)
71-
coords = reducer.fit_transform(embeddings)
72-
X, Y = coords.T
73-
74-
# %% Visualisation interactive
75-
fig = go.Figure()
76-
77-
fig.add_trace(go.Scatter(
78-
x=X, y=Y,
79-
mode='markers',
80-
marker=dict(
81-
size=10,
82-
color=np.arange(len(X)), # Couleur par index
83-
colorscale='Viridis',
84-
showscale=True,
85-
line=dict(width=0.5, color='white')
86-
),
87-
text=[f"<b>{name}</b><br><br>{path}" for name, path in zip(names, paths)],
88-
hovertemplate='%{text}<extra></extra>'
89-
))
9056

91-
fig.update_layout(
92-
title="Nœuds de niveau 5",
93-
xaxis_title="UMAP 1",
94-
yaxis_title="UMAP 2",
95-
width=1200,
96-
height=800,
97-
hovermode='closest',
98-
plot_bgcolor='white',
99-
xaxis=dict(showgrid=True, gridcolor='lightgray'),
100-
yaxis=dict(showgrid=True, gridcolor='lightgray')
101-
)
10257

103-
fig.show()
10458
# %%
10559
import os
10660
import s3fs
10761
os.environ["AWS_ACCESS_KEY_ID"] = 'UN8E5UMY78E5H4AKC7HF'
10862
os.environ["AWS_SECRET_ACCESS_KEY"] = 'fSUt5up9uh4qfyHH4LIQ6J0GiQp42eFc+fKWrRS2'
10963
os.environ["AWS_SESSION_TOKEN"] = 'eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3NLZXkiOiJVTjhFNVVNWTc4RTVINEFLQzdIRiIsImFsbG93ZWQtb3JpZ2lucyI6WyIqIl0sImF1ZCI6WyJtaW5pby1kYXRhbm9kZSIsIm9ueXhpYSIsImFjY291bnQiXSwiYXV0aF90aW1lIjoxNzcwNjI4NzkzLCJhenAiOiJvbnl4aWEiLCJlbWFpbCI6InRoZW8uZmVycnlAaW5zZWUuZnIiLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiZXhwIjoxNzcxNTEyODA1LCJmYW1pbHlfbmFtZSI6IkZlcnJ5IiwiZ2l2ZW5fbmFtZSI6IlRoZW8iLCJncm91cHMiOlsiVVNFUl9PTllYSUEiLCJhcGUiLCJtb2RlbHMtaGYiLCJzc3BsYWIiXSwiaWF0IjoxNzcwOTA4MDA0LCJpc3MiOiJodHRwczovL2F1dGgubGFiLnNzcGNsb3VkLmZyL2F1dGgvcmVhbG1zL3NzcGNsb3VkIiwianRpIjoib25ydHJ0OjllMjk1ZmEzLTliNmMtNjZjYi0yMWE0LTA2NDlhNGVkMWUzYSIsImxvY2FsZSI6ImZyIiwibmFtZSI6IlRoZW8gRmVycnkiLCJwb2xpY3kiOiJzdHNvbmx5IiwicHJlZmVycmVkX3VzZXJuYW1lIjoidGhlb2YiLCJyZWFsbV9hY2Nlc3MiOnsicm9sZXMiOlsib2ZmbGluZV9hY2Nlc3MiLCJ1bWFfYXV0aG9yaXphdGlvbiIsInZpcCIsImRlZmF1bHQtcm9sZXMtc3NwY2xvdWQiXX0sInJlc291cmNlX2FjY2VzcyI6eyJhY2NvdW50Ijp7InJvbGVzIjpbIm1hbmFnZS1hY2NvdW50IiwibWFuYWdlLWFjY291bnQtbGlua3MiLCJ2aWV3LXByb2ZpbGUiXX19LCJyb2xlcyI6WyJvZmZsaW5lX2FjY2VzcyIsInVtYV9hdXRob3JpemF0aW9uIiwidmlwIiwiZGVmYXVsdC1yb2xlcy1zc3BjbG91ZCJdLCJzY29wZSI6Im9wZW5pZCBwcm9maWxlIGdyb3VwcyBlbWFpbCIsInNpZCI6ImRiYTY1NzAxLWE3OTctMDFjZi0yYWE1LTRkYjkzY2Q0ZWM4NiIsInN1YiI6IjNlYTdiY2Q0LWJkMjMtNDA2Yy1hYmE2LWFmMzM3ZjBlMTAzNiIsInR5cCI6IkJlYXJlciJ9.keTVOmqa7NmhFGb5Jp384W0EisDdxox7Sip2f1B4MPdfN5z_tDtU85beJbBqCFl6TJdybu0PHVRX_sDW5q4Fgg'
11064
os.environ["AWS_DEFAULT_REGION"] = 'us-east-1'
65+
66+
N_CODES = 20
67+
11168
fs = s3fs.S3FileSystem(
11269
client_kwargs={'endpoint_url': 'https://'+'minio.lab.sspcloud.fr'},
11370
key = os.environ["AWS_ACCESS_KEY_ID"],
11471
secret = os.environ["AWS_SECRET_ACCESS_KEY"],
11572
token = os.environ["AWS_SESSION_TOKEN"])
11673

11774

75+
11876
def sample_codes(fs: s3fs.S3FileSystem, population_path: str, code_column: str, n_codes: int):
11977
"""
12078
Sample codes using Polars from S3.
@@ -139,10 +97,113 @@ def sample_codes(fs: s3fs.S3FileSystem, population_path: str, code_column: str,
13997
fs=fs,
14098
population_path=path,
14199
code_column=columns,
142-
n_codes=10)
100+
n_codes=N_CODES)
101+
102+
labels, target_codes = zip(*codes)
103+
104+
emb_model = OpenAIEmbeddings(
105+
model=os.environ['EMBEDDING_MODEL'],
106+
openai_api_base=os.environ['URL_EMBEDDING_API'],
107+
openai_api_key="EMPTY",
108+
tiktoken_enabled=False,
109+
)
110+
111+
labels_embeddings = emb_model.embed_documents(list(labels))
143112

144-
print(codes)
145-
labels, codes = zip(*codes)
113+
label_to_code_idx = {}
114+
115+
for i, (label, label_emb, target_code) in enumerate(zip(labels, labels_embeddings, target_codes)):
116+
embeddings.append(label_emb)
117+
names.append(label[:50])
118+
paths.append(f"Libellé -> Code cible: {target_code}")
119+
120+
label_idx = n_nace_nodes + i
121+
if target_code in codes_dict:
122+
label_to_code_idx[label_idx] = codes_dict[target_code]
123+
124+
# %%
125+
print(label_to_code_idx)
126+
127+
# %% UMAP
128+
reducer = umap.UMAP(random_state=42, n_neighbors=10, min_dist=0.1)
129+
embeddings = np.array(embeddings)
130+
coords = reducer.fit_transform(embeddings)
131+
X, Y = coords.T
132+
133+
134+
135+
136+
# %% Visualisation interactive
137+
fig = go.Figure()
138+
139+
140+
# 1. Ajouter les lignes de connexion AVANT les points
141+
for label_idx, code_idx in label_to_code_idx.items():
142+
fig.add_trace(go.Scatter(
143+
x=[X[label_idx], X[code_idx]],
144+
y=[Y[label_idx], Y[code_idx]],
145+
mode='lines',
146+
line=dict(color='rgba(150, 150, 150, 0.8)', width=3, dash='solid'),
147+
showlegend=False,
148+
hoverinfo='skip'
149+
))
150+
151+
# 2. Ajouter les nœuds NACE (cercles)
152+
fig.add_trace(go.Scatter(
153+
x=X[:n_nace_nodes],
154+
y=Y[:n_nace_nodes],
155+
mode='markers',
156+
name='Codes NACE',
157+
marker=dict(
158+
size=10,
159+
color=np.arange(n_nace_nodes),
160+
colorscale='Viridis',
161+
showscale=True,
162+
line=dict(width=0.5, color='white'),
163+
symbol='circle'
164+
),
165+
text=[f"<b>{name}</b><br><br>{path}" for name, path in zip(names[:n_nace_nodes], paths[:n_nace_nodes])],
166+
hovertemplate='%{text}<extra></extra>'
167+
))
168+
169+
# 3. Ajouter les libellés (étoiles)
170+
fig.add_trace(go.Scatter(
171+
x=X[n_nace_nodes:],
172+
y=Y[n_nace_nodes:],
173+
mode='markers',
174+
name='Libellés',
175+
marker=dict(
176+
size=15,
177+
color='red',
178+
symbol='star', # ou 'diamond', 'square', 'cross', 'x', 'triangle-up'
179+
line=dict(width=1, color='darkred')
180+
),
181+
text=[f"<b>{name}</b><br><br>{path}" for name, path in zip(names[n_nace_nodes:], paths[n_nace_nodes:])],
182+
hovertemplate='%{text}<extra></extra>'
183+
))
184+
185+
fig.update_layout(
186+
title="Nœuds NACE niveau 5 et libellés échantillonnés",
187+
xaxis_title="UMAP 1",
188+
yaxis_title="UMAP 2",
189+
width=1400,
190+
height=900,
191+
hovermode='closest',
192+
plot_bgcolor='white',
193+
xaxis=dict(showgrid=True, gridcolor='lightgray'),
194+
yaxis=dict(showgrid=True, gridcolor='lightgray'),
195+
legend=dict(
196+
yanchor="top",
197+
y=0.99,
198+
xanchor="right",
199+
x=0.99
200+
)
201+
)
202+
203+
fig.show()
204+
205+
# %%
206+
fig.write_html("umap_visualization.html")
146207

147208
# %%
148209
result = await classify_navigator(labels[0])

src/agents/Text2Code/classifiers/navigator_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def get_instructions(self) -> str:
2121
return """
2222
Vous êtes un expert en classification NACE. Votre mission est de naviguer
2323
dans l'arborescence afin d'atteindre le code le plus spécifique caractérisant l'activité indiquée.
24-
Après avoir vérifié que vous êtes au niveau 4 de l'arbre, et que votre position actuelle est bien finale,
24+
Après avoir vérifié que votre position actuelle est bien finale (is_final = 1),
2525
vous renverrez votre position.
2626
Si vous n'avez pas réussi à atteindre une position finale, dites-le.
2727
Soyez méthodique et justifiez chaque choix !

src/navigator/navigator.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def get_code_information(code: str) -> Dict[str, Any]:
5858
"code": info.get("code"),
5959
"name": info.get("name"),
6060
"level": info.get("level"),
61+
"is_final": info.get("is_final"),
6162
"description": info.get("description", "")[:500], # Limiter la taille
6263
}
6364

@@ -76,10 +77,9 @@ def get_current_children() -> List[Dict[str, Any]]:
7677
"""
7778
logger.info(f"Navigator: get_current_children called at the position {navigator.current_code}")
7879
children_found = _unfreeze_list_of_dicts(navigator._cached_get_children(navigator.current_code))
79-
keys_to_keep = ["code", "name"]
80-
print(f"Keys_to_keep: {keys_to_keep}")
80+
keys_to_keep = ["code", "name", "is_final"]
8181
filtered_children_found = [
82-
{k:d[k] for k in keys_to_keep}
82+
{k: d[k] for k in keys_to_keep}
8383
for d in children_found
8484
]
8585
logger.info(f"Navigator children found: {filtered_children_found}")
@@ -371,26 +371,8 @@ def submit_classification(
371371
go_to_parent,
372372
go_to_child,
373373
get_context_summary,
374-
# submit_classification
375374
]
376375

377-
""" return [
378-
get_current_information,
379-
get_code_information,
380-
get_current_parent,
381-
get_current_children,
382-
get_current_siblings,
383-
get_current_descendants,
384-
navigate_to,
385-
go_to_parent,
386-
go_to_child,
387-
reset_to_root,
388-
get_context_summary,
389-
get_navigation_history,
390-
submit_classification
391-
] """
392-
393-
394376
class Navigator(Graph):
395377
"""
396378
Classe de navigation dans la hiérarchie NACE avec état persistant.

src/neo4j_graph/graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def _cached_get_code_information(self, code: str) -> Tuple[Tuple[str, Any], ...]
183183
RETURN node.CODE as code,
184184
node.LEVEL as level,
185185
node.NAME as name,
186+
node.FINAL as is_final,
186187
node.text as description,
187188
node.Includes as includes,
188189
node.IncludesAlso as includes_also,
@@ -211,7 +212,7 @@ def _cached_get_children(self, code: str) -> Tuple[Tuple[Tuple[str, Any], ...],
211212
MATCH (node {CODE: $code})-[:HAS_CHILD]->(child)
212213
RETURN child.CODE as code,
213214
child.LEVEL as level,
214-
child.FINAL as final,
215+
child.FINAL as is_final,
215216
child.NAME as name,
216217
child.text as description,
217218
child.Includes as includes,
@@ -233,6 +234,7 @@ def _cached_get_descendants(
233234
MATCH (node {{CODE: $code}})-[:HAS_CHILD*{levels}]->(descendant)
234235
RETURN descendant.CODE as code,
235236
descendant.LEVEL as level,
237+
descendant.FINAL as is_final,
236238
descendant.NAME as name,
237239
descendant.text as description,
238240
descendant.Includes as includes,
@@ -255,6 +257,7 @@ def _cached_get_siblings(self, code: str) -> Tuple[Tuple[Tuple[str, Any], ...],
255257
RETURN sibling.CODE as code,
256258
sibling.LEVEL as level,
257259
sibling.NAME as name,
260+
sibling.FINAL as is_final,
258261
sibling.text as description,
259262
sibling.Includes as includes,
260263
sibling.Excludes as excludes
@@ -293,6 +296,7 @@ def _cached_search_codes(self, search_term: str) -> Tuple[Tuple[Tuple[str, Any],
293296
OR toLower(node.text) CONTAINS toLower($search_term)
294297
RETURN node.CODE as code,
295298
node.LEVEL as level,
299+
node.FINAL as is_final,
296300
node.NAME as name,
297301
node.text as description
298302
ORDER BY node.LEVEL, node.CODE

0 commit comments

Comments
 (0)