# File overload.py (inlined here instead of being imported)
# Names exported by `from overload import *`.
__all__ = [
    'overloaded_functions',
    'overload',
    'NotSupplied',
    'not_supplied',
    'OverloadedFunctionNotFound',
    'get_args_as_dict',
]
from collections import defaultdict
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Type
overloaded_functions = defaultdict(list)
def overload(classname: str=None):
"""A decorator to overload functions and methods.
For a method specify its class name as the classname argument."""
def wrapper(f):
name = f.__name__ if not classname else f'{classname}.{f.__name__}'
overloaded_functions[name].append(f)
return f
return wrapper
class NotSupplied:
    """Type of the ``not_supplied`` sentinel.

    Used as a parameter default to tell "argument omitted" apart from every
    real value, including ``None``.
    """

    def __repr__(self) -> str:
        # A readable repr keeps signatures that use the sentinel as a default
        # legible in help()/inspect output, instead of showing
        # ``<...NotSupplied object at 0x...>``.
        return 'not_supplied'


not_supplied = NotSupplied()  # To distinguish arguments that are not supplied
class OverloadedFunctionNotFound(RuntimeError):
    """Raised by a dispatcher when no registered overload fits the call."""

    _MESSAGE = 'No overloaded function could be found for the arguments passed.'

    def __init__(self):
        super().__init__(self._MESSAGE)
def get_args_as_dict(
    arg_specs: Sequence[Sequence[tuple[int, str, Type]]],
    args: Sequence[Any],
    kwargs: Mapping[str, Any]
) -> dict[str, Any]:
    """Work out which alternative argument names a call actually used.

    When overloaded signatures differ in the names or types of some
    arguments, this returns a dictionary with the actual argument names and
    values found.

    ``arg_specs`` is a list of slots. Each slot is a list of alternatives,
    and each alternative is a tuple ``(position, name, cls)``:

    - the index at which the argument appears if passed positionally,
    - its keyword name,
    - the type (anything ``isinstance`` accepts) it must have.

    Alternatives are tried in order. Each is matched first by keyword, then
    by position, and the first match wins.

    For example, with ``arg_specs = [[(1, 'sku', str), (1, 'id', int)]]`` the
    call must supply either an ``sku`` argument of type ``str`` or an ``id``
    argument of type ``int``. If it is not passed by keyword, the type of the
    positional argument at index 1 selects the alternative. Consequently, two
    signatures that are identical except for an argument's name can only be
    told apart if that argument is keyword-only. Note that ``bool`` is a
    subclass of ``int``, so ``True`` matches an ``int`` alternative.

    Args:
        arg_specs: The slots to resolve, as described above.
        args: The call's positional arguments.
        kwargs: The call's keyword arguments.

    Returns:
        A new dict mapping each positionally matched name to its value,
        updated with every entry of ``kwargs``.

    Raises:
        OverloadedFunctionNotFound: if some slot has no matching alternative.
            An empty slot never matches.
    """
    args_dict: dict[str, Any] = {}
    n_args = len(args)
    for alternatives in arg_specs:
        for position, arg_name, cls in alternatives:
            if arg_name in kwargs and isinstance(kwargs[arg_name], cls):
                # Matched by keyword; copied in with the other kwargs below.
                break
            if position < n_args and isinstance(args[position], cls):
                args_dict[arg_name] = args[position]
                break
        else:
            # No alternative of this slot matched (or the slot was empty).
            raise OverloadedFunctionNotFound()
    args_dict.update(kwargs)
    return args_dict
######################################################
# File main.py
from typing import Any
#from overload import *
# Some examples of how we can determine which overloaded function to call:
# The OP's example:
from enum import Enum
class Format(Enum):
    """Physical release format of a Product."""

    # Presumably a placeholder for "no valid format" (value 0).
    _invalid = 0
    CD = 1
    # NOTE: misspelling of "Vinyl"; renaming the member would break any code
    # that refers to Format.Vynil, so it is left unchanged.
    Vynil = 2
