"""Parametric Extended Kalman Filter (EKF) implementation for state estimation in nonlinear dynamic systems. This module provides a class `PEKF` that implements the Extended Kalman Filter (EKF) algorithm for state estimation in nonlinear dynamic systems. The EKF algorithm estimates the state of a discrete-time controlled process that is governed by a nonlinear stochastic difference equation. The parametric variant of the EKF allows for data-driven learning of hyperparameters and model parameters for the state transition and measurement functions using neural networks and backpropagation algorithms. The loss function for training the PEKF is the Maximum Likelihood Estimation (MLE) loss which is computed using the joint log-likelihood of measurements. Classes: EKF: Implements the Extended Kalman Filter. Functions: joint_jacobian_transform: Transforms a function(callable or nn.Module) to compute its Jacobian and value in single pass. log_likelihood: Computes the log-likelihood for given residuals (innovation) and it's covariance. Example: ``` # Define state transition and measurement functions as nn.Module f = StateTransitionFunction() h = MeasurementFunction() # Initialize the EKF ekf = EKF(dim_x=4, dim_z=2, f=f, h=h) # Define initial state, covariance, and noise matrices x0 = torch.zeros(4) P0 = torch.eye(4) Q = torch.eye(4) * 0.1 R = torch.eye(2) * 0.1 # Define measurement sequence z = torch.rand(10, 2) # Perform batch filtering and smoothing results = ekf.batch_smoothing(z, x0, P0, Q, R) # Calculate the loss for training loss = ekf(z, x0, P0, Q, R) ``` """ import torch import torch.nn as nn from .joint_jacobian_function import joint_jacobian_transform from .loss.negative_log_likelihood import log_likelihood from .loss.varaitional_loss import ( kl_divergence_of_transistion_model, log_likelihood_of_observation_model, ) __all__ = [ "ParametricExtendedKalmanFilter", "NegativeLogLikelihoodPEKF", "VariationalNEKF", ] class ParametricExtendedKalmanFilter(nn.Module): """Implementation of the Extended Kalman Filter (EKF) algorithm for state estimation. Args: dim_x (int): Dimension of the state vector. dim_z (int): Dimension of the measurement vector. f (nn.Module): State transition function. h (nn.Module): Measurement function. Attributes: dim_x (int): Dimension of the state vector. dim_z (int): Dimension of the measurement vector. f (nn.Module): State transition function. h (nn.Module): Measurement function. I (torch.Tensor): Identity matrix of size (dim_x, dim_x). _f (Callable): State transition function with Jacobian computation. _h (Callable): Measurement function with Jacobian computation. Methods: predict: Predicts the state of the system. update: Updates the state estimate based on the measurement. predict_update: Runs the predict-update loop. batch_filtering: Processes the sequence of measurements. fixed_interval_smoothing: Performs fixed-interval smoothing on the state estimates. batch_smoothing: Processes the sequence of measurements to form a Maximum Likelihood Estimation (MLE) loss. autocorreleation: Computes the autocorrelation of the innovation residuals sequence. forward: Processes the sequence of measurements to form a Maximum Likelihood Estimation (MLE) loss. 
""" TERMS = { "PriorEstimate": "x_prior", "PriorCovariance": "P_prior", "StateJacobian": "State_Jacobian", "PosteriorEstimate": "x_posterior", "PosteriorCovariance": "P_posterior", "InnovationResidual": "innovation_residual", "InnovationCovariance": "innovation_covariance", "KalmanGain": "Kalman_gain", "MeasurementJacobian": "Measurement_Jacobian", "SmoothedEstimate": "x_smoothed", "SmoothedCovariance": "P_smoothed", "SmoothedInitialEstimate": "x0_smoothed", "SmoothedInitialCovariance": "P0_smoothed", } def __init__( self, dim_x: int, dim_z: int, f: nn.Module, h: nn.Module, ) -> None: """Initializes the PEKF algorithm. Args: dim_x (int): Dimension of the state vector. dim_z (int): Dimension of the measurement vector. f (nn.Module): State transition function. h (nn.Module): Measurement function. Note: - The state transition function f signature is f(x: torch.Tensor, *args) -> torch.Tensor - The measurement function h signature is h(x: torch.Tensor, *args) -> torch.Tensor - Any argument must scale with the batch dimension. """ super().__init__() # Store the dimensions of the matrices self.dim_x = dim_x self.dim_z = dim_z # Store the state transition and measurement functions self.f = f self.h = h # Register the jacobian functions self._f = joint_jacobian_transform(f) self._h = joint_jacobian_transform(h) def predict( self, x_posterior: torch.Tensor, P_posterior: torch.Tensor, Q: torch.Tensor | nn.Parameter, f_args: tuple = (), ) -> dict[str, torch.Tensor]: """Predicts the state of the system. Args: x_posterior (torch.Tensor): Posterior state estimate. (dim_x, ) P_posterior (torch.Tensor): Posterior state error covariance. (dim_x, dim_x) Q (torch.Tensor | nn.Parameter): Process noise covariance. (dim_x, dim_x) f_args (tuple, optional): Additional arguments for the state transition function. Defaults to (). Returns: dict[str, torch.Tensor]: Dictionary containing the predicted state estimate, state error covariance, and the state transition matrix. Note: - f_args must scale with the batch dimension. """ # Return the predicted state and state error covariance F, x_prior = self._f(x_posterior, *f_args) P_prior = F @ P_posterior @ F.T + Q return { self.TERMS["PriorEstimate"]: x_prior, self.TERMS["PriorCovariance"]: P_prior, self.TERMS["StateJacobian"]: F, } def update( self, z: torch.Tensor, x_prior: torch.Tensor, P_prior: torch.Tensor, R: torch.Tensor | nn.Parameter, h_args: tuple = (), ) -> dict[str, torch.Tensor]: """Updates the state estimate based on the measurement. Args: x_prior (torch.Tensor): Prior state estimate. (dim_x, ) P_prior (torch.Tensor): Prior state error covariance. (dim_x, dim_x) z (torch.Tensor): Measurement vector. (dim_z, ) R (torch.Tensor | nn.Parameter): Measurement noise covariance. (dim_z, dim_z) h_args (tuple, optional): Additional arguments for the measurement function. Defaults to (). Returns: dict[str, torch.Tensor]: Dictionary containing the updated state estimate, state error covariance, and the innovation. Note: - h_args must scale with the batch dimension. 
""" # Compute the predicted measurement and the Jacobian H, z_pred = self._h(x_prior, *h_args) # Compute the innovation y = z - z_pred # Compute the innovation covariance matrix S = H @ P_prior @ H.T + R # Compute the Kalman gain K = P_prior @ H.T @ torch.linalg.inv(S) # Update the state vector x_post = x_prior + K @ y # Update the state covariance matrix using joseph form since # EKF is not guaranteed to be optimal factor = torch.eye(self.dim_x, device=x_post.device, dtype=x_post.dtype) - K @ H P_post = factor @ P_prior @ factor.T + K @ R @ K.T return { self.TERMS["PosteriorEstimate"]: x_post, self.TERMS["PosteriorCovariance"]: P_post, self.TERMS["InnovationResidual"]: y, self.TERMS["InnovationCovariance"]: S, self.TERMS["KalmanGain"]: K, self.TERMS["MeasurementJacobian"]: H, } def predict_update( self, x_posterior: torch.Tensor, P_posterior: torch.Tensor, z: torch.Tensor, Q: torch.Tensor | nn.Parameter, R: torch.Tensor | nn.Parameter, f_args: tuple = (), h_args: tuple = (), ) -> dict[str, torch.Tensor]: """Runs the predict-update loop. Args: x_posterior (torch.Tensor): Posterior state estimate. (dim_x, ) P_posterior (torch.Tensor): Posterior state error covariance. (dim_x, dim_x) z (torch.Tensor): Measurement vector. (dim_z, ) Q (torch.Tensor | nn.Parameter): Process noise covariance. (dim_x, dim_x) R (torch.Tensor | nn.Parameter): Measurement noise covariance. (dim_z, dim_z) f_args (tuple, optional): Additional arguments for the state transition function. Defaults to (). h_args (tuple, optional): Additional arguments for the measurement function. Defaults to (). Returns: dict[str, torch.Tensor]: Dictionary containing the state estimates and state error covariances. Note: - h_args and f_args must scale with the batch dimension. """ # Predict the state prediction = self.predict(x_posterior, P_posterior, Q, f_args) # Update the state update = self.update( z=z, x_prior=prediction[self.TERMS["PriorEstimate"]], P_prior=prediction[self.TERMS["PriorCovariance"]], R=R, h_args=h_args, ) return {**prediction, **update} def batch_filtering( self, z: torch.Tensor, x0: torch.Tensor, P0: torch.Tensor, Q: torch.Tensor | nn.Parameter, R: torch.Tensor | nn.Parameter, f_args: tuple = (), h_args: tuple = (), ) -> dict[str, torch.Tensor]: """Processes the sequence of measurements. Args: z (torch.Tensor): Measurement sequence. (num_timesteps, dim_z) x0 (torch.Tensor): Initial state estimate. (dim_x, ) P0 (torch.Tensor): Initial state error covariance. (dim_x, dim_x) Q (torch.Tensor | nn.Parameter): Process noise covariance. (dim_x, dim_x) R (torch.Tensor | nn.Parameter): Measurement noise covariance. (dim_z, dim_z) f_args (tuple, optional): Additional arguments for the state transition function. Defaults to (). h_args (tuple, optional): Additional arguments for the measurement function. Defaults to (). Returns: dict[str, list[torch.Tensor]]: Dictionary containing lists of state estimates and state error covariances at each time step. Note: - h_args and f_args must scale with the batch dimension. 
""" # Sequence length T = z.shape[0] # Initialize the intermediate variables output = { self.TERMS[key]: [] for key in self.TERMS.keys() if not key.startswith("Smoothed") } # Run the filtering algorithm for t in range(T): # Perform the predict-update loop results = self.predict_update( x_posterior=( x0 if t == 0 else output[self.TERMS["PosteriorEstimate"]][-1] ), P_posterior=( P0 if t == 0 else output[self.TERMS["PosteriorCovariance"]][-1] ), z=z[t], Q=Q, R=R, f_args=(args[t] for args in f_args), h_args=(args[t] for args in h_args), ) # Update the output for term in output: output[term].append(results[term]) # Stack the results for term in output: output[term] = torch.stack(output[term]) return output def fixed_interval_smoothing( self, x0: torch.Tensor, P0: torch.Tensor, x_posterior: torch.Tensor, P_posterior: torch.Tensor, FJacobians: torch.Tensor, Q: torch.Tensor | nn.Parameter, ) -> dict[str, torch.Tensor]: """Performs fixed-interval smoothing on the state estimates. Args: x0 (torch.Tensor): Initial state estimate. (dim_x, ) P0 (torch.Tensor): Initial state error covariance. (dim_x, dim_x) x_posterior (torch.Tensor): Filtered state estimates and covariances. (T, dim_x) P_posterior (torch.Tensor): Filtered state error covariances. (T, dim_x, dim_x) FJacobians (torch.Tensor): State transition Jacobians. (T, dim_x, dim_x) Q (torch.Tensor): Process noise covariance. (dim_x, dim_x) Returns: dict[str, list[torch.Tensor]]: Dictionary containing lists of smoothed state estimates and state error covariances at each time step. """ # Initialize the smoothed state estimates and state error covariances x_smoothed = [] P_smoothed = [] # Last state estimate is already the smoothed state estimate x_smoothed.append(x_posterior[-1]) P_smoothed.append(P_posterior[-1]) # Sequence length T = x_posterior.shape[0] # Loop and perform fixed-interval smoothing from T - 2 (i.e second last state estimate) to 0 (i.e first state estimate) for t in range(T - 2, -1, -1): # Compute the prior covariance P_prior = FJacobians[t + 1] @ P_posterior[t] @ FJacobians[t + 1].T + Q # Compute the smoothing gain L = P_posterior[t] @ FJacobians[t + 1].T @ torch.linalg.inv(P_prior) # Compute the smoothed state estimate x_smoothed.insert( 0, x_posterior[t] + L @ (x_smoothed[0] - FJacobians[t + 1] @ x_posterior[t]), ) # Compute the smoothed state error covariance P_smoothed.insert(0, P_posterior[t] + L @ (P_smoothed[0] - P_prior) @ L.T) # Smoothed initial state estimate and state error covariance P_prior = FJacobians[0] @ P0 @ FJacobians[0].T + Q L = P0 @ FJacobians[0].T @ torch.linalg.inv(P_prior) x0_smoothed = x0 + L @ (x_smoothed[0] - FJacobians[0] @ x0) P0_smoothed = P0 + L @ (P_smoothed[0] - P_prior) @ L.T return { self.TERMS["SmoothedEstimate"]: torch.stack(x_smoothed), self.TERMS["SmoothedCovariance"]: torch.stack(P_smoothed), self.TERMS["SmoothedInitialEstimate"]: x0_smoothed, self.TERMS["SmoothedInitialCovariance"]: P0_smoothed, } def batch_smoothing( self, z: torch.Tensor, x0: torch.Tensor, P0: torch.Tensor, Q: torch.Tensor | nn.Parameter, R: torch.Tensor | nn.Parameter, f_args: tuple = (), h_args: tuple = (), ) -> dict[str, torch.Tensor]: """Processes the sequence of measurements to form an Maximum Likelihood Estimation (MLE) loss. Args: z (torch.Tensor): Measurement sequence. (T, dim_z) x0 (torch.Tensor): Initial state estimate. (dim_x, ) P0 (torch.Tensor): Initial state covariance. (dim_x, dim_x) Q (torch.Tensor | nn.Parameter): Process noise covariance. 
            R (torch.Tensor | nn.Parameter): Measurement noise covariance. (dim_z, dim_z)
            f_args (tuple, optional): Additional arguments for the state transition function. Defaults to ().
            h_args (tuple, optional): Additional arguments for the measurement function. Defaults to ().

        Returns:
            dict[str, torch.Tensor]: Dictionary containing the filtered and smoothed state
                estimates and state error covariances.

        Note:
            - h_args and f_args must scale with the time dimension.
        """
        # Process the measurements
        results = self.batch_filtering(z, x0, P0, Q, R, f_args, h_args)

        # Perform fixed-interval smoothing
        smoothed = self.fixed_interval_smoothing(
            x0,
            P0,
            results[self.TERMS["PosteriorEstimate"]],
            results[self.TERMS["PosteriorCovariance"]],
            results[self.TERMS["StateJacobian"]],
            Q,
        )

        return {**results, **smoothed}

    @staticmethod
    def autocorreleation(
        innovation_residuals: torch.Tensor,
        lag: int = 0,
    ) -> torch.Tensor:
        """Computes the autocorrelation of the innovation residuals sequence.

        Args:
            innovation_residuals (torch.Tensor): Innovation residuals. (T, dim_z)
            lag (int, optional): Lag. Defaults to 0.

        Returns:
            torch.Tensor: Autocorrelation. (dim_z, dim_z)
        """
        # If the sequence is not longer than the lag, return 0
        if innovation_residuals.shape[0] <= lag:
            return 0

        # Center the residuals
        residuals = innovation_residuals - torch.mean(innovation_residuals, dim=0)

        # Compute the outer product expectation of the residuals
        outer_product = 0
        for i in range(len(residuals) - lag):
            outer_product += torch.outer(residuals[i], residuals[i + lag])

        # Compute the autocorrelation
        return outer_product / (len(residuals) - lag)


class NegativeLogLikelihoodPEKF(ParametricExtendedKalmanFilter):
    """Calculates the Maximum Likelihood Estimation (MLE) loss for training the PEKF."""

    def forward(
        self,
        z: torch.Tensor,
        x0: torch.Tensor,
        P0: torch.Tensor,
        Q: torch.Tensor | nn.Parameter,
        R: torch.Tensor | nn.Parameter,
        f_args: tuple = (),
        h_args: tuple = (),
    ) -> torch.Tensor:
        """Processes the sequence of measurements to form a Maximum Likelihood Estimation (MLE) loss.

        Args:
            z (torch.Tensor): Measurement sequence. (T, dim_z)
            x0 (torch.Tensor): Initial state estimate. (dim_x, )
            P0 (torch.Tensor): Initial state covariance. (dim_x, dim_x)
            Q (torch.Tensor | nn.Parameter): Process noise covariance. (dim_x, dim_x)
            R (torch.Tensor | nn.Parameter): Measurement noise covariance. (dim_z, dim_z)
            f_args (tuple, optional): Additional arguments for the state transition function. Defaults to ().
            h_args (tuple, optional): Additional arguments for the measurement function. Defaults to ().

        Returns:
            loss (torch.Tensor): The loss value.

        Note:
            - h_args and f_args must scale with the time dimension.
            - Do not call this method for inference. Use batch_filtering or batch_smoothing instead.
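
        Example:
            A minimal training sketch, assuming `model` is a constructed
            `NegativeLogLikelihoodPEKF` whose `f` and `h` contain learnable
            parameters; the optimizer choice and number of epochs are illustrative.
            ```
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            for epoch in range(100):
                optimizer.zero_grad()
                loss = model(z, x0, P0, Q, R)
                loss.backward()
                optimizer.step()
            ```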
""" # Process the measurements results = self.batch_filtering( z=z, x0=x0, P0=P0, Q=Q, R=R, f_args=f_args, h_args=h_args, ) # The negative log-likelihood loss for grdient descent # The negative sign is used to convert the maximum likelihood problem to a minimization problem return -log_likelihood( results[self.TERMS["InnovationResidual"]], results[self.TERMS["InnovationCovariance"]], ) class VariationalNEKF(ParametricExtendedKalmanFilter): """The variational Neural Extended Kalman Filter (NEKF) model described in the paper.""" def forward( self, z: torch.Tensor, x0: torch.Tensor, P0: torch.Tensor, Q: torch.Tensor | nn.Parameter, R: torch.Tensor | nn.Parameter, f_args: tuple = (), h_args: tuple = (), ) -> dict[str, torch.Tensor]: """Processes the sequence of measurements to form an Variational Objective ELBO loss. Args: z (torch.Tensor): Measurement sequence. (T, dim_z) x0 (torch.Tensor): Initial state estimate. (dim_x, ) P0 (torch.Tensor): Initial state covariance. (dim_x, dim_x) Q (torch.Tensor | nn.Parameter): Process noise covariance. (dim_x, dim_x) R (torch.Tensor | nn.Parameter): Measurement noise covariance. (dim_z, dim_z) f_args (tuple, optional): Additional arguments for the state transition function. Defaults to (). h_args (tuple, optional): Additional arguments for the measurement function. Defaults to (). Returns: loss (torch.Tensor): The loss value. Note: - h_args and f_args must scale with the time dimension. - Do not call this method for inference. Use batch_filtering or batch_smoothing instead. """ # Process the measurements # Diable gradient computation of claculation of approximate posterior results = self.batch_smoothing( z=z, x0=x0, P0=P0, Q=Q, R=R, f_args=f_args, h_args=h_args, ) # With gradient computation enabled, calculate the KL divergence and log likelihood kl_divergence = kl_divergence_of_transistion_model( Df=self._f, DF_args=f_args, x0=x0, P0=P0, x_smoothed=results[self.TERMS["SmoothedEstimate"]], P_smoothed=results[self.TERMS["SmoothedCovariance"]], Q=Q, ) # Compute the log likelihood of the observation model log_likelihood = log_likelihood_of_observation_model( y=z, Dh=self._h, DH_args=h_args, x_smoothed=results[self.TERMS["SmoothedEstimate"]], P_smoothed=results[self.TERMS["SmoothedCovariance"]], R=R, ) # Minimize the negative ELBO return kl_divergence - log_likelihood