def build_attn_metadata(
    attn_metadata_builders: list[AttentionMetadataBuilder],
    num_reqs: int,
    num_tokens: int,
    query_start_loc: CpuGpuBuffer,
    seq_lens: CpuGpuBuffer,
    num_computed_tokens_cpu: torch.Tensor,
    block_tables: tuple[torch.Tensor, ...],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
    """Build per-layer attention metadata for every KV cache group.

    For each KV cache group ``i`` in ``kv_cache_config.kv_cache_groups``, a
    ``CommonAttentionMetadata`` is assembled from the batch-wide tensors plus
    that group's ``block_tables[i]`` and ``slot_mappings[i]``. It is passed to
    ``attn_metadata_builders[i].build(...)``, and the resulting object is
    shared by every layer name listed in that group.

    Args:
        attn_metadata_builders: One builder per KV cache group, indexed in
            the same order as ``kv_cache_config.kv_cache_groups``.
        num_reqs: Number of requests in the batch.
        num_tokens: Forwarded as ``num_actual_tokens``.
        query_start_loc: Buffer exposing ``.gpu``, ``.cpu`` and ``.np`` views;
            the first ``num_reqs + 1`` entries are used. Expected to hold
            cumulative query start offsets, i.e. request ``r`` covers
            ``[qsl[r], qsl[r + 1])``.
        seq_lens: Buffer exposing ``.gpu``, ``.cpu`` and ``.np`` views; the
            first ``num_reqs`` entries are used.
        num_computed_tokens_cpu: Forwarded unchanged to
            ``CommonAttentionMetadata``.
        block_tables: Block table tensor for each KV cache group.
        slot_mappings: Indexed by KV cache group; ``slot_mappings[i]`` is
            passed as that group's ``slot_mapping``.
        kv_cache_config: Supplies ``kv_cache_groups`` and their layer names.

    Returns:
        Mapping from layer name to the metadata built for that layer's group.
    """
    query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
    query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
    query_start_loc_np = query_start_loc.np[: num_reqs + 1]
    # query_start_loc is cumulative, so its maximum is the *total* number of
    # query tokens, not the longest single query. Per-request query lengths
    # are the differences between adjacent offsets. ``initial=0`` keeps an
    # empty batch (num_reqs == 0) from raising on an empty array.
    query_lens_np = query_start_loc_np[1:] - query_start_loc_np[:-1]
    max_query_len = int(query_lens_np.max(initial=0))

    seq_lens_gpu = seq_lens.gpu[:num_reqs]
    seq_lens_cpu = seq_lens.cpu[:num_reqs]
    max_seq_len = int(seq_lens.np[:num_reqs].max(initial=0))

    attn_metadata: dict[str, Any] = {}
    for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc_gpu,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens=seq_lens_gpu,
            seq_lens_cpu=seq_lens_cpu,
            max_seq_len=max_seq_len,
            num_computed_tokens_cpu=num_computed_tokens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            block_table_tensor=block_tables[i],
            slot_mapping=slot_mappings[i],
            causal=True,
        )
        # No shared prefix is passed to the builder (common_prefix_len=0).
        metadata = attn_metadata_builders[i].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
        # Every layer in the group shares the same metadata object.
        for layer_name in kv_cache_group.layer_names:
            attn_metadata[layer_name] = metadata
    return attn_metadata