# Single simulated run of `exper` with a 2-dimensional weight vector.
# Values whose meaning is not visible in this file are marked as assumptions.
N = 2  # length of w0 and size of the identity matrices below
# Scaled identity 2**-9 * I; passed both as G and as the initial matrix P0
# (presumably the state-noise covariance -- confirm against `exper`).
G = 1 / 2**9 * jnp.identity(N, dtype=jnp.float32)
# Scaled identity 10 * I (its role is defined by `exper` -- TODO confirm).
Sigma = 10 * jnp.identity(N, dtype=jnp.float32)
# Initial weights: all zeros (same values as the former `0 * jnp.ones(...)`).
w0 = jnp.zeros((N,), dtype=jnp.float32)
propy1 = 0.5  # presumably the probability of drawing y = 1 -- confirm

# `exper` returns the inputs X, labels Y and true weights W, followed by the
# per-filter results plotted below (estimates Wtt_*, matrices Ptt_*; VA and EM
# also return an Xit_* value that the visible code does not use).
(
    X,
    Y,
    W,
    (Wtt_EKF, Ptt_EKF),
    (Wtt_VA, Ptt_VA, Xit_VA),
    (Wtt_EM, Ptt_EM, Xit_EM),
) = exper(
    key=jrd.PRNGKey(422),  # other seeds noted alongside this one: 822, 522
    N=N,
    T=500,  # presumably the number of time steps -- confirm
    G=G,
    w0=w0,
    Sigma=Sigma,
    P0=G,
    propy1=propy1,
)

# --- Experiment ---
# Test
# Trajectories of the true weights, one curve per component of W.
for _col, _lbl in enumerate((r"$(w_0)_t$", r"$(w_1)_t$")):
    plt.plot(W[:, _col], label=_lbl)
plt.legend()
# Scatter a window of inputs for each label value, positives first.
# The [100:300] window is applied AFTER masking by label, so it selects
# positions 100-299 within each class subset (not the same time steps).
_scatter_groups = (
    (1, r"$\mathbf{x}_t\ (y_t=1)$"),
    (0, r"$\mathbf{x}_t\ (y_t=0)$"),
)
for _cls, _lbl in _scatter_groups:
    _pts = X[Y == _cls][100:300]
    plt.scatter(_pts[:, 0], _pts[:, 1], label=_lbl)
plt.legend()
# Component 0: true weight vs. the three filtered estimates, with distinct
# line styles and axis labels.
# NOTE(review): this figure labels the Wtt_EKF / Wtt_VA / Wtt_EM series as
# DLR / FU / FP, while the other figures use EKF / VA / EM -- confirm the
# intended naming.
plt.plot(W[:, 0], label=r"$(w_0)_t$ (True Weights)")
_w0_series = (
    (Wtt_EKF, "-", r"$(\hat{w}_0)_t$, (DLR)"),
    (Wtt_VA, "--", r"$(\hat{w}_0)_t$, (FU)"),
    (Wtt_EM, ":", r"$(\hat{w}_0)_t$, (FP)"),
)
for _est, _fmt, _lbl in _w0_series:
    plt.plot(_est[:, 0], _fmt, label=_lbl)
plt.xlabel("time")
plt.ylabel("weights")
plt.legend()
# Component 1: true weight vs. each filter's estimate.
plt.plot(W[:, 1], label=r"$(w_1)_t$")
for _est, _name in ((Wtt_EKF, "EKF"), (Wtt_VA, "VA"), (Wtt_EM, "EM")):
    # rf-string: {{w}} renders as {w}; ({_name}) becomes e.g. (EKF).
    plt.plot(_est[:, 1], label=rf"$(\hat{{w}}_1)_t$, ({_name})")
plt.legend()
# Entry [0, 0] of each filter's Ptt_* matrix over time (labelled with the
# 1-based subscript P_{1,1}); presumably the filtered variance -- confirm.
for _cov, _name in ((Ptt_EKF, "EKF"), (Ptt_VA, "VA"), (Ptt_EM, "EM")):
    plt.plot(_cov[:, 0, 0], label=rf"$(P_{{1,1}})_t$, ({_name})")
