"""
OTA Update Provider: manage multiple Arduino-style projects and .bin firmware uploads.
Uses shared MySQL config (config.py); stores .bin files on disk.
"""
import os
import re
import shutil
import pymysql
from flask import Blueprint, request, redirect, url_for, render_template, send_file, flash, Response

from config import (
    MYSQL_HOST,
    MYSQL_PORT,
    MYSQL_DATABASE,
    MYSQL_USER,
    MYSQL_PASSWORD,
    OTA_FIRMWARE_DIR,
)

# Blueprint that every OTA route in this module is registered on.
bp = Blueprint("ota", __name__)

# Keyword arguments for pymysql.connect() (see get_db()); DictCursor makes
# every fetched row a dict keyed by column name.
OTA_DB_CONFIG = dict(
    host=MYSQL_HOST,
    port=MYSQL_PORT,
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    database=MYSQL_DATABASE,
    cursorclass=pymysql.cursors.DictCursor,
)

# Maximum firmware upload size, in MiB.
MAX_UPLOAD_MB = 2


def parse_version(s):
    """Parse a dotted version such as 'X.Y.Z' or 'X.Y' into a tuple of ints.

    - A single leading 'v' or 'V' is ignored ("v1.2.3" -> (1, 2, 3)).
    - Empty dot segments are skipped ("1..2" -> (1, 2)).
    - ``None`` or an empty string yields ``()``.
    - Other non-string values are converted with ``str()`` first.
    - A version containing any non-numeric segment (e.g. "1.0.0-beta") falls
      back to ``(0,)``, which version_less() treats like 0.0.0.
    """
    # Coerce before the try: previously .strip() on a non-str raised outside
    # the handler, so the AttributeError fallback could never trigger.
    text = str(s).strip() if s else ""
    if text[:1] in ("v", "V"):
        text = text[1:]
    try:
        return tuple(int(part) for part in text.split(".") if part.strip())
    except ValueError:
        return (0,)


def version_less(a_str, b_str):
    """Return True when version *a_str* sorts strictly before *b_str*.

    Both strings go through parse_version(); the shorter tuple is padded with
    zeros so that "1.0" equals "1.0.0" and "1.0.0" < "1.0.1".
    """
    left = parse_version(a_str)
    right = parse_version(b_str)
    width = max(len(left), len(right))
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    # Equal-length tuples compare element by element, first difference wins.
    return left < right


def get_db():
    """Open a new PyMySQL connection configured by OTA_DB_CONFIG.

    A fresh connection is created on every call; the caller is responsible
    for closing it.
    """
    settings = dict(OTA_DB_CONFIG)
    return pymysql.connect(**settings)


# MySQL server error ER_DUP_FIELDNAME ("Duplicate column name").
_ER_DUP_FIELDNAME = 1060

# ota_update_log columns that are added in place when missing; presumably
# tables created before these columns existed lack them.
_UPDATE_LOG_LATE_COLUMNS = (
    ("project_name", "VARCHAR(255) NULL"),
    ("project_slug", "VARCHAR(255) NULL"),
    ("node_id", "VARCHAR(255) NULL"),
    ("node_name", "VARCHAR(255) NULL"),
    ("from_version", "VARCHAR(64) NULL"),
)


def _table_columns(cur, table):
    """Return the lower-cased column names of *table* in the connection's database."""
    cur.execute(
        """SELECT COLUMN_NAME AS column_name
           FROM information_schema.COLUMNS
           WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = %s""",
        (table,),
    )
    names = set()
    for row in cur.fetchall():
        value = row["column_name"] if isinstance(row, dict) else row[0]
        names.add(str(value).lower())
    return names


