# set directory -----------------------------------------------------------

# Remember the current warning level, then silence all warnings
# (warn = -1) for everything that follows.
# NOTE(review): `oldw` is not restored anywhere in the visible part of this
# script -- confirm that `options(warn = oldw)` is called at the end.
oldw <- getOption("warn")
options(warn = -1)

# Switch the working directory to the folder containing this script, so the
# relative paths used below ("../chunk_info_allfeats.rds", "en/results/")
# resolve from there.
# NOTE(review): the path comes from rstudioapi, so this presumably only works
# when the script is run from inside RStudio -- confirm before using Rscript.
setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
getwd()

# Directory and libraries -------------------------------------------------
library(tidyverse)
library(strex)
library(data.table)
library(summarytools)
library(descr)

library(mlr3)
library(mlr3learners)
library(DALEXtra)
library(xgboost)
library(mlr)
require(gridExtra)
library(cvms)
library(janitor)
library(yardstick)
library(caret)
require(Matrix)

# read english data -------------------------------------------------------

# Load the English chunk table and rename the reference-type columns to a
# uniform lower-case spelling. The reftype_3 entry maps the column onto
# itself, exactly as before.
english_chunk <- rename(
  read_rds("../chunk_info_allfeats.rds"),
  reftype_2 = refType_2,
  reftype_3 = reftype_3,
  reftype_4 = refType_4
)

# Split into partitions using the `type` column. Each partition keeps the
# same columns, selected by position: column 2 and columns 12 to 23.
train_en <- select(filter(english_chunk, type == "train"), 2, 12:23)

test_en <- select(filter(english_chunk, type == "test"), 2, 12:23)

dev_en <- select(filter(english_chunk, type == "dev"), 2, 12:23)

#***********************************
# functions
#***********************************
# predictions_3 <- function(model,test,test_set){ predict(model, test, type="prob") %>% 
#     as.data.frame() %>% 
#     mutate(class_prediction = colnames(.)[max.col(.)]) %>%
#     mutate(class_prediction= as.factor(class_prediction)) %>% 
#     mutate(original_value = test_en$reftype_3) %>% 
#     mutate(ID= test_en$ID)
# }

# Score `test` with `model` and return the class probabilities as a data
# frame with two extra columns:
#   class_prediction - name of the probability column with the largest value
#                      in each row (via max.col), stored as a factor
#   original_value   - the `reftype` column of `test`
predictions <- function(model, test) {
  class_probs <- as.data.frame(predict(model, test, type = "prob"))
  best_class <- colnames(class_probs)[max.col(class_probs)]
  mutate(
    class_probs,
    class_prediction = as.factor(best_class),
    original_value = test$reftype
  )
}

# Same as predictions(), but the true label is read from the `reftype_3`
# column of `test`.
#   class_prediction - name of the probability column with the largest value
#                      in each row (via max.col), stored as a factor
#   original_value   - the `reftype_3` column of `test`
predictions_3 <- function(model, test) {
  class_probs <- as.data.frame(predict(model, test, type = "prob"))
  best_class <- colnames(class_probs)[max.col(class_probs)]
  mutate(
    class_probs,
    class_prediction = as.factor(best_class),
    original_value = test$reftype_3
  )
}

# Variant of predictions() for the four-way task: the true label is read from
# the `reftype_4` column of `test`, and one extra row is appended at the end.
#   class_prediction - name of the probability column with the largest value
#                      in each row (via max.col), stored as a factor
#   original_value   - the `reftype_4` column of `test`
predictions_4 <- function(model, test) {
  class_probs <- as.data.frame(predict(model, test, type = "prob"))
  best_class <- colnames(class_probs)[max.col(class_probs)]
  scored <- mutate(
    class_probs,
    class_prediction = as.factor(best_class),
    original_value = test$reftype_4
  )
  # NOTE(review): this appends a made-up observation (true class "name",
  # predicted "demonstrative") that is not part of the test set. Presumably
  # it forces a "demonstrative" prediction column to exist in downstream
  # tables -- confirm, because it also ends up in the saved predictions and
  # in every metric computed from them.
  add_row(
    as.data.frame(scored),
    original_value = "name",
    class_prediction = "demonstrative"
  )
}

