!---------------------------------------------------------------------------------------
! FLEXINVERT: mod_analytic
!---------------------------------------------------------------------------------------
!  FLEXINVERT is free software: you can redistribute it and/or modify
!  it under the terms of the GNU General Public License as published by
!  the Free Software Foundation, either version 3 of the License, or
!  (at your option) any later version.
!
!  FLEXINVERT is distributed in the hope that it will be useful,
!  but WITHOUT ANY WARRANTY; without even the implied warranty of
!  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!  GNU General Public License for more details.
!
!  You should have received a copy of the GNU General Public License
!  along with FLEXINVERT.  If not, see <http://www.gnu.org/licenses/>.
!
!  Copyright 2017, Rona Thompson
!---------------------------------------------------------------------------------------
!
!> mod_analytic
!! Analytical Solution Module
!!
!! Purpose:    Calculates the analytical solution to the inverse problem
!!
!! Details:    Finds the state vector, x, corresponding to the maximum probability, 
!!             that is, the solution which minimizes the cost function:
!!
!!             J(x) = (x-xb)^T*B^(-1)*(x-xb) + (Hx-y)^T*R^(-1)*(Hx-y)
!!
!!             where:
!!               xb = prior/background value of x (nn)
!!               B  = prior error covariance matrix (nn,nn)
!!               H  = kernel matrix (nm,nn)
!!               R  = observation error covariance matrix (nm,nm)
!!               y  = observation vector (nm)
!!
!!             Optimal value of x is found using either:
!!
!!             1) invert_nvar
!!                inverts the nvar x nvar dimensioned matrix: (H^TR^(-1)H + B^(-1)) 
!!                x = xb + (H^TR^(-1)H + B^(-1))^(-1)H^TR^(-1)(y-Hxb)
!!                and calculates the posterior covariance:
!!                cova = (H^TR^(-1)H + B^(-1))^(-1)
!!
!!             2) invert_nobs
!!                inverts the nobs x nobs dimensioned matrix: (HBH^T + R)
!!                x = xb + BH^T(HBH^T + R)^(-1)(y - Hxb)
!!                and calculates the posterior covariance:
!!                cova = B - BH^T(HBH^T + R)^(-1)HB
!!
!! Reference:  Tarantola, Inverse Problem Theory, 2005
!
!---------------------------------------------------------------------------------------

