cur_t_rnn, hc_t = self.capturer_t(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:])
if self.cat_contained:
    cur_c_rnn, hc_c = self.capturer_c(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t)
cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_c)
# 4) tower, t,c,l
# CMTL
hc_t, hc_c, hc_l = hc_t.squeeze(), hc_c.squeeze(), hc_l.squeeze()
c_pred = self.fc_c(hc_c)
c_trans = self.label_trans_c(c_pred.clone())
t_pred = self.fc_t(torch.cat((hc_t, c_trans), dim=-1))
t_trans = self.label_trans_t(t_pred.clone())
l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1))
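In the capturers you compute hc_t first and pass it into the computation of hc_c (and then hc_c into hc_l), so the order is t, c, l, which matches the "# 4) tower, t,c,l" comment. In the tower, however, you compute c_pred and c_trans first and then use c_trans to compute t_pred, so the order is c, t, l. The two stages do not appear to be consistent, and this may hurt your results. Is this intentional?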
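
To make the mismatch concrete, here is a minimal sketch that writes down the data dependencies as I read them from the snippet above and compares the task order of the two stages. The dictionaries and the helper names (`capturer_deps`, `tower_deps`, `task_letter`, `task_order`) are my own illustration and are not part of the repository; the tower chain is simplified so that `c_trans` and `t_trans` are folded into `c_pred` and `t_pred`.

```python
# Data dependencies read off the snippet: each output maps to the earlier
# output it consumes (None = depends only on the shared inputs).
# These dictionaries are a hypothetical summary, not code from the repo.
capturer_deps = {"hc_t": None, "hc_c": "hc_t", "hc_l": "hc_c"}
tower_deps = {"c_pred": None, "t_pred": "c_pred", "l_pred": "t_pred"}


def task_letter(name):
    """Map a variable name such as 'hc_t' or 'c_pred' to its task letter."""
    return name.replace("hc_", "").replace("_pred", "")


def task_order(deps):
    """Return the task letters in the order the chain computes them."""
    # Invert the mapping so we can walk from producer to consumer.
    consumer_of = {src: out for out, src in deps.items()}
    order = []
    current = consumer_of[None]  # head of the chain
    while current is not None:
        order.append(task_letter(current))
        current = consumer_of.get(current)
    return order


capturer_order = task_order(capturer_deps)
tower_order = task_order(tower_deps)
print("capturers:", capturer_order)
print("tower:    ", tower_order)
print("consistent:", capturer_order == tower_order)
```

If the t, c, l order of the capturers is the intended one, I would expect the tower to follow the same chain (t_pred first, then c_pred from hc_c and t_trans), but that depends on the input sizes of fc_t and fc_c, which I have not checked.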
