import hmac
import os
import socket
import ssl
import struct
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError
from concurrent.futures.process import BrokenProcessPool
from typing import Dict, Tuple, Any

from flask import Flask, request, jsonify

# ------------------------------------------------------------
# CONFIGURATION
# ------------------------------------------------------------
# Shared secret that clients must send in the X-API-Key request header.
# NOTE(review): when MODBUS_API_KEY is unset this falls back to a hard-coded
# default that is visible in the source -- always set the variable in deployment.
API_KEY = os.environ.get("MODBUS_API_KEY", "tajny-klic")  # simple header-based auth
MAX_WORKERS_PER_ENDPOINT = 1  # how many jobs may run in parallel against one (ip, port)
DEFAULT_MODBUS_PORT = 502  # used when a request does not specify "port"
# Timeout in seconds handed to the Modbus client that runs inside the worker process.
# NOTE(review): if the client retries timed-out requests, one call can take longer
# than FUTURE_TIMEOUT below -- confirm the client's retry settings.
MODBUS_TIMEOUT = 3  # s (in the worker subprocess)
FUTURE_TIMEOUT = 5  # s (how long the HTTP handler in the main process waits for a result)

# Certificate paths (for HTTPS)
SERVER_CERT = "server.crt"
SERVER_KEY = "server.key"
CA_CERT = "ca.crt"  # CA for client certificates, if client-cert auth is enabled

app = Flask(__name__)

# Cache of executors, one per Modbus endpoint: (ip, port) -> ProcessPoolExecutor
pools: Dict[Tuple[str, int], ProcessPoolExecutor] = {}


# ------------------------------------------------------------
# Helpers
# ------------------------------------------------------------
def get_pool_for_endpoint(ip: str, port: int) -> ProcessPoolExecutor:
    """
    Return the process pool dedicated to the Modbus endpoint ``(ip, port)``.

    Pools are created lazily and cached in the module-level ``pools`` dict.
    Every request for the same endpoint therefore goes through one executor
    with ``MAX_WORKERS_PER_ENDPOINT`` workers.

    NOTE(review): the cache is never pruned. Every distinct (ip, port) named
    by a client keeps its executor for the lifetime of the server.
    """
    key = (ip, port)
    pool = pools.get(key)
    if pool is None:
        # The Flask server may handle requests on several threads, so two
        # requests for a new endpoint can get here at the same time. With a
        # plain check-then-assign, both would build a pool and the endpoint
        # would briefly have two executors (breaking the per-endpoint worker
        # limit). dict.setdefault is atomic, so exactly one executor is stored
        # and both callers use it. The losing executor was never submitted to
        # and is simply garbage-collected.
        pool = pools.setdefault(key, ProcessPoolExecutor(max_workers=MAX_WORKERS_PER_ENDPOINT))
    return pool