module mod_analytic

  use mod_var
  use mod_settings
  use mod_obs
  use mod_states
  use mod_covar

  implicit none
  private

  public :: analytic, invert_nvar, invert_nobs, calc_cost

  contains

  ! --------------------------------------------------
  ! analytic
  ! --------------------------------------------------

  subroutine analytic(files, config, obs, states, covar, cost_o)

    ! Driver for the analytical inversion.
    !
    ! Selects the formulation whose matrix to invert is the smaller one
    ! (nvar x nvar in invert_nvar, nobs x nobs in invert_nobs), runs it,
    ! and then evaluates the observation-space cost with calc_cost.
    !
    ! files   : file/path settings handed on to the inversion routine
    ! config  : run configuration
    ! obs     : observation data, updated by the inversion and calc_cost
    ! states  : state vector data, updated by the inversion
    ! covar   : prior error covariance data
    ! cost_o  : on return, the observation-space cost set by calc_cost

    implicit none

    type (files_t),             intent (in)     :: files
    type (config_t),            intent (in)     :: config
    type (obs_t),               intent (in out) :: obs
    type (states_t),            intent (in out) :: states
    type (covar_t),             intent (in out) :: covar
    real,                       intent (in out) :: cost_o

    logical                                     :: solve_in_state_space

    ! ties (nvar equal to nobs) are resolved in favour of invert_nvar
    solve_in_state_space = ( nvar.le.nobs )

    if ( .not.solve_in_state_space ) then
      call invert_nobs(files, config, obs, states, covar)
    else
      call invert_nvar(files, config, obs, states, covar)
    endif

    ! cost of the posterior solution in observation space
    call calc_cost(config, obs, states, cost_o)

  end subroutine analytic

  ! --------------------------------------------------
  ! invert_nvar
  ! --------------------------------------------------

  subroutine invert_nvar(files, config, obs, states, covar)

    ! Analytical solution obtained by inverting the nvar x nvar matrix
    ! (H^TR^(-1)H + B^(-1)):
    !
    !   x    = x0 + (H^TR^(-1)H + B^(-1))^(-1) H^TR^(-1) (y - Hx0)
    !   cova = (H^TR^(-1)H + B^(-1))^(-1)
    !
    ! R is treated as diagonal (obs%err**2); observations whose error is
    ! not positive get zero weight.  B^(-1) is assembled block by block
    ! from covar%evecs, covar%evals and covar%icort.
    !
    ! files   : provides path_output, the directory prefix of all dumps
    ! config  : uses opt_cini, cinierr, verbose and const_out
    ! obs     : uses err and delta (delta is negated to form y - Hx0)
    ! states  : reads px0; sets px and, where the posterior variance is
    !           positive, pxerr
    ! covar   : uses evecs, evals and icort
    !
    ! Files written to files%path_output: gain.txt (always), cova.txt
    ! (verbose or const_out), zwork.txt and zwork_dir.txt (verbose).
    ! A dump file that cannot be opened is reported to the log and
    ! skipped; allocation or SVD failures stop the program.

    type (files_t),             intent (in)     :: files
    type (config_t),            intent (in)     :: config
    type (obs_t),               intent (in out) :: obs
    type (states_t),            intent (in out) :: states
    type (covar_t),             intent (in out) :: covar

    real(kind=8), dimension(nobs)               :: icovr
    real(kind=8), dimension(:,:), allocatable   :: gain
    real(kind=8), dimension(:,:), allocatable   :: nwork
    real(kind=8), dimension(:,:), allocatable   :: zwork
    real(kind=8), dimension(:,:), allocatable   :: izwork
    real(kind=8), dimension(:,:), allocatable   :: umat, tvmat
    real(kind=8), dimension(:), allocatable     :: sigma, isigma
    real(kind=8), dimension(:), allocatable     :: work
    integer, dimension(:), allocatable          :: iwork
    real(kind=8), dimension(neig)               :: tmp
    real(kind=8), dimension(ndvar)              :: tmp1
    integer                                     :: lwork
    integer                                     :: i, j, k, it, jt
    integer                                     :: info, ierr

    real(kind=8), dimension(ndvar,ndvar)        :: icovb
    real(kind=8), dimension(ndvar,neig)         :: tmp0
    real(kind=8), dimension(:,:), allocatable   :: zwork_dir

    ! inverse observation error covariance
    ! ------------------------------------

    ! for diagonal R; non-positive errors leave the weight at zero
    icovr(:) = 0d0
    do i = 1, nobs
      if ( obs%err(i).gt.0. ) icovr(i) = 1d0/dble(obs%err(i))**2
    end do

    ! calculate (H^TR^(-1)H + B^(-1))
    ! -------------------------------

    ! calculate H^TR^(-1)
    allocate ( nwork(nvar,nobs), stat=ierr )
    call check_alloc(ierr)
    do i = 1, nvar
      nwork(i,:) = dble(hmat(:,i))*icovr
    end do

    allocate ( zwork(nvar,nvar), stat=ierr )
    call check_alloc(ierr)
    zwork(:,:) = 0d0
    do jt = 1, ntstate
      do k = 1, ndvar
        do j = 1, neig
          ! row vector of U*D^(-1)
          tmp(j) = covar%evecs(k,j)/covar%evals(j)
        end do
        do i = 1, ndvar
          ! row vector of B^(-1)
          tmp1(i) = dot_product(tmp,covar%evecs(i,:))
        end do
        do it = 1, ntstate
          ! one row of block (jt,it): icort-scaled B^(-1) row plus H^TR^(-1)H row
          zwork((jt-1)*ndvar+k,(it-1)*ndvar+1:it*ndvar) = covar%icort(jt,it)*tmp1(:) + &
                matmul(nwork((jt-1)*ndvar+k,:),dble(hmat(:,(it-1)*ndvar+1:it*ndvar)))
        end do
      end do
    end do
    if ( config%opt_cini ) then
      ! rows/columns npvar+1:nvar: only H^TR^(-1)H off the diagonal,
      ! plus 1/cinierr**2 on the diagonal
      zwork(:,npvar+1:nvar) = matmul(nwork(:,:),dble(hmat(:,npvar+1:nvar)))
      zwork(npvar+1:nvar,:) = matmul(nwork(npvar+1:nvar,:),dble(hmat(:,:)))
      do it = 1, ntcini*ncini
        zwork(npvar+it,npvar+it) = zwork(npvar+it,npvar+it) + 1./config%cinierr**2
      end do
    endif

    ! verbose only
    ! ------------

    ! rebuild the same matrix from an explicitly formed B^(-1) and dump
    ! both versions so they can be compared
    if ( config%verbose ) then
      ! B^(-1)
      do i = 1, ndvar
        tmp0(i,:) = covar%evecs(i,:)/covar%evals
      end do
      icovb = matmul(tmp0,transpose(covar%evecs))
      allocate ( zwork_dir(nvar,nvar), stat=ierr )
      call check_alloc(ierr)
      ! H^TR^(-1)H + B^(-1)
      do jt = 1, ntstate
        do it = 1, ntstate
          zwork_dir((jt-1)*ndvar+1:jt*ndvar,(it-1)*ndvar+1:it*ndvar) = &
                   matmul(nwork((jt-1)*ndvar+1:jt*ndvar,:), &
                          dble(hmat(:,(it-1)*ndvar+1:it*ndvar))) + covar%icort(jt,it)*icovb
        end do
      end do
      if ( config%opt_cini ) then
        zwork_dir(:,npvar+1:nvar) = matmul(nwork(:,:),dble(hmat(:,npvar+1:nvar)))
        zwork_dir(npvar+1:nvar,:) = matmul(nwork(npvar+1:nvar,:),dble(hmat(:,:)))
        do it = 1, ntcini*ncini
          zwork_dir(npvar+it,npvar+it) = zwork_dir(npvar+it,npvar+it) + 1./config%cinierr**2
        end do
      endif
      call write_matrix('zwork.txt', zwork)
      call write_matrix('zwork_dir.txt', zwork_dir)
      deallocate(zwork_dir)
    endif

    ! inverse of (H^TR^(-1)H + B^(-1))
    ! --------------------------------

    ! use Singular Value Decomposition
    lwork = 4*nvar*nvar + 7*nvar
    allocate( work(lwork), stat=ierr )
    call check_alloc(ierr)
    allocate( iwork(8*nvar), stat=ierr )
    call check_alloc(ierr)
    allocate( umat(nvar,nvar), stat=ierr )
    call check_alloc(ierr)
    allocate( tvmat(nvar,nvar), stat=ierr )
    call check_alloc(ierr)
    allocate( sigma(nvar), stat=ierr )
    call check_alloc(ierr)
    ! zwork = umat * diag(sigma) * tvmat; the contents of zwork are
    ! destroyed by this call
    call dgesdd('A',nvar,nvar,zwork,nvar,sigma,umat,nvar,tvmat,nvar,work,lwork,iwork,info)
    if ( info.ne.0 ) then
      write(logid,*) 'ERROR invert_nvar: singular value decomposition'
      stop
    endif
    allocate( isigma(nvar), stat=ierr )
    call check_alloc(ierr)
    allocate( izwork(nvar,nvar), stat=ierr )
    call check_alloc(ierr)
    ! reciprocal singular values; exact zeros are left at zero
    isigma(:) = 0d0
    do i = 1, nvar
      if(sigma(i).ne.0.) isigma(i)=1d0/sigma(i)
    end do
    ! row i of the inverse: V(i,:) * diag(isigma) * U^T, with V(i,k) = tvmat(k,i)
    do i = 1, nvar
      izwork(i,:) = matmul(tvmat(:,i)*isigma, transpose(umat))
    end do
    deallocate(umat)
    deallocate(tvmat)
    deallocate(sigma)
    deallocate(isigma)
    deallocate(iwork)
    deallocate(work)

    ! gain matrix
    ! -----------

    allocate ( gain(nvar,nobs), stat=ierr )
    call check_alloc(ierr)
    gain = matmul(izwork, nwork)

    ! calculate x = x0 + G(y - Hx0)
    ! -----------------------------

    states%px = states%px0 + matmul(real(gain,kind=4),(-1.*obs%delta))

    ! posterior error
    ! ---------------

    ! square root of the diagonal of the posterior covariance
    ! (cross-correlations are not included); entries whose variance is
    ! not positive keep their previous pxerr value
    do i = 1, nvar
