# assume sp.losi, sp.dxlosi exist (sigmoid and its derivative)
# --- helper: Mackay's K(s) and derivative K'(s) (eq.9 and eq.68 in paper) ---
@jax.jit
def K_of_s(s: Float[Array, ""]):
    """MacKay's moderation factor K(s) = (1 + s^2/8)^(-1/2) (paper eq. 9)."""
    scaled_variance = 1.0 + s * s / 8.0
    return scaled_variance ** (-0.5)
@jax.jit
def Kprime_of_s(s: Float[Array, ""]):
    """Derivative of MacKay's K: K'(s) = -(s/8) (1 + s^2/8)^(-3/2) (paper eq. 68)."""
    scaled_variance = 1.0 + s * s / 8.0
    return -(s / 8.0) * scaled_variance ** (-1.5)
# --- Nonstationary EKF runner ---
class NS_EKF_Out(NamedTuple):
    """Per-time-step outputs of the nonstationary EKF runner."""
    W: Float[Array, "T N"]  # filtered weight estimates w_{t/t}, one row per step
    P: Float[Array, "T N N"]  # filtered error covariances P_{t/t}, one matrix per step
    q: Float[Array, "T"] # tracked q_t per time step
@partial(jax.jit, static_argnames=['N','T','Nw'])
def EKF_nonstationary(
    N: int,                          # state dimension
    T: int,                          # number of time steps
    x: Float[Array, "{T} {N}"],      # inputs sequence
    y: Float[Array, "{T}"],          # scalar class labels (0/1) per t
    w0: Float[Array, "{N}"],         # initial w_{0/-1}
    P0: Float[Array, "{N} {N}"],     # initial P_{0/-1}
    q0: float = 1e-6,                # initial state-noise variance
    eta_q: float = 1e-3,             # learning rate for q updates
    Nw: int = 50,                    # window size for q gradient estimation
    q_min: float = 0.0,              # lower clip bound for q_t
    q_max: float = 1.0               # upper clip bound for q_t
) -> NS_EKF_Out:
    r"""Nonstationary EKF for dynamic logistic regression.

    Uses a state-noise covariance Q_t = q_t * I and adapts the scalar q_t
    online by gradient ascent on the moderated log-evidence, averaging the
    per-sample gradient over a sliding window of the last `Nw` observations
    (Appendix E, eq. 70 of the dynamic-logistic-regression paper).

    Returns
    -------
    NS_EKF_Out with
        W : filtered means w_{t/t}, shape (T, N)
        P : filtered covariances P_{t/t}, shape (T, N, N)
        q : tracked state-noise variance q_t, shape (T,)
    """
    I = jnp.eye(N)

    class Carry(NamedTuple):
        Ptm: Float[Array, "{N} {N}"]           # P_{t/t-1}
        wtm: Float[Array, "{N}"]               # w_{t/t-1}
        q_t: Float[Array, ""]                  # current scalar q_t
        buf_Ptm: Float[Array, "{Nw} {N} {N}"]  # circular buffer of past P_{t/t-1}
        buf_x: Float[Array, "{Nw} {N}"]        # buffer of past inputs x_t
        buf_y: Float[Array, "{Nw}"]            # buffer of past labels y_t
        buf_idx: int                           # next write index (0..Nw-1)
        buf_count: int                         # entries filled so far (<= Nw)

    class Input(NamedTuple):
        xt: Float[Array, "{N}"]
        yt: Float[Array, ""]

    class Output(NamedTuple):
        wtt_: Float[Array, "{N}"]
        Ptt_: Float[Array, "{N} {N}"]
        q_: Float[Array, ""]

    def compute_q_gradient_from_buffer(buf_Ptm, buf_x, buf_y, buf_count,
                                       q_current, w_current):
        r"""Mean gradient of the log-evidence w.r.t. q over the window.

        Per sample (Appendix E, eq. 70-like form):
            grad ≈ (y - ~y) * a * K'(s) * (x^T x) / (2 s^2)
        with s^2 = x^T (Ptm + q I) x, a = w^T x, ~y = sigmoid(K(s) * a).

        BUGFIX vs. original: under `jit` one may neither slice with the traced
        count (`buf_Ptm[:buf_count]` — dynamic shape) nor branch on it with a
        Python conditional (`... if buf_count > 0 else 0.0` — tracer bool).
        Instead the whole fixed-length buffer is scanned and unfilled slots
        are masked out before averaging.
        """
        def per_sample_grad(carry, elems):
            Ptm_i, x_i, y_i = elems
            Pprior = Ptm_i + q_current * I      # P_{t/t-1} + Q_t
            s2 = x_i @ (Pprior @ x_i)           # scalar variance s^2
            s2_safe = jnp.maximum(s2, 1e-12)    # guard against tiny s^2
            s = jnp.sqrt(s2_safe)
            a = w_current @ x_i                 # activation using current w (approx)
            y_tilde = sp.losi(K_of_s(s) * a)    # moderated prediction ~y
            grad = (y_i - y_tilde) * a * Kprime_of_s(s) * (x_i @ x_i) / (2.0 * s2_safe)
            return carry, grad

        _, grads = lax.scan(per_sample_grad, None, (buf_Ptm, buf_x, buf_y))
        # Zero out never-written slots, average over the filled count, and
        # return 0 while the buffer is still empty.
        filled = jnp.arange(buf_Ptm.shape[0]) < buf_count
        total = jnp.sum(jnp.where(filled, grads, 0.0))
        return jnp.where(buf_count > 0, total / jnp.maximum(buf_count, 1), 0.0)

    def step(carry: Carry, inputs: Input) -> Tuple[Carry, Output]:
        Ptm, wtm, q_t, buf_Ptm, buf_x, buf_y, buf_idx, buf_count = carry
        xt, yt = inputs
        # EKF measurement update using the current prior (Ptm, wtm);
        # Ptt/wtt are the user's original update functions.
        Ptt_ = Ptt(Ptm, wtm, xt)
        wtt_ = wtt(Ptm, wtm, xt, yt)
        # Time update: P_{(t+1)/t} = P_{t/t} + Q_t where Q_t = q_t * I.
        Pnext = Ptt_ + q_t * I
        # Write the current prior and observation into the circular buffers.
        buf_Ptm = buf_Ptm.at[buf_idx].set(Ptm)
        buf_x = buf_x.at[buf_idx].set(xt)
        buf_y = buf_y.at[buf_idx].set(yt)
        buf_idx_next = (buf_idx + 1) % Nw
        buf_count_next = jnp.minimum(buf_count + 1, Nw)
        # One gradient-ascent step on q per time step (could be done every M
        # steps instead), clipped to [q_min, q_max].
        grad_q = compute_q_gradient_from_buffer(
            buf_Ptm, buf_x, buf_y, buf_count_next, q_t, wtm)
        q_new = jnp.clip(q_t + eta_q * grad_q, q_min, q_max)
        new_carry = Carry(Pnext, wtt_, q_new, buf_Ptm, buf_x, buf_y,
                          buf_idx_next, buf_count_next)
        return new_carry, Output(wtt_, Ptt_, q_new)

    # Initial carry: P0, w0 are the priors P_{0/-1}, w_{0/-1}. The integer
    # slots are created as JAX scalars (not Python ints) so the carry pytree
    # keeps a stable type across scan iterations.
    init_carry = Carry(P0, w0, jnp.asarray(q0),
                       jnp.zeros((Nw, N, N)), jnp.zeros((Nw, N)), jnp.zeros((Nw,)),
                       jnp.asarray(0), jnp.asarray(0))
    _, outputs = lax.scan(step, init_carry, Input(x, y), length=T)
    return NS_EKF_Out(outputs.wtt_, outputs.Ptt_, outputs.q_)
