# Experiment configuration passed to `exper` in the next cell.
N = 2  # weight dimension; later cells index W[:, 0], W[:, 1] and Ptt[:, 1, 1]
# Passed to `exper` both as G and as P0 (initial covariance).
G = 1/2**9 * jnp.identity(N, dtype=jnp.float32)
Sigma = 10 * jnp.identity(N, dtype=jnp.float32)
w0 = 0 * jnp.ones((N,), dtype=jnp.float32)  # initial weight vector (all zeros)
# NOTE(review): presumably P(y_t = 1); the results table has a matching
# `propy1` column -- confirm against `exper`.
propy1 = 0.5

# Run the simulation and the three filters (EKF, VA, EM).
# Outputs used below: X (inputs, scattered by label Y), W (true weights),
# Wtt_* (per-filter weight estimates) and Ptt_* (per-filter covariances,
# indexed as [t, i, j]). The Xit_* outputs are not used in the cells shown.
X, Y, W, \
(Wtt_EKF, Ptt_EKF), \
(Wtt_VA, Ptt_VA, Xit_VA), \
(Wtt_EM, Ptt_EM, Xit_EM) = exper(
    key=jrd.PRNGKey(422),  # other seeds tried: 822, 522
    N=N,
    T=500,
    G=G,
    w0=w0,
    Sigma=Sigma,
    P0=G,
    propy1=propy1,
)
実験
Test
# Trajectories of both components of the true weight vector.
plt.plot(W[:,0], label=r"$(w_0)_t$")
plt.plot(W[:,1], label=r"$(w_1)_t$")
plt.legend()
# Scatter the inputs by class. The [100:300] slice is applied AFTER the
# boolean mask, so it selects the 100th..299th samples of each class,
# not a common time window.
plt.scatter(X[Y==1][100:300,0], X[Y==1][100:300,1], label=r"$\mathbf{x}_t\ (y_t=1)$")
plt.scatter(X[Y==0][100:300,0], X[Y==0][100:300,1], label=r"$\mathbf{x}_t\ (y_t=0)$")
plt.legend()
# First weight component: ground truth vs. the three filter estimates.
# NOTE(review): these legend names (DLR/FU/FP) differ from the EKF/VA/EM
# names used in the other cells -- confirm which naming is intended.
plt.plot(W[:,0], label=r"$(w_0)_t$ (True Weights)")
plt.plot(Wtt_EKF[:,0], "-", label=r"$(\hat{w}_0)_t$, (DLR)")
plt.plot(Wtt_VA[:,0], "--", label=r"$(\hat{w}_0)_t$, (FU)")
plt.plot(Wtt_EM[:,0], ":", label=r"$(\hat{w}_0)_t$, (FP)")
plt.xlabel("time")
plt.ylabel("weights")
plt.legend()
# Second weight component: ground truth vs. the three filter estimates.
plt.plot(W[:,1], label=r"$(w_1)_t$")
plt.plot(Wtt_EKF[:,1], label=r"$(\hat{w}_1)_t$, (EKF)")
plt.plot(Wtt_VA[:,1], label=r"$(\hat{w}_1)_t$, (VA)")
plt.plot(Wtt_EM[:,1], label=r"$(\hat{w}_1)_t$, (EM)")
plt.legend()
# (0, 0) entry of each filter's covariance over time.
plt.plot(Ptt_EKF[:,0,0], label=r"$(P_{1,1})_t$, (EKF)")
plt.plot(Ptt_VA[:,0,0], label=r"$(P_{1,1})_t$, (VA)")
plt.plot(Ptt_EM[:,0,0], label=r"$(P_{1,1})_t$, (EM)")
plt.legend()
# (1, 1) entry of each filter's covariance over time.
plt.plot(Ptt_EKF[:,1,1], label=r"$(P_{2,2})_t$, (EKF)")
plt.plot(Ptt_VA[:,1,1], label=r"$(P_{2,2})_t$, (VA)")
plt.plot(Ptt_EM[:,1,1], label=r"$(P_{2,2})_t$, (EM)")
plt.legend()
# Per-step value of simple.losi(<w_t, x_t>) under the true weights and under
# each filter's estimate.
# NOTE(review): simple.losi is presumably the logistic sigmoid -- confirm.
# NOTE(review): the legend labels read (w_0)_t / (\hat{w}_0)_t even though
# the plotted series are losi(w . x), not weight components.
plt.plot(simple.losi(jnp.sum(W*X, axis=1)), 'o', label=r"$(w_0)_t$")
plt.plot(simple.losi(jnp.sum(Wtt_EKF*X, axis=1)), 'o', label=r"$(\hat{w}_0)_t$, (EKF)")
plt.plot(simple.losi(jnp.sum(Wtt_VA*X, axis=1)), 'o', label=r"$(\hat{w}_0)_t$, (VA)")
plt.plot(simple.losi(jnp.sum(Wtt_EM*X, axis=1)), 'o', label=r"$(\hat{w}_0)_t$, (EM)")
plt.legend()
# Reference curve: simple.losi(<w_t, x_t>) under the true weights.
# Also used by the SSE printout cell.
true_line = simple.losi(jnp.sum(W*X, axis=1))
# Deviation of each filter's curve from the reference.
plt.plot(simple.losi(jnp.sum(Wtt_EKF*X, axis=1)) - true_line, label=r"$(\hat{w}_0)_t$, (EKF)")
plt.plot(simple.losi(jnp.sum(Wtt_VA*X, axis=1)) - true_line, label=r"$(\hat{w}_0)_t$, (VA)")
plt.plot(simple.losi(jnp.sum(Wtt_EM*X, axis=1)) - true_line, label=r"$(\hat{w}_0)_t$, (EM)")
plt.legend()
# Sum of squared deviations from `true_line`, one line per filter,
# in the order EKF, VA, EM.
print(
    *(
        jnp.sum((simple.losi(jnp.sum(w_hat*X, axis=1)) - true_line)**2)
        for w_hat in (Wtt_EKF, Wtt_VA, Wtt_EM)
    ),
    sep="\n",
)
7.2249694
7.6888556
7.6575623
import pandas as pd