class Artist:
    """Minimal artist record that only carries its id."""

    def __init__(self, id):
        # `id` shadows the builtin but stays as the parameter name because
        # callers may pass it by keyword.
        self.id = id

    @classmethod
    def fromId(cls, db, artistId):
        """Alternate constructor building an artist from its id.

        `db` is accepted for interface compatibility but is not used here.
        """
        artist = cls(artistId)
        return artist
class SomeDbAbstraction:
    """Stand-in database for the examples: prints queries, returns canned data."""

    def run_query(self, query : str) -> dict[str,Any]:
        """Print *query* and return a fixed Product row.

        The ``artist`` column holds an artist id, matching Product's INSERT
        (which writes ``artist.id``) and its readers (which call
        ``Artist.fromId(db, row['artist'])``). The previous mock returned an
        ``Artist`` object here, which produced ``Artist(id=Artist(3))``.
        """
        print('Query executed:', query)
        return {
            'sku': 'Some SKU',
            'title': 'Some Title',
            'artist': 3,
            'description': 'Some description',
            # NOTE(review): Product's INSERT stores format.name ('CD'); a
            # Format member is returned so that readers calling Format(value)
            # directly also work.
            'format': Format.CD,
            'price': 10.00,
            'id': 999
        }

    def get_insert_id(self) -> int:
        """Return the id of the last inserted row (always 1 in this mock)."""
        return 1
