
SLAM Modules

This section documents the differentiable SE(3) manifold operations, factor residuals, and SLAM-oriented measurement models.


slam.manifold

Manifold utilities for SE(3) and Euclidean variables in DSG-JIT.

This module centralizes the geometric logic needed by manifold-aware optimization routines, in particular:

• SE(3) exponential / logarithm maps
• Retraction and local parameterization for poses
• Jacobian-friendly helpers for composing / inverting SE(3)
• Metadata that maps variable types to their manifold model
  (e.g. "pose_se3" → "se3", "place1d" → "euclidean")

The core idea is that the optimizer (e.g. Gauss–Newton) should work in a local tangent space while the state lives on a manifold (SE(3) for poses, ℝⁿ for Euclidean variables). This module provides:

• Primitive SE(3) operations:
    - `se3_exp`, `se3_log`         (tangent ↔ group)
    - `so3_exp`, `so3_log`         (rotation-only)
    - `relative_pose_se3`          (pose_a⁻¹ ∘ pose_b)
    - `se3_retract`                (pose ⊕ δξ update rule)

• Manifold metadata helpers:
    - `TYPE_TO_MANIFOLD`           (str → {"se3", "euclidean", ...})
    - `get_manifold_for_var_type`
    - `build_manifold_metadata`    (NodeId → slice, manifold type)

Integration with the Optimizer

optimization.solvers.gauss_newton_manifold uses this module to:

1. Split the global state vector into blocks per variable.
2. Decide which update rule to apply:
    - SE(3) retraction for "pose_se3" blocks
    - Plain addition for Euclidean blocks
3. Keep the core solver logic generic while remaining
   numerically stable on curved manifolds.
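Step 2 can be sketched as a small dispatch loop. This sketch assumes the solver already has `block_slices` and `manifold_types` (the two tables built by `build_manifold_metadata`) and a table of per-manifold update rules. `apply_manifold_update`, `euclidean_retract` and the `retractions` table are hypothetical names, not DSG-JIT API. The signature of the real `se3_retract` is not shown on this page, so it is not called here.

```python
from typing import Callable, Dict, Hashable

import jax.numpy as jnp

# A retraction maps (current block value, tangent increment) -> new block value.
Retraction = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]


def euclidean_retract(x_block: jnp.ndarray, dx_block: jnp.ndarray) -> jnp.ndarray:
    """Plain vector addition, the update rule for R^n blocks."""
    return x_block + dx_block


def apply_manifold_update(
    x: jnp.ndarray,
    dx: jnp.ndarray,
    block_slices: Dict[Hashable, slice],
    manifold_types: Dict[Hashable, str],
    retractions: Dict[str, Retraction],
) -> jnp.ndarray:
    """Update each variable block with the rule registered for its manifold.

    Manifold names missing from `retractions` fall back to addition.
    """
    x_new = x
    for nid, sl in block_slices.items():
        retract = retractions.get(manifold_types[nid], euclidean_retract)
        # JAX arrays are immutable, so write the block back functionally.
        x_new = x_new.at[sl].set(retract(x[sl], dx[sl]))
    return x_new


# Usage: a 6D pose block followed by a 1D place block. In DSG-JIT the
# "se3" entry would presumably point at se3_retract. Here only the
# Euclidean rule exists, so every block is updated additively.
x = jnp.zeros(7)
dx = jnp.full(7, 0.1)
block_slices = {"pose0": slice(0, 6), "place0": slice(6, 7)}
manifold_types = {"pose0": "se3", "place0": "euclidean"}
x_next = apply_manifold_update(x, dx, block_slices, manifold_types, {})
```

Keeping the rules in a table means the loop itself never mentions SE(3). This matches the goal of keeping the core solver generic.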

Design Goals

Numerical stability: Use small-angle fallbacks and well-conditioned SE(3) operations to avoid NaNs in optimization and differentiation.

Separation of concerns: The factor graph and residuals should not hard-code SE(3) math; all manifold operations live here, behind a clean API.

JAX-friendly: All functions are written to be compatible with JIT compilation (jax.jit), jax.grad, jax.jvp, and jax.vmap.
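The stability and JAX-friendliness goals can be illustrated with a rotation exponential that has a small-angle fallback. This is a minimal sketch written for this page: `hat` and `so3_exp_sketch` are illustrative names, and the actual `so3_exp` in `slam.manifold` may be implemented differently. The "double `where`" pattern stops the unused branch from ever evaluating `sqrt(0)` or `x / 0`, so both values and gradients stay finite at θ = 0.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def hat(w: jnp.ndarray) -> jnp.ndarray:
    """Skew-symmetric matrix of a 3-vector, so hat(w) @ v == cross(w, v)."""
    wx, wy, wz = w[0], w[1], w[2]
    return jnp.array([[0.0, -wz, wy],
                      [wz, 0.0, -wx],
                      [-wy, wx, 0.0]])


def so3_exp_sketch(w: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    """Rodrigues' formula with a Taylor-series fallback near theta = 0."""
    theta_sq = jnp.dot(w, w)
    small = theta_sq < eps
    # Double where: substitute a safe value before sqrt/division so the
    # branch that is discarded still has finite values and gradients.
    safe_theta_sq = jnp.where(small, 1.0, theta_sq)
    theta = jnp.sqrt(safe_theta_sq)
    a = jnp.where(small, 1.0 - theta_sq / 6.0, jnp.sin(theta) / theta)
    b = jnp.where(small, 0.5 - theta_sq / 24.0, (1.0 - jnp.cos(theta)) / safe_theta_sq)
    K = hat(w)
    return jnp.eye(3) + a * K + b * (K @ K)


so3_exp_jit = jax.jit(so3_exp_sketch)

# The gradient at the identity rotation is finite, with no NaNs.
grad_at_zero = jax.grad(lambda w: so3_exp_sketch(w).sum())(jnp.zeros(3))

# Sanity check against a generic matrix exponential.
w = jnp.array([0.1, -0.2, 0.3])
max_err = jnp.max(jnp.abs(so3_exp_sketch(w) - expm(hat(w))))
```

A naive `jnp.sin(theta) / theta` with `theta = jnp.linalg.norm(w)` typically produces NaN gradients at the identity. Pose optimization starts there often, which is why the fallback is worth the extra lines.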

Notes

This module currently focuses on SE(3) + Euclidean manifolds, but the design allows extending to other manifolds (e.g. SO(2), quaternions, Lie groups for velocities) by:

• Adding new entries to `TYPE_TO_MANIFOLD`
• Implementing the corresponding retract / exp / log primitives
• Extending the manifold-aware solver dispatch if needed
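Following those steps, a new manifold might look like the sketch below. Every name in it is hypothetical: the `"heading2d"` variable type, the `"so2"` manifold tag, `TYPE_TO_MANIFOLD_EXAMPLE` and `so2_retract`. They only show the shape of an extension, not an existing part of DSG-JIT.

```python
import jax.numpy as jnp

# Hypothetical extension table; the real one in DSG-JIT is TYPE_TO_MANIFOLD.
TYPE_TO_MANIFOLD_EXAMPLE = {
    "pose_se3": "se3",
    "place1d": "euclidean",
    "heading2d": "so2",  # hypothetical new variable type
}


def get_manifold_example(var_type: str) -> str:
    # Same fallback rule as get_manifold_for_var_type: unknown -> "euclidean".
    return TYPE_TO_MANIFOLD_EXAMPLE.get(var_type, "euclidean")


def so2_retract(theta: jnp.ndarray, dtheta: jnp.ndarray) -> jnp.ndarray:
    """Add a tangent increment to an angle and wrap the result to [-pi, pi)."""
    return (theta + dtheta + jnp.pi) % (2.0 * jnp.pi) - jnp.pi


wrapped = so2_retract(jnp.array(3.0), jnp.array(0.5))  # 3.5 wraps to 3.5 - 2*pi
```

The retraction is the only piece that really differs from the Euclidean case. Without the wrap, angles would drift outside a single chart and equal headings would compare as unequal.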

build_manifold_metadata(packed_state, fg)

Build manifold metadata for a factor graph.

This function inspects the variables in a `core.factor_graph.FactorGraph` and constructs two lookup tables that are consumed by manifold-aware solvers such as `optimization.solvers.gauss_newton_manifold`:

  • block_slices maps each `core.types.NodeId` to a `slice` in the packed state vector.
  • manifold_types maps each `core.types.NodeId` to a short string describing the manifold model (e.g. "se3" or "euclidean").

The indices produced by `core.factor_graph.FactorGraph.pack_state` may be stored either as slices or as (start, length) tuples; this helper normalizes everything to proper Python slice objects so the solver does not need to handle multiple formats.

:param packed_state: The pair returned by `FactorGraph.pack_state`; only its second element (the per-variable index) is used.
:param fg: The factor graph whose variables should be analyzed to construct manifold metadata.
:return: A tuple (block_slices, manifold_types) where block_slices is a mapping from `NodeId` to `slice` in the flat state vector, and manifold_types is a mapping from `NodeId` to a manifold model name string.

Source code in dsg-jit/dsg_jit/slam/manifold.py
def build_manifold_metadata(
    packed_state: Tuple[jnp.ndarray, dict],
    fg: FactorGraph
) -> Tuple[Dict[NodeId, slice], Dict[NodeId, str]]:
    """Build manifold metadata for a factor graph.

    This function inspects the variables in a :class:`~core.factor_graph.FactorGraph`
    and constructs two lookup tables that are consumed by
    manifold-aware solvers such as
    :func:`optimization.solvers.gauss_newton_manifold`:

    * ``block_slices`` maps each :class:`~core.types.NodeId` to a
      :class:`slice` in the packed state vector.
    * ``manifold_types`` maps each :class:`~core.types.NodeId` to a
      short string describing the manifold model (e.g. ``"se3"`` or
      ``"euclidean"``).

    The indices produced by :meth:`core.factor_graph.FactorGraph.pack_state`
    may be stored either as slices or as ``(start, length)`` tuples;
    this helper normalizes everything to proper Python ``slice``
    objects so the solver does not need to handle multiple formats.

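    :param packed_state: Pair returned by
        :meth:`core.factor_graph.FactorGraph.pack_state`; only its
        second element (the per-variable index) is used here.
    :param fg: The factor graph whose variables should be analyzed to
        construct manifold metadata.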
    :return: A tuple ``(block_slices, manifold_types)`` where
        ``block_slices`` is a mapping from :class:`~core.types.NodeId`
        to :class:`slice` in the flat state vector, and
        ``manifold_types`` is a mapping from :class:`~core.types.NodeId`
        to a manifold model name string.
    """
    _, index = packed_state

    block_slices: Dict[NodeId, slice] = {}
    manifold_types: Dict[NodeId, str] = {}

    for nid, var in fg.variables.items():
        idx = index[nid]

        # Normalize to a slice
        if isinstance(idx, slice):
            sl = idx
        else:
            # assume (start, length)
            start, length = idx
            sl = slice(start, start + length)

        block_slices[nid] = sl

        manifold = get_manifold_for_var_type(var.type)
        manifold_types[nid] = manifold

    return block_slices, manifold_types
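The slice normalization can be exercised on its own. The sketch below repeats only that step on a hand-written index. The node names and the `index` dict are hypothetical stand-ins for the second element of `packed_state`. The resulting slices are then used to pull per-variable blocks out of a flat state vector.

```python
from typing import Dict, Tuple, Union

import jax.numpy as jnp

IndexEntry = Union[slice, Tuple[int, int]]


def normalize_index(idx: IndexEntry) -> slice:
    """Same rule as in build_manifold_metadata: slices pass through,
    (start, length) tuples become slice(start, start + length)."""
    if isinstance(idx, slice):
        return idx
    start, length = idx
    return slice(start, start + length)


# Hypothetical packed index that mixes both formats.
index: Dict[str, IndexEntry] = {"pose0": (0, 6), "place0": slice(6, 7)}
x = jnp.arange(7.0)
blocks = {nid: x[normalize_index(idx)] for nid, idx in index.items()}
```

After normalization, every consumer can index with `x[sl]` directly and never needs to check which format `pack_state` produced.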

get_manifold_for_var_type(var_type)

Return the manifold model name for a given variable type.

This is a thin helper around `TYPE_TO_MANIFOLD` that maps a high-level variable type tag (e.g. "pose_se3", "place1d") to the underlying manifold model used by manifold-aware solvers.

:param var_type: Variable type string, such as "pose_se3", "place1d", "room1d", "landmark3d" or "voxel_cell".
:return: The manifold model name (for example "se3" or "euclidean"). If the type is unknown, "euclidean" is returned by default.

Source code in dsg-jit/dsg_jit/slam/manifold.py
def get_manifold_for_var_type(var_type: str) -> str:
    """Return the manifold model name for a given variable type.

    This is a thin helper around :data:`TYPE_TO_MANIFOLD` that maps a
    high-level variable type tag (e.g. ``"pose_se3"``, ``"place1d"``)
    to the underlying manifold model used by manifold-aware solvers.

    :param var_type: Variable type string, such as ``"pose_se3"``,
        ``"place1d"``, ``"room1d"``, ``"landmark3d"`` or
        ``"voxel_cell"``.
    :return: The manifold model name (for example ``"se3"`` or
        ``"euclidean"``). If the type is unknown, ``"euclidean"`` is
        returned by default.
    """
    return TYPE_TO_MANIFOLD.get(var_type, "euclidean")

slam.measurements

Residual models (measurement factors) for DSG-JIT.

This module defines the measurement-level building blocks used by the factor graph:

• Each function here implements a residual:
      r(x; params) ∈ ℝᵏ
  compatible with JAX differentiation and JIT compilation.

• Factor types in the graph (e.g. "prior", "odom_se3_geodesic",
  "voxel_point_obs") are mapped to these residual functions via
  `FactorGraph.register_residual`.

Broadly, the residuals fall into several families:

1. Priors and Simple Euclidean Factors

• `prior_residual`:
    Generic prior on any variable:
        r = x − target

Useful for:
    - anchoring poses (pose0 ≈ identity)
    - clamping scalar variables (places, rooms, weights, etc.)

2. SE(3) / SLAM-Style Motion Factors

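• `odom_se3_geodesic_residual`:
    SE(3) relative pose constraint based on the group logarithm:
        r = log( meas⁻¹ ∘ (T_i⁻¹ ∘ T_j) )

    Operates on "pose_se3" variables; the residual lives in se(3)
    (a 6D tangent vector). The current implementation, marked
    experimental in its docstring, instead subtracts in the tangent
    space, r = relative_pose_se3(T_i, T_j) − meas. Both forms vanish
    when the estimated relative pose matches the measurement, but they
    generally differ elsewhere.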

• Additive variant:
    - `odom_se3_residual` (aliased as `se3_chain_residual`)
      for simpler experiments where translation and rotation are
      treated additively in ℝ⁶.

These encode frame-to-frame odometry, loop closures, and generic relative pose constraints between SE(3) nodes.
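The overview writes the geodesic residual as a group composition followed by a logarithm. The rotation-only sketch below contrasts that form with plain tangent-space subtraction. SO(3) stands in for full SE(3) to keep it short, and all names are illustrative rather than DSG-JIT API. The two forms agree for rotations about a common axis. They diverge once the rotations do not commute.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def hat(w):
    return jnp.array([[0.0, -w[2], w[1]],
                      [w[2], 0.0, -w[0]],
                      [-w[1], w[0], 0.0]])


def so3_exp(w):
    # Generic matrix exponential of the skew matrix: fine for a demo.
    return expm(hat(w))


def so3_log(R):
    # Valid away from theta = pi; uses the small-angle limit 1/2 near 0.
    cos_theta = jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = jnp.arccos(cos_theta)
    small = theta < 1e-6
    safe_theta = jnp.where(small, 1.0, theta)
    scale = jnp.where(small, 0.5, safe_theta / (2.0 * jnp.sin(safe_theta)))
    vee = jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return scale * vee


def rot_geodesic_residual(w_i, w_j, w_meas):
    """log(R_meas^T (R_i^T R_j)): compose on the group, then take the log."""
    R_rel = so3_exp(w_i).T @ so3_exp(w_j)
    return so3_log(so3_exp(w_meas).T @ R_rel)


def rot_additive_residual(w_i, w_j, w_meas):
    """(w_j - w_i) - w_meas: subtract directly in the tangent space."""
    return (w_j - w_i) - w_meas


w_i = jnp.array([0.3, 0.0, 0.0])
w_j = jnp.array([0.0, 0.4, 0.0])
# Measurement equal to the exact relative rotation R_i^T R_j.
w_meas = so3_log(so3_exp(w_i).T @ so3_exp(w_j))
r_geo = rot_geodesic_residual(w_i, w_j, w_meas)  # approximately zero
r_add = rot_additive_residual(w_i, w_j, w_meas)  # clearly nonzero
```

With a perfect measurement, only the composed form reports zero error. The additive form leaves a spurious residual that comes from the coupling between the x and y rotations.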

3. Landmark and Attachment Factors

• `pose_landmark_relative_residual`:
    Relative position of a landmark with respect to an SE(3) pose,
    enforcing that the landmark expressed in the pose frame matches
    the measurement:
        T_pose⁻¹ ∘ landmark ≈ measurement

• `pose_landmark_bearing_residual`:
    Bearing-only constraint between a pose and a landmark (e.g.,
    enforcing angular consistency between measurement and predicted
    direction).

• `pose_place_attachment_residual`:
    Softly attaches a pose coordinate (e.g. x) to a 1D "place"
    variable, used for 1D topological / metric alignment.

These connect metric states (poses, landmarks, places) into a coherent SLAM + scene-graph representation.

4. Voxel Grid / Volumetric Factors

• `voxel_smoothness_residual`:
    Encourages neighboring voxel centers to form a smooth chain or
    grid. Used to regularize voxel grids representing surfaces or
    1D/2D/3D structures.

• `voxel_point_observation_residual`:
    Ties a voxel cell to an observed point in world coordinates,
    often used for learning voxel positions from point-like
    observations.

These factors are key to the differentiable voxel experiments and the hybrid SE(3) + voxel benchmarks.

5. Weighting and Noise Models

Many residuals apply per-factor weights through a shared helper (others, such as `pose_landmark_relative_residual`, leave weighting to the factor graph):

• `_apply_weight(r, params)`:
    Applies scalar or diagonal weights to a residual, enabling:

        - Hand-tuned noise models (e.g. σ⁻¹)
        - Learnable factor-type weights (via log_scales)
        - Consistent scaling for multi-term objectives

This is what allows the engine to support learnable factor weights in Phase 4 experiments (e.g. learning odom vs. observation trade-offs).
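The helper's exact keys are not documented on this page. The sketch below is therefore a hypothetical stand-in, `apply_weight_sketch`, with assumed `"weight"` and `"log_scale"` keys. It shows how scalar or diagonal weighting and a learnable log-scale could combine, and that the log-scale receives a gradient.

```python
from typing import Dict

import jax
import jax.numpy as jnp


def apply_weight_sketch(r: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Hypothetical stand-in for _apply_weight.

    Assumes params may hold "weight" (scalar or per-component) and/or a
    learnable "log_scale"; the real helper's keys and semantics may differ.
    """
    w = jnp.asarray(params.get("weight", 1.0))
    if "log_scale" in params:
        # exp keeps the effective scale positive while log_scale is unconstrained.
        w = w * jnp.exp(params["log_scale"])
    return w * r  # scalar or elementwise (diagonal) scaling


r = jnp.array([0.5, -1.0, 2.0])


def weighted_cost(log_scale):
    # Squared norm of the weighted residual as a function of the log-scale.
    return jnp.sum(apply_weight_sketch(r, {"log_scale": log_scale}) ** 2)


g = jax.grad(weighted_cost)(0.0)  # d/ds of exp(2s) * ||r||^2 at s = 0
```

Parameterizing through `exp(log_scale)` is a common way to make a positive weight learnable by plain gradient descent.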

Design Goals

Clear factor semantics: Each residual corresponds to a named factor type used throughout tests and experiments, so it’s obvious what each factor is doing.

Differentiable and JIT-friendly: All residuals are written to be compatible with jax.jit and jax.grad, enabling higher-level meta-learning and end-to-end differentiable training loops.

Composable: Residuals do not own the factor graph logic; they simply implement r(x; params). All graph structure, manifold handling, and joint optimization is handled in core.factor_graph, slam.manifold, and optimization.solvers.

Notes

When adding a new factor type:

1. Implement a residual here:
       def my_factor_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray

2. Register it with the factor graph:
       fg.register_residual("my_factor", my_factor_residual)

3. (Optionally) add tests under `tests/` and, if relevant, a
   differentiable experiment under `experiments/`.

This pattern keeps the measurement models centralized and makes the engine easy to extend for new research ideas.
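The sketch below shows step 1 in practice. It defines a hypothetical `pose_height_residual`, which is not part of DSG-JIT, and checks that it traces under `jax.jit` and differentiates with `jax.jacfwd`. The 6D layout with z in the third slot follows the `[tx, ty, tz, ...]` convention stated for `odom_se3_residual`. Registration appears only as a comment because it needs a real `FactorGraph`.

```python
from typing import Dict

import jax
import jax.numpy as jnp


def pose_height_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Hypothetical factor: keep a pose's z-coordinate near a target height.

    Assumes x is a 6D pose laid out as [tx, ty, tz, wx, wy, wz].
    Returns a length-1 vector, like the other scalar residuals.
    """
    return jnp.array([x[2] - params["height"]])


params = {"height": jnp.array(1.5)}
x = jnp.array([0.0, 0.0, 2.0, 0.0, 0.0, 0.0])
r = jax.jit(pose_height_residual)(x, params)     # residual [0.5]
J = jax.jacfwd(pose_height_residual)(x, params)  # Jacobian, shape (1, 6)
# Step 2 would then be:
# fg.register_residual("pose_height", pose_height_residual)
```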

object_at_pose_residual(x, params)

Residual tying a 3D object position to a pose translation.

Interprets x as [pose(6), object(3)] and encourages the object position to coincide with the pose translation plus an optional fixed offset.

:param x: Stacked state block [pose(6), object(3)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing integer fields "pose_dim" and "obj_dim", and optionally "offset" (a 3D vector) and a weight handled by `_apply_weight`.
:type params: dict
:return: 3D residual object - (pose_translation + offset).
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def object_at_pose_residual(x: jnp.ndarray, params: dict) -> jnp.ndarray:
    """
    Residual tying a 3D object position to a pose translation.

    Interprets ``x`` as ``[pose(6), object(3)]`` and encourages the
    object position to coincide with the pose translation plus an
    optional fixed offset.

    :param x: Stacked state block ``[pose(6), object(3)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing integer fields
        ``"pose_dim"`` and ``"obj_dim"``, and optionally ``"offset"``
        (a 3D vector) and a weight handled by :func:`_apply_weight`.
    :type params: dict
    :return: 3D residual ``object - (pose_translation + offset)``.
    :rtype: jnp.ndarray
    """
    pose_dim = int(params["pose_dim"])
    obj_dim = int(params["obj_dim"])

    assert x.shape[0] == pose_dim + obj_dim

    pose = x[:pose_dim]
    obj = x[pose_dim : pose_dim + obj_dim]

    offset = params.get("offset", jnp.zeros(3))
    offset = jnp.asarray(offset).reshape(3,)

    t = pose[:3]  # tx, ty, tz
    r = obj - (t + offset)
    return _apply_weight(r, params)

odom_residual(x, params)

Simple odometry-style residual in Euclidean space.

Interprets x as a concatenation of two poses pose0 and pose1 in R^d and enforces an additive odometry relation

(pose1 - pose0) - measurement = 0.

:param x: Stacked pose vector [pose0, pose1].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "measurement" with the desired relative displacement.
:type params: Dict[str, jnp.ndarray]
:return: Euclidean odometry residual.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def odom_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Simple odometry-style residual in Euclidean space.

    Interprets ``x`` as a concatenation of two poses ``pose0`` and
    ``pose1`` in R^d and enforces an additive odometry relation

    ``(pose1 - pose0) - measurement = 0``.

    :param x: Stacked pose vector ``[pose0, pose1]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"measurement"`` with
        the desired relative displacement.
    :type params: Dict[str, jnp.ndarray]
    :return: Euclidean odometry residual.
    :rtype: jnp.ndarray
    """
    dim = x.shape[0] // 2
    pose0 = x[:dim]
    pose1 = x[dim:]
    meas = params["measurement"]
    return (pose1 - pose0) - meas
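The residual is linear in `x`, so its Jacobian is the constant block `[-I, I]`. Note also that `odom_residual` does not call `_apply_weight`. A quick check with `jax.jacfwd` makes the Jacobian structure visible. It runs on a standalone copy of the same arithmetic, and the name `odom_residual_local` is hypothetical.

```python
import jax
import jax.numpy as jnp


def odom_residual_local(x, params):
    # Same arithmetic as odom_residual above, copied so this runs standalone.
    dim = x.shape[0] // 2
    return (x[dim:] - x[:dim]) - params["measurement"]


x = jnp.array([0.0, 1.0, 2.0, 4.0])  # pose0 = [0, 1], pose1 = [2, 4]
params = {"measurement": jnp.array([2.0, 2.0])}
r = odom_residual_local(x, params)              # [0.0, 1.0]
J = jax.jacfwd(odom_residual_local)(x, params)  # [[-1, 0, 1, 0], [0, -1, 0, 1]]
```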

odom_se3_geodesic_residual(x, params)

Experimental SE(3) geodesic residual using relative_pose_se3.

Interprets x as two 6D poses in se(3) and uses `core.math3d.relative_pose_se3` to compute the estimated relative pose before subtracting the provided measurement.

:param x: Stacked pose vector [pose0(6), pose1(6)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "measurement" with the desired relative pose in se(3), and optionally a weight understood by `_apply_weight`.
:type params: Dict[str, jnp.ndarray]
:return: Geodesic SE(3) odometry residual in se(3).
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def odom_se3_geodesic_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Experimental SE(3) geodesic residual using ``relative_pose_se3``.

    Interprets ``x`` as two 6D poses in se(3) and uses
    :func:`core.math3d.relative_pose_se3` to compute the estimated
    relative pose before subtracting the provided measurement.

    :param x: Stacked pose vector ``[pose0(6), pose1(6)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"measurement"`` with
        the desired relative pose in se(3), and optionally a weight
        understood by :func:`_apply_weight`.
    :type params: Dict[str, jnp.ndarray]
    :return: Geodesic SE(3) odometry residual in se(3).
    :rtype: jnp.ndarray
    """
    assert x.shape[0] == 12, "odom_se3_geodesic_residual expects two 6D poses stacked."

    pose0 = x[:6]
    pose1 = x[6:]
    meas = params["measurement"]

    xi_est = relative_pose_se3(pose0, pose1)
    r = xi_est - meas
    return _apply_weight(r, params)

odom_se3_residual(x, params)

SE(3)-style odometry residual in a 6D vector parameterization.

Treats each pose as a 6-vector [tx, ty, tz, wx, wy, wz] and a 6D measurement in the same parameterization. The residual is

(pose_j - pose_i) - measurement.

This is a simple additive model in R^6 and is used as the workhorse SE(3) chain factor in many experiments.

:param x: Stacked pose vector [pose_i(6), pose_j(6)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "measurement" with the desired relative pose in R^6.
:type params: Dict[str, jnp.ndarray]
:return: SE(3) odometry residual in R^6.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def odom_se3_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    SE(3)-style odometry residual in a 6D vector parameterization.

    Treats each pose as a 6-vector ``[tx, ty, tz, wx, wy, wz]`` and a
    6D measurement in the same parameterization. The residual is

    ``(pose_j - pose_i) - measurement``.

    This is a simple additive model in R^6 and is used as the workhorse
    SE(3) chain factor in many experiments.

    :param x: Stacked pose vector ``[pose_i(6), pose_j(6)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"measurement"`` with
        the desired relative pose in R^6.
    :type params: Dict[str, jnp.ndarray]
    :return: SE(3) odometry residual in R^6.
    :rtype: jnp.ndarray
    """
    # Additive approximation in R^6, not a true SE(3) geodesic residual.
    # TODO: replace with a true Newton solver for pure SE(3).
    dim = x.shape[0] // 2  # should be 6
    pose_i = x[:dim]
    pose_j = x[dim:]
    meas = params["measurement"]
    r = (pose_j - pose_i) - meas
    return r

pose_landmark_bearing_residual(x, params)

Bearing-only residual between a pose and a 3D landmark.

Interprets x as [pose(6), landmark(3)] and compares the predicted bearing from the pose to the landmark against a measured bearing vector.

:param x: Stacked state block [pose(6), landmark(3)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "bearing_meas" (a 3D bearing vector in the pose frame). Any weighting is applied upstream.
:type params: dict
:return: 3D residual bearing_pred - bearing_meas.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def pose_landmark_bearing_residual(
    x: jnp.ndarray,
    params: dict,
) -> jnp.ndarray:
    """
    Bearing-only residual between a pose and a 3D landmark.

    Interprets ``x`` as ``[pose(6), landmark(3)]`` and compares the
    predicted bearing from the pose to the landmark against a measured
    bearing vector.

    :param x: Stacked state block ``[pose(6), landmark(3)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"bearing_meas"``
        (a 3D bearing vector in the pose frame). Any weighting is
        applied upstream.
    :type params: dict
    :return: 3D residual ``bearing_pred - bearing_meas``.
    :rtype: jnp.ndarray
    """
    pose = x[:6]
    landmark = x[6:9]

    bearing_meas = params["bearing_meas"]  # (3,)

    T = se3_exp(pose)
    R = T[:3, :3]
    t = T[:3, 3]

    landmark_pose = R.T @ (landmark - t)

    def safe_normalize(v):
        n = jnp.linalg.norm(v)
        return v / (n + 1e-8)

    bearing_pred = safe_normalize(landmark_pose)
    bearing_meas = safe_normalize(bearing_meas)

    residual = bearing_pred - bearing_meas
    return residual
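Both vectors are normalized, so the residual is the chord between two unit vectors. Its norm is 2·sin(α/2), where α is the angle between the predicted and measured bearings, and the length of the measured bearing does not matter. The standalone sketch below shows both properties. Its names are hypothetical, and it takes the pose as an explicit rotation `R` and translation `t` (what `se3_exp` would produce) so it does not depend on `se3_exp`.

```python
import jax.numpy as jnp


def safe_normalize(v, eps=1e-8):
    # Same guard as in pose_landmark_bearing_residual: never divide by zero.
    return v / (jnp.linalg.norm(v) + eps)


def bearing_residual_local(R, t, landmark, bearing_meas):
    """Bearing residual for a pose given directly as rotation R and translation t."""
    landmark_pose = R.T @ (landmark - t)  # landmark expressed in the pose frame
    return safe_normalize(landmark_pose) - safe_normalize(bearing_meas)


R, t = jnp.eye(3), jnp.zeros(3)
landmark = jnp.array([0.0, 0.0, 5.0])
# The measured bearing need not be unit length: it is normalized internally.
r_same = bearing_residual_local(R, t, landmark, jnp.array([0.0, 0.0, 2.0]))
# A bearing 90 degrees away gives |r| = 2 * sin(45 deg) = sqrt(2).
r_perp = bearing_residual_local(R, t, landmark, jnp.array([1.0, 0.0, 0.0]))
```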

pose_landmark_relative_residual(x, params)

Relative pose–landmark residual in SE(3).

Interprets x as [pose(6), landmark(3)] and enforces that the landmark, expressed in the pose frame, matches a measured 3D point.

:param x: Stacked state block [pose(6), landmark(3)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "measurement" (a 3D point in the pose frame). Any weighting is applied upstream by `_apply_weight`.
:type params: dict
:return: 3D residual between predicted and measured landmark positions in the pose frame.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def pose_landmark_relative_residual(
    x: jnp.ndarray,
    params: dict,
) -> jnp.ndarray:
    """
    Relative pose–landmark residual in SE(3).

    Interprets ``x`` as ``[pose(6), landmark(3)]`` and enforces that the
    landmark, expressed in the pose frame, matches a measured 3D point.

    :param x: Stacked state block ``[pose(6), landmark(3)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"measurement"``
        (a 3D point in the pose frame). Any weighting is applied
        upstream by :func:`_apply_weight`.
    :type params: dict
    :return: 3D residual between predicted and measured landmark
        positions in the pose frame.
    :rtype: jnp.ndarray
    """
    pose = x[:6]
    landmark = x[6:9]

    meas = params["measurement"]  # (3,)
    T = se3_exp(pose)
    R = T[:3, :3]
    t = T[:3, 3]

    # landmark expressed in pose frame
    landmark_pose = R.T @ (landmark - t)

    residual = landmark_pose - meas
    return residual  # weight is applied via _apply_weight in FactorGraph

pose_place_attachment_residual(x, params)

Residual tying a scalar place variable to one coordinate of a pose.

Interprets x as [pose, place] and enforces that the place value tracks a particular coordinate of the pose (e.g., x-position).

:param x: Stacked state block [pose, place].
:type x: jnp.ndarray
:param params: Parameter dictionary with integer entries "pose_dim", "place_dim", and "pose_coord_index" indicating the layout of x and which pose coordinate to attach to. May also contain a weight handled by `_apply_weight`.
:type params: dict
:return: 1D residual enforcing place[0] ≈ pose[pose_coord_index].
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def pose_place_attachment_residual(x: jnp.ndarray, params: dict) -> jnp.ndarray:
    """
    Residual tying a scalar place variable to one coordinate of a pose.

    Interprets ``x`` as ``[pose, place]`` and enforces that the place
    value tracks a particular coordinate of the pose (e.g., x-position).

    :param x: Stacked state block ``[pose, place]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary with integer entries
        ``"pose_dim"``, ``"place_dim"``, and ``"pose_coord_index"``
        indicating the layout of ``x`` and which pose coordinate to
        attach to. May also contain a weight handled by
        :func:`_apply_weight`.
    :type params: dict
    :return: 1D residual enforcing ``place[0] ≈ pose[pose_coord_index]``.
    :rtype: jnp.ndarray
    """
    pose_dim = int(params["pose_dim"])
    place_dim = int(params["place_dim"])
    coord_idx = int(params["pose_coord_index"])

    assert x.shape[0] == pose_dim + place_dim

    pose = x[:pose_dim]
    place = x[pose_dim : pose_dim + place_dim]

    # Make it 1D of length 1, not scalar
    r = jnp.array([place[0] - pose[coord_idx]])
    return _apply_weight(r, params)

pose_temporal_smoothness_residual(x, params)

Temporal smoothness residual between two SE(3) poses.

Interprets x as [pose_t, pose_t1] in R^6 and penalizes the difference pose_t1 - pose_t.

:param x: Stacked state block [pose_t(6), pose_t1(6)].
:type x: jnp.ndarray
:param params: Parameter dictionary, optionally containing a weight handled by `_apply_weight`.
:type params: dict
:return: 6D temporal smoothness residual.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def pose_temporal_smoothness_residual(x: jnp.ndarray, params: dict) -> jnp.ndarray:
    """
    Temporal smoothness residual between two SE(3) poses.

    Interprets ``x`` as ``[pose_t, pose_t1]`` in R^6 and penalizes the
    difference ``pose_t1 - pose_t``.

    :param x: Stacked state block ``[pose_t(6), pose_t1(6)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary, optionally containing a weight
        handled by :func:`_apply_weight`.
    :type params: dict
    :return: 6D temporal smoothness residual.
    :rtype: jnp.ndarray
    """
    dim = x.shape[0] // 2
    pose_t = x[:dim]
    pose_t1 = x[dim:]
    r = pose_t1 - pose_t
    return _apply_weight(r, params)

pose_voxel_point_residual(x, params)

Residual between a pose and a voxel center given a point measurement.

Interprets x as [pose(6), voxel_center(3)]. The measurement is a point expressed in the pose frame; it is projected into the world frame and compared against the voxel center.

:param x: Stacked state block [pose(6), voxel_center(3)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "point_meas" (a 3D point in the pose frame). Any weighting is applied upstream by `_apply_weight`.
:type params: dict
:return: 3D residual voxel_center - predicted_world_point.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def pose_voxel_point_residual(
    x: jnp.ndarray,
    params: dict,
) -> jnp.ndarray:
    """
    Residual between a pose and a voxel center given a point measurement.

    Interprets ``x`` as ``[pose(6), voxel_center(3)]``. The measurement
    is a point expressed in the pose frame; it is projected into the
    world frame and compared against the voxel center.

    :param x: Stacked state block ``[pose(6), voxel_center(3)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"point_meas"``
        (a 3D point in the pose frame). Any weighting is applied
        upstream by :func:`_apply_weight`.
    :type params: dict
    :return: 3D residual ``voxel_center - predicted_world_point``.
    :rtype: jnp.ndarray
    """
    pose = x[:6]
    voxel = x[6:9]  # voxel center in world frame

    point_meas = params["point_meas"]  # (3,)

    T = se3_exp(pose)          # 4x4
    R = T[:3, :3]
    t = T[:3, 3]

    world_point = R @ point_meas + t  # predicted world point from this measurement

    residual = voxel - world_point
    return residual
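`pose_landmark_relative_residual` maps a world point into the pose frame with Rᵀ(p − t). `pose_voxel_point_residual` maps a pose-frame point into the world with R p + t. The two maps are inverses of each other, so the two factors use consistent frame conventions. The minimal standalone check below uses an explicit rotation about z in place of `se3_exp`. The helper names are hypothetical.

```python
import jax.numpy as jnp


def rot_z(angle):
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def world_to_pose(R, t, p_world):
    # Convention of pose_landmark_relative_residual: R^T (p - t).
    return R.T @ (p_world - t)


def pose_to_world(R, t, p_pose):
    # Convention of pose_voxel_point_residual: R p + t.
    return R @ p_pose + t


R, t = rot_z(0.7), jnp.array([1.0, -2.0, 0.5])
p_world = jnp.array([3.0, 1.0, 2.0])
p_pose = world_to_pose(R, t, p_world)
round_trip = pose_to_world(R, t, p_pose)  # recovers p_world up to float error
```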

prior_residual(x, params)

Simple prior on a single variable.

Computes residual = x - target for any vector dimension.

:param x: Current variable value (flattened state block).
:type x: jnp.ndarray
:param params: Parameter dictionary containing "target" and optionally a weight understood by `_apply_weight`.
:type params: Dict[str, jnp.ndarray]
:return: Prior residual x - target (possibly reweighted).
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def prior_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Simple prior on a single variable.

    Computes ``residual = x - target`` for any vector dimension.

    :param x: Current variable value (flattened state block).
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"target"`` and
        optionally a weight understood by :func:`_apply_weight`.
    :type params: Dict[str, jnp.ndarray]
    :return: Prior residual ``x - target`` (possibly reweighted).
    :rtype: jnp.ndarray
    """
    target = params["target"]
    r = x - target
    return _apply_weight(r, params)

range_residual(x, params)

Range-only residual between a pose and a 3D target.

This residual assumes that x is the concatenation of a 6D SE(3) pose (in se(3) vector form) and a 3D target position:

    x = [pose_se3(6), target(3)]

Only the translational part of the pose is used. The residual is:

    r = ||target - t|| - r_meas

where t is the pose translation and r_meas is the measured range. A scalar weight is applied in the same way as other residuals via `_apply_weight`.

:param x: Concatenated pose and target state, shape (9,).
:param params: Parameter dictionary with keys:
  - "range": scalar or length-1 array containing the measured range.
  - "weight" (optional): scalar weight to apply. If omitted, a weight of 1.0 is used by `_apply_weight`.
:return: Residual vector of shape (1,) (after weighting).

Source code in dsg-jit/dsg_jit/slam/measurements.py
def range_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Range-only residual between a pose and a 3D target.

    This residual assumes that ``x`` is the concatenation of a 6D SE(3)
    pose (in se(3) vector form) and a 3D target position::

        x = [pose_se3(6), target(3)]

    Only the translational part of the pose is used. The residual is::

        r = ||target - t|| - r_meas

    where ``t`` is the pose translation and ``r_meas`` is the measured
    range. A scalar weight is applied in the same way as other residuals
    via :func:`_apply_weight`.

    :param x: Concatenated pose and target state, shape ``(9,)``.
    :param params: Parameter dictionary with keys:
        - ``"range"``: scalar or length-1 array containing the
          measured range.
        - ``"weight"`` (optional): scalar weight to apply. If omitted,
          a weight of ``1.0`` is used by :func:`_apply_weight`.
    :return: Residual vector of shape ``(1,)`` (after weighting).
    """
    # Split state: first 6 are se(3) (pose), last 3 are 3D target position.
    pose = x[:6]
    target = x[6:9]

    # Translation component of the pose.
    t = pose[:3]

    # Euclidean distance between pose translation and target.
    diff = target - t
    dist = jnp.linalg.norm(diff)

    # Measured range can be a scalar or length-1 array.
    r_meas = params["range"]
    r_meas = jnp.array(r_meas, dtype=jnp.float32).reshape(())

    # Residual: predicted - measured.
    r = dist - r_meas

    # Wrap as 1D vector and apply weight.
    r_vec = jnp.array([r], dtype=jnp.float32)
    return _apply_weight(r_vec, params)
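Only `pose[:3]` and the target enter the residual. Its gradient with respect to the target is the unit direction from the pose translation to the target, and the gradient with respect to `pose[:3]` is the opposite direction. The sketch below checks this with `jax.grad` on a standalone copy of the unweighted arithmetic (`range_residual_local` is a hypothetical name). The norm is not differentiable where the target coincides with `pose[:3]`.

```python
import jax
import jax.numpy as jnp


def range_residual_local(x, r_meas):
    # Same arithmetic as range_residual, without weighting: ||target - t|| - r_meas.
    t, target = x[:3], x[6:9]
    return jnp.linalg.norm(target - t) - r_meas


x = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 4.0, 0.0])
value = range_residual_local(x, 4.0)          # 5 - 4 = 1
g = jax.grad(range_residual_local)(x, 4.0)
# g[6:9] is (target - t) / ||target - t|| = [0.6, 0.8, 0]; g[:3] is its
# negative; the rotational entries g[3:6] are zero.
```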

se3_chain_residual(x, params)

Alias for SE(3) chain / odometry residual used in visualization.

This is a thin wrapper around `odom_se3_residual`, so that experiments and visualization code can refer to a semantically descriptive name ("se3_chain") without duplicating logic.

:param x: Stacked pose vector [pose_i(6), pose_j(6)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "measurement" with the desired relative pose in R^6.
:type params: Dict[str, jnp.ndarray]
:return: SE(3) chain residual produced by `odom_se3_residual`.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def se3_chain_residual(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Alias for SE(3) chain / odometry residual used in visualization.

    This is a thin wrapper around :func:`odom_se3_residual`, so that
    experiments and visualization code can refer to a semantically
    descriptive name ("se3_chain") without duplicating logic.

    :param x: Stacked pose vector ``[pose_i(6), pose_j(6)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"measurement"`` with
        the desired relative pose in R^6.
    :type params: Dict[str, jnp.ndarray]
    :return: SE(3) chain residual produced by :func:`odom_se3_residual`.
    :rtype: jnp.ndarray
    """
    return odom_se3_residual(x, params)

sigma_to_weight(sigma)

Convert standard deviation(s) to an information-style weight.

For a scalar standard deviation sigma, this returns 1 / sigma**2. For a vector of standard deviations, it returns the elementwise inverse-variance 1 / sigma[i]**2.

:param sigma: Scalar or vector of standard deviations.
:type sigma: Union[float, jnp.ndarray]
:return: Scalar or vector of weights 1 / sigma**2.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def sigma_to_weight(sigma):
    """
    Convert standard deviation(s) to an information-style weight.

    For a scalar standard deviation ``sigma``, this returns ``1 / sigma**2``.
    For a vector of standard deviations, it returns the elementwise
    inverse-variance ``1 / sigma[i]**2``.

    :param sigma: Scalar or vector of standard deviations.
    :type sigma: Union[float, jnp.ndarray]
    :return: Scalar or vector of weights ``1 / sigma**2``.
    :rtype: jnp.ndarray
    """
    s = jnp.asarray(sigma)
    return 1.0 / (s * s)
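A short worked example of the inverse-variance formula, plus the general whitening identity it connects to: multiplying a residual elementwise by `sqrt(w) = 1 / sigma` makes its squared norm equal the weighted cost `r^T diag(w) r`. Whether `_apply_weight` multiplies by `w` or by `sqrt(w)` is not shown in this section, so treat the whitening step as an illustration rather than a description of the library.

```python
import jax.numpy as jnp


def sigma_to_weight(sigma):
    # Same formula as above: elementwise inverse variance.
    s = jnp.asarray(sigma)
    return 1.0 / (s * s)


w_scalar = sigma_to_weight(0.1)                  # approximately 100
w_vec = sigma_to_weight(jnp.array([0.5, 2.0]))   # [4.0, 0.25]

# Whitening: scale the residual by sqrt(w) = 1 / sigma, then sum squares.
r = jnp.array([1.0, 1.0])
r_white = jnp.sqrt(w_vec) * r
cost = jnp.sum(r_white ** 2)                     # 4.0 + 0.25 = 4.25
```

A tight sigma therefore dominates the cost: the component with sigma 0.5 contributes 16 times more than the one with sigma 2.0 for the same raw error.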

voxel_point_observation_residual(x, params)

Observation factor tying a voxel center to a world-frame point.

Interprets x as [voxel_center(3)] and encourages it to match an observed point in world coordinates.

:param x: State block containing a single voxel center.
:type x: jnp.ndarray
:param params: Parameter dictionary containing "point_world" (a 3D point in the world frame). Any weighting is applied upstream by `_apply_weight`.
:type params: dict
:return: 3D residual voxel_center - point_world.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def voxel_point_observation_residual(
    x: jnp.ndarray,
    params: dict,
) -> jnp.ndarray:
    """
    Observation factor tying a voxel center to a world-frame point.

    Interprets ``x`` as ``[voxel_center(3)]`` and encourages it to match
    an observed point in world coordinates.

    :param x: State block containing a single voxel center.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"point_world"``
        (a 3D point in the world frame). Any weighting is applied
        upstream by :func:`_apply_weight`.
    :type params: dict
    :return: 3D residual ``voxel_center - point_world``.
    :rtype: jnp.ndarray
    """
    voxel = x[:3]
    point_world = params["point_world"]  # (3,)
    return voxel - point_world
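Because the residual is a pure function of `x` and `params`, many voxel/point observations can be evaluated at once with `jax.vmap`. This sketch re-declares the function locally and batches over a hypothetical pair of voxels and observed points; the batching is illustrative and not something the source claims the pipeline does.

```python
import jax
import jax.numpy as jnp


def voxel_point_observation_residual(x, params):
    # Mirrors the function above: voxel_center - point_world.
    return x[:3] - params["point_world"]


# Two hypothetical voxel centers and their observed world-frame points.
voxels = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])
points = jnp.array([[1.0, 2.0, 2.5],
                    [0.1, 0.0, 0.0]])

# vmap maps over axis 0 of both the state and every leaf of the params dict.
r = jax.vmap(voxel_point_observation_residual)(voxels, {"point_world": points})
# r has shape (2, 3): [[0, 0, 0.5], [-0.1, 0, 0]]
```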

voxel_smoothness_residual(x, params)

Smoothness / grid regularity constraint between two voxel centers.

Interprets x as [voxel_i(3), voxel_j(3)] and penalizes the deviation from an expected offset between neighboring voxels.

:param x: Stacked state block [voxel_i(3), voxel_j(3)].
:type x: jnp.ndarray
:param params: Parameter dictionary containing "offset" (a 3D expected difference voxel_j - voxel_i) and optionally a weight handled by `_apply_weight`.
:type params: dict
:return: 3D residual (voxel_j - voxel_i) - offset.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/slam/measurements.py
def voxel_smoothness_residual(
    x: jnp.ndarray,
    params: dict,
) -> jnp.ndarray:
    """
    Smoothness / grid regularity constraint between two voxel centers.

    Interprets ``x`` as ``[voxel_i(3), voxel_j(3)]`` and penalizes the
    deviation from an expected offset between neighboring voxels.

    :param x: Stacked state block ``[voxel_i(3), voxel_j(3)]``.
    :type x: jnp.ndarray
    :param params: Parameter dictionary containing ``"offset"`` (a 3D
        expected difference ``voxel_j - voxel_i``) and optionally a
        weight handled by :func:`_apply_weight`.
    :type params: dict
    :return: 3D residual ``(voxel_j - voxel_i) - offset``.
    :rtype: jnp.ndarray
    """
    voxel_i = x[:3]
    voxel_j = x[3:6]

    offset = params["offset"]  # (3,)

    residual = (voxel_j - voxel_i) - offset
    return residual
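The smoothness residual is linear, so its Jacobian with respect to the stacked state is the constant block `[-I | I]`. The sketch below checks this with `jax.jacfwd` on a hypothetical 0.1 m grid offset; the function is re-declared locally so the example runs on its own.

```python
import jax
import jax.numpy as jnp


def voxel_smoothness_residual(x, params):
    # Mirrors the function above: (voxel_j - voxel_i) - offset.
    return (x[3:6] - x[:3]) - params["offset"]


# Two neighbouring voxels; the expected offset (hypothetical) is 0.1 along x.
x = jnp.array([0.0, 0.0, 0.0, 0.12, 0.0, 0.0])
params = {"offset": jnp.array([0.1, 0.0, 0.0])}

r = voxel_smoothness_residual(x, params)              # approximately [0.02, 0, 0]
J = jax.jacfwd(voxel_smoothness_residual)(x, params)  # shape (3, 6), equals [-I | I]
```

Since shifting both voxels by the same vector leaves the residual unchanged, smoothness factors alone do not fix absolute position; something like the point observation factor above is needed to anchor the grid.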

slam.pipeline

High-level SLAM pipelines built on top of DSG-JIT.

This module provides small, composable helpers that glue together:

  • WorldModel / SceneGraphWorld
  • Sensors + SensorFusionManager
  • FactorGraph + Gauss-Newton optimizer

The intent is that experiments (or ROS2 nodes) call into these functions rather than reimplementing the same boilerplate in every file.

PoseGraphResult(x_opt, pose_ids, landmark_ids=None) dataclass

Result of a pose-graph SLAM solve.

:param x_opt: Optimized stacked state vector.
:type x_opt: jax.numpy.ndarray
:param pose_ids: List of node ids corresponding to poses in the graph.
:type pose_ids: list[int]
:param landmark_ids: Optional list of landmark node ids, if present.
:type landmark_ids: list[int] | None

pose_vectors_from_result(wm, result)

Extract SE(3) pose vectors from an optimized solution.

:param wm: The world model that owns the variables.
:type wm: world.model.WorldModel
:param result: Optimization result describing pose node ids and the stacked state.
:type result: PoseGraphResult
:return: Mapping from pose node id -> pose vector (6,).
:rtype: dict[int, jax.numpy.ndarray]

Source code in dsg-jit/dsg_jit/slam/pipeline.py
def pose_vectors_from_result(
    wm: WorldModel,
    result: PoseGraphResult,
) -> Dict[int, jnp.ndarray]:
    """
    Extract SE(3) pose vectors from an optimized solution.

    :param wm:
        The world model that owns the variables.
    :type wm: world.model.WorldModel
    :param result:
        Optimization result describing pose node ids and the stacked state.
    :type result: PoseGraphResult

    :return:
        Mapping from pose node id -> pose vector (6,).
    :rtype: dict[int, jax.numpy.ndarray]
    """
    x_opt = result.x_opt
    _, index = wm.pack_state()

    out: Dict[int, jnp.ndarray] = {}
    for nid in result.pose_ids:
        sl = index[nid]
        out[nid] = x_opt[sl]
    return out
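The extraction above is plain slicing of the stacked state by a node-id-to-slice index. This self-contained sketch uses a hypothetical state and index in place of `wm.pack_state()`, then pulls out translations for plotting, assuming the `[t(3), rot(3)]` pose layout implied by the range residual (`t = pose[:3]`).

```python
import jax.numpy as jnp

# Hypothetical stacked state: two 6-DoF poses followed by one 3D landmark.
x_opt = jnp.arange(15, dtype=jnp.float32)
# Hypothetical index, shaped like the second value returned by wm.pack_state().
index = {0: slice(0, 6), 1: slice(6, 12), 2: slice(12, 15)}
pose_ids = [0, 1]

# Same slicing pattern as pose_vectors_from_result.
poses = {nid: x_opt[index[nid]] for nid in pose_ids}

# ASSUMPTION: translation occupies the first three entries of each pose vector.
translations = jnp.stack([poses[nid][:3] for nid in sorted(poses)])
# translations == [[0, 1, 2], [6, 7, 8]]
```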

run_pose_graph_slam(wm, cfg=None)

Run Gauss-Newton on the pose/landmark graph in wm.fg.

This treats whatever is currently in the WorldModel's factor graph as the SLAM problem. It does not modify wm in-place; it returns the optimized stacked state and helper lists for extracting poses/landmarks.

Typical usage:

    result = run_pose_graph_slam(wm)
    _, index = wm.pack_state()
    poses = [result.x_opt[index[nid]] for nid in result.pose_ids]

:param wm: The world model containing a FactorGraph with SE(3) pose variables (and optionally landmark variables) plus factors (odom, priors, range/bearing, etc.).
:type wm: world.model.WorldModel
:param cfg: Configuration for the Gauss-Newton solver. If None, a default GNConfig is used.
:type cfg: core.types.GNConfig | None
:return: A `PoseGraphResult` containing the optimized stacked state vector and node-id lists for poses and landmarks.
:rtype: PoseGraphResult

Source code in dsg-jit/dsg_jit/slam/pipeline.py
def run_pose_graph_slam(
    wm: WorldModel,
    cfg: GNConfig | None = None,
) -> PoseGraphResult:
    """
    Run Gauss-Newton on the pose/landmark graph in ``wm.fg``.

    This treats whatever is currently in the WorldModel's factor graph as
    the SLAM problem. It does not modify ``wm`` in-place; it returns the
    optimized stacked state and helper lists for extracting poses/landmarks.

    Typical usage:

    .. code-block:: python

        result = run_pose_graph_slam(wm)
        _, index = wm.pack_state()
        poses = [result.x_opt[index[nid]] for nid in result.pose_ids]

    :param wm:
        The world model containing a FactorGraph with SE(3) pose variables
        (and optionally landmark variables) plus factors (odom, priors,
        range/bearing, etc.).
    :type wm: world.model.WorldModel
    :param cfg:
        Configuration for the Gauss-Newton solver. If ``None``, a default
        ``GNConfig`` is used.
    :type cfg: core.types.GNConfig | None

    :return:
        A :class:`PoseGraphResult` containing the optimized stacked state
        vector and node-id lists for poses and landmarks.
    :rtype: PoseGraphResult
    """
    if cfg is None:
        cfg = GNConfig()

    fg: FactorGraph = wm.fg

    # Pack initial state
    x0, _ = wm.pack_state()
    residual_fn = wm.build_residual()

    # Manifold types: we already stored these per-variable in the graph.
    manifold_types = build_manifold_metadata(packed_state=wm.pack_state(), fg=fg)

    # Solve
    x_opt = gauss_newton_manifold(
        residual_fn,
        x0,
        manifold_types,
        cfg,
    )

    # Build convenience lists for poses / landmarks.
    pose_ids: List[int] = []
    landmark_ids: List[int] = []

    for nid, v in wm.fg.variables.items():
        if v.manifold == "se3":
            pose_ids.append(nid)
        elif v.manifold == "R3":
            landmark_ids.append(nid)

    return PoseGraphResult(
        x_opt=x_opt,
        pose_ids=sorted(pose_ids),
        landmark_ids=sorted(landmark_ids) if landmark_ids else None,
    )
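To make the solve step concrete, here is a toy Gauss-Newton loop on a purely Euclidean 1D pose chain. It is not `gauss_newton_manifold`: it uses plain additive updates everywhere (the manifold solver would apply an SE(3) retraction to pose blocks instead), and the residual function and measurements are hypothetical.

```python
import jax
import jax.numpy as jnp


def residual_fn(x):
    # Hypothetical 1D chain: prior on x[0], two odometry edges, one loop closure.
    return jnp.array([
        x[0] - 0.0,            # prior anchoring the first pose
        (x[1] - x[0]) - 1.0,   # odometry 0 -> 1
        (x[2] - x[1]) - 1.0,   # odometry 1 -> 2
        (x[2] - x[0]) - 2.0,   # loop closure 0 -> 2
    ])


def gauss_newton(residual_fn, x0, iters=5):
    x = x0
    for _ in range(iters):
        r = residual_fn(x)
        J = jax.jacfwd(residual_fn)(x)
        # Normal equations: (J^T J) dx = -J^T r
        dx = jnp.linalg.solve(J.T @ J, -J.T @ r)
        x = x + dx  # Euclidean update; SE(3) blocks would be retracted instead
    return x


x_opt = gauss_newton(residual_fn, jnp.array([0.3, 0.5, 2.7]))
```

Because these residuals are linear and the measurements are mutually consistent, the first iteration already lands on `[0, 1, 2]`; the remaining iterations produce near-zero steps.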

update_worldmodel_from_solution(wm, result)

Write optimized variables from a `PoseGraphResult` back into wm.

This is a small helper so that downstream code (DSG construction, visualization, dataset export) can reflect the optimized state.

:param wm: The world model whose factor-graph variables will be updated in-place.
:type wm: world.model.WorldModel
:param result: Output from `run_pose_graph_slam`, containing the optimized stacked state vector.
:type result: PoseGraphResult

Source code in dsg-jit/dsg_jit/slam/pipeline.py
def update_worldmodel_from_solution(wm: WorldModel, result: PoseGraphResult) -> None:
    """
    Write optimized variables from a :class:`PoseGraphResult` back into ``wm``.

    This is a small helper so that downstream code (DSG construction,
    visualization, dataset export) can reflect the optimized state.

    :param wm:
        The world model whose factor-graph variables will be updated in-place.
    :type wm: world.model.WorldModel
    :param result:
        Output from :func:`run_pose_graph_slam`, containing the optimized
        stacked state vector.
    :type result: PoseGraphResult
    """
    fg = wm.fg
    x_opt = result.x_opt
    _, index = wm.pack_state()  # re-pack to get consistent slices

    for nid, sl in index.items():
        v = fg.variables[nid]
        v.value = x_opt[sl]
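The write-back loop can be exercised without a full world model. This sketch uses a hypothetical `ToyVariable` stand-in and a hand-written index in place of `wm.pack_state()`. Note that the real helper recomputes the index, so it implicitly assumes the graph's variables have not changed between the solve and the write-back.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ToyVariable:
    # Hypothetical stand-in for a factor-graph variable with a mutable ``value``.
    value: jnp.ndarray


variables = {0: ToyVariable(jnp.zeros(6)), 1: ToyVariable(jnp.zeros(3))}
# Hypothetical index: node id -> slice into the stacked state.
index = {0: slice(0, 6), 1: slice(6, 9)}
x_opt = jnp.arange(9, dtype=jnp.float32)

# Same loop as update_worldmodel_from_solution: copy each block back in place.
for nid, sl in index.items():
    variables[nid].value = x_opt[sl]
```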

visualize_pose_graph_3d(wm, title=None)

Convenience helper to plot the current factor graph in 3D.

This simply calls `world.visualization.plot_factor_graph_3d` with the world's underlying `FactorGraph`.

:param wm: World model whose factor graph will be visualized.
:type wm: world.model.WorldModel
:param title: Optional plot title.
:type title: str | None

Source code in dsg-jit/dsg_jit/slam/pipeline.py
def visualize_pose_graph_3d(
    wm: WorldModel,
    title: str | None = None,
) -> None:
    """
    Convenience helper to plot the current factor graph in 3D.

    This simply calls :func:`world.visualization.plot_factor_graph_3d`
    with the world's underlying :class:`FactorGraph`.

    :param wm:
        World model whose factor graph will be visualized.
    :type wm: world.model.WorldModel
    :param title:
        Optional plot title.
    :type title: str | None
    """
    # Note: ``title`` is accepted but not forwarded; only the graph is passed.
    plot_factor_graph_3d(wm.fg)