Skip to content

Commit 0434324

Browse files
hwu71 and mahaloz authored
fix: handle adjacent @@ variable tokens in split_words() (#15)
* fix: handle adjacent @@ variable tokens in split_words()

  When variables appear adjacent without spaces in decompiled code
  (e.g., func(a,b,c)), the @@ placeholder tokens merge into one word.
  re.search() only matched the first pattern, silently losing the rest
  and causing a holder/mask count mismatch that discards all predictions.
  Replace re.search() with re.finditer() to extract all @@ patterns.

* fix broken tests with pin

---------

Co-authored-by: mahaloz <zion@zionbasque.com>
1 parent 1e6a85b commit 0434324

3 files changed

Lines changed: 16 additions & 14 deletions

File tree

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ long_description_content_type = text/markdown
1515
[options]
1616
install_requires =
1717
torch
18-
transformers
18+
transformers>=5.2.0
1919
tqdm
2020
dailalib
2121
libbs>=1.18.1

varbert/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.3.0"
1+
__version__ = "2.3.1"
22

33
import importlib.resources
44
import tarfile

varbert/model.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def varec_init(self):
107107
str(self.model_base_dir),
108108
avar_vocab_size = self.vocab_size,
109109
from_tf=False,
110-
config=config
110+
config=config
111111
)
112112

113113
model.to(device)
@@ -116,21 +116,24 @@ def varec_init(self):
116116
@staticmethod
117117
def create_inputs_for_model(code_txt, tokenizer):
118118
input_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code_txt))
119-
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
119+
input_ids = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
120120
return torch.tensor(input_ids, dtype=torch.long)
121121

122122
@staticmethod
123123
def split_words(text: str):
124124
words = text.replace("\n", " ").split(" ")
125125
r = []
126126
for w in words:
127-
m = re.search(r"@@[^\s@]+@@[^\s@]+@@", w)
128-
if m is not None:
129-
if m.start() > 0:
130-
r.append(w[: m.start()])
131-
r.append(w[m.start(): m.end()])
132-
if m.end() < len(w):
133-
r.append(w[m.end():])
127+
matches = list(re.finditer(r"@@[^\s@]+@@[^\s@]+@@", w))
128+
if matches:
129+
pos = 0
130+
for m in matches:
131+
if m.start() > pos:
132+
r.append(w[pos: m.start()])
133+
r.append(w[m.start(): m.end()])
134+
pos = m.end()
135+
if pos < len(w):
136+
r.append(w[pos:])
134137
else:
135138
r.append(w)
136139
r = [w for w in r if len(w) > 0]
@@ -206,7 +209,7 @@ def preprocess_word_mask(self, ftext, tokenizer):
206209
tpwords.append(vocab[t])
207210
towords.append(vocab[t])
208211
pos += 1
209-
212+
210213
assert len(tpwords) == len(towords)
211214
assert None not in tpwords
212215
assert None not in towords
@@ -280,7 +283,7 @@ def process(self, code: str):
280283
# _code = "\n".join(_code_lines)
281284

282285
input_ids = self.preprocess_word_mask(_code, tokenizer)[0]
283-
input_ids_with_special_tokens = tokenizer.build_inputs_with_special_tokens(input_ids)
286+
input_ids_with_special_tokens = [tokenizer.bos_token_id] + input_ids + [tokenizer.eos_token_id]
284287
if len(input_ids_with_special_tokens) < 800:
285288
# padding
286289
padded_input_ids = input_ids_with_special_tokens[:-1] + [1] * 800 + [2]
@@ -411,4 +414,3 @@ def forward(
411414
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
412415
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
413416
}
414-

0 commit comments

Comments (0)