def ensure_ota_tables(conn):
    """Create the OTA tables if needed and add any missing ota_update_log columns.

    Missing columns are detected through information_schema, so no ALTER is
    issued once the schema is current. Real errors from an ALTER (permissions,
    lock timeouts, lost connection) propagate instead of being swallowed. Only
    a "duplicate column" error is tolerated, because a concurrent request may
    add the same column between the check and the ALTER.

    Args:
        conn: an open PyMySQL connection; it is committed before returning.
    """
    with conn.cursor() as cur:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS ota_projects (
                id INT AUTO_INCREMENT PRIMARY KEY,
                name VARCHAR(255) NOT NULL,
                slug VARCHAR(255) NOT NULL UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cur.execute("""
            CREATE TABLE IF NOT EXISTS ota_firmware (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id INT NOT NULL,
                version VARCHAR(64) NOT NULL,
                original_filename VARCHAR(255) NOT NULL,
                stored_path VARCHAR(512) NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (project_id) REFERENCES ota_projects(id) ON DELETE CASCADE
            )
        """)
        # Fresh installs get every column up front; older tables are patched below.
        cur.execute("""
            CREATE TABLE IF NOT EXISTS ota_update_log (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id INT NOT NULL,
                project_name VARCHAR(255) NULL,
                project_slug VARCHAR(255) NULL,
                version VARCHAR(64) NOT NULL,
                from_version VARCHAR(64) NULL,
                client_ip VARCHAR(64) NULL,
                node_id VARCHAR(255) NULL,
                node_name VARCHAR(255) NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (project_id) REFERENCES ota_projects(id) ON DELETE CASCADE
            )
        """)
        present = _table_columns(cur, "ota_update_log")
        for column, definition in _UPDATE_LOG_LATE_COLUMNS:
            if column in present:
                continue
            try:
                # Identifiers come from the constant tuple above, never from user input.
                cur.execute(f"ALTER TABLE ota_update_log ADD COLUMN {column} {definition}")
            except pymysql.OperationalError as exc:
                if not exc.args or exc.args[0] != _ER_DUP_FIELDNAME:
                    raise
    conn.commit()


def slugify(name):
    """Turn a project name into a URL- and filesystem-safe slug.

    Steps:
    1. Characters carrying diacritics are reduced to their base letter
       ("Café" -> "cafe") via Unicode NFKD decomposition.
    2. The text is lower-cased.
    3. Every character outside [a-z0-9-] becomes "-".
    4. Runs of dashes collapse to one, and leading/trailing dashes are removed.

    Returns "project" when nothing usable remains.
    """
    import unicodedata

    text = unicodedata.normalize("NFKD", (name or "").strip())
    # Drop the combining marks produced by decomposition (i.e. the accents).
    text = "".join(ch for ch in text if not unicodedata.combining(ch)).lower()
    text = re.sub(r"[^a-z0-9\-]", "-", text)
    text = re.sub(r"-+", "-", text).strip("-")
    return text or "project"


def _parse_version_parts(s):
    """Return the first three numeric components of *s* as (major, minor, patch).

    Missing components become 0 and extra ones are dropped:
    "1.0.1" -> (1, 0, 1), "1" -> (1, 0, 0), "" -> (0, 0, 0).
    """
    padded = tuple(parse_version(s)) + (0, 0, 0)
    major, minor, patch = padded[:3]
    return (major, minor, patch)


def next_version_from(version_str):
    """Return the version that follows *version_str*, rolling digits over at 9.

    Examples: 1.0.1 -> 1.0.2, 1.0.9 -> 1.1.0, 1.9.9 -> 2.0.0.
    A missing or blank *version_str* yields "1.0.0".
    """
    if not (version_str or "").strip():
        return "1.0.0"
    major, minor, patch = _parse_version_parts(version_str)
    if patch + 1 <= 9:
        return f"{major}.{minor}.{patch + 1}"
    # Patch overflowed: bump minor and reset patch.
    if minor + 1 <= 9:
        return f"{major}.{minor + 1}.0"
    # Minor overflowed too: bump major and reset the rest.
    return f"{major + 1}.0.0"


def get_latest_version_str(conn, project_id):
    """Return the project's greatest firmware version (per version_less()), or None.

    Blank versions are ignored. If several versions compare equal, the first
    one returned by the query is kept.
    """
    with conn.cursor() as cur:
        cur.execute("SELECT version FROM ota_firmware WHERE project_id = %s", (project_id,))
        rows = cur.fetchall()
    best = None
    for row in rows:
        candidate = (row.get("version") or "").strip()
        if not candidate:
            continue
        if best is None or version_less(best, candidate):
            best = candidate
    return best


def next_version(conn, project_id):
    """Suggest the version for the project's next upload.

    Bumps the project's greatest existing version via next_version_from()
    (1.0.1 -> 1.0.2, 1.0.9 -> 1.1.0); a project without firmware gets "1.0.0".
    """
    return next_version_from(get_latest_version_str(conn, project_id))


def _ensure_firmware_dir():
    """Create OTA_FIRMWARE_DIR and any missing parents; an existing directory is left alone."""
    firmware_root = OTA_FIRMWARE_DIR
    os.makedirs(firmware_root, exist_ok=True)


@bp.route("/")
def index():
    """Render the OTA dashboard.

    Shows all projects ordered by name plus the 50 newest update-log entries
    across all projects. ``created_at`` values are turned into ISO-8601
    strings before reaching the template.
    """
    conn = get_db()
    try:
        ensure_ota_tables(conn)
        with conn.cursor() as cur:
            cur.execute("SELECT id, name, slug, created_at FROM ota_projects ORDER BY name")
            projects = cur.fetchall()
        for project in projects:
            stamp = project.get("created_at")
            if stamp:
                project["created_at"] = stamp.isoformat()
        with conn.cursor() as cur:
            cur.execute(
                """SELECT l.id, l.version, l.from_version, l.client_ip, l.node_id, l.node_name, l.created_at,
                          COALESCE(NULLIF(TRIM(l.project_name), ''), p.name) AS project_name,
                          COALESCE(NULLIF(TRIM(l.project_slug), ''), p.slug) AS project_slug
                   FROM ota_update_log l
                   JOIN ota_projects p ON l.project_id = p.id
                   ORDER BY l.created_at DESC LIMIT 50"""
            )
            update_log = cur.fetchall()
        for entry in update_log:
            stamp = entry.get("created_at")
            if stamp:
                entry["created_at"] = stamp.isoformat()
        return render_template("ota/index.html", projects=projects, update_log=update_log)
    finally:
        conn.close()


@bp.route("/project/new", methods=["GET", "POST"])
def project_new():
    """Show the new-project form (GET) or create a project from it (POST).

    The slug is derived from the name with slugify(). These cases re-render
    the form with HTTP 400 instead of failing with a 500:
    - an empty name;
    - a name or slug longer than the VARCHAR(255) columns;
    - a duplicate slug, whether found by the pre-check or by the UNIQUE
      constraint when two requests race.

    On success, redirects to the new project's detail page.
    """
    if request.method == "GET":
        return render_template("ota/project_new.html")
    name = (request.form.get("name") or "").strip()
    if not name:
        flash("Project name is required.")
        return render_template("ota/project_new.html"), 400
    slug = slugify(name)
    if len(name) > 255 or len(slug) > 255:
        flash("Project name is too long (max 255 characters).")
        return render_template("ota/project_new.html"), 400
    duplicate_msg = f"A project with slug '{slug}' already exists. Choose a different name."
    conn = get_db()
    try:
        ensure_ota_tables(conn)
        with conn.cursor() as cur:
            cur.execute("SELECT id FROM ota_projects WHERE slug = %s", (slug,))
            if cur.fetchone():
                flash(duplicate_msg)
                return render_template("ota/project_new.html"), 400
            try:
                cur.execute(
                    "INSERT INTO ota_projects (name, slug) VALUES (%s, %s)",
                    (name, slug),
                )
            except pymysql.IntegrityError:
                # Another request inserted the same slug after our SELECT.
                conn.rollback()
                flash(duplicate_msg)
                return render_template("ota/project_new.html"), 400
        conn.commit()
        return redirect(url_for("ota.project_detail", slug=slug))
    finally:
        conn.close()


@bp.route("/project/<slug>")
def project_detail(slug):
    """Render a project page: current firmware, suggested next version, update log.

    The firmware shown is the one devices are served by project_update(): the
    row with the semantically greatest version, ties going to the newest row
    (highest id). Picking the most recently uploaded row instead would show
    the wrong image after an upload with a lower manual version override.
    The log lists the project's 30 newest update entries.
    """
    slug = (slug or "").strip()
    conn = get_db()
    try:
        ensure_ota_tables(conn)
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, name, slug FROM ota_projects WHERE slug = %s",
                (slug,),
            )
            project = cur.fetchone()
        if not project:
            flash("Project not found.")
            return redirect(url_for("ota.index"))
        with conn.cursor() as cur:
            cur.execute(
                """SELECT id, version, original_filename, created_at
                   FROM ota_firmware WHERE project_id = %s ORDER BY id DESC""",
                (project["id"],),
            )
            rows = cur.fetchall()
        latest = None
        for row in rows:
            # Strict comparison keeps the first (highest-id) row among equal versions.
            if latest is None or version_less(latest["version"], row["version"]):
                latest = row
        # The template receives a list holding at most one row.
        firmware = [] if latest is None else [latest]
        for f in firmware:
            if f.get("created_at"):
                f["created_at"] = f["created_at"].isoformat()
        next_ver = next_version(conn, project["id"])
        with conn.cursor() as cur:
            cur.execute(
                """SELECT id, version, from_version, client_ip, node_id, node_name, created_at,
                          COALESCE(NULLIF(TRIM(project_name), ''), %s) AS project_name,
                          COALESCE(NULLIF(TRIM(project_slug), ''), %s) AS project_slug
                   FROM ota_update_log WHERE project_id = %s ORDER BY created_at DESC LIMIT 30""",
                (project["name"], project["slug"], project["id"]),
            )
            project_update_log = cur.fetchall()
            for e in project_update_log:
                if e.get("created_at"):
                    e["created_at"] = e["created_at"].isoformat()
        return render_template(
            "ota/project_detail.html",
            project=project,
            firmware=firmware,
            next_version=next_ver,
            project_update_log=project_update_log,
        )
    finally:
        conn.close()


