#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import csv
import ssl
import time
import json
import math
import struct
import queue
import signal
import pathlib
import datetime as dt
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import pymysql
import requests
from configparser import ConfigParser
import importlib
from typing import Protocol

class Plugin(Protocol):
    # Structural interface of a scheduler plugin. Besides the members below,
    # load_plugins() reads the optional attributes `daily` and `ren` from the
    # plugin object via getattr(); they are not part of this protocol.

    # Identifier used in log messages.
    name: str
    # Called by the scheduler with the tick time, the API client and the log directory.
    def on_tick(self, tick: dt.datetime, api: "ModbusApiClient", log_root: str) -> None: ...

# -----------------------------
# Configuration and data model
# -----------------------------

@dataclass
class PluginState:
    """Scheduling state of one loaded plugin."""

    # The plugin object (the `plugin` attribute of the imported plugin module).
    plugin: Any
    # Tick at which the plugin is due next.
    next_run: dt.datetime
    # Period in seconds between runs.
    ren: Optional[int]  # None => daily
class LegacyPlugin:
    # NOTE(review): defines no on_tick() and is not referenced in the visible
    # part of this file; it looks like a sample of the attributes that
    # load_plugins() reads from a plugin - confirm whether it is still needed.
    name = "legacy_mbm"
    ren = 120  # or daily = True
@dataclass(frozen=True)
class Rule:
    """One enabled row of the ``mrc`` table: what to read and which CSV column receives it."""

    id: int
    name: str              # rules sharing a name share one CSV folder/row
    csv_col: int           # 1..9
    ip: str
    port: int
    unit: int
    instr: int             # 1,2,3,7...
    addr: int              # 0-based (note: the server in mbm_service does addr-1 when addr>0)
    dtype: str             # float/u16/s16/u32/s32/coil/di/raw...
    ren: Optional[int]     # period in seconds; None => daily at 00:00:00
    prc: Optional[int]     # decimals for rounding


@dataclass
class TaskState:
    """Mutable scheduling state attached to one (immutable) rule."""

    rule: Rule
    # Tick at which the rule is due next; advanced by the scheduler loop.
    next_run: dt.datetime

# -----------------------------
# Helpers
# -----------------------------

def load_plugins(cfg: ConfigParser) -> List[PluginState]:
    """Import the plugins listed in ``[plugins] enabled`` and schedule their first run.

    Each name is imported as ``plugins.<name>`` and the module's ``plugin``
    attribute is used as the plugin object.  A plugin with a truthy ``daily``
    attribute is scheduled for the next midnight (or for right now when the
    current tick is exactly midnight); any other plugin starts at the current
    30 s tick with its ``ren`` period, falling back to 30 when ``ren`` is not
    a positive multiple of 30.
    """
    enabled = cfg.get("plugins", "enabled", fallback="")
    module_names = [part.strip() for part in enabled.split(",") if part.strip()]

    first_tick = now_tick_30s(dt.datetime.now())
    day_start = first_tick.replace(hour=0, minute=0, second=0, microsecond=0)

    states: List[PluginState] = []
    for module_name in module_names:
        plugin = importlib.import_module(f"plugins.{module_name}").plugin

        if bool(getattr(plugin, "daily", False)):
            if first_tick == day_start:
                first_run = day_start
            else:
                first_run = day_start + dt.timedelta(days=1)
            states.append(PluginState(plugin=plugin, next_run=first_run, ren=None))
        else:
            period = int(getattr(plugin, "ren", 30))
            if not (period > 0 and period % 30 == 0):
                log(f"PLUGIN {getattr(plugin,'name',module_name)}: invalid ren={period}, using 30", 1)
                period = 30
            states.append(PluginState(plugin=plugin, next_run=first_tick, ren=period))

    return states
    
def log(msg: str, level: int = 0) -> None:
    """Print ``msg`` to stdout with a timestamp and a severity tag.

    ``level``: 0 = info, 1 = warning, 2 = error; any other value is tagged "[?]".
    """
    tags = {0: "[I]", 1: "[W]", 2: "[E]"}
    try:
        tag = tags[level]
    except KeyError:
        tag = "[?]"
    stamp = dt.datetime.now().isoformat(timespec="seconds")
    print(stamp, tag, msg, flush=True)


