"""
Script to run predictions on the FGVC-Aircraft dataset with a ResNet-18 classifier,
then score them with the LiRA attack.
It requires three other files, downloaded into the same folder:
    * the list : images_family_infer.csv
    * the model : resnet18_level2.ckpt
    * the attack parameters : lira_parameters.csv
Python >= 3.11 is highly recommended.
With uv (https://docs.astral.sh/uv/getting-started/installation/#installation-methods):
```
    uv python install 3.13
    uv python pin 3.13
    uv add jsonargparse lightning pandas torchvision scipy
    uv run infer.py
```
With pip:
```
    python -m venv venv
    source venv/bin/activate
    python -m pip install jsonargparse lightning pandas torchvision scipy
    python infer.py
```
"""

import torch
import pandas as pd
from tqdm import tqdm
from torchvision import datasets
import torchvision.models as torchvision_models
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import pandas as pd
from jsonargparse import ActionConfigFile, ArgumentParser
from PIL import Image
from scipy.stats import norm
import numpy as np


import os


# Deterministic preprocessing applied to every image before inference:
# resize, centre-crop to 224x224, convert to a tensor, then normalise each
# channel with the commonly used ImageNet mean/std values.
basic_data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def parse_opt():
    """Parse the command-line options (optionally from a ``--cfg`` config file).

    Every option has a default, so the script can be run with no arguments
    as shown in the module docstring. Options were previously also marked
    ``required=True``, which contradicted those defaults.

    Returns:
        The parsed namespace with attributes ``batch_size``, ``device``,
        ``data_dir``, ``csv_file``, ``model_filepath``, ``res_filepath``,
        ``lira_paramterspath`` and ``output_scorepath``.
    """
    parser = ArgumentParser()
    parser.add_argument("--cfg", action=ActionConfigFile)
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of data to process simultanously",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="device used for inference"
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/tmp/data/",
        help="Path to the data directory, where to download the dataset",
    )
    parser.add_argument(
        "--csv_file",
        type=str,
        default="images_family_infer.csv",
        help="Path to the csv file linking each image to a class outputed by the model",
    )
    parser.add_argument(
        "--model_filepath",
        type=str,
        default="resnet18_level2.ckpt",
        help="Path to the model checkpoint",
    )
    parser.add_argument(
        "--res_filepath",
        type=str,
        default="res2.csv",
        help="Path to the csv file output, with one pred per data",
    )
    # The flag name is misspelled ("paramters") but kept as-is so existing
    # command lines and config files keep working.
    parser.add_argument(
        "--lira_paramterspath",
        type=str,
        default="lira_parameters.csv",
        help="Path to the csv containing the provided LIRA parameters",
    )
    parser.add_argument(
        "--output_scorepath",
        type=str,
        default="lira_score_res2.csv",
        help="Path to the csv file outputed, with the LIRA score per data",
    )
    opt = parser.parse_args()
    return opt


def make_model(
    model_archi, model_init_weights=None, state_dict=None, nb_classes=50, V1=False
):
    """Build a torchvision model whose ``fc`` head outputs ``nb_classes`` logits.

    Args:
        model_archi: name of a torchvision model builder (lower-cased before lookup).
        model_init_weights: name of a torchvision weights enum, or None for
            random initialisation.
        state_dict: optional weights loaded after the head is replaced.
        nb_classes: output size of the new ``fc`` layer.
        V1: when True use ``IMAGENET1K_V1`` weights instead of ``DEFAULT``.
    """
    weights = None
    if model_init_weights is not None:
        weights_enum = getattr(torchvision_models, model_init_weights)
        weights = weights_enum.IMAGENET1K_V1 if V1 else weights_enum.DEFAULT

    builder = getattr(torchvision_models, model_archi.lower())
    model = builder(weights=weights)
    # Swap the classification head for one sized to our label set.
    model.fc = torch.nn.Linear(model.fc.in_features, nb_classes)

    if state_dict is not None:
        model.load_state_dict(state_dict)
    return model


