Optimization Modules

This section documents the JIT-compiled Gauss–Newton solvers and wrappers used inside DSG-JIT.


optimization.solvers

Nonlinear optimization solvers for DSG-JIT.

This module implements the core iterative solvers used throughout the system, with a focus on JAX-friendly, JIT-compilable routines that operate on flat state vectors and manifold-aware blocks (e.g., SE(3) poses).

The solvers are designed to work with residual functions produced by core.factor_graph.FactorGraph, and are used in:

• Pure SE3 SLAM chains
• Voxel grid smoothness / observation problems
• Hybrid SE3 + voxel joint optimization
• Differentiable experiments where measurements or weights are learned

Key Concepts

GNConfig
Dataclass holding configuration for Gauss–Newton:

• max_iters: maximum number of GN iterations
• damping: Levenberg–Marquardt-style damping
• max_step_norm: optional clamp on update step size
• verbose / debug flags (if enabled)

gauss_newton(residual_fn, x0, cfg)
Classic Gauss–Newton on a flat Euclidean state:

• residual_fn: r(x) -> (m,) JAX array
• x0: initial state
• cfg: GNConfig

Computes updates using normal equations:
    Jᵀ J Δx = -Jᵀ r
and returns the optimized state.
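As a sketch, a single update of this form on a toy residual (standalone code mirroring the solver, not the library function itself) might look like:

```python
import jax
import jax.numpy as jnp

def residual_fn(x):
    # Toy residual: drive x toward the target (1, 2).
    return x - jnp.array([1.0, 2.0])

def gn_step(x, damping=1e-6):
    r = residual_fn(x)
    J = jax.jacobian(residual_fn)(x)
    H = J.T @ J + damping * jnp.eye(x.shape[0])  # damped normal matrix
    delta = jnp.linalg.solve(H, J.T @ r)         # solves Jᵀ J Δx = Jᵀ r
    return x - delta

x_opt = gn_step(jnp.zeros(2))  # one step suffices for a linear residual
```

Because the toy residual is linear, a single step lands essentially on the target.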

gauss_newton_manifold(residual_fn, x0, block_slices, manifold_types, cfg)
Manifold-aware Gauss–Newton:

• residual_fn: r(x) -> (m,)
• x0: initial flat state vector
• block_slices: NodeId -> slice in x
• manifold_types: NodeId -> {"se3", "euclidean", ...}
• cfg: GNConfig

For SE3 blocks:
    • The update is computed in the tangent space (se(3))
    • Applied via retract / exponential map
    • Ensures updates stay on the manifold

For Euclidean blocks:
    • Updates are applied additively.
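The block-wise update scheme can be sketched as follows; the block layout, node ids, and toy_retract stand-in (for se3_retract_left) are all hypothetical:

```python
import jax.numpy as jnp

# Hypothetical two-block state; block "b" uses a toy additive retraction
# standing in for se3_retract_left (NOT the real SE(3) retract).
block_slices = {"a": slice(0, 2), "b": slice(2, 4)}
manifold_types = {"a": "euclidean", "b": "se3"}

def toy_retract(x_i, d_i):
    return x_i + d_i  # placeholder; the solver applies the exponential map here

x = jnp.array([0.0, 0.0, 1.0, 1.0])
delta = jnp.array([0.1, 0.2, 0.3, 0.4])  # scaled Gauss–Newton step

x_new = x
for nid, sl in block_slices.items():
    d_i, x_i = delta[sl], x[sl]
    if manifold_types.get(nid, "euclidean") == "se3":
        x_i_new = toy_retract(x_i, -d_i)  # tangent-space update via retraction
    else:
        x_i_new = x_i - d_i               # plain additive update
    x_new = x_new.at[sl].set(x_i_new)
```

Only the per-block dispatch is the point here; the real retraction keeps SE(3) blocks on the manifold.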

Design Goals

• Fully JAX-compatible: All heavy operations are written in terms of JAX primitives so that solvers can be JIT-compiled and differentiated through when needed.

• Stable and controlled: Optional damping and step-norm clamping help avoid NaNs and divergence in difficult configurations (e.g., bad initialization or large residuals).

• Reusable: Experiments and higher-level training loops (e.g., in experiments/ and optimization/jit_wrappers.py) call into these solvers as the core iterative engine for DSG-JIT.

Notes

These solvers are intentionally minimal and generic. They do not know anything about SE3 or voxels directly; instead, they rely on the factor graph and manifold metadata to interpret the state vector correctly.

If you add new manifold types (e.g., quaternions or higher-dimensional poses), extend the manifold handling logic in the manifold-aware solver.

damped_newton(objective, x0, cfg)

Damped Newton optimizer for small problems.

Uses a Levenberg–Marquardt-style update::

    (H + λ I) δ = ∇f(x)
    x_{k+1} = x_k - δ

where H is the Hessian of the objective and λ is a damping factor.

:param objective: Objective function f(x) that maps a state vector to a scalar loss.
:type objective: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial state vector.
:type x0: jnp.ndarray
:param cfg: Newton solver configuration (number of iterations and damping).
:type cfg: NewtonConfig
:return: Optimized state vector after damped Newton iterations.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/optimization/solvers.py
def damped_newton(objective: ObjectiveFn, x0: jnp.ndarray, cfg: NewtonConfig) -> jnp.ndarray:
    """Damped Newton optimizer for small problems.

    Uses a Levenberg–Marquardt-style update::

        (H + λ I) \\delta = \\nabla f(x)
        x_{k+1} = x_k - \\delta

    where ``H`` is the Hessian of the objective and ``λ`` is a damping factor.

    :param objective: Objective function ``f(x)`` that maps a state vector to a scalar loss.
    :type objective: Callable[[jnp.ndarray], jnp.ndarray]
    :param x0: Initial state vector.
    :type x0: jnp.ndarray
    :param cfg: Newton solver configuration (number of iterations and damping).
    :type cfg: NewtonConfig
    :return: Optimized state vector after damped Newton iterations.
    :rtype: jnp.ndarray
    """
    grad_fn = jax.grad(objective)
    hess_fn = jax.hessian(objective)

    x = x0
    for _ in range(cfg.max_iters):
        g = grad_fn(x)
        H = hess_fn(x)

        n = x.shape[0]
        H_damped = H + cfg.damping * jnp.eye(n)

        # Solve H_damped * delta = g
        delta = jnp.linalg.solve(H_damped, g)

        x = x - delta

    return x
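A self-contained usage sketch of the solver above; the solver body is repeated and NewtonConfig re-declared with only the fields the loop reads, since this snippet does not import the package:

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass
class NewtonConfig:
    # Stand-in with only the fields the loop reads; the real dataclass
    # lives in optimization.solvers.
    max_iters: int = 10
    damping: float = 1e-4

def damped_newton(objective, x0, cfg):
    grad_fn = jax.grad(objective)
    hess_fn = jax.hessian(objective)
    x = x0
    for _ in range(cfg.max_iters):
        H = hess_fn(x) + cfg.damping * jnp.eye(x.shape[0])
        x = x - jnp.linalg.solve(H, grad_fn(x))
    return x

# Quadratic bowl centered at (1, -2); Newton converges almost immediately.
f = lambda x: jnp.sum((x - jnp.array([1.0, -2.0])) ** 2)
x_opt = damped_newton(f, jnp.zeros(2), NewtonConfig())
```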

gauss_newton(residual_fn, x0, cfg)

Gauss–Newton on a residual function r(x): R^n -> R^m.

The algorithm forms the normal equations::

    J^T J δ = J^T r
    x_{k+1} = x_k - δ

with optional diagonal damping and step-size clamping for stability.

:param residual_fn: Residual function r(x) returning a 1D array of shape (m,).
:type residual_fn: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial state vector of shape (n,).
:type x0: jnp.ndarray
:param cfg: Gauss–Newton configuration (iterations, damping, step-norm clamp).
:type cfg: GNConfig
:return: Optimized state vector after Gauss–Newton iterations.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/optimization/solvers.py
def gauss_newton(residual_fn: ObjectiveFn, x0: jnp.ndarray, cfg: GNConfig) -> jnp.ndarray:
    """Gauss–Newton on a residual function ``r(x): R^n -> R^m``.

    The algorithm forms the normal equations::

        J^T J \\delta = J^T r
        x_{k+1} = x_k - \\delta

    with optional diagonal damping and step-size clamping for stability.

    :param residual_fn: Residual function ``r(x)`` returning a 1D array of shape ``(m,)``.
    :type residual_fn: Callable[[jnp.ndarray], jnp.ndarray]
    :param x0: Initial state vector of shape ``(n,)``.
    :type x0: jnp.ndarray
    :param cfg: Gauss–Newton configuration (iterations, damping, step-norm clamp).
    :type cfg: GNConfig
    :return: Optimized state vector after Gauss–Newton iterations.
    :rtype: jnp.ndarray
    """
    J_fn = jax.jacobian(residual_fn)  # J: (m, n)

    def step(x: jnp.ndarray) -> jnp.ndarray:
        r = residual_fn(x)    # (m,)
        J = J_fn(x)           # (m, n)

        H = J.T @ J           # (n, n)
        g = J.T @ r           # (n,)

        n = x.shape[0]
        H_damped = H + cfg.damping * jnp.eye(n)

        delta = jnp.linalg.solve(H_damped, g)  # (n,)

        # Optional step-size clamp to avoid huge jumps
        step_norm = jnp.linalg.norm(delta)
        scale = jnp.minimum(1.0, cfg.max_step_norm / (step_norm + 1e-9))

        return x - scale * delta

    x = x0
    for _ in range(cfg.max_iters):
        x = step(x)
    return x
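Because a linear residual makes Gauss–Newton exact in one undamped step, a minimal standalone line-fit illustrates the update:

```python
import jax
import jax.numpy as jnp

# Noiseless line data; the residual is linear in (a, b), so one undamped
# Gauss–Newton step solves the least-squares problem exactly.
t = jnp.linspace(0.0, 1.0, 5)
y = 3.0 * t + 0.5

def residual_fn(x):
    a, b = x
    return a * t + b - y  # stacked (m,) residual

x = jnp.zeros(2)
r = residual_fn(x)
J = jax.jacobian(residual_fn)(x)
x = x - jnp.linalg.solve(J.T @ J, J.T @ r)  # -> approximately (3.0, 0.5)
```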

gauss_newton_manifold(residual_fn, x0, block_slices, manifold_types, cfg)

Manifold-aware Gauss–Newton solver.

This variant still solves in a flat parameter space, but applies updates block-wise using the appropriate manifold retraction. In particular:

  • Blocks marked as "se3" are updated via se3_retract_left in the Lie algebra se(3).
  • Blocks marked as "euclidean" are updated additively.

:param residual_fn: Residual function r(x) returning a 1D array of shape (m,).
:type residual_fn: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial flat state vector of shape (n,).
:type x0: jnp.ndarray
:param block_slices: Mapping from node identifier to slice in x defining that variable's block. May be a dict or a sequence of (node_id, slice) pairs.
:type block_slices: Union[Mapping[Any, slice], Sequence[Tuple[Any, slice]]]
:param manifold_types: Mapping from node identifier to manifold label (e.g. "se3" or "euclidean"). May be a dict or a sequence of (node_id, manifold_type) pairs.
:type manifold_types: Union[Mapping[Any, str], Sequence[Tuple[Any, str]]]
:param cfg: Gauss–Newton configuration (iterations, damping, step-norm clamp).
:type cfg: GNConfig
:return: Optimized state vector after manifold-aware Gauss–Newton iterations.
:rtype: jnp.ndarray

Source code in dsg-jit/dsg_jit/optimization/solvers.py
def gauss_newton_manifold(
    residual_fn: ObjectiveFn,
    x0: jnp.ndarray,
    block_slices: Union[Mapping[Any, slice], Sequence[Tuple[Any, slice]]],
    manifold_types: Union[Mapping[Any, str], Sequence[Tuple[Any, str]]],
    cfg: GNConfig,
) -> jnp.ndarray:
    """Manifold-aware Gauss–Newton solver.

    This variant still solves in a flat parameter space, but applies updates
    block-wise using the appropriate manifold retraction. In particular:

    * Blocks marked as ``"se3"`` are updated via ``se3_retract_left`` in the
      Lie algebra ``se(3)``.
    * Blocks marked as ``"euclidean"`` are updated additively.

    :param residual_fn: Residual function ``r(x)`` returning a 1D array of shape ``(m,)``.
    :type residual_fn: Callable[[jnp.ndarray], jnp.ndarray]
    :param x0: Initial flat state vector of shape ``(n,)``.
    :type x0: jnp.ndarray
    :param block_slices: Mapping from node identifier to slice in ``x`` defining that variable's block.
                         May be a dict or a sequence of (node_id, slice) pairs.
    :type block_slices: Union[Mapping[Any, slice], Sequence[Tuple[Any, slice]]]
    :param manifold_types: Mapping from node identifier to manifold label (e.g. ``"se3"`` or ``"euclidean"``).
                           May be a dict or a sequence of (node_id, manifold_type) pairs.
    :type manifold_types: Union[Mapping[Any, str], Sequence[Tuple[Any, str]]]
    :param cfg: Gauss–Newton configuration (iterations, damping, step-norm clamp).
    :type cfg: GNConfig
    :return: Optimized state vector after manifold-aware Gauss–Newton iterations.
    :rtype: jnp.ndarray
    """
    # J: (m, n), r: (m,)
    J_fn = jax.jacobian(residual_fn)

    x = x0

    for _ in range(cfg.max_iters):
        r = residual_fn(x)       # (m,)
        J = J_fn(x)              # (m, n)

        H = J.T @ J              # (n, n)
        g = J.T @ r              # (n,)

        n = x.shape[0]
        H_damped = H + cfg.damping * jnp.eye(n)

        delta = jnp.linalg.solve(H_damped, g)  # (n,)

        # Step size clamp
        step_norm = jnp.linalg.norm(delta)
        scale = jnp.minimum(1.0, cfg.max_step_norm / (step_norm + 1e-9))
        delta_scaled = scale * delta

        x_new = x

        if isinstance(block_slices, Mapping) and isinstance(manifold_types, Mapping):
            # Legacy path: dict lookups.
            for nid, sl in block_slices.items():
                d_i = delta_scaled[sl]
                x_i = x[sl]
                mtype = manifold_types.get(nid, "euclidean")

                if mtype == "se3":
                    # Interpret d_i as a twist in se(3) and apply left retraction
                    x_i_new = se3_retract_left(x_i, -d_i)
                else:
                    # Euclidean update
                    x_i_new = x_i - d_i

                x_new = x_new.at[sl].set(x_i_new)
        else:
            # JAX-friendly path: iterate in lockstep over (nid, slice) and (nid, mtype)
            # without dict indexing. `block_slices` and `manifold_types` are expected
            # to be aligned and ordered consistently.
            for (nid, sl), (_, mtype) in zip(block_slices, manifold_types):
                d_i = delta_scaled[sl]
                x_i = x[sl]

                if mtype == "se3":
                    x_i_new = se3_retract_left(x_i, -d_i)
                else:
                    x_i_new = x_i - d_i

                x_new = x_new.at[sl].set(x_i_new)

        x = x_new

    return x
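The two accepted metadata layouts can be illustrated with hypothetical node ids and block widths (the 7-wide pose block is an assumption for illustration; use your graph's real slices):

```python
# Dict layout (legacy path):
block_slices_dict = {0: slice(0, 7), 1: slice(7, 10)}
manifold_types_dict = {0: "se3", 1: "euclidean"}

# Sequence layout for the JAX-friendly path: aligned, consistently ordered
# (node_id, slice) and (node_id, manifold_type) pairs.
block_slices_seq = [(0, slice(0, 7)), (1, slice(7, 10))]
manifold_types_seq = [(0, "se3"), (1, "euclidean")]

# The solver walks both sequences in lockstep:
pairs = [(sl, mtype)
         for (nid, sl), (_, mtype) in zip(block_slices_seq, manifold_types_seq)]
```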

gradient_descent(objective, x0, cfg)

Simple gradient descent optimizer.

Performs iterative updates of the form::

    x_{k+1} = x_k - learning_rate * ∇f(x_k)

until ``max_iters`` is reached.

:param objective: Objective function ``f(x)`` that maps a state vector to a scalar loss.
:type objective: Callable[[jnp.ndarray], jnp.ndarray]
:param x0: Initial state vector.
:type x0: jnp.ndarray
:param cfg: Gradient-descent configuration (learning rate and number of iterations).
:type cfg: GDConfig
:return: Optimized state vector after gradient descent.
:rtype: jnp.ndarray
Source code in dsg-jit/dsg_jit/optimization/solvers.py
def gradient_descent(objective: ObjectiveFn, x0: jnp.ndarray, cfg: GDConfig) -> jnp.ndarray:
    """Simple gradient descent optimizer.

    Performs iterative updates of the form::

        x_{k+1} = x_k - learning_rate * \\nabla f(x_k)

    until ``max_iters`` is reached.

    :param objective: Objective function ``f(x)`` that maps a state vector to a scalar loss.
    :type objective: Callable[[jnp.ndarray], jnp.ndarray]
    :param x0: Initial state vector.
    :type x0: jnp.ndarray
    :param cfg: Gradient-descent configuration (learning rate and number of iterations).
    :type cfg: GDConfig
    :return: Optimized state vector after gradient descent.
    :rtype: jnp.ndarray
    """
    grad_fn = jax.grad(objective)

    x = x0
    for _ in range(cfg.max_iters):
        g = grad_fn(x)
        x = x - cfg.learning_rate * g
    return x
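A standalone sketch of the same update rule on a toy objective:

```python
import jax
import jax.numpy as jnp

# Toy objective: f(x) = ||x||^2, minimized at the origin.
f = lambda x: jnp.sum(x ** 2)
grad_fn = jax.grad(f)

x = jnp.array([4.0, -2.0])
for _ in range(100):
    x = x - 0.1 * grad_fn(x)  # x_{k+1} = x_k - learning_rate * grad
```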

optimization.jit_wrappers

JIT-friendly optimization wrappers and training utilities for DSG-JIT.

This module provides higher-level utilities that sit on top of the core solvers in optimization.solvers. They are responsible for:

• Building JIT-compiled solve functions for a fixed world model-backed
  factor graph
• Wrapping Gauss–Newton in a functional interface (solve(x0) -> x_opt)
• Supporting differentiable inner loops for meta-learning experiments
• Implementing simple trainer-style loops used in Phase 4 experiments

Typical Usage

The experiments in experiments/ use this module to:

• Construct a `WorldModel`-backed factor graph (SE3, voxels, hybrid)
• Get a JIT-compiled residual or objective from the world model
  (e.g., via WorldModel.build_residual, which internally groups
  factors by type and shape and uses jax.vmap for efficiency)
• Build a `solve_once(x0)` function using Gauss–Newton
• Use `jax.grad` or `jax.value_and_grad` over an outer loss that depends
  on the optimized state

Example patterns include:

• Learning SE3 odometry measurements by backpropagating through the
  inner Gauss–Newton solve
• Learning voxel observation points that make a grid consistent with
  known ground-truth centers
• Learning factor-type weights (log-scales) for odometry vs. observations
  via supervised losses on final poses/voxels
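The backprop-through-the-inner-solve pattern behind these experiments can be sketched on a scalar toy problem (all names hypothetical; the real experiments differentiate through the factor-graph solvers):

```python
import jax

def inner_solve(theta):
    # Stand-in inner optimizer: exact Newton steps on f(x) = 0.5*(x - theta)^2,
    # so the solve returns x* = theta while remaining differentiable.
    x = 0.0
    for _ in range(5):
        x = x - (x - theta)
    return x

def outer_loss(theta):
    x_opt = inner_solve(theta)
    return (x_opt - 3.0) ** 2  # hypothetical supervised target on x*

g = jax.grad(outer_loss)(0.0)  # gradient flows through the inner iterations
```

Here x*(theta) = theta, so the outer gradient at theta = 0 is 2*(0 - 3) = -6; JAX obtains it by differentiating through the unrolled inner loop.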

Key Utilities (typical contents)

build_jit_gauss_newton(...)
Given a WorldModel and a GNConfig, returns a JIT-compiled function:

    solve_once(x0) -> x_opt

build_param_residual(...)
Wraps a residual function so that it depends both on the state x and on learnable parameters theta (e.g., measurements, observation points).

DSGTrainer (if present)
A lightweight helper class implementing:

• inner_solve(theta): run Gauss–Newton or GD on the graph
• loss(theta): compute a supervised loss on the optimized state
• step(theta): one gradient step on theta

Design Goals

• Separate concerns: The low-level solver logic lives in solvers.py, while experiment-specific JIT wiring and training loops live here.

• Encourage functional patterns: All wrappers aim to expose pure functions that JAX can JIT and differentiate, avoiding hidden state and side effects.

• Make research experiments easy: This is the layer where new meta-learning or differentiable-graph experiments should be prototyped before they are promoted into a more general API.

Notes

Because these wrappers are tailored to DSG-JIT’s factor graph structure, they assume:

• Residual functions derived from WorldModel, e.g.
  WorldModel.build_residual and its hyper-parameterized
  variants
• State vectors packed/unpacked via the world model’s core graph
  machinery (``WorldModel.pack_state`` / ``WorldModel.unpack_state``)

When modifying or extending this module, take care to preserve JIT and grad-friendliness: avoid Python-side mutation inside jitted functions and keep logic purely functional wherever possible.

JittedGN(fn, cfg) dataclass

JIT-compiled Gauss–Newton solver for a fixed world model-backed factor graph.

Note

This wrapper targets the Euclidean solver gauss_newton. For SE(3)/manifold problems use JittedGNManifold instead.

This lightweight wrapper stores a jitted solve function and the configuration used to build it. Typical usage:

residual_fn = wm.build_residual()  # vmap-optimized residual
cfg = GNConfig(...)
jgn = JittedGN.from_residual(residual_fn, cfg)
x_opt = jgn(x0)

:param fn: JIT-compiled function that maps an initial state vector x0 to an optimized state x_opt.
:param cfg: Gauss–Newton configuration used when building the jitted solver.

__call__(x0)

Run the jitted Gauss–Newton solve on an initial state.

:param x0: Initial flat state vector to optimize.
:return: Optimized state vector after running Gauss–Newton.

Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
def __call__(self, x0: jnp.ndarray) -> jnp.ndarray:
    """Run the jitted Gauss–Newton solve on an initial state.

    :param x0: Initial flat state vector to optimize.
    :return: Optimized state vector after running Gauss–Newton.
    """
    return self.fn(x0)

from_residual(residual_fn, cfg) staticmethod

Construct a JittedGN from a residual function.

This wraps gauss_newton with the provided configuration and JIT-compiles the resulting solve(x0) function.

:param residual_fn: Residual function r(x) returning the stacked residual vector for a fixed factor graph.
:param cfg: Gauss–Newton configuration (step limits, damping, etc.).
:return: A JittedGN instance whose __call__ method runs the jitted Gauss–Newton solve.

Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
@staticmethod
def from_residual(
    residual_fn: Callable[[jnp.ndarray], jnp.ndarray],
    cfg: GNConfig,
) -> "JittedGN":
    """Construct a :class:`JittedGN` from a residual function.

    This wraps :func:`gauss_newton` with the provided configuration
    and JIT-compiles the resulting ``solve(x0)`` function.

    :param residual_fn: Residual function ``r(x)`` returning the stacked
                        residual vector for a fixed factor graph.
    :param cfg: Gauss–Newton configuration (step limits, damping, etc.).
    :return: A :class:`JittedGN` instance whose ``__call__`` method
             runs the jitted Gauss–Newton solve.
    """
    # Wrap existing gauss_newton. cfg is closed over and treated as static.
    def solve(x0: jnp.ndarray) -> jnp.ndarray:
        return gauss_newton(residual_fn, x0, cfg)

    # jit the whole solve for this graph
    jitted = jax.jit(solve)
    return JittedGN(fn=jitted, cfg=cfg)
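The closure-and-jit pattern used by from_residual can be sketched standalone with a toy Gauss–Newton loop (hypothetical helper name; the real wrapper calls the library's gauss_newton):

```python
import jax
import jax.numpy as jnp

def make_jitted_solve(residual_fn, damping=1e-6, max_iters=3):
    # Config-like values are closed over, so they are static under jit.
    def solve(x0):
        x = x0
        J_fn = jax.jacobian(residual_fn)
        for _ in range(max_iters):
            r, J = residual_fn(x), J_fn(x)
            H = J.T @ J + damping * jnp.eye(x0.shape[0])
            x = x - jnp.linalg.solve(H, J.T @ r)
        return x
    return jax.jit(solve)  # jit the whole solve for this fixed residual

solve_once = make_jitted_solve(lambda x: x - jnp.array([2.0, -1.0]))
x_opt = solve_once(jnp.zeros(2))
```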

from_world_model(wm, cfg) staticmethod

Construct a JittedGN directly from a WorldModel.

This helper calls WorldModel.build_residual to obtain the vmap-optimized residual function for the current world, and then wraps it in a jitted Gauss–Newton solve.

Typical usage::

wm = WorldModel()
# ... add variables, factors, register residuals ...
jgn = JittedGN.from_world_model(wm, GNConfig(max_iters=20))
x0, _ = wm.pack_state()
x_opt = jgn(x0)

:param wm: World model whose factor graph defines the optimization problem. Its build_residual method is used to obtain the residual function.
:param cfg: Gauss–Newton configuration (step limits, damping, etc.).
:return: A JittedGN instance whose __call__ method runs the jitted Gauss–Newton solve using the world model's vmap-optimized residual.

Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
@staticmethod
def from_world_model(
    wm: "WorldModel",
    cfg: GNConfig,
) -> "JittedGN":
    """Construct a :class:`JittedGN` directly from a :class:`WorldModel`.

    This helper calls :meth:`WorldModel.build_residual` to obtain the
    vmap-optimized residual function for the current world, and then
    wraps it in a jitted Gauss–Newton solve.

    Typical usage::

        wm = WorldModel()
        # ... add variables, factors, register residuals ...
        jgn = JittedGN.from_world_model(wm, GNConfig(max_iters=20))
        x0, _ = wm.pack_state()
        x_opt = jgn(x0)

    :param wm: World model whose factor graph defines the optimization
               problem. Its :meth:`build_residual` method is used to
               obtain the residual function.
    :param cfg: Gauss–Newton configuration (step limits, damping, etc.).
    :return: A :class:`JittedGN` instance whose ``__call__`` method
             runs the jitted Gauss–Newton solve using the world model’s
             vmap-optimized residual.
    """
    residual_fn = wm.build_residual()
    return JittedGN.from_residual(residual_fn, cfg)

JittedGNManifold(fn, cfg) dataclass

JIT-compiled manifold Gauss–Newton solver for a fixed graph.

This wrapper is intended for SLAM-style problems where the packed state vector is a concatenation of manifold variables (e.g., SE(3) poses and R^3 landmarks). It closes over the residual function and manifold metadata and returns a single jitted solve function.

Typical usage::

residual_fn = wm.build_residual()
manifold_types, block_slices = build_manifold_metadata(...)
cfg = GNConfig(max_iters=1)
jgn = JittedGNManifold.from_residual(residual_fn, manifold_types, block_slices, cfg)
x_opt = jgn(x0)

IMPORTANT

To avoid repeated compilation, construct this once and reuse it for every incremental step. Ensure the shapes/dtypes of x0 and the residual output remain constant across steps (template mode).
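A small sketch of why shape stability matters: a Python-side counter in the function body increments only while JAX traces, so with constant shapes/dtypes the solve compiles once and is reused (names hypothetical; the multiplication stands in for the manifold GN solve):

```python
import jax
import jax.numpy as jnp

trace_count = 0

def _solve_impl(x0):
    # Python side effects run only while JAX traces the function.
    global trace_count
    trace_count += 1
    return x0 * 0.5  # stand-in for the manifold GN solve

solve = jax.jit(_solve_impl)

x = jnp.ones(6)
for _ in range(3):
    x = solve(x)  # same (6,) float input each step -> no recompilation

# trace_count stays at 1: compiled once, reused for every incremental step
```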

:param fn: JIT-compiled function mapping x0 -> x_opt.
:param cfg: Gauss–Newton configuration.

__call__(x0)

Run the jitted manifold Gauss–Newton solve.

Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
def __call__(self, x0: jnp.ndarray) -> jnp.ndarray:
    """Run the jitted manifold Gauss–Newton solve."""
    return self.fn(x0)

from_residual(residual_fn, manifold_types, block_slices, cfg) staticmethod

Construct a JittedGNManifold from residual + metadata.

:param residual_fn: Residual function r(x).
:param manifold_types: Per-block manifold type strings.
:param block_slices: Per-block slices into the packed vector.
:param cfg: Solver configuration.
:return: A reusable, jitted solver.

Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
@staticmethod
def from_residual(
    residual_fn: Callable[[jnp.ndarray], jnp.ndarray],
    manifold_types: Any,
    block_slices: Any,
    cfg: GNConfig,
) -> "JittedGNManifold":
    """Construct a :class:`JittedGNManifold` from residual + metadata.

    :param residual_fn: Residual function ``r(x)``.
    :param manifold_types: Per-block manifold type strings.
    :param block_slices: Per-block slices into the packed vector.
    :param cfg: Solver configuration.
    :return: A reusable, jitted solver.
    """

    def solve(x0: jnp.ndarray) -> jnp.ndarray:
        return gauss_newton_manifold(
            residual_fn=residual_fn,
            x0=x0,
            manifold_types=manifold_types,
            block_slices=block_slices,
            cfg=cfg,
        )

    # JIT the whole solve; cfg/manifold metadata are closed over.
    jitted = jax.jit(solve)
    return JittedGNManifold(fn=jitted, cfg=cfg)

from_world_model(wm, manifold_types, block_slices, cfg) staticmethod

Construct a manifold GN solver directly from a WorldModel.

This helper obtains the residual via WorldModel.build_residual.

:param wm: World model.
:param manifold_types: Per-block manifold types.
:param block_slices: Per-block slices.
:param cfg: Solver configuration.
:return: A reusable, jitted solver.

Source code in dsg-jit/dsg_jit/optimization/jit_wrappers.py
@staticmethod
def from_world_model(
    wm: "WorldModel",
    manifold_types: list[str],
    block_slices: list[slice],
    cfg: GNConfig,
) -> "JittedGNManifold":
    """Construct a manifold GN solver directly from a :class:`WorldModel`.

    This helper obtains the residual via :meth:`WorldModel.build_residual`.

    :param wm: World model.
    :param manifold_types: Per-block manifold types.
    :param block_slices: Per-block slices.
    :param cfg: Solver configuration.
    :return: A reusable, jitted solver.
    """
    residual_fn = wm.build_residual()
    return JittedGNManifold.from_residual(
        residual_fn=residual_fn,
        manifold_types=manifold_types,
        block_slices=block_slices,
        cfg=cfg,
    )