!! Copyright (C) 2004-2013 M. Oliveira, F. Nogueira
!! Copyright (C) 2012 T. Cerqueira
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module eigensolver_m
  use global_m
  use oct_parser_m
  use messages_m
  use io_m
  use units_m
  use mesh_m
  use quantum_numbers_m
  use potentials_m
  use wave_equations_m
  implicit none


                    !---Derived Data Types---!

  !> Eigensolver settings.
  type eigensolver_t
    real(R8) :: tol  !< tolerance (EigenSolverTolerance times units_in%energy%factor)
    integer  :: mode !< EIGEN_SAFE or EIGEN_FAST
  end type eigensolver_t

  !> Node of the singly-linked list of sampled energies built while bracketing.
  type ld_e
     real(R8) :: e       !< energy at which the wave equation was integrated
     real(R8) :: ldd     !< value returned by wave_equation_ld_diff at energy e
     integer  :: nnodes  !< number of nodes reported by wave_equation_ld_diff at energy e
     logical  :: refine  !< should the interval [e, next%e] still be bisected?
     type(ld_e), pointer :: next !< next node of the list (null for the last node)
  end type ld_e


                   !---Global Variables---!

  !> Possible values of the EigenSolverMode input variable.
  integer, parameter :: EIGEN_SAFE = 1, &
                        EIGEN_FAST = 2


                    !---Public/Private Statements---!

  private
  public :: eigensolver_t, &
            eigensolver_init, &
            eigensolver_bracket, &
            eigensolver_find_ev, &
            eigensolver_end, &
            EIGEN_SAFE, EIGEN_FAST

