fork download
  1. # File overload.py (inlined here instead of being imported)
  2.  
  3. __all__ = [
  4. 'overloaded_functions',
  5. 'overload',
  6. 'NotSupplied',
  7. 'not_supplied',
  8. 'OverloadedFunctionNotFound',
  9. 'get_args_as_dict'
  10. ]
  11.  
  12. from collections import defaultdict
  13. from typing import Any, Type
  14.  
  15. overloaded_functions = defaultdict(list)
  16.  
  17. def overload(classname: str=None):
  18. """A decorator to overload functions and methods.
  19. For a method specify its class name as the classname argument."""
  20. def wrapper(f):
  21. name = f.__name__ if not classname else f'{classname}.{f.__name__}'
  22. overloaded_functions[name].append(f)
  23. return f
  24. return wrapper
  25.  
  26. class NotSupplied:
  27. pass
  28.  
  29. not_supplied = NotSupplied() # To distinguish arguments that are not supplied
  30.  
  31. class OverloadedFunctionNotFound(RuntimeError):
  32. def __init__(self):
  33. super().__init__('No overloaded function could be found for the arguments passed.')
  34.  
  35. def get_args_as_dict(
  36. arg_specs: list[list[tuple[int, str, Type]]],
  37. args: list[Any],
  38. kwargs: dict[str, Any]
  39. ) -> dict[str, Any]:
  40.  
  41. """When overloaded arguments vary by names or types this function will return a dictionary
  42. with the actual argument names and values found.
  43.  
  44. The input is a list. Each element of the list is a list of tuples where
  45. each tuple specifies an alternate possibility. The tuple specifies what position
  46. in the argument list this argument should be found if its positional, the name of the
  47. argument and the type of the argument.
  48.  
  49. For example: args_specs is [[(1, 'sku', str), (1, 'id', int)]]
  50. The functions input arguments should have either an sku argument of type str
  51. or an id argument of type int. If the argument is not a keyword argument then
  52. we look at the type of the positional argument at index 1 for both alternates and select
  53. the correct alternate based on type. Clearly, we cannot have two signatures that are identical
  54. except for the argument name varying unless that argument is defined as keyword only.
  55. """
  56.  
  57. args_dict = {}
  58. found_kwarg = False
  59. n_args = len(args)
  60.  
  61. for arg_spec_list in arg_specs:
  62. for position, arg_name, cls in arg_spec_list:
  63. found_type = False
  64. if arg_name in kwargs and isinstance(kwargs[arg_name], cls):
  65. found_type = True
  66. break
  67. if position < n_args and isinstance(args[position], cls):
  68. args_dict[arg_name] = args[position]
  69. found_type = True
  70. break
  71. if not found_type:
  72. raise OverloadedFunctionNotFound()
  73.  
  74. args_dict.update(kwargs)
  75. return args_dict
  76.  
  77. ######################################################
  78.  
  79. # File main.py
  80.  
  81. from typing import Any
  82.  
  83. #from overload import *
  84.  
  85. # Some examples of how we can determine which overloaded function to call:
  86.  
  87. # The OP's example:
  88.  
  89. from enum import Enum
  90.  
  91. class Format(Enum):
  92. _invalid = 0
  93. CD = 1
  94. Vynil = 2
  95.  
  96. class Artist:
  97. def __init__(self, id):
  98. self.id = id
  99.  
  100. @classmethod
  101. def fromId(cls, db, artistId):
  102. return cls(artistId)
  103.  
  104.  
  105. class SomeDbAbstraction:
  106. def run_query(self, query : str) -> dict[str,Any]:
  107. print('Query executed:', query)
  108. return {
  109. 'sku': 'Some SKU',
  110. 'title': 'Some Title',
  111. 'artist': Artist(3),
  112. 'description': 'Some description',
  113. 'format': Format.CD,
  114. 'price': 10.00,
  115. 'id': 999
  116. }
  117.  
  118. def get_insert_id(self) -> int:
  119. return 1
  120.  
  121. class Product:
  122. # use a dict instead of slots for an easy
  123. # __repr__ implementation
  124. """
  125. __slots__ = (
  126. "__id",
  127. "sku",
  128. "title",
  129. "artist",
  130. "description",
  131. "format",
  132. "price"
  133. )
  134. """
  135.  
  136. def __repr__(self):
  137. return str(self.__dict__)
  138.  
  139. @overload('Product')
  140. def __init__(
  141. self,
  142. db : SomeDbAbstraction,
  143. sku : str,
  144. title : str,
  145. artist : Artist,
  146. description : str,
  147. format : Format,
  148. price : float
  149. ) -> None:
  150. print('__init__ #0')
  151. db.run_query(f"""
  152. INSERT INTO
  153. Product
  154. SET
  155. sku = '{sku}',
  156. title = {title},
  157. artist = {artist.id},
  158. description = {description},
  159. format = {format.name},
  160. price = {price}
  161. ;
  162. """)
  163. self.__id = db.get_insert_id()
  164. self.sku = title
  165. self.title = title
  166. self.artist = artist
  167. self.description = description
  168. self.format = format
  169. self.price = price
  170.  
  171. @overload('Product')
  172. def __init__(self, db : SomeDbAbstraction, id : int):
  173. print('__init__ #1')
  174. data : dict[str,Any] = db.run_query(f"SELECT * FROM Product WHERE id = {id};")
  175. self.__id = data.get("id")
  176. self.sku = data.get("title")
  177. self.title = data.get("title")
  178. self.artist = Artist.fromId(db, data.get("artist"))
  179. self.description = data.get("description")
  180. self.format = Format(data.get("format"))
  181. self.price = data.get("price")
  182.  
  183. @overload('Product')
  184. def __init__(self, db : SomeDbAbstraction, sku : str):
  185. print('__init__ #2')
  186. data : dict[str,Any] = db.run_query(f"SELECT * FROM Product WHERE sku = '{sku}';")
  187. self.__id = data.get("id")
  188. self.sku = data.get("title")
  189. self.title = data.get("title")
  190. self.artist = Artist.fromId(db, data.get("artist"))
  191. self.description = data.get("description")
  192. self.format = Format(data.get("format"))
  193. self.price = data.get("price")
  194.  
  195. @overload('Product')
  196. def __init__(
  197. self,
  198. db : SomeDbAbstraction,
  199. sku : str,
  200. title : str,
  201. artistId : int,
  202. description : str,
  203. format : Format,
  204. price : float
  205. ) -> None:
  206. print('__init__ #3')
  207. db.run_query(f"""
  208. INSERT INTO
  209. Product
  210. SET
  211. sku = '{sku}',
  212. title = {title},
  213. artist = {artistId},
  214. description = {description},
  215. format = {format.name},
  216. price = {price}
  217. ;
  218. """)
  219. self.__id = db.get_insert_id()
  220. self.sku = title
  221. self.title = title
  222. self.artist = Artist.fromId(db, artistId)
  223. self.description = description
  224. self.format = format
  225. self.price = price
  226.  
  227. def __init__(self, *args, **kwargs):
  228. """The dispatcher."""
  229. n_args = len(args)
  230. n_kwargs = len(kwargs)
  231. total_args = n_args + n_kwargs
  232.  
  233. if total_args == 2:
  234. # __init__ 1 or __init__ 2 according to whether the keyword arguments
  235. # contain `id` or `sku` with the expected type or if the 2nd positional
  236. # argument is one of the expected types:
  237. d = get_args_as_dict([[(1, 'id', int), (1, 'sku', str)]], args, kwargs)
  238. idx = 1 if 'id' in d else 2
  239. elif total_args == 7:
  240. # __init__ 0 or __init__ 3 according to whether the keyword arguments
  241. # contain `artist` or `artistId` with the expected type or if the 4th positional
  242. # argument is one of the expected types:
  243. d = get_args_as_dict([[(3, 'artist', Artist), (3, 'artistId', int)]], args, kwargs)
  244. idx = 0 if 'artist' in d else 3
  245. else:
  246. raise OverloadedFunctionNotFound()
  247.  
  248. overloaded_functions['Product.__init__'][idx](self, *args, **kwargs)
  249.  
  250. print('Example A')
  251. print(Product(
  252. SomeDbAbstraction(),
  253. 'some sku',
  254. 'some title',
  255. Artist(7),
  256. 'some description',
  257. Format.CD,
  258. 10.00), end='\n\n')
  259.  
  260. print('Example B')
  261. print(Product(
  262. SomeDbAbstraction(),
  263. 'some sku',
  264. 'some title',
  265. 9,
  266. 'some description',
  267. Format.CD,
  268. 10.00), end='\n\n')
  269.  
  270. print('Example C')
  271. print(Product(SomeDbAbstraction(), id=17), end='\n\n')
  272.  
  273. print('Example D')
  274. print(Product(SomeDbAbstraction(), 'some sku'), end='\n\n')
  275.  
  276. ###############################################################################
  277.  
  278. # Other examples:
  279.  
  280. # Example 1: One of the two overloaded functions takes an extra argument.
  281. # This could also have been done without overloading, by giving the second
  282. # argument a default value:
  283.  
  284. from math import log
  285.  
  286. @overload()
  287. def my_log_fn(n: int | float) -> float:
  288. """Return log(n, 10)."""
  289.  
  290. return log(n, 10)
  291.  
  292. @overload()
  293. def my_log_fn(n: int | float, base: int) -> float:
  294. """Return log(n, base)."""
  295.  
  296. return log(n, base)
  297.  
  298. def my_log_fn(n: int | float, base: int | NotSupplied=not_supplied) -> float:
  299. if base is not_supplied:
  300. return overloaded_functions['my_log_fn'][0](n)
  301.  
  302. return overloaded_functions['my_log_fn'][1](n, base)
  303.  
  304. print('Example 1:', my_log_fn(1_000_000), my_log_fn(1_000_000, 1_000))
  305.  
  306. ###############################################################################
  307.  
  308. # Example 2: We can distinguish which overloaded function to call based on
  309. # the number of arguments. This is easy if we accept only
  310. # positional arguments or only keyword arguments:
  311.  
  312. @overload()
  313. def add_to(s: set, value: object, /) -> None:
  314. """Add a value to a set."""
  315. s.add(value)
  316.  
  317. @overload()
  318. def add_to(d: dict, key: object, value: object, /) -> None:
  319. """Add a value to a dictionary."""
  320. d[key] = value
  321.  
  322. # The actual implementation
  323. def add_to(*args) -> None:
  324. """Decide which overloaded function to call based on number of arguments."""
  325.  
  326. n_args = len(args)
  327.  
  328. if n_args == 2:
  329. return overloaded_functions['add_to'][0](*args)
  330.  
  331. if n_args == 3:
  332. return overloaded_functions['add_to'][1](*args)
  333.  
  334. raise OverloadedFunctionNotFound()
  335.  
  336.  
  337. d = {}
  338. s = set()
  339.  
  340. add_to(d, 'a', 1)
  341. add_to(s, 2)
  342.  
  343. print('Example 2:', f'd = {d}, s = {s}')
  344.  
  345. ###############################################################################
  346.  
  347. # Example 3: We can distinguish which overloaded function to call based on
  348. # the arguments types. This is easy when the signatures
  349. # of the overloaded functions are identical except for
  350. # the types:
  351.  
  352. @overload()
  353. def foo(a: int, b: int) -> int:
  354. """Divide a by b and return result."""
  355.  
  356. return a // b
  357.  
  358. @overload()
  359. def foo(a: float, b: float) -> float:
  360. """Divide a by b and return result."""
  361.  
  362. return a / b
  363.  
  364. def foo(a: int | float, b: int | float) -> int | float:
  365. if isinstance(a, int) and isinstance(b, int):
  366. return overloaded_functions['foo'][0](a, b)
  367.  
  368. # We will not require both a and b to be floats
  369. if isinstance(a, (int, float)) and isinstance(b, (int, float)):
  370. return overloaded_functions['foo'][1](float(a), float(b))
  371.  
  372. raise OverloadedFunctionNotFound()
  373.  
  374. print('Example 3:', foo(7, 2), foo(7.0, 2.0))
  375.  
  376. ###############################################################################
  377.  
  378. # Example 4: Same as Example 3 but with a class:
  379.  
  380. class TestClass:
  381. @overload('TestClass')
  382. def foo(self, a: int, b: int) -> int:
  383. """Divide a by b and return result."""
  384.  
  385. return a // b
  386.  
  387. @overload('TestClass')
  388. def foo(self, a: float, b: float) -> float:
  389. """Divide a by b and return result."""
  390.  
  391. return a / b
  392.  
  393. def foo(self, a: int | float, b: int | float) -> int | float:
  394. if isinstance(a, int) and isinstance(b, int):
  395. return overloaded_functions['TestClass.foo'][0](self, a, b)
  396.  
  397. # We will not require both a and b to be floats
  398. if isinstance(a, (int, float)) and isinstance(b, (int, float)):
  399. return overloaded_functions['TestClass.foo'][1](self, float(a), float(b))
  400.  
  401. raise OverloadedFunctionNotFound()
  402.  
  403. test_class = TestClass()
  404. print('Example 4:', test_class.foo(7, 2), test_class.foo(7.0, 2.0))
  405.  
  406.  
