import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Scratch harness for a small entity registry that buckets entities by their exact runtime class.
 *
 * <p>Lookups are exact-class only: an entity is stored under {@code entity.getClass()}, and
 * {@link #getAll(Class)} reads the bucket for exactly the class passed in. As a consequence,
 * {@code getAll(Enemy.class)} returns an empty list even when {@code SomeEnemy} instances are
 * registered, because an interface is never the runtime class of an object.
 */
class Scratch {

    /**
     * Registry storage: concrete runtime class -> entities registered under that class, in the
     * order they were added.
     *
     * <p>NOTE(review): the buckets are {@link CopyOnWriteArrayList}s but the map itself is a plain
     * {@link HashMap}, so concurrent {@link #add(Entity)} calls are not safe. Confirm whether
     * concurrent use is intended before relying on the list type for thread-safety.
     */
    private static final Map<Class<? extends Entity>, List<Entity>> entities
        = new HashMap<>();

    /**
     * Exercises {@code add}, {@code getAll} and {@code remove}, printing the outcome of each check.
     * Every printed line is expected to be {@code true}.
     */
    public static void main(String[] args) {
        SomeEnemy firstSomeEnemy = new SomeEnemy();
        add(firstSomeEnemy);

        SomeEnemy secondSomeEnemy = new SomeEnemy();
        add(secondSomeEnemy);

        AnotherEnemy firstAnotherEnemy = new AnotherEnemy();
        add(firstAnotherEnemy);

        AnotherEnemy secondAnotherEnemy = new AnotherEnemy();
        add(secondAnotherEnemy);

        List<SomeEnemy> fetchOne = getAll(SomeEnemy.class);
        SomeEnemy firstFetched = fetchOne.get(0);
        // Mutating the returned snapshot must not leak into the registry; the isEmpty() check on
        // fetchThree below would print false if it did.
        fetchOne.add(new SomeEnemy());
        System.out.println(firstFetched == firstSomeEnemy);

        remove(firstFetched);
        List<SomeEnemy> fetchTwo = getAll(SomeEnemy.class);
        SomeEnemy secondFetched = fetchTwo.get(0);
        System.out.println(secondFetched == secondSomeEnemy);

        remove(secondSomeEnemy);
        List<SomeEnemy> fetchThree = getAll(SomeEnemy.class);
        System.out.println(fetchThree.isEmpty());

        List<AnotherEnemy> fetchFour = getAll(AnotherEnemy.class);
        AnotherEnemy fourthFetched = fetchFour.get(0);
        System.out.println(fourthFetched == firstAnotherEnemy);

        remove(firstAnotherEnemy);
        List<AnotherEnemy> fetchFive = getAll(AnotherEnemy.class);
        AnotherEnemy fifthFetched = fetchFive.get(0);
        System.out.println(fifthFetched == secondAnotherEnemy);

        remove(secondAnotherEnemy);
        List<AnotherEnemy> fetchSix = getAll(AnotherEnemy.class);
        System.out.println(fetchSix.isEmpty());

    }

    /**
     * Registers {@code entity} under its exact runtime class, creating that class's bucket on
     * first use. Adding the same instance twice stores it twice.
     *
     * @param entity the entity to register; must not be {@code null}
     * @throws NullPointerException if {@code entity} is {@code null}
     */
    static void add(Entity entity) {
        // Single map operation instead of containsKey/put/get.
        entities
            .computeIfAbsent(entity.getClass(), key -> new CopyOnWriteArrayList<>())
            .add(entity);
    }

    /**
     * Returns a snapshot of every entity registered under exactly {@code type}.
     *
     * <p>The result is a fresh, mutable {@link ArrayList}. Changing it does not affect the
     * registry. Entities whose runtime class is a subclass of {@code type} are not included.
     *
     * @param type the exact runtime class to look up
     * @param <T> the element type of the returned list
     * @return a new list, empty if nothing is registered under {@code type}
     */
    private static <T extends Entity> List<T> getAll(Class<? extends T> type) {
        List<Entity> bucket = entities.getOrDefault(type, Collections.emptyList());
        List<T> snapshot = new ArrayList<>(bucket.size());
        for (Entity entity : bucket) {
            // Buckets are keyed by the exact runtime class, so this checked cast cannot fail.
            // It replaces the former unchecked List<Entity> -> List<T> cast.
            snapshot.add(type.cast(entity));
        }
        return snapshot;
    }

    /**
     * Removes the first occurrence of {@code entity} from the bucket for its runtime class.
     * Removing an entity that was never added is a no-op.
     *
     * @param entity the entity to remove; must not be {@code null}
     * @throws NullPointerException if {@code entity} is {@code null}
     */
    public static void remove(Entity entity) {
        List<Entity> bucket = entities.get(entity.getClass());
        if (bucket != null) {
            bucket.remove(entity);
        }
    }
}

/** Root type for everything the registry can hold; declares no members. */
interface Entity {
}

/** Sub-interface of {@link Entity}; declares no members of its own. */
interface Enemy extends Entity {
}

/** A concrete {@link Enemy} implementation with no state or behavior of its own. */
class SomeEnemy implements Enemy {
}

/** A second concrete {@link Enemy} implementation, a distinct class from {@link SomeEnemy}. */
class AnotherEnemy implements Enemy {
}