import json
import os
import subprocess
import time

import cv2
import RPi.GPIO as GPIO
import serial
from ultralytics import YOLO


MODEL_PATH = "best.pt"          # YOLO model file (trained weights)
CONF_THRESHOLD = 0.8            # Minimum detection confidence (0.0-1.0) to accept a box

# Class names (lowercase, as emitted by the model) that count as predators
# and may trigger an alarm.
PREDATOR_CLASSES = [
    "coyote",
    "fox",
    "hawk",
    "opossum",
    "raccoon",
    "snake",
    "possumbrushtail"  # mapped to "opossum" for audio
]

CHICKEN_CLASS = "chicken"       # This is used for the escalation logic
ESCALATED_ALARM = "escalated.mp3"  # played when a predator box overlaps a chicken box

# Unique sounds for each predator, keyed by the *audio* label
# (i.e. after possumbrushtail has been remapped to opossum).
SOUND_MAP = {
    "coyote": "coyote.mp3",
    "fox": "fox.mp3",
    "hawk": "hawk.mp3",
    "opossum": "opossum.mp3",
    "raccoon": "raccoon.mp3",
    "snake": "snake.mp3"
    # possumbrushtail => "opossum" in code
}

USE_NOTECARD = True
#NOTECARD_PORT = "/dev/ttyAMA0"  # alternative port (UART header) — left for reference
NOTECARD_PORT = "/dev/ttyACM0"   # Notecard serial device (USB)

NOTECARD_BAUD = 9600

# Rate-limiting: at most one trigger per MIN_TIME_BETWEEN_TRIGGERS seconds;
# if MAX_TRIGGERS_5_MIN triggers accumulate in a rolling 5-minute window,
# detection is suspended for FORCED_COOLDOWN_HOURS.
MIN_TIME_BETWEEN_TRIGGERS = 30    # seconds
MAX_TRIGGERS_5_MIN = 10
FORCED_COOLDOWN_HOURS = 2

# We still do skipping frames logic: inference runs once per
# FRAMES_PER_INFERENCE camera frames rather than every frame.
FRAME_RATE = 30
PROCESSING_TIME_SECONDS = 8
FRAMES_PER_INFERENCE = FRAME_RATE * PROCESSING_TIME_SECONDS  # 240 => one detection per ~8s

# Snapshots folder (created up front so imwrite never fails on a missing dir)
SAVE_DIR = "captures"
os.makedirs(SAVE_DIR, exist_ok=True)

# Show debug window with bounding boxes? (requires a desktop/VNC session)
SHOW_VIDEO = True

# Camera resolution (optional; the driver may ignore these hints)
CAM_WIDTH = 640
CAM_HEIGHT = 480

# Video recording settings for the post-detection clip
RECORD_DURATION = 5    # seconds to record
RECORD_FPS = 15        # frames/sec in recorded clip
VIDEO_OUTPUT_DIR = "recordings"
os.makedirs(VIDEO_OUTPUT_DIR, exist_ok=True)

# ==========================================

def play_alarm_sound(sound_file):
    """Play an mp3 alarm through mpg123, blocking until playback ends.

    Args:
        sound_file: Path to the mp3 to play. None/empty means "no sound
            configured" and is silently ignored; a missing file logs an
            error and returns without raising.
    """
    if not sound_file:
        return
    if not os.path.exists(sound_file):
        print("[ERROR] Sound file not found:", sound_file)
        return
    print("[INFO] Playing sound:", sound_file)
    # Use an argument list with no shell: the old os.system() call broke
    # (and was shell-injectable) for filenames containing quotes or other
    # shell metacharacters. check=False preserves best-effort behavior.
    subprocess.run(["mpg123", "-q", sound_file], check=False)

