Writing an LLM from scratch, part 13 — the ‘why’ of attention, or: attention heads are dumb

Now that I’ve finished chapter 3 of Sebastian Raschka’s book
“Build a Large Language Model (from Scratch)”, having worked my way through
multi-head attention in the last post, I thought it would be worth pausing to
take stock before moving on to chapter 4.

There are two things I want to cover: the “why” of self-attention, and some thoughts
on context lengths. This post is on the “why”; that is, why does the particular
set of matrix multiplications described in the book do what we want it to do?

As always, this is something I’m doing
primarily to get things clear in my own head — with the possible extra benefit of it
being of use to other people out there. I will, of course, run it past multiple
LLMs to make sure I’m not posting total nonsense, but caveat lector!

Let’s get into it. As I wrote in
part 8 of this series:

I think it’s also worth noting that [what’s in the book is] very much a “mechanistic” explanation — it says how we do these calculations
without saying why. I think that the “why” is actually out of scope for this book, but it’s something that fascinates
me, and I’ll blog about it soon.

That “soon” is now 🙂

Attention heads are dumb

I think that my core problem with getting my head around why these equations
work was that I was overestimating
what a single attention head could do. In
part 6, I wrote, of
the phrase “the fat cat sat on the mat”:

So while the input embedding for “cat” just means “cat in position 3”, the context vector
for “cat” in this sentence also has some kind of overtones about it being a cat
that is sitting, perhaps less strongly that it’s a specific cat (“the” rather than “a”),
and hints of it being sitting on a mat.

The thing that I hadn’t understood was that this is true as far as it goes, but
only for the output of the attention mechanism as a whole, not for a single
attention head.

Each individual attention head is really dumb, and what it’s doing is much
simpler than that!

The two things that combine to make the mechanism as a whole smart are multi-head
attention and layering. The book has gone over multi-head attention in detail, so
let’s drill down on that second part.

Layers

Right at the start, in part 1, I wrote:

One other thing that Raschka mentions that confuses me a little is that
apparently the original transformer architecture had six encoder and six decoder
blocks, and GPT-3 has 96 transformer layers. That doesn’t fit very comfortably
with my model of how this all works. Both encoders and decoders seem like
stand-alone things that accept inputs (tokens/embeddings) and produce outputs
(embeddings/tokens). What would you do with multiple layers of them?

Now that we’ve covered how attention works, that’s become a bit clearer.
A multi-head attention block gets a set of input embeddings, one per token in
the input sequence, and produces the same number of context vectors.
There’s nothing stopping us from treating those context vectors as the input embeddings
for another attention block and doing the same thing again.

(That also explains why Raschka mentions that the number of dimensions in the
context vectors often matches the number in the input embeddings; it makes it easier
to use the same “shape” of multi-head attention calculations for each layer.)
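To make the stacking idea concrete, here’s a minimal pure-Python sketch. It’s a toy under loudly stated assumptions, not the book’s PyTorch code: a single causal attention head whose query, key and value projections are all the identity (so Q, K and V are just the input vectors), with scores scaled by the square root of the dimension and normalised with softmax. Because it turns n vectors of size d into n vectors of size d, its output can be fed straight back in as the next layer’s input.

```python
import math

def self_attention(x):
    """Toy causal self-attention layer with identity projections (Q = K = V = x).

    Takes n vectors of dimension d and returns n context vectors of dimension d.
    """
    d = len(x[0])
    out = []
    for i, query in enumerate(x):
        # Causal: token i only looks at tokens 0..i.
        scores = [sum(q * k for q, k in zip(query, x[j])) / math.sqrt(d)
                  for j in range(i + 1)]
        # Softmax, subtracting the max for numerical stability.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Context vector: weighted sum of the (identity-projected) value vectors.
        out.append([sum(w * x[j][c] for j, w in enumerate(weights)) for c in range(d)])
    return out

# Hypothetical input: 4 tokens, each with a made-up 3-dimensional embedding.
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.1, 0.0], [0.0, 0.5, 0.2], [0.3, 0.3, 0.3]]

hidden = embeddings
for _ in range(96):  # GPT-3 has 96 layers; here every "layer" is the same toy function
    hidden = self_attention(hidden)

print(len(hidden), len(hidden[0]))  # 4 3
```

Real layers have learned projections, several heads and extra machinery around the attention, but the shape-preserving property is what makes the stacking possible. (Under the causal mask the first token can only attend to itself, so in this toy its vector passes through every layer unchanged.)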

In my mind, this is similar to the way an image-processing network — say, a CNN —
works. In those, the first layer
might detect edges, the second might detect lines at certain orientations, the
next particular shapes, and then somewhere later on, the nth might recognise
dogs’ faces.

So the representation of the token “cat” that I described above would not be part of
the output of one attention head, and perhaps even the first layer of the attention
mechanism might not have anything that rich. But it might be in the output of
the third layer of multi-head attention, or the fourth, or something like that.

By the 96th of those layers in GPT-3, what’s represented in the context vectors
is going to be super-enriched, and have lots of information spread across the different
tokens. And realising this was a bit of an epiphany for me as well.

No more fixed-length bottleneck

If you cast your
mind back to part 5, a big problem
with encoder/decoder RNNs that did not have attention mechanisms was
the fixed-length bottleneck. You would run your input sequence into an encoder
RNN, which would try to represent its meaning in its hidden state — a vector of
a particular fixed length — ready to pass it
on to the decoder. Easy with a short input, but increasingly hard and eventually
impossible as it gets longer, because you’d be trying to pack more and more information
into the same “space”.

But with attention, this super-enriched and combined representation of the input
sequence that comes out of the last attention layer is proportional in length to the
number of tokens in the input! You’re still limited by available memory, of course
(and other things — see the next post), but the more tokens you have, the larger
this “hidden state” of the context vectors.

That’s pretty cool.
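The difference in how the “state” scales comes down to one line of arithmetic per approach. A minimal sketch, assuming a hypothetical size of 1,024 numbers for both the RNN hidden state and each context vector (the function names are made up for illustration):

```python
def rnn_state_size(n_tokens, hidden_size):
    # Encoder RNN: the whole input is squeezed into one fixed-size hidden state,
    # so n_tokens deliberately doesn't appear in the result.
    return hidden_size

def attention_state_size(n_tokens, d_model):
    # Last attention layer: one d_model-sized context vector per input token.
    return n_tokens * d_model

HIDDEN = 1024  # hypothetical size, used for both so the comparison is like-for-like

for n in (10, 1_000, 100_000):
    print(f"{n:>7} tokens: RNN {rnn_state_size(n, HIDDEN):>11,} numbers, "
          f"attention {attention_state_size(n, HIDDEN):>11,} numbers")
```

The RNN column never changes, however long the input gets; the attention column grows in step with the number of tokens.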

So, using multi-head attention plus layers allows us to build up complex representations
even when each individual attention head is dumb. But, going back to the core of
this post, why do these dumb attention heads use the specific calculations that they
do?

Why dumb attention heads work

Let’s use an example.

A heads-up/warning first: the attention heads learn their own representations and patterns
to match on during training by gradient descent, so whatever they learn will
probably be weird and alien and not relate in any way to grammar as we understand
it. But for this example, let’s pretend that isn’t the case, and that we’ve got an attention head
that has learned how to match articles (like “a”, “an”, and “the”) up with their
associated nouns.

How would that work? Let’s take “the fat cat sat on the mat”, and ignore everything
apart from the two “the”s, and the nouns “cat” and “mat”. We’ll say that our attention head
wants to produce a context vector for “cat” that combines it with the first “the”
(meaning that it will contain the concept that we’re talking about a specific cat
rather than just “a” cat), and similarly it wants to blend the second “the” into
“mat”.

Now, remember that our input sequence is a series of input embeddings, which are combinations of the token
embeddings (which are vectors in a space that point to some abstract “meaning” of
the tokens) and position embeddings (which represent their position in the
sequence).
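Here’s a minimal sketch of how those two combine, assuming (as in the GPT-style setup the book uses) that the token embedding and the position embedding are simply added element-wise. The 4-dimensional vectors are made-up numbers; in a real model both tables are learned.

```python
# Hypothetical toy tables with made-up 4-dimensional vectors; a real model learns both.
sentence = ["the", "fat", "cat", "sat", "on", "the", "mat"]

token_embeddings = {
    "the": [0.9, 0.0, 0.1, 0.0],
    "fat": [0.0, 0.2, 0.7, 0.1],
    "cat": [0.1, 0.9, 0.2, 0.3],
    "sat": [0.0, 0.1, 0.3, 0.8],
    "on":  [0.4, 0.0, 0.0, 0.5],
    "mat": [0.1, 0.8, 0.0, 0.4],
}

# One vector per position; index 0 holds position 1, index 6 holds position 7.
position_embeddings = [[0.01 * p, 0.0, 0.0, -0.01 * p] for p in range(1, 8)]

# Input embedding = token embedding + position embedding, element by element.
input_embeddings = [
    [t + p for t, p in zip(token_embeddings[token], position_embeddings[i])]
    for i, token in enumerate(sentence)
]

for token, vec in zip(sentence, input_embeddings):
    print(token, [round(v, 2) for v in vec])
```

The two “the”s share a token embedding, but their input embeddings differ because their positions do.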

Taking “mat” as our example, we project its input embedding, which means “the token ‘mat’
at position 7”[1], into query space. The breakthrough for me was realising that query
space is another embedding space, just like the original one for the input embeddings,
but one in which the vectors represent different things.

Let’s say that representations in this new embedding space are much simpler; they don’t have
as much detail as those in the original one. This space just represents “this is an article” or
“this is not an article”, plus some information about position. That is, the embedding
for an article at position 1 is close to the embedding for an article at position 2,
but is not very close to the one for an article at position 69,536.
And other things that are not articles would be somewhere even further away.

In this example, perhaps the projection that our attention head has
learned will map “‘mat’ at position 7” to an embedding pointing in the direction of
“some article — the or a — at position 6 or lower, probably quite close”.
In other words, the projection into query space turns an input embedding for a token into the kind of thing
this attention head is looking for when it’s handling that token. Likewise, “‘cat’ at position 3”
would be projected into an embedding vector meaning “some article at position 2 or
lower, probably quite close”.

Now, as well as projecting the input embeddings into the query space, we’re also projecting them into the key space. In that case, our imaginary
article-matching head would create a projection that would turn the first “the” into
something meaning “an article at position 1”, and the second into one meaning
“an article at position 6”.
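Here’s a hand-built toy version of those two projections in pure Python, to show that they really can be plain matrix multiplications. Everything in it is a hypothetical illustration, not something a real head learned: the input embeddings are invented 4-dimensional vectors holding just “is an article”, “is a noun” and an angle encoding position (so nearby positions point in similar directions, standing in for “probably quite close”), and W_q and W_k are 3×4 matrices written by hand, so the shared space is smaller than the input space. The query weights turn a noun at position p into “an article at position p - 1”; the key weights turn an article at position p into “an article at position p”.

```python
import math

THETA = 0.3  # hypothetical: angle per position step
A = 2.0      # hypothetical: how much "article-ness" matters relative to position

def input_embedding(is_article, is_noun, position):
    # Toy 4-d input embedding: two type flags plus an angle encoding the 1-indexed position.
    return [float(is_article), float(is_noun),
            math.cos(position * THETA), math.sin(position * THETA)]

# Query weights ("what I'm looking for"): nouns ask for an article, and the
# lower 2x2 block rotates the position angle back by one step (p -> p - 1).
W_q = [
    [0.0, A,   0.0,              0.0],
    [0.0, 0.0, math.cos(THETA),  math.sin(THETA)],
    [0.0, 0.0, -math.sin(THETA), math.cos(THETA)],
]

# Key weights ("what I am"): article or not, at my own position.
W_k = [
    [A,   0.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]

def project(W, x):
    # Matrix-vector product: maps a 4-d input embedding into the 3-d QK space.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

q_mat = project(W_q, input_embedding(0, 1, 7))   # "mat" at position 7
k_the6 = project(W_k, input_embedding(1, 0, 6))  # second "the"
k_the1 = project(W_k, input_embedding(1, 0, 1))  # first "the"
k_on5 = project(W_k, input_embedding(0, 0, 5))   # "on", not an article

print(round(cosine(q_mat, k_the6), 3),
      round(cosine(q_mat, k_the1), 3),
      round(cosine(q_mat, k_on5), 3))  # 1.0 0.814 0.427
```

In this toy, the query for “mat” and the key for the second “the” come out as the same vector, the first “the” is a weaker match, and a non-article like “on” is weaker still.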

So, the query weights have projected our input embeddings into this “low-resolution” embedding space
to point in a direction meaning “this is what I’m interested in”, and the key weights
have projected the input embeddings into the same embedding space in a direction meaning “this is what I am”.

That means that when we do our dot product, the query vector for “mat” will point in a
very similar direction to the key vector for the second “the”, and so the dot
product will be high. Remember, so long as vectors are roughly the same length,
the dot product is an indication of how similar they are.
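A small sketch of that “roughly the same length” caveat, with made-up 2-dimensional vectors: the raw dot product mixes direction and length, while cosine similarity divides the lengths out and compares direction only.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cosine_similarity(u, v):
    # Dot product with both lengths divided out: depends on direction only.
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

query = [1.0, 1.0]
keys = {
    "similar": [1.0, 0.9],     # nearly the same direction, similar length
    "different": [1.0, -1.0],  # perpendicular to the query
    "long": [10.0, -6.0],      # a much less similar direction, but a lot longer
}

for name, k in keys.items():
    print(f"{name:9} dot={dot(query, k):5.2f} cosine={cosine_similarity(query, k):5.2f}")
```

The long key gets the biggest raw dot product even though it points in a much less similar direction than the “similar” key, which is why the comparison only works cleanly when lengths are comparable.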

What’s important about this is that the shared embedding space that the query and
key vectors use can actually be pretty impoverished compared to the rich space
that the input embeddings used. In our case, all the head cares about is whether
tokens are nouns or articles or something else, and their position.

To make this concrete, here’s the set of attention scores that I imagined the
attention mechanism might come up with back in part 6 (modified to be causal,
so that tokens don’t pay any attention to tokens in their “future”):

Token  ω(“The”)  ω(“fat”)  ω(“cat”)  ω(“sat”)  ω(“on”)   ω(“the”)  ω(“mat”)
The    1         0         0         0         0         0         0
fat    0.2       1         0         0         0         0         0
cat    0.6       0.8       1         0         0         0         0
sat    0.1       0         0.85      1         0         0         0
on     0         0.1       0.4       0.6       1         0         0
the    0         0         0         0         0.1       1         0
mat    0         0         0.2       0.8       0.7       0.6       1

Each row holds, for the token in the first column, its attention scores for each
token in the sentence, including itself. It’s based on my own intuition about the
importance of words, and it’s the kind of thing you might imagine a clever attention
head coming up with. (Remember that ω is the variable we use to represent attention scores.)

But our closer-to-real-world example of an article-noun matching head is really dumb,
so it might come up with something more like this:

Token  ω(“The”)  ω(“fat”)  ω(“cat”)  ω(“sat”)  ω(“on”)   ω(“the”)  ω(“mat”)
The    0         0         0         0         0         0         0
fat    0         0         0         0         0         0         0
cat    0.8       0         1         0         0         0         0
sat    0         0         0         0         0         0         0
on     0         0         0         0         0         0         0
the    0         0         0         0         0         0         0
mat    0.1       0         0         0         0         0.8       1

All it has done is decide to pay attention to the “the”s when considering the nouns,
and it’s even paying a bit of attention to the first “the” when considering “mat”,
because it doesn’t know that it has to be the closest “the” that it matches with.[2]
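The same kind of toy can reproduce the shape of the “cat” and “mat” rows in that second table. This sketch skips the projections and writes the query and key vectors down directly; the angle encoding of position, the weight on article-ness and the vocabulary sets are all hypothetical. Keys say “article or not, at position p”, and queries from nouns say “article, at position p - 1”. Only the token itself and earlier tokens are scored, to keep it causal.

```python
import math

THETA = 0.3  # hypothetical: angle per position step
A = 2.0      # hypothetical: weight on "article-ness" relative to position

sentence = ["the", "fat", "cat", "sat", "on", "the", "mat"]
ARTICLES = {"a", "an", "the"}
NOUNS = {"cat", "mat"}

def angle(position):
    # Nearby positions get vectors pointing in similar directions.
    return [math.cos(position * THETA), math.sin(position * THETA)]

def key_vec(token, position):
    # "What I am": article or not, plus where I am.
    return [A if token in ARTICLES else 0.0] + angle(position)

def query_vec(token, position):
    # "What I'm looking for": a noun wants an article just before it.
    return [A if token in NOUNS else 0.0] + angle(position - 1)

def causal_scores(i):
    """Raw attention scores for the token at 0-based index i over tokens 0..i."""
    q = query_vec(sentence[i], i + 1)  # 1-indexed positions, as in the post
    return [sum(a * b for a, b in zip(q, key_vec(sentence[j], j + 1)))
            for j in range(i + 1)]

for i in (2, 6):  # "cat" and "mat"
    print(sentence[i], [round(s, 2) for s in causal_scores(i)])
```

For “mat”, the second “the” scores highest and the first “the” comes next, well above the non-articles: because nearby positions have similar directions, the head gives some credit to any earlier article without knowing which one is closest.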

Now, as I said earlier, the real attention heads, having been trained by gradient
descent over billions of tokens, will probably have learned something weird and
abstract and not related to the way we think of language, grammar and the parts
of speech.

But taken individually, they will be really dumb, because the equation is doing something really simple:
when considering a particular kind of thing, look for this other kind of thing.
Each token is projected into a shared embedding space by the query weights
(“what I’m looking for”) and into the same space by the key weights
(“what I am”), and the dot product does the comparison to find
matches.

Of course, that doesn’t mean we lose any information. This impoverished embedding space is only used
to do the matching to work out our attention scores. When we work out the context
vectors, we use a projection into value space, which can be as rich as we like.

It’s worth noting that although the example Raschka is using in the book has the
same dimensionality for the shared space for query and key vectors, and the space
for value vectors, there’s actually no need for that. I’ve seen specs for LLMs
where the QK space has fewer dimensions — which makes sense, at least for this
trivial example.
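A minimal sketch of that split, assuming the scaled dot-product recipe from the book (divide the scores by the square root of the key dimension, then apply softmax) plus a causal mask. The queries and keys here live in a 2-dimensional space and the values in a 5-dimensional one; all of the vectors are made up and assumed to be already projected.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def context_vectors(queries, keys, values):
    """Causal scaled dot-product attention for one head.

    Queries and keys share one (small) dimension; values can have a different one.
    """
    d_k = len(keys[0])
    out = []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d_k)
                  for j in range(i + 1)]
        weights = softmax(scores)
        # Mix the value vectors using the weights from the QK comparison.
        out.append([sum(w * values[j][c] for j, w in enumerate(weights))
                    for c in range(len(values[0]))])
    return out

# Hypothetical, already-projected vectors for three tokens.
queries = [[0.0, 1.0], [1.0, 0.0], [2.0, 0.5]]  # 2-d QK space
keys = [[2.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[1.0, 0.0, 0.0, 0.0, 0.0],             # 5-d value space
          [0.0, 1.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 1.0, 0.5, 0.25]]

ctx = context_vectors(queries, keys, values)
print(len(ctx), len(ctx[0]))  # 3 5
```

The context vectors come out with the value dimension (5), not the QK dimension (2): the small shared space only decides the weights, and the richer value vectors are what actually get mixed together.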

It’s also worth noting that this key/query space is impoverished in this example,
but in a real “alien” learned example, it could actually be quite complex and
rich — but much harder to understand than this example. Ultimately, the
nature of that embedding space will be learned in the same way as everything else,
and will match whatever thing the head in question has learned to do.

The elegance of dumb attention

So, that is (right now) my understanding of how scaled dot product attention works.
We’re just doing simple pattern matching, where each token’s input embedding is
projected by the query weights into a (learned) embedding space that is able to represent what it is
“looking for” in some sense. It’s also projected by the key weights into the same
space, but this time in a way that makes it point to what it “is” in the same sense.
Then the dot product matches those up so that we can associate input embeddings with
each other to work out our attention scores.

That all makes sense in my head, and I hope it does in at least a few other people’s 🙂

I’ll wrap this one up here; next time I’ll post about what I currently think the
material we’ve covered in the book so far means for context lengths. We’ve seen
the upside of that hidden state that grows as the input sequence does; what are
the downsides?


  1. I’ll one-index for this. 

  2. I can’t think of any way a single head could, TBH. It’s considering all other tokens in parallel,
    so when it’s looking at the first “the” it doesn’t know that there’s another closer one. 

Source: www.gilesthomas.com