def safe_folder_name(name: str) -> str:
    """Turn an arbitrary rule name into a single, safe directory name.

    Characters other than word characters, ``-``, ``.`` and whitespace become
    ``_``; runs of whitespace collapse to one ``_``; the result is cut to 128
    characters.  An empty result becomes ``"noname"``.

    The special names ``.`` and ``..`` are rewritten to underscores: joined
    onto the log root they would otherwise point at the root itself or at its
    parent directory.
    """
    cleaned = re.sub(r"[^\w\-\.\s]", "_", name.strip(), flags=re.UNICODE)
    cleaned = re.sub(r"\s+", "_", cleaned)[:128]
    if cleaned in (".", ".."):
        # Path separators are already gone, so these two are the only names
        # that can still escape the log root.
        cleaned = "_" * len(cleaned)
    return cleaned or "noname"


def now_tick_30s(now: Optional[dt.datetime] = None) -> dt.datetime:
    """Floor ``now`` (default: the current local time) to the 30-second grid."""
    moment = now or dt.datetime.now()
    floored_second = moment.second - moment.second % 30
    return moment.replace(second=floored_second, microsecond=0)


def seconds_until_next_tick(now: Optional[dt.datetime] = None) -> float:
    """Return the seconds from ``now`` (default: current local time) to the next 30 s boundary.

    The result is never negative; exactly on a boundary it is 30.0.
    """
    moment = now or dt.datetime.now()
    # Start of the current 30 s window; the next boundary is 30 s later, and
    # timedelta arithmetic takes care of minute/hour/day roll-over.
    window_start = moment.replace(second=moment.second - moment.second % 30, microsecond=0)
    boundary = window_start + dt.timedelta(seconds=30)
    wait = (boundary - moment).total_seconds()
    return wait if wait > 0.0 else 0.0


def csv_filename_for_day(day: dt.date) -> str:
    """Return the CSV file name for ``day`` in the form ``DD_MM_YY.csv``."""
    two_digit_year = day.year % 100
    return "{:02d}_{:02d}_{:02d}.csv".format(day.day, day.month, two_digit_year)


def ensure_dir(path: pathlib.Path) -> None:
    """Create ``path`` and any missing parents; an existing directory is left as is."""
    os.makedirs(path, exist_ok=True)


# -----------------------------
# Database
# -----------------------------

def connect_mysql(cfg: ConfigParser):
    """Open a MySQL connection using the ``[database]`` section of ``cfg``.

    ``host``, ``user``, ``password`` and ``name`` are required options;
    ``port`` defaults to 3306.  The connection uses dict cursors, utf8mb4 and
    has autocommit switched off.  Returns the pymysql connection object.
    """
    section = "database"
    # Dict literals evaluate top to bottom, so options are read in the same
    # order as before and a missing option is reported identically.
    params = {
        "host": cfg.get(section, "host"),
        "port": cfg.getint(section, "port", fallback=3306),
        "user": cfg.get(section, "user"),
        "password": cfg.get(section, "password"),
        "database": cfg.get(section, "name"),
        "autocommit": False,
        "cursorclass": pymysql.cursors.DictCursor,
        "charset": "utf8mb4",
    }
    connection = pymysql.connect(**params)
    log("MySQL - Připojeno", 0)
    return connection


def load_rules(sql) -> List[Rule]:
    """Fetch all enabled rows of the ``mrc`` table and convert them to ``Rule`` objects.

    ``sql`` is an open connection whose cursor yields dict-like rows.  The
    ``ren`` and ``prc`` columns may be NULL and are then kept as ``None``;
    every other numeric column is converted with ``int``.
    """
    query = """
        SELECT id, name, input, ip, port, unit, instr, addr, type, ren, prc
        FROM mrc
        WHERE enabled=1
    """

    def optional_int(value: Any) -> Optional[int]:
        # NULL stays None, anything else must convert to int.
        return None if value is None else int(value)

    def to_rule(row: Any) -> Rule:
        column = int(row["input"])
        period = optional_int(row["ren"])
        decimals = optional_int(row.get("prc"))
        return Rule(
            id=int(row["id"]),
            name=str(row["name"]),
            csv_col=column,
            ip=str(row["ip"]),
            port=int(row["port"]),
            unit=int(row["unit"]),
            instr=int(row["instr"]),
            addr=int(row["addr"]),
            dtype=str(row["type"]).lower(),
            ren=period,
            prc=decimals,
        )

    with sql.cursor() as cursor:
        cursor.execute(query)
        fetched = cursor.fetchall()

    return [to_rule(row) for row in fetched]


