{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parametric Extended Kalman Filter (PEKF) for GPS triangulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
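    "This notebook implements a Parametric Extended Kalman Filter (PEKF) for GPS triangulation: an Extended Kalman Filter (EKF) whose process and measurement noise covariances are tuned from the data by gradient descent. The filter is implemented in Python with PyTorch and evaluated on real RINEX observation and navigation data from the IGS station JPLM00USA. It estimates the receiver position, velocity, clock bias and clock drift from the GPS measurements. Its positioning error is then compared with that of a classical EKF and a weighted least squares (WLS) solution."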
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State Definition\n",
    "The state of the system is defined as follows:\n",
    "$$\n",
    "\\begin{align}\n",
    "\\mathbf{x} &= \\begin{bmatrix} x \\\\ \\dot{x} \\\\ y \\\\ \\dot{y} \\\\ z \\\\ \\dot{z} \\\\ cdt \\\\ c\\dot{dt}\\end{bmatrix}\n",
    "\\end{align}\n",
    "$$\n",
    "where $x,y,z$ are the position of the object in the ECEF frame, $\\dot{x},\\dot{y},\\dot{z}$ are its velocity in the ECEF frame, $cdt$ is the receiver clock bias and $c\\dot{dt}$ is the receiver clock drift. Both clock terms are scaled by the speed of light $c$, so the bias is expressed in meters and the drift in meters per second."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State transition function\n",
    "Assuming a constant velocity model, each position-velocity pair (and the clock bias-drift pair) evolves over one sampling interval $\\Delta t$ according to:\n",
    "$$\n",
    "A = \\begin{bmatrix} 1 & \\Delta t \\\\ 0 & 1 \\end{bmatrix}\n",
    "$$\n",
    "The state transition matrix $F$ is then block diagonal:\n",
    "$$\n",
    "F = \\begin{bmatrix} A & 0 & 0 & 0 \\\\ 0 & A & 0 & 0 \\\\ 0 & 0 & A & 0 \\\\ 0 & 0 & 0 & A \\end{bmatrix}\n",
    "$$\n",
    "This is an $8 \\times 8$ matrix that maps the state at epoch $k$ to the state at epoch $k+1$, one interval $\\Delta t$ later."
   ]
  },
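  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the block structure, the sketch below builds $F$ with NumPy from a Kronecker product and propagates an example state by one epoch. Each position (and the clock bias) advances by $\\Delta t$ times its rate, while the rates are unchanged. This is a standalone illustration, not the library's `TransitionModel`. The helper `build_transition_matrix` and the example state values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def build_transition_matrix(dt: float) -> np.ndarray:\n",
    "    # Per-axis constant velocity block A = [[1, dt], [0, 1]]\n",
    "    a = np.array([[1.0, dt], [0.0, 1.0]])\n",
    "    # Block-diagonal F = diag(A, A, A, A) for the x, y, z and clock pairs\n",
    "    return np.kron(np.eye(4), a)\n",
    "\n",
    "\n",
    "sketch_F = build_transition_matrix(30.0)\n",
    "\n",
    "# Hypothetical state ordered as [x, vx, y, vy, z, vz, cdt, cdt_dot]\n",
    "sketch_x = np.array([1000.0, 1.0, 2000.0, -2.0, 3000.0, 0.5, 10.0, 0.1])\n",
    "sketch_x_next = sketch_F @ sketch_x\n",
    "\n",
    "print(sketch_F.shape)\n",
    "print(sketch_x_next)"
   ]
  },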
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process noise\n",
    "The discrete-time process noise matrix is obtained by projecting the continuous-time process noise $Q_c$ onto the discrete-time model over one sampling interval. For a single position-velocity axis, the continuous-time noise is white noise on the acceleration:\n",
    "$$\n",
    "Q_c = \\begin{bmatrix} 0 & 0 \\\\ 0 & 1  \\end{bmatrix} \\cdot \\Phi_{B}^2\n",
    "$$\n",
    "where $\\Phi_{B}^2$ is the spectral density of the acceleration noise.\n",
    "\n",
    "Projecting it with the per-axis transition matrix $A(\\tau)$, i.e. $A$ with $\\Delta t$ replaced by $\\tau$, gives the discrete-time process noise of one axis:\n",
    "$$\n",
    "Q_p = \\int_{0}^{\\Delta t} A(\\tau)Q_cA^T(\\tau) d\\tau = \\begin{bmatrix} \\frac{1}{3}\\Delta t^3 & \\frac{1}{2}\\Delta t^2 \\\\ \\frac{1}{2}\\Delta t^2 & \\Delta t \\end{bmatrix} \\cdot \\Phi_{B}^2\n",
    "$$\n",
    "The receiver clock is modelled as a two-state random walk process. It is defined by $S_f$, the white noise spectral density leading to a random walk clock bias error, and $S_g$, the white noise spectral density leading to a random walk clock drift error. The clock process noise matrix $Q_{clk}$ is then:\n",
    "$$\n",
    "Q_{clk} = \\begin{bmatrix} S_f \\Delta t +\\frac{S_g\\Delta t^3}{3} & S_g\\frac{\\Delta t^2}{2} \\\\ S_g\\frac{\\Delta t^2}{2} & S_g \\Delta t  \\end{bmatrix}\n",
    "$$\n",
    "\n",
    "So the final process noise matrix $Q$ is block diagonal:\n",
    "$$\n",
    "Q = \\begin{bmatrix} Q_p & 0 & 0 & 0 \\\\ 0 & Q_p & 0 & 0 \\\\ 0 & 0 & Q_p & 0 \\\\ 0 & 0 & 0 & Q_{clk} \\end{bmatrix}\n",
    "$$"
   ]
  },
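  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The closed-form blocks above can be checked numerically. The sketch below evaluates $\\int_0^{\\Delta t} A(\\tau) Q_c A^T(\\tau) d\\tau$ with a trapezoidal rule and compares it with $Q_p$ and with the clock block. It then assembles the full $8 \\times 8$ matrix $Q$. The sketch is independent of the library's `discretized_process_noise_matrix`. The helper names and the spectral densities (`phi_b2`, `s_f`, `s_g`) are hypothetical example values, not the values used in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def a_of(tau):\n",
    "    # Per-axis transition matrix A with dt replaced by tau\n",
    "    return np.array([[1.0, tau], [0.0, 1.0]])\n",
    "\n",
    "\n",
    "def q_axis(dt, phi_b2):\n",
    "    # Closed-form white-noise-acceleration block Q_p for one axis\n",
    "    return phi_b2 * np.array([[dt**3 / 3.0, dt**2 / 2.0], [dt**2 / 2.0, dt]])\n",
    "\n",
    "\n",
    "def q_clock(dt, s_f, s_g):\n",
    "    # Closed-form two-state random walk block for (clock bias, clock drift)\n",
    "    return np.array([[s_f * dt + s_g * dt**3 / 3.0, s_g * dt**2 / 2.0],\n",
    "                     [s_g * dt**2 / 2.0, s_g * dt]])\n",
    "\n",
    "\n",
    "def q_numeric(dt, qc, n=2001):\n",
    "    # Trapezoidal approximation of the integral of A(tau) Qc A(tau)^T over [0, dt]\n",
    "    taus = np.linspace(0.0, dt, n)\n",
    "    vals = np.array([a_of(t) @ qc @ a_of(t).T for t in taus])\n",
    "    step = taus[1] - taus[0]\n",
    "    return step * (vals.sum(axis=0) - 0.5 * (vals[0] + vals[-1]))\n",
    "\n",
    "\n",
    "def assemble_q(dt, phi_b2, s_f, s_g):\n",
    "    # Block-diagonal 8x8 Q = diag(Q_p, Q_p, Q_p, clock block)\n",
    "    q = np.zeros((8, 8))\n",
    "    for k in range(3):\n",
    "        q[2 * k:2 * k + 2, 2 * k:2 * k + 2] = q_axis(dt, phi_b2)\n",
    "    q[6:8, 6:8] = q_clock(dt, s_f, s_g)\n",
    "    return q\n",
    "\n",
    "\n",
    "# Hypothetical spectral densities, not the values used in this notebook\n",
    "sketch_dt, sketch_phi_b2, sketch_s_f, sketch_s_g = 30.0, 0.01, 0.4, 0.02\n",
    "print(np.allclose(q_numeric(sketch_dt, sketch_phi_b2 * np.diag([0.0, 1.0])), q_axis(sketch_dt, sketch_phi_b2)))\n",
    "print(np.allclose(q_numeric(sketch_dt, np.diag([sketch_s_f, sketch_s_g])), q_clock(sketch_dt, sketch_s_f, sketch_s_g)))\n",
    "sketch_Q = assemble_q(sketch_dt, sketch_phi_b2, sketch_s_f, sketch_s_g)\n",
    "print(sketch_Q.round(3))"
   ]
  },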
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Measurement covariance\n",
    "The measurements are assumed to be uncorrelated with a common standard deviation $\\sigma_r$, so the measurement covariance matrix is diagonal:\n",
    "$$\n",
    "R = \\sigma_r^2 I_m = \\begin{bmatrix} \\sigma_r^2 & 0 & \\cdots & 0 \\\\ 0 & \\sigma_r^2 & \\cdots & 0 \\\\ \\vdots & \\vdots & \\ddots & \\vdots \\\\ 0 & 0 & \\cdots & \\sigma_r^2 \\end{bmatrix}\n",
    "$$\n",
    "where $m$ is the number of measurements and $\\sigma_r$ is the standard deviation of the range measurements in meters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Measurement function\n",
    "The measurement function $h$ maps the state to the measurement space. For each tracked satellite $i$, the predicted pseudorange is the geometric distance between the receiver and the satellite plus the receiver clock bias:\n",
    "$$\n",
    "h_i(\\mathbf{x}) = \\sqrt {(x - x_i)^2 + (y - y_i)^2 + (z - z_i)^2} + cdt\n",
    "$$\n",
    "where $(x_i, y_i, z_i)$ is the ECEF position of satellite $i$. Since $h$ is nonlinear in the receiver position, a nonlinear filter is required. The EKF handles this by linearizing $h$ about the predicted state through its Jacobian."
   ]
  },
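  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because $h$ is nonlinear, the EKF uses its Jacobian $H = \\partial h / \\partial \\mathbf{x}$. The $i$-th row of $H$ contains the unit line-of-sight vector from satellite $i$ to the receiver in the $x, y, z$ columns and a 1 in the $cdt$ column. The NumPy sketch below implements $h$ and this Jacobian for the 8-dimensional state, then checks the Jacobian against central finite differences. It is not the library's `ObservationModel`. The function names and the receiver and satellite coordinates are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def pseudorange(state, sv_pos):\n",
    "    # state: [x, vx, y, vy, z, vz, cdt, cdt_dot]; sv_pos: (m, 3) ECEF satellite positions in meters\n",
    "    rx = state[[0, 2, 4]]\n",
    "    return np.linalg.norm(sv_pos - rx, axis=1) + state[6]\n",
    "\n",
    "\n",
    "def pseudorange_jacobian(state, sv_pos):\n",
    "    # Row i holds the unit vector from satellite i to the receiver, and 1 for the clock bias\n",
    "    rx = state[[0, 2, 4]]\n",
    "    diff = rx - sv_pos\n",
    "    rho = np.linalg.norm(diff, axis=1, keepdims=True)\n",
    "    H = np.zeros((sv_pos.shape[0], state.shape[0]))\n",
    "    H[:, [0, 2, 4]] = diff / rho\n",
    "    H[:, 6] = 1.0\n",
    "    return H\n",
    "\n",
    "\n",
    "def numerical_jacobian(fun, state, eps=1.0):\n",
    "    # Central differences, one state component at a time\n",
    "    cols = []\n",
    "    for j in range(state.shape[0]):\n",
    "        step = np.zeros_like(state)\n",
    "        step[j] = eps\n",
    "        cols.append((fun(state + step) - fun(state - step)) / (2.0 * eps))\n",
    "    return np.stack(cols, axis=1)\n",
    "\n",
    "\n",
    "# Hypothetical receiver state and satellite positions (ECEF, meters)\n",
    "sketch_rx = np.array([0.0, 0.0, 0.0, 0.0, 6.371e6, 0.0, 100.0, 0.0])\n",
    "sketch_sv = np.array([\n",
    "    [0.0, 0.0, 2.6e7],\n",
    "    [1.5e7, 0.0, 2.1e7],\n",
    "    [-7.5e6, 1.3e7, 2.1e7],\n",
    "    [-7.5e6, -1.3e7, 2.1e7],\n",
    "])\n",
    "sketch_H = pseudorange_jacobian(sketch_rx, sketch_sv)\n",
    "sketch_H_num = numerical_jacobian(lambda s: pseudorange(s, sketch_sv), sketch_rx)\n",
    "print(pseudorange(sketch_rx, sketch_sv))\n",
    "print(np.abs(sketch_H - sketch_H_num).max())"
   ]
  },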
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing 2022-11-19 23:59:30: 100%|██████████| 2880/2880 [00:05<00:00, 488.79it/s]\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "from navigator.epoch import Epoch, from_rinex_files, EpochCollection\n",
    "from pathlib import Path\n",
    "\n",
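    "# Daily observation (30 s sampling) and GPS navigation RINEX files for IGS station JPLM00USA, 2022 day 323\n",
    "obsPath = Path(\"data/JPLM00USA_R_20223230000_01D_30S_MO.crx.gz\")\n",
    "navPath = Path(\"data/JPLM00USA_R_20223230000_01D_GN.rnx.gz\")\n",
    "\n",
    "# Parse the files into a list of epochs using the Epoch.DUAL preprocessing profile\n",
    "epoches = list(from_rinex_files(\n",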
    "    observation_file=obsPath,\n",
    "    navigation_file=navPath,\n",
    "    profile=Epoch.DUAL\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since these epochs form a continuous time series, we wrap the list of `Epoch` objects in an `EpochCollection`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection = EpochCollection(\n",
    "    epochs=epoches,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
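    "To run the EKF, the tracked satellites must be continuously visible, so that the measurement vector has the same size and satellite ordering at every epoch. We therefore group the epochs by satellite visibility with `collection.track(mode=EpochCollection.VISIBILITY)` and keep the first continuous segment, `ctg0`."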
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ctg0 = collection.track(mode=EpochCollection.VISIBILITY)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
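    "Let's set up the WLS, EKF and PEKF triangulation algorithms for comparison. First, we convert the epochs into time series of measurements `z` and satellite positions `sv_pos` for the satellites in `ctg0[0].common_sv`, and convert them to `torch` tensors."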
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from navigator.core import IterativeTriangulationInterface\n",
    "import torch\n",
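    "CODE_ONLY = False  # When False, carrier-phase measurements are used in addition to code measurements\n",
    "# Stack the epochs into time series of measurements and satellite positions\n",
    "z, sv_pos = IterativeTriangulationInterface.epoches_to_timeseries(\n",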
    "    epoches=ctg0,\n",
    "    sv_filter=ctg0[0].common_sv,\n",
    "    code_only=CODE_ONLY\n",
    ")\n",
    "\n",
    "# Convert to torch tensor\n",
    "z = torch.tensor(z, dtype=torch.float64)\n",
    "sv_pos = torch.tensor(sv_pos, dtype=torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Triangulating: 100%|██████████| 171/171 [00:21<00:00,  8.14it/s]\n"
     ]
    }
   ],
   "source": [
    "from navigator.core import ExtendedKalmanInterface, Triangulate, IterativeTriangulationInterface\n",
    "\n",
    "# Define the iterative weighted least squares (WLS) triangulation baseline to compare with the EKF and PEKF\n",
    "tri_wls = Triangulate(\n",
    "    interface=IterativeTriangulationInterface(\n",
    "        code_only=CODE_ONLY\n",
    "    )\n",
    ")\n",
    "\n",
    "df_wls = tri_wls.triangulate_time_series(epoches=ctg0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define the state and measurement dimensions, the initial state and covariance, and the noise matrices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from navigator.neural import (\n",
    "    ParametricExtendedInterface,\n",
    "    SymmetricPositiveDefiniteMatrix,\n",
    "    DiagonalSymmetricPositiveDefiniteMatrix,\n",
    "    TransitionModel,\n",
    "    ObservationModel,\n",
    "    discretized_process_noise_matrix,\n",
    "    BiasedObservationModel)\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Define the state dimension and measurement dimension\n",
    "DIM_STATE = 8\n",
    "# One measurement per satellite in the first epoch, or two (code and phase) when CODE_ONLY is False\n",
    "DIM_MEASUREMENT = len(ctg0[0]) if CODE_ONLY else 2 * len(ctg0[0])\n",
    "DT = 30.0  # Sampling interval of the data in seconds\n",
    "# Define the initial state and a large initial covariance\n",
    "x0 = torch.zeros(DIM_STATE)\n",
    "P0 = torch.eye(DIM_STATE) * 1e5\n",
    "\n",
    "# Define the process noise (Q_AUT discretized over DT) and the measurement noise\n",
    "Q_AUT = torch.eye(DIM_STATE) * 1e-6\n",
    "Q = discretized_process_noise_matrix(dt=DT, Q_0=Q_AUT)\n",
    "R = torch.eye(DIM_MEASUREMENT) * 64  # Variance of 64, i.e. sigma_r = 8 m, for every measurement\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up the classical EKF first and compute its fixed-interval smoothed state estimates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Extended Kalman Interface\n",
    "eki = ExtendedKalmanInterface(\n",
    "    dt=DT, # This data is in 30 s intervals\n",
    "    num_sv=DIM_MEASUREMENT, # Measurement dimension per epoch\n",
    "    x0=x0.numpy(), # Prior state\n",
    "    P0=P0.numpy(), # Prior covariance\n",
    "    R=R.numpy(), # Measurement noise\n",
    "    Q_0=Q_AUT.numpy(), # Process noise\n",
    "    log_innovation=True, # Log the innovation\n",
    "    code_only=True,  # num_sv normally differs from the measurement dimension when phase is used; DIM_MEASUREMENT already accounts for that\n",
    ")\n",
    "\n",
    "ekf_state, ekf_resudials = eki.fixed_interval_smoothing(\n",
    "    x0=x0.numpy(),\n",
    "    P0=P0.numpy(),\n",
    "    z=z.numpy(),\n",
    "    sv_pos=sv_pos.numpy()\n",
    ")"
   ]
  },
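  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the sketch below shows one textbook EKF predict/update cycle for the state and measurement models described at the top of this notebook. It is written in plain NumPy and run on simulated pseudoranges from a static receiver. It is a simplified, self-contained illustration, not the implementation behind `ExtendedKalmanInterface`, which additionally performs fixed-interval smoothing. The function `ekf_step`, the satellite geometry and all noise values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def ekf_step(x, P, z, sv_pos, F, Q, R):\n",
    "    # Predict with the linear constant velocity model\n",
    "    x_pred = F @ x\n",
    "    P_pred = F @ P @ F.T + Q\n",
    "    # Linearize the pseudorange model about the predicted state\n",
    "    rx = x_pred[[0, 2, 4]]\n",
    "    diff = rx - sv_pos\n",
    "    rho = np.linalg.norm(diff, axis=1)\n",
    "    h = rho + x_pred[6]\n",
    "    H = np.zeros((sv_pos.shape[0], x.shape[0]))\n",
    "    H[:, [0, 2, 4]] = diff / rho[:, None]\n",
    "    H[:, 6] = 1.0\n",
    "    # Innovation, innovation covariance and Kalman gain K = P H^T S^-1\n",
    "    y = z - h\n",
    "    S = H @ P_pred @ H.T + R\n",
    "    K = np.linalg.solve(S, H @ P_pred).T\n",
    "    # Joseph-form update keeps the covariance symmetric\n",
    "    x_new = x_pred + K @ y\n",
    "    I_KH = np.eye(x.shape[0]) - K @ H\n",
    "    P_new = I_KH @ P_pred @ I_KH.T + K @ R @ K.T\n",
    "    return x_new, P_new, y, S\n",
    "\n",
    "\n",
    "# Hypothetical static receiver and six satellites (ECEF, meters)\n",
    "sketch_sv = np.array([\n",
    "    [0.0, 0.0, 2.6e7],\n",
    "    [1.5e7, 0.0, 2.1e7],\n",
    "    [-7.5e6, 1.3e7, 2.1e7],\n",
    "    [-7.5e6, -1.3e7, 2.1e7],\n",
    "    [1.5e7, 1.5e7, 1.5e7],\n",
    "    [-2.2e7, 0.0, 1.4e7],\n",
    "])\n",
    "sketch_truth = np.array([0.0, 0.0, 0.0, 0.0, 6.371e6, 0.0, 100.0, 0.0])\n",
    "\n",
    "sketch_dt = 30.0\n",
    "sketch_F = np.kron(np.eye(4), np.array([[1.0, sketch_dt], [0.0, 1.0]]))\n",
    "sketch_Q = 1e-4 * np.kron(np.eye(4), np.array([[sketch_dt**3 / 3.0, sketch_dt**2 / 2.0],\n",
    "                                               [sketch_dt**2 / 2.0, sketch_dt]]))\n",
    "sketch_R = np.eye(len(sketch_sv))  # 1 m measurement noise standard deviation\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x_est = sketch_truth + np.array([80.0, 0.0, -60.0, 0.0, 50.0, 0.0, 30.0, 0.0])\n",
    "P_est = 1e4 * np.eye(8)\n",
    "initial_error = np.linalg.norm((x_est - sketch_truth)[[0, 2, 4]])\n",
    "for _ in range(10):\n",
    "    # Simulated noisy pseudoranges from the true (static) state\n",
    "    rho_true = np.linalg.norm(sketch_sv - sketch_truth[[0, 2, 4]], axis=1)\n",
    "    z_k = rho_true + sketch_truth[6] + rng.normal(0.0, 1.0, len(sketch_sv))\n",
    "    x_est, P_est, innovation, S_k = ekf_step(x_est, P_est, z_k, sketch_sv, sketch_F, sketch_Q, sketch_R)\n",
    "final_error = np.linalg.norm((x_est - sketch_truth)[[0, 2, 4]])\n",
    "print(initial_error, final_error)"
   ]
  },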
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up the PEKF, whose process noise $Q$ and measurement noise $R$ are trainable and estimated from the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
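    "# Define the noise models as trainable symmetric positive definite matrices (R is diagonal)\n",
    "q_module = SymmetricPositiveDefiniteMatrix(M=Q, trainable=True)\n",
    "r_module = DiagonalSymmetricPositiveDefiniteMatrix(M=R, trainable=True)\n",
    "\n",
    "# Define the transition and observation models (kept fixed during tuning)\n",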
    "transistion_model = TransitionModel(\n",
    "    dt=DT,\n",
    "    learnable=False,\n",
    ")\n",
    "observation_model = ObservationModel(\n",
    "    trainable=False,\n",
    "    dim_measurement=DIM_MEASUREMENT,\n",
    ")\n",
    "\n",
    "pekf = ParametricExtendedInterface(\n",
    "    dt=DT,\n",
    "    dim_z=DIM_MEASUREMENT,\n",
    "    f=transistion_model,\n",
    "    h=observation_model,\n",
    "    Q=q_module,\n",
    "    R=r_module,\n",
    "    objective=ParametricExtendedInterface.NEGATIVE_LOG_LIKELIHOOD,\n",
    ").to(torch.float64)\n"
   ]
  },
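  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The PEKF is tuned with the `NEGATIVE_LOG_LIKELIHOOD` objective. For Kalman filters, this objective is commonly the Gaussian negative log-likelihood of the innovations $\\mathbf{y}_k$. With innovation covariances $S_k$ over $m$ measurements, it is summed over epochs:\n",
    "$$\n",
    "\\mathcal{L} = \\frac{1}{2}\\sum_k \\left( \\log\\det S_k + \\mathbf{y}_k^T S_k^{-1}\\mathbf{y}_k + m \\log 2\\pi \\right)\n",
    "$$\n",
    "To keep $Q$ and $R$ valid covariances during gradient descent, they can be parametrized so that they always remain symmetric positive definite, for example through a Cholesky factor with a positive diagonal. The NumPy sketch below illustrates both ideas. It reflects an assumption about how `ParametricExtendedInterface` and `SymmetricPositiveDefiniteMatrix` work, not a description of their implementation. The helper names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def innovation_nll(innovations, covariances):\n",
    "    # Gaussian negative log-likelihood of innovations y_k ~ N(0, S_k), summed over epochs\n",
    "    total = 0.0\n",
    "    for y, S in zip(innovations, covariances):\n",
    "        L = np.linalg.cholesky(S)\n",
    "        # Solve L w = y so that w @ w equals y^T S^-1 y\n",
    "        w = np.linalg.solve(L, y)\n",
    "        log_det = 2.0 * np.sum(np.log(np.diag(L)))\n",
    "        total += 0.5 * (log_det + w @ w + y.shape[0] * np.log(2.0 * np.pi))\n",
    "    return total\n",
    "\n",
    "\n",
    "def spd_from_params(params, dim):\n",
    "    # Map an unconstrained vector to an SPD matrix through a Cholesky factor\n",
    "    L = np.zeros((dim, dim))\n",
    "    L[np.tril_indices(dim)] = params\n",
    "    # Exponentiate the diagonal so that it is strictly positive\n",
    "    L[np.diag_indices(dim)] = np.exp(np.diag(L))\n",
    "    return L @ L.T\n",
    "\n",
    "\n",
    "sketch_m = 3\n",
    "rng = np.random.default_rng(1)\n",
    "sketch_S = spd_from_params(rng.normal(size=sketch_m * (sketch_m + 1) // 2), sketch_m)\n",
    "sketch_y = rng.multivariate_normal(np.zeros(sketch_m), sketch_S, size=5)\n",
    "print(innovation_nll(sketch_y, [sketch_S] * 5))"
   ]
  },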
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
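    "Now let's set up the tuning. The first `TUNING_LENGTH` epochs form the tuning dataset. An Adam optimizer is configured with a separate learning rate for each parameter group, together with an exponential learning rate scheduler whose decay factor `GAMMA` reduces the learning rate by a factor of 100 over `MAX_EPOCHES` epochs."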
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "TUNING_LENGTH = 50\n",
    "MAX_EPOCHES = 100\n",
    "\n",
    "z_tune = z[:TUNING_LENGTH]\n",
    "sv_pos_tune = sv_pos[:TUNING_LENGTH]\n",
    "\n",
    "\n",
    "# Set the learning rate for the optimizer\n",
    "LR_Q = 1e-3\n",
    "LR_R = 1e-3\n",
    "LR_H = 1e-5\n",
    "LR_F = 1e-5\n",
    "\n",
    "GAMMA = (1/100) ** (1/MAX_EPOCHES)\n",
    "\n",
    "params = [\n",
    "    {'params': q_module.parameters(), 'lr': LR_Q}, # Process noise Q\n",
    "    {'params': r_module.parameters(), 'lr': LR_R}, # Measurement noise R\n",
    "    {'params': transistion_model.parameters(), 'lr': LR_F},\n",
    "    {'params': observation_model.parameters(), 'lr': LR_H},\n",
    "]\n",
    "\n",
    "optimizer = optim.Adam(params, lr=1e-3)\n",
    "\n",
    "explrs = optim.lr_scheduler.ExponentialLR(optimizer, gamma=GAMMA)"
   ]
  },
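  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `GAMMA = (1/100) ** (1/MAX_EPOCHES)`, an exponential schedule that multiplies the learning rate by `GAMMA` once per epoch ends at one hundredth of the initial value after `MAX_EPOCHES` epochs. The plain-Python check of this arithmetic below is independent of the optimizer, and its variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning rate after k epochs of exponential decay: lr_k = base_lr * gamma ** k\n",
    "max_epochs = 100\n",
    "base_lr = 1e-3\n",
    "gamma = (1 / 100) ** (1 / max_epochs)\n",
    "lrs = [base_lr * gamma ** k for k in range(max_epochs + 1)]\n",
    "print(gamma)\n",
    "print(lrs[0], lrs[max_epochs // 2], lrs[-1])"
   ]
  },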
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
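    "Let's tune the PEKF on the tuning dataset. Gradients are clipped to a norm of 1000, and no learning rate scheduler is passed to `tune` (`lr_scheduler=None`)."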
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: Loss = 27038360064.2245, Q_grad_norm = 3968671806.196714, R_grad_norm = 73814684.66811484, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 1: Loss = 14107759806.863756, Q_grad_norm = 5427325854.641763, R_grad_norm = 57393183.098413624, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 2: Loss = 11258521373.270433, Q_grad_norm = 1142424715.2717113, R_grad_norm = 52620834.22234754, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 3: Loss = 9990074199.231567, Q_grad_norm = 741224403.505499, R_grad_norm = 45512246.20947275, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 4: Loss = 9379591952.420927, Q_grad_norm = 109591670.42726618, R_grad_norm = 44240649.808141544, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 5: Loss = 9188403904.928392, Q_grad_norm = 34569096.39256566, R_grad_norm = 43809336.85934101, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 6: Loss = 9055900159.286354, Q_grad_norm = 19697835.02483456, R_grad_norm = 43200877.29420602, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 7: Loss = 8919655606.221231, Q_grad_norm = 14468083.481402045, R_grad_norm = 42270765.88222693, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 8: Loss = 8768895808.666277, Q_grad_norm = 11727206.844378693, R_grad_norm = 41053096.55266218, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 9: Loss = 8602960199.455137, Q_grad_norm = 9796163.238875085, R_grad_norm = 39609970.34157556, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 10: Loss = 8423994027.050131, Q_grad_norm = 8266764.492912107, R_grad_norm = 38005464.442452624, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 11: Loss = 8234973347.414517, Q_grad_norm = 7011646.733540471, R_grad_norm = 36298405.17560213, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 12: Loss = 8038969393.125875, Q_grad_norm = 5971698.336797704, R_grad_norm = 34539476.216455996, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 13: Loss = 7838828058.864831, Q_grad_norm = 5107861.629157683, R_grad_norm = 32770222.856757324, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 14: Loss = 7637034406.189409, Q_grad_norm = 4389393.379750947, R_grad_norm = 31023182.700174958, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 15: Loss = 7435669768.402222, Q_grad_norm = 3790829.1217356743, R_grad_norm = 29322672.16482941, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 16: Loss = 7236419522.635701, Q_grad_norm = 3291004.7662997334, R_grad_norm = 27685912.744935658, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 17: Loss = 7040605548.214679, Q_grad_norm = 2872480.9824708966, R_grad_norm = 26124260.038254924, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 18: Loss = 6849230792.23733, Q_grad_norm = 2521010.3039658023, R_grad_norm = 24644397.47555266, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 19: Loss = 6663026643.204123, Q_grad_norm = 2225012.5433521424, R_grad_norm = 23249407.57018774, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 20: Loss = 6482499291.135314, Q_grad_norm = 1975087.8314740174, R_grad_norm = 21939688.336816207, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 21: Loss = 6307971694.342197, Q_grad_norm = 1763590.2305525464, R_grad_norm = 20713705.947683364, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 22: Loss = 6139620828.9814005, Q_grad_norm = 1584271.4400774112, R_grad_norm = 19568595.187601577, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 23: Loss = 5977509226.006213, Q_grad_norm = 1431991.708373743, R_grad_norm = 18500624.100254975, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 24: Loss = 5821611507.736794, Q_grad_norm = 1302490.991483885, R_grad_norm = 17505547.35433697, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 25: Loss = 5671836131.463947, Q_grad_norm = 1192210.5912873638, R_grad_norm = 16578866.094658688, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 26: Loss = 5528042864.119204, Q_grad_norm = 1098155.7019640987, R_grad_norm = 15716015.589517664, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 27: Loss = 5390056926.198083, Q_grad_norm = 1017790.681189338, R_grad_norm = 14912496.66445456, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 28: Loss = 5257680004.818107, Q_grad_norm = 948959.4386432016, R_grad_norm = 14163963.357392084, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 29: Loss = 5130698949.96082, Q_grad_norm = 889824.7074157381, R_grad_norm = 13466279.450027293, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 30: Loss = 5008892402.180165, Q_grad_norm = 838821.1917955762, R_grad_norm = 12815550.821556894, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 31: Loss = 4892035928.739513, Q_grad_norm = 794618.3535918841, R_grad_norm = 12208141.150565067, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 32: Loss = 4779905734.126259, Q_grad_norm = 756089.8999135436, R_grad_norm = 11640675.60164148, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 33: Loss = 4672281542.946666, Q_grad_norm = 722287.7495524728, R_grad_norm = 11110036.430303894, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 34: Loss = 4568948595.152737, Q_grad_norm = 692419.0796441452, R_grad_norm = 10613353.281645585, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 35: Loss = 4469699027.90784, Q_grad_norm = 665825.6094055086, R_grad_norm = 10147990.315745607, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 36: Loss = 4374332823.454965, Q_grad_norm = 641964.8863599034, R_grad_norm = 9711530.982719215, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 37: Loss = 4282658406.7676897, Q_grad_norm = 620393.1109334312, R_grad_norm = 9301762.871843401, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 38: Loss = 4194492864.0040264, Q_grad_norm = 600749.8059972642, R_grad_norm = 8916661.514112946, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 39: Loss = 4109662115.730308, Q_grad_norm = 582743.8619216436, R_grad_norm = 8554375.311739963, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 40: Loss = 4028000779.2145195, Q_grad_norm = 566141.4186487757, R_grad_norm = 8213210.618420382, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 41: Loss = 3949352083.028952, Q_grad_norm = 550755.2202965793, R_grad_norm = 7891618.255915522, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 42: Loss = 3873567550.690697, Q_grad_norm = 536435.3712707654, R_grad_norm = 7588180.349597303, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 43: Loss = 3800506693.08533, Q_grad_norm = 523061.53160780115, R_grad_norm = 7301598.547757609, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 44: Loss = 3730036717.7085285, Q_grad_norm = 510536.46131617343, R_grad_norm = 7030683.222352857, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 45: Loss = 3662032179.813398, Q_grad_norm = 498780.5211108988, R_grad_norm = 6774343.518224294, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 46: Loss = 3596374567.9867682, Q_grad_norm = 487727.37172854337, R_grad_norm = 6531578.189808527, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 47: Loss = 3532951983.156283, Q_grad_norm = 477320.48627165105, R_grad_norm = 6301467.572205095, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 48: Loss = 3471658745.957352, Q_grad_norm = 467510.5226759343, R_grad_norm = 6083165.956096881, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 49: Loss = 3412395102.257381, Q_grad_norm = 458253.33131496806, R_grad_norm = 5875895.106150232, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 50: Loss = 3355066823.8598933, Q_grad_norm = 449508.52772947995, R_grad_norm = 5678938.007223504, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 51: Loss = 3299584947.929803, Q_grad_norm = 441238.587641148, R_grad_norm = 5491633.448285461, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 52: Loss = 3245865405.637994, Q_grad_norm = 433408.24362437584, R_grad_norm = 5313371.064820745, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 53: Loss = 3193828786.1809406, Q_grad_norm = 425984.2457569391, R_grad_norm = 5143586.8833352765, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 54: Loss = 3143400040.1781535, Q_grad_norm = 418935.2201503746, R_grad_norm = 4981759.276420607, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 55: Loss = 3094508189.4978027, Q_grad_norm = 412231.7802174723, R_grad_norm = 4827405.230433204, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 56: Loss = 3047086134.3128977, Q_grad_norm = 405846.57228937553, R_grad_norm = 4680077.18830024, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 57: Loss = 3001070379.112437, Q_grad_norm = 399754.47757433157, R_grad_norm = 4539359.936384947, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 58: Loss = 2956400835.7441974, Q_grad_norm = 393932.6672256611, R_grad_norm = 4404867.94926156, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 59: Loss = 2913020610.6253324, Q_grad_norm = 388360.71802662464, R_grad_norm = 4276242.918153062, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 60: Loss = 2870875814.7719646, Q_grad_norm = 383020.6313337291, R_grad_norm = 4153151.5381404418, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 61: Loss = 2829915381.301838, Q_grad_norm = 377896.7367209831, R_grad_norm = 4035283.4751066156, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 62: Loss = 2790090903.9950395, Q_grad_norm = 372975.66386389325, R_grad_norm = 3922349.5407393924, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 63: Loss = 2751369560.7543535, Q_grad_norm = 368139.7175257958, R_grad_norm = 3814097.984747661, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 64: Loss = 2713775864.312582, Q_grad_norm = 362837.77258403064, R_grad_norm = 3710368.508006495, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 65: Loss = 2677176683.397008, Q_grad_norm = 357813.70642072445, R_grad_norm = 3610798.6361605227, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 66: Loss = 2641533906.6203136, Q_grad_norm = 353048.86314685503, R_grad_norm = 3515170.5281538353, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 67: Loss = 2606811253.7254066, Q_grad_norm = 348526.7985007457, R_grad_norm = 3423280.257893774, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 68: Loss = 2572974161.460669, Q_grad_norm = 344232.9597482138, R_grad_norm = 3334936.735632529, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 69: Loss = 2539989688.7265816, Q_grad_norm = 340154.37452536135, R_grad_norm = 3249960.77005987, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 70: Loss = 2507826447.653602, Q_grad_norm = 336279.36643346597, R_grad_norm = 3168184.25178519, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 71: Loss = 2476454483.768415, Q_grad_norm = 332597.389366387, R_grad_norm = 3089449.2681438685, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 72: Loss = 2445845235.6738324, Q_grad_norm = 329098.8120321471, R_grad_norm = 3013607.469568211, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 73: Loss = 2415971449.2755365, Q_grad_norm = 325774.8064807266, R_grad_norm = 2940519.3384909625, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 74: Loss = 2386819837.927515, Q_grad_norm = 322529.56547849544, R_grad_norm = 2870068.8164014313, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 75: Loss = 2358397147.370074, Q_grad_norm = 319145.5552865029, R_grad_norm = 2802168.9732474466, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 76: Loss = 2330629054.380563, Q_grad_norm = 315966.108735716, R_grad_norm = 2736642.511588301, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 77: Loss = 2303493516.9822807, Q_grad_norm = 312977.3776587466, R_grad_norm = 2673380.6100964085, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 78: Loss = 2276969441.1998596, Q_grad_norm = 310166.93910768616, R_grad_norm = 2612280.5989281763, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 79: Loss = 2251036588.7847815, Q_grad_norm = 307523.6928067692, R_grad_norm = 2553245.468947392, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 80: Loss = 2225675557.848583, Q_grad_norm = 305037.7349573418, R_grad_norm = 2496183.549086136, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 81: Loss = 2200918853.6372986, Q_grad_norm = 302385.84707083297, R_grad_norm = 2441064.440481936, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 82: Loss = 2176693209.2516255, Q_grad_norm = 299909.4981363322, R_grad_norm = 2387743.838112871, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 83: Loss = 2152980789.455838, Q_grad_norm = 297603.5202610562, R_grad_norm = 2336143.470019785, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 84: Loss = 2129765699.3625336, Q_grad_norm = 295456.2184043212, R_grad_norm = 2286190.531728606, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 85: Loss = 2107032648.9468656, Q_grad_norm = 293457.4119986033, R_grad_norm = 2237815.992665105, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 86: Loss = 2084766910.147102, Q_grad_norm = 291598.2478090113, R_grad_norm = 2190954.347540627, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 87: Loss = 2062989946.7840197, Q_grad_norm = 289670.24337831285, R_grad_norm = 2145580.2525428073, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 88: Loss = 2041664724.794162, Q_grad_norm = 287804.2255310878, R_grad_norm = 2101609.5663072676, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 89: Loss = 2020760769.6023047, Q_grad_norm = 286090.6805651594, R_grad_norm = 2058968.2173610283, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 90: Loss = 2000265975.3543732, Q_grad_norm = 284519.3082772663, R_grad_norm = 2017603.5596954026, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 91: Loss = 1980168633.3083196, Q_grad_norm = 283081.1803825348, R_grad_norm = 1977465.4875788363, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 92: Loss = 1960469553.9755368, Q_grad_norm = 281704.76252735365, R_grad_norm = 1938518.2108050925, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 93: Loss = 1941173970.8321483, Q_grad_norm = 280301.3160763746, R_grad_norm = 1900731.6729639717, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 94: Loss = 1922238767.2888656, Q_grad_norm = 279036.5241686112, R_grad_norm = 1864030.3217370787, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 95: Loss = 1903654176.7994034, Q_grad_norm = 277901.20522878424, R_grad_norm = 1828373.4026978405, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 96: Loss = 1885438128.7445185, Q_grad_norm = 276750.5698946697, R_grad_norm = 1793747.9537637352, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 97: Loss = 1867561395.88332, Q_grad_norm = 275681.2066710509, R_grad_norm = 1760097.266519884, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 98: Loss = 1850004074.344493, Q_grad_norm = 274741.0811140182, R_grad_norm = 1727375.4196005904, f_grad_norm = 0.0, h_grad_norm = 0.0\n",
      "Epoch 99: Loss = 1832757908.9235291, Q_grad_norm = 273921.7783331637, R_grad_norm = 1695548.9407187204, f_grad_norm = 0.0, h_grad_norm = 0.0\n"
     ]
    }
   ],
   "source": [
    "output = pekf.tune(\n",
    "    z=z_tune,\n",
    "    x0=x0,\n",
    "    P0=P0,\n",
    "    sv_pos=sv_pos_tune,\n",
    "    lr_scheduler=None,\n",
    "    max_epochs=MAX_EPOCHES,\n",
    "    clip_grad_norm=1000,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot the training loss, i.e. the negative log-likelihood objective, over the tuning epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'Loss vs Epoch')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 800x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the loss\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_theme(\n",
    "    style=\"darkgrid\"\n",
    ")\n",
    "loss_plt , ax = plt.subplots(\n",
    "    figsize=(8, 5)\n",
    ")\n",
    "\n",
    "# Plot the loss\n",
    "ax.plot(output['losses'])\n",
    "\n",
    "# Set the labels\n",
    "ax.set_xlabel(\"Epoch\")\n",
    "ax.set_ylabel(\"Loss\")\n",
    "\n",
    "# Set the title\n",
    "ax.set_title(\"Loss vs Epoch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run the tuned PEKF as a batch smoother over the full time series, with gradient tracking disabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outs = pekf.batch_smoothing(\n",
    "        z=z,\n",
    "        x0=x0,\n",
    "        P0=P0,\n",
    "        sv_pos=sv_pos,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
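    "The plot below shows the autocorrelation function (ACF) of the innovation residuals of the first measurement for the PEKF and the EKF. For a well-tuned filter, the innovations should be approximately white, i.e. their ACF should be close to zero at all non-zero lags."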
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1000x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Import the ACF plotting function\n",
    "from statsmodels.graphics.tsaplots import plot_acf\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "afc_plot , ax = plt.subplots(1, 1, figsize=(10, 5))\n",
    "\n",
    "plot_acf(\n",
    "    x=outs[\"innovation_residual\"][:, 0], ax=ax\n",
    ");\n",
    "plot_acf(\n",
    "    x=ekf_resudials[\"innovation_residual\"][:, 0], ax=ax,\n",
    ");\n",
    "\n",
    "ax.set_title(\"Innovation Residual ACF\");"
   ]
  },
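  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The autocorrelation shown by `plot_acf` can also be computed directly. The sketch below implements the standard (biased) sample ACF in NumPy and applies it to synthetic sequences rather than the filter residuals. One is a white sequence; the other is an AR(1) sequence standing in for the correlated innovations of a mistuned filter. For white innovations, most values at non-zero lags should fall inside the approximate 95% band $\\pm 1.96/\\sqrt{N}$. The helper `sample_acf` and the synthetic signals are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_acf(x, nlags):\n",
    "    # Biased sample autocorrelation, normalized so that lag 0 equals 1\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    x = x - x.mean()\n",
    "    denom = np.dot(x, x)\n",
    "    return np.array([np.dot(x[:len(x) - k], x[k:]) / denom for k in range(nlags + 1)])\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "n = 500\n",
    "white = rng.normal(size=n)\n",
    "# AR(1) sequence with coefficient 0.8, standing in for correlated innovations\n",
    "ar1 = np.zeros(n)\n",
    "for t in range(1, n):\n",
    "    ar1[t] = 0.8 * ar1[t - 1] + rng.normal()\n",
    "\n",
    "band = 1.96 / np.sqrt(n)\n",
    "acf_white = sample_acf(white, 10)\n",
    "acf_ar1 = sample_acf(ar1, 10)\n",
    "print('approximate 95% band:', band)\n",
    "print('white lag-1 ACF:', acf_white[1])\n",
    "print('AR(1) lag-1 ACF:', acf_ar1[1])"
   ]
  },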
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's collect the estimated states of each method into `pandas` DataFrames so that we can compute their errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "pekf_states = pd.DataFrame(outs[\"x_smoothed\"], columns=ParametricExtendedInterface.STATE_NAMES)\n",
    "ekf_states = pd.DataFrame(ekf_state, columns=ParametricExtendedInterface.STATE_NAMES)\n",
    "wls_states = pd.DataFrame(df_wls, columns=ParametricExtendedInterface.STATE_NAMES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since JPLM00USA is an IGS station, its precise coordinates are known and serve as the ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "realCoords = pd.DataFrame(\n",
    "    data=Epoch.IGS_NETWORK.get_xyz(station=\"JPLM00USA\").reshape(1,-1).repeat(len(ctg0), axis=0),\n",
    "    columns=[\"x\", \"y\", \"z\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's calculate the East-North-Up (ENU) position error of each method with respect to the known station coordinates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from navigator.core import Triangulate\n",
    "\n",
    "\n",
    "enuError_PEKF = Triangulate.enu_error(\n",
    "    predicted=pekf_states,\n",
    "    actual=realCoords\n",
    ")\n",
    "enuError_EKF = Triangulate.enu_error(\n",
    "    predicted=ekf_states,\n",
    "    actual=realCoords\n",
    ")\n",
    "\n",
    "enuError_WLS = Triangulate.enu_error(\n",
    "    predicted=wls_states,\n",
    "    actual=realCoords\n",
    ")\n"
   ]
  },
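  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An ENU error is the ECEF difference between an estimated and the true position, rotated into the local East-North-Up frame at the true position. The NumPy sketch below shows this standard conversion on the WGS84 ellipsoid, with an iterative ECEF-to-geodetic latitude computation. It illustrates the idea and is not necessarily how `Triangulate.enu_error` is implemented. The function names and the example coordinates are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# WGS84 ellipsoid parameters\n",
    "WGS84_A = 6378137.0\n",
    "WGS84_F = 1.0 / 298.257223563\n",
    "WGS84_E2 = WGS84_F * (2.0 - WGS84_F)\n",
    "\n",
    "\n",
    "def ecef_to_lat_lon(xyz, iterations=10):\n",
    "    # Iterative geodetic latitude and longitude (radians) from ECEF coordinates (meters)\n",
    "    x, y, z = xyz\n",
    "    lon = np.arctan2(y, x)\n",
    "    p = np.hypot(x, y)\n",
    "    lat = np.arctan2(z, p * (1.0 - WGS84_E2))\n",
    "    for _ in range(iterations):\n",
    "        n = WGS84_A / np.sqrt(1.0 - WGS84_E2 * np.sin(lat) ** 2)\n",
    "        h = p / np.cos(lat) - n\n",
    "        lat = np.arctan2(z, p * (1.0 - WGS84_E2 * n / (n + h)))\n",
    "    return lat, lon\n",
    "\n",
    "\n",
    "def enu_rotation(lat, lon):\n",
    "    # Rows are the local East, North and Up unit vectors expressed in ECEF\n",
    "    sl, cl = np.sin(lat), np.cos(lat)\n",
    "    so, co = np.sin(lon), np.cos(lon)\n",
    "    return np.array([\n",
    "        [-so, co, 0.0],\n",
    "        [-sl * co, -sl * so, cl],\n",
    "        [cl * co, cl * so, sl],\n",
    "    ])\n",
    "\n",
    "\n",
    "def enu_error_sketch(predicted_xyz, true_xyz):\n",
    "    # Rotate ECEF error vectors (one per row) into the ENU frame at the true position\n",
    "    true_xyz = np.asarray(true_xyz, dtype=float)\n",
    "    lat, lon = ecef_to_lat_lon(true_xyz)\n",
    "    return (np.asarray(predicted_xyz, dtype=float) - true_xyz) @ enu_rotation(lat, lon).T\n",
    "\n",
    "\n",
    "# Hypothetical true position and two estimates offset by a few meters in ECEF\n",
    "sketch_true = np.array([-2493304.0, -4655215.0, 3565497.0])\n",
    "sketch_pred = sketch_true + np.array([[1.0, 0.0, 0.0], [0.0, 2.0, -3.0]])\n",
    "print(enu_error_sketch(sketch_pred, sketch_true))"
   ]
  },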
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The boxen plot below shows the ENU error by method, excluding the first 20 epochs (the initial convergence period):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 800x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = Triangulate.enu_boxen_plot(\n",
    "    error_dict={\n",
    "        \"PEKF\": enuError_PEKF[20:],\n",
    "        \"EKF\": enuError_EKF[20:],\n",
    "        \"WLS\": enuError_WLS[20:]\n",
    "    }\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's calculate the total (3D) position error of each method, i.e. the Euclidean norm of its ENU error at each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "totalError = pd.DataFrame(\n",
    "{\n",
    "    \"PEKF\": enuError_PEKF.apply(np.linalg.norm, axis=1),\n",
    "    \"EKF\": enuError_EKF.apply(np.linalg.norm, axis=1),\n",
    "    \"WLS\": enuError_WLS.apply(np.linalg.norm, axis=1)\n",
    "}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot the distribution of the total error, excluding the first 30 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1000x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the total error\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 5))\n",
    "sns.set_style(\"darkgrid\")\n",
    "sns.boxenplot(data=totalError[30:], ax=ax)\n",
    "\n",
    "# Set the labels\n",
    "ax.set_xlabel(\"Method\")\n",
    "ax.set_ylabel(\"Total Error (m)\")\n",
    "\n",
    "# Set the title\n",
    "ax.set_title(\"Positioning Error Comparison between PEKF, EKF and WLS\")\n",
    "\n",
    "# Save the figure\n",
    "# fig.savefig(\"total_error_comparison.png\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "navigator-96XIUzCO-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}