---
skip_showdoc: true
---
key によるバッチ化
Test
# Single experiment: run Comp.RMS once per PRNG key, vectorised with jax.vmap.
# G is the transition matrix Gamma = p I and Sigma the covariance of x_t (q I).
N = 10
T = 1000  # presumably the number of time steps (the T of the evaluation formula) -- TODO confirm
G = 1/2**7 * jnp.identity(N, dtype=jnp.float32)
Sigma = 1.5 * jnp.identity(N, dtype=jnp.float32)
w0 = jnp.zeros((N,), dtype=jnp.float32)  # initial weight: all zeros (was 0 * ones / sqrt(N))
P0 = G
propy1 = 0.5

# One call of Comp.RMS per key; all other arguments are shared across the batch.
batched_exper = jax.vmap(
    lambda key: Comp.RMS(key, N, T, G, w0, Sigma, P0, propy1)
)

master_key = jrd.PRNGKey(0)
keys = jrd.split(master_key, 1000)
W_norms, RMS_EKF, RMS_VA, RMS_EM = batched_exper(keys)
# Scatter each method's RMS against W_norms (flattened over all runs).
# Labels + legend so the three overlaid series can be told apart.
plt.scatter(W_norms.reshape(-1), RMS_EKF.reshape(-1), label="EKF")
plt.scatter(W_norms.reshape(-1), RMS_VA.reshape(-1), marker="x", label="VA")
plt.scatter(W_norms.reshape(-1), RMS_EM.reshape(-1), marker=".", label="EM")
plt.legend()
# Bin definition: edges 0.0, 0.1, ..., 12.0 (120 intervals).
bin_edges = jnp.arange(0.0, 12.1, 0.1)  # e.g. [0.0, 0.1, 0.2, ..., 12.0]
# Label interval [edge_i, edge_{i+1}) by its right edge: 0.1, 0.2, ..., 12.0.
# Deriving the labels from the edges guarantees exactly one label per interval;
# a separate jnp.arange(0.1, 12, 0.1) yields only 119 labels, so bincount with
# length=len(bins) silently dropped the last interval [11.9, 12.0).
bins = bin_edges[1:]

# Bin index of every element: digitize returns i with edges[i-1] <= x < edges[i];
# subtract 1 for a 0-based interval index. Values >= 12.0 map to len(bins) and
# are dropped by jnp.bincount below (negative indices would be clipped to 0).
bin_idx = jnp.digitize(W_norms.ravel(), bin_edges) - 1  # flat, one entry per element of W_norms

# Flatten
ekf_flat = RMS_EKF.ravel()
va_flat = RMS_VA.ravel()
em_flat = RMS_EM.ravel()

# Per-bin sums and counts
sum_ekf = jnp.bincount(bin_idx, weights=ekf_flat, length=len(bins))
sum_va = jnp.bincount(bin_idx, weights=va_flat, length=len(bins))
sum_em = jnp.bincount(bin_idx, weights=em_flat, length=len(bins))
counts = jnp.bincount(bin_idx, length=len(bins))

# Per-bin mean; empty bins become NaN instead of dividing by zero
mean_ekf = jnp.where(counts > 0, sum_ekf / counts, jnp.nan)
mean_va = jnp.where(counts > 0, sum_va / counts, jnp.nan)
mean_em = jnp.where(counts > 0, sum_em / counts, jnp.nan)

# Collect the result: one row per bin
df_source = jnp.stack([bins, mean_ekf, mean_va, mean_em], axis=1)

# Convert to a pandas DataFrame indexed by the bin label
df = pd.DataFrame(jnp.array(df_source), columns=["bin", "EKF_RMS", "VA_RMS", "EM_RMS"]).set_index("bin")
df[0.8:10].head(20)  # first 20 bins whose label lies in [0.8, 10]
bin | EKF_RMS | VA_RMS | EM_RMS |
---|---|---|---|
0.8 | 0.742320 | 0.743473 | 0.743350 |
0.9 | 0.839010 | 0.841791 | 0.841518 |
1.0 | 0.924905 | 0.927862 | 0.927469 |
1.1 | 1.016297 | 1.020453 | 1.019893 |
1.2 | 1.105977 | 1.110465 | 1.109501 |
1.3 | 1.174655 | 1.180493 | 1.179403 |
1.4 | 1.252368 | 1.259346 | 1.257725 |
1.5 | 1.327946 | 1.335311 | 1.333957 |
1.6 | 1.379003 | 1.375262 | 1.373536 |
1.7 | 1.458880 | 1.461715 | 1.460669 |
1.8 | 1.502239 | 1.492858 | 1.491333 |
1.9 | 1.553879 | 1.537835 | 1.536580 |
2.0 | 1.609895 | 1.585574 | 1.583595 |
2.1 | 1.650254 | 1.615584 | 1.614009 |
2.2 | 1.710323 | 1.670204 | 1.669090 |
2.3 | 1.770196 | 1.712925 | 1.711877 |
2.4 | 1.830160 | 1.768064 | 1.767073 |
2.5 | 1.889573 | 1.808915 | 1.808709 |
2.6 | 1.922787 | 1.830989 | 1.830768 |
2.7 | 1.968957 | 1.869247 | 1.869909 |
df[:2.7].plot()  # per-bin mean RMS of each method for bin labels up to 2.7
# Single experiment with Comp.losi_error, one call per PRNG key via jax.vmap.
# G is the transition matrix Gamma = p I and Sigma the covariance of x_t (q I).
N = 10
T = 1000  # presumably the number of time steps -- TODO confirm
G = 1/2**9 * jnp.identity(N, dtype=jnp.float32)
Sigma = 0.5 * jnp.identity(N, dtype=jnp.float32)
w0 = jnp.zeros((N,), dtype=jnp.float32)  # initial weight: all zeros (was 0 * ones / sqrt(N))
P0 = G
propy1 = 0.5