def _rewrite_stored_paths(cur, project_id, new_slug):
    """Re-point every ota_firmware row of *project_id* into the *new_slug* directory.

    Only the directory part of each stored_path changes; file names are kept.
    """
    cur.execute(
        "SELECT id, stored_path FROM ota_firmware WHERE project_id = %s",
        (project_id,),
    )
    for row in cur.fetchall():
        new_path = os.path.join(new_slug, os.path.basename(row["stored_path"]))
        cur.execute("UPDATE ota_firmware SET stored_path = %s WHERE id = %s", (new_path, row["id"]))


@bp.route("/project/<slug>/edit", methods=["GET", "POST"])
def project_edit(slug):
    """Edit a project's name and slug.

    Changing the slug also renames the project's directory under
    OTA_FIRMWARE_DIR and rewrites every ota_firmware.stored_path. The steps
    run in this order to keep disk and database consistent:
    1. Stage the database changes in one uncommitted transaction.
    2. Move the directory; if that fails, roll back the transaction.
    3. Commit; if the commit fails, move the directory back.

    Validation problems re-render the form with HTTP 400; a failed directory
    move re-renders it with HTTP 500.
    """
    slug = (slug or "").strip()
    conn = get_db()
    try:
        ensure_ota_tables(conn)
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, name, slug FROM ota_projects WHERE slug = %s",
                (slug,),
            )
            project = cur.fetchone()
        if not project:
            flash("Project not found.")
            return redirect(url_for("ota.index"))
        if request.method == "GET":
            return render_template("ota/project_edit.html", project=project)

        def form_error(message, status=400):
            flash(message)
            return render_template("ota/project_edit.html", project=project), status

        # Work from the slug stored in the DB, not the URL text. The lookup
        # above may match case-insensitively (collation-dependent — confirm),
        # while directory names on disk are usually case-sensitive.
        current_slug = project["slug"]
        name = (request.form.get("name") or "").strip()
        new_slug = slugify(request.form.get("slug") or current_slug)
        if not name:
            return form_error("Project name is required.")
        if len(name) > 255 or len(new_slug) > 255:
            return form_error("Project name and slug must be at most 255 characters.")
        slug_changed = new_slug != current_slug
        taken_msg = f"Slug '{new_slug}' is already used by another project."
        if slug_changed:
            with conn.cursor() as cur:
                cur.execute("SELECT id FROM ota_projects WHERE slug = %s AND id != %s", (new_slug, project["id"]))
                if cur.fetchone():
                    return form_error(taken_msg)

        # 1) Stage all database changes in the (not yet committed) transaction.
        try:
            with conn.cursor() as cur:
                if slug_changed:
                    _rewrite_stored_paths(cur, project["id"], new_slug)
                cur.execute(
                    "UPDATE ota_projects SET name = %s, slug = %s WHERE id = %s",
                    (name, new_slug, project["id"]),
                )
        except pymysql.IntegrityError:
            # UNIQUE(slug) violated: another request claimed new_slug after our check.
            conn.rollback()
            return form_error(taken_msg)

        # 2) Move the firmware directory; undo the staged DB changes if that fails.
        moved = None
        if slug_changed:
            old_dir = os.path.join(OTA_FIRMWARE_DIR, current_slug)
            new_dir = os.path.join(OTA_FIRMWARE_DIR, new_slug)
            if os.path.isdir(old_dir):
                try:
                    os.rename(old_dir, new_dir)
                except OSError:
                    conn.rollback()
                    return form_error(
                        f"Could not move the firmware directory to '{new_slug}'; project not updated.",
                        500,
                    )
                moved = (old_dir, new_dir)

        # 3) Commit; if that fails, put the directory back before propagating.
        try:
            conn.commit()
        except Exception:
            if moved:
                try:
                    os.rename(moved[1], moved[0])
                except OSError:
                    pass  # report the original commit error, not this one
            raise
        flash("Project updated.")
        return redirect(url_for("ota.project_detail", slug=new_slug))
    finally:
        conn.close()


