Quick Review of Self-attention
Edited from ChatGPT
Self-attention, also known as intra-attention, is a mechanism used in neural networks to relate different positions of a single sequence in order to compute a representation of that sequence. It has been particularly influential in the development of transformer models and is a core component of architectures such as BERT, GPT, and the transformer-based parts of SAM (Segment Anything Model).
Why It Is Called Self-Attention
The term “self-attention” is used because the mechanism allows the model to pay attention to different parts of the input sequence itself. Here’s a detailed explanation of how self-attention works and why it’s named this way:
1. **Self-Referencing**:
   - In self-attention, each position in the input sequence is compared with every position in the same sequence. This means each element of the sequence can reference every other element, including itself.
   - This self-referencing capability is why it is called “self-attention.”
2. **Attention Mechanism**:
   - Attention mechanisms allow models to focus on different parts of the input when making predictions. Self-attention specifically computes attention scores within a single sequence, enabling the model to weigh the importance of each element relative to the others in that sequence.
How Self-Attention Works
1. **Input Representation**:
   - Consider an input sequence of n elements (e.g., words in a sentence, tokens in a document). Each element is represented as a vector (an embedding).
2. **Query, Key, and Value Vectors**:
   - For each element in the sequence, self-attention generates three vectors: a Query (Q), a Key (K), and a Value (V).
   - These vectors are created by multiplying the input vector by three learned weight matrices, one each for queries, keys, and values (WQ, WK, and WV in the example below). The same three matrices are shared by every element of the sequence.
3. **Attention Scores**:
   - The attention score for a pair of elements is the dot product of the Query vector of one element with the Key vector of the other, divided by sqrt(d_k), where d_k is the length of the Key vectors. This scaling keeps the scores from growing with the vector dimension, which would otherwise push the softmax toward extreme values.
   - The resulting score indicates how strongly the first element should attend to the second. Every element is also scored against itself.
4. **Softmax Normalization**:
   - For each element, its scores against all elements of the sequence are passed through the softmax function, which makes them positive and ensures they sum to one. The resulting weights determine how much each element contributes to that element’s new representation.
5. **Weighted Sum**:
   - Each element’s final representation is computed as a weighted sum of the Value vectors of all elements, with the weights being the normalized attention scores.
6. **Output**:
   - The output is a new sequence of vectors, one per input element, where each vector aggregates information from the entire input sequence according to its attention weights.
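For readers who prefer notation, the steps above can be written compactly for element i of an n-element sequence. The symbols are chosen to mirror the example calculation in this article (W_Q corresponds to WQ, and so on), including the division by sqrt(d_k) used there, where d_k is the length of the Key vectors:

$$
\begin{aligned}
q_i &= W_Q x_i, \qquad k_i = W_K x_i, \qquad v_i = W_V x_i \\
s_{ij} &= \frac{q_i \cdot k_j}{\sqrt{d_k}} \\
\alpha_{ij} &= \frac{\exp(s_{ij})}{\sum_{l=1}^{n} \exp(s_{il})} \\
o_i &= \sum_{j=1}^{n} \alpha_{ij}\, v_j
\end{aligned}
$$

Stacking the inputs as the rows of a matrix X, so that Q = X W_Q^T, K = X W_K^T and V = X W_V^T, turns the last three lines into the familiar compact form softmax(Q K^T / sqrt(d_k)) V, with the softmax applied to each row separately.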
Example Calculation
For simplicity, let’s consider an input sequence of three elements, X = [x1, x2, x3], where each xi is a vector and WQ, WK, WV are the learned weight matrices.
Step-by-Step Process:
1. **Generate Q, K, V for each element**:
```python
# WQ, WK, WV: learned weight matrices; @ is matrix-vector multiplication
Q1, K1, V1 = WQ @ x1, WK @ x1, WV @ x1
Q2, K2, V2 = WQ @ x2, WK @ x2, WV @ x2
Q3, K3, V3 = WQ @ x3, WK @ x3, WV @ x3
```
2. **Compute Attention Scores**:
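```python
# d_k: length of each Key vector (sqrt as in math.sqrt)
score_11 = Q1.dot(K1) / sqrt(d_k)
score_12 = Q1.dot(K2) / sqrt(d_k)
score_13 = Q1.dot(K3) / sqrt(d_k)
score_21 = Q2.dot(K1) / sqrt(d_k)
# and so on for all nine pairs
```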
3. **Normalize Scores with Softmax**:
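```python
# softmax over each element's row of scores, so each weights_i sums to one
weights_1 = softmax([score_11, score_12, score_13])
weights_2 = softmax([score_21, score_22, score_23])
weights_3 = softmax([score_31, score_32, score_33])
```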
4. **Compute Weighted Sum**:
```python
output1 = weights_1[0] * V1 + weights_1[1] * V2 + weights_1[2] * V3
output2 = weights_2[0] * V1 + weights_2[1] * V2 + weights_2[2] * V3
output3 = weights_3[0] * V1 + weights_3[1] * V2 + weights_3[2] * V3
```
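Putting the four steps together, here is a minimal runnable sketch of the three-element example using NumPy. It is an illustration under stated assumptions, not a reference implementation: the sizes d_model = 4 and d_k = 3 are arbitrary, random matrices stand in for the learned WQ, WK, WV, and the `softmax` helper is defined here rather than taken from a library. The last few lines repeat the computation in matrix form to confirm that both routes agree.

```python
import numpy as np

def softmax(scores):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d_model, d_k = 4, 3  # illustrative sizes (assumed, not from the article)

# Three input vectors x1, x2, x3 (random stand-ins for embeddings)
x1, x2, x3 = rng.normal(size=(3, d_model))
xs = (x1, x2, x3)

# Random stand-ins for the learned matrices WQ, WK, WV, each of shape (d_k, d_model)
WQ, WK, WV = rng.normal(size=(3, d_k, d_model))

# Step 1: project every element; Q[0] plays the role of Q1, and so on
Q = [WQ @ x for x in xs]
K = [WK @ x for x in xs]
V = [WV @ x for x in xs]

# Step 2: scaled dot-product scores, scores[i, j] = Q[i] . K[j] / sqrt(d_k)
scores = np.array([[Q[i].dot(K[j]) / np.sqrt(d_k) for j in range(3)]
                   for i in range(3)])

# Step 3: softmax over each row gives weights_1, weights_2, weights_3
weights = softmax(scores)

# Step 4: each output is a weighted sum of all three value vectors
outputs = [sum(weights[i, j] * V[j] for j in range(3)) for i in range(3)]

# The same computation in matrix form, with x1, x2, x3 stacked as rows of X
X = np.stack(xs)
Qm, Km, Vm = X @ WQ.T, X @ WK.T, X @ WV.T
outputs_matrix = softmax(Qm @ Km.T / np.sqrt(d_k)) @ Vm

print(np.allclose(np.stack(outputs), outputs_matrix))  # True
```

Each row of `weights` sums to one, and `outputs_matrix` has shape (3, d_k): one context-aware vector per input element. The loop version follows the article’s step-by-step notation; the matrix version is how the computation is usually written in practice.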
Summary
- **Self-Referencing**: Self-attention involves comparing each element of the sequence with every other element in the same sequence, hence “self”.
- **Attention Mechanism**: It uses attention scores to weigh the importance of each element relative to the others.
- **Output**: The result is a new sequence where each element is a context-aware representation, considering the entire sequence.
Self-attention allows models to capture dependencies regardless of their distance in the sequence, making it highly effective for tasks involving sequential data such as natural language processing and, in the case of SAM, image segmentation based on combined embeddings of images and prompts.
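To make the point about distance concrete: the score between elements i and j depends only on the vectors at those two positions, not on how far apart they are. The small NumPy sketch below places an identical vector at a nearby position (1) and a distant position (9) and checks that element 0 gives both copies the same weight. Sizes and random weights are illustrative assumptions, and `attention_weights` is a hypothetical helper; the toy inputs deliberately carry no position information of their own.

```python
import numpy as np

def attention_weights(X, WQ, WK):
    """Row-wise softmax of scaled dot-product scores for X of shape (n, d_model)."""
    d_k = WQ.shape[0]
    scores = (X @ WQ.T) @ (X @ WK.T).T / np.sqrt(d_k)
    scores = scores - scores.max(axis=1, keepdims=True)  # for numerical stability
    w = np.exp(scores)
    return w / w.sum(axis=1, keepdims=True)

rng = np.random.default_rng(1)
n, d_model, d_k = 10, 8, 4  # illustrative sizes (assumed)
X = rng.normal(size=(n, d_model))
X[9] = X[1]  # identical vectors at a nearby position (1) and a distant one (9)
WQ, WK = rng.normal(size=(2, d_k, d_model))

w = attention_weights(X, WQ, WK)
# Element 0 gives the same weight (up to rounding) to both copies
print(w[0, 1], w[0, 9])
```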