# Per-class and macro-averaged precision / recall / F1 for a prediction table.
#
# Args:
#   dt: data frame with columns `original_value` (true class) and
#       `class_prediction` (predicted class).
#
# Returns:
#   A data frame with one row per class and the columns
#     modelname      - the expression passed as `dt`, as text
#     diag           - correctly classified instances of the class
#     rowsums        - instances whose true class is this class
#     colsums        - instances predicted as this class
#     p, q           - rowsums / n and colsums / n
#     precision, recall, f1 - per-class scores, rounded to 5 digits
#     macroPrecision, macroRecall, macroF1 - unweighted means over classes
#   Scores are NaN where the denominator is zero (e.g. a class that was
#   never predicted); this is left to the caller to handle.
perclass_all <- function(dt) {
  modelname <- deparse(substitute(dt))

  # Build the confusion matrix over the union of labels seen in either
  # column. Without this, a class that is never predicted (or never observed)
  # yields a non-square or misaligned table and the diagonal no longer
  # corresponds to "correct" counts.
  class_levels <- function(x) {
    if (is.factor(x)) levels(x) else sort(unique(as.character(x)))
  }
  classes <- union(
    class_levels(dt$original_value),
    class_levels(dt$class_prediction)
  )
  cm <- as.matrix(table(
    Actual = factor(as.character(dt$original_value), levels = classes),
    Predicted = factor(as.character(dt$class_prediction), levels = classes)
  ))

  n <- sum(cm)                    # number of instances
  correct <- diag(cm)             # correctly classified instances per class
  rowsums <- apply(cm, 1, sum)    # number of instances per class
  colsums <- apply(cm, 2, sum)    # number of predictions per class
  p <- rowsums / n                # distribution over the actual classes
  q <- colsums / n                # distribution over the predicted classes

  precision <- round(correct / colsums, digits = 5)
  recall <- round(correct / rowsums, digits = 5)
  f1 <- round(2 * precision * recall / (precision + recall), digits = 5)

  macroPrecision <- round(mean(precision), digits = 5)
  macroRecall <- round(mean(recall), digits = 5)
  macroF1 <- round(mean(f1), digits = 5)

  data.frame(
    modelname,
    diag = correct,
    rowsums, colsums, p, q,
    precision, recall, f1,
    macroPrecision, macroRecall, macroF1
  )
}

#dt = data.frame(modelname,diag,rowsums, colsums,p,q, precision, recall, f1, macroPrecision, macroRecall, macroF1, avgAccuracy)

#***********************************
# xgboost model function
#***********************************
# Resampling settings passed as `trControl` to the caret::train() calls
# below: 5-fold cross-validation with per-iteration logging, saved
# predictions and class probabilities enabled.
ctrl <- trainControl(
  method = "cv",
  number = 5,
  verboseIter = TRUE,
  # savePredictions = "final",
  savePredictions = TRUE,
  classProbs = TRUE
)


# Parameter grid passed as `tuneGrid` to the caret::train() calls below.
# Every parameter has exactly one value, so the grid has a single row.
tune_grid <- expand.grid(
  nrounds = 1000,
  max_depth = 5,
  eta = 0.05,
  gamma = 0.01,
  colsample_bytree = 0.75,
  min_child_weight = 0,
  subsample = 0.5
)

