/* Playfair cipher: encryption and decryption with a 5x5 key square. */

#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define MAXSTR 1000
#define WIDTH 5
#define TABLESIZE 25
#define NLETTER 26
#define ENCRYPTION 1
#define DECRYPTION (WIDTH - ENCRYPTION)

char *extract(char *, const char *);
void delX(char *);
char *getpair(char *, const char *);
int lookup(const char *, int);
void makepair(char *, const char *);
void maketable(char *, const char *);
char *playfair(const char *, const char *, int, int);
char *replace(char *, const char *, int);
void replace_all(char *, const char *, int);
int rowcoltochar(const char *, int, int);
void showpair(const char *);
const char *topair(char *, const char *);
void showtable(char *);

/*
 * Demo: encrypt "hello, world" with the key "fairplay", then decrypt the
 * result again, printing the intermediate steps (verbose mode).
 * Returns EXIT_FAILURE if either step fails.
 */
int main(int argc, char *argv[])
{
    char *out = NULL, *out2 = NULL;

    (void) argc;
    (void) argv;

    printf("<< ENCRYPTION >>\n");
    out = playfair("hello, world", "fairplay", ENCRYPTION, 1);
    if (out == NULL) {
        /* Do not feed a NULL ciphertext into the decryption step. */
        fprintf(stderr, "playfair: encryption failed\n");
        return EXIT_FAILURE;
    }
    printf("%s\n", out);

    printf("\n<< DECRYPTION >>\n");
    out2 = playfair(out, "fairplay", DECRYPTION, 1);
    if (out2 == NULL) {
        fprintf(stderr, "playfair: decryption failed\n");
        free(out);
        return EXIT_FAILURE;
    }
    printf("%s\n", out2);

    free(out);
    free(out2);
    return 0;
}

/*
 * Encrypt or decrypt src with the Playfair cipher keyed by key.
 *
 * Only the letters of src and key are used (upper-cased by extract()).
 * direction must be ENCRYPTION or DECRYPTION. When encrypting, the letters
 * are first split into pairs by makepair(). When decrypting, the filler
 * 'X' characters are removed from the result by delX().
 * When verbose is non-zero, intermediate steps are printed to stdout.
 *
 * Returns a newly allocated string that the caller must free(). Returns NULL
 * if src or key is NULL, direction is invalid, or memory allocation fails.
 */
char *playfair(const char *src, const char *key, int direction, int verbose)
{
    char table[TABLESIZE];
    char *es = NULL;   /* letters of src */
    char *ek = NULL;   /* letters of key (only used for verbose output) */
    char *out = NULL;

    if (src == NULL || key == NULL)
        return NULL;
    if (direction != ENCRYPTION && direction != DECRYPTION)
        return NULL;

    es = malloc(strlen(src) + 1);
    /* Sized from the key: extract(ek, key) writes up to strlen(key) letters. */
    ek = malloc(strlen(key) + 1);
    if (es == NULL || ek == NULL)
        goto error;

    extract(ek, key);
    if (verbose)
        printf("key:\n%s\n%s\n", key, ek);
    maketable(table, key);
    if (verbose)
        showtable(table);
    extract(es, src);
    if (verbose)
        printf("src:\n%s\n%s\n", src, es);

    /*
     * Pairing can insert one filler per letter, so at most 2 * len chars
     * plus the NUL are needed. calloc guarantees the buffer is terminated
     * whatever the pairing step writes.
     */
    out = calloc(strlen(es) * 2 + 1, 1);
    if (out == NULL)
        goto error;
    if (direction == ENCRYPTION)
        makepair(out, es);
    else
        strcpy(out, es);
    if (verbose)
        showpair(out);
    replace_all(out, table, direction);
    if (verbose) {
        showpair(out);
        fflush(stdout);
    }
    if (direction == DECRYPTION)
        delX(out);

    free(ek);
    free(es);
    return out;

error:
    free(ek);
    free(es);
    free(out);
    return NULL;
}