def _sanitize_filename(name):
    """Keep only safe characters for a stored filename."""
    name = os.path.basename(name or "firmware.bin")
    name = re.sub(r"[^a-zA-Z0-9._\-]", "_", name)
    return name or "firmware.bin"


# Accepted format for a manually entered version: dot-separated integers
# ("1", "1.2", "1.2.3", ...), i.e. values parse_version() can order.
_VERSION_OVERRIDE_RE = re.compile(r"\d+(?:\.\d+)*")
# Widths of ota_firmware.version (VARCHAR(64)) and original_filename (VARCHAR(255)).
_VERSION_MAX_LEN = 64
_FILENAME_MAX_LEN = 255


def _remove_quietly(path):
    """Best-effort delete of *path*; OS errors (e.g. file already gone) are ignored."""
    try:
        os.remove(path)
    except OSError:
        pass


def _save_upload_checked(file, dest, max_mb):
    """Save uploaded *file* to *dest*, requiring 0 < size <= *max_mb* MiB.

    The data is first written to ``dest + ".part"`` and moved onto *dest* only
    after its real on-disk size has been checked, so a rejected or failed
    upload never replaces an existing file. Returns None on success, otherwise
    a user-facing error message. The ".part" file is removed in every case.
    """
    part = dest + ".part"
    try:
        file.save(part)
        size = os.path.getsize(part)
        if size == 0:
            return "The uploaded file is empty."
        if size > max_mb * 1024 * 1024:
            return f"File too large (max {max_mb} MB)."
        os.replace(part, dest)
        return None
    finally:
        if os.path.exists(part):
            _remove_quietly(part)