def validate_request(data: dict) -> Tuple[bool, Any]:
    """
    Validate and normalise the JSON body of a ``/modbus/call`` request.

    Returns:
        ``(True, params)`` on success. ``params`` has the keys ``ip``
        (canonical dotted-quad IPv4), ``port``, ``unit``, ``instr``, ``addr``,
        ``type``, ``count`` and ``value``.

        ``(False, help_json)`` on failure. ``help_json`` describes the expected
        request and carries the specific reason in ``help_json["error"]``.

    Malformed input always yields ``(False, help_json)`` rather than an
    exception, so the endpoint answers 400 instead of an unhandled 500. This
    covers wrong JSON types, non-numeric strings and out-of-range numbers.
    Optional fields (``unit``, ``port``, ``type``, ``count``) that are absent
    or JSON ``null`` take their defaults.
    """
    help_json = {
        "ok": False,
        "error": "bad request",
        "expected": {
            "ip": "string, required",
            "port": "int 1-65535, optional, default 502",
            "unit": "int 0-247, optional, default 1",
            "instr": "int 1-7, required",
            "addr": "int 1-65536, required",
            "type": "raw|int|float|hex|bin, optional, default raw",
            "count": "int >=1, used when reading (instr 1-3, 7), default 1",
            "value": "int or list[int], required for instr 4-6",
        },
        "instr_legend": {
            "1": "read holding registers",
            "2": "read coils",
            "3": "read input registers",
            "4": "write single holding register (value=int)",
            "5": "write single coil (value=0/1/bool)",
            "6": "write multiple holding registers (value=[int,...])",
            "7": "read discrete inputs (DI)",
        },
    }

    def fail(message: str) -> Tuple[bool, Any]:
        # Attach the specific reason to the shared help payload.
        help_json["error"] = message
        return False, help_json

    def parse_int(raw: Any) -> Any:
        # int(raw), or None for anything int() rejects: lists, dicts,
        # non-numeric strings, NaN/Infinity.
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            return None

    if not isinstance(data, dict):
        return fail("request body must be a JSON object")

    raw_unit = data.get("unit")
    unit = 1 if raw_unit is None else parse_int(raw_unit)
    if unit is None or not 0 <= unit <= 247:
        return fail("unit must be 0..247")

    ip = data.get("ip")
    if not ip:
        return fail("missing ip")
    if not isinstance(ip, str):
        return fail("invalid ip")
    try:
        # inet_aton also accepts shorthand such as "127.1". Round-tripping
        # through inet_ntoa canonicalises it, so one device always maps to one
        # per-endpoint pool key.
        ip = socket.inet_ntoa(socket.inet_aton(ip))
    except (OSError, ValueError):
        return fail("invalid ip")

    raw_port = data.get("port")
    port = DEFAULT_MODBUS_PORT if raw_port is None else parse_int(raw_port)
    if port is None or not 1 <= port <= 65535:
        return fail("port must be int 1..65535")

    if data.get("instr") is None:
        return fail("missing instr (1-7)")
    instr = parse_int(data["instr"])
    if instr is None:
        return fail("instr must be int (1-7)")
    if instr not in (1, 2, 3, 4, 5, 6, 7):
        return fail("instr must be in 1..7")

    if data.get("addr") is None:
        return fail("missing addr")
    addr = parse_int(data["addr"])
    if addr is None:
        return fail("addr must be int")
    # Addresses are 1-based (modbus_task subtracts 1 before sending), so
    # 1..65536 covers the whole 16-bit address space.
    if not 1 <= addr <= 65536:
        return fail("addr must be 1..65536")

    out_type = data.get("type")
    if out_type is None:
        out_type = "raw"
    if out_type not in ("raw", "int", "float", "hex", "bin"):
        return fail("type must be raw|int|float|hex|bin")

    raw_count = data.get("count")
    count = 1 if raw_count is None else parse_int(raw_count)
    if count is None:
        return fail("count must be int")
    # count is only used by reads; writes keep accepting any int.
    if instr in (1, 2, 3, 7) and count < 1:
        return fail("count must be >= 1")

    # writes MUST carry a value
    if instr in (4, 5, 6) and "value" not in data:
        return fail("missing value for write (instr 4-6)")

    return True, {
        "ip": ip,
        "port": port,
        "unit": unit,
        "instr": instr,
        "addr": addr,
        "type": out_type,
        "count": count,
        "value": data.get("value"),
    }


# ------------------------------------------------------------
# SUBPROCESS MODBUS TASK
# ------------------------------------------------------------
def call_with_unit(unit: int, fn, **kwargs):
    """
    Call ``fn(**kwargs)`` and pass the Modbus unit id as a keyword argument.

    Different pymodbus versions name that keyword differently, so ``slave``,
    ``unit`` and ``device_id`` are tried in that order. Only a ``TypeError``
    saying that the tried keyword is unexpected moves on to the next name.
    Any other ``TypeError`` raised by ``fn`` is a genuine error. It is
    re-raised unchanged instead of being masked as "does not accept unit" and
    instead of calling ``fn`` again.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        TypeError: if ``fn`` accepts none of the candidate keywords, or if
            ``fn`` raises an unrelated ``TypeError`` itself.
    """
    last_err = None

    for key in ("slave", "unit", "device_id"):
        try:
            return fn(**kwargs, **{key: unit})
        except TypeError as e:
            # Python reports a bad keyword as
            # "...got an unexpected keyword argument '<key>'".
            if f"unexpected keyword argument '{key}'" not in str(e):
                raise
            last_err = e

    # Nothing worked -> raise a clear error (so it shows up in the API response).
    raise TypeError(f"{fn.__qualname__} does not accept unit/slave/device_id (unit={unit}). last={last_err}")


# Read instructions: instr -> (client method name, response attribute with the data)
_READ_OPERATIONS = {
    1: ("read_holding_registers", "registers"),
    2: ("read_coils", "bits"),
    3: ("read_input_registers", "registers"),
    7: ("read_discrete_inputs", "bits"),
}