/*
 * Build the TABLESIZE-cell Playfair key square in table (row-major, WIDTH
 * columns; the result is NOT NUL-terminated).
 *
 * The distinct letters of key come first, in order of first appearance and
 * upper-cased; J is folded into I. The remaining letters of A-Z (except J)
 * follow in alphabetical order. Non-letters in key are ignored. Keys of any
 * length are accepted: the key is scanned in place, with no intermediate
 * buffer.
 */
void maketable(char *table, const char *key)
{
    char used[NLETTER];
    const unsigned char *p;
    int j = 0;
    int letter;

    memset(table, '\0', TABLESIZE);
    memset(used, 0, sizeof(used));
    /* J never gets its own cell; it shares I's. */
    used['J' - 'A'] = 1;

    for (p = (const unsigned char *) key; *p != '\0' && j < TABLESIZE; p++) {
        int c;

        /* ctype functions need an unsigned char value, not a plain char. */
        if (!isalpha(*p))
            continue;
        c = toupper(*p);
        /* Locale-specific letters have no cell in the A-Z grid. */
        if (c < 'A' || c > 'Z')
            continue;
        if (c == 'J')
            c = 'I';
        if (used[c - 'A'])
            continue;
        used[c - 'A'] = 1;
        table[j++] = (char) c;
    }

    /* Fill the remaining cells with the unused letters in alphabetical order. */
    for (letter = 0; j < TABLESIZE; letter++) {
        if (!used[letter])
            table[j++] = (char) ('A' + letter);
    }
}

/*
 * Print the TABLESIZE-cell key square to stdout, WIDTH letters per row,
 * separated by single spaces. Empty ('\0') cells are shown as '.'.
 */
void showtable(char *table)
{
    const char *cell = table;
    int n;

    for (n = 0; n < TABLESIZE; n++, cell++) {
        int ch = (*cell != '\0') ? *cell : '.';
        int col = n % WIDTH;

        if (col == 0)
            printf("%c", ch);
        else
            printf(" %c", ch);
        if (col == WIDTH - 1)
            printf("\n");
    }
}

/*
 * Copy only the alphabetic characters of from into to, upper-cased, and
 * NUL-terminate the result. to must have room for strlen(from) + 1 bytes.
 * Returns to.
 */
char *extract(char *to, const char *from)
{
    char *dst = to;
    const unsigned char *p;

    /* Passing a negative plain char to isalpha/toupper is undefined behavior. */
    for (p = (const unsigned char *) from; *p != '\0'; p++) {
        if (isalpha(*p))
            *dst++ = (char) toupper(*p);
    }
    *dst = '\0';
    return to;
}

void makepair(char *out, const char *p)
{
    while (p = topair(out, p))
        out += 2;
    out = '\0';
}

/*
 * Form the next digraph from s into pair[0..1] (not NUL-terminated).
 *
 * The second letter is replaced by the filler 'X' when it would repeat the
 * first letter or when s has only one letter left. In that case only one
 * input letter is consumed.
 *
 * Returns a pointer just past the consumed input, or NULL if s is empty.
 */
const char *topair(char *pair, const char *s)
{
    char first, second;

    if (s[0] == '\0')
        return NULL;

    first = s[0];
    second = s[1];
    pair[0] = first;

    if (second != '\0' && second != first) {
        pair[1] = second;
        return s + 2;
    }

    pair[1] = 'X';
    return s + 1;
}

/*
 * Print s to stdout as dash-separated two-character groups
 * (e.g. "HELX" -> "HE-LX"), followed by a newline.
 */
void showpair(const char *s)
{
    const char *p = s;

    while (*p != '\0') {
        if (p != s)
            printf("-");
        if (p[1] != '\0') {
            printf("%c%c", p[0], p[1]);
            p += 2;
        } else {
            /* Odd trailing character: not expected, but print it on its own. */
            printf("%c", p[0]);
            p += 1;
        }
    }
    printf("\n");
}

