## -----------------------------------------------------------------------------
## The model cost and residuals
## -----------------------------------------------------------------------------
# Some of the CAKE R modules are based on mkin.
#
# Modifications developed by Hybrid Intelligence (formerly Tessella), part of
# Capgemini Engineering, for Syngenta, Copyright (C) 2011-2022 Syngenta
# Tessella Project Reference: 6245, 7247, 8361, 7414, 10091
#
# The CAKE R modules are free software: you can
# redistribute them and/or modify them under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/> 

# Calculates the cost (weighted sum of squared residuals) of a model output
# against a set of observations.
#
# Arguments:
#   model  Model output with a "time" column and one column per model variable.
#          It is indexed both as model["time"] and model[, i].
#          NOTE(review): presumably a data frame - confirm against callers.
#   obs    Observations in long format: the first column holds the name of the
#          observed variable, and the columns "time" and "value" hold the
#          independent and dependent values.
#   err    NULL (all weights are 1), or the name of the column in obs holding
#          the error estimate of each observation. Residuals are divided by it.
#   ...    Unused.
#
# Returns a list of class "modCost" with elements:
#   model, cost  the total weighted sum of squared residuals
#   minlogp      minus the sum of the log normal densities of the model values
#                around the observations, using each observation's own error
#   var          one row per observed variable (name, N, SSR.unweighted,
#                SSR.unscaled, SSR)
#   residuals    one row per compared point (name, x, yObs, yModel, weight,
#                res.unweighted, res)
#
# Stops if a required column or an observed variable cannot be found, if obs
# contains no observed variables, or if an error estimate is NA, zero or
# negative.
CakeCost <- function (model, obs, err = NULL, ...) {
  xName <- "time"
  yName <- "value"

  if (!is.null(err) && !is.character(err))
    stop ("'err' should be NULL or the name of the column with the error estimates in obs")

  ##=============================================================================
  ## Observations
  ##=============================================================================

  ## The position of the independent variable
  ixObs <- which(colnames(obs) == xName)
  if (length(ixObs) != 1)
    stop(paste("Independent variable column not found in observations", xName))

  ## The position of the weighting values (0 means "no weighting")
  ierr <- 0
  if (!is.null(err)) {
    ierr <- which(colnames(obs) == err)
    if (length(ierr) == 0)
      stop(paste("Column with error estimates not found in observations", err))
  }

  ## The dependent variables; every name must also be a model variable
  ObsDataSetNames <- as.character(unique(obs[, 1]))
  iyObs <- which(colnames(obs) == yName)
  if (length(iyObs) == 0)
    stop(paste("Column with value of dependent variable not found in observations", yName))

  #================================
  # The model results
  #================================

  ModelDataSetNames <- colnames(model)  # Names of model variables

  ixModel <- which(colnames(model) == xName)
  if (length(ixModel) == 0)
    stop(paste("Cannot calculate cost: independent variable not found in model output", xName))
  xModel <- model[, ixModel]    # Independent variable, model

  ## Sometimes a fit is encountered for which the model is unable to calculate
  ## values on the full range of observed values. In this case, we will return
  ## an infinite cost to ensure this value is not selected.
  ## Evaluated only after both "time" columns are known to exist, so a missing
  ## column is reported by the checks above rather than by a subsetting error.
  modelCalculatedFully <- all(unlist(obs[xName]) %in% unlist(model[xName]))

  if (length(ObsDataSetNames) == 0)
    stop("No observed variables found in observations")

  #================================
  # Compare model and data...
  #================================

  # Per-variable pieces are collected in lists and bound once after the loop.
  nVar <- length(ObsDataSetNames)
  residualParts <- vector("list", nVar)
  costParts <- vector("list", nVar)
  sdParts <- vector("list", nVar)

  for (i in seq_along(ObsDataSetNames)) {   # for each observed variable ...
    ii <- which(ModelDataSetNames == ObsDataSetNames[i])
    if (length(ii) == 0) stop(paste("observed variable not found in model output", ObsDataSetNames[i]))
    yModel <- model[, ii]

    # Rows of this variable with a non-NA observation. The same row set is
    # used for x, y and the error estimates so that they stay aligned.
    iDat <- which(obs[, 1] == ObsDataSetNames[i])
    iDat <- iDat[!is.na(obs[iDat, iyObs])]
    xObs <- obs[iDat, ixObs]
    yObs <- obs[iDat, iyObs]

    # CAKE version - interpolation needs more than one valid model point
    if (length(unique(xModel[!is.na(xModel)])) > 1 && length(yModel[!is.na(yModel)]) > 1) {
      yModel <- approx(xModel, yModel, xout = xObs)$y
    } else {
      cat("CakeCost Warning: Only one valid point - using mean (yModel was", yModel, ")\n")
      yModel <- mean(yModel[!is.na(yModel)])
      yObs <- mean(yObs)
    }

    # Drop points where the model value could not be obtained
    iNotNa <- which(!is.na(yModel))
    yModel <- yModel[iNotNa]
    yObs <- yObs[iNotNa]
    xObs <- xObs[iNotNa]

    if (ierr > 0) {
      Err <- obs[iDat, ierr][iNotNa]
    } else {
      Err <- 1
    }

    if (any(is.na(Err)))
      stop(paste("error: cannot estimate weighing for observed variable: ", ObsDataSetNames[i]))
    if (min(Err) <= 0)
      stop(paste("error: weighing for observed variable is 0 or negative:", ObsDataSetNames[i]))

    if (!modelCalculatedFully) {
      # The model is unable to predict on the full range: set cost to Inf
      xObs <- 0
      yObs <- 0
      yModel <- Inf
      residual <- Inf
      residualWeighted <- Inf
      weight_for_residual <- Inf
      # Any positive sd gives zero density at an infinite model value
      sdRows <- 1
    } else {
      residual <- (yModel - yObs)
      residualWeighted <- residual / Err
      weight_for_residual <- 1 / Err
      sdRows <- rep_len(Err, length(residual))
    }

    residualParts[[i]] <- data.frame(
      name           = ObsDataSetNames[i],
      x              = xObs,
      yObs           = yObs,
      yModel         = yModel,
      weight         = weight_for_residual,
      res.unweighted = residual,
      res            = residualWeighted)

    costParts[[i]] <- data.frame(
      name           = ObsDataSetNames[i],
      N              = length(residual),
      SSR.unweighted = sum(residual^2),
      SSR.unscaled   = sum(residualWeighted^2),
      SSR            = sum(residualWeighted^2))

    sdParts[[i]] <- sdRows
  } # end loop over all observed variables

  Residual <- do.call(rbind, residualParts)
  CostVar <- do.call(rbind, costParts)

  ## SSR
  Cost <- sum(CostVar$SSR)

  # Each residual is paired with its own error estimate, not with the
  # estimates of whichever variable happened to be processed last.
  Lprob <- -sum(log(pmax(0, dnorm(Residual$yModel, Residual$yObs, unlist(sdParts))))) # avoid log of negative values

  out <- list(model = Cost, cost = Cost, minlogp = Lprob, var = CostVar, residuals = Residual)
  class(out) <- "modCost"

  return(out)
}

