import sys
import csv
import json
import argparse
import numpy as np
import serial
from io import StringIO
from PyQt5.QtWidgets import QApplication, QWidget, QLabel
from PyQt5 import QtCore
import pyqtgraph as pg
from PyQt5.QtCore import QThread
from scipy.signal import butter, filtfilt


class BandpassFilter:
    """Zero-phase Butterworth band-pass filter.

    The filter coefficients are designed once, at construction time, from the
    cutoff frequencies and the sampling rate (all in Hz). :meth:`apply` runs
    the filter forwards and backwards via ``scipy.signal.filtfilt``.

    Args:
        lowcut: Lower band edge in Hz.
        highcut: Upper band edge in Hz.
        sample_rate: Sampling rate of the signals to be filtered, in Hz.
        order: Butterworth prototype order (default 4).
    """

    def __init__(self, lowcut, highcut, sample_rate, order=4):
        self.lowcut, self.highcut = lowcut, highcut
        self.sample_rate = sample_rate
        self.order = order
        coefficients = self._design_filter()
        self._b, self._a = coefficients

    def _design_filter(self):
        """Return the ``(b, a)`` coefficients for the band, normalised to Nyquist."""
        nyq = self.sample_rate / 2.0
        band_edges = [edge / nyq for edge in (self.lowcut, self.highcut)]
        return butter(self.order, band_edges, btype='band')

    def apply(self, data):
        """Return ``data`` filtered forwards and backwards with the stored coefficients."""
        numerator, denominator = self._b, self._a
        return filtfilt(numerator, denominator, data)


class CSIAmplitudeBuffer:
    """Fixed-size sliding window of the most recent amplitude samples.

    New samples are written at the end of ``data``; once the window is full
    the oldest sample is discarded. Slots that have not been written yet hold
    0.0.

    Args:
        size: Number of samples kept in the window (must be >= 1).

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """

    def __init__(self, size):
        if size < 1:
            raise ValueError(f"size must be a positive integer, got {size!r}")
        self.size = size
        self.data = np.zeros(size)
        # Number of samples received so far, capped at ``size``. Tracked
        # explicitly so that genuine zero-amplitude samples still count
        # towards readiness (counting non-zero entries would miss them).
        self._count = 0

    def append(self, value):
        """Shift the window left by one and store ``value`` as the newest sample."""
        self.data[:-1] = self.data[1:]
        self.data[-1] = value
        self._count = min(self._count + 1, self.size)

    def is_ready(self):
        """Return True once at least half of the window has been filled."""
        return self._count >= self.size // 2

    def get(self):
        """Return a copy of the window, oldest sample first."""
        return self.data.copy()


