vllm.v1.worker.gpu.cudagraph_utils

CudaGraphManager

Source code in vllm/v1/worker/gpu/cudagraph_utils.py
class CudaGraphManager:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.device = device

        self.max_model_len = vllm_config.model_config.max_model_len
        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None

        self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
        self.padded_sizes = self._init_padded_sizes()

        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        self.pool = torch.cuda.graph_pool_handle()
        self.hidden_states: torch.Tensor | None = None

    def _init_padded_sizes(self) -> dict[int, int]:
        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            # CUDA graphs are disabled.
            return {}
        if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
            raise NotImplementedError("Piecewise CUDA graphs are not supported")
        if self.compilation_config.level != 0:
            raise NotImplementedError("Dynamo is not used. Compilation level must be 0")

        padded_sizes: dict[int, int] = {}
        assert len(self.cudagraph_sizes) > 0
        for i in range(1, self.cudagraph_sizes[-1] + 1):
            for x in self.cudagraph_sizes:
                if i <= x:
                    padded_sizes[i] = x
                    break
        return padded_sizes

    def needs_capture(self) -> bool:
        return len(self.padded_sizes) > 0

    def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
        if max(scheduler_output.num_scheduled_tokens.values()) > 1:
            # Prefill is included.
            return None
        return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)

    def capture_graph(
        self,
        batch_size: int,
        model: nn.Module,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        assert batch_size not in self.graphs

        # Prepare dummy inputs.
        input_ids = input_buffers.input_ids.gpu[:batch_size]
        positions = input_buffers.positions.gpu[:batch_size]

        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
        input_buffers.query_start_loc.np[batch_size:] = batch_size
        input_buffers.query_start_loc.copy_to_gpu()
        input_buffers.seq_lens.np[:batch_size] = self.max_model_len
        input_buffers.seq_lens.np[batch_size:] = 0
        input_buffers.seq_lens.copy_to_gpu()

        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
        slot_mappings = block_tables.slot_mappings[:, :batch_size]

        attn_metadata = build_attn_metadata(
            attn_metadata_builders=attn_metadata_builders,
            num_reqs=batch_size,
            num_tokens=batch_size,
            query_start_loc=input_buffers.query_start_loc,
            seq_lens=input_buffers.seq_lens,
            num_computed_tokens_cpu=None,  # FIXME
            block_tables=input_block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
        )

        # Warm up.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=batch_size,
        ):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
            )
            if self.hidden_states is None:
                self.hidden_states = torch.empty_like(hidden_states)
        torch.cuda.synchronize()

        # Capture the graph.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, self.pool):
            with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=batch_size,
            ):
                hidden_states = model(
                    input_ids=input_ids,
                    positions=positions,
                )
            self.hidden_states[:batch_size] = hidden_states
        self.graphs[batch_size] = graph

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        assert self.needs_capture()
        # Capture larger graphs first.
        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
        if is_global_first_rank():
            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")

        with freeze_gc(), graph_capture(device=self.device):
            for batch_size in sizes_to_capture:
                self.capture_graph(
                    batch_size,
                    model,
                    input_buffers,
                    block_tables,
                    attn_metadata_builders,
                    kv_cache_config,
                )

    def run(self, batch_size: int) -> torch.Tensor:
        assert batch_size in self.graphs
        self.graphs[batch_size].replay()
        return self.hidden_states[:batch_size]
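
A minimal lifecycle sketch, assuming the surrounding worker already owns the model, input buffers, block tables, attention metadata builders, and KV cache config (the variable names below are hypothetical): graphs are captured once at startup, then replayed for decode-only batches whose padded size has a captured graph.

manager = CudaGraphManager(vllm_config, device=torch.device("cuda"))
if manager.needs_capture():
    manager.capture(
        model,
        input_buffers,
        block_tables,
        attn_metadata_builders,
        kv_cache_config,
    )

# Per scheduler step: the caller first writes the batch's inputs into the same
# persistent buffers used at capture time, then replays a matching graph.
padded_size = manager.get_cudagraph_size(scheduler_output)
if padded_size is not None:
    hidden_states = manager.run(padded_size)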

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_sizes instance-attribute

cudagraph_sizes = sorted(cudagraph_capture_sizes)

device instance-attribute

device = device

graphs instance-attribute

graphs: dict[int, CUDAGraph] = {}

hidden_states instance-attribute

hidden_states: Tensor | None = None

max_model_len instance-attribute

max_model_len = max_model_len

padded_sizes instance-attribute

padded_sizes = _init_padded_sizes()

pool instance-attribute

pool = graph_pool_handle()

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(vllm_config: VllmConfig, device: device)
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
):
    self.vllm_config = vllm_config
    self.device = device

    self.max_model_len = vllm_config.model_config.max_model_len
    self.compilation_config = vllm_config.compilation_config
    assert self.compilation_config is not None

    self.cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
    self.padded_sizes = self._init_padded_sizes()

    self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
    self.pool = torch.cuda.graph_pool_handle()
    self.hidden_states: torch.Tensor | None = None

_init_padded_sizes

_init_padded_sizes() -> dict[int, int]
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def _init_padded_sizes(self) -> dict[int, int]:
    if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
        # CUDA graphs are disabled.
        return {}
    if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
        raise NotImplementedError("Piecewise CUDA graphs are not supported")
    if self.compilation_config.level != 0:
        raise NotImplementedError("Dynamo is not used. Compilation level must be 0")

    padded_sizes: dict[int, int] = {}
    assert len(self.cudagraph_sizes) > 0
    for i in range(1, self.cudagraph_sizes[-1] + 1):
        for x in self.cudagraph_sizes:
            if i <= x:
                padded_sizes[i] = x
                break
    return padded_sizes
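
For example, with hypothetical capture sizes [1, 2, 4, 8], every batch size from 1 up to the largest capture size is padded to the smallest capture size that can hold it; larger batches have no entry and fall back to eager execution. An equivalent formulation of the mapping:

>>> sizes = [1, 2, 4, 8]  # hypothetical cudagraph_capture_sizes
>>> {i: min(x for x in sizes if x >= i) for i in range(1, sizes[-1] + 1)}
{1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}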

capture

capture(
    model: Module,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    kv_cache_config: KVCacheConfig,
) -> None
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
@torch.inference_mode()
def capture(
    self,
    model: nn.Module,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    kv_cache_config: KVCacheConfig,
) -> None:
    assert self.needs_capture()
    # Capture larger graphs first.
    sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
    if is_global_first_rank():
        sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")

    with freeze_gc(), graph_capture(device=self.device):
        for batch_size in sizes_to_capture:
            self.capture_graph(
                batch_size,
                model,
                input_buffers,
                block_tables,
                attn_metadata_builders,
                kv_cache_config,
            )

capture_graph

capture_graph(
    batch_size: int,
    model: Module,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    kv_cache_config: KVCacheConfig,
) -> None
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def capture_graph(
    self,
    batch_size: int,
    model: nn.Module,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    kv_cache_config: KVCacheConfig,
) -> None:
    assert batch_size not in self.graphs

    # Prepare dummy inputs.
    input_ids = input_buffers.input_ids.gpu[:batch_size]
    positions = input_buffers.positions.gpu[:batch_size]

    input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
    input_buffers.query_start_loc.np[batch_size:] = batch_size
    input_buffers.query_start_loc.copy_to_gpu()
    input_buffers.seq_lens.np[:batch_size] = self.max_model_len
    input_buffers.seq_lens.np[batch_size:] = 0
    input_buffers.seq_lens.copy_to_gpu()

    input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :batch_size]

    attn_metadata = build_attn_metadata(
        attn_metadata_builders=attn_metadata_builders,
        num_reqs=batch_size,
        num_tokens=batch_size,
        query_start_loc=input_buffers.query_start_loc,
        seq_lens=input_buffers.seq_lens,
        num_computed_tokens_cpu=None,  # FIXME
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
    )

    # Warm up.
    with set_forward_context(
        attn_metadata,
        self.vllm_config,
        num_tokens=batch_size,
    ):
        hidden_states = model(
            input_ids=input_ids,
            positions=positions,
        )
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)
    torch.cuda.synchronize()

    # Capture the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, self.pool):
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=batch_size,
        ):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
            )
        self.hidden_states[:batch_size] = hidden_states
    self.graphs[batch_size] = graph
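
The warm-up pass and the persistent self.hidden_states buffer follow the standard PyTorch CUDA graph pattern: one-time lazy initialization must happen before capture, and inputs and outputs must live at fixed addresses because replay() reuses the captured memory. A generic, standalone sketch of that pattern (not vLLM code; assumes a CUDA device):

import torch

# Static input/output buffers reused across capture and every replay.
static_input = torch.zeros(8, device="cuda")
static_output = torch.empty_like(static_input)

# Warm up on a side stream so one-time initialization stays out of the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(static_input * 2 + 1)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()

# Capture the same computation into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output.copy_(static_input * 2 + 1)

# Replay: write new data into the static input, replay, read the static output.
static_input.fill_(3.0)
graph.replay()
torch.cuda.synchronize()
assert torch.allclose(static_output, torch.full_like(static_input, 7.0))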

get_cudagraph_size

get_cudagraph_size(
    scheduler_output: SchedulerOutput,
) -> int | None
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def get_cudagraph_size(self, scheduler_output: SchedulerOutput) -> int | None:
    if max(scheduler_output.num_scheduled_tokens.values()) > 1:
        # Prefill is included.
        return None
    return self.padded_sizes.get(scheduler_output.total_num_scheduled_tokens)
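
Because the method only reads num_scheduled_tokens and total_num_scheduled_tokens, its behavior is easy to illustrate with a stand-in object. A sketch, assuming a manager whose capture sizes are the hypothetical [1, 2, 4, 8] from the example above:

from types import SimpleNamespace

# Decode-only step: 3 requests, one new token each -> padded to 4.
decode_step = SimpleNamespace(
    num_scheduled_tokens={"req0": 1, "req1": 1, "req2": 1},
    total_num_scheduled_tokens=3,
)
assert manager.get_cudagraph_size(decode_step) == 4

# A step containing prefill (any request scheduling > 1 token) returns None,
# so the caller falls back to eager execution.
prefill_step = SimpleNamespace(
    num_scheduled_tokens={"req0": 17, "req1": 1},
    total_num_scheduled_tokens=18,
)
assert manager.get_cudagraph_size(prefill_step) is None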

needs_capture

needs_capture() -> bool
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def needs_capture(self) -> bool:
    return len(self.padded_sizes) > 0

run

run(batch_size: int) -> Tensor
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
def run(self, batch_size: int) -> torch.Tensor:
    assert batch_size in self.graphs
    self.graphs[batch_size].replay()
    return self.hidden_states[:batch_size]
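
A replay sketch: the current batch's inputs must already be written into the same persistent buffers (input_ids, positions, block tables, slot mappings) that were used at capture time, since replay() reuses those memory addresses. The returned tensor is sliced to the padded batch size, so the caller keeps only the rows for real requests (variable names below are hypothetical):

hidden_states = manager.run(padded_size)   # shape: [padded_size, hidden_size]
hidden_states = hidden_states[:num_reqs]   # drop the padding rows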

freeze_gc

freeze_gc()
Source code in vllm/v1/worker/gpu/cudagraph_utils.py
@contextmanager
def freeze_gc():
    gc.collect()
    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
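
gc.collect() first clears existing garbage, and gc.freeze() then moves every surviving object into the permanent generation so the cyclic garbage collector does not trace or free them while graphs are being captured; gc.unfreeze() restores normal collection afterwards. A minimal standard-library illustration:

import gc

gc.collect()
gc.freeze()
try:
    # Objects allocated before freeze() are exempt from collection here,
    # keeping GC pauses (and object teardown) out of the capture window.
    print("frozen objects:", gc.get_freeze_count())
finally:
    gc.unfreeze()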