!! Copyright (C) 2004-2014 M. Oliveira, F. Nogueira
!! Copyright (C) 2011-2012 T. Cerqueira
!! Copyright (C) 2014 P. Borlido
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module ode_integrator_m
  use global_m
  use messages_m
  use io_m
  use gsl_interface_m
  use quantum_numbers_m
  use potentials_m
  implicit none

  !> External ODE driver used by ode_integration (defined outside this module).
  integer, external :: ode_solve

                    !---Derived Data Types---!

  !> State of an ODE integrator: the equation being solved, the handles to
  !! the GSL objects used to solve it, and the parameters (quantum numbers,
  !! energy and potential) that enter the equation.
  type ode_integrator_t
    integer  :: ode                !< equation to integrate (one of the ODE_SCHRODINGER, ... constants)
    integer  :: dim                !< number of coupled functions (2, or 8 for ODE_DIRAC_POL2)
    integer  :: nstepmax           !< maximum number of steps the solver may take
    integer(POINTER_SIZE) :: evl   !< handle to the GSL evolve object
    integer(POINTER_SIZE) :: stp   !< handle to the GSL stepping object
    integer(POINTER_SIZE) :: ctrl  !< handle to the GSL control object
    type(qn_t) :: qn               !< quantum numbers of the state
    real(R8)   :: e                !< energy at which the equation is integrated
    type(potential_t) :: potential !< potential entering the equation
  end type ode_integrator_t


                    !---Global Variables---!

  !Stepping-function identifiers (passed to gsl_odeiv_step_alloc)
  integer, parameter :: ODE_RK2   = 1
  integer, parameter :: ODE_RK4   = 2
  integer, parameter :: ODE_RKF4  = 3
  integer, parameter :: ODE_RKCK4 = 4
  integer, parameter :: ODE_RKPD8 = 5

  !Equation identifiers
  integer, parameter :: ODE_SCHRODINGER = 11
  integer, parameter :: ODE_SCALAR_REL  = 22
  integer, parameter :: ODE_DIRAC       = 33
  integer, parameter :: ODE_DIRAC_POL1  = 44
  integer, parameter :: ODE_DIRAC_POL2  = 55

  !> Value given to the first function at the starting point of an inward
  !! integration; also the target magnitude for ode_practical_infinity.
  real(R8), parameter :: ODE_FINF = 1.0E-20_r8


                    !---Public/Private Statements---!

  private

  !Types and procedures
  public :: ode_integrator_t
  public :: ode_integrator_null, ode_integrator_init, ode_integrator_end
  public :: ode_integration, ode_function_to_wavefunction
  public :: ode_match_functions, ode_function_mismatch, ode_practical_infinity

  !Constants
  public :: ODE_FINF
  public :: ODE_RK2, ODE_RK4, ODE_RKF4, ODE_RKCK4, ODE_RKPD8
  public :: ODE_SCHRODINGER, ODE_SCALAR_REL, ODE_DIRAC, ODE_DIRAC_POL1, ODE_DIRAC_POL2