# Train one "xgbTree" classifier with caret, score it on `test`, and write
# all evaluation artefacts to disk.
#
# Args:
#   train: data frame whose `reftype` column is the outcome; every other
#          column is used as a predictor (formula `reftype ~ .`).
#   test:  data frame scored with the fitted model; its `reftype` column is
#          used as the reference label.
#   model: string used as the file-name prefix under "en/results/".
#
# Side effects:
#   Creates "en/results/" if it does not exist and writes
#   <model>_pred.txt/.rds, <model>_perclass.txt/.rds,
#   <model>_overall.txt/.rds, <model>_confmat.txt and <model>_confmat.jpg.
#   Also prints the overall statistics to the console.
#
# Returns:
#   The value of the final ggsave() call.
#
# Relies on the file-level objects `ctrl`, `tune_grid`, `predictions()` and
# `perclass_all()`.
xgboost <- function(train, test, model) {
  out_dir <- "en/results/"
  # Writing fails if the directory is missing; creating it is a no-op
  # when it already exists.
  dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)
  out_file <- function(suffix) paste0(out_dir, model, suffix)

  set.seed(200)
  # NOTE(review): tuneLength is presumably ignored because an explicit
  # tuneGrid is supplied -- kept as in the original call.
  fit <- caret::train(reftype ~ ., data = train, method = "xgbTree",
                      trControl = ctrl,
                      tuneGrid = tune_grid,
                      tuneLength = 10)

  # Keep the name `pred`: perclass_all() records the argument expression in
  # its `modelname` column.
  pred <- predictions(fit, test) %>% as.data.frame()
  # `file =` replaces readr's deprecated `path =` argument.
  write_delim(pred, file = out_file("_pred.txt"), delim = "\t")
  write_rds(pred, file = out_file("_pred.rds"))

  stat <- perclass_all(pred)
  write_delim(stat, file = out_file("_perclass.txt"), delim = "\t")
  write_rds(stat, file = out_file("_perclass.rds"))

  cm <- caret::confusionMatrix(data = pred$class_prediction,
                               reference = pred$original_value,
                               mode = "everything")
  # print() echoes the statistics and passes them on; the resulting value
  # column is named "." and is renamed to `model`.
  overall <- print(cm$overall) %>%
    as.data.frame() %>%
    rownames_to_column(.) %>%
    rename(model = '.')
  write_delim(overall, file = out_file("_overall.txt"), delim = "\t")
  write_rds(overall, file = out_file("_overall.rds"))

  confmat <- as.data.frame.matrix(cm$table)
  write_delim(confmat, file = out_file("_confmat.txt"), delim = "\t")

  set.seed(123)
  truth_predicted <- data.frame(
    obs = pred$original_value,
    pred = pred$class_prediction
  )
  truth_predicted$obs <- as.factor(truth_predicted$obs)
  truth_predicted$pred <- as.factor(truth_predicted$pred)

  # Heat-map rendering of the confusion matrix.
  cm_plt <- conf_mat(truth_predicted, obs, pred)
  p <- autoplot(cm_plt, type = "heatmap") +
    scale_fill_gradient(low = "#D6EAF8", high = "#2E86C1")
  ggsave(filename = out_file("_confmat.jpg"), plot = p)
}

#***********************
# 2 way
#***********************************
# Show the columns available in the training partition.
colnames(train_en)

# Outcome column (reftype_2) followed by the predictors for the two-way model.
two_way <- c(
  "reftype_2", "Syn", "DisStat", "SenStat", "DistAnt",
  "IntRef", "LocPro", "GloPro", "entity"
)

