#!/usr/bin/env Rscript

# Copyright (C) 2025 Université de Reims Champagne-Ardenne.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     (1) Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#
#     (2) Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#
#     (3)The name of the author may not be used to
#     endorse or promote products derived from this software without
#     specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
# IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

## Bind the magrittr pipe locally instead of attaching all of magrittr.
`%>%` <- magrittr::`%>%`

## Report warnings as soon as they are raised.
options (warn = 1)

## Index of the input files: a header-less CSV whose three columns (file
## name, type IRI, annotator) are all read as character strings.
all_files <- readr::read_csv (
                 "files.csv",
                 col_names = c ("filename", "type", "annotator"),
                 col_types = rep (list (readr::col_character ()), 3))

## Only FSP files are analysed; keep their file name and annotator.
fsps <- dplyr::select (
            dplyr::filter (all_files, type == "https://neonatool.github.io/adftool-v2#FSP"),
            filename, annotator)

## TODO: run the PCA separately for each annotator.

## For each FSP file listed in `fsps`: load its signal, align it so that
## its central minimum sits at index 128 of a 256-sample window, and
## compute per-file shape features.  The result has 256 rows per file:
## the per-file features repeated on every row, plus `time` (1..256) and
## `signal` (the aligned values, NA where the window extends past the
## recording).
shapes <- do.call (dplyr::bind_rows,
                   lapply (seq_len (nrow (fsps)), function (i) {
                       filename <- fsps$filename[i]
                       annotator <- fsps$annotator[i]
                       ## Header-less CSV of (time, value) pairs.  After
                       ## sorting by time, `time` is replaced by the sample
                       ## index 1..n, so every later `time` is an index.
                       whole_thing <- (readr::read_csv (filename,
                                                        col_names = c ("time", "value"),
                                                        col_types = list (readr::col_double (),
                                                                          readr::col_double ()))
                           %>% dplyr::arrange (time)
                           %>% dplyr::mutate (time = seq_len (dplyr::n ())))
                       ## Alignment point: the index of the lowest value
                       ## within the central half (25%-75% quantiles) of the
                       ## sample indices.
                       center <- (whole_thing
                           %>% dplyr::filter (time >= quantile (whole_thing$time, 0.25))
                           %>% dplyr::filter (time <= quantile (whole_thing$time, 0.75))
                           %>% dplyr::arrange (value))$time[1]
                       ## Move sample `center` to position 128: pad the left
                       ## with NA when it comes earlier than 128, otherwise
                       ## drop the first samples; then pad the right with NA
                       ## and truncate to exactly 256 samples.
                       left_padding <- round (128 - center)
                       centered <- numeric (256)
                       if (left_padding >= 0) {
                           relevant <- c (rep (NA, left_padding),
                                          whole_thing$value)
                           covered <- c (relevant, rep (NA, max (0, 256 - length (relevant))))
                           centered[] <- covered[1:256]
                       } else {
                           discard <- -left_padding
                           relevant <- whole_thing$value[(discard + 1):(nrow (whole_thing))]
                           covered <- c (relevant,
                                         rep (NA, (max (0, 256 - length (relevant)))))
                           centered[] <- covered[1:256]
                       }
                       stopifnot (length (centered) == 256)
                       stopifnot (is.finite (centered[128]))
                       ## Express the signal relative to its value at the
                       ## alignment point, so that centered[128] == 0.
                       centered[] <- centered[] - centered[128]
                       ## Distances (in samples) from the alignment point to
                       ## the maximum on each side (length_*), and from each
                       ## maximum to the corresponding window edge (slope_*).
                       ## which.max() ignores NA.
                       ## NOTE(review): if indices 129..256 are all NA,
                       ## which.max() returns integer(0) and the tibble below
                       ## fails — confirm this cannot happen.
                       length_before <- 128 - which.max(centered[1:128])
                       length_after <- which.max(centered[129:256])
                       slope_before <- which.max(centered[1:128])
                       slope_after <- 128 - which.max(centered[129:256])
                       ## NOTE(review): the divisions by 256 presumably turn
                       ## sample counts into seconds at 256 Hz — confirm.
                       ## NOTE(review): `peak-to-peak amplitude (µV)` is
                       ## max(value) of the raw signal, not max - min — confirm
                       ## this is intended.
                       ## Both "max amplitude" windows include the alignment
                       ## point (index 128, value 0).
                       (tibble::tibble (filename = filename,
                                        annotator = annotator,
                                        duration = (max (whole_thing$time) - min (whole_thing$time)) / 256.0,
                                        `peak-to-peak amplitude (µV)` = max (whole_thing$value),
                                        `max amplitude before (µV)` = max (centered[1:128], na.rm = TRUE),
                                        `max amplitude after (µV)` = max (centered[128:256], na.rm = TRUE),
                                        `length of the spike (before) (s)` = length_before / 256,
                                        `length of the spike (after) (s)` = length_after / 256,
                                        `length of the initial slope (s)` = slope_before / 256,
                                        `length of the final slope (s)` = slope_after / 256,
                                        `relaxation ratio` = length_after / length_before,
                                        time = 1:256,
                                        signal = centered))
                   }))

## Return the single value shared by every element of `data`.
##
## Used to collapse per-file columns that are repeated on each row of a
## group.  Fails (via stopifnot) when `data` is empty or when its elements
## are not all equal.  A single-element vector is valid and returns that
## element; the previous `data[2:length (data)]` indexing turned into
## c (2, 1) for length 1 and spuriously failed.
the <- function (data) {
    stopifnot (length (data) > 0, all (data == data[1]))
    data[1]
}

## Histogram of the per-file FSP durations, in milliseconds.
ggplot2::ggsave ("duration.png", local ({
    per_file <- dplyr::mutate (
                    dplyr::summarize (dplyr::group_by (shapes, filename),
                                      duration = the (duration)),
                    `duration (ms)` = 1000 * duration)
    (ggplot2::ggplot (per_file, ggplot2::aes (x = `duration (ms)`))
        + ggplot2::geom_histogram (breaks = seq (0, 1000, by = 50))
        + ggplot2::ggtitle ("Duration of the FSP"))
}))

## Histogram of the per-file FSP amplitudes, in µV.
ggplot2::ggsave ("amplitude.png", local ({
    per_file <- dplyr::summarize (dplyr::group_by (shapes, filename),
                                  `amplitude (µV)` = the (`peak-to-peak amplitude (µV)`))
    (ggplot2::ggplot (per_file, ggplot2::aes (x = `amplitude (µV)`))
        + ggplot2::geom_histogram (breaks = seq (0, 250, by = 10))
        + ggplot2::ggtitle ("Minimum-to-max amplitude of the FSP"))
}))

## One row per retained FSP file, holding:
## - its aligned signal as a 256 x 1 matrix in a list column;
## - 0/1 indicators for each annotator and their complements;
## - the per-file shape features;
## - whether the larger maximum lies before or after the alignment point.
## A file is retained only when its amplitude lies in (50, 250) and its
## duration in (0.250, 1.000).
useful_shapes <- local ({
    in_range <- dplyr::filter (shapes,
                               `peak-to-peak amplitude (µV)` > 50,
                               duration > 0.250,
                               `peak-to-peak amplitude (µV)` < 250,
                               duration < 1.000)
    ordered <- dplyr::arrange (dplyr::group_by (in_range, filename), time)
    per_file <- dplyr::summarize (
                    ordered,
                    signal = list (t (t (signal))),
                    `by NK` = as.numeric (the (annotator) == "NK"),
                    `by AG` = as.numeric (the (annotator) == "AG"),
                    `by GL` = as.numeric (the (annotator) == "GL"),
                    `peak-to-peak amplitude (µV)` = the (`peak-to-peak amplitude (µV)`),
                    `max amplitude before (µV)` = the (`max amplitude before (µV)`),
                    `max amplitude after (µV)` = the (`max amplitude after (µV)`),
                    `length of the spike (before) (s)` = the (`length of the spike (before) (s)`),
                    `length of the spike (after) (s)` = the (`length of the spike (after) (s)`),
                    `length of the initial slope (s)` = the (`length of the initial slope (s)`),
                    `length of the final slope (s)` = the (`length of the final slope (s)`),
                    `relaxation ratio` = the (`relaxation ratio`),
                    annotator = the (annotator))
    dplyr::mutate (per_file,
                   `not by NK` = 1 - `by NK`,
                   `not by AG` = 1 - `by AG`,
                   `not by GL` = 1 - `by GL`,
                   `max before spike` = as.numeric (`max amplitude before (µV)` >= `max amplitude after (µV)`),
                   `max after spike` = 1 - `max before spike`)
})

## Aligned signals as a matrix: one row per retained file, 256 columns.
shapes_data <- t (do.call (cbind, useful_shapes$signal))

## Fill the NA padding with the mean of the corresponding time column.
na_mask <- is.na (shapes_data)
column_means <- colMeans (shapes_data, na.rm = TRUE)
shapes_data_filled <- shapes_data
shapes_data_filled[na_mask] <- column_means[col (shapes_data)[na_mask]]

## Extra (non-shape) variables, each standardized to zero mean and unit
## standard deviation.  Every column goes through the same helper.  The
## previous hand-written version had `(x - mean (x) / sd (x))` for the
## length and relaxation-ratio columns, which left them unstandardized.
extra_variable_names <- c ("by NK", "by AG", "by GL",
                           "not by NK", "not by AG", "not by GL",
                           "peak-to-peak amplitude (µV)",
                           "max amplitude before (µV)",
                           "max amplitude after (µV)",
                           "max before spike", "max after spike",
                           "length of the spike (before) (s)",
                           "length of the spike (after) (s)",
                           "length of the initial slope (s)",
                           "length of the final slope (s)",
                           "relaxation ratio")
standardize <- function (v) (v - mean (v)) / sd (v)
extra_variables <- as.matrix (useful_shapes[extra_variable_names])
extra_variables[] <- apply (extra_variables, 2, standardize)

stopifnot (ncol (shapes_data) == 256)
stopifnot (ncol (shapes_data_filled) == 256)
stopifnot (ncol (extra_variables) == 16)
stopifnot (nrow (shapes_data) == nrow (extra_variables))
stopifnot (nrow (shapes_data_filled) == nrow (extra_variables))

## Overlay every aligned signal (NA padding dropped) as a translucent line,
## with time scaled by 1/256.
ggplot2::ggsave ("shapes.png", local ({
    long_shapes <- dplyr::bind_rows (lapply (seq_len (nrow (shapes_data)), function (row) {
        tibble::tibble (individual = row,
                        time = seq_len (256) / 256,
                        signal = shapes_data[row, ])
    }))
    (ggplot2::ggplot (dplyr::filter (long_shapes, !is.na (signal)),
                      ggplot2::aes (x = time, y = signal, group = individual))
        + ggplot2::geom_line (alpha = 0.1)
        + ggplot2::ggtitle ("Shapes"))
}))

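## prcomp centers (and here scales) each time column on its own, ignoring
## the continuity between neighbouring time steps. That cannot work here:
## feature 128 is constantly 0 (every signal is shifted so that its
## alignment point is 0), so it cannot be scaled to unit variance. We
## therefore first center and scale each individual along the time dimension.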

## Standardize each individual (row) over time: subtract its own mean and
## divide by its own standard deviation.  The per-row vectors recycle down
## the columns, so element [i, j] uses the statistics of row i.
row_means <- apply (shapes_data_filled, 1, mean)
row_sds <- apply (shapes_data_filled, 1, sd)
shapes_data_centered <- (shapes_data_filled - row_means) / row_sds

## PCA of the time-standardized signals.  Each time column is additionally
## centered and scaled to unit variance, and 128 components are kept.
##
## acp has:
##
## - acp$sdev: standard deviations of all principal components (square
##   roots of the eigenvalues of the correlation matrix), not only the 128
##   retained ones;
## - acp$rotation: loadings used to project an individual (256 by 128);
## - acp$center, acp$scale: the per-column affine transformation applied
##   after time standardization and before projection;
## - acp$x: coordinates of the individuals (N by 128).
acp <- prcomp (shapes_data_centered, retx = TRUE, center = TRUE, scale. = TRUE, rank. = 128)

## The sign of each eigenvector is arbitrary.  Flip every component whose
## loadings are more often negative than positive, together with the
## matching coordinates.
component_signs <- ifelse (colSums (sign (acp$rotation)) < 0, -1, 1)
acp$rotation <- sweep (acp$rotation, 2, component_signs, "*")
acp$x <- sweep (acp$x, 2, component_signs, "*")

## With scale. = TRUE the eigenvalues are those of the 256 x 256
## correlation matrix, whose trace (the total inertia) is 256.
expected_total_inertia <- 256
actual_total_inertia <- sum (acp$sdev ^ 2)
stopifnot (abs (expected_total_inertia - actual_total_inertia) / expected_total_inertia < 1e-8)

## Share of the total inertia carried by each principal axis.  prcomp
## returns one standard deviation per axis, up to min (N, 256) of them.
## The axis index therefore follows length (acp$sdev) instead of assuming
## 256, which made this tibble fail when N < 256.
ggplot2::ggsave (
             "eigenvalues.png",
             ggplot2::ggplot (tibble::tibble (axis = seq_along (acp$sdev),
                                              eigenvalue = acp$sdev ^ 2)
                              %>% dplyr::mutate (weight = eigenvalue / sum (eigenvalue)),
                              ggplot2::aes (x = axis, y = weight))
             + ggplot2::geom_line ()
             + ggplot2::scale_x_continuous (trans = 'log10')
             + ggplot2::scale_y_continuous (labels = scales::percent_format ())
             + ggplot2::ggtitle ("Weight of the different axes in the eigenvalue decomposition of the correlation matrix"))

## For each time step, compute the pointwise mean and standard deviation of
## the aligned signals (ignoring the NA padding) and the number of signals
## that actually cover that step.
average_shape <- tibble::tibble (time = seq_len (ncol (shapes_data)),
                                 `amplitude (µV)` = apply (shapes_data, 2, function (data) mean (data, na.rm = TRUE)),
                                 sd = apply (shapes_data, 2, function (data) sd (data, na.rm = TRUE)),
                                 support = apply (shapes_data, 2, function (data) sum (!is.na (data))))

## Average shape with a +/- 1 sd band; the line transparency follows support.
ggplot2::ggsave (
             "average_shape.png",
             ggplot2::ggplot (average_shape %>% dplyr::mutate (time = time / 256, `-1 sd` = `amplitude (µV)` - sd, `+1 sd` = `amplitude (µV)` + sd),
                              ggplot2::aes (x = time, y = `amplitude (µV)`, ymin = `-1 sd`, ymax = `+1 sd`, alpha = support))
             + ggplot2::geom_line ()
             + ggplot2::geom_ribbon (alpha = 0.1)
             + ggplot2::ggtitle ("Average shape"))

## Long table of the retained components' loadings: one row per
## (component, time step).  `weight` is the component's eigenvalue divided
## by 256, the total inertia of the correlation matrix.
## NOTE(review): `shape` multiplies the whole loading vector of component j
## by acp$center[j], which is the centering constant of *time column* j,
## not a per-component quantity.  This looks unintended (perhaps the
## element-wise acp$scale, or an offset by acp$center, was meant) — confirm.
eigenshapes <- dplyr::bind_rows (lapply (seq_len (ncol (acp$x)), function (component) {
    tibble::tibble (eigenshape = component,
                    weight = acp$sdev[component] ^ 2 / 256,
                    time = seq_len (nrow (acp$rotation)),
                    shape = acp$rotation[, component] * acp$center[component])
}))

## Plot the first nine components against time scaled by 1/256, with the
## transparency following each component's weight.
ggplot2::ggsave ("eigenshapes.png", local ({
    first_components <- (eigenshapes
        %>% dplyr::filter (eigenshape < 10)
        %>% dplyr::arrange (eigenshape, time)
        %>% dplyr::mutate (time = time / 256))
    (ggplot2::ggplot (first_components,
                      ggplot2::aes (x = time, y = shape, alpha = weight, group = eigenshape))
        + ggplot2::geom_line ()
        + ggplot2::ggtitle ("Eigenshapes (around the average shape)"))
}))

## Squared coordinates: inertia[i, k] is the part of individual i's inertia
## carried by axis k (rows = individuals, columns = retained axes).
inertia <- acp$x ^ 2
individual_inertia <- rowSums (inertia)

## Inertia of each retained axis, on the N * eigenvalue scale.  prcomp
## uses the N - 1 variance denominator, so colSums (inertia) equals
## (N - 1) * eigenvalue.  Hence the first bound below always holds, and the
## 0.95 lower bound holds only when N > 20.
axis_inertia <- (nrow (acp$x) * acp$sdev ^ 2)[seq_len (ncol (inertia))]
check_axis_inertia <- colSums (inertia)
stopifnot (all (check_axis_inertia <= axis_inertia))
stopifnot (all (check_axis_inertia > 0.95 * axis_inertia))

## Fraction of the individuals needed to account for at least `threshold`
## of the inertia of axis `axis`, taking the largest contributors first.
##
## axis:      index of a retained principal axis (a column of `inertia`).
## threshold: target share of the axis inertia, in (0, 1].
## Returns a value in (0, 1].
##
## The previous version read the row `inertia[axis, ]` (an individual, not
## an axis).  It also counted only the contributors strictly below the
## threshold, one fewer than the smallest group that reaches it.
check_contributions <- function (axis, threshold) {
    contributions <- inertia[, axis]
    cumulative_share <- cumsum (sort (contributions, decreasing = TRUE)) / sum (contributions)
    needed <- min (length (contributions), sum (cumulative_share < threshold) + 1)
    needed / length (contributions)
}

## One row per retained axis (not per individual).
axes_health <- dplyr::bind_rows (lapply (seq_len (ncol (inertia)), function (axis) {
    tibble::tibble (axis = axis,
                    contributing_individuals = check_contributions (axis, 0.95))
}))

ggplot2::ggsave (
             "axis_health.png",
             ggplot2::ggplot (axes_health, ggplot2::aes (x = axis, y = contributing_individuals))
             + ggplot2::geom_line ()
             + ggplot2::scale_y_continuous (labels = scales::percent_format ())
             + ggplot2::ggtitle ("Relative size of the smallest subgroup of individuals that make 95% of each axis inertia"))

## Accumulator of the individuals that plot_pane() finds both well
## represented and outstanding in at least one plane.
interesting_individuals <- tibble::tibble (index = integer (0))

## The first n_pcs columns of `dataset` are the principal-component scores.
## The following ones are the standardized extra variables.
n_pcs <- ncol (inertia)
dataset <- dplyr::bind_cols (tibble::as_tibble (acp$x), tibble::as_tibble (extra_variables))

correlation_matrix <- cor (as.matrix (dataset))

## Negative correlations with the annotator indicators are discarded.  Each
## indicator also appears together with its complement (`not by ...`).
negative_correlation_meaningless <- colnames (correlation_matrix) %in% c ("by GL", "not by GL", "by NK", "not by NK", "by AG", "not by AG")

## Correlations between each of the first 24 principal components (at
## most n_pcs) and each extra variable, sorted by decreasing magnitude.
correlations <- local ({
    all_pairs <- dplyr::bind_rows (lapply (seq_len (ncol (correlation_matrix)), function (row) {
        tibble::tibble (variable = colnames (correlation_matrix)[row],
                        variable_index = row,
                        other_variable = colnames (correlation_matrix),
                        other_variable_index = seq_len (ncol (correlation_matrix)),
                        correlation = correlation_matrix[row, ])
    }))
    (all_pairs
        %>% dplyr::filter ((! negative_correlation_meaningless[other_variable_index] | correlation > 0)
                           & variable_index < other_variable_index
                           & variable_index <= n_pcs
                           & variable_index <= 24
                           & other_variable_index > n_pcs)
        %>% dplyr::select (variable, other_variable, correlation)
        %>% dplyr::arrange (- abs (correlation)))
})

## Scatter plot of the individuals in the plane of principal axes `axis_1`
## and `axis_2`.
##
## Only individuals whose quality of representation in the plane (squared
## cosine) exceeds `quality_threshold` are drawn.  Individuals whose
## contribution to the plane inertia exceeds `contribution_threshold` get
## a different point shape and a "<index>—<contribution %>" label.
##
## Side effect: appends to the global `interesting_individuals` the indices
## of the individuals that are both well represented and outstanding.
## Reads the globals acp, inertia, individual_inertia, axis_inertia and
## useful_shapes.  Returns a ggplot object.
plot_pane <- function (axis_1, axis_2, contribution_threshold = 0.02, quality_threshold = 0.25) {
    x <- acp$x[, axis_1]
    y <- acp$x[, axis_2]
    n_individuals <- nrow (inertia)
    represented_inertia <- x ^ 2 + y ^ 2
    represented_quality <- represented_inertia / individual_inertia
    annotator <- useful_shapes$annotator
    plane_inertia <- axis_inertia[axis_1] + axis_inertia[axis_2]
    pane_contribution <- represented_inertia / plane_inertia
    outstanding <- pane_contribution > contribution_threshold
    sufficient <- represented_quality > quality_threshold
    outstanding_contribution <- ifelse (outstanding, "HIGH", "LOW")
    sufficient_quality <- which (sufficient)
    ## Total inertia of the cloud over *all* axes, on the same N * eigenvalue
    ## scale as axis_inertia.  Summing axis_inertia covered only the axes
    ## kept by prcomp's rank. and overstated the plane's share.
    total_inertia <- nrow (acp$x) * sum (acp$sdev ^ 2)
    interesting_individuals <<- dplyr::distinct (
        dplyr::bind_rows (interesting_individuals,
                          tibble::tibble (index = which (outstanding & sufficient))))
    point_labels <- ifelse (outstanding_contribution[sufficient_quality] == "HIGH",
                            sprintf ("%d—%.1f%%", sufficient_quality, 100 * pane_contribution[sufficient_quality]),
                            "")
    (ggplot2::ggplot (tibble::tibble (x = x[sufficient_quality],
                                      y = y[sufficient_quality],
                                      `individual\ncontribution\n(plane)` = outstanding_contribution[sufficient_quality],
                                      `outstanding contribution value` = point_labels,
                                      annotator = annotator[sufficient_quality]),
                      ggplot2::aes (x = x,
                                    y = y,
                                    color = annotator,
                                    shape = `individual\ncontribution\n(plane)`,
                                    label = `outstanding contribution value`))
        + ggplot2::geom_point ()
        + ggrepel::geom_label_repel (color = "black")
        + ggplot2::scale_shape_manual (limits = c ("LOW", "HIGH"),
                                       values = c ("LOW" = 3, "HIGH" = 2))
        + ggplot2::scale_color_discrete (limits = c ("AG", "GL", "NK"))
        + ggplot2::ggtitle (sprintf ("[%d, %d] plane of individuals, quality at least %.1f%%", axis_1, axis_2, 100 * quality_threshold),
                            subtitle = sprintf ("%.1f%% of the total inertia in the plane\n%.1f%% of the total individuals represented", 100 * (plane_inertia / total_inertia), 100 * (length (sufficient_quality) / n_individuals))))
}

## Correlation circle for the principal axes `axis_1` and `axis_2`.
##
## Each variable is drawn as a segment from the origin to its correlations
## with the two axes.  Only the two selected principal components and the
## extra variables are kept, and only when their correlation vector is
## longer than 0.2.  The unit circle is drawn once, from its own one-row
## data with inherit.aes = FALSE.  Previously it inherited every variable
## row and aesthetic, so one circle was drawn per variable.
## Reads the globals correlation_matrix and n_pcs.  Returns a ggplot object.
variables_plane <- function (axis_1, axis_2) {
    indicator_names <- c ("by GL", "not by GL", "by AG", "not by AG", "by NK", "not by NK")
    variables <- (tibble::tibble (x = c (correlation_matrix[axis_1,]),
                                  y = c (correlation_matrix[axis_2,]),
                                  variable = colnames (correlation_matrix))
        %>% dplyr::mutate (variable_index = seq_len (dplyr::n ()),
                           quality = sqrt (x ^ 2 + y ^ 2))
        %>% dplyr::mutate (variable = ifelse (variable_index <= n_pcs,
                                              sprintf ("PC %d", variable_index),
                                              variable),
                           `variable type` = ifelse (variable_index <= n_pcs,
                                                     "principal component",
                                                     ifelse (variable %in% indicator_names,
                                                             "annotator indicator",
                                                             "shape feature")))
        %>% dplyr::filter (variable_index == axis_1 | variable_index == axis_2 | variable_index > n_pcs)
        %>% dplyr::filter (quality > 0.2))
    unit_circle <- tibble::tibble (x0 = 0, y0 = 0, r = 1)
    (ggplot2::ggplot (variables,
                      ggplot2::aes (x = x, y = y, xend = 0, yend = 0, label = variable, color = `variable type`))
        + ggforce::geom_circle (ggplot2::aes (x0 = x0, y0 = y0, r = r),
                                data = unit_circle, color = "black", inherit.aes = FALSE)
        + ggplot2::geom_segment ()
        + ggplot2::coord_fixed ()
        + ggrepel::geom_label_repel (color = "black")
        + ggplot2::scale_color_discrete (limits = c ("principal component", "annotator indicator", "shape feature"))
        + ggplot2::ggtitle (sprintf ("[%d, %d] plane of variables", axis_1, axis_2)))
}

## Save the individuals plot and the variables plot for the planes (1, 2),
## (3, 4) and (5, 6).  plot_pane() also records the outstanding
## individuals of each plane in `interesting_individuals`.
for (i in seq (1, 6, by = 2)) {
    ggplot2::ggsave (sprintf ("plane_%d_%d.png", i, i + 1), plot_pane (i, i + 1))
    ggplot2::ggsave (sprintf ("variables_%d_%d.png", i, i + 1), variables_plane (i, i + 1))
}

## Overlay the aligned signals of the individuals recorded above.  With no
## such individual the plot data would have no `time`/`signal` columns and
## ggplot would fail, so the figure is skipped in that case.
if (nrow (interesting_individuals) > 0) {
    ggplot2::ggsave (
                 "outstanding_contributions.png",
                 ggplot2::ggplot (do.call (dplyr::bind_rows,
                                           lapply (interesting_individuals$index, function (i) {
                                               (tibble::tibble (
                                                           individual =
                                                               sprintf ("%d by %s",
                                                                        i,
                                                                        useful_shapes$annotator[i]),
                                                           time = seq_len (ncol (shapes_data)) / 256,
                                                           signal = shapes_data[i,],
                                                           annotator = useful_shapes$annotator[i])
                                                   %>% dplyr::filter (!is.na (signal)))
                                           })),
                                  ggplot2::aes (x = time, y = signal, color = individual))
                 + ggplot2::geom_line ()
                 + ggplot2::ggtitle ("Shapes with outstanding contributions"))
} else {
    message ("No individual with an outstanding contribution: outstanding_contributions.png not written")
}

## Clear the accumulator filled by plot_pane().
interesting_individuals <- tibble::tibble (index = integer (0))