import gc
import types
import inspect
from typing import Optional
from weakref import WeakKeyDictionary

def proxy0(data):
    """Return a zero-argument closure that yields *data*.

    The returned function captures *data* in a single closure cell, so
    ``proxy0(x).__closure__[0]`` is a cell object holding ``x``.
    """
    def proxy1():
        return data

    return proxy1

_CELLTYPE = type(proxy0(None).__closure__[0])

def replace_all_refs(org_obj, new_obj):
    gc.collect()

    hit = False
    for referrer in gc.get_referrers(org_obj):
        if isinstance(referrer, types.FrameType):
            continue
        if isinstance(referrer, dict):
            cls = None
            if '__dict__' in referrer and '__weakref__' in referrer:
                for cls in gc.get_referrers(referrer):
                    if inspect.isclass(cls) and cls.__dict__ == referrer:
                        break
            for key, value in referrer.items():
                if value is org_obj:
                    hit = True
                    value = new_obj
                    referrer[key] = value
                    if cls:
                        setattr(cls, key, new_obj)
                if key is org_obj:
                    hit = True
                    del referrer[key]
                    referrer[new_obj] = value
        elif isinstance(referrer, list):
            for i, value in enumerate(referrer):
                if value is org_obj:
                    hit = True
                    referrer[i] = new_obj
        elif isinstance(referrer, set):
            referrer.remove(org_obj)
            referrer.add(new_obj)
            hit = True
        elif isinstance(referrer, (tuple, frozenset,)):
            new_tuple = []
            for obj in referrer:
                if obj is org_obj:
                    new_tuple.append(new_obj)
                else:
                    new_tuple.append(obj)
            replace_all_refs(referrer, type(referrer)(new_tuple))
        elif isinstance(referrer, _CELLTYPE):
            def proxy0(data):
                def proxy1(): return data
                return proxy1
            proxy = proxy0(new_obj)
            newcell = proxy.__closure__[0]
            replace_all_refs(referrer, newcell)
        elif isinstance(referrer, types.FunctionType):
            localsmap = {}
            for key in ['__code__', '__globals__', '__name__',
                        '__defaults__', '__closure__']:
                orgattr = getattr(referrer, key)
                if orgattr is org_obj:
                    localsmap[key[2:-2]] = new_obj
                else:
                    localsmap[key[2:-2]] = orgattr
            localsmap['argdefs'] = localsmap['defaults']
            del localsmap['defaults']
            newfn = types.FunctionType(**localsmap)
            replace_all_refs(referrer, newfn)
    if hit is False:
        raise AttributeError("Object '%r' not found" % org_obj)
    return org_obj

_NO_DEFAULT = object()  # sentinel: the owner class had no value for the attribute


class DeferredStrProperty:
    """Data descriptor serving placeholder strings until a real value is set.

    Installed on a class by ``DeferredStr``.  While an instance has a pending
    placeholder for ``attribute``, reading the attribute returns that
    placeholder; assigning the attribute redirects every reference to the
    placeholder (via ``replace_all_refs``) to the assigned value and stores
    the value on the instance.
    """

    # Registry shared by all descriptors: weakly maps each instance to a
    # ``{attribute name: DeferredStr}`` dict of its pending placeholders, so
    # several deferred attributes on one instance do not collide.
    instances = WeakKeyDictionary()

    def __init__(self, attribute, default=_NO_DEFAULT):
        self.attribute = attribute
        # Value the owner class held for ``attribute`` before this descriptor
        # replaced it; returned for instances that have no value of their own.
        self.default = default

    def _pending(self, obj):
        """Return *obj*'s pending placeholder for this attribute, or None."""
        return self.instances.get(obj, {}).get(self.attribute)

    def __get__(self, obj, objtype=None):
        if obj is None:
            # Class-level access (also through subclasses) yields the descriptor.
            return self
        placeholder = self._pending(obj)
        if placeholder is not None:
            return placeholder
        try:
            return vars(obj)[self.attribute]
        except KeyError:
            if self.default is _NO_DEFAULT:
                raise AttributeError(
                    f"{type(obj).__name__!r} object has no attribute "
                    f"{self.attribute!r}"
                ) from None
            return self.default

    def __set__(self, obj, value):
        pending = self.instances.get(obj)
        if pending is not None and self.attribute in pending:
            # Redirect references while the registry still holds the
            # placeholder, so replace_all_refs() is guaranteed to find it.
            replace_all_refs(pending[self.attribute], value)
            del pending[self.attribute]
            if not pending:
                del self.instances[obj]
        # Always store the value so reads work once the placeholder is gone.
        vars(obj)[self.attribute] = value


class DeferredStr(str):
    """Placeholder ``str`` standing in for ``instance.<attribute>``.

    Creating one installs a ``DeferredStrProperty`` for *attribute* on
    ``type(instance)`` unless one is already present (possibly inherited),
    and registers the placeholder as pending for *instance*.  Its text is
    ``'x'``.
    """

    def __new__(cls, instance, attribute):
        owner = type(instance)
        # getattr_static returns the raw class attribute without invoking
        # descriptors, and copes with inherited or missing attributes.
        current = inspect.getattr_static(owner, attribute, _NO_DEFAULT)
        if not isinstance(current, DeferredStrProperty):
            # Keep the class's previous value as the fallback for other
            # instances instead of discarding it.
            setattr(owner, attribute, DeferredStrProperty(attribute, current))
        placeholder = super().__new__(cls, 'x')
        DeferredStrProperty.instances.setdefault(instance, {})[attribute] = placeholder
        return placeholder

class Test:
    """Demo class with one optional string attribute, ``value``."""

    value: Optional[str] = None

# --- Demo of DeferredStr / DeferredStrProperty; runs at module top level. ---
# `deferred_str` is deliberately a module-level name: replace_all_refs() can
# only redirect references held in containers such as this module's globals
# dict (frames are skipped).
instance = Test()
instance2 = Test()
# Installs a DeferredStrProperty for Test.value (if not already present) and
# registers a placeholder whose text is 'x' as pending for `instance`.
deferred_str = DeferredStr(instance, "value")
print("2" + deferred_str)  # prints 2x
print("".join((deferred_str,)))  # prints x
print(instance.value)  # the pending placeholder itself, so prints x
print('some time later...')
instance2.value = '3'  # nothing pending for instance2: stored in its __dict__
# Assigning the pending attribute calls replace_all_refs(placeholder, '1'),
# which rebinds the global `deferred_str` to '1'.
instance.value = '1'
print(instance2.value)  # prints 3
print(deferred_str)  # prints 1
print(deferred_str + "2")  # prints 12
# NOTE(review): this joins the characters of the string; `(deferred_str,)`
# may have been intended (the output is the same for a 1-char string).
print("".join(deferred_str))
print("2" + deferred_str)  # prints 21
print("".join((deferred_str,)))  # prints 1
