// This file is part of the FXT library.
// Copyright (C) 2011, 2012, 2014, 2023, 2024 Joerg Arndt
// License: GNU General Public License version 3 or later,
// see the file COPYING.txt in the main directory.

#include "walsh/square-wave-transform.h"

#include "aux0/sumdiff.h"  // diffsum(), sumdiff_r()
#include "fxttypes.h"

#include <cmath>  // sqrt()

//#include "jjassert.h"

void
square_wave_basis(double *f, ulong n, ulong k)
// Compute the k-th basis vector (0 <= k < n) for the square wave transform
// into f[0..n-1]; n is expected to be a power of two.
// The vectors are linearly independent but not mutually orthogonal.
//   k <  n/2:  +1 in f[0..k], -1 in the next n/2 elements, +1 in the rest.
//   k == n-1:  all +1.
//   otherwise: both halves equal basis vector (k - n/2) of length n/2.
// If k >= n there is no such basis vector and f[] is left unchanged
// (previously this case recursed without bound).
{
    if ( k >= n )  return;  // invalid index; also stops the recursion for n==0

    const ulong nh = n / 2;
    if ( k < nh )
    {
        const ulong w = k + 1;  // length of the leading run of +1
        ulong j = 0;
        while ( j < w )       { f[j] = +1.0;  ++j; }
        while ( j < w + nh )  { f[j] = -1.0;  ++j; }
        while ( j < n )       { f[j] = +1.0;  ++j; }
        return;
    }

    if ( k == n - 1 )
    {
        for (ulong j=0; j<n; ++j)  { f[j] = +1.0; }
        return;
    }

    // nh <= k < n-1, hence k-nh < nh: the recursive call is always valid.
    // Both halves hold the same vector, so compute the left one and copy it.
    square_wave_basis( f, nh, k - nh );
    for (ulong j=0; j<nh; ++j)  { f[nh+j] = f[j]; }
}
// -------------------------


void
swt_normalize(double *f, ulong n)
// Normalization for the square wave transform (used when nq==true).
// n < 2:  nothing is done.
// n == 2: both elements are multiplied by sqrt(1/2).
// n >= 4 (a power of two): from the left, the block of length n/2 is
// scaled by sqrt(1/4), the next block of length n/4 by sqrt(1/8), the
// next of length n/8 by sqrt(1/16), and so on down to blocks of length 4.
// The rightmost 4 elements are scaled by sqrt(1/n).
{
    constexpr double r2 = 0.70710678118654752440084436210484903928;  // sqrt(1/2)

    if ( n < 2 )  return;

    if ( n == 2 )
    {
        f[0] *= r2;
        f[1] *= r2;
        return;
    }

    double scale = 0.5;  // sqrt(1/4), for the left half
    ulong len = n / 2;
    while ( len >= 4 )
    {
        // For n a power of two the block starts right after the previous one.
        double *blk = f + (n - 2 * len);
        for (ulong i = 0; i < len; ++i)  { blk[i] *= scale; }
        scale *= r2;  // each further block gets one more factor sqrt(1/2)
        len /= 2;
    }

    // rightmost 4 elements: here scale == sqrt(1/n)
    double *tail = f + (n - 4);
    tail[0] *= scale;
    tail[1] *= scale;
    tail[2] *= scale;
    tail[3] *= scale;
}
// -------------------------


