; Hill cipher: encrypt and decrypt text by multiplying blocks of letters by a key matrix modulo some modulus

;; Build a ROWS x COLUMNS matrix represented as a vector of row vectors.
;; Every row is a freshly allocated vector.  If an optional fill value is
;; supplied (only the first extra argument is used) every cell holds it;
;; otherwise the cells hold whatever make-vector leaves unspecified.
(define (make-matrix rows columns . value)
  (let ((m (make-vector rows)))
    (let fill-row ((i 0))
      (unless (= i rows)
        (vector-set! m i
          (if (pair? value)
              (make-vector columns (car value))
              (make-vector columns)))
        (fill-row (+ i 1))))
    m))

;; Number of rows: the length of the outer vector.
(define (matrix-rows mat)
  (vector-length mat))

;; Number of columns, read from the length of row 0 (the matrix must have
;; at least one row; all rows are assumed to be the same length).
(define (matrix-cols mat)
  (let ((first-row (vector-ref mat 0)))
    (vector-length first-row)))

;; Entry of M at row I, column J.
(define (matrix-ref m i j)
  (let ((row (vector-ref m i)))
    (vector-ref row j)))

;; Store X at row I, column J of M, mutating that row vector in place.
(define (matrix-set! m i j x)
  (let ((row (vector-ref m i)))
    (vector-set! row j x)))

;; Counting loop.
;;   (for (var first past step) body ...)
;;     Run BODY with VAR = FIRST, FIRST+STEP, ... and stop before reaching
;;     PAST.  A non-negative STEP stops once VAR >= PAST; a negative STEP
;;     stops once VAR <= PAST.  A step pointing away from PAST therefore
;;     gives an empty loop.  FIRST, PAST and STEP are each evaluated
;;     exactly once.
;;   (for (var first past) body ...)
;;     As above, with STEP = 1 when FIRST < PAST and -1 otherwise.
;;   (for (var past) body ...)
;;     Same as (for (var 0 past) body ...).
(define-syntax for
  (syntax-rules ()
    ((for (var first past step) body ...)
      (let* ((f first) (p past) (s step)
             ;; The stop test follows the direction of the step, so the
             ;; loop cannot run away when STEP points away from PAST.
             (done? (if (negative? s) <= >=)))
        (do ((var f (+ var s)))
            ((done? var p))
          body ...)))
    ((for (var first past) body ...)
      (let* ((f first) (p past))
        (for (var f p (if (< f p) 1 -1)) body ...)))
    ((for (var past) body ...)
      (for (var 0 past) body ...))))

