import torch

# Note: LogprobsMode, SamplingMetadata, SamplerOutput, apply_top_k_top_p,
# gumbel_sample, and compute_topk_logprobs are assumed to be provided by the
# surrounding package; they are not defined in this file.


class Sampler:
    def __init__(
        self,
        logprobs_mode: LogprobsMode = "raw_logprobs",
    ):
        if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
        self.logprobs_mode = logprobs_mode
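
    # Samples one token per request from the (optionally filtered) logits.
    # Returns the sampled token ids and, when return_logits is True, the
    # processed logits (after temperature scaling and top-k/top-p masking).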
    def sample_token(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        return_logits: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
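        # Temperature scaling. Greedy requests (temperature == 0) get a
        # placeholder temperature of 1.0 so the division below is safe;
        # gumbel_sample still receives the original temperatures and is
        # presumably expected to fall back to argmax for those rows.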
        is_greedy = sampling_metadata.temperature == 0
        temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
        logits = logits / temp.view(-1, 1)
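        # Mask out tokens outside the per-request top-k / top-p sets before
        # renormalizing with softmax.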
        logits = apply_top_k_top_p(
            logits, sampling_metadata.top_k, sampling_metadata.top_p
        )
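        # gumbel_sample presumably implements the Gumbel-max trick: take the
        # argmax over noise-perturbed scores, with the noise derived from the
        # per-request seeds and positions. This draws from the categorical
        # distribution without an explicit multinomial call and keeps results
        # reproducible across runs.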
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled = gumbel_sample(
            probs,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
        )
        sampled = sampled.to(torch.int64)
        return sampled, logits if return_logits else None
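
    # Entry point: samples one token per request and, when logprobs are
    # requested, gathers top-k logprobs alongside the sampled ids.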
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
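        # "processed_logprobs" reports logprobs over the post-processing
        # distribution (after temperature scaling and top-k/top-p masking),
        # while "raw_logprobs" reports them over the unmodified logits.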
        if sampling_metadata.max_num_logprobs is not None:
            if self.logprobs_mode == "processed_logprobs":
                sampled, logits = self.sample_token(
                    logits, sampling_metadata, return_logits=True
                )
            else:
                assert self.logprobs_mode == "raw_logprobs"
                sampled, _ = self.sample_token(
                    logits, sampling_metadata, return_logits=False
                )
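            # `logits` now holds either the processed or the raw logits,
            # depending on the mode selected above.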
            logprobs_tensors = compute_topk_logprobs(
                logits,
                sampling_metadata.max_num_logprobs,
                sampled,
            )
        else:
            sampled, _ = self.sample_token(
                logits, sampling_metadata, return_logits=False
            )
            logprobs_tensors = None
        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row holds the one token generated
            # for that request.
            sampled_token_ids=sampled.view(-1, 1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output
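

# A minimal usage sketch (not from the original source). The SamplingMetadata
# constructor and field names below are assumptions inferred from the
# attributes accessed above, so this is illustrative rather than runnable
# as-is:
#
#     logits = model(input_ids)[:, -1, :]          # [num_requests, vocab_size]
#     metadata = SamplingMetadata(
#         temperature=torch.tensor([0.7, 0.0]),    # second request is greedy
#         top_k=torch.tensor([50, 0]),
#         top_p=torch.tensor([0.9, 1.0]),
#         seeds=torch.tensor([1234, 5678]),
#         pos=torch.tensor([0, 0]),
#         max_num_logprobs=5,
#     )
#     sampler = Sampler(logprobs_mode="raw_logprobs")
#     output = sampler.sample(logits, metadata)
#     # output.sampled_token_ids has shape [num_requests, 1]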