Nel nostro precedente blog post abbiamo parlato di BERT, un modello di rappresentazione del linguaggio che ha rivoluzionato il campo del Natural Language Processing. Inoltre, abbiamo sottolineato come le reti Transformer che costituiscono i layer di BERT rappresentino la chiave per il successo di questo modello.
Mentre le reti Transformer continuano a migliorare lo stato dell’arte in molti task NLP, recenti ricerche si sono concentrate su un problema che affligge questo tipo di reti neurali artificiali: la complessità quadratica rispetto alla lunghezza della sequenza di input.
Per farla breve, questo inconveniente rende computazionalmente costoso addestrare un modello Transformer capace di ricevere lunghe sequenze in input.
Per avere un’idea di cosa ciò significhi, basti sapere che l’elaborazione di una sequenza di lunghezza 100 richiede a una rete Transformer un numero di operazioni (prodotti scalari) dell’ordine di 1002, mentre l’elaborazione di una sequenza di input di lunghezza 1000 (ad esempio un capitolo ragionevolmente lungo) richiede all’incirca un milione di operazioni. Questo inconveniente limita sensibilmente l’applicabilità dei modelli Transformer: nel task di Question Answering, ad esempio, l’accesso a porzioni lunghe di testo è fondamentale.
In questo articolo passeremo in rassegna due alternative che mirano a risolvere questo problema di complessità, a nostro avviso semplici ma efficaci.
Due proposte per risolvere il problema della complessità nelle reti Transformer
Longformer
La prima proposta che presenteremo è il Longformer, una rete che mira ad aumentare la lunghezza massima delle sequenze di input riducendo il contesto della self-attention.
Durante l’elaborazione della i-esima parola nella sequenza di input, il Transformer nella sua versione originale tiene conto di tutti gli altri token di input per fornire il relativo output. Ciò implica che, se indichiamo con N la lunghezza della sequenza di input, la rete esegue N prodotti scalari per ciascuno degli N token di input (ecco perché la complessità risulta essere quadratica!).
Il Longformer, invece, prende in considerazione solo gli m token precedenti e gli m token seguenti il token processato (utilizzando una cosiddetta sliding window attention). In tal modo, al Longformer è richiesto di eseguire 2m prodotti scalari per ciascuno degli N token di input, cosicché la complessità del problema scali linearmente rispetto a N.
Attention completa (Transformer originale) Sliding Window Attention (Longformer)
Linear Transformer
La seconda alternativa che esamineremo è il Linear Transformer.
In questo modello il problema della complessità quadratica viene risolto semplicemente applicando un trucco di algebra lineare!
Senza tirare in ballo complesse formule matematiche, possiamo descrivere il Linear Transformer come un Transformer che adotta una diversa funzione di similarità nel calcolo delle attention.
Se nella versione originale la similarità tra un vettore query q e un vettore chiave k era calcolata come sim(q, k) = exp (qTk), nel Linear Transformer questa viene calcolata come sim(q, k) = ((q)T(k)) = (elu(q) + 1)T(elu(k) + 1).
Sfruttando la linearità di questa nuova funzione di similarità rispetto a (k), il Linear Transformer riduce il calcolo di N prodotti scalari per ogni token di input a una singola operazione di prodotto scalare per token di input. Con questa soluzione la complessità di questo modello scala linearmente rispetto a N.
Sia Longformer che Linear Transformer risultano essere estremamente più veloci rispetto alla rete Transformer originale durante l’elaborazione di lunghe sequenze di input, permettendo l’utilizzo dei meccanismi di attention in ogni task NLP.