#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
estnorm.py

Purpose:
    Estimate a normal linear regression model by maximum likelihood, using a lambda function as the objective

Version:
    1       Following estnorm.ox, using 1d parameter vectors
    slsqp   Using sequential least-squares quadratic programming (SLSQP)

Date:
    2017/8/21, 2019/8/30

Author:
    Charles Bos
"""
###########################################################
### Imports
import numpy as np
import pandas as pd
#import matplotlib.pyplot as plt
import scipy.optimize as opt
import math

###########################################################
### Get hessian and related functions
from lib.grad import *

###########################################################
def GetPars(vP):
    """
    Purpose:
      Split the stacked parameter vector into sigma and the beta's

    Inputs:
      vP        numpy array with iK+1 elements (1D or e.g. a column vector),
                sigma first, followed by the iK beta's

    Return value:
      dS        double, sigma (first element of vP)
      vBeta     iK 1D vector, the remaining elements (beta's)
    """
    # Flatten to 1D so that column vectors are treated the same as 1D vectors
    vFlat= vP.reshape(np.size(vP),)

    return (vFlat[0], vFlat[1:])

###########################################################
def GetParNames(iK):
    """
    Purpose:
      Build labels for the parameter vector, sigma first, then the beta's

    Inputs:
      iK        integer, number of beta's

    Return value:
      asP       list of iK+1 strings: "Sigma", "B1", ..., "B<iK>"
    """
    asP= ["Sigma"]
    asP.extend("B" + str(j) for j in range(1, iK+1))

    return asP

###########################################################
### mX= GenrX(iN, iK)
def GenrX(iN, iK):
    """
    Purpose:
      Simulate a regressor matrix: an intercept column followed by
      iK-1 columns of U(0,1) draws

    Inputs:
      iN        integer, number of observations
      iK        integer, number of regressors (including the constant)

    Return values:
      mX        iN x iK matrix, first column ones, rest uniform draws
    """
    vConst= np.ones((iN, 1))            # intercept column
    mU= np.random.rand(iN, iK-1)        # uniform regressors

    return np.hstack([vConst, mU])

###########################################################
### vY= GenrY(vP, mX)
def GenrY(vP, mX):
    """
    Purpose:
      Simulate the dependent variable y = X beta + sigma * eps,
      with eps standard normal draws

    Inputs:
      vP        iK+1 vector of parameters, sigma first, then beta's
      mX        iN x iK matrix of regressors

    Return values:
      vY        iN vector of simulated data
    """
    (dSigma, vBeta)= GetPars(vP)
    vMean= mX @ vBeta                       # systematic part
    vEps= np.random.randn(mX.shape[0])      # one N(0,1) draw per observation

    return vMean + dSigma * vEps

###########################################################
### vLL= LnLRegr(vP, vY, mX)
def LnLRegr(vP, vY, mX):
    """
    Purpose:
        Evaluate the per-observation loglikelihood of the normal
        regression model y = X beta + sigma * eps

    Inputs:
        vP      iK+1 1D-vector of parameters, sigma first, then beta's
        vY      iN 1D-vector of data
        mX      iN x iK matrix of regressors

    Return value:
        vLL     iN vector of loglikelihood contributions; all -inf
                when sigma <= 0
    """
    iN, iK= mX.shape
    # Expect one sigma plus iK beta's; only warn, do not abort
    if np.size(vP) != iK+1:
        print ("Warning: wrong size vP= ", vP)

    dSigma, vBeta= GetPars(vP)
    # Non-positive sigma is infeasible: mark with 'x' and return -inf everywhere
    if dSigma <= 0:
        print ("x", end="")
        return -math.inf*np.ones(iN)

    vZ= (vY - mX @ vBeta)/dSigma            # standardized residuals
    vLL= -0.5*(np.log(2*np.pi) + 2*np.log(dSigma) + np.square(vZ))

    print (".", end="")                     # sign of life per evaluation
    return vLL

###########################################################
### dPos= fnsigmapos(vP)
def fnsigmapos(vP):
    """
    Purpose:
      Return sigma from the parameter vector. It is intended as an SLSQP
      'ineq' constraint, which requires fun(vP) >= 0, i.e. sigma >= 0

    Inputs:
      vP        iK+1 vector of parameters, sigma first

    Return value:
      dSigma    double, sigma
    """
    (dSigma, _)= GetPars(vP)
    return dSigma

###########################################################
### (vP, vS, dLL, sMess)= EstimateRegr(vY, mX, bBounds)
def EstimateRegr(vY, mX, bBounds):
    """
    Purpose:
      Estimate the normal regression model by maximum likelihood, using SLSQP

    Inputs:
      vY        iN vector of data
      mX        iN x iK matrix of regressors
      bBounds   boolean, if True keep sigma >= 0 through bounds, else through
                an inequality constraint function

    Return value:
      vP        iK+1 vector of optimal parameters sigma and beta's
      vS        iK+1 vector of standard errors, from the inverse of the
                negative numerical hessian of the loglikelihood
      dLL       double, loglikelihood at the optimum
      sMess     string, message returned by the optimizer
    """
    (iN, iK)= mX.shape
    vP0= np.ones(iK+1)        # Get (bad...) starting values

    # Create lambda function returning NEGATIVE AVERAGE LL, as function of vP only
    AvgNLnLRegr= lambda vP: -np.mean(LnLRegr(vP, vY, mX), axis=0)

    if (bBounds):
        # Bounds, lower/upper: sigma in [0, inf), beta's unrestricted
        tBounds= ((0, None),) + iK*((None, None),)
        res= opt.minimize(AvgNLnLRegr, vP0, method="SLSQP", bounds=tBounds)
    else:
        # Tuple of constraint dicts: 'ineq' requires fnsigmapos(vP) = sigma >= 0
        tCons= ({'type': 'ineq', 'fun': fnsigmapos},)
        res= opt.minimize(AvgNLnLRegr, vP0, method="SLSQP", constraints=tCons)

    vP= res.x
    sMess= res.message
    dLL= -iN*res.fun          # Undo the averaging and the sign flip
    print ("\nSLSQP using ", "bounds" if bBounds else "constraints", "results in ", sMess, "\nPars: ", vP, "\nLL= ", dLL, ", f-eval= ", res.nfev)

    # Hessian of the average negative LL; scaling by -iN gives the hessian of LL
    mHn= hessian_2sided(AvgNLnLRegr, vP)
    mH= -iN*mHn
    mS2= -np.linalg.inv(mH)   # Covariance: inverse of the negative hessian
    vS= np.sqrt(np.diag(mS2))

    return (vP, vS, dLL, sMess)

###########################################################
### Output(mPPS, dLL, sMess)
def Output(mPPS, dLL, sMess):
    """
    Purpose:
      Print the estimation results on screen

    Inputs:
      mPPS      3 x (iK+1) matrix, rows: true parameters, estimates,
                standard errors
      dLL       double, loglikelihood
      sMess     string, optimizer message
    """
    print ("\n\nEstimation resulted in ", sMess)
    print ("Using ML with LL= ", dLL)

    asNames= GetParNames(mPPS.shape[1] - 1)
    dfEst= pd.DataFrame(mPPS.T, index=asNames, columns=["PTrue", "PHat", "s(P)"])
    print ("Parameter estimates:\n", dfEst)


###########################################################
### main
def main():
    """
    Purpose:
      Simulate regression data, then estimate the model twice with SLSQP,
      once keeping sigma positive through bounds and once through a
      constraint function, and report the results of both runs
    """
    vP0= [.1, 5, 2, -2]    # dSigma and vBeta together
    iN= 100
    iSeed= 1234
    bBounds= True

    # Generate data
    np.random.seed(iSeed)
    vP0= np.array(vP0)
    iK= vP0.size - 1
    mX= GenrX(iN, iK)
    vY= GenrY(vP0, mX)

    # Estimate with both ways of keeping sigma positive, and report each run
    for bUseBounds in (bBounds, not bBounds):
        (vP, vS, dLnPdf, sMess)= EstimateRegr(vY, mX, bUseBounds)
        print ("\nResults using", "bounds" if bUseBounds else "constraints")
        Output(np.vstack([vP0, vP, vS]), dLnPdf, sMess)

###########################################################
### start main
if __name__ == "__main__":
    # Run the simulation and estimation only when executed as a script, not on import
    main()
