loke.dev

The Draft Model Gamble: What Nobody Tells You About Speculative Decoding in High-Throughput LLMs

Why adding a 'dumb' model to your inference pipeline is the secret to 3x throughput—and the mathematical threshold where the gamble starts to slow you down.


I was staring at a Grafana dashboard watching the tokens-per-second count on a Llama-3-70B instance crawl along like a tired snail. We were paying for high-end H100s, yet the user experience felt like 1996 dial-up because autoregressive decode was pinned against the memory-bandwidth ceiling. Then I flipped the switch on a 135M-parameter "draft model" for speculative decoding, and the throughput tripled instantly. Then we changed the prompt type, and everything ground to a halt.

This is the reality of speculative decoding. It’s often sold as a "free lunch" for LLM inference, a way to cheat the laws of physics and get 70B-level quality at 7B-level speeds. But it’s actually a high-stakes mathematical gamble. If you don't understand the threshold where your draft model becomes a liability, you’re just adding latency and wasting VRAM.

The Core Mechanism: One Model to Guess, One to Correct

At its heart, speculative decoding is a parlor trick that works because GPUs are weird.

Standard LLM inference is "autoregressive." To generate 10 tokens, you have to run the model 10 times, one after another. This is incredibly inefficient because most of a GPU's power is spent moving model weights from memory to the processor, not actually doing the math. We call this being memory-bandwidth bound.

Speculative decoding changes the flow:
1. A tiny "Draft Model" (e.g., TinyLlama-135M) quickly guesses the next $K$ tokens (say, 5 tokens).
2. The big "Target Model" (e.g., Llama-3-70B) looks at all 5 tokens at once in a single forward pass.
3. Because a transformer's forward pass scores all input positions in parallel, and a single decode step is memory-bandwidth bound rather than compute bound, the Target Model can verify those 5 tokens in roughly the same wall-clock time it takes to generate 1 token.
4. If the Target Model agrees with the Draft Model’s guesses, you just "won" 5 tokens for the price of one big model call.

If the draft model is wrong on token #3, you keep tokens #1 and #2, throw away the rest, and restart.
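The whole draft/verify/rewind flow can be sketched end-to-end. Here is a toy, framework-free simulation (all names hypothetical) where both "models" are just functions over token lists, using greedy exact-match verification rather than the full rejection-sampling scheme:

```python
def speculative_loop(target_next, draft_next, prompt, max_new_tokens, k=5):
    """Toy greedy speculative decoding: draft K tokens, keep the longest
    prefix the target agrees with, then append the target's correction."""
    tokens = list(prompt)
    target_calls = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1. Draft model guesses K tokens autoregressively (cheap).
        draft = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2. Target verifies all K guesses; on a GPU this is ONE forward
        #    pass, simulated here position-by-position.
        target_calls += 1
        for guess in draft:
            expected = target_next(tokens)
            if guess == expected:
                tokens.append(guess)      # accepted: a "won" token
            else:
                tokens.append(expected)   # rejected: take target's token, stop
                break
    return tokens[len(prompt):], target_calls

# Toy models: the target counts up by 1; the draft agrees except on
# every 4th position, where it guesses wrong.
target_next = lambda toks: toks[-1] + 1
draft_next = lambda toks: toks[-1] + 1 if len(toks) % 4 else toks[-1] + 2

out, calls = speculative_loop(target_next, draft_next, [0], max_new_tokens=10, k=5)
# 12 tokens delivered for only 3 target-model calls.
```

Even with the draft missing every fourth guess, the target model runs far fewer times than the number of tokens produced, which is the entire trick.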

The Math You Can’t Ignore

Most blog posts skip the math, but the math is why your implementation might be failing. The speedup ($S$) of speculative decoding is roughly defined by this relationship:

$$S = \frac{1 - \alpha^{K+1}}{(1 - \alpha)(1 + rK)}$$

Where:
* $\alpha$ is the acceptance rate (the probability the target model agrees with the draft).
* $K$ is the lookahead (how many tokens the draft model guesses).
* $r$ is the ratio of time it takes to run one draft step versus one target step.

Here is the kicker: If your draft model is too big ($r$ is high) or too stupid ($\alpha$ is low), your speedup $S$ becomes less than 1. That means you are moving slower than if you had just used the big model alone.

I’ve seen teams spend weeks trying to get a 7B model to draft for a 70B model, only to realize the 7B model was so slow that even with a 90% acceptance rate, they were losing time. The "sweet spot" for $r$ is usually between $1:50$ and $1:10$ (i.e., $r$ between 0.02 and 0.1).
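You can plug numbers into the speedup formula directly. A minimal calculator (the helper name is mine, not from any library):

```python
def spec_speedup(alpha: float, k: int, r: float) -> float:
    """Expected speedup S = (1 - alpha^(K+1)) / ((1 - alpha) * (1 + r*K)).

    alpha: acceptance rate, k: lookahead length,
    r: draft step time / target step time.
    """
    return (1 - alpha ** (k + 1)) / ((1 - alpha) * (1 + r * k))

# A fast, accurate draft is a big win:
good = spec_speedup(alpha=0.8, k=5, r=0.05)   # ~2.95x
# A slow, inaccurate draft is a net loss:
bad = spec_speedup(alpha=0.3, k=5, r=0.4)     # < 1x: slower than vanilla
```

Note that with alpha = 0 and r = 0 the formula collapses to exactly 1.0, which is the sanity check: a free but useless draft model neither helps nor hurts.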

Hands-on: Implementing Speculative Decoding

You don't need to write custom CUDA kernels to test this. Hugging Face has made this remarkably accessible with their assisted generation API: you pass the draft model as the assistant_model argument to generate().

Here is a clean implementation of how you’d actually set this up in a production-like script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time

# The "Target" model - the one we want high quality from
target_model_id = "meta-llama/Meta-Llama-3-8B"
# The "Draft" model - the tiny one that guesses
# (it must share the target's tokenizer; see the vocabulary section below)
draft_model_id = "ibm-fms/llama-135m-accelerator"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(target_model_id)
target_model = AutoModelForCausalLM.from_pretrained(
    target_model_id, torch_dtype=torch.float16
).to(device)

draft_model = AutoModelForCausalLM.from_pretrained(
    draft_model_id, torch_dtype=torch.float16
).to(device)

prompt = "Explain the concept of quantum entanglement in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Standard Generation
start = time.time()
outputs_std = target_model.generate(**inputs, max_new_tokens=50)
print(f"Standard Time: {time.time() - start:.2f}s")

# Speculative Decoding (Assisted Generation)
start = time.time()
outputs_spec = target_model.generate(
    **inputs, 
    assistant_model=draft_model, # This triggers speculative decoding
    max_new_tokens=50
)
print(f"Speculative Time: {time.time() - start:.2f}s")

When you run this, you’ll notice something immediately: the first few tokens might take a beat longer, but then the text starts "chunking" out. That’s the draft model winning its bets.

The "Draft Model Cliff": When Accuracy Matters More Than Speed

Why isn't everyone doing this for every task? Because the Acceptance Rate ($\alpha$) is highly volatile.

In code generation, the draft model often has a high acceptance rate because syntax is predictable. If the draft model starts a for loop, there’s a high probability the target model agrees with the i in range that follows.

But in creative writing or complex reasoning, the draft model falls off a cliff. If the target model wants to use the word "melancholy" but the draft model guesses "sad," the draft is rejected. You just wasted the compute time of the draft model *and* the verification pass of the target model for zero gain.

The Gotcha: If your acceptance rate drops below ~50-60%, the overhead of running the draft model and the extra logic of the verification step usually makes the total latency worse than vanilla inference.
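That threshold isn't folklore; you can derive it from the same speedup formula by solving for the $\alpha$ where $S$ crosses 1 at a given $r$ and $K$. A quick bisection sketch (helper name is mine):

```python
def breakeven_alpha(k: int, r: float, tol: float = 1e-6) -> float:
    """Smallest acceptance rate alpha at which speculative decoding
    breaks even (S = 1), for lookahead k and draft/target time ratio r.
    Valid because S is monotonically increasing in alpha."""
    def speedup(a: float) -> float:
        return (1 - a ** (k + 1)) / ((1 - a) * (1 + r * k))

    lo, hi = 0.0, 0.999999  # S < 1 at lo, S > 1 at hi
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if speedup(mid) < 1.0:
            lo = mid
        else:
            hi = mid
    return hi

# e.g. breakeven_alpha(5, 0.1) ≈ 0.33, but breakeven_alpha(5, 0.4) ≈ 0.71
```

With a genuinely cheap draft ($r = 0.1$) you break even at a surprisingly low acceptance rate; the ~50-60% danger zone above corresponds to slower drafts and the real-world overheads the pure formula ignores.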

Practical Logic: The speculative_verify Function

If you're building a custom inference engine, you can't just use model.generate(). You need to handle the distribution matching. The target model shouldn't just check if the draft's token is *the* top token; it should check if the draft's token is *statistically likely* under the target's distribution.

Here is a simplified logic flow for how the verification step works:

def speculative_verify(draft_logits, target_logits, draft_tokens):
    """
    Simplified verification of draft tokens via modified rejection sampling.
    Returns the accepted prefix; on a rejection, the last element is a
    corrected token sampled from the residual distribution.
    """
    accepted_tokens = []
    for i in range(len(draft_tokens)):
        # Probability distributions at position i
        p = torch.softmax(target_logits[i], dim=-1)  # target
        q = torch.softmax(draft_logits[i], dim=-1)   # draft
        
        token_id = draft_tokens[i]
        
        # Rejection sampling criterion:
        # If the target thinks token_id is at least as likely as the draft
        # did, always accept. Otherwise accept with probability p/q.
        if p[token_id] >= q[token_id]:
            accepted_tokens.append(token_id)
        elif torch.rand(1).item() < (p[token_id] / q[token_id]):
            accepted_tokens.append(token_id)
        else:
            # Reject this token and all subsequent ones, then sample a
            # replacement from the 'residual' distribution max(p - q, 0),
            # which preserves the target model's output distribution.
            residual_dist = torch.clamp(p - q, min=0)
            new_token = torch.multinomial(residual_dist / residual_dist.sum(), 1)
            accepted_tokens.append(new_token.item())
            break
    # Note: the full algorithm also samples one "bonus" token from the
    # target's (K+1)-th distribution when every draft token is accepted.
    return accepted_tokens

The KV Cache Nightmare

Here’s what nobody tells you about building this at scale: The KV Cache management is a nightmare.

In standard inference, the Key-Value (KV) cache grows by exactly one token per step. In speculative decoding, the draft and target models each maintain their own KV cache. When the target model rejects three tokens, you have to "rewind" both caches to the last accepted position.

If you are using a library like vLLM or TGI, they handle this for you using PagedAttention. But if you're rolling your own implementation on bare metal or using TensorRT-LLM, you need to be extremely careful with memory pointers. If you forget to truncate the KV cache after a rejection, the next forward pass will have corrupted context, and your model will start hallucinating gibberish.
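A toy, framework-free sketch of what that rewind has to do (real engines do this with tensor slicing over [batch, heads, seq_len, head_dim] caches or PagedAttention block tables; all names here are hypothetical):

```python
class ToyKVCache:
    """Per-layer list of per-token (key, value) entries. Lists keep the
    bookkeeping visible; real caches are GPU tensors."""
    def __init__(self, num_layers: int):
        self.layers = [[] for _ in range(num_layers)]

    def append(self, layer: int, kv) -> None:
        self.layers[layer].append(kv)

    def seq_len(self) -> int:
        return len(self.layers[0])

    def rewind_to(self, length: int) -> None:
        # After a rejection, truncate EVERY layer to the accepted length.
        # Missing even one layer leaves stale keys/values in context, and
        # the next forward pass attends over corrupted history.
        for i, layer in enumerate(self.layers):
            self.layers[i] = layer[:length]

# 5 speculative tokens were appended, only 2 were accepted ->
# rewind BOTH caches to prompt_len + 2 before the next step.
draft_cache, target_cache = ToyKVCache(2), ToyKVCache(2)
prompt_len, accepted = 10, 2
for cache in (draft_cache, target_cache):
    for layer in range(2):
        for t in range(prompt_len + 5):
            cache.append(layer, ("k%d" % t, "v%d" % t))
    cache.rewind_to(prompt_len + accepted)
```

The point of the sketch is the invariant: after every verification round, both caches must agree on exactly the accepted sequence length, nothing more.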

How to Choose Your Draft Model Gamble

If you're looking to implement this in a high-throughput environment, don't just pick the smallest model available. Consider these three factors:

1. The Vocabulary Match

This is the biggest technical hurdle. Your draft model must use the exact same tokenizer as your target model. If your target is Llama-3 (128k vocab) and your draft is a generic GPT-2 model (50k vocab), it literally cannot work. You would need to add a "translation layer" which eats up all your speed gains. This is why IBM's fms-accelerator models or Medusa heads are becoming popular—they are specifically designed to match the target's vocabulary.

2. The Training Data Overlap

A draft model trained on Wikipedia won't help you much if your target model is doing SQL generation. The closer the draft model's training distribution is to the target model's output, the higher your $\alpha$ will be. I've found that distilling the target model into a smaller version of itself (e.g., using the target's outputs as the ground truth for training the draft) is the most effective way to boost the acceptance rate.

3. The "Medusa" Alternative

If you don't want to manage two separate models, look into Medusa. Instead of a second model, you add multiple "heads" (extra linear layers) to the top of your big model. Each head tries to predict the next+1, next+2, next+3 tokens simultaneously.
* Pros: No second model to load; KV cache is easier to manage.
* Cons: You have to train these heads yourself for your specific model.

Measuring Success

Stop looking at just "Tokens per Second" (TPS). In speculative decoding, TPS can be misleading because it might include tokens that were later rejected. You need to measure Effective Throughput—the number of tokens actually delivered to the user per second.

I recommend tracking these three metrics:
1. Acceptance Rate per Request Type: Categorize prompts (code, chat, summary).
2. Draft Latency vs. Target Latency: If your draft takes 5ms and your target takes 50ms, you have a 1:10 ratio. That's good. If it's 20ms vs 50ms, you're in the danger zone.
3. VRAM Overhead: Speculative decoding requires keeping two models in memory. If this forces you to use a lower quantization (e.g., moving from 8-bit to 4-bit) to fit the draft model, your quality might drop more than the speed gain is worth.
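The first two metrics fall out of counters you are likely already logging. A minimal sketch (counter names and the sample numbers are made up for illustration):

```python
def effective_throughput(delivered_tokens: int, wall_seconds: float) -> float:
    """Tokens actually delivered to the user per second.
    Rejected draft tokens do not count."""
    return delivered_tokens / wall_seconds

def acceptance_rate(accepted: int, drafted: int) -> float:
    """Fraction of drafted tokens the target model accepted (alpha)."""
    return accepted / drafted if drafted else 0.0

# Per-request-type counters -> per-request-type metrics.
counters = {
    "code":     {"accepted": 900, "drafted": 1000, "delivered": 1100, "secs": 4.0},
    "creative": {"accepted": 300, "drafted": 1000, "delivered": 500,  "secs": 6.0},
}
stats = {
    kind: {
        "alpha":   acceptance_rate(c["accepted"], c["drafted"]),
        "eff_tps": effective_throughput(c["delivered"], c["secs"]),
    }
    for kind, c in counters.items()
}
```

Bucketing by request type is what surfaces the cliff: with numbers like these, code traffic sits comfortably above break-even while creative traffic quietly drags the fleet average down.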

Is the Gamble Worth It?

Speculative decoding is not a silver bullet. It’s an optimization for when you are bandwidth-bound.

If you are running a batch size of 128 on an H100, your GPU is likely already compute-bound. In that scenario, speculative decoding will actually slow you down because you’re adding more compute to a system that’s already maxed out.

But if you are serving a single user (batch size 1) and they want their answer *now*, the draft model gamble is the most effective way to make a massive LLM feel snappy. Just keep a close eye on your $\alpha$, and be ready to pull the plug if your draft model starts guessing wrong.