Skip to content

Commit fc48f87

Browse files
committed
Adding Gradient with respect to T
1 parent 644b51f commit fc48f87

File tree

1 file changed

+100
-44
lines changed

1 file changed

+100
-44
lines changed

notebooks/Kalman_Filter_Gradient.ipynb

Lines changed: 100 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 214,
13+
"execution_count": 1,
1414
"id": "90979a41",
1515
"metadata": {},
1616
"outputs": [],
@@ -40,18 +40,10 @@
4040
},
4141
{
4242
"cell_type": "code",
43-
"execution_count": null,
43+
"execution_count": 2,
4444
"id": "fdb156d6",
4545
"metadata": {},
46-
"outputs": [
47-
{
48-
"name": "stdout",
49-
"output_type": "stream",
50-
"text": [
51-
"<function compile_statespace.<locals>.f at 0x000001820DC942C0>\n"
52-
]
53-
}
54-
],
46+
"outputs": [],
5547
"source": [
5648
"mod = (\n",
5749
" sts.LevelTrendComponent(order=2, innovations_order=[0, 1], name='level') +\n",
@@ -87,7 +79,7 @@
8779
},
8880
{
8981
"cell_type": "code",
90-
"execution_count": 218,
82+
"execution_count": 3,
9183
"id": "3661408d",
9284
"metadata": {},
9385
"outputs": [],
@@ -140,7 +132,7 @@
140132
},
141133
{
142134
"cell_type": "code",
143-
"execution_count": 219,
135+
"execution_count": 4,
144136
"id": "35351096",
145137
"metadata": {},
146138
"outputs": [],
@@ -268,7 +260,7 @@
268260
},
269261
{
270262
"cell_type": "code",
271-
"execution_count": 132,
263+
"execution_count": 5,
272264
"id": "ee21ef4e",
273265
"metadata": {},
274266
"outputs": [],
@@ -330,7 +322,7 @@
330322
},
331323
{
332324
"cell_type": "code",
333-
"execution_count": 133,
325+
"execution_count": 6,
334326
"id": "8c89b018",
335327
"metadata": {},
336328
"outputs": [],
@@ -387,7 +379,7 @@
387379
},
388380
{
389381
"cell_type": "code",
390-
"execution_count": 134,
382+
"execution_count": 7,
391383
"id": "bba53a26",
392384
"metadata": {},
393385
"outputs": [],
@@ -426,7 +418,7 @@
426418
},
427419
{
428420
"cell_type": "code",
429-
"execution_count": 11,
421+
"execution_count": 8,
430422
"id": "c17949b7",
431423
"metadata": {},
432424
"outputs": [],
@@ -441,7 +433,7 @@
441433
"id": "f0bc0287",
442434
"metadata": {},
443435
"source": [
444-
"## Gradient with respect to **H**\n",
436+
"### Gradient with respect to **H**\n",
445437
"\n",
446438
"From the article we have :\n",
447439
"\n",
@@ -472,7 +464,7 @@
472464
},
473465
{
474466
"cell_type": "code",
475-
"execution_count": 135,
467+
"execution_count": 9,
476468
"id": "84cb6867",
477469
"metadata": {},
478470
"outputs": [],
@@ -499,6 +491,68 @@
499491
" return K.T @ P_filtered_grad @ K - 0.5 * K.T @ a_filtered_grad @ v.T @ F_inv - 0.5 * F_inv @ v @ a_filtered_grad.T @ K + F_inv - F_inv @ v @ v.T @ F_inv"
500492
]
501493
},
494+
{
495+
"cell_type": "markdown",
496+
"id": "4fa2ffc0",
497+
"metadata": {},
498+
"source": [
499+
"### Gradient with respect to **T**\n",
500+
"\n",
501+
"This gradient was not given in the article. Here are the steps that got me to this expression :\n",
502+
"\n",
503+
"1 - Only $x_{n|n-1}$ and $P_{n|n-1}$ depends on $T_n$. Hence :\n",
504+
"$$\n",
505+
"\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} + \\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial T}{\\partial P_{n|n-1}}\n",
506+
"$$\n",
507+
"2 - Using the equation (11) and (12) of the article, on the (1), we directly got that :\n",
508+
"$$\n",
509+
"\\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T\n",
510+
"$$\n",
511+
"3 - Recognizing the first quadratic form in the equation (2), and using equation (11) we got :\n",
512+
"$$\n",
513+
"\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T^T} = P_{n|n-1}T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}^T + P_{n|n-1}^T T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}\n",
514+
"$$\n",
515+
"4 - Now transposing to get the dependencies on T :\n",
516+
"$$\n",
517+
"\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n|n-1}^T +\\frac{\\partial L}{\\partial P_{n|n-1}}^T T_n P_{n|n-1}\n",
518+
"$$\n",
519+
"5 - Finally, we have :\n",
520+
"$$\n",
521+
"\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T + \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n|n-1}^T +\\frac{\\partial L}{\\partial P_{n|n-1}}^T T_n P_{n|n-1}\n",
522+
"$$"
523+
]
524+
},
525+
{
526+
"cell_type": "code",
527+
"execution_count": 10,
528+
"id": "9a560ed9",
529+
"metadata": {},
530+
"outputs": [],
531+
"source": [
532+
"def grad_T(inp, out, out_grad):\n",
533+
" y, a, P, T, Z, H, Q = inp\n",
534+
" a_hat_grad, P_h_grad, y_grad = out_grad\n",
535+
"\n",
536+
" y_hat = Z.dot(a)\n",
537+
" v = y - y_hat\n",
538+
"\n",
539+
" PZT = P.dot(Z.T)\n",
540+
" F = Z.dot(PZT) + H\n",
541+
" F_inv = pt.linalg.inv(F)\n",
542+
"\n",
543+
" K = PZT.dot(F_inv)\n",
544+
" I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
545+
"\n",
546+
" v = v.dimshuffle(0, 'x')\n",
547+
" a = a.dimshuffle(0, 'x')\n",
548+
" a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n",
549+
"\n",
550+
" a_filtered = a + K.dot(v)\n",
551+
" P_filtered = I_KZ @ P\n",
552+
"\n",
553+
" return a_hat_grad @ a_filtered.T + P_h_grad @ T @ P_filtered.T + P_h_grad.T @ T @ P_filtered"
554+
]
555+
},
502556
{
503557
"cell_type": "markdown",
504558
"id": "bd458dee",
@@ -509,7 +563,7 @@
509563
},
510564
{
511565
"cell_type": "code",
512-
"execution_count": null,
566+
"execution_count": 11,
513567
"id": "afb362e5",
514568
"metadata": {},
515569
"outputs": [],
@@ -566,7 +620,7 @@
566620
},
567621
{
568622
"cell_type": "code",
569-
"execution_count": 259,
623+
"execution_count": 12,
570624
"id": "7cead2c1",
571625
"metadata": {},
572626
"outputs": [],
@@ -576,7 +630,7 @@
576630
"kalman_step_op = OpFromGraph(\n",
577631
" inputs=[y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
578632
" outputs=kalman_step(y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym),\n",
579-
" lop_overrides=[grad_y, grad_a_hat, grad_P_hat, None, None, grad_H, grad_Q],\n",
633+
" lop_overrides=[grad_y, grad_a_hat, grad_P_hat, grad_T, None, grad_H, grad_Q],\n",
580634
" inline=True\n",
581635
")\n",
582636
"\n",
@@ -604,7 +658,7 @@
604658
},
605659
{
606660
"cell_type": "code",
607-
"execution_count": 264,
661+
"execution_count": 13,
608662
"id": "b6eb5d48",
609663
"metadata": {},
610664
"outputs": [],
@@ -665,7 +719,7 @@
665719
},
666720
{
667721
"cell_type": "code",
668-
"execution_count": 253,
722+
"execution_count": 14,
669723
"id": "908946b0",
670724
"metadata": {},
671725
"outputs": [],
@@ -712,7 +766,7 @@
712766
},
713767
{
714768
"cell_type": "code",
715-
"execution_count": null,
769+
"execution_count": 15,
716770
"id": "a85fe92e",
717771
"metadata": {},
718772
"outputs": [],
@@ -769,7 +823,7 @@
769823
},
770824
{
771825
"cell_type": "code",
772-
"execution_count": 254,
826+
"execution_count": 16,
773827
"id": "27a60fb3",
774828
"metadata": {},
775829
"outputs": [],
@@ -779,15 +833,15 @@
779833
},
780834
{
781835
"cell_type": "code",
782-
"execution_count": 257,
836+
"execution_count": 17,
783837
"id": "a413c8e9",
784838
"metadata": {},
785839
"outputs": [
786840
{
787841
"name": "stdout",
788842
"output_type": "stream",
789843
"text": [
790-
"defaultdict(<class 'dict'>, {'exec_time': 0.016296579997288063})\n"
844+
"defaultdict(<class 'dict'>, {'exec_time': 0.017576184973586352})\n"
791845
]
792846
}
793847
],
@@ -797,7 +851,7 @@
797851
},
798852
{
799853
"cell_type": "code",
800-
"execution_count": 260,
854+
"execution_count": 18,
801855
"id": "d35b98d6",
802856
"metadata": {},
803857
"outputs": [],
@@ -807,15 +861,15 @@
807861
},
808862
{
809863
"cell_type": "code",
810-
"execution_count": 261,
864+
"execution_count": 19,
811865
"id": "539c18c2",
812866
"metadata": {},
813867
"outputs": [
814868
{
815869
"name": "stdout",
816870
"output_type": "stream",
817871
"text": [
818-
"defaultdict(<class 'dict'>, {'exec_time': 0.015451419999590144})\n"
872+
"defaultdict(<class 'dict'>, {'exec_time': 0.021262520016171044})\n"
819873
]
820874
}
821875
],
@@ -825,7 +879,7 @@
825879
},
826880
{
827881
"cell_type": "code",
828-
"execution_count": 267,
882+
"execution_count": 20,
829883
"id": "1e633e75",
830884
"metadata": {},
831885
"outputs": [],
@@ -835,15 +889,15 @@
835889
},
836890
{
837891
"cell_type": "code",
838-
"execution_count": 268,
892+
"execution_count": 21,
839893
"id": "7118dfec",
840894
"metadata": {},
841895
"outputs": [
842896
{
843897
"name": "stdout",
844898
"output_type": "stream",
845899
"text": [
846-
"defaultdict(<class 'dict'>, {'Forward pass': 0.00269013000652194, 'Backprop': 0.002321220003068447})\n"
900+
"defaultdict(<class 'dict'>, {'Forward pass': 0.002400995010975749, 'Backprop': 0.0018996200058609247})\n"
847901
]
848902
}
849903
],
@@ -869,7 +923,7 @@
869923
},
870924
{
871925
"cell_type": "code",
872-
"execution_count": 274,
926+
"execution_count": 22,
873927
"id": "fbae0189",
874928
"metadata": {},
875929
"outputs": [],
@@ -905,7 +959,7 @@
905959
},
906960
{
907961
"cell_type": "code",
908-
"execution_count": 278,
962+
"execution_count": 23,
909963
"id": "c3a114b2",
910964
"metadata": {},
911965
"outputs": [
@@ -914,25 +968,25 @@
914968
"output_type": "stream",
915969
"text": [
916970
"Comparison between classic a0 gradient and our custom OpFromGraph : True\n",
917-
"Comparison between classic a0 gradient and our handmaid NumPy backprop : True\n"
971+
"Comparison between classic a0 gradient and our handmade NumPy backprop : True\n"
918972
]
919973
}
920974
],
921975
"source": [
922976
"print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0, grad_a0_op))\n",
923-
"print(\"Comparison between classic a0 gradient and our handmaid NumPy backprop :\", np.allclose(grad_a0, grad_a0_np))"
977+
"print(\"Comparison between classic a0 gradient and our handmade NumPy backprop :\", np.allclose(grad_a0, grad_a0_np))"
924978
]
925979
},
926980
{
927981
"cell_type": "code",
928-
"execution_count": 279,
982+
"execution_count": 24,
929983
"id": "867d5e2f",
930984
"metadata": {},
931985
"outputs": [],
932986
"source": [
933987
"# First the classic way with autodiff\n",
934988
"\n",
935-
"grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, H_sym, Q_sym])\n",
989+
"grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, T_sym, H_sym, Q_sym])\n",
936990
"f_grad = pytensor.function(\n",
937991
" inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
938992
" outputs=grad_list,\n",
@@ -942,7 +996,7 @@
942996
"\n",
943997
"# Now using our OpFromGraph custom gradient\n",
944998
"\n",
945-
"grad_list_op = pt.grad(loss_op, [data_sym, a0_sym, P0_sym, H_sym, Q_sym])\n",
999+
"grad_list_op = pt.grad(loss_op, [data_sym, a0_sym, P0_sym, T_sym, H_sym, Q_sym])\n",
9461000
"f_grad = pytensor.function(\n",
9471001
" inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
9481002
" outputs=grad_list_op,\n",
@@ -953,7 +1007,7 @@
9531007
},
9541008
{
9551009
"cell_type": "code",
956-
"execution_count": 289,
1010+
"execution_count": 25,
9571011
"id": "25f0a57b",
9581012
"metadata": {},
9591013
"outputs": [
@@ -964,6 +1018,7 @@
9641018
"Comparison between classic y gradient and our custom OpFromGraph : True\n",
9651019
"Comparison between classic a0 gradient and our custom OpFromGraph : True\n",
9661020
"Comparison between classic P0 gradient and our custom OpFromGraph : True\n",
1021+
"Comparison between classic T gradient and our custom OpFromGraph : True\n",
9671022
"Comparison between classic H gradient and our custom OpFromGraph : True\n",
9681023
"Comparison between classic Q gradient and our custom OpFromGraph : True\n"
9691024
]
@@ -973,8 +1028,9 @@
9731028
"print(\"Comparison between classic y gradient and our custom OpFromGraph :\", np.allclose(grad_a0[0], grad_a0_op[0]))\n",
9741029
"print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0[1], grad_a0_op[1]))\n",
9751030
"print(\"Comparison between classic P0 gradient and our custom OpFromGraph :\", np.allclose((grad_a0[2] + grad_a0[2].T)/2, grad_a0_op[2]))\n",
976-
"print(\"Comparison between classic H gradient and our custom OpFromGraph :\", np.allclose(grad_a0[3], grad_a0_op[3]))\n",
977-
"print(\"Comparison between classic Q gradient and our custom OpFromGraph :\", np.allclose((grad_a0[4] + grad_a0[4].T)/2, grad_a0_op[4]))"
1031+
"print(\"Comparison between classic T gradient and our custom OpFromGraph :\", np.allclose(grad_a0[3], grad_a0_op[3]))\n",
1032+
"print(\"Comparison between classic H gradient and our custom OpFromGraph :\", np.allclose(grad_a0[4], grad_a0_op[4]))\n",
1033+
"print(\"Comparison between classic Q gradient and our custom OpFromGraph :\", np.allclose((grad_a0[5] + grad_a0[5].T)/2, grad_a0_op[5]))"
9781034
]
9791035
}
9801036
],

0 commit comments

Comments
 (0)