#!/usr/bin/env python3
# Copyright (C) 2026 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import numpy as np
import openvino_genai as ov_genai

from pathlib import Path
from PIL import Image
from openvino import Tensor


def streamer(subword: str) -> None:
    """Print each generated sub-word to stdout as it arrives.

    Args:
        subword: Sub-word chunk of the generated text.

    Returns:
        None. In this example we never want to stop generation from the
        streamer; "return None" is treated the same as
        "return openvino_genai.StreamingStatus.RUNNING".
    """
    # flush=True so the partial answer appears immediately, not when the
    # stdout buffer happens to fill up.
    print(subword, end="", flush=True)

def read_image(path: str) -> Tensor:
    """Read an image file from disk into an ov.Tensor.

    Args:
        path: The path to the image file.

    Returns:
        The ov.Tensor containing the image's RGB pixel data.
    """
    # Use a context manager so the underlying file handle is closed
    # promptly instead of lingering until garbage collection.
    with Image.open(path) as pic:
        image_data = np.array(pic.convert("RGB"))
    return Tensor(image_data)


def read_images(path: str) -> list[Tensor]:
    """Load one image, or every image in a directory, as ov.Tensors.

    Args:
        path: Path to a single image file or to a directory of images.

    Returns:
        A list of ov.Tensors, one per image; directory entries are
        visited in sorted order for a deterministic result.
    """
    location = Path(path)
    if not location.is_dir():
        return [read_image(path)]
    return [read_image(str(child)) for child in sorted(location.iterdir())]


def parse_lora_pairs(raw):
    """Turn a flat token list into (adapter_path, alpha) pairs.

    Args:
        raw: Flat list of command-line tokens alternating between a LoRA
            safetensors path and its alpha weight.

    Returns:
        A list of (path, alpha) tuples with each alpha converted to float.

    Raises:
        argparse.ArgumentTypeError: if fewer than two tokens are given,
            the tokens do not come in pairs, or an alpha is not numeric.
    """
    if len(raw) < 2:
        raise argparse.ArgumentTypeError(
            "At least one LoRA adapter pair is required: <LORA_SAFETENSORS> <ALPHA> [<LORA_SAFETENSORS> <ALPHA> ...]"
        )
    if len(raw) % 2 != 0:
        raise argparse.ArgumentTypeError("LoRA args must come in pairs: <LORA_SAFETENSORS> <ALPHA> ...")

    def to_pair(path, alpha_text):
        # Convert eagerly so a bad alpha is reported together with the
        # adapter it belongs to.
        try:
            return path, float(alpha_text)
        except ValueError as e:
            raise argparse.ArgumentTypeError(f"Invalid alpha '{alpha_text}' for LoRA '{path}'") from e

    return [to_pair(raw[i], raw[i + 1]) for i in range(0, len(raw), 2)]


def main() -> int:
    """Run the VLM sample twice: with LoRA adapters applied, then without.

    Returns:
        0 on success (errors surface as exceptions / argparse exits).
    """
    parser = argparse.ArgumentParser(
        description="OpenVINO GenAI VLM sample: run with and without LoRA adapters.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("model_dir", help="Path to model directory")
    parser.add_argument("images_path", help="Image file OR directory with images")
    parser.add_argument("prompt", help="Prompt/question to ask")
    parser.add_argument(
        "lora_pairs",
        nargs="+",
        metavar="LORA_ALPHA",
        help="Pairs: <LORA_SAFETENSORS> <ALPHA> ...",
    )
    args = parser.parse_args()

    # Validate the LoRA arguments before doing any image I/O.
    adapter_pairs = parse_lora_pairs(args.lora_pairs)
    images = read_images(args.images_path)

    device = "CPU"  # GPU can be used as well

    extra_kwargs = {}
    if adapter_pairs:
        # Register every adapter with its weight (alpha) on the pipeline.
        config = ov_genai.AdapterConfig()
        for lora_path, alpha in adapter_pairs:
            config.add(ov_genai.Adapter(lora_path), alpha)
        extra_kwargs["adapters"] = config

    pipe = ov_genai.VLMPipeline(args.model_dir, device, **extra_kwargs)

    generation_config = ov_genai.GenerationConfig()
    generation_config.max_new_tokens = 100

    print("Generating answer with LoRA adapters applied:")
    pipe.generate(args.prompt, images=images, generation_config=generation_config, streamer=streamer)

    print("\n----------\nGenerating answer without LoRA adapters applied:")
    # An empty AdapterConfig overrides the pipeline's adapters for this
    # single call, producing the baseline (no-LoRA) answer.
    pipe.generate(
        args.prompt,
        images=images,
        generation_config=generation_config,
        adapters=ov_genai.AdapterConfig(),
        streamer=streamer,
    )

    print("\n----------")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status; the
    # original discarded it, so a non-zero return could not be observed.
    raise SystemExit(main())
