|
10 | 10 | },
|
11 | 11 | {
|
12 | 12 | "cell_type": "code",
|
13 |
| - "execution_count": 214, |
| 13 | + "execution_count": 1, |
14 | 14 | "id": "90979a41",
|
15 | 15 | "metadata": {},
|
16 | 16 | "outputs": [],
|
|
40 | 40 | },
|
41 | 41 | {
|
42 | 42 | "cell_type": "code",
|
43 |
| - "execution_count": null, |
| 43 | + "execution_count": 2, |
44 | 44 | "id": "fdb156d6",
|
45 | 45 | "metadata": {},
|
46 |
| - "outputs": [ |
47 |
| - { |
48 |
| - "name": "stdout", |
49 |
| - "output_type": "stream", |
50 |
| - "text": [ |
51 |
| - "<function compile_statespace.<locals>.f at 0x000001820DC942C0>\n" |
52 |
| - ] |
53 |
| - } |
54 |
| - ], |
| 46 | + "outputs": [], |
55 | 47 | "source": [
|
56 | 48 | "mod = (\n",
|
57 | 49 | " sts.LevelTrendComponent(order=2, innovations_order=[0, 1], name='level') +\n",
|
|
87 | 79 | },
|
88 | 80 | {
|
89 | 81 | "cell_type": "code",
|
90 |
| - "execution_count": 218, |
| 82 | + "execution_count": 3, |
91 | 83 | "id": "3661408d",
|
92 | 84 | "metadata": {},
|
93 | 85 | "outputs": [],
|
|
140 | 132 | },
|
141 | 133 | {
|
142 | 134 | "cell_type": "code",
|
143 |
| - "execution_count": 219, |
| 135 | + "execution_count": 4, |
144 | 136 | "id": "35351096",
|
145 | 137 | "metadata": {},
|
146 | 138 | "outputs": [],
|
|
268 | 260 | },
|
269 | 261 | {
|
270 | 262 | "cell_type": "code",
|
271 |
| - "execution_count": 132, |
| 263 | + "execution_count": 5, |
272 | 264 | "id": "ee21ef4e",
|
273 | 265 | "metadata": {},
|
274 | 266 | "outputs": [],
|
|
330 | 322 | },
|
331 | 323 | {
|
332 | 324 | "cell_type": "code",
|
333 |
| - "execution_count": 133, |
| 325 | + "execution_count": 6, |
334 | 326 | "id": "8c89b018",
|
335 | 327 | "metadata": {},
|
336 | 328 | "outputs": [],
|
|
387 | 379 | },
|
388 | 380 | {
|
389 | 381 | "cell_type": "code",
|
390 |
| - "execution_count": 134, |
| 382 | + "execution_count": 7, |
391 | 383 | "id": "bba53a26",
|
392 | 384 | "metadata": {},
|
393 | 385 | "outputs": [],
|
|
426 | 418 | },
|
427 | 419 | {
|
428 | 420 | "cell_type": "code",
|
429 |
| - "execution_count": 11, |
| 421 | + "execution_count": 8, |
430 | 422 | "id": "c17949b7",
|
431 | 423 | "metadata": {},
|
432 | 424 | "outputs": [],
|
|
441 | 433 | "id": "f0bc0287",
|
442 | 434 | "metadata": {},
|
443 | 435 | "source": [
|
444 |
| - "## Gradient with respect to **H**\n", |
| 436 | + "### Gradient with respect to **H**\n", |
445 | 437 | "\n",
|
446 | 438 | "From the article we have :\n",
|
447 | 439 | "\n",
|
|
472 | 464 | },
|
473 | 465 | {
|
474 | 466 | "cell_type": "code",
|
475 |
| - "execution_count": 135, |
| 467 | + "execution_count": 9, |
476 | 468 | "id": "84cb6867",
|
477 | 469 | "metadata": {},
|
478 | 470 | "outputs": [],
|
|
499 | 491 | " return K.T @ P_filtered_grad @ K - 0.5 * K.T @ a_filtered_grad @ v.T @ F_inv - 0.5 * F_inv @ v @ a_filtered_grad.T @ K + F_inv - F_inv @ v @ v.T @ F_inv"
|
500 | 492 | ]
|
501 | 493 | },
|
| 494 | + { |
| 495 | + "cell_type": "markdown", |
| 496 | + "id": "4fa2ffc0", |
| 497 | + "metadata": {}, |
| 498 | + "source": [ |
| 499 | + "### Gradient with respect to **T**\n", |
| 500 | + "\n", |
| 501 | + "This gradient was not given in the article. Here are the steps that got me to this expression :\n", |
| 502 | + "\n", |
| 503 | + "1 - Only $x_{n|n-1}$ and $P_{n|n-1}$ depends on $T_n$. Hence :\n", |
| 504 | + "$$\n", |
| 505 | + "\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} + \\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial T}{\\partial P_{n|n-1}}\n", |
| 506 | + "$$\n", |
| 507 | + "2 - Using the equation (11) and (12) of the article, on the (1), we directly got that :\n", |
| 508 | + "$$\n", |
| 509 | + "\\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T\n", |
| 510 | + "$$\n", |
| 511 | + "3 - Recognizing the first quadratic form in the equation (2), and using equation (11) we got :\n", |
| 512 | + "$$\n", |
| 513 | + "\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T^T} = P_{n|n-1}T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}^T + P_{n|n-1}^T T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}\n", |
| 514 | + "$$\n", |
| 515 | + "4 - Now transposing to get the dependencies on T :\n", |
| 516 | + "$$\n", |
| 517 | + "\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n|n-1}^T +\\frac{\\partial L}{\\partial P_{n|n-1}}^T T_n P_{n|n-1}\n", |
| 518 | + "$$\n", |
| 519 | + "5 - Finally, we have :\n", |
| 520 | + "$$\n", |
| 521 | + "\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T + \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n|n-1}^T +\\frac{\\partial L}{\\partial P_{n|n-1}}^T T_n P_{n|n-1}\n", |
| 522 | + "$$" |
| 523 | + ] |
| 524 | + }, |
| 525 | + { |
| 526 | + "cell_type": "code", |
| 527 | + "execution_count": 10, |
| 528 | + "id": "9a560ed9", |
| 529 | + "metadata": {}, |
| 530 | + "outputs": [], |
| 531 | + "source": [ |
| 532 | + "def grad_T(inp, out, out_grad):\n", |
| 533 | + " y, a, P, T, Z, H, Q = inp\n", |
| 534 | + " a_hat_grad, P_h_grad, y_grad = out_grad\n", |
| 535 | + "\n", |
| 536 | + " y_hat = Z.dot(a)\n", |
| 537 | + " v = y - y_hat\n", |
| 538 | + "\n", |
| 539 | + " PZT = P.dot(Z.T)\n", |
| 540 | + " F = Z.dot(PZT) + H\n", |
| 541 | + " F_inv = pt.linalg.inv(F)\n", |
| 542 | + "\n", |
| 543 | + " K = PZT.dot(F_inv)\n", |
| 544 | + " I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n", |
| 545 | + "\n", |
| 546 | + " v = v.dimshuffle(0, 'x')\n", |
| 547 | + " a = a.dimshuffle(0, 'x')\n", |
| 548 | + " a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n", |
| 549 | + "\n", |
| 550 | + " a_filtered = a + K.dot(v)\n", |
| 551 | + " P_filtered = I_KZ @ P\n", |
| 552 | + "\n", |
| 553 | + " return a_hat_grad @ a_filtered.T + P_h_grad @ T @ P_filtered.T + P_h_grad.T @ T @ P_filtered" |
| 554 | + ] |
| 555 | + }, |
502 | 556 | {
|
503 | 557 | "cell_type": "markdown",
|
504 | 558 | "id": "bd458dee",
|
|
509 | 563 | },
|
510 | 564 | {
|
511 | 565 | "cell_type": "code",
|
512 |
| - "execution_count": null, |
| 566 | + "execution_count": 11, |
513 | 567 | "id": "afb362e5",
|
514 | 568 | "metadata": {},
|
515 | 569 | "outputs": [],
|
|
566 | 620 | },
|
567 | 621 | {
|
568 | 622 | "cell_type": "code",
|
569 |
| - "execution_count": 259, |
| 623 | + "execution_count": 12, |
570 | 624 | "id": "7cead2c1",
|
571 | 625 | "metadata": {},
|
572 | 626 | "outputs": [],
|
|
576 | 630 | "kalman_step_op = OpFromGraph(\n",
|
577 | 631 | " inputs=[y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
|
578 | 632 | " outputs=kalman_step(y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym),\n",
|
579 |
| - " lop_overrides=[grad_y, grad_a_hat, grad_P_hat, None, None, grad_H, grad_Q],\n", |
| 633 | + " lop_overrides=[grad_y, grad_a_hat, grad_P_hat, grad_T, None, grad_H, grad_Q],\n", |
580 | 634 | " inline=True\n",
|
581 | 635 | ")\n",
|
582 | 636 | "\n",
|
|
604 | 658 | },
|
605 | 659 | {
|
606 | 660 | "cell_type": "code",
|
607 |
| - "execution_count": 264, |
| 661 | + "execution_count": 13, |
608 | 662 | "id": "b6eb5d48",
|
609 | 663 | "metadata": {},
|
610 | 664 | "outputs": [],
|
|
665 | 719 | },
|
666 | 720 | {
|
667 | 721 | "cell_type": "code",
|
668 |
| - "execution_count": 253, |
| 722 | + "execution_count": 14, |
669 | 723 | "id": "908946b0",
|
670 | 724 | "metadata": {},
|
671 | 725 | "outputs": [],
|
|
712 | 766 | },
|
713 | 767 | {
|
714 | 768 | "cell_type": "code",
|
715 |
| - "execution_count": null, |
| 769 | + "execution_count": 15, |
716 | 770 | "id": "a85fe92e",
|
717 | 771 | "metadata": {},
|
718 | 772 | "outputs": [],
|
|
769 | 823 | },
|
770 | 824 | {
|
771 | 825 | "cell_type": "code",
|
772 |
| - "execution_count": 254, |
| 826 | + "execution_count": 16, |
773 | 827 | "id": "27a60fb3",
|
774 | 828 | "metadata": {},
|
775 | 829 | "outputs": [],
|
|
779 | 833 | },
|
780 | 834 | {
|
781 | 835 | "cell_type": "code",
|
782 |
| - "execution_count": 257, |
| 836 | + "execution_count": 17, |
783 | 837 | "id": "a413c8e9",
|
784 | 838 | "metadata": {},
|
785 | 839 | "outputs": [
|
786 | 840 | {
|
787 | 841 | "name": "stdout",
|
788 | 842 | "output_type": "stream",
|
789 | 843 | "text": [
|
790 |
| - "defaultdict(<class 'dict'>, {'exec_time': 0.016296579997288063})\n" |
| 844 | + "defaultdict(<class 'dict'>, {'exec_time': 0.017576184973586352})\n" |
791 | 845 | ]
|
792 | 846 | }
|
793 | 847 | ],
|
|
797 | 851 | },
|
798 | 852 | {
|
799 | 853 | "cell_type": "code",
|
800 |
| - "execution_count": 260, |
| 854 | + "execution_count": 18, |
801 | 855 | "id": "d35b98d6",
|
802 | 856 | "metadata": {},
|
803 | 857 | "outputs": [],
|
|
807 | 861 | },
|
808 | 862 | {
|
809 | 863 | "cell_type": "code",
|
810 |
| - "execution_count": 261, |
| 864 | + "execution_count": 19, |
811 | 865 | "id": "539c18c2",
|
812 | 866 | "metadata": {},
|
813 | 867 | "outputs": [
|
814 | 868 | {
|
815 | 869 | "name": "stdout",
|
816 | 870 | "output_type": "stream",
|
817 | 871 | "text": [
|
818 |
| - "defaultdict(<class 'dict'>, {'exec_time': 0.015451419999590144})\n" |
| 872 | + "defaultdict(<class 'dict'>, {'exec_time': 0.021262520016171044})\n" |
819 | 873 | ]
|
820 | 874 | }
|
821 | 875 | ],
|
|
825 | 879 | },
|
826 | 880 | {
|
827 | 881 | "cell_type": "code",
|
828 |
| - "execution_count": 267, |
| 882 | + "execution_count": 20, |
829 | 883 | "id": "1e633e75",
|
830 | 884 | "metadata": {},
|
831 | 885 | "outputs": [],
|
|
835 | 889 | },
|
836 | 890 | {
|
837 | 891 | "cell_type": "code",
|
838 |
| - "execution_count": 268, |
| 892 | + "execution_count": 21, |
839 | 893 | "id": "7118dfec",
|
840 | 894 | "metadata": {},
|
841 | 895 | "outputs": [
|
842 | 896 | {
|
843 | 897 | "name": "stdout",
|
844 | 898 | "output_type": "stream",
|
845 | 899 | "text": [
|
846 |
| - "defaultdict(<class 'dict'>, {'Forward pass': 0.00269013000652194, 'Backprop': 0.002321220003068447})\n" |
| 900 | + "defaultdict(<class 'dict'>, {'Forward pass': 0.002400995010975749, 'Backprop': 0.0018996200058609247})\n" |
847 | 901 | ]
|
848 | 902 | }
|
849 | 903 | ],
|
|
869 | 923 | },
|
870 | 924 | {
|
871 | 925 | "cell_type": "code",
|
872 |
| - "execution_count": 274, |
| 926 | + "execution_count": 22, |
873 | 927 | "id": "fbae0189",
|
874 | 928 | "metadata": {},
|
875 | 929 | "outputs": [],
|
|
905 | 959 | },
|
906 | 960 | {
|
907 | 961 | "cell_type": "code",
|
908 |
| - "execution_count": 278, |
| 962 | + "execution_count": 23, |
909 | 963 | "id": "c3a114b2",
|
910 | 964 | "metadata": {},
|
911 | 965 | "outputs": [
|
|
914 | 968 | "output_type": "stream",
|
915 | 969 | "text": [
|
916 | 970 | "Comparison between classic a0 gradient and our custom OpFromGraph : True\n",
|
917 |
| - "Comparison between classic a0 gradient and our handmaid NumPy backprop : True\n" |
| 971 | + "Comparison between classic a0 gradient and our handmade NumPy backprop : True\n" |
918 | 972 | ]
|
919 | 973 | }
|
920 | 974 | ],
|
921 | 975 | "source": [
|
922 | 976 | "print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0, grad_a0_op))\n",
|
923 |
| - "print(\"Comparison between classic a0 gradient and our handmaid NumPy backprop :\", np.allclose(grad_a0, grad_a0_np))" |
| 977 | + "print(\"Comparison between classic a0 gradient and our handmade NumPy backprop :\", np.allclose(grad_a0, grad_a0_np))" |
924 | 978 | ]
|
925 | 979 | },
|
926 | 980 | {
|
927 | 981 | "cell_type": "code",
|
928 |
| - "execution_count": 279, |
| 982 | + "execution_count": 24, |
929 | 983 | "id": "867d5e2f",
|
930 | 984 | "metadata": {},
|
931 | 985 | "outputs": [],
|
932 | 986 | "source": [
|
933 | 987 | "# First the classic way with autodiff\n",
|
934 | 988 | "\n",
|
935 |
| - "grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, H_sym, Q_sym])\n", |
| 989 | + "grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, T_sym, H_sym, Q_sym])\n", |
936 | 990 | "f_grad = pytensor.function(\n",
|
937 | 991 | " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
|
938 | 992 | " outputs=grad_list,\n",
|
|
942 | 996 | "\n",
|
943 | 997 | "# Now using our OpFromGraph custom gradient\n",
|
944 | 998 | "\n",
|
945 |
| - "grad_list_op = pt.grad(loss_op, [data_sym, a0_sym, P0_sym, H_sym, Q_sym])\n", |
| 999 | + "grad_list_op = pt.grad(loss_op, [data_sym, a0_sym, P0_sym, T_sym, H_sym, Q_sym])\n", |
946 | 1000 | "f_grad = pytensor.function(\n",
|
947 | 1001 | " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
|
948 | 1002 | " outputs=grad_list_op,\n",
|
|
953 | 1007 | },
|
954 | 1008 | {
|
955 | 1009 | "cell_type": "code",
|
956 |
| - "execution_count": 289, |
| 1010 | + "execution_count": 25, |
957 | 1011 | "id": "25f0a57b",
|
958 | 1012 | "metadata": {},
|
959 | 1013 | "outputs": [
|
|
964 | 1018 | "Comparison between classic y gradient and our custom OpFromGraph : True\n",
|
965 | 1019 | "Comparison between classic a0 gradient and our custom OpFromGraph : True\n",
|
966 | 1020 | "Comparison between classic P0 gradient and our custom OpFromGraph : True\n",
|
| 1021 | + "Comparison between classic T gradient and our custom OpFromGraph : True\n", |
967 | 1022 | "Comparison between classic H gradient and our custom OpFromGraph : True\n",
|
968 | 1023 | "Comparison between classic Q gradient and our custom OpFromGraph : True\n"
|
969 | 1024 | ]
|
|
973 | 1028 | "print(\"Comparison between classic y gradient and our custom OpFromGraph :\", np.allclose(grad_a0[0], grad_a0_op[0]))\n",
|
974 | 1029 | "print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0[1], grad_a0_op[1]))\n",
|
975 | 1030 | "print(\"Comparison between classic P0 gradient and our custom OpFromGraph :\", np.allclose((grad_a0[2] + grad_a0[2].T)/2, grad_a0_op[2]))\n",
|
976 |
| - "print(\"Comparison between classic H gradient and our custom OpFromGraph :\", np.allclose(grad_a0[3], grad_a0_op[3]))\n", |
977 |
| - "print(\"Comparison between classic Q gradient and our custom OpFromGraph :\", np.allclose((grad_a0[4] + grad_a0[4].T)/2, grad_a0_op[4]))" |
| 1031 | + "print(\"Comparison between classic T gradient and our custom OpFromGraph :\", np.allclose(grad_a0[3], grad_a0_op[3]))\n", |
| 1032 | + "print(\"Comparison between classic H gradient and our custom OpFromGraph :\", np.allclose(grad_a0[4], grad_a0_op[4]))\n", |
| 1033 | + "print(\"Comparison between classic Q gradient and our custom OpFromGraph :\", np.allclose((grad_a0[5] + grad_a0[5].T)/2, grad_a0_op[5]))" |
978 | 1034 | ]
|
979 | 1035 | }
|
980 | 1036 | ],
|
|
0 commit comments