#include <assert.h>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>
using namespace std;
/// Non-owning pointer wrapper that routes every access to the pointee through
/// an external lock object.
///
/// @tparam PtrType  The pointee type (the wrapper stores a `PtrType*`).
/// @tparam LockType Any type exposing `Lock()` and `Unlock()` member functions.
///
/// `operator->` returns a temporary Proxy that calls `Lock()` when created and
/// `Unlock()` when destroyed. A temporary lives until the end of the full
/// expression, so the lock covers e.g. an entire `w->Method(...)` call.
/// Copies of a wrapper share the same pointee and lock object.
///
/// Caveat: two `->` uses on wrappers sharing one lock inside a single full
/// expression (e.g. `a->x != b->x`) acquire that lock twice on the same
/// thread; with a non-recursive lock such as std::mutex that is undefined
/// behaviour (typically a self-deadlock). Read into locals instead.
///
/// The wrapper never frees the pointee on its own; only Delete() does.
template <class PtrType, class LockType>
class ELockWrapper
{
public:
    /// Scoped lock holder returned by ELockWrapper::operator->.
    /// Move-only: a moved-from Proxy no longer owns the lock, so the lock is
    /// released exactly once even if the returned Proxy is moved rather than
    /// elided (pre-C++17).
    class Proxy {
    public:
        Proxy(PtrType* p, LockType* l) : ptr(p), mLock(l) { mLock->Lock(); }
        Proxy(Proxy&& other) noexcept : ptr(other.ptr), mLock(other.mLock) { other.mLock = nullptr; }
        Proxy(const Proxy&) = delete;
        Proxy& operator=(const Proxy&) = delete;
        Proxy& operator=(Proxy&&) = delete;
        ~Proxy() { if (mLock) mLock->Unlock(); }
        PtrType* operator->() { return ptr; }
        /// Returns a copy of the pointee, taken while the lock is held.
        PtrType operator*() { return *ptr; }
    private:
        PtrType* ptr;
        LockType* mLock;  // null once moved-from: nothing to unlock
    };

    ELockWrapper() : ptr(nullptr), lock(nullptr) {}
    // Implicit on purpose so `Wrapper w = nullptr;` and `w = nullptr;` work.
    ELockWrapper(std::nullptr_t) : ELockWrapper() {}
    ELockWrapper(PtrType* p, LockType* l) : ptr(p), lock(l) {}
    ELockWrapper(PtrType* p, LockType& l) : ptr(p), lock(&l) {}
    ELockWrapper(const ELockWrapper& copy) = default;
    ELockWrapper& operator=(const ELockWrapper& x) = default;

    // Comparisons look only at the pointee address, never at the lock, and
    // do not take the lock.
    bool operator==(const ELockWrapper& cmp) const { return cmp.ptr == ptr; }
    bool operator!=(const ELockWrapper& cmp) const { return !operator==(cmp); }
    bool operator==(PtrType* t) const { return ptr == t; }
    bool operator!=(PtrType* t) const { return ptr != t; }
    bool operator==(bool b) const { return (ptr != nullptr) == b; }
    bool operator!=(bool b) const { return !operator==(b); }
    operator bool() const { return ptr != nullptr; }

    /// Locks, and yields access to the pointee until the end of the full
    /// expression.
    Proxy operator->() {
        return Proxy(ptr, lock);
    }

    /// Returns a copy of the pointee made while the lock is held; the lock is
    /// released before the caller sees the copy.
    PtrType operator*() {
        return *Proxy(ptr, lock);
    }

    /// Deletes the pointee while holding the lock, then nulls this wrapper.
    /// Other copies of this wrapper still hold the old address and must not
    /// be dereferenced afterwards. No-op on a null wrapper.
    void Delete() {
        if (!ptr) return;
        {
            // Must be a named object: a bare `Proxy(ptr, lock);` is a
            // temporary that unlocks before the delete below runs.
            Proxy guard(ptr, lock);
            delete ptr;
        }
        ptr = nullptr;
    }
private:
    PtrType* ptr;
    LockType* lock;
};
/* Anything below this is for testing purposes */
/// Test payload holding a plain int counter. `x` is not atomic, so concurrent
/// IncX calls need external synchronization to produce an exact total.
struct TestClass {
    int x;

    TestClass(int initial) : x(initial) {}

    /// Adds one to x, numIters times (does nothing when numIters <= 0).
    void IncX(int numIters) {
        int remaining = numIters;
        while (remaining > 0) {
            x += 1;
            --remaining;
        }
    }
};
/// Adapter giving a std::mutex the Lock()/Unlock() interface. The mutex is
/// borrowed by reference, so it must outlive this object.
struct TestLock {
    std::mutex& m;

    TestLock(std::mutex& target) : m(target) {}

    void Lock() { this->m.lock(); }
    void Unlock() { this->m.unlock(); }
};
void IncThr(ELockWrapper<TestClass, TestLock>tsp, int numIters) {
tsp->IncX(numIters);
}
int main() {
TestClass obj(25);
std::mutex m;
TestLock lock(m);
ELockWrapper<TestClass, TestLock> thread_safe_ptr(&obj, &lock);
ELockWrapper<TestClass, TestLock> other_thread_safe_ptr = nullptr;
auto thread_safe_ptr_copy = thread_safe_ptr;
obj.x = 26;
if(thread_safe_ptr_copy->x != thread_safe_ptr->x) { cout << "Broken - copy" << endl; }
if(thread_safe_ptr->x != obj.x) { cout << "Broken - access" << endl; }
if((*thread_safe_ptr).x != obj.x) { cout << "Broken - dereference" << endl; }
if(thread_safe_ptr != &obj) { cout << "Broken - comparison with ptr" << endl; }
if(other_thread_safe_ptr == thread_safe_ptr) { cout << "Broken - comparison with other wrapper" << endl; }
thread_safe_ptr = nullptr;
if(!thread_safe_ptr || thread_safe_ptr == nullptr || thread_safe_ptr == false) { }
else {
cout << "Broken - nullptr comparisons" << endl;
}
thread_safe_ptr_copy->x = 0;
std::vector<std::thread> threads;
for(int i = 0; i < 8; ++i) threads.push_back(std::thread(&IncThr, thread_safe_ptr_copy, 1000000));
for(int i = 0; i < 8; ++i) threads[i].join();
if(thread_safe_ptr_copy->x != 8000000) { cout << "Broken - concurrency - " << thread_safe_ptr_copy->x << endl; }
TestClass* deallocObj = new TestClass(100);
ELockWrapper<TestClass, TestLock> delete_ptr(deallocObj, &lock);
delete_ptr.Delete();
return 0;
}