Skip to content

Commit e550a59

Browse files
Merge pull request #21 from InseeFrLab/mateom
Adding NaiveCode2Text and a small script for exporting to .parquet format
2 parents 1bbae86 + 08b16e2 commit e550a59

21 files changed

Lines changed: 1358 additions & 7 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Personal usage
2+
test.ipynb
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[codz]

explorations.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def sample_codes(fs: s3fs.S3FileSystem, population_path: str, code_column: str,
116116

117117
return sampled[code_column].to_numpy()
118118

119+
label_idx = n_nace_nodes + i
120+
if target_code in codes_dict:
121+
label_to_code_idx[label_idx] = codes_dict[target_code]
119122

120123
codes = sample_codes(
121124
fs=fs,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies = [
2626
"pca>=2.10.1",
2727
"plotly>=6.5.1",
2828
"polars>=1.38.1",
29+
"pyarrow>=23.0.1",
2930
"s3fs>=2024.12.0",
3031
"transformers>=4.57.3",
3132
"umap-learn>=0.5.11",

src/agents/NaiveCode2Text/__init__.py

Whitespace-only changes.
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import polars as pl
2+
import s3fs
3+
import numpy as np
4+
5+
6+
def sample_codes(
    fs: s3fs.S3FileSystem,
    population_path: str,
    code_column: str,
    n_codes: int
) -> np.ndarray:
    """
    Sample codes with replacement using dataframes from Polars.

    The whole parquet file is materialized in memory; see
    ``sample_codes_lazy`` for a lazy alternative.

    Args:
        fs (S3FileSystem): The filesystem for importation.
        population_path (str): The path of the parquet file of the population.
        code_column (str): The name of the column for codes.
        n_codes (int): The number of codes to sample.

    Returns:
        numpy.ndarray: An array of n_codes codes sampled with replacement.
    """
    # Read the full population once, then sample from the single code column.
    with fs.open(population_path, 'rb') as parquet_file:
        population = pl.read_parquet(parquet_file)

    drawn = (
        population
        .select(code_column)
        .sample(n=n_codes, with_replacement=True)
    )

    return drawn[code_column].to_numpy()
31+
32+
33+
def sample_codes_lazy(
    fs: s3fs.S3FileSystem,
    population_path: str,
    code_column: str,
    n_codes: int
) -> np.ndarray:
    """
    Sample codes with replacement using lazyframes from Polars.

    Args:
        fs (S3FileSystem): The filesystem for importation.
        population_path (str): The path of the parquet file of the population.
        code_column (str): The name of the column for codes.
        n_codes (int): The number of codes to sample.
            Fix: was annotated ``str``; it is used (and documented) as an int,
            consistent with ``sample_codes``.

    Returns:
        numpy.ndarray: An array of n_codes codes sampled with replacement.
    """
    # Everything stays inside the `with` block: the lazy scan reads from the
    # open file handle, so it must be collected before the handle is closed.
    with fs.open(population_path, 'rb') as f:
        lf = (
            pl.scan_parquet(f)
            .with_row_index("row_id")
        )

        # Only the row count is materialized here, not the data itself.
        total_rows = lf.select(pl.len()).collect().item()

        # Draw row ids with replacement; the inner join below emits one output
        # row per matching pair, so duplicated ids yield duplicated samples.
        random_ids = (
            pl.Series("row_id", range(total_rows))
            .sample(n=n_codes, with_replacement=True)
            .to_frame()
            .lazy()
        )

        sampled = lf.join(random_ids, on="row_id", how="inner")

        df = sampled.collect()

    return df[code_column].to_numpy()
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from src.neo4j_graph.graph import Graph
2+
3+
4+
def get_code_information(
    graph: Graph,
    code: str
) -> dict:
    """
    Retrieve code specifications from a Neo4j graph.

    Args:
        graph (Graph from local library): The Neo4j graph.
        code (str): The code to specify.

    Returns:
        dict: Every accessible information of the code in the graph
            (code, level, name, description, includes, includes_also,
            excludes, implementation_rule, parent_code, children,
            children_count). An empty dict when the code is not found.
    """
    # Single query: the node itself, its (optional) parent, and the list of
    # its (optional) children collected as {code, name} maps.
    query = """
    MATCH (node {CODE: $code})
    OPTIONAL MATCH (node)<-[:HAS_CHILD]-(parent)
    OPTIONAL MATCH (node)-[:HAS_CHILD]->(child)
    WITH node, parent, collect({code: child.CODE, name: child.NAME}) as children
    RETURN node.CODE as code,
           node.LEVEL as level,
           node.NAME as name,
           node.text as description,
           node.Includes as includes,
           node.IncludesAlso as includes_also,
           node.Excludes as excludes,
           node.Implementation_rule as implementation_rule,
           parent.CODE as parent_code,
           children,
           size(children) as children_count
    """
    result = graph.graph.query(query, params={"code": code})

    if not result:
        print("No result in get_code_information")
        # Fix: the original returned an empty tuple `()` despite the `-> dict`
        # annotation. An empty dict matches the contract and is equally falsy
        # for callers that truth-test the result.
        return {}

    return result[0]
43+
44+
45+
def NAF_to_NACE(
    code: str
) -> str:
    """
    For the case of NAF code format (DDDDL), transform it into NACE (DD.DD).

    Args:
        code (str): The code in NAF format to transform.

    Returns:
        str: The code in NACE format.
    """
    # NAF "DDDDL" -> NACE "DD.DD": keep the four digits, drop the trailing
    # letter, and insert a dot between the division and the sub-division.
    division = code[:2]
    subdivision = code[2:4]
    return f"{division}.{subdivision}"
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# For code sampling
POPULATION_PATH = "projet-ape/data/08112022_27102024/naf2025/split/df_train.parquet"
CODE_COLUMN = "nace2025"

# For prompt creation
PROMPT_PATH = "src/agents/NaiveCode2Text/prompts/"

# To retrieve specifications of every code correctly:
# dividers used to split the raw includes/examples/excludes text fields.
INCLUDES_DIVIDER = "\n-"
EXAMPLES_DIVIDER = "\n"
EXCLUDE_DIVIDER = "\n"

# Randomization for specifications
# NOTE(review): *_GEOM_PROB values look like geometric-distribution success
# probabilities used when sub-sampling specifications — confirm against the
# prompt builder.
RANDOM_SPEC_SAMPLING = True
RANDOM_INCLUDES_GEOM_PROB = 0.3
RANDOM_INCLUDES_MIN = 1
RANDOM_INCLUDES_MAX = None  # None = up to the max number of includes
RANDOM_EXAMPLES_GEOM_PROB = 0.2
RANDOM_EXAMPLES_MIN = 1
RANDOM_EXAMPLES_MAX = None  # None = up to the max number of examples per include

# Exportation
OUTPUT_PATH = "projet-ape/synthetic_data_test/naive/"
OUTPUT_FORMAT = ".parquet"  # .txt or .parquet
BATCH_SIZE = 5  # If choosing .parquet output format

# LLM Hyperparameters
MODEL = "gpt-oss:20b"
TEMPERATURE = 1.8
LANGUAGE = "English"

# Generation specifications
N_CODES = 12  # Number of codes to sample
NB_LABELS = 10  # Number of labels to generate per code
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import os
import logging
import time

from dotenv import load_dotenv
import s3fs
from openai import OpenAI

from src.agents.NaiveCode2Text.config_naive import \
    MODEL, TEMPERATURE, OUTPUT_PATH, N_CODES, POPULATION_PATH, CODE_COLUMN, \
    OUTPUT_FORMAT, BATCH_SIZE, LANGUAGE, NB_LABELS, PROMPT_PATH, \
    INCLUDES_DIVIDER, EXAMPLES_DIVIDER, EXCLUDE_DIVIDER, RANDOM_SPEC_SAMPLING, \
    RANDOM_INCLUDES_GEOM_PROB, RANDOM_INCLUDES_MIN, RANDOM_INCLUDES_MAX, \
    RANDOM_EXAMPLES_GEOM_PROB, RANDOM_EXAMPLES_MIN, RANDOM_EXAMPLES_MAX
from src.agents.NaiveCode2Text.prompts import prompt_builder, label_generator
from src.agents.NaiveCode2Text.code_retrieval import code_sampler, code_specifier
from src.neo4j_graph.graph import Graph, Neo4JConfig

# Logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment
load_dotenv(override=True)

if __name__ == "__main__":
    # Clock for speed testing
    # NOTE(review): `start` is only bound when OUTPUT_FORMAT == ".txt"; the
    # matching `end - start` below is guarded by the same condition, so this
    # is safe — but brittle if a new output format reuses the timing.
    if OUTPUT_FORMAT == ".txt":
        start = time.perf_counter()

    # Access configurations (S3 credentials come from the environment)
    FS = s3fs.S3FileSystem(
        client_kwargs={'endpoint_url': os.environ["AWS_ENDPOINT_URL"]},
        key=os.environ["AWS_ACCESS_KEY_ID"],
        secret=os.environ["AWS_SECRET_ACCESS_KEY"],
        token=os.environ["AWS_SESSION_TOKEN"]
    )

    LLM_API_KEY = os.environ["LLM_API_KEY"]
    LLM_URL = os.environ["LLM_URL"]
    LLM_CLIENT = OpenAI(api_key=LLM_API_KEY, base_url=LLM_URL)

    # Sampling from original data (with replacement, lazy parquet scan)
    logger.info("Sampling from data...")
    code_list = code_sampler.sample_codes_lazy(
        fs=FS,
        population_path=POPULATION_PATH,
        code_column=CODE_COLUMN,
        n_codes=N_CODES
    )

    # NAF to NACE: "DDDDL" -> "DD.DD", the format stored in the Neo4j graph
    logger.info("Transforming codes from NAF to NACE...")
    code_list = [code_specifier.NAF_to_NACE(code) for code in code_list]

    # Neo4j connection
    logger.info("Connecting to Neo4j graph...")
    notice_graph = Graph(Neo4JConfig(
        url=os.environ["NEO4J_URL"],
        username=os.environ["NEO4J_USERNAME"],
        password=os.environ["NEO4J_PWD"]
    ))

    # Define an automatic name for output.
    # ":" and "." are stripped so the model name / temperature are
    # filesystem-safe; the extension is appended after the replace.
    file_name = f"generation_{MODEL}_temp{TEMPERATURE}".replace(":", "-").replace(".", "") \
        + OUTPUT_FORMAT
    FINAL_PATH = OUTPUT_PATH + file_name

    # Prompt generation
    logger.info("Generating prompts...")

    # Per-batch buffers: reset after each parquet flush so that names/labels
    # always line up with the sliced codes passed to export_to_parquet.
    name_list = []
    label_list = []

    # Model set up (structured-output model sized to NB_LABELS)
    LabelGenerationModel = label_generator.build_label_generation_model(NB_LABELS)

    system_prompt = prompt_builder.build_system_prompt(
        prompt_path=PROMPT_PATH,
        language=LANGUAGE,
        nb_labels=NB_LABELS
    )

    for i, code in enumerate(code_list):
        logger.info(f"Processing step {i+1}...")

        # Get code details from Neo4j
        code_details = code_specifier.get_code_information(
            graph=notice_graph,
            code=code
        )

        # For exportation purpose
        # NOTE(review): raises if get_code_information found no result
        # (empty/falsy return has no "name" key) — confirm sampled codes
        # always exist in the graph.
        name_list.append(code_details["name"])

        # Build prompts
        user_prompt = prompt_builder.build_user_prompt(
            code_details=code_details,
            language=LANGUAGE,
            nb_labels=NB_LABELS,
            includes_divider=INCLUDES_DIVIDER,
            examples_divider=EXAMPLES_DIVIDER,
            excludes_divider=EXCLUDE_DIVIDER,
            random_spec_sampling=RANDOM_SPEC_SAMPLING,
            random_includes_geom_prob=RANDOM_INCLUDES_GEOM_PROB,
            random_includes_min=RANDOM_INCLUDES_MIN,
            random_includes_max=RANDOM_INCLUDES_MAX,
            random_examples_geom_prob=RANDOM_EXAMPLES_GEOM_PROB,
            random_examples_min=RANDOM_EXAMPLES_MIN,
            random_examples_max=RANDOM_EXAMPLES_MAX
        )

        # Ask the chatbot
        generation = label_generator.ask_model(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            llm_client=LLM_CLIENT,
            model=MODEL,
            temperature=TEMPERATURE,
            LabelGeneration=LabelGenerationModel
        )

        label_list.append(generation.labels)

        # Flush every full batch to parquet, then clear the buffers.
        if OUTPUT_FORMAT == ".parquet" and (i+1) % BATCH_SIZE == 0:
            logger.info("Saving intermediate results...")
            label_generator.export_to_parquet(
                codes=code_list[i+1-BATCH_SIZE:i+1],
                names=name_list,
                labels=label_list,
                file_path=FINAL_PATH,
                fs=FS
            )
            label_list = []
            name_list = []

    end = time.perf_counter()

    if OUTPUT_FORMAT == ".txt":
        logger.info("Saving results to txt...")
        label_generator.export_to_txt(
            codes=code_list,
            names=name_list,
            labels=label_list,
            file_path=FINAL_PATH,
            generation_time=end-start
        )

    elif OUTPUT_FORMAT == ".parquet":
        logger.info("Saving final results...")
        # Rows past the last full batch (when N_CODES is not a multiple of
        # BATCH_SIZE) were never flushed inside the loop; save them now.
        first_unsaved_index = BATCH_SIZE*(N_CODES//BATCH_SIZE)
        if first_unsaved_index < N_CODES:
            label_generator.export_to_parquet(
                codes=code_list[BATCH_SIZE*(N_CODES//BATCH_SIZE):],
                names=name_list,
                labels=label_list,
                file_path=FINAL_PATH,
                fs=FS
            )

0 commit comments

Comments
 (0)