Skip to content

Commit 37e54a2

Browse files
committed
Fixed weight correction for stdp synapses with axonal delays for edge cases
1 parent b435e30 commit 37e54a2

File tree

10 files changed

+54
-27
lines changed

10 files changed

+54
-27
lines changed

models/stdp_pl_synapse_hom_ax_delay.h

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ class stdp_pl_synapse_hom_ax_delay : public Connection< targetidentifierT, Axona
188188
*/
189189
void correct_synapse_stdp_ax_delay( const size_t tid,
190190
const double t_last_spike,
191+
const double t_spike_critical_interval_end,
191192
double& weight_revert,
193+
const double K_plus_revert,
192194
const double t_post_spike,
193195
const STDPPLHomAxDelayCommonProperties& cp );
194196

@@ -301,8 +303,12 @@ stdp_pl_synapse_hom_ax_delay< targetidentifierT >::send( Event& e,
301303
&start,
302304
&finish );
303305

306+
// Framework for STDP with predominantly axonal delays:
307+
// Store pre-synaptic trace for potential later correction
308+
const double K_plus_revert = Kplus_;
309+
304310
// facilitation due to postsynaptic spikes since last pre-synaptic spike
305-
double minus_dt;
311+
double minus_dt = 0.; // TODO JV
306312
while ( start != finish )
307313
{
308314
minus_dt = t_lastspike_ + axonal_delay_ms - ( start->t_ + dendritic_delay_ms );
@@ -331,13 +337,11 @@ stdp_pl_synapse_hom_ax_delay< targetidentifierT >::send( Event& e,
331337
const long time_while_critical =
332338
e.get_rel_delivery_steps( kernel().simulation_manager.get_slice_origin() ) - 2 * get_dendritic_delay_steps() + 1;
333339
// Only add correction entry if there could potentially be any post-synaptic spike that occurs before the
334-
// pre-synaptic one arrives at the synapse. Has to be strictly greater than min_delay, because a post-synaptic spike
335-
// at time slice_origin+min_delay corresponds to the last update step in the current slice (before delivery) and was
336-
// thus already known at time of delivery of the pre-synaptic one.
340+
// pre-synaptic one arrives at the synapse.
337341
if ( time_while_critical > 0 )
338342
{
339343
static_cast< ArchivingNode* >( target )->add_correction_entry_stdp_ax_delay(
340-
static_cast< SpikeEvent& >( e ), t_lastspike_, weight_revert, time_while_critical );
344+
static_cast< SpikeEvent& >( e ), t_lastspike_, weight_revert, K_plus_revert, time_while_critical );
341345
}
342346

343347
Kplus_ = Kplus_ * std::exp( ( t_lastspike_ - t_spike ) * cp.tau_plus_inv_ ) + 1.0;
@@ -384,28 +388,27 @@ template < typename targetidentifierT >
384388
inline void
385389
stdp_pl_synapse_hom_ax_delay< targetidentifierT >::correct_synapse_stdp_ax_delay( const size_t tid,
386390
const double t_last_spike,
391+
const double t_spike_critical_interval_end,
387392
double& weight_revert,
393+
const double K_plus_revert,
388394
const double t_post_spike,
389395
const STDPPLHomAxDelayCommonProperties& cp )
390396
{
391-
const double t_spike = t_lastspike_; // no new pre-synaptic spike since last send()
392397
const double wrong_weight = weight_; // incorrectly transmitted weight
393-
weight_ = weight_revert; // removes the last depressive step
394398
Node* target = get_target( tid );
395399

396400
const double axonal_delay_ms = get_axonal_delay_ms();
397401
double dendritic_delay_ms = get_dendritic_delay_ms();
398402

403+
const double t_spike = t_spike_critical_interval_end + dendritic_delay_ms - axonal_delay_ms;
404+
399405
// facilitation due to new post-synaptic spike
400406
const double minus_dt = t_last_spike + axonal_delay_ms - ( t_post_spike + dendritic_delay_ms );
401407

402-
double K_plus_revert;
403408
// Only facilitate if not facilitated already (only if first correction for this post-spike)
404409
if ( minus_dt < -1.0 * kernel().connection_manager.get_stdp_eps() )
405410
{
406-
// Kplus value at t_last_spike_ needed
407-
K_plus_revert = ( Kplus_ - 1.0 ) / std::exp( ( t_last_spike - t_spike ) * cp.tau_plus_inv_ );
408-
weight_ = facilitate_( weight_, K_plus_revert * std::exp( minus_dt * cp.tau_plus_inv_ ), cp );
411+
weight_ = facilitate_( weight_revert, K_plus_revert * std::exp( minus_dt * cp.tau_plus_inv_ ), cp );
409412

410413
// update weight_revert in case further correction will be required later
411414
weight_revert = weight_;

nestkernel/archiving_node.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ void
304304
ArchivingNode::add_correction_entry_stdp_ax_delay( SpikeEvent& spike_event,
305305
const double t_last_pre_spike,
306306
const double weight_revert,
307+
const double K_plus_revert,
307308
const double time_while_critical )
308309
{
309310
assert( correction_entries_stdp_ax_delay_.size()
@@ -313,8 +314,8 @@ ArchivingNode::add_correction_entry_stdp_ax_delay( SpikeEvent& spike_event,
313314
assert( static_cast< size_t >( idx ) < correction_entries_stdp_ax_delay_.size() );
314315

315316
const SpikeData& spike_data = spike_event.get_sender_spike_data();
316-
correction_entries_stdp_ax_delay_[ idx ].push_back(
317-
CorrectionEntrySTDPAxDelay( spike_data.get_lcid(), spike_data.get_syn_id(), t_last_pre_spike, weight_revert ) );
317+
correction_entries_stdp_ax_delay_[ idx ].push_back( CorrectionEntrySTDPAxDelay(
318+
spike_data.get_lcid(), spike_data.get_syn_id(), t_last_pre_spike, weight_revert, K_plus_revert ) );
318319
}
319320

320321
void
@@ -358,7 +359,9 @@ ArchivingNode::correct_synapses_stdp_ax_delay_( const Time& t_spike )
358359
it_corr_entry.syn_id_,
359360
it_corr_entry.lcid_,
360361
it_corr_entry.t_last_pre_spike_,
362+
( ori + Time::step( lag + 1 ) ).get_ms(),
361363
it_corr_entry.weight_revert_,
364+
it_corr_entry.K_plus_revert_,
362365
t_spike.get_ms() );
363366
}
364367
// indicate that the new spike was processed by these STDP synapses

nestkernel/archiving_node.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class ArchivingNode : public StructuralPlasticityNode
112112
void add_correction_entry_stdp_ax_delay( SpikeEvent& spike_event,
113113
const double t_last_pre_spike,
114114
const double weight_revert,
115+
const double K_plus_revert,
115116
const double time_while_critical );
116117

117118
protected:
@@ -176,22 +177,25 @@ class ArchivingNode : public StructuralPlasticityNode
176177
CorrectionEntrySTDPAxDelay( const size_t lcid,
177178
const synindex syn_id,
178179
const double t_last_pre_spike,
179-
const double weight_revert )
180+
const double weight_revert,
181+
const double K_plus_revert )
180182
: lcid_( lcid )
181183
, syn_id_( syn_id )
182184
, t_last_pre_spike_( t_last_pre_spike )
183185
, weight_revert_( weight_revert )
186+
, K_plus_revert_( K_plus_revert )
184187
{
185188
}
186189

187190
unsigned int lcid_; //!< local connection index
188191
synindex syn_id_; //!< synapse-type index
189192
double t_last_pre_spike_; //!< time of the last pre-synaptic spike before this spike
190193
double weight_revert_; //!< synaptic weight to revert to (STDP depression needs to be undone)
194+
double K_plus_revert_; //!< pre-synaptic trace before possibly incorrect facilitation
191195
};
192196

193197
//! check for correct correction entry size
194-
using correction_entry_size = StaticAssert< sizeof( ArchivingNode::CorrectionEntrySTDPAxDelay ) == 24 >::success;
198+
using correction_entry_size = StaticAssert< sizeof( ArchivingNode::CorrectionEntrySTDPAxDelay ) == 32 >::success;
195199

196200
protected:
197201
/**

nestkernel/connection.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ class Connection
179179
*/
180180
void correct_synapse_stdp_ax_delay( const size_t tid,
181181
const double t_last_pre_spike,
182+
const double t_spike_critical_interval_end,
182183
double& weight_revert,
184+
const double K_plus_revert,
183185
const double t_post_spike,
184186
const CommonSynapseProperties& );
185187

@@ -461,9 +463,11 @@ Connection< targetidentifierT, DelayTypeT >::calibrate( const TimeConverter& tc
461463
template < typename targetidentifierT, typename DelayTypeT >
462464
inline void
463465
Connection< targetidentifierT, DelayTypeT >::correct_synapse_stdp_ax_delay( const size_t,
466+
const double,
464467
const double,
465468
double&,
466469
const double,
470+
const double,
467471
const CommonSynapseProperties& )
468472
{
469473
throw IllegalConnection( "Connection does not support correction in case of STDP with predominantly axonal delays." );

nestkernel/connection_manager.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,16 +322,19 @@ class ConnectionManager : public ManagerInterface
322322
* @param syn_id Synapse type.
323323
* @param lcid Local index of the synapse in the array of connections of the same type for this thread.
324324
* @param t_last_pre_spike Time of the last pre-synaptic spike before the pre-synaptic spike which needs a correction.
325+
* @param t_spike
325326
* @param weight_revert The synaptic weight before depression after facilitation as baseline for potential later
326327
* correction.
327328
* @param t_post_spike Time of the current post-synaptic spike.
328329
*/
329-
void correct_synapse_stdp_ax_delay( const size_t tid,
330-
const synindex syn_id,
331-
const size_t lcid,
332-
const double t_last_pre_spike,
330+
void correct_synapse_stdp_ax_delay( size_t tid,
331+
synindex syn_id,
332+
size_t lcid,
333+
double t_last_pre_spike,
334+
double t_spike_critical_interval_end,
333335
double& weight_revert,
334-
const double t_post_spike );
336+
double K_plus_revert,
337+
double t_post_spike );
335338

336339
/**
337340
* Send event e to all device targets of source source_node_id
@@ -937,12 +940,14 @@ ConnectionManager::correct_synapse_stdp_ax_delay( const size_t tid,
937940
const synindex syn_id,
938941
const size_t lcid,
939942
const double t_last_pre_spike,
943+
const double t_spike_critical_interval_end,
940944
double& weight_revert,
945+
const double K_plus_revert,
941946
const double t_post_spike )
942947
{
943948
++num_corrections_;
944949
connections_[ tid ][ syn_id ]->correct_synapse_stdp_ax_delay(
945-
tid, syn_id, lcid, t_last_pre_spike, weight_revert, t_post_spike );
950+
tid, syn_id, lcid, t_last_pre_spike, t_spike_critical_interval_end, weight_revert, K_plus_revert, t_post_spike );
946951
}
947952

948953
inline void

nestkernel/connector_base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class ConnectorBase
161161
* @param syn_id Synapse type.
162162
* @param lcid Local index of the synapse in the array of connections of the same type for this thread.
163163
* @param t_last_pre_spike Time of the last pre-synaptic spike before the pre-synaptic spike which needs a correction.
164+
* @param t_spike The time of the pre-synaptic spike which needs a correction.
164165
* @param weight_revert The synaptic weight before depression after facilitation as baseline for potential later
165166
* correction.
166167
* @param t_post_spike Time of the current post-synaptic spike.
@@ -169,7 +170,9 @@ class ConnectorBase
169170
const synindex syn_id,
170171
const size_t lcid,
171172
const double t_last_pre_spike,
173+
const double t_spike_critical_interval_end,
172174
double& weight_revert,
175+
const double K_plus_revert,
173176
const double t_post_spike ) = 0;
174177

175178
virtual void
@@ -438,7 +441,9 @@ class Connector : public ConnectorBase
438441
const synindex syn_id,
439442
const size_t lcid,
440443
const double t_last_pre_spike,
444+
const double t_spike_critical_interval_end,
441445
double& weight_revert,
446+
const double K_plus_revert,
442447
const double t_post_spike ) override;
443448

444449
// Implemented in connector_base_impl.h

nestkernel/connector_base_impl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,16 @@ Connector< ConnectionT >::correct_synapse_stdp_ax_delay( const size_t tid,
6767
const synindex syn_id,
6868
const size_t lcid,
6969
const double t_last_pre_spike,
70+
const double t_spike_critical_interval_end,
7071
double& weight_revert,
72+
const double K_plus_revert,
7173
const double t_post_spike )
7274
{
7375
typename ConnectionT::CommonPropertiesType const& cp = static_cast< GenericConnectorModel< ConnectionT >* >(
7476
kernel().model_manager.get_connection_models( tid )[ syn_id ] )
7577
->get_common_properties();
76-
C_[ lcid ].correct_synapse_stdp_ax_delay( tid, t_last_pre_spike, weight_revert, t_post_spike, cp );
78+
C_[ lcid ].correct_synapse_stdp_ax_delay(
79+
tid, t_last_pre_spike, t_spike_critical_interval_end, weight_revert, K_plus_revert, t_post_spike, cp );
7780
}
7881

7982
} // of namespace nest

nestkernel/event.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,10 +1011,9 @@ inline void
10111011
Event::set_stamp( Time const& s )
10121012
{
10131013
stamp_ = s;
1014-
stamp_steps_ = 0; // setting stamp_steps to zero indicates
1015-
// stamp_steps needs to be recalculated from
1016-
// stamp_ next time it is needed (e.g., in
1017-
// get_rel_delivery_steps)
1014+
// setting stamp_steps to zero indicates stamp_steps needs to be recalculated from stamp_ next time it is needed
1015+
// (e.g., in get_rel_delivery_steps)
1016+
stamp_steps_ = 0;
10181017
}
10191018

10201019
inline long

nestkernel/node.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ Node::get_local_device_id() const
144144
}
145145

146146
void
147-
Node::add_correction_entry_stdp_ax_delay( SpikeEvent&, const double, const double, const double )
147+
Node::add_correction_entry_stdp_ax_delay( SpikeEvent&, const double, const double, const double, const double )
148148
{
149149
throw UnexpectedEvent( "Node does not support framework for STDP synapses with predominantly axonal delays." );
150150
}

nestkernel/node.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,7 @@ class Node
10251025
void add_correction_entry_stdp_ax_delay( SpikeEvent& spike_event,
10261026
const double t_last_pre_spike,
10271027
const double weight_revert,
1028+
const double K_plus_revert,
10281029
const double time_while_critical );
10291030

10301031
/**

0 commit comments

Comments
 (0)