"""SQLite database layer for conversation history."""

import hashlib
import hmac
import os
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timezone

# Path of the SQLite database file; override with the DB_PATH environment variable.
DB_PATH = os.getenv("DB_PATH", "chat.db")


def get_connection() -> sqlite3.Connection:
    """Open a new connection to the database at ``DB_PATH``.

    Rows come back as :class:`sqlite3.Row` (indexable by column name), the
    journal is switched to WAL mode, and foreign-key enforcement is turned
    on for this connection. SQLite leaves it off by default, and the schema's
    ``ON DELETE CASCADE`` clauses rely on it.

    The caller owns the returned connection and is responsible for closing it.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except BaseException:
        # Don't leak the handle if setup fails (e.g. "database is locked"
        # while switching journal mode); the caller never receives it.
        conn.close()
        raise
    return conn


def init_db():
    """Create the schema if needed and migrate older databases in place.

    Safe to call on every start-up. Tables and indexes use ``IF NOT EXISTS``,
    and each migration adds its column only when ``PRAGMA table_info`` shows
    the column is missing.
    """
    # (table, column, column definition) for columns added after the original
    # schema shipped. Identifiers here are constants, never user input, so
    # interpolating them into DDL below is safe.
    migrations = (
        ("conversations", "system_prompt", "TEXT NOT NULL DEFAULT 'You are a helpful assistant.'"),
        ("conversations", "user_id", "INTEGER NOT NULL DEFAULT 0"),
        ("conversations", "temperature", "REAL DEFAULT NULL"),
        ("conversations", "top_p", "REAL DEFAULT NULL"),
        ("conversations", "top_k", "INTEGER DEFAULT NULL"),
        ("conversations", "max_tokens", "INTEGER DEFAULT NULL"),
        ("conversations", "summary", "TEXT DEFAULT NULL"),
        ("messages", "image_path", "TEXT DEFAULT NULL"),
        ("users", "default_temperature", "REAL DEFAULT 1.0"),
        ("users", "default_top_p", "REAL DEFAULT 0.95"),
        ("users", "default_top_k", "INTEGER DEFAULT 64"),
        ("users", "default_max_tokens", "INTEGER DEFAULT 4096"),
    )
    with closing(get_connection()) as conn:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL UNIQUE,
                password_hash TEXT NOT NULL,
                salt TEXT NOT NULL,
                created_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS conversations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL DEFAULT 0,
                title TEXT NOT NULL DEFAULT 'New Chat',
                system_prompt TEXT NOT NULL DEFAULT 'You are a helpful assistant.',
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
            );
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conversation_id INTEGER NOT NULL,
                role TEXT NOT NULL CHECK(role IN ('system', 'user', 'assistant')),
                content TEXT NOT NULL,
                created_at TEXT NOT NULL,
                FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
            );
        """)

        # Migrations: read each table's columns once, then add what's missing.
        existing: dict[str, set[str]] = {}
        for table, column, definition in migrations:
            if table not in existing:
                existing[table] = {r[1] for r in conn.execute(f"PRAGMA table_info({table})")}
            if column not in existing[table]:
                conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {definition}")
                existing[table].add(column)

        # Indexes are created only after the migrations. On a database that
        # predates conversations.user_id, indexing it earlier fails with
        # "no such column" and aborts init before the column can be added.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_messages_conv ON messages(conversation_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_conversations_user ON conversations(user_id)")
        conn.commit()


def _now() -> str:
    return datetime.now(timezone.utc).isoformat()


def _hash_password(password: str, salt: str) -> str:
    return hashlib.sha256((salt + password).encode()).hexdigest()


# --- Users ---