# Load the tab-separated results file and display it.
df1 = pd.read_csv("data.csv", sep="\t")
df1
 | N | G | Sigma | propy1 | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|---|---|
2 | 2.0 | 0.062500 | 0.50 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
4 | 2.0 | 0.062500 | 1.00 | 0.5 | 0.093443 | 0.091697 | 0.091294 |
6 | 2.0 | 0.062500 | 2.00 | 0.5 | 0.093836 | 0.089228 | 0.088805 |
8 | 2.0 | 0.062500 | 4.00 | 0.5 | 0.094327 | 0.086147 | 0.085870 |
10 | 2.0 | 0.062500 | 8.00 | 0.5 | 0.096139 | 0.082455 | 0.082723 |
18 | 2.0 | 0.015625 | 2.00 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
20 | 2.0 | 0.015625 | 4.00 | 0.5 | 0.093443 | 0.091697 | 0.091294 |
22 | 2.0 | 0.015625 | 8.00 | 0.5 | 0.093836 | 0.089228 | 0.088805 |
34 | 2.0 | 0.003906 | 8.00 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
48 | 2.0 | 0.000244 | 0.25 | 0.5 | 0.039164 | 0.039163 | 0.039163 |
60 | 4.0 | 0.062500 | 0.25 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
62 | 4.0 | 0.062500 | 0.50 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
64 | 4.0 | 0.062500 | 1.00 | 0.5 | 0.085952 | 0.080761 | 0.080898 |
66 | 4.0 | 0.062500 | 2.00 | 0.5 | 0.078222 | 0.070136 | 0.070658 |
68 | 4.0 | 0.062500 | 4.00 | 0.5 | 0.072354 | 0.060668 | 0.061745 |
70 | 4.0 | 0.062500 | 8.00 | 0.5 | 0.064038 | 0.051949 | 0.053414 |
76 | 4.0 | 0.015625 | 1.00 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
78 | 4.0 | 0.015625 | 2.00 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
80 | 4.0 | 0.015625 | 4.00 | 0.5 | 0.085952 | 0.080761 | 0.080898 |
82 | 4.0 | 0.015625 | 8.00 | 0.5 | 0.078222 | 0.070136 | 0.070658 |
92 | 4.0 | 0.003906 | 4.00 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
94 | 4.0 | 0.003906 | 8.00 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
120 | 8.0 | 0.062500 | 0.25 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
122 | 8.0 | 0.062500 | 0.50 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
124 | 8.0 | 0.062500 | 1.00 | 0.5 | 0.071028 | 0.062117 | 0.062546 |
126 | 8.0 | 0.062500 | 2.00 | 0.5 | 0.059638 | 0.050017 | 0.050485 |
128 | 8.0 | 0.062500 | 4.00 | 0.5 | 0.048867 | 0.039920 | 0.040344 |
130 | 8.0 | 0.062500 | 8.00 | 0.5 | 0.038852 | 0.031403 | 0.031772 |
132 | 8.0 | 0.015625 | 0.25 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
134 | 8.0 | 0.015625 | 0.50 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
136 | 8.0 | 0.015625 | 1.00 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
138 | 8.0 | 0.015625 | 2.00 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
140 | 8.0 | 0.015625 | 4.00 | 0.5 | 0.071028 | 0.062117 | 0.062546 |
142 | 8.0 | 0.015625 | 8.00 | 0.5 | 0.059638 | 0.050017 | 0.050485 |
146 | 8.0 | 0.003906 | 0.50 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
148 | 8.0 | 0.003906 | 1.00 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
150 | 8.0 | 0.003906 | 2.00 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
152 | 8.0 | 0.003906 | 4.00 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
154 | 8.0 | 0.003906 | 8.00 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
156 | 8.0 | 0.000977 | 0.25 | 0.5 | 0.104002 | 0.103998 | 0.103986 |
162 | 8.0 | 0.000977 | 2.00 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
164 | 8.0 | 0.000977 | 4.00 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
166 | 8.0 | 0.000977 | 8.00 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
168 | 8.0 | 0.000244 | 0.25 | 0.5 | 0.076621 | 0.076618 | 0.076616 |
170 | 8.0 | 0.000244 | 0.50 | 0.5 | 0.090635 | 0.090629 | 0.090625 |
172 | 8.0 | 0.000244 | 1.00 | 0.5 | 0.104002 | 0.103998 | 0.103986 |
178 | 8.0 | 0.000244 | 8.00 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
"err_EKF_sc"] < df1["err_VA_sc"])].groupby("N").count() df1[(df1[
 | G | Sigma | propy1 | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|---|
N |
"err_EKF_sc"] > df1["err_VA_sc"] df1[
2 True
4 True
6 True
8 True
10 True
18 True
20 True
22 True
34 True
48 True
60 True
62 True
64 True
66 True
68 True
70 True
76 True
78 True
80 True
82 True
92 True
94 True
120 True
122 True
124 True
126 True
128 True
130 True
132 True
134 True
136 True
138 True
140 True
142 True
146 True
148 True
150 True
152 True
154 True
156 True
162 True
164 True
166 True
168 True
170 True
172 True
178 True
dtype: bool