import java.io.*;
import java.util.*;

public class Main {

    /**
     * Entry point: wires stdin/stdout readers into the solver, runs it once and flushes output.
     */
    public static void main(String[] args) {
        InputReader reader = new InputReader(System.in);
        OutputWriter writer = new OutputWriter(System.out);
        new TaskD(reader, writer).solve();
        // OutputWriter buffers its output; closing it is what makes the answer appear.
        writer.close();
    }

    /**
     * Solver: reads an n x m grid of values 1..n*m and counts the value ranges [l, r] whose
     * cells, connected by side adjacency, induce a tree.
     *
     * <p>Pass 1 ({@link #computeLeftBounds()}) uses a two-pointer sweep with a link-cut tree to
     * find, for every r, the smallest l such that cells l..r induce a forest. Pass 2
     * ({@link #countTreeRanges()}) keeps vertices-minus-edges for every l in a segment tree and
     * counts the ranges with exactly one component.
     */
    static class TaskD {
        private static final int[] DX = {1, -1, 0, 0};
        private static final int[] DY = {0, 0, 1, -1};

        private final InputReader in;
        private final OutputWriter out;

        // Per-run state is held per instance rather than in static fields.
        private int n;
        private int m;
        private int total;

        /** f[i][j] is the zero-based value written in cell (i, j). */
        private int[][] f;
        /** posX[v], posY[v] are the coordinates of the cell holding value v. */
        private int[] posX;
        private int[] posY;
        /** leftBound[r] is the smallest l for which cells with values l..r induce a forest. */
        private int[] leftBound;

        /**
         * edges.get(c) lists every r > c whose grid edge (c, r) is currently linked in
         * {@link #lct}; an edge is recorded only at its smaller endpoint.
         */
        private List<List<Integer>> edges;
        private LinkCutTree lct;

        TaskD(InputReader in, OutputWriter out) {
            this.in = in;
            this.out = out;
        }

        /** Reads the grid, runs both passes and prints the number of tree-inducing ranges. */
        private void solve() {
            readGrid();
            computeLeftBounds();
            out.print(countTreeRanges());
        }