contains

  !-----------------------------------------------------------------------
  !> Nullifies and sets to zero all the components of the ODE integrator,
  !> including the handles to the GSL objects, so that an integrator that
  !> was never initialized holds well-defined values.
  !-----------------------------------------------------------------------
  subroutine ode_integrator_null(odeint)
    type(ode_integrator_t), intent(out) :: odeint !< integrator to nullify

    odeint%ode = 0
    odeint%dim = 0
    odeint%nstepmax = 0
    !intent(out) leaves the GSL handles undefined unless set explicitly
    odeint%evl = 0
    odeint%stp = 0
    odeint%ctrl = 0
    odeint%qn = QN_NULL
    odeint%e = M_ZERO
    call potential_null(odeint%potential)

  end subroutine ode_integrator_null

  !-----------------------------------------------------------------------
  !> Initialize the ODE integrator: store the equation, quantum numbers,
  !> energy and potential, and allocate the GSL stepping, evolve and
  !> control objects. Those objects are released by ode_integrator_end.
  !-----------------------------------------------------------------------
  subroutine ode_integrator_init(odeint, ode, stepping_function, nstepmax, tol, qn, ev, potential)
    type(ode_integrator_t), intent(inout) :: odeint            !< ODE integrator
    integer,                intent(in)    :: ode               !< equation to be integrated
    integer,                intent(in)    :: stepping_function !< stepping function used during integration
    integer,                intent(in)    :: nstepmax          !< maximum number of steps allowed during integration
    real(R8),               intent(in)    :: tol               !< tolerance given to the GSL control object
    type(qn_t),             intent(in)    :: qn                !< set of quantum numbers
    type(potential_t),      intent(in)    :: potential         !< potential object
    real(R8),               intent(in)    :: ev                !< energy

    call push_sub("ode_integrator_init")

    odeint%ode = ode
    !The polarized Dirac equation with two coupled sets has 8 functions
    odeint%dim = 2
    if (ode == ODE_DIRAC_POL2) odeint%dim = 8
    odeint%nstepmax = nstepmax
    odeint%qn = qn
    odeint%e = ev
    odeint%potential = potential

    !Initialize GSL objects. NOTE(review): ode_integrator_end only frees
    !them when nstepmax > 0, so nstepmax is expected to be positive here.
    call gsl_odeiv_step_alloc(stepping_function, odeint%dim, odeint%stp)
    call gsl_odeiv_evolve_alloc(odeint%dim, odeint%evl)
    call gsl_odeiv_control_standart_new(odeint%ctrl, 1.0E-22_r8, tol, M_ONE, M_ONE)

    call pop_sub()
  end subroutine ode_integrator_init

  !-----------------------------------------------------------------------
  !> Frees all the memory associated with the GSL objects used by the
  !> ODE solver and resets every component of the integrator.
  !-----------------------------------------------------------------------
  subroutine ode_integrator_end(odeint)
    type(ode_integrator_t), intent(inout) :: odeint !< integrator to finalize

    call push_sub("ode_integrator_end")

    !The GSL objects are only freed when nstepmax > 0. Both this routine and
    !ode_integrator_null set nstepmax to 0, so a second call frees nothing.
    !NOTE(review): ode_integrator_init allocates the objects even when
    !nstepmax <= 0; in that case they are not freed here.
    if (odeint%nstepmax > 0) then
      call gsl_odeiv_step_free(odeint%stp)
      call gsl_odeiv_evolve_free(odeint%evl)
      call gsl_odeiv_control_free(odeint%ctrl)
    end if

    !Clear the handles so that freed objects cannot be referenced afterwards
    odeint%evl = 0
    odeint%stp = 0
    odeint%ctrl = 0

    odeint%ode = 0
    odeint%dim = 0
    odeint%nstepmax = 0
    odeint%qn = QN_NULL
    odeint%e = M_ZERO
    call potential_end(odeint%potential)

    call pop_sub()
  end subroutine ode_integrator_end

  !-----------------------------------------------------------------------
  !> Integrates the radial wave-equation from r1 to r2.
  !>
  !> If r1 < r2 the integration is outward and starts from the boundary
  !> conditions close to the origin (ode_bc_origin). Otherwise it is inward
  !> and starts from the boundary conditions far away (ode_bc_infinity).
  !> On exit, r and f point to newly allocated arrays with nstep points,
  !> which the caller must deallocate. NOTE(review): any previous target of
  !> r or f is not deallocated here.
  !-----------------------------------------------------------------------
  subroutine ode_integration(odeint, r1, r2, nstep, r, f)
    type(ode_integrator_t), intent(in)  :: odeint !< ode integrator object
    real(R8),               intent(in)  :: r1     !< starting point for integration
    real(R8),               intent(in)  :: r2     !< final point for integration
    integer,                intent(out) :: nstep  !< number of steps taken by the solver
    real(R8),               pointer     :: r(:)   !< mesh used by the ODE solver
    real(R8),               pointer     :: f(:,:) !< functions

    integer  :: ierr
    real(R8), allocatable :: rmax(:), fmax(:,:)

    call push_sub("ode_integration")

    !Work arrays large enough for the maximum number of steps
    allocate(rmax(odeint%nstepmax), fmax(odeint%nstepmax, odeint%dim))

    !Set initial points for integration
    rmax(1) = r1
    if (r1 < r2) then
      !Outward integration
      call ode_bc_origin(odeint, rmax(1), fmax(1,:))
    else
      !Inward integration
      call ode_bc_infinity(odeint, rmax(1), fmax(1,:))
    end if

    !Integrate the equation from r1 to r2 (ode_derivatives gives the right-hand side)
    ierr = ode_solve(r2, odeint%stp, odeint%evl, odeint%ctrl, &
                     (r1+r2)/M_TWO, odeint%nstepmax, nstep, rmax, &
                     odeint%dim, fmax, odeint, ode_derivatives)

    if (ierr /= 0) then
      if (in_debug_mode) then
        call ode_integrator_debug(odeint, r1, r2, nstep, rmax, fmax)
      end if
      message(1) = "Error in subroutine ode_integration. Error message:"
      call gsl_strerror(ierr, message(2))
      call write_fatal(2)
    end if

    !Copy the steps actually taken to arrays of the right size
    allocate(r(nstep), f(nstep, odeint%dim))
    r(1:nstep) = rmax(1:nstep)
    f(1:nstep,1:odeint%dim) = fmax(1:nstep,1:odeint%dim)

    !Deallocate arrays
    deallocate(rmax, fmax)

    call pop_sub()
  end subroutine ode_integration

  !-----------------------------------------------------------------------
  !> Converts the ODE solution at one point r into wave-function values
  !> (wf) and their radial derivatives (wfp). The derivatives come from
  !> evaluating the right-hand side of the ODE system at r.
  !>
  !> For ODE_DIRAC and ODE_DIRAC_POL1 the two functions are stored in
  !> slots 1:2 (sg = 0 or -1/2) or 3:4 (sg = +1/2). For ODE_DIRAC_POL2 the
  !> two sets of four functions are summed.
  !-----------------------------------------------------------------------
  subroutine ode_function_to_wavefunction(odeint, r, f, wf, wfp)
    type(ode_integrator_t), intent(in)  :: odeint        !< ode integrator object
    real(R8),               intent(in)  :: r             !< r coordinate
    real(R8),               intent(in)  :: f(odeint%dim) !< ODE solution at r
    real(R8),               intent(out) :: wf(:)         !< wave-functions at r
    real(R8),               intent(out) :: wfp(:)        !< wave-functions derivatives at r

    real(R8) :: dfdr(odeint%dim)

    call push_sub("ode_function_to_wavefunction")

    call ode_derivatives(odeint, r, f, dfdr)

    select case (odeint%ode)
    case (ODE_SCHRODINGER, ODE_SCALAR_REL)
      !A single radial function
      wf(1)  = f(1)
      wfp(1) = dfdr(1)

    case (ODE_DIRAC, ODE_DIRAC_POL1)
      if (odeint%qn%sg == M_HALF) then
        !Functions go to the second pair of slots
        wf(1:2)  = M_ZERO
        wfp(1:2) = M_ZERO
        wf(3:4)  = f(1:2)
        wfp(3:4) = dfdr(1:2)
      else if (odeint%qn%sg == M_ZERO .or. odeint%qn%sg == -M_HALF) then
        !Functions go to the first pair of slots
        wf(1:2)  = f(1:2)
        wfp(1:2) = dfdr(1:2)
        if (odeint%qn%sg == -M_HALF) then
          wf(3:4)  = M_ZERO
          wfp(3:4) = M_ZERO
        end if
      end if

    case (ODE_DIRAC_POL2)
      !Sum of the two coupled sets
      wf(1:4)  = f(1:4) + f(5:8)
      wfp(1:4) = dfdr(1:4) + dfdr(5:8)

    case default
      message(1) = "Error in ode_function_to_wavefunction"
      call write_fatal(1)

    end select

    call pop_sub()
  end subroutine ode_function_to_wavefunction

  !-----------------------------------------------------------------------
  !> Given the outward and inward solutions to the ODE, match all
  !> functions minus one.
  !>
  !> The inward solutions are first flipped, if needed, so that their first
  !> function is positive at the matching point (the last point of f_in).
  !>
  !> For the two-function equations, the outward solution is then scaled
  !> so that its first function agrees with the inward one.
  !>
  !> For ODE_DIRAC_POL2, the outward sets f_out(:,1:4) and f_out(:,5:8)
  !> are scaled by a and b, and the inward set f_in(:,1:4) by c. The
  !> coefficients make functions 1, 3 and 4 of f(1:4) + f(5:8) continuous:
  !>     a*ap1_o + b*ap2_o = c*ap1_i + ap2_i
  !>     a*am1_o + b*am2_o = c*am1_i + am2_i
  !>     a*bm1_o + b*bm2_o = c*bm1_i + bm2_i
  !> Function 2 is left unmatched; it is the one ode_function_mismatch
  !> compares.
  !-----------------------------------------------------------------------
  subroutine ode_match_functions(odeint, nstep_out, nstep_in, f_out, f_in)
    type(ode_integrator_t), intent(in)    :: odeint                       !< ode integrator object
    integer,                intent(in)    :: nstep_out                    !< number of points of f_out
    integer,                intent(in)    :: nstep_in                     !< number of points of f_in
    real(R8),               intent(inout) :: f_out(nstep_out, odeint%dim) !< outward solutions
    real(R8),               intent(inout) :: f_in (nstep_in,  odeint%dim) !< inward solutions

    real(R8) :: a, b, c, den
    real(R8) :: ap1_o, ap2_o, ap1_i, ap2_i
    real(R8) :: am1_o, am2_o, am1_i, am2_i
    real(R8) :: bm1_o, bm2_o, bm1_i, bm2_i

    call push_sub("ode_match_functions")

    !Make sure the inward wavefunctions are positive
    if (f_in(nstep_in, 1) < M_ZERO) f_in = -f_in

    select case (odeint%ode)
    case (ODE_SCHRODINGER, ODE_SCALAR_REL, ODE_DIRAC, ODE_DIRAC_POL1)
      !Scale the outward solution so that the first function is continuous
      f_out = f_out*f_in(nstep_in,1)/f_out(nstep_out,1)

    case (ODE_DIRAC_POL2)
      !Values at the matching point. Functions 2 and 6 do not enter the
      !system, since function 2 of the sum is the one left unmatched.
      ap1_o = f_out(nstep_out, 1) ; ap1_i = f_in(nstep_in, 1)
      am1_o = f_out(nstep_out, 3) ; am1_i = f_in(nstep_in, 3)
      bm1_o = f_out(nstep_out, 4) ; bm1_i = f_in(nstep_in, 4)
      ap2_o = f_out(nstep_out, 5) ; ap2_i = f_in(nstep_in, 5)
      am2_o = f_out(nstep_out, 7) ; am2_i = f_in(nstep_in, 7)
      bm2_o = f_out(nstep_out, 8) ; bm2_i = f_in(nstep_in, 8)

      !Cramer's rule: den is the determinant of the 3x3 system
      den = (am1_i*ap1_o - ap1_i*am1_o)*bm2_o + (ap1_i*am2_o - am1_i*ap2_o)*bm1_o + &
            (am1_o*ap2_o - am2_o*ap1_o)*bm1_i
      a = ((am1_i*ap2_i - ap1_i*am2_i)*bm2_o + (ap1_i*am2_o - am1_i*ap2_o)*bm2_i + &
           (am2_i*ap2_o - am2_o*ap2_i)*bm1_i)/den
      b = ((am1_i*ap1_o - ap1_i*am1_o)*bm2_i + (ap1_i*am2_i - am1_i*ap2_i)*bm1_o + &
           (am1_o*ap2_i - am2_i*ap1_o)*bm1_i)/den
      c = ((am1_o*ap2_i - am2_i*ap1_o)*bm2_o + (am2_o*ap1_o - am1_o*ap2_o)*bm2_i + &
           (am2_i*ap2_o - am2_o*ap2_i)*bm1_o)/den

      f_out(1:nstep_out,1:4) = a*f_out(1:nstep_out,1:4)
      f_out(1:nstep_out,5:8) = b*f_out(1:nstep_out,5:8)
      f_in(1:nstep_in,1:4)   = c*f_in(1:nstep_in,1:4)

    case default
      message(1) = "Error in ode_match_functions"
      call write_fatal(1)

    end select

    call pop_sub()
  end subroutine ode_match_functions

  !-----------------------------------------------------------------------
  !> Returns the mismatch at r between the outward and inward solutions,
  !> assumed already matched by ode_match_functions.
  !>
  !> For the two-function equations this is the difference of the
  !> logarithmic derivatives of the first function, taken from the ODE
  !> right-hand side. For ODE_DIRAC_POL2 it is the difference of the
  !> second component of f(1:4) + f(5:8), i.e. of f(2) + f(6).
  !-----------------------------------------------------------------------
  function ode_function_mismatch(odeint, r, f_out, f_in)
    type(ode_integrator_t), intent(in) :: odeint            !< ode integrator object
    real(R8),               intent(in) :: r                 !< coordinate where to evaluate the mismatch
    real(R8),               intent(in) :: f_out(odeint%dim) !< outward solutions
    real(R8),               intent(in) :: f_in(odeint%dim)  !< inward solutions
    real(R8) :: ode_function_mismatch

    real(R8) :: dfdr_out(odeint%dim), dfdr_in(odeint%dim)

    call push_sub("ode_function_mismatch")

    if (odeint%ode == ODE_DIRAC_POL2) then
      !Second wave-function component, built from both sets
      ode_function_mismatch = (f_out(2) + f_out(6)) - (f_in(2) + f_in(6))

    else if (odeint%ode == ODE_SCHRODINGER .or. odeint%ode == ODE_SCALAR_REL .or. &
             odeint%ode == ODE_DIRAC       .or. odeint%ode == ODE_DIRAC_POL1) then
      !Difference of the logarithmic derivatives of the first function
      call ode_derivatives(odeint, r, f_out, dfdr_out)
      call ode_derivatives(odeint, r, f_in,  dfdr_in)
      ode_function_mismatch = dfdr_out(1)/f_out(1) - dfdr_in(1)/f_in(1)

    else
      message(1) = "Error in ode_function_mismatch"
      call write_fatal(1)

    end if

    call pop_sub()
  end function ode_function_mismatch

  !-----------------------------------------------------------------------
  !> Set the value of the functions at a point close to the origin, using
  !> the series expansion of the solutions for r -> 0.
  !-----------------------------------------------------------------------
  subroutine ode_bc_origin(odeint, r0, f)
    type(ode_integrator_t), intent(in)  :: odeint        !< integrator object
    real(R8),               intent(in)  :: r0            !< point close to origin
    real(R8),               intent(out) :: f(odeint%dim) !< function at r0

    real(R8) :: e, z, vp0, s, a0, a1, a2, a3, b0, b1, b2, w
    real(R8) :: a01, a12, b01, d02, d11, bxc_0, clm, c02, b12, c11
    type(qn_t) :: qn

    call push_sub("ode_bc_origin")

    ! For a potential v(r) = vp(r) - Z/r, the general solutions of the equations
    ! when r->0 are of the form:
    !
    !   g(r) = r**(s-1) (a0 + a1 r + a2 r**2 + ...)
    !   f(r) = r**(s-1) (b0 + b1 r + b2 r**2 + ...)
    !
    ! Schrodinger equation:
    !   s = l
    !   a0 = 0
    !   a2 = - Z/(l + 1) a1
    !   a3 = [-Z a2 + (vp - e) a1]/(2l + 3)
    !   b0 = l a1
    !   b1 = (l + 1) a2
    !   b2 = (l + 2) a3
    !
    ! Scalar-relativistic equation, with w = e + 2c**2 - vp:
    !   s  = sqrt(l(l + 1) + 1 - (Z/c)**2)
    !   a1 = [w b0 - 2 Z/c (e + c**2 - vp) a0]/[(2s + 1)c]
    !   a2 = [w (vp - e)/c a0 + w b1 - 2 Z/c (e + c**2 - vp) a1]/[2(2s + 2)c]
    !   b0 = c/Z (s - 1) a0
    !   b1 = c/Z [s a1 - w/c b0]
    !   b2 = c/Z [(s + 1) a2 - w/c b1]
    !   For l = 0 and c -> infinity these give a1 = -Z a0, the
    !   non-relativistic nuclear cusp.
    !   When Z = 0:
    !     l = 0: s = 1, a0 = 1, b1 = (vp - e)/(3c) a0, a2 = w/(2c) b1
    !     l > 0: s = l, b0 = 1, a1 = w/(l c) b0
    !
    ! Dirac equation:
    !   s  = sqrt(k**2 - (Z/c)**2)
    !   a0 = Z/c b0/(s + k)  v  a0 = - c/Z (s - k) b0
    !   a1 = [(s + 1 - k)(e + 2c**2 - vp) b0 - Z/c (e - vp) a0]/(2s + 1)/c
    !   a2 = [(s + 2 - k)(e + 2c**2 - vp) b1 - Z/c (e - vp) a1]/(2s + 2)/(2c)
    !   b0 = -Z/c a0/(s - k)  v  b0 = c/Z (s + k) a0
    !   b1 = [-(e - vp)/c a0 - Z/c a1]/(s + 1 - k)
    !   b2 = [-(e - vp)/c a1 - Z/c a2]/(s + 2 - k)
    !
    ! Polarized Spin Dirac equation:
    !  g1(r) = r^(s1-1) a01 + r^s2 b11
    !  f1(r) = r^(s1-1) c01 + r^s2 d11
    !  g2(r) = r^s1 a12 + r^(s2-1) b02
    !  f2(r) = r^s1 c12 + r^(s2-1) d02
    !
    !  s1  = sqrt( (l + 1)**2 - (z/c)**2 ) - 1
    !  s2  = sqrt( l**2 - (z/c)**2 ) - 1
    !  clm = sqrt( (2l + 1)**2 - (2m)**2 ) / (2l + 1)
    !  c01 = - Z/c a01 / (s1+l+1)
    !  a12 = + Z/c a01/c B0/2 Clm / (l+s1+1)
    !  c12 = + B0/2 Clm/c a01
    !  b02 = Z/c d02/(s2+l)
    !  b11 = - B0/2 Clm/c d02
    !  d11 = Z/c d02/c B0/2 Clm / (l+s2)
    !
    ! NOTE: for the scalar-relativistic equation and for Dirac
    !       equation Z=0 is a special case.

    z = potential_nuclear_charge(odeint%potential)
    vp0 = v(odeint%potential, r0, odeint%qn) + z/r0
    e = odeint%e
    qn = odeint%qn

    select case (odeint%ode)
    case (ODE_SCHRODINGER)
      s = qn%l
      a0 = M_ZERO
      a1 = M_ONE
      a2 = -z/(s + M_ONE)*a1
      a3 = (-z*a2 + (vp0 - e)*a1)/(M_TWO*s + M_THREE)
      b0 = s*a1
      b1 = (s + 1)*a2
      b2 = (s + 2)*a3

    case (ODE_SCALAR_REL)
      w = e + M_TWO*M_C**2 - vp0
      if (z /= M_ZERO) then
        s = sqrt(qn%l*(qn%l + M_ONE) + M_ONE - (z/M_C)**2)
        a0 = M_ONE
        b0 = M_C/z*(s - M_ONE)*a0
        !The (e + c**2 - vp) factor comes from expanding 1/m = 2c**2 r/(Z + w r)
        a1 = (w*b0 - z/M_C*M_TWO*(e + M_C**2 - vp0)*a0)/(M_TWO*s + M_ONE)/M_C
        b1 = M_C/z*(s*a1 - w*b0/M_C)
        a2 = (w*(vp0 - e)*a0/M_C + w*b1 - z/M_C*M_TWO*(e + M_C**2 - vp0)*a1)/(M_FOUR*s + M_FOUR)/M_C
        b2 = M_C/z*((s + M_ONE)*a2 - w*b1/M_C)
        a3 = M_ZERO
      else
        if (qn%l == 0) then
          s = M_ONE
          a0 = M_ONE
          !df/dr = (v - e)/c g - 2f/r at lowest order gives 3 b1 = (vp - e) a0/c
          b1 = (vp0 - e)*a0/(M_THREE*M_C)
          a2 = w/(M_TWO*M_C)*b1
          b0 = M_ZERO ; a1 = M_ZERO ; b2 = M_ZERO ; a3 = M_ZERO
        else
          s = qn%l
          b0 = M_ONE
          a1 = w/(s*M_C)*b0
          a0 = M_ZERO ; b1 = M_ZERO ; a2 = M_ZERO ; b2 = M_ZERO ; a3 = M_ZERO
        end if
      end if

    case (ODE_DIRAC)
      if (z /= M_ZERO) then
        a0 = M_ONE
        s = sqrt(qn%k**2 - (z/M_C)**2)
        b0 = M_C/z*(s + qn%k)*a0
      else
        if (qn%k < 0) then
          s = -qn%k
          a0 = M_ONE
          b0 = M_ZERO
        else
          s = qn%k
          a0 = M_ZERO
          b0 = M_ONE
        end if
      end if
      w = e + M_TWO*M_C**2 - vp0
      a1 = (w*(s + M_ONE - qn%k)*b0 - z/M_C*(e - vp0)*a0)/(M_TWO*s + M_ONE)/M_C
      b1 = -((e - vp0)*a0 + z*a1)/(s + M_ONE - qn%k)/M_C
      a2 = (w*(s + M_TWO - qn%k)*b1 - z/M_C*(e - vp0)*a1)/(M_TWO*s + M_TWO)/(M_TWO*M_C)
      b2 = -((e - vp0)*a1 + z*a2)/(s + M_TWO - qn%k)/M_C
      a3 = M_ZERO

    case (ODE_DIRAC_POL1)
      s  = sqrt( (qn%l + M_ONE)**2 - (z/M_C)**2 )
      a0 = M_ONE
      b0 = -z/M_C*a0/(s+qn%l+M_ONE)
      a1 = M_ZERO
      b1 = M_ZERO
      a2 = M_ZERO
      b2 = M_ZERO
      a3 = M_ZERO

    case (ODE_DIRAC_POL2)
      clm = -sqrt( (M_TWO*qn%l + M_ONE)**2 - M_FOUR*qn%m**2  )/(M_TWO*qn%l + M_ONE)
      bxc_0 = bxc(odeint%potential, r0)

      !Set 1
      s  = sqrt( (qn%l + M_ONE)**2 - (z/M_C)**2 )
      a01 = M_ONE
      b01 = -z/M_C*a01/(s+qn%l+M_ONE)
      a12 = z/M_C*a01/M_C*bxc_0/M_TWO*clm/(qn%l + s + M_ONE)
      b12 = bxc_0/M_TWO*clm/M_C*a01

      f(1) = a01*r0**(s - M_ONE)
      f(2) = b01*r0**(s - M_ONE)
      f(3) = a12*r0**s
      f(4) = b12*r0**s

      !Set 2
      s  = sqrt( qn%l**2 - (z/M_C)**2 )
      d02 = M_ONE
      c02 = + z/M_C*d02/(s + qn%l)
      c11 = - bxc_0/M_TWO*clm/M_C*d02
      d11 = + z/M_C*d02/M_C*bxc_0/M_TWO*clm/(qn%l + s)

      f(5) = c11*r0**s
      f(6) = d11*r0**s
      f(7) = c02*r0**(s - M_ONE)
      f(8) = d02*r0**(s - M_ONE)

    case default
      message(1) = "Error in ode_bc_origin"
      call write_fatal(1)

    end select

    if (odeint%ode /= ODE_DIRAC_POL2) then
      f(1) = r0**(s - M_ONE)*(a0 + a1*r0 + a2*r0**2 + a3*r0**3)
      f(2) = r0**(s - M_ONE)*(b0 + b1*r0 + b2*r0**2)
    end if

    ! MGGA term: f(2) is (1 + 2 v_tau) dR/dr for the Schrodinger equation
    if (odeint%ode == ODE_SCHRODINGER) then
      f(2) = f(2)*(M_ONE + M_TWO*vtau(odeint%potential, r0, qn))
    end if

    call pop_sub()
  end subroutine ode_bc_origin

  !-----------------------------------------------------------------------
  !> Set the value of the functions at a point far away from the nucleus.
  !>
  !> The first function is set to ODE_FINF. The remaining ones follow from
  !> the asymptotic form of the equations at the integrator energy. For
  !> ODE_DIRAC_POL2 the second set is a permutation of the first:
  !> f(5:8) = (f(3), f(4), f(1), f(2)).
  !-----------------------------------------------------------------------
  subroutine ode_bc_infinity(odeint, rinf, f)
    type(ode_integrator_t), intent(in)  :: odeint        !< integrator object
    real(R8),               intent(in)  :: rinf          !< point far away
    real(R8),               intent(out) :: f(odeint%dim) !< function at rinf

    real(R8) :: energy, v_far, m_far, clm, decay, bxc_int, bxc_far
    type(qn_t) :: qn

    call push_sub("ode_bc_infinity")

    energy = odeint%e
    qn = odeint%qn

    select case (odeint%ode)
    case (ODE_SCHRODINGER)
      f(1) = ODE_FINF
      f(2) = -sqrt(-M_TWO*energy)*f(1)*(M_ONE + M_TWO*vtau(odeint%potential, rinf, qn))

    case (ODE_SCALAR_REL)
      f(1) = ODE_FINF
      v_far = v(odeint%potential, rinf, qn)
      m_far = M_ONE + (energy - v_far)/M_TWO/M_C2
      f(2) = -sqrt(-M_TWO*m_far*energy)/(M_TWO*m_far*M_C)*f(1)

    case (ODE_DIRAC, ODE_DIRAC_POL1)
      f(1) = ODE_FINF
      if (energy == M_ZERO) then
        f(2) = M_ZERO
      else
        f(2) = energy/(M_C*sqrt(-energy)*sqrt(M_TWO + energy/M_C2))*f(1)
      end if

    case (ODE_DIRAC_POL2)
      !First set
      f(1) = ODE_FINF
      if (energy == M_ZERO) then
        f(2:4) = M_ZERO
      else
        clm = -sqrt( (M_TWO*qn%l + M_ONE)**2 - M_FOUR*qn%m**2  )/(M_TWO*qn%l + M_ONE)
        bxc_int = -bxc_integral(odeint%potential, ra=rinf)
        bxc_far = bxc(odeint%potential, rinf)
        decay = sqrt(-energy)*sqrt(M_TWO + energy/M_C2)

        f(2) = energy/(M_C*decay)*f(1)
        f(3) = decay*clm*bxc_int/(M_TWO*energy)*f(1)
        f(4) = clm/(M_TWO*M_C)*(bxc_int - bxc_far/decay)*f(1)
      end if

      !Second set: same values with the two pairs swapped
      f(5:6) = f(3:4)
      f(7:8) = f(1:2)

    case default
      message(1) = "Error in ode_bc_infinity"
      call write_fatal(1)

    end select

    call pop_sub()
  end subroutine ode_bc_infinity

  !-----------------------------------------------------------------------
  !> Returns the value of the practical infinity. The practical infinity
  !> is such that the function at the practical infinity is of the order
  !> of magnitude of ODE_FINF.
  !>
  !> The energy is capped at -1.0e-12 so that the decay constant never
  !> vanishes. For the Schrodinger and scalar-relativistic equations a pure
  !> exponential decay is assumed. For the Dirac equations, the root of
  !>     h log(r) - k r - log(ODE_FINF)
  !> is located by bisection, starting at ri, with h built from r*v(r).
  !> The result is always larger than ri.
  !-----------------------------------------------------------------------
  function ode_practical_infinity(odeint, ri, tol)
    type(ode_integrator_t), intent(in) :: odeint !< integrator object
    real(R8),               intent(in) :: ri     !< classical turning point; the search starts here
    real(R8),               intent(in) :: tol    !< tolerance
    real(R8) :: ode_practical_infinity

    real(R8) :: vinf, k, x, f, g, xm, d, h, e_max, log_finf

    call push_sub("ode_practical_infinity")

    e_max = min(odeint%e, -1.0e-12_r8)
    log_finf = log(ODE_FINF)

    select case (odeint%ode)
    case (ODE_SCHRODINGER)
      ode_practical_infinity = -log_finf/sqrt(-M_TWO*e_max)

    case (ODE_SCALAR_REL)
      ode_practical_infinity = -log_finf/sqrt(-M_TWO*(M_ONE + e_max/M_TWO/M_C**2)*e_max)

    case (ODE_DIRAC, ODE_DIRAC_POL1, ODE_DIRAC_POL2)
      k = sqrt(-e_max)*sqrt(M_TWO + e_max/M_C2)
      x = ri
      vinf = v(odeint%potential, x, odeint%qn)*x
      h = -vinf*(e_max + M_C2)/(k*M_C2) - M_ONE
      !Value of the target function at the left end of the bracket
      f = h*log(x) - k*x - log_finf
      d = -log_finf/k - x
      do
        d = d*M_HALF
        xm = x + d
        vinf = v(odeint%potential, xm, odeint%qn)*xm
        h = -vinf*(e_max + M_C2)/(k*M_C2) - M_ONE
        g = h*log(xm) - k*xm - log_finf
        !Move the left end while the target keeps the sign it has at ri
        if (f*g > M_ZERO) x = xm
        if (g == M_ZERO .or. d <= tol) exit
      end do
      ode_practical_infinity = xm

    case default
      message(1) = "Error in ode_practical_infinity"
      call write_fatal(1)

    end select

    !Make sure the practical infinity is always larger than the classical turning point
    if (ode_practical_infinity <= ri) ode_practical_infinity = ri + tol

    call pop_sub()
  end function ode_practical_infinity

  !-----------------------------------------------------------------------
  !> This subroutine is supposed to be called from the GSL ODE solver. It
  !> basically defines the system of equations to be solved by returning
  !> the values of the derivatives of the functions.
  !>
  !> The system of differential equations should be written in the
  !> following way:
  !>
  !>  \f$\frac{d y_i}{d r} = f_i(r, y_1(r), \dots, y_n(r))\f$
  !>
  !> with n = odeint%dim.
  !-----------------------------------------------------------------------
  subroutine ode_derivatives(odeint, r, y, f)
    type(ode_integrator_t), intent(in)  :: odeint        !< integrator object (equation, quantum numbers, energy, potential)
    real(R8),               intent(in)  :: r             !< radial coordinate
    real(R8),               intent(in)  :: y(odeint%dim) !< values of the functions at r
    real(R8),               intent(out) :: f(odeint%dim) !< derivatives of the functions at r

    real(R8) :: k, vr_m_e, m, mi, ci, c2i, ri, r2i, optvtau, llp1
    real(R8) :: clm, bxc_r, two_l, two_m
    real(R8) :: c11, c12, c33, c34, c21, c23, c22, c41, c43, c44

    vr_m_e = v(odeint%potential, r, odeint%qn) - odeint%e
    ri = M_ONE/r

    select case (odeint%ode)
    case (ODE_SCHRODINGER)
      ! Schrodinger equation
      !
      !                               d R
      ! y_1 = R     y_2 = (1 + 2 v_t) ---
      !                               d r
      !
      !          y_2
      ! f_1 = ---------
      !       1 + 2 v_t
      !
      !         2                   l(l+1)
      ! f_2 = - - y_2 + [(1 + 2 v_t)------ + 2(v_ks - e)] y_1 - 2 n(r)
      !         r                    r^2
      !
      ! where n(r) is the value returned by n(potential, r, qn).
      optvtau = M_ONE + M_TWO*vtau(odeint%potential, r, odeint%qn)
      llp1 = real(odeint%qn%l, R8)*(real(odeint%qn%l, R8) + M_ONE)
      r2i = ri*ri

      f(1) = y(2)/optvtau
      f(2) = (M_TWO*vr_m_e + optvtau*llp1*r2i)*y(1) - M_TWO*ri*y(2) - M_TWO*n(odeint%potential, r, odeint%qn)

    case (ODE_SCALAR_REL)
      ! Scalar-relativistic equation
      !
      ! m = 1 - (v - e)/(2c^2)
      ! f_1 = 2 m c y_2
      ! f_2 = [l(l+1)/(2 m r^2) + (v - e)]/c y_1 - 2/r y_2
      !
      ! Differentiating f_1 gives a (dm/dr)/m dy_1/dr term, so the dv/dr
      ! contribution is implicit and dv/dr is not evaluated here.
      r2i = ri*ri
      ci = M_ONE/M_C
      c2i = ci*ci
      m = M_ONE - vr_m_e*M_HALF*c2i
      mi = M_ONE/m
      llp1 = real(odeint%qn%l, R8)*(real(odeint%qn%l, R8) + M_ONE)

      f(1) = M_TWO*m*M_C*y(2)
      f(2) = (llp1*M_HALF*mi*r2i + vr_m_e)*ci*y(1) - M_TWO*ri*y(2)

    case (ODE_DIRAC)
      ! Spin-unpolarized Dirac equation
      !
      ! f_1 = -(k + 1)/r y_1 + [2c^2 - (v - e)]/c y_2
      ! f_2 = (v - e)/c y_1 + (k - 1)/r y_2
      k = real(odeint%qn%k, R8)

      ci = M_ONE/M_C
      f(1) = -(k + M_ONE)*ri*y(1) + (M_TWO*M_C2 - vr_m_e)*ci*y(2)
      f(2) = vr_m_e*ci*y(1) + (k - M_ONE)*ri*y(2)

    case (ODE_DIRAC_POL1)
      ! Spin-polarized Dirac equation for m == 2l+1
      ci = M_ONE/M_C
      bxc_r = bxc(odeint%potential, r)

      c11 = odeint%qn%l*ri
      c12 = ci*(M_TWO*M_C2 - (vr_m_e + M_TWO*odeint%qn%m/(M_TWO*odeint%qn%l + M_THREE)*bxc_r))
      c21 = ci*(vr_m_e + M_TWO*odeint%qn%m/(M_TWO*odeint%qn%l + M_ONE)*bxc_r)
      c22 = -(odeint%qn%l + M_TWO)*ri

      f(1) = c11*y(1) + c12*y(2)
      f(2) = c21*y(1) + c22*y(2)

    case (ODE_DIRAC_POL2)
      ! Spin-polarized Dirac equation for m /= 2l+1. The two sets y(1:4)
      ! and y(5:8) obey the same linear system.
      ci = M_ONE/M_C
      two_l = M_TWO*odeint%qn%l
      two_m = M_TWO*odeint%qn%m
      bxc_r = bxc(odeint%potential, r)
      clm = -sqrt( (two_l + M_ONE)**2 - two_m**2  )/(two_l + M_ONE)

      c11 = odeint%qn%l*ri
      c12 = ci*(M_TWO*M_C2 - (vr_m_e + two_m/(two_l + M_THREE)*bxc_r))
      c21 = ci*(vr_m_e + two_m/(two_l + M_ONE)*bxc_r)
      c22 = -(odeint%qn%l + M_TWO)*ri
      c23 = ci*clm*bxc_r
      c33 = -(odeint%qn%l + M_ONE)*ri
      c34 = ci*(M_TWO*M_C2 - (vr_m_e - two_m/(two_l - M_ONE) *bxc_r))
      c41 = c23
      c43 = ci*(vr_m_e - two_m/(two_l + M_ONE)*bxc_r)
      c44 = (odeint%qn%l - M_ONE)*ri

      f(1) = c11*y(1) + c12*y(2)
      f(2) = c21*y(1) + c22*y(2) + c23*y(3)
      f(3) = c33*y(3) + c34*y(4)
      f(4) = c41*y(1) + c43*y(3) + c44*y(4)
      f(5) = c11*y(5) + c12*y(6)
      f(6) = c21*y(5) + c22*y(6) + c23*y(7)
      f(7) = c33*y(7) + c34*y(8)
      f(8) = c41*y(5) + c43*y(7) + c44*y(8)

    case default
      !Unknown equation: f would otherwise be left undefined
      message(1) = "Error in ode_derivatives"
      call write_fatal(1)

    end select

  end subroutine ode_derivatives

  !-----------------------------------------------------------------------
  !> Prints debug information. The integrator parameters are written to
  !> the "debug_info/ode_integrator" file. The first nstep points of the
  !> mesh and functions are written to the "debug_info/ode_functions" file.
  !-----------------------------------------------------------------------
  subroutine ode_integrator_debug(odeint, ri, rf, nstep, r, f)
    type(ode_integrator_t), intent(in) :: odeint                         !< integrator object
    real(R8),               intent(in) :: ri                             !< initial integration point
    real(R8),               intent(in) :: rf                             !< final integration point
    integer,                intent(in) :: nstep                          !< number of steps taken by the solver
    real(R8),               intent(in) :: r(odeint%nstepmax)             !< mesh used by the ODE solver
    real(R8),               intent(in) :: f(odeint%nstepmax, odeint%dim) !< integrated function

    character(len=80) :: fmt
    integer           :: unit, i

    call push_sub("ode_integrator_debug")

    call io_open(unit, file='debug_info/ode_integrator')
    write(unit,'("ODE Integrator maximum number of steps: ",I6)') odeint%nstepmax
    write(unit,'("Quantum numbers:")')
    write(unit,'("  l =  ",I2)') odeint%qn%l
    write(unit,'("  k =  ",I2)') odeint%qn%k
    write(unit,'("  s =  ",F4.1)') odeint%qn%s
    write(unit,'("  m =  ",F4.1)') odeint%qn%m
    write(unit,'("  sg =  ",F4.1)') odeint%qn%sg
    write(unit,'("Energy:",F12.5)') odeint%e
    close(unit)

    call io_open(unit, file='debug_info/ode_functions')
    write(unit,'("# Integration Starting Point: ",ES10.3E2)') ri
    write(unit,'("# Integration Ending Point:   ",ES10.3E2)') rf
    !I0 writes the repeat count with the minimal width, so any dim fits
    write(fmt,'(A,I0,A)') "(ES15.8E2,1X,", odeint%dim,"(ES15.8E2,1X))"
    do i = 1, nstep
      write(unit,fmt) r(i), f(i,1:odeint%dim)
    end do
    close(unit)

    call pop_sub()
  end subroutine ode_integrator_debug

end module ode_integrator_m