Success #stdin #stdout 0.22s 17216KB
stdin
Standard input is empty
stdout
Example A
__init__ #0
Query executed: 
            INSERT INTO
                Product
            SET
                sku = 'some sku',
                title = some title,
                artist = 7,
                description = some description,
                format = CD,
                price = 10.0
            ;
        
{'_Product__id': 1, 'sku': 'some title', 'title': 'some title', 'artist': <__main__.Artist object at 0x14a06397b800>, 'description': 'some description', 'format': <Format.CD: 1>, 'price': 10.0}

Example B
__init__ #3
Query executed: 
            INSERT INTO
                Product
            SET
                sku = 'some sku',
                title = some title,
                artist = 9,
                description = some description,
                format = CD,
                price = 10.0
            ;
        
{'_Product__id': 1, 'sku': 'some title', 'title': 'some title', 'artist': <__main__.Artist object at 0x14a06397b7a0>, 'description': 'some description', 'format': <Format.CD: 1>, 'price': 10.0}

Example C
__init__ #1
Query executed: SELECT * FROM Product WHERE id = 17;
{'_Product__id': 999, 'sku': 'Some Title', 'title': 'Some Title', 'artist': <__main__.Artist object at 0x14a06397b6b0>, 'description': 'Some description', 'format': <Format.CD: 1>, 'price': 10.0}

Example D
__init__ #2
Query executed: SELECT * FROM Product WHERE sku = 'some sku';
{'_Product__id': 999, 'sku': 'Some Title', 'title': 'Some Title', 'artist': <__main__.Artist object at 0x14a06397b7a0>, 'description': 'Some description', 'format': <Format.CD: 1>, 'price': 10.0}

Example 1: 5.999999999999999 2.0
Example 2: d = {'a': 1}, s = {2}
Example 3: 3 3.5
Example 4: 3 3.5