def _to_register_value(raw: Any) -> Any:
    """Return ``raw`` as an int in the 16-bit register range 0..65535, else None."""
    try:
        number = int(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    return number if 0 <= number <= 0xFFFF else None


def modbus_task(ip: str, port: int, instr: int, addr: int, count: int, value: Any, unit: int):
    """
    Perform one Modbus TCP operation. Runs in a worker of the endpoint's pool.

    Args:
        ip: Modbus TCP host.
        port: Modbus TCP port.
        instr: operation code.
            Reads: 1 = holding registers, 2 = coils, 3 = input registers,
            7 = discrete inputs.
            Writes: 4 = single register, 5 = single coil,
            6 = multiple registers.
        addr: 1-based address, converted to the 0-based wire address
            (values <= 0 are clamped to 0).
        count: number of items to read; ignored for writes.
        value: write payload. An int 0..65535 for 4, 0/1/bool for 5, and a
            non-empty list of ints 0..65535 for 6.
        unit: Modbus unit id, passed via ``call_with_unit``.

    Returns:
        A JSON-serialisable dict, one of:
        ``{"ok": True, "kind": "read", "data": [...]}``,
        ``{"ok": True, "kind": "write", "written": ...}``,
        ``{"ok": False, "error": "..."}``.
        Exceptions raised by the Modbus client are not caught here; they
        reach the caller through the future.
    """
    # Validate the instruction and any write payload before connecting. A bad
    # request never opens a socket and never surfaces as an opaque worker
    # exception.
    if instr == 4:
        payload = _to_register_value(value)
        if payload is None:
            return {"ok": False, "error": "value must be int 0..65535 for instr 4"}
        write_method, write_arg, written = "write_register", "value", payload
    elif instr == 5:
        if isinstance(value, bool):
            payload = value
        else:
            try:
                payload = bool(int(value))
            except (TypeError, ValueError, OverflowError):
                return {"ok": False, "error": "value must be 0/1/bool for instr 5"}
        write_method, write_arg, written = "write_coil", "value", int(payload)
    elif instr == 6:
        if not isinstance(value, list):
            return {"ok": False, "error": "value must be list[int] for instr 6"}
        payload = [_to_register_value(item) for item in value]
        if not payload or any(item is None for item in payload):
            return {"ok": False, "error": "value must be a non-empty list of ints 0..65535 for instr 6"}
        write_method, write_arg, written = "write_registers", "values", payload
    elif instr not in _READ_OPERATIONS:
        return {"ok": False, "error": f"unsupported instr {instr}"}

    from pymodbus.client import ModbusTcpClient

    zero_based_addr = addr - 1 if addr > 0 else 0
    client = ModbusTcpClient(host=ip, port=port, timeout=MODBUS_TIMEOUT)
    try:
        # connect() sits inside try/finally so close() also runs when the
        # connection attempt fails.
        if not client.connect():
            return {"ok": False, "error": f"cannot connect to {ip}:{port}"}

        if instr in _READ_OPERATIONS:
            method_name, data_attr = _READ_OPERATIONS[instr]
            rr = call_with_unit(unit, getattr(client, method_name), address=zero_based_addr, count=count)
            if rr.isError():
                return {"ok": False, "error": str(rr)}
            # Bit reads can return more bits than requested (bits are packed
            # into whole bytes), so trim them to the requested count.
            data = rr.bits[:count] if data_attr == "bits" else rr.registers
            return {"ok": True, "kind": "read", "data": data}

        rr = call_with_unit(
            unit, getattr(client, write_method), address=zero_based_addr, **{write_arg: payload}
        )
        if rr.isError():
            return {"ok": False, "error": str(rr)}
        return {"ok": True, "kind": "write", "written": written}
    finally:
        client.close()


# ------------------------------------------------------------
# OUTPUT CONVERSION
# ------------------------------------------------------------
def convert_read_data(raw, out_type: str):
    """
    Format data returned by a Modbus read according to ``out_type``.

    Args:
        raw: list of register values (ints) or of bit values (bools).
        out_type: one of ``raw``, ``int``, ``float``, ``hex`` or ``bin``.

    Bit lists only support ``bin``, which gives a "1"/"0" string. Every other
    type returns a bit list unchanged. For registers:

    - ``int`` unwraps a single-element list.
    - ``hex``/``bin`` format each register with ``hex()``/``bin()``.
    - ``float`` decodes the first two registers as a big-endian IEEE-754
      float32 (high word first).

    ``raw`` and unknown types return the input unchanged.

    Raises:
        ValueError: for ``float`` with fewer than two registers.
    """
    if out_type == "raw":
        return raw

    is_bit_list = isinstance(raw, list) and len(raw) > 0 and isinstance(raw[0], bool)
    if is_bit_list:
        if out_type != "bin":
            return raw
        return "".join("1" if bit else "0" for bit in raw)

    if out_type == "int":
        return raw[0] if len(raw) == 1 else raw

    if out_type in ("hex", "bin"):
        formatter = hex if out_type == "hex" else bin
        return [formatter(register) for register in raw]

    if out_type == "float":
        if len(raw) < 2:
            raise ValueError("need at least 2 registers for float")
        word_bytes = struct.pack(">HH", raw[0], raw[1])
        (decoded,) = struct.unpack(">f", word_bytes)
        return decoded

    return raw


# ------------------------------------------------------------
# API ENDPOINT
# ------------------------------------------------------------
@app.route("/modbus/call", methods=["POST"])
def modbus_call():
    """
    Handle ``POST /modbus/call``: run one Modbus operation and return JSON.

    The caller must send the shared secret in the ``X-API-Key`` header and a
    JSON object body accepted by ``validate_request``. The operation runs in
    the endpoint's process pool, and the handler waits at most
    ``FUTURE_TIMEOUT`` seconds for it.

    Responses:
        200 with ``kind="read"`` (converted + raw data and request meta) or
        ``kind="write"``.
        400 for a missing or invalid body.
        401 for a wrong API key.
        500 for Modbus, worker, timeout or conversion errors.
    """
    # --- simple header-based auth ---
    client_key = request.headers.get("X-API-Key")
    # hmac.compare_digest avoids content-based short-circuiting, so response
    # timing does not reveal how much of a guessed key was correct.
    if client_key is None or not hmac.compare_digest(
        client_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        return jsonify({"ok": False, "error": "unauthorized"}), 401

    data = request.get_json(silent=True)
    # Only a non-empty JSON object is usable. Arrays and scalars have no .get().
    if not data or not isinstance(data, dict):
        return jsonify({
            "ok": False,
            "error": "invalid json",
            "help": "POST JSON with ip, instr, addr ...",
        }), 400

    ok, parsed = validate_request(data)
    if not ok:
        return jsonify(parsed), 400

    ip = parsed["ip"]
    port = parsed["port"]
    unit = parsed["unit"]
    instr = parsed["instr"]
    addr = parsed["addr"]
    out_type = parsed["type"]
    count = parsed["count"]
    value = parsed["value"]

    pool = get_pool_for_endpoint(ip, port)
    try:
        future = pool.submit(modbus_task, ip, port, instr, addr, count, value, unit)
        try:
            result = future.result(timeout=FUTURE_TIMEOUT)
        except FutureTimeoutError as e:
            # Best effort: if the job is still waiting behind an earlier one,
            # cancel it so it cannot run (e.g. perform a write) after the
            # caller was told it failed. A job already handed to a worker
            # cannot be cancelled and finishes in the background.
            future.cancel()
            detail = str(e) or f"no result within {FUTURE_TIMEOUT} s"
            return jsonify({"ok": False, "error": f"timeout or worker error: {detail}"}), 500
    except BrokenProcessPool as e:
        # A worker process died, so this executor can never run jobs again.
        # Drop it from the cache (unless another request already replaced it)
        # so the next call for this endpoint gets a fresh pool.
        if pools.get((ip, port)) is pool:
            pools.pop((ip, port), None)
        return jsonify({"ok": False, "error": f"timeout or worker error: {e}"}), 500
    except Exception as e:
        return jsonify({"ok": False, "error": f"timeout or worker error: {e}"}), 500

    if not result.get("ok"):
        # pass the Modbus/worker error through to the caller
        return jsonify(result), 500

    # Writes: nothing to convert
    if result.get("kind") == "write":
        return jsonify({
            "ok": True,
            "kind": "write",
            "ip": ip,
            "port": port,
            "addr": addr,
            "instr": instr,
            "written": result.get("written"),
        })

    # READ
    raw = result.get("data")
    try:
        converted = convert_read_data(raw, out_type)
    except Exception as e:
        return jsonify({
            "ok": False,
            "error": f"convert error: {e}",
            "raw": raw,
        }), 500

    return jsonify({
        "ok": True,
        "kind": "read",
        "data": converted,
        "raw": raw,
        "format": out_type,
        "meta": {
            "ip": ip,
            "port": port,
            "unit": unit,
            "addr": addr,
            "instr": instr,
            "count": count,
        },
    })


# ------------------------------------------------------------
# MAIN
# ------------------------------------------------------------
if __name__ == "__main__":
    # Serve over HTTPS when both SERVER_CERT and SERVER_KEY exist; otherwise
    # fall back to plain HTTP.
    use_https = os.path.exists(SERVER_CERT) and os.path.exists(SERVER_KEY)

    # The Werkzeug interactive debugger can execute code sent from a browser.
    # Combined with host="0.0.0.0" it must never be on by default. Opt in with
    # MODBUS_DEBUG=1 for local development only.
    debug = os.environ.get("MODBUS_DEBUG") == "1"

    if use_https:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(SERVER_CERT, SERVER_KEY)

        # To require client certificates signed by CA_CERT, enable:
        # context.verify_mode = ssl.CERT_REQUIRED
        # context.load_verify_locations(CA_CERT)

        app.run(host="0.0.0.0", port=5000, ssl_context=context, debug=debug)
    else:
        # Plain-HTTP fallback: the X-API-Key header travels unencrypted.
        app.run(host="0.0.0.0", port=5000, debug=debug)