def validate_rules(rules: List[Rule], ren_max: int = 3600) -> Tuple[List[Rule], List[str]]:
    """Split ``rules`` into the usable ones and a list of error messages.

    A rule is rejected on the first failing check, in this order:
    ``csv_col`` outside 1..9; ``ren`` (when set) not a positive multiple of 30
    or greater than ``ren_max``; ``instr`` not one of 1/2/3/7; ``prc`` (when
    set) greater than 10.  Returns ``(accepted_rules, error_messages)``.
    """
    accepted: List[Rule] = []
    problems: List[str] = []

    for candidate in rules:
        period = candidate.ren
        decimals = candidate.prc

        if candidate.csv_col < 1 or candidate.csv_col > 9:
            reason = "input/csv_col musí být 1..9"
        elif period is not None and (period % 30 != 0 or period <= 0 or period > ren_max):
            reason = f"ren musí být násobek 30 a 1..{ren_max} (nebo NULL)"
        elif candidate.instr not in (1, 2, 3, 7):
            reason = "instr podporuj jen 1/2/3/7 (read) v tomto loggeru"
        elif decimals is not None and decimals > 10:
            reason = f"prc je podezřele vysoké (>{10})"
        else:
            reason = None

        if reason is None:
            accepted.append(candidate)
        else:
            problems.append(f"id={candidate.id}: {reason}")

    return accepted, problems


# -----------------------------
# API client
# -----------------------------

