Skip to content

Commit c76a04c

Browse files
benjefferyhyanwong
authored andcommitted
Switch to ruff linting
1 parent 0e25baf commit c76a04c

36 files changed

+338
-467
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- run:
2525
name: Lint Python
2626
command: |
27-
flake8 --max-line-length 89 tsdate setup.py tests
27+
ruff check --max-line-length 90 tsdate setup.py tests
2828
- save_cache:
2929
key: tsdate-{{ checksum "data/prior_1000df.bak" }}
3030
paths:

.pre-commit-config.yaml

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,10 @@ repos:
77
- id: mixed-line-ending
88
- id: check-case-conflict
99
- id: check-yaml
10-
- repo: https://github.com/asottile/reorder_python_imports
11-
rev: v3.10.0
10+
- repo: https://github.com/astral-sh/ruff-pre-commit
11+
rev: v0.4.5
1212
hooks:
13-
- id: reorder-python-imports
14-
- repo: https://github.com/asottile/pyupgrade
15-
rev: v3.10.1
16-
hooks:
17-
- id: pyupgrade
18-
args: [--py3-plus, --py38-plus]
19-
- repo: https://github.com/psf/black
20-
rev: 23.7.0
21-
hooks:
22-
- id: black
23-
language_version: python3
24-
- repo: https://github.com/pycqa/flake8
25-
rev: 6.1.0
26-
hooks:
27-
- id: flake8
28-
args: [--config=.flake8]
29-
additional_dependencies: ["flake8-bugbear==23.7.10", "flake8-builtins==2.1.0"]
13+
- id: ruff
14+
args: [ "--fix", "--config", "ruff.toml" ]
15+
- id: ruff-format
16+
args: [ "--config", "ruff.toml" ]

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
tskit>=0.5.0
22
tsinfer>=0.3.0
3-
flake8
3+
ruff
44
numpy
55
tqdm
66
daiquiri

ruff.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
line-length = 90
2+
3+
[lint]
4+
select = ["E", "F", "B", "W", "I", "N", "UP", "A", "RUF", "PT", "NPY"]
5+
# N803,806,802 Allow capital varnames
6+
# E741 Allow "l" as var name
7+
# PT011 allow pytest raises without match
8+
ignore = ["N803", "N806", "N802", "E741", "PT011", "PT009"]
9+
10+
[lint.isort]
11+
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
12+
known-first-party = ["tsdate"]

tests/distribution_functions.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
Utility functions to construct distributions used in variational inference,
2525
for testing purposes
2626
"""
27+
2728
import mpmath
2829
import numpy as np
2930
import scipy.integrate
3031
import scipy.special
3132

32-
from tsdate import approx
33-
from tsdate import hypergeo
33+
from tsdate import approx, hypergeo
3434

3535

3636
def kl_divergence(p, logq):
@@ -81,9 +81,7 @@ def pr_a(a, n, k):
8181
if n == k:
8282
return pr_t_bar_a(t, 1)
8383
else:
84-
return np.sum(
85-
[pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)]
86-
)
84+
return np.sum([pr_a(a, n, k) * pr_t_bar_a(t, a, n) for a in range(2, n - k + 2)])
8785

8886

8987
class TiltedGammaDiff:
@@ -114,8 +112,12 @@ def _U(a, b, z):
114112
return float(val)
115113

116114
def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3, reorder=True):
117-
assert shape1 > 0 and shape2 > 0 and shape3 > 0
118-
assert rate1 >= 0 and rate2 > 0 and rate3 >= 0
115+
assert shape1 > 0
116+
assert shape2 > 0
117+
assert shape3 > 0
118+
assert rate1 >= 0
119+
assert rate2 > 0
120+
assert rate3 >= 0
119121
# for convergence of 2F1, we need rate2 > rate3. Invariant
120122
# transformations of 2F1 allow us to switch arguments, with
121123
# appropriate rescaling
@@ -369,8 +371,12 @@ def _M(a, b, x):
369371
return float(val)
370372

371373
def __init__(self, shape1, shape2, shape3, rate1, rate2, rate3):
372-
assert shape1 > 0 and shape2 > 0 and shape3 > 0
373-
assert rate1 >= 0 and rate2 > 0 and rate3 >= 0
374+
assert shape1 > 0
375+
assert shape2 > 0
376+
assert shape3 > 0
377+
assert rate1 >= 0
378+
assert rate2 > 0
379+
assert rate3 >= 0
374380
# for numeric stability of hypergeometric we need rate2 > rate1
375381
# as this is a convolution, the order of (1) and (2) don't matter
376382
self.reparametrize = rate1 > rate2
@@ -481,11 +487,7 @@ def sufficient_statistics(self):
481487
+ scipy.special.betaln(self.shape1, self.shape2)
482488
)
483489
x = dF_dz * T / S**2 + B / S
484-
xsq = (
485-
d2F_dz2 * T**2 / S**4
486-
+ B * (B + 1) / S**2
487-
+ 2 * dF_dz * (1 + B) * T / S**3
488-
)
490+
xsq = d2F_dz2 * T**2 / S**4 + B * (B + 1) / S**2 + 2 * dF_dz * (1 + B) * T / S**3
489491
logx = dF_db + scipy.special.digamma(B) - np.log(S)
490492
return logconst, x, xsq, logx
491493

tests/exact_moments.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Moments for EP updates using exact hypergeometric evaluations rather than
44
a Laplace approximation; intended for testing and accuracy benchmarking.
55
"""
6+
67
from math import exp
78
from math import log
89

