From 739baa59a7959d410bb6323ba20a0a4caf716f28 Mon Sep 17 00:00:00 2001
From: jackwilkie <76436546+jackwilkie@users.noreply.github.com>
Date: Thu, 8 Jun 2023 10:52:45 +0100
Subject: [PATCH] Update saint_i.py

Inter-sample attention now inherits from nn.MultiheadAttention, which
includes optimizations such as flash attention.
---
 models/saint_i.py | 66 +++++++++++++++++++----------------------------
 1 file changed, 27 insertions(+), 39 deletions(-)

diff --git a/models/saint_i.py b/models/saint_i.py
index 156360b..8fdc197 100644
--- a/models/saint_i.py
+++ b/models/saint_i.py
@@ -7,49 +7,37 @@ from .transformer import PositionwiseFeedForward, EncoderLayer, Encoder
 
 
 def intersample(query , key , value,dropout=None):
-    "Calculate the intersample of a given query batch"
-    #x , bs , n , d
-    b, h, n , d = query.shape
-    #print(query.shape,key.shape, value.shape )
-    query , key , value = query.reshape(1, b, h, n*d), \
-                          key.reshape(1, b, h, n*d), \
-                          value.reshape(1, b, h, n*d)
-
-    output, _ = attention(query, key ,value) #1 , b, n*d
-    output = output.squeeze(0) #b, n*d
-    output = output.reshape(b, h, n, d) #b,n,d
-
-    return output
-
-class MultiHeadedIntersampleAttention(nn.Module):
-    def __init__(self, h, d_model, dropout=0.1):
-        "Take in model size and number of heads."
-        super(MultiHeadedIntersampleAttention, self).__init__()
-        assert d_model % h == 0
-        # We assume d_v always equals d_k
-
-        self.d_k = d_model // h
-        self.h = h
-        self.linears = clones(nn.Linear(d_model, d_model), 4)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout)
+    '''
+    Wrapper class for MHA which calculates attention over samples rather than features
+    '''
+
+    def __init__(self, *args, **kwargs):
+        '''
+        Arguments are passed to the MHA class
+        '''
+
+        # initialise MHA attention layer
+        super().__init__(*args, **kwargs)
 
-    def forward(self, query, key, value):
-        "Implements Figure 2"
-
-        nbatches = query.size(0)
-        # 1) Do all the linear projections in batch from d_model => h x d_k
-        query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
-                             for l, x in zip(self.linears, (query, key, value))]
+    # Override forward method to transpose
+    def forward(self, query, key, value, **kwargs):
+        '''
+        Requires query, key, value tensors of size batch x n x d_feature; transposes, calculates attention across samples, transposes back and returns
+        kwargs are passed directly to nn.MultiheadAttention
+        '''
+
+        batch, n, d_features = query.size() # get original size
 
-        # 2) Apply attention on all the projected vectors in batch.
-        x = intersample(query, key, value,
-                        dropout=self.dropout)
+        # reshape to 1 x batch x n * d_features
+        query = query.reshape(1, batch, -1)
+        key = key.reshape(1, batch, -1)
+        value = value.reshape(1, batch, -1)
 
-        # 3) "Concat" using a view and apply a final linear.
-        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) # bs , n , d_model
-        return self.linears[-1](x) # bs , n , d_model
+        output, attn_output_weights = super().forward(query, key, value, **kwargs) # call forward function for MHA
+
+        return output.reshape(batch, n, d_features), attn_output_weights # return output and attention weights
+
 
 
 
 def make_saint_i(num_heads, embed_dim, num_layers, d_ff, dropout, dropout_ff=0.8):
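
Note on the new wrapper (illustrative, not part of the patch): the body added by this hunk is a class body (it defines __init__ and forward taking self), so it only works if the enclosing declaration becomes a subclass of nn.MultiheadAttention, as the commit message describes. Below is a minimal self-contained sketch of such a wrapper. The class name InterSampleAttention, the batch_first=True construction and the example dimensions are assumptions for illustration, not taken from the patch; embed_dim must equal n * d_features so that each flattened sample maps to a single token.

import torch
import torch.nn as nn


class InterSampleAttention(nn.MultiheadAttention):
    '''
    Hypothetical name; SAINT-style intersample attention. Each sample's
    n x d embeddings are flattened into one n*d token and attention is
    computed across the samples of the batch.
    '''

    def forward(self, query, key, value, **kwargs):
        batch, n, d_features = query.size()  # original batch x n x d shape

        # Flatten each sample and treat the whole batch as one sequence of
        # `batch` tokens. This layout assumes the layer was constructed with
        # batch_first=True; with the default batch_first=False the sequence
        # length would be 1 and no mixing across samples would occur.
        query = query.reshape(1, batch, -1)
        key = key.reshape(1, batch, -1)
        value = value.reshape(1, batch, -1)

        # Delegate to the optimised nn.MultiheadAttention forward pass
        output, attn_output_weights = super().forward(query, key, value, **kwargs)

        # Restore the original batch x n x d shape
        return output.reshape(batch, n, d_features), attn_output_weights


if __name__ == "__main__":
    batch, n, d = 8, 5, 16  # 8 samples, 5 features, 16-dim embeddings
    attn = InterSampleAttention(embed_dim=n * d, num_heads=4, batch_first=True)
    x = torch.randn(batch, n, d)
    out, weights = attn(x, x, x)
    print(out.shape, weights.shape)  # torch.Size([8, 5, 16]) torch.Size([1, 8, 8])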