-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfunction_manager.py
More file actions
110 lines (90 loc) · 4.1 KB
/
function_manager.py
File metadata and controls
110 lines (90 loc) · 4.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import pickle
import re
import atexit
from collections import defaultdict
from transformers import RobertaTokenizer, RobertaModel
class FunctionManager:
    """Manage function-name metadata: stable numeric ids, preprocessed
    token strings, and CodeBERT embeddings, persisted to a pickle file.

    Each processed function is stored in ``self.data`` under TWO keys —
    its name (str) and its numeric id (int) — both mapping to the same
    info dict, so lookups work either way.
    """

    def __init__(self, data_dir='resources', filename='function_data.pkl', model_path='models/codebert-base', debug=False):
        """Resolve paths relative to this file, load resources/model, and
        register an atexit hook so accumulated data is saved on exit."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        self.data_dir = os.path.join(base_dir, data_dir)
        self.filename = filename
        self.model_path = os.path.join(base_dir, model_path)
        self.debug = debug
        self._init()
        # Persist accumulated data automatically when the interpreter exits.
        atexit.register(self.save)

    def _init(self):
        """Load the word filter, negative-sample list, cached data, and the
        CodeBERT tokenizer/model."""
        # Words to drop from preprocessed function names.
        self.word_filter = set()
        # Encoding pinned so behavior doesn't depend on the platform default.
        with open(f'{self.data_dir}/word_filter.txt', 'r', encoding='utf-8') as f:
            for line in f:
                self.word_filter.add(line.strip())
        # Function names used as negative samples for training.
        self.negative_functions = set()
        with open(f'{self.data_dir}/negative_functions.txt', 'r', encoding='utf-8') as f:
            for line in f:
                self.negative_functions.add(line.strip())
        # Load previously cached function data, if any.
        file_dir = os.path.join(self.data_dir, self.filename)
        if os.path.exists(file_dir):
            # NOTE(review): pickle.load on an untrusted file can execute
            # arbitrary code; this file is assumed to be written only by
            # save() below — confirm nothing else can replace it.
            with open(file_dir, 'rb') as f:
                self.data = pickle.load(f)
        else:
            self.data = {}
        # Tokenizer and encoder used to embed function names.
        self.tokenizer = RobertaTokenizer.from_pretrained(self.model_path)
        self.model = RobertaModel.from_pretrained(self.model_path)

    def save(self):
        """Persist the accumulated function data to the pickle file."""
        with open(os.path.join(self.data_dir, self.filename), 'wb') as f:
            pickle.dump(self.data, f)

    def function2id(self, profile):
        """Register every function named in *profile*, computing token and
        embedding for names not seen before.

        ``profile`` is assumed to be a pprof-style protobuf with a
        ``function`` list and a ``string_table`` — TODO confirm against caller.
        """
        for function in profile.function:
            func_name = profile.string_table[function.name]
            if func_name not in self.data:
                self._process_function(func_name)

    def _process_function(self, func_name):
        """Assign an id to *func_name*, compute token and embedding, and
        store the info dict under both the name and the id."""
        # len(self.data) counts BOTH keys of every stored function, so ids
        # advance by 2 per function. They remain unique; the scheme is kept
        # as-is for compatibility with previously pickled data.
        node_id = len(self.data)
        func_info = {
            'id': node_id,
            'name': func_name,
            'token': self.preprocess_function_name(func_name),
            'embedding': self._generate_embedding(func_name)
        }
        self.data[func_name] = func_info
        self.data[node_id] = func_info
        if self.debug:
            print(f"New function {func_name}")

    def _tokenize_function_name(self, func_name):
        """Preprocess *func_name* and return model-ready token tensors."""
        name = self.preprocess_function_name(func_name)
        return self.tokenizer(name, return_tensors="pt")

    def _generate_embedding(self, func_name):
        """Return the CodeBERT pooled embedding of *func_name*."""
        inputs = self._tokenize_function_name(func_name)
        outputs = self.model(**inputs)
        # Detach so the stored tensor carries no autograd graph.
        return outputs.pooler_output.detach()  # shape: (1, 768)

    def preprocess_function_name(self, name):
        """Normalize a function name: CamelCase split, lowercase, special
        characters replaced by spaces, filtered words removed."""
        # CamelCase -> snake_case, then lowercase.
        name = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name).lower()
        # Replace every non-alphanumeric, non-space character (this already
        # covers '.' and '_', so no separate replace() pass is needed).
        name = re.sub(r'[^a-zA-Z0-9\s]', ' ', name)
        # Collapse runs of whitespace into single spaces.
        name = re.sub(r'\s+', ' ', name).strip()
        # Drop words from the configured filter list.
        tokens = name.split()
        filtered_tokens = [token for token in tokens if token.lower() not in self.word_filter]
        return ' '.join(filtered_tokens)

    def __getitem__(self, func_name):
        """Return the info dict (id, name, token, embedding) for
        *func_name*, computing and caching it if new."""
        if func_name not in self.data:
            self._process_function(func_name)
        return self.data[func_name]

    def __len__(self):
        """Return the number of distinct functions processed.

        Counts only the str keys, since every function occupies two keys
        in ``self.data`` (its name and its numeric id).
        """
        return sum(1 for key in self.data if isinstance(key, str))
function_manager = FunctionManager(debug=False)