batched_exper = jax.vmap(
    lambda key: Comp.losi_error(key, N, T, G, w0, Sigma, P0, propy1)
)

master_key = jrd.PRNGKey(0)
keys = jrd.split(master_key, 1000)
# NOTE: despite the RMS_* names, these hold the outputs of Comp.losi_error.
RMS_EKF, RMS_VA, RMS_EM = batched_exper(keys)
RMS_EKF.sum(), RMS_VA.sum(), RMS_EM.sum()  # total error of each method over all runs
(Array(1.0641768e+07, dtype=float32),
Array(8.087362e+06, dtype=float32),
Array(8.100394e+06, dtype=float32))
# Parameter grid for the sweep.
Ns = [2, 4, 8]  # dimension N (plain ints: used as matrix sizes)
Gs = jnp.array([1/2**4, 1/2**6, 1/2**8, 1/2**10, 1/2**12])  # p in Gamma = p I
Sigmas = jnp.array([1/2**2, 1/2, 1, 2, 4, 8])  # q in Sigma = q I
propy1s = jnp.array([0.5, 0.1])
import pandas as pd

# Grid sweep over N, G_ (Gamma = G_ I), Sigma_ (Sigma = Sigma_ I) and propy1.
# For each setting, record sqrt(mean(losi_error)) of each method over 1000 runs.
T = 1000
# The PRNG keys do not depend on any loop variable, so build them once.
master_key = jrd.PRNGKey(0)
keys = jrd.split(master_key, 1000)

rows = []
for N in Ns:
    print("N", N)
    for G_ in Gs:
        print("G", G_)
        for Sigma_ in Sigmas:
            for propy1 in propy1s:
                G = G_ * jnp.identity(N, dtype=jnp.float32)
                Sigma = Sigma_ * jnp.identity(N, dtype=jnp.float32)
                w0 = jnp.zeros((N,), dtype=jnp.float32)
                P0 = G
                # The lambda closes over the loop variables, but it is traced
                # immediately below in the same iteration, so late binding is harmless.
                batched_exper = jax.vmap(
                    lambda key: Comp.losi_error(key, N, T, G, w0, Sigma, P0, propy1),
                    in_axes=(0,),
                )
                err_EKF, err_VA, err_EM = batched_exper(keys)
                err_EKF_sc = jnp.sqrt(err_EKF.mean())
                err_VA_sc = jnp.sqrt(err_VA.mean())
                err_EM_sc = jnp.sqrt(err_EM.mean())

                rows.append([
                    int(N),
                    float(G_),
                    float(Sigma_),
                    float(propy1),
                    float(err_EKF_sc),
                    float(err_VA_sc),
                    float(err_EM_sc),
                ])