void
square_wave_transform(double *f, ulong ldn, bool nq/*=true*/)
// Square wave transform (in place) of f[0..n-1] where n = 2**ldn.
// The transform of a delta-pulse has norm 1 (if nq==true).
// inverse_square_wave_transform() is the inverse only if nq==true
// in both transforms.
// Algorithm is O(n) where n is the length of the vector f[].
//
// Reference:
//  John Pender, David Covey:
//  New square wave transform for digital signal processing,
//  IEEE Transactions on Signal Processing, vol.40, no.8, pp.2095-2097,
//  (August-1992).
//
// With nq==false the transform of a delta pulse at position k is row k
// of the following matrix T ('+':=+1, '-':=-1, ' ':=0), i.e., the
// transform maps the row vector f to f*T:
//
//   0: [ +                             + +             + +     + + + + + ]
//   1: [ - +                             - +             - +     - + - + ]
//   2: [   - +                             - +             - +   - - + + ]
//   3: [     - +                             - +             - + + - - + ]
//   4: [       - +                             - +       -     - + + + + ]
//   5: [         - +                             - +     + -     - + - + ]
//   6: [           - +                             - +     + -   - - + + ]
//   7: [             - +                             - +     + - + - - + ]
//   8: [               - +               -             - +     + + + + + ]
//   9: [                 - +             + -             - +     - + - + ]
//  10: [                   - +             + -             - +   - - + + ]
//  11: [                     - +             + -             - + + - - + ]
//  12: [                       - +             + -       -     - + + + + ]
//  13: [                         - +             + -     + -     - + - + ]
//  14: [                           - +             + -     + -   - - + + ]
//  15: [                             - +             + -     + - + - - + ]
//  16: [ -                             - +             + +     + + + + + ]
//  17: [ + -                             - +             - +     - + - + ]
//  18: [   + -                             - +             - +   - - + + ]
//  19: [     + -                             - +             - + + - - + ]
//  20: [       + -                             - +       -     - + + + + ]
//  21: [         + -                             - +     + -     - + - + ]
//  22: [           + -                             - +     + -   - - + + ]
//  23: [             + -                             - +     + - + - - + ]
//  24: [               + -               -             - +     + + + + + ]
//  25: [                 + -             + -             - +     - + - + ]
//  26: [                   + -             + -             - +   - - + + ]
//  27: [                     + -             + -             - + + - - + ]
//  28: [                       + -             + -       -     - + + + + ]
//  29: [                         + -             + -     + -     - + - + ]
//  30: [                           + -             + -     + -   - - + + ]
//  31: [                             + -             + -     + - + - - + ]
// If nq==false then the norm of all row vectors is sqrt(2*ldn).
// If nq==true  then the norm of all row vectors is 1.
// Note: the square wave basis functions (see square_wave_basis()) are
// the inverse transforms of delta pulses, not the rows above.
{
    const ulong n = 1UL << ldn;
    if ( n <= 2 )
    {
        // handled separately: the general code below accesses f[n-4]
        if ( n == 2 )
        {
            diffsum( f[n-2], f[n-1] );
            if ( nq )  swt_normalize(f, n);
        }
        return;
    }

    // first pass: diffsum of halves
    ulong nh = n / 2;
    for (ulong j=0; j < nh; ++j)  { diffsum( f[j], f[j+nh] ); }

    ulong off = 0;
    // remaining passes: cyclic diff left, diffsum right
    // (each pass handles the block f[off..off+2*nh-1], then moves right)
    while ( nh >= 4 )
    {
        // cyclic differences left:
        const double t = f[off];  // will be added to last
        for (ulong j=off; j < off+nh-1; ++j)  { f[j] -= f[j+1]; }
        f[off+nh-1] += t;

        // diffsum right:
        const ulong n4 = nh / 2;
        for (ulong j=off+nh; j < off+nh+n4; ++j)  { diffsum( f[j], f[j+n4] ); }

        off += nh;
        nh /= 2;
    }

    // rightmost 4 elements:
    diffsum( f[n-4], f[n-3] );
    diffsum( f[n-2], f[n-1] );

    // normalization at end:
    if ( nq )  swt_normalize(f, n);
}
// -------------------------


