Understanding Masked Multi-Head Attention in Simple Terms

Aditya Mangal
Aug 27, 2024

Masked Multi-Head Attention is a crucial mechanism used in the decoder part of the Transformer model, especially in tasks like language generation, where the model predicts the next word in a sequence.

1. What is Masked Multi-Head Attention?

Masked multi-head attention is a variation of multi-head attention that ensures the model doesn’t “cheat” by looking ahead at future words when generating a sequence. This is important in tasks like text generation, where you want the model to predict the next word based only on the words it has seen so far.

2. Why is Masking Necessary?

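Imagine you’re writing a sentence word by word. You want to predict each word one at a time, without peeking at the words that come later in the sentence. During training, though, the model is usually given the whole target sentence at once so that every position can be trained in parallel, which means the future words are physically present in the input. If the model could attend to them, it would be like knowing the answer before you finish reading the question: it would learn to copy the next word instead of predicting it, and that shortcut would not exist at generation time, when the future words haven’t been written yet.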
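Here’s a small sketch of what “no peeking” means for a single training sentence, assuming the usual next-word setup in which the target at each position is simply the following word. The helper `context_windows` is hypothetical, written only for this illustration: it pairs every target word with the prefix the model is allowed to see when predicting it.

```python
def context_windows(sentence):
    """Pair each target word with the words the model may use to predict it (illustrative helper)."""
    # Next-word setup: inputs are all words but the last, targets are all words but the first.
    inputs, targets = sentence[:-1], sentence[1:]
    # Position i predicts targets[i] and may only look at inputs[0..i], never anything later.
    return [(inputs[: i + 1], target) for i, target in enumerate(targets)]


for visible, target in context_windows(["The", "cat", "sat", "on", "the", "mat"]):
    print(f"{' '.join(visible):<18} -> {target}")
```

The second line pairs “The cat” with “sat”: the model has to predict “sat” without ever seeing “sat,” “on,” “the,” or “mat.”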

3. How Does Masking Work?

  • Attention Mechanism Recap: Normally, in self-attention, each word in a sequence looks at every word in the sequence, including itself, to decide which ones are important. For example, in an unmasked layer processing “The cat sat on the mat,” the word “sat” can draw on “The,” “cat,” “sat,” “on,” “the,” and “mat.”
  • Masking: In masked multi-head attention, a mask blocks out future words, so each position can attend only to itself and the positions before it. This means when predicting the word “sat,” the model can only consider “The” and “cat,” the words already generated. It can’t look at “on,” “the,” or “mat” because they come later in the sequence.
  • Technically, the mask is applied to the attention scores before the softmax step. It assigns a very large negative value (such as negative infinity) to every score where a position would look at a later position, so when the softmax is applied, those positions get an attention weight of zero.
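The masking step can be sketched in a few lines of plain Python. This is a minimal illustration under simplifying assumptions: the raw scores are made-up numbers rather than real query-key dot products, and the function names (`causal_mask`, `masked_attention_weights`) are hypothetical. Real implementations do the same thing with tensor operations over whole batches.

```python
import math


def causal_mask(n):
    """n x n additive mask: 0.0 where position i may look at position j (j <= i), -inf otherwise."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]


def softmax(row):
    """Numerically stable softmax; math.exp(-inf) is exactly 0.0."""
    m = max(row)  # the diagonal is never masked, so m is always finite
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_attention_weights(scores):
    """Add the causal mask to raw attention scores, then apply softmax to each row."""
    mask = causal_mask(len(scores))
    return [softmax([s + m for s, m in zip(row, mask_row)])
            for row, mask_row in zip(scores, mask)]


# "The cat sat" with every raw score set to 1.0, so only the mask shapes the weights.
weights = masked_attention_weights([[1.0, 1.0, 1.0] for _ in range(3)])
for word, row in zip(["The", "cat", "sat"], weights):
    print(word, [round(w, 3) for w in row])
# The [1.0, 0.0, 0.0]
# cat [0.5, 0.5, 0.0]
# sat [0.333, 0.333, 0.333]
```

Each row still sums to 1; the mask simply moves all of the weight onto the current word and the words before it.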

4. How Does it Work in Practice?

Let’s break it down step by step:

  • Step 1: Input Sequence: Suppose the input sequence is “The cat sat on the mat.”
  • Step 2: Self-Attention with Masking: When the model is predicting the word “sat,” it applies masking so that it can only consider “The” and “cat,” the words that have come before it.
  • Step 3: Multi-Head Attention: Several heads process this masked input in parallel. Each head has its own learned projections, so different heads can learn to focus on different aspects, such as word meaning, position, or context, just like in regular multi-head attention. The same mask applies inside every head.
  • Step 4: Combining Results: The outputs of all the heads are concatenated and passed through a final linear projection. Later layers use this combined representation to predict the next word, still without any information from future words.
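To tie the four steps together, here is a minimal, self-contained sketch of a masked multi-head attention layer in plain Python. Assumptions: the weights are random rather than learned, the embeddings are made-up numbers, and the helper names (`masked_head`, `init_params`, `masked_multi_head_attention`) are hypothetical. Parts a full decoder adds, such as positional encodings, residual connections, layer normalization, and batching, are left out.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax; math.exp(-inf) is exactly 0.0."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_head(x, w_q, w_k, w_v):
    """One attention head with a causal mask: position i only sees positions 0..i."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    n, d_head = len(x), len(w_q[0])
    out = []
    for i in range(n):
        # Scaled dot-product scores; future positions (j > i) are set to -inf.
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_head) if j <= i else -math.inf
            for j in range(n)
        ]
        weights = softmax(scores)
        # Weighted sum of value vectors; masked positions contribute weight 0.0.
        out.append([sum(weights[j] * v[j][c] for j in range(n)) for c in range(d_head)])
    return out


def init_params(d_model, num_heads, seed=0):
    """Hypothetical random weights; a trained model would learn these."""
    if d_model % num_heads:
        raise ValueError("d_model must be divisible by num_heads")
    rng = random.Random(seed)
    d_head = d_model // num_heads

    def rand(rows, cols):
        return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

    # Each head gets its own query, key, and value projections.
    heads = [(rand(d_model, d_head), rand(d_model, d_head), rand(d_model, d_head))
             for _ in range(num_heads)]
    return heads, rand(d_model, d_model)


def masked_multi_head_attention(x, heads, w_o):
    """Step 3: run every head on the same masked input. Step 4: concatenate and project."""
    head_outputs = [masked_head(x, w_q, w_k, w_v) for w_q, w_k, w_v in heads]
    concat = [[val for h in head_outputs for val in h[i]] for i in range(len(x))]
    return matmul(concat, w_o)


# Toy 4-dimensional embeddings standing in for "The cat sat" (made-up numbers).
x = [[1.0, 0.0, 0.5, 0.0],
     [0.0, 1.0, 0.0, 0.5],
     [0.5, 0.5, 1.0, 0.0]]
heads, w_o = init_params(d_model=4, num_heads=2)
y = masked_multi_head_attention(x, heads, w_o)
print(len(y), len(y[0]))  # 3 4
```

Because the causal mask is applied inside every head, swapping out the last word changes only the last output row; the rows for “The” and “cat” stay the same.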

5. Importance in Language Generation

In language generation tasks like machine translation or text generation, masked multi-head attention is essential. It allows the model to generate text in a left-to-right manner, ensuring that each word is predicted based only on the preceding context.
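One way to see why the mask keeps left-to-right generation consistent with training: with a causal mask, the output at each position is unchanged when later words are appended, so the representation of “The cat sat” is the same whether or not “on” follows. The sketch below checks this for a single attention head. To keep it short, it assumes the queries, keys, and values are the raw toy vectors themselves (no learned projections), and `attention` and `prefix_unchanged` are hypothetical names.

```python
import math


def attention(x, causal=True):
    """Single-head self-attention with Q = K = V = x, optionally with a causal mask."""
    n, d = len(x), len(x[0])
    out = []
    for i in range(n):
        # With causal=True, scores for later positions (j > i) become -inf.
        scores = [
            sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d)
            if (j <= i or not causal) else -math.inf
            for j in range(n)
        ]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append([sum(exps[j] / total * x[j][c] for j in range(n)) for c in range(d)])
    return out


def prefix_unchanged(causal):
    """Do the outputs for the first words stay the same after a new word is appended?"""
    prefix = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy vectors standing in for "The cat sat"
    extended = prefix + [[0.5, -1.0]]              # ... with "on" appended
    before = attention(prefix, causal)
    after = attention(extended, causal)[: len(prefix)]
    return all(math.isclose(a, b) for r1, r2 in zip(before, after) for a, b in zip(r1, r2))


print(prefix_unchanged(causal=True))   # True
print(prefix_unchanged(causal=False))  # False
```

Without the mask, appending “on” changes what the earlier positions computed, which is exactly the look-ahead that masked attention rules out.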

Summary

  • Masked Multi-Head Attention prevents the model from looking ahead at future words when generating text, ensuring that each word is predicted based on the preceding words only.
  • Masking is applied to block future positions, so the model only considers previous words.
  • Multi-head attention still works as usual, with multiple heads analyzing the masked information in parallel, but only using the allowed context.

This technique is key to making the Transformer model effective at generating coherent and contextually appropriate text.