概要
濾波推定値 \(\hat{\mathbf w}_{t/t}\) \[\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)\]
推定誤差共分散行列 \(\mathbf P_{t/t}\) \[\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)\] \[\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T\] \[\mathbf P_{t/t}^{-1}=\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T\]
各変数の情報の利用
| \(\!\) | \(\mathbf P_{t/t}\) | \(\hat{\mathbf w}_{t/t}\) | \(\mathbf P_{t/t-1}\) | \(\hat{\mathbf w}_{t/t-1}\) | \(\mathbf x_t\) | \(y_t\) |
|---|---|---|---|---|---|---|
| \(\mathbf P_{t/t}\) | \(\!\) | \(\!\) | \(\bigcirc\) | \(\bigcirc\) | \(\bigcirc\) | \(\!\) |
| \(\hat{\mathbf w}_{t/t}\) | \(\!\) | \(\!\) | \(\bigcirc\) | \(\bigcirc\) | \(\bigcirc\) | \(\bigcirc\) |
ラプラス近似による導出
1. \(p(\mathbf w_t\mid\mathbf Y_t)\)
\[ p(\mathbf w_t\mid \mathbf Y_t)=\frac{p(y_t\mid\mathbf w_t)p(\mathbf w_t\mid\mathbf Y_{t-1})}{p(y_t\mid\mathbf Y_{t-1})} \]
\[p(\mathbf w_t\mid\mathbf Y_{t-1})=\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})\]
よって、
\[ \begin{equation*} p(\mathbf w_t\mid \mathbf Y_t)= \begin{cases} \sigma(\mathbf w_t^T\mathbf x_t)\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) &( y_t=1) \\ \{1-\sigma(\mathbf w_t^T\mathbf x_t)\}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) & (y_t=0) \end{cases} \end{equation*} \]
2. \(\hat{\mathbf w}_{t/t}\)
\(\mathbf w_{t}\) をMAP推定する。 \(\ln p(\mathbf w_t\mid\mathbf Y_t)\) の微分を \(\mathbf 0\) と置く。
\[ \begin{split} &\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\ &= \frac{\partial}{\partial \mathbf w_t}\ln\left[\sigma(\mathbf w_t^T\mathbf x_t)^{y_t}\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}^{1-y_t}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})\right] \\ &= \frac{\partial}{\partial \mathbf w_t}y_t\ln\sigma(\mathbf w_t^T\mathbf x_t)+\frac{\partial}{\partial\mathbf w_t}(1-y_t)\ln\{1-\sigma(\mathbf w_t^T\mathbf x_t)\} \\ &\phantom{=} +\frac{\partial}{\partial\mathbf w_t}\ln\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) \\ &= \left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \end{split} \]
ここで、\(\sigma(\mathbf w_t^T\mathbf x_t)\) をテイラー展開で一次近似する。
\(\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)\) とする。
\[ \begin{split} \sigma(\mathbf w_t^T\mathbf x_t) &\simeq \sigma_t+\left.\frac{\partial\sigma(\mathbf w_t^T\mathbf x_t)}{\partial \mathbf w_t}\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\ &\phantom{00}=\sigma_t+\sigma_t\{1-\sigma_t\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t \end{split} \]
\(\sigma(\mathbf w_t^T\mathbf x_t)\) を \(\sigma_t\) に置き換え、\(\mathbf 0\) とおく。
\[ \begin{split} &\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\ &=\left[ y_t-\sigma_t-\sigma_t\left\{1-\sigma_t\right\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t \right]\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\ &=(y_t-\sigma_t)\mathbf x_t-\left[\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T\right](\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\ &=\mathbf 0 \end{split} \]
\[ \begin{split} \mathbf w_t &= \hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}^{-1}+\sigma_t\left\{1-\sigma_t\right\}\mathbf x_t\mathbf x_t^T\right]^{-1}\mathbf x_t(y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf x_t(y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}\mathbf x_t-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right](y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\left[1-\frac{\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t) \end{split} \]
よって \(\mathbf w_t\) のMAP推定値 \(\hat{\mathbf w}_{t/t}\) は
\[\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)\]
となる。
3. \(\mathbf P_{t/t}\)
\(\hat{\mathbf w}_{t/t-1}\) を \(p(\mathbf w_t\mid\mathbf Y_t)\) のピークとすると \(\mathbf P_{t/t}\) が得られる。
( \(\hat{\mathbf w}_{t/t}\) をピークとするのが本来のラプラス近似)
\[ \begin{split} \mathbf P_{t/t}^{-1} &= \left.-\frac{\partial^2}{\partial\mathbf w_t^2} \ln p(\mathbf w_t\mid\mathbf Y_t)\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\ &= \left. -\frac{\partial}{\partial\mathbf w_t}\left[\left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})\right]\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\ &= \mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T \end{split} \]
Sherman-Morrison の公式によって、 \(\mathbf P_{t/t}\) が得られる。
\[ \mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T \]
関数等
Ptt
def Ptt(
Ptm:Float[Array, 'N N'], # $\mathbf P_{t/t-1}$
w:Float[Array, 'N'], # $\hat{\mathbf w}_{t/t-1}$
x:Float[Array, 'N'], # $\mathbf x_t$
)->Float[Array, 'N N']: # $\mathbf P_{t/t}$
推定誤差共分散行列 \(\mathbf P_{t/t}\) \[\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)\] \[\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T\]
wtt
def wtt(
Ptm:Float[Array, 'N N'], # $\mathbf P_{t/t-1}$
w:Float[Array, 'N'], # $\hat{\mathbf w}_{t/t-1}$
x:Float[Array, 'N'], # $\mathbf x_t$
y:Float[Array, ''], # $y_t$
)->Float[Array, 'N']: # $\hat{\mathbf w}_{t/t}$
濾波推定値 \(\hat{\mathbf w}_{t/t}\) \[\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)\]
EKF_out
def EKF_out(
args:VAR_POSITIONAL, kwargs:VAR_KEYWORD
):
EKF 関数の返り値

| 変数 | Type | Details |
|---|---|---|
| W | Float[Array, 'T N'] | \(\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}\) |
| P | Float[Array, 'T N N'] | \(\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}\) |
EKF
def EKF(
N:int, # $N$
T:int, # $T$
x:Float[Array, '{T} {N}'], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
y:Float[Array, '{T}'], # $\{ y_t \}_{t=0,\ldots,T-1}$
G:Float[Array, '{N} {N}'], # $\boldsymbol\Gamma$
w0:Float[Array, '{N}'], # $\hat{\mathbf w}_{0/-1}$
P0:Float[Array, '{N} {N}'], # $\mathbf P_{0/-1}$
)->EKF_out:
拡張カルマンフィルタ