@@ -181,9 +182,7 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
181182
d3 = d2 * (a + 1) / (c + 1)
182183
mn_m = s1 * exp(f111 - f000) / t / 2 * (1 + z) + b / t / 2
183184
sq_m = (
184-
d1 * exp(f020 - f000) / 3
185-
+ d2 * exp(f121 - f000) / 3
186-
+ d3 * exp(f222 - f000) / 3
185+
d1 * exp(f020 - f000) / 3 + d2 * exp(f121 - f000) / 3 + d3 * exp(f222 - f000) / 3
187186
)
188187
va_m = sq_m - mn_m**2
189188
return mn_m, va_m
@@ -499,9 +498,7 @@ def test_rootward_moments(self, pars):
499498
)[0]
500499
assert np.isclose(logconst, np.log(ck_normconst))
501500
ck_t_i = scipy.integrate.quad(
502-
lambda t_i: t_i
503-
* self.pdf_rootward(t_i, t_j, *pars_redux)
504-
/ ck_normconst,
501+
lambda t_i: t_i * self.pdf_rootward(t_i, t_j, *pars_redux) / ck_normconst,
505502
t_j,
506503
np.inf,
507504
epsabs=0,
@@ -752,8 +749,7 @@ def f(t_i, t_j): # conditional moments
752749
)[0]
753750
ck_mn = (
754751
scipy.integrate.quad(
755-
lambda t_i: f(t_i, t_j)[0]
756-
* self.pdf_rootward(t_i, t_j, *pars_redux),
752+
lambda t_i: f(t_i, t_j)[0] * self.pdf_rootward(t_i, t_j, *pars_redux),
757753
t_j,
758754
np.inf,
759755
)[0]
@@ -762,8 +758,7 @@ def f(t_i, t_j): # conditional moments
762758
assert np.isclose(mn, ck_mn)
763759
ck_va = (
764760
scipy.integrate.quad(
765-
lambda t_i: f(t_i, t_j)[1]
766-
* self.pdf_rootward(t_i, t_j, *pars_redux),
761+
lambda t_i: f(t_i, t_j)[1] * self.pdf_rootward(t_i, t_j, *pars_redux),
767762
t_j,
768763
np.inf,
769764
)[0]

tests/test_accuracy.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""
2424
Test cases for tsdate accuracy.
2525
"""
26+
2627
import json
2728
import os
2829

@@ -40,7 +41,7 @@ class TestAccuracy:
4041
Test for some of the basic functions used in tsdate
4142
"""
4243

43-
@pytest.mark.makefiles
44+
@pytest.mark.makefiles()
4445
def test_make_static_files(self, request):
4546
"""
4647
The function used to create the tree sequences for accuracy testing.
@@ -75,7 +76,13 @@ def test_make_static_files(self, request):
7576
ts.dump(os.path.join(request.fspath.dirname, "data", f"{name}.trees"))
7677

7778
@pytest.mark.parametrize(
78-
"ts_name,min_r2_ts,min_r2_unconstrained,min_spear_ts,min_spear_unconstrained",
79+
(
80+
"ts_name",
81+
"min_r2_ts",
82+
"min_r2_unconstrained",
83+
"min_spear_ts",
84+
"min_spear_unconstrained",
85+
),
7986
[
8087
("one_tree", 0.98601, 0.98601, 0.97719, 0.97719),
8188
("few_trees", 0.98220, 0.98220, 0.97744, 0.97744),
@@ -91,9 +98,7 @@ def test_basic(
9198
min_spear_unconstrained,
9299
request,
93100
):
94-
ts = tskit.load(
95-
os.path.join(request.fspath.dirname, "data", ts_name + ".trees")
96-
)
101+
ts = tskit.load(os.path.join(request.fspath.dirname, "data", ts_name + ".trees"))
97102

98103
sim_ancestry_parameters = json.loads(ts.provenance(0).record)["parameters"]
99104
assert sim_ancestry_parameters["command"] == "sim_ancestry"
@@ -144,7 +149,7 @@ def test_scaling(self, Ne):
144149
assert 0.9 < dts.node(dts.first().root).time / (2 * Ne) < 1.1
145150

146151
@pytest.mark.parametrize(
147-
"bkwd_rate, trio_tmrca",
152+
("bkwd_rate", "trio_tmrca"),
148153
[ # calculated from simulations
149154
(-1.0, 0.76),
150155
(-0.9, 0.79),

tests/test_approximations.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,32 @@
2424
"""
2525
Test cases for the gamma-variational approximations in tsdate
2626
"""
27+
2728
from math import sqrt
2829

2930
import numpy as np
3031
import pytest
3132
import scipy.integrate
3233
import scipy.special
3334
import scipy.stats
34-
from exact_moments import leafward_moments
35-
from exact_moments import moments
36-
from exact_moments import mutation_block_moments
37-
from exact_moments import mutation_edge_moments
38-
from exact_moments import mutation_leafward_moments
39-
from exact_moments import mutation_moments
40-
from exact_moments import mutation_rootward_moments
41-
from exact_moments import mutation_sideways_moments
42-
from exact_moments import mutation_twin_moments
43-
from exact_moments import mutation_unphased_moments
44-
from exact_moments import rootward_moments
45-
from exact_moments import sideways_moments
46-
from exact_moments import twin_moments
47-
from exact_moments import unphased_moments
35+
from exact_moments import (
36+
leafward_moments,
37+
moments,
38+
mutation_block_moments,
39+
mutation_edge_moments,
40+
mutation_leafward_moments,
41+
mutation_moments,
42+
mutation_rootward_moments,
43+
mutation_sideways_moments,
44+
mutation_twin_moments,
45+
mutation_unphased_moments,
46+
rootward_moments,
47+
sideways_moments,
48+
twin_moments,
49+
unphased_moments,
50+
)
4851

49-
from tsdate import approx
50-
from tsdate import hypergeo
52+
from tsdate import approx, hypergeo
5153

5254
# TODO: better test set?
5355
_gamma_trio_test_cases = [ # [shape1, rate1, shape2, rate2, muts, rate]
@@ -294,9 +296,7 @@ def test_average_gammas(self):
294296
E_x = np.mean(shape + 1)
295297
E_logx = np.mean(scipy.special.digamma(shape + 1))
296298
assert np.isclose(E_x, (avg_shape + 1) / avg_rate)
297-
assert np.isclose(
298-
E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate)
299-
)
299+
assert np.isclose(E_logx, scipy.special.digamma(avg_shape + 1) - np.log(avg_rate))
300300

301301

302302
class TestKLMinimizationFailed:
@@ -305,7 +305,7 @@ class TestKLMinimizationFailed:
305305
"""
306306

307307
def test_violates_jensen(self):
308-
with pytest.raises(approx.KLMinimizationFailed, match="violates Jensen's"):
308+
with pytest.raises(approx.KLMinimizationFailedError, match="violates Jensen's"):
309309
approx.approximate_gamma_kl(1, 0)
310310

311311
def test_asymptotic_bound(self):
@@ -314,10 +314,12 @@ def test_asymptotic_bound(self):
314314
alpha, _ = approx.approximate_gamma_kl(1, logx)
315315
alpha += 1
316316
alpha_bound = -0.5 / logx
317-
assert alpha == alpha_bound and alpha > 1e4
317+
assert alpha == alpha_bound
318+
assert alpha > 1e4
318319
# check that bound matches optimization result just under threshold
319320
logx = -0.000051
320321
alpha, _ = approx.approximate_gamma_kl(1, logx)
321322
alpha += 1
322323
alpha_bound = -0.5 / logx
323-
assert np.abs(alpha - alpha_bound) < 1 and alpha < 1e4
324+
assert np.abs(alpha - alpha_bound) < 1
325+
assert alpha < 1e4

tests/test_cache.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Tests for the cache management code.
33
"""
4+
45
import os
56
import pathlib
67
import unittest

tests/test_cli.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""
2424
Test cases for the command line interface for tsdate.
2525
"""
26+
2627
import json
2728
import logging
2829
from unittest import mock
@@ -75,11 +76,11 @@ def test_recombination_rate(self):
7576
parser = cli.tsdate_cli_parser()
7677
params = ["-m", "1e10"]
7778
args = parser.parse_args(
78-
["date", self.infile, self.output] + params + ["-r", "1e-100"]
79+
["date", self.infile, self.output, *params, "-r", "1e-100"]
7980
)
8081
assert args.recombination_rate == 1e-100
8182
args = parser.parse_args(
82-
["date", self.infile, self.output] + params + ["--recombination-rate", "73"]
83+
["date", self.infile, self.output, *params, "--recombination-rate", "73"]
8384
)
8485
assert args.recombination_rate == 73
8586

@@ -97,24 +98,22 @@ def test_epsilon(self):
9798
def test_num_threads(self):
9899
parser = cli.tsdate_cli_parser()
99100
params = ["--method", "maximization", "--num-threads"]
100-
args = parser.parse_args(["date", self.infile, self.output] + params + ["1"])
101+
args = parser.parse_args(["date", self.infile, self.output, *params, "1"])
101102
assert args.num_threads == 1
102-
args = parser.parse_args(["date", self.infile, self.output] + params + ["2"])
103+
args = parser.parse_args(["date", self.infile, self.output, *params, "2"])
103104
assert args.num_threads == 2
104105

105106
def test_probability_space(self):
106107
parser = cli.tsdate_cli_parser()
107108
params = ["--method", "inside_outside", "--probability-space"]
108-
args = parser.parse_args(
109-
["date", self.infile, self.output] + params + ["linear"]
110-
)
109+
args = parser.parse_args(["date", self.infile, self.output, *params, "linear"])
111110
assert args.probability_space == "linear"
112111
args = parser.parse_args(
113-
["date", self.infile, self.output] + params + ["logarithmic"]
112+
["date", self.infile, self.output, *params, "logarithmic"]
114113
)
115114
assert args.probability_space == "logarithmic"
116115

117-
@pytest.mark.parametrize("flag, log_status", logging_flags.items())
116+
@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
118117
def test_verbosity(self, flag, log_status):
119118
parser = cli.tsdate_cli_parser()
120119
args = parser.parse_args(["preprocess", self.infile, self.output, flag])
@@ -130,7 +129,7 @@ def test_method(self, method):
130129
params = ["-m", "1e-8", "--method", method]
131130
if method != "variational_gamma":
132131
params += ["-n", "10"]
133-
args = parser.parse_args(["date", self.infile, self.output] + params)
132+
args = parser.parse_args(["date", self.infile, self.output, *params])
134133
assert args.method == method
135134

136135
def test_progress(self):
@@ -231,7 +230,7 @@ def test_no_output_variational_gamma(self, tmp_path, capfd):
231230
assert out == ""
232231
assert err == ""
233232

234-
@pytest.mark.parametrize("flag, log_status", logging_flags.items())
233+
@pytest.mark.parametrize(("flag", "log_status"), logging_flags.items())
235234
def test_verbosity(self, tmp_path, caplog, flag, log_status):
236235
popsize = 10000
237236
ts = msprime.simulate(

0 commit comments

Comments
 (0)