class HeartbeatAnalyzer(QWidget):
    """Main UI widget for plotting FFT and displaying heartbeat info.

    On every timer tick the widget band-pass filters the current contents of
    ``amplitude_buffer``, plots the magnitude spectrum from 0 Hz up to
    ``max_display_freq`` and reports the strongest frequency inside the
    ``[lowcut, highcut]`` heartbeat band together with its period.

    Args:
        amplitude_buffer: Object exposing ``is_ready()`` and ``get()``, as
            ``CSIAmplitudeBuffer`` does.
        sampling_rate: Sampling rate of the buffered samples, in Hz.
        lowcut: Lower heartbeat band edge in Hz (default 0.5).
        highcut: Upper heartbeat band edge in Hz (default 2.5).
        max_display_freq: Upper frequency limit of the plotted spectrum, in
            Hz (default 10.0).
        update_interval_ms: Plot refresh period in milliseconds (default 250).
    """

    def __init__(self, amplitude_buffer, sampling_rate, lowcut=0.5, highcut=2.5,
                 max_display_freq=10.0, update_interval_ms=250):
        super().__init__()
        self.setWindowTitle("Heartbeat Detection from CSI")
        self.resize(1000, 550)

        self.amplitude_buffer = amplitude_buffer
        self.sampling_rate = sampling_rate
        # Single source of truth for the heartbeat band: the same limits drive
        # both the filter design and the peak search in _find_peak_frequency.
        self.lowcut = lowcut
        self.highcut = highcut
        self.max_display_freq = max_display_freq
        self.bandpass_filter = BandpassFilter(
            lowcut=lowcut, highcut=highcut, sample_rate=sampling_rate
        )

        # Setup plot widget
        self.plot_widget = pg.PlotWidget(self)
        self.plot_widget.setGeometry(QtCore.QRect(0, 0, 1000, 500))
        self.plot_widget.setTitle("FFT of CSI Amplitude")
        self.plot_widget.setLabel('left', 'Magnitude')
        self.plot_widget.setLabel('bottom', 'Frequency (Hz)')
        self.plot_curve = self.plot_widget.plot([], pen='r')

        # Label for detected heartbeat period display
        self.period_label = QLabel(self)
        self.period_label.setGeometry(10, 510, 500, 30)
        self.period_label.setStyleSheet("font-size: 16px; color: blue")
        self.period_label.setText("Period: Calculating...")

        # Parented to the widget so Qt ties the timer's lifetime to it.
        self.update_timer = QtCore.QTimer(self)
        self.update_timer.timeout.connect(self._update_fft_plot)
        self.update_timer.start(update_interval_ms)

    def _update_fft_plot(self):
        """Recompute the spectrum, then refresh the plot and the peak label."""
        if not self.amplitude_buffer.is_ready():
            self.period_label.setText("Period: Not enough data")
            return

        freqs, magnitudes = self._compute_spectrum(self.amplitude_buffer.get())

        self.plot_curve.setData(freqs, magnitudes)
        if magnitudes.size > 0:
            self.plot_widget.setYRange(0, magnitudes.max() * 1.1)
            self.plot_widget.setXRange(0.1, self.max_display_freq)

        peak_freq = self._find_peak_frequency(freqs, magnitudes)
        if peak_freq is None:
            self.period_label.setText("Peak Frequency: Not detected")
        else:
            period = 1 / peak_freq
            self.period_label.setText(f"Peak Frequency: {peak_freq:.3f} Hz (Period: {period:.3f} s)")

    def _compute_spectrum(self, raw_data):
        """Return ``(freqs, magnitudes)`` of the filtered, Hann-windowed signal.

        Only frequencies from 0 Hz up to ``max_display_freq`` are returned.
        """
        filtered = self.bandpass_filter.apply(raw_data)
        windowed = (filtered - np.mean(filtered)) * np.hanning(len(filtered))
        # The input is real, so the one-sided rfft already holds every
        # non-negative frequency bin; no negative half needs to be masked off.
        magnitudes = np.abs(np.fft.rfft(windowed))
        freqs = np.fft.rfftfreq(len(windowed), d=1 / self.sampling_rate)
        keep = freqs <= self.max_display_freq
        return freqs[keep], magnitudes[keep]

    def _find_peak_frequency(self, freqs, magnitudes):
        """Return the frequency of the largest magnitude in ``[lowcut, highcut]``, or None."""
        in_band = (freqs >= self.lowcut) & (freqs <= self.highcut)
        if not np.any(in_band):
            return None
        band_freqs = freqs[in_band]
        peak_freq = band_freqs[np.argmax(magnitudes[in_band])]
        # Guard against a zero-frequency peak, which has no finite period.
        return peak_freq if peak_freq > 0 else None