def send_alert_via_notecard(species_label, confidence, escalated=False):
    """Best-effort push of a predator alert to the Notecard over serial.

    Builds a `note.add` request for the `alerts.qo` notefile and writes it
    as a single JSON line. Any failure (port missing, write error, ...) is
    logged and swallowed so detection keeps running. No-op when
    USE_NOTECARD is False.
    """
    if not USE_NOTECARD:
        return
    try:
        alert_body = {
            "alert": "PREDATOR_DETECTED",
            "predator": species_label,
            "confidence": round(confidence, 2),
            "escalated": escalated,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }
        request = {"req": "note.add", "file": "alerts.qo", "body": alert_body}
        payload = (json.dumps(request) + "\n").encode("utf-8")
        # Context manager guarantees the port is closed even if write fails.
        with serial.Serial(NOTECARD_PORT, NOTECARD_BAUD, timeout=1) as port:
            port.write(payload)
            port.flush()
        print(f"[INFO] Notecard alert sent. Predator={species_label}, escalated={escalated}")
    except Exception as e:
        print("[ERROR] Could not send Notecard alert:", e)

def boxes_intersect(boxA, boxB):
    """Return True when two axis-aligned boxes (x1, y1, x2, y2) overlap.

    Touching edges count as an intersection.
    """
    ax1, ay1, ax2, ay2 = boxA
    bx1, by1, bx2, by2 = boxB
    # Boxes overlap iff their projections overlap on BOTH axes.
    overlap_x = ax1 <= bx2 and bx1 <= ax2
    overlap_y = ay1 <= by2 and by1 <= ay2
    return overlap_x and overlap_y

def record_video(cap, duration=5, fps=15):
    """Capture frames from `cap` for `duration` seconds into an XVID .avi clip.

    Args:
        cap: An already-open cv2.VideoCapture to read frames from.
        duration: Number of seconds of footage to record.
        fps: Frame rate stamped into the output file.

    The clip is written to VIDEO_OUTPUT_DIR with a timestamped filename.
    """
    stamp = time.strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(VIDEO_OUTPUT_DIR, f"{stamp}_predator_clip.avi")
    print(f"[INFO] Recording video: {out_path}")

    # Ask the camera for its real resolution so the writer matches the frames.
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'XVID'), fps, frame_size)

    deadline = time.time() + duration
    while time.time() < deadline:
        ok, frame = cap.read()
        if not ok:
            break  # camera dropped out; keep whatever was captured so far
        writer.write(frame)

    writer.release()
    print("[INFO] Finished recording video.")