contains

  !-----------------------------------------------------------------------
  !> Initialize eigensolver object.
  !!
  !! Reads EigenSolverTolerance (default 1.0E-8, multiplied by
  !! units_in%energy%factor) and EigenSolverMode (default EIGEN_FAST),
  !! stops with a fatal error on a non-positive tolerance or an unknown
  !! mode, and prints a short summary of the settings.
  !-----------------------------------------------------------------------
  subroutine eigensolver_init(eigensolver)
    type(eigensolver_t), intent(out) :: eigensolver

    call push_sub("eigensolver_init")

    !Read eigensolver tolerance
    eigensolver%tol = oct_parse_f90_double('EigenSolverTolerance', 1.0E-8_r8)
    eigensolver%tol = eigensolver%tol*units_in%energy%factor
    if (eigensolver%tol <= M_ZERO) then
      message(1) = "EigenSolverTolerance must be positive."
      call write_fatal(1)
    end if

    !Read eigensolver mode
    eigensolver%mode = oct_parse_f90_int('EigenSolverMode', EIGEN_FAST)
    message(1) = ""
    message(2) = "Eigensolver Info"
    call write_info(2)
    write(message(1),'(2X,"Method:    ",A)') "Brent's method"
    write(message(2),'(2X,"Tolerance: ",ES10.3E2)') eigensolver%tol
    select case (eigensolver%mode)
    case (EIGEN_SAFE)
      message(3) = "  Mode: safe"
    case (EIGEN_FAST)
      message(3) = "  Mode: fast"
    case default
      message(1) = "Unknown EigenSolverMode"
      call write_fatal(1)
    end select
    call write_info(3,20)

    call pop_sub()
  end subroutine eigensolver_init

  !-----------------------------------------------------------------------
  !> End eigensolver object.
  !-----------------------------------------------------------------------
  subroutine eigensolver_end(eigensolver)
    type(eigensolver_t), intent(inout) :: eigensolver

    call push_sub("eigensolver_end")

    !There is nothing to end yet

    call pop_sub()
  end subroutine eigensolver_end

  !-----------------------------------------------------------------------
  !> Bracket the eigenvalues.
  !!
  !! Starts from the interval [wave_equation_emin(qns(1)), wave_equation_emax(qns(n_ev))]
  !! and repeatedly bisects the sub-intervals that may still contain a
  !! missing eigenvalue. Eigenvalue i is bracketed by two consecutive
  !! sampled energies p2, p3 when wave_equation_ld_diff changes sign between
  !! them, both report nnodes(i) nodes, and |ldd| decreases from p1 to p2 and
  !! increases from p3 to p4. The last condition distinguishes a zero from a
  !! discontinuity. The search stops when all eigenvalues are bracketed or
  !! when the grid spacing drops below sqrt(tol); in the latter case the
  !! missing ones keep bracketed(i) = .false. and brackets(:,i) = 0.
  !-----------------------------------------------------------------------
  subroutine eigensolver_bracket(n_ev, qns, wave_eq, eigensolver, potential, integrator, brackets, bracketed)
    integer,             intent(in)    :: n_ev              !< number of eigenvalues to be bracketed
    type(qn_t),          intent(in)    :: qns(n_ev)         !< sets of quantum numbers
    integer,             intent(in)    :: wave_eq           !< wave-equation to use
    type(eigensolver_t), intent(in)    :: eigensolver       !< eigensolver object
    type(potential_t),   intent(in)    :: potential         !< potential to use in the wave-equation
    type(integrator_t),  intent(inout) :: integrator        !< integrator object
    real(R8),            intent(out)   :: brackets(2, n_ev) !< information about upper and lower bounds of the
                                                            !! intervals containing the eigenvalues
    logical,             intent(out)   :: bracketed(n_ev)   !< did the eigensolver manage to bracket the eigenvalues?

    integer :: i, min_nnodes, max_nnodes
    real(R8) :: h
    type(ld_e), pointer :: first, ptr, next, new_ptr, p1, p2, p3, p4
    type(qn_t) :: qn
    integer, allocatable :: nnodes(:)

    call push_sub("eigensolver_bracket")

    if (n_ev == 0) then
      !No eigenvalues to bracket
      call pop_sub()
      return
    end if

    nullify(first, next, ptr, new_ptr)
    nullify(p1, p2, p3, p4)
    brackets = M_ZERO

    !Quantum numbers should be ordered by increasing node number
    do i = 2, n_ev
      ASSERT(qns(i)%n >= qns(i-1)%n)
    end do

    !We use qn for calling the wave equation integrator
    qn = qns(1)

    !Number of nodes of the eigenvalues
    allocate(nnodes(n_ev))
    nnodes = qn_number_of_nodes(qns)
    min_nnodes = minval(nnodes)
    max_nnodes = maxval(nnodes)

    !Initial bracket: a two-node list spanning [emin, emax]
    allocate(first); call ld_e_null(first)
    allocate(next);  call ld_e_null(next)
    first%e = wave_equation_emin(qns(1), wave_eq, potential)
    next%e  = wave_equation_emax(qns(n_ev), potential)
    first%ldd = wave_equation_ld_diff(qn, first%e, wave_eq, potential, integrator, first%nnodes)
    next%ldd  = wave_equation_ld_diff(qn, next%e,  wave_eq, potential, integrator, next%nnodes)
    first%refine = .true.
    next%refine = .false.
    h = abs(next%e - first%e)
    first%next => next

    !Bracketing
    main: do
      !Check which eigenvalues have been bracketed
      p1 => first
      bracketed = .false.
      do
        !We need four points, because we are going to check if we have a zero or a discontinuity
        p2 => p1%next
        if (.not. associated(p2%next)) exit
        p3 => p2%next
        if (.not. associated(p3%next)) exit
        p4 => p3%next

        !Now the actual check (one eigenvalue at most per interval [p2, p3])
        do i = 1, n_ev
          if (bracketed(i)) cycle
          if (p2%ldd*p3%ldd < M_ZERO .and. &
              p2%nnodes == nnodes(i) .and. p3%nnodes == nnodes(i) &
             .and. abs(p1%ldd) > abs(p2%ldd) .and. &
             abs(p3%ldd) < abs(p4%ldd) &
             ) then
            brackets(1,i) = p2%e
            brackets(2,i) = p3%e
            bracketed(i) = .true.
            exit
          end if
        end do
        p1 => p2
      end do

      !If we have bracketed all the eigenvalues then we are done!
      if (all(bracketed)) exit

      !Check if there is something wrong: the grid is already as fine as we allow
      if (h < sqrt(eigensolver%tol)) then
        if (in_debug_mode) then
          call bracket_eigenvalue_debug()
          call potential_debug(potential)
        end if
        exit
      end if

      !Check which brackets should be further refined
      ptr => first
      do
        next => ptr%next
        !We should not refine this bracket if it:
        if (ptr%refine .and. (&
             ! already contains an eigenvalue and no missing eigenvalues have the same number of nodes
             (any(brackets(1,:) == ptr%e) .and. .not. any(nnodes == ptr%nnodes .and. .not.bracketed)) .or. &
             ! has less nodes than the minimum number of nodes we are looking for
             (ptr%nnodes < min_nnodes .and. next%nnodes < min_nnodes) .or. &
             ! has more nodes than the maximum number of nodes we are looking for
             (ptr%nnodes > max_nnodes .and. next%nnodes > max_nnodes) .or. &
             ! has a number of nodes we do not require
             (ptr%nnodes == next%nnodes .and. .not. any(ptr%nnodes == nnodes)) .or. &
             ! cannot contain an eigenvalue (no sign change and a single eigenvalue with
             ! this number of nodes). This test is only applied in EIGEN_FAST mode,
             ! because the function can have discontinuities (e.g. in the spin-polarized
             ! relativistic case), so a sign-preserving interval may still hide an eigenvalue.
             (ptr%nnodes == next%nnodes .and. ptr%ldd*next%ldd > M_ZERO .and. &
             count(nnodes == ptr%nnodes) == 1 .and. eigensolver%mode == EIGEN_FAST) &
             )) then
          ptr%refine = .false.
        end if
        ptr => next
        if (.not. associated(ptr%next)) exit
      end do

      !Refine brackets: insert the midpoint of every interval flagged for refinement
      ptr => first
      h = h*M_HALF
      do
        next => ptr%next
        if (ptr%refine) then
          allocate(new_ptr)
          new_ptr%e = ptr%e + (next%e - ptr%e)*M_HALF
          new_ptr%ldd = wave_equation_ld_diff(qn, new_ptr%e, wave_eq, potential, integrator, new_ptr%nnodes)
          new_ptr%refine = .true.
          new_ptr%next => next
          ptr%next => new_ptr
          nullify(new_ptr)
        end if
        !Skip the newly inserted node; it is only considered in the next pass
        ptr => next
        if (.not. associated(ptr%next)) exit
      end do

    end do main

    !Deallocate everything
    deallocate(nnodes)
    call end_list(first)

    call pop_sub()
  contains
    !> Dump the sampled energies and the brackets found so far to the debug_info directory.
    subroutine bracket_eigenvalue_debug()
      integer :: unit, i

      call io_open(unit, 'debug_info/bracket_eigenvalue')

      ptr => first
      do
        write(unit,'(F12.5,1X,ES10.3E2,1X,I2,1X,L2)') ptr%e, ptr%ldd, ptr%nnodes, ptr%refine
        if (.not. associated(ptr%next)) exit
        ptr => ptr%next
      end do
      close(unit)

      call io_open(unit, 'debug_info/bracketed_eigenvalues')
      do i = 1, n_ev
        write(unit,'(F12.5,1X,F12.5,1X,I2)') brackets(1, i), brackets(2, i), nnodes(i)
      end do

      close(unit)

    end subroutine bracket_eigenvalue_debug

  end subroutine eigensolver_bracket

  !-----------------------------------------------------------------------
  !> Reset all components of an already allocated list node.
  !-----------------------------------------------------------------------
  subroutine ld_e_null(ptr)
    type(ld_e), pointer :: ptr

    ptr%e = M_ZERO
    ptr%ldd = M_ZERO
    ptr%nnodes = 0
    ptr%refine = .false.
    nullify(ptr%next)

  end subroutine ld_e_null

  !-----------------------------------------------------------------------
  !> Deallocate every node of the list starting at ptr.
  !! On return ptr is disassociated. A null list is accepted.
  !-----------------------------------------------------------------------
  subroutine end_list(ptr)
    type(ld_e), pointer :: ptr

    type(ld_e), pointer :: next

    if (.not. associated(ptr)) return

    do while (associated(ptr%next))
      next => ptr%next
      deallocate(ptr)
      ptr => next
    end do
    deallocate(ptr)

  end subroutine end_list

  !-----------------------------------------------------------------------
  !> Use Brent's method to find the eigenvalue.
  !!
  !! Locates the zero of wave_equation_ld_diff inside the given bracket. A
  !! bracket equal to (0, 0) flags an unbound state, for which ev = 0 is
  !! returned without any integration.
  !-----------------------------------------------------------------------
  subroutine eigensolver_find_ev(qn, wave_eq, eigensolver, potential, integrator, bracket, ev)
    type(qn_t),          intent(in)    :: qn          !< set of quantum numbers
    integer,             intent(in)    :: wave_eq     !< wave-equation to use
    type(eigensolver_t), intent(in)    :: eigensolver !< eigensolver object
    type(potential_t),   intent(in)    :: potential   !< potential to use in the wave-equation
    type(integrator_t),  intent(inout) :: integrator  !< integrator object
    real(R8),            intent(in)    :: bracket(2)  !< information about upper and lower bounds of the
                                                      !! interval containing the eigenvalue
    real(R8),            intent(out)   :: ev          !< eigenvalue

    real(R8), parameter :: eps = epsilon(M_ONE)
    integer  :: nnodes_dum
    real(R8) :: a, b, c, fa, fb, fc, tol1, n, d, e, r, s, q, p

    call push_sub("eigensolver_find_ev")

    a = bracket(1)
    b = bracket(2)
    if (a == M_ZERO .and. b == M_ZERO) then !Unbound state
      ev = M_ZERO
      call pop_sub()
      return
    end if
    fa = wave_equation_ld_diff(qn, a, wave_eq, potential, integrator, nnodes_dum)
    fb = wave_equation_ld_diff(qn, b, wave_eq, potential, integrator, nnodes_dum)
    c = b
    fc = fb
    d = b - a
    e = d

    do
      !Keep the root between b and c: if fb and fc have the same sign, restart c from a
      if ((fb < M_ZERO .and. fc < M_ZERO) .or. (fb > M_ZERO .and. fc > M_ZERO)) then
        c = a; fc = fa
        d = b - a; e = d
      end if

      !b is always the best estimate so far (smallest |f|)
      if (abs(fc) < abs(fb)) then
        a = b; b = c; c = a
        fa = fb; fb = fc; fc = fa
      end if
      tol1 = M_TWO*eps*abs(b) + M_HALF*eigensolver%tol
      !n is half the width of the current bracket [b, c]
      n = (c - b)*M_HALF

      !Converged: bracket small enough, or an exact zero was hit
      if (abs(n) <= eigensolver%tol .or. fb == M_ZERO) exit

      if (abs(e) > eigensolver%tol .and. abs(fa) > abs(fb)) then
        !Attempt interpolation
        s = fb/fa
        if (a == c) then !Linear interpolation
          p = M_TWO*n*s
          q = M_ONE - s
        else !Inverse quadratic interpolation
          q = fa/fc
          r = fb/fc
          p = s*(M_TWO*n*q*(q - r) - (b - a)*(r - M_ONE))
          q = (q - M_ONE)*(r - M_ONE)*(s - M_ONE)
        end if
        if (p > M_ZERO) then
          q = -q
        else
          p = abs(p)
        end if
        s = e
        e = d
        !Accept the interpolation only if it stays well inside the bracket and
        !converges fast enough; otherwise bisect.
        !NOTE(review): abs(M_HALF*s*q) is a factor two stricter than the textbook
        !abs(s*q) bound; kept as is to preserve the current behavior.
        if (M_TWO*p < min(M_THREE*n*q - abs(tol1*q), abs(M_HALF*s*q))) then
          d = p/q
        else
          d = n
          e = d
        end if
      else
        !Bounds decreasing too slowly or interpolation not applicable: bisect
        d = n
        e = d
      end if

      !Move the previous best guess to a and evaluate the new trial point
      a = b
      fa = fb
      if (abs(d) > tol1) then
        b = b + d
      else
        b = b + sign(tol1, n)
      end if

      fb = wave_equation_ld_diff(qn, b, wave_eq, potential, integrator, nnodes_dum)
    end do

    ev = b
    call pop_sub()
  end subroutine eigensolver_find_ev

end module eigensolver_m