;; Element-wise sum of matrices A and B, returned as a new matrix.
;; Signals an error when the two matrices differ in either dimension.
(define (matrix-add a b)
  (let ((ar (matrix-rows a)) (ac (matrix-cols a))
        (br (matrix-rows b)) (bc (matrix-cols b)))
    (unless (and (= ar br) (= ac bc))
      (error 'matrix-add "incompatible matrices"))
    (let ((sum (make-matrix ar ac)))
      (for (row ar)
        (for (col ac)
          (matrix-set! sum row col
            (+ (matrix-ref a row col) (matrix-ref b row col)))))
      sum)))

;; New matrix whose entries are those of A each multiplied by the scalar N.
(define (matrix-scalar-multiply n a)
  (let ((rows (matrix-rows a))
        (cols (matrix-cols a)))
    (let ((scaled (make-matrix rows cols)))
      (for (r rows)
        (for (c cols)
          (matrix-set! scaled r c (* n (matrix-ref a r c)))))
      scaled)))

;; Matrix product A x B.  The column count of A must equal the row count
;; of B, otherwise an error is signalled.  Entry (row, col) of the result
;; is the dot product of row ROW of A with column COL of B.
(define (matrix-multiply a b)
  (let ((ar (matrix-rows a)) (ac (matrix-cols a))
        (br (matrix-rows b)) (bc (matrix-cols b)))
    (unless (= ac br)
      (error 'matrix-multiply "incompatible matrices"))
    (let ((product (make-matrix ar bc 0)))
      (for (row ar)
        (for (col bc)
          ;; Accumulate left to right, then store the finished sum.
          (let dot ((k 0) (sum 0))
            (if (= k ac)
                (matrix-set! product row col sum)
                (dot (+ k 1)
                     (+ sum (* (matrix-ref a row k)
                               (matrix-ref b k col))))))))
      product)))

;; Transpose of A: entry (i, j) of A becomes entry (j, i) of the result.
(define (matrix-transpose a)
  (let* ((ar (matrix-rows a))
         (ac (matrix-cols a))
         (result (make-matrix ac ar)))
    (for (j ac)
      (for (i ar)
        (matrix-set! result j i (matrix-ref a i j))))
    result))

;; Minor matrix of A: a copy of A with row I and column J removed.
;; For an r x c matrix A the result is (r-1) x (c-1).  Row indices run
;; below r and column indices below c, so non-square inputs work too.
(define (sub-matrix a i j)
  (let* ((r (matrix-rows a))
         (c (matrix-cols a))
         (m (make-matrix (- r 1) (- c 1)))
         (new-i -1))
    (for (old-i r)
      (when (not (= old-i i))
        (set! new-i (+ new-i 1))
        (let ((new-j -1))
          (for (old-j c)
            (when (not (= old-j j))
              (set! new-j (+ new-j 1))
              (matrix-set! m new-i new-j (matrix-ref a old-i old-j)))))))
    m))

;; Determinant of the square matrix A.
;; Base cases:
;;   - the empty 0x0 matrix has determinant 1 (empty product);
;;   - a 1x1 matrix has its single entry as determinant;
;;   - a 2x2 matrix uses the closed form ad - bc.
;; Larger matrices use cofactor (Laplace) expansion along row 0.
(define (matrix-determinant a) ; assume a is square
  (let ((n (matrix-rows a)))
    (cond
      ((= n 0) 1)
      ((= n 1) (matrix-ref a 0 0))
      ((= n 2)
       (- (* (matrix-ref a 0 0) (matrix-ref a 1 1))
          (* (matrix-ref a 1 0) (matrix-ref a 0 1))))
      (else
       ;; k alternates +1, -1, ... to supply the cofactor sign (-1)^j.
       (let loop ((j 0) (k 1) (d 0))
         (if (= j n)
             d
             (loop (+ j 1) (- k)
                   (+ d (* k (matrix-ref a 0 j)
                           (matrix-determinant (sub-matrix a 0 j)))))))))))

;; Cofactor matrix of the square matrix A: entry (i, j) is (-1)^(i+j)
;; times the determinant of A with row i and column j removed.
;; 1x1 and 2x2 matrices are handled directly; larger ones go through
;; matrix-determinant on each minor.
(define (matrix-cofactors a) ; assume a is square
  (let* ((n (matrix-rows a))
         (cof (make-matrix n n)))
    (define (sign i j) (if (even? (+ i j)) 1 -1))
    (cond
      ((= n 1)
       ;; The only minor of a 1x1 matrix is the empty matrix, whose
       ;; determinant is 1 by convention, so the single cofactor is 1.
       (matrix-set! cof 0 0 1))
      ((= n 2)
       ;; The minor at (i, j) of a 2x2 matrix is the lone entry at
       ;; (1-i, 1-j).
       (for (i n)
         (for (j n)
           (matrix-set! cof i j
             (* (sign i j) (matrix-ref a (- 1 i) (- 1 j)))))))
      (else
       (for (i n)
         (for (j n)
           (matrix-set! cof i j
             (* (sign i j) (matrix-determinant (sub-matrix a i j))))))))
    cof))

;; Adjugate (classical adjoint) of A: the transpose of its cofactor matrix.
(define (matrix-adjugate a)
  (let ((cofactors (matrix-cofactors a)))
    (matrix-transpose cofactors)))

;; Inverse of the square matrix A computed as adj(A) scaled by 1/det(A).
;; A zero determinant makes the reciprocal `/` signal an error.
(define (matrix-inverse a)
  (let ((scale (/ (matrix-determinant a)))
        (adj (matrix-adjugate a)))
    (matrix-scalar-multiply scale adj)))

;; Multiplicative inverse of X modulo M, via the extended Euclidean
;; algorithm.  REM/PREV run Euclid's remainder sequence starting from
;; (X, M), while S/S-PREV track the matching coefficient of X (mod M).
;; When the remainder hits 0, PREV is gcd(X, M).  The result is the
;; inverse if that gcd is 1, and 0 otherwise (no inverse exists).
;; Callers in this file pass X already reduced modulo M.
(define (inverse x m)
  (do ((rem x (modulo prev rem))
       (prev m rem)
       (s-prev 0 s)
       (s 1 (modulo (- s-prev (* (quotient prev rem) s)) m)))
      ((zero? rem)
       (if (= prev 1) (modulo s-prev m) 0))))

;; Apply F to every entry of A in row-major order and return a new matrix
;; holding the results.
(define (matrix-map f a)
  (let* ((rows (matrix-rows a))
         (cols (matrix-cols a))
         (result (make-matrix rows cols)))
    (for (r rows)
      (for (c cols)
        (matrix-set! result r c (f (matrix-ref a r c)))))
    result))

;; Matrix product A x B with every entry reduced modulo M.
(define (matrix-multiply-modulo a b m)
  (matrix-map (lambda (x) (modulo x m))
              (matrix-multiply a b)))

;; Inverse of the square matrix A modulo M, computed as
;; det(A)^-1 * transpose(cofactors(A)) with entries reduced modulo M.
;; Signals an error when det(A) shares a factor with M, because no
;; modular inverse exists in that case.
(define (matrix-inverse-modulo a m)
  (define (modm n) (modulo n m))
  (let ((det (modm (matrix-determinant a))))
    ;; `inverse` returns 0 for a non-invertible value, which would silently
    ;; produce an all-zero "inverse"; reject that case explicitly instead.
    (unless (= (gcd det m) 1)
      (error 'matrix-inverse-modulo "matrix is not invertible modulo m" det m))
    (matrix-map modm
      (matrix-scalar-multiply
        (inverse det m)
        (matrix-transpose (matrix-cofactors a))))))

;; Letter -> number, case-insensitively: #\A/#\a -> 0 ... #\Z/#\z -> 25.
(define (c->i c)
  (let ((upper (char-upcase c)))
    (- (char->integer upper) (char->integer #\A))))

;; Number -> upper-case letter: 0 -> #\A ... 25 -> #\Z.
(define (i->c i)
  (integer->char (+ (char->integer #\A) i)))

;; Lay the characters of STR out row by row in a matrix BLOCKSIZE columns
;; wide.  The row count is rounded up, and unused cells in the final row
;; keep the #\Z fill.
(define (string->matrix str blocksize)
  (let* ((len (string-length str))
         (rows (ceiling (/ len blocksize)))
         (m (make-matrix rows blocksize #\Z)))
    (let place ((k 0))
      (when (< k len)
        (matrix-set! m (quotient k blocksize) (remainder k blocksize)
                     (string-ref str k))
        (place (+ k 1))))
    m))

;; Read the character entries of M back out in row-major order as a string.
(define (matrix->string m)
  (let* ((rows (matrix-rows m))
         (cols (matrix-cols m))
         (out (make-string (* rows cols))))
    (for (r rows)
      (for (c cols)
        (string-set! out (+ (* r cols) c) (matrix-ref m r c))))
    out))

;; Hill-cipher encryption.  STR is split into rows of BLOCKSIZE letters
;; and each row is multiplied by the key matrix built from KEY.  Arithmetic
;; is done modulo MODULUS and the result is mapped back to letters.
(define (encrypt str key blocksize modulus)
  (let* ((plain (matrix-map c->i (string->matrix str blocksize)))
         (key-matrix (matrix-map c->i (string->matrix key blocksize)))
         (cipher (matrix-multiply-modulo plain key-matrix modulus)))
    (matrix->string (matrix-map i->c cipher))))

;; Hill-cipher decryption: like encrypt, but the rows of STR are multiplied
;; by the inverse of the key matrix modulo MODULUS.
(define (decrypt str key blocksize modulus)
  (let* ((cipher (matrix-map c->i (string->matrix str blocksize)))
         (key-inv (matrix-inverse-modulo
                    (matrix-map c->i (string->matrix key blocksize))
                    modulus))
         (plain (matrix-multiply-modulo cipher key-inv modulus)))
    (matrix->string (matrix-map i->c plain))))

;; Demo: print one encryption and one decryption, each on its own line.
(let ((show (lambda (text) (display text) (newline))))
  (show (encrypt "PROGRAMMINGPRAXIS" "GYBNQKURP" 3 26))
  (show (decrypt "TMFXAUYSSONMQTYCVR" "GYBNQKURP" 3 26)))