def main():
    """Run the detection loop: capture frames, periodically run YOLO, react.

    On each predator detection (subject to rate limiting) this saves a
    snapshot, records a short video clip, plays an alarm sound (escalated
    if the predator box overlaps a chicken box), and sends a Notecard
    alert. Rate limiting: at most one trigger per MIN_TIME_BETWEEN_TRIGGERS
    seconds, and a forced FORCED_COOLDOWN_HOURS cooldown once
    MAX_TRIGGERS_5_MIN triggers accumulate within 5 minutes.
    """
    # No relay usage, no need for GPIO config beyond setwarnings(False) if you want
    GPIO.setwarnings(False)

    print("[INFO] Loading YOLO model from:", MODEL_PATH)
    model = YOLO(MODEL_PATH)

    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, CAM_WIDTH)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, CAM_HEIGHT)

    if not cap.isOpened():
        print("[ERROR] Cannot open camera (index 0).")
        return

    last_trigger_time = 0.0      # wall-clock time of the most recent trigger
    forced_cooldown_until = 0.0  # detection suspended until this time
    triggers = []                # timestamps of recent triggers (pruned to a 5-min window)

    frame_count = 0

    print("[INFO] Starting detection loop. We'll do inference every 240 frames (~8s).")
    if SHOW_VIDEO:
        print("Press 'q' in the window to exit if you have a desktop/VNC.")

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("[ERROR] Camera read failed.")
                break

            frame_count += 1
            now = time.time()

            # forced cooldown: keep the preview alive but skip all detection
            if now < forced_cooldown_until:
                if SHOW_VIDEO:
                    cv2.imshow("Chicken Guardian", frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                continue

            # skip detection if not time (inference only every FRAMES_PER_INFERENCE frames)
            if frame_count % FRAMES_PER_INFERENCE != 0:
                if SHOW_VIDEO:
                    cv2.imshow("Chicken Guardian", frame)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                continue

            # ========== RUN YOLO INFERENCE ==========

            results = model.predict(frame, conf=CONF_THRESHOLD)
            bboxes = results[0].boxes

            predator_boxes = []  # (raw label, audio label, confidence, xyxy)
            chicken_boxes = []   # xyxy only; used for the escalation overlap check

            for box in bboxes:
                cls_id = int(box.cls[0])
                conf = float(box.conf[0])

                raw_label = model.names[cls_id]
                label = raw_label.lower()

                # unify possumbrushtail => opossum for audio
                if label == "possumbrushtail":
                    audio_label = "opossum"
                else:
                    audio_label = label

                xyxy = box.xyxy[0].cpu().numpy().astype(int)
                (x1, y1, x2, y2) = xyxy

                # Draw bounding box (colors are BGR)
                if label in PREDATOR_CLASSES:
                    color = (0, 0, 255)  # red
                elif label == CHICKEN_CLASS:
                    color = (255, 255, 0) # cyan
                else:
                    color = (0, 255, 0)   # green

                text = f"{label}: {conf:.2f}"
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, text, (x1, max(y1 - 10, 0)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

                if label in PREDATOR_CLASSES:
                    predator_boxes.append((label, audio_label, conf, xyxy))
                elif label == CHICKEN_CLASS:
                    chicken_boxes.append(xyxy)

            # Rate-limiting triggers
            for (pred_label, audio_label, pred_conf, pred_xyxy) in predator_boxes:
                # Too soon after the last trigger? Skip this predator.
                if (now - last_trigger_time) < MIN_TIME_BETWEEN_TRIGGERS:
                    continue

                # remove triggers older than 5 min
                cutoff = now - 300
                triggers = [t for t in triggers if t >= cutoff]

                if len(triggers) >= MAX_TRIGGERS_5_MIN:
                    print("[WARNING] Too many triggers -> 2-hour cooldown.")
                    forced_cooldown_until = now + (FORCED_COOLDOWN_HOURS * 3600)
                    break

                # check if escalated (predator box overlaps any chicken box)
                escalated = any(boxes_intersect(pred_xyxy, cxy) for cxy in chicken_boxes)
                if escalated:
                    print(f"[INFO] ESCALATED alarm: {pred_label} overlapping chicken!")
                    alarm_sound = ESCALATED_ALARM
                else:
                    alarm_sound = SOUND_MAP.get(audio_label, None)

                last_trigger_time = now
                triggers.append(now)

                print(f"[INFO] Detected {pred_label} (conf={pred_conf:.2f}), escalated={escalated}")

                # Save snapshot (frame still carries the drawn boxes/labels)
                time_str = time.strftime("%Y%m%d_%H%M%S")
                file_name = f"{time_str}_{pred_label}_{pred_conf:.2f}.jpg"
                save_path = os.path.join(SAVE_DIR, file_name)
                cv2.imwrite(save_path, frame)
                print(f"[INFO] Snapshot saved: {save_path}")

                # ===  Record short video clip  ===
                record_video(cap, RECORD_DURATION, RECORD_FPS)

                # === Then play audio & send Notecard ===
                play_alarm_sound(alarm_sound)
                send_alert_via_notecard(pred_label, pred_conf, escalated=escalated)

                # After this, we simply return to main loop and detection resumes.

            # Show debug window if desired
            if SHOW_VIDEO:
                cv2.imshow("Chicken Guardian", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

    except KeyboardInterrupt:
        print("[INFO] Exiting on Ctrl+C.")
    finally:
        # Always release the camera and any debug windows, even on error.
        cap.release()
        if SHOW_VIDEO:
            cv2.destroyAllWindows()

if __name__ == "__main__":
    main()