"""FastAPI backend for LLM Chat with auth."""

import os
# Point the Hugging Face cache at the persistent network volume unless HF_HOME
# is already set to a non-empty value. This runs before huggingface_hub is
# imported further down, which reads HF_HOME when it is imported.
_DEFAULT_HF_HOME = "/workspace/.cache/huggingface"
if not os.environ.get("HF_HOME", ""):
    os.environ["HF_HOME"] = _DEFAULT_HF_HOME
import json
import csv
import io
import time
import threading
import urllib.request
import re
import uuid
from html.parser import HTMLParser
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Response, UploadFile, File
from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import base64

# Select the multimodal chat handler used for turns that contain images:
#   1. llama-cpp-python's own Gemma3ChatHandler, if this version ships it;
#   2. otherwise a local Gemma3ChatHandler built on Llava15ChatHandler with a
#      Gemma-style turn template (ported from llama-cpp-python PR #1989);
#   3. otherwise None, in which case vision support is disabled.
try:
    from llama_cpp.llama_chat_format import Gemma3ChatHandler as VisionChatHandler
except ImportError:
    try:
        # Gemma3ChatHandler not in this version - define it from PR #1989
        from llama_cpp.llama_chat_format import Llava15ChatHandler

        class Gemma3ChatHandler(Llava15ChatHandler):
            """Llava15ChatHandler with a Gemma-style ``<start_of_turn>``/``<end_of_turn>`` template.

            The template folds a leading system message into the first remaining
            turn, renders assistant turns with the role "model", and emits each
            ``image_url`` content item as its URL string between blank lines
            (presumably swapped for image embeddings by the Llava15ChatHandler
            base class - verify against the installed llama-cpp-python).
            """

            # NOTE(review): presumably disables any default system message the
            # base class would inject - confirm in Llava15ChatHandler.
            DEFAULT_SYSTEM_MESSAGE = None
            # Jinja2 chat template. The "\n" sequences are real newlines inside
            # the Jinja string literals, because these are plain Python strings.
            CHAT_FORMAT = (
                "{% if messages[0]['role'] == 'system' %}"
                "{% if messages[0]['content'] is string %}"
                "{% set first_user_prefix = messages[0]['content'] + '\n\n' %}"
                "{% else %}"
                "{% set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' %}"
                "{% endif %}"
                "{% set loop_messages = messages[1:] %}"
                "{% else %}"
                "{% set first_user_prefix = \"\" %}"
                "{% set loop_messages = messages %}"
                "{% endif %}"
                "{% for message in loop_messages %}"
                "{% if (message['role'] == 'assistant') %}"
                "{% set role = \"model\" %}"
                "{% else %}"
                "{% set role = message['role'] %}"
                "{% endif %}"
                "{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}"
                "{% if message['content'] is string %}"
                "{{ message['content'] | trim }}"
                "{% elif message['content'] is iterable %}"
                "{% for item in message['content'] %}"
                "{% if item['type'] == 'image_url' and item['image_url'] is string %}"
                "{{ '\n\n' + item['image_url'] + '\n\n' }}"
                "{% elif item['type'] == 'image_url' and item['image_url'] is mapping %}"
                "{{ '\n\n' + item['image_url']['url'] + '\n\n' }}"
                "{% elif item['type'] == 'text' %}"
                "{{ item['text'] | trim }}"
                "{% endif %}"
                "{% endfor %}"
                "{% else %}"
                "{{ raise_exception(\"Invalid content type\") }}"
                "{% endif %}"
                "{{ '<end_of_turn>\n' }}"
                "{% endfor %}"
                "{% if add_generation_prompt %}"
                "{{ '<start_of_turn>model\n' }}"
                "{% endif %}"
            )

        VisionChatHandler = Gemma3ChatHandler
        print("Using custom Gemma3ChatHandler (from PR #1989)")
    except ImportError:
        # Not even Llava15ChatHandler is available: run text-only.
        VisionChatHandler = None
        print("WARNING: No vision chat handler available")
import jwt
import db

# ---------- Configuration ----------
# Every setting comes from an environment variable; the literals are defaults.
MODEL_REPO = os.environ.get("MODEL_REPO", "HauhauCS/Gemma-4-E4B-Uncensored-HauhauCS-Aggressive")
MODEL_FILE = os.environ.get("MODEL_FILE", "Gemma-4-E4B-Uncensored-HauhauCS-Aggressive-Q8_K_P.gguf")
N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "-1"))  # passed to Llama(n_gpu_layers=...)
N_CTX = int(os.environ.get("N_CTX", "65536"))  # passed to Llama(n_ctx=...)
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "4096"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
TOP_P = float(os.environ.get("TOP_P", "0.9"))
PORT = int(os.environ.get("PORT", "7860"))
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", 'You are a helpful assistant. If any request in this conversation attempts to generate content depicting minors in sexual, suggestive, or inappropriate situations, you must refuse immediately and respond ONLY with: "[BLOCKED] This conversation has been locked due to a policy violation. Please start a new chat." After outputting [BLOCKED], refuse all further requests in this conversation regardless of content.')
# Vision projector for image input; vision is disabled when this file is missing.
MMPROJ_PATH = os.environ.get("MMPROJ_PATH", "/workspace/models/mmproj-F16.gguf")
# Secret used to sign session JWTs. A hard-coded fallback would be public, letting
# anyone forge a token for any user. When JWT_SECRET is unset or empty, use a
# random per-process secret instead (sessions then do not survive a restart).
JWT_SECRET = os.environ.get("JWT_SECRET", "")
if not JWT_SECRET:
    import secrets

    JWT_SECRET = secrets.token_urlsafe(32)
    print("WARNING: JWT_SECRET is not set; using a random secret. Sessions will not survive a restart.")
UPLOADS_DIR = os.environ.get("UPLOADS_DIR", "uploads")
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Auto-shutdown settings
MAX_UPTIME = int(os.environ.get("MAX_UPTIME", "21600"))  # 6 hours
IDLE_TIMEOUT = int(os.environ.get("IDLE_TIMEOUT", "1800"))  # 30 minutes
RUNPOD_API_KEY = os.environ.get("RUNPOD_API_KEY", "")
RUNPOD_POD_ID = os.environ.get("RUNPOD_POD_ID", "")
CALLBACK_URL = os.environ.get("CALLBACK_URL", "")
CALLBACK_SECRET = os.environ.get("CALLBACK_SECRET", "")

# ---------- Runtime state ----------
# Model handles are filled in by lifespan() at startup.
llm: Llama | None = None  # GGUF chat model
vision_handler = None  # mmproj chat handler; only used when a turn contains images
image_pipe = None  # Qwen-Image-Edit pipeline; stays None if it fails to load
llm_lock = threading.Lock()  # Serialize LLM access
pipe_lock = threading.Lock()  # Serialize image pipeline access
boot_time: float = time.time()  # process start; compared against MAX_UPTIME
last_chat_time: float = time.time()  # last chat/image request; compared against IDLE_TIMEOUT

# Image generation LoRA adapters.
# Maps the user-facing (Japanese) label returned by /api/image-adapters to the
# Hugging Face repo, the weight file inside it, and the adapter name passed to
# load_lora_weights / set_adapters.
IMAGE_ADAPTER_SPECS = {
    "ポーズ変更": {"repo": "lilylilith/AnyPose", "weights": "2511-AnyPose-helper-00006000.safetensors", "adapter_name": "pose"},  # pose change
    "アップスケーラー": {"repo": "starsfriday/Qwen-Image-Edit-2511-Upscale2K", "weights": "qwen_image_edit_2511_upscale.safetensors", "adapter_name": "upscale"},  # upscaler
    "スタイル転写": {"repo": "zooeyy/Style-Transfer", "weights": "Style Transfer-Alpha-V0.1.safetensors", "adapter_name": "style-transfer"},  # style transfer
    "高速生成": {"repo": "lightx2v/Qwen-Image-Edit-2511-Lightning", "weights": "Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors", "adapter_name": "lightning"},  # fast generation (4-step Lightning)
    "オブジェクト削除": {"repo": "prithivMLmods/Qwen-Image-Edit-2511-Object-Remover", "weights": "Qwen-Image-Edit-2511-Object-Remover.safetensors", "adapter_name": "object-remover"},  # object removal
    "アニメ変換": {"repo": "prithivMLmods/Qwen-Image-Edit-2511-Anime", "weights": "Qwen-Image-Edit-2511-Anime-2000.safetensors", "adapter_name": "anime"},  # anime conversion
    "オブジェクト追加": {"repo": "prithivMLmods/Qwen-Image-Edit-2511-Object-Adder", "weights": "Qwen-Image-Edit-2511-Object-Adder.safetensors", "adapter_name": "object-adder"},  # object addition
    "線画補間": {"repo": "EQUES/qwen-image-edit-2511-lineart-interpolation", "weights": "pytorch_lora_weights.safetensors", "adapter_name": "lineart"},  # line-art interpolation
    "破れ服": {"repo": "nappa114514/Qwen-Image-Edit-2511-torn-clothes", "weights": "tc3_002.safetensors", "adapter_name": "torn-clothes"},  # torn clothes
    "モノクロキャラ変換": {"repo": "nappa114514/Qwen-Image-Edit-2511-monochrome-charachange", "weights": "rch2_002.safetensors", "adapter_name": "monochrome"},  # monochrome character change
    "アングル変換": {"repo": "dx8152/Qwen-Edit-2509-Multiple-angles", "weights": "镜头转换.safetensors", "adapter_name": "angles"},  # camera-angle change
    "写真→アニメ": {"repo": "autoweeb/Qwen-Image-Edit-2509-Photo-to-Anime", "weights": "Qwen-Image-Edit-2509-Photo-to-Anime_000001000.safetensors", "adapter_name": "photo-anime"},  # photo to anime
    "漫画トーン": {"repo": "nappa114514/Qwen-Image-Edit-2509-Manga-Tone", "weights": "tone001.safetensors", "adapter_name": "manga-tone"},  # manga screentone
}
IMAGE_LOADED_ADAPTERS = set()  # adapter_name values already loaded into image_pipe
IMAGE_ACTIVE_ADAPTERS = []  # Currently active adapters (for set_adapters change detection)


