Extended Kalman Filter

The extended Kalman filter derived for Dynamic Logistic Regression.

Overview

Filtered estimate \(\hat{\mathbf w}_{t/t}\): \[\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)\]

Estimation-error covariance matrix \(\mathbf P_{t/t}\): \[\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)\] \[\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T\] \[\mathbf P_{t/t}^{-1}=\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T\]

Quantities used by each update

| | \(\mathbf P_{t/t}\) | \(\hat{\mathbf w}_{t/t}\) | \(\mathbf P_{t/t-1}\) | \(\hat{\mathbf w}_{t/t-1}\) | \(\mathbf x_t\) | \(y_t\) |
|---|---|---|---|---|---|---|
| \(\mathbf P_{t/t}\) | | | \(\bigcirc\) | \(\bigcirc\) | \(\bigcirc\) | |
| \(\hat{\mathbf w}_{t/t}\) | | | \(\bigcirc\) | \(\bigcirc\) | \(\bigcirc\) | \(\bigcirc\) |

Derivation via the Laplace approximation

1. \(p(\mathbf w_t\mid\mathbf Y_t)\)

\[ p(\mathbf w_t\mid \mathbf Y_t)=\frac{p(y_t\mid\mathbf w_t)p(\mathbf w_t\mid\mathbf Y_{t-1})}{p(y_t\mid\mathbf Y_{t-1})} \]

\[p(\mathbf w_t\mid\mathbf Y_{t-1})=\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})\]

Therefore, up to the normalizing constant \(p(y_t\mid\mathbf Y_{t-1})\),

\[ \begin{equation*} p(\mathbf w_t\mid \mathbf Y_t)\propto \begin{cases} \sigma(\mathbf w_t^T\mathbf x_t)\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) &( y_t=1) \\ \{1-\sigma(\mathbf w_t^T\mathbf x_t)\}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) & (y_t=0) \end{cases} \end{equation*} \]

2. \(\hat{\mathbf w}_{t/t}\)

We obtain the MAP estimate of \(\mathbf w_t\) by setting the derivative of \(\ln p(\mathbf w_t\mid\mathbf Y_t)\) to \(\mathbf 0\).

\[ \begin{split} &\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\ &= \frac{\partial}{\partial \mathbf w_t}\ln\left[\sigma(\mathbf w_t^T\mathbf x_t)^{y_t}\{1-\sigma(\mathbf w_t^T\mathbf x_t)\}^{1-y_t}\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1})\right] \\ &= \frac{\partial}{\partial \mathbf w_t}y_t\ln\sigma(\mathbf w_t^T\mathbf x_t)+\frac{\partial}{\partial\mathbf w_t}(1-y_t)\ln\{1-\sigma(\mathbf w_t^T\mathbf x_t)\} \\ &\phantom{=} +\frac{\partial}{\partial\mathbf w_t}\ln\mathcal N(\mathbf w_t\mid\hat{\mathbf w}_{t/t-1},\mathbf P_{t/t-1}) \\ &= \left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \end{split} \]

Here, \(\sigma(\mathbf w_t^T\mathbf x_t)\) is linearized with a first-order Taylor expansion.

Let \(\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)\).

\[ \begin{split} \sigma(\mathbf w_t^T\mathbf x_t) &\simeq \sigma_t+\left.\frac{\partial\sigma(\mathbf w_t^T\mathbf x_t)}{\partial \mathbf w_t}\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\ &\phantom{00}=\sigma_t+\sigma_t\{1-\sigma_t\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t \end{split} \]
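A quick numerical check of this linearization (a standalone NumPy sketch; the names here are illustrative, not from the library):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
x = rng.normal(size=3)
w_hat = rng.normal(size=3)        # plays the role of w_{t/t-1}
dw = 1e-4 * rng.normal(size=3)    # small perturbation w_t - w_{t/t-1}

sigma_t = sigmoid(w_hat @ x)
exact = sigmoid((w_hat + dw) @ x)
approx = sigma_t + sigma_t * (1.0 - sigma_t) * (dw @ x)  # first-order term

assert abs(exact - approx) < 1e-6  # the error is second order in dw
```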

Substituting this first-order approximation for \(\sigma(\mathbf w_t^T\mathbf x_t)\) and setting the derivative to \(\mathbf 0\):

\[ \begin{split} &\phantom{=}\frac{\partial}{\partial \mathbf w_t}\ln p(\mathbf w_t\mid\mathbf Y_t) \\ &=\left[ y_t-\sigma_t-\sigma_t\left\{1-\sigma_t\right\}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})^T\mathbf x_t \right]\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\ &=(y_t-\sigma_t)\mathbf x_t-\left[\mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T\right](\mathbf w_t-\hat{\mathbf w}_{t/t-1}) \\ &=\mathbf 0 \end{split} \]

\[ \begin{split} \mathbf w_t &= \hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}^{-1}+\sigma_t\left\{1-\sigma_t\right\}\mathbf x_t\mathbf x_t^T\right]^{-1}\mathbf x_t(y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf x_t(y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\left[\mathbf P_{t/t-1}\mathbf x_t-\frac{\sigma_t(1-\sigma_t)\mathbf P_{t/t-1}\mathbf x_t\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right](y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\left[1-\frac{\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\right]\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t) \\ &=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t) \end{split} \]

Hence the MAP estimate \(\hat{\mathbf w}_{t/t}\) of \(\mathbf w_t\) is

\[\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)\]

3. \(\mathbf P_{t/t}\)

Taking \(\hat{\mathbf w}_{t/t-1}\) to be the peak of \(p(\mathbf w_t\mid\mathbf Y_t)\) yields \(\mathbf P_{t/t}\).

(Strictly, the Laplace approximation would place the peak at \(\hat{\mathbf w}_{t/t}\).)

\[ \begin{split} \mathbf P_{t/t}^{-1} &= \left.-\frac{\partial^2}{\partial\mathbf w_t^2} \ln p(\mathbf w_t\mid\mathbf Y_t)\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\ &= \left. -\frac{\partial}{\partial\mathbf w_t}\left[\left\{y_t-\sigma(\mathbf w_t^T\mathbf x_t)\right\}\mathbf x_t-\mathbf P_{t/t-1}^{-1}(\mathbf w_t-\hat{\mathbf w}_{t/t-1})\right]\right|_{\mathbf w_t=\hat{\mathbf w}_{t/t-1}} \\ &= \mathbf P_{t/t-1}^{-1}+\sigma_t(1-\sigma_t)\mathbf x_t\mathbf x_t^T \end{split} \]

Applying the Sherman-Morrison formula then gives \(\mathbf P_{t/t}\):

\[ \mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T \]
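This rank-one identity is easy to verify numerically (NumPy sketch; `c` stands for \(\sigma_t(1-\sigma_t)\)):

```python
import numpy as np

rng = np.random.default_rng(1)
N = 4
A = rng.normal(size=(N, N))
P = A @ A.T + N * np.eye(N)   # SPD stand-in for P_{t/t-1}
x = rng.normal(size=N)
c = 0.25 * 0.75               # sigma_t (1 - sigma_t) with sigma_t = 0.25

# information-form update, inverted directly
P_tt_direct = np.linalg.inv(np.linalg.inv(P) + c * np.outer(x, x))

# Sherman-Morrison form
Px = P @ x
P_tt_sm = P - (c / (1.0 + c * (x @ Px))) * np.outer(Px, Px)

assert np.allclose(P_tt_direct, P_tt_sm)
```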

Functions


source

Ptt


def Ptt(
    Ptm:Float[Array, 'N N'], # $\mathbf P_{t/t-1}$
    w:Float[Array, 'N'], # $\hat{\mathbf w}_{t/t-1}$
    x:Float[Array, 'N'], # $\mathbf x_t$
)->Float[Array, 'N N']: # $\mathbf P_{t/t}$

Estimation-error covariance matrix \(\mathbf P_{t/t}\) \[\sigma_t=\sigma(\hat{\mathbf w}_{t/t-1}^T\mathbf x_t)\] \[\mathbf P_{t/t}=\mathbf P_{t/t-1}-\frac{\sigma_t(1-\sigma_t)}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}(\mathbf P_{t/t-1}\mathbf x_t)(\mathbf P_{t/t-1}\mathbf x_t)^T\]
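The body of `Ptt` is not shown here; a NumPy sketch consistent with the signature and the formula above (the library version uses JAX, with the project's sigmoid `sp.losi`):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def Ptt_sketch(Ptm, w, x):
    """P_{t/t} from P_{t/t-1} (Ptm), w_{t/t-1} (w) and x_t (x)."""
    s = sigmoid(w @ x)            # sigma_t
    c = s * (1.0 - s)
    Px = Ptm @ x
    return Ptm - (c / (1.0 + c * (x @ Px))) * np.outer(Px, Px)
```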


source

wtt


def wtt(
    Ptm:Float[Array, 'N N'], # $\mathbf P_{t/t-1}$
    w:Float[Array, 'N'], # $\hat{\mathbf w}_{t/t-1}$
    x:Float[Array, 'N'], # $\mathbf x_t$
    y:Float[Array, 'N'], # $y_t$
)->Float[Array, 'N']: # $\hat{\mathbf w}_{t/t}$

Filtered estimate \(\hat{\mathbf w}_{t/t}\) \[\hat{\mathbf w}_{t/t}=\hat{\mathbf w}_{t/t-1}+\frac{1}{1+\sigma_t(1-\sigma_t)\mathbf x_t^T\mathbf P_{t/t-1}\mathbf x_t}\mathbf P_{t/t-1}\mathbf x_t(y_t-\sigma_t)\]
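Similarly, a NumPy sketch of `wtt` (the library version is JAX; `sigmoid` stands in for `sp.losi`):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def wtt_sketch(Ptm, w, x, y):
    """w_{t/t} from P_{t/t-1} (Ptm), w_{t/t-1} (w), x_t (x) and y_t (y)."""
    s = sigmoid(w @ x)            # sigma_t
    c = s * (1.0 - s)
    Px = Ptm @ x
    return w + Px * (y - s) / (1.0 + c * (x @ Px))
```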


source

EKF_out


def EKF_out(
    args:VAR_POSITIONAL, kwargs:VAR_KEYWORD
):

Return value of the `EKF` function

| | Type | Details |
|---|---|---|
| W | Float[Array, 'T N'] | \(\{\hat{\mathbf w}_{t/t}\}_{t=0,\ldots,T-1}\) |
| P | Float[Array, 'T N N'] | \(\{\mathbf P_{t/t}\}_{t=0,\ldots,T-1}\) |
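`EKF_out` is just a named container for the two filtered trajectories; a plausible stand-in definition (the actual class may be built differently):

```python
from typing import NamedTuple
import numpy as np  # stands in for JAX arrays here

class EKF_out_sketch(NamedTuple):
    W: np.ndarray  # (T, N): filtered means w_{t/t}
    P: np.ndarray  # (T, N, N): filtered covariances P_{t/t}
```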


source

EKF


def EKF(
    N:int, # $N$
    T:int, # $T$
    x:Float[Array, '{T} {N}'], # $\{ \mathbf x_t \}_{t=0,\ldots,T-1}$
    y:Float[Array, '{T} {N}'], # $\{ y_t \}_{t=0,\ldots,T-1}$
    G:Float[Array, '{N} {N}'], # $\boldsymbol\Gamma$
    w0:Float[Array, '{N}'], # $\hat{\mathbf w}_{0/-1}$
    P0:Float[Array, '{N} {N}'], # $\mathbf P_{0/-1}$
)->EKF_out:

Extended Kalman filter
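The loop body is not shown; below is a NumPy sketch of one plausible implementation. The measurement update follows the formulas above, while the prediction step \(\hat{\mathbf w}_{t+1/t}=\boldsymbol\Gamma\hat{\mathbf w}_{t/t}\), \(\mathbf P_{t+1/t}=\boldsymbol\Gamma\mathbf P_{t/t}\boldsymbol\Gamma^T\) is an assumption, since this document only specifies the update step (the library version runs the loop with `lax.scan`):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def EKF_sketch(N, T, x, y, G, w0, P0):
    """NumPy sketch of the EKF loop; the prediction step is assumed."""
    W = np.zeros((T, N))
    P = np.zeros((T, N, N))
    w_pred, P_pred = w0, P0                 # w_{0/-1}, P_{0/-1}
    for t in range(T):
        s = sigmoid(w_pred @ x[t])          # sigma_t
        c = s * (1.0 - s)
        Px = P_pred @ x[t]
        denom = 1.0 + c * (x[t] @ Px)
        W[t] = w_pred + Px * (y[t] - s) / denom          # w_{t/t}
        P[t] = P_pred - (c / denom) * np.outer(Px, Px)   # P_{t/t}
        # assumed prediction: w_{t+1/t} = G w_{t/t}, P_{t+1/t} = G P_{t/t} G^T
        w_pred = G @ W[t]
        P_pred = G @ P[t] @ G.T
    return W, P
```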

# Assumed imports for the code below (the notebook presumably defines these
# earlier); sp.losi / sp.dxlosi are this project's sigmoid and its derivative.
from functools import partial
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax import lax
from jaxtyping import Array, Float

# --- helper: Mackay's K(s) and derivative K'(s) (eq.9 and eq.68 in paper) ---
@jax.jit
def K_of_s(s: Float[Array, ""]):
  # K(s) = (1 + s^2 / 8)^(-1/2)
  return (1.0 + (s**2) / 8.0) ** (-0.5)

@jax.jit
def Kprime_of_s(s: Float[Array, ""]):
  # K'(s) = - (s / 8) * (1 + s^2 / 8)^(-3/2)
  return -(s / 8.0) * (1.0 + (s**2) / 8.0) ** (-1.5)
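Two quick sanity checks on these helpers (plain-NumPy restatement): \(K(0)=1\), so at zero predictive variance the moderated prediction \(\sigma(K(s)a)\) reduces to the plain sigmoid, and \(K'(s)\) matches a finite difference:

```python
import numpy as np

def K_np(s):
    return (1.0 + s**2 / 8.0) ** (-0.5)

def Kp_np(s):
    return -(s / 8.0) * (1.0 + s**2 / 8.0) ** (-1.5)

assert K_np(0.0) == 1.0            # no moderation at zero variance
s, h = 1.3, 1e-6
fd = (K_np(s + h) - K_np(s - h)) / (2 * h)
assert abs(fd - Kp_np(s)) < 1e-8   # derivative formula checks out
```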

# --- Nonstationary EKF runner ---
class NS_EKF_Out(NamedTuple):
  W: Float[Array, "T N"]
  P: Float[Array, "T N N"]
  q: Float[Array, "T"]   # tracked q_t per time step

@partial(jax.jit, static_argnames=['N','T','Nw'])
def EKF_nonstationary(
    N: int,
    T: int,
    x: Float[Array, "{T} {N}"],  # inputs sequence
    y: Float[Array, "{T}"],      # scalar class labels (0/1) per t
    w0: Float[Array, "{N}"],     # initial w_{0/-1}
    P0: Float[Array, "{N} {N}"], # initial P_{0/-1}
    q0: float = 1e-6,            # initial state-noise variance
    eta_q: float = 1e-3,         # learning rate for q updates
    Nw: int = 50,                # window size for q gradient estimation
    q_min: float = 0.0,
    q_max: float = 1.0
) -> NS_EKF_Out:
  r"""
  Nonstationary EKF:
  - uses Q_t = q_t * I
  - updates q_t by gradient-ascent (average gradient over window of length Nw)
    based on Appendix E (eq.70) in the provided paper (dynamic logistic regression.pdf).
  """

  I = jnp.eye(N)

  class Carry(NamedTuple):
    Ptm: Float[Array, "{N} {N}"]  # P_{t/t-1}
    wtm: Float[Array, "{N}"]      # w_{t/t-1}
    q_t: Float[Array, ""]         # current scalar q_t
    buf_Ptm: Float[Array, "{Nw} {N} {N}"]  # circular buffer of past Ptm (for gradient)
    buf_x: Float[Array, "{Nw} {N}"]        # buffer of past x
    buf_y: Float[Array, "{Nw}"]            # buffer of past y
    buf_idx: int                          # next write index (0..Nw-1)
    buf_count: int                        # how many entries filled (<= Nw)

  class Input(NamedTuple):
    xt: Float[Array, "{N}"]
    yt: Float[Array, ""]

  class Output(NamedTuple):
    wtt_: Float[Array, "{N}"]
    Ptt_: Float[Array, "{N} {N}"]
    q_: Float[Array, ""]

  # initialize buffers with zeros
  init_buf_Ptm = jnp.zeros((Nw, N, N))
  init_buf_x = jnp.zeros((Nw, N))
  init_buf_y = jnp.zeros((Nw,))

  def compute_q_gradient_from_buffer(buf_Ptm, buf_x, buf_y, buf_count, q_current, w_current):
    """
    Average gradient of the log-evidence w.r.t. q over the filled buffer
    entries (eq.70-like form from Appendix E):
      grad ≈ (y - ~y) * a * K'(s) * (x^T x) / (2 s^2)
    where s^2 = x^T (Ptm + q I) x, ~y = sigmoid(K(s) * a), a = w^T x
    (activation using the previous w). The buffer has static length Nw and
    entries beyond buf_count are masked out, keeping everything jit-compatible.
    """
    def per_sample_grad(Ptm_i, x_i, y_i):
      Pprior = Ptm_i + q_current * I      # P_{t/t-1} + Q_t
      s2 = x_i @ (Pprior @ x_i)           # scalar predictive variance
      s2_safe = jnp.maximum(s2, 1e-12)    # prevent tiny s2 -> numerical issues
      s = jnp.sqrt(s2_safe)
      a = w_current @ x_i                 # activation using current w (approx)
      y_tilde = sp.losi(K_of_s(s) * a)    # moderated prediction ~y
      return (y_i - y_tilde) * a * Kprime_of_s(s) * (x_i @ x_i) / (2.0 * s2_safe)

    # vmap over the full static-length buffer, then mask the unfilled entries;
    # slicing with the traced buf_count would not be jit-compatible
    grads = jax.vmap(per_sample_grad)(buf_Ptm, buf_x, buf_y)
    mask = jnp.arange(buf_Ptm.shape[0]) < buf_count
    # mean over the filled entries only (0 if the buffer is empty)
    return jnp.sum(jnp.where(mask, grads, 0.0)) / jnp.maximum(buf_count, 1)

  def step(carry: Carry, inputs: Input) -> Tuple[Carry, Output]:
    Ptm, wtm, q_t, buf_Ptm, buf_x, buf_y, buf_idx, buf_count = carry
    xt, yt = inputs

    # EKF update using current Ptm and wtm (same functions as user's originals)
    Ptt_ = Ptt(Ptm, wtm, xt)   # P_{t/t} (uses sp.dxlosi etc.)
    wtt_ = wtt(Ptm, wtm, xt, yt)

    # next predicted P_{(t+1)/t} = P_{t/t} + Q_t where Q_t = q_t * I
    Pnext = Ptt_ + q_t * I

    # update circular buffers: write current Ptm, xt, yt
    buf_Ptm = buf_Ptm.at[buf_idx].set(Ptm)
    buf_x   = buf_x.at[buf_idx].set(xt)
    buf_y   = buf_y.at[buf_idx].set(yt)
    buf_idx_next = (buf_idx + 1) % buf_Ptm.shape[0]
    buf_count_next = jnp.minimum(buf_count + 1, buf_Ptm.shape[0])

    # compute q gradient and update q every step (could be done every M steps)
    grad_q = compute_q_gradient_from_buffer(buf_Ptm, buf_x, buf_y, buf_count_next, q_t, wtm)
    q_new = q_t + eta_q * grad_q
    q_new = jnp.clip(q_new, q_min, q_max)

    new_carry = Carry(Pnext, wtt_, q_new, buf_Ptm, buf_x, buf_y, buf_idx_next, buf_count_next)
    out = Output(wtt_, Ptt_, q_new)
    return new_carry, out

  # initial carry: P0 is prior P_{0/-1} already; w0 is w_{0/-1}
  init_carry = Carry(P0, w0, jnp.array(q0), init_buf_Ptm, init_buf_x, init_buf_y,
                     jnp.array(0), jnp.array(0))  # array indices keep the scan carry dtype stable

  _, outputs = lax.scan(step, init_carry, Input(x, y), length=T)

  W = outputs.wtt_
  P = outputs.Ptt_
  q_seq = outputs.q_
  return NS_EKF_Out(W, P, q_seq)