Skip to content

Commit a2eee24

Browse files
committed
Extended comments on and some code improvements for conn builders
1 parent 518f964 commit a2eee24

File tree

3 files changed

+146
-48
lines changed

3 files changed

+146
-48
lines changed

nestkernel/conn_builder.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -789,13 +789,6 @@ nest::ThirdOutBuilder::ThirdOutBuilder( const NodeCollectionPTR third,
789789
{
790790
}
791791

792-
void
793-
nest::ThirdOutBuilder::connect()
794-
{
795-
assert( false ); // should never be called
796-
}
797-
798-
799792
nest::ThirdBernoulliWithPoolBuilder::ThirdBernoulliWithPoolBuilder( const NodeCollectionPTR third,
800793
const NodeCollectionPTR targets,
801794
ThirdInBuilder* third_in,
@@ -855,6 +848,9 @@ nest::ThirdBernoulliWithPoolBuilder::ThirdBernoulliWithPoolBuilder( const NodeCo
855848

856849
if ( not random_pool_ )
857850
{
851+
// Tell every target neuron its position in the target node collection.
852+
// This is necessary to assign the right block pool to it.
853+
//
858854
// We cannot do this parallel with targets->local_begin() since we need to
859855
// count over all elements of the node collection which might be a complex
860856
// composition of slices with non-trivial mapping between elements and vps.
@@ -879,6 +875,9 @@ nest::ThirdBernoulliWithPoolBuilder::~ThirdBernoulliWithPoolBuilder()
879875

880876
if ( not random_pool_ )
881877
{
878+
// Reset tmp_nc_index in target nodes in case a node has never been a target.
879+
// We do not want non-invalid values to persist beyond the lifetime of this builder.
880+
//
882881
// Here we can work in parallel since we just reset to invalid_index
883882
for ( auto tgt_it = targets_->thread_local_begin(); tgt_it != targets_->end(); ++tgt_it )
884883
{

nestkernel/conn_builder.h

Lines changed: 128 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class ThirdOutBuilder;
6565
* the connect interface. Derived classes implement the connect
6666
* method.
6767
*
68-
* @note Naming classes *Builder to avoid name confusion with Connector classes.
68+
* @note This class is also the base class for all components of a TripartiteConnBuilder.
6969
*/
7070
class BipartiteConnBuilder
7171
{
@@ -76,6 +76,16 @@ class BipartiteConnBuilder
7676
//! Delete synapses with or without structural plasticity
7777
virtual void disconnect();
7878

79+
/**
80+
* Create new bipartite builder.
81+
*
82+
* @param sources Source population to connect from
83+
* @param targets Target population to connect to
84+
* @param third_out `nullptr` if pure bipartite connection, pointer to \class ThirdOutBuilder object if this builder
85+
* creates the primary connection of a tripartite connectivity
86+
* @param conn_spec Connection specification (if part of tripartite, spec for the specific part)
87+
* @param syn_specs Collection of synapse specifications (usually single element, several for collocated synapses)
88+
*/
7989
BipartiteConnBuilder( NodeCollectionPTR sources,
8090
NodeCollectionPTR targets,
8191
ThirdOutBuilder* third_out,
@@ -105,8 +115,6 @@ class BipartiteConnBuilder
105115

106116
void set_synaptic_element_names( const std::string& pre_name, const std::string& post_name );
107117

108-
bool all_parameters_scalar_() const;
109-
110118
/**
111119
* Updates the number of connected synaptic elements in the target and the source.
112120
*
@@ -118,12 +126,14 @@ class BipartiteConnBuilder
118126
*/
119127
bool change_connected_synaptic_elements( size_t snode_id, size_t tnode_id, const size_t tid, int update );
120128

129+
//! Return true if rule allows creation of symmetric connectivity
121130
virtual bool
122131
supports_symmetric() const
123132
{
124133
return false;
125134
}
126135

136+
//! Return true if rule automatically creates symmetric connectivity
127137
virtual bool
128138
is_symmetric() const
129139
{
@@ -153,6 +163,8 @@ class BipartiteConnBuilder
153163
//! Implements the actual connection algorithm
154164
virtual void connect_() = 0;
155165

166+
bool all_parameters_scalar_() const;
167+
156168
virtual void
157169
sp_connect_()
158170
{
@@ -204,17 +216,17 @@ class BipartiteConnBuilder
204216
*/
205217
bool loop_over_targets_() const;
206218

207-
NodeCollectionPTR sources_;
208-
NodeCollectionPTR targets_;
219+
NodeCollectionPTR sources_; //!< Population to connect from
220+
NodeCollectionPTR targets_; //!< Population to connect to
209221

210-
ThirdOutBuilder* third_out_; //!< to be triggered when primary connection is created
222+
ThirdOutBuilder* third_out_; //!< To be triggered when primary connection is created
211223

212224
bool allow_autapses_;
213225
bool allow_multapses_;
214226
bool make_symmetric_;
215227
bool creates_symmetric_connections_;
216228

217-
//! buffer for exceptions raised in threads
229+
//! Buffer for exceptions raised in threads
218230
std::vector< std::shared_ptr< WrappedThreadException > > exceptions_raised_;
219231

220232
// Name of the pre synaptic and postsynaptic elements for this connection builder
@@ -223,12 +235,18 @@ class BipartiteConnBuilder
223235

224236
bool use_structural_plasticity_;
225237

226-
//! pointers to connection parameters specified as arrays
238+
//! Pointers to connection parameters specified as arrays
227239
std::vector< ConnParameter* > parameters_requiring_skipping_;
228240

229241
std::vector< size_t > synapse_model_id_;
230242

231-
//! dictionaries to pass to connect function, one per thread for every syn_spec
243+
/**
244+
* Dictionaries to pass to connect function, one per thread for every syn_spec
245+
*
246+
* Outer dim: syn_spec, inner dim: thread
247+
*
248+
* @note Each thread can independently modify its dictionary to pass parameters on
249+
*/
232250
std::vector< std::vector< DictionaryDatum > > param_dicts_;
233251

234252
private:
@@ -300,20 +318,47 @@ class BipartiteConnBuilder
300318
* This builder creates the actual connections from primary sources to third-factor nodes
301319
* based on the source-third lists generated by the third-out builder.
302320
*
321+
* The `ThirdOutBuilder::third_connect()` method calls `register_connection()`
322+
* for each source node to third-factor node connection that needs to be created to store this
323+
* information in the `ThirdIn` builder.
324+
*
325+
* The `connect_()` method of this class needs to be called after all primary connections
326+
* and third-factor to target connections have been created. It then exchanges information
327+
* about required source-third connections with the other MPI ranks and creates required
328+
* connections locally.
329+
*
303330
* The class is final because there is no freedom of choice of connection rule at this stage.
304331
*/
305332
class ThirdInBuilder final : public BipartiteConnBuilder
306333
{
307334
public:
308-
ThirdInBuilder( NodeCollectionPTR,
309-
NodeCollectionPTR,
310-
const DictionaryDatum&, // only for compatibility with BipartiteConnBuilder
311-
const std::vector< DictionaryDatum >& );
335+
/**
336+
* Create ThirdInBuilder
337+
*
338+
* @param sources Source population of primary connection
339+
* @param third Third-factor population
340+
* @param third_conn_spec is ignored by this builder but required to make base class happy
341+
* @param syn_specs Collection of synapse specification for connection from primary source to third factor
342+
*
343+
* @todo Once DictionaryDatums are gone, see if we can remove `third_conn_spec` and just pass empty conn spec
344+
* container to base-class constructor, since \class ThirdInBuilder has no connection rule properties to set.
345+
*/
346+
ThirdInBuilder( NodeCollectionPTR sources,
347+
NodeCollectionPTR third,
348+
const DictionaryDatum& third_conn_spec,
349+
const std::vector< DictionaryDatum >& syn_specs );
312350
~ThirdInBuilder();
313351

314-
void register_connection( size_t primary_source_id, size_t primary_target_id );
352+
/**
353+
* Register required source node to third-factor node connection.
354+
*
355+
* @param primary_source_id GID of source node to connect from
356+
* @param third_node_id GID of target node to connect to
357+
*/
358+
void register_connection( size_t primary_source_id, size_t third_node_id );
315359

316360
private:
361+
//! Exchange required connection info via MPI and create needed connections locally
317362
void connect_() override;
318363

319364
/**
@@ -336,15 +381,16 @@ class ThirdInBuilder final : public BipartiteConnBuilder
336381
{
337382
}
338383

339-
size_t source_gid;
340-
size_t third_gid;
341-
size_t third_rank;
384+
size_t source_gid; //!< GID of source node to connect from
385+
size_t third_gid; //!< GID of third-factor node to connect to
386+
size_t third_rank; //!< Rank of third-factor node (stored locally as it is needed multiple times)
342387
};
343388

344-
//! source-thirdparty GID pairs to be communicated; one per thread
389+
//! Container for register source-third connections, one per thread via pointer for collision free operation in
390+
//! thread-local storage
345391
std::vector< BlockVector< SourceThirdInfo_ >* > source_third_gids_;
346392

347-
//! number of source-third pairs to send. Outer dimension writing thread, inner dimension rank to send to
393+
//! Number of source-third pairs to send. Outer dimension is writing thread, inner dimension MPI rank to send to
348394
std::vector< std::vector< size_t >* > source_third_counts_;
349395
};
350396

@@ -361,14 +407,33 @@ class ThirdInBuilder final : public BipartiteConnBuilder
361407
class ThirdOutBuilder : public BipartiteConnBuilder
362408
{
363409
public:
364-
ThirdOutBuilder( NodeCollectionPTR,
365-
NodeCollectionPTR,
366-
ThirdInBuilder*,
367-
const DictionaryDatum&, // only for compatibility with BipartiteConnBuilder
368-
const std::vector< DictionaryDatum >& );
410+
/**
411+
* Create ThirdOutBuilder.
412+
*
413+
* @param third Third-factor population
414+
* @param targets Target population of primary connection
415+
* @param third_in ThirdInBuilder which will create source-third connections later
416+
* @param third_conn_spec Specification for third-factor connectivity
417+
* @param syn_specs Collection of synapse specifications for third-target connections
418+
*/
419+
ThirdOutBuilder( const NodeCollectionPTR third,
420+
const NodeCollectionPTR targets,
421+
ThirdInBuilder* third_in,
422+
const DictionaryDatum& third_conn_spec,
423+
const std::vector< DictionaryDatum >& syn_specs );
369424

370-
void connect() override;
425+
void
426+
connect() override final
427+
{
428+
assert( false );
429+
} //!< only call third_connect() on ThirdOutBuilder
371430

431+
/**
432+
* Create third-factor connection for given primary connection.
433+
*
434+
* @param source_gid GID of source of primary connection
435+
* @param target Target node of primary connection
436+
*/
372437
virtual void third_connect( size_t source_gid, Node& target ) = 0;
373438

374439
protected:
@@ -392,7 +457,7 @@ class ConnBuilder
392457
* @param sources Source population for primary connection
393458
* @param targets Target population for primary connection
394459
* @param conn_spec Connection specification dictionary for tripartite bernoulli rule
395-
* @param syn_spec Dictionary of synapse specification
460+
* @param syn_spec Collection of dictionaries with synapse specifications
396461
*/
397462
ConnBuilder( const std::string& primary_rule,
398463
NodeCollectionPTR sources,
@@ -410,7 +475,8 @@ class ConnBuilder
410475
* @param third Third-party population
411476
* @param conn_spec Connection specification dictionary for tripartite bernoulli rule
412477
* @param syn_specs Dictionary of synapse specifications for the three connections that may be created. Allowed keys
413-
* are `"primary"`, `"third_in"`, `"third_out"`
478+
* are `"primary"`, `"third_in"`, `"third_out"`, and for each of these the value must be a collection of dictionaries
479+
* with synapse specifications as for bipartite connectivity.
414480
*/
415481
ConnBuilder( const std::string& primary_rule,
416482
const std::string& third_rule,
@@ -430,20 +496,23 @@ class ConnBuilder
430496
void disconnect();
431497

432498
private:
433-
// order of declarations based on dependencies
499+
// Order of declarations based on dependencies, do not change.
434500
ThirdInBuilder* third_in_builder_;
435501
ThirdOutBuilder* third_out_builder_;
436502
BipartiteConnBuilder* primary_builder_;
437503
};
438504

439-
505+
/**
506+
* Build third-factor connectivity based on Bernoulli trials, selecting third factor nodes from a fixed pool per target
507+
* node.
508+
*/
440509
class ThirdBernoulliWithPoolBuilder : public ThirdOutBuilder
441510
{
442511
public:
443512
ThirdBernoulliWithPoolBuilder( NodeCollectionPTR,
444513
NodeCollectionPTR,
445-
ThirdInBuilder* third_in,
446-
const DictionaryDatum&, // only for compatibility with BCB
514+
ThirdInBuilder*,
515+
const DictionaryDatum&,
447516
const std::vector< DictionaryDatum >& );
448517
~ThirdBernoulliWithPoolBuilder();
449518

@@ -454,16 +523,39 @@ class ThirdBernoulliWithPoolBuilder : public ThirdOutBuilder
454523
connect_() override
455524
{
456525
assert( false );
457-
}
526+
} //!< only call third_connect()
527+
528+
/**
529+
* For block pool, return index of first pool element for given target node.
530+
*
531+
* @param targe_index
532+
*/
458533
size_t get_first_pool_index_( const size_t target_index ) const;
459534

460-
double p_;
461-
bool random_pool_;
462-
size_t pool_size_;
463-
size_t targets_per_third_;
535+
double p_; //!< probability of creating a third-factor connection
536+
bool random_pool_; //!< random or block pool?
537+
size_t pool_size_; //!< number of nodes per pool
538+
size_t targets_per_third_; //!< number of target nodes per third-factor node
464539

540+
/**
541+
* Type for single pool of third-factor nodes
542+
*
543+
* @todo Could probably be BlockVector, but currently some problem with back_inserter when sampling pool.
544+
*/
465545
typedef std::vector< NodeIDTriple > PoolType_;
546+
547+
//! Type mapping target GID to pool for this target
466548
typedef std::map< size_t, PoolType_ > TgtPoolMap_;
549+
550+
/**
551+
* Thread-specific pools of third-factor nodes.
552+
*
553+
* Each thread maintains a map from target node IDs to the third-factor node pool for that target node.
554+
* Since each target lives on exactly one thread, there will be no overlap. For each node, the pool is
555+
* created when a third-factor connection needs to be made to that node for the first time.
556+
* The pools are deleted when the ConnBuilder is destroyed at the end of the connect call.
557+
* We store a pointer instead of the map itself to ensure that the map is in thread-local memory.
558+
*/
467559
std::vector< TgtPoolMap_* > pools_; // outer: threads
468560
};
469561

nestkernel/node.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -954,14 +954,16 @@ class Node
954954
DeprecationWarning deprecation_warning;
955955

956956
/**
957-
* This is only to be used to get the index in the NC to the ThirdOutBuilder
957+
* Set index in node collection; required by ThirdOutBuilder.
958958
*/
959959
void set_tmp_nc_index( size_t index );
960960

961961
/**
962-
* Return index in NC and invalidate entry to avoid multiple reads. Only to be called by ThirdOutBuilder
962+
* Return and invalidate index in node collection; required by ThirdOutBuilder.
963+
*
964+
* @note Not const since it invalidates index in node object.
963965
*/
964-
size_t get_tmp_nc_index() const;
966+
size_t get_tmp_nc_index();
965967

966968

967969
private:
@@ -1048,6 +1050,8 @@ class Node
10481050
* @note This is only here so that the primary connection builder can inform the ThirdOutBuilder
10491051
* about the index of the target neuron in the targets node collection. This is required for block-based
10501052
* builders.
1053+
*
1054+
* @note Set by set_tmp_nc_index() and invalidated by get_tmp_nc_index().
10511055
*/
10521056
size_t tmp_nc_index_;
10531057
};
@@ -1196,11 +1200,14 @@ Node::set_tmp_nc_index( size_t index )
11961200
}
11971201

11981202
inline size_t
1199-
Node::get_tmp_nc_index() const
1203+
Node::get_tmp_nc_index()
12001204
{
12011205
assert( tmp_nc_index_ != invalid_index );
12021206

1203-
return tmp_nc_index_;
1207+
const auto index = tmp_nc_index_;
1208+
tmp_nc_index_ = invalid_index;
1209+
1210+
return index;
12041211
}
12051212

12061213

0 commit comments

Comments
 (0)