void
inverse_square_wave_transform(double *f, ulong ldn, bool nq/*=true*/)
// Inverse square wave transform (in place) of f[0..n-1] where n = 2**ldn.
// Inverse of the square wave transform only if nq==true in both transforms.
// Algorithm is O(n) where n is the length of the vector f[].
//
// Reference:
//  John Pender, David Covey:
//  New square wave transform for digital signal processing,
//  IEEE Transactions on Signal Processing, vol.40, no.8, pp.2095-2097,
//  (August-1992).
//
// With nq==false the inverse transform of a delta pulse at position k
// is row k of the following matrix B ('+':=+1, '-':=-1), i.e., the result
// is the row vector f*B = sum_k f[k]*B[k].
// The rows are the square wave basis functions (see square_wave_basis());
// they are linearly independent but not mutually orthogonal.
//
//   0: [ + - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + ]
//   1: [ + + - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + ]
//   2: [ + + + - - - - - - - - - - - - - - - - + + + + + + + + + + + + + ]
//   3: [ + + + + - - - - - - - - - - - - - - - - + + + + + + + + + + + + ]
//   4: [ + + + + + - - - - - - - - - - - - - - - - + + + + + + + + + + + ]
//   5: [ + + + + + + - - - - - - - - - - - - - - - - + + + + + + + + + + ]
//   6: [ + + + + + + + - - - - - - - - - - - - - - - - + + + + + + + + + ]
//   7: [ + + + + + + + + - - - - - - - - - - - - - - - - + + + + + + + + ]
//   8: [ + + + + + + + + + - - - - - - - - - - - - - - - - + + + + + + + ]
//   9: [ + + + + + + + + + + - - - - - - - - - - - - - - - - + + + + + + ]
//  10: [ + + + + + + + + + + + - - - - - - - - - - - - - - - - + + + + + ]
//  11: [ + + + + + + + + + + + + - - - - - - - - - - - - - - - - + + + + ]
//  12: [ + + + + + + + + + + + + + - - - - - - - - - - - - - - - - + + + ]
//  13: [ + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - + + ]
//  14: [ + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - + ]
//  15: [ + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - ]
//  16: [ + - - - - - - - - + + + + + + + + - - - - - - - - + + + + + + + ]
//  17: [ + + - - - - - - - - + + + + + + + + - - - - - - - - + + + + + + ]
//  18: [ + + + - - - - - - - - + + + + + + + + - - - - - - - - + + + + + ]
//  19: [ + + + + - - - - - - - - + + + + + + + + - - - - - - - - + + + + ]
//  20: [ + + + + + - - - - - - - - + + + + + + + + - - - - - - - - + + + ]
//  21: [ + + + + + + - - - - - - - - + + + + + + + + - - - - - - - - + + ]
//  22: [ + + + + + + + - - - - - - - - + + + + + + + + - - - - - - - - + ]
//  23: [ + + + + + + + + - - - - - - - - + + + + + + + + - - - - - - - - ]
//  24: [ + - - - - + + + + - - - - + + + + - - - - + + + + - - - - + + + ]
//  25: [ + + - - - - + + + + - - - - + + + + - - - - + + + + - - - - + + ]
//  26: [ + + + - - - - + + + + - - - - + + + + - - - - + + + + - - - - + ]
//  27: [ + + + + - - - - + + + + - - - - + + + + - - - - + + + + - - - - ]
//  28: [ + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + ]
//  29: [ + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - ]
//  30: [ + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - ]
//  31: [ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ]
// If nq==false then the norm of all row vectors is sqrt(n).
// If nq==true then the norm of the row vectors is
//  sqrt(1) for the last 4,  sqrt(2) for the prior 4 (again 4),
//  sqrt(4) for the prior 8,  sqrt(8) for the prior 16,
//  sqrt(16) for the prior 32,  etc.
//
{
    const ulong n = 1UL << ldn;
    if ( n<=2 )
    {
        // handled separately: the general code below accesses f[n-4]
        if ( n==2 )
        {
            if ( nq )  swt_normalize(f, n);
            sumdiff_r( f[n-2], f[n-1] );
        }
        return;
    }

    // normalization at start:
    if ( nq )  swt_normalize(f, n);

    // rightmost 4 elements:
    sumdiff_r( f[n-4], f[n-3] );
    sumdiff_r( f[n-2], f[n-1] );

    // first passes: inverse cyclic differences left, inverse diffsum right
    // (blocks are processed right to left, undoing the forward passes)
    for (ulong nh=4; nh < n; nh*=2 )
    {
        const ulong off = n - 2 * nh;

        // inverse cyclic differences left:
        double s = 0.0;
        for (ulong j=off; j < off+nh; ++j)  { s += f[j]; }
        for (ulong j=off; j < off+nh; ++j)
        {
            const double d = f[j];
            f[j] = s;
            s -= (d + d);
        }

        // inverse diffsum right:
        const ulong n4 = nh / 2;
        for (ulong j=off+nh; j < off+nh+n4; ++j)  { sumdiff_r( f[j], f[n4+j] ); }
    }

    // last pass: inverse diffsum of halves
    const ulong nh = n >> 1;
    for (ulong j=0; j < nh; ++j)  { sumdiff_r( f[j], f[nh+j] ); }
}
// -------------------------
