"""
HuggingFace GGUF Model Chat UI on Paperspace
- Downloads a GGUF model from HuggingFace Hub
- Runs inference via llama-cpp-python (GPU accelerated)
- Provides a Gradio chat interface
"""

import os
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# ---------- Configuration ----------
MODEL_REPO = os.environ.get("MODEL_REPO", "HauhauCS/Gemma-4-E4B-Uncensored-HauhauCS-Aggressive")
MODEL_FILE = os.environ.get("MODEL_FILE", "Gemma-4-E4B-Uncensored-HauhauCS-Aggressive-Q8_K_P.gguf")
N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "-1"))  # -1 = all layers on GPU
N_CTX = int(os.environ.get("N_CTX", "8192"))
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "4096"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
TOP_P = float(os.environ.get("TOP_P", "0.9"))
PORT = int(os.environ.get("PORT", "7860"))
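# Every setting above can be overridden via environment variables, e.g.
#   MODEL_REPO=<user/repo> MODEL_FILE=<file.gguf> N_CTX=4096 python app.py
# ("app.py" is a placeholder for whatever this script is saved as).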

# ---------- Model Download & Loading ----------
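# hf_hub_download caches files in the Hugging Face cache (default ~/.cache/huggingface/hub,
# configurable via HF_HOME), so repeated runs reuse the GGUF instead of re-downloading it.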
print(f"Downloading model: {MODEL_REPO} / {MODEL_FILE}...")
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
print(f"Model downloaded to: {model_path}")

print("Loading model into GPU...")
llm = Llama(
    model_path=model_path,
    n_gpu_layers=N_GPU_LAYERS,
    n_ctx=N_CTX,
    verbose=True,
)
print("Model loaded successfully.")


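# Note: create_chat_completion formats `messages` with the chat template stored in the
# GGUF metadata when available (recent llama-cpp-python reads tokenizer.chat_template
# automatically); older versions may need an explicit chat_format= when constructing Llama.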
def chat(user_message, history, system_prompt):
    """Generate a streaming response."""
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt})
    for msg in history:
        if msg["content"]:
            messages.append({"role": msg["role"], "content": msg["content"]})
    messages.append({"role": "user", "content": user_message})

    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": ""})

    partial = ""
    for chunk in llm.create_chat_completion(
        messages=messages,
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        stream=True,
    ):
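        # Each chunk is an OpenAI-style delta; the first one typically carries only the
        # role, so only append text when a "content" field is present.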
        delta = chunk["choices"][0]["delta"]
        if "content" in delta:
            partial += delta["content"]
            history[-1]["content"] = partial
            yield history

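# Quick sanity check without the UI (non-streaming call against the same model):
#   print(llm.create_chat_completion(
#       messages=[{"role": "user", "content": "Hello"}], max_tokens=32,
#   )["choices"][0]["message"]["content"])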

# ---------- Gradio UI ----------
with gr.Blocks(title="LLM Chat") as demo:
    gr.Markdown(f"# LLM Chat\n**Model**: `{MODEL_FILE}` &nbsp;|&nbsp; **GPU Layers**: `{N_GPU_LAYERS}`")

    system_prompt = gr.Textbox(
        label="System Prompt",
        value="You are a helpful assistant.",
        lines=2,
    )

    chatbot = gr.Chatbot(height=500, type="messages")  # expects {"role", "content"} dicts, matching chat()
    msg = gr.Textbox(label="Message", placeholder="Type your message here...", autofocus=True)
    clear = gr.Button("Clear")

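    # Wire up events: submitting a message streams the assistant's reply into the
    # chatbot and then clears the input box; "Clear" resets the conversation.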
    msg.submit(chat, [msg, chatbot, system_prompt], chatbot).then(
        lambda: "", outputs=[msg]
    )
    clear.click(lambda: [], outputs=[chatbot])

if __name__ == "__main__":
    demo.queue()
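    # The queue handles streamed (generator) outputs and concurrent users; share=True
    # additionally creates a temporary public *.gradio.live link in case the Paperspace
    # port is not reachable directly.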
    demo.launch(server_name="0.0.0.0", server_port=PORT, share=True)