@bp.route("/project/<slug>/upload", methods=["POST"])
def upload(slug):
    """Store an uploaded .bin file as a new firmware version of the project.

    Form fields:
        file: required; the filename must end in ".bin".
        version: optional override. When omitted, next_version() is used.
            An override must be dot-separated integers of at most 64
            characters, so it orders correctly and fits the column.

    The upload is rejected when:
    - the version equals an existing one (per version_less()), since that
      makes "latest" ambiguous and lets two rows share one stored file;
    - the file is empty or larger than MAX_UPLOAD_MB. The limit is checked
      on the bytes actually written, not on the request's Content-Length.

    The stored file is deleted again if the database insert fails. Every
    outcome flashes a message and redirects to the project page.
    """
    slug = (slug or "").strip()
    file = request.files.get("file")
    version_override = (request.form.get("version") or "").strip()
    conn = get_db()
    try:
        ensure_ota_tables(conn)
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, name, slug FROM ota_projects WHERE slug = %s",
                (slug,),
            )
            project = cur.fetchone()
        if not project:
            flash("Project not found.")
            return redirect(url_for("ota.index"))
        # Build paths from the stored slug. The lookup may match the URL text
        # case-insensitively (collation-dependent — confirm), while directories
        # on disk are usually case-sensitive.
        slug = project["slug"]

        def back(message):
            flash(message)
            return redirect(url_for("ota.project_detail", slug=slug))

        if not file or not file.filename:
            return back("Please choose a .bin file to upload.")
        if not file.filename.lower().endswith(".bin"):
            return back("Only .bin files are allowed.")
        if version_override and (
            len(version_override) > _VERSION_MAX_LEN
            or not _VERSION_OVERRIDE_RE.fullmatch(version_override)
        ):
            return back(
                "Invalid version: use numbers separated by dots, e.g. 1.2.3 "
                f"(max {_VERSION_MAX_LEN} characters)."
            )
        version = version_override or next_version(conn, project["id"])
        with conn.cursor() as cur:
            cur.execute("SELECT version FROM ota_firmware WHERE project_id = %s", (project["id"],))
            existing = [row["version"] for row in cur.fetchall()]
        # Semantic duplicates ("1.0" == "1.0.0") are rejected as well.
        if any(not version_less(version, other) and not version_less(other, version) for other in existing):
            return back(
                f"Version {version} already exists for this project; "
                "delete it first or choose another version."
            )
        _ensure_firmware_dir()
        project_dir = os.path.join(OTA_FIRMWARE_DIR, slug)
        os.makedirs(project_dir, exist_ok=True)
        safe_version = re.sub(r"[^a-zA-Z0-9._\-]", "_", version)
        stored_name = f"{safe_version}_{_sanitize_filename(file.filename)}"
        if not stored_name.lower().endswith(".bin"):
            stored_name += ".bin"
        stored_path = os.path.join(project_dir, stored_name)
        error = _save_upload_checked(file, stored_path, MAX_UPLOAD_MB)
        if error:
            return back(error)
        rel_path = os.path.join(slug, stored_name)
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """INSERT INTO ota_firmware (project_id, version, original_filename, stored_path)
                       VALUES (%s, %s, %s, %s)""",
                    (project["id"], version, file.filename[:_FILENAME_MAX_LEN], rel_path),
                )
            conn.commit()
        except Exception:
            # Don't leave an orphan file behind when the row could not be recorded.
            _remove_quietly(stored_path)
            raise
        return back(f"Uploaded {file.filename} as version {version}.")
    finally:
        conn.close()


