/*
 * Copyright (c) 2013, Ambroz Bizjak <ambrop7@gmail.com>
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met: 
 * 
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer. 
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution. 
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * Efficient computation of a weighted sum of last N received numbers.
 * See: http://cstheory.stackexchange.com/questions/10612/weighted-sum-of-last-n-numbers/10670#10670
 * 
 * Initialization:
 *   WSum state;
 *   WSum_Init(&state, number_of_weights, pointer_to_weights);
 * The weights need to be specified in "reverse", meaning that the first weight
 * applies to the most recent number.
 * 
 * Operation:
 *   double current_wsum = WSum_AddNumber(&state, your_number);
 */

#include <stdio.h>
#include <complex.h>
#include <string.h>
#include <math.h>
#include <assert.h>
#include <stdlib.h>

/*
 * Recursive radix-2 FFT step over 2^exp points (probably not the fastest FFT).
 *
 * exp  - log2 of the number of points to transform.
 * data - the points; overwritten with the transform.
 * temp - scratch space with room for the same number of points as data.
 * rou  - the root of unity used as the twiddle base for this size; it is
 *        squared for each half-size sub-transform.
 */
static void fft_rec (int exp, double complex *data, double complex *temp, double complex rou)
{
    // A single point is its own transform.
    if (exp == 0) {
        return;
    }

    const size_t half = (size_t)1 << (exp - 1);
    double complex *even = temp;
    double complex *odd = temp + half;

    // Deinterleave: even-indexed points to the lower half of temp,
    // odd-indexed points to the upper half.
    for (size_t k = 0; k < half; k++) {
        even[k] = data[2 * k];
        odd[k] = data[2 * k + 1];
    }

    // Transform each half in place; data is free to serve as their scratch
    // because its contents were just copied out and are rebuilt below.
    const double complex rou_squared = rou * rou;
    fft_rec(exp - 1, even, data, rou_squared);
    fft_rec(exp - 1, odd, data, rou_squared);

    // Butterfly: combine the two half-size transforms back into data.
    double complex twiddle = 1;
    for (size_t k = 0; k < half; k++) {
        const double complex lower = even[k];
        const double complex upper = twiddle * odd[k];
        data[k] = lower + upper;
        data[half + k] = lower - upper;
        twiddle *= rou;
    }
}

/*
 * Transforms 2^exp complex points in place using fft_rec.
 *
 * exp  - log2 of the number of points.
 * data - the points; overwritten with the transform.
 * temp - scratch space with room for the same number of points as data.
 */
static void fft (int exp, double complex *data, double complex *temp)
{
    const size_t points = (size_t)1 << exp;

    // Twiddle base e^(-2*pi*i / points).
    const double angle = -2 * 3.14159265358979323846 / points;

    fft_rec(exp, data, temp, cexp(I * angle));
}

// Input: m weights, 2m-1 numbers (m must be at least 1).
// Output: m sums, sums[i] = sum (k=0 to m-1) weights[k]*numbers[i+m-1-k]
// This function implements everything FFT-related that the WSum code needs.
// The three work buffers are allocated on every call and released before
// returning; if any allocation fails the process is aborted, since there is
// no return value through which to report the error.
static void compute_sums (size_t m, const double *weights, const double *numbers, double *sums)
{
    // With m == 0 the expression 2 * m - 1 below would wrap around.
    assert(m > 0);
    
    // Transform length: smallest power of two that is >= 2m. The full linear
    // convolution has 3m-2 terms, so the circular one wraps around, but only
    // into indices <= m-3; the indices m-1 .. 2m-2 read below stay clean.
    int dft_exp = 0;
    size_t dft_n = 1;
    while (dft_n < 2 * m) {
        dft_exp++;
        dft_n *= 2;
    }
    
    double complex *a = malloc(dft_n * sizeof(*a));
    double complex *b = malloc(dft_n * sizeof(*b));
    double complex *c = malloc(dft_n * sizeof(*c));
    if (a == NULL || b == NULL || c == NULL) {
        fprintf(stderr, "compute_sums: out of memory\n");
        abort();
    }
    
    // a = weights, zero-padded to the transform length
    for (size_t i = 0; i < m; i++) {
        a[i] = weights[i];
    }
    for (size_t i = m; i < dft_n; i++) {
        a[i] = 0;
    }
    
    // b = numbers, zero-padded to the transform length
    for (size_t i = 0; i < 2 * m - 1; i++) {
        b[i] = numbers[i];
    }
    for (size_t i = 2 * m - 1; i < dft_n; i++) {
        b[i] = 0;
    }
    
    // c is only scratch space for these two transforms
    fft(dft_exp, a, c);
    fft(dft_exp, b, c);
    
    // Pointwise product of the spectra. Conjugating it and running the
    // forward transform again yields the conjugate of the unnormalized
    // inverse transform; only the real part is used, so the final
    // conjugation can be skipped.
    for (size_t i = 0; i < dft_n; i++) {
        c[i] = conj(a[i] * b[i]);
    }
    
    // a is no longer needed and serves as scratch space here
    fft(dft_exp, c, a);
    
    // Pick the m requested convolution terms and apply the 1/dft_n scaling.
    for (size_t i = 0; i < m; i++) {
        sums[i] = creal(c[m - 1 + i]) / dft_n;
    }
    
    free(c);
    free(b);
    free(a);
}

