diff --git a/seq2vec/model/seq2vec_base.py b/seq2vec/model/seq2vec_base.py index dbfd037..f471025 100644 --- a/seq2vec/model/seq2vec_base.py +++ b/seq2vec/model/seq2vec_base.py @@ -27,10 +27,12 @@ def __call__(self, seqs): Raises ------ """ - result = [] - for seq in seqs: - result.append(self.transform_single_sequence(seq)) - return np.array(result) + n_seqs = len(seqs) + result = np.zeros((n_seqs, self.latent_size), dtype=np.float32) + + for idx, seq in enumerate(seqs): + result[idx, :] = self.transform_single_sequence(seq) + return result @abstractmethod def transform_single_sequence(self, seq):