import Control.Monad
import Control.Monad.ST
import Data.Array
import Data.Array.ST
import Data.List
import Data.STRef

-- | Read the value held in an 'STRef', run a monadic update on it, and
-- store the result back into the same reference.
withSTRef :: STRef s a -> (a -> ST s a) -> ST s ()
withSTRef ref k = do
  a <- readSTRef ref
  a' <- k a
  writeSTRef ref a'

-- | A square number spiral, indexed by @(row, column)@ starting at @(0, 0)@.
type Spiral = Array (Int, Int) Int

-- | Build a @size × size@ number spiral.
--
-- The value 1 is placed at row and column @(size-1) `div` 2@.  The numbers
-- 2, 3, … then wind outwards in legs of length 1 right, 1 down, 2 left,
-- 2 up, 3 right, 3 down, and so on.  After @size@ pairs of legs every cell
-- has been written exactly once, so the result holds @1 .. size*size@.
--
-- Calls 'error' when @size < 1@.
spiral :: Int -> Spiral
spiral size
  | size < 1 = error "spiral: size < 1"
  | otherwise = runSTArray $ do
      -- Every cell is overwritten below; 0 is only a placeholder.
      arr <- newArray ((0, 0), (size - 1, size - 1)) 0
      -- Cursor state: (column x, row y, next number to write).
      pos <- let start = (size - 1) `div` 2 in newSTRef (start, start, 1)
      let move (dx, dy) = withSTRef pos $ \(x, y, n) -> do
            -- Force the counter so it does not accumulate a thunk chain.
            n `seq` writeArray arr (y, x) n
            return (x + dx, y + dy, n + 1)
          right = move (1, 0)
          down = move (0, 1)
          left = move (-1, 0)
          up = move (0, -1)
          -- Leg pair for even k: k steps up, then k+1 steps right.
          over k = do
            replicateM_ k up
            replicateM_ (k + 1) right
          -- Leg pair for odd k: k steps down, then k+1 steps left.
          under k = do
            replicateM_ k down
            replicateM_ (k + 1) left
      -- Run over 0, under 1, over 2, ... : exactly `size` leg pairs,
      -- which write 1 + 3 + 5 + ... + (2*size - 1) = size*size cells.
      zipWithM_ ($) (cycle [over, under]) [0 .. size - 1]
      return arr

-- | Side length of a spiral.
--
-- Accepts any square array whose bounds start at @(0, 0)@, including the
-- 1×1 spiral with bounds @((0,0),(0,0))@; calls 'error' otherwise.
spiralSize :: Spiral -> Int
spiralSize = f . bounds
  where
    -- w == 0 is the valid 1x1 case, so the check is >= rather than >.
    f ((0, 0), (w, h)) | w == h && w >= 0 = w + 1
    f _ = error "Invalid spiral dimensions"

-- | Print a spiral row by row, separating columns with tab characters.
printSpiral :: Spiral -> IO ()
printSpiral grid = do
  let size = spiralSize grid
      rows = [[grid ! (i, j) | j <- [0 .. size - 1]] | i <- [0 .. size - 1]]
  mapM_ (putStrLn . intercalate "\t" . map show) rows

-- | Sum of all cells lying on either diagonal of the spiral.
--
-- Each cell is counted once.  The two diagonals share a cell only when the
-- side length is odd; that shared centre is subtracted once to undo the
-- double count.  For even sizes the diagonals are disjoint.
sumDiagonals :: Spiral -> Int
sumDiagonals grid = mainDiag + antiDiag - centre
  where
    size = spiralSize grid
    idx = [0 .. size - 1]
    mainDiag = sum [grid ! (i, i) | i <- idx]
    antiDiag = sum [grid ! (size - i - 1, i) | i <- idx]
    centre
      | odd size = grid ! (size `div` 2, size `div` 2)
      | otherwise = 0

-- | Print the sum of both diagonals of the 1001×1001 spiral.
main :: IO ()
main = print $ sumDiagonals $ spiral 1001