
Core Modules

This section documents the foundational data structures and factor-graph engine used throughout DSG-JIT.


core.types

Core typed data structures for DSG-JIT.

This module defines the lightweight container classes used throughout the differentiable factor graph system. These types are intentionally minimal: they store only structural information and initial values, while all numerical operations are performed by JAX-compiled functions in the optimization layer.

Classes

Variable Represents a node in the factor graph. A variable contains:
  • id: Unique identifier (string or int)
  • type: Semantic type string (e.g., "pose3", "landmark3d")
  • value: Initial numeric state, typically a 1-D JAX array

Factor Represents a constraint between one or more variables. A factor contains:
  • id: Unique identifier
  • type: String key selecting a residual function
  • var_ids: Ordered tuple of variable ids used by the residual
  • params: Dictionary of parameters passed into the residual function (e.g., weights, measurements, priors)

Pose3 Minimal 3D pose holder parameterized as (x, y, z, roll, pitch, yaw).

Notes

These objects are deliberately simple and mutable; they are not meant to be used directly inside JAX-compiled functions. During optimization, higher-level components (such as WorldModel) pack the variable values into a flat JAX array x, so that JIT-compiled solvers operate on purely functional data. FactorGraph itself only stores structure.
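A minimal sketch of that packing step, assuming values are flattened in dictionary (insertion) order and each variable is remembered by its slice of x. The `Variable` class here is a local stand-in mirroring the documented `Variable(id, type, value)` fields, and `pack_values` / `unpack_value` are hypothetical helpers, not DSG-JIT APIs.

```python
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp


@dataclass
class Variable:
    """Local stand-in mirroring the documented Variable(id, type, value)."""
    id: Hashable
    type: str
    value: Any


def pack_values(variables: dict) -> tuple[jnp.ndarray, dict]:
    """Hypothetical helper: flatten all variable values into one vector x.

    Returns x together with a map from variable id to its slice of x.
    """
    index, chunks, offset = {}, [], 0
    for vid, var in variables.items():  # dicts preserve insertion order
        v = jnp.ravel(jnp.asarray(var.value, dtype=jnp.float32))
        index[vid] = slice(offset, offset + v.size)
        chunks.append(v)
        offset += v.size
    return jnp.concatenate(chunks), index


def unpack_value(x: jnp.ndarray, index: dict, vid: Hashable) -> jnp.ndarray:
    """Hypothetical helper: read one variable's block back out of x."""
    return x[index[vid]]


variables = {
    "pose0": Variable("pose0", "pose3", jnp.zeros(6)),
    "pose1": Variable("pose1", "pose3", jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.1])),
    "lm0": Variable("lm0", "landmark3d", jnp.array([2.0, 1.0, 0.0])),
}
x, index = pack_values(variables)  # x.shape == (15,)
```

Since the slices are ordinary Python objects computed once, a residual function of x can close over `index` and still be JIT-compiled with only x traced.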

This module forms the backbone of DSG-JIT's dynamic scene graph architecture, enabling hybrid SE3, voxel, and semantic structures to be represented in a unified factor graph.

Factor(id, type, var_ids, params) dataclass

Abstract factor connecting one or more variables.

A factor encodes a residual term in the overall objective, defined over an ordered tuple of variable ids and parameterized by a small dictionary of measurements, noise models, or other hyperparameters.

:param id: Unique identifier for the factor.
:type id: FactorId
:param type: String key indicating the factor/residual type (e.g., "odom", "loop_closure", "object_prior").
:type type: str
:param var_ids: Ordered tuple of NodeIds that this factor connects.
:type var_ids: tuple[NodeId, ...]
:param params: Dictionary of parameters passed into the residual function, such as measurements, noise weights, or prior means.
:type params: Dict[str, Any]
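Because `type` is a plain string key, dispatching a factor to its residual can be as simple as a dictionary lookup that feeds the variables in `var_ids` order, together with `params`, into the selected function. The sketch below assumes such a registry; the `RESIDUALS` dictionary, the `"prior"` and `"odom"` residual bodies, and the `mean` / `measurement` parameter names are illustrative assumptions, not DSG-JIT's actual residual library.

```python
from collections.abc import Callable, Hashable
from dataclasses import dataclass, field
from typing import Any

import jax.numpy as jnp


@dataclass
class Factor:
    """Local stand-in mirroring the documented Factor(id, type, var_ids, params)."""
    id: Hashable
    type: str
    var_ids: tuple
    params: dict[str, Any] = field(default_factory=dict)


def prior_residual(values, params):
    # Hypothetical prior: difference between the variable and a target mean.
    (x,) = values
    return x - params["mean"]


def odom_residual(values, params):
    # Hypothetical Euclidean odometry: (x_j - x_i) should match the measurement.
    xi, xj = values
    return (xj - xi) - params["measurement"]


# Hypothetical registry keyed by Factor.type.
RESIDUALS: dict[str, Callable] = {"prior": prior_residual, "odom": odom_residual}


def evaluate(factor: Factor, values_by_id: dict) -> jnp.ndarray:
    """Select the residual by factor.type and pass variables in var_ids order."""
    fn = RESIDUALS[factor.type]
    values = [values_by_id[vid] for vid in factor.var_ids]
    return fn(values, factor.params)


values = {"p0": jnp.array([0.0, 0.0, 0.0]), "p1": jnp.array([1.1, 0.0, 0.0])}
odom = Factor("f1", "odom", ("p0", "p1"), {"measurement": jnp.array([1.0, 0.0, 0.0])})
prior = Factor("f0", "prior", ("p0",), {"mean": jnp.zeros(3)})
r_odom = evaluate(odom, values)    # approximately [0.1, 0.0, 0.0]
r_prior = evaluate(prior, values)  # [0.0, 0.0, 0.0]
```

Swapping `var_ids` in the odometry factor to `("p1", "p0")` gives a different residual (about [-2.1, 0.0, 0.0] instead of [0.1, 0.0, 0.0]), which is why the tuple is ordered.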

Pose3(x, y, z, roll, pitch, yaw) dataclass

Minimal 3D pose holder.

This is a lightweight container for a 3D pose parameterized as (x, y, z, roll, pitch, yaw).

:param x: Position along the x-axis in meters.
:type x: float
:param y: Position along the y-axis in meters.
:type y: float
:param z: Position along the z-axis in meters.
:type z: float
:param roll: Rotation around the x-axis in radians.
:type roll: float
:param pitch: Rotation around the y-axis in radians.
:type pitch: float
:param yaw: Rotation around the z-axis in radians.
:type yaw: float

Variable(id, type, value) dataclass

Generic optimization variable node in the factor graph.

Each variable represents a node in the factor graph and stores its identifier, semantic type string, and current numeric value.

:param id: Unique identifier for the variable node.
:type id: NodeId
:param type: Semantic type of the variable (e.g., "pose3", "landmark3d").
:type type: str
:param value: Initial or current numeric value for this variable, typically a 1-D JAX array or other array-like object.
:type value: Any


core.math3d

SE3 and SO3 manifold operations for DSG-JIT.

This module implements the minimal 3D Lie-group mathematics required for differentiable SLAM and scene-graph optimization:

• SO(3) exponential & logarithm maps
• SE(3) exponential & logarithm maps
• Composition and inversion
• Small-angle approximations for stable Jacobians
• Helper utilities for constructing transforms
All functions are written in JAX and support:
  • JIT compilation
  • Automatic differentiation
  • Batched operation via jax.vmap
  • Numerically stable behavior near zero-rotation limits
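As written in the source below, the functions operate on a single vector or matrix: `so3_exp`, for instance, takes the norm of its whole input and branches on that scalar with `jax.lax.cond`. Batching is therefore most naturally expressed with `jax.vmap`. A minimal sketch, with `hat` and `so3_exp` re-implemented locally so the block runs on its own:

```python
import jax
import jax.numpy as jnp


def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def so3_exp(w):
    # Mirrors the documented so3_exp: Rodrigues' formula with a small-angle branch.
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle():
        return I + hat(w)

    def normal_angle():
        K = hat(w / theta)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)


# Map over a leading batch axis; jit compiles the batched version once.
batched_so3_exp = jax.jit(jax.vmap(so3_exp))

ws = jnp.array([[0.0, 0.0, 0.5], [0.3, -0.1, 0.2], [1.0, 0.0, 0.0]])
Rs = batched_so3_exp(ws)  # shape (3, 3, 3): one rotation per row of ws
```

Under `vmap`, a `lax.cond` whose predicate varies across the batch is evaluated as a select: both branches are computed for every element and the matching result is kept per row.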

Key Functions

so3_exp(w) Maps a 3-vector (axis-angle) to a 3×3 rotation matrix.

so3_log(R) Maps a rotation matrix back to its axis-angle representation.

se3_exp(xi) Maps a 6-vector twist ξ = (v, ω), translation part first, to a 4×4 SE(3) transform matrix.

se3_log(T) Inverse of se3_exp; extracts a twist from an SE3 matrix.

se3_inverse(T) Computes the inverse of an SE3 transform.

se3_compose(A, B) Composes two SE3 transforms: A ∘ B.

Utilities

hat(ω) Converts a 3-vector to its skew-symmetric matrix.

vee(Ω) Converts a 3×3 skew matrix back into a 3-vector.

Notes

These functions are heavily used throughout DSG-JIT:
• Odometry factors
• Loop-closure factors
• Pose-voxel alignment
• Deformation-graph updates
• Hybrid SE3 + voxel joint solvers

Correctness of the global optimization engine critically depends on these Lie-group operations being differentiable, stable, and JIT-friendly.

compose_pose_se3(a, b)

Compose two SE(3) poses in 6D vector representation.

:param a: 6D pose vector [tx, ty, tz, wx, wy, wz].
:type a: jnp.ndarray
:param b: 6D pose vector [tx, ty, tz, wx, wy, wz].
:type b: jnp.ndarray
:return: Composed pose a ∘ b in 6D vector form.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def compose_pose_se3(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """
    Compose two SE(3) poses in 6D vector representation.

    :param a: 6D pose vector.
    :type a: jnp.ndarray
    :param b: 6D pose vector.
    :type b: jnp.ndarray
    :return: Composed pose ``a ∘ b`` in 6D vector form.
    :rtype: jnp.ndarray
    """
    ta, wa = pose_vec_to_rt(a)
    tb, wb = pose_vec_to_rt(b)

    Ra = so3_exp(wa)
    Rb = so3_exp(wb)

    R = Ra @ Rb
    t = Ra @ tb + ta

    w = so3_log(R)
    return jnp.concatenate([t, w])
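To see what composition means in the 6D [t, w] layout, the sketch below lifts each pose to a 4×4 homogeneous matrix [[so3_exp(w), t], [0, 1]] and checks that `compose_pose_se3` agrees with plain matrix multiplication. The helpers are local re-implementations mirroring the functions on this page; `to_matrix` is a hypothetical helper added for the check.

```python
import jax
import jax.numpy as jnp


# Local re-implementations mirroring hat, vee, so3_exp and so3_log on this page.
def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def vee(M):
    return jnp.array([M[2, 1] - M[1, 2], M[0, 2] - M[2, 0], M[1, 0] - M[0, 1]]) / 2.0


def so3_exp(w):
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle():
        return I + hat(w)

    def normal_angle():
        K = hat(w / theta)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)


def so3_log(R):
    theta = jnp.arccos(jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0))

    def small_angle():
        return vee(R - jnp.eye(3))

    def general():
        return theta / (2.0 * jnp.sin(theta) + 1e-12) * vee(R - R.T)

    return jax.lax.cond(theta < 1e-5, small_angle, general)


def compose_pose_se3(a, b):
    # Same steps as the documented version: R = Ra @ Rb, t = Ra @ tb + ta.
    Ra, Rb = so3_exp(a[3:]), so3_exp(b[3:])
    return jnp.concatenate([Ra @ b[:3] + a[:3], so3_log(Ra @ Rb)])


def to_matrix(p):
    """Hypothetical helper: 6D [t, w] pose -> 4x4 homogeneous transform."""
    T = jnp.eye(4)
    T = T.at[:3, :3].set(so3_exp(p[3:]))
    return T.at[:3, 3].set(p[:3])


a = jnp.array([1.0, 0.0, 0.5, 0.0, 0.0, 0.3])
b = jnp.array([0.2, -0.4, 0.0, 0.1, -0.2, 0.05])
ab = compose_pose_se3(a, b)
```

Composition is not commutative, so `compose_pose_se3(a, b)` and `compose_pose_se3(b, a)` generally differ. Composing with the zero 6-vector (the value `se3_identity` returns) leaves a pose unchanged up to floating-point round-off.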

hat(v)

Compute the so(3) hat operator.

:param v: 3-vector.
:type v: jnp.ndarray
:return: 3×3 skew-symmetric matrix such that hat(v) @ w = v × w.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def hat(v: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the so(3) hat operator.

    :param v: 3‑vector.
    :type v: jnp.ndarray
    :return: 3×3 skew‑symmetric matrix such that ``hat(v) @ w = v × w``.
    :rtype: jnp.ndarray
    """
    x, y, z = v[0], v[1], v[2]
    return jnp.array(
        [
            [0.0, -z, y],
            [z, 0.0, -x],
            [-y, x, 0.0],
        ]
    )
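The defining property in the docstring, hat(v) @ w = v × w, is easy to check numerically against `jnp.cross`. The sketch below re-implements `hat` locally and also checks that hat(v) is skew-symmetric and that v lies in its null space (v × v = 0).

```python
import jax.numpy as jnp


def hat(v):
    # Same layout as the documented hat: rows [0, -z, y], [z, 0, -x], [-y, x, 0].
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([-0.5, 4.0, 0.25])
H = hat(v)

Hw = H @ w               # → [-11.5, -1.75, 5.0]
cross = jnp.cross(v, w)  # → [-11.5, -1.75, 5.0]
```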

pose_vec_to_rt(v)

Split a 6D pose vector into translation and rotation components.

:param v: 6D pose vector [tx, ty, tz, wx, wy, wz].
:type v: jnp.ndarray
:return: Tuple (t, w) where t is translation (3,) and w is axis-angle rotation (3,).
:rtype: tuple[jnp.ndarray, jnp.ndarray]

Source code in dsg-jit/dsg_jit/core/math3d.py
def pose_vec_to_rt(v: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Split a 6D pose vector into translation and rotation components.

    :param v: 6D pose vector ``[tx, ty, tz, wx, wy, wz]``.
    :type v: jnp.ndarray
    :return: Tuple ``(t, w)`` where ``t`` is translation ``(3,)`` and ``w`` is axis-angle rotation ``(3,)``.
    :rtype: tuple[jnp.ndarray, jnp.ndarray]
    """
    v = jnp.asarray(v)
    t = v[0:3]
    w = v[3:6]
    return t, w

relative_pose_se3(a, b)

Compute relative pose from a to b in 6D coordinates.

:param a: First pose (6,).
:type a: jnp.ndarray
:param b: Second pose (6,).
:type b: jnp.ndarray
:return: Relative twist such that Exp(xi) ≈ T_a^{-1} T_b.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def relative_pose_se3(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """
    Compute relative pose from ``a`` to ``b`` in 6D coordinates.

    :param a: First pose ``(6,)``.
    :type a: jnp.ndarray
    :param b: Second pose ``(6,)``.
    :type b: jnp.ndarray
    :return: Relative twist such that ``Exp(xi) ≈ T_a^{-1} T_b``.
    :rtype: jnp.ndarray
    """
    ta, wa = pose_vec_to_rt(a)
    tb, wb = pose_vec_to_rt(b)

    Ra = so3_exp(wa)
    Rb = so3_exp(wb)

    R_rel = Ra.T @ Rb
    w_rel = so3_log(R_rel)
    t_rel = Ra.T @ (tb - ta)

    return jnp.concatenate([t_rel, w_rel])
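Since `relative_pose_se3` and `compose_pose_se3` use the same 6D [t, w] layout, composing `a` with the relative pose from `a` to `b` should reproduce `b` up to round-off. The sketch re-implements both functions locally, mirroring the source; reading that source, the translation part of the result is Ra^T (tb - ta) itself, i.e. the position of b expressed in a's frame.

```python
import jax
import jax.numpy as jnp


# Local re-implementations mirroring hat, vee, so3_exp and so3_log on this page.
def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def vee(M):
    return jnp.array([M[2, 1] - M[1, 2], M[0, 2] - M[2, 0], M[1, 0] - M[0, 1]]) / 2.0


def so3_exp(w):
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle():
        return I + hat(w)

    def normal_angle():
        K = hat(w / theta)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)


def so3_log(R):
    theta = jnp.arccos(jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0))

    def small_angle():
        return vee(R - jnp.eye(3))

    def general():
        return theta / (2.0 * jnp.sin(theta) + 1e-12) * vee(R - R.T)

    return jax.lax.cond(theta < 1e-5, small_angle, general)


def compose_pose_se3(a, b):
    Ra, Rb = so3_exp(a[3:]), so3_exp(b[3:])
    return jnp.concatenate([Ra @ b[:3] + a[:3], so3_log(Ra @ Rb)])


def relative_pose_se3(a, b):
    # Mirrors the documented version: rotation Ra^T Rb, translation Ra^T (tb - ta).
    Ra, Rb = so3_exp(a[3:]), so3_exp(b[3:])
    return jnp.concatenate([Ra.T @ (b[:3] - a[:3]), so3_log(Ra.T @ Rb)])


a = jnp.array([1.0, 2.0, 0.0, 0.0, 0.0, 0.4])
b = jnp.array([1.5, 2.5, 0.1, 0.05, -0.1, 0.9])
rel = relative_pose_se3(a, b)
b_again = compose_pose_se3(a, rel)  # should reproduce b
```

A quantity like `rel` is what an odometry or loop-closure residual would typically compare against a measured relative pose.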

se3_exp(xi)

Exponential map from se(3) to SE(3).

:param xi: 6-vector twist [v_x, v_y, v_z, w_x, w_y, w_z].
:type xi: jnp.ndarray
:return: 4×4 homogeneous SE(3) transform matrix.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def se3_exp(xi: jnp.ndarray) -> jnp.ndarray:
    """
    Exponential map from se(3) to SE(3).

    :param xi: 6‑vector twist ``[v_x, v_y, v_z, w_x, w_y, w_z]``.
    :type xi: jnp.ndarray
    :return: 4×4 homogeneous SE(3) transform matrix.
    :rtype: jnp.ndarray
    """
    v = xi[:3]
    w = xi[3:]

    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    # Rotation
    R = so3_exp(w)

    # Left Jacobian J of SO(3); the translation of Exp(xi) is J @ v
    def small_angle():
        # For tiny rotation, J ≈ I + 0.5 * W (first-order Taylor expansion)
        W = hat(w)
        return I + 0.5 * W

    def normal_angle():
        # J = I + (1 - cos θ)/θ² W + (θ - sin θ)/θ³ W²
        W = hat(w)
        W2 = W @ W
        A = (1.0 - jnp.cos(theta)) / (theta * theta)
        B = (theta - jnp.sin(theta)) / (theta * theta * theta)
        return I + A * W + B * W2

    J = jax.lax.cond(theta < 1e-5, small_angle, normal_angle)

    t = J @ v

    # Build full SE(3) matrix
    T = jnp.eye(4)
    T = T.at[:3, :3].set(R)
    T = T.at[:3, 3].set(t)

    return T
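A convenient way to test an SE(3) exponential is to compare it with the generic matrix exponential of the 4×4 twist matrix [[hat(w), v], [0, 0]], computed here with `jax.scipy.linalg.expm`. The sketch below re-implements the map with the standard SO(3) left Jacobian V = I + (1 - cos θ)/θ² W + (θ - sin θ)/θ³ W² (with W = hat(w)), so that the translation is V @ v, and keeps the [v, w] ordering from the docstring. `twist_matrix` is a hypothetical helper added for the comparison.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def se3_exp(xi):
    """Sketch of Exp: se(3) -> SE(3) for xi = [v, w] (translation part first)."""
    v, w = xi[:3], xi[3:]
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)
    W = hat(w)
    W2 = W @ W

    def small_angle():
        # Second-order Taylor expansions of V and R around theta = 0.
        return I + 0.5 * W + W2 / 6.0, I + W + 0.5 * W2

    def normal_angle():
        A = jnp.sin(theta) / theta
        B = (1.0 - jnp.cos(theta)) / theta**2
        C = (theta - jnp.sin(theta)) / theta**3
        V = I + B * W + C * W2  # left Jacobian of SO(3)
        R = I + A * W + B * W2  # Rodrigues' formula
        return V, R

    V, R = jax.lax.cond(theta < 1e-5, small_angle, normal_angle)
    T = jnp.eye(4)
    T = T.at[:3, :3].set(R)
    return T.at[:3, 3].set(V @ v)


def twist_matrix(xi):
    """Hypothetical helper: 4x4 matrix [[hat(w), v], [0, 0]] for xi = [v, w]."""
    M = jnp.zeros((4, 4))
    M = M.at[:3, :3].set(hat(xi[3:]))
    return M.at[:3, 3].set(xi[:3])


xi = jnp.array([0.5, -0.2, 1.0, 0.3, 0.4, -0.6])
T = se3_exp(xi)
T_ref = expm(twist_matrix(xi))  # generic matrix exponential as a reference
```

The translation block of the result is V @ v rather than v itself; only for a pure translation (w = 0) do the two coincide.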

se3_identity()

Return the identity SE(3) pose.

:return: Zero 6-vector representing identity rotation and translation.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def se3_identity() -> jnp.ndarray:
    """
    Return the identity SE(3) pose.

    :return: Zero 6‑vector representing identity rotation and translation.
    :rtype: jnp.ndarray
    """
    return jnp.zeros(6, dtype=jnp.float32)

se3_retract_left(pose, delta)

Left‑multiplicative SE(3) retraction.

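:param pose: Base pose in 6D coordinates.
:type pose: jnp.ndarray
:param delta: Incremental twist in 6D.
:type delta: jnp.ndarray
:return: Updated pose Exp(delta) * pose in 6D vector form. The rotation part of delta enters through so3_exp, while its translation part is added directly, so the result equals compose_pose_se3(delta, pose).
:rtype: jnp.ndarray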

Source code in dsg-jit/dsg_jit/core/math3d.py
def se3_retract_left(pose: jnp.ndarray, delta: jnp.ndarray) -> jnp.ndarray:
    """
    Left‑multiplicative SE(3) retraction.

    :param pose: Base pose in 6D coordinates.
    :type pose: jnp.ndarray
    :param delta: Incremental twist in 6D.
    :type delta: jnp.ndarray
    :return: Updated pose ``Exp(delta) * pose`` in 6D vector form.
    :rtype: jnp.ndarray
    """
    pose = jnp.asarray(pose)
    delta = jnp.asarray(delta)

    # Split into translation + rotation-vector
    t, w = pose_vec_to_rt(pose)
    dt, dw = pose_vec_to_rt(delta)

    R = so3_exp(w)
    R_d = so3_exp(dw)

    # Apply left-multiplicative update
    R_new = R_d @ R
    t_new = R_d @ t + dt

    w_new = so3_log(R_new)
    return jnp.concatenate([t_new, w_new])
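Reading the code above, the update applies `delta` as a pose in the same [t, w] layout: the new rotation is so3_exp(dw) @ R and the new translation is so3_exp(dw) @ t + dt, so the result coincides with `compose_pose_se3(delta, pose)`. (A true Exp(delta) of an se(3) twist would use V @ dt for the translation; the two agree to first order for small increments.) A sketch checking that equivalence, with all helpers re-implemented locally:

```python
import jax
import jax.numpy as jnp


# Local re-implementations mirroring hat, vee, so3_exp and so3_log on this page.
def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def vee(M):
    return jnp.array([M[2, 1] - M[1, 2], M[0, 2] - M[2, 0], M[1, 0] - M[0, 1]]) / 2.0


def so3_exp(w):
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle():
        return I + hat(w)

    def normal_angle():
        K = hat(w / theta)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)


def so3_log(R):
    theta = jnp.arccos(jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0))

    def small_angle():
        return vee(R - jnp.eye(3))

    def general():
        return theta / (2.0 * jnp.sin(theta) + 1e-12) * vee(R - R.T)

    return jax.lax.cond(theta < 1e-5, small_angle, general)


def compose_pose_se3(a, b):
    Ra, Rb = so3_exp(a[3:]), so3_exp(b[3:])
    return jnp.concatenate([Ra @ b[:3] + a[:3], so3_log(Ra @ Rb)])


def se3_retract_left(pose, delta):
    # Mirrors the documented update: R_new = R_d @ R, t_new = R_d @ t + dt.
    t, w = pose[:3], pose[3:]
    dt, dw = delta[:3], delta[3:]
    R, R_d = so3_exp(w), so3_exp(dw)
    return jnp.concatenate([R_d @ t + dt, so3_log(R_d @ R)])


pose = jnp.array([1.0, 0.0, 0.5, 0.1, 0.2, 0.3])
delta = jnp.array([0.01, -0.02, 0.0, 0.0, 0.0, 0.05])  # a small increment
updated = se3_retract_left(pose, delta)
```

In a solver loop, an update of this kind would apply a computed increment to a pose variable; because the result stays in the [t, w] layout, it can be written straight back into the flat state vector.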

so3_exp(w)

Exponential map from so(3) to SO(3).

:param w: Rotation vector in axis-angle form (3,).
:type w: jnp.ndarray
:return: Rotation matrix in SO(3) with shape (3, 3).
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def so3_exp(w: jnp.ndarray) -> jnp.ndarray:
    """
    Exponential map from so(3) to SO(3).

    :param w: Rotation vector in axis‑angle form ``(3,)``.
    :type w: jnp.ndarray
    :return: Rotation matrix in SO(3) with shape ``(3, 3)``.
    :rtype: jnp.ndarray
    """
    w = jnp.asarray(w)
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle() -> jnp.ndarray:
        # First-order approximation for small angles
        return I + hat(w)

    def normal_angle() -> jnp.ndarray:
        k = w / theta
        K = hat(k)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)
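Two quick checks on the Rodrigues branch: the result should be a proper rotation (R Rᵀ = I, det R = 1) that leaves its own axis fixed (R w = w), and it should match the generic matrix exponential of hat(w) from `jax.scipy.linalg.expm`. The sketch also differentiates through the small-angle branch at w = 0 with `jax.jacfwd`; there the first-order form I + hat(w) gives the derivative hat(e_k) for each component k. Functions are re-implemented locally to keep the block self-contained.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def so3_exp(w):
    # Mirrors the documented so3_exp: first-order branch for tiny angles,
    # Rodrigues' formula with the unit axis k = w / theta otherwise.
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle():
        return I + hat(w)

    def normal_angle():
        K = hat(w / theta)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)


w = jnp.array([0.2, -0.7, 0.4])
R = so3_exp(w)

# Jacobian dR/dw at the identity; J0[:, :, k] is the derivative w.r.t. w[k].
J0 = jax.jacfwd(so3_exp)(jnp.zeros(3))
```

The first-order branch I + hat(w) is not exactly orthogonal, but its deviation is of order θ², which is negligible below the 1e-5 threshold.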

so3_log(R)

Logarithm map from SO(3) to so(3).

:param R: Rotation matrix in SO(3) with shape (3, 3).
:type R: jnp.ndarray
:return: Rotation vector (3,) in axis-angle form.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def so3_log(R: jnp.ndarray) -> jnp.ndarray:
    """
    Logarithm map from SO(3) to so(3).

    :param R: Rotation matrix in SO(3) with shape ``(3, 3)``.
    :type R: jnp.ndarray
    :return: Rotation vector ``(3,)`` in axis‑angle form.
    :rtype: jnp.ndarray
    """
    R = jnp.asarray(R)
    # Compute cos(theta) with clamping
    trace = jnp.trace(R)
    cos_theta = (trace - 1.0) / 2.0

    # Clamp to valid domain for arccos
    cos_theta = jnp.clip(cos_theta, -1.0, 1.0)
    theta = jnp.arccos(cos_theta)

    # Small-angle threshold
    eps = 1e-5

    def small_angle_case(_) -> jnp.ndarray:
        # For very small angles, R ~ I + hat(w), so:
        # hat(w) ~ R - I  => w ~ vee(R - I)
        w_skew = R - jnp.eye(3, dtype=R.dtype)
        return vee(w_skew)

    def general_case(_) -> jnp.ndarray:
        # Standard formula:
        #   w^ = (theta / (2 sin(theta))) * (R - R^T)
        #   w  = vee(w^)
        w_skew = R - R.T
        # Safe denominator
        denom = 2.0 * jnp.sin(theta)
        factor = theta / (denom + 1e-12)
        w = factor * vee(w_skew)
        return w

    w = jax.lax.cond(
        theta < eps,
        small_angle_case,
        general_case,
        operand=None,
    )

    return w
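A quick consistency check is the round trip so3_log(so3_exp(w)) ≈ w for rotation angles |w| < π, including vectors small enough to take both small-angle branches. The sketch re-implements both maps as documented. One caveat visible in the general-case formula: it recovers the axis from R - Rᵀ, which vanishes as θ → π (sin θ → 0), so rotations close to 180° are poorly conditioned in this form; handling them would need a separate branch that neither this sketch nor the code above provides.

```python
import jax
import jax.numpy as jnp


# Local re-implementations mirroring hat, vee, so3_exp and so3_log on this page.
def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def vee(M):
    return jnp.array([M[2, 1] - M[1, 2], M[0, 2] - M[2, 0], M[1, 0] - M[0, 1]]) / 2.0


def so3_exp(w):
    theta = jnp.linalg.norm(w)
    I = jnp.eye(3)

    def small_angle():
        return I + hat(w)

    def normal_angle():
        K = hat(w / theta)
        return I + jnp.sin(theta) * K + (1.0 - jnp.cos(theta)) * (K @ K)

    return jax.lax.cond(theta < 1e-5, small_angle, normal_angle)


def so3_log(R):
    # trace(R) = 1 + 2 cos(theta); clip guards arccos against round-off.
    theta = jnp.arccos(jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0))

    def small_angle():
        return vee(R - jnp.eye(3))

    def general():
        return theta / (2.0 * jnp.sin(theta) + 1e-12) * vee(R - R.T)

    return jax.lax.cond(theta < 1e-5, small_angle, general)


w = jnp.array([0.3, -0.5, 0.8])       # |w| ≈ 0.99 rad
w_big = jnp.array([0.0, 0.0, 2.5])    # still below pi
tiny = jnp.array([1e-7, 0.0, -2e-7])  # takes the small-angle branches

w_back = so3_log(so3_exp(w))
w_big_back = so3_log(so3_exp(w_big))
tiny_back = so3_log(so3_exp(tiny))
```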

vee(R)

Inverse of the hat operator.

:param R: 3×3 skew-symmetric matrix.
:type R: jnp.ndarray
:return: Corresponding 3-vector.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/core/math3d.py
def vee(R: jnp.ndarray) -> jnp.ndarray:
    """
    Inverse of the hat operator.

    :param R: 3×3 skew‑symmetric matrix.
    :type R: jnp.ndarray
    :return: Corresponding 3‑vector.
    :rtype: jnp.ndarray
    """
    return jnp.array([
        R[2, 1] - R[1, 2],
        R[0, 2] - R[2, 0],
        R[1, 0] - R[0, 1],
    ]) / 2.0
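Because `vee` averages each pair of opposite off-diagonal entries, it inverts `hat` exactly on skew-symmetric input and, on a general 3×3 matrix M, returns the vector of its skew-symmetric part (M - Mᵀ)/2 while ignoring any symmetric component. That behavior is convenient in the small-angle branch of `so3_log`, where R - I is only approximately skew-symmetric. A minimal check, with both helpers re-implemented locally:

```python
import jax.numpy as jnp


def hat(v):
    x, y, z = v[0], v[1], v[2]
    return jnp.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])


def vee(M):
    # Averages each antisymmetric pair, i.e. returns the vector of (M - M^T) / 2.
    return jnp.array([M[2, 1] - M[1, 2], M[0, 2] - M[2, 0], M[1, 0] - M[0, 1]]) / 2.0


v = jnp.array([0.4, -1.2, 2.0])
S = jnp.array([[1.0, 0.5, 0.0],
               [0.5, 2.0, -0.3],
               [0.0, -0.3, 3.0]])  # an arbitrary symmetric matrix

v_back = vee(hat(v))      # recovers v
v_proj = vee(hat(v) + S)  # the symmetric part S is ignored
```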

core.factor_graph

Core factor graph data structure for DSG-JIT.

This module implements a minimal, backend-agnostic factor graph. It stores variables and factors and provides basic helpers for managing the graph structure, but it does not contain any JAX, residual, or JIT-specific logic. All residual construction, vmap batching, and solver orchestration are handled by higher-level components such as WorldModel (dsg_jit.world.model.WorldModel).

The FactorGraph stores:
  • Variables (nodes in the optimization graph)
  • Factors (constraints between variables)

Design Philosophy

  • Keep this layer small and generic so it can serve as a stable backend for multiple front-ends (WorldModel, alternative world models, or external bindings).
  • Do not depend on JAX or expose residual-building APIs here; instead, expose only structural information (variables, factors, IDs).
  • Allow higher layers to interpret variables and factors however they like (e.g., as poses, voxels, semantic objects) without baking that interpretation into the core graph.

Typical Usage

  • add_variable / add_factor: Build up the factor graph structure.
  • Higher-level code (e.g., dsg_jit.world.model) inspects the variables and factors attributes to construct residual functions and objectives suitable for Gauss–Newton or gradient-based solvers.
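The division of labor described above can be illustrated with a small sketch: the graph only holds structure, and a separate function walks the factors, slices each variable's block out of a flat vector x, and stacks the residuals so JAX can differentiate them. Everything here (the 1-D toy variables, the `"prior"` / `"odom"` factor types, the `build_residual` helper, and a single Gauss-Newton step via `jax.jacfwd`) is a hypothetical illustration, not WorldModel's actual implementation.

```python
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Any

import jax
import jax.numpy as jnp


@dataclass
class Variable:
    id: Hashable
    type: str
    value: Any


@dataclass
class Factor:
    id: Hashable
    type: str
    var_ids: tuple
    params: dict[str, Any]


def build_residual(variables: dict, factors: dict):
    """Hypothetical helper: return (x0, r) where r(x) stacks all factor residuals."""
    index, chunks, offset = {}, [], 0
    for vid, var in variables.items():
        v = jnp.ravel(jnp.asarray(var.value, dtype=jnp.float32))
        index[vid] = slice(offset, offset + v.size)
        chunks.append(v)
        offset += v.size
    x0 = jnp.concatenate(chunks)

    def r(x):
        out = []
        for f in factors.values():
            vals = [x[index[vid]] for vid in f.var_ids]
            if f.type == "prior":
                out.append(vals[0] - f.params["mean"])
            elif f.type == "odom":
                out.append((vals[1] - vals[0]) - f.params["measurement"])
            else:
                raise KeyError(f"unknown factor type {f.type!r}")
        return jnp.concatenate(out)

    return x0, r


variables = {i: Variable(i, "scalar", jnp.array([0.0])) for i in range(3)}
factors = {
    "prior0": Factor("prior0", "prior", (0,), {"mean": jnp.array([0.0])}),
    "odom01": Factor("odom01", "odom", (0, 1), {"measurement": jnp.array([1.0])}),
    "odom12": Factor("odom12", "odom", (1, 2), {"measurement": jnp.array([2.0])}),
}
x0, r = build_residual(variables, factors)

# One Gauss-Newton step: solve (J^T J) dx = -J^T r at x0.
J = jax.jacfwd(r)(x0)
dx = jnp.linalg.solve(J.T @ J, -J.T @ r(x0))
x1 = x0 + dx
```

Because this toy problem is linear, one step lands on the least-squares solution x = [0, 1, 3]. With SE(3) variables a solver would iterate and would typically apply each increment through a retraction such as se3_retract_left rather than plain addition.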

FactorGraph(variables=dict(), factors=dict()) dataclass

Abstract factor graph for DSG-JIT.

This class stores variables and factors and exposes simple helpers to register them. All residual-building and optimization logic is delegated to higher-level components such as WorldModel.

add_factor(factor)

Register a new factor in the factor graph.

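The factor must only reference variables that already exist in variables. Note that the implementation only asserts that the factor id is unique; it does not check var_ids, so this precondition is left to the caller.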

:param factor: Factor to add to the graph. Its id must be unique.
:type factor: Factor

Source code in dsg-jit/dsg_jit/core/factor_graph.py
def add_factor(self, factor: Factor) -> None:
    """Register a new factor in the factor graph.

    The factor must only reference variables that already exist in
    :attr:`variables`.

    :param factor: Factor to add to the graph. Its ``id`` must be unique.
    :type factor: Factor
    """
    assert factor.id not in self.factors
    self.factors[factor.id] = factor

add_variable(var)

Register a new variable in the factor graph.

This does not modify any existing factors; it simply makes the variable available to be referenced by factors.

:param var: Variable to add to the graph. Its id must be unique.
:type var: Variable

Source code in dsg-jit/dsg_jit/core/factor_graph.py
def add_variable(self, var: Variable) -> None:
    """Register a new variable in the factor graph.

    This does *not* modify any existing factors; it simply makes the
    variable available to be referenced by factors.

    :param var: Variable to add to the graph. Its ``id`` must be unique.
    :type var: Variable
    """
    assert var.id not in self.variables
    self.variables[var.id] = var
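A self-contained sketch of how the two registration helpers are meant to be used. `FactorGraph`, `Variable`, and `Factor` here are local stand-ins with the documented fields. This version checks the docstring's precondition that a factor may only reference registered variables, and raises ValueError rather than relying on assert (which `python -O` strips); both are choices made for this sketch, not the library's behavior.

```python
from collections.abc import Hashable
from dataclasses import dataclass, field
from typing import Any

import jax.numpy as jnp


@dataclass
class Variable:
    id: Hashable
    type: str
    value: Any


@dataclass
class Factor:
    id: Hashable
    type: str
    var_ids: tuple
    params: dict[str, Any] = field(default_factory=dict)


@dataclass
class FactorGraph:
    variables: dict = field(default_factory=dict)
    factors: dict = field(default_factory=dict)

    def add_variable(self, var: Variable) -> None:
        if var.id in self.variables:
            raise ValueError(f"duplicate variable id {var.id!r}")
        self.variables[var.id] = var

    def add_factor(self, factor: Factor) -> None:
        if factor.id in self.factors:
            raise ValueError(f"duplicate factor id {factor.id!r}")
        missing = [vid for vid in factor.var_ids if vid not in self.variables]
        if missing:
            raise ValueError(f"factor {factor.id!r} references unknown variables {missing}")
        self.factors[factor.id] = factor


fg = FactorGraph()
fg.add_variable(Variable("pose0", "pose3", jnp.zeros(6)))
fg.add_variable(Variable("pose1", "pose3", jnp.zeros(6)))
fg.add_factor(Factor("odom01", "odom", ("pose0", "pose1"), {"measurement": jnp.zeros(6)}))

try:
    # "pose2" was never registered, so this factor is rejected.
    fg.add_factor(Factor("bad", "odom", ("pose1", "pose2")))
    rejected = False
except ValueError:
    rejected = True
```

Registering variables before the factors that use them keeps the graph consistent at every step, which is the order the docstrings above assume.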