class ModbusApiClient:
    """Small HTTP client for the gateway's ``/modbus/call`` endpoint."""

    def __init__(self, base_url: str, api_key: str, verify_tls: bool = False, timeout_s: float = 6.0):
        """Store the connection settings; no network traffic happens here.

        NOTE(review): ``verify_tls`` defaults to False, i.e. TLS certificate
        verification is off unless the caller switches it on.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.verify_tls = verify_tls
        self.timeout_s = timeout_s

    @staticmethod
    def _raw_from_payload(data: Any, body_preview: str = "") -> Any:
        """Validate a decoded JSON reply and return its ``raw`` member.

        Raises:
            RuntimeError: "INVALID JSON RESPONSE" when the reply is not a JSON
                object, "API error" when its ``ok`` flag is missing or falsy.
        """
        if not isinstance(data, dict):
            raise RuntimeError(f"INVALID JSON RESPONSE: {body_preview}")
        if not data.get("ok"):
            raise RuntimeError(f"API error: {data}")
        return data.get("raw")  # read calls return their data under "raw"

    def read_raw(self, ip: str, port: int, unit: int, instr: int, addr: int, count: int) -> Any:
        """POST one raw read request and return the ``raw`` field of the reply.

        Raises:
            RuntimeError: on a non-2xx HTTP status, a body that is not valid
                JSON, or a reply whose ``ok`` flag is not truthy.
            Exceptions raised by ``requests.post`` itself (connection errors,
            timeouts) propagate unchanged.
        """
        url = f"{self.base_url}/modbus/call"
        payload = {
            "ip": ip,
            "port": port,
            "unit": unit,
            "instr": instr,
            "addr": addr,
            "count": count,
            "type": "raw",
        }
        headers = {
            "Content-Type": "application/json",
            "X-API-Key": self.api_key,
        }
        r = requests.post(url, headers=headers, json=payload, timeout=self.timeout_s, verify=self.verify_tls)
        if not r.ok:
            raise RuntimeError(f"HTTP {r.status_code}: {r.text[:200]}")

        # Only the JSON parsing sits inside the try block.  Previously the
        # "API error" raised for ok=false was caught by the same handler and
        # re-reported as "INVALID JSON RESPONSE", hiding the real cause.
        try:
            data = r.json()
        except ValueError as err:  # requests' JSON decode errors derive from ValueError
            raise RuntimeError(f"INVALID JSON RESPONSE: {r.text[:200]}") from err

        return self._raw_from_payload(data, r.text[:200])


# -----------------------------
# Conversions (logger side)
# -----------------------------

def regs_to_float_be(regs: List[int]) -> float:
    """Decode two 16-bit registers (high word first) as a big-endian 32-bit float."""
    # Same byte order as the server side: pack(">HH") followed by unpack(">f").
    word_hi = regs[0]
    word_lo = regs[1]
    (value,) = struct.unpack(">f", struct.pack(">HH", word_hi, word_lo))
    return value


def regs_to_u32_be(regs: List[int]) -> int:
    """Combine two 16-bit registers (high word first) into one unsigned 32-bit value."""
    word_hi = int(regs[0])
    word_lo = int(regs[1])
    return word_hi << 16 | word_lo


def regs_to_s32_be(regs: List[int]) -> int:
    """Combine two 16-bit registers (high word first) into a signed 32-bit value."""
    unsigned = regs_to_u32_be(regs)
    if unsigned & 0x80000000:
        # Sign bit set: map into the negative range (two's complement).
        return unsigned - 0x100000000
    return unsigned


def reg_to_s16(v: int) -> int:
    """Interpret the low 16 bits of ``v`` as a signed (two's complement) integer."""
    word = int(v) & 0xFFFF
    if word & 0x8000:
        return word - 0x10000
    return word


def apply_prc(value: Any, prc: Optional[int]) -> Any:
    """Round ``value`` to ``prc`` decimals; return it untouched when that is not possible.

    With ``prc`` set, a numeric value always comes back as ``float`` (ints
    included).  With ``prc`` None, or when the value cannot be converted to
    float, the original object is returned.
    """
    if prc is None:
        return value
    try:
        rounded = round(float(value), int(prc))
    except Exception:
        return value
    else:
        return rounded


def decode_value(dtype: str, raw: Any, prc: Optional[int]) -> Any:
    """Convert the raw API data of one rule into the value to be logged.

    Args:
        dtype: type name (case-insensitive); empty/None is treated as "raw".
        raw: list[int] for registers or list[bool] for coils/DI; a single
            non-list value is treated as a one-element register list.
        prc: number of decimals for rounding, or None for no rounding.

    Returns:
        The decoded value, or ``None`` when there is no data to decode
        (``raw`` is None, the register list is empty, or a 32-bit type got
        fewer than two registers).  ``dtype == "raw"`` returns ``raw`` as is.
    """
    dtp = (dtype or "raw").lower()

    if dtp == "raw":
        return raw

    # No payload at all: report "no value" instead of failing in int(None).
    if raw is None:
        return None

    # Coils / DI arrive as list[bool]; normally only the first bit matters
    # because such reads use count=1.
    if isinstance(raw, list) and raw and isinstance(raw[0], bool):
        if dtp in ("coil", "di", "bool"):
            return bool(raw[0])
        return raw

    # Registers.
    regs = raw if isinstance(raw, list) else [raw]

    # A short or empty reply yields None for every type, the same way the
    # 32-bit branches below already treat missing registers, rather than
    # raising IndexError on regs[0].
    if not regs:
        return None

    if dtp in ("u16", "uint16"):
        return apply_prc(int(regs[0]), prc)

    if dtp in ("s16", "int16"):
        return apply_prc(reg_to_s16(regs[0]), prc)

    if dtp in ("u32", "uint32"):
        if len(regs) < 2:
            return None
        return apply_prc(regs_to_u32_be(regs[:2]), prc)

    if dtp in ("s32", "int32"):
        if len(regs) < 2:
            return None
        return apply_prc(regs_to_s32_be(regs[:2]), prc)

    if dtp == "float":
        if len(regs) < 2:
            return None
        return apply_prc(regs_to_float_be(regs[:2]), prc)

    # Unknown type: pass a single register (or the whole list) through rounding.
    return apply_prc(regs[0] if len(regs) == 1 else regs, prc)


def needed_register_count(dtype: str, instr: int) -> int:
    """Return how many registers/bits one value of ``dtype`` occupies: 2 or 1.

    Only the 32-bit register types (u32/s32/float and their aliases) need two
    registers, and only when the read is not a bit read (``instr`` 2 or 7, or
    a coil/di/bool type).
    """
    kind = (dtype or "raw").lower()
    is_bit_read = instr in (2, 7) or kind in ("coil", "di", "bool")
    is_wide = kind in ("u32", "uint32", "s32", "int32", "float")
    return 2 if (is_wide and not is_bit_read) else 1


# -----------------------------
# Grouping into contiguous blocks
# -----------------------------

@dataclass
class Block:
    """One read request covering the address range shared by its member rules."""

    ip: str
    port: int
    unit: int
    instr: int
    start_addr: int      # lowest address among the members
    count: int           # number of registers/bits from start_addr to the end of the last member
    members: List[Rule]  # rules whose values lie inside this range


def build_blocks(rules: List[Rule]) -> List[Block]:
    """Merge rules into as few read requests as possible.

    Rules are grouped by ``(ip, port, unit, instr)`` in order of first
    appearance.  Within a group they are sorted by address, and neighbours
    whose ranges overlap or touch are merged into one ``Block``; a gap between
    ranges starts a new block.
    """
    grouped: Dict[Tuple[str, int, int, int], List[Rule]] = {}
    for rule in rules:
        key = (rule.ip, rule.port, rule.unit, rule.instr)
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(rule)

    result: List[Block] = []

    for (ip, port, unit, instr), group in grouped.items():
        run: List[Rule] = []  # members of the block currently being built
        run_start = 0
        run_end = 0           # exclusive end address of the current block

        for rule in sorted(group, key=lambda item: item.addr):
            first = rule.addr
            end_excl = first + needed_register_count(rule.dtype, rule.instr)

            if run and first <= run_end:
                # Overlapping or directly adjacent: extend the current block.
                run.append(rule)
                if end_excl > run_end:
                    run_end = end_excl
                continue

            if run:
                # Gap: close the finished block before starting a new one.
                result.append(Block(
                    ip=ip, port=port, unit=unit, instr=instr,
                    start_addr=run_start,
                    count=int(run_end - run_start),
                    members=run,
                ))
            run, run_start, run_end = [rule], first, end_excl

        if run:
            result.append(Block(
                ip=ip, port=port, unit=unit, instr=instr,
                start_addr=run_start,
                count=int(run_end - run_start),
                members=run,
            ))

    return result
# -----------------------------
# CSV output
# -----------------------------

class CsvWriter:
    """Appends measurement rows to per-name, per-day CSV files under a root directory."""

    def __init__(self, root_dir: str):
        self.root = pathlib.Path(root_dir)

    def write_row(self, name: str, when: dt.datetime, cols_1_to_9: List[Optional[Any]]) -> None:
        """Append one row for ``name`` to the CSV file of ``when``'s day.

        The file lives in ``<root>/<sanitised name>/``; the directory is
        created on demand and a header row is written when the file does not
        exist yet.  The row holds the time as HH:MM:SS followed by the given
        columns, with ``None`` written as an empty field.  The file is flushed
        and fsynced after every row.
        """
        target_dir = self.root / safe_folder_name(name)
        ensure_dir(target_dir)

        target = target_dir / csv_filename_for_day(when.date())
        needs_header = not target.exists()

        # Row layout: timestamp + one field per column.
        record = [when.strftime("%H:%M:%S")]
        record.extend("" if cell is None else str(cell) for cell in cols_1_to_9)

        with target.open("a", newline="", encoding="utf-8") as handle:
            out = csv.writer(handle)
            if needs_header:
                # Optional header; drop this branch if no header is wanted.
                out.writerow(["time", *(f"c{i}" for i in range(1, 10))])
            out.writerow(record)
            handle.flush()
            os.fsync(handle.fileno())


# -----------------------------
# Scheduler (30 s tick)
# -----------------------------

def initial_next_run(rule: Rule, start: dt.datetime) -> dt.datetime:
    """Return the first due time of ``rule`` for a logger started at ``start``.

    - ``ren`` set: the 30 s tick at or before ``start``, i.e. the rule is due
      on the very first tick (no alignment to multiples of ``ren`` is done).
    - ``ren`` None: the next 00:00:00, or today's midnight when the tick of
      ``start`` is exactly midnight.
    """
    first_tick = now_tick_30s(start)

    if rule.ren is not None:
        return first_tick

    day_start = first_tick.replace(hour=0, minute=0, second=0, microsecond=0)
    if first_tick != day_start:
        # Already past midnight: wait for the next one.
        return day_start + dt.timedelta(days=1)
    return day_start


def _next_run_after(next_run: dt.datetime, step: dt.timedelta, tick: dt.datetime) -> dt.datetime:
    """Advance ``next_run`` by whole ``step``s to the first slot strictly after ``tick``.

    Skipping every missed slot at once keeps a task that fell behind (long
    tick, suspended machine) from firing on each following tick until it has
    caught up.  ``step`` must be positive.
    """
    missed = max(0, (tick - next_run) // step)
    return next_run + (missed + 1) * step


def _read_row(api: "ModbusApiClient", name: str, rules_for_name: "List[Rule]",
              deadline: dt.datetime) -> Optional[List[Optional[Any]]]:
    """Read all due rules of one name and return the 9 CSV columns, or None if nothing was read.

    API failures and decode failures are logged and skipped so that one bad
    device or one malformed reply cannot stop the logger.  Reading stops early
    once ``deadline`` has passed.
    """
    cols: List[Optional[Any]] = [None] * 9
    got_value = False

    for b in build_blocks(rules_for_name):
        if dt.datetime.now() >= deadline:
            break

        try:
            # raw is list[int] (registers) or list[bool] (coils/DI)
            raw = api.read_raw(b.ip, b.port, b.unit, b.instr, b.start_addr, b.count)
        except Exception as e:
            log(f"API FAIL name={name} {b.ip}:{b.port} instr={b.instr} addr={b.start_addr} count={b.count}: {e}", 1)
            continue

        # Hand each member of the block its slice of the reply.
        for r in b.members:
            span = needed_register_count(r.dtype, r.instr)
            offset = r.addr - b.start_addr
            part = raw[offset: offset + span] if isinstance(raw, list) else raw
            try:
                val = decode_value(r.dtype, part, r.prc)
            except Exception as e:
                log(f"DECODE FAIL name={name} id={r.id} type={r.dtype} addr={r.addr}: {e}", 1)
                continue
            cols[r.csv_col - 1] = val
            got_value = True

    return cols if got_value else None


def run_logger(cfg_path: str) -> None:
    """Run the logger until SIGINT/SIGTERM.

    Reads the INI file at ``cfg_path``, loads plugins and the enabled rules
    from MySQL, then loops on a 30-second tick: every due rule is read through
    the HTTP API and written to its CSV file, and every due plugin gets its
    ``on_tick`` call.  Work that does not fit into the 30 s window of a tick
    is skipped for that tick.

    Raises:
        FileNotFoundError: when the config file cannot be read.
        Errors from the database connection or from plugin imports at start-up
        propagate to the caller.
    """
    cfg = ConfigParser()
    # ConfigParser.read() silently skips unreadable files and returns the list
    # of files it did parse; fail with a clear message instead of a later
    # "no section" error.
    if not cfg.read(cfg_path, encoding="utf-8"):
        raise FileNotFoundError(f"Config file not found or unreadable: {cfg_path}")

    plugin_states = load_plugins(cfg)

    # The database is needed only to load the rules; release the connection
    # instead of keeping it idle for the whole lifetime of the process.
    sql = connect_mysql(cfg)
    try:
        rules = load_rules(sql)
    finally:
        sql.close()

    rules, errors = validate_rules(rules, ren_max=cfg.getint("logger", "ren_max", fallback=3600))
    if errors:
        for e in errors:
            log(f"CONFIG ERROR: {e}", 2)
        log("Některé řádky byly přeskočeny kvůli chybám.", 1)

    log(f"Načteno pravidel: {len(rules)}", 0)

    api = ModbusApiClient(
        base_url=cfg.get("api", "base_url"),
        api_key=cfg.get("api", "api_key"),
        verify_tls=cfg.getboolean("api", "verify_tls", fallback=False),
        timeout_s=cfg.getfloat("api", "timeout_s", fallback=6.0),
    )
    log_dir = cfg.get("logger", "log_dir", fallback="logs")
    writer = CsvWriter(log_dir)

    # Task state: one entry per rule.
    now = dt.datetime.now()
    tasks: List[TaskState] = [TaskState(rule=r, next_run=initial_next_run(r, now)) for r in rules]

    stop = False

    def handle_sig(_sig, _frm):
        nonlocal stop
        stop = True

    signal.signal(signal.SIGINT, handle_sig)
    signal.signal(signal.SIGTERM, handle_sig)

    # Priority: ren=30 first, then 60, ... and daily (None) rules last.
    def prio(ts: TaskState) -> Tuple[int, int]:
        if ts.rule.ren is None:
            return (10_000_000, ts.rule.id)
        return (ts.rule.ren, ts.rule.id)

    one_day = dt.timedelta(days=1)
    tick_length = dt.timedelta(seconds=30)

    log("Logger startuje. Tick každých 30s.", 0)

    while not stop:
        tick = now_tick_30s()
        tick_deadline = tick + tick_length

        due: List[TaskState] = sorted((t for t in tasks if t.next_run <= tick), key=prio)

        if due:
            log(f"Tick {tick.time()} due={len(due)}", 0)

        # Rules are processed grouped by name: each name yields one CSV row
        # per tick (provided at least one value was read).
        due_by_name: Dict[str, List[Rule]] = {}
        for ts in due:
            due_by_name.setdefault(ts.rule.name, []).append(ts.rule)

        # When time runs out, low-priority names are skipped; names that
        # contain a ren=30 rule are therefore handled first.
        names_sorted = sorted(
            due_by_name,
            key=lambda n: (0 if any(r.ren == 30 for r in due_by_name[n]) else 1, n)
        )

        done_count = 0
        skipped_count = 0

        for name in names_sorted:
            rules_for_name = due_by_name[name]

            if dt.datetime.now() >= tick_deadline:
                # Out of time: skip the rest of this tick.
                skipped_count += len(rules_for_name)
                continue

            cols = _read_row(api, name, rules_for_name, tick_deadline)
            if cols is not None:
                try:
                    writer.write_row(name=name, when=tick, cols_1_to_9=cols)
                except OSError as e:
                    # E.g. disk full or permission problem: lose this row,
                    # keep the logger alive.
                    log(f"CSV FAIL name={name}: {e}", 2)

            done_count += len(rules_for_name)

        # Move every due task to its next slot after this tick, even if it
        # was skipped ("skip the window" policy).
        for ts in due:
            step = one_day if ts.rule.ren is None else dt.timedelta(seconds=ts.rule.ren)
            ts.next_run = _next_run_after(ts.next_run, step, tick)

        if due:
            log(f"Tick hotovo={done_count} přeskočeno={skipped_count}", 0)

        for ps in plugin_states:
            if ps.next_run > tick:
                continue
            try:
                ps.plugin.on_tick(tick, api, log_dir)
            except Exception as e:
                log(f"PLUGIN FAIL {getattr(ps.plugin,'name','?')}: {e}", 1)
            finally:
                step = one_day if ps.ren is None else dt.timedelta(seconds=ps.ren)
                ps.next_run = _next_run_after(ps.next_run, step, tick)

        # Wait for the next 30 s boundary in short slices: time.sleep() resumes
        # after a signal handler returns, so a single long sleep would delay
        # shutdown by up to 30 seconds.
        wake_at = now_tick_30s() + tick_length
        while not stop:
            remaining = (wake_at - dt.datetime.now()).total_seconds()
            if remaining <= 0:
                break
            time.sleep(min(remaining, 1.0))

    log("Logger ukončen.", 0)


if __name__ == "__main__":
    # Usage: python3 modbus_logger.py config.ini
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python3 modbus_logger.py config.ini")
        raise SystemExit(2)
    run_logger(cli_args[0])
