import torch
import torch.nn as nn


# Define the original model architecture.
# Only the layer definitions are needed here: load_state_dict() matches
# parameters by name, so forward() is not required just to export the weights.
class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)


def format_cpp_array(name, array, shape):
    """Format a 1-D or 2-D float array as a C PROGMEM array definition."""
    cpp = f"const float {name}[{']['.join(map(str, shape))}] PROGMEM = {{\n"
    if len(shape) == 2:
        for row in array:
            cpp += "  { " + ", ".join(f"{v:.6f}f" for v in row) + " },\n"
    elif len(shape) == 1:
        cpp += "  " + ", ".join(f"{v:.6f}f" for v in array) + "\n"
    else:
        raise ValueError(f"Unsupported array rank: {len(shape)}")
    cpp += "};\n"
    return cpp
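
# For reference, the emitted C looks like the block below (illustrative values
# only, shown for a hypothetical 2x3 weight matrix):
#
#   const float LINEAR1_WEIGHTS[2][3] PROGMEM = {
#     { 0.123456f, -0.045210f, 0.778812f },
#     { -0.331005f, 0.902314f, 0.004126f },
#   };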


input_size = 11  # example value — update this to match your original training
hidden_size = 64  # example value — update this too
output_size = 3  # example value — update this too

model = Linear_QNet(input_size, hidden_size, output_size)

# Load the trained weights onto the CPU and switch to inference mode
model.load_state_dict(torch.load('/model/model.pth', map_location='cpu'))
model.eval()


# nn.Linear stores weights as (out_features, in_features), so each row of the
# exported matrix corresponds to one output neuron.
linear1_weight = model.linear1.weight.detach().cpu().numpy()
linear1_bias = model.linear1.bias.detach().cpu().numpy()
linear2_weight = model.linear2.weight.detach().cpu().numpy()
linear2_bias = model.linear2.bias.detach().cpu().numpy()

# Print the C array definitions (redirect stdout to a header file to include them in the microcontroller code)
print(f"const int input_size = {input_size};")
print(f"const int hidden_size = {hidden_size};")
print(f"const int output_size = {output_size};\n")

print(format_cpp_array("LINEAR1_WEIGHTS", linear1_weight, linear1_weight.shape))
print(format_cpp_array("LINEAR1_BIAS", linear1_bias, linear1_bias.shape))
print(format_cpp_array("LINEAR2_WEIGHTS", linear2_weight, linear2_weight.shape))
print(format_cpp_array("LINEAR2_BIAS", linear2_bias, linear2_bias.shape))
