#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
speed_numba.py

Purpose:
    Compare the speed of computing X'X with plain Python loops, NumPy
    matrix multiplication, and Numba jit-compiled loops (serial and parallel)

Version:
    1       First start
    2       Use decorators, and parallel

Date:
    2018/8/29, 2021/3/27

Author:
    Charles Bos
"""
###########################################################
### Imports
import numpy as np
from numba import jit,njit,prange
from lib.inctime import *

###########################################################
### mXtX= MatMult(mX, iR)
def MatMult(mX, iR):
    """
    Purpose:
        Pass time computing R times X'X, using numpy

    Inputs:
        mX  iN x iK matrix
        iR  integer, number of repetitions, at least 1

    Return value:
        mXtX    iK x iK matrix of X'X

    Raises:
        ValueError  if iR < 1; without a single repetition there is no
                    result to return
    """
    # Without this check an empty loop would leave mXtX unbound and the
    # return statement would fail with an unhelpful UnboundLocalError.
    if iR < 1:
        raise ValueError("iR must be at least 1")

    # The product is deliberately recomputed on every pass: the repetition
    # is the workload being timed, only the last result is handed back.
    for _ in range(iR):
        mXtX = mX.T @ mX
    return mXtX

###########################################################
### mXtX= Loop(mX, iR)
def Loop(mX, iR):
    """
    Purpose:
        Pass time computing R times X'X, in a plain (interpreted) loop

    Inputs:
        mX  iN x iK matrix
        iR  integer, number of repetitions, at least 1

    Return value:
        mXtX    iK x iK matrix of X'X

    Raises:
        ValueError  if iR < 1; without a single repetition there is no
                    result to return
    """
    # Without this check an empty loop would leave mXtX unbound and the
    # return statement would fail with an unhelpful UnboundLocalError.
    if iR < 1:
        raise ValueError("iR must be at least 1")

    (iN, iK) = mX.shape
    for _ in range(iR):
        # Start every repetition from a fresh matrix, so each pass does
        # the full amount of work
        mXtX = np.zeros((iK, iK))
        for i in range(iK):
            # X'X is symmetric: compute the lower triangle (j <= i) only
            for j in range(i + 1):
                # Inner product of columns i and j, accumulated in place
                for k in range(iN):
                    mXtX[i, j] += mX[k, i] * mX[k, j]
                # Mirror the finished element into the upper triangle
                mXtX[j, i] = mXtX[i, j]
    return mXtX

###########################################################
### mXtX= Loop_NJit(mX, iR)
@njit()
def Loop_NJit(mX, iR):
    """
    Purpose:
        Pass time computing R times X'X, in a loop which is compiled
        by Numba (njit)

    Inputs:
        mX  iN x iK matrix
        iR  integer, number of repetitions, at least 1

    Return value:
        mXtX    iK x iK matrix of X'X

    Raises:
        ValueError  if iR < 1; without a single repetition there is no
                    result to return
    """
    # mXtX is only assigned inside the repetition loop, so an empty loop
    # leaves nothing valid to return.
    # NOTE(review): in compiled code the unassigned result is presumably
    # undefined rather than a clean Python error -- confirm for the Numba
    # version in use. The explicit check avoids relying on that.
    if iR < 1:
        raise ValueError("iR must be at least 1")

    (iN, iK) = mX.shape
    for _ in range(iR):
        # Start every repetition from a fresh matrix, so each pass does
        # the full amount of work
        mXtX = np.zeros((iK, iK))
        for i in range(iK):
            # X'X is symmetric: compute the lower triangle (j <= i) only
            for j in range(i + 1):
                # Inner product of columns i and j, accumulated in place
                for k in range(iN):
                    mXtX[i, j] += mX[k, i] * mX[k, j]
                # Mirror the finished element into the upper triangle
                mXtX[j, i] = mXtX[i, j]
    return mXtX

###########################################################
### mXtX= Loop_Inner(mX)
@njit(parallel=False)       # Compile the inner part with Numba, explicitly without parallelisation
def Loop_Inner(mX):
    """
    Purpose:
        Compute X'X once, element by element in explicit loops

    Inputs:
        mX  iN x iK matrix

    Return value:
        mXtX    iK x iK matrix of X'X
    """
    iN, iK = mX.shape
    mXtX = np.zeros((iK, iK))

    # Walk the lower triangle column by column (iRow >= iCol); every element
    # is an independent inner product of two columns of mX
    for iCol in range(iK):
        for iRow in range(iCol, iK):
            for iObs in range(iN):
                mXtX[iRow, iCol] += mX[iObs, iRow] * mX[iObs, iCol]
            # Symmetry: copy the finished element to its mirror position
            mXtX[iCol, iRow] = mXtX[iRow, iCol]

    return mXtX

###########################################################
### mXtXr= Loop_parallel(mX, iR)
@njit(parallel=True)        # Compile with Numba, allowing the outer loop to run in parallel
def Loop_parallel(mX, iR):
    """
    Purpose:
        Pass time computing R times X'X, with the repetitions spread
        over a parallel loop

    Inputs:
        mX  iN x iK matrix
        iR  integer, number of repetitions, at least 1

    Return value:
        mXtX    iK x iK matrix of X'X (technically, the average of iR repetitions)

    Raises:
        ValueError  if iR < 1; the average over zero repetitions is
                    undefined
    """
    # Without this check iR == 0 would divide the zero matrix by zero and
    # silently hand back NaNs instead of X'X.
    if iR < 1:
        raise ValueError("iR must be at least 1")

    iK = mX.shape[1]
    mXtXr = np.zeros((iK, iK))
    for r in prange(iR):            # Use prange, indicating a parallel loop
        # In-place += on the array acts as a reduction over the iterations:
        # each repetition contributes one full copy of X'X to the sum
        mXtXr += Loop_Inner(mX)

    # Sum of iR identical matrices, divided by iR, gives X'X back
    return mXtXr / iR

###########################################################
### main
def main():
    """
    Purpose:
        Time the alternative X'X routines on one random regressor matrix
    """
    # Magic numbers
    iN = 100        # rows of X
    iK = 10         # columns of X
    iR = 1000       # repetitions per timed routine

    # Initialisation: a column of ones followed by iK-1 standard normal columns
    mOnes = np.ones((iN, 1))
    mRand = np.random.randn(iN, iK - 1)
    mX = np.hstack([mOnes, mRand])

    print ("Calculation X'X for (%i x %i) matrix X, repeating R=%i times" % (iN, iK, iR))

    # Estimation: (label, routine, repetitions), executed in this order.
    # The jitted routines are first called with a single repetition, so that
    # the compilation is timed separately from the actual Rx run.
    # NOTE(review): Timer comes from lib.inctime; presumably it reports the
    # time spent in the with-block under the given label -- confirm there.
    lRuns = [
        ("Loop, Rx", Loop, iR),
        ("MatMult, Rx", MatMult, iR),
        ("Loop_NJit 1x, compiling", Loop_NJit, 1),
        ("Loop_NJit Rx", Loop_NJit, iR),
        ("Loop_parallel 1x, compiling", Loop_parallel, 1),
        ("Loop_parallel Rx", Loop_parallel, iR),
    ]
    for (sLabel, fnXtX, iReps) in lRuns:
        with Timer(sLabel):
            mXtX = fnXtX(mX, iReps)

    # Output
    print ("Quite a difference...\n")

###########################################################
### start main
# Run the timing comparison only when this file is executed as a script,
# not when it is imported as a module
if __name__ == "__main__":
    main()
