//////////////////////////////////////////////////////////////////////
// linked_list.h

#pragma once

//////////////////////////////////////////////////////////////////////

#include <cstddef>
#include <functional>

//////////////////////////////////////////////////////////////////////

struct list_node
{
	list_node *next;
	list_node *prev;
};

//////////////////////////////////////////////////////////////////////

template<typename T, size_t off> struct linked_list
{
	//////////////////////////////////////////////////////////////////////

	static list_node const *get_node(T const *o)
	{
		return reinterpret_cast<list_node const *>(reinterpret_cast<char const *>(o) + off);
	}

	//////////////////////////////////////////////////////////////////////

	static list_node *get_node(T *o)
	{
		return reinterpret_cast<list_node *>(reinterpret_cast<char *>(o) + off);
	}

	//////////////////////////////////////////////////////////////////////

	static T const *get_object(list_node const *node)
	{
		return reinterpret_cast<T const *>(reinterpret_cast<char const *>(node) - off);
	}

	//////////////////////////////////////////////////////////////////////

	static T *get_object(list_node *node)
	{
		return reinterpret_cast<T *>(reinterpret_cast<char *>(node) - off);
	}

	//////////////////////////////////////////////////////////////////////

	list_node root;

	//////////////////////////////////////////////////////////////////////

	linked_list()
	{
		root.next = &root;
		root.prev = &root;
	}

	//////////////////////////////////////////////////////////////////////

	void push_front(T *obj)
	{
		list_node *node = get_node(obj);
		root.next->prev = node;
		node->next = root.next;
		node->prev = &root;
		root.next = node;
	}

	//////////////////////////////////////////////////////////////////////

	void push_front(T &obj)
	{
		list_node *node = get_node(&obj);
		root.next->prev = node;
		node->next = root.next;
		node->prev = &root;
		root.next = node;
	}

	//////////////////////////////////////////////////////////////////////

	void push_back(T *obj)
	{
		list_node *node = get_node(obj);
		node->prev = root.prev;
		node->next = &root;
		root.prev->next = node;
		root.prev = node;
	}

	//////////////////////////////////////////////////////////////////////

	void push_back(T &obj)
	{
		list_node *node = get_node(&obj);
		node->prev = root.prev;
		node->next = &root;
		root.prev->next = node;
		root.prev = node;
	}

	//////////////////////////////////////////////////////////////////////

	void insert_before(T *pos, T *obj)
	{
		list_node *node = get_node(obj);
		list_node *n = get_node(pos);
		n->prev->next = node;
		node->prev = n->prev;
		n->prev = node;
		node->next = n;
	}

	//////////////////////////////////////////////////////////////////////

	void insert_before(T &pos, T &obj)
	{
		list_node *node = get_node(&obj);
		list_node *n = get_node(&pos);
		n->prev->next = node;
		node->prev = n->prev;
		n->prev = node;
		node->next = n;
	}

	//////////////////////////////////////////////////////////////////////

	void insert_after(T *pos, T *obj)
	{
		list_node *node = get_node(obj);
		list_node *n = get_node(pos);
		n->next->prev = node;
		node->next = n->next;
		n->next = node;
		node->prev = n;
	}

	//////////////////////////////////////////////////////////////////////

	void insert_after(T &pos, T &obj)
	{
		list_node *node = get_node(&obj);
		list_node *n = get_node(&pos);
		n->next->prev = node;
		node->next = n->next;
		n->next = node;
		node->prev = n;
	}

	//////////////////////////////////////////////////////////////////////

	void remove(T *obj)
	{
		list_node *node = get_node(obj);
		node->prev->next = node->next;
		node->next->prev = node->prev;
	}

	//////////////////////////////////////////////////////////////////////

	void remove(T &obj)
	{
		list_node *node = get_node(&obj);
		node->prev->next = node->next;
		node->next->prev = node->prev;
	}

	//////////////////////////////////////////////////////////////////////

	T *pop_back()
	{
		list_node *node = root.prev;
		node->prev->next = node->next;
		node->next->prev = node->prev;
		return get_object(node);
	}

	//////////////////////////////////////////////////////////////////////

	T *pop_front()
	{
		list_node *node = root.next;
		node->next->prev = node->prev;
		node->prev->next = node->next;
		return get_object(node);
	}

	//////////////////////////////////////////////////////////////////////

	bool empty() const
	{
		return root.next == &root;
	}

	//////////////////////////////////////////////////////////////////////

	void clear()
	{
		root.next = root.prev = &root;
	}

	//////////////////////////////////////////////////////////////////////

	T *head() const
	{
		return get_object(root.next);
	}

	//////////////////////////////////////////////////////////////////////

	T *tail() const
	{
		return get_object(root.prev);
	}

	//////////////////////////////////////////////////////////////////////

	T const *end()
	{
		return get_object(&root);
	}

	//////////////////////////////////////////////////////////////////////

	T *next(T *i) const
	{
		return get_object(get_node(i)->next);
	}

	//////////////////////////////////////////////////////////////////////

	T *prev(T *i) const
	{
		return get_object(get_node(i)->prev);
	}

	//////////////////////////////////////////////////////////////////////

	bool for_each(std::function<bool (T *)> func)
	{
		for(T *i = head(); i != end(); i = next(i))
		{
			if(!func(i))
			{
				return false;
			}
		}
		return true;
	}

	//////////////////////////////////////////////////////////////////////

	T *find(std::function<bool (T *)> func)
	{
		for(T *i = head(); i != end(); i = next(i))
		{
			if(func(i))
			{
				return i;
			}
		}
		return nullptr;
	}

	//////////////////////////////////////////////////////////////////////

};

//////////////////////////////////////////////////////////////////////

// Yuck:

#define declare_linked_list(type_name, node_name) \
	linked_list<type_name, offsetof(type_name, node_name)>

#define typedef_linked_list(type_name, node_name) \
	typedef declare_linked_list(type_name, node_name)

//////////////////////////////////////////////////////////////////////
// main.cpp

#include <stdio.h>
#include <random>
//#include "linked_list.h"

struct foo
{
	foo() : i(rand() % 10) { }

	int i;

	list_node node1;	// would like it if these could be made private
	list_node node2;	// but the nasty macros need to see inside...
	list_node node3;	// getting rid of the macros would be even better
};

// None of these 3 options are very nice:

// 1. declare a list with the macro
declare_linked_list(foo, node1) list1;

// 2. or via a typedef
typedef_linked_list(foo, node2) list2_t;
list2_t list2;

// 3. or very wordy non-macro declaration
linked_list<foo, offsetof(foo, node3)> list3;

int main(int, char **)
{
	printf("Begin\n");
	
	foo foos[10];

	for(int i=0; i<10; ++i)
	{
		list1.push_back(foos[i]);
		list2.push_back(foos[i]);
		list3.push_back(foos[i]);
	}

	int sum = 0;
	int n = 0;
	// but this for loop is clear and readable and has very low overhead
	for(foo *i = list1.head(); i != list1.end(); i = list1.next(i))
	{
		sum += i->i;
	}
	printf("Total: %d\n", sum);

	list2.remove(foos[2]);

	n = 0;
	sum = 0;
	for(foo *i = list2.head(); i != list2.end(); i = list2.next(i))
	{
		sum += i->i;
	}
	printf("Total2: %d\n", sum);

	getchar();
	return 0;
}