plt.legend()
# Entry [1, 1] of each filter's Ptt_* matrix over time (labelled with the
# 1-based subscript P_{2,2}); presumably the filtered variance -- confirm.
_p22_series = (
    (Ptt_EKF, r"$(P_{2,2})_t$, (EKF)"),
    (Ptt_VA, r"$(P_{2,2})_t$, (VA)"),
    (Ptt_EM, r"$(P_{2,2})_t$, (EM)"),
)
for _cov, _lbl in _p22_series:
    plt.plot(_cov[:, 1, 1], label=_lbl)
plt.legend()
# Per-step model output p_t = simple.losi(w_t . x_t): the row-wise dot product
# of weights and inputs passed through `simple.losi` (presumably the logistic
# sigmoid -- confirm). Plotted for the true weights and for each estimate.
# Legend labels fixed: they previously read (w_0)_t / (\hat{w}_0)_t, which
# described a weight component rather than the plotted quantity.
plt.plot(simple.losi(jnp.sum(W * X, axis=1)), "o", label=r"$p_t$ (true)")
for _est, _name in ((Wtt_EKF, "EKF"), (Wtt_VA, "VA"), (Wtt_EM, "EM")):
    plt.plot(
        simple.losi(jnp.sum(_est * X, axis=1)),
        "o",
        label=rf"$\hat{{p}}_t$, ({_name})",
    )
plt.legend()
# Difference between each filter's model output and the true-weight output.
# `true_line` = simple.losi(w_t . x_t) under the true weights; it stays a
# module-level name because the squared-error printout reuses it.
# Legend labels fixed: they previously read (\hat{w}_0)_t, but the plotted
# quantity is the output difference \hat{p}_t - p_t.
true_line = simple.losi(jnp.sum(W * X, axis=1))
for _est, _name in ((Wtt_EKF, "EKF"), (Wtt_VA, "VA"), (Wtt_EM, "EM")):
    plt.plot(
        simple.losi(jnp.sum(_est * X, axis=1)) - true_line,
        label=rf"$\hat{{p}}_t - p_t$, ({_name})",
    )
plt.legend()
# Sum over time of squared differences between each filter's model output and
# `true_line` (the true-weight output), printed in the order EKF, VA, EM.
for _est in (Wtt_EKF, Wtt_VA, Wtt_EM):
    print(jnp.sum((simple.losi(jnp.sum(_est * X, axis=1)) - true_line) ** 2))
# Recorded notebook output (EKF, VA, EM):
# 7.2249694
# 7.6888556
# 7.6575623
import pandas as pd