def shutdown_machine():
    """Stop this RunPod pod; best effort, failures are logged and never raised.

    1. If CALLBACK_URL and CALLBACK_SECRET are configured, POST
       ``{"url": "clear", "secret": ...}`` so the PHP callback endpoint clears
       this server's URL.
    2. If RUNPOD_API_KEY and RUNPOD_POD_ID are configured, ask the RunPod REST
       API to stop the pod. Otherwise a warning is printed and nothing is stopped.
    """
    print("=== AUTO-SHUTDOWN: Stopping machine ===")
    if CALLBACK_URL and CALLBACK_SECRET:
        try:
            data = json.dumps({"url": "clear", "secret": CALLBACK_SECRET}).encode()
            req = urllib.request.Request(CALLBACK_URL, data=data, headers={"Content-Type": "application/json"})
            # Use a context manager so the HTTP response (and its socket) is closed.
            with urllib.request.urlopen(req, timeout=10):
                pass
        except Exception as e:
            print(f"Failed to clear URL: {e}")

    if not (RUNPOD_API_KEY and RUNPOD_POD_ID):
        print("RUNPOD_API_KEY / RUNPOD_POD_ID not set; cannot stop pod.")
        return
    try:
        req = urllib.request.Request(
            f"https://rest.runpod.io/v1/pods/{RUNPOD_POD_ID}/stop",
            method="POST",
            headers={"Authorization": f"Bearer {RUNPOD_API_KEY}", "Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=15):
            pass
    except Exception as e:
        print(f"Failed to stop pod: {e}")


def auto_shutdown_watcher():
    """Poll once a minute; stop the machine when the uptime or idle limit is hit.

    The uptime limit is checked before the idle limit. Once either trips,
    shutdown_machine() is called a single time and the thread exits.
    """
    while True:
        time.sleep(60)
        checked_at = time.time()
        if checked_at - boot_time >= MAX_UPTIME:
            reason = f"Max uptime reached ({MAX_UPTIME}s). Shutting down."
        elif checked_at - last_chat_time >= IDLE_TIMEOUT:
            reason = f"Idle timeout reached ({IDLE_TIMEOUT}s). Shutting down."
        else:
            continue
        print(reason)
        shutdown_machine()
        return


def create_token(user_id: int, username: str) -> str:
    """Issue an HS256-signed JWT carrying the user's id and name, valid for 30 days."""
    lifetime_seconds = 30 * 86400
    claims = {
        "user_id": user_id,
        "username": username,
        "exp": int(time.time()) + lifetime_seconds,
    }
    return jwt.encode(claims, JWT_SECRET, algorithm="HS256")


def get_user_from_request(request: Request) -> dict | None:
    """Return the decoded claims of the ``token`` cookie, or None.

    None is returned when the cookie is missing or empty, or when the JWT is
    expired or otherwise invalid.
    """
    raw_token = request.cookies.get("token")
    if raw_token:
        try:
            return jwt.decode(raw_token, JWT_SECRET, algorithms=["HS256"])
        except jwt.InvalidTokenError:  # also covers jwt.ExpiredSignatureError
            pass
    return None


def save_image(base64_data: str) -> str:
    """Save base64 image to file and return filename."""
    filename = f"{uuid.uuid4().hex}.png"
    filepath = os.path.join(UPLOADS_DIR, filename)
    with open(filepath, "wb") as f:
        f.write(base64.b64decode(base64_data))
    return filename


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load all models at startup, then start the auto-shutdown watcher.

    Startup order:
    1. Initialise the database.
    2. Download the GGUF chat model from the Hugging Face Hub and load it with llama.cpp.
    3. Build the optional mmproj vision handler.
    4. Try to load the image-edit pipeline (a failure only disables image generation).
    5. Start the watcher thread.

    Nothing is released after ``yield``.
    """
    global llm, vision_handler, image_pipe, last_chat_time
    db.init_db()

    print(f"Downloading model: {MODEL_REPO} / {MODEL_FILE}...")
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    print(f"Model downloaded to: {model_path}")

    print("Loading model into GPU...")
    vision_handler = None
    if VisionChatHandler and os.path.exists(MMPROJ_PATH):
        try:
            print(f"Loading mmproj: {MMPROJ_PATH} (handler: {VisionChatHandler.__name__})")
            vision_handler = VisionChatHandler(clip_model_path=MMPROJ_PATH, verbose=False)
            print("mmproj loaded.")
        except Exception as e:
            print(f"WARNING: Failed to load mmproj: {e}")
            vision_handler = None
    # Don't pass chat_handler to Llama - let it auto-detect from GGUF.
    # vision_handler is used only when images are present.
    llm = Llama(
        model_path=model_path,
        n_gpu_layers=N_GPU_LAYERS,
        n_ctx=N_CTX,
        verbose=False,
    )
    print("Model loaded successfully.")

    # Load the image generation pipeline; on any failure image_pipe stays None.
    try:
        import torch
        from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
        from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
        print("Loading image generation pipeline...")
        dtype = torch.bfloat16
        image_pipe = QwenImageEditPlusPipeline.from_pretrained(
            "FireRedTeam/FireRed-Image-Edit-1.1",
            transformer=QwenImageTransformer2DModel.from_pretrained(
                "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
                torch_dtype=dtype,
                device_map="cuda",
            ),
            torch_dtype=dtype,
        ).to("cuda")
        print("Image pipeline loaded successfully.")
    except Exception as e:
        print(f"WARNING: Failed to load image pipeline: {e}")
        image_pipe = None

    # Downloading and loading the models above can take many minutes, and
    # last_chat_time was set at import time. Restart the idle clock now that
    # requests can actually be served. Otherwise the loading time counts as
    # idle time and the pod may be stopped right after startup. boot_time is
    # left unchanged so MAX_UPTIME still caps the total runtime.
    last_chat_time = time.time()

    watcher = threading.Thread(target=auto_shutdown_watcher, daemon=True)
    watcher.start()
    print(f"Auto-shutdown: max uptime={MAX_UPTIME}s, idle timeout={IDLE_TIMEOUT}s")

    yield


# The ASGI application; model loading and the auto-shutdown watcher run in lifespan().
app = FastAPI(lifespan=lifespan)

# Front-end assets (login.html, index.html, ...) are served from ./static under /static.
_static_assets = StaticFiles(directory="static")
app.mount("/static", _static_assets, name="static")


# Negative prompt used by /api/generate-image when the request does not supply
# one. It covers general quality and anatomy defects, non-photographic styles,
# text/watermarks, and explicit exclusions of minors and childlike features.
DEFAULT_NEGATIVE_PROMPT = """low quality, blurry, pixelated, grainy, noisy, artifacts, jpeg artifacts, overexposed, underexposed, bad lighting,
deformed, ugly, disfigured, mutated, mutation, bad anatomy, extra limbs, missing limbs, fused fingers, too many fingers, extra fingers, malformed hands, poorly drawn hands, poorly drawn face, bad proportions,
cartoon, anime, illustration, drawing, sketch, painting, digital art, 3d render, cgi, render,
text, watermark, signature, logo, copyright, username,
multiple people, crowd, extra person, childlike features, young face, child, kid, teen, teenager, minor, underage, childlike body, extremely young body, infant body, toddler body, aged face, old,
censored, mosaic, bar censor, clothing on nude parts"""


# Minor-exclusion terms (taken from DEFAULT_NEGATIVE_PROMPT) that stay in the
# negative prompt even when a client supplies its own.
_REQUIRED_NEGATIVE_TERMS = (
    "childlike features, young face, child, kid, teen, teenager, minor, underage, "
    "childlike body, extremely young body, infant body, toddler body"
)


def _resolve_negative_prompt(requested) -> str:
    """Return DEFAULT_NEGATIVE_PROMPT, or the client's text with _REQUIRED_NEGATIVE_TERMS appended."""
    if requested is None:
        return DEFAULT_NEGATIVE_PROMPT
    custom = str(requested).strip()
    return f"{custom}, {_REQUIRED_NEGATIVE_TERMS}" if custom else _REQUIRED_NEGATIVE_TERMS


def _target_size(width: int, height: int, long_side: int = 1024, multiple: int = 8) -> tuple[int, int]:
    """Scale (width, height) so the longer side equals *long_side*, keeping the aspect ratio.

    Both sides are rounded down to a multiple of *multiple* but never below it,
    so extreme aspect ratios cannot produce a zero-sized output.
    """
    if width > height:
        new_w, new_h = long_side, int(long_side * height / width)
    else:
        new_h, new_w = long_side, int(long_side * width / height)
    return (
        max(multiple, new_w // multiple * multiple),
        max(multiple, new_h // multiple * multiple),
    )


def _apply_image_adapters(labels: list[str]) -> list[str]:
    """Load (once) and activate the LoRA adapters named by *labels*; return the labels applied.

    Duplicate labels, unknown labels and adapters that fail to load are skipped,
    so set_adapters never receives an adapter that was not loaded. Active
    adapters are weighted equally. When nothing is requested, adapters left
    active by an earlier request are set to weight 0 so they no longer affect
    the output. IMAGE_ACTIVE_ADAPTERS records the active set, so set_adapters
    is only called on a change. Must be called with ``pipe_lock`` held.
    """
    applied, names = [], []
    newly_loaded = False
    for label in dict.fromkeys(labels):  # de-duplicate, keep request order
        spec = IMAGE_ADAPTER_SPECS.get(label)
        if spec is None:
            continue
        name = spec["adapter_name"]
        if name not in IMAGE_LOADED_ADAPTERS:
            try:
                print(f"Loading LoRA: {label}")
                image_pipe.load_lora_weights(spec["repo"], weight_name=spec["weights"], adapter_name=name)
            except Exception as e:
                print(f"Failed to load LoRA {label}: {e}")
                continue
            IMAGE_LOADED_ADAPTERS.add(name)
            newly_loaded = True
        names.append(name)
        applied.append(label)

    # Loading a new adapter may itself change which adapters are active, so
    # re-apply in that case as well.
    if newly_loaded or names != IMAGE_ACTIVE_ADAPTERS:
        if names:
            image_pipe.set_adapters(names, adapter_weights=[1.0 / len(names)] * len(names))
        elif IMAGE_LOADED_ADAPTERS:
            loaded = sorted(IMAGE_LOADED_ADAPTERS)
            image_pipe.set_adapters(loaded, adapter_weights=[0.0] * len(loaded))
        IMAGE_ACTIVE_ADAPTERS[:] = names
    return applied


def _run_image_job(labels, images, prompt, negative_prompt, width, height, steps, generator, guidance_scale):
    """Activate LoRAs, run the pipeline once and save the result as ``gen_<hex>.png`` in UPLOADS_DIR.

    Blocking; intended to run in a worker thread. Returns ``(filename, applied_labels)``.
    """
    with pipe_lock:
        applied = _apply_image_adapters(labels)
        result = image_pipe(
            image=images,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=steps,
            generator=generator,
            true_cfg_scale=guidance_scale,
        ).images[0]
    filename = f"gen_{uuid.uuid4().hex}.png"
    result.save(os.path.join(UPLOADS_DIR, filename), "PNG")
    return filename, applied


@app.post("/api/generate-image")
async def generate_image(request: Request):
    """Generate or edit an image with the Qwen-Image-Edit pipeline.

    JSON body:
        prompt (str, required)
        negative_prompt (str, optional; the minor-exclusion terms are always kept)
        lora (label or list of labels from IMAGE_ADAPTER_SPECS)
        source_image (base64, optional; without it a blank white 1024x1024 canvas is edited)
        seed (int; negative means random)
        steps (int >= 1, default 4)
        guidance_scale (float, default 1.0; passed as true_cfg_scale)
        conversation_id (optional; must belong to the caller)

    Returns ``{"ok": True, "image_url": ..., "seed": ...}``. Errors: 400 for
    invalid input, 401 unauthenticated, 404 foreign/unknown conversation, 503
    when the pipeline is not loaded, 500 when generation fails.
    """
    global last_chat_time
    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    if image_pipe is None:
        return JSONResponse({"error": "Image generation not available"}, status_code=503)
    last_chat_time = time.time()  # image requests count as activity for the idle timeout

    body = await request.json()
    prompt = body.get("prompt", "")
    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    # NOTE(review): in diffusers' Qwen-Image edit pipelines the negative prompt is
    # only applied when true_cfg_scale > 1, so with the default guidance_scale of
    # 1.0 it may be ignored entirely - confirm for this pipeline.
    negative_prompt = _resolve_negative_prompt(body.get("negative_prompt"))

    lora_adapters = body.get("lora") or []
    if isinstance(lora_adapters, str):
        lora_adapters = [lora_adapters]
    if not isinstance(lora_adapters, list):
        return JSONResponse({"error": "lora must be a string or a list of strings"}, status_code=400)
    lora_adapters = [label for label in lora_adapters if isinstance(label, str)]

    raw_seed, raw_steps, raw_cfg = body.get("seed"), body.get("steps"), body.get("guidance_scale")
    try:
        seed = -1 if raw_seed is None else int(raw_seed)
        steps = 4 if raw_steps is None else int(raw_steps)
        guidance_scale = 1.0 if raw_cfg is None else float(raw_cfg)
    except (TypeError, ValueError, OverflowError):
        return JSONResponse({"error": "seed, steps and guidance_scale must be numbers"}, status_code=400)
    if steps < 1:
        return JSONResponse({"error": "steps must be at least 1"}, status_code=400)

    # Only attach the result to a conversation the caller owns.
    conv_id = body.get("conversation_id")
    if conv_id:
        conv = db.get_conversation(conv_id)
        if not conv or conv["user_id"] != user["user_id"]:
            return JSONResponse({"error": "Not found"}, status_code=404)

    import asyncio
    import gc
    import random

    import torch
    from PIL import Image as PILImage

    source_image_b64 = body.get("source_image")
    if source_image_b64:
        try:
            source = PILImage.open(io.BytesIO(base64.b64decode(source_image_b64))).convert("RGB")
        except (TypeError, ValueError, OSError, PILImage.DecompressionBombError):
            return JSONResponse({"error": "Invalid source image"}, status_code=400)
        pil_images = [source]
    else:
        # No source image: edit a blank white canvas instead.
        pil_images = [PILImage.new("RGB", (1024, 1024), (255, 255, 255))]

    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    new_w, new_h = _target_size(*pil_images[0].size)

    gc.collect()
    torch.cuda.empty_cache()
    try:
        generator = torch.Generator(device="cuda").manual_seed(seed)
        # Run the blocking GPU work in a worker thread so the event loop keeps
        # serving other requests; pipe_lock (taken inside) serializes pipeline use.
        filename, applied = await asyncio.to_thread(
            _run_image_job,
            lora_adapters,
            pil_images,
            prompt,
            negative_prompt,
            new_w,
            new_h,
            steps,
            generator,
            guidance_scale,
        )

        if conv_id:
            db.add_message(
                conv_id,
                "assistant",
                f"[Generated Image]\nPrompt: {prompt}\nLoRA: {', '.join(applied) if applied else 'none'}\nSeed: {seed}",
                image_path=json.dumps([filename]),
            )

        return JSONResponse({"ok": True, "image_url": f"/uploads/{filename}", "seed": seed})
    except Exception as e:
        print(f"Image generation error: {e}")
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        gc.collect()
        torch.cuda.empty_cache()


@app.get("/api/image-adapters")
async def list_image_adapters(request: Request):
    """Return the selectable LoRA labels and whether the image pipeline is loaded (login required)."""
    if not get_user_from_request(request):
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    payload = {
        "adapters": [*IMAGE_ADAPTER_SPECS],
        "available": image_pipe is not None,
    }
    return JSONResponse(payload)


@app.get("/uploads/{filename}")
async def serve_upload(filename: str):
    """Serve a stored image from UPLOADS_DIR by its bare file name.

    Only plain file names are accepted. Names containing a path separator, and
    "." / "..", are answered with 404, as are non-regular files, so a request
    cannot traverse out of UPLOADS_DIR or hit a directory.
    """
    if filename in ("", ".", "..") or "\\" in filename or os.path.basename(filename) != filename:
        return JSONResponse({"error": "not found"}, status_code=404)
    filepath = os.path.join(UPLOADS_DIR, filename)
    if not os.path.isfile(filepath):
        return JSONResponse({"error": "not found"}, status_code=404)
    return FileResponse(filepath, media_type="image/png")


# --- Web Search (DuckDuckGo) ---

class TextExtractor(HTMLParser):
    """Collect the visible text of an HTML document.

    Text inside <script>, <style>, <nav>, <header> and <footer> is dropped.
    A nesting depth is tracked rather than a boolean flag. Otherwise closing an
    inner skipped element (e.g. a </script> inside a <header>) would re-enable
    text collection while still inside the outer one.
    """

    SKIP_TAGS = frozenset(("script", "style", "nav", "header", "footer"))

    def __init__(self):
        super().__init__()
        self.text = []  # stripped text chunks, in document order
        self._skip_depth = 0  # number of currently open skipped elements

    @property
    def skip(self) -> bool:
        """True while inside at least one skipped element (kept for backward compatibility)."""
        return self._skip_depth > 0

    def handle_starttag(self, tag, attrs):
        if tag in self.SKIP_TAGS:
            self._skip_depth += 1

    def handle_endtag(self, tag):
        # Clamp at zero: a stray closing tag must not push the depth negative,
        # which would hide all remaining text.
        if tag in self.SKIP_TAGS and self._skip_depth > 0:
            self._skip_depth -= 1

    def handle_data(self, data):
        if not self._skip_depth:
            self.text.append(data.strip())

    def get_text(self):
        """Return the non-empty collected chunks joined by single spaces."""
        return " ".join(chunk for chunk in self.text if chunk)


def web_search(query: str, num_results: int = 3) -> str:
    """Search DuckDuckGo's HTML endpoint and return the top results with their page text.

    Each of the first ``num_results`` hits becomes a section with its title,
    its URL and up to 2000 characters of the page's visible text, or
    "(Content unavailable)" if the page could not be fetched. Returns
    "No results found." when nothing was parsed, and "Search error: ..." if the
    search request itself fails. Errors are reported in the returned string
    rather than raised.
    """
    import html as html_lib
    import urllib.parse

    max_bytes = 2_000_000  # cap on bytes read per response, to bound memory use
    headers = {"User-Agent": "Mozilla/5.0"}
    try:
        search_url = f"https://html.duckduckgo.com/html/?q={urllib.parse.quote(query)}"
        req = urllib.request.Request(search_url, headers=headers)
        with urllib.request.urlopen(req, timeout=10) as response:
            search_html = response.read(max_bytes).decode("utf-8", errors="ignore")

        # Extract result URLs and titles
        results = []
        for match in re.finditer(r'<a rel="nofollow" class="result__a" href="(.*?)">(.*?)</a>', search_html):
            if len(results) >= num_results:
                break
            # Attribute values and link text are HTML-escaped (e.g. "&amp;").
            url = html_lib.unescape(match.group(1))
            title = html_lib.unescape(re.sub(r'<.*?>', '', match.group(2)))
            if url.startswith("//duckduckgo.com/l/?uddg="):
                # DuckDuckGo redirect link: the real target is percent-encoded in "uddg".
                url = urllib.parse.unquote(url.split("uddg=")[1].split("&")[0])
            results.append({"url": url, "title": title})

        # Fetch page content for each result
        output = []
        for r in results:
            try:
                req = urllib.request.Request(r["url"], headers=headers)
                with urllib.request.urlopen(req, timeout=5) as page:
                    page_html = page.read(max_bytes).decode("utf-8", errors="ignore")
                parser = TextExtractor()
                parser.feed(page_html)
                content = parser.get_text()[:2000]  # limit content
                output.append(f"### {r['title']}\nURL: {r['url']}\n{content}\n")
            except Exception:
                output.append(f"### {r['title']}\nURL: {r['url']}\n(Content unavailable)\n")

        return "\n".join(output) if output else "No results found."
    except Exception as e:
        return f"Search error: {str(e)}"


def fetch_url(url: str) -> str:
    """Fetch an http(s) URL and return up to 4000 characters of its visible text.

    Only http and https are allowed. urllib would otherwise also open e.g.
    ``file://`` URLs and return local files. Any failure, including a refused
    scheme, is returned as "Failed to fetch <url>: <reason>" rather than raised.

    NOTE(review): private/link-local hosts are not blocked, so a URL supplied by
    a user can still reach internal services (SSRF) - consider an allow/deny list.
    """
    import urllib.parse

    max_bytes = 2_000_000  # cap on bytes read, to bound memory use
    try:
        scheme = urllib.parse.urlsplit(url).scheme.lower()
        if scheme not in ("http", "https"):
            raise ValueError(f"unsupported URL scheme: {scheme or '(none)'}")
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
        with urllib.request.urlopen(req, timeout=10) as response:
            page_html = response.read(max_bytes).decode("utf-8", errors="ignore")
        parser = TextExtractor()
        parser.feed(page_html)
        content = parser.get_text()[:4000]
        return f"Content from {url}:\n{content}"
    except Exception as e:
        return f"Failed to fetch {url}: {str(e)}"


def extract_urls(message: str) -> list[str]:
    """Return the http(s) URLs found in *message*, in order of appearance.

    A URL runs until whitespace, "<", ">" or a quote character. Trailing
    sentence punctuation (ASCII and full-width, e.g. ".", ",", "。", "、") is
    stripped. So is a trailing ")", "]" or "）" that has no matching opener
    inside the URL. For example, "(see https://a.example/x)." yields
    "https://a.example/x", while ".../Foo_(bar)" keeps its parenthesis.
    """
    trailing_punct = ".,;:!?。、，．！？"
    closers = {")": "(", "]": "[", "）": "（"}
    urls = []
    for candidate in re.findall(r'https?://[^\s<>"\']+', message):
        url = candidate
        while url:
            last = url[-1]
            unmatched_closer = last in closers and url.count(last) > url.count(closers[last])
            if last in trailing_punct or unmatched_closer:
                url = url[:-1]
            else:
                break
        if not url.endswith("://"):  # nothing left after the scheme
            urls.append(url)
    return urls


def needs_search(message: str) -> str | None:
    """Check if message needs web search. Returns search query or None."""
    search_patterns = [
        r'(?:検索|調べ|探し|search|look up|find)',
        r'(?:最新|最近|今日|昨日|今の|現在の)',
        r'(?:ニュース|news)',
        r'(?:価格|値段|料金|price)',
        r'(?:天気|weather)',
    ]
    for pattern in search_patterns:
        if re.search(pattern, message, re.IGNORECASE):
            return message
    # Check for @search prefix
    if message.startswith("@search ") or message.startswith("@検索 "):
        return message.split(" ", 1)[1]
    return None


# --- Server Status API ---

@app.get("/api/status")
async def server_status():
    """Report uptime and idle time plus the seconds left before each auto-shutdown limit."""
    now = time.time()
    uptime, idle = int(now - boot_time), int(now - last_chat_time)
    return dict(
        uptime=uptime,
        idle=idle,
        remaining_uptime=max(0, MAX_UPTIME - uptime),
        remaining_idle=max(0, IDLE_TIMEOUT - idle),
        max_uptime=MAX_UPTIME,
        idle_timeout=IDLE_TIMEOUT,
    )


# --- Pages ---

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the chat UI to a logged-in user and the login page to everyone else."""
    logged_in = get_user_from_request(request)
    page_path = "static/index.html" if logged_in else "static/login.html"
    with open(page_path, "r", encoding="utf-8") as page:
        return page.read()


# --- Auth API ---

@app.post("/api/register")
async def register(request: Request):
    """Create an account and sign the new user in.

    Expects JSON ``{"username": str, "password": str}``. Responds 400 when the
    body is malformed, a field is missing or not a string, or the password is
    shorter than 4 characters; 409 when the username is taken. On success the
    session JWT is set as an httponly cookie valid for 30 days.
    """
    try:
        body = await request.json()
    except ValueError:  # malformed JSON
        body = None
    if not isinstance(body, dict):
        body = {}
    username = body.get("username", "")
    password = body.get("password", "")
    # Non-string JSON values would otherwise crash .strip()/len() with a 500.
    if not isinstance(username, str) or not isinstance(password, str):
        return JSONResponse({"error": "Username and password required"}, status_code=400)
    username = username.strip()
    if not username or not password:
        return JSONResponse({"error": "Username and password required"}, status_code=400)
    if len(password) < 4:
        return JSONResponse({"error": "Password must be at least 4 characters"}, status_code=400)
    user_id = db.create_user(username, password)
    if user_id is None:
        return JSONResponse({"error": "Username already taken"}, status_code=409)
    token = create_token(user_id, username)
    resp = JSONResponse({"ok": True, "username": username})
    resp.set_cookie("token", token, httponly=True, max_age=86400 * 30, samesite="lax")
    return resp


@app.post("/api/login")
async def login(request: Request):
    """Verify credentials and, on success, set the session JWT cookie (httponly, 30 days).

    Malformed JSON is treated as an empty body. Non-string credentials are
    rejected with the same 401 as wrong ones instead of crashing with a 500.
    """
    try:
        body = await request.json()
    except ValueError:  # malformed JSON
        body = None
    if not isinstance(body, dict):
        body = {}
    username = body.get("username", "")
    password = body.get("password", "")
    if not isinstance(username, str) or not isinstance(password, str):
        return JSONResponse({"error": "Invalid username or password"}, status_code=401)
    username = username.strip()
    user_id = db.authenticate_user(username, password)
    if user_id is None:
        return JSONResponse({"error": "Invalid username or password"}, status_code=401)
    token = create_token(user_id, username)
    resp = JSONResponse({"ok": True, "username": username})
    resp.set_cookie("token", token, httponly=True, max_age=86400 * 30, samesite="lax")
    return resp


@app.post("/api/logout")
async def logout():
    """End the session by deleting the ``token`` cookie."""
    response = JSONResponse({"ok": True})
    response.delete_cookie("token")
    return response


@app.get("/api/me")
async def me(request: Request):
    """Return the caller's identity and, when a user record exists, their default sampling settings."""
    session = get_user_from_request(request)
    if not session:
        return JSONResponse({"error": "Not logged in"}, status_code=401)
    profile = {"user_id": session["user_id"], "username": session["username"]}
    stored = db.get_user(session["user_id"])
    if stored:
        # The fallback values apply only when the stored record lacks the key.
        for key, fallback in (
            ("default_temperature", 1.0),
            ("default_top_p", 0.95),
            ("default_top_k", 64),
            ("default_max_tokens", 4096),
        ):
            profile[key] = stored.get(key, fallback)
    return profile


@app.put("/api/me/settings")
async def update_user_settings(request: Request):
    """Save the caller's default sampling settings.

    Missing fields fall back to temperature=1.0, top_p=0.95, top_k=64 and
    max_tokens=4096. A non-object body or non-numeric values yield 400 instead
    of an unhandled 500.
    """
    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    body = await request.json()
    if not isinstance(body, dict):
        return JSONResponse({"error": "JSON object expected"}, status_code=400)
    try:
        settings = {
            "temperature": float(body.get("temperature", 1.0)),
            "top_p": float(body.get("top_p", 0.95)),
            "top_k": int(body.get("top_k", 64)),
            "max_tokens": int(body.get("max_tokens", 4096)),
        }
    except (TypeError, ValueError, OverflowError):
        return JSONResponse({"error": "Settings must be numeric"}, status_code=400)
    db.update_user_defaults(user["user_id"], **settings)
    return {"ok": True}


# --- Conversation API ---

@app.get("/api/conversations")
async def get_conversations(request: Request):
    """List the caller's conversations (401 when not logged in)."""
    caller = get_user_from_request(request)
    if caller:
        return db.list_conversations(caller["user_id"])
    return JSONResponse({"error": "Unauthorized"}, status_code=401)


@app.post("/api/conversations")
async def create_conversation(request: Request):
    """Create a conversation for the caller.

    An optional JSON body may set ``system_prompt`` (a string); otherwise
    SYSTEM_PROMPT is used. Content types with parameters (e.g.
    "application/json; charset=utf-8") are accepted, an empty body means
    "use defaults", and malformed JSON yields 400.
    """
    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    body = {}
    media_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if media_type == "application/json":
        raw = await request.body()
        if raw.strip():
            try:
                parsed = json.loads(raw)
            except ValueError:
                return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
            if isinstance(parsed, dict):
                body = parsed
    system_prompt = body.get("system_prompt")
    if not isinstance(system_prompt, str):
        system_prompt = SYSTEM_PROMPT
    conv_id = db.create_conversation(user_id=user["user_id"], system_prompt=system_prompt)
    return {"id": conv_id, "system_prompt": system_prompt}


@app.get("/api/conversations/{conv_id}")
async def get_conversation(conv_id: int, request: Request):
    """Return one conversation record if the caller owns it; 401 / 404 otherwise."""
    caller = get_user_from_request(request)
    if not caller:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    record = db.get_conversation(conv_id)
    belongs_to_caller = bool(record) and record["user_id"] == caller["user_id"]
    return record if belongs_to_caller else JSONResponse({"error": "Not found"}, status_code=404)


@app.get("/api/conversations/{conv_id}/messages")
async def get_messages(conv_id: int, request: Request):
    """Return every message of a conversation the caller owns; 401 / 404 otherwise."""
    caller = get_user_from_request(request)
    if not caller:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    owned = db.get_conversation(conv_id)
    if owned and owned["user_id"] == caller["user_id"]:
        return db.get_messages(conv_id)
    return JSONResponse({"error": "Not found"}, status_code=404)


@app.put("/api/conversations/{conv_id}/system-prompt")
async def update_system_prompt(conv_id: int, request: Request):
    """Replace the system prompt of one of the caller's conversations.

    Expects JSON ``{"system_prompt": str}``. A missing or non-string value is
    answered with 400 instead of an unhandled KeyError (500).
    """
    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conv = db.get_conversation(conv_id)
    if not conv or conv["user_id"] != user["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)
    body = await request.json()
    system_prompt = body.get("system_prompt") if isinstance(body, dict) else None
    if not isinstance(system_prompt, str):
        return JSONResponse({"error": "system_prompt must be a string"}, status_code=400)
    db.update_system_prompt(conv_id, system_prompt)
    return {"ok": True}


@app.put("/api/conversations/{conv_id}/settings")
async def update_conversation_settings(conv_id: int, request: Request):
    """Set per-conversation sampling overrides (temperature, top_p, top_k, max_tokens).

    A field that is omitted, null or "" means "no override" and is passed to
    the DB layer as None. /api/chat treats a None setting as "use the user's or
    global default". Other values are converted to float (temperature, top_p)
    or int (top_k, max_tokens). Values that cannot be converted yield 400
    instead of being stored verbatim and later handed to the sampler.
    """
    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conv = db.get_conversation(conv_id)
    if not conv or conv["user_id"] != user["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)
    body = await request.json()
    if not isinstance(body, dict):
        return JSONResponse({"error": "JSON object expected"}, status_code=400)

    def optional(key: str, cast):
        # Convert body[key] with *cast*, mapping missing/null/"" to None.
        value = body.get(key)
        return None if value is None or value == "" else cast(value)

    try:
        settings = {
            "temperature": optional("temperature", float),
            "top_p": optional("top_p", float),
            "top_k": optional("top_k", int),
            "max_tokens": optional("max_tokens", int),
        }
    except (TypeError, ValueError, OverflowError):
        return JSONResponse({"error": "Settings must be numeric or null"}, status_code=400)
    db.update_conversation_settings(conv_id, **settings)
    return {"ok": True}


@app.patch("/api/conversations/{conv_id}")
async def rename_conversation(conv_id: int, request: Request):
    """Update a conversation's title when the JSON body contains ``title``; other keys are ignored."""
    caller = get_user_from_request(request)
    if not caller:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conversation = db.get_conversation(conv_id)
    is_owner = bool(conversation) and conversation["user_id"] == caller["user_id"]
    if not is_owner:
        return JSONResponse({"error": "Not found"}, status_code=404)
    changes = await request.json()
    if "title" in changes:
        db.update_conversation_title(conv_id, changes["title"])
    return {"ok": True}


@app.delete("/api/conversations/{conv_id}")
async def delete_conversation(conv_id: int, request: Request):
    """Delete one of the caller's conversations and, best effort, its image files.

    Each message's ``image_path`` is expected to be a JSON list of file names
    in UPLOADS_DIR; a single JSON string is tolerated too. Only the base name
    of each entry is used, so a stored value cannot name a path outside
    UPLOADS_DIR. Files that are already gone or cannot be removed are skipped
    (the latter is logged) and do not block deleting the conversation.
    """
    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conv = db.get_conversation(conv_id)
    if not conv or conv["user_id"] != user["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)
    for msg in db.get_messages(conv_id):
        raw = msg.get("image_path")
        if not raw:
            continue
        try:
            filenames = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            continue
        if isinstance(filenames, str):
            filenames = [filenames]
        if not isinstance(filenames, list):
            continue
        for fname in filenames:
            base = os.path.basename(fname) if isinstance(fname, str) else ""
            if base in ("", ".", ".."):
                continue
            filepath = os.path.join(UPLOADS_DIR, base)
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass
            except OSError as e:
                print(f"Failed to remove {filepath}: {e}")
    db.delete_conversation(conv_id)
    return {"ok": True}


# --- Delete the last message of a conversation, whatever its role (used for regenerate) ---

@app.delete("/api/conversations/{conv_id}/messages/last")
async def delete_last_message(conv_id: int, request: Request):
    """Remove the newest message (highest id) of one of the caller's conversations.

    Intended for "regenerate". NOTE(review): the newest message is removed
    whatever its role, and any image file it references stays on disk - confirm
    this matches what the client expects.
    """
    caller = get_user_from_request(request)
    if not caller:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conversation = db.get_conversation(conv_id)
    if not conversation or conversation["user_id"] != caller["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)
    connection = db.get_connection()
    try:
        newest = connection.execute(
            "SELECT id FROM messages WHERE conversation_id = ? ORDER BY id DESC LIMIT 1",
            (conv_id,),
        ).fetchone()
        if newest:
            connection.execute("DELETE FROM messages WHERE id = ?", (newest["id"],))
            connection.commit()
    finally:
        connection.close()
    return {"ok": True}


# --- Truncate messages after edit ---

@app.post("/api/conversations/{conv_id}/messages/truncate")
async def truncate_messages(conv_id: int, request: Request):
    """Delete the first message whose content equals ``after_content``, and every later message.

    NOTE(review): matching is by content, so when the same text occurs more
    than once the earliest occurrence wins. Image files of deleted messages are
    not removed from disk.
    """
    caller = get_user_from_request(request)
    if not caller:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conversation = db.get_conversation(conv_id)
    if not conversation or conversation["user_id"] != caller["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)
    payload = await request.json()
    target_content = payload.get("after_content", "")
    connection = db.get_connection()
    try:
        rows = connection.execute(
            "SELECT id, content FROM messages WHERE conversation_id = ? ORDER BY id",
            (conv_id,),
        ).fetchall()
        first_match = next((row["id"] for row in rows if row["content"] == target_content), None)
        if first_match is not None:
            connection.execute(
                "DELETE FROM messages WHERE conversation_id = ? AND id >= ?",
                (conv_id, first_match),
            )
            connection.commit()
    finally:
        connection.close()
    return {"ok": True}


# --- Chat API (SSE streaming) ---

def _locked_chat_completion(**kwargs):
    """Run a blocking, non-streaming ``llm.create_chat_completion`` under ``llm_lock``.

    Call this via ``asyncio.to_thread`` from async code so that waiting for
    the lock happens on a worker thread. A blocking acquire on the event-loop
    thread would stall every stream. With a non-reentrant lock it deadlocks
    outright: a streaming ``/api/chat`` response that holds ``llm_lock`` is
    resumed by that same event loop and could never run again to release it.

    Returns whatever ``llm.create_chat_completion(**kwargs)`` returns.
    """
    with llm_lock:
        return llm.create_chat_completion(**kwargs)


@app.post("/api/chat")
async def chat(request: Request):
    """Stream an assistant reply for a conversation as Server-Sent Events.

    Request JSON fields:
      * ``conversation_id``: conversation to append to (must belong to the caller).
      * ``message``: the user's text.
      * ``images``: optional list of base64 strings. The legacy single
        ``image`` field is still accepted.

    The user message is stored, old history may be auto-summarised, URL
    contents and web-search results may be injected into the system prompt,
    and the reply is streamed. Stream events (``data: <json>``):
    ``{"thinking": ...}``, ``{"token": ...}``, ``{"status": "generating_image"}``,
    ``{"generated_image": url}``, ``{"title": ...}``, then ``data: [DONE]``.
    ``: keepalive`` comment lines are sent while waiting for the model.

    Returns 401 when unauthenticated and 404 when the conversation is missing
    or owned by another user.
    """
    import asyncio  # local import: used to keep lock waits and blocking I/O off the event loop

    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    global last_chat_time
    last_chat_time = time.time()

    body = await request.json()
    conv_id = body["conversation_id"]
    user_message = body["message"]
    images_data = body.get("images")  # list of base64 encoded images or None
    # Backward compat: single image field
    if not images_data and body.get("image"):
        images_data = [body["image"]]

    # Check ownership BEFORE writing anything to disk, so a request for a
    # missing or foreign conversation cannot leave orphaned image files.
    conv = db.get_conversation(conv_id)
    if not conv or conv["user_id"] != user["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)

    # Save images to files
    image_filenames = [save_image(img_b64) for img_b64 in images_data] if images_data else []

    # A NULL system prompt falls back to the global default. An explicit ""
    # (the user cleared it) is kept, meaning "no system prompt".
    system_prompt = conv.get("system_prompt")
    if system_prompt is None:
        system_prompt = SYSTEM_PROMPT

    # Resolve generation settings: conversation > user defaults > global.
    # None at any level (e.g. a NULL column) means "not set" and falls through,
    # so None is never passed to the sampler.
    user_data = db.get_user(user["user_id"]) or {}

    def _setting(conv_key, user_key, fallback):
        value = conv.get(conv_key)
        if value is None:
            value = user_data.get(user_key)
        return fallback if value is None else value

    chat_temperature = _setting("temperature", "default_temperature", TEMPERATURE)
    chat_top_p = _setting("top_p", "default_top_p", TOP_P)
    chat_top_k = _setting("top_k", "default_top_k", 64)
    chat_max_tokens = _setting("max_tokens", "default_max_tokens", MAX_NEW_TOKENS)

    # Auto-compact: if message count exceeds 40 (20 pairs), summarize old messages
    COMPACT_THRESHOLD = 40
    KEEP_RECENT = 10  # keep last 10 pairs (20 messages)
    msg_count = db.count_messages(conv_id)
    if msg_count >= COMPACT_THRESHOLD:
        old_messages = db.get_messages(conv_id)
        # Take older messages (everything except recent)
        old_part = old_messages[:-(KEEP_RECENT * 2)]
        if old_part:
            summary_input = conv.get("summary", "") or ""
            old_text = "\n".join([f"{m['role']}: {m['content'][:200]}" for m in old_part])
            if summary_input:
                old_text = f"Previous summary:\n{summary_input}\n\nNew messages to summarize:\n{old_text}"
            # Summarize on a worker thread: it takes llm_lock there, never on the event loop.
            summary_response = await asyncio.to_thread(
                _locked_chat_completion,
                messages=[
                    {"role": "system", "content": "Summarize the following conversation concisely in Japanese. Keep key facts, decisions, and context. Be brief."},
                    {"role": "user", "content": old_text},
                ],
                max_tokens=500,
                temperature=0.3,
            )
            new_summary = summary_response["choices"][0]["message"]["content"]
            db.update_summary(conv_id, new_summary)
            db.delete_old_messages(conv_id, keep_recent=KEEP_RECENT)
            conv = db.get_conversation(conv_id)  # refresh

    # Fetch URLs if present in message. Network I/O runs on worker threads so
    # slow sites do not freeze every other request on the event loop.
    urls = extract_urls(user_message)
    url_context = ""
    if urls:
        url_contents = [await asyncio.to_thread(fetch_url, u) for u in urls[:3]]  # max 3 URLs
        url_context = "\n\n".join(url_contents)

    # Web search if needed (skip if URLs were provided)
    # NOTE(review): needs_search runs on the event loop; if it calls the LLM
    # under llm_lock it has the same stall/deadlock risk. Confirm.
    search_query = needs_search(user_message) if not urls else None
    search_context = ""
    if search_query:
        search_context = await asyncio.to_thread(web_search, search_query)

    db.add_message(conv_id, "user", user_message, image_path=json.dumps(image_filenames) if image_filenames else None)

    history = db.get_messages(conv_id)
    messages = []
    system_parts = []
    if system_prompt.strip():
        system_parts.append(system_prompt)
    # Add image generation instruction if pipeline is available
    if image_pipe is not None:
        adapter_descriptions = {
            "アニメ変換": "Convert photo/image to anime style",
            "写真→アニメ": "Convert photo to anime illustration",
            "アップスケーラー": "Upscale and enhance image quality",
            "スタイル転写": "Transfer artistic style to image",
            "ポーズ変更": "Change the pose of a person in the image",
            "オブジェクト削除": "Remove objects from image",
            "オブジェクト追加": "Add objects to image",
            "漫画トーン": "Convert to manga/comic tone style",
            "線画補間": "Line art interpolation",
            "モノクロキャラ変換": "Convert character to monochrome",
            "アングル変換": "Change camera angle of the image",
            "破れ服": "Make clothes torn/ripped effect",
        }
        adapter_info = "\n".join([f"  - {name}: {desc}" for name, desc in adapter_descriptions.items() if name in IMAGE_ADAPTER_SPECS])
        system_parts.append(f"""You have image generation/editing capability. When the user asks to create, generate, edit, transform, or modify an image, you MUST respond with a brief message (1 sentence in user's language) followed by a JSON block:

```json
{{"action": "generate_image", "prompt": "detailed English prompt", "lora": [], "source_image": false}}
```

Rules:
- prompt: Write a concise English EDITING INSTRUCTION, not a description of the original image
  - GOOD: "Change the clothes to pajamas", "Transform into anime style", "A cute cat sitting on a sofa"
  - BAD: "Three women standing together wearing casual clothes" (this describes the original, not the edit)
- IMPORTANT: When people appear in images, they must look like they are in their 20s or older. Never generate images of children, teenagers, or anyone who looks under 20.
- For i2i (editing uploaded image): prompt should describe WHAT TO CHANGE, not what the image currently shows
- For t2i (creating from scratch): prompt should describe the desired image
- source_image: true if editing an uploaded image (i2i), false for creating from scratch (t2i)
- lora: Choose the most appropriate adapter(s) from the list below, or use empty array [] if none fit
- For realistic/photorealistic edits (実写, リアル, etc.), do NOT use any adapter (empty array)
- For normal conversation (no image task), respond normally without JSON

Available LoRA adapters:
{adapter_info}""")
    # Add summary if exists
    summary = conv.get("summary")
    if summary:
        system_parts.append(f"Previous conversation summary:\n{summary}")
    # Add URL content if any
    if url_context:
        system_parts.append(f"Content from URLs provided by user:\n{url_context}\n\nUse this content to answer the user's question.")
    # Add search results if any
    if search_context:
        system_parts.append(f"Web search results for reference:\n{search_context}\n\nUse these results to answer the user's question accurately.")
    # Combine all system messages into one (Qwen requires single system message)
    if system_parts:
        messages.append({"role": "system", "content": "\n\n".join(system_parts)})
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})

    # If images are provided and multimodal is supported, modify the last user message
    has_image = False
    if images_data and messages and messages[-1]["role"] == "user":
        if vision_handler is not None:
            has_image = True
            last_text = messages[-1]["content"]
            if not last_text:
                last_text = "What is in this image?"
            last_text += "\n\n(Determine the user's intent: if they want to edit/transform/change this image, output generate_image JSON with a concise English editing instruction as the prompt — do NOT describe what the image currently shows. If they want to know about the image, describe it normally. Reply in the same language as the user's message. Never repeat the same sentence or phrase.)"
            content_parts = [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}
                for img_b64 in images_data
            ]
            content_parts.append({"type": "text", "text": last_text})
            messages[-1]["content"] = content_parts
        else:
            messages[-1]["content"] += "\n\n(Note: Images were attached but multimodal is not available on this server.)"

    async def generate():
        full_response = ""
        last_token_time = time.time()
        # Stream tokens, handle thinking blocks and garbage inline
        # Qwen injects <think> automatically so start in thinking state
        is_qwen = "qwen" in MODEL_REPO.lower() or "qwen" in MODEL_FILE.lower()
        parse_state = "thinking" if is_qwen else "normal"
        thought_buf = ""

        # Acquire llm_lock WITHOUT blocking the event loop. This stream holds
        # the lock across yields and is resumed by the event loop, so a
        # blocking acquire here would stall the current lock holder (or, with a
        # non-reentrant Lock, deadlock). Poll instead, sending a keepalive
        # about every 5 s while waiting.
        wait_started = time.monotonic()
        while not llm_lock.acquire(blocking=False):
            await asyncio.sleep(0.2)
            if time.monotonic() - wait_started >= 5:
                yield ": keepalive\n\n"
                wait_started = time.monotonic()

        swapped_handler = False
        previous_handler = None
        try:
            # Swap chat_handler inside the lock to prevent a race with other
            # requests; remember the old one so it can be restored exactly.
            if has_image and vision_handler is not None:
                previous_handler = llm.chat_handler
                llm.chat_handler = vision_handler
                swapped_handler = True
            # NOTE(review): iterating this synchronous generator blocks the
            # event loop while each chunk is produced; other requests only run
            # between chunks.
            for chunk in llm.create_chat_completion(
                messages=messages,
                max_tokens=chat_max_tokens,
                temperature=chat_temperature,
                top_p=chat_top_p,
                top_k=chat_top_k,
                min_p=0.05,
                repeat_penalty=1.1,
                frequency_penalty=0.3,
                presence_penalty=0.3,
                stream=True,
                stop=["<end_of_turn>", "<start_of_turn>"],
            ):
                delta = chunk["choices"][0]["delta"]
                if "content" in delta:
                    token = delta["content"]
                    last_token_time = time.time()

                    if parse_state == "thinking":
                        thought_buf += token
                        # Gemma ends thoughts with <channel|>, Qwen with </think>.
                        for marker in ("<channel|>", "</think>"):
                            end_pos = thought_buf.find(marker)
                            if end_pos < 0:
                                continue
                            real_after = thought_buf[end_pos + len(marker):].lstrip()
                            thought_text = thought_buf[:end_pos].strip()
                            if thought_text.startswith("thought"):
                                thought_text = thought_text[len("thought"):].strip()
                            parse_state = "normal"
                            if thought_text:
                                yield f"data: {json.dumps({'thinking': thought_text})}\n\n"
                            if real_after:
                                full_response += real_after
                                yield f"data: {json.dumps({'token': real_after})}\n\n"
                            break
                        continue

                    # Check if thinking block starts (Gemma: <|channel>, Qwen: <think>)
                    started_marker = None
                    if "<|channel>" in token:
                        started_marker = "<|channel>"
                    elif "<think>" in token:
                        started_marker = "<think>"
                    if started_marker is not None:
                        before, _, thought_buf = token.partition(started_marker)
                        parse_state = "thinking"
                        if before:
                            full_response += before
                            yield f"data: {json.dumps({'token': before})}\n\n"
                        continue

                    # Check start markers split across tokens at the closing ">"
                    split_prefix = None
                    if token.startswith(">"):
                        if full_response.endswith("<|channel"):
                            split_prefix = "<|channel"
                        elif full_response.endswith("<think"):
                            split_prefix = "<think"
                    if split_prefix is not None:
                        full_response = full_response[:-len(split_prefix)]
                        parse_state = "thinking"
                        thought_buf = token[1:]
                        continue

                    full_response += token
                    yield f"data: {json.dumps({'token': token})}\n\n"
                else:
                    # Send keepalive comment if no token for a while
                    if time.time() - last_token_time > 10:
                        yield ": keepalive\n\n"
                        last_token_time = time.time()

        finally:
            if swapped_handler:
                llm.chat_handler = previous_handler
            llm_lock.release()

        # If still in thinking state at end, send whatever we have
        if parse_state == "thinking" and thought_buf:
            yield f"data: {json.dumps({'thinking': thought_buf})}\n\n"

        # --- Image generation: LLM JSON output is primary trigger ---
        gen_result = None
        gen_prompt = None
        gen_loras = []
        should_generate = False

        # Strategy 1: Check if LLM output contains generate_image JSON
        if image_pipe is not None and full_response:
            try:
                json_match = re.search(r'```json\s*(\{.*?\})\s*```', full_response, re.DOTALL)
                if not json_match:
                    json_match = re.search(r'(\{[^{}]*"action"\s*:\s*"[^"]*"[^{}]*\})', full_response, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1) if json_match.lastindex else json_match.group(0)
                    llm_cmd = json.loads(json_str)
                    # Accept any action with a prompt field as image generation
                    if llm_cmd.get("prompt") and llm_cmd.get("action"):
                        gen_prompt = llm_cmd.get("prompt", "").strip()
                        # Normalise "lora" to a list of names: the model may emit
                        # null, a bare string, or junk entries.
                        gen_loras = llm_cmd.get("lora") or []
                        if isinstance(gen_loras, str):
                            gen_loras = [gen_loras]
                        elif not isinstance(gen_loras, list):
                            gen_loras = []
                        gen_loras = [name for name in gen_loras if isinstance(name, str)]
                        should_generate = True
                        # Remove JSON block from displayed response
                        clean_response = re.sub(r'```json\s*\{.*?\}\s*```', '', full_response, flags=re.DOTALL).strip()
                        if not clean_response:
                            clean_response = re.sub(r'\{[^{}]*"action"\s*:\s*"[^"]*"[^{}]*\}', '', full_response).strip()
                        if clean_response and clean_response != full_response:
                            full_response = clean_response
                        # Translate prompt to English if needed
                        if gen_prompt and any(ord(c) > 127 for c in gen_prompt):
                            try:
                                tr = await asyncio.to_thread(
                                    _locked_chat_completion,
                                    messages=[
                                        {"role": "system", "content": "Translate the following image editing instruction to English. Output ONLY the English translation, nothing else. All people must be adults (18+). Include 'adult' if people are involved."},
                                        {"role": "user", "content": gen_prompt},
                                    ],
                                    max_tokens=100, temperature=0.1,
                                )
                                gen_prompt = tr["choices"][0]["message"]["content"].strip()
                                print(f"Translated prompt: {gen_prompt[:80]}")
                            except Exception as e:
                                # Best effort: keep the untranslated prompt.
                                print(f"Prompt translation failed, using original: {e}")
                        print(f"LLM decided to generate: prompt={gen_prompt[:80]}, lora={gen_loras}")
            except (json.JSONDecodeError, AttributeError, IndexError):
                pass

        # Strategy 2: Fallback keyword detection (if LLM didn't output JSON)
        if not should_generate and image_pipe is not None:
            image_gen_keywords = ["画像を生成", "画像を作", "画像作って", "画像生成", "イメージを生成", "イメージを作",
                                  "イラストを作", "イラストを生成", "写真を生成", "写真を作", "実写画像",
                                  "全身画像", "全身の画像",
                                  "に変えて", "に変更", "を変えて", "を変更", "を編集", "を加工",
                                  "にして", "風にして", "アニメ化", "高画質化",
                                  "やり直し", "もう一回", "もう1回", "再生成", "作り直", "やりなおし",
                                  "generate image", "create image", "make image", "draw", "edit image", "retry"]
            should_generate = any(kw in user_message for kw in image_gen_keywords)
            if should_generate:
                # Keyword-based LoRA detection
                lora_keywords = {
                    "アニメ": "アニメ変換", "anime": "アニメ変換",
                    "アニメ調": "アニメ変換", "アニメ風": "アニメ変換",
                    "アップスケール": "アップスケーラー", "高画質": "アップスケーラー",
                    "スタイル": "スタイル転写",
                    "ポーズ": "ポーズ変更",
                    "削除": "オブジェクト削除", "消して": "オブジェクト削除",
                    "追加": "オブジェクト追加",
                    "漫画": "漫画トーン",
                    "線画": "線画補間",
                    "モノクロ": "モノクロキャラ変換", "白黒": "モノクロキャラ変換",
                    "アングル": "アングル変換",
                }
                for kw, lora_name in lora_keywords.items():
                    if kw in user_message:
                        gen_loras.append(lora_name)
                # Translate prompt via LLM — use user message only (not LLM's analysis)
                try:
                    translate_resp = await asyncio.to_thread(
                        _locked_chat_completion,
                        messages=[
                            {"role": "system", "content": "Convert the following image editing request into a concise English editing instruction (max 80 words). Output ONLY the instruction, e.g. 'Change the clothes to a red dress'. Do NOT describe the original image. All people must be depicted as adults (18+). Include 'adult' in the prompt if people are involved."},
                            {"role": "user", "content": user_message},
                        ],
                        max_tokens=120, temperature=0.3,
                    )
                    gen_prompt = translate_resp["choices"][0]["message"]["content"].strip()
                    print(f"Fallback translated prompt: {gen_prompt[:100]}")
                except Exception as e:
                    print(f"Fallback prompt translation failed, using raw message: {e}")
                    gen_prompt = user_message

        if should_generate:
            # NOTE(review): the diffusion pipeline call and time.sleep below run
            # synchronously on the event loop for the whole generation.
            try:
                import torch, gc, random
                from PIL import Image as PILImage

                yield f"data: {json.dumps({'status': 'generating_image'})}\n\n"

                gc.collect()
                torch.cuda.empty_cache()

                if not gen_prompt:
                    gen_prompt = user_message
                if len(gen_prompt) > 500:
                    gen_prompt = gen_prompt[:500]

                # Source image for i2i
                pil_images = []
                if images_data:
                    img_bytes = base64.b64decode(images_data[0])
                    pil_images.append(PILImage.open(io.BytesIO(img_bytes)).convert("RGB"))
                else:
                    # No image uploaded — find the most recent image in conversation
                    # (generated by assistant OR uploaded by user)
                    for msg in reversed(history):
                        if msg.get("image_path"):
                            try:
                                prev_imgs = json.loads(msg["image_path"])
                                if prev_imgs:
                                    prev_path = os.path.join(UPLOADS_DIR, prev_imgs[-1])
                                    if os.path.exists(prev_path):
                                        pil_images.append(PILImage.open(prev_path).convert("RGB"))
                                        print(f"Reusing previous image ({msg['role']}): {prev_imgs[-1]}")
                                        break
                            except (json.JSONDecodeError, TypeError):
                                pass

                gen_seed = random.randint(0, 2**31 - 1)
                generator = torch.Generator(device="cuda").manual_seed(gen_seed)

                # Pipeline requires source image - use blank for t2i
                if not pil_images:
                    pil_images = [PILImage.new("RGB", (1024, 1024), (255, 255, 255))]

                w, h = pil_images[0].size
                if w > h:
                    gw, gh = 1024, (1024 * h // w // 8) * 8
                else:
                    gh, gw = 1024, (1024 * w // h // 8) * 8

                print(f"Generating image: prompt={gen_prompt[:100]}... lora={gen_loras} size={gw}x{gh}")
                with pipe_lock:
                    # Apply LoRA adapters (inside lock to prevent concurrent access)
                    adapter_names = []
                    for adapter in gen_loras:
                        spec = IMAGE_ADAPTER_SPECS.get(adapter)
                        if not spec:
                            continue
                        a_name = spec["adapter_name"]
                        adapter_names.append(a_name)
                        if a_name not in IMAGE_LOADED_ADAPTERS:
                            try:
                                print(f"Loading LoRA: {adapter}")
                                image_pipe.load_lora_weights(spec["repo"], weight_name=spec["weights"], adapter_name=a_name)
                                IMAGE_LOADED_ADAPTERS.add(a_name)
                            except Exception as e:
                                print(f"LoRA load error: {e}")
                    if adapter_names:
                        weights = [1.0 / len(adapter_names)] * len(adapter_names)
                        image_pipe.set_adapters(adapter_names, adapter_weights=weights)
                    elif IMAGE_LOADED_ADAPTERS:
                        image_pipe.set_adapters([], adapter_weights=[])

                    # Ensure enough VRAM before generation (need ~5GB free)
                    REQUIRED_FREE_GB = 5
                    MAX_WAIT_SEC = 30
                    wait_start = time.time()
                    while True:
                        gc.collect()
                        torch.cuda.empty_cache()
                        torch.cuda.synchronize()
                        free_bytes, total_bytes = torch.cuda.mem_get_info()
                        free_gb = free_bytes / (1024**3)
                        if free_gb >= REQUIRED_FREE_GB:
                            break
                        if time.time() - wait_start > MAX_WAIT_SEC:
                            print(f"VRAM wait timeout: only {free_gb:.1f}GB free, proceeding anyway")
                            break
                        print(f"VRAM low ({free_gb:.1f}GB free), waiting...")
                        time.sleep(2)

                    # Retry on transient errors (OOM, transformer state issues)
                    result_img = None
                    last_err = None
                    for attempt in range(3):
                        try:
                            result_img = image_pipe(
                                image=pil_images,  # never empty: blank canvas fallback above
                                prompt=gen_prompt,
                                negative_prompt=DEFAULT_NEGATIVE_PROMPT,
                                height=gh, width=gw,
                                num_inference_steps=4,
                                generator=generator,
                                true_cfg_scale=1.0,
                            ).images[0]
                            break
                        except Exception as e:
                            last_err = e
                            print(f"Image gen attempt {attempt+1}/3 failed: {e}")
                            gc.collect()
                            torch.cuda.empty_cache()
                            torch.cuda.synchronize()
                            time.sleep(2)
                    if result_img is None:
                        raise last_err

                gen_filename = f"gen_{uuid.uuid4().hex}.png"
                result_img.save(os.path.join(UPLOADS_DIR, gen_filename), "PNG")
                gen_result = {"image_url": f"/uploads/{gen_filename}", "seed": gen_seed}
                print(f"Image generated: {gen_filename}")

                gc.collect()
                torch.cuda.empty_cache()
            except Exception as e:
                print(f"Image gen error: {e}")
                err_msg = f"\n\n(Image generation failed: {str(e)})"
                yield f"data: {json.dumps({'token': err_msg})}\n\n"

        # Save assistant message
        if gen_result:
            db.add_message(conv_id, "assistant", full_response, image_path=json.dumps([gen_result["image_url"].split("/")[-1]]))
            yield f"data: {json.dumps({'generated_image': gen_result['image_url']})}\n\n"
        else:
            db.add_message(conv_id, "assistant", full_response)

        # First exchange in the conversation: derive a title from the user message.
        if len(history) == 1:
            title = user_message[:50] + ("..." if len(user_message) > 50 else "")
            db.update_conversation_title(conv_id, title)
            yield f"data: {json.dumps({'title': title})}\n\n"

        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={"X-Accel-Buffering": "no", "Cache-Control": "no-cache"},
    )


# --- CSV Download ---

@app.get("/api/conversations/{conv_id}/download")
async def download_csv(conv_id: int, request: Request):
    """Export a conversation as a CSV attachment with ``role,content`` columns.

    The first data row is the conversation's system prompt (role ``system``),
    followed by every stored message in the order ``db.get_messages`` returns.
    The download filename comes from the conversation title, sent
    RFC 5987-encoded (``filename*=UTF-8''...``) so non-ASCII titles survive.

    Returns 401 when unauthenticated and 404 when the conversation is missing
    or owned by another user.
    """
    from urllib.parse import quote

    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)
    conv = db.get_conversation(conv_id)
    if not conv or conv["user_id"] != user["user_id"]:
        return JSONResponse({"error": "Not found"}, status_code=404)

    messages = db.get_messages(conv_id)
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(["role", "content"])
    writer.writerow(["system", conv["system_prompt"]])
    writer.writerows([msg["role"], msg["content"]] for msg in messages)

    # Titles can be NULL or empty (a first message carrying only an image
    # produces ""). Fall back to a generic name so the file is never ".csv".
    raw_title = conv.get("title") or ""
    safe_title = raw_title.replace('"', '').replace('\r', '').replace('\n', '')[:50].strip()
    encoded_title = quote(safe_title or "conversation")
    return Response(
        content=output.getvalue(),
        media_type="text/csv",
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_title}.csv"},
    )


# --- CSV Upload (batch processing) ---

@app.post("/api/conversations/upload")
async def upload_csv(request: Request, file: UploadFile = File(...)):
    """Batch-process a CSV of prompts, creating one conversation per row.

    The CSV (UTF-8, optional BOM) must have a header row with ``user_prompt``
    and optionally ``system_prompt`` columns. Rows with an empty
    ``user_prompt`` are skipped. Each remaining row gets a new conversation,
    one non-streaming LLM reply, and both messages stored.

    Returns ``{"processed": n, "conversations": [...]}`` with a short preview
    per row. Responds 400 for non-UTF-8 input and 401 when unauthenticated.
    """
    import asyncio  # local import: keeps llm_lock waits off the event loop

    user = get_user_from_request(request)
    if not user:
        return JSONResponse({"error": "Unauthorized"}, status_code=401)

    content = await file.read()
    try:
        text = content.decode("utf-8-sig")
    except UnicodeDecodeError:
        return JSONResponse({"error": "CSV must be UTF-8 encoded"}, status_code=400)
    reader = csv.DictReader(io.StringIO(text))

    def _complete(messages):
        # Runs on a worker thread. Blocking on llm_lock from the event loop
        # could stall or deadlock a streaming /api/chat response that holds the
        # lock and needs the loop to finish.
        with llm_lock:
            return llm.create_chat_completion(
                messages=messages,
                max_tokens=MAX_NEW_TOKENS,
                temperature=TEMPERATURE,
                top_p=TOP_P,
            )

    results = []
    for row in reader:
        # DictReader fills missing trailing cells with None, so guard before strip().
        system_prompt = (row.get("system_prompt") or "").strip()
        user_prompt = (row.get("user_prompt") or "").strip()
        if not user_prompt:
            continue

        # Create a new conversation per row. Use the same effective system
        # prompt for storage and generation, so follow-up chats in this
        # conversation see exactly what produced the first answer.
        effective_system_prompt = system_prompt or SYSTEM_PROMPT
        title = user_prompt[:50] + ("..." if len(user_prompt) > 50 else "")
        conv_id = db.create_conversation(
            user_id=user["user_id"],
            title=title,
            system_prompt=effective_system_prompt,
        )
        db.add_message(conv_id, "user", user_prompt)

        # Build messages for LLM
        messages = []
        if effective_system_prompt.strip():
            messages.append({"role": "system", "content": effective_system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        # Generate response (non-streaming, lock taken on a worker thread)
        response = await asyncio.to_thread(_complete, messages)
        assistant_content = response["choices"][0]["message"]["content"]
        db.add_message(conv_id, "assistant", assistant_content)

        results.append({
            "conversation_id": conv_id,
            "user_prompt": user_prompt,
            "assistant_response": assistant_content[:100] + ("..." if len(assistant_content) > 100 else ""),
        })

    return {"processed": len(results), "conversations": results}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "server:app" import string. With the
    # string, uvicorn imports this file a second time as module "server",
    # re-running all module-level code, and fails outright if the file has
    # another name. An import string is only needed for reload=True or
    # workers > 1, neither of which is used here.
    uvicorn.run(app, host="0.0.0.0", port=PORT, reload=False, timeout_keep_alive=300)