        /** Reads n, m and the grid values, and builds the value-to-position index. */
        private void readGrid() {
            n = in.readInt();
            m = in.readInt();
            total = n * m;
            f = new int[n][m];
            posX = new int[total];
            posY = new int[total];
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < m; ++j) {
                    // Input values are 1-based; this assumes they form a permutation of 1..n*m.
                    f[i][j] = in.readInt() - 1;
                    posX[f[i][j]] = i;
                    posY[f[i][j]] = j;
                }
            }
        }

        /**
         * Two-pointer pass: for each r, adds cell r to the window and advances l until cells
         * l..r induce a forest, storing the resulting l in leftBound[r].
         */
        private void computeLeftBounds() {
            leftBound = new int[total];
            lct = new LinkCutTree(total);
            edges = new ArrayList<>(total);
            for (int i = 0; i < total; ++i) {
                edges.add(new ArrayList<>());
            }
            for (int l = 0, r = 0; r < total; ++r) {
                l = addCell(l, r);
                leftBound[r] = l;
            }
        }

        /**
         * Sweeps r and keeps, for every l <= r, the value (#cells in l..r) - (#grid edges inside
         * l..r) in a segment tree. For l >= leftBound[r] the induced graph is a forest, so this
         * value is its number of components (at least 1). It equals 1 at l = r, so the minimum
         * over [leftBound[r], r] is 1, and the count of minima is the number of ranges ending
         * at r that induce a tree.
         *
         * @return the total number of tree-inducing ranges
         */
        private long countTreeRanges() {
            SegmentTree tree = new SegmentTree(total + 4);
            // Breakpoints: -1, the neighbours of r with value <= r, and r itself (at most 6).
            int[] cuts = new int[DX.length + 2];
            long ans = 0;
            for (int r = 0; r < total; ++r) {
                int k = 0;
                cuts[k++] = -1;
                for (int to : neighbours(r)) {
                    if (to <= r) {
                        cuts[k++] = to;
                    }
                }
                cuts[k++] = r;
                Arrays.sort(cuts, 0, k);
                // Adding cell r adds one vertex for every l <= r and one edge for every
                // neighbour `to` with l <= to. Walking the intervals (cuts[i-1], cuts[i]] from
                // right to left, the net change therefore starts at +1 and drops by one for
                // each breakpoint passed.
                for (int i = k - 1, delta = 1; i > 0; --i, --delta) {
                    if (cuts[i - 1] < cuts[i]) {
                        tree.inc(cuts[i - 1] + 1, cuts[i] + 1, delta);
                    }
                }
                ans += tree.query(leftBound[r], r + 1);
            }
            return ans;
        }

        /**
         * Adds cell r to the window whose left end is l by linking it to each neighbour inside
         * [l, r]. Whenever a link would close a cycle, the leftmost cell is dropped and the
         * link is retried.
         *
         * @return the new left end, chosen so that cells l..r induce a forest
         */
        private int addCell(int l, int r) {
            for (int to : neighbours(r)) {
                while (l <= to && to <= r) {
                    if (lct.link(to, r)) {
                        edges.get(to).add(r);
                        break;
                    }
                    removeCell(l);
                    ++l;
                }
            }
            return l;
        }

        /**
         * Cuts every link recorded at cell c and forgets them. Links to smaller values were
         * already cut when those cells left the window, since cells leave in increasing order.
         */
        private void removeCell(int c) {
            List<Integer> adjacent = edges.get(c);
            for (int to : adjacent) {
                lct.cut(c, to);
            }
            adjacent.clear();
        }

        /**
         * Returns the values held by the in-grid side neighbours of the cell holding
         * {@code value}, in DX/DY order.
         */
        private int[] neighbours(int value) {
            int x = posX[value];
            int y = posY[value];
            int[] buf = new int[DX.length];
            int count = 0;
            for (int d = 0; d < DX.length; ++d) {
                int px = x + DX[d];
                int py = y + DY[d];
                if (px >= 0 && px < n && py >= 0 && py < m) {
                    buf[count++] = f[px][py];
                }
            }
            return Arrays.copyOf(buf, count);
        }
    }

    /**
     * Bottom-up (iterative) segment tree over positions {@code 0 .. size-1} that supports
     * "add a value to a range" and "count the positions of a range attaining its minimum".
     *
     * <p>Layout:
     * <ul>
     *   <li>Leaves occupy indices {@code size .. 2*size-1}, and node {@code p} has children
     *       {@code 2p} and {@code 2p+1}.</li>
     *   <li>{@code t[p]} is the minimum of the subtree of {@code p}. It includes additions
     *       applied to {@code p} itself, but not additions still pending in its ancestors.</li>
     *   <li>{@code cnt[p]} counts the leaves of that subtree attaining {@code t[p]}.</li>
     *   <li>{@code push[p]} (internal nodes only) is the addition applied to {@code p} that has
     *       not yet been handed down to its children.</li>
     * </ul>
     * All positions start at value 0.
     */
    static class SegmentTree {
        int size;
        int height;
        int[] t;
        int[] cnt;
        int[] push;

        SegmentTree(int n) {
            this.size = n;
            // Bit length of n: how many ancestor levels push() has to walk through.
            this.height = Integer.SIZE - Integer.numberOfLeadingZeros(n);
            t = new int[n << 1];
            cnt = new int[n << 1];
            push = new int[n];
            // Every leaf trivially attains its own minimum.
            Arrays.fill(cnt, n, n << 1, 1);
            for (int node = n - 1; node >= 0; --node) {
                cnt[node] = cnt[2 * node] + cnt[2 * node + 1];
            }
        }

        /** Adds {@code delta} to the whole subtree of {@code node}, deferring it for internal nodes. */
        private void apply(int node, int delta) {
            t[node] += delta;
            if (node < size) {
                push[node] += delta;
            }
        }

        /** Recomputes t/cnt for every strict ancestor of {@code p}, bottom-up. */
        private void recalc(int p) {
            for (int node = p >> 1; node > 0; node >>= 1) {
                int left = node << 1;
                int right = left | 1;
                int childMin = Math.min(t[left], t[right]);
                t[node] = childMin + push[node];
                int ways = 0;
                if (t[left] == childMin) {
                    ways += cnt[left];
                }
                if (t[right] == childMin) {
                    ways += cnt[right];
                }
                cnt[node] = ways;
            }
        }

        /** Propagates pending additions from the top down along the ancestors of {@code p}. */
        private void push(int p) {
            for (int level = height; level > 0; level--) {
                int node = p >> level;
                if (push[node] == 0) {
                    continue;
                }
                apply(node << 1, push[node]);
                apply(node << 1 | 1, push[node]);
                push[node] = 0;
            }
        }

        /** Adds {@code value} to every position in {@code [l, r)}. */
        private void inc(int l, int r, int value) {
            int lo = l + size;
            int hi = r + size;
            int firstLeaf = lo;
            int lastLeaf = hi - 1;
            while (lo < hi) {
                if ((lo & 1) == 1) {
                    apply(lo++, value);
                }
                if ((hi & 1) == 1) {
                    apply(--hi, value);
                }
                lo >>= 1;
                hi >>= 1;
            }
            recalc(firstLeaf);
            recalc(lastLeaf);
        }

        /**
         * Returns how many positions in {@code [l, r)} attain the minimum of that range. The
         * minimum itself is not returned. An empty range yields 0.
         */
        private int query(int l, int r) {
            int lo = l + size;
            int hi = r + size;
            push(lo);
            push(hi - 1);
            int best = Integer.MAX_VALUE;
            int ways = 0;
            while (lo < hi) {
                if ((lo & 1) == 1) {
                    int node = lo++;
                    if (t[node] < best) {
                        best = t[node];
                        ways = cnt[node];
                    } else if (t[node] == best) {
                        ways += cnt[node];
                    }
                }
                if ((hi & 1) == 1) {
                    int node = --hi;
                    if (t[node] < best) {
                        best = t[node];
                        ways = cnt[node];
                    } else if (t[node] == best) {
                        ways += cnt[node];
                    }
                }
                lo >>= 1;
                hi >>= 1;
            }
            return ways;
        }
    }

    /**
     * Link-cut tree over nodes {@code 0 .. n-1} that maintains a forest. {@link #link(int, int)}
     * refuses to join nodes that are already connected, and {@link #cut(int, int)} removes an
     * existing edge.
     *
     * <p>Representation:
     * <ul>
     *   <li>Each preferred path is stored as a splay tree. A splay node's left subtree holds
     *       nodes deeper on its path, and its right subtree holds nodes closer to the
     *       represented root; {@code expose} attaches the previously exposed, deeper part as
     *       {@code y.left}.</li>
     *   <li>A splay-tree root's {@code parent} is its path-parent pointer.</li>
     *   <li>{@code revert} marks a subtree whose children still have to be swapped.</li>
     * </ul>
     * Splaying and exposing are iterative, so deep trees do not risk stack overflow.
     */
    static class LinkCutTree {

        private static class Node {
            Node left;
            Node right;
            Node parent;
            boolean revert;

            /** True when this node is the root of its splay tree; any parent is then a path-parent. */
            boolean isRoot() {
                return parent == null || (parent.left != this && parent.right != this);
            }

            /** Applies a pending reversal: swaps the children and hands the flag down to them. */
            void push() {
                if (revert) {
                    revert = false;
                    Node t = left;
                    left = right;
                    right = t;
                    if (left != null)   left.revert = !left.revert;
                    if (right != null)  right.revert = !right.revert;
                }
            }
        }

        /** Number of nodes. Per-instance, so separate trees no longer overwrite each other's size. */
        final int size;
        final Node[] nodes;

        LinkCutTree(int n) {
            this.size = n;
            nodes = new Node[n];
            for(int i = 0; i < n; ++i) {
                nodes[i] = new Node();
            }
        }

        /** Adds edge (x, y) unless x and y are already connected; returns whether it was added. */
        private boolean link(int x, int y) {
            return link(nodes[x], nodes[y]);
        }

        /**
         * Removes edge (x, y).
         *
         * @throws IllegalArgumentException if (x, y) is not an edge of the forest
         */
        private void cut(int x, int y) {
            cut(nodes[x], nodes[y]);
        }

        private static boolean link(Node x, Node y) {
            if (connected(x, y))
                return false;
            // Make x the root of its tree, then hang that tree below y via a path-parent pointer.
            makeRoot(x);
            x.parent = y;
            return true;
        }

        private static void cut(Node x, Node y) {
            makeRoot(x);
            expose(y);
            // After exposing y on the path from root x, the edge exists exactly when that path
            // is {y, x}: x is y's only (shallower-side) splay child and has no children itself.
            if (y.right != x || x.left != null || x.right != null) {
                throw new IllegalArgumentException("error: no edge (x,y)");
            }
            y.right.parent = null;
            y.right = null;
        }

        /**
         * Sets ch.parent = p. When isLeftChild is non-null, also installs ch as p's left (true)
         * or right (false) child. Null means p is only a path-parent.
         */
        static void connect(Node ch, Node p, Boolean isLeftChild) {
            if (ch != null)
                ch.parent = p;
            if (isLeftChild != null) {
                if (isLeftChild)
                    p.left = ch;
                else
                    p.right = ch;
            }
        }

        /** Rotates x above its parent, preserving the grandparent link or path-parent pointer. */
        static void rotate(Node x) {
            Node p = x.parent;
            Node g = p.parent;
            boolean isRootP = p.isRoot();
            boolean leftChildX = (x == p.left);
            connect(leftChildX ? x.right : x.left, p, leftChildX);
            connect(p, x, !leftChildX);
            connect(x, g, !isRootP ? p == g.left : null);
        }

        /** Splays x to the root of its splay tree, pushing pending reversals on the way. */
        static void splay(Node x) {
            while (!x.isRoot()) {
                Node p = x.parent;
                Node g = p.parent;
                if (!p.isRoot())
                    g.push();
                p.push();
                x.push();
                if (!p.isRoot())
                    rotate((x == p.left) == (p == g.left) ? p/*zig-zig*/ : x/*zig-zag*/);
                rotate(x);
            }
            x.push();
        }

        /**
         * Makes the path from the represented root to x preferred, with x at its deep end. On
         * return, x is the root of that path's splay tree and its leftmost node.
         *
         * @return the last splay root visited while climbing path-parent pointers
         */
        static Node expose(Node x) {
            Node last = null;
            for (Node y = x; y != null; y = y.parent) {
                splay(y);
                y.left = last;
                last = y;
            }
            splay(x);
            return last;
        }

        /** Re-roots x's tree at x by exposing x and reversing the exposed path. */
        private static void makeRoot(Node x) {
            expose(x);
            x.revert = !x.revert;
        }

        /**
         * After expose(x), x has no parent. Exposing y gives x a parent again exactly when x
         * lies in y's tree.
         */
        private static boolean connected(Node x, Node y) {
            if (x == y)
                return true;
            expose(x);
            expose(y);
            return x.parent != null;
        }
    }

    /** Thin wrapper over a {@link PrintWriter} that prints space-separated values, one call per line. */
    static class OutputWriter {
        private final PrintWriter writer;

        /** Writes to {@code outputStream} through a buffer, using the platform default charset. */
        public OutputWriter(OutputStream outputStream) {
            Writer buffered = new BufferedWriter(new OutputStreamWriter(outputStream));
            writer = new PrintWriter(buffered);
        }

        /** Writes to {@code writer} without adding a buffer of its own. */
        public OutputWriter(Writer writer) {
            this.writer = new PrintWriter(writer);
        }

        /**
         * Prints the arguments separated by single spaces, then a line separator. With no
         * arguments only the line separator is printed; null elements print as "null".
         */
        public void print(Object... objects) {
            boolean first = true;
            for (Object item : objects) {
                if (!first) {
                    writer.print(' ');
                }
                writer.print(item);
                first = false;
            }
            writer.println();
        }

        /**
         * Closes the wrapped writer. When built from an OutputStream, this is what flushes the
         * buffered output.
         */
        public void close() {
            writer.close();
        }

    }

    /** Minimal buffered reader of whitespace-separated numeric tokens from an {@link InputStream}. */
    static class InputReader {
        private InputStream stream;
        private byte[] buf = new byte[1024];
        private int curChar;
        private int numChars;
        // Never assigned in this class, so isSpaceChar() falls back to isWhitespace().
        private InputReader.SpaceCharFilter filter;

        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        /**
         * Returns the next input byte (as a signed byte value), or -1 when the stream is
         * exhausted.
         *
         * @throws InputMismatchException if called again after end of stream was reported, or
         *     if the underlying stream fails (the IOException is attached as the cause)
         */
        public int read() {
            if (numChars == -1) {
                throw new InputMismatchException();
            }
            if (curChar >= numChars) {
                curChar = 0;
                try {
                    numChars = stream.read(buf);
                } catch (IOException e) {
                    InputMismatchException wrapped = new InputMismatchException(e.toString());
                    wrapped.initCause(e);
                    throw wrapped;
                }
                if (numChars <= 0) {
                    return -1;
                }
            }
            return buf[curChar++];
        }

        /**
         * Skips whitespace and reads an optionally '-'-prefixed decimal integer. The character
         * that terminates the token is consumed. Overflow is not detected.
         *
         * @throws InputMismatchException on a non-digit character inside the token
         */
        public int readInt() {
            int c = read();
            while (isSpaceChar(c)) {
                c = read();
            }
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = read();
            }
            int res = 0;
            do {
                if (c < '0' || c > '9') {
                    throw new InputMismatchException();
                }
                res *= 10;
                res += c - '0';
                c = read();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        /** Whitespace test used by the token readers: the custom filter if set, else {@link #isWhitespace}. */
        public boolean isSpaceChar(int c) {
            if (filter != null) {
                return filter.isSpaceChar(c);
            }
            return isWhitespace(c);
        }

        /** True for space, '\n', '\r', '\t' and the end-of-stream marker -1. */
        public static boolean isWhitespace(int c) {
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }

        /**
         * Skips whitespace and reads a decimal number with an optional leading '-', an optional
         * fractional part and an optional 'e'/'E' exponent. The exponent is parsed with
         * {@link #readInt()}, so it may be negative. Digits are accumulated in double
         * arithmetic, so the result is not guaranteed to be correctly rounded.
         *
         * @throws InputMismatchException on an unexpected character inside the token
         */
        public double readDouble() {
            int c = read();
            while (isSpaceChar(c)) {
                c = read();
            }
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = read();
            }
            double res = 0;
            while (!isSpaceChar(c) && c != '.') {
                if (c == 'e' || c == 'E') {
                    // The sign must be applied on this early return as well ("-2e3" is -2000).
                    return sgn * res * Math.pow(10, readInt());
                }
                if (c < '0' || c > '9') {
                    throw new InputMismatchException();
                }
                res *= 10;
                res += c - '0';
                c = read();
            }
            if (c == '.') {
                c = read();
                double m = 1;
                while (!isSpaceChar(c)) {
                    if (c == 'e' || c == 'E') {
                        return sgn * res * Math.pow(10, readInt());
                    }
                    if (c < '0' || c > '9') {
                        throw new InputMismatchException();
                    }
                    m /= 10;
                    res += (c - '0') * m;
                    c = read();
                }
            }
            return res * sgn;
        }

        /** Pluggable definition of which characters separate tokens. */
        public interface SpaceCharFilter {
            public boolean isSpaceChar(int ch);
        }

    }
}