## -----------------------------------------------------------------------------
## Internal cost function for optimisers
## -----------------------------------------------------------------------------
# Cost function. The returned structure must have $model
# We need to preserve state between calls so make a closure
# Builds the cost function handed to the optimisers, together with accessors
# for the state it keeps between calls (call counter, best cost so far, last
# predicted output).
#
# Arguments:
#   mkinmod                   Model definition; passed to AdjustOdeInitialValues
#                             and PostProcessOdeOutput, and its $map / $diffs
#                             names are read for the analytical solution.
#   state.ini.optim           Initial state values that are optimised; only its
#                             length is used here, to split the vector P.
#   state.ini.optim.boxnames  Names given to the optimised initial states.
#   state.ini.fixed           Named initial state values that are not optimised.
#   parms.fixed               Parameter values that are not optimised.
#   observed                  Observations passed to CakeCost and CakePenalties;
#                             its $time column defines the output times.
#   mkindiff                  Derivative function passed to ode() as 'func'.
#   quiet                     If FALSE, each improvement of the cost is printed.
#   atol                      Passed to ode() and PostProcessOdeOutput.
#   solution                  "analytical" or "deSolve".
#   err                       Name of the error column, passed to CakeCost.
#
# Returns a list of functions: cost, get.predicted, get.calls, set.calls,
# get.best.cost, reset.best.cost and set.error.
CakeInternalCostFunctions <- function(mkinmod, state.ini.optim, state.ini.optim.boxnames, 
                                    state.ini.fixed, parms.fixed, observed, mkindiff,  
                                    quiet, atol=1e-6, solution="deSolve", err="err"){
    # State shared by the closures below
    cost.old <- 1e+100
    calls <- 0
    out_predicted <- NA

    get.predicted <- function(){ out_predicted }

    get.best.cost <- function(){ cost.old }
    reset.best.cost <- function() { cost.old <<- 1e+100 }

    get.calls <- function(){ calls }
    set.calls <- function(newcalls){ calls <<- newcalls }

    # Replaces the error column of the observations used by later cost calls
    set.error <- function(err) { observed$err <<- err }

    # The called cost function. P holds the optimised initial states first,
    # followed by the optimised parameters. The returned structure has $model.
    cost <- function(P) {
        calls <<- calls + 1
        print(P)

        nStateOptim <- length(state.ini.optim)

        if (nStateOptim > 0) {
            odeini <- c(P[seq_len(nStateOptim)], state.ini.fixed)
            names(odeini) <- c(state.ini.optim.boxnames, names(state.ini.fixed))
        } else {
            odeini <- state.ini.fixed
        }

        # Everything after the initial states is a parameter. seq_len() keeps
        # this empty when P holds only initial states, where (n + 1):length(P)
        # would count downwards and pick up NA plus the last state value.
        iParms <- seq_len(max(length(P) - nStateOptim, 0)) + nStateOptim
        odeparms <- c(P[iParms], parms.fixed)

        # Ensure initial state is at time 0
        outtimes <- unique(c(0, observed$time))

        odeini <- AdjustOdeInitialValues(odeini, mkinmod, odeparms)

        if (solution == "analytical") {
            parms <- as.list(c(odeparms, odeini))
            parent.type <- names(mkinmod$map[[1]])[1]
            parent.name <- names(mkinmod$diffs)[[1]]

            out <- CakeAnalyticalSolution(parms, parent.type, parent.name, outtimes)
        } else if (solution == "deSolve") {
            out <- ode(y = odeini, times = outtimes, func = mkindiff, parms = odeparms, atol = atol)
        } else {
            # Without this, 'out' would be undefined (or an unrelated outer 'out')
            stop(paste("Unknown solution type:", solution))
        }

        out_transformed <- PostProcessOdeOutput(out, mkinmod, atol)

        out_predicted <<- out_transformed
        modCost <- CakeCost(out_transformed, observed,  err = err)
        modCost$penalties <- CakePenalties(odeparms, out_transformed, observed)
        modCost$model <- modCost$cost + modCost$penalties

        if (modCost$model < cost.old) {
            if (!quiet) {
                cat("Model cost at call ", calls, ": m", modCost$cost, 'p:', modCost$penalties, 'o:', modCost$model, "\n")
            }

            cost.old <<- modCost$model
        }

        # HACK to make nls.lm respect the penalty, as it just uses residuals and ignores the cost
        if (modCost$penalties > 0) {
            nRes <- length(modCost$residuals$res)
            modCost$residuals$res <- modCost$residuals$res + (sign(modCost$residuals$res) * modCost$penalties / nRes)
        }

        return(modCost)
    }

    list(cost=cost, 
        get.predicted=get.predicted,
        get.calls=get.calls, set.calls=set.calls,
        get.best.cost=get.best.cost, reset.best.cost=reset.best.cost,
        set.error=set.error
    )
}

