@@ -19,6 +19,8 @@ def write_unique_counts(fasta_filename: Path,
19
19
index_filename : Path ,
20
20
kmer_batch_size : int ,
21
21
kmer_lengths : list [int ],
22
+ initial_search_length : int ,
23
+ exclude_bases : set [bytes ],
22
24
num_threads : int ,
23
25
use_binary_search = False ,
24
26
verbose : bool = False ):
@@ -110,6 +112,8 @@ def write_unique_counts(fasta_filename: Path,
110
112
sequence_segment ,
111
113
min_kmer_length ,
112
114
max_kmer_length ,
115
+ initial_search_length ,
116
+ exclude_bases ,
113
117
num_threads ,
114
118
data_type , # type: ignore
115
119
verbose )
@@ -119,6 +123,7 @@ def write_unique_counts(fasta_filename: Path,
119
123
sequence_segment ,
120
124
kmer_lengths ,
121
125
num_kmers ,
126
+ exclude_bases ,
122
127
num_threads ,
123
128
data_type , # type: ignore
124
129
verbose )
@@ -155,6 +160,8 @@ def binary_search(index_filename: Path,
155
160
sequence_segment : SequenceSegment ,
156
161
min_kmer_length : int ,
157
162
max_kmer_length : int ,
163
+ initial_search_length : int ,
164
+ exclude_bases : set [bytes ],
158
165
num_threads : int ,
159
166
data_type : Union [np .uint8 , np .uint16 , np .uint32 ],
160
167
verbose : bool ) -> tuple [npt .NDArray [np .uint ], int ]:
@@ -163,13 +170,16 @@ def binary_search(index_filename: Path,
163
170
164
171
# NB: Floor division for midpoint
165
172
# NB: Avoid an overflow error by dividing first before sum
166
- starting_kmer_length = (max_kmer_length // 2 ) + (min_kmer_length // 2 )
173
+ if initial_search_length :
174
+ starting_kmer_length = initial_search_length
175
+ else :
176
+ starting_kmer_length = (max_kmer_length // 2 ) + (min_kmer_length // 2 )
167
177
168
178
# Track which kmer positions have finished searching,
169
179
# skipping any kmers starting with an ambiguous base
170
- finished_search = np . frombuffer (sequence_segment . data ,
171
- dtype = np . uint8 ,
172
- count = num_kmers ) == ord ( b'N' )
180
+ finished_search = get_ambiguous_positions (sequence_segment ,
181
+ num_kmers ,
182
+ exclude_bases )
173
183
174
184
# Print out the number of ambiguous positions skipped
175
185
ambiguous_positions_skipped = finished_search .sum ()
@@ -194,8 +204,10 @@ def binary_search(index_filename: Path,
194
204
current_length_query ,
195
205
finished_search ,
196
206
max_kmer_length ,
197
- min_kmer_length )
207
+ min_kmer_length ,
208
+ initial_search_length )
198
209
210
+ # Print out the number of ambiguous positions skipped if verbosity is on
199
211
if verbose :
200
212
upper_bound_change_count = np .count_nonzero (
201
213
upper_length_bound [(~ finished_search ).nonzero ()] < max_kmer_length )
@@ -209,7 +221,7 @@ def binary_search(index_filename: Path,
209
221
verbose_print (verbose , f"{ short_kmers_discarded_count } k-mers shorter "
210
222
"than the minimum length discarded due to ambiguity" )
211
223
212
- # List of minimum lengths (where 0 is nothing was found)
224
+ # List of unique minimum lengths (where 0 is nothing was found)
213
225
unique_lengths = np .zeros (num_kmers , dtype = data_type )
214
226
215
227
iteration_count = 1
@@ -303,7 +315,8 @@ def update_upper_search_bound(upper_length_bound_array: npt.NDArray[np.uint],
303
315
current_length_query_array : npt .NDArray [np .uint ],
304
316
finished_search_array : npt .NDArray [np .bool_ ],
305
317
max_kmer_length ,
306
- min_kmer_length ):
318
+ min_kmer_length ,
319
+ initial_search_length ):
307
320
"""Modifies in the input arrays to update the upper search bound based on
308
321
ambiguous bases in the sequence data.
309
322
Updates the query lengths between the new maximum upper bound
@@ -329,12 +342,23 @@ def update_upper_search_bound(upper_length_bound_array: npt.NDArray[np.uint],
329
342
# Set the maximum length up to 1 next to the ambiguous base position
330
343
upper_length_bound_array [length_change_position :i ] = \
331
344
max_lengths_to_ambiguous_position
345
+
332
346
# Calculate the new query length as the midpoint between the updated
333
347
# upper and the current lower bounds
334
- current_length_query_array [ length_change_position : i ] = np .floor (
348
+ new_initial_search_array = np .floor (
335
349
(upper_length_bound_array [length_change_position :i ] / 2 ) +
336
350
(lower_length_bound_array [length_change_position :i ] / 2 )).astype (
337
351
data_type )
352
+
353
+ # If we have an initial search length
354
+ if initial_search_length :
355
+ # Use the initial search length if it is less than the new midpoint
356
+ new_initial_search_array = np .fmin (new_initial_search_array ,
357
+ initial_search_length )
358
+
359
+ current_length_query_array [length_change_position :i ] = \
360
+ new_initial_search_array
361
+
338
362
# Mark positions with values of (min length - 1) to 1 as finished
339
363
finished_search_array [minimum_length_position + 1 :i ] = True
340
364
@@ -343,14 +367,16 @@ def linear_search(index_filename: Path,
343
367
sequence_segment : SequenceSegment ,
344
368
kmer_lengths : list [int ],
345
369
num_kmers : int ,
370
+ exclude_bases : set [bytes ],
346
371
num_threads : int ,
347
372
data_type : Union [np .uint8 , np .uint16 , np .uint32 ],
348
373
verbose : bool ) -> tuple [npt .NDArray [np .uint ], int ]:
349
374
# Track which kmer positions have finished searching,
350
375
# skipping any kmers starting with an ambiguous base
351
376
# NB: Iterating over bytes returns ints
352
- finished_search = np .array ([c == ord (b'N' ) for c in
353
- sequence_segment .data [:num_kmers ]])
377
+ finished_search = get_ambiguous_positions (sequence_segment ,
378
+ num_kmers ,
379
+ exclude_bases )
354
380
355
381
ambiguous_positions_skipped = finished_search .sum ()
356
382
verbose_print (verbose , f"Skipping { ambiguous_positions_skipped } ambiguous "
@@ -478,6 +504,26 @@ def get_num_kmers(sequence_segment: SequenceSegment,
478
504
return sequence_length - lookahead_length
479
505
480
506
507
+ def get_ambiguous_positions (sequence_segment : SequenceSegment ,
508
+ num_positions : int ,
509
+ ambiguous_bases : set [bytes ]):
510
+ """Returns a boolean array of ambiguous positions in a sequence segment
511
+ Where True is an ambiguous position and False is a non-ambiguous"""
512
+
513
+ # Track which kmer positions have finished searching,
514
+ # skipping any kmers starting with an ambiguous base
515
+ sequence_buffer = np .frombuffer (sequence_segment .data ,
516
+ dtype = np .uint8 ,
517
+ count = num_positions )
518
+
519
+ ambiguous_array_positions = np .full (num_positions , False , dtype = bool )
520
+ for base in ambiguous_bases :
521
+ ambiguous_array_positions |= (sequence_buffer == ord (base ))
522
+
523
+ return ambiguous_array_positions
524
+
525
+
526
+
481
527
def print_summary_statisitcs (verbose : bool ,
482
528
total_unique_lengths_count : int ,
483
529
total_ambiguous_positions : int ,
@@ -505,6 +551,8 @@ def main(args):
505
551
index_filename = args .index_file
506
552
kmer_batch_size = args .kmer_batch_size
507
553
kmer_lengths_arg = args .kmer_lengths
554
+ initial_search_length = args .initial_search_length
555
+ exclude_bases_arg = args .exclude_bases
508
556
num_threads = args .thread_count
509
557
verbose = args .verbose
510
558
@@ -530,10 +578,20 @@ def main(args):
530
578
else :
531
579
kmer_lengths = list (map (int , kmer_lengths_arg .split ("," )))
532
580
581
+ if (initial_search_length and
582
+ not use_binary_search ):
583
+ raise ValueError ("Initial search length only valid when a range of "
584
+ "k-mer lengths is given" )
585
+
586
+ exclude_bases = set ([bytes (base , encoding = "utf-8" )
587
+ for base in exclude_bases_arg ])
588
+
533
589
write_unique_counts (Path (fasta_filename ),
534
590
Path (index_filename ),
535
591
kmer_batch_size ,
536
592
kmer_lengths ,
593
+ initial_search_length ,
594
+ exclude_bases ,
537
595
num_threads ,
538
596
use_binary_search ,
539
597
verbose )
0 commit comments