class CSIReader(QThread):
    """Thread responsible for reading CSI data from serial port.

    Each line that contains ``CSI_DATA`` is parsed as CSV. The last field must
    be a JSON array of interleaved integer pairs, and the third-to-last field
    must be that array's length. This presumably matches the ESP32 CSI console
    format; confirm against the firmware. The amplitude of subcarrier
    ``subcarrier_index`` is emitted through ``new_amplitude_signal``.

    Args:
        serial_port: Serial port name, e.g. ``COM3`` or ``/dev/ttyUSB0``.
        subcarrier_index: Index of the subcarrier to track. Negative values
            count from the end, as with Python sequence indexing.
    """

    new_amplitude_signal = QtCore.pyqtSignal(float)

    def __init__(self, serial_port, subcarrier_index):
        super().__init__()
        self.serial_port_name = serial_port
        self.subcarrier_index = subcarrier_index
        self._running = True
        self.serial_connection = None

    @staticmethod
    def _parse_amplitude(line, subcarrier_index):
        """Return the amplitude of ``subcarrier_index`` in one CSI text line.

        Returns:
            The amplitude as a float, or None if the line is not a CSI record
            or is malformed.

        Raises:
            IndexError: If the record is well formed but has no subcarrier at
                ``subcarrier_index``, so the caller can report the bad setting.
        """
        if 'CSI_DATA' not in line:
            return None
        try:
            csv_data = next(csv.reader(StringIO(line)))
            csi_json_data = json.loads(csv_data[-1])
            data_length = int(csv_data[-3])
        except (StopIteration, IndexError, ValueError, csv.Error):
            # json.JSONDecodeError is a ValueError subclass.
            return None
        if not isinstance(csi_json_data, list) or data_length != len(csi_json_data):
            return None

        num_subcarriers = data_length // 2
        if not -num_subcarriers <= subcarrier_index < num_subcarriers:
            raise IndexError(
                f"subcarrier index {subcarrier_index} out of range for "
                f"{num_subcarriers} subcarriers"
            )
        k = subcarrier_index % num_subcarriers
        # Values come in (imag, real) pairs; only the magnitude is used.
        imag, real = csi_json_data[2 * k], csi_json_data[2 * k + 1]
        try:
            return abs(complex(real, imag))
        except TypeError:
            return None

    def run(self):
        """Open the serial port and emit amplitudes until :meth:`stop` is called."""
        try:
            self.serial_connection = serial.Serial(
                port=self.serial_port_name,
                baudrate=921600,
                bytesize=8,
                parity='N',
                stopbits=1,
                timeout=1
            )
        except Exception as e:
            print(f"Error opening serial port: {e}")
            return

        try:
            if not self.serial_connection.isOpen():
                print("Could not open serial port.")
                return

            reported_bad_index = False
            while self._running:
                try:
                    raw_line = self.serial_connection.readline()
                except (serial.SerialException, OSError) as e:
                    # A failing port (e.g. the device went away) will not
                    # recover by retrying in a tight loop; stop reading.
                    print(f"Error reading serial data: {e}")
                    break

                line = raw_line.decode(errors='ignore').strip()
                try:
                    amplitude = self._parse_amplitude(line, self.subcarrier_index)
                except IndexError as e:
                    # A misconfigured index fails on every record; report once.
                    if not reported_bad_index:
                        print(f"Error reading serial data: {e}")
                        reported_bad_index = True
                    continue
                except Exception as e:
                    # Best effort: one bad line must not kill the reader thread.
                    print(f"Error reading serial data: {e}")
                    continue

                if amplitude is not None:
                    self.new_amplitude_signal.emit(amplitude)
        finally:
            self.serial_connection.close()

    def stop(self):
        """Ask the read loop to finish and block until the thread has exited.

        The loop checks the flag between reads, so the wait is roughly bounded
        by the serial read timeout.
        """
        self._running = False
        self.wait()


def _parse_args(argv=None):
    """Parse command-line options; ``argv`` defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description="CSI Heartbeat Detection Application")
    parser.add_argument('-p', '--port', required=True, help="Serial port name (e.g., COM3, /dev/ttyUSB0)")
    parser.add_argument('-s', '--subcarrier', type=int, default=100,
                        help="Index of the subcarrier whose amplitude is tracked (default: 100)")
    parser.add_argument('-r', '--sample-rate', type=float, default=100.0,
                        help="Assumed CSI sampling rate in Hz; should match the rate "
                             "at which the device emits CSI records (default: 100)")
    parser.add_argument('-b', '--buffer-size', type=int, default=512,
                        help="Number of samples analysed per FFT (default: 512)")
    args = parser.parse_args(argv)
    if args.sample_rate <= 0:
        parser.error("--sample-rate must be positive")
    if args.buffer_size < 1:
        parser.error("--buffer-size must be positive")
    return args


def main():
    """Start the serial reader thread and the heartbeat GUI, then exit with Qt's exit code."""
    args = _parse_args()

    app = QApplication(sys.argv)

    amplitude_buffer = CSIAmplitudeBuffer(size=args.buffer_size)
    analyzer = HeartbeatAnalyzer(amplitude_buffer, sampling_rate=args.sample_rate)

    reader_thread = CSIReader(serial_port=args.port, subcarrier_index=args.subcarrier)
    reader_thread.new_amplitude_signal.connect(amplitude_buffer.append)

    reader_thread.start()
    analyzer.show()

    exit_code = app.exec_()

    # Stop and join the reader so its read loop ends and the port is closed.
    reader_thread.stop()
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