# Evaluates the analytical solution of a single-compartment (parent only)
# kinetic model at the requested times.
#
# Arguments:
#   parms        Named list holding the initial parent value (under the name
#                parent.name) and the kinetic parameters of the chosen model.
#   parent.type  One of "SFO", "FOMC", "DFOP", "HS" or "IORE".
#   parent.name  Name of the parent variable; also the suffix of the
#                parent-specific parameter names (e.g. k_<parent.name>).
#   outtimes     Times at which the solution is evaluated.
#
# Returns a matrix with columns "time" and the parent name (with "_free"
# removed), and with the output times as row names.
#
# Stops if a required parameter is missing from parms or if parent.type is
# not one of the supported kinetic models.
CakeAnalyticalSolution <- function(parms, parent.type, parent.name, outtimes) {
    # Exact lookup by name. Unlike evaluating the name as code, this cannot
    # fall back on an unrelated object of the same name (e.g. base::beta) when
    # the parameter is missing, and it does not interpret the name as an
    # expression.
    getParm <- function(name)
    {
      if (!(name %in% names(parms)))
        stop(paste("Parameter not found for analytical solution:", name))
      parms[[name]]
    }

    # Parameters that carry the parent name as suffix, e.g. "k_parent"
    getParentParm <- function(prefix)
    {
      getParm(paste(prefix, parent.name, sep="_"))
    }

    o <- switch(parent.type,
                SFO = SFO.solution(outtimes,
                                    getParm(parent.name),
                                    getParentParm("k")),
                FOMC = FOMC.solution(outtimes,
                                    getParm(parent.name),
                                    getParm("alpha"), getParm("beta")),
                DFOP = DFOP.solution(outtimes,
                                    getParm(parent.name),
                                    getParentParm("k1"),
                                    getParentParm("k2"),
                                    getParentParm("g")),
                HS = HS.solution(outtimes,
                                getParm(parent.name),
                                getParm("k1"), getParm("k2"),
                                getParm("tb")),
                IORE = IORE.solution(outtimes,
                                    getParm(parent.name),
                                    getParentParm("k"),
                                    getParm("N")),
                # Default: without it switch() returns NULL and the failure
                # only shows up later as a dimnames length error.
                stop(paste("Unsupported parent type for analytical solution:", parent.type)))

    out <- cbind(outtimes, o)
    dimnames(out) <- list(outtimes, c("time", sub("_free", "", parent.name)))

    return(out)
}