# Per-configuration error summary for the three filters (tab-separated file).
df1 = pd.read_csv("data.csv", sep="\t")
df1
# Recorded notebook output of `df1` (the first column is the unnamed index):
# | | N | G | Sigma | propy1 | err_EKF_sc | err_VA_sc | err_EM_sc |
# |---|---|---|---|---|---|---|---|
# | 2 | 2.0 | 0.062500 | 0.50 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
# | 4 | 2.0 | 0.062500 | 1.00 | 0.5 | 0.093443 | 0.091697 | 0.091294 |
# | 6 | 2.0 | 0.062500 | 2.00 | 0.5 | 0.093836 | 0.089228 | 0.088805 |
# | 8 | 2.0 | 0.062500 | 4.00 | 0.5 | 0.094327 | 0.086147 | 0.085870 |
# | 10 | 2.0 | 0.062500 | 8.00 | 0.5 | 0.096139 | 0.082455 | 0.082723 |
# | 18 | 2.0 | 0.015625 | 2.00 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
# | 20 | 2.0 | 0.015625 | 4.00 | 0.5 | 0.093443 | 0.091697 | 0.091294 |
# | 22 | 2.0 | 0.015625 | 8.00 | 0.5 | 0.093836 | 0.089228 | 0.088805 |
# | 34 | 2.0 | 0.003906 | 8.00 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
# | 48 | 2.0 | 0.000244 | 0.25 | 0.5 | 0.039164 | 0.039163 | 0.039163 |
# | 60 | 4.0 | 0.062500 | 0.25 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
# | 62 | 4.0 | 0.062500 | 0.50 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
# | 64 | 4.0 | 0.062500 | 1.00 | 0.5 | 0.085952 | 0.080761 | 0.080898 |
# | 66 | 4.0 | 0.062500 | 2.00 | 0.5 | 0.078222 | 0.070136 | 0.070658 |
# | 68 | 4.0 | 0.062500 | 4.00 | 0.5 | 0.072354 | 0.060668 | 0.061745 |
# | 70 | 4.0 | 0.062500 | 8.00 | 0.5 | 0.064038 | 0.051949 | 0.053414 |
# | 76 | 4.0 | 0.015625 | 1.00 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
# | 78 | 4.0 | 0.015625 | 2.00 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
# | 80 | 4.0 | 0.015625 | 4.00 | 0.5 | 0.085952 | 0.080761 | 0.080898 |
# | 82 | 4.0 | 0.015625 | 8.00 | 0.5 | 0.078222 | 0.070136 | 0.070658 |
# | 92 | 4.0 | 0.003906 | 4.00 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
# | 94 | 4.0 | 0.003906 | 8.00 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
# | 120 | 8.0 | 0.062500 | 0.25 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
# | 122 | 8.0 | 0.062500 | 0.50 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
# | 124 | 8.0 | 0.062500 | 1.00 | 0.5 | 0.071028 | 0.062117 | 0.062546 |
# | 126 | 8.0 | 0.062500 | 2.00 | 0.5 | 0.059638 | 0.050017 | 0.050485 |
# | 128 | 8.0 | 0.062500 | 4.00 | 0.5 | 0.048867 | 0.039920 | 0.040344 |
# | 130 | 8.0 | 0.062500 | 8.00 | 0.5 | 0.038852 | 0.031403 | 0.031772 |
# | 132 | 8.0 | 0.015625 | 0.25 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
# | 134 | 8.0 | 0.015625 | 0.50 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
# | 136 | 8.0 | 0.015625 | 1.00 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
# | 138 | 8.0 | 0.015625 | 2.00 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
# | 140 | 8.0 | 0.015625 | 4.00 | 0.5 | 0.071028 | 0.062117 | 0.062546 |
# | 142 | 8.0 | 0.015625 | 8.00 | 0.5 | 0.059638 | 0.050017 | 0.050485 |
# | 146 | 8.0 | 0.003906 | 0.50 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
# | 148 | 8.0 | 0.003906 | 1.00 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
# | 150 | 8.0 | 0.003906 | 2.00 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
# | 152 | 8.0 | 0.003906 | 4.00 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
# | 154 | 8.0 | 0.003906 | 8.00 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
# | 156 | 8.0 | 0.000977 | 0.25 | 0.5 | 0.104002 | 0.103998 | 0.103986 |
# | 162 | 8.0 | 0.000977 | 2.00 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
# | 164 | 8.0 | 0.000977 | 4.00 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
# | 166 | 8.0 | 0.000977 | 8.00 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
# | 168 | 8.0 | 0.000244 | 0.25 | 0.5 | 0.076621 | 0.076618 | 0.076616 |
# | 170 | 8.0 | 0.000244 | 0.50 | 0.5 | 0.090635 | 0.090629 | 0.090625 |
# | 172 | 8.0 | 0.000244 | 1.00 | 0.5 | 0.104002 | 0.103998 | 0.103986 |
# | 178 | 8.0 | 0.000244 | 8.00 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
# Per value of N, count the configurations where the EKF error is strictly
# below the VA error.
df1[df1["err_EKF_sc"] < df1["err_VA_sc"]].groupby("N").count()
# Recorded notebook output: an empty table (no configuration satisfies the
# condition); columns as below, index named N.
# | | G | Sigma | propy1 | err_EKF_sc | err_VA_sc | err_EM_sc |
# |---|---|---|---|---|---|---|
# | N | | | | | | |
# Row-wise check of whether the EKF error exceeds the VA error.
df1["err_EKF_sc"] > df1["err_VA_sc"]
# Recorded notebook output (index, value) -- True for every row:
# 2 True
# 4 True
# 6 True
# 8 True
# 10 True
# 18 True
# 20 True
# 22 True
# 34 True
# 48 True
# 60 True
# 62 True
# 64 True
# 66 True
# 68 True
# 70 True
# 76 True
# 78 True
# 80 True
# 82 True
# 92 True
# 94 True
# 120 True
# 122 True
# 124 True
# 126 True
# 128 True
# 130 True
# 132 True
# 134 True
# 136 True
# 138 True
# 140 True
# 142 True
# 146 True
# 148 True
# 150 True
# 152 True
# 154 True
# 156 True
# 162 True
# 164 True
# 166 True
# 168 True
# 170 True
# 172 True
# 178 True
# dtype: bool