from collections import namedtuple, defaultdict
# A single resistor: the two nodes it connects and its resistance value
# (kept exactly as supplied, e.g. the string "10" when parsed from text).
Resistor = namedtuple('Resistor', 'node1 node2 value')
class Circuit(object):
    """A two-terminal network of resistors.

    ``self.network`` maps a sorted ``(node_a, node_b)`` tuple to the list of
    resistances wired directly between those two nodes; several entries on
    one key are resistors in parallel. ``self.endpoints`` holds the two
    terminals between which ``equivalent_resistance`` is measured.
    """

    def __init__(self, start, end, resistors):
        """Build the network.

        :param start: first terminal node.
        :param end: second terminal node.
        :param resistors: iterable of objects with ``node1``, ``node2`` and
            ``value`` attributes (such as ``Resistor``). ``value`` may be a
            number or a numeric string such as ``"10"`` or ``"2.5"``.
        """
        self.endpoints = (start, end)
        self.network = defaultdict(list)
        for r in resistors:
            key = tuple(sorted({r.node1, r.node2}))
            self.network[key].append(self._to_number(r.value))

    @staticmethod
    def _to_number(value):
        """Return *value* as a number: strings parse as int, else float.

        Non-string values are returned unchanged, so fractional resistances
        such as 2.5 are never truncated.
        """
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                return float(value)
        return value

    @staticmethod
    def _nodes_of(network):
        """Return the set of every node named in a key of *network*."""
        nodes = set()
        for key in network:
            nodes.update(key)
        return nodes

    @staticmethod
    def _edges_at(network, node):
        """Return the keys of *network* that contain *node*."""
        return [key for key in network if node in key]

    def nodes(self):
        """Return the set of all nodes that appear in the network."""
        return self._nodes_of(self.network)

    def neighbors(self, node):
        """Return the nodes directly connected to *node* by a resistor.

        :raises KeyError: if *node* does not appear in the network.
        """
        neighbors = set()
        for key in self._edges_at(self.network, node):
            neighbors.update(key)
        neighbors.remove(node)
        return neighbors

    def equivalent_resistance(self):
        """Return the resistance measured between the two endpoints.

        The reduction runs on a private copy of ``self.network``, so the
        circuit is left intact and the method can be called repeatedly. Each
        round merges parallel resistors, drops dead-end branches (which carry
        no current) and eliminates non-terminal nodes joining exactly two
        single resistors in series, until one resistor joins the endpoints.

        :raises ValueError: if the endpoints are not connected, or if the
            network cannot be reduced by series/parallel steps alone (for
            example a Wheatstone bridge).
        """
        start, end = self.endpoints
        if start == end:
            return 0
        target = tuple(sorted((start, end)))
        # Copy the network; self-loop keys (a single node) carry no current.
        network = defaultdict(list)
        for key, values in self.network.items():
            if len(key) == 2 and values:
                network[key].extend(values)
        while True:
            changed = self._merge_parallel(network)
            changed = self._prune_dangling(network) or changed
            changed = self._merge_series(network) or changed
            if list(network) == [target] and len(network[target]) == 1:
                return network[target][0]
            # Every step strictly removes a resistor, so no change means stuck.
            if not changed:
                raise ValueError(
                    'cannot reduce the network between %r and %r with series '
                    'and parallel steps' % (start, end))

    @staticmethod
    def _merge_parallel(network):
        """Collapse every key holding several resistances into one value.

        A zero resistance shorts its parallel siblings, giving 0.
        Returns True if anything was merged.
        """
        changed = False
        for key, values in list(network.items()):
            if len(values) < 2:
                continue
            if any(v == 0 for v in values):
                network[key] = [0]
            else:
                # 1.0 keeps the division floating-point even for int values.
                network[key] = [1.0 / sum(1.0 / v for v in values)]
            changed = True
        return changed

    def _prune_dangling(self, network):
        """Delete the edge of every non-terminal node with a single neighbour.

        No current flows through such a dead-end branch, so it does not
        affect the equivalent resistance. Returns True if anything was removed.
        """
        changed = False
        for node in self._nodes_of(network):
            if node in self.endpoints:
                continue
            edges = self._edges_at(network, node)
            if len(edges) == 1:
                del network[edges[0]]
                changed = True
        return changed

    def _merge_series(self, network):
        """Eliminate non-terminal nodes that join exactly two single resistors.

        The two resistors are replaced by their sum between the node's two
        neighbours. Returns True if any node was eliminated.
        """
        changed = False
        for node in self._nodes_of(network):
            if node in self.endpoints:
                continue
            # The node may already be gone if an earlier step removed it.
            edges = self._edges_at(network, node)
            if len(edges) != 2:
                continue
            first, second = edges
            if len(network[first]) != 1 or len(network[second]) != 1:
                continue
            outer = tuple(sorted(set(first + second) - {node}))
            total = network.pop(first)[0] + network.pop(second)[0]
            network[outer].append(total)
            changed = True
        return changed
# Demo input: the header line lists node names (only the first and the last
# are used, as the two terminals); each remaining line is
# "<node> <node> <resistance>".
sample = '''A B C D E F
A C 5
A B 10
D A 5
D E 10
C E 10
E F 15
B F 20'''
nodes = sample.partition('\n')[0].split()
resistors = [Resistor(*fields)
             for fields in (row.split() for row in sample.splitlines()[1:])]
c = Circuit(nodes[0], nodes[-1], resistors)
print(c.equivalent_resistance())