Quick Review: Cross-attention

Paul Xiong
Jun 7, 2024


Edited from ChatGPT.

Cross-attention is a mechanism in neural networks that allows one sequence to attend to another sequence, effectively integrating information from both. This is particularly useful in tasks where there are two distinct types of inputs that need to influence each other, such as image and prompt embeddings in the Segment Anything Model (SAM).

Cross-Attention Explained

Cross-attention involves updating the embeddings of one sequence (e.g., image embeddings) based on the information from another sequence (e.g., prompt embeddings), and vice versa. This allows the model to effectively integrate and utilize information from both sources.

How Cross-Attention Works

1. Input Sequences:

  • Image Embeddings: Representations of the image, typically produced by an image encoder.
  • Prompt Embeddings: Representations of the prompts (e.g., points, boxes, masks, text), typically produced by a prompt encoder.

2. Query, Key, and Value Vectors:

  • For each element in the image embeddings and prompt embeddings, cross-attention generates three vectors: Query (Q), Key (K), and Value (V).

3. Attention Scores:

  • The attention score for an element in the image embeddings with respect to an element in the prompt embeddings is computed by taking the dot product of the Query vector of the image embedding with the Key vector of the prompt embedding.
  • Similarly, attention scores are computed in the opposite direction, from prompt embeddings to image embeddings. This two-way exchange is what distinguishes cross-attention from self-attention, where queries, keys, and values all come from the same sequence.

4. Softmax Normalization:

  • The attention scores are normalized using the softmax function to produce a set of weights that sum to one.

5. Weighted Sum:

  • Each element’s final representation is computed as a weighted sum of the Value vectors from the other sequence, with the weights being the normalized attention scores.
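
Putting steps 2–5 together, each direction is just scaled dot-product attention, with the query taken from one sequence and the keys and values taken from the other. In pseudo-code (using the same names as the example below):

attended = softmax(Q @ K.T / sqrt(d_k)) @ V  # Q from one sequence; K, V from the other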

Example Calculation

Let’s consider an example with two sequences: image embeddings I = [i_1, i_2, i_3] and prompt embeddings P = [p_1, p_2, p_3].

Step-by-Step Process:

1. Generate Q, K, V for Image and Prompt Embeddings:

Q_image, K_image, V_image = I @ WQ_image, I @ WK_image, I @ WV_image      # Shape: [num_image_tokens, embedding_dim]
Q_prompt, K_prompt, V_prompt = P @ WQ_prompt, P @ WK_prompt, P @ WV_prompt  # Shape: [num_prompt_tokens, embedding_dim]

2. Compute Attention Scores (Image to Prompt):

score_ip = Q_image @ K_prompt.T / sqrt(d_k)  # Shape: [num_image_tokens, num_prompt_tokens]

3. Normalize Scores with Softmax:

weights_ip = softmax(score_ip, dim=-1)  # Shape: [num_image_tokens, num_prompt_tokens]

4. Compute Weighted Sum (Update Image Embeddings):

I_updated = weights_ip @ V_prompt  # Shape: [num_image_tokens, embedding_dim]

5. Compute Attention Scores (Prompt to Image):

score_pi = Q_prompt @ K_image.T / sqrt(d_k)  # Shape: [num_prompt_tokens, num_image_tokens]

6. Normalize Scores with Softmax:

weights_pi = softmax(score_pi, dim=-1)  # Shape: [num_prompt_tokens, num_image_tokens]

7. Compute Weighted Sum (Update Prompt Embeddings):

P_updated = weights_pi @ V_image  # Shape: [num_prompt_tokens, embedding_dim]
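
The fragments above are pseudo-code (WQ_image, softmax, and sqrt are not defined anywhere). Below is a self-contained PyTorch sketch of the same seven steps; the projection matrices are randomly initialized and the sizes (3 image tokens, 3 prompt tokens, d_k = 4) are purely illustrative, not values taken from SAM:

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_image_tokens, num_prompt_tokens, d_k = 3, 3, 4  # toy sizes, not SAM's real dimensions

I = torch.rand(num_image_tokens, d_k)   # image embeddings  [i1, i2, i3]
P = torch.rand(num_prompt_tokens, d_k)  # prompt embeddings [p1, p2, p3]

# Step 1: each sequence has its own (here randomly initialized) projection matrices
WQ_image, WK_image, WV_image = torch.rand(3, d_k, d_k)
WQ_prompt, WK_prompt, WV_prompt = torch.rand(3, d_k, d_k)
Q_image, K_image, V_image = I @ WQ_image, I @ WK_image, I @ WV_image
Q_prompt, K_prompt, V_prompt = P @ WQ_prompt, P @ WK_prompt, P @ WV_prompt

# Steps 2-4: image tokens attend to prompt tokens
score_ip = Q_image @ K_prompt.T / math.sqrt(d_k)   # [num_image_tokens, num_prompt_tokens]
weights_ip = F.softmax(score_ip, dim=-1)           # each row sums to one
I_updated = weights_ip @ V_prompt                  # [num_image_tokens, d_k]

# Steps 5-7: prompt tokens attend to image tokens
score_pi = Q_prompt @ K_image.T / math.sqrt(d_k)   # [num_prompt_tokens, num_image_tokens]
weights_pi = F.softmax(score_pi, dim=-1)
P_updated = weights_pi @ V_image                   # [num_prompt_tokens, d_k]

print(I_updated.shape, P_updated.shape)  # torch.Size([3, 4]) torch.Size([3, 4])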

Implementation in a Transformer Decoder Block

Here’s how you might implement cross-attention within a transformer decoder block:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, d_model, nhead):
        super(CrossAttention, self).__init__()
        # batch_first=True so all inputs are [batch, seq_len, d_model]
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn_image_to_prompt = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn_prompt_to_image = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.linear1 = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.1)
        self.linear2 = nn.Linear(d_model, d_model)

    def forward(self, image_embedding, prompt_embedding):
        # Self-attention on the prompt (query, key, and value all come from the prompt)
        prompt_embedding2, _ = self.self_attn(prompt_embedding, prompt_embedding, prompt_embedding)
        prompt_embedding = prompt_embedding + self.dropout(prompt_embedding2)
        prompt_embedding = self.linear1(prompt_embedding)

        # Cross-attention, prompt to image: image queries attend to prompt keys/values
        image_embedding2, _ = self.cross_attn_prompt_to_image(image_embedding, prompt_embedding, prompt_embedding)
        image_embedding = image_embedding + self.dropout(image_embedding2)
        image_embedding = self.linear2(image_embedding)

        # Cross-attention, image to prompt: prompt queries attend to image keys/values
        prompt_embedding2, _ = self.cross_attn_image_to_prompt(prompt_embedding, image_embedding, image_embedding)
        prompt_embedding = prompt_embedding + self.dropout(prompt_embedding2)
        prompt_embedding = self.linear2(prompt_embedding)

        return image_embedding, prompt_embedding

# Example usage
d_model = 512
nhead = 8
image_embedding = torch.rand(10, 32, d_model) # Batch size: 10, Sequence length: 32, Embedding dim: 512
prompt_embedding = torch.rand(10, 5, d_model) # Batch size: 10, Sequence length: 5, Embedding dim: 512

cross_attention = CrossAttention(d_model, nhead)
updated_image_embedding, updated_prompt_embedding = cross_attention(image_embedding, prompt_embedding)
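
Both outputs keep the shapes of their inputs, so the block can be stacked or followed by further decoder layers. A quick sanity check:

print(updated_image_embedding.shape)   # torch.Size([10, 32, 512])
print(updated_prompt_embedding.shape)  # torch.Size([10, 5, 512])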

Summary

  • Cross-Attention Mechanism: Allows the model to update image embeddings based on prompt embeddings and vice versa, facilitating integration of information from both sources.
  • Bidirectional Attention: Cross-attention is computed in both directions to ensure comprehensive information flow between the image and prompt embeddings.
  • Effective Integration: This mechanism enhances the model’s ability to generate accurate segmentation masks by considering the relationships and dependencies between the image and prompt embeddings.

Cross-attention is essential in tasks like SAM where combining contextual information from different modalities (images and prompts) is crucial for accurate and effective performance.
