Optimization Modules
This section documents the JIT-compiled Gauss–Newton solvers and wrappers used inside DSG-JIT.
optimization.solvers
Nonlinear optimization solvers for DSG-JIT.
This module implements the core iterative solvers used throughout the system, with a focus on JAX-friendly, JIT-compilable routines that operate on flat state vectors and manifold-aware blocks (e.g., SE(3) poses).
The solvers are designed to work with residual functions produced by
core.factor_graph.FactorGraph, and are used in:
• Pure SE3 SLAM chains
• Voxel grid smoothness / observation problems
• Hybrid SE3 + voxel joint optimization
• Differentiable experiments where measurements or weights are learned
Key Concepts
GNConfig
Dataclass holding configuration for Gauss–Newton:
• max_iters: maximum number of GN iterations
• damping: Levenberg–Marquardt-style damping
• max_step_norm: optional clamp on update step size
• verbose / debug flags (if enabled)
gauss_newton(residual_fn, x0, cfg)
Classic Gauss–Newton on a flat Euclidean state:
• residual_fn: r(x) -> (m,) JAX array
• x0: initial state
• cfg: GNConfig
Computes updates using normal equations:
Jᵀ J Δx = -Jᵀ r
and returns the optimized state.
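The loop described above can be sketched in a few lines of JAX. This is an illustrative, self-contained version of the documented algorithm (normal equations with diagonal damping), not the actual source of `gauss_newton`; the helper name and argument defaults are assumptions.

```python
import jax
import jax.numpy as jnp

def gauss_newton_sketch(residual_fn, x0, max_iters=10, damping=1e-6):
    """Minimize ||r(x)||^2 via damped normal equations (illustrative sketch)."""
    def step(x, _):
        r = residual_fn(x)                           # residual, shape (m,)
        J = jax.jacfwd(residual_fn)(x)               # Jacobian, shape (m, n)
        H = J.T @ J + damping * jnp.eye(x.shape[0])  # damped Gauss-Newton Hessian
        dx = jnp.linalg.solve(H, -J.T @ r)           # solve J^T J Δx = -J^T r
        return x + dx, None
    x_opt, _ = jax.lax.scan(step, x0, None, length=max_iters)
    return x_opt

# Usage: drive r(x) = x^2 - 2 to zero, i.e. x converges to sqrt(2).
r = lambda x: x**2 - 2.0
x = gauss_newton_sketch(r, jnp.array([1.0]))
```

Because the body uses only `lax.scan` and linear algebra primitives, the whole solve can itself be wrapped in `jax.jit` or differentiated with `jax.grad`.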
gauss_newton_manifold(residual_fn, x0, block_slices, manifold_types, cfg)
Manifold-aware Gauss–Newton:
• residual_fn: r(x) -> (m,)
• x0: initial flat state vector
• block_slices: NodeId -> slice in x
• manifold_types: NodeId -> {"se3", "euclidean", ...}
• cfg: GNConfig
For SE3 blocks:
• The update is computed in the tangent space (se(3))
• Applied via retract / exponential map
• Ensures updates stay on the manifold
For Euclidean blocks:
• Updates are applied additively.
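The block-wise update rule above can be sketched generically. Here the SE(3) retraction is stood in by an arbitrary `retract` callable (in DSG-JIT this role is played by `se3_retract_left`); the block layout, the `"unit"` manifold label, and the renormalization retraction in the usage example are purely illustrative assumptions.

```python
import jax.numpy as jnp

def apply_update_blockwise(x, dx, block_slices, manifold_types, retractions):
    """Apply a flat solver step dx block by block, respecting each manifold."""
    x_new = x
    for node_id, sl in block_slices.items():
        kind = manifold_types[node_id]
        if kind == "euclidean":
            x_new = x_new.at[sl].set(x[sl] + dx[sl])  # additive update
        else:
            # manifold update: retract the tangent-space step onto the manifold
            x_new = x_new.at[sl].set(retractions[kind](x[sl], dx[sl]))
    return x_new

# Toy usage: one Euclidean block plus one unit-norm block whose retraction is
# renormalization (a stand-in for the SE(3) exponential map / retract).
renorm = lambda v, dv: (v + dv) / jnp.linalg.norm(v + dv)
x  = jnp.array([0.0, 0.0, 1.0, 0.0])
dx = jnp.array([1.0, 2.0, 0.0, 1.0])
slices = {0: slice(0, 2), 1: slice(2, 4)}
kinds  = {0: "euclidean", 1: "unit"}
x_up = apply_update_blockwise(x, dx, slices, kinds, {"unit": renorm})
```

The dict iteration happens in Python at trace time, so the pattern remains JIT-compatible as long as the block structure is fixed.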
Design Goals
• Fully JAX-compatible: All heavy operations are written in terms of JAX primitives so that solvers can be JIT-compiled and differentiated through when needed.
• Stable and controlled: Optional damping and step-norm clamping help avoid NaNs and divergence in difficult configurations (e.g., bad initialization or large residuals).
• Reusable:
Experiments and higher-level training loops (e.g., in experiments/
and optimization/jit_wrappers.py) call into these solvers as the
core iterative engine for DSG-JIT.
Notes
These solvers are intentionally minimal and generic. They do not know anything about SE3 or voxels directly; instead, they rely on the factor graph and manifold metadata to interpret the state vector correctly.
If you add new manifold types (e.g., quaternions or higher-dimensional poses), extend the manifold handling logic in the manifold-aware solver.
damped_newton(objective, x0, cfg)
Damped Newton optimizer for small problems.
Uses a Levenberg–Marquardt-style update::
(H + λ I) δ = ∇f(x)
x_{k+1} = x_k - δ
where H is the Hessian of the objective and λ is a damping factor.
:param objective: Objective function f(x) that maps a state vector to a scalar loss.
:type objective: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial state vector.
:type x0: jnp.ndarray
:param cfg: Newton solver configuration (number of iterations and damping).
:type cfg: NewtonConfig
:return: Optimized state vector after damped Newton iterations.
:rtype: jnp.ndarray
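The update above can be sketched directly with `jax.grad` and `jax.hessian`. This mirrors the documented behavior but is not the actual implementation; the `NewtonConfig` fields (iterations, damping) are mirrored as plain keyword arguments for illustration.

```python
import jax
import jax.numpy as jnp

def damped_newton_sketch(objective, x0, max_iters=20, damping=1e-3):
    """Damped Newton: solve (H + λI) δ = ∇f(x), then x <- x - δ."""
    grad_fn = jax.grad(objective)
    hess_fn = jax.hessian(objective)
    x = x0
    for _ in range(max_iters):
        H = hess_fn(x) + damping * jnp.eye(x.shape[0])  # damped Hessian
        delta = jnp.linalg.solve(H, grad_fn(x))         # (H + λI) δ = ∇f(x)
        x = x - delta
    return x

# Usage: minimize a quadratic bowl with minimum at (1, -2).
f = lambda x: (x[0] - 1.0)**2 + 2.0 * (x[1] + 2.0)**2
x_opt = damped_newton_sketch(f, jnp.zeros(2))
```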
Source code in dsg-jit/dsg_jit/optimization/solvers.py
gauss_newton(residual_fn, x0, cfg)
Gauss–Newton on a residual function r(x): R^n -> R^m.
The algorithm forms the normal equations::
J^T J \delta = J^T r
x_{k+1} = x_k - \delta
with optional diagonal damping and step-size clamping for stability.
:param residual_fn: Residual function r(x) returning a 1D array of shape (m,).
:type residual_fn: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial state vector of shape (n,).
:type x0: jnp.ndarray
:param cfg: Gauss–Newton configuration (iterations, damping, step-norm clamp).
:type cfg: GNConfig
:return: Optimized state vector after Gauss–Newton iterations.
:rtype: jnp.ndarray
Source code in dsg-jit/dsg_jit/optimization/solvers.py
gauss_newton_manifold(residual_fn, x0, block_slices, manifold_types, cfg)
Manifold-aware Gauss–Newton solver.
This variant still solves in a flat parameter space, but applies updates block-wise using the appropriate manifold retraction. In particular:
- Blocks marked as "se3" are updated via se3_retract_left in the Lie algebra se(3).
- Blocks marked as "euclidean" are updated additively.
:param residual_fn: Residual function r(x) returning a 1D array of shape (m,).
:type residual_fn: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial flat state vector of shape (n,).
:type x0: jnp.ndarray
:param block_slices: Mapping from node identifier to slice in x defining that variable's block.
May be a dict or a sequence of (node_id, slice) pairs.
:type block_slices: Union[Mapping[Any, slice], Sequence[Tuple[Any, slice]]]
:param manifold_types: Mapping from node identifier to manifold label (e.g. "se3" or "euclidean").
May be a dict or a sequence of (node_id, manifold_type) pairs.
:type manifold_types: Union[Mapping[Any, str], Sequence[Tuple[Any, str]]]
:param cfg: Gauss–Newton configuration (iterations, damping, step-norm clamp).
:type cfg: GNConfig
:return: Optimized state vector after manifold-aware Gauss–Newton iterations.
:rtype: jnp.ndarray
Source code in dsg-jit/dsg_jit/optimization/solvers.py
gradient_descent(objective, x0, cfg)
Simple gradient descent optimizer.
Performs iterative updates of the form::
x_{k+1} = x_k - learning_rate * ∇f(x_k)
until ``max_iters`` is reached.
:param objective: Objective function ``f(x)`` that maps a state vector to a scalar loss.
:type objective: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial state vector.
:type x0: jnp.ndarray
:param cfg: Gradient-descent configuration (learning rate and number of iterations).
:type cfg: GDConfig
:return: Optimized state vector after gradient descent.
:rtype: jnp.ndarray
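A minimal sketch of this update, written with `lax.scan` so it matches the module's JIT-friendly style. The `GDConfig` fields (learning rate, iteration count) are mirrored as plain arguments; this is an illustration, not the actual source.

```python
import jax
import jax.numpy as jnp

def gradient_descent_sketch(objective, x0, learning_rate=0.1, max_iters=100):
    """Plain gradient descent: x <- x - lr * grad f(x)."""
    grad_fn = jax.grad(objective)
    def step(x, _):
        return x - learning_rate * grad_fn(x), None
    x_opt, _ = jax.lax.scan(step, x0, None, length=max_iters)
    return x_opt

# Usage: minimize sum((x - 3)^2); the iterates converge to (3, 3).
f = lambda x: jnp.sum((x - 3.0)**2)
x_opt = gradient_descent_sketch(f, jnp.zeros(2))
```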
Source code in dsg-jit/dsg_jit/optimization/solvers.py
optimization.jit_wrappers
JIT-friendly optimization wrappers and training utilities for DSG-JIT.
This module provides higher-level utilities that sit on top of the core
solvers in optimization.solvers. They are responsible for:
• Building JIT-compiled solve functions for a fixed world model-backed
factor graph
• Wrapping Gauss–Newton in a functional interface (solve(x0) -> x_opt)
• Supporting differentiable inner loops for meta-learning experiments
• Implementing simple trainer-style loops used in Phase 4 experiments
Typical Usage
The experiments in experiments/ use this module to:
• Construct a `WorldModel`-backed factor graph (SE3, voxels, hybrid)
• Get a JIT-compiled residual or objective from the world model
(e.g., via :meth:`WorldModel.build_residual`, which internally groups
factors by type and shape and uses :func:`jax.vmap` for efficiency)
• Build a `solve_once(x0)` function using Gauss–Newton
• Use `jax.grad` or `jax.value_and_grad` over an outer loss that depends
on the optimized state
Example patterns include:
• Learning SE3 odometry measurements by backpropagating through the
inner Gauss–Newton solve
• Learning voxel observation points that make a grid consistent with
known ground-truth centers
• Learning factor-type weights (log-scales) for odometry vs. observations
via supervised losses on final poses/voxels
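The "backprop through the inner solve" pattern behind these experiments can be sketched generically: a learnable measurement `theta` parameterizes the residual, an inner Gauss–Newton-style solve produces `x_opt(theta)`, and `jax.grad` differentiates an outer supervised loss through the whole solve. All names and the toy two-variable chain below are illustrative assumptions, not the DSG-JIT API.

```python
import jax
import jax.numpy as jnp

def inner_solve(theta, x0, iters=10):
    """Tiny GN solve: a prior on x[0] plus an 'odometry' measurement theta."""
    def residual(x):
        return jnp.stack([x[0] - 0.0, (x[1] - x[0]) - theta])
    def step(x, _):
        J = jax.jacfwd(residual)(x)
        dx = jnp.linalg.solve(J.T @ J + 1e-9 * jnp.eye(2), -J.T @ residual(x))
        return x + dx, None
    x_opt, _ = jax.lax.scan(step, x0, None, length=iters)
    return x_opt

def outer_loss(theta):
    x_opt = inner_solve(theta, jnp.zeros(2))
    return (x_opt[1] - 1.0) ** 2      # supervise the final pose toward 1.0

g = jax.grad(outer_loss)(0.5)         # gradient w.r.t. the learned measurement
```

Because every step of the inner solve is a JAX primitive, the gradient flows through the linear solves without any custom adjoint code.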
Key Utilities (typical contents)
build_jit_gauss_newton(...)
Given a WorldModel and a GNConfig, returns a JIT-compiled function:
solve_once(x0) -> x_opt
build_param_residual(...)
Wraps a residual function so that it depends both on the state x and
on learnable parameters theta (e.g., measurements, observation points).
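The wrapping described above amounts to a functional closure: a residual of both state and parameters, `r(x, theta)`, is partially applied so the solver sees an ordinary `r(x)`. The factory name and signature below are a sketch of the pattern, not the exact `build_param_residual` API.

```python
import jax.numpy as jnp

def build_param_residual_sketch(residual_of_x_and_theta):
    """Return a factory: theta -> residual_fn(x), suitable for jitted solvers."""
    def bind(theta):
        return lambda x: residual_of_x_and_theta(x, theta)
    return bind

# Usage: theta is a learnable measurement appended to a fixed prior residual.
r_xt = lambda x, theta: jnp.concatenate([x - 1.0, x - theta])
bind = build_param_residual_sketch(r_xt)
r_fixed = bind(jnp.array([2.0]))   # residual_fn(x) with theta baked in
out = r_fixed(jnp.array([1.5]))    # stacked residual [x - 1, x - theta]
```

Keeping `theta` as an explicit argument to the factory (rather than hiding it in object state) is what lets `jax.grad` differentiate the whole pipeline with respect to the parameters.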
DSGTrainer (if present)
A lightweight helper class implementing:
• inner_solve(theta): run Gauss–Newton or GD on the graph
• loss(theta): compute a supervised loss on the optimized state
• step(theta): one gradient step on theta
Design Goals
• Separate concerns:
The low-level solver logic lives in solvers.py, while experiment-
specific JIT wiring and training loops live here.
• Encourage functional patterns: All wrappers aim to expose pure functions that JAX can JIT and differentiate, avoiding hidden state and side effects.
• Make research experiments easy: This is the layer where new meta-learning or differentiable-graph experiments should be prototyped before they are promoted into a more general API.
Notes
Because these wrappers are tailored to DSG-JIT’s factor graph structure, they assume:
• Residual functions derived from :class:`WorldModel`, e.g.
:meth:`WorldModel.build_residual` and its hyper-parameterized
variants
• State vectors packed/unpacked via the world model’s core graph
machinery (``WorldModel.pack_state`` / ``WorldModel.unpack_state``)
When modifying or extending this module, take care to preserve JIT and grad-friendliness: avoid Python-side mutation inside jitted functions and keep logic purely functional wherever possible.
JittedGN(fn, cfg)
dataclass
JIT-compiled Gauss–Newton solver for a fixed world model-backed factor graph.
Note
This wrapper targets the Euclidean solver :func:`gauss_newton`. For
SE(3)/manifold problems use :class:`JittedGNManifold` instead.
This lightweight wrapper stores a jitted solve function and the configuration used to build it. Typical usage:
residual_fn = wm.build_residual() # vmap-optimized residual
cfg = GNConfig(...)
jgn = JittedGN.from_residual(residual_fn, cfg)
x_opt = jgn(x0)
:param fn: JIT-compiled function that maps an initial state
vector x0 to an optimized state x_opt.
:param cfg: Gauss–Newton configuration used when building
the jitted solver.
__call__(x0)
Run the jitted Gauss–Newton solve on an initial state.
:param x0: Initial flat state vector to optimize.
:return: Optimized state vector after running Gauss–Newton.
Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
from_residual(residual_fn, cfg)
staticmethod
Construct a :class:`JittedGN` from a residual function.
This wraps :func:`gauss_newton` with the provided configuration
and JIT-compiles the resulting solve(x0) function.
:param residual_fn: Residual function r(x) returning the stacked
residual vector for a fixed factor graph.
:param cfg: Gauss–Newton configuration (step limits, damping, etc.).
:return: A :class:`JittedGN` instance whose __call__ method
runs the jitted Gauss–Newton solve.
Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
from_world_model(wm, cfg)
staticmethod
Construct a :class:`JittedGN` directly from a :class:`WorldModel`.
This helper calls :meth:`WorldModel.build_residual` to obtain the
vmap-optimized residual function for the current world, and then
wraps it in a jitted Gauss–Newton solve.
Typical usage::
wm = WorldModel()
# ... add variables, factors, register residuals ...
jgn = JittedGN.from_world_model(wm, GNConfig(max_iters=20))
x0, _ = wm.pack_state()
x_opt = jgn(x0)
:param wm: World model whose factor graph defines the optimization
problem. Its :meth:`build_residual` method is used to
obtain the residual function.
:param cfg: Gauss–Newton configuration (step limits, damping, etc.).
:return: A :class:`JittedGN` instance whose __call__ method
runs the jitted Gauss–Newton solve using the world model’s
vmap-optimized residual.
Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
JittedGNManifold(fn, cfg)
dataclass
JIT-compiled manifold Gauss–Newton solver for a fixed graph.
This wrapper is intended for SLAM-style problems where the packed state vector is a concatenation of manifold variables (e.g., SE(3) poses and R^3 landmarks). It closes over the residual function and manifold metadata and returns a single jitted solve function.
Typical usage::
residual_fn = wm.build_residual()
manifold_types, block_slices = build_manifold_metadata(...)
cfg = GNConfig(max_iters=1)
jgn = JittedGNManifold.from_residual(residual_fn, manifold_types, block_slices, cfg)
x_opt = jgn(x0)
IMPORTANT
To avoid repeated compilation, construct this once and reuse it for
every incremental step. Ensure the shapes/dtypes of x0 and the
residual output remain constant across steps (template mode).
:param fn: JIT-compiled function mapping x0 -> x_opt.
:param cfg: Gauss–Newton configuration.
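The construct-once advice above follows from how `jax.jit` caches compilations per input shape and dtype: a call with a new shape triggers a fresh trace and compile. A minimal, generic illustration (the solve body here is a trivial stand-in, not the actual GN solver):

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def solve_once(x0):
    global trace_count
    trace_count += 1          # Python side effect runs only when JAX retraces
    return x0 * 2.0           # stand-in for the jitted Gauss-Newton solve

solve_once(jnp.zeros(4))      # first call: traces and compiles
solve_once(jnp.ones(4))       # same shape/dtype: cache hit, no retrace
solve_once(jnp.zeros(5))      # new shape: triggers a second compilation
```

This is why "template mode" keeps the packed state and residual shapes constant across incremental steps: every shape change costs a full recompile.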
__call__(x0)
Run the jitted manifold Gauss–Newton solve.
Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
from_residual(residual_fn, manifold_types, block_slices, cfg)
staticmethod
Construct a :class:`JittedGNManifold` from residual + metadata.
:param residual_fn: Residual function r(x).
:param manifold_types: Per-block manifold type strings.
:param block_slices: Per-block slices into the packed vector.
:param cfg: Solver configuration.
:return: A reusable, jitted solver.
Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
from_world_model(wm, manifold_types, block_slices, cfg)
staticmethod
Construct a manifold GN solver directly from a :class:`WorldModel`.
This helper obtains the residual via :meth:`WorldModel.build_residual`.
:param wm: World model.
:param manifold_types: Per-block manifold types.
:param block_slices: Per-block slices.
:param cfg: Solver configuration.
:return: A reusable, jitted solver.
Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py