@bp.route("/project/<slug>/update")
def project_update(slug):
    """ESP32 OTA endpoint polled by devices.

    The device sends its running version in the ``x-esp32-version`` header,
    optionally with ``x-esp32-node-id`` / ``x-esp32-node-name``. Responses:

    * 200 with the firmware image when the project's semantically newest
      firmware is newer than the device's version;
    * 304 when the device is up to date or sent no version;
    * 404 when the project, its firmware rows, or the newest file is missing.

    Each served update is logged once. A device makes two GETs per update
    (pre-check + httpUpdate), so a repeat for the same project, version,
    client IP and node within two minutes is not logged again. Node ids only
    separate requests when both carry one. Devices that identify themselves
    are therefore logged separately even behind a shared NAT address, while
    a GET without the header still de-duplicates as before.
    """
    slug = (slug or "").strip()
    device_version = (request.headers.get("x-esp32-version") or "").strip()
    # If device didn't send version, don't send firmware (avoids re-installing same version in a loop)
    if not device_version:
        return Response(status=304)
    conn = get_db()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, name, slug FROM ota_projects WHERE slug = %s",
                (slug,),
            )
            project = cur.fetchone()
        if not project:
            return Response("No project or firmware", status=404)
        with conn.cursor() as cur:
            cur.execute(
                """SELECT id, version, stored_path FROM ota_firmware
                   WHERE project_id = %s ORDER BY id DESC""",
                (project["id"],),
            )
            firmware_rows = cur.fetchall()
        if not firmware_rows:
            return Response("No project or firmware", status=404)
        # Latest = row with the largest semantic version; among equal versions
        # the first row in id-descending order (the newest) is kept.
        latest = firmware_rows[0]
        for row in firmware_rows[1:]:
            if version_less(latest["version"], row["version"]):
                latest = row
        server_version = (latest.get("version") or "").strip()
        full_path = os.path.join(OTA_FIRMWARE_DIR, latest["stored_path"])
        if not os.path.isfile(full_path):
            return Response("No project or firmware", status=404)
        if not version_less(device_version, server_version):
            return Response(status=304)
        client_ip = (request.remote_addr or "").strip() or None
        node_id = (request.headers.get("x-esp32-node-id") or "").strip() or None
        node_name = (request.headers.get("x-esp32-node-name") or "").strip() or None
        with conn.cursor() as cur:
            cur.execute(
                """SELECT 1 FROM ota_update_log
                   WHERE project_id = %s AND version = %s AND (client_ip <=> %s)
                   AND (node_id <=> %s OR node_id IS NULL OR %s IS NULL)
                   AND created_at > NOW() - INTERVAL 2 MINUTE
                   LIMIT 1""",
                (project["id"], server_version, client_ip, node_id, node_id),
            )
            if cur.fetchone() is None:
                cur.execute(
                    """INSERT INTO ota_update_log (project_id, project_name, project_slug, version, from_version, client_ip, node_id, node_name)
                       VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
                    (project["id"], project.get("name") or "", project.get("slug") or "", server_version, device_version, client_ip, node_id, node_name),
                )
        conn.commit()
        return send_file(
            full_path,
            mimetype="application/octet-stream",
            as_attachment=True,
            download_name="firmware.bin",
        )
    finally:
        conn.close()


@bp.route("/project/<slug>/firmware/<int:firmware_id>")
def download_firmware(slug, firmware_id):
    """Send one stored firmware image as a .bin attachment.

    The firmware row must belong to the project identified by *slug*.
    - Unknown firmware: flash a message and redirect to the dashboard.
    - File missing from disk: flash a message and redirect to the project page.
    """
    slug = (slug or "").strip()
    conn = get_db()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """SELECT f.id, f.stored_path, p.slug
                   FROM ota_firmware f
                   JOIN ota_projects p ON f.project_id = p.id
                   WHERE p.slug = %s AND f.id = %s""",
                (slug, firmware_id),
            )
            match = cur.fetchone()
        if match is None:
            flash("Firmware not found.")
            return redirect(url_for("ota.index"))
        relative = match["stored_path"]
        absolute = os.path.join(OTA_FIRMWARE_DIR, relative)
        if os.path.isfile(absolute):
            return send_file(
                absolute,
                as_attachment=True,
                download_name=os.path.basename(relative),
                mimetype="application/octet-stream",
            )
        flash("File no longer exists on disk.")
        return redirect(url_for("ota.project_detail", slug=slug))
    finally:
        conn.close()


@bp.route("/project/<slug>/firmware/<int:firmware_id>/delete", methods=["POST"])
def delete_firmware(slug, firmware_id):
    """Delete a firmware record and, unless another record still uses it, its file.

    The row is deleted and committed first. The file under OTA_FIRMWARE_DIR is
    then removed, unless another ota_firmware row references the same
    stored_path (possible when one version/filename was uploaded twice).
    Deleting one record therefore never breaks another.

    File removal is best-effort: a failure is reported via flash but does not
    undo the database delete.
    """
    slug = (slug or "").strip()
    conn = get_db()
    try:
        with conn.cursor() as cur:
            cur.execute(
                """SELECT f.id, f.stored_path, p.slug
                   FROM ota_firmware f
                   JOIN ota_projects p ON f.project_id = p.id
                   WHERE p.slug = %s AND f.id = %s""",
                (slug, firmware_id),
            )
            row = cur.fetchone()
        if not row:
            flash("Firmware not found.")
            return redirect(url_for("ota.project_detail", slug=slug))
        stored_path = row["stored_path"]
        with conn.cursor() as cur:
            cur.execute("DELETE FROM ota_firmware WHERE id = %s", (firmware_id,))
            # Same transaction, so the row just deleted is no longer visible here.
            cur.execute(
                "SELECT 1 FROM ota_firmware WHERE stored_path = %s LIMIT 1",
                (stored_path,),
            )
            still_referenced = cur.fetchone() is not None
        conn.commit()
        full_path = os.path.join(OTA_FIRMWARE_DIR, stored_path)
        if not still_referenced and os.path.isfile(full_path):
            try:
                os.remove(full_path)
            except OSError:
                flash("Firmware deleted, but its file could not be removed from disk.")
                return redirect(url_for("ota.project_detail", slug=slug))
        flash("Firmware deleted.")
        return redirect(url_for("ota.project_detail", slug=slug))
    finally:
        conn.close()
