-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathuse.py
More file actions
177 lines (147 loc) · 7.66 KB
/
use.py
File metadata and controls
177 lines (147 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
import torch.nn as nn
import numpy as np
from phonemizer import phonemize
from phonemizer.separator import Separator
# Set device
# Run on GPU when available; every tensor/model below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SpatialAttention(nn.Module):
    """Per-pixel gating: a 1x1 conv collapses channels to a single map,
    a sigmoid squashes it to (0, 1), and the input is scaled element-wise.

    Output shape equals input shape (B, C, H, W).
    """

    def __init__(self, in_channels):
        super().__init__()
        # NOTE: attribute names `conv`/`sigmoid` are fixed — the saved
        # checkpoint's state-dict keys depend on them.
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attention_map = self.sigmoid(self.conv(x))
        return x * attention_map
class StreamEncoder(nn.Module):
    """Encode one input stream (a stack of pairwise feature matrices) into a
    fixed-size feature map.

    Pipeline: 3x3 conv -> BN -> ReLU -> spatial attention -> 2x2 max-pool ->
    3x3 conv -> BN -> ReLU -> adaptive average-pool to 4x4.  The adaptive pool
    makes the output (B, 64, 4, 4) regardless of the input's spatial size.
    """

    def __init__(self, in_channels):
        super().__init__()
        # NOTE: keep the attribute name `conv` — checkpoint state-dict keys
        # (conv.0.weight, ...) depend on it and on the layer order below.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            SpatialAttention(32),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),
        )

    def forward(self, x):
        return self.conv(x)
