実験

Test
N=2
G = 1/2**9 * jnp.identity(N, dtype=jnp.float32)
Sigma = 10* jnp.identity(N, dtype=jnp.float32)
w0 = 0*jnp.ones((N,), dtype=jnp.float32)
propy1 = 0.5

X, Y, W, \
  (Wtt_EKF, Ptt_EKF), \
    (Wtt_VA, Ptt_VA, Xit_VA), \
      (Wtt_EM, Ptt_EM, Xit_EM) \
        = exper(
  key=jrd.PRNGKey(422), # 822 522
  N=N, 
  T=500, 
  G=G,
  w0=w0,
  Sigma=Sigma,
  P0=G,
  propy1=propy1)
plt.plot(W[:,0], label=r"$(w_0)_t$")
plt.plot(W[:,1], label=r"$(w_1)_t$")
plt.legend()

plt.scatter(X[Y==1][100:300,0], X[Y==1][100:300,1], label=r"$\mathbf{x}_t\ (y_t=1)$")
plt.scatter(X[Y==0][100:300,0], X[Y==0][100:300,1], label=r"$\mathbf{x}_t\ (y_t=0)$")
plt.legend()

plt.plot(W[:,0], label=r"$(w_0)_t$ (True Weights)")
plt.plot(Wtt_EKF[:,0], "-", label=r"$(\hat{w}_0)_t$, (DLR)")
plt.plot(Wtt_VA[:,0], "--", label=r"$(\hat{w}_0)_t$, (FU)")
plt.plot(Wtt_EM[:,0], ":", label=r"$(\hat{w}_0)_t$, (FP)")
plt.xlabel("time")
plt.ylabel("weights")
plt.legend()

plt.plot(W[:,1], label=r"$(w_1)_t$")
plt.plot(Wtt_EKF[:,1], label=r"$(\hat{w}_1)_t$, (EKF)")
plt.plot(Wtt_VA[:,1], label=r"$(\hat{w}_1)_t$, (VA)")
plt.plot(Wtt_EM[:,1], label=r"$(\hat{w}_1)_t$, (EM)")
plt.legend()

plt.plot(Ptt_EKF[:,0,0], label=r"$(P_{1,1})_t$, (EKF)")
plt.plot(Ptt_VA[:,0,0], label=r"$(P_{1,1})_t$, (VA)")
plt.plot(Ptt_EM[:,0,0], label=r"$(P_{1,1})_t$, (EM)")
plt.legend()

plt.plot(Ptt_EKF[:,1,1], label=r"$(P_{2,2})_t$, (EKF)")
plt.plot(Ptt_VA[:,1,1], label=r"$(P_{2,2})_t$, (VA)")
plt.plot(Ptt_EM[:,1,1], label=r"$(P_{2,2})_t$, (EM)")
plt.legend()

plt.plot(simple.losi(jnp.sum(W*X, axis=1)), 'o', label=r"$(w_0)_t$")
plt.plot(simple.losi(jnp.sum(Wtt_EKF*X, axis=1)), 'o', label=r"$(\hat{w}_0)_t$, (EKF)")
plt.plot(simple.losi(jnp.sum(Wtt_VA*X, axis=1)), 'o', label=r"$(\hat{w}_0)_t$, (VA)")
plt.plot(simple.losi(jnp.sum(Wtt_EM*X, axis=1)), 'o', label=r"$(\hat{w}_0)_t$, (EM)")
plt.legend()

true_line = simple.losi(jnp.sum(W*X, axis=1))
plt.plot(simple.losi(jnp.sum(Wtt_EKF*X, axis=1)) - true_line, label=r"$(\hat{w}_0)_t$, (EKF)")
plt.plot(simple.losi(jnp.sum(Wtt_VA*X, axis=1)) - true_line, label=r"$(\hat{w}_0)_t$, (VA)")
plt.plot(simple.losi(jnp.sum(Wtt_EM*X, axis=1)) - true_line, label=r"$(\hat{w}_0)_t$, (EM)")
plt.legend()

