Skip to content

Commit a3f5369

Browse files
committed
Fix issue: #3532
1 parent b9bdcc7 commit a3f5369

File tree

2 files changed

+72
-20
lines changed

2 files changed

+72
-20
lines changed

nestkernel/connection_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ nest::ConnectionManager::find_connection( const size_t tid,
996996
{
997997
// lcid will hold the position of the /first/ connection from node
998998
// snode_id to any local node, or be invalid
999-
size_t lcid = source_table_.find_first_source( tid, syn_id, snode_id );
999+
size_t lcid = source_table_.find_first_source( tid, syn_id, snode_id, use_compressed_spikes() );
10001000
if ( lcid == invalid_index )
10011001
{
10021002
return invalid_index;
@@ -1446,7 +1446,7 @@ nest::ConnectionManager::get_targets( const std::vector< size_t >& sources,
14461446
{
14471447
for ( size_t i = 0; i < sources.size(); ++i )
14481448
{
1449-
const size_t start_lcid = source_table_.find_first_source( tid, syn_id, sources[ i ] );
1449+
const size_t start_lcid = source_table_.find_first_source( tid, syn_id, sources[ i ], use_compressed_spikes() );
14501450
if ( start_lcid != invalid_index )
14511451
{
14521452
connections_[ tid ][ syn_id ]->get_target_node_ids( tid, start_lcid, post_synaptic_element, targets[ i ] );

nestkernel/source_table.h

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <iostream>
3030
#include <map>
3131
#include <set>
32+
#include <utility>
3233
#include <vector>
3334

3435
// Includes from nestkernel:
@@ -300,7 +301,10 @@ class SourceTable
300301
* Finds the first entry in sources_ at the given thread id and
301302
* synapse type that is equal to snode_id.
302303
*/
303-
size_t find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const;
304+
size_t find_first_source( const size_t tid,
305+
const synindex syn_id,
306+
const size_t snode_id,
307+
bool isCompressedEnabled = false ) const;
304308

305309
/**
306310
* Marks entry in sources_ at given position as disabled.
@@ -471,27 +475,75 @@ SourceTable::no_targets_to_process( const size_t tid )
471475
}
472476

473477
inline size_t
474-
SourceTable::find_first_source( const size_t tid, const synindex syn_id, const size_t snode_id ) const
478+
SourceTable::find_first_source( const size_t tid,
479+
const synindex syn_id,
480+
const size_t snode_id,
481+
bool isCompressedEnabled /* default = false */ ) const
475482
{
476-
// binary search in sorted sources
477-
const BlockVector< Source >::const_iterator begin = sources_[ tid ][ syn_id ].begin();
478-
const BlockVector< Source >::const_iterator end = sources_[ tid ][ syn_id ].end();
479-
BlockVector< Source >::const_iterator it = std::lower_bound( begin, end, Source( snode_id, true ) );
480-
481-
// source found by binary search could be disabled, iterate through
482-
// sources until a valid one is found
483-
while ( it != end )
483+
using SourceIter = BlockVector< Source >::const_iterator;
484+
485+
const SourceIter begin = sources_[ tid ][ syn_id ].begin();
486+
const SourceIter end = sources_[ tid ][ syn_id ].end();
487+
488+
if ( isCompressedEnabled )
484489
{
485-
if ( it->get_node_id() == snode_id and not it->is_disabled() )
490+
// binary search in sorted sources
491+
SourceIter it = std::lower_bound( begin, end, Source( snode_id, true ) );
492+
493+
// source found by binary search could be disabled, iterate through
494+
// sources until a valid one is found
495+
while ( it != end )
486496
{
487-
const size_t lcid = it - begin;
488-
return lcid;
497+
if ( it->get_node_id() == snode_id and not it->is_disabled() )
498+
{
499+
const size_t lcid = it - begin;
500+
return lcid;
501+
}
502+
++it;
489503
}
490-
++it;
504+
505+
// no enabled entry with this snode ID found
506+
return invalid_index;
491507
}
508+
else
509+
{
492510

493-
// no enabled entry with this snode ID found
494-
return invalid_index;
511+
auto nth_equal =
512+
[]( SourceIter first, SourceIter last, const Source& value, size_t n ) -> std::pair< bool, SourceIter >
513+
{
514+
if ( n == 0 )
515+
{
516+
auto iter = std::find( first, last, value );
517+
return { iter != last, iter };
518+
}
519+
auto iter = std::find( first, last, value );
520+
while ( n > 0 && iter != last )
521+
{
522+
--n;
523+
iter = std::find( std::next( iter ), last, value );
524+
}
525+
return { iter != last, iter };
526+
};
527+
size_t pos = 0;
528+
auto res = nth_equal( begin, end, Source( snode_id, true ), pos );
529+
if ( !res.first )
530+
{
531+
return invalid_index;
532+
}
533+
534+
while ( res.first )
535+
{
536+
if ( res.second->get_node_id() == snode_id && not res.second->is_disabled() )
537+
{
538+
// found a valid source
539+
size_t lcid = res.second - begin;
540+
return lcid;
541+
}
542+
++pos;
543+
res = nth_equal( std::next( res.second ), end, Source( snode_id, true ), pos );
544+
}
545+
return invalid_index;
546+
}
495547
}
496548

497549
inline void
@@ -521,8 +573,8 @@ SourceTable::num_unique_sources( const size_t tid, const synindex syn_id ) const
521573
size_t n = 0;
522574
size_t last_source = 0;
523575
for ( BlockVector< Source >::const_iterator cit = sources_[ tid ][ syn_id ].begin();
524-
cit != sources_[ tid ][ syn_id ].end();
525-
++cit )
576+
cit != sources_[ tid ][ syn_id ].end();
577+
++cit )
526578
{
527579
if ( last_source != ( *cit ).get_node_id() )
528580
{

0 commit comments

Comments
 (0)