class DualStreamCNN(nn.Module):
    """Two-stream multi-label classifier over text and phoneme matrices.

    Each class owns a learned pair of gates mixing the flattened text and
    phoneme features; the blend runs through a shared FC trunk, and the
    matching row of the output layer produces that class's logit.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names are fixed by the saved checkpoint's state-dict keys.
        self.text_encoder = StreamEncoder(3)   # 3-channel text matrices
        self.phone_encoder = StreamEncoder(2)  # 2-channel phoneme matrices
        self.flatten = nn.Flatten()
        # Per-class (text, phone) mixing weights; squashed by sigmoid in forward().
        self.stream_weights = nn.Parameter(torch.ones(num_classes, 2))
        self.fc_shared = nn.Linear(1024, 512)  # 64 channels * 4 * 4 per stream
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.output_layer = nn.Linear(512, num_classes)

    def forward(self, txt, phn):
        """Return (batch, num_classes) logits for the two input streams."""
        text_vec = self.flatten(self.text_encoder(txt))
        phone_vec = self.flatten(self.phone_encoder(phn))
        gates = torch.sigmoid(self.stream_weights)
        per_class_logits = []
        for cls_idx in range(self.output_layer.out_features):
            # Gate-blend the two streams for this class, then run the shared trunk.
            blended = gates[cls_idx, 0] * text_vec + gates[cls_idx, 1] * phone_vec
            hidden = self.dropout(self.relu(self.fc_shared(blended)))
            # Single-row slice of the output layer: this class's logit only.
            logit = self.output_layer.weight[cls_idx] @ hidden.t() + self.output_layer.bias[cls_idx]
            per_class_logits.append(logit)
        return torch.stack(per_class_logits, dim=1)
def create_text_matrix(text, max_len=64):
    """Build a 3-channel word-pair feature matrix for *text*.

    Channel 0: exact (punctuation-stripped, lowercased) word identity at (i, j).
    Channel 1: shared first letter (alliteration cue), weighted 0.5.
    Channel 2: word i carries a clause-internal mark (',', ';', ':').
               NOTE(review): '.' is NOT in this list, so sentence-final words
               are never marked as clause ends — confirm this matches training.

    Returns a float32 array of shape (3, max_len, max_len); words past
    max_len are ignored.
    """
    raw_words = str(text).lower().split()
    clean_words = [w.strip('.,!?;:"()') for w in raw_words]
    matrix = np.zeros((3, max_len, max_len), dtype=np.float32)
    limit = min(len(clean_words), max_len)
    for i in range(limit):
        # Loop-invariant per row: punctuation presence and first letter.
        has_punct = any(mark in raw_words[i] for mark in (',', ';', ':'))
        first_i = clean_words[i][:1]
        for j in range(limit):
            if i == j:
                continue
            if clean_words[i] == clean_words[j] and len(clean_words[i]) > 0:
                matrix[0, i, j] = 1.0
            if first_i == clean_words[j][:1]:
                matrix[1, i, j] += 0.5
            if has_punct:
                matrix[2, i, j] = 1.0
    return matrix
def create_phone_matrix(phone_text, max_len=64):
    """Build a 2-channel phoneme-pair feature matrix.

    The input is a space-separated phoneme string using "<W>" as the word
    separator (as in training).  Separator tokens are dropped, so matrix
    positions index individual phonemes, not words.

    Channel 0: identical phonemes at positions (i, j).
    Channel 1: matching last two characters (coarse rhyme cue).

    Returns a float32 array of shape (2, max_len, max_len).
    """
    tokens = str(phone_text).lower().split()
    phones = [tok for tok in tokens if tok != "<w>"]
    matrix = np.zeros((2, max_len, max_len), dtype=np.float32)
    limit = min(len(phones), max_len)
    for i in range(limit):
        tail_i = phones[i][-2:]
        for j in range(limit):
            if i == j:
                continue
            if phones[i] == phones[j]:
                matrix[0, i, j] = 1.0
            if tail_i == phones[j][-2:]:
                matrix[1, i, j] = 1.0
    return matrix
def get_highlights(text, txt_mat, phn_mat, figure_name):
    """Map a detected rhetorical figure back to word indices in *text*.

    Clause boundaries come from the punctuation channel (txt_mat[2]); repeated
    word pairs come from the word-identity channel (txt_mat[0]).  The pairs are
    then filtered by the figure's positional definition.  phn_mat is accepted
    for interface symmetry but is not consulted here.

    Returns a sorted list of word indices.
    """
    n_words = len(text.lower().split())

    # Rows with any activity in the punctuation channel mark clause-final words.
    clause_ends = np.where(txt_mat[2].sum(axis=1) > 0)[0]
    # A clause starts at index 0 and immediately after each clause end.
    clause_starts = np.insert(clause_ends + 1, 0, 0)
    clause_starts = clause_starts[clause_starts < n_words]

    # All (i, j) pairs of identical words.
    coords = np.argwhere(txt_mat[0] > 0)

    figure = figure_name.lower()
    if figure == 'epanaphora':
        # Repetition at the START of clauses.
        keep = lambda i, j: i in clause_starts or j in clause_starts
    elif figure == 'epiphora':
        # Repetition at the END of clauses.
        keep = lambda i, j: i in clause_ends or j in clause_ends
    elif figure == 'anadiplosis':
        # End of one clause repeated at the start of the next (A->B, B->C).
        keep = lambda i, j: ((i in clause_ends and j in clause_starts)
                             or (j in clause_ends and i in clause_starts))
    else:
        # Fallback for general repetition (ploke, epizeuxis, ...).
        keep = lambda i, j: True

    picked = set()
    for i, j in coords:
        if keep(i, j):
            if i < n_words:
                picked.add(i)
            if j < n_words:
                picked.add(j)
    return sorted(picked)
def get_figure_highlights(text, phonetic_text, model, checkpoint):
    """Run the multi-label rhetoric model on one sentence and map each
    detected figure back to the words that realise it.

    Args:
        text: raw sentence (original casing is preserved in the output).
        phonetic_text: space-separated phonemes with "<W>" word separators.
        model: a DualStreamCNN already loaded with the checkpoint weights.
        checkpoint: dict with 'classes' (label names) and 'thresholds'
            (per-class decision thresholds), as saved at training time.

    Returns:
        {'FIGURE_NAME': ['word1', 'word2', ...]} for every figure whose
        score exceeds its threshold and which maps to at least one word.
    """
    model.eval()
    words = text.split()  # keep original casing for the output

    # 1. Preprocess: same matrix builders (and max_len) as training.
    txt_mat = create_text_matrix(text)
    phn_mat = create_phone_matrix(phonetic_text)
    txt_tensor = torch.tensor(txt_mat).unsqueeze(0).to(device)
    phn_tensor = torch.tensor(phn_mat).unsqueeze(0).to(device)

    # 2. Inference: per-class probabilities for the single sample.
    with torch.no_grad():
        probs = torch.sigmoid(model(txt_tensor, phn_tensor)).cpu().numpy()[0]

    # 3. Decode: for each figure over its tuned threshold, find the word
    # indices that triggered it and translate them back to surface words.
    final_highlights = {}
    for i, score in enumerate(probs):
        if score > checkpoint['thresholds'][i]:
            fig_name = checkpoint['classes'][i]
            indices = get_highlights(text, txt_mat, phn_mat, fig_name)
            highlighted_words = [words[idx] for idx in indices if idx < len(words)]
            if highlighted_words:
                final_highlights[fig_name] = highlighted_words
    return final_highlights
# --- Script entry: load the checkpoint, build the model, run two demos ---
# NOTE(security): weights_only=False unpickles arbitrary objects from the
# checkpoint file — only load checkpoints from trusted sources.
checkpoint = torch.load('./models/rhetoric_multilabel_annotator_model_2.pth',
                        map_location=device, weights_only=False)
model = DualStreamCNN(num_classes=len(checkpoint['classes'])).to(device)
model.load_state_dict(checkpoint['model_state_dict'])

h1 = get_figure_highlights(
    "The general who became a slave; the slave who became a gladiator; the gladiator who defied an Emperor",
    "ð ə <W> dʒ ɛ n ɹ ə l <W> h uː <W> b ɪ k eɪ m <W> ə <W> s l eɪ v <W> ð ə <W> s l eɪ v <W> h uː <W> b ɪ k eɪ m <W> ə <W> ɡ l æ d i eɪ t ə ɹ <W> ð ə <W> ɡ l æ d i eɪ t ə ɹ <W> h uː <W> d ɪ f aɪ d <W> æ n <W> ɛ m p ə ɹ ə ɹ",
    model, checkpoint)
h3 = get_figure_highlights(
    "Put out the light, and then put out the light",
    "p ʊ t <W> aʊ t <W> ð ə <W> l aɪ t <W> æ n d <W> ð ɛ n <W> p ʊ t <W> aʊ t <W> ð ə <W> l aɪ t <W> ʃ eɪ k s p ɪ ɹ <W> ɑː θ ɛ l oʊ <W> f aɪ v <W> t uː <W> s ɛ v ə n",
    model, checkpoint)

# Print the results: one line per detected figure.
for result in (h1, h3):
    for figure, hits in result.items():
        print(f"{figure.upper()}: {', '.join(hits)}")