class Product:
    """A product that is either inserted into, or loaded from, a database.

    ``Product(...)`` dispatches to one of four overloaded initializers,
    registered under ``'Product.__init__'`` in definition order:

    0. ``Product(db, sku, title, artist, description, format, price)`` with an
       ``Artist`` instance: insert a new row.
    1. ``Product(db, id)`` with an ``int`` id: load the row with that id.
    2. ``Product(db, sku)`` with a ``str`` sku: load the row with that sku.
    3. Like 0, but with an ``int`` ``artistId``: insert a new row.
    """

    # A per-instance __dict__ is kept (no __slots__) so that __repr__ can
    # simply print self.__dict__.

    def __repr__(self):
        return str(self.__dict__)

    # ---- private helpers shared by the overloads ----

    @staticmethod
    def _quote(value) -> str:
        """Return *value* as a single-quoted SQL string literal.

        Embedded single quotes are doubled (the standard SQL escape) so the
        value cannot terminate the literal early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _as_format(value) -> Format:
        """Convert a stored ``format`` column value back into a ``Format``.

        The INSERT stores ``format.name`` (e.g. ``'CD'``), so strings are
        looked up by name. Any other value (a member or its integer value)
        goes through ``Format(...)``.
        """
        if isinstance(value, str):
            return Format[value]
        return Format(value)

    def _set_fields(self, product_id, sku, title, artist, description, fmt, price) -> None:
        """Assign every attribute, in the order __repr__ displays them."""
        self.__id = product_id
        self.sku = sku
        self.title = title
        self.artist = artist
        self.description = description
        self.format = fmt
        self.price = price

    def _insert_row(self, db, sku, title, artist_id, description, fmt, price):
        """INSERT a Product row and return ``db.get_insert_id()``."""
        # NOTE: db.run_query only accepts a finished SQL string, so string
        # values are escaped with _quote and numeric ones are coerced, so that
        # no caller-supplied text reaches the statement unescaped. With a real
        # DB-API driver, use parameter binding instead.
        q = self._quote
        db.run_query(f"""
            INSERT INTO
                Product
            SET
                sku = {q(sku)},
                title = {q(title)},
                artist = {int(artist_id)},
                description = {q(description)},
                format = {q(fmt.name)},
                price = {float(price)}
            ;
        """)
        return db.get_insert_id()

    def _load_row(self, db, data: dict[str, Any]) -> None:
        """Populate this instance from a row dict returned by db.run_query."""
        self._set_fields(
            data.get("id"),
            data.get("sku"),
            data.get("title"),
            Artist.fromId(db, data.get("artist")),
            data.get("description"),
            self._as_format(data.get("format")),
            data.get("price"),
        )

    # ---- the overloads (registration order matters to the dispatcher) ----

    @overload('Product')
    def __init__(
        self,
        db : SomeDbAbstraction,
        sku : str,
        title : str,
        artist : Artist,
        description : str,
        format : Format,
        price : float
    ) -> None:
        """#0: insert a new product whose artist is given as an Artist."""
        print('__init__ #0')
        new_id = self._insert_row(db, sku, title, artist.id, description, format, price)
        self._set_fields(new_id, sku, title, artist, description, format, price)

    @overload('Product')
    def __init__(self, db : SomeDbAbstraction, id : int) -> None:
        """#1: load the product with the given id."""
        print('__init__ #1')
        data : dict[str,Any] = db.run_query(f"SELECT * FROM Product WHERE id = {int(id)};")
        self._load_row(db, data)

    @overload('Product')
    def __init__(self, db : SomeDbAbstraction, sku : str) -> None:
        """#2: load the product with the given sku."""
        print('__init__ #2')
        data : dict[str,Any] = db.run_query(
            f"SELECT * FROM Product WHERE sku = {self._quote(sku)};"
        )
        self._load_row(db, data)

    @overload('Product')
    def __init__(
        self,
        db : SomeDbAbstraction,
        sku : str,
        title : str,
        artistId : int,
        description : str,
        format : Format,
        price : float
    ) -> None:
        """#3: insert a new product whose artist is given by id."""
        print('__init__ #3')
        new_id = self._insert_row(db, sku, title, artistId, description, format, price)
        self._set_fields(
            new_id, sku, title, Artist.fromId(db, artistId), description, format, price
        )

    def __init__(self, *args, **kwargs):
        """The dispatcher: pick one of the overloads above and call it.

        Raises:
            OverloadedFunctionNotFound: if the call matches no overload.
        """
        total_args = len(args) + len(kwargs)
        if total_args == 2:
            # __init__ 1 or __init__ 2 according to whether the keyword arguments
            # contain `id` or `sku` with the expected type or if the 2nd positional
            # argument is one of the expected types:
            d = get_args_as_dict([[(1, 'id', int), (1, 'sku', str)]], args, kwargs)
            idx = 1 if 'id' in d else 2
        elif total_args == 7:
            # __init__ 0 or __init__ 3 according to whether the keyword arguments
            # contain `artist` or `artistId` with the expected type or if the 4th
            # positional argument is one of the expected types:
            d = get_args_as_dict([[(3, 'artist', Artist), (3, 'artistId', int)]], args, kwargs)
            idx = 0 if 'artist' in d else 3
        else:
            raise OverloadedFunctionNotFound()
        overloaded_functions['Product.__init__'][idx](self, *args, **kwargs)
# Exercise each Product overload; every call gets a fresh SomeDbAbstraction.
_product_demos = (
    ('Example A', lambda: Product(
        SomeDbAbstraction(),
        'some sku',
        'some title',
        Artist(7),
        'some description',
        Format.CD,
        10.00)),
    ('Example B', lambda: Product(
        SomeDbAbstraction(),
        'some sku',
        'some title',
        9,
        'some description',
        Format.CD,
        10.00)),
    ('Example C', lambda: Product(SomeDbAbstraction(), id=17)),
    ('Example D', lambda: Product(SomeDbAbstraction(), 'some sku')),
)
for _demo_title, _build_product in _product_demos:
    print(_demo_title)
    print(_build_product(), end='\n\n')
###############################################################################
# Other examples:
# Example 1: Here one of the two overloaded functions takes an extra argument:
# This could have been done without overloading by specifying a default
# value for the second argument:
from math import log
@overload()
def my_log_fn(n: int | float) -> float:
    """Return the base-10 logarithm of n.

    Uses math.log10, which is more accurate than log(n, 10); for example,
    log(1000, 10) evaluates to 2.9999999999999996 while log10(1000) is 3.0.
    """
    from math import log10
    return log10(n)


