From c6b1826988fd5057eea66d437071cb3af4fa0c10 Mon Sep 17 00:00:00 2001 From: FantasticCode2018 <”22151214378@stu.xidian.edu.cn“> Date: Mon, 12 Jun 2023 15:17:18 +0800 Subject: [PATCH] Skip answers missing from ent2id and read data files as UTF-8 --- WebQSP/data.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/WebQSP/data.py b/WebQSP/data.py index 32554f6..7543071 100644 --- a/WebQSP/data.py +++ b/WebQSP/data.py @@ -52,7 +52,7 @@ def __init__(self, input_dir, fn, bert_name, ent2id, rel2id, batch_size, trainin sub_map = defaultdict(list) so_map = defaultdict(list) - for line in open(os.path.join(input_dir, 'fbwq_full/train.txt')): + for line in open(os.path.join(input_dir, 'fbwq_full/train.txt'), encoding='utf-8'): l = line.strip().split('\t') s = l[0].strip() p = l[1].strip() @@ -77,7 +77,7 @@ def __init__(self, input_dir, fn, bert_name, ent2id, rel2id, batch_size, trainin question_2 = question_2[1] # question = question_1 + 'NE' + question_2 question = question_1.strip() - ans = line[1].split('|') + ans = line[1].strip().split('|') # if (head, ans[0]) not in so_map: @@ -92,7 +92,14 @@ def __init__(self, input_dir, fn, bert_name, ent2id, rel2id, batch_size, trainin head = [ent2id[head]] question = self.tokenizer(question.strip(), max_length=64, padding='max_length', return_tensors="pt") - ans = [ent2id[a] for a in ans] + # ans = [ent2id[a] for a in ans] + ############################################ + ans1 = [] + for a in ans: + if ent2id.get(a) is not None: + ans1.append(ent2id[a]) + ans = list(filter(lambda x: x is not None, ans1)) + ############################################ data.append([head, question, ans, entity_range]) print('data number: {}'.format(len(data))) @@ -117,7 +124,7 @@ def load_data(input_dir, bert_name, batch_size): else: print('Read data...') ent2id = {} - for line in open(os.path.join(input_dir, 'fbwq_full/entities.dict')): + for line in open(os.path.join(input_dir, 'fbwq_full/entities.dict'), 
encoding='utf-8'): l = line.strip().split('\t') ent2id[l[0].strip()] = len(ent2id) # print(len(ent2id)) @@ -128,7 +135,7 @@ def load_data(input_dir, bert_name, batch_size): rel2id[l[0].strip()] = int(l[1]) triples = [] - for line in open(os.path.join(input_dir, 'fbwq_full/train.txt')): + for line in open(os.path.join(input_dir, 'fbwq_full/train.txt'), encoding='utf-8'): l = line.strip().split('\t') s = ent2id[l[0].strip()] p = rel2id[l[1].strip()]