fork download
  1. import Control.Monad
  2. import Control.Monad.ST
  3. import Data.Array
  4. import Data.Array.ST
  5. import Data.List
  6. import Data.STRef
  7.  
  8. withSTRef :: STRef s a -> (a -> ST s a) -> ST s ()
  9. withSTRef ref k = do
  10. a <- readSTRef ref
  11. a' <- k a
  12. writeSTRef ref a'
  13.  
  14. type Spiral = Array (Int,Int) Int
  15.  
  16. spiral :: Int -> Spiral
  17. spiral size
  18. | size < 1 = error "spiral: size < 1"
  19. | otherwise = runSTArray $ do
  20. array <- newArray ((0, 0), (size-1, size-1)) undefined
  21. pos <- let start = (size-1) `div` 2
  22. in newSTRef (start, start, 1)
  23.  
  24. let move (dx,dy) = withSTRef pos $ \(x,y,n) -> do
  25. n `seq` writeArray array (y,x) n
  26. return (x+dx, y+dy, n+1)
  27.  
  28. let right = move (1,0)
  29. down = move (0,1)
  30. left = move (-1,0)
  31. up = move (0,-1)
  32.  
  33. let over n = do
  34. replicateM_ n up
  35. replicateM_ (n+1) right
  36. let under n = do
  37. replicateM_ n down
  38. replicateM_ (n+1) left
  39.  
  40. sequence_ $ take size $ zipWith ($) (cycle [over, under]) [0..]
  41.  
  42. return array
  43.  
  44. spiralSize :: Spiral -> Int
  45. spiralSize = f . bounds where
  46. f ((0,0), (w,h)) | w == h && w > 0 = w + 1
  47. f _ = error "Invalid spiral dimensions"
  48.  
  49. printSpiral :: Spiral -> IO ()
  50. printSpiral spiral = do
  51. let size = spiralSize spiral
  52. let items = [[spiral ! (i,j) | j <- [0..size-1]] | i <- [0..size-1]]
  53. mapM_ (putStrLn . intercalate "\t" . map show) items
  54.  
  55. sumDiagonals :: Spiral -> Int
  56. sumDiagonals spiral =
  57. let size = spiralSize spiral
  58. s = sum [spiral ! (i,i) + spiral ! (size-i-1, i) | i <- [0..size-1]]
  59. in s-1 -- subtract 1 to undo counting the middle twice
  60.  
  61. main = print $ sumDiagonals $ spiral 1001
stdin
Standard input is empty
compilation info
[1 of 1] Compiling Main             ( prog.hs, prog.o )
Linking prog ...
stdout
669171001