!      if ( sum(izwork(i,:)).gt.0 ) states%pxerr(i) = real(sqrt(sum(izwork(i,:))),kind=4)
      if ( izwork(i,i).gt.0 ) states%pxerr(i) = real(sqrt(izwork(i,i)),kind=4)
    end do

    ! verbose output
    ! --------------

    if ( config%verbose.or.config%const_out ) then
      call write_matrix('cova.txt', izwork)
    endif

    call write_matrix('gain.txt', gain)

    deallocate(nwork)
    deallocate(zwork)
    deallocate(izwork)
    deallocate(gain)

  contains

    ! Stop with a log message when an allocate statement reported failure.
    subroutine check_alloc(stat)

      integer, intent (in) :: stat

      if ( stat.ne.0 ) then
        write(logid,*) 'ERROR invert_nvar: not enough memory'
        stop
      endif

    end subroutine check_alloc

    ! Write mat row by row to files%path_output//name, replacing any
    ! existing file.  If the file cannot be opened the problem is logged
    ! and nothing is written, so that no output ends up on an
    ! unconnected unit.
    subroutine write_matrix(name, mat)

      character(len=*),             intent (in) :: name
      real(kind=8), dimension(:,:), intent (in) :: mat

      character(len=max_path_len)               :: filename
      character(len=max_path_len)               :: rowfmt
      integer                                   :: irow, ios

      filename = trim(files%path_output)//name
      open(100,file=trim(filename),status='replace',action='write',iostat=ios)
      if ( ios.ne.0 ) then
        write(logid,*) 'WARNING invert_nvar: cannot open '//trim(filename)
        return
      endif
      ! one record per matrix row
      write(rowfmt,'(A,I6,A)') '(',size(mat,2),'(E11.4,1X))'
      do irow = 1, size(mat,1)
        write(100,rowfmt) mat(irow,:)
      end do
      close(100)

    end subroutine write_matrix

  end subroutine invert_nvar

  ! --------------------------------------------------
  ! invert_nobs
  ! --------------------------------------------------

  subroutine invert_nobs(files, config, obs, states, covar)

    ! Analytical solution obtained by inverting the nobs x nobs matrix
    ! (HBH^T + R):
    !
    !   x    = x0 + BH^T (HBH^T + R)^(-1) (y - Hx0)
    !   cova = B - BH^T (HBH^T + R)^(-1) HB
    !
    ! R is treated as diagonal (obs%err**2).  B is assembled block by
    ! block from covar%evecs, covar%evals and covar%cort.
    !
    ! files   : provides path_output, the directory prefix of all dumps
    ! config  : uses opt_cini, cinierr, verbose and const_out
    ! obs     : uses err and delta (delta is negated to form y - Hx0)
    ! states  : reads px0; sets px and, where the posterior variance is
    !           positive, pxerr
    ! covar   : uses evecs, evals and cort
    !
    ! Files written to files%path_output: gain.txt (always), cova.txt
    ! (const_out), zwork.txt, zwork_dir.txt and imwork.txt (verbose).
    ! A dump file that cannot be opened is reported to the log and
    ! skipped; allocation or factorization failures stop the program.

    type (files_t),             intent (in)     :: files
    type (config_t),            intent (in)     :: config
    type (obs_t),               intent (in out) :: obs
    type (states_t),            intent (in out) :: states
    type (covar_t),             intent (in out) :: covar

    real(kind=8), dimension(:,:), allocatable   :: cova
    real(kind=8), dimension(:,:), allocatable   :: gain
    real(kind=8), dimension(:,:), allocatable   :: zwork
    real(kind=8), dimension(:,:), allocatable   :: mwork
    real(kind=8), dimension(:,:), allocatable   :: imwork
    real(kind=8), dimension(neig)               :: tmp
    real(kind=8), dimension(1,ndvar)            :: tmp1
    integer                                     :: i, j, k, it, jt
    integer                                     :: info, ierr

    real(kind=8), dimension(:,:), allocatable   :: zwork_dir
    real(kind=8), dimension(ndvar,ndvar)        :: covb
    real(kind=8), dimension(ndvar,neig)         :: tmp0


    ! calculate BH^T
    ! --------------

    allocate ( zwork(nvar,nobs), stat=ierr )
    call check_alloc(ierr)
    zwork(:,:) = 0d0
    do jt = 1, ntstate
      do k = 1, ndvar
        do j = 1, neig
          ! row vector of U*D
          tmp(j) = covar%evecs(k,j)*covar%evals(j)
        end do
        do i = 1, ndvar
          ! row vector of B
          tmp1(1,i) = dot_product(tmp,covar%evecs(i,:))
        end do
        do it = 1, ntstate
            ! accumulate the contribution of block (jt,it) to this row of BH^T
            zwork((jt-1)*ndvar+k,:) = zwork((jt-1)*ndvar+k,:) + covar%cort(jt,it) * &
                    matmul(tmp1(1,:),transpose(dble(hmat(:,(it-1)*ndvar+1:it*ndvar))))
        end do
      end do
    end do
    if ( config%opt_cini ) then
      ! rows npvar+1:nvar use a diagonal prior with variance cinierr**2
      zwork(npvar+1:nvar,:) = config%cinierr**2 * transpose(dble(hmat(:,npvar+1:nvar)))
    endif

    ! verbose only
    ! ------------

    ! rebuild BH^T from an explicitly formed B and dump both versions so
    ! they can be compared
    if ( config%verbose ) then
      ! B
      do i = 1, ndvar
        tmp0(i,:) = covar%evecs(i,:)*covar%evals(:)
      end do
      covb = matmul(tmp0,transpose(covar%evecs))
      ! BH^(T)
      allocate ( zwork_dir(nvar,nobs), stat=ierr )
      call check_alloc(ierr)
      zwork_dir(:,:) = 0d0
      do jt = 1, ntstate
        do it = 1, ntstate
          zwork_dir((jt-1)*ndvar+1:jt*ndvar,:) = zwork_dir((jt-1)*ndvar+1:jt*ndvar,:) + &
                      matmul(covar%cort(jt,it)*covb,transpose(dble(hmat(:,(it-1)*ndvar+1:it*ndvar))))
        end do
      end do
      if ( config%opt_cini ) then
        zwork_dir(npvar+1:nvar,:) = config%cinierr**2 * transpose(dble(hmat(:,npvar+1:nvar)))
      endif
      call write_matrix('zwork.txt', zwork)
      call write_matrix('zwork_dir.txt', zwork_dir)
      deallocate(zwork_dir)
    endif

    ! calculate (HBH^T + R)
    ! ---------------------

    allocate ( mwork(nobs,nobs), stat=ierr )
    call check_alloc(ierr)
    mwork = matmul(dble(hmat),zwork)
    ! widen the error before squaring, as invert_nvar does for R^(-1)
    do i = 1, nobs
      mwork(i,i) = mwork(i,i) + dble(obs%err(i))**2
    end do

    ! calculate (HBH^T + R)^-1
    ! ------------------------

    ! use Cholesky factorization; any non-zero info is treated as failure
    call dpotrf('L',nobs,mwork,nobs,info)
    if ( info.ne.0 ) then
      write(logid,*) 'ERROR invert_nobs: lower Cholesky triangle'
      stop
    endif
    call dpotri('L',nobs,mwork,nobs,info)
    if ( info.ne.0 ) then
      write(logid,*) 'ERROR invert_nobs: Cholesky inverse'
      stop
    endif
    ! returns lower ('L') triangle of inverse of mwork
    allocate ( imwork(nobs,nobs), stat=ierr )
    call check_alloc(ierr)
    ! mirror the lower triangle into the upper one to get the full
    ! symmetric inverse
    imwork = mwork
    do i = 1, nobs
      do j = 1, i
        imwork(j,i) = mwork(i,j)
      end do
    end do

    deallocate(mwork)

    ! gain matrix G = BH^T(HBH^T + R)^(-1)
    ! ------------------------------------

    allocate ( gain(nvar,nobs), stat=ierr )
    call check_alloc(ierr)
    gain = matmul(zwork,imwork)

    ! calculate x = x0 + G(y - Hx0)
    ! -----------------------------

    states%px = states%px0 + matmul(real(gain,kind=4),(-1.*obs%delta))

    ! posterior error
    ! ---------------

    ! calculate A = B - GHB, using HB = (BH^T)^T = transpose(zwork)
    allocate ( cova(nvar,nvar), stat=ierr )
    call check_alloc(ierr)
    cova(:,:) = 0d0
    do jt = 1, ntstate
      do k = 1, ndvar
        do j = 1, neig
          ! calculate row vector of U*D
          tmp(j) = covar%evecs(k,j)*covar%evals(j)
        end do
        do i = 1, ndvar
          ! calculate row vector of B
          tmp1(1,i) = dot_product(tmp,covar%evecs(i,:))
        end do
        do it = 1, ntstate
          cova((jt-1)*ndvar+k,(it-1)*ndvar+1:it*ndvar) = covar%cort(jt,it)*tmp1(1,:) - &
               matmul(gain((jt-1)*ndvar+k,:),transpose(zwork((it-1)*ndvar+1:it*ndvar,:)))
        end do
      end do
    end do
    if ( config%opt_cini ) then
      ! B is zero off the diagonal for rows/columns npvar+1:nvar, so only
      ! -GHB remains there; cinierr**2 is added back on the diagonal
      cova(1:npvar,npvar+1:nvar) = -1d0*matmul(gain(1:npvar,:),transpose(zwork(npvar+1:nvar,:)))
      cova(npvar+1:nvar,:) = -1d0*matmul(gain(npvar+1:nvar,:),transpose(zwork(:,:)))
      do i = 1, ntcini*ncini
        cova(npvar+i,npvar+i) = config%cinierr**2 + cova(npvar+i,npvar+i)
      end do
    endif

    ! square root of the diagonal of the posterior covariance
    ! (cross-correlations are not included); entries whose variance is
    ! not positive keep their previous pxerr value
    do i = 1, nvar
