def all_gather_sampler_output(
    sampler_output: SamplerOutput,
    num_reqs: int,
    tp_size: int,
) -> SamplerOutput:
    """Replace each tensor field of ``sampler_output`` with its gathered form.

    Every field is passed to ``pad_and_all_gather`` with a per-rank row count
    of ``ceil(num_reqs / tp_size)``. The result is then truncated to its first
    ``num_reqs`` rows. The fields are processed in this order:
    ``sampled_token_ids``, and then, only when ``logprobs_tensors`` is not
    None, ``logprob_token_ids``, ``logprobs`` and ``selected_token_ranks``.

    NOTE(review): ``pad_and_all_gather`` is defined elsewhere. It presumably
    pads the local shard and performs a collective across the TP group. If
    so, every rank must issue these calls in the same order; do not reorder
    them.

    Args:
        sampler_output: Mutated in place; its fields are reassigned.
        num_reqs: Number of leading rows kept after each gather.
        tp_size: Divisor used to compute the per-rank row count.

    Returns:
        The same ``sampler_output`` object that was passed in.
    """
    # Ceil division: rows each rank contributes to the gather.
    rows_per_rank = (num_reqs + tp_size - 1) // tp_size

    def _gather(local_part):
        # Gather at the per-rank size, then drop any trailing padding rows.
        return pad_and_all_gather(local_part, rows_per_rank)[:num_reqs]

    sampler_output.sampled_token_ids = _gather(sampler_output.sampled_token_ids)

    logprobs_tensors = sampler_output.logprobs_tensors
    if logprobs_tensors is None:
        return sampler_output

    # TODO(woosuk): 3 small all-gathers, could be merged into one.
    for field_name in ("logprob_token_ids", "logprobs", "selected_token_ranks"):
        setattr(
            logprobs_tensors,
            field_name,
            _gather(getattr(logprobs_tensors, field_name)),
        )
    return sampler_output