|
29 | 29 | #include <iostream> |
30 | 30 | #include <map> |
31 | 31 | #include <set> |
| 32 | +#include <utility> |
32 | 33 | #include <vector> |
33 | 34 |
|
34 | 35 | // Includes from nestkernel: |
@@ -300,7 +301,10 @@ class SourceTable |
300 | 301 | * Finds the first entry in sources_ at the given thread id and |
301 | 302 | * synapse type that is equal to snode_id. |
302 | 303 | */ |
303 | | - size_t find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const; |
| 304 | + size_t find_first_source( const size_t tid, |
| 305 | + const synindex syn_id, |
| 306 | + const size_t snode_id, |
| 307 | + bool isCompressedEnabled = false ) const; |
304 | 308 |
|
305 | 309 | /** |
306 | 310 | * Marks entry in sources_ at given position as disabled. |
@@ -471,27 +475,75 @@ SourceTable::no_targets_to_process( const size_t tid ) |
471 | 475 | } |
472 | 476 |
|
473 | 477 | inline size_t |
474 | | -SourceTable::find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const |
| 478 | +SourceTable::find_first_source( const size_t tid, |
| 479 | + const synindex syn_id, |
| 480 | + const size_t snode_id, |
| 481 | + bool isCompressedEnabled /* default = false */ ) const |
475 | 482 | { |
476 | | - // binary search in sorted sources |
477 | | - const BlockVector< Source >::const_iterator begin = sources_[ tid ][ syn_id ].begin(); |
478 | | - const BlockVector< Source >::const_iterator end = sources_[ tid ][ syn_id ].end(); |
479 | | - BlockVector< Source >::const_iterator it = std::lower_bound( begin, end, Source( snode_id, true ) ); |
480 | | - |
481 | | - // source found by binary search could be disabled, iterate through |
482 | | - // sources until a valid one is found |
483 | | - while ( it != end ) |
| 483 | + using SourceIter = BlockVector< Source >::const_iterator; |
| 484 | + |
| 485 | + const SourceIter begin = sources_[ tid ][ syn_id ].begin(); |
| 486 | + const SourceIter end = sources_[ tid ][ syn_id ].end(); |
| 487 | + |
| 488 | + if ( isCompressedEnabled ) |
484 | 489 | { |
485 | | - if ( it->get_node_id() == snode_id and not it->is_disabled() ) |
| 490 | + // binary search in sorted sources |
| 491 | + SourceIter it = std::lower_bound( begin, end, Source( snode_id, true ) ); |
| 492 | + |
| 493 | + // source found by binary search could be disabled, iterate through |
| 494 | + // sources until a valid one is found |
| 495 | + while ( it != end ) |
486 | 496 | { |
487 | | - const size_t lcid = it - begin; |
488 | | - return lcid; |
| 497 | + if ( it->get_node_id() == snode_id and not it->is_disabled() ) |
| 498 | + { |
| 499 | + const size_t lcid = it - begin; |
| 500 | + return lcid; |
| 501 | + } |
| 502 | + ++it; |
489 | 503 | } |
490 | | - ++it; |
| 504 | + |
| 505 | + // no enabled entry with this snode ID found |
| 506 | + return invalid_index; |
491 | 507 | } |
| 508 | + else |
| 509 | + { |
492 | 510 |
|
493 | | - // no enabled entry with this snode ID found |
494 | | - return invalid_index; |
| 511 | + auto nth_equal = |
| 512 | + []( SourceIter first, SourceIter last, const Source& value, size_t n ) -> std::pair< bool, SourceIter > |
| 513 | + { |
| 514 | + if ( n == 0 ) |
| 515 | + { |
| 516 | + auto iter = std::find( first, last, value ); |
| 517 | + return { iter != last, iter }; |
| 518 | + } |
| 519 | + auto iter = std::find( first, last, value ); |
| 520 | + while ( n > 0 && iter != last ) |
| 521 | + { |
| 522 | + --n; |
| 523 | + iter = std::find( std::next( iter ), last, value ); |
| 524 | + } |
| 525 | + return { iter != last, iter }; |
| 526 | + }; |
| 527 | + size_t pos = 0; |
| 528 | + auto res = nth_equal( begin, end, Source( snode_id, true ), pos ); |
| 529 | + if ( !res.first ) |
| 530 | + { |
| 531 | + return invalid_index; |
| 532 | + } |
| 533 | + |
| 534 | + while ( res.first ) |
| 535 | + { |
| 536 | + if ( res.second->get_node_id() == snode_id && not res.second->is_disabled() ) |
| 537 | + { |
| 538 | + // found a valid source |
| 539 | + size_t lcid = res.second - begin; |
| 540 | + return lcid; |
| 541 | + } |
| 542 | + ++pos; |
| 543 | + res = nth_equal( std::next( res.second ), end, Source( snode_id, true ), pos ); |
| 544 | + } |
| 545 | + return invalid_index; |
| 546 | + } |
495 | 547 | } |
496 | 548 |
|
497 | 549 | inline void |
@@ -521,8 +573,8 @@ SourceTable::num_unique_sources( const size_t tid, const synindex syn_id ) const |
521 | 573 | size_t n = 0; |
522 | 574 | size_t last_source = 0; |
523 | 575 | for ( BlockVector< Source >::const_iterator cit = sources_[ tid ][ syn_id ].begin(); |
524 | | - cit != sources_[ tid ][ syn_id ].end(); |
525 | | - ++cit ) |
| 576 | + cit != sources_[ tid ][ syn_id ].end(); |
| 577 | + ++cit ) |
526 | 578 | { |
527 | 579 | if ( last_source != ( *cit ).get_node_id() ) |
528 | 580 | { |
|
0 commit comments