Skip to content

Commit 0b88cba

Browse files
authored
Merge pull request #106 from Salehbigdeli/master
Make typedispatch decorator more general
2 parents a434735 + 6f8bcfd commit 0b88cba

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

fastcore/dispatch.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def __init__(self, funcs=(), bases=()):
9797

9898
def add(self, f):
9999
"Add type `t` and function `f`"
100-
a0,a1 = _p2_anno(f)
100+
if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)
101+
else: a0,a1 = _p2_anno(f)
101102
t = self.funcs.d.get(a0)
102103
if t is None:
103104
t = _TypeDict()
@@ -121,7 +122,8 @@ def __call__(self, *args, **kwargs):
121122
ts = L(args).map(type)[:2]
122123
f = self[tuple(ts)]
123124
if not f: return args[0]
124-
if self.inst is not None: f = MethodType(f, self.inst)
125+
if isinstance(f, staticmethod): f = f.__func__
126+
elif self.inst is not None: f = MethodType(f, self.inst)
125127
elif self.owner is not None: f = MethodType(f, self.owner)
126128
return f(*args, **kwargs)
127129

@@ -152,7 +154,9 @@ class DispatchReg:
152154
"A global registry for `TypeDispatch` objects keyed by function name"
153155
def __init__(self): self.d = defaultdict(TypeDispatch)
154156
def __call__(self, f):
155-
nm = f'{f.__qualname__}'
157+
if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'
158+
else: nm = f'{f.__qualname__}'
159+
if isinstance(f, classmethod): f=f.__func__
156160
self.d[nm].add(f)
157161
return self.d[nm]
158162

nbs/03_dispatch.ipynb

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,8 @@
449449
"\n",
450450
" def add(self, f):\n",
451451
" \"Add type `t` and function `f`\"\n",
452-
" a0,a1 = _p2_anno(f)\n",
452+
" if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n",
453+
" else: a0,a1 = _p2_anno(f)\n",
453454
" t = self.funcs.d.get(a0)\n",
454455
" if t is None:\n",
455456
" t = _TypeDict()\n",
@@ -473,7 +474,8 @@
473474
" ts = L(args).map(type)[:2]\n",
474475
" f = self[tuple(ts)]\n",
475476
" if not f: return args[0]\n",
476-
" if self.inst is not None: f = MethodType(f, self.inst)\n",
477+
" if isinstance(f, staticmethod): f = f.__func__\n",
478+
" elif self.inst is not None: f = MethodType(f, self.inst)\n",
477479
" elif self.owner is not None: f = MethodType(f, self.owner)\n",
478480
" return f(*args, **kwargs)\n",
479481
"\n",
@@ -943,7 +945,7 @@
943945
{
944946
"data": {
945947
"text/markdown": [
946-
"<h4 id=\"TypeDispatch.__call__\" class=\"doc_header\"><code>TypeDispatch.__call__</code><a href=\"__main__.py#L33\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
948+
"<h4 id=\"TypeDispatch.__call__\" class=\"doc_header\"><code>TypeDispatch.__call__</code><a href=\"__main__.py#L34\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
947949
"\n",
948950
"> <code>TypeDispatch.__call__</code>(**\\*`args`**, **\\*\\*`kwargs`**)\n",
949951
"\n",
@@ -1031,7 +1033,7 @@
10311033
{
10321034
"data": {
10331035
"text/markdown": [
1034-
"<h4 id=\"TypeDispatch.returns\" class=\"doc_header\"><code>TypeDispatch.returns</code><a href=\"__main__.py#L21\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
1036+
"<h4 id=\"TypeDispatch.returns\" class=\"doc_header\"><code>TypeDispatch.returns</code><a href=\"__main__.py#L22\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
10351037
"\n",
10361038
"> <code>TypeDispatch.returns</code>(**`x`**)\n",
10371039
"\n",
@@ -1198,7 +1200,9 @@
11981200
" \"A global registry for `TypeDispatch` objects keyed by function name\"\n",
11991201
" def __init__(self): self.d = defaultdict(TypeDispatch)\n",
12001202
" def __call__(self, f):\n",
1201-
" nm = f'{f.__qualname__}'\n",
1203+
" if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'\n",
1204+
" else: nm = f'{f.__qualname__}'\n",
1205+
" if isinstance(f, classmethod): f=f.__func__\n",
12021206
" self.d[nm].add(f)\n",
12031207
" return self.d[nm]\n",
12041208
"\n",
@@ -1227,6 +1231,36 @@
12271231
"test_eq(f_td_test('a','b'), 'ab')"
12281232
]
12291233
},
1234+
{
1235+
"cell_type": "markdown",
1236+
"metadata": {},
1237+
"source": [
1238+
"#### Using typedispatch With other decorators\n",
1239+
"\n",
1240+
"You can use `typedispatch` with `classmethod` and `staticmethod` decorator"
1241+
]
1242+
},
1243+
{
1244+
"cell_type": "code",
1245+
"execution_count": null,
1246+
"metadata": {},
1247+
"outputs": [],
1248+
"source": [
1249+
"class A:\n",
1250+
" @typedispatch\n",
1251+
" def f_td_test(self, x:numbers.Integral, y): return x+1\n",
1252+
" @typedispatch\n",
1253+
" @classmethod\n",
1254+
" def f_td_test(cls, x:int, y:float): return x+y\n",
1255+
" @typedispatch\n",
1256+
" @staticmethod\n",
1257+
" def f_td_test(x:int, y:int): return x*y\n",
1258+
" \n",
1259+
"test_eq(A.f_td_test(3,2), 6)\n",
1260+
"test_eq(A.f_td_test(3,2.0), 5)\n",
1261+
"test_eq(A().f_td_test(3,'2.0'), 4)"
1262+
]
1263+
},
12301264
{
12311265
"cell_type": "markdown",
12321266
"metadata": {},

0 commit comments

Comments
 (0)