SGLang Meets DeepSeek

This post introduces SGLang, an inference system designed to optimize large language model (LLM) deployment. We focus on:

  1. SGLang’s workflow and its distribution setup (e.g., PP/TP/DP constraints);
  2. DeepSeek’s inference challenges on SGLang and how DP Attention optimization addresses them.

SGLang Distribution Setup (w/o DP Attention)

Overview

The SGLang1 distribution setup can be summarized as follows:

  • NO PP support;
  • DP>1 is not supported across multiple nodes2. In fact, DP will be deprecated in the future; SGLang recommends the SGLang Router (an orchestrator written in Rust) for DP instead;
  • TP % nnodes == 0.
assert (
    self.tp_size % self.nnodes == 0
), "tp_size must be divisible by number of nodes"
assert not (
    self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
), "multi-node data parallel is not supported unless dp attention!"

Configurations

The following figures show the structure for different parallel configurations in SGLang.

Figure 1. TP=1, DP=1, nnodes=1

Figure 2. TP=2, DP=1, nnodes=1

Figure 3. TP=4, DP=1, nnodes=2

Figure 4. TP=2, DP=2, nnodes=1

DP Attention Optimization

TP in Llama 3

Before introducing the DP Attention optimization, let's first investigate how TP is implemented in Llama 3.

  • word_embedding: vocab parallel (see the sketch after this list);
  • positional_embedding: replicate;
  • attn: parallel on head dim;
  • mlp: parallel on column then row.
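
As a quick illustration of the first item, here is a vocab-parallel embedding sketch in plain PyTorch. It is not SGLang's implementation; names and shapes are assumptions for the sketch: each rank stores vocab_size / tp rows and the partial lookups are summed with an all-reduce.

import torch
import torch.distributed as dist

class NaiveVocabParallelEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, hidden_dim, tp_rank, tp_size):
        super().__init__()
        assert vocab_size % tp_size == 0
        self.shard = vocab_size // tp_size            # rows owned by this rank
        self.start = tp_rank * self.shard             # first vocab id on this rank
        self.weight = torch.nn.Parameter(torch.randn(self.shard, hidden_dim))

    def forward(self, input_ids):
        # Look up only the ids that fall into the local shard; zero out the rest.
        local = input_ids - self.start
        mask = (local < 0) | (local >= self.shard)
        out = self.weight[local.clamp(0, self.shard - 1)]
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)
        # Sum partial embeddings across the TP group to recover the full lookup.
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
        return out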

Llama 3 attn can be implemented as follows. For convenience, we omit RoPE and use MHA. Notice that every weight matrix on a rank is reduced by a factor of tp, and so is the KV cache size.

def forward(self, hidden_states):
    # hidden_states: (bs, seq_len, hidden_dim)
    q, k, v = self.qkv_proj(hidden_states) \
        .view(bs, seq_len, head_num // tp, 3 * head_dim) \
        .split(head_dim, dim=-1)
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))
    # qkv_proj weight on each rank: (hidden_dim, 3*head_num*head_dim/tp)
    # {q,k,v}: (bs, head_num/tp, seq_len, head_dim)
    save_kvcache(k, v)  # each rank stores only 1/tp of the KV cache
    attn_weights = matmul(q, k.transpose(2, 3))
    attn_weights = softmax(attn_weights, dim=-1)
    # attn_weights: (bs, head_num/tp, seq_len, seq_len)
    attn_output = matmul(attn_weights, v)
    attn_output = attn_output.transpose(1, 2)
    # attn_output: (bs, seq_len, head_num/tp, head_dim)
    attn_output = attn_output.reshape(bs, seq_len, -1)
    # o_proj is row-parallel: its per-rank weight is (head_num*head_dim/tp, hidden_dim),
    # so each rank produces a partial sum that is combined with an all-reduce.
    attn_output = self.o_proj(attn_output).all_reduce()
    # attn_output: (bs, seq_len, hidden_dim)
    return attn_output, attn_weights

The MLP in Llama 3 is parallelized column-wise first, then row-wise.

self.gate_up_proj = MergedColumnParallelLinear(...)
self.down_proj = RowParallelLinear(...)
self.act_fn = SiluAndMul()

def forward(self, x):
    gate_up, _ = self.gate_up_proj(x)
    x = self.act_fn(gate_up)
    x, _ = self.down_proj(x)
    return x
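
The point of the column-then-row ordering is that the intermediate activation stays sharded across ranks and the whole MLP needs only one all-reduce. A minimal sketch in plain PyTorch (shapes and names are assumptions, not SGLang's parallel linear classes):

import torch
import torch.distributed as dist

def tp_mlp(x, gate_up_w, down_w):
    # x:         (bs, seq_len, hidden_dim), replicated on every TP rank
    # gate_up_w: (hidden_dim, 2 * inter_dim // tp)   column-parallel shard
    # down_w:    (inter_dim // tp, hidden_dim)       row-parallel shard
    gate, up = torch.matmul(x, gate_up_w).chunk(2, dim=-1)
    h = torch.nn.functional.silu(gate) * up         # SiluAndMul on the local shard
    partial = torch.matmul(h, down_w)               # partial (bs, seq_len, hidden_dim)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # the MLP's single all-reduce
    return partial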

Problems of TP in DeepSeek

However, the TP implementation used in Llama 3 is not optimal for DeepSeek, for the following reasons:

  • mla_attn:

    • A latent of size (1, kv_lora_rank + qk_rope_head_dim) must be cached for each token. It cannot be partitioned along the num_head dim, so the KV cache size cannot be reduced by TP (see the worked example after this list);

    • Some params (e.g. q_a_proj, kv_a_proj_with_mqa) cannot be parallelized, which leads to parameter duplication across ranks.

      self.q_a_proj = ReplicatedLinear(...)
      self.q_b_proj = ColumnParallelLinear(...)
      self.kv_b_proj = ColumnParallelLinear(...)
      self.o_proj = RowParallelLinear(...)
      self.kv_a_proj_with_mqa = ReplicatedLinear(...)
  • moe_mlp:

    • Expert parallelism (EP) is a better fit than TP, since each individual expert is small.
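
To make the first problem concrete, the sketch below compares per-token, per-layer KV-cache sizes under TP, using DeepSeek-V3-like numbers (kv_lora_rank = 512, qk_rope_head_dim = 64, 128 heads of dim 128; these values are illustrative assumptions):

kv_lora_rank, qk_rope_head_dim = 512, 64
num_heads, head_dim = 128, 128

for tp in (1, 8):
    # Plain MHA: K and V are sharded along the head dim, so the cache shrinks with tp.
    mha_per_rank = 2 * num_heads * head_dim // tp
    # MLA: the latent has no head dim to shard, so every rank keeps the full latent.
    mla_per_rank = kv_lora_rank + qk_rope_head_dim
    print(f"tp={tp}: MHA {mha_per_rank} elems/rank, "
          f"MLA {mla_per_rank} elems/rank (duplicated {tp}x across the TP group)")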

SGLang DP Attention for DeepSeek V3

DP Attention is intended to solve the above problems3 in mla_attn. The parallel policy becomes:

  • word_embedding: replicate;
  • positional_embedding: replicate;
  • attn: parallel over the batch dimension (each AttnDP rank serves independent requests).

Figure 5. DP Attention

Its configuration requirement is:

  • 1 < AttnDP ≤ TP and TP % AttnDP == 0. This is because SGLang supports DP+TP attention: the TP ranks are partitioned into AttnDP groups, each running attention with a TP size of TP / AttnDP (a sketch of this mapping follows the code below).
assert (
    self.dp_size > 1
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
assert self.tp_size % self.dp_size == 0
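
Here is one natural rank mapping consistent with these constraints (hypothetical helper, not SGLang's internals): every TP rank gets an AttnDP group id, which decides the requests it serves, and an attention-TP rank inside that group.

def attn_dp_mapping(tp_size, attn_dp_size):
    # Mirror of the constraints above (hypothetical helper, not SGLang code).
    assert 1 < attn_dp_size <= tp_size
    assert tp_size % attn_dp_size == 0
    attn_tp_size = tp_size // attn_dp_size      # TP size inside each AttnDP group
    return {
        rank: (rank // attn_tp_size,            # AttnDP group id (independent requests)
               rank % attn_tp_size)             # attention-TP rank within the group
        for rank in range(tp_size)
    }

# TP=4, AttnDP=2 (as in Figure 7 below): two groups of two ranks each.
print(attn_dp_mapping(tp_size=4, attn_dp_size=2))
# {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}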

Some concrete configurations are shown below:

Figure 6. TP=2, AttnDP=2, nnodes=1

Figure 7. TP=4, AttnDP=2, nnodes=1

  1. https://github.com/sgl-project/sglang, on tag v0.4.4_post4↩︎

  2. When DP Attention is enabled, “DP” takes on a different meaning in SGLang. For clarity, we write “AttnDP” in that case (the canonical “DP” is in fact always 1 there). Plain “DP” always refers to the canonical data parallelism with DP Attention disabled.↩︎

  3. See also https://zhuanlan.zhihu.com/p/15280741714↩︎