/*
 * Apply replace() in place to each complete two-character group of s.
 * A trailing odd character, if any, is left untouched.
 */
void replace_all(char *s, const char *table, int dec)
{
    char *p;

    for (p = s; p[0] != '\0' && p[1] != '\0'; p += 2)
        replace(p, table, dec);
}

/*
 * Apply one Playfair substitution in place to the two letters pair[0] and
 * pair[1], using the key square table.
 *
 * dec is the shift added (mod WIDTH) to the column when both letters share a
 * row, or to the row when they share a column. ENCRYPTION (1) moves
 * right/down; DECRYPTION (WIDTH - 1) moves left/up. Otherwise the letters
 * swap columns (the rectangle rule), which is its own inverse.
 *
 * If either letter is not found in the table, the pair is left unchanged
 * instead of indexing outside the table.
 *
 * Returns pair.
 */
char *replace(char *pair, const char *table, int dec)
{
    int i, pos[2], row[2], col[2];

    for (i = 0; i < 2; i++) {
        pos[i] = lookup(table, pair[i]);
        /* The grid is square (rows also wrap modulo WIDTH below), so valid
           positions are 0 .. WIDTH * WIDTH - 1. */
        if (pos[i] < 0 || pos[i] >= WIDTH * WIDTH)
            return pair;
        row[i] = pos[i] / WIDTH;
        col[i] = pos[i] % WIDTH;
    }
    if (row[0] == row[1]) {
        for (i = 0; i < 2; i++)
            col[i] = (col[i] + dec) % WIDTH;
    } else if (col[0] == col[1]) {
        for (i = 0; i < 2; i++)
            row[i] = (row[i] + dec) % WIDTH;
    } else {
        int tmp = col[0];
        col[0] = col[1];
        col[1] = tmp;
    }
    for (i = 0; i < 2; i++)
        pair[i] = rowcoltochar(table, row[i], col[i]);
    return pair;
}

/*
 * Return the cell index (0-based, row-major) of letter c in the key square,
 * or -1 if c is not an upper-case letter A-Z. J is looked up as I, matching
 * maketable(), which gives J no cell of its own.
 *
 * Precondition: table holds every letter A-Z except J, as built by
 * maketable(). table does not need to be NUL-terminated.
 */
int lookup(const char *table, int c)
{
    const char *p;

    if (c == 'J')
        c = 'I';
    /*
     * Reject non-letters before searching: the table has no NUL terminator,
     * so strchr() for a character that is absent would read past its end.
     */
    if (c < 'A' || c > 'Z')
        return -1;
    p = strchr(table, c);
    return (p != NULL) ? (int) (p - table) : -1;
}

/* Return the character stored at (row, col) of the row-major, WIDTH-column table. */
int rowcoltochar(const char *table, int row, int col)
{
    int index = row * WIDTH + col;

    return table[index];
}

/*
 * Remove, in place, the filler 'X' characters that pairing inserts
 * (see topair()) from decrypted text.
 *
 * topair() inserts 'X' as the second letter of a pair only when the next
 * letter repeats the first one, or when the input has ended. An 'X' in the
 * second position is therefore dropped only when the following character
 * equals the pair's first letter, or when the pair is the last one. Any
 * other 'X' is genuine text and is kept (e.g. "EXIT" stays "EXIT").
 * A trailing odd character is preserved.
 */
void delX(char *s)
{
    char *t = s;   /* write position; never runs ahead of the read position s */

    while (s[0] != '\0') {
        char first = s[0];
        char second = s[1];

        *t++ = first;
        if (second == '\0')
            break;   /* odd length: keep the lone trailing character */
        /* s[2] is valid here because second != '\0'. */
        if (!(second == 'X' && (s[2] == '\0' || s[2] == first)))
            *t++ = second;
        s += 2;
    }
    *t = '\0';
}

/* END OF FILE */