class CudaGraphManager:
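    """Captures and replays full-model CUDA graphs for decode-only batches.

    Batch sizes are padded up to the nearest captured size. A rough usage
    sketch (the surrounding runner and its objects are assumed, not defined
    here):

        manager = CudaGraphManager(vllm_config, device)
        if manager.needs_capture():
            manager.capture(model, input_buffers, block_tables,
                            attn_metadata_builders, kv_cache_config)
        # Per step, after writing the real inputs into the input buffers:
        size = manager.get_cudagraph_size(scheduler_output)
        if size is not None:
            hidden_states = manager.run(size)
    """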
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
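        # Maps each batch size to the smallest capture size that covers it
        # (empty when CUDA graphs are disabled).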
self.padded_sizes = self._init_padded_sizes()
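        # One captured graph per batch size, all sharing a single memory pool.
        # The persistent output buffer is allocated lazily on the first capture.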
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
def _init_padded_sizes(self) -> dict[int, int]:
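        """Build the batch-size -> padded capture size lookup table.

        Returns an empty dict when CUDA graphs are disabled, so no capture
        (and no padding) happens.
        """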
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
# CUDA graphs are disabled.
return {}
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
raise NotImplementedError("Piecewise CUDA graphs are not supported")
        if self.compilation_config.level != 0:
            raise NotImplementedError(
                "Dynamo is not used; the compilation level must be 0."
            )
padded_sizes: dict[int, int] = {}
assert len(self.cudagraph_sizes) > 0
for i in range(1, self.cudagraph_sizes[-1] + 1):
for x in self.cudagraph_sizes:
if i <= x:
padded_sizes[i] = x
break
return padded_sizes
def needs_capture(self) -> bool:
return len(self.padded_sizes) > 0
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
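        """Return the padded graph size for a decode-only batch.

        Returns None when no captured graph applies: a prefill is scheduled,
        or the batch exceeds the largest captured size.
        """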
if max(scheduler_output.num_scheduled_tokens.values()) > 1:
            # At least one request has more than one scheduled token, i.e. a
            # prefill is included; captured graphs only cover pure decode.
return None
return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
def capture_graph(
self,
batch_size: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
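        """Warm up and capture one CUDA graph for the given batch size.

        The graph records a forward pass over dummy decode inputs that live in
        the persistent input buffers; replaying it later re-reads whatever the
        runner has written into those buffers.
        """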
assert batch_size not in self.graphs
        # Prepare dummy inputs inside the persistent buffers the graph will
        # read from at replay time.
input_ids = input_buffers.input_ids.gpu[:batch_size]
positions = input_buffers.positions.gpu[:batch_size]
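        # Pure decode: one token per request, so query_start_loc is 0..batch_size.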
input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
input_buffers.query_start_loc.np[batch_size:] = batch_size
input_buffers.query_start_loc.copy_to_gpu()
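        # Capture with the maximum sequence length so the recorded kernels
        # cover the worst case; the real lengths are written in before replay.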
input_buffers.seq_lens.np[:batch_size] = self.max_model_len
input_buffers.seq_lens.np[batch_size:] = 0
input_buffers.seq_lens.copy_to_gpu()
input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :batch_size]
attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders,
num_reqs=batch_size,
num_tokens=batch_size,
query_start_loc=input_buffers.query_start_loc,
seq_lens=input_buffers.seq_lens,
num_computed_tokens_cpu=None, # FIXME
block_tables=input_block_tables,
slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config,
)
        # Warm up: run the model once eagerly so lazy initialization (e.g.
        # autotuning, workspace allocation) happens outside graph capture.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
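        # Allocate the persistent output buffer on the first capture (the
        # largest batch size, so it is big enough for every graph).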
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
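        # Make sure the warm-up run has finished before capture starts.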
torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
)
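            # Record the copy into the persistent output buffer so that each
            # replay updates self.hidden_states in place.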
self.hidden_states[:batch_size] = hidden_states
self.graphs[batch_size] = graph
@torch.inference_mode()
def capture(
self,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
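        """Capture one CUDA graph per configured batch size."""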
assert self.needs_capture()
        # Capture larger graphs first so smaller graphs can reuse memory from
        # the shared pool.
sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
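        # Keep the Python GC quiet and stay inside the graph-capture context
        # while recording the graphs.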
with freeze_gc(), graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)
def run(self, batch_size: int) -> torch.Tensor:
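        """Replay the captured graph for this batch size and return its output."""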
assert batch_size in self.graphs
self.graphs[batch_size].replay()
return self.hidden_states[:batch_size]