/*
 * State of the weighted-sum computation. Set up by WSum_Init, advanced by
 * WSum_AddNumber and released by WSum_Free. Inputs are processed in blocks
 * of m numbers.
 */
typedef struct {
    size_t m;          // block size chosen by WSum_Init
    size_t n;          // number of weights, rounded up to a multiple of m
    size_t num_m;      // number of blocks, n / m
    double *weights;   // n + 1 entries: the caller's weights followed by zeros
    double *inputs;    // n + m entries: past input blocks, then the m pending inputs at index n
    double *sums;      // n entries: per-block partial sums read while the current block fills
    double *new_sums;  // n entries: partial sums being computed for the next block; swapped with sums
    size_t m_pos;      // number of inputs received in the current block
    size_t num_m_pos;  // number of new_sums blocks already computed for the current block
} WSum;

/*
 * Initializes a weighted-sum state.
 *
 * o            - state to initialize; release it with WSum_Free.
 * orig_n       - number of weights, must be greater than zero.
 * orig_weights - orig_n weights, the first one applying to the most recent
 *                number. They are copied, so the caller keeps ownership.
 *
 * All buffers are heap-allocated. If an allocation fails the process is
 * aborted, since there is no return value through which to report the error.
 */
void WSum_Init (WSum *o, size_t orig_n, const double *orig_weights)
{
    assert(orig_n > 0);
    
    // choose m as sqrt(orig_n * ln(orig_n)) truncated, clamped to [1, orig_n]
    o->m = (size_t)sqrt(orig_n * log(orig_n));
    if (o->m == 0) {
        o->m = 1;
    } else if (o->m > orig_n) {
        o->m = orig_n;
    }
    
    // choose n, taking care it's a multiple of m
    o->n = orig_n;
    if (o->n % o->m) {
        o->n += o->m - (o->n % o->m);
    }
    o->num_m = o->n / o->m;
    
    // allocate everything up front so a failure is detected before any
    // buffer is written to
    o->weights = malloc((o->n + 1) * sizeof(o->weights[0]));
    o->inputs = malloc((o->n + o->m) * sizeof(o->inputs[0]));
    o->sums = malloc(o->n * sizeof(o->sums[0]));
    o->new_sums = malloc(o->n * sizeof(o->new_sums[0]));
    if (o->weights == NULL || o->inputs == NULL || o->sums == NULL || o->new_sums == NULL) {
        fprintf(stderr, "WSum_Init: out of memory\n");
        abort();
    }
    
    // initialize weights; the padding up to n and the extra entry at index n
    // are zero (WSum_AddNumber passes weights at offsets up to
    // (num_m - 1) * m + 1, which reach index n)
    for (size_t i = 0; i < orig_n; i++) {
        o->weights[i] = orig_weights[i];
    }
    for (size_t i = orig_n; i < o->n + 1; i++) {
        o->weights[i] = 0;
    }
    
    // initialize inputs, including the m pending slots at the end, so that
    // no part of the buffer is ever left indeterminate
    for (size_t i = 0; i < o->n + o->m; i++) {
        o->inputs[i] = 0;
    }
    
    // initialize sums
    for (size_t i = 0; i < o->n; i++) {
        o->sums[i] = 0;
        o->new_sums[i] = 0;
    }
    
    o->m_pos = 0;
    o->num_m_pos = 0;
}

/*
 * Releases the buffers owned by the state. The state object itself belongs
 * to the caller and is not freed.
 *
 * The pointers are reset to NULL afterwards, so calling this function again
 * on the same state is harmless instead of being a double free.
 */
void WSum_Free (WSum *o)
{
    free(o->new_sums);
    o->new_sums = NULL;
    
    free(o->sums);
    o->sums = NULL;
    
    free(o->inputs);
    o->inputs = NULL;
    
    free(o->weights);
    o->weights = NULL;
}

