Pooling models¶
Source: https://github.com/vllm-project/vllm/tree/main/examples/online_serving/pooling.
Cohere rerank usage¶
# vllm serve BAAI/bge-reranker-base
python examples/online_serving/pooling/cohere_rerank_client.py
Embedding requests base64 encoding_format usage¶
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/embedding_requests_base64_client.py
Embedding requests bytes encoding_format usage¶
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/embedding_requests_bytes_client.py
Jinaai rerank usage¶
# vllm serve BAAI/bge-reranker-base
python examples/online_serving/pooling/jinaai_rerank_client.py
Multi vector retrieval usage¶
# vllm serve BAAI/bge-m3
python examples/online_serving/pooling/multi_vector_retrieval_client.py
Named Entity Recognition (NER) usage¶
# vllm serve boltuix/NeuroBERT-NER
python examples/online_serving/pooling/ner_client.py
OpenAI chat embedding for multimodal usage¶
# vllm serve TIGER-Lab/VLM2Vec-Full --runner pooling --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec_phi3v.jinja
python examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py --model vlm2vec
OpenAI classification usage¶
# vllm serve jason9693/Qwen2.5-1.5B-apeach
python examples/online_serving/pooling/openai_classification_client.py
OpenAI cross_encoder score usage¶
# vllm serve BAAI/bge-reranker-v2-m3
python examples/online_serving/pooling/openai_cross_encoder_score.py
OpenAI cross_encoder score for multimodal usage¶
# vllm serve jinaai/jina-reranker-m0
python examples/online_serving/pooling/openai_cross_encoder_score_for_multimodal.py
OpenAI embedding usage¶
# vllm serve intfloat/e5-small
python examples/online_serving/pooling/openai_embedding_client.py
OpenAI embedding matryoshka dimensions usage¶
# vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
python examples/online_serving/pooling/openai_embedding_matryoshka_fy.py
OpenAI pooling usage¶
# vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
python examples/online_serving/pooling/openai_pooling_client.py
Online Prithvi Geospatial MAE usage¶
# vllm serve christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM --model-impl terratorch --task embed --trust-remote-code --skip-tokenizer-init --enforce-eager --io-processor-plugin terratorch_segmentation --enable-mm-embeds
python examples/online_serving/pooling/prithvi_geospatial_mae.py
Example materials¶
cohere_rerank_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python

Note that `pip install cohere` is needed to run this example.

run: vllm serve BAAI/bge-reranker-base
"""

import cohere
from cohere import Client, ClientV2

model = "BAAI/bge-reranker-base"

query = "What is the capital of France?"

documents = [
    "The capital of France is Paris",
    "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving",
]


def cohere_rerank(
    client: Client | ClientV2, model: str, query: str, documents: list[str]
) -> dict:
    return client.rerank(model=model, query=query, documents=documents)


def main():
    # cohere v1 client
    cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
    rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
    print("-" * 50)
    print("rerank_v1_result:\n", rerank_v1_result)
    print("-" * 50)

    # or the v2 client
    cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
    rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
    print("rerank_v2_result:\n", rerank_v2_result)
    print("-" * 50)


if __name__ == "__main__":
    main()
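The Cohere SDK returns a `results` list whose entries carry an `index` back into the original `documents` plus a `relevance_score` (these field names come from the Cohere SDK's response shape, not from vLLM). A minimal usage sketch for re-ordering the documents by score:

def print_ranked(result, documents: list[str]) -> None:
    # Highest relevance first; r.index points back into `documents`.
    for r in sorted(result.results, key=lambda r: r.relevance_score, reverse=True):
        print(f"{r.relevance_score:.4f}  {documents[r.index]}")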
embedding_requests_base64_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API using vLLM API server
NOTE:
    start a supported embeddings model server with `vllm serve`, e.g.
    vllm serve intfloat/e5-small
"""

import argparse
import base64

import requests
import torch

from vllm.utils.serial_utils import (
    EMBED_DTYPE_TO_TORCH_DTYPE,
    ENDIANNESS,
    binary2tensor,
)


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="intfloat/e5-small")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/v1/embeddings"
    model_name = args.model

    # The OpenAI client does not support the embed_dtype and endianness parameters.
    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
        for endianness in ENDIANNESS:
            prompt = {
                "model": model_name,
                "input": "vLLM is great!",
                "encoding_format": "base64",
                "embed_dtype": embed_dtype,
                "endianness": endianness,
            }
            response = post_http_request(prompt=prompt, api_url=api_url)

            embedding = []
            for data in response.json()["data"]:
                binary = base64.b64decode(data["embedding"])
                tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
                embedding.append(tensor.to(torch.float32))
            embedding = torch.cat(embedding)
            print(embed_dtype, endianness, embedding.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
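If the client side should not depend on vLLM, the base64 payload can also be decoded with plain numpy instead of `binary2tensor`. A minimal sketch, assuming the default `float32` dtype and little-endian byte order (other `embed_dtype`/`endianness` combinations need the matching numpy dtype string):

import base64

import numpy as np


def decode_base64_embedding(b64: str) -> np.ndarray:
    # "<f4" = little-endian float32.
    return np.frombuffer(base64.b64decode(b64), dtype="<f4")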
embedding_requests_bytes_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API using vLLM API server
NOTE:
    start a supported embeddings model server with `vllm serve`, e.g.
    vllm serve intfloat/e5-small
"""

import argparse
import json

import requests
import torch

from vllm.utils.serial_utils import (
    EMBED_DTYPE_TO_TORCH_DTYPE,
    ENDIANNESS,
    MetadataItem,
    decode_pooling_output,
)


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="intfloat/e5-small")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/v1/embeddings"
    model_name = args.model

    # The OpenAI client does not support the bytes encoding_format.
    # The OpenAI client does not support the embed_dtype and endianness parameters.
    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
        for endianness in ENDIANNESS:
            prompt = {
                "model": model_name,
                "input": "vLLM is great!",
                "encoding_format": "bytes",
                "embed_dtype": embed_dtype,
                "endianness": endianness,
            }
            response = post_http_request(prompt=prompt, api_url=api_url)

            metadata = json.loads(response.headers["metadata"])
            body = response.content
            items = [MetadataItem(**x) for x in metadata["data"]]
            embedding = decode_pooling_output(items=items, body=body)
            embedding = [x.to(torch.float32) for x in embedding]
            embedding = torch.cat(embedding)
            print(embed_dtype, endianness, embedding.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
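Compared with base64, the `bytes` format sends raw tensor data in the response body and moves the layout information (per-item offsets, dtype, shape) into the `metadata` response header, which is what `MetadataItem` and `decode_pooling_output` unpack above; this avoids base64's roughly 33% size overhead. As a rough illustration only, assuming a single-input request whose body is one contiguous little-endian float32 tensor (use `decode_pooling_output` for the general multi-item case):

import numpy as np


def decode_single_bytes_body(body: bytes) -> np.ndarray:
    # Assumes the whole body is one little-endian float32 tensor;
    # multi-item responses need the offsets from the metadata header.
    return np.frombuffer(body, dtype="<f4")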
jinaai_rerank_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Jina AI (https://jina.ai/reranker) and Cohere rerank APIs.

run: vllm serve BAAI/bge-reranker-base
"""

import json

import requests

url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}


def main():
    response = requests.post(url, headers=headers, json=data)

    # Check the response
    if response.status_code == 200:
        print("Request successful!")
        print(json.dumps(response.json(), indent=2))
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()
multi_vector_retrieval_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Pooling API for multi-vector retrieval.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve BAAI/bge-m3
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-m3")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    prompt = {"model": model_name, "input": prompts}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    for output in pooling_response.json()["data"]:
        multi_vector = torch.tensor(output["data"])
        print(multi_vector.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
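A model like BAAI/bge-m3 returns one vector per token (hence the 2-D shapes printed above), so query/document relevance is typically computed with a ColBERT-style late-interaction (MaxSim) score rather than a single dot product. A minimal sketch of that scoring step (illustrative, not part of the example above):

import torch


def maxsim_score(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> float:
    # query_vecs: [num_query_tokens, dim]; doc_vecs: [num_doc_tokens, dim]
    sim = query_vecs @ doc_vecs.T  # pairwise token-token similarities
    # For each query token, take its best-matching document token, then sum.
    return sim.max(dim=-1).values.sum().item()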
ner_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
"""
Example online usage of Pooling API for Named Entity Recognition (NER).

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve boltuix/NeuroBERT-NER
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER")
    return parser.parse_args()


def main(args):
    from transformers import AutoConfig, AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Load tokenizer and config
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    label_map = config.id2label

    # Input text
    text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    prompt = {"model": model_name, "input": text}
    pooling_response = post_http_request(prompt=prompt, api_url=api_url)

    # Run inference
    output = pooling_response.json()["data"][0]
    logits = torch.tensor(output["data"])
    predictions = logits.argmax(dim=-1)
    inputs = tokenizer(text, return_tensors="pt")

    # Map predictions to labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [label_map[p.item()] for p in predictions]
    assert len(tokens) == len(predictions)

    # Print results
    for token, label in zip(tokens, labels):
        if token not in tokenizer.all_special_tokens:
            print(f"{token:15} → {label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
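The labels printed above follow the BIO scheme (`B-PER`, `I-PER`, `O`, ...), so a common post-processing step is merging consecutive tokens into entity spans. A minimal sketch, assuming BIO labels and ignoring WordPiece subword merging:

def bio_to_spans(tokens: list[str], labels: list[str]) -> list[tuple[str, str]]:
    spans, current, current_type = [], [], None
    for token, label in zip(tokens, labels):
        if label.startswith("B-"):  # a new entity starts here
            if current:
                spans.append((" ".join(current), current_type))
            current, current_type = [token], label[2:]
        elif label.startswith("I-") and current_type == label[2:]:
            current.append(token)  # continuation of the current entity
        else:  # "O" or a type mismatch closes any open entity
            if current:
                spans.append((" ".join(current), current_type))
            current, current_type = [], None
    if current:
        spans.append((" ".join(current), current_type))
    return spans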
openai_chat_embedding_client_for_multimodal.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Example Python client for multimodal embedding API using vLLM API server.

Refer to each `run_*` function for the command to run the server for that model.
"""

import argparse
import base64
import io
from typing import Literal

from openai import OpenAI
from openai._types import NOT_GIVEN, NotGiven
from openai.types.chat import ChatCompletionMessageParam
from openai.types.create_embedding_response import CreateEmbeddingResponse
from PIL import Image

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"


def create_chat_embeddings(
    client: OpenAI,
    *,
    messages: list[ChatCompletionMessageParam],
    model: str,
    encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN,
) -> CreateEmbeddingResponse:
    """
    Convenience function for accessing vLLM's Chat Embeddings API,
    which is an extension of OpenAI's existing Embeddings API.
    """
    return client.post(
        "/embeddings",
        cast_to=CreateEmbeddingResponse,
        body={"messages": messages, "model": model, "encoding_format": encoding_format},
    )


def run_clip(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve openai/clip-vit-base-patch32 \
        --runner pooling
    """
    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "a photo of a cat"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)


def run_dse_qwen2_vl(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
        --runner pooling \
        --trust-remote-code \
        --max-model-len 8192 \
        --chat-template examples/template_dse_qwen2_vl.jinja
    """
    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image_url,
                        },
                    },
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
    # of the minimum input size
    buffer = io.BytesIO()
    image_placeholder = Image.new("RGB", (56, 56))
    image_placeholder.save(buffer, "png")
    buffer.seek(0)
    image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            # The placeholder was saved as PNG above.
                            "url": f"data:image/png;base64,{image_placeholder}",
                        },
                    },
                    {"type": "text", "text": "Query: What is the weather like today?"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)


def run_siglip(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve google/siglip-base-patch16-224 \
        --runner pooling
    """
    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "a photo of a cat"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)


def run_vlm2vec(client: OpenAI, model: str):
    """
    Start the server using:

    vllm serve TIGER-Lab/VLM2Vec-Full \
        --runner pooling \
        --trust-remote-code \
        --max-model-len 4096 \
        --chat-template examples/template_vlm2vec_phi3v.jinja
    """
    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": "Represent the given image."},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image embedding output:", response.data[0].embedding)

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {
                        "type": "text",
                        "text": "Represent the given image with the following question: What is in the image.",
                    },
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Image+Text embedding output:", response.data[0].embedding)

    response = create_chat_embeddings(
        client,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "A cat and a dog"},
                ],
            }
        ],
        model=model,
        encoding_format="float",
    )

    print("Text embedding output:", response.data[0].embedding)


model_example_map = {
    "clip": run_clip,
    "dse_qwen2_vl": run_dse_qwen2_vl,
    "siglip": run_siglip,
    "vlm2vec": run_vlm2vec,
}


def parse_args():
    parser = argparse.ArgumentParser(
        "Script to call a specified VLM through the API. Make sure to serve "
        "the model with `--runner pooling` before running this."
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=model_example_map.keys(),
        required=True,
        help="The name of the embedding model.",
    )
    return parser.parse_args()


def main(args):
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model_id = models.data[0].id

    model_example_map[args.model](client, model_id)


if __name__ == "__main__":
    args = parse_args()
    main(args)
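For two-tower models such as CLIP and SigLIP, the image and text embeddings returned above live in a shared space, so cross-modal relevance is simply the cosine similarity of the two vectors. A small dependency-free sketch (not part of the script above):

import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)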
openai_classification_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for classification API using vLLM API server
NOTE:
    start a supported classification model server with `vllm serve`, e.g.
    vllm serve jason9693/Qwen2.5-1.5B-apeach
"""

import argparse
import pprint

import requests


def post_http_request(payload: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=payload)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
    return parser.parse_args()


def main(args):
    host = args.host
    port = args.port
    model_name = args.model

    api_url = f"http://{host}:{port}/classify"
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    payload = {
        "model": model_name,
        "input": prompts,
    }
    classify_response = post_http_request(payload=payload, api_url=api_url)
    pprint.pprint(classify_response.json())


if __name__ == "__main__":
    args = parse_args()
    main(args)
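Each entry of the `/classify` response's `data` list pairs the winning label with the full per-class probability distribution, so extracting the top prediction per prompt is a one-liner. A small sketch, assuming `label` and `probs` fields in each item (check the response schema of your vLLM version):

def top_predictions(response_json: dict) -> list[tuple[str, float]]:
    # One (label, confidence) pair per input prompt.
    return [(item["label"], max(item["probs"])) for item in response_json["data"]]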
openai_cross_encoder_score.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Score API.

Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/score"
    model_name = args.model

    text_1 = "What is the capital of Brazil?"
    text_2 = "The capital of Brazil is Brasilia."
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 and text_2 are both strings:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())

    text_1 = "What is the capital of France?"
    text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 is a string and text_2 is a list:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())

    text_1 = ["What is the capital of Brazil?", "What is the capital of France?"]
    text_2 = ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 and text_2 are both lists:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())


if __name__ == "__main__":
    args = parse_args()
    main(args)
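A typical follow-up to scoring is reranking: pair every candidate in `text_2` with its score and sort. A minimal sketch, assuming each item in the response's `data` list carries `index` and `score` fields as printed above:

def rank_documents(score_json: dict, documents: list[str]) -> list[tuple[float, str]]:
    pairs = [(item["score"], documents[item["index"]]) for item in score_json["data"]]
    return sorted(pairs, reverse=True)  # best match first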
openai_cross_encoder_score_for_multimodal.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Score API.

Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="jinaai/jina-reranker-m0")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/score"
    model_name = args.model

    text_1 = "slm markdown"
    text_2 = {
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
                },
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
                },
            },
        ]
    }
    prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 is a string and text_2 is an image list:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())


if __name__ == "__main__":
    args = parse_args()
    main(args)
openai_embedding_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for embedding API using vLLM API server
NOTE:
    start a supported embeddings model server with `vllm serve`, e.g.
    vllm serve intfloat/e5-small
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def main():
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model = models.data[0].id

    responses = client.embeddings.create(
        # ruff: noqa: E501
        input=[
            "Hello my name is",
            "The best thing about vLLM is that it supports many different models",
        ],
        model=model,
    )

    for data in responses.data:
        # A list of floats; its length matches the model's embedding size
        # (e.g. 384 for intfloat/e5-small).
        print(data.embedding)


if __name__ == "__main__":
    main()
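For large batches it can be cheaper to request `encoding_format="base64"`; when you ask for it explicitly, recent versions of the openai-python client hand back the raw base64 string instead of a float list. A minimal decoding sketch reusing `client` and `model` from the example above (float32 is the assumed wire dtype):

import base64

import numpy as np

responses = client.embeddings.create(
    input=["Hello my name is"], model=model, encoding_format="base64"
)
vec = np.frombuffer(base64.b64decode(responses.data[0].embedding), dtype=np.float32)
print(vec.shape)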
openai_embedding_matryoshka_fy.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for the embedding API's `dimensions` parameter
(Matryoshka embeddings) using the vLLM API server
NOTE:
    start a supported Matryoshka Embeddings model server with `vllm serve`, e.g.
    vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def main():
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model = models.data[0].id

    responses = client.embeddings.create(
        input=["Follow the white rabbit."],
        model=model,
        dimensions=32,
    )

    for data in responses.data:
        print(data.embedding)  # List of 32 floats


if __name__ == "__main__":
    main()
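Matryoshka models are trained so that a prefix of the full embedding is itself a usable embedding, and `dimensions=32` asks the server for that truncated (and typically renormalized) prefix. A hypothetical client-side check of the equivalence, reusing `client` and `model` from the example above and assuming the server normalizes its outputs:

import numpy as np

full = client.embeddings.create(input=["Follow the white rabbit."], model=model)
head = np.array(full.data[0].embedding[:32])
head /= np.linalg.norm(head)  # renormalize the truncated prefix

small = client.embeddings.create(
    input=["Follow the white rabbit."], model=model, dimensions=32
)
print(np.allclose(head, np.array(small.data[0].embedding), atol=1e-2))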
openai_pooling_client.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Pooling API.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="internlm/internlm2-1_8b-reward")
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Input like Completions API
    prompt = {"model": model_name, "input": "vLLM is great!"}
    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    print("-" * 50)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print("-" * 50)

    # Input like Chat API
    prompt = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "vLLM is great!"}],
            }
        ],
    }
    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print("-" * 50)


if __name__ == "__main__":
    args = parse_args()
    main(args)
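For a reward model like internlm/internlm2-1_8b-reward, each item in the response's `data` field holds the pooled reward output, which can be loaded back into a tensor for post-processing. A small sketch continuing from `pooling_response` above (the exact shape depends on the model's pooler):

import torch

for item in pooling_response.json()["data"]:
    reward = torch.tensor(item["data"])
    print(reward.shape, reward)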
prithvi_geospatial_mae.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os

import requests

# This example shows how to perform an online inference that generates
# multimodal data. In this specific case, it takes a geotiff image as
# input, processes it using the multimodal data processor, and performs
# inference.
# Requirements:
# - install TerraTorch v1.1 (or later):
#   pip install "terratorch>=1.1"
# - start vllm in serving mode with the below args
#   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
#   --model-impl terratorch
#   --task embed --trust-remote-code
#   --skip-tokenizer-init --enforce-eager
#   --io-processor-plugin terratorch_segmentation
#   --enable-mm-embeds


def main():
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

    server_endpoint = "http://localhost:8000/pooling"

    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    }

    ret = requests.post(server_endpoint, json=request_payload_url)

    print(f"response.status_code: {ret.status_code}")
    print(f"response.reason: {ret.reason}")

    response = ret.json()
    decoded_image = base64.b64decode(response["data"]["data"])

    out_path = os.path.join(os.getcwd(), "online_prediction.tiff")
    with open(out_path, "wb") as f:
        f.write(decoded_image)


if __name__ == "__main__":
    main()