
Commit a52f451

replace black and flake8 with ruff and apply ruff
1 parent 44ae61e commit a52f451

10 files changed: +221, -224 lines changed

.github/workflows/tests.yaml

Lines changed: 4 additions & 4 deletions
@@ -28,10 +28,10 @@ jobs:
           python3 -m pip install --upgrade pip
       - name: Linting
         run: |
-          pip install flake8==7.1.0 Flake8-pyproject black yamllint
-          flake8 allegro/ --count --show-source --statistics
-          black allegro/ --check
-          yamllint .
+          pip install ruff yamllint
+          ruff check .
+          ruff format --check .
+          yamllint .
       - name: Install git
         run: |
           apt install git -y
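With Flake8-pyproject dropped from the CI install, lint settings would typically move to ruff's own tables in pyproject.toml (whether that file is among the other changed files is not visible in this excerpt). A minimal sketch of what such a section could look like; the rule selection and line length below are illustrative assumptions, not the project's actual settings:

[tool.ruff]
# keep the line length black used by default
line-length = 88

[tool.ruff.lint]
# roughly the pyflakes/pycodestyle coverage flake8 provided
select = ["E", "F", "W"]

`ruff check .` then lints against the selected rules, and `ruff format --check .` verifies formatting without rewriting any files.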

allegro/model/allegro_models.py

Lines changed: 0 additions & 1 deletion
@@ -270,7 +270,6 @@ def FullAllegroEnergyModel(
     # === pair potentials ===
     prev_irreps_out = per_type_energy_scale_shift.irreps_out
     if pair_potential is not None:
-
         # case where model doesn't have edge cutoffs up to this point, but pair potential required
         if AtomicDataDict.EDGE_CUTOFF_KEY not in prev_irreps_out:
             cutoff = AddRadialCutoffToData(

allegro/nn/_allegro.py

Lines changed: 6 additions & 6 deletions
@@ -47,9 +47,9 @@ def __init__(
             num_layers >= 1
         )  # zero layers is "two body", but we don't need to support that fallback case

-        assert (
-            avg_num_neighbors is not None
-        ), "`avg_num_neighbors` must be set for Allegro models, but `avg_num_neighbors=None` found"
+        assert avg_num_neighbors is not None, (
+            "`avg_num_neighbors` must be set for Allegro models, but `avg_num_neighbors=None` found"
+        )

         # === save parameters ===
         self.num_layers = num_layers
@@ -99,9 +99,9 @@ def __init__(
         self.tps = torch.nn.ModuleList([])

         env_embed_irreps = Irreps([(1, ir) for _, ir in input_irreps])
-        assert (
-            env_embed_irreps[0].ir == SCALAR
-        ), "env_embed_irreps must start with scalars"
+        assert env_embed_irreps[0].ir == SCALAR, (
+            "env_embed_irreps must start with scalars"
+        )

         arg_irreps = env_embed_irreps
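The recurring edit in this file and the ones below is purely mechanical: the style ruff format applies here wraps a long assert by parenthesizing the message string, where black parenthesized the condition. A minimal sketch of the pattern, using a placeholder condition and message:

# black-era wrapping: parentheses around the condition
assert (
    value is not None
), "`value` must be set"

# ruff format: parentheses around the message instead
assert value is not None, (
    "`value` must be set"
)

The behavior of the assert is unchanged; only the line wrapping differs.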

allegro/nn/_edgeembed.py

Lines changed: 3 additions & 3 deletions
@@ -42,9 +42,9 @@ def __init__(
         self.num_types = len(type_names)

         # == type embedding ==
-        assert (
-            initial_embedding_dim % 2 == 0
-        ), "`initial_embedding_dim` must be an even number"
+        assert initial_embedding_dim % 2 == 0, (
+            "`initial_embedding_dim` must be an even number"
+        )

         self.center_embed = torch.nn.Embedding(
             num_embeddings=self.num_types,

allegro/nn/_strided/_cueq_contracter.py

Lines changed: 0 additions & 2 deletions
@@ -64,7 +64,6 @@ def allegro_tp_desc(


 class CuEquivarianceContracter(Contracter):
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

@@ -89,7 +88,6 @@ def forward(
         idxs: torch.Tensor,
         scatter_dim_size: int,
     ) -> torch.Tensor:
-
         # NOTE: the reason for some duplicated code is because TorchScript doesn't support super() calls
         # see https://github.com/pytorch/pytorch/issues/42885

allegro/nn/_strided/_flashallegro.py

Lines changed: 12 additions & 17 deletions
@@ -61,7 +61,6 @@ def tensor_product_p_kernel(
     BLOCK_DIM: tl.constexpr,
     output_dtype: tl.constexpr,
 ):
-
     # Program IDs remain the same
     pid_b = tl.program_id(0)
     pid_dim = tl.program_id(1)
@@ -184,7 +183,6 @@ def tensor_product_up_kernel(
     BLOCK_DIM: tl.constexpr,
     output_dtype: tl.constexpr,
 ):
-
     # Calculate program IDs
     pid_b = tl.program_id(0)
     pid_u = tl.program_id(1)
@@ -296,12 +294,12 @@ def _metadata_helper(dim, coo, coovalue, p_to_nnz_values):
     indptr = torch.cumsum(indptr, dim=0)

     # Assert that all values fit within int16 range
-    assert torch.all(
-        indptr <= torch.iinfo(torch.int16).max
-    ), "Values in indptr exceed int16 max value"
-    assert torch.all(
-        indptr >= torch.iinfo(torch.int16).min
-    ), "Values in indptr exceed int16 min value"
+    assert torch.all(indptr <= torch.iinfo(torch.int16).max), (
+        "Values in indptr exceed int16 max value"
+    )
+    assert torch.all(indptr >= torch.iinfo(torch.int16).min), (
+        "Values in indptr exceed int16 min value"
+    )

     # Cast to int16 after validation
     indptr = indptr.to(dtype=torch.int16)
@@ -314,7 +312,6 @@ def _metadata_helper(dim, coo, coovalue, p_to_nnz_values):
     return indptr, l1s, l2s, vals, p_to_nnz_mapper


 def _initialize_metadata(w3j):
-
     assert len(w3j.shape) == 4
     P, I, J, K = w3j.shape

@@ -332,12 +329,12 @@ def _initialize_metadata(w3j):
     del nzidx, w3j_sum, w3j

     # Assert that all values fit within int16 range
-    assert torch.all(
-        p_to_nnz_mapper <= torch.iinfo(torch.int16).max
-    ), "Values in p_to_nnz_mapper exceed int16 max value"
-    assert torch.all(
-        p_to_nnz_mapper >= torch.iinfo(torch.int16).min
-    ), "Values in p_to_nnz_mapper exceed int16 min value"
+    assert torch.all(p_to_nnz_mapper <= torch.iinfo(torch.int16).max), (
+        "Values in p_to_nnz_mapper exceed int16 max value"
+    )
+    assert torch.all(p_to_nnz_mapper >= torch.iinfo(torch.int16).min), (
+        "Values in p_to_nnz_mapper exceed int16 min value"
+    )

     # Cast to int16 after validation
     p_to_nnz_mapper = p_to_nnz_mapper.to(dtype=torch.int16)
@@ -403,7 +400,6 @@ def _triton_kernel_allegro(
     # Acc dtype
     output_dtype: torch.dtype,
 ) -> torch.Tensor:
-
     output_dtype = TORCH_TRITON_DTYPE_MAPPER[output_dtype]

     BATCH = x.shape[0]
@@ -675,7 +671,6 @@ def _flashallegro_backward(ctx, grad_output):


 class TritonContracter(Contracter):
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
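The asserts reformatted above implement a check-then-cast pattern: validate that metadata values fit in int16 before a cast that does not itself raise on out-of-range values. As a standalone illustration only (the helper name is hypothetical, not part of the repository):

import torch

def checked_int16(t: torch.Tensor) -> torch.Tensor:
    # Validate that every value fits in the int16 range before casting,
    # since Tensor.to(dtype=...) will not raise on overflow.
    info = torch.iinfo(torch.int16)
    assert torch.all(t <= info.max), "values exceed int16 max value"
    assert torch.all(t >= info.min), "values exceed int16 min value"
    return t.to(dtype=torch.int16)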

allegro/nn/spline.py

Lines changed: 88 additions & 88 deletions
@@ -1,89 +1,89 @@
# This file is a part of the `allegro` package. Please see LICENSE and README at the root for information on using it.
from math import pi
import torch
from e3nn.util.jit import compile_mode
from nequip.utils.global_dtype import _GLOBAL_DTYPE


@compile_mode("script")
class PerClassSpline(torch.nn.Module):
    """Module implementing the spline required for a two-body scalar embedding.

    Per-class splines with finite support for [0, 1], and will go to zero smoothly at 1.

    Args:
        num_classes (int) : number of classes or categories (for ``index_select`` operation)
        num_channels (int) : number of output channels
        num_splines (int) : number of spline basis functions
        spline_span (int) : number of spline basis functions that overlap on spline grid points
    """

    def __init__(
        self,
        num_classes: int,
        num_channels: int,
        num_splines: int,
        spline_span: int,
        dtype: torch.dtype = _GLOBAL_DTYPE,
    ):
        super().__init__()

        # === sanity check ===
        assert 0 <= spline_span <= num_splines
        assert num_splines > 0

        # === save inputs parameters ===
        self.num_classes = num_classes
        self.num_channels = num_channels
        self.num_splines = num_splines
        self.spline_span = spline_span
        self.dtype = dtype

        # === spline grid parameters ===
        lower = (
            torch.arange(
                -self.spline_span, self.num_splines - spline_span, dtype=self.dtype
            )
            / self.num_splines
        )
        diff = (self.spline_span + 1) / self.num_splines

        self.register_buffer("lower", lower)
        self.register_buffer("upper", lower + diff)
        self._const = 2 * pi / diff

        # === use torch.nn.Embedding for spline weights ===
        self.class_embed = torch.nn.Embedding(
            num_embeddings=self.num_classes,
            embedding_dim=self.num_channels * self.num_splines,
            dtype=dtype,
        )

    def extra_repr(self) -> str:
        msg = f"num classes : {self.num_classes}\n"
        msg += f"num channels: {self.num_channels}\n"
        msg += f"num splines : {self.num_splines}\n"
        msg += f"spline span : {self.spline_span}"
        return msg

    def forward(self, x: torch.Tensor, classes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor) : input tensor with shape (z, 1)
            classes (torch.Tensor): class tensor with shape (z,) whose values are integer indices from 0 to num_classes - 1
        """
        # index out weights based on classes: -> (z, num_channels, num_splines)
        spline_weights = self.class_embed(classes).view(
            classes.size(0), self.num_channels, self.num_splines
        )
        spline_basis = self._get_basis(x)
        # (z, num_channels, num_splines), (z, num_splines) -> (z, num_channels)
        return torch.bmm(spline_weights, spline_basis.unsqueeze(-1)).squeeze(-1)

    def _get_basis(self, x: torch.Tensor) -> torch.Tensor:
        # construct spline basis
        # x: (z, 1) -> spline_basis: (z, num_splines)
        normalized_x = self._const * (
            torch.clamp(x, min=self.lower, max=self.upper) - self.lower
        )
        return 0.25 * (1 - torch.cos(normalized_x)).square()
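Reading off _get_basis above, each of the n = num_splines basis functions is a smooth bump with finite support; in the notation below (ours, not the repository's), with s = spline_span:

b_i(x) = 1/4 * (1 - cos( (2π/Δ) * (clamp(x, l_i, u_i) - l_i) ))²,   l_i = (i - s)/n,   Δ = (s + 1)/n,   u_i = l_i + Δ,   i = 0, …, n - 1.

The clamp makes b_i vanish identically outside [l_i, u_i], the cosine makes it go to zero smoothly at both endpoints, and it peaks at 1 at the interval midpoint. forward then contracts the per-class weights of shape (z, num_channels, n) with this basis via torch.bmm, giving one scalar per channel.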