# Build the DataFrame once instead of enlarging it row by row via df.loc[len(df)].
df = pd.DataFrame(rows, columns=["N", "G", "Sigma", "propy1", "err_EKF_sc", "err_VA_sc", "err_EM_sc"])
print(df.head())
N 2
G 0.0625
G 0.015625
G 0.00390625
G 0.0009765625
G 0.00024414062
N 4
G 0.0625
G 0.015625
G 0.00390625
G 0.0009765625
G 0.00024414062
N 8
G 0.0625
G 0.015625
G 0.00390625
G 0.0009765625
G 0.00024414062
N G Sigma propy1 err_EKF_sc err_VA_sc err_EM_sc
0 2.0 0.0625 0.25 0.5 0.092954 0.093325 0.093123
1 2.0 0.0625 0.25 0.1 0.092736 0.093197 0.092990
2 2.0 0.0625 0.50 0.5 0.093420 0.093336 0.093043
3 2.0 0.0625 0.50 0.1 0.093356 0.093240 0.092941
4 2.0 0.0625 1.00 0.5 0.093443 0.091697 0.091294
| | N | G | Sigma | propy1 | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|---|---|
2 | 2.0 | 0.062500 | 0.50 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
4 | 2.0 | 0.062500 | 1.00 | 0.5 | 0.093443 | 0.091697 | 0.091294 |
6 | 2.0 | 0.062500 | 2.00 | 0.5 | 0.093836 | 0.089228 | 0.088805 |
8 | 2.0 | 0.062500 | 4.00 | 0.5 | 0.094327 | 0.086147 | 0.085870 |
10 | 2.0 | 0.062500 | 8.00 | 0.5 | 0.096139 | 0.082455 | 0.082723 |
18 | 2.0 | 0.015625 | 2.00 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
20 | 2.0 | 0.015625 | 4.00 | 0.5 | 0.093443 | 0.091697 | 0.091294 |
22 | 2.0 | 0.015625 | 8.00 | 0.5 | 0.093836 | 0.089228 | 0.088805 |
34 | 2.0 | 0.003906 | 8.00 | 0.5 | 0.093420 | 0.093336 | 0.093043 |
48 | 2.0 | 0.000244 | 0.25 | 0.5 | 0.039164 | 0.039163 | 0.039163 |
60 | 4.0 | 0.062500 | 0.25 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
62 | 4.0 | 0.062500 | 0.50 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
64 | 4.0 | 0.062500 | 1.00 | 0.5 | 0.085952 | 0.080761 | 0.080898 |
66 | 4.0 | 0.062500 | 2.00 | 0.5 | 0.078222 | 0.070136 | 0.070658 |
68 | 4.0 | 0.062500 | 4.00 | 0.5 | 0.072354 | 0.060668 | 0.061745 |
70 | 4.0 | 0.062500 | 8.00 | 0.5 | 0.064038 | 0.051949 | 0.053414 |
76 | 4.0 | 0.015625 | 1.00 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
78 | 4.0 | 0.015625 | 2.00 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
80 | 4.0 | 0.015625 | 4.00 | 0.5 | 0.085952 | 0.080761 | 0.080898 |
82 | 4.0 | 0.015625 | 8.00 | 0.5 | 0.078222 | 0.070136 | 0.070658 |
92 | 4.0 | 0.003906 | 4.00 | 0.5 | 0.100617 | 0.099913 | 0.099700 |
94 | 4.0 | 0.003906 | 8.00 | 0.5 | 0.093656 | 0.091329 | 0.091214 |
120 | 8.0 | 0.062500 | 0.25 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
122 | 8.0 | 0.062500 | 0.50 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
124 | 8.0 | 0.062500 | 1.00 | 0.5 | 0.071028 | 0.062117 | 0.062546 |
126 | 8.0 | 0.062500 | 2.00 | 0.5 | 0.059638 | 0.050017 | 0.050485 |
128 | 8.0 | 0.062500 | 4.00 | 0.5 | 0.048867 | 0.039920 | 0.040344 |
130 | 8.0 | 0.062500 | 8.00 | 0.5 | 0.038852 | 0.031403 | 0.031772 |
132 | 8.0 | 0.015625 | 0.25 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
134 | 8.0 | 0.015625 | 0.50 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
136 | 8.0 | 0.015625 | 1.00 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
138 | 8.0 | 0.015625 | 2.00 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
140 | 8.0 | 0.015625 | 4.00 | 0.5 | 0.071028 | 0.062117 | 0.062546 |
142 | 8.0 | 0.015625 | 8.00 | 0.5 | 0.059638 | 0.050017 | 0.050485 |
146 | 8.0 | 0.003906 | 0.50 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
148 | 8.0 | 0.003906 | 1.00 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
150 | 8.0 | 0.003906 | 2.00 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
152 | 8.0 | 0.003906 | 4.00 | 0.5 | 0.096732 | 0.091938 | 0.092006 |
154 | 8.0 | 0.003906 | 8.00 | 0.5 | 0.083484 | 0.076024 | 0.076319 |
156 | 8.0 | 0.000977 | 0.25 | 0.5 | 0.104002 | 0.103998 | 0.103986 |
162 | 8.0 | 0.000977 | 2.00 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
164 | 8.0 | 0.000977 | 4.00 | 0.5 | 0.120174 | 0.119479 | 0.119283 |
166 | 8.0 | 0.000977 | 8.00 | 0.5 | 0.109672 | 0.107326 | 0.107210 |
168 | 8.0 | 0.000244 | 0.25 | 0.5 | 0.076621 | 0.076618 | 0.076616 |
170 | 8.0 | 0.000244 | 0.50 | 0.5 | 0.090635 | 0.090629 | 0.090625 |
172 | 8.0 | 0.000244 | 1.00 | 0.5 | 0.104002 | 0.103998 | 0.103986 |
178 | 8.0 | 0.000244 | 8.00 | 0.5 | 0.125066 | 0.124960 | 0.124792 |
"propy1").mean() df.groupby(
propy1 | N | G | Sigma | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|---|
0.1 | 4.666667 | 0.01665 | 2.625 | 0.090087 | 0.087966 | 0.087953 |
0.5 | 4.666667 | 0.01665 | 2.625 | 0.090186 | 0.088121 | 0.088114 |
# Keep only the propy1 == 0.5 runs; propy1 itself is not among the kept columns.
# .loc with a mask and a column list avoids chained indexing.
df1 = df.loc[df["propy1"] == 0.5, ["N", "G", "Sigma", "err_EKF_sc", "err_VA_sc", "err_EM_sc"]]
# NOTE(review): the recorded output reads (90, 7), but this selection has 6 columns;
# that output looks stale -- re-run to confirm.
df1.shape
(90, 7)
損失関数と同じ評価関数
\[E=\|\hat{\mathbf w}_{t} - \mathbf w_t\|^2\]
新たに導入した評価関数: \[E=\sqrt{\frac{1}{T}\sum_{t=1}^{T}\left[\sigma(\mathbf w_t^T\mathbf x_t) - \sigma(\hat{\mathbf w}_{t-1}^T\mathbf x_t)\right]^2}\]
次元数 \(N\) に関しての比較
90 個のデータのうち、変分近似が拡張カルマンフィルタよりも \(E\) が大きかったデータは次のとおり。

- \(N=2\) : 20/30
- \(N=4\) : 18/30
- \(N=8\) : 5/30
"err_EKF_sc"] < df1["err_VA_sc"])].groupby("N").count() df1[(df1[
N | G | Sigma | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|
2.0 | 20 | 20 | 20 | 20 | 20 |
4.0 | 18 | 18 | 18 | 18 | 18 |
8.0 | 5 | 5 | 5 | 5 | 5 |
遷移行列 \(\boldsymbol\Gamma=pI\) に関しての比較

90 個のデータのうち、変分近似が拡張カルマンフィルタよりも \(E\) が大きかったデータは次のとおり(\(p \in \{1/2^4, 1/2^6, 1/2^8, 1/2^{10}, 1/2^{12}\}\))。

- \(p = 1/2^{12}\) : 13/18
- \(p = 1/2^{10}\) : 14/18
- \(p = 1/2^8\) : 10/18
- \(p = 1/2^6\) : 5/18
- \(p = 1/2^4\) : 1/18
"err_EKF_sc"] < df1["err_VA_sc"])].groupby("G").count() df1[(df1[
G | N | Sigma | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|
0.000244 | 13 | 13 | 13 | 13 | 13 |
0.000977 | 14 | 14 | 14 | 14 | 14 |
0.003906 | 10 | 10 | 10 | 10 | 10 |
0.015625 | 5 | 5 | 5 | 5 | 5 |
0.062500 | 1 | 1 | 1 | 1 | 1 |
\(\mathbf x_t\) の共分散行列 \(\Sigma=qI\) に関しての比較

90 個のデータのうち、変分近似が拡張カルマンフィルタよりも \(E\) が大きかったデータは次のとおり(\(q \in \{1/2^2, 1/2, 1, 2, 4, 8\}\))。

- \(q = 1/2^2\) : 9/15
- \(q = 1/2\) : 9/15
- \(q = 1\) : 8/15
- \(q = 2\) : 7/15
- \(q = 4\) : 6/15
- \(q = 8\) : 4/15
"err_EKF_sc"] < df1["err_VA_sc"])].groupby("Sigma").count() df1[(df1[
Sigma | N | G | err_EKF_sc | err_VA_sc | err_EM_sc |
---|---|---|---|---|---|
0.25 | 9 | 9 | 9 | 9 | 9 |
0.50 | 9 | 9 | 9 | 9 | 9 |
1.00 | 8 | 8 | 8 | 8 | 8 |
2.00 | 7 | 7 | 7 | 7 | 7 |
4.00 | 6 | 6 | 6 | 6 | 6 |
8.00 | 4 | 4 | 4 | 4 | 4 |
"G", "err_EKF_sc", "err_VA_sc", "err_EM_sc"]].groupby("G").mean().plot() df1[[
"Sigma", "err_EKF_sc", "err_VA_sc", "err_EM_sc"]].groupby("Sigma").mean().plot() df1[[
"N", "err_EKF_sc", "err_VA_sc", "err_EM_sc"]].groupby("N").mean().plot() df1[[