# pip install aifes-converter --ignore-requires-python
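#
# Trains a small feed-forward Keras network that maps a scalar input to
# 8 outputs (one-hot or ramp targets), visualizes the loss, per-layer
# gradient flow and weights, and converts the trained model to AIfES
# sources for microcontroller deployment.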

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from packaging import version
import tensorflow as tf

np.set_printoptions(precision=2, suppress=True)

# Choose the training dataset to be generated
exp_scale = False       # True: exponential input scale (dB steps); False: linear 0..1
one_hot_output = False  # True: one-hot targets; False: ramp targets

# Input -> one-hot output
if exp_scale:
    one_hot_X = np.linspace(-24, 0, 8, dtype=np.float32).reshape(-1, 1)
    one_hot_X = 0.5 * 10**(one_hot_X / 20)  # dB -> linear amplitude
else:
    one_hot_X = np.linspace(0, 1, 8, dtype=np.float32).reshape(-1, 1)
one_hot_Y = np.eye(8, dtype=np.float32)  # row i is the one-hot target for input i

# Input -> slope output
if exp_scale:
    slope_X = np.linspace(-18, 6, 9, dtype=np.float32).reshape(-1, 1)
    slope_X = 10**(slope_X / 20)  # dB -> linear amplitude
else:
    slope_X = np.linspace(0, 1, 9, dtype=np.float32).reshape(-1, 1)
slope_Y = np.tri(8, 8, dtype=np.float32)  # lower-triangular ones
slope_Y = np.vstack((np.zeros((1, 8), dtype=np.float32), slope_Y))  # prepend all-zero row -> 9x8
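# slope_Y is a 9x8 "thermometer" ramp: each row turns on one more output, e.g.
#   input 0 -> [0 0 0 0 0 0 0 0]
#   input 1 -> [1 0 0 0 0 0 0 0]
#   input 2 -> [1 1 0 0 0 0 0 0]
#   ...
#   input 8 -> [1 1 1 1 1 1 1 1]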

if one_hot_output:
    X, Y = one_hot_X, one_hot_Y
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    output_layer_fn = 'softmax'
else: # Slope output
    X, Y = slope_X, slope_Y
    loss_fn = tf.keras.losses.BinaryCrossentropy()
    output_layer_fn = 'sigmoid'
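# Note: softmax + categorical cross-entropy treats the 8 outputs as mutually
# exclusive classes, while sigmoid + binary cross-entropy treats each output
# as an independent on/off decision, which matches the ramp targets.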

# Show the generated training data
print(X)
print(Y)
# === Define model with two hidden layers ===
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(3, activation='leaky_relu'),     # Hidden layer 1
    tf.keras.layers.Dense(3, activation='leaky_relu'),     # Hidden layer 2
    tf.keras.layers.Dense(8, activation=output_layer_fn)   # Output layer
])
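# The tiny 1-3-3-8 topology keeps the converted AIfES model small enough for a
# microcontroller; leaky ReLU presumably helps the narrow hidden layers avoid
# dead units.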

EPOCHS = 5000 # Decreasing EPOCHS makes training faster at the cost of accuracy.
optimizer = tf.keras.optimizers.Adam()

loss_history = []
grad_histories = [[] for _ in range(len(model.layers))]
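# grad_histories[i] records the mean |gradient| of layer i's kernel each epoch,
# making vanishing or exploding gradients easy to spot in the plot below.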

# === Training loop ===
for epoch in range(EPOCHS):

    with tf.GradientTape() as tape:
        predictions = model(X, training=True)
        loss = loss_fn(Y, predictions)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    loss_history.append(loss.numpy())

    # Track per-layer kernel gradients (trainable_variables is ordered [W0, b0, W1, b1, ...])
    for layer_idx in range(len(model.layers)):
        grad_weights = grads[layer_idx * 2]  # even indices hold the kernel gradients
        avg_grad = tf.reduce_mean(tf.abs(grad_weights)).numpy() if grad_weights is not None else 0.0
        grad_histories[layer_idx].append(avg_grad)

    if (epoch + 1) % 100 == 0:
        print(f"epoch: {epoch+1}, loss: {loss.numpy():.4f}")

# === Get all weight matrices (skip biases) ===
weight_matrices = [model.layers[i].get_weights()[0] for i in range(len(model.layers))]
# === Plot: Loss + All Gradients (1 plot) + Weight Heatmaps ===
num_layers = len(model.layers)
num_cols = max(3, num_layers)  # at least 3 columns: loss, gradients, predictions
fig, axs = plt.subplots(2, num_cols, figsize=(5 * num_cols, 10))

# --- Top row: Loss and combined gradient plot ---
# Loss plot
axs[0, 0].plot(loss_history, label="Loss", color='blue')
axs[0, 0].set_title("Loss over Epochs")
axs[0, 0].set_xlabel("Epoch")
axs[0, 0].set_ylabel("Loss")
axs[0, 0].grid(True)
axs[0, 0].legend()

# Combined gradient plot
for i, grads in enumerate(grad_histories):
    axs[0, 1].plot(grads, label=f"Layer {i+1}")
axs[0, 1].set_title("Gradient Flow (All Layers)")
axs[0, 1].set_xlabel("Epoch")
axs[0, 1].set_ylabel("Avg. Gradient Magnitude")
axs[0, 1].grid(True)
axs[0, 1].legend()

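# Predictions heatmap: each row shows the 8 network outputs for one input sample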
pred = np.round(model.predict(X), 2)
axs[0, 2].set_title("Predictions")
sns.heatmap(pred, vmin=0, vmax=1, annot=True, cmap='coolwarm', center=0, ax=axs[0, 2])

# Turn off any unused axes in the top row
for j in range(3, num_cols):
    axs[0, j].axis("off")

for i, weights in enumerate(weight_matrices):
    sns.heatmap(weights, annot=True, cmap='coolwarm', center=0,
                xticklabels=[f'L{i+1}_{j}' for j in range(weights.shape[1])],
                yticklabels=[f'In{j}' for j in range(weights.shape[0])],
                ax=axs[1, i])
    axs[1, i].set_title(f"Weights: Layer {i+1}")
    axs[1, i].set_xlabel("Output Neurons")
    axs[1, i].set_ylabel("Input Neurons")

# Turn off unused axes in the second row
for j in range(num_layers, num_cols):
    axs[1, j].axis("off")

# Convert to the AIfES format
from aifes import keras2aifes  # (a pytorch2aifes module is also available)

if version.parse(tf.__version__) >= version.parse("2.7.0"):
    # Workaround: TF >= 2.7 no longer exposes input_shape on the first layer,
    # but the AIfES converter expects it there.
    model.layers[0].input_shape = model.input_shape

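# convert_to_fnn_q7 quantizes the model to 8-bit (Q7) fixed point; the
# representative input data is used to calibrate the quantization ranges.
# target_alignment and byteorder should match the target MCU (here: 2-byte
# aligned, little-endian). Both calls write the AIfES model sources to '.'.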
keras2aifes.convert_to_fnn_q7(model, '.', representative_data=X, target_alignment=2, byteorder="little")
keras2aifes.convert_to_fnn_f32(model, '.')
plt.tight_layout()
plt.show()


# Other converter variants:
#keras2aifes.convert_to_fnn_f32
#keras2aifes.convert_to_fnn_f32_cmsis
#keras2aifes.convert_to_fnn_f32_express
#keras2aifes.convert_to_fnn_q7_cmsis
#keras2aifes.convert_to_fnn_q7_express