class AircraftDataset(Dataset):
    """Map-style dataset yielding ``(image_id, image, label)`` triples.

    Args:
        df: DataFrame with an ``id`` column (image file stem, without
            ``.jpg``) and a ``class`` column (label returned unchanged).
        img_fold: directory containing the ``<id>.jpg`` files.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, df, img_fold, transform=None):
        self.image_ids = df["id"].values
        self.label = df["class"].values
        self.transform = transform
        self.root = img_fold

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = str(self.image_ids[index])
        img_path = os.path.join(self.root, image_id + ".jpg")
        # Close the source file deterministically (Image.open is lazy and
        # keeps the handle open); convert() loads the pixels into a new
        # image, so it stays valid after the file is closed.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return image_id, img, self.label[index]


def kth_best(row, k):
    """Return the k-th largest element of ``row`` (k=1 gives the maximum).

    ``row`` must provide ``tolist()`` (e.g. a torch tensor or NumPy array).
    Duplicates are counted individually, so ties are not collapsed.
    """
    return sorted(row.tolist())[-k]


def compute_hinge_loss(pred, label, logits):
    """Return the hinge loss of one sample: the label's logit minus a competitor.

    The competitor is the second-largest logit when ``pred == label`` and the
    largest logit otherwise. Assuming ``pred`` is the argmax of ``logits``,
    this is the label logit minus the best logit of any other class.

    Args:
        pred: predicted class index.
        label: true class index.
        logits: 1-D logits, indexable by ``label`` and exposing ``tolist()``.
    """
    ordered = sorted(logits.tolist())
    competitor = ordered[-2] if pred == label else ordered[-1]
    return logits[label] - competitor


def get_parameter(df_param, image_id):
    """Return ``(mean_in, mean_out, std_in, std_out)`` for ``image_id``.

    If several rows share the same ``image_id``, the first one is used.

    Args:
        df_param: DataFrame with ``image_id``, ``mean_in``, ``mean_out``,
            ``std_in`` and ``std_out`` columns.
        image_id: value compared against the ``image_id`` column.

    Raises:
        KeyError: if no row matches ``image_id``.
    """
    row = df_param[df_param.image_id == image_id]
    if row.empty:
        # Previously surfaced as an opaque IndexError from ``.values[0]``.
        raise KeyError(f"No LIRA parameters found for image_id {image_id!r}")
    return tuple(
        row[column].values[0]
        for column in ("mean_in", "mean_out", "std_in", "std_out")
    )


def lira_score(
    infer_res_path="res2.csv",
    parameters_path="lira_parameters.csv",
    output_path="lira_score.csv",
):
    """Score each image with a LiRA log-likelihood ratio and write it to CSV.

    For every row of the inference results, the hinge loss is evaluated under
    two Gaussians taken from the parameters file, ``N(mean_in, std_in)`` and
    ``N(mean_out, std_out)``. The per-row score is the difference of the two
    log-densities. An image's score is the mean of that difference over all
    of its rows.

    Args:
        infer_res_path: CSV with at least ``image_id`` and ``hingeloss`` columns.
        parameters_path: CSV with ``image_id``, ``mean_in``, ``mean_out``,
            ``std_in`` and ``std_out`` columns. If an ``image_id`` appears
            more than once, its first row is used.
        output_path: destination CSV with columns ``image_id`` and ``pred``,
            one row per image, in order of first appearance in the results.

    Raises:
        KeyError: if an image in the results has no row in the parameters file.
    """
    # image_id is read as text so ids with leading zeros are preserved.
    df = pd.read_csv(infer_res_path, converters={"image_id": str})
    df_parameters = pd.read_csv(parameters_path, converters={"image_id": str})

    param_cols = ["mean_in", "mean_out", "std_in", "std_out"]
    df_parameters = df_parameters.drop_duplicates(subset="image_id", keep="first")
    # One join replaces a boolean-mask scan of both tables per image, which
    # was quadratic in the number of images.
    merged = df.merge(
        df_parameters[["image_id", *param_cols]],
        on="image_id",
        how="left",
        indicator=True,
    )
    unmatched = merged["_merge"] == "left_only"
    if unmatched.any():
        missing = merged.loc[unmatched, "image_id"].unique().tolist()
        raise KeyError(f"No LIRA parameters for image_id(s): {missing}")

    hinge = merged["hingeloss"].to_numpy(dtype=float)
    log_ratio = norm.logpdf(
        hinge, merged["mean_in"].to_numpy(), merged["std_in"].to_numpy()
    ) - norm.logpdf(
        hinge, merged["mean_out"].to_numpy(), merged["std_out"].to_numpy()
    )

    # Aggregate with np.mean rather than GroupBy.mean so that a NaN term still
    # yields a NaN score instead of being silently skipped.
    df_res = (
        pd.DataFrame({"image_id": merged["image_id"], "pred": log_ratio})
        .groupby("image_id", sort=False)["pred"]
        .agg(lambda scores: np.mean(scores.to_numpy()))
        .reset_index()
    )
    df_res.to_csv(output_path, index=False)


def _load_model(model_filepath, device):
    """Rebuild the ResNet-18 classifier from a checkpoint, in eval mode on ``device``."""
    # NOTE(review): torch.load is pickle-based; only load trusted checkpoints.
    ckpt = torch.load(model_filepath, map_location=device)
    # Keep only entries whose first dotted component is "model" and strip it
    # (e.g. "model.fc.bias" -> "fc.bias").
    state_dict = {}
    for key, value in ckpt["state_dict"].items():
        prefix, _, rest = key.partition(".")
        if prefix == "model":
            state_dict[rest] = value
    model = make_model(
        model_archi="resnet18",
        model_init_weights=None,
        nb_classes=state_dict["fc.bias"].shape[0],
        state_dict=state_dict,
    )
    model.eval()
    model.to(device)
    return model


def make_inference(
    model_filepath, device, batch_size, img_folder, csv_images, res_filepath
):
    """Run the checkpointed classifier over the listed images and write a CSV.

    Args:
        model_filepath: checkpoint whose ``state_dict`` holds the weights under
            a ``model.`` prefix.
        device: torch device used for the forward passes.
        batch_size: number of images per forward pass.
        img_folder: directory containing ``<id>.jpg`` images.
        csv_images: CSV with ``id`` and ``class`` columns (read as strings;
            ``class`` must be convertible to ``int``).
        res_filepath: output CSV with header
            ``image_id,truth,pred,maxlogit,hingeloss``. It is overwritten.
    """
    model = _load_model(model_filepath, device)
    print(model)

    df_images = pd.read_csv(csv_images, dtype=str, keep_default_na=False)
    infer_dataset = AircraftDataset(
        df=df_images,
        img_fold=img_folder,
        transform=basic_data_transforms,
    )
    infer_loader = torch.utils.data.DataLoader(
        infer_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    # Open the output once rather than once per batch. no_grad avoids
    # building autograd graphs that inference never uses.
    with open(res_filepath, "w") as f, torch.no_grad():
        f.write("image_id,truth,pred,maxlogit,hingeloss\n")
        for images_ids, images, labels in tqdm(infer_loader, total=len(infer_loader)):
            logits = model(images.to(device)).to("cpu")
            maxlogits, preds = torch.max(logits, 1)
            lines = []
            for image_id, label, maxlog, pred, logit in zip(
                images_ids, labels, maxlogits, preds, logits
            ):
                hingeloss = compute_hinge_loss(pred.item(), int(label), logit)
                lines.append(f"{image_id},{label},{pred},{maxlog},{hingeloss}\n")
            f.write("".join(lines))


def _main():
    """Download the dataset, run inference, then apply the LIRA scoring."""
    args = parse_opt()

    download_dir = os.path.join(args.data_dir, "download")
    print(f"Download the data in {download_dir}")
    # Called for its side effect of making the dataset available under
    # download_dir; the images are then read directly from disk.
    datasets.FGVCAircraft(root=download_dir, split="trainval", download=True)
    jpg_folder = os.path.join(download_dir, "fgvc-aircraft-2013b/data/images/")

    print("making the inference with the model")
    make_inference(
        model_filepath=args.model_filepath,
        device=args.device,
        batch_size=args.batch_size,
        img_folder=jpg_folder,
        csv_images=args.csv_file,
        res_filepath=args.res_filepath,
    )
    print(f"Inference done, results written in {args.res_filepath}")

    print("applying LIRA")
    lira_score(
        infer_res_path=args.res_filepath,
        parameters_path=args.lira_paramterspath,
        output_path=args.output_scorepath,
    )
    print(f"LIRA performed, attack results written in {args.output_scorepath}")


if __name__ == "__main__":
    _main()