print(jnp.sum((simple.losi(jnp.sum(Wtt_EKF*X, axis=1)) - true_line)**2))
print(jnp.sum((simple.losi(jnp.sum(Wtt_VA*X, axis=1)) - true_line)**2))
print(jnp.sum((simple.losi(jnp.sum(Wtt_EM*X, axis=1)) - true_line)**2))
7.2249694
7.6888556
7.6575623
import pandas as pd
df1 = pd.read_csv("data.csv", sep="\t")
df1
N G Sigma propy1 err_EKF_sc err_VA_sc err_EM_sc
2 2.0 0.062500 0.50 0.5 0.093420 0.093336 0.093043
4 2.0 0.062500 1.00 0.5 0.093443 0.091697 0.091294
6 2.0 0.062500 2.00 0.5 0.093836 0.089228 0.088805
8 2.0 0.062500 4.00 0.5 0.094327 0.086147 0.085870
10 2.0 0.062500 8.00 0.5 0.096139 0.082455 0.082723
18 2.0 0.015625 2.00 0.5 0.093420 0.093336 0.093043
20 2.0 0.015625 4.00 0.5 0.093443 0.091697 0.091294
22 2.0 0.015625 8.00 0.5 0.093836 0.089228 0.088805
34 2.0 0.003906 8.00 0.5 0.093420 0.093336 0.093043
48 2.0 0.000244 0.25 0.5 0.039164 0.039163 0.039163
60 4.0 0.062500 0.25 0.5 0.100617 0.099913 0.099700
62 4.0 0.062500 0.50 0.5 0.093656 0.091329 0.091214
64 4.0 0.062500 1.00 0.5 0.085952 0.080761 0.080898
66 4.0 0.062500 2.00 0.5 0.078222 0.070136 0.070658
68 4.0 0.062500 4.00 0.5 0.072354 0.060668 0.061745
70 4.0 0.062500 8.00 0.5 0.064038 0.051949 0.053414
76 4.0 0.015625 1.00 0.5 0.100617 0.099913 0.099700
78 4.0 0.015625 2.00 0.5 0.093656 0.091329 0.091214
80 4.0 0.015625 4.00 0.5 0.085952 0.080761 0.080898
82 4.0 0.015625 8.00 0.5 0.078222 0.070136 0.070658
92 4.0 0.003906 4.00 0.5 0.100617 0.099913 0.099700
94 4.0 0.003906 8.00 0.5 0.093656 0.091329 0.091214
120 8.0 0.062500 0.25 0.5 0.096732 0.091938 0.092006
122 8.0 0.062500 0.50 0.5 0.083484 0.076024 0.076319
124 8.0 0.062500 1.00 0.5 0.071028 0.062117 0.062546
126 8.0 0.062500 2.00 0.5 0.059638 0.050017 0.050485
128 8.0 0.062500 4.00 0.5 0.048867 0.039920 0.040344
130 8.0 0.062500 8.00 0.5 0.038852 0.031403 0.031772
132 8.0 0.015625 0.25 0.5 0.120174 0.119479 0.119283
134 8.0 0.015625 0.50 0.5 0.109672 0.107326 0.107210
136 8.0 0.015625 1.00 0.5 0.096732 0.091938 0.092006
138 8.0 0.015625 2.00 0.5 0.083484 0.076024 0.076319
140 8.0 0.015625 4.00 0.5 0.071028 0.062117 0.062546
142 8.0 0.015625 8.00 0.5 0.059638 0.050017 0.050485
146 8.0 0.003906 0.50 0.5 0.125066 0.124960 0.124792
148 8.0 0.003906 1.00 0.5 0.120174 0.119479 0.119283
150 8.0 0.003906 2.00 0.5 0.109672 0.107326 0.107210
152 8.0 0.003906 4.00 0.5 0.096732 0.091938 0.092006
154 8.0 0.003906 8.00 0.5 0.083484 0.076024 0.076319
156 8.0 0.000977 0.25 0.5 0.104002 0.103998 0.103986
162 8.0 0.000977 2.00 0.5 0.125066 0.124960 0.124792
164 8.0 0.000977 4.00 0.5 0.120174 0.119479 0.119283
166 8.0 0.000977 8.00 0.5 0.109672 0.107326 0.107210
168 8.0 0.000244 0.25 0.5 0.076621 0.076618 0.076616
170 8.0 0.000244 0.50 0.5 0.090635 0.090629 0.090625
172 8.0 0.000244 1.00 0.5 0.104002 0.103998 0.103986
178 8.0 0.000244 8.00 0.5 0.125066 0.124960 0.124792
df1[(df1["err_EKF_sc"] < df1["err_VA_sc"])].groupby("N").count()
G Sigma propy1 err_EKF_sc err_VA_sc err_EM_sc
N
df1["err_EKF_sc"] > df1["err_VA_sc"]
2      True
4      True
6      True
8      True
10     True
18     True
20     True
22     True
34     True
48     True
60     True
62     True
64     True
66     True
68     True
70     True
76     True
78     True
80     True
82     True
92     True
94     True
120    True
122    True
124    True
126    True
128    True
130    True
132    True
134    True
136    True
138    True
140    True
142    True
146    True
148    True
150    True
152    True
154    True
156    True
162    True
164    True
166    True
168    True
170    True
172    True
178    True
dtype: bool