Pooling models

Source https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/pooling.

Convert an LLM to a sequence classification model

# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
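
Once converted, the checkpoint loads as a sequence classification model and works with the score API. A minimal sketch, assuming the converted Qwen3 checkpoint was saved to ./Qwen3-Reranker-0.6B-seq-cls as above:

from vllm import LLM

llm = LLM(model="./Qwen3-Reranker-0.6B-seq-cls", runner="pooling")
outputs = llm.score(
    "What is the capital of China?",
    ["The capital of China is Beijing.", "Gravity is a force."],
)
print([output.outputs.score for output in outputs])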

Embed jina_embeddings_v3 usage

Only the text matching task is supported for now. See Pull Request #16120.

python examples/offline_inference/pooling/embed_jina_embeddings_v3.py

Embed matryoshka dimensions usage

python examples/offline_inference/pooling/embed_matryoshka_fy.py
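
Matryoshka models keep their leading embedding dimensions meaningful, so you can request a truncated embedding via PoolingParams to trade accuracy for storage. A minimal sketch mirroring the full script below:

from vllm import LLM, PoolingParams

llm = LLM(model="jinaai/jina-embeddings-v3", runner="pooling", trust_remote_code=True)
outputs = llm.embed(
    ["Follow the white rabbit."],
    pooling_params=PoolingParams(dimensions=32),
)
print(len(outputs[0].outputs.embedding))  # 32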

Multi vector retrieval usage

python examples/offline_inference/pooling/multi_vector_retrieval.py
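
The script prints the shape of the per-token embeddings; scoring is left to the reader. A minimal late-interaction (ColBERT-style MaxSim) sketch over two such token-embedding matrices (the function name is illustrative, not part of the example):

import torch

def maxsim_score(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> float:
    # query_vecs: (num_query_tokens, dim); doc_vecs: (num_doc_tokens, dim)
    q = torch.nn.functional.normalize(query_vecs, dim=-1)
    d = torch.nn.functional.normalize(doc_vecs, dim=-1)
    # For each query token, take its best-matching document token, then sum.
    return (q @ d.T).max(dim=-1).values.sum().item()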

Named Entity Recognition (NER) usage

python examples/offline_inference/pooling/ner.py

Prithvi Geospatial MAE usage

python examples/offline_inference/pooling/prithvi_geospatial_mae.py

IO Processor Plugins for Prithvi Geospatial MAE

python examples/offline_inference/pooling/prithvi_geospatial_mae_io_processor.py

Qwen3 reranker usage

python examples/offline_inference/pooling/qwen3_reranker.py

Example materials

convert_model_to_seq_cls.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    assert len(tokens) == 2
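    # For two tokens, softmax([z_no, z_yes])[1] == sigmoid(z_yes - z_no),
    # so a single-logit head whose weight is the difference of the "yes"
    # and "no" lm_head rows reproduces the original yes-probability.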

    lm_head_weights = causal_lm.lm_head.weight

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    score_weight = lm_head_weights[true_id].to(device).to(
        torch.float32
    ) - lm_head_weights[false_id].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    lm_head_weights = causal_lm.lm_head.weight

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]

    score_weight = lm_head_weights[token_ids].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # LLM-as-reranker models default to not using a pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="classifier from tokens",
    )
    parser.add_argument(
        "--method", type=str, default="no_post_processing", help="Converting converting"
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
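
To sanity-check a conversion done with from_2_way_softmax, compare the converted head against the original yes/no logits: sigmoid of the single seq-cls logit should match the softmax yes-probability. A rough sketch (the local path is the hypothetical output of the command above):

import torch
import transformers

path = "./Qwen3-Reranker-0.6B-seq-cls"  # hypothetical conversion output
tokenizer = transformers.AutoTokenizer.from_pretrained(path)
model = transformers.AutoModelForSequenceClassification.from_pretrained(path)

inputs = tokenizer("Query: Explain gravity Document: Gravity is a force.", return_tensors="pt")
with torch.no_grad():
    logit = model(**inputs).logits[0, 0]
print(torch.sigmoid(logit))  # ~= P("yes") under the original two-way softmax
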
embed_jina_embeddings_v3.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="jinaai/jina-embeddings-v3",
        runner="pooling",
        trust_remote_code=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    # Only the text matching task is supported for now. See #16120
    outputs = llm.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:")
    print("Only text matching task is supported for now. See #16120")
    print("-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        embeds_trimmed = (
            (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
        )
        print(
            f"Prompt: {prompt!r} \n"
            f"Embeddings for text matching: {embeds_trimmed} "
            f"(size={len(embeds)})"
        )
        print("-" * 60)


if __name__ == "__main__":
    args = parse_args()
    main(args)
embed_matryoshka_fy.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs, PoolingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="jinaai/jina-embeddings-v3",
        runner="pooling",
        trust_remote_code=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts, pooling_params=PoolingParams(dimensions=32))

    # Print the outputs.
    print("\nGenerated Outputs:")
    print("-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        embeds_trimmed = (
            (str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
        )
        print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
        print("-" * 60)


if __name__ == "__main__":
    args = parse_args()
    main(args)
multi_vector_retrieval.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="BAAI/bge-m3",
        runner="pooling",
        enforce_eager=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        print(len(embeds))

    # Generate embedding for each token. The output is a list of PoolingRequestOutput.
    outputs = llm.encode(prompts, pooling_task="token_embed")

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        multi_vector = output.outputs.data
        print(multi_vector.shape)


if __name__ == "__main__":
    args = parse_args()
    main(args)
ner.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER

from argparse import Namespace

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="boltuix/NeuroBERT-NER",
        runner="pooling",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    ]

    # Create an LLM.
    llm = LLM(**vars(args))
    tokenizer = llm.get_tokenizer()
    label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

    # Run inference
    outputs = llm.encode(prompts, pooling_task="token_classify")

    for prompt, output in zip(prompts, outputs):
        logits = output.outputs.data
        predictions = logits.argmax(dim=-1)

        # Map predictions to labels
        tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
        labels = [label_map[p.item()] for p in predictions]

        # Print results
        for token, label in zip(tokens, labels):
            if token not in tokenizer.all_special_tokens:
                print(f"{token:15}{label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
prithvi_geospatial_mae.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import datetime
import os

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm import LLM

torch.set_default_dtype(torch.float16)

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99

datamodule_config = {
    "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    "batch_size": 16,
    "constant_scale": 0.0001,
    "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
    "drop_last": True,
    "no_data_replace": 0.0,
    "no_label_replace": -1,
    "num_workers": 8,
    "test_transform": [
        albumentations.Resize(
            always_apply=False, height=448, interpolation=1, p=1, width=448
        ),
        albumentations.pytorch.ToTensorV2(
            transpose_mask=False, always_apply=True, p=1.0
        ),
    ],
}


class PrithviMAE:
    def __init__(self, model):
        self.model = LLM(
            model=model,
            skip_tokenizer_init=True,
            dtype="float16",
            enforce_eager=True,
            model_impl="terratorch",
            enable_mm_embeds=True,
        )

    def run(self, input_data, location_coords):
        # merge the inputs into one data structure
        if input_data is not None and input_data.dtype == torch.float32:
            input_data = input_data.to(torch.float16)
            input_data = input_data[0]

        mm_data = {
            "pixel_values": input_data,
            "location_coords": location_coords,
        }

        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)

        return outputs[0].outputs.data


def generate_datamodule():
    datamodule = Sen1Floods11NonGeoDataModule(
        data_root=datamodule_config["data_root"],
        batch_size=datamodule_config["batch_size"],
        num_workers=datamodule_config["num_workers"],
        bands=datamodule_config["bands"],
        drop_last=datamodule_config["drop_last"],
        test_transform=datamodule_config["test_transform"],
    )

    return datamodule


def process_channel_group(orig_img, channels):
    """
    Args:
        orig_img: torch.Tensor representing original image (reference)
        with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width)
        for original image
    """

    orig_img = orig_img[channels, ...]
    valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
    valid_mask[orig_img == NO_DATA_FLOAT] = False

    # Rescale (enhancing contrast)
    max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)

    # No data as zeros
    orig_img[~valid_mask] = 0

    return orig_img


def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return image + meta info.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.lnglat()
        except Exception:
            # Cannot read coords
            coords = None

    return img, meta, coords


def save_geotiff(image, output_path: str, meta: dict):
    """Save multi-band image in Geotiff file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):
    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(
    file_paths: list[str],
    mean: list[float] | None = None,
    std: list[float] | None = None,
    indices: list[int] | None = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the
              images in *file_paths*.
        std: list containing std values for each band in the
             images in *file_paths*.

    Returns:
        np.array containing created example
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            match = re.search(r"(\d{7,8}T\d{6})", file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split("T")[0][4:]
                if len(julian_day) == 3:
                    julian_day = int(julian_day)
                else:
                    julian_day = (
                        datetime.datetime.strptime(julian_day, "%m%d")
                        .timetuple()
                        .tm_yday
                    )
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f"Could not extract timestamp for {file} ({e})")

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas


def run_model(
    input_data,
    temporal_coords,
    location_coords,
    model,
    datamodule,
    img_size,
    lightning_model=None,
):
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - (original_h % img_size)) % img_size
    pad_w = (img_size - (original_w % img_size)) % img_size
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window

    batch_size = 1
    # batch = torch.tensor(input_data, device="cpu")
    batch = torch.tensor(input_data)
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    if temporal_coords:
        temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
        location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
    else:
        location_coords = None

    # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
    pred_imgs = []
    for x in windows:
        # Apply standardization
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)["image"]

        with torch.no_grad():
            pred = model.run(x, location_coords=location_coords)
        y_hat = pred.argmax(dim=1)

        y_hat = torch.nn.functional.interpolate(
            y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
        )

        pred_imgs.append(y_hat)

    pred_imgs = torch.concat(pred_imgs, dim=0)

    # Build images from patches
    pred_imgs = rearrange(
        pred_imgs,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size
    pred_imgs = pred_imgs[..., :original_h, :original_w]

    # Squeeze (batch size 1)
    pred_imgs = pred_imgs[0]

    return pred_imgs


def main(
    data_file: str,
    model: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] | None = None,
):
    os.makedirs(output_dir, exist_ok=True)

    model_obj = PrithviMAE(model=model)
    datamodule = generate_datamodule()
    img_size = 512  # Size of Sen1Floods11

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file],
        indices=input_indices,
    )

    meta_data = meta_data[0]  # only one image

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    channels = [
        datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
    ]  # BGR -> RGB

    pred = run_model(
        input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
    )
    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(
        output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
    )
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # Save image + pred
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )
    rgb_orig = rgb_orig.to(torch.float32)

    pred[pred == 0.0] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]

    img_pred_file = os.path.join(
        output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
    )
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # Save image rgb
    if rgb_outputs:
        name_suffix = os.path.splitext(os.path.basename(data_file))[0]
        rgb_file = os.path.join(
            output_dir,
            f"original_rgb_{name_suffix}.tiff",
        )
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=rgb_file,
            meta=meta_data,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        help="Path to a checkpoint file to load from.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help="""
        0-based indices of the six Prithvi channels to be selected from the input.
        By default selects [1,2,3,8,11,12] for S2L1C data.
        """,
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )
    args = parser.parse_args()

    main(**vars(args))
prithvi_geospatial_mae_io_processor.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os

import torch

from vllm import LLM

# This example shows how to perform offline inference that generates
# multimodal data. In this specific case the example takes a geotiff
# image as input, processes it using the multimodal data processor, and
# performs inference.
# Requirements:
# - install TerraTorch v1.1 (or later):
#   pip install "terratorch>=1.1"


def main():
    torch.set_default_dtype(torch.float16)
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

    img_prompt = dict(
        data=image_url,
        data_format="url",
        image_format="tiff",
        out_data_format="b64_json",
    )

    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        skip_tokenizer_init=True,
        trust_remote_code=True,
        enforce_eager=True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM.
        # The maximum number depends on the available GPU memory
        max_num_seqs=32,
        io_processor_plugin="terratorch_segmentation",
        model_impl="terratorch",
        enable_mm_embeds=True,
    )

    pooler_output = llm.encode(img_prompt, pooling_task="plugin")
    output = pooler_output[0].outputs

    print(output)
    decoded_data = base64.b64decode(output.data)

    file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
    with open(file_path, "wb") as f:
        f.write(decoded_data)

    print(f"Output file path: {file_path}")


if __name__ == "__main__":
    main()
qwen3_reranker.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-0.6B"

# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking using the
# logits of the "no" and "yes" tokens.
# It needs to compute logits over the whole 151,669-token vocabulary,
# making this method extremely inefficient, not to mention incompatible
# with the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. See: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline using this method are not only more efficient
# and compatible with the vllm score API, but also make the init parameters
# more concise, for example:
# llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")

# If you want to load the official original version, the init parameters are
# as follows.


def get_llm() -> LLM:
    """Initializes and returns the LLM model for Qwen3-Reranker."""
    return LLM(
        model=model_name,
        runner="pooling",
        hf_overrides={
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
    )


# Why we need hf_overrides for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - First, `"architectures": ["Qwen3ForSequenceClassification"]` manually
#   routes the model to Qwen3ForSequenceClassification.
# - Then, `"classifier_from_token": ["no", "yes"]` extracts the vectors
#   corresponding to the classifier tokens from lm_head.
# - Third, `"is_original_qwen3_reranker": True` enables the conversion
#   logic that folds these two vectors into one.

# Please use the query_template and document_template to format the query and
# document for better reranking results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"


def main() -> None:
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]

    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for query in queries
    ]
    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]

    llm = get_llm()
    outputs = llm.score(queries, documents)

    print("-" * 30)
    print([output.outputs.score for output in outputs])
    print("-" * 30)


if __name__ == "__main__":
    main()