
Commit 8ccf00b

ParaFold for AlphaFold 2.3.1

1 parent: 8aeebbf

Note: this is a large commit; some of the changed files are not shown below.

52 files changed: +784, -724 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+ *.slurm

LICENSE

Lines changed: 0 additions & 21 deletions
This file was deleted.

README.md

Lines changed: 32 additions & 24 deletions

@@ -2,69 +2,77 @@
   <img src="./docs/parafoldlogo.png" width="400" >
   </div>

-  # ParallelFold
+  # ParaFold

   Author: Bozitao Zhong - zbztzhz@gmail.com

-  :station: We are adding new functions to ParallelFold, you can see our [Roadmap](https://trello.com/b/sAqBIxBC/parallelfold).
-
-  :bookmark_tabs: Please cite our [paper](https://arxiv.org/abs/2111.06340) if you used ParallelFold (ParaFold) in you research.
+  :bookmark_tabs: Please cite our [paper](https://arxiv.org/abs/2111.06340) if you used ParaFold (ParallelFold) in your research.

   ## Overview

+  Recent change: **ParaFold now supports AlphaFold 2.3.1**
+
   This project is a modified version of DeepMind's [AlphaFold2](https://github.com/deepmind/alphafold) to achieve high-throughput protein structure prediction.

   We have made the following modifications to the original AlphaFold pipeline:

   - Divide the **CPU part** (MSA and template searching) and the **GPU part** (prediction model)

-  **ParallelFold now supports AlphaFold 2.1.2**
-

   ## How to install

   We recommend installing AlphaFold locally, and not using **docker**.

-  For CUDA 11, you can refer to the [installation guide here](./docs/install.md).
+  ```bash
+  # clone this repo
+  git clone https://github.com/Zuricho/ParallelFold.git

-  For CUDA 10.1, you can refer to the [installation guide here](./docs/install_cuda10.md).
+  # Create a miniconda environment for ParaFold/AlphaFold
+  # We recommend Python 3.8: versions < 3.7 have missing packages, and versions newer than 3.8 were not tested
+  conda create -n parafold python=3.8

+  pip install py3dmol
+  # openmm 7.7 is recommended (the original AlphaFold used 7.5.1, which is no longer supported)
+  conda install -c conda-forge openmm=7.7 pdbfixer

+  # use pip3 to install most of the packages
+  pip3 install -r requirements.txt

-  ## Some detail information of modified files
+  # install cuda and cudnn
+  # cudatoolkit 11.3.1 matches cudnn 8.2.1
+  conda install cudatoolkit=11.3 cudnn

-  - `run_alphafold.py`: modified version of original `run_alphafold.py`, it has multiple additional functions like skipping featuring steps when exists `feature.pkl` in output folder
-  - `run_alphafold.sh`: bash script to run `run_alphafold.py`
-  - `run_figure.py`: this file can help you make figure for your system
+  # downgrade jaxlib to the version that matches the cuda and cudnn versions
+  pip3 install --upgrade --no-cache-dir jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

+  # install packages for multiple sequence alignment
+  conda install -c bioconda hmmer=3.3.2 hhsuite=3.3.0 kalign2=2.04

+  chmod +x run_alphafold.sh
+  ```

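Once the steps above complete, a quick import check can confirm the environment is wired up. This is a sketch, not part of the commit; the expected versions follow the pinned install commands, and a GPU device only appears if CUDA/cuDNN are visible:

```python
# Environment sanity check (sketch; versions as pinned in the install steps above).
import jax
from openmm import version as openmm_version

print(openmm_version.version)  # expect 7.7
print(jax.__version__)         # expect 0.3.25
print(jax.devices())           # should list a GPU device when CUDA/cuDNN are set up
```
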

-  ## How to run

-  Visit the [usage page](./docs/usage.md) to know how to run

+  ## Some detail information of modified files

+  - `run_alphafold.py`: a modified version of the original `run_alphafold.py`; it has additional functions, such as skipping the feature step when `feature.pkl` already exists in the output folder (see the sketch after this diff)
+  - `run_alphafold.sh`: bash script to run `run_alphafold.py`
+  - `run_figure.py`: this file can help you make figures for your system

-  ## Functions
-
-  You can use some flags to change the prediction model for ParallelFold:
-
-  `-r`: Skip AMBER refinement [Under repair]

-  `-b`: Use benchmark mode - run the JAX model twice; the second run can be used to evaluate running time
+  ## How to run

-  `-R`: Change the number of cycles in recycling
+  Visit the [usage page](./docs/usage.md) to learn how to run it

-  **More functions are under development.**

   ## What is this for

-  ParallelFold can help you accelerate AlphaFold when you want to predict multiple sequences. After dividing the CPU part and GPU part, users can finish feature step by multiple processors. Using ParallelFold, you can run AlphaFold 2~3 times faster than DeepMind's procedure.
+  ParaFold can help you accelerate AlphaFold when you want to predict multiple sequences. After dividing the CPU part and the GPU part, users can finish the feature step with multiple processors. Using ParaFold, you can run AlphaFold 2~3 times faster than DeepMind's procedure.

-  **If you have any question, please send GitHub issues**
+  **If you have any questions, please raise an issue**

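The CPU/GPU split described in the README hinges on one check: if the feature step already produced `feature.pkl`, the GPU stage loads it instead of re-running MSA and template search. A minimal sketch of that pattern (the helper name and file layout here are hypothetical; the actual logic lives in the modified `run_alphafold.py`):

```python
import os
import pickle

def load_or_compute_features(output_dir, compute_fn):
  """Reuse feature.pkl when present, so the GPU stage can skip the CPU stage."""
  features_path = os.path.join(output_dir, 'feature.pkl')  # hypothetical layout
  if os.path.exists(features_path):
    with open(features_path, 'rb') as f:
      return pickle.load(f)  # CPU stage already ran, possibly on another node
  feature_dict = compute_fn()  # MSA + template search: the CPU-heavy part
  with open(features_path, 'wb') as f:
    pickle.dump(feature_dict, f, protocol=4)
  return feature_dict
```
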

alphafold/common/residue_constants.py

Lines changed: 3 additions & 3 deletions

@@ -120,7 +120,7 @@
   # 4,5,6,7: 'chi1,2,3,4-group'
   # The atom positions are relative to the axis-end-atom of the corresponding
   # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
-  # is defined such that the dihedral-angle-definiting atom (the last entry in
+  # is defined such that the dihedral-angle-defining atom (the last entry in
   # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
   # format: [atomname, group_idx, rel_position]
   rigid_group_atom_positions = {

@@ -772,10 +772,10 @@ def _make_rigid_transformation_4x4(ex, ey, translation):
   # and an array with (restype, atomtype, coord) for the atom positions
   # and compute affine transformation matrices (4,4) from one rigid group to the
   # previous group
-  restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int)
+  restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
   restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
   restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
-  restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int)
+  restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
   restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
   restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
   restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
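The `np.int` to `int` change tracks NumPy's removal of its deprecated scalar aliases (deprecated in NumPy 1.20, removed in 1.24); the builtin type produces the same dtype. A minimal illustration:

```python
import numpy as np

# On NumPy >= 1.24 the old alias no longer exists:
#   np.zeros([21, 37], dtype=np.int)  # AttributeError: module 'numpy' has no attribute 'int'
# The builtin is equivalent and works on every NumPy version:
groups = np.zeros([21, 37], dtype=int)
assert groups.dtype == np.dtype(int)
```
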

alphafold/data/pipeline.py

Lines changed: 7 additions & 7 deletions

@@ -117,7 +117,7 @@ def __init__(self,
                uniref90_database_path: str,
                mgnify_database_path: str,
                bfd_database_path: Optional[str],
-               uniclust30_database_path: Optional[str],
+               uniref30_database_path: Optional[str],
                small_bfd_database_path: Optional[str],
                template_searcher: TemplateSearcher,
                template_featurizer: templates.TemplateHitFeaturizer,

@@ -135,9 +135,9 @@ def __init__(self,
           binary_path=jackhmmer_binary_path,
           database_path=small_bfd_database_path)
     else:
-      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
+      self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
           binary_path=hhblits_binary_path,
-          databases=[bfd_database_path, uniclust30_database_path])
+          databases=[bfd_database_path, uniref30_database_path])
     self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
         binary_path=jackhmmer_binary_path,
         database_path=mgnify_database_path)

@@ -211,14 +211,14 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
           use_precomputed_msas=self.use_precomputed_msas)
       bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
     else:
-      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
-      hhblits_bfd_uniclust_result = run_msa_tool(
-          msa_runner=self.hhblits_bfd_uniclust_runner,
+      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+      hhblits_bfd_uniref_result = run_msa_tool(
+          msa_runner=self.hhblits_bfd_uniref_runner,
          input_fasta_path=input_fasta_path,
          msa_out_path=bfd_out_path,
          msa_format='a3m',
          use_precomputed_msas=self.use_precomputed_msas)
-      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+      bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])

     templates_result = self.template_featurizer.get_templates(
         query_sequence=input_sequence,
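These renames reflect AlphaFold 2.3 switching the HHBlits search database from UniClust30 to UniRef30; the tool invocation itself is unchanged, only the database and the derived names change. A construction sketch mirroring the code above (the binary and database paths are placeholders for a local installation):

```python
from alphafold.data.tools import hhblits

# Placeholders: point these at your local hhblits binary and databases.
hhblits_bfd_uniref_runner = hhblits.HHBlits(
    binary_path='/usr/bin/hhblits',
    databases=[
        '/data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt',
        '/data/uniref30/UniRef30_2021_03',
    ])
```
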

alphafold/data/templates.py

Lines changed: 4 additions & 4 deletions

@@ -89,8 +89,8 @@ class LengthError(PrefilterError):
     'template_aatype': np.float32,
     'template_all_atom_masks': np.float32,
     'template_all_atom_positions': np.float32,
-    'template_domain_names': np.object,
-    'template_sequence': np.object,
+    'template_domain_names': object,
+    'template_sequence': object,
     'template_sum_probs': np.float32,
 }

@@ -1002,8 +1002,8 @@ def get_templates(
         (1, num_res, residue_constants.atom_type_num), np.float32),
     'template_all_atom_positions': np.zeros(
         (1, num_res, residue_constants.atom_type_num, 3), np.float32),
-    'template_domain_names': np.array([''.encode()], dtype=np.object),
-    'template_sequence': np.array([''.encode()], dtype=np.object),
+    'template_domain_names': np.array([''.encode()], dtype=object),
+    'template_sequence': np.array([''.encode()], dtype=object),
     'template_sum_probs': np.array([0], dtype=np.float32)
 }
 return TemplateSearchResult(
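Same story as `np.int` above: `np.object` was removed in NumPy 1.24, and the builtin `object` yields the identical `'O'` dtype for arrays of arbitrary Python objects (here, bytes):

```python
import numpy as np

names = np.array([''.encode()], dtype=object)  # same dtype as the removed np.object
assert names.dtype == np.dtype('O')
```
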

alphafold/data/tools/jackhmmer.py

Lines changed: 18 additions & 8 deletions

@@ -167,10 +167,20 @@ def query(self,
            input_fasta_path: str,
            max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
    """Queries the database using Jackhmmer."""
+   return self.query_multiple([input_fasta_path], max_sequences)[0]
+
+ def query_multiple(
+     self,
+     input_fasta_paths: Sequence[str],
+     max_sequences: Optional[int] = None,
+ ) -> Sequence[Sequence[Mapping[str, Any]]]:
+   """Queries the database for multiple queries using Jackhmmer."""
    if self.num_streamed_chunks is None:
-     single_chunk_result = self._query_chunk(
-         input_fasta_path, self.database_path, max_sequences)
-     return [single_chunk_result]
+     single_chunk_results = []
+     for input_fasta_path in input_fasta_paths:
+       single_chunk_results.append([self._query_chunk(
+           input_fasta_path, self.database_path, max_sequences)])
+     return single_chunk_results

    db_basename = os.path.basename(self.database_path)
    db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'

@@ -185,7 +195,7 @@ def query(self,

    # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
-     chunked_output = []
+     chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
      for i in range(1, self.num_streamed_chunks + 1):
        # Copy the chunk locally
        if i == 1:

@@ -197,9 +207,9 @@ def query(self,

        # Run Jackhmmer with the chunk
        future.result()
-       chunked_output.append(self._query_chunk(
-           input_fasta_path, db_local_chunk(i), max_sequences))
-
+       for fasta_index, input_fasta_path in enumerate(input_fasta_paths):
+         chunked_outputs[fasta_index].append(self._query_chunk(
+             input_fasta_path, db_local_chunk(i), max_sequences))
        # Remove the local copy of the chunk
        os.remove(db_local_chunk(i))
        # Do not set next_future for the last chunk so that this works even for

@@ -208,4 +218,4 @@ def query(self,
      future = next_future
      if self.streaming_callback:
        self.streaming_callback(i)
-   return chunked_output
+   return chunked_outputs
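The new `query_multiple` runs several FASTA queries over a single pass through the (possibly streamed) database chunks, so each chunk is downloaded once rather than once per query, and `query` becomes a thin wrapper around it. A usage sketch (the binary and database paths are placeholders):

```python
from alphafold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',               # placeholder
    database_path='/data/uniref90/uniref90.fasta')  # placeholder

# One pass over the database serves both queries.
results = runner.query_multiple(['a.fasta', 'b.fasta'], max_sequences=5000)
# results[i] is the per-chunk result list for the i-th query;
# runner.query('a.fasta') returns the same thing as results[0].
```
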

alphafold/model/all_atom_multimer.py

Lines changed: 1 addition & 1 deletion

@@ -426,7 +426,7 @@ def torsion_angles_to_frames(
   chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
   chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]

-  all_frames_to_backb = jax.tree_multimap(
+  all_frames_to_backb = jax.tree_map(
       lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
       chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
       chi4_frame_to_backb[:, None])
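`jax.tree_multimap` was a deprecated alias that newer JAX releases removed; `jax.tree_map` accepts multiple trees and applies the function leaf-wise, so it is a drop-in replacement here. A minimal illustration:

```python
import jax
import jax.numpy as jnp

a = {'frames': jnp.ones((4, 5))}
b = {'frames': jnp.zeros((4, 1))}

# Leaf-wise concatenation across two trees, exactly what tree_multimap used to do:
merged = jax.tree_map(lambda *x: jnp.concatenate(x, axis=-1), a, b)
assert merged['frames'].shape == (4, 6)
```
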

alphafold/model/common_modules.py

Lines changed: 61 additions & 0 deletions

@@ -128,3 +128,64 @@ def __call__(self, inputs):

    return output

+
+ class LayerNorm(hk.LayerNorm):
+   """LayerNorm module.
+
+   Equivalent to hk.LayerNorm but with different parameter shapes: they are
+   always vectors rather than possibly higher-rank tensors. This makes it
+   easier to change the layout whilst keeping the model weight-compatible.
+   """
+
+   def __init__(self,
+                axis,
+                create_scale: bool,
+                create_offset: bool,
+                eps: float = 1e-5,
+                scale_init=None,
+                offset_init=None,
+                use_fast_variance: bool = False,
+                name=None,
+                param_axis=None):
+     super().__init__(
+         axis=axis,
+         create_scale=False,
+         create_offset=False,
+         eps=eps,
+         scale_init=None,
+         offset_init=None,
+         use_fast_variance=use_fast_variance,
+         name=name,
+         param_axis=param_axis)
+     self._temp_create_scale = create_scale
+     self._temp_create_offset = create_offset
+
+   def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
+     is_bf16 = (x.dtype == jnp.bfloat16)
+     if is_bf16:
+       x = x.astype(jnp.float32)
+
+     param_axis = self.param_axis[0] if self.param_axis else -1
+     param_shape = (x.shape[param_axis],)
+
+     param_broadcast_shape = [1] * x.ndim
+     param_broadcast_shape[param_axis] = x.shape[param_axis]
+     scale = None
+     offset = None
+     if self._temp_create_scale:
+       scale = hk.get_parameter(
+           'scale', param_shape, x.dtype, init=self.scale_init)
+       scale = scale.reshape(param_broadcast_shape)
+
+     if self._temp_create_offset:
+       offset = hk.get_parameter(
+           'offset', param_shape, x.dtype, init=self.offset_init)
+       offset = offset.reshape(param_broadcast_shape)
+
+     out = super().__call__(x, scale=scale, offset=offset)
+
+     if is_bf16:
+       out = out.astype(jnp.bfloat16)
+
+     return out
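The new module keeps scale/offset as vectors and routes bfloat16 inputs through a float32 normalization. A usage sketch under Haiku's transform (the shapes and module name are arbitrary; `LayerNorm` is the class defined above):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  # Vector-shaped scale/offset over the last axis, per the class above.
  ln = LayerNorm(axis=-1, create_scale=True, create_offset=True, name='ln')
  return ln(x)

f = hk.transform(forward)
x = jnp.ones((2, 8), jnp.bfloat16)
params = f.init(jax.random.PRNGKey(0), x)
out = f.apply(params, None, x)
assert out.dtype == jnp.bfloat16  # bf16 in, float32 normalization inside, bf16 out
```
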