!      if ( sum(cova(i,:)).gt.0 ) states%pxerr(i) = real(sqrt(sum(cova(i,:))),kind=4)
      if( cova(i,i).gt.0. ) states%pxerr(i) = real(sqrt(cova(i,i)),kind=4)
    end do

    ! verbose output
    ! --------------

    ! NOTE(review): invert_nvar writes cova.txt when verbose OR const_out
    ! is set; here only const_out triggers it - confirm which is intended
    if (config%const_out) then
      call write_matrix('cova.txt', cova)
    endif

    if ( config%verbose ) then
      call write_matrix('imwork.txt', imwork)
    endif

    call write_matrix('gain.txt', gain)

    deallocate(gain)
    deallocate(imwork)
    deallocate(cova)
    deallocate(zwork)

  contains

    ! Stop with a log message when an allocate statement reported failure.
    subroutine check_alloc(stat)

      integer, intent (in) :: stat

      if ( stat.ne.0 ) then
        write(logid,*) 'ERROR invert_nobs: not enough memory'
        stop
      endif

    end subroutine check_alloc

    ! Write mat row by row to files%path_output//name, replacing any
    ! existing file.  If the file cannot be opened the problem is logged
    ! and nothing is written, so that no output ends up on an
    ! unconnected unit.
    subroutine write_matrix(name, mat)

      character(len=*),             intent (in) :: name
      real(kind=8), dimension(:,:), intent (in) :: mat

      character(len=max_path_len)               :: filename
      character(len=max_path_len)               :: rowfmt
      integer                                   :: irow, ios

      filename = trim(files%path_output)//name
      open(100,file=trim(filename),status='replace',action='write',iostat=ios)
      if ( ios.ne.0 ) then
        write(logid,*) 'WARNING invert_nobs: cannot open '//trim(filename)
        return
      endif
      ! one record per matrix row
      write(rowfmt,'(A,I6,A)') '(',size(mat,2),'(E11.4,1X))'
      do irow = 1, size(mat,1)
        write(100,rowfmt) mat(irow,:)
      end do
      close(100)

    end subroutine write_matrix

  end subroutine invert_nobs

  ! --------------------------------------------------
  ! calc_cost
  ! --------------------------------------------------

  subroutine calc_cost(config, obs, states, cost_o)

    ! Evaluates the posterior model-observation mismatch and the
    ! observation-space cost
    !
    !   Jo = (y - Hx)^TR^(-1)(y - Hx)
    !
    ! for diagonal R (obs%err**2).
    !
    ! config  : uses opt_cini and spec
    ! obs     : sets model, cinipos and delta; reads the remaining
    !           concentration terms, conc and err
    ! states  : reads the posterior state vector px
    ! cost_o  : on return, the sum of delta**2/err**2 over all
    !           observations with a positive error
    !
    ! Observations whose error is not positive are left out of the sum
    ! instead of dividing by zero; invert_nvar likewise gives them zero
    ! weight.

    implicit none

    type (config_t),            intent (in)     :: config
    type (obs_t),               intent (in out) :: obs
    type (states_t),            intent (in)     :: states
    real,                       intent (in out) :: cost_o
    integer                                     :: i
    logical                                     :: is_co2

    ! posterior mixing ratio contribution from fluxes in domain
    obs%model = matmul(hmat(:,1:npvar),states%px(1:npvar))

    ! posterior initial mixing ratios
    if ( config%opt_cini ) then
        ! optimized: taken from the last part of the posterior state vector
        obs%cinipos(:) = matmul(hmat(:,npvar+1:nvar),states%px(npvar+1:nvar))
    else
      ! not optimized: sum obs%cini over its second dimension
      obs%cinipos(:) = sum(obs%cini(:,:),dim=2)
    endif

    ! the species does not change inside the loop, so test it once
    is_co2 = ( trim(config%spec).eq.'co2' )

    ! cost observation space
    ! Jo = (y - Hx)^TR^(-1)(y - Hx)
    cost_o = 0.
    do i = 1, nobs
      ! delta is modelled minus observed concentration
      if ( is_co2 ) then
        obs%delta(i) = obs%cpri(i) - obs%cakpri(i) + &
                        obs%model(i) + obs%nee(i) + obs%fff(i) + obs%ocn(i) - &
                        obs%conc(i) + obs%bkg(i) + obs%cinipos(i)
      else
        obs%delta(i) = obs%cpri(i) - obs%cakpri(i) + &
                        obs%model(i) + obs%ghg(i) - obs%conc(i) + obs%bkg(i) + obs%cinipos(i)
      endif
      ! skip zero or negative errors, which would give Inf/NaN
      if ( obs%err(i).gt.0. ) cost_o = cost_o + obs%delta(i)**2/obs%err(i)**2
    end do

  end subroutine calc_cost

  ! --------------------------------------------------


end module mod_analytic