generated from Amsterdam-Internships/InternshipAmsterdamGeneral
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_script.py
More file actions
83 lines (69 loc) · 2.35 KB
/
evaluate_script.py
File metadata and controls
83 lines (69 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
'''
This script is used to calculate automatic evaluation metrics for translated sentences.
parameters: --source_path --reference_path --target_path --chart_title
'''
# Compute BLEU, SARI and METEOR scores for a validation set and save a bar chart.
import argparse
import os
import shutil

import evaluate
import matplotlib.pyplot as plt


def _read_lines(path):
    """Return the stripped lines of the file at *path* as a list of strings."""
    # Explicit encoding so results do not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_path',
    help='path to source sentences file',
    required=True)
    parser.add_argument('--target_path',
    help='path to predicted sentences file',
    required=True)
    parser.add_argument('--reference_path',
    help='path to reference sentences file',
    required=True)
    parser.add_argument('--chart_title',
    help='Title of the results chart',
    required=True)
    args = parser.parse_args()
    title = args.chart_title

    # Load the three parallel files: source, reference and predicted sentences.
    sources = _read_lines(args.source_path)
    # Each reference is wrapped in a list because the metrics support multiple
    # references per sentence; here there is exactly one.
    references = [[line] for line in _read_lines(args.reference_path)]
    predictions = _read_lines(args.target_path)

    # Load metrics and compute scores.
    sari = evaluate.load("sari")
    bleu = evaluate.load("bleu")
    meteor = evaluate.load("meteor")
    sari_score = sari.compute(sources=sources, predictions=predictions, references=references)
    bleu_score = bleu.compute(predictions=predictions, references=references)
    meteor_score = meteor.compute(predictions=predictions, references=references)

    # Print scores. shutil.get_terminal_size falls back to 80x24 when stdout
    # is not a TTY, where os.get_terminal_size would raise OSError.
    columns = shutil.get_terminal_size().columns
    print('=' * columns)
    print(sari_score)
    print(bleu_score)
    print(meteor_score)
    print('=' * columns)

    # SARI is already on a 0-100 scale; BLEU and METEOR are 0-1, so scale up.
    labels = ['SARI', 'BLEU', 'METEOR']
    scores = [sari_score['sari'], bleu_score['bleu'] * 100, meteor_score['meteor'] * 100]
    plt.bar(labels, scores)
    plt.title(title)
    plt.xlabel('Metric')
    plt.ylabel('Score')
    plt.ylim([0, 100])
    # Annotate each bar with its rounded score.
    for i, v in enumerate(scores):
        plt.text(i, v + 1, str(round(v, 2)), horizontalalignment='center', fontweight='bold')
    # Ensure the output directory exists before saving the chart.
    os.makedirs('media', exist_ok=True)
    plt.savefig('media/{}.png'.format(title))


if __name__ == '__main__':
    main()