@overload()
def my_log_fn(n: int | float, base: int) -> float:
    """Return log(n, base)."""
    return log(n, base)


def my_log_fn(n: int | float, base: int | NotSupplied=not_supplied) -> float:
    """Dispatch on whether `base` was supplied.

    Without `base` the base-10 overload (index 0) is used; otherwise the
    two-argument overload (index 1).
    """
    if base is not_supplied:
        return overloaded_functions['my_log_fn'][0](n)
    return overloaded_functions['my_log_fn'][1](n, base)


print('Example 1:', my_log_fn(1_000_000), my_log_fn(1_000_000, 1_000))
###############################################################################
# Example 2: We can distinguish which overloaded function to call based on
# the number of arguments. This is easy if we accept only
# positional arguments or only keyword arguments:
@overload()
def add_to(s: set, value: object, /) -> None:
    """Insert `value` into the set `s`."""
    s.add(value)


@overload()
def add_to(d: dict, key: object, value: object, /) -> None:
    """Store `value` under `key` in the dictionary `d`."""
    d[key] = value


# The actual implementation
def add_to(*args) -> None:
    """Select the overload purely by arity: 2 args -> set, 3 args -> dict."""
    index_by_arity = {2: 0, 3: 1}
    index = index_by_arity.get(len(args))
    if index is None:
        raise OverloadedFunctionNotFound()
    return overloaded_functions['add_to'][index](*args)


d = {}
s = set()
add_to(d, 'a', 1)
add_to(s, 2)
print('Example 2:', f'd = {d}, s = {s}')
###############################################################################
# Example 3: We can distinguish which overloaded function to call based on
# the argument types. This is easy when the signatures
# of the overloaded functions are identical except for
# the types:
@overload()
def foo(a: int, b: int) -> int:
    """Return the floor division a // b of two ints."""
    return a // b


@overload()
def foo(a: float, b: float) -> float:
    """Return the true division a / b of two floats."""
    return a / b


def foo(a: int | float, b: int | float) -> int | float:
    """Dispatch on argument types.

    Two ints use the integer overload; any other int/float mix is promoted
    to float and uses the float overload.
    """
    if isinstance(a, int) and isinstance(b, int):
        return overloaded_functions['foo'][0](a, b)
    numeric = (int, float)
    if not all(isinstance(x, numeric) for x in (a, b)):
        raise OverloadedFunctionNotFound()
    # Both a and b need not be floats; mixed arguments are promoted.
    return overloaded_functions['foo'][1](float(a), float(b))


print('Example 3:', foo(7, 2), foo(7.0, 2.0))
###############################################################################
# Example 4: Same as Example 3 but with a class:
class TestClass:
    """Method version of the `foo` example: `foo` is overloaded on types."""

    @overload('TestClass')
    def foo(self, a: int, b: int) -> int:
        """Return the floor division a // b of two ints."""
        return a // b

    @overload('TestClass')
    def foo(self, a: float, b: float) -> float:
        """Return the true division a / b of two floats."""
        return a / b

    def foo(self, a: int | float, b: int | float) -> int | float:
        """Dispatch on argument types, passing `self` through explicitly.

        Two ints use the integer overload; any other int/float mix is
        promoted to float and uses the float overload.
        """
        if isinstance(a, int) and isinstance(b, int):
            return overloaded_functions['TestClass.foo'][0](self, a, b)
        numeric = (int, float)
        if not all(isinstance(x, numeric) for x in (a, b)):
            raise OverloadedFunctionNotFound()
        # Both a and b need not be floats; mixed arguments are promoted.
        return overloaded_functions['TestClass.foo'][1](self, float(a), float(b))


test_class = TestClass()
print('Example 4:', test_class.foo(7, 2), test_class.foo(7.0, 2.0))