# Keep only those columns, turn character columns into factors, and give the
# outcome the generic name `reftype` expected by xgboost().
two_way_train <- train_en %>%
  select(all_of(two_way)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  rename(reftype = reftype_2)

two_way_test <- test_en %>%
  select(all_of(two_way)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  rename(reftype = reftype_2)

xgboost(two_way_train, two_way_test, "en_two_way")

#***********************
# 3 way
#***********************************
# Show the columns available in the training partition.
colnames(train_en)

# Outcome column (reftype_3) followed by the predictors for the three-way
# model.
three_way <- c(
  "reftype_3", "Syn", "DisStat", "SenStat", "DistAnt",
  "IntRef", "LocPro", "GloPro", "entity"
)

# Keep only those columns, turn character columns into factors, and give the
# outcome the generic name `reftype` expected by xgboost().
three_way_train <- train_en %>%
  select(all_of(three_way)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  rename(reftype = reftype_3)

three_way_test <- test_en %>%
  select(all_of(three_way)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  rename(reftype = reftype_3)

xgboost(three_way_train, three_way_test, "en_three_way")


#***********************
# 4 way oversample
#***********************************
# Show the columns available in the training partition.
colnames(train_en)

# Outcome column (reftype_4) followed by the predictors for the four-way
# model.
four_way <- c(
  "reftype_4", "Syn", "DisStat", "SenStat", "DistAnt",
  "IntRef", "LocPro", "GloPro", "entity"
)

four_way_train <- train_en %>%
  select(all_of(four_way)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  rename(reftype = reftype_4)

# Up-sample the training data with caret::upSample(). This step was
# commented out, which left `four_way_train_up` undefined and made the
# xgboost() call below fail in a fresh session.
# upSample() appends the outcome as a column named "Class"; it is renamed
# back to `reftype` and moved to the front.
# NOTE(review): assumes `reftype` is (or converts cleanly to) a factor, as
# upSample() requires -- confirm against the data.
set.seed(234)
four_way_train_up <- upSample(
  x = as.data.frame(four_way_train[, -1]),
  y = as.factor(four_way_train$reftype)
) %>%
  rename(reftype = Class) %>%
  select(reftype, everything())

four_way_test <- test_en %>%
  select(all_of(four_way)) %>%
  mutate(across(where(is.character), as.factor)) %>%
  rename(reftype = reftype_4)

# Use a dedicated prefix: the previous "en_three_way" prefix would overwrite
# the result files of the three-way model.
xgboost(four_way_train_up, four_way_test, "en_four_way_up")

#***********************
# 4 way
#***********************************
# Show the columns available in the training partition.
colnames(train_en)

# Outcome column (reftype_4) followed by the predictors for the four-way
# model. Unlike the other runs, the outcome keeps its name `reftype_4` here
# because predictions_4() reads that column.
four_way <- c(
  "reftype_4", "Syn", "DisStat", "SenStat", "DistAnt",
  "IntRef", "LocPro", "GloPro", "entity"
)

four_way_train <- train_en %>%
  select(all_of(four_way)) %>%
  mutate(across(where(is.character), as.factor))
four_way_test <- test_en %>%
  select(all_of(four_way)) %>%
  mutate(across(where(is.character), as.factor))

set.seed(200)
xgboost_4 <- caret::train(reftype_4 ~ ., data = four_way_train,
                          method = "xgbTree",
                          trControl = ctrl,
                          tuneGrid = tune_grid,
                          tuneLength = 10)

# Keep the name `pred`: perclass_all() records the argument expression in
# its `modelname` column.
pred <- predictions_4(xgboost_4, four_way_test) %>% as.data.frame()

# `file =` replaces readr's deprecated `path =` argument.
write_delim(pred, file = "en/results/en_four_way_pred.txt", delim = "\t")
write_rds(pred, file = "en/results/en_four_way_pred.rds")

stat <- perclass_all(pred)

# Classes with an undefined F1 (NaN from a zero denominator) count as 0, and
# an undefined macro F1 is recomputed from the patched per-class values.
stat <- stat %>%
  as.data.frame() %>%
  mutate(f1 = ifelse(is.nan(f1), 0, f1)) %>%
  mutate(macroF1 = ifelse(is.nan(macroF1), mean(f1), macroF1))

# The previous paste() call used the default " " separator and produced the
# file name "en/results/en_four_way _perclass .txt" with embedded spaces.
write_delim(stat, file = "en/results/en_four_way_perclass.txt", delim = "\t")

