分析02

00_Gen/gen_xy_logistic によって生成されたデータに対して各種モデルを適用する

関数定義


Parameter

 Parameter (G:jaxtyping.Float[Array,'{N}{N}'],
            S:jaxtyping.Float[Array,'{N}{N}'],
            w0:jaxtyping.Float[Array,'{N}'],
            P0:jaxtyping.Float[Array,'{N}{N}'],
            epsilon:jaxtyping.Float[Array,'']=Array(1.5258789e-05,
            dtype=float32, weak_type=True))

\(\!\) 00_Gen/gen_xy_logistic によって生成するときのパラメータ。

\(\!\) Type Default Details
G Float[Array, ‘{N} {N}’] \(\!\) \(\boldsymbol\Gamma\)
S Float[Array, ‘{N} {N}’] \(\!\) \(\boldsymbol\Sigma\)
w0 Float[Array, ‘{N}’] \(\!\) \(\mathbf w_{-1}\)
P0 Float[Array, ‘{N} {N}’] \(\!\) \(\mathbf P_{-1}\)
epsilon Float[Array, ’’] 1.52587890625e-05 \(\epsilon\)

\(\!\)


Param

 Param (N:int, p:int, q:int, r:int)

\(\!\) 可変パラメータ。各パラメータは次のように定義される。

  • \(N\)
  • \(T=1000\)
  • \(\boldsymbol\Gamma=2^p\mathbf I\)
  • \(\boldsymbol\Sigma=2^q\mathbf I\)
  • \(\mathbf w_{-1}=(r/2)(1,\ldots,1)^T/\sqrt{N}\)
  • \(\mathbf P_{-1}=\boldsymbol\Gamma\)
  • \(\epsilon=2^{-16}\)
\(\!\) Type Details
N int \(N\)
p int \(p\)
q int \(q\)
r int \(r\)

\(\!\)


restore_param

 restore_param (param:__main__.Param)

\(\!\) Param から Parameter に変換する。 \(\!\)

\(\!\) Type Details
param Param Param
Returns Tuple \(N\), \(T\), Parameter

save_data

 save_data (param:__main__.Param, name:str, data:dict)

\(\!\) 02_Data.h5 にデータを格納する

.
├───param1
│   ├───Gen
│   │       W (seed, T, N)
│   │       X (seed, T, N)
│   │       Y (seed, T)
│   │
│   ├───Model1
│   │       W (seed, T, N)
│   │       P (seed, T, N, N)
│   │
│   ├───Model2
│
├───param2

\(\!\)

\(\!\) Type Details
param Param Param
name str Model name (Gen, EKF, etc.)
data dict dataset_name: jnp.array
Returns None \(\!\)

WXY

 WXY (W:jaxtyping.Float[Array,'TN'], X:jaxtyping.Float[Array,'TN'],
      Y:jaxtyping.Float[Array,'T'])

\(\!\)

\(\!\) Type Details
W Float[Array, ‘T N’] \(\{\mathbf w_t\}_{t=0,\ldots,T-1}\)
X Float[Array, ‘T N’] \(\{\mathbf x_t\}_{t=0,\ldots,T-1}\)
Y Float[Array, ‘T’] \(\{y_t\}_{t=0,\ldots,T-1}\)

\(\!\)


generate

 generate (key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']],
           N:int, T:int, p:__main__.Parameter)
\(\!\) Type Details
key Union PRNGKeyArray
N int \(N\)
T int \(T\)
p Parameter Parameter
Returns WXY \(\!\)

generate_main

 generate_main (param:__main__.Param, seed:int)

\(\!\) データを生成し、02_Data.h5 に保存する。 \(\!\)

\(\!\) Type Details
param Param Param
seed int seed値

predict_main

 predict_main (param:__main__.Param, model_name:str, func:Callable[[int,in
               t,jaxtyping.Float[Array,'TN'],jaxtyping.Float[Array,'T'],__
               main__.Parameter],NamedTuple])

\(\!\) 02_Data.h5 のデータ \(X,Y\) に対して func\(W\) 等を推論し、保存する。 \(\!\)

\(\!\) Type Details
param Param Param
model_name str EKF, etc.
func Callable \(N,T,\{\mathbf x_t\}_{t=0,\ldots,T-1},\{y_t\}_{t=0,\ldots,T-1},\mathrm{p}\to\{\hat{\mathbf w_t}\}_{t=0,\ldots,T-1},\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1},\ldots\)

関数のテスト

LOCK = True

データ生成

分析

推論結果の様子

結果

  • \(N\)
  • \(T=1000\)
  • \(\boldsymbol\Gamma=2^p\mathbf I\)
  • \(\boldsymbol\Sigma=2^q\mathbf I\)
  • \(\mathbf w_{-1}=(r/2)(1,\ldots,1)^T/\sqrt{N}\)
  • \(\mathbf P_{-1}=\boldsymbol\Gamma\)
  • \(\epsilon=2^{-16}\)

Listening on: localhost:8080

誤差

\[E(\{w_{0,t}\}_{t=0,\ldots,T})=\{\hat{w}_{0,t} - w_{0,t}\}_{t=0,\ldots,T}\]

  • \(N\)
  • \(T=1000\)
  • \(\boldsymbol\Gamma=2^p\mathbf I\)
  • \(\boldsymbol\Sigma=2^q\mathbf I\)
  • \(\mathbf w_{-1}=(r/2)(1,\ldots,1)^T/\sqrt{N}\)
  • \(\mathbf P_{-1}=\boldsymbol\Gamma\)
  • \(\epsilon=2^{-16}\)

Listening on: localhost:8081

Listening on: localhost:8082

algorithm frob_error relative_error
p
-4 EKF 2.979151 33.705244
-4 VApre 1.796990 20.330618
-4 VAEM 1.792639 20.281393
-4 wEXP_PVA 1.763215 19.948498
-6 EKF 0.973215 44.042689
-6 VApre 0.778222 35.218304
-6 VAEM 0.777547 35.187783
-6 wEXP_PVA 0.774319 35.041681
-8 EKF 0.380470 68.872392
-8 VApre 0.352851 63.872866
-8 VAEM 0.352767 63.857691
-8 wEXP_PVA 0.352492 63.807886
-10 EKF 0.165993 120.191809
-10 VApre 0.162936 117.978518
-10 VAEM 0.162928 117.972616
-10 wEXP_PVA 0.162913 117.961913
-12 EKF 0.072970 211.344333
-12 VApre 0.072780 210.792704
-12 VAEM 0.072779 210.791215
-12 wEXP_PVA 0.072779 210.790546