def create_user(username: str, password: str) -> int | None:
    """Insert a new user with a freshly generated salt and salted password hash.

    The salt is 16 random bytes from :mod:`secrets`, hex-encoded, and is
    stored next to the hash.

    Returns:
        The new user's id, or ``None`` if the insert violates a constraint
        (for example a duplicate ``username``, which is UNIQUE).
    """
    salt = secrets.token_hex(16)
    password_hash = _hash_password(password, salt)
    # closing() guarantees the connection is released even when an error
    # other than IntegrityError (e.g. "database is locked") propagates.
    with closing(get_connection()) as conn:
        try:
            cur = conn.execute(
                "INSERT INTO users (username, password_hash, salt, created_at) VALUES (?, ?, ?, ?)",
                (username, password_hash, salt, _now()),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            return None
        return cur.lastrowid


def authenticate_user(username: str, password: str) -> int | None:
    conn = get_connection()
    row = conn.execute(
        "SELECT id, password_hash, salt FROM users WHERE username = ?",
        (username,),
    ).fetchone()
    conn.close()
    if not row:
        return None
    if _hash_password(password, row["salt"]) == row["password_hash"]:
        return row["id"]
    return None


def get_user(user_id: int) -> dict | None:
    """Return a user's public profile and sampling defaults, or ``None`` if not found.

    Keys: ``id``, ``username``, ``default_temperature``, ``default_top_p``,
    ``default_top_k``, ``default_max_tokens``. Credentials are not selected.
    """
    with closing(get_connection()) as conn:
        row = conn.execute(
            "SELECT id, username, default_temperature, default_top_p, default_top_k, default_max_tokens FROM users WHERE id = ?",
            (user_id,),
        ).fetchone()
    return dict(row) if row else None


def update_user_defaults(user_id: int, temperature: float, top_p: float, top_k: int, max_tokens: int):
    """Overwrite all four of a user's default sampling settings in one UPDATE.

    A non-existent ``user_id`` matches no row and is silently ignored.
    """
    with closing(get_connection()) as conn:
        conn.execute(
            "UPDATE users SET default_temperature=?, default_top_p=?, default_top_k=?, default_max_tokens=? WHERE id=?",
            (temperature, top_p, top_k, max_tokens, user_id),
        )
        conn.commit()


# --- Conversations ---

def create_conversation(user_id: int, title: str = "New Chat", system_prompt: str = "You are a helpful assistant.") -> int:
    """Create a conversation owned by ``user_id`` and return its id.

    ``created_at`` and ``updated_at`` are both set to the same current UTC
    timestamp.
    """
    now = _now()
    # closing() releases the connection even if the INSERT raises
    # (e.g. a foreign-key violation for an unknown user_id).
    with closing(get_connection()) as conn:
        cur = conn.execute(
            "INSERT INTO conversations (user_id, title, system_prompt, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
            (user_id, title, system_prompt, now, now),
        )
        conn.commit()
        return cur.lastrowid


def list_conversations(user_id: int) -> list[dict]:
    """Return a user's conversations, most recently updated first.

    Each dict has keys ``id``, ``title``, ``system_prompt`` and ``updated_at``.
    """
    with closing(get_connection()) as conn:
        rows = conn.execute(
            "SELECT id, title, system_prompt, updated_at FROM conversations WHERE user_id = ? ORDER BY updated_at DESC",
            (user_id,),
        ).fetchall()
    return [dict(r) for r in rows]


def get_conversation(conversation_id: int) -> dict | None:
    """Return one conversation with its settings and summary, or ``None`` if missing.

    Ownership is not checked here. ``user_id`` is included so the caller
    can verify it.
    """
    with closing(get_connection()) as conn:
        row = conn.execute(
            "SELECT id, user_id, title, system_prompt, temperature, top_p, top_k, max_tokens, summary FROM conversations WHERE id = ?",
            (conversation_id,),
        ).fetchone()
    return dict(row) if row else None


def update_system_prompt(conversation_id: int, system_prompt: str):
    """Replace a conversation's system prompt.

    ``updated_at`` is left unchanged, so the conversation's position in
    ``list_conversations`` does not move.
    """
    with closing(get_connection()) as conn:
        conn.execute(
            "UPDATE conversations SET system_prompt = ? WHERE id = ?",
            (system_prompt, conversation_id),
        )
        conn.commit()


def update_summary(conversation_id: int, summary: str):
    """Store (overwrite) the summary text for a conversation."""
    with closing(get_connection()) as conn:
        conn.execute(
            "UPDATE conversations SET summary = ? WHERE id = ?",
            (summary, conversation_id),
        )
        conn.commit()


def delete_old_messages(conversation_id: int, keep_recent: int = 10):
    """Delete all but the newest ``2 * keep_recent`` messages of a conversation.

    ``keep_recent`` is meant as a number of user+assistant *pairs*, so twice
    that many individual messages are kept. Roles are not inspected, so this
    equals exact pairs only when messages strictly alternate. "Newest" is by
    message id. ``keep_recent=0`` deletes every message in the conversation.

    Raises:
        ValueError: if ``keep_recent`` is negative.
    """
    if keep_recent < 0:
        raise ValueError(f"keep_recent must be >= 0, got {keep_recent}")
    keep_count = keep_recent * 2  # user + assistant pairs
    with closing(get_connection()) as conn:
        # One statement instead of fetching ids and binding one parameter per
        # deleted row: that could exceed SQLite's host-parameter limit on long
        # chats, and rows[:-0] made keep_recent=0 delete nothing.
        conn.execute(
            """
            DELETE FROM messages
            WHERE conversation_id = ?
              AND id NOT IN (
                  SELECT id FROM messages
                  WHERE conversation_id = ?
                  ORDER BY id DESC
                  LIMIT ?
              )
            """,
            (conversation_id, conversation_id, keep_count),
        )
        conn.commit()


def count_messages(conversation_id: int) -> int:
    """Return how many messages are stored for a conversation (0 if none)."""
    with closing(get_connection()) as conn:
        row = conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE conversation_id = ?",
            (conversation_id,),
        ).fetchone()
    return row["cnt"]


def update_conversation_settings(conversation_id: int, temperature: float | None, top_p: float | None, top_k: int | None, max_tokens: int | None):
    """Overwrite all four per-conversation sampling overrides.

    Passing ``None`` stores SQL NULL for that setting. Presumably NULL means
    "fall back to the user's defaults"; confirm against the caller.
    """
    with closing(get_connection()) as conn:
        conn.execute(
            "UPDATE conversations SET temperature=?, top_p=?, top_k=?, max_tokens=? WHERE id=?",
            (temperature, top_p, top_k, max_tokens, conversation_id),
        )
        conn.commit()


def get_messages(conversation_id: int) -> list[dict]:
    """Return a conversation's messages ordered by id (oldest first).

    Each dict has keys ``role``, ``content`` and ``image_path`` (may be None).
    """
    with closing(get_connection()) as conn:
        rows = conn.execute(
            "SELECT role, content, image_path FROM messages WHERE conversation_id = ? ORDER BY id",
            (conversation_id,),
        ).fetchall()
    return [dict(r) for r in rows]


def add_message(conversation_id: int, role: str, content: str, image_path: str | None = None) -> int:
    """Append a message to a conversation and return the new message id.

    Also bumps the conversation's ``updated_at`` to the same timestamp, so
    it sorts first in ``list_conversations``. Both writes are committed
    together.
    """
    now = _now()
    # closing() releases the connection even when the INSERT is rejected
    # (e.g. the CHECK on role or the foreign key on conversation_id); the
    # uncommitted writes are then discarded on close.
    with closing(get_connection()) as conn:
        cur = conn.execute(
            "INSERT INTO messages (conversation_id, role, content, image_path, created_at) VALUES (?, ?, ?, ?, ?)",
            (conversation_id, role, content, image_path, now),
        )
        conn.execute(
            "UPDATE conversations SET updated_at = ? WHERE id = ?",
            (now, conversation_id),
        )
        conn.commit()
        return cur.lastrowid


def update_conversation_title(conversation_id: int, title: str):
    """Rename a conversation; ``updated_at`` is not changed."""
    with closing(get_connection()) as conn:
        conn.execute(
            "UPDATE conversations SET title = ? WHERE id = ?",
            (title, conversation_id),
        )
        conn.commit()


def delete_conversation(conversation_id: int):
    """Delete a conversation by id.

    Its messages are removed by the ``ON DELETE CASCADE`` foreign key, which
    SQLite only enforces because ``get_connection`` enables ``foreign_keys``.
    """
    with closing(get_connection()) as conn:
        conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
        conn.commit()