/*
 * Feeds one number into the state and returns the current weighted sum,
 * which includes x itself.
 *
 * The result is the sum of two parts: partial sums that were precomputed for
 * the older input blocks (o->sums), and a direct multiply-accumulate over the
 * inputs received so far in the current block. While the current block fills,
 * the partial sums for the next block are prepared in o->new_sums; the two
 * buffers are exchanged once the block is complete.
 */
double WSum_AddNumber (WSum *o, double x)
{
    assert(o->m_pos < o->m);
    
    const size_t block_len = o->m;
    const size_t block_count = o->num_m;
    double *pending = o->inputs + o->n;
    
    // contribution of the older blocks, one precomputed term per block
    double total = 0;
    for (size_t b = 0; b < block_count; b++) {
        total += o->sums[b * block_len + o->m_pos];
    }
    
    // store the new number in the pending block
    pending[o->m_pos] = x;
    o->m_pos++;
    const size_t filled = o->m_pos;
    
    // contribution of the pending block; the newest input meets weights[0]
    for (size_t i = 0; i < filled; i++) {
        total += pending[i] * o->weights[filled - 1 - i];
    }
    
    // The last two new_sums blocks depend on the pending inputs, so they can
    // only be computed once the pending block is full. Until then the update
    // stops two blocks short (or does nothing when there are fewer than two).
    size_t update_limit = block_count;
    if (filled < block_len) {
        update_limit = (block_count >= 2) ? block_count - 2 : 0;
    }
    
    // compute the new_sums blocks that are due
    for (; o->num_m_pos < update_limit; o->num_m_pos++) {
        const size_t pos = o->num_m_pos;
        double *window = o->inputs + pos * block_len;
        
        if (pos == block_count - 1) {
            // last block: the slots following it are the pending ones, which
            // must read as zero for this computation
            for (size_t j = 0; j < block_len; j++) {
                pending[j] = 0;
            }
        } else {
            // shift the block after next into the next block's place
            memcpy(window + block_len, window + 2 * block_len, block_len * sizeof(o->inputs[0]));
        }
        
        compute_sums(block_len,
                     o->weights + (block_count - 1 - pos) * block_len + 1,
                     window,
                     o->new_sums + pos * block_len);
    }
    
    // pending block complete: finish the shift and start using the new sums
    if (filled == block_len) {
        assert(o->num_m_pos == block_count);
        memcpy(o->inputs, o->inputs + block_len, block_len * sizeof(o->inputs[0]));
        
        double *retired = o->sums;
        o->sums = o->new_sums;
        o->new_sums = retired;
        
        o->m_pos = 0;
        o->num_m_pos = 0;
    }
    
    return total;
}

// Set to 1 to verify every result against a direct computation.
#define CHECK_SUMS 1
// Number of random inputs to feed through the state.
#define NUM_ITER 1000

/*
 * Test driver: feeds NUM_ITER values from rand() through a WSum state built
 * from a fixed weight table and prints each result. With CHECK_SUMS enabled
 * every result is compared against a straightforward weighted sum; the first
 * mismatch beyond 0.01 stops the run and makes the program exit with 1.
 */
int main (void)
{
    double w[] = {4, 6, 1, -2, 61, 1, 623, 135, 15, 224, 5, 3, 146, 52, 83, 23};
    size_t n = sizeof(w) / sizeof(w[0]);
    
#if CHECK_SUMS
    // the last n inputs, oldest first; sized from w at compile time so that
    // no variable-length array is needed
    double prev[sizeof(w) / sizeof(w[0])] = {0};
#endif
    
    WSum state;
    WSum_Init(&state, n, w);
    
    printf("m=%zu n=%zu num_m=%zu\n", state.m, state.n, state.num_m);
    
    int status = 0;
    
    for (int k = 0; k < NUM_ITER; k++) {
        double x = rand();
        double sum = WSum_AddNumber(&state, x);
        
#if CHECK_SUMS
        // slide the reference window by one and append x
        memmove(prev, prev + 1, (n - 1) * sizeof(prev[0]));
        prev[n - 1] = x;
        
        double check_sum = 0;
        for (size_t i = 0; i < n; i++) {
            check_sum += w[i] * prev[n - 1 - i];
        }
        
        int failed = (fabs(sum - check_sum) > 0.01);
        printf("%f %f %f %s\n", x, sum, check_sum, (failed ? "FAILED" : "OK"));
        
        if (failed) {
            // leave the loop instead of returning so the state is still freed
            status = 1;
            break;
        }
#else
        printf("%f %f\n", x, sum);
#endif
    }
    
    WSum_Free(&state);
    return status;
}
