
World Model Modules

This section documents the high-level world model, scene-graph integration, voxel grid management, and training utilities.


world.model

World-level wrapper and optimization front-end around the core factor graph.

This module defines the world model abstraction: a typed layer on top of core.factor_graph.FactorGraph that understands high-level entities (poses, places, rooms, voxels, objects, agents) and also centralizes residual construction, JIT compilation, and solver orchestration.

In other words, WorldModel is the bridge between:

• Low-level optimization (factor graph, residual functions, manifolds)
• High-level scene graph abstractions (poses, agents, rooms, voxels)
• Application code that wants a simple, stable API for "optimize my world"

The underlying FactorGraph remains a relatively small, generic data structure that stores variables and factors and knows nothing about JAX, JIT, or manifolds. All JAX-specific logic (residual registries, vmap-based batching, Gauss–Newton wrappers, etc.) is owned by the world model.

Key responsibilities

  • Manage the underlying FactorGraph instance.
  • Provide ergonomic helpers to:
      • Add variables with automatically assigned NodeIds.
      • Add typed factors (e.g. priors, odometry, attachments, voxel terms).
      • Pack / unpack state vectors for optimization.
  • Maintain simple bookkeeping structures (e.g. maps from user-facing handles / indices back to NodeIds) so that experiments and higher-level layers do not need to manipulate NodeIds directly.
  • Maintain a residual-function registry that maps factor-type strings (e.g. "odom_se3", "voxel_point_obs") to JAX-compatible residual functions.
  • Build unified, vmap-optimized residual and objective functions on demand, caching compiled versions keyed by graph structure.
  • Expose convenient optimization entry points (e.g. optimize, or optimization.jit_wrappers.JittedGN) that operate directly on the world model.
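At its core, the residual registry in the list above is a dictionary from factor-type strings to callables. The sketch below is a stdlib-only stand-in, not the DSG-JIT implementation: ResidualRegistry, prior_residual and the list-based residual signature are hypothetical, while real residuals consume and return JAX arrays. The error message mirrors the one raised inside build_residual when a factor type has no registered residual.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical residual signature: (stacked variable values, params) -> residual list.
ResidualFn = Callable[[Sequence[float], dict], List[float]]


class ResidualRegistry:
    """Minimal stand-in for the world model's factor-type -> residual map."""

    def __init__(self) -> None:
        self._fns: Dict[str, ResidualFn] = {}

    def register(self, f_type: str, fn: ResidualFn) -> None:
        self._fns[f_type] = fn

    def lookup(self, f_type: str) -> ResidualFn:
        fn = self._fns.get(f_type)
        if fn is None:
            # Same failure mode as build_residual for unregistered factor types.
            raise ValueError(f"No residual fn registered for factor type '{f_type}'")
        return fn


def prior_residual(x: Sequence[float], params: dict) -> List[float]:
    # Hypothetical prior: element-wise difference from a target value.
    return [xi - ti for xi, ti in zip(x, params["target"])]


registry = ResidualRegistry()
registry.register("prior", prior_residual)
print(registry.lookup("prior")([1.0, 2.0], {"target": [0.5, 2.0]}))  # [0.5, 0.0]
```

Keeping the registry on the world model rather than on the factor graph is what lets FactorGraph stay free of JAX-specific callables.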

Typical usage

Experiments and higher layers typically:

1. Construct a WorldModel.
2. Add variables and factors according to a scenario.
3. Register residual functions for each factor type of interest.
4. Build a residual or objective from the world model and call into
   dsg_jit.optimization.solvers or dsg_jit.optimization.jit_wrappers
   to run Gauss–Newton (potentially manifold-aware) or gradient-based
   optimization.
5. Decode and interpret the optimized state via the world model's
   convenience accessors, or export it to higher-level scene-graph
   structures.

Design goals

  • Backend separation: keep FactorGraph as a minimal, backend-agnostic data structure (variables, factors, connectivity), while WorldModel owns JAX-facing logic such as residual construction, vmap batching, and JIT caching.
  • Scene-friendly: provide enough structure that scene graphs, voxel modules, and DSG layers can build on top of the world model without duplicating graph or optimization logic.
  • Ergonomic but explicit: favor simple, explicit methods (add_variable, add_factor, register_residual, optimize) over hidden magic, so that experiments remain easy to debug and extend.

ActiveWindowTemplate(variable_slots, factor_slots) dataclass

Defines a fixed-capacity active factor graph template for JIT-stable operation. Each variable/factor slot is identified by (type, slot_idx).

FactorSlot(factor_type, slot_idx, factor_id, var_slot_keys) dataclass

Bookkeeping for a factor slot in the active template.

VarSlot(var_type, slot_idx, node_id, dim) dataclass

Bookkeeping for a variable slot in the active template.
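The three dataclasses above share one idea: every slot is addressed by a (type, slot_idx) key, and factor slots refer to variable slots by those keys rather than by NodeIds. A minimal sketch, assuming local mirrors of VarSlot and FactorSlot with NodeId/FactorId modeled as plain ints and empty slots marked by None; the helper preallocate_var_slots is hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

SlotKey = Tuple[str, int]  # (type, slot_idx)


@dataclass
class VarSlot:  # hypothetical mirror of the documented dataclass
    var_type: str
    slot_idx: int
    node_id: Optional[int]  # None = slot not yet bound to a variable (assumption)
    dim: int


@dataclass
class FactorSlot:  # hypothetical mirror of the documented dataclass
    factor_type: str
    slot_idx: int
    factor_id: Optional[int]
    var_slot_keys: Tuple[SlotKey, ...]


def preallocate_var_slots(var_type: str, capacity: int, dim: int) -> Dict[SlotKey, VarSlot]:
    """Create `capacity` empty slots keyed by (type, slot_idx)."""
    return {(var_type, i): VarSlot(var_type, i, None, dim) for i in range(capacity)}


var_slots = preallocate_var_slots("pose", capacity=3, dim=6)
# An odometry factor slot that always connects pose slots 0 and 1.
odom = FactorSlot("odom_se3", 0, None, (("pose", 0), ("pose", 1)))
print(sorted(var_slots))  # [('pose', 0), ('pose', 1), ('pose', 2)]
print(all(k in var_slots for k in odom.var_slot_keys))  # True
```

Because the keys never change once allocated, the graph's shape stays fixed even as the NodeIds bound to each slot are swapped out.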

WorldModel() dataclass

High-level world model built on top of FactorGraph.

Modes
  • Dynamic/unbounded factor graph (legacy, research mode): variables and factors can be added and removed dynamically.
  • Fixed-capacity active template (real-time / JIT-stable mode): a fixed set of variable/factor slots is preallocated for JIT compatibility and in-place updates.

In addition to wrapping the core factor graph, this class keeps simple bookkeeping dictionaries that make it easier to build static and dynamic scene graphs on top of DSG-JIT. These maps are deliberately lightweight and optional: if you never pass a name when adding variables, the underlying optimization behavior is unchanged.
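The fixed-capacity mode relies on a simple rule visible in build_residual: each factor's residual is multiplied by params.get("active", 1.0). Deactivating a slot therefore zeroes its contribution without changing the length of the stacked residual vector, so a JIT-compiled function never sees a new shape. A stdlib-only sketch of that rule; the residual values are made up and masked_residual is a hypothetical helper.

```python
from typing import Dict, List


def masked_residual(residual: List[float], params: Dict[str, float]) -> List[float]:
    # Same rule as build_residual: scale by params["active"], defaulting to 1.0.
    activity = params.get("active", 1.0)
    return [activity * r for r in residual]


# Three preallocated factor slots; the last one is currently unused.
slot_residuals = [[0.2, -0.1], [0.05, 0.3], [9.9, 9.9]]
slot_params = [{"active": 1.0}, {}, {"active": 0.0}]

stacked = [v for r, p in zip(slot_residuals, slot_params) for v in masked_residual(r, p)]
print(len(stacked))  # 6: fixed by capacity, not by how many slots are live
print(stacked[4:])   # [0.0, 0.0]: the inactive slot adds nothing to the cost
```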

Source code in dsg-jit/dsg_jit/world/model.py
def __init__(self) -> None:
    # Core factor graph
    self.fg = FactorGraph()
    # Semantic maps; these are purely for convenience and do not affect
    # the underlying optimization.
    self.pose_ids = {}
    self.room_ids = {}
    self.place_ids = {}
    self.object_ids = {}
    # Mapping: agent_id -> {timestep -> NodeId}
    self.agent_pose_ids = {}
    # Residual Registry
    self._residual_registry: Dict[str, ResidualFn] = {}
    self._compiled_solvers: Dict[Tuple[str, str], Any] = {}
    # Active window template fields (for slot-based mode)
    self._active_template: Optional[ActiveWindowTemplate] = None
    self._var_slots: Dict[Tuple[str, int], VarSlot] = {}
    self._factor_slots: Dict[Tuple[str, int], FactorSlot] = {}
    self._active_factor_mask: Dict[FactorId, bool] = {}

add_agent_pose(agent_id, t, value, var_type='pose')

Add (and register) a pose for a particular agent at a timestep.

This convenience helper is meant for dynamic scene graphs where you track multiple agents over time. It simply delegates to add_variable and then records the mapping (agent_id, t) -> NodeId in agent_pose_ids.

:param agent_id: String identifier for the agent (e.g. "robot_0").
:param t: Discrete timestep index.
:param value: Initial pose value for this agent at time t.
:param var_type: Underlying variable type to use (defaults to "pose"; you can change this to "pose_se3" in advanced use cases).
:returns: The NodeId of the new agent pose variable.

Source code in dsg-jit/dsg_jit/world/model.py
def add_agent_pose(
    self,
    agent_id: str,
    t: int,
    value: jnp.ndarray,
    var_type: str = "pose",
) -> NodeId:
    """Add (and register) a pose for a particular agent at a timestep.

    This convenience helper is meant for dynamic scene graphs where you
    track multiple agents over time. It simply delegates to
    :meth:`add_variable` and then records the mapping ``(agent_id, t)``.

    :param agent_id: String identifier for the agent (e.g. ``"robot_0"``).
    :param t: Discrete timestep index.
    :param value: Initial pose value for this agent at time ``t``.
    :param var_type: Underlying variable type to use (defaults to
        ``"pose"``; you can change this to ``"pose_se3"`` in advanced
        use-cases).
    :returns: The :class:`NodeId` of the new agent pose variable.
    """
    nid = self.add_variable(var_type, value)
    if agent_id not in self.agent_pose_ids:
        self.agent_pose_ids[agent_id] = {}
    self.agent_pose_ids[agent_id][t] = nid
    return nid
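The agent bookkeeping is a two-level dictionary, agent_id -> {timestep -> NodeId}. The sketch below reproduces only that mapping, with NodeIds modeled as ints and a local counter standing in for add_variable; record_agent_pose is a hypothetical helper, and dict.setdefault is equivalent to the explicit membership check in the method.

```python
from typing import Dict


def record_agent_pose(agent_pose_ids: Dict[str, Dict[int, int]], agent_id: str, t: int, nid: int) -> None:
    # Same bookkeeping as add_agent_pose: agent_id -> {timestep -> NodeId}.
    agent_pose_ids.setdefault(agent_id, {})[t] = nid


agent_pose_ids: Dict[str, Dict[int, int]] = {}
next_id = 0
for agent in ("robot_0", "robot_1"):
    for t in range(3):
        record_agent_pose(agent_pose_ids, agent, t, next_id)
        next_id += 1

# Recover one agent's trajectory, ordered by timestep.
traj = [agent_pose_ids["robot_1"][t] for t in sorted(agent_pose_ids["robot_1"])]
print(traj)  # [3, 4, 5]
```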

add_camera_bearings(pose_id, landmark_ids, bearings, weight=None, factor_type='pose_landmark_bearing')

Add one or more camera bearing factors for a single pose.

This is a thin convenience wrapper for camera-like measurements that observe known landmarks via bearing (direction) only. It assumes that the underlying factor type is implemented by a residual such as slam.measurements.pose_landmark_bearing_residual.

Each row of bearings is expected to correspond to one landmark in landmark_ids. The dimensionality (e.g. 2D angle or 3D unit vector) is left to the residual function.

:param pose_id: Identifier of the pose variable from which all bearings are taken.
:param landmark_ids: List of landmark node identifiers, one per row in bearings.
:param bearings: Array of shape (N, D) containing bearing measurements in the sensor or camera frame.
:param weight: Optional scalar weight or inverse noise level applied uniformly to all bearings in this call. If None, the default inside the residual is used.
:param factor_type: Factor type string to register in the underlying FactorGraph. Defaults to "pose_landmark_bearing".
:returns: The FactorId of the last factor added. One factor is added per (pose, landmark) pair.
:raises ValueError: If len(landmark_ids) != bearings.shape[0], or if bearings is empty.

Source code in dsg-jit/dsg_jit/world/model.py
def add_camera_bearings(
    self,
    pose_id: NodeId,
    landmark_ids: list[NodeId],
    bearings: jnp.ndarray,
    weight: float | None = None,
    factor_type: str = "pose_landmark_bearing",
) -> FactorId:
    """Add one or more camera bearing factors for a single pose.

    This is a thin convenience wrapper for camera-like measurements that
    observe known landmarks via bearing (direction) only. It assumes that
    the underlying factor type is implemented by a residual such as
    :func:`slam.measurements.pose_landmark_bearing_residual`.

    Each row of ``bearings`` is expected to correspond to one
    landmark in ``landmark_ids``. The dimensionality (e.g. 2D angle
    or 3D unit vector) is left to the residual function.

    :param pose_id: Identifier of the pose variable from which all
        bearings are taken.
    :param landmark_ids: List of landmark node identifiers, one per row
        in ``bearings``.
    :param bearings: Array of shape ``(N, D)`` containing bearing
        measurements in the sensor or camera frame.
    :param weight: Optional scalar weight or inverse noise level applied
        uniformly to all bearings in this call. If ``None``, the default
        inside the residual is used.
    :param factor_type: Factor type string to register in the underlying
        :class:`FactorGraph`. Defaults to ``"pose_landmark_bearing"``.
    :returns: The :class:`FactorId` of the last factor added. One factor
        is added per (pose, landmark) pair.
    """
    if bearings.shape[0] != len(landmark_ids):
        raise ValueError(
            "add_camera_bearings expected len(landmark_ids) == bearings.shape[0], "
            f"got {len(landmark_ids)} vs {bearings.shape[0]}"
        )

    last_fid: FactorId | None = None
    for lm_id, b in zip(landmark_ids, bearings):
        params: Dict[str, object] = {"bearing": jnp.asarray(b)}
        if weight is not None:
            params["weight"] = float(weight)
        last_fid = self.add_factor(factor_type, [pose_id, lm_id], params)

    # mypy/linters: last_fid will never be None if bearings is non-empty.
    if last_fid is None:
        raise ValueError("add_camera_bearings called with empty bearings array.")
    return last_fid
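To make the "one factor per (pose, landmark) pair" behavior concrete, here is a hypothetical bearing residual of the kind the "pose_landmark_bearing" type might use: the predicted unit direction from pose to landmark minus the measured bearing, scaled by the optional weight. It is deliberately simplified and is not slam.measurements.pose_landmark_bearing_residual. It treats the pose as a 3D position only (rotation ignored), so the measured bearing is assumed to already be in the world frame.

```python
import math
from typing import Dict, List, Sequence


def bearing_residual(pose_xyz: Sequence[float], landmark_xyz: Sequence[float],
                     params: Dict[str, object]) -> List[float]:
    """Hypothetical residual: predicted unit direction minus measured bearing.

    Rotation is ignored, so params["bearing"] is assumed to be a world-frame
    unit vector. params["weight"] is optional, mirroring add_camera_bearings.
    """
    d = [l - p for l, p in zip(landmark_xyz, pose_xyz)]
    norm = math.sqrt(sum(c * c for c in d))
    predicted = [c / norm for c in d]
    weight = float(params.get("weight", 1.0))
    return [weight * (pr - m) for pr, m in zip(predicted, params["bearing"])]


# One residual per (pose, landmark) pair, as add_camera_bearings creates factors.
pose = [0.0, 0.0, 0.0]
landmarks = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]
bearings = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
res = [bearing_residual(pose, lm, {"bearing": b}) for lm, b in zip(landmarks, bearings)]
print(res)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

A real camera residual would also rotate the predicted direction into the camera frame using the pose's orientation.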

add_factor(f_type, var_ids, params)

Add a new factor to the underlying factor graph.

This allocates a fresh FactorId, normalizes the input variable identifiers to NodeId instances, constructs a core.types.Factor, and registers it in fg.

:param f_type: String identifying the factor type. This must match a factor type registered in the world model's residual registry (see register_residual) so that the appropriate residual function can be looked up during optimization; build_residual raises ValueError for unregistered types.
:param var_ids: Iterable of variable identifiers (ints or NodeId instances) that this factor connects.
:param params: Dictionary of factor parameters passed through to the residual function (e.g. measurements, noise models, weights).
:returns: The FactorId of the newly added factor.

Source code in dsg-jit/dsg_jit/world/model.py
def add_factor(self, f_type: str, var_ids, params: Dict) -> FactorId:
    """Add a new factor to the underlying factor graph.

    This allocates a fresh :class:`FactorId`, normalizes the input
    variable identifiers to :class:`NodeId` instances, constructs a
    :class:`core.types.Factor`, and registers it in :attr:`fg`.

    :param f_type: String identifying the factor type. This must match a
        factor type registered in the world model's residual registry
        (see :meth:`register_residual`) so that the appropriate
        residual function can be looked up during optimization.
    :param var_ids: Iterable of variable identifiers (ints or
        :class:`NodeId` instances) that this factor connects.
    :param params: Dictionary of factor parameters passed through to the
        residual function (e.g. measurements, noise models, weights).
    :returns: The :class:`FactorId` of the newly added factor.
    """
    # Allocate a fresh FactorId. We cannot rely on len(self.fg.factors)
    # when factors may have been removed (e.g. after marginalization),
    # so we take the maximum existing id and add one.
    if self.fg.factors:
        max_existing_id = max(int(fid) for fid in self.fg.factors.keys())
        fid_int = max_existing_id + 1
    else:
        fid_int = 0
    fid = FactorId(fid_int)

    # Normalize everything to NodeId
    node_ids = tuple(NodeId(int(vid)) for vid in var_ids)

    f = Factor(
        id=fid,
        type=f_type,
        var_ids=node_ids,
        params=params,
    )
    self.fg.add_factor(f)
    # Adding a factor changes the factor graph structure; clear cached
    # compiled solvers / residuals so they can be rebuilt consistently.
    self._compiled_solvers.clear()
    return fid
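The id-allocation comment above is easy to verify with plain integers: after a removal, len() can return an id that is still in use, while max + 1 cannot. A stdlib-only sketch; next_id is a hypothetical helper that follows the same rule as add_factor and add_variable.

```python
from typing import Iterable


def next_id(existing: Iterable[int]) -> int:
    """Allocate max(existing) + 1, or 0 for an empty graph."""
    ids = list(existing)
    return max(ids) + 1 if ids else 0


factor_ids = {0, 1, 2, 3}
factor_ids.discard(2)       # e.g. a factor dropped during marginalization
print(len(factor_ids))      # 3 -> would collide with the existing id 3
print(next_id(factor_ids))  # 4 -> always fresh
print(next_id(set()))       # 0
```

Scanning every existing id makes each insertion O(n) in the graph size. A running counter would avoid that, at the cost of extra state that has to stay consistent with the graph.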

add_imu_preintegration_factor(pose_i, pose_j, delta, weight=None, factor_type='pose_imu_preintegration')

Add an IMU preintegration-style factor between two poses.

This is intended to work with a preintegrated IMU summary (e.g. as produced by sensors.imu), where delta contains fields such as "dR", "dv", "dp", and corresponding covariance or information terms.

The exact keys expected in delta are left to the residual implementation for factor_type. The dictionary is stored in params["delta"] with its keys unchanged (each value is converted via jax.numpy.asarray), which keeps this interface flexible.

:param pose_i: NodeId of the starting pose (time t_k).
:param pose_j: NodeId of the ending pose (time t_{k+1}).
:param delta: Dictionary describing the preintegrated IMU increment between pose_i and pose_j. All arrays should be JAX arrays or types convertible via jax.numpy.asarray.
:param weight: Optional scalar weight / scaling to apply to the IMU factor inside the residual.
:param factor_type: Factor type string to register; by default this is "pose_imu_preintegration".
:returns: The FactorId of the created IMU factor.

Source code in dsg-jit/dsg_jit/world/model.py
def add_imu_preintegration_factor(
    self,
    pose_i: NodeId,
    pose_j: NodeId,
    delta: Dict[str, jnp.ndarray],
    weight: float | None = None,
    factor_type: str = "pose_imu_preintegration",
) -> FactorId:
    """Add an IMU preintegration-style factor between two poses.

    This is intended to work with a preintegrated IMU summary (e.g. as
    produced by :mod:`sensors.imu`), where ``delta`` contains fields such
    as ``"dR"``, ``"dv"``, ``"dp"``, and corresponding covariance or
    information terms.

    The exact keys expected in ``delta`` are left to the residual
    implementation for ``factor_type``; storing the dictionary in
    ``params["delta"]`` with its keys unchanged (values converted via
    :func:`jax.numpy.asarray`) keeps this interface flexible.

    :param pose_i: NodeId of the starting pose (time :math:`t_k`).
    :param pose_j: NodeId of the ending pose (time :math:`t_{k+1}`).
    :param delta: Dictionary describing the preintegrated IMU increment
        between ``pose_i`` and ``pose_j``. All arrays should be JAX
        arrays or types convertible via :func:`jax.numpy.asarray`.
    :param weight: Optional scalar weight / scaling to apply to the IMU
        factor inside the residual.
    :param factor_type: Factor type string to register; by default this is
        ``"pose_imu_preintegration"``.
    :returns: The :class:`FactorId` of the created IMU factor.
    """
    params: Dict[str, object] = {"delta": {k: jnp.asarray(v) for k, v in delta.items()}}
    if weight is not None:
        params["weight"] = float(weight)
    return self.add_factor(factor_type, [pose_i, pose_j], params)
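How a residual consumes params["delta"] is up to its implementation. The sketch below is a hypothetical and heavily simplified position-only residual, included only to show the data flow. It assumes identity orientation at pose_i and ignores velocity, gravity, and the "dR"/"dv" terms, so the measured increment delta["dp"] should simply equal p_j - p_i. A real preintegration residual accounts for all of those.

```python
from typing import Dict, List, Sequence


def imu_position_residual(p_i: Sequence[float], p_j: Sequence[float],
                          params: Dict[str, object]) -> List[float]:
    """Hypothetical translation-only IMU residual (rotation, velocity, gravity ignored)."""
    dp = params["delta"]["dp"]          # preintegrated position increment
    w = float(params.get("weight", 1.0))  # optional weight, as stored by the method
    return [w * ((pj - pi) - d) for pi, pj, d in zip(p_i, p_j, dp)]


params = {"delta": {"dp": [1.0, 0.0, 0.0]}, "weight": 2.0}
print(imu_position_residual([0.0, 0.0, 0.0], [1.5, 0.0, 0.0], params))  # [1.0, 0.0, 0.0]
```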

add_lidar_ranges(pose_id, landmark_ids, ranges, directions=None, weight=None, factor_type='pose_lidar_range')

Add LiDAR-style range factors for a single pose.

This helper is intended for simple range-only or range-with-direction measurements to known landmarks, coming from a LiDAR or depth sensor.

The interpretation of directions depends on the chosen residual implementation, but a common convention is that each row is a unit vector in the sensor frame pointing toward the target.

:param pose_id: Identifier of the pose variable from which ranges are measured.
:param landmark_ids: List of landmark node identifiers, one per range sample.
:param ranges: Array of shape (N,) holding range values in meters.
:param directions: Optional array of shape (N, 3) with unit direction vectors associated with each range measurement.
:param weight: Optional scalar weight applied to all range factors.
:param factor_type: Factor type string to register; by default this is "pose_lidar_range". The residual function for this type is expected to consume "range" and optionally "direction" in params.
:returns: The FactorId of the last factor added.
:raises ValueError: If the lengths of landmark_ids, ranges and (when given) directions disagree, or if ranges is empty.

Source code in dsg-jit/dsg_jit/world/model.py
def add_lidar_ranges(
    self,
    pose_id: NodeId,
    landmark_ids: list[NodeId],
    ranges: jnp.ndarray,
    directions: Optional[jnp.ndarray] = None,
    weight: float | None = None,
    factor_type: str = "pose_lidar_range",
) -> FactorId:
    """Add LiDAR-style range factors for a single pose.

    This helper is intended for simple range-only or range-with-direction
    measurements to known landmarks, coming from a LiDAR or depth sensor.

    The interpretation of ``directions`` depends on the chosen residual
    implementation, but a common convention is that each row is a unit
    vector in the sensor frame pointing toward the target.

    :param pose_id: Identifier of the pose variable from which ranges
        are measured.
    :param landmark_ids: List of landmark node identifiers, one per range
        sample.
    :param ranges: Array of shape ``(N,)`` holding range values in meters.
    :param directions: Optional array of shape ``(N, 3)`` with unit
        direction vectors associated with each range measurement.
    :param weight: Optional scalar weight applied to all range factors.
    :param factor_type: Factor type string to register; by default this is
        ``"pose_lidar_range"``. The residual function for this type is
        expected to consume ``"range"`` and optionally ``"direction"`` in
        ``params``.
    :returns: The :class:`FactorId` of the last factor added.
    """
    if ranges.shape[0] != len(landmark_ids):
        raise ValueError(
            "add_lidar_ranges expected len(landmark_ids) == ranges.shape[0], "
            f"got {len(landmark_ids)} vs {ranges.shape[0]}"
        )
    if directions is not None and directions.shape[0] != ranges.shape[0]:
        raise ValueError(
            "add_lidar_ranges expected directions.shape[0] == ranges.shape[0], "
            f"got {directions.shape[0]} vs {ranges.shape[0]}"
        )

    last_fid: FactorId | None = None
    for i, lm_id in enumerate(landmark_ids):
        params: Dict[str, object] = {"range": float(ranges[i])}
        if directions is not None:
            params["direction"] = jnp.asarray(directions[i])
        if weight is not None:
            params["weight"] = float(weight)
        last_fid = self.add_factor(factor_type, [pose_id, lm_id], params)

    if last_fid is None:
        raise ValueError("add_lidar_ranges called with empty ranges array.")
    return last_fid
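Two hypothetical residuals illustrate the parameters this method stores. The range-only form compares the predicted distance with params["range"]. The range-with-direction form checks that pose + range * direction lands on the landmark. Both treat the pose as a 3D position and assume the direction is already expressed in the world frame (sensor rotation ignored). Neither is the DSG-JIT residual for "pose_lidar_range".

```python
import math
from typing import Dict, List, Sequence


def range_residual(pose_xyz: Sequence[float], landmark_xyz: Sequence[float],
                   params: Dict[str, object]) -> float:
    """Hypothetical range-only residual: predicted distance minus measured range."""
    predicted = math.dist(pose_xyz, landmark_xyz)
    w = float(params.get("weight", 1.0))
    return w * (predicted - params["range"])


def range_direction_residual(pose_xyz: Sequence[float], landmark_xyz: Sequence[float],
                             params: Dict[str, object]) -> List[float]:
    """Hypothetical variant: the measured endpoint should coincide with the landmark."""
    r = params["range"]
    d = params["direction"]  # assumed world-frame unit vector
    w = float(params.get("weight", 1.0))
    return [w * (p + r * di - l) for p, di, l in zip(pose_xyz, d, landmark_xyz)]


pose = [0.0, 0.0, 0.0]
lm = [3.0, 4.0, 0.0]
print(range_residual(pose, lm, {"range": 4.5}))  # 0.5
print(range_direction_residual(pose, lm, {"range": 5.0, "direction": [0.6, 0.8, 0.0]}))
```

The range-only form constrains only distance, so a single measurement leaves the landmark free to slide on a sphere. Supplying directions pins down the full 3D offset.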

add_object(center, name=None)

Add an object centroid variable (3D point).

:param center: 3D position of the object centroid.
:param name: Optional semantic name to register in object_ids.
:returns: The NodeId of the new object variable.

Source code in dsg-jit/dsg_jit/world/model.py
def add_object(self, center: jnp.ndarray, name: Optional[str] = None) -> NodeId:
    """Add an object centroid variable (3D point).

    :param center: 3D position of the object centroid.
    :param name: Optional semantic name to register in :attr:`object_ids`.
    :returns: The :class:`NodeId` of the new object variable.
    """
    nid = self.add_variable("object", center)
    if name is not None:
        self.object_ids[name] = nid
    return nid

add_place(center, name=None)

Add a place / waypoint variable (3D point).

:param center: 3D position of the place/waypoint.
:param name: Optional semantic name to register in place_ids.
:returns: The NodeId of the new place variable.

Source code in dsg-jit/dsg_jit/world/model.py
def add_place(self, center: jnp.ndarray, name: Optional[str] = None) -> NodeId:
    """Add a place / waypoint variable (3D point).

    :param center: 3D position of the place/waypoint.
    :param name: Optional semantic name to register in :attr:`place_ids`.
    :returns: The :class:`NodeId` of the new place variable.
    """
    nid = self.add_variable("place", center)
    if name is not None:
        self.place_ids[name] = nid
    return nid

add_pose(value, name=None)

Add an SE(3) pose variable.

This is a thin wrapper around add_variable. If name is provided, the pose is also registered in pose_ids, which can be useful for scene-graph style code that wants stable, human-readable handles.

:param value: Initial pose value, typically a 6D se(3) vector.
:param name: Optional semantic name used as a key in pose_ids.
:returns: The NodeId of the newly created pose variable.

Source code in dsg-jit/dsg_jit/world/model.py
def add_pose(self, value: jnp.ndarray, name: Optional[str] = None) -> NodeId:
    """Add an SE(3) pose variable.

    This is a thin wrapper around :meth:`add_variable`. If ``name`` is
    provided, the pose is also registered in :attr:`pose_ids`, which can
    be useful for scene-graph style code that wants stable, human-readable
    handles.

    :param value: Initial pose value, typically a 6D se(3) vector.
    :param name: Optional semantic name used as a key in :attr:`pose_ids`.
    :returns: The :class:`NodeId` of the newly created pose variable.
    """
    nid = self.add_variable("pose", value)
    if name is not None:
        self.pose_ids[name] = nid
    return nid

add_room(center, name=None)

Add a room center variable (3D point).

:param center: 3D position of the room center.
:param name: Optional semantic name to register in room_ids.
:returns: The NodeId of the new room variable.

Source code in dsg-jit/dsg_jit/world/model.py
def add_room(self, center: jnp.ndarray, name: Optional[str] = None) -> NodeId:
    """Add a room center variable (3D point).

    :param center: 3D position of the room center.
    :param name: Optional semantic name to register in :attr:`room_ids`.
    :returns: The :class:`NodeId` of the new room variable.
    """
    nid = self.add_variable("room", center)
    if name is not None:
        self.room_ids[name] = nid
    return nid

add_variable(var_type, value)

Add a new variable to the underlying factor graph.

This allocates a fresh NodeId, constructs a core.types.Variable with the given type and initial value, registers it in fg, and returns the newly created id.

:param var_type: String describing the variable type (e.g. "pose", "room", "place", "object"). This is used by residual functions and manifold metadata to interpret the state.
:param value: Initial value for the variable, represented as a 1D JAX array. The dimensionality is inferred from value.shape[0].
:returns: The NodeId of the newly added variable.

Source code in dsg-jit/dsg_jit/world/model.py
def add_variable(self, var_type: str, value: jnp.ndarray) -> NodeId:
    """Add a new variable to the underlying factor graph.

    This allocates a fresh :class:`NodeId`, constructs a
    :class:`core.types.Variable` with the given type and initial value,
    registers it in :attr:`fg`, and returns the newly created id.

    :param var_type: String describing the variable type (e.g. ``"pose"``,
        ``"room"``, ``"place"``, ``"object"``). This is used by
        residual functions and manifold metadata to interpret the state.
    :param value: Initial value for the variable, represented as a
        1D JAX array. The dimensionality is inferred from
        ``value.shape[0]``.
    :returns: The :class:`NodeId` of the newly added variable.
    """
    # Allocate a fresh NodeId. We cannot rely on len(self.fg.variables)
    # when variables may have been removed (e.g. after marginalization),
    # so we take the maximum existing id and add one.
    if self.fg.variables:
        max_existing_id = max(int(nid) for nid in self.fg.variables.keys())
        nid_int = max_existing_id + 1
    else:
        nid_int = 0
    nid = NodeId(nid_int)
    v = Variable(id=nid, type=var_type, value=value)
    self.fg.add_variable(v)
    # Graph structure has changed; clear any cached compiled solvers
    # and residuals so they can be rebuilt on demand.
    self._compiled_solvers.clear()
    return nid
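Because each variable's dimension is inferred from its value, the optimizer's flat state vector is just the concatenation of all values plus an index recording where each one lives. build_residual calls pack_state and unpack_state, but their exact signatures and index format are not shown here. The functions below are a hypothetical stdlib-only sketch that assumes NodeId ordering and an (offset, dim) index.

```python
from typing import Dict, List, Tuple

Index = Dict[int, Tuple[int, int]]  # NodeId -> (offset, dim); an assumed format


def pack_state(variables: Dict[int, List[float]]) -> Tuple[List[float], Index]:
    """Concatenate variable values in NodeId order and record (offset, dim) per id."""
    x: List[float] = []
    index: Index = {}
    for nid in sorted(variables):
        value = variables[nid]
        index[nid] = (len(x), len(value))  # dim inferred from the value's length
        x.extend(value)
    return x, index


def unpack_state(x: List[float], index: Index) -> Dict[int, List[float]]:
    """Slice the flat vector back into per-variable values."""
    return {nid: list(x[start:start + dim]) for nid, (start, dim) in index.items()}


variables = {0: [0.0] * 6, 1: [1.0, 2.0, 3.0]}  # a 6D pose and a 3D room center
x, index = pack_state(variables)
print(len(x), index)  # 9 {0: (0, 6), 1: (6, 3)}
```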

build_objective()

Construct a scalar objective f(x) = ||r(x)||^2.

This wraps build_residual and returns a function that computes the squared L2 norm of the residual vector.

:return: JIT-compiled objective function f(x).
:rtype: Callable[[jnp.ndarray], jnp.ndarray]

Source code in dsg-jit/dsg_jit/world/model.py
def build_objective(self):
    """Construct a scalar objective ``f(x) = ||r(x)||^2``.

    This wraps :meth:`build_residual` and returns a function
    that computes the squared L2 norm of the residual vector.

    :return: JIT-compiled objective function ``f(x)``.
    :rtype: Callable[[jnp.ndarray], jnp.ndarray]
    """
    residual = self.build_residual()

    def objective(x: jnp.ndarray) -> jnp.ndarray:
        r = residual(x)
        return jnp.sum(r ** 2)

    return jax.jit(objective)
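The objective is the sum of squared residual entries, f(x) = sum_i r_i(x)^2. The stdlib-only sketch below shows the same composition without jax.jit. make_objective and the toy residual (two priors, the second with weight 2) are hypothetical.

```python
from typing import Callable, List


def make_objective(residual: Callable[[List[float]], List[float]]) -> Callable[[List[float]], float]:
    # Same composition as build_objective, minus jax.jit: f(x) = sum_i r_i(x)^2.
    def objective(x: List[float]) -> float:
        return sum(r * r for r in residual(x))
    return objective


def residual(x: List[float]) -> List[float]:
    # Toy residual: pull x0 toward 1 and x1 toward -3 (the latter with weight 2).
    return [x[0] - 1.0, 2.0 * (x[1] + 3.0)]


f = make_objective(residual)
print(f([1.0, -3.0]))  # 0.0 at the minimum
print(f([2.0, -2.0]))  # 1.0 + 4.0 = 5.0
```

Some solvers minimize 0.5 * ||r(x)||^2 instead. The constant does not change the minimizer, but it halves the gradient, which matters when mixing this objective with hand-tuned step sizes.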

build_residual(*, use_type_weights=False, learn_odom=False, learn_voxel_points=False)

Construct a unified residual function for the current world.

This method is the WorldModel-level entry point for building a JAX-compatible residual function that stacks all factor residuals. It is intended to subsume the various specialized builders that previously lived on FactorGraph, such as:

  • build_residual_function_with_type_weights
  • build_residual_function_se3_odom_param_multi
  • build_residual_function_voxel_point_param[_multi]

Instead of having separate entry points, this method is intended to expose a single interface whose behavior is controlled by configuration flags and a structured "hyper-parameter" argument passed at call time. Only the flag-free form is implemented so far.

Parameters

use_type_weights : bool, optional
    Not yet implemented; passing True raises NotImplementedError. Reserved for future integration with type-weighted residuals.
learn_odom : bool, optional
    Not yet implemented; passing True raises NotImplementedError. Reserved for future integration with learnable odometry parameters.
learn_voxel_points : bool, optional
    Not yet implemented; passing True raises NotImplementedError. Reserved for future integration with learnable voxel observation points.

Returns

callable
    A JAX-compatible residual function r(x), where x is a packed state vector (all flags must be False). The result is cached: under a constant key in slot-based mode, otherwise under a signature built from the variable count and each factor's type and arity.
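Two structural steps in build_residual can be reproduced with plain Python: the legacy-mode cache key ("v<count>|type:arity|...") and the grouping of factors by (type, per-variable dimensions), which determines which factors can share one batched vmap call. In the sketch below, Factor is a hypothetical namedtuple stand-in for core.types.Factor, variable dimensions are given directly instead of being read from values, and "pose_place" is a made-up factor type.

```python
from collections import namedtuple
from typing import Dict, List, Tuple

Factor = namedtuple("Factor", ["type", "var_ids", "params"])  # hypothetical stand-in
GroupKey = Tuple[str, Tuple[int, ...]]


def structure_signature(var_count: int, factors: List[Factor]) -> str:
    # Same format as the legacy cache key: "v<count>|type:arity|..."
    return f"v{var_count}|" + "|".join(f"{f.type}:{len(f.var_ids)}" for f in factors)


def group_factors(factors: List[Factor], var_dims: Dict[int, int]) -> Dict[GroupKey, List[Factor]]:
    # Factors sharing (type, per-variable dims) can be stacked and vmapped together.
    groups: Dict[GroupKey, List[Factor]] = {}
    for f in factors:
        key = (f.type, tuple(var_dims[nid] for nid in f.var_ids))
        groups.setdefault(key, []).append(f)
    return groups


var_dims = {0: 6, 1: 6, 2: 6, 3: 3}
factors = [
    Factor("prior", (0,), {}),
    Factor("odom_se3", (0, 1), {}),
    Factor("odom_se3", (1, 2), {}),
    Factor("pose_place", (2, 3), {}),
]
print(structure_signature(len(var_dims), factors))
# v4|prior:1|odom_se3:2|odom_se3:2|pose_place:2
print({k: len(v) for k, v in group_factors(factors, var_dims).items()})
# {('prior', (6,)): 1, ('odom_se3', (6, 6)): 2, ('pose_place', (6, 3)): 1}
```

Note that the signature records factor types and arities but not which variables each factor connects, so two graphs that differ only in connectivity map to the same cache key.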

Source code in dsg-jit/dsg_jit/world/model.py
def build_residual(
    self,
    *,
    use_type_weights: bool = False,
    learn_odom: bool = False,
    learn_voxel_points: bool = False,
) -> Callable[..., Any]:
    """Construct a unified residual function for the current world.

    This method is the WorldModel-level entry point for building a
    JAX-compatible residual function that stacks all factor residuals.
    It is intended to subsume the various specialized builders that
    previously lived on :class:`FactorGraph`, such as:

    * ``build_residual_function_with_type_weights``
    * ``build_residual_function_se3_odom_param_multi``
    * ``build_residual_function_voxel_point_param[_multi]``

    Instead of having separate entry points, this method is intended to
    expose a single interface whose behavior is controlled by
    configuration flags and a structured "hyper-parameter" argument
    passed at call time. Only the flag-free form is implemented so far.

    Parameters
    ----------
    use_type_weights : bool, optional
        Not yet implemented; passing ``True`` raises
        :class:`NotImplementedError`. Reserved for future integration
        with type-weighted residuals.
    learn_odom : bool, optional
        Not yet implemented; passing ``True`` raises
        :class:`NotImplementedError`. Reserved for future integration
        with learnable odometry parameters.
    learn_voxel_points : bool, optional
        Not yet implemented; passing ``True`` raises
        :class:`NotImplementedError`. Reserved for future integration
        with learnable voxel observation points.

    Returns
    -------
    callable
        A JAX-compatible residual function. In the simplest case
        (all flags ``False``) the signature is ``r(x)`` where ``x`` is
        a packed state vector.
    """
    # NOTE: For now, the configuration flags are accepted but not yet
    # wired into the implementation. They are kept in the signature to
    # preserve the planned API surface and avoid breaking callers.
    if use_type_weights or learn_odom or learn_voxel_points:
        raise NotImplementedError(
            "Hyper-parameterized residuals are provided by dedicated "
            "WorldModel helper methods (e.g. "
            "build_residual_function_with_type_weights, "
            "build_residual_function_se3_odom_param_multi). "
            "The generic build_residual hyper-parameter flags are not "
            "yet implemented."
        )

    # Slot-based mode: Use a constant cache key and enforce fixed structure.
    if self._active_template is not None:
        cache_key = ("residual", "active_template")
    else:
        # Legacy dynamic mode: cache by structure.
        factors = tuple(self.fg.factors.values())
        var_count = len(self.fg.variables)
        sig_parts = [f"{f.type}:{len(f.var_ids)}" for f in factors]
        structure_sig = f"v{var_count}|" + "|".join(sig_parts)
        cache_key = ("residual", structure_sig)

    cached = self._compiled_solvers.get(cache_key)
    if cached is not None:
        return cached

    # Group factors as before.
    factors = tuple(self.fg.factors.values())
    group_to_factors: Dict[Tuple[str, Tuple[int, ...]], List[Factor]] = {}
    for f in factors:
        var_dims: List[int] = []
        for nid in f.var_ids:
            v = self.fg.variables[nid].value
            var_dims.append(int(jnp.asarray(v).shape[0]))
        shape_sig = tuple(var_dims)
        key = (f.type, shape_sig)
        group_to_factors.setdefault(key, []).append(f)

    residual_fns = self._residual_registry
    _, index = self.pack_state()

    def residual(x: jnp.ndarray) -> jnp.ndarray:
        """Stacked residual function over all factors for the current graph, with slot-based activity mask if present."""
        var_values = self.unpack_state(x, index)
        res_chunks: List[jnp.ndarray] = []

        for (f_type, _shape_sig), flist in group_to_factors.items():
            res_fn = residual_fns.get(f_type, None)
            if res_fn is None:
                raise ValueError(
                    f"No residual fn registered for factor type '{f_type}'"
                )
            if not flist:
                continue
            # Singleton group
            if len(flist) == 1:
                f = flist[0]
                vs = [var_values[nid] for nid in f.var_ids]
                stacked = jnp.concatenate(vs)
                # Slot-based: multiply by "active" param if present
                r = res_fn(stacked, f.params)
                activity = f.params.get("active", 1.0)
                r = r * activity
                res_chunks.append(jnp.reshape(r, (-1,)))
                continue
            # Batched path
            stacked_states: List[jnp.ndarray] = []
            params_list: List[Dict[str, Any]] = []
            for f in flist:
                vs = [var_values[nid] for nid in f.var_ids]
                stacked_states.append(jnp.concatenate(vs))
                params_list.append(f.params)
            stacked_states_arr = jnp.stack(stacked_states, axis=0)
            params_tree = jtu.tree_map(
                lambda *vals: jnp.stack(
                    [jnp.asarray(v) for v in vals], axis=0
                ),
                *params_list,
            )
            def single_factor_residual(s: jnp.ndarray, p: Dict[str, Any]) -> jnp.ndarray:
                r = res_fn(s, p)
                activity = p.get("active", 1.0)
                return r * activity
            batched_res = jax.vmap(single_factor_residual)(
                stacked_states_arr, params_tree
            )
            res_chunks.append(jnp.reshape(batched_res, (-1,)))
        if not res_chunks:
            return jnp.zeros((0,), dtype=x.dtype)
        return jnp.concatenate(res_chunks, axis=0)

    residual_jit = jax.jit(residual)
    self._compiled_solvers[cache_key] = residual_jit
    return residual_jit
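
The grouping-and-batching strategy used by build_residual can be reproduced outside WorldModel. The following is a minimal sketch, assuming a toy "prior" residual r = x - target and plain dicts standing in for fg.variables and fg.factors (prior_residual, registry, variables, and factors are all hypothetical names, not dsg-jit API). It groups factors by (type, variable dimensions), stacks each group's params with jax.tree_util.tree_map, evaluates the group with jax.vmap, and multiplies by the per-factor "active" flag.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# Hypothetical residual: a Euclidean prior pulling a variable toward "target".
def prior_residual(s, p):
    return s - p["target"]

registry = {"prior": prior_residual}

# Toy stand-ins for fg.variables and fg.factors (not the dsg-jit classes).
variables = {0: jnp.array([1.0, 2.0]), 1: jnp.array([3.0, 4.0])}
factors = [
    {"type": "prior", "var_ids": (0,),
     "params": {"target": jnp.zeros(2), "active": jnp.array(1.0)}},
    {"type": "prior", "var_ids": (1,),
     "params": {"target": jnp.zeros(2), "active": jnp.array(0.0)}},  # inactive slot
]

# Group by (factor type, per-variable dimensions) so every group is vmappable.
groups = {}
for f in factors:
    dims = tuple(int(variables[n].shape[0]) for n in f["var_ids"])
    groups.setdefault((f["type"], dims), []).append(f)

def residual(var_values):
    chunks = []
    for (f_type, _dims), flist in groups.items():
        res_fn = registry[f_type]
        # One stacked state row per factor in the group.
        states = jnp.stack(
            [jnp.concatenate([var_values[n] for n in f["var_ids"]]) for f in flist]
        )
        # Stack each params leaf along a new leading (factor) axis.
        params = jtu.tree_map(lambda *v: jnp.stack(v), *[f["params"] for f in flist])
        batched = jax.vmap(lambda s, p: res_fn(s, p) * p["active"])(states, params)
        chunks.append(batched.reshape(-1))
    return jnp.concatenate(chunks)

r = residual(variables)
print(r)  # the inactive second factor contributes zeros
```

Masking with "active" instead of removing factors keeps the residual's shape fixed, which is what lets a single compiled function serve every configuration of a fixed-capacity template.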

build_residual_function_se3_odom_param_multi()

Build a residual function with learnable SE(3) odometry.

All factors of type "odom_se3" are treated as depending on a parameter array theta of shape (K, 6), where K is the number of odometry factors. Each row of theta represents a perturbable se(3) measurement.

Returns

(residual_fn, index): the residual function residual_fn(x, theta) and the pack index mapping returned by pack_state.

Source code in dsg-jit/dsg_jit/world/model.py
def build_residual_function_se3_odom_param_multi(self):
    """Build a residual function with learnable SE(3) odometry.

    All factors of type ``\"odom_se3\"`` are treated as depending on a
    parameter array ``theta`` of shape ``(K, 6)``, where ``K`` is the
    number of odometry factors. Each row of ``theta`` represents a
    perturbable se(3) measurement.

    Returns
    -------
    (residual_fn, index)
        ``residual_fn(x, theta)`` and the pack index mapping from
        :meth:`pack_state`.
    """
    factors = list(self.fg.factors.values())
    residual_fns = self._residual_registry

    _, index = self.pack_state()

    def residual(x: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """
        Parameters
        ----------
        x : jnp.ndarray
            Flat state vector.
        theta : jnp.ndarray
            Shape (K, 6), per-odom se(3) measurement.
        """
        var_values = self.unpack_state(x, index)
        res_list: List[jnp.ndarray] = []
        odom_idx = 0

        for f in factors:
            res_fn = residual_fns.get(f.type, None)
            if res_fn is None:
                raise ValueError(
                    f"No residual fn registered for factor type '{f.type}'"
                )

            stacked = jnp.concatenate([var_values[vid] for vid in f.var_ids])

            if f.type == "odom_se3":
                meas = theta[odom_idx]  # (6,)
                odom_idx += 1
                base_params = dict(f.params)
                base_params["measurement"] = meas
                params = base_params
            else:
                params = f.params

            r = res_fn(stacked, params)
            w = params.get("weight", 1.0)
            res_list.append(jnp.sqrt(w) * r)

        if not res_list:
            return jnp.zeros((0,), dtype=x.dtype)

        return jnp.concatenate(res_list)

    return residual, index
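
Because theta enters the residual as an ordinary JAX argument, the odometry measurements can be differentiated like any other input. The sketch below assumes a hypothetical 1-D translational "odometry" residual (x_j - x_i) - theta_k as a stand-in for the SE(3) model, with theta of shape (K, 1) rather than (K, 6). It assigns one row of theta per odometry factor in order, as the odom_idx counter does, and uses jax.grad to differentiate 0.5·||r(x, theta)||² with respect to theta.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1-D odometry residual standing in for the SE(3) model:
# the relative displacement between two scalars minus the measurement.
def odom_residual(stacked, measurement):
    xi, xj = stacked[0], stacked[1]
    return jnp.atleast_1d((xj - xi) - measurement[0])

# Chain of three scalar "poses" linked by two odometry factors.
odom_factors = [(0, 1), (1, 2)]

def residual(x, theta):
    res = []
    for k, (i, j) in enumerate(odom_factors):  # k plays the role of odom_idx
        stacked = jnp.stack([x[i], x[j]])
        res.append(odom_residual(stacked, theta[k]))
    return jnp.concatenate(res)

def loss(theta, x):
    r = residual(x, theta)
    return 0.5 * jnp.sum(r ** 2)

x = jnp.array([0.0, 1.0, 3.0])
theta = jnp.array([[1.0], [1.0]])  # shape (K, 1) here instead of (K, 6)
g = jax.grad(loss)(theta, x)
# Descent on theta would increase the second measurement toward the observed gap of 2.0.
```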

build_residual_function_voxel_point_param()

Build a residual function with a shared voxel observation point.

All factors of type "voxel_point_obs" will use a dynamic point_world argument passed at call time, rather than a fixed value stored in the factor params.

Returns

(residual_fn, index): the residual function residual_fn(x, point_world), where point_world has shape (3,), and the pack index mapping returned by pack_state.

Source code in dsg-jit/dsg_jit/world/model.py
def build_residual_function_voxel_point_param(self):
    """Build a residual function with a shared voxel observation point.

    All factors of type ``\"voxel_point_obs\"`` will use a dynamic
    ``point_world`` argument passed at call time, rather than a fixed
    value stored in the factor params.

    Returns
    -------
    (residual_fn, index)
        ``residual_fn(x, point_world)`` where ``point_world`` has
        shape (3,).
    """
    factors = list(self.fg.factors.values())
    residual_fns = self._residual_registry

    _, index = self.pack_state()

    def residual(x: jnp.ndarray, point_world: jnp.ndarray) -> jnp.ndarray:
        """
        Parameters
        ----------
        x : jnp.ndarray
            Flat state vector.
        point_world : jnp.ndarray
            Shape (3,), observation point in world coords for ALL
            voxel_point_obs factors. For now we assume a single
            voxel_point_obs, or that all share the same point.
        """
        var_values = self.unpack_state(x, index)
        res_list: List[jnp.ndarray] = []

        for f in factors:
            res_fn = residual_fns.get(f.type, None)
            if res_fn is None:
                raise ValueError(
                    f"No residual fn registered for factor type '{f.type}'"
                )

            stacked = jnp.concatenate([var_values[vid] for vid in f.var_ids])

            if f.type == "voxel_point_obs":
                base_params = dict(f.params)
                base_params["point_world"] = point_world
                params = base_params
            else:
                params = f.params

            r = res_fn(stacked, params)
            w = params.get("weight", 1.0)
            res_list.append(jnp.sqrt(w) * r)

        if not res_list:
            return jnp.zeros((0,), dtype=x.dtype)

        return jnp.concatenate(res_list)

    return residual, index
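
Since point_world is supplied at call time, you can ask how every voxel_point_obs residual responds to moving that single shared point. The sketch below assumes a hypothetical residual voxel_center - point_world (the real voxel_point_obs model may differ) for two voxels packed into x. It uses jax.jacfwd to compute the Jacobian with respect to the shared 3-vector.

```python
import jax
import jax.numpy as jnp

# Hypothetical voxel observation residual: voxel center minus observed point.
def voxel_point_residual(center, point_world):
    return center - point_world

def residual(x, point_world):
    # x packs two 3-D voxel centers; every factor sees the same point_world.
    centers = x.reshape(2, 3)
    return jnp.concatenate([voxel_point_residual(c, point_world) for c in centers])

x = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
point = jnp.array([0.5, 0.5, 0.5])
# Jacobian w.r.t. the shared point: shape (6, 3), one -I block per factor.
J = jax.jacfwd(residual, argnums=1)(x, point)
```

Because every factor contributes a block to the same three columns, the shared point accumulates evidence from all observations at once. The per-factor variant that follows removes that coupling.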

build_residual_function_voxel_point_param_multi()

Build a residual function with per-factor voxel observation points.

Each "voxel_point_obs" factor consumes a row of the parameter array theta of shape (K, 3), where K is the number of such factors.

Returns

(residual_fn, index): the residual function residual_fn(x, theta), where theta has shape (K, 3), and the pack index mapping returned by pack_state.

Source code in dsg-jit/dsg_jit/world/model.py
def build_residual_function_voxel_point_param_multi(self):
    """Build a residual function with per-factor voxel observation points.

    Each ``\"voxel_point_obs\"`` factor consumes a row of the parameter
    array ``theta`` of shape ``(K, 3)``, where ``K`` is the number of
    such factors.

    Returns
    -------
    (residual_fn, index)
        ``residual_fn(x, theta)`` where ``theta`` has shape (K, 3).
    """
    factors = list(self.fg.factors.values())
    residual_fns = self._residual_registry

    _, index = self.pack_state()

    def residual(x: jnp.ndarray, theta: jnp.ndarray) -> jnp.ndarray:
        """
        Parameters
        ----------
        x : jnp.ndarray
            Flat state vector.
        theta : jnp.ndarray
            Shape (K, 3), per-voxel-point observation in world
            coordinates.
        """
        var_values = self.unpack_state(x, index)
        res_list: List[jnp.ndarray] = []
        obs_idx = 0  # python counter over voxel_point_obs factors

        for f in factors:
            res_fn = residual_fns.get(f.type, None)
            if res_fn is None:
                raise ValueError(
                    f"No residual fn registered for factor type '{f.type}'"
                )

            stacked = jnp.concatenate([var_values[vid] for vid in f.var_ids])

            if f.type == "voxel_point_obs":
                point_world = theta[obs_idx]  # (3,)
                obs_idx += 1
                base_params = dict(f.params)
                base_params["point_world"] = point_world
                params = base_params
            else:
                params = f.params

            r = res_fn(stacked, params)
            w = params.get("weight", 1.0)
            res_list.append(jnp.sqrt(w) * r)

        if not res_list:
            return jnp.zeros((0,), dtype=x.dtype)

        return jnp.concatenate(res_list)

    return residual, index
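
The obs_idx counter is ordinary Python, so JAX resolves it while tracing: under jax.jit the loop is unrolled once and each factor is wired to a fixed row of theta. The sketch below assumes the same hypothetical center-minus-point residual as a stand-in for voxel_point_obs, plus a hypothetical list centers_per_factor giving the voxel each factor observes. The Jacobian confirms that each residual block depends only on its own row of theta.

```python
import jax
import jax.numpy as jnp

centers_per_factor = [0, 1]  # hypothetical: factor k observes voxel k

def residual(x, theta):
    centers = x.reshape(-1, 3)
    res = []
    obs_idx = 0  # Python counter, unrolled at trace time
    for voxel in centers_per_factor:
        res.append(centers[voxel] - theta[obs_idx])
        obs_idx += 1
    return jnp.concatenate(res)

residual_jit = jax.jit(residual)
x = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])
theta = jnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]])  # shape (K, 3) with K = 2
r = residual_jit(x, theta)

# Shape (6, 2, 3): rows 0-2 depend only on theta[0], rows 3-5 only on theta[1].
J = jax.jacfwd(residual, argnums=1)(x, theta)
```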

build_residual_function_with_type_weights(factor_type_order)

Build a residual function that supports learnable type weights.

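The returned function has signature r(x, log_scales), where log_scales[i] is the log-weight associated with factor_type_order[i]; every residual of that type is multiplied by exp(log_scales[i]). Factor types not listed in factor_type_order are left unscaled (unit weight). Unlike the other builders in this section, this method returns only the residual function, not a (residual_fn, index) pair.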

This is a WorldModel-based version of the old FactorGraph helper, implemented in terms of pack_state, unpack_state, and the WorldModel residual registry.

Source code in dsg-jit/dsg_jit/world/model.py
def build_residual_function_with_type_weights(
    self, factor_type_order: List[str]
):
    """Build a residual function that supports learnable type weights.

    The returned function has signature ``r(x, log_scales)`` where
    ``log_scales[i]`` is the log-weight associated with
    ``factor_type_order[i]``. Missing types default to unit weight.

    This is a WorldModel-based version of the old FactorGraph helper,
    implemented in terms of ``pack_state``, ``unpack_state``, and the
    WorldModel residual registry.
    """
    factors = list(self.fg.factors.values())
    residual_fns = self._residual_registry
    _, index = self.pack_state()

    type_to_idx = {t: i for i, t in enumerate(factor_type_order)}

    def residual(x: jnp.ndarray, log_scales: jnp.ndarray) -> jnp.ndarray:
        var_values = self.unpack_state(x, index)
        res_list: List[jnp.ndarray] = []

        for factor in factors:
            res_fn = residual_fns.get(factor.type, None)
            if res_fn is None:
                raise ValueError(
                    f"No residual fn registered for factor type '{factor.type}'"
                )

            stacked = jnp.concatenate(
                [var_values[vid] for vid in factor.var_ids], axis=0
            )
            r = res_fn(stacked, factor.params)  # (k,)

            idx = type_to_idx.get(factor.type, None)
            if idx is not None:
                scale = jnp.exp(log_scales[idx])
            else:
                scale = 1.0

            r_scaled = scale * r
            r_scaled = jnp.reshape(r_scaled, (-1,))
            res_list.append(r_scaled)

        if not res_list:
            return jnp.zeros((0,), dtype=x.dtype)

        return jnp.concatenate(res_list, axis=0)

    return residual
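
Parameterizing the per-type weights in log space keeps every scale positive, and the gradient of the cost with respect to each log-scale is easy to interpret. The sketch below assumes hypothetical raw residuals tagged by type, including a "loop" type that is absent from factor_type_order and therefore stays unscaled. For a cost 0.5·Σr² with r = exp(s)·r_raw, the derivative with respect to s equals r², summed over factors of that type.

```python
import jax
import jax.numpy as jnp

factor_type_order = ["odom", "prior"]
type_to_idx = {t: i for i, t in enumerate(factor_type_order)}

# Hypothetical per-factor raw residuals, tagged by type ("loop" is not listed).
raw = [("odom", jnp.array([1.0])), ("prior", jnp.array([2.0])), ("loop", jnp.array([3.0]))]

def scaled_residual(log_scales):
    out = []
    for f_type, r in raw:
        idx = type_to_idx.get(f_type)
        # Unlisted types keep unit scale, as in the builder above.
        scale = jnp.exp(log_scales[idx]) if idx is not None else 1.0
        out.append(scale * r)
    return jnp.concatenate(out)

log_scales = jnp.log(jnp.array([2.0, 0.5]))
r = scaled_residual(log_scales)
g = jax.grad(lambda s: 0.5 * jnp.sum(scaled_residual(s) ** 2))(log_scales)
```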

configure_factor_slot(factor_type, slot_idx, var_ids, params, active=True)

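Configure a factor slot in the active template: set its variable ids, merge the given params into the slot's existing params, and set its activity. Existing param keys are preserved so that every factor in a vmapped group keeps the same pytree structure. Python float and int values are converted to float32 JAX arrays, and the "active" entry is set to 1.0 or 0.0. Raises KeyError if the slot does not exist.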

Source code in dsg-jit/dsg_jit/world/model.py
def configure_factor_slot(
    self,
    factor_type: str,
    slot_idx: int,
    var_ids: Tuple[NodeId, ...],
    params: Dict,
    active: bool = True,
) -> None:
    """Configure a factor slot in the active template: set variable ids, params, and activity."""
    slot_key = (factor_type, slot_idx)
    slot = self._factor_slots.get(slot_key)
    if slot is None:
        raise KeyError(f"Factor slot {slot_key} not found in active template.")
    fid = slot.factor_id
    f = self.fg.factors[fid]
    # Update factor's var_ids and params in place.
    object.__setattr__(f, "var_ids", tuple(var_ids))

    # IMPORTANT: preserve existing keys so vmapped stacking sees a
    # consistent pytree structure across all factors in a batched group.
    new_params = dict(f.params)
    new_params.update(params)
    # Normalize scalar params to JAX arrays for stable stacking.
    for k, v in list(new_params.items()):
        if isinstance(v, (float, int)):
            new_params[k] = jnp.array(v, dtype=jnp.float32)
    new_params["active"] = jnp.array(1.0 if active else 0.0, dtype=jnp.float32)
    # Set params the same way as var_ids so this also works if Factor is frozen.
    object.__setattr__(f, "params", new_params)
    self._active_factor_mask[fid] = active
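
The comment about preserving keys matters because jax.tree_util.tree_map can only stack params dicts that share the same keys and compatible leaf shapes; a slot missing a key would break the batched path in build_residual. The sketch below illustrates this under stated assumptions: merge_slot_params is a hypothetical helper that re-implements the merge logic shown above in isolation, and the odom-style defaults mirror those created by init_active_template.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

def merge_slot_params(existing, update, active):
    # Mirrors the merge above: keep old keys, overwrite with new values.
    merged = dict(existing)
    merged.update(update)
    for k, v in list(merged.items()):
        if isinstance(v, (float, int)):
            merged[k] = jnp.array(v, dtype=jnp.float32)
    merged["active"] = jnp.array(1.0 if active else 0.0, dtype=jnp.float32)
    return merged

default = {"measurement": jnp.zeros(6), "weight": jnp.array(1.0), "active": jnp.array(0.0)}
slot_a = merge_slot_params(default, {"measurement": jnp.ones(6), "weight": 4.0}, active=True)
slot_b = dict(default)  # untouched, still inactive

# Same keys and leaf shapes, so the two slots stack cleanly for vmap.
stacked = jtu.tree_map(lambda *v: jnp.stack(v), slot_a, slot_b)
```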

fixed_lag_marginalize(keep_ids, damping=1e-06)

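Disabled: this method is currently a no-op. Fixed-lag marginalization is not supported in active template mode; use bounded active templates for sliding-window (fixed-lag) smoothing instead. Outside template mode the call also does nothing, since no dynamic-mode implementation is present.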

Source code in dsg-jit/dsg_jit/world/model.py
def fixed_lag_marginalize(
    self,
    keep_ids: List[NodeId],
    damping: float = 1e-6,
) -> None:
    """
    Disabled: Fixed-lag marginalization is not supported in active template mode.
    Use bounded active templates for sliding window/fixed-lag smoothing instead.
    """
    if self._active_template is not None:
        # Fixed-lag smoothing is handled via bounded active templates.
        # This method is disabled in slot-based mode.
        return
    # (Legacy code for dynamic mode could be restored here if needed.)
    pass

get_residual(factor_type)

Return the residual function registered for a given factor type.

Parameters

factor_type : str
    String identifier for the factor type.

Returns

callable or None
    The residual function previously registered via register_residual, or None if no function is registered for the requested type.

Source code in dsg-jit/dsg_jit/world/model.py
def get_residual(self, factor_type: str) -> Optional[Callable[..., Any]]:
    """Return the residual function registered for a given factor type.

    Parameters
    ----------
    factor_type : str
        String identifier for the factor type.

    Returns
    -------
    callable or None
        The residual function previously registered via
        :meth:`register_residual`, or ``None`` if no function is
        registered for the requested type.
    """
    return self._residual_registry.get(factor_type)

get_residuals()

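Return the residual registry: a dict mapping each factor-type string to its registered residual function. This is the live registry object, not a copy, so mutating it changes the WorldModel's registrations.

:return: Dict[str, ResidualFn]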

Source code in dsg-jit/dsg_jit/world/model.py
def get_residuals(self) -> Dict[str, ResidualFn]:
    """Returns the residual registry, all currently registered residuals.

    :return: Dict[str, ResidualFn]
    """
    return self._residual_registry

get_variable_value(nid)

Return the current value of a variable.

This is a thin convenience wrapper over the underlying FactorGraph variable storage and is useful when building dynamic scene graphs that want to query individual nodes.

:param nid: Identifier of the variable.
:returns: A JAX array holding the variable's current value.

Source code in dsg-jit/dsg_jit/world/model.py
def get_variable_value(self, nid: NodeId) -> jnp.ndarray:
    """Return the current value of a variable.

    This is a thin convenience wrapper over the underlying
    :class:`FactorGraph` variable storage and is useful when building
    dynamic scene graphs that want to query individual nodes.

    :param nid: Identifier of the variable.
    :returns: A JAX array holding the variable's current value.
    """
    return self.fg.variables[nid].value

init_active_template(template)

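Initialize a fixed-capacity active factor graph template for JIT-stable operation. All variables and factors are preallocated and the structure is fixed afterwards. This replaces the current factor graph with a fresh one, so any previously added variables and factors are discarded. Variable slots start at zero, and factor slots start inactive with default params (specific defaults for "prior", "odom_se3", and "marginal_prior"; only "weight" and "active" for other types).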

Source code in dsg-jit/dsg_jit/world/model.py
def init_active_template(self, template: ActiveWindowTemplate) -> None:
    """Initialize a fixed-capacity active factor graph template for JIT-stable operation.
    All variables and factors are preallocated; structure is fixed.
    """
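    # Reset the factor graph, slot bookkeeping, and cached compiled
    # residuals (they close over the previous graph's factors).
    self.fg = FactorGraph()
    self._active_template = template
    self._var_slots.clear()
    self._factor_slots.clear()
    self._active_factor_mask.clear()
    self._compiled_solvers.clear()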
    # Pre-allocate variable slots.
    for var_type, slot_idx, dim in template.variable_slots:
        init_val = jnp.zeros((dim,), dtype=jnp.float32)
        node_id = self.add_variable(var_type, init_val)
        slot_key = (var_type, slot_idx)
        self._var_slots[slot_key] = VarSlot(var_type, slot_idx, node_id, dim)
    # Pre-allocate factor slots (inactive by default).
    for factor_type, slot_idx, var_slot_keys in template.factor_slots:
        var_ids = tuple(self._var_slots[vk].node_id for vk in var_slot_keys)

        # Compute the stacked state dimension for this factor slot.
        stacked_dim = 0
        for vk in var_slot_keys:
            stacked_dim += int(self._var_slots[vk].dim)

        # IMPORTANT: In slot-based mode we rely on vmap + tree stacking.
        # That requires that all factors within a batched group have the
        # same params keys and compatible shapes.
        if factor_type == "prior":
            # Prior residual typically expects a 6D target for SE(3) poses.
            # We default to zeros and unit weight.
            params: Dict[str, Any] = {
                "target": jnp.zeros((stacked_dim,), dtype=jnp.float32),
                "weight": jnp.array(1.0, dtype=jnp.float32),
                "active": jnp.array(0.0, dtype=jnp.float32),
            }
        elif factor_type == "odom_se3":
            # Odometry measurement lives in the se(3) tangent space (6D),
            # even though the stacked state for two poses is 12D.
            params = {
                "measurement": jnp.zeros((6,), dtype=jnp.float32),
                "weight": jnp.array(1.0, dtype=jnp.float32),
                "active": jnp.array(0.0, dtype=jnp.float32),
            }
        elif factor_type == "marginal_prior":
            # Dense Gaussian prior induced by marginalization.
            params = {
                "mean": jnp.zeros((stacked_dim,), dtype=jnp.float32),
                "sqrt_info": jnp.eye(stacked_dim, dtype=jnp.float32),
                "weight": jnp.array(1.0, dtype=jnp.float32),
                "active": jnp.array(0.0, dtype=jnp.float32),
            }
        else:
            # Generic fallback: only active + weight.
            # Callers can override/extend keys via configure_factor_slot.
            params = {
                "weight": jnp.array(1.0, dtype=jnp.float32),
                "active": jnp.array(0.0, dtype=jnp.float32),
            }

        factor_id = self.add_factor(factor_type, var_ids, params)
        slot_key = (factor_type, slot_idx)
        self._factor_slots[slot_key] = FactorSlot(factor_type, slot_idx, factor_id, var_slot_keys)
        self._active_factor_mask[factor_id] = False
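
The code above consumes the template only by iteration: variable_slots yields (var_type, slot_idx, dim) triples, and factor_slots yields (factor_type, slot_idx, var_slot_keys) triples. Assuming that layout, a fixed-lag window of poses could be described as follows. ToyTemplate and sliding_window_template are hypothetical helpers, not dsg-jit API, and the stacked dimension of each factor slot is computed the same way as in the loop above.

```python
from typing import List, NamedTuple, Tuple

SlotKey = Tuple[str, int]

class ToyTemplate(NamedTuple):
    # Hypothetical stand-in for ActiveWindowTemplate with the field layout
    # that init_active_template iterates over.
    variable_slots: List[Tuple[str, int, int]]                  # (var_type, slot_idx, dim)
    factor_slots: List[Tuple[str, int, Tuple[SlotKey, ...]]]    # (factor_type, slot_idx, var_slot_keys)

def sliding_window_template(window: int, pose_dim: int = 6) -> ToyTemplate:
    var_slots = [("pose", i, pose_dim) for i in range(window)]
    # One anchoring prior on the oldest pose, then odometry between neighbours.
    factor_slots = [("prior", 0, (("pose", 0),))]
    factor_slots += [
        ("odom_se3", i, (("pose", i), ("pose", i + 1))) for i in range(window - 1)
    ]
    return ToyTemplate(var_slots, factor_slots)

tpl = sliding_window_template(window=4)
dims = {(t, i): d for t, i, d in tpl.variable_slots}
stacked_dims = [sum(dims[k] for k in keys) for _, _, keys in tpl.factor_slots]
```

With a window of 4, each odometry slot stacks two 6-D poses into a 12-D state while its default measurement stays 6-D, which matches the odom_se3 branch above.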

list_residual_types()

List all factor types with registered residual functions.

This is a convenience helper for debugging, diagnostics, and tests to verify that the WorldModel has been configured with the expected residuals for the current application.

Returns

list of str
    Sorted list of factor type strings for which residuals have been registered.

Source code in dsg-jit/dsg_jit/world/model.py
def list_residual_types(self) -> List[str]:
    """List all factor types with registered residual functions.

    This is a convenience helper for debugging, diagnostics, and tests
    to verify that the WorldModel has been configured with the expected
    residuals for the current application.

    Returns
    -------
    list of str
        Sorted list of factor type strings for which residuals have
        been registered.
    """
    return sorted(self._residual_registry.keys())

marginalize_variables(marginalized_ids, damping=1e-06)

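Disabled: this method is currently a no-op. Marginalization is not supported in active template mode; use bounded active templates for fixed-lag smoothing instead. Outside template mode the call also does nothing, since no dynamic-mode implementation is present.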

Source code in dsg-jit/dsg_jit/world/model.py
def marginalize_variables(
    self,
    marginalized_ids: List[NodeId],
    damping: float = 1e-6,
) -> None:
    """
    Disabled: Marginalization is not supported in active template mode.
    Use bounded active templates for fixed-lag smoothing instead.
    """
    # If in active template mode, do nothing and explain.
    if self._active_template is not None:
        # Fixed-lag smoothing is handled via bounded active templates.
        # This method is disabled in slot-based mode.
        return
    # (Legacy code for dynamic mode could be restored here if needed.)
    pass

optimize(lr=0.1, iters=300, method='gd', damping=0.001, max_step_norm=1.0)

Run a local optimizer on the current world state.

This method packs the current variables into a flat state vector, constructs an appropriate objective or residual function, runs one of the supported optimizers, and writes the optimized state back into fg.variables.

Supported methods:

  • "gd": vanilla gradient descent on the scalar objective ‖r(x)‖².
  • "newton": damped Newton on the same scalar objective.
  • "gn": Gauss–Newton on the stacked residual vector, assuming Euclidean variables.
  • "manifold_gn": manifold-aware Gauss–Newton that uses slam.manifold.build_manifold_metadata to handle SE(3) and Euclidean blocks differently.
  • "gn_jit": JIT-compiled Gauss–Newton using optimization.jit_wrappers.JittedGN.

:param lr: Learning rate for gradient-descent-based methods (currently used only when method == "gd").
:param iters: Maximum number of iterations for the chosen optimizer.
:param method: Name of the optimization method to use; one of the values listed above. Any other value raises ValueError.
:param damping: Damping / regularization parameter used by the Newton and Gauss–Newton variants.
:param max_step_norm: Maximum allowed step norm for the Gauss–Newton methods; steps larger than this are clamped to improve stability.
:returns: None. The world model is updated in place.

Source code in dsg-jit/dsg_jit/world/model.py
def optimize(
    self,
    lr: float = 0.1,
    iters: int = 300,
    method: str = "gd",
    damping: float = 1e-3,
    max_step_norm: float = 1.0,
) -> None:
    """Run a local optimizer on the current world state.

    This method packs the current variables into a flat state vector,
    constructs an appropriate objective or residual function, runs one
    of the supported optimizers, and writes the optimized state back
    into :attr:`fg.variables`.

    Supported methods:

    - ``"gd"``: vanilla gradient descent on the scalar objective
      :math:`\\|r(x)\\|^2`.
    - ``"newton"``: damped Newton on the same scalar objective.
    - ``"gn"``: Gauss--Newton on the stacked residual vector assuming
      Euclidean variables.
    - ``"manifold_gn"``: manifold-aware Gauss--Newton that uses
      :func:`slam.manifold.build_manifold_metadata` to handle SE(3)
      and Euclidean blocks differently.
    - ``"gn_jit"``: JIT-compiled Gauss--Newton using
      :class:`optimization.jit_wrappers.JittedGN`.

    :param lr: Learning rate for gradient-descent-based methods
        (currently used when ``method == "gd"``).
    :param iters: Maximum number of iterations for the chosen optimizer.
    :param method: Name of the optimization method to use. See the list
        above for supported values.
    :param damping: Damping / regularization parameter used by the
        Newton and Gauss--Newton variants.
    :param max_step_norm: Maximum allowed step norm for Gauss--Newton
        methods; steps larger than this are clamped to improve stability.
    :returns: ``None``. The world model is updated in place.
    """
    x_init, index = self.pack_state()
    residual_fn = self.build_residual()

    if method == "gd":
        obj = self.build_objective()
        cfg = GDConfig(learning_rate=lr, max_iters=iters)
        x_opt = gradient_descent(obj, x_init, cfg)

    elif method == "newton":
        obj = self.build_objective()
        cfg = NewtonConfig(max_iters=iters, damping=damping)
        x_opt = damped_newton(obj, x_init, cfg)

    elif method == "gn":
        cfg = GNConfig(max_iters=iters, damping=damping, max_step_norm=max_step_norm)
        x_opt = gauss_newton(residual_fn, x_init, cfg)

    elif method == "manifold_gn":
        # Reuse the state packed above instead of packing it a second time.
        block_slices, manifold_types = build_manifold_metadata(
            packed_state=(x_init, index), fg=self.fg
        )
        cfg = GNConfig(max_iters=iters, damping=damping, max_step_norm=max_step_norm)
        x_opt = gauss_newton_manifold(
            residual_fn, x_init, block_slices, manifold_types, cfg
        )

    elif method == "gn_jit":
        cfg = GNConfig(max_iters=iters, damping=damping, max_step_norm=max_step_norm)
        jgn = JittedGN.from_residual(residual_fn, cfg)
        x_opt = jgn(x_init)
    else:
        raise ValueError(f"Unknown optimization method '{method}'")

    # Write back
    values = self.unpack_state(x_opt, index)
    for nid, val in values.items():
        self.fg.variables[nid].value = val
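
This sketch illustrates the general damped Gauss–Newton idea with step clamping described for the "gn" method, written directly in JAX. It is not the dsg-jit gauss_newton implementation, whose exact update rule may differ. The function gauss_newton_sketch and the toy residual are hypothetical. Each iteration solves (JᵀJ + λI) dx = -Jᵀr and then shrinks dx to length max_step_norm whenever it is longer.

```python
import jax
import jax.numpy as jnp

def gauss_newton_sketch(residual_fn, x0, iters=10, damping=1e-3, max_step_norm=1.0):
    """Damped Gauss-Newton with step-norm clamping (illustrative only)."""
    x = x0
    for _ in range(iters):
        r = residual_fn(x)
        J = jax.jacfwd(residual_fn)(x)
        H = J.T @ J + damping * jnp.eye(x.shape[0])
        g = J.T @ r
        dx = -jnp.linalg.solve(H, g)
        # Clamp large steps to keep early iterations stable.
        norm = jnp.linalg.norm(dx)
        dx = dx * jnp.minimum(1.0, max_step_norm / (norm + 1e-12))
        x = x + dx
    return x

# Toy problem: priors at 0 and 2 plus an odometry-like term (x1 - x0) - 2.
def residual(x):
    return jnp.array([x[0] - 0.0, x[1] - 2.0, (x[1] - x[0]) - 2.0])

x_opt = gauss_newton_sketch(residual, jnp.array([5.0, -5.0]), iters=20)
```

Starting far from the solution, the first several steps are clamped to unit length; once the remaining error is below max_step_norm, the full damped step is taken and convergence is fast because this toy problem is linear.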

pack_state()

Pack all variable values into a single flat JAX array.

The variables are ordered by sorted NodeId to ensure stable indexing across calls.

:return: Tuple of (x, index), where x is the concatenated state vector and index is the mapping produced by _build_state_index.
:rtype: Tuple[jnp.ndarray, Dict[NodeId, Tuple[int, int]]]

Source code in dsg-jit/dsg_jit/world/model.py
def pack_state(self) -> Tuple[jnp.ndarray, Dict[NodeId, Tuple[int, int]]]:
    """Pack all variable values into a single flat JAX array.

    The variables are ordered by sorted :class:`NodeId` to ensure stable
    indexing across calls.

    :return: Tuple of ``(x, index)`` where ``x`` is the concatenated
        state vector and ``index`` is the mapping produced by
        :meth:`_build_state_index`.
    :rtype: Tuple[jnp.ndarray, Dict[NodeId, Tuple[int, int]]]
    """
    index = self._build_state_index()
    chunks = []
    for node_id in sorted(self.fg.variables.keys()):
        var = self.fg.variables[node_id]
        chunks.append(jnp.asarray(var.value))
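    if not chunks:
        # jnp.concatenate raises on an empty list; return an empty state.
        return jnp.zeros((0,), dtype=jnp.float32), index
    return jnp.concatenate(chunks), index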

register_residual(factor_type, fn)

Register a residual function for a given factor type.

This is the WorldModel-level registry that associates factor type strings (e.g. "odom_se3", "voxel_point_obs") with JAX-compatible residual functions. The registered functions are consumed by higher-level residual builders such as build_residual. Registering a function under an existing factor_type replaces the previous one.

Parameters

factor_type : str
    String identifier for the factor type. This must match the type field stored in Factor instances in the underlying FactorGraph.
fn : Callable
    Residual function implementing the measurement model. The exact signature is intentionally flexible, but it is expected to be compatible with the unified residual builder returned by build_residual (e.g. it may be vmapped across factors of a given type).

Source code in dsg-jit/dsg_jit/world/model.py
def register_residual(self, factor_type: str, fn: Callable[..., Any]) -> None:
    """Register a residual function for a given factor type.

    This is the WorldModel-level registry that associates factor type
    strings (e.g. ``"odom_se3"``, ``"voxel_point_obs"``) with
    JAX-compatible residual functions. The registered functions are
    consumed by higher-level residual builders such as
    :meth:`build_residual`.

    Parameters
    ----------
    factor_type : str
        String identifier for the factor type. This must match the
        ``type`` field stored in :class:`Factor` instances in the
        underlying :class:`FactorGraph`.
    fn : Callable
        Residual function implementing the measurement model. The
        exact signature is intentionally flexible, but it is expected
        to be compatible with the unified residual builder returned by
        :meth:`build_residual` (e.g. it may be vmapped across factors
        of a given type).
    """
    self._residual_registry[factor_type] = fn
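
Every builder in this section calls a registered residual as res_fn(stacked_state, params). A residual intended for the batched path should therefore use only jnp operations on those two arguments, so that jax.vmap can map it over stacked params. A minimal sketch under that assumption: point_prior_residual and the "point_prior" type are hypothetical, and a plain dict stands in for the registry. With a real WorldModel, the registration step would be wm.register_residual("point_prior", point_prior_residual).

```python
import jax
import jax.numpy as jnp

# Hypothetical residual for a 3-D point prior. It uses only jnp operations
# on (stacked_state, params), so it can be vmapped over a factor group.
def point_prior_residual(stacked: jnp.ndarray, params: dict) -> jnp.ndarray:
    return stacked - params["target"]

registry = {}
registry["point_prior"] = point_prior_residual  # what register_residual stores

# Two factors of the same type, batched along the leading axis.
states = jnp.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
params = {"target": jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])}
batched = jax.vmap(registry["point_prior"])(states, params)
```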

set_variable_slot(var_type, slot_idx, value)

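Set the value of a variable slot in the active template and return the NodeId backing the slot. Raises KeyError if the slot does not exist and ValueError if the value's shape does not match the slot dimension.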

Source code in dsg-jit/dsg_jit/world/model.py
def set_variable_slot(self, var_type: str, slot_idx: int, value: jnp.ndarray) -> NodeId:
    """Set the value of a variable slot in the active template."""
    slot_key = (var_type, slot_idx)
    slot = self._var_slots.get(slot_key)
    if slot is None:
        raise KeyError(f"Variable slot {slot_key} not found in active template.")
    if value.shape[0] != slot.dim:
        raise ValueError(f"Value shape {value.shape} does not match slot dim {slot.dim}")
    self.fg.variables[slot.node_id].value = value
    return slot.node_id

snapshot_state()

Capture a shallow snapshot of the current world state.

The snapshot maps integer node ids to their current values. This is intentionally simple and serialization-friendly, and is meant to be consumed by higher-level dynamic scene graph structures that want to record the evolution of the world over time.

:returns: A dictionary mapping int(NodeId) to JAX arrays.

Source code in dsg-jit/dsg_jit/world/model.py
def snapshot_state(self) -> Dict[int, jnp.ndarray]:
    """Capture a shallow snapshot of the current world state.

    The snapshot maps integer node ids to their current values. This is
    intentionally simple and serialization-friendly, and is meant to be
    consumed by higher-level dynamic scene graph structures that want to
    record the evolution of the world over time.

    :returns: A dictionary mapping ``int(NodeId)`` to JAX arrays.
    """
    return {int(nid): jnp.array(var.value) for nid, var in self.fg.variables.items()}

unpack_state(x, index)

Unpack a flat state vector back into per-variable arrays.

:param x: Flattened state vector produced by :meth:pack_state or produced by an optimizer.
:type x: jnp.ndarray
:param index: Mapping from :class:NodeId to (start, dim) blocks as returned by :meth:_build_state_index.
:type index: Dict[NodeId, Tuple[int, int]]
:return: Mapping from node id to its corresponding slice of x.
:rtype: Dict[NodeId, jnp.ndarray]

Source code in dsg-jit/dsg_jit/world/model.py
def unpack_state(self, x: jnp.ndarray, index: Dict[NodeId, Tuple[int, int]]) -> Dict[NodeId, jnp.ndarray]:
    """Unpack a flat state vector back into per-variable arrays.

    :param x: Flattened state vector produced by :meth:`pack_state` or
        produced by an optimizer.
    :type x: jnp.ndarray
    :param index: Mapping from :class:`NodeId` to ``(start, dim)`` blocks
        as returned by :meth:`_build_state_index`.
    :type index: Dict[NodeId, Tuple[int, int]]
    :return: Mapping from node id to its corresponding slice of ``x``.
    :rtype: Dict[NodeId, jnp.ndarray]
    """
    result: Dict[NodeId, jnp.ndarray] = {}
    for node_id, (start, dim) in index.items():
        result[node_id] = x[start:start+dim]
    return result
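
The `(start, dim)` index convention can be illustrated with a pack/unpack round trip. This sketch assumes blocks are laid out in ascending node-id order; `build_state_index` and `pack_state` here are hypothetical stand-ins for `_build_state_index` and `pack_state`, whose exact ordering is not shown in this section.

```python
from typing import Dict, Tuple

import numpy as np

Index = Dict[int, Tuple[int, int]]


def build_state_index(values: Dict[int, np.ndarray]) -> Index:
    # Assign each variable a contiguous (start, dim) block, in sorted node-id order.
    index: Index = {}
    start = 0
    for nid in sorted(values):
        dim = values[nid].shape[0]
        index[nid] = (start, dim)
        start += dim
    return index


def pack_state(values: Dict[int, np.ndarray], index: Index) -> np.ndarray:
    # Concatenate blocks in the order of their start offsets.
    order = sorted(index, key=lambda nid: index[nid][0])
    return np.concatenate([values[nid] for nid in order])


def unpack_state(x: np.ndarray, index: Index) -> Dict[int, np.ndarray]:
    # Same slicing rule as the method above: x[start:start + dim].
    return {nid: x[start:start + dim] for nid, (start, dim) in index.items()}


values = {3: np.array([1.0, 2.0, 3.0]), 1: np.array([0.5])}
index = build_state_index(values)
x = pack_state(values, index)
restored = unpack_state(x, index)
```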

unpack_state_inplace(x_opt)

Write the optimized state vector back into the FactorGraph variable table.

Source code in dsg-jit/dsg_jit/world/model.py
def unpack_state_inplace(self, x_opt: jnp.ndarray) -> None:
    """
    Write the optimized state vector back into the FactorGraph variable table.
    """
    _, index = self.pack_state()  # index maps node_id -> (start, end)

    for node_id, (start, end) in index.items():
        block = x_opt[start:end]
        var = self.fg.variables[node_id]
        var.value = block  # overwrite stored variable value

marginal_prior_residual(stacked, params)

Residual for a dense Gaussian prior induced by marginalization.

This residual encodes a quadratic term of the form

1/2 (x - μ)^T H (x - μ)

via a square-root factorization H = L^T L, where L is the transpose of the lower-triangular Cholesky factor of H. The parameters are:

mean       : μ, a 1D array of the same shape as stacked.
sqrt_info  : L, a square matrix such that L^T L ≈ H.

The returned residual is L @ (x - μ), so that the overall contribution to the objective is 1/2 ||L (x - μ)||^2 = 1/2 (x - μ)^T H (x - μ).

Source code in dsg-jit/dsg_jit/world/model.py
def marginal_prior_residual(stacked: jnp.ndarray, params: Dict[str, Any]) -> jnp.ndarray:
    """Residual for a dense Gaussian prior induced by marginalization.

    This residual encodes a quadratic term of the form

        1/2 (x - μ)^T H (x - μ)

    via a Cholesky factorization H = L^T L. The parameters are:

        mean       : μ, a 1D array of the same shape as ``stacked``.
        sqrt_info  : L, a square matrix such that L^T L ≈ H.

    The returned residual is L @ (x - μ), so that the overall contribution
    to the objective is 1/2 ||L (x - μ)||^2.
    """
    mean = params["mean"]
    sqrt_info = params["sqrt_info"]
    return sqrt_info @ (stacked - mean)
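
A quick numerical check of the relation between the residual and the quadratic form stated above, written with NumPy (`jax.numpy.linalg.cholesky` likewise returns the lower-triangular factor). Since `cholesky` returns C with H = C C^T, the square-root information matrix is its transpose, L = C^T. The matrix and vectors are arbitrary illustrative values.

```python
import numpy as np


def marginal_prior_residual(stacked: np.ndarray, params: dict) -> np.ndarray:
    # Same computation as the listing above, with NumPy in place of jax.numpy.
    return params["sqrt_info"] @ (stacked - params["mean"])


H = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive-definite information matrix
C = np.linalg.cholesky(H)               # lower-triangular, H = C @ C.T
L = C.T                                 # upper-triangular, L.T @ L = H
mu = np.array([1.0, -2.0])
x = np.array([1.5, -1.0])

r = marginal_prior_residual(x, {"mean": mu, "sqrt_info": L})
cost_from_residual = 0.5 * r @ r
cost_quadratic = 0.5 * (x - mu) @ H @ (x - mu)
```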

world.scene_graph

Dynamic 3D scene graph utilities built on top of the world model.

This module provides a SceneGraphWorld abstraction that organizes poses, places, rooms, objects, and agents into a dynamic scene graph backed by the differentiable factor graph.

Conceptually, this layer is responsible for:

• Creating typed nodes:
    - Robot / agent poses (SE3)
    - Places / topological nodes (1D)
    - Rooms / regions
    - Objects (points / positions in space)
• Adding semantic and metric relationships between them via factors:
    - Pose priors
    - SE3 odometry / loop closures
    - Pose–place attachments
    - Pose–object / object–place relations
• Maintaining lightweight indexing:
    - Maps from (agent, time) → pose NodeId
    - Collections of place / room / object node ids
    - Optional trajectory dictionaries

What it does not do:

• It does not implement the optimizer itself.
• It does not hard-code SE3 math or Jacobians.
• It does not perform rendering or perception.

All numerical optimization is delegated to:

- `world.model.WorldModel` (and its `FactorGraph`)
- `optimization.solvers` (Gauss–Newton / manifold variants)
- `slam.manifold` and `slam.measurements` for geometry and residuals

Typical usage

Experiments in experiments/exp0X_*.py follow a common pattern:

1. Construct a `SceneGraphWorld`.
2. Add a small chain of poses, places, and objects.
3. Attach priors and odometry factors.
4. Optionally attach voxel or observation factors.
5. Optimize via Gauss–Newton (JIT or non-JIT).
6. Inspect the resulting scene graph state.
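
Steps 2 to 5 can be sketched on a toy 1D chain without the library: a prior on the first pose, two odometry edges, and a plain Gauss–Newton loop. This is an illustrative stand-in, not the `optimization.solvers` implementation; it uses a finite-difference Jacobian where the library would rely on JAX autodiff.

```python
import numpy as np


def residuals(x: np.ndarray) -> np.ndarray:
    # Stand-in factors: a prior on pose 0 and two odometry edges of length 1.
    return np.array([
        x[0] - 0.0,           # prior: pose 0 at the origin
        (x[1] - x[0]) - 1.0,  # odometry 0 -> 1
        (x[2] - x[1]) - 1.0,  # odometry 1 -> 2
    ])


def numerical_jacobian(f, x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    r0 = f(x)
    J = np.zeros((r0.size, x.size))
    for k in range(x.size):
        step = np.zeros_like(x)
        step[k] = eps
        J[:, k] = (f(x + step) - r0) / eps  # forward difference per coordinate
    return J


def gauss_newton(f, x0: np.ndarray, iters: int = 5) -> np.ndarray:
    x = x0.astype(float).copy()
    for _ in range(iters):
        r = f(x)
        J = numerical_jacobian(f, x)
        # Solve the normal equations J^T J dx = -J^T r and take the full step.
        dx = np.linalg.solve(J.T @ J, -J.T @ r)
        x = x + dx
    return x


x_opt = gauss_newton(residuals, np.array([0.3, 0.7, 2.5]))
```

Because these residuals are linear, a single Gauss–Newton step already lands on the solution; nonlinear SE(3) factors are what make the extra iterations necessary.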

Design goals

  • Ergonomics: hide raw NodeId and factor wiring behind friendly helpers like “add pose”, “add agent pose”, “attach place”, etc.
  • Differentiable backbone: everything created here remains compatible with JAX JIT and automatic differentiation downstream.
  • Extensibility: easy to add new relation types and node types without changing the optimizer or lower-level infrastructure.

SceneGraphNoiseConfig(prior_pose_sigma=0.001, odom_se3_sigma=0.05, smooth_pose_sigma=0.5, pose_place_sigma=0.05, object_at_pose_sigma=0.05, pose_landmark_sigma=0.05, pose_landmark_bearing_sigma=0.05, pose_voxel_point_sigma=0.05, voxel_smoothness_sigma=0.1, voxel_point_obs_sigma=0.05) dataclass

Default noise (standard deviation) per factor type.

These are expressed in the same units as the residuals:
  • prior / odom / smoothness: R^6 pose (m, m, m, rad, rad, rad)
  • pose_place / object_at_pose: R^1 or R^3 (m)
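
Because `SceneGraphNoiseConfig` is a dataclass and helpers such as `add_odom_se3_additive` read `self.noise.<field>` whenever `sigma` is `None`, one way to change a default is to build a modified copy with `dataclasses.replace`. The class below is a local copy of the documented defaults so the sketch runs on its own.

```python
from dataclasses import dataclass, replace


@dataclass
class SceneGraphNoiseConfig:
    # Local copy of the documented defaults (standard deviations per factor type).
    prior_pose_sigma: float = 0.001
    odom_se3_sigma: float = 0.05
    smooth_pose_sigma: float = 0.5
    pose_place_sigma: float = 0.05
    object_at_pose_sigma: float = 0.05
    pose_landmark_sigma: float = 0.05
    pose_landmark_bearing_sigma: float = 0.05
    pose_voxel_point_sigma: float = 0.05
    voxel_smoothness_sigma: float = 0.1
    voxel_point_obs_sigma: float = 0.05


defaults = SceneGraphNoiseConfig()
# Copy with a looser odometry prior; every other field keeps its default.
noisy_odom = replace(defaults, odom_se3_sigma=0.2)
```

With a `SceneGraphWorld` instance `sg`, assigning `sg.noise = noisy_odom` would affect odometry factors added afterwards; factors created earlier keep the weight computed from `sigma` when they were added.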

SceneGraphWorld()

World-level dynamic scene graph wrapper that manages typed nodes and semantic relationships, built atop the WorldModel. Provides ergonomic helpers for creating and connecting SE(3) poses, places, rooms, objects, and agents, and maintains convenient indexing for scene-graph experiments.

In addition to delegating numerical optimization to the underlying WorldModel, SceneGraphWorld maintains its own lightweight memory of node states. This persistent cache decouples the scene graph from the FactorGraph so that sliding-window marginalization or variable removal at the optimization level does not cause information loss at the scene-graph level.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def __init__(self) -> None:
    self.wm = WorldModel()
    self.pose_trajectory = {}
    self.noise = SceneGraphNoiseConfig()

    # --- Named semantic node indexes ---
    self.room_nodes = {}
    self.place_nodes = {}
    self.object_nodes = {}

    # --- Semantic adjacency (for visualization / topology) ---
    self.room_place_edges = []
    self.object_room_edges = []

    # --- Persistent scene-graph memory of node states ---
    self._memory: Dict[int, SceneNodeState] = {}
    self._factor_memory: Dict[int, SGFactorRecord] = {}
    self._next_factor_id: int = 0

    # --- Active-template mode (bounded FG / single-JIT) ---
    self._active_template_enabled: bool = False

    # var_type -> capacity (# slots)
    self._slot_capacity: Dict[str, int] = {}
    # node_id -> (var_type, slot_idx)
    self._slot_assign: Dict[int, Tuple[str, int]] = {}
    # var_type -> FIFO list of node_ids assigned (for eviction)
    self._slot_fifo: DefaultDict[str, Deque[int]] = defaultdict(deque)

    # factor_type -> capacity (# slots)
    self._factor_slot_capacity: Dict[str, int] = {}
    # factor_type -> next slot index (round-robin)
    self._factor_slot_next: DefaultDict[str, int] = defaultdict(int)

    # --- Global residuals registry ---
    self.wm.register_residual("prior", prior_residual)
    self.wm.register_residual("odom_se3", odom_se3_residual)
    self.wm.register_residual("odom_se3_geodesic", odom_se3_geodesic_residual)
    self.wm.register_residual("pose_place_attachment", pose_place_attachment_residual)
    self.wm.register_residual("object_at_pose", object_at_pose_residual)
    self.wm.register_residual("pose_temporal_smoothness", pose_temporal_smoothness_residual)
    self.wm.register_residual("pose_landmark_relative", pose_landmark_relative_residual)
    self.wm.register_residual("pose_landmark_bearing", pose_landmark_bearing_residual)
    self.wm.register_residual("pose_voxel_point", pose_voxel_point_residual)
    self.wm.register_residual("voxel_smoothness", voxel_smoothness_residual)
    self.wm.register_residual("voxel_point_obs", voxel_point_observation_residual)
    self.wm.register_residual("range", range_residual)
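
The active-template bookkeeping above (`_slot_capacity`, `_slot_assign`, `_slot_fifo`) suggests a bounded pool of slots per variable type with first-in, first-out eviction. The sketch below is a hypothetical reconstruction of that idea; `SlotAllocator` is an illustrative name and the actual `_assign_var_slot` logic is not shown in this section.

```python
from collections import defaultdict, deque
from typing import DefaultDict, Deque, Dict, Optional, Tuple


class SlotAllocator:
    """Hypothetical bounded per-type slot pool with FIFO eviction."""

    def __init__(self, capacity: Dict[str, int]) -> None:
        self.capacity = capacity                                      # var_type -> number of slots
        self.assign: Dict[int, Tuple[str, int]] = {}                  # node_id -> (var_type, slot_idx)
        self.fifo: DefaultDict[str, Deque[int]] = defaultdict(deque)  # var_type -> node_ids, oldest first

    def assign_slot(self, var_type: str, node_id: int) -> Tuple[int, Optional[int]]:
        """Return (slot_idx, evicted_node_id or None)."""
        queue = self.fifo[var_type]
        if len(queue) < self.capacity[var_type]:
            slot_idx = len(queue)  # a free slot is still available
            evicted = None
        else:
            evicted = queue.popleft()  # the oldest node gives up its slot
            _, slot_idx = self.assign.pop(evicted)
        self.assign[node_id] = (var_type, slot_idx)
        queue.append(node_id)
        return slot_idx, evicted


alloc = SlotAllocator({"pose_se3": 2})
results = [alloc.assign_slot("pose_se3", nid) for nid in (10, 11, 12, 13)]
```

Keeping the slot count fixed keeps the factor-graph shape fixed, which is what lets one JIT compilation be reused; evicted nodes are meant to survive in the scene graph's persistent `_memory` cache rather than being lost.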

add_agent_pose_landmark_bearing(agent, t, landmark_id, bearing, sigma=None)

Add a bearing-only pose–landmark constraint for an agent at time t.

This wraps :meth:add_pose_landmark_bearing and resolves the pose id from :attr:pose_trajectory.

:param agent: Agent identifier.
:param t: Timestep index for the pose.
:param landmark_id: Node id of the 3D landmark variable.
:param bearing: Iterable of length 3 giving the bearing vector in the pose frame (it will be normalized internally).
:param sigma: Optional noise standard deviation. If None, falls back to :attr:SceneGraphNoiseConfig.pose_landmark_bearing_sigma.
:return: Integer factor id of the created bearing constraint.
:raises KeyError: If no pose has been registered for (agent, t).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_pose_landmark_bearing(
    self,
    agent: str,
    t: int,
    landmark_id: int,
    bearing,
    sigma: float | None = None,
) -> int:
    """
    Add a bearing-only pose–landmark constraint for an agent at time ``t``.

    This wraps :meth:`add_pose_landmark_bearing` and resolves the pose id
    from :attr:`pose_trajectory`.

    :param agent: Agent identifier.
    :param t: Timestep index for the pose.
    :param landmark_id: Node id of the 3D landmark variable.
    :param bearing: Iterable of length 3 giving the bearing vector in the
        pose frame (it will be normalized internally).
    :param sigma: Optional noise standard deviation. If ``None``, falls back
        to :attr:`SceneGraphNoiseConfig.pose_landmark_bearing_sigma`.
    :return: Integer factor id of the created bearing constraint.
    :raises KeyError: If no pose has been registered for ``(agent, t)``.
    """
    key = (agent, t)
    if key not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t}")

    pose_id = self.pose_trajectory[key]
    return self.add_pose_landmark_bearing(
        pose_id=pose_id,
        landmark_id=landmark_id,
        bearing=bearing,
        sigma=sigma,
    )
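
To make the "normalized internally" remark concrete, here is a hypothetical bearing residual: the predicted unit direction to the landmark, expressed in the pose frame, minus the normalized measured bearing. It simplifies the pose to a position plus a yaw angle about z; the library's `pose_landmark_bearing_residual` is not shown here and may use a different form.

```python
import numpy as np


def rot_z(yaw: float) -> np.ndarray:
    # Rotation about the world z-axis (pose frame -> world frame).
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def bearing_residual(pose_xyz, pose_yaw, landmark_xyz, bearing_meas) -> np.ndarray:
    # Direction to the landmark in the pose frame, as a unit vector.
    d_pose = rot_z(pose_yaw).T @ (np.asarray(landmark_xyz, float) - np.asarray(pose_xyz, float))
    predicted = d_pose / np.linalg.norm(d_pose)
    # Normalize the measurement so only its direction matters.
    measured = np.asarray(bearing_meas, dtype=float)
    measured = measured / np.linalg.norm(measured)
    return predicted - measured


# Pose at the origin facing world +y; the landmark lies straight ahead in the pose frame.
r = bearing_residual([0.0, 0.0, 0.0], np.pi / 2, [0.0, 5.0, 0.0], [2.0, 0.0, 0.0])
```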

add_agent_pose_landmark_relative(agent, t, landmark_id, measurement, sigma=None)

Add a relative pose–landmark constraint for an agent at time t.

This is a small ergonomic wrapper around :meth:add_pose_landmark_relative that resolves the pose id using :attr:pose_trajectory.

:param agent: Agent identifier.
:param t: Timestep index for the pose.
:param landmark_id: Node id of the 3D landmark variable.
:param measurement: Iterable of length 3 giving the expected landmark position in the pose frame.
:param sigma: Optional noise standard deviation. If None, falls back to :attr:SceneGraphNoiseConfig.pose_landmark_sigma.
:return: Integer factor id of the created relative landmark constraint.
:raises KeyError: If no pose has been registered for (agent, t).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_pose_landmark_relative(
    self,
    agent: str,
    t: int,
    landmark_id: int,
    measurement,
    sigma: float | None = None,
) -> int:
    """
    Add a relative pose–landmark constraint for an agent at time ``t``.

    This is a small ergonomic wrapper around
    :meth:`add_pose_landmark_relative` that resolves the pose id using
    :attr:`pose_trajectory`.

    :param agent: Agent identifier.
    :param t: Timestep index for the pose.
    :param landmark_id: Node id of the 3D landmark variable.
    :param measurement: Iterable of length 3 giving the expected landmark
        position in the pose frame.
    :param sigma: Optional noise standard deviation. If ``None``, falls back
        to :attr:`SceneGraphNoiseConfig.pose_landmark_sigma`.
    :return: Integer factor id of the created relative landmark constraint.
    :raises KeyError: If no pose has been registered for ``(agent, t)``.
    """
    key = (agent, t)
    if key not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t}")

    pose_id = self.pose_trajectory[key]
    return self.add_pose_landmark_relative(
        pose_id=pose_id,
        landmark_id=landmark_id,
        measurement=measurement,
        sigma=sigma,
    )

add_agent_pose_place_attachment(agent, t, place_id, coord_index=0, sigma=None)

Attach an agent pose at time t to a place node.

This is a higher-level wrapper around :meth:add_place_attachment which resolves the pose id via :attr:pose_trajectory.

:param agent: Agent identifier.
:param t: Integer timestep index.
:param place_id: Node id of the place variable (1D or 3D).
:param coord_index: Index of the pose coordinate to tie to the place (typically 0 for x, 1 for y, etc.). Defaults to 0.
:param sigma: Optional noise standard deviation. If None, falls back to :attr:SceneGraphNoiseConfig.pose_place_sigma.
:return: Integer factor id of the created attachment constraint.
:raises KeyError: If no pose has been registered for (agent, t).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_pose_place_attachment(
    self,
    agent: str,
    t: int,
    place_id: int,
    coord_index: int = 0,
    sigma: float | None = None,
) -> int:
    """
    Attach an agent pose at time ``t`` to a place node.

    This is a higher-level wrapper around :meth:`add_place_attachment`
    which resolves the pose id via :attr:`pose_trajectory`.

    :param agent: Agent identifier.
    :param t: Integer timestep index.
    :param place_id: Node id of the place variable (1D or 3D).
    :param coord_index: Index of the pose coordinate to tie to the place
        (typically 0 for x, 1 for y, etc.). Defaults to 0.
    :param sigma: Optional noise standard deviation. If ``None``, falls back
        to :attr:`SceneGraphNoiseConfig.pose_place_sigma`.
    :return: Integer factor id of the created attachment constraint.
    :raises KeyError: If no pose has been registered for ``(agent, t)``.
    """
    pose_key = (agent, t)
    if pose_key not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t}")

    pose_id = self.pose_trajectory[pose_key]
    return self.add_place_attachment(
        pose_id=pose_id,
        place_id=place_id,
        coord_index=coord_index,
        sigma=sigma,
    )
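
A hypothetical form of the residual behind this attachment for a 1D place, assuming the first entries of the se(3) pose vector are translation (consistent with "0 for x, 1 for y"). The real `pose_place_attachment_residual` is not shown in this section, and its handling of 3D places may differ.

```python
import numpy as np


def place_attachment_residual(pose: np.ndarray, place: np.ndarray, coord_index: int = 0) -> np.ndarray:
    """Hypothetical residual tying pose[coord_index] to a 1D place coordinate."""
    return np.atleast_1d(pose[coord_index] - place[0])


pose = np.array([2.0, 0.5, 0.0, 0.0, 0.0, 0.1])  # se(3) 6-vector; index 0 is x
place_x = np.array([1.8])                          # 1D place along the corridor axis
r_x = place_attachment_residual(pose, place_x)                        # ties x: 2.0 - 1.8
r_y = place_attachment_residual(pose, np.array([0.5]), coord_index=1)  # ties y: 0.5 - 0.5
```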

add_agent_pose_se3(agent, t, value)

Add an SE(3) pose for a given agent at a specific timestep.

:param agent: Agent identifier (for example, a robot name).
:param t: Integer timestep index.
:param value: Length-6 array-like se(3) vector for the pose.
:return: Integer node id of the created pose variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_pose_se3(self, agent: str, t: int, value: jnp.ndarray) -> int:
    """
    Add an SE(3) pose for a given agent at a specific timestep.

    :param agent: Agent identifier (for example, a robot name).
    :param t: Integer timestep index.
    :param value: Length-6 array-like se(3) vector for the pose.
    :return: Integer node id of the created pose variable.
    """
    if self._active_template_enabled:
        nid_int = int(self._assign_var_slot("pose_se3", jnp.asarray(value)))
    else:
        nid_int = int(self.wm.add_variable("pose_se3", value))
    self._remember_node(nid_int, "pose_se3", jnp.asarray(value))
    self.pose_trajectory[(agent, t)] = nid_int
    return nid_int
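
`pose_trajectory` is a plain dictionary keyed by `(agent, t)`, so one agent's trajectory can be recovered by filtering and sorting. The helper name `agent_trajectory` is hypothetical.

```python
from typing import Dict, List, Tuple


def agent_trajectory(pose_trajectory: Dict[Tuple[str, int], int], agent: str) -> List[Tuple[int, int]]:
    """Return [(t, node_id), ...] for one agent, sorted by timestep."""
    return sorted((t, nid) for (a, t), nid in pose_trajectory.items() if a == agent)


pose_trajectory = {("robot_a", 1): 5, ("robot_b", 0): 2, ("robot_a", 0): 0}
traj_a = agent_trajectory(pose_trajectory, "robot_a")
```

Because `add_agent_pose_se3` assigns `self.pose_trajectory[(agent, t)] = nid_int`, calling it twice for the same `(agent, t)` repoints that key at the newer node; the listing above does not remove the earlier pose variable.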

add_agent_pose_voxel_point(agent, t, voxel_id, point_meas, sigma=None)

Constrain a voxel cell using a point measurement from an agent pose.

This wraps :meth:add_pose_voxel_point and resolves the pose id from :attr:pose_trajectory.

:param agent: Agent identifier.
:param t: Timestep index for the pose.
:param voxel_id: Node id of the voxel cell variable.
:param point_meas: Iterable of length 3 giving a point in the pose frame (for example, a back-projected LiDAR or depth sample).
:param sigma: Optional noise standard deviation. If None, falls back to :attr:SceneGraphNoiseConfig.pose_voxel_point_sigma.
:return: Integer factor id of the created voxel-point constraint.
:raises KeyError: If no pose has been registered for (agent, t).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_pose_voxel_point(
    self,
    agent: str,
    t: int,
    voxel_id: int,
    point_meas,
    sigma: float | None = None,
) -> int:
    """
    Constrain a voxel cell using a point measurement from an agent pose.

    This wraps :meth:`add_pose_voxel_point` and resolves the pose id from
    :attr:`pose_trajectory`.

    :param agent: Agent identifier.
    :param t: Timestep index for the pose.
    :param voxel_id: Node id of the voxel cell variable.
    :param point_meas: Iterable of length 3 giving a point in the pose
        frame (for example, a back-projected LiDAR or depth sample).
    :param sigma: Optional noise standard deviation. If ``None``, falls
        back to :attr:`SceneGraphNoiseConfig.pose_voxel_point_sigma`.
    :return: Integer factor id of the created voxel-point constraint.
    :raises KeyError: If no pose has been registered for ``(agent, t)``.
    """
    key = (agent, t)
    if key not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t}")

    pose_id = self.pose_trajectory[key]
    return self.add_pose_voxel_point(
        pose_id=pose_id,
        voxel_id=voxel_id,
        point_meas=point_meas,
        sigma=sigma,
    )
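
A hypothetical sketch of a pose-to-voxel point residual: transform `point_meas` from the pose frame into the world frame and compare it with the voxel cell position. The pose is reduced to a position plus a yaw angle for brevity, and the sign convention is an assumption; `pose_voxel_point_residual` itself is not shown here.

```python
import numpy as np


def rot_z(yaw: float) -> np.ndarray:
    # Rotation about the world z-axis (pose frame -> world frame).
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def voxel_point_residual_sketch(pose_xyz, pose_yaw, voxel_center, point_meas) -> np.ndarray:
    # Map the measured point into the world frame: p_world = R p + t.
    p_world = rot_z(pose_yaw) @ np.asarray(point_meas, float) + np.asarray(pose_xyz, float)
    # Residual is zero when the voxel sits exactly where the point was observed.
    return np.asarray(voxel_center, float) - p_world


# Pose at (1, 0, 0) facing world +y; a point 2 m ahead lands at world (1, 2, 0).
r = voxel_point_residual_sketch([1.0, 0.0, 0.0], np.pi / 2, [1.0, 2.0, 0.0], [2.0, 0.0, 0.0])
```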

add_agent_range_measurement(agent, t, target_nid, measured_range, sigma=None, weight=None)

Add a range-only factor using an agent's pose at a given timestep.

This is a convenience wrapper around :meth:add_range_measurement that looks up the pose node id from :attr:pose_trajectory using (agent, t) and then creates a "range" factor to a target node.

:param agent: Agent identifier (for example, a robot name).
:param t: Integer timestep index for the agent pose.
:param target_nid: NodeId of the target variable (for example, place3d, voxel_cell or object3d).
:param measured_range: Observed distance (same units as the world coordinates).
:param sigma: Optional standard deviation of the measurement noise. If provided (and weight is None), it is converted to a weight via :func:slam.measurements.sigma_to_weight.
:param weight: Optional explicit weight. If both sigma and weight are given, weight takes precedence.
:return: Integer factor id of the created range factor.
:raises KeyError: If no pose has been registered for (agent, t).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_range_measurement(
    self,
    agent: str,
    t: int,
    target_nid: int,
    measured_range: float,
    sigma: float | None = None,
    weight: float | None = None,
) -> int:
    """
    Add a range-only factor using an agent's pose at a given timestep.

    This is a convenience wrapper around :meth:`add_range_measurement`
    that looks up the pose node id from :attr:`pose_trajectory` using
    ``(agent, t)`` and then creates a ``"range"`` factor to a target node.

    :param agent: Agent identifier (for example, a robot name).
    :param t: Integer timestep index for the agent pose.
    :param target_nid: NodeId of the target variable (for example, ``place3d``,
        ``voxel_cell`` or ``object3d``).
    :param measured_range: Observed distance (same units as the world coordinates).
    :param sigma: Optional standard deviation of the measurement noise. If
        provided (and ``weight`` is ``None``), it is converted to a weight via
        :func:`slam.measurements.sigma_to_weight`.
    :param weight: Optional explicit weight. If both ``sigma`` and ``weight``
        are given, ``weight`` takes precedence.
    :return: Integer factor id of the created range factor.
    :raises KeyError: If no pose has been registered for ``(agent, t)``.
    """
    pose_key = (agent, t)
    if pose_key not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t}")

    pose_nid = self.pose_trajectory[pose_key]
    return self.add_range_measurement(
        pose_nid=pose_nid,
        target_nid=target_nid,
        measured_range=measured_range,
        sigma=sigma,
        weight=weight,
    )
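
The sigma/weight precedence rule can be sketched as follows. `sigma_to_weight_assumed` uses weight = 1/sigma, which is an assumption about `slam.measurements.sigma_to_weight`, and the unit-weight fallback when neither argument is given is also hypothetical. The range residual shown (Euclidean distance minus measured range, scaled by the weight) is a common form, not necessarily the library's `range_residual`.

```python
from typing import Optional

import numpy as np


def sigma_to_weight_assumed(sigma: float) -> float:
    # Assumption: whiten a residual by dividing by its standard deviation.
    return 1.0 / sigma


def resolve_weight(sigma: Optional[float] = None, weight: Optional[float] = None) -> float:
    if weight is not None:  # an explicit weight takes precedence
        return float(weight)
    if sigma is not None:
        return sigma_to_weight_assumed(sigma)
    return 1.0  # hypothetical default when neither is given


def range_residual_sketch(pose_xyz, target_xyz, measured_range: float, weight: float) -> np.ndarray:
    dist = np.linalg.norm(np.asarray(target_xyz, float) - np.asarray(pose_xyz, float))
    return weight * np.atleast_1d(dist - measured_range)


w = resolve_weight(sigma=0.5)
r = range_residual_sketch([0.0, 0.0, 0.0], [3.0, 4.0, 0.0], 4.5, w)
```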

add_agent_temporal_smoothness(agent, t, sigma=None)

Enforce temporal smoothness between successive poses of a given agent.

This enforces a smoothness constraint between the poses at timesteps t and t+1 for the specified agent, using :meth:add_temporal_smoothness internally.

:param agent: Agent identifier.
:param t: Timestep index for the first pose in the pair.
:param sigma: Optional standard deviation controlling smoothness. If None, falls back to :attr:SceneGraphNoiseConfig.smooth_pose_sigma.
:return: Integer factor id of the created smoothness constraint.
:raises KeyError: If either pose (agent, t) or (agent, t+1) has not been registered.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_agent_temporal_smoothness(
    self,
    agent: str,
    t: int,
    sigma: float | None = None,
) -> int:
    """
    Enforce temporal smoothness between successive poses of a given agent.

    This enforces a smoothness constraint between the poses at timesteps
    ``t`` and ``t+1`` for the specified agent, using
    :meth:`add_temporal_smoothness` internally.

    :param agent: Agent identifier.
    :param t: Timestep index for the first pose in the pair.
    :param sigma: Optional standard deviation controlling smoothness. If
        ``None``, falls back to :attr:`SceneGraphNoiseConfig.smooth_pose_sigma`.
    :return: Integer factor id of the created smoothness constraint.
    :raises KeyError: If either pose ``(agent, t)`` or ``(agent, t+1)`` has
        not been registered.
    """
    key_t = (agent, t)
    key_t1 = (agent, t + 1)

    if key_t not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t}")
    if key_t1 not in self.pose_trajectory:
        raise KeyError(f"No pose registered for agent={agent!r}, t={t+1}")

    pose_id_t = self.pose_trajectory[key_t]
    pose_id_t1 = self.pose_trajectory[key_t1]
    return self.add_temporal_smoothness(
        pose_id_t=pose_id_t,
        pose_id_t1=pose_id_t1,
        sigma=sigma,
    )
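
Since `add_agent_temporal_smoothness` raises `KeyError` when either endpoint is missing, a caller that wants smoothness along an agent's whole trajectory can first collect only the consecutive `(t, t+1)` pairs that exist. The helper `consecutive_pose_pairs` is hypothetical.

```python
from typing import Dict, List, Tuple


def consecutive_pose_pairs(
    pose_trajectory: Dict[Tuple[str, int], int], agent: str
) -> List[Tuple[int, int, int]]:
    """Return (t, node_id_t, node_id_t1) for each t where both t and t+1 exist."""
    pairs = []
    for (a, t), nid in pose_trajectory.items():
        nxt = pose_trajectory.get((a, t + 1))
        if a == agent and nxt is not None:
            pairs.append((t, nid, nxt))
    return sorted(pairs)


# Agent "r" has a gap at t=2, so only the (0, 1) pair qualifies.
pose_trajectory = {("r", 0): 0, ("r", 1): 1, ("r", 3): 3, ("q", 0): 7, ("q", 1): 8}
pairs_r = consecutive_pose_pairs(pose_trajectory, "r")
```

Each returned `t` can then be passed as `sg.add_agent_temporal_smoothness(agent, t)` on a `SceneGraphWorld` instance `sg`.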

add_landmark3d(xyz)

Add a 3D landmark node (R^3).

:param xyz: Iterable of length 3 giving world coordinates.
:return: Integer node id of the created landmark variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_landmark3d(self, xyz) -> int:
    """
    Add a 3D landmark node (R^3).

    :param xyz: Iterable of length 3 giving world coordinates.
    :return: Integer node id of the created landmark variable.
    """
    value = jnp.array(xyz, dtype=jnp.float32).reshape(3,)
    if self._active_template_enabled:
        nid_int = int(self._assign_var_slot("landmark3d", value))
    else:
        nid_int = int(self.wm.add_variable("landmark3d", value))
    self._remember_node(nid_int, "landmark3d", value)
    return nid_int

add_named_object3d(name, xyz)

Add a 3D object and register it under a semantic name.

:param name: Identifier for the object (for example, "chair_1").
:param xyz: Iterable of length 3 giving the world-frame position.
:return: Integer node id of the created object variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_named_object3d(self, name: str, xyz) -> int:
    """
    Add a 3D object and register it under a semantic name.

    :param name: Identifier for the object (for example, ``"chair_1"``).
    :param xyz: Iterable of length 3 giving the world-frame position.
    :return: Integer node id of the created object variable.
    """
    obj_id = self.add_object3d(xyz)
    self.object_nodes[name] = obj_id
    return obj_id

add_object3d(xyz)

Add an object with 3D position (R^3).

:param xyz: Iterable of length 3 giving the object position in world coordinates.
:return: Integer node id of the created object variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_object3d(self, xyz) -> int:
    """
    Add an object with 3D position (R^3).

    :param xyz: Iterable of length 3 giving the object position in
        world coordinates.
    :return: Integer node id of the created object variable.
    """
    xyz = jnp.array(xyz, dtype=jnp.float32).reshape(3,)
    if self._active_template_enabled:
        nid_int = int(self._assign_var_slot("object3d", xyz))
    else:
        nid_int = int(self.wm.add_variable("object3d", xyz))
    self._remember_node(nid_int, "object3d", xyz)
    return nid_int

add_object_room_edge(object_id, room_id)

Register a semantic edge between an object node and a room node.

This helper is intentionally lightweight: it does not add a numeric factor to the underlying factor graph. Instead it records topological connectivity for visualization and higher-level reasoning, similar to classic dynamic scene-graph frameworks.

:param object_id: Integer node id of the object variable.
:param room_id: Integer node id of the room variable.
:return: None.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_object_room_edge(self, object_id: int, room_id: int) -> None:
    """
    Register a semantic edge between an object node and a room node.

    This helper is intentionally lightweight: it does *not* add a numeric
    factor to the underlying factor graph. Instead it records topological
    connectivity for visualization and higher-level reasoning, similar to
    classic dynamic scene-graph frameworks.

    :param object_id: Integer node id of the object variable.
    :param room_id: Integer node id of the room variable.
    :return: None.
    """
    self.object_room_edges.append((int(object_id), int(room_id)))
    # Also register in the SceneGraph's factor memory for visualization.
    self._remember_factor(
        f_type="semantic_object_room",
        var_ids=(int(object_id), int(room_id)),
        params={},
        relation="object-room",
    )
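
`object_room_edges` holds `(object_id, room_id)` tuples, so grouping objects by room for visualization takes a single pass. The helper `objects_by_room` is hypothetical.

```python
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple


def objects_by_room(object_room_edges: List[Tuple[int, int]]) -> Dict[int, List[int]]:
    """Group object ids under the room they were linked to, preserving insertion order."""
    grouped: DefaultDict[int, List[int]] = defaultdict(list)
    for obj_id, room_id in object_room_edges:
        grouped[room_id].append(obj_id)
    return dict(grouped)


edges = [(10, 1), (11, 1), (12, 2)]
rooms = objects_by_room(edges)
```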

add_odom_se3_additive(pose_i, pose_j, dx, sigma=None)

Add an additive SE(3) odometry factor in R^6.

The measurement is a translation along the x-axis plus zero rotation.

:param pose_i: Node id of the source pose.
:param pose_j: Node id of the destination pose.
:param dx: Translation along the x-axis in meters.
:param sigma: Optional standard deviation for the odometry noise. If None, :attr:SceneGraphNoiseConfig.odom_se3_sigma is used.
:return: Integer factor id of the created odometry constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_odom_se3_additive(
    self,
    pose_i: int,
    pose_j: int,
    dx: float,
    sigma: float | None = None,
) -> int:
    """
    Add an additive SE(3) odometry factor in R^6.

    The measurement is a translation along the x-axis plus zero rotation.

    :param pose_i: Node id of the source pose.
    :param pose_j: Node id of the destination pose.
    :param dx: Translation along the x-axis in meters.
    :param sigma: Optional standard deviation for the odometry noise. If
        ``None``, :attr:`SceneGraphNoiseConfig.odom_se3_sigma` is used.
    :return: Integer factor id of the created odometry constraint.
    """
    meas = jnp.array([dx, 0.0, 0.0, 0.0, 0.0, 0.0])

    if sigma is None:
        sigma = self.noise.odom_se3_sigma

    weight = sigma_to_weight(sigma)

    params = {
        "measurement": meas,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="odom_se3",
        var_ids=(pose_i, pose_j),
        params=params,
        relation="factor:odom_se3",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("odom_se3", (pose_i, pose_j), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "odom_se3",
        (pose_i, pose_j),
        params,
    )
    return int(fid)

add_odom_se3_geodesic(pose_i, pose_j, dx, yaw=0.0, sigma=None)

Add a geodesic SE(3) odometry factor.

The measurement is parameterized in se(3) as a translation dx along the x-axis plus a yaw rotation about the z-axis; all other components are zero.

:param pose_i: Node id of the source pose.
:param pose_j: Node id of the destination pose.
:param dx: Translation along the x-axis in meters.
:param yaw: Heading change around the z-axis in radians.
:param sigma: Optional standard deviation for the odometry noise. If None, :attr:SceneGraphNoiseConfig.odom_se3_sigma is used.
:return: Integer factor id of the created odometry constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_odom_se3_geodesic(
    self,
    pose_i: int,
    pose_j: int,
    dx: float,
    yaw: float = 0.0,
    sigma: float | None = None,
) -> int:
    """
    Add a geodesic SE(3) odometry factor.

    The measurement is parameterized as translation + yaw in se(3).

    :param pose_i: Node id of the source pose.
    :param pose_j: Node id of the destination pose.
    :param dx: Translation along the x-axis in meters.
    :param yaw: Heading change around the z-axis in radians.
    :param sigma: Optional standard deviation for the odometry noise. If
        ``None``, :attr:`SceneGraphNoiseConfig.odom_se3_sigma` is used.
    :return: Integer factor id of the created odometry constraint.
    """
    meas = jnp.array([dx, 0.0, 0.0, 0.0, 0.0, yaw])

    if sigma is None:
        sigma = self.noise.odom_se3_sigma

    weight = sigma_to_weight(sigma)

    params = {
        "measurement": meas,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="odom_se3_geodesic",
        var_ids=(pose_i, pose_j),
        params=params,
        relation="factor:odom_se3_geodesic",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("odom_se3_geodesic", (pose_i, pose_j), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "odom_se3_geodesic",
        (pose_i, pose_j),
        params,
    )
    return int(fid)
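The measurement stored by add_odom_se3_geodesic can be reproduced without JAX. The following is a minimal sketch. It assumes the [tx, ty, tz, rx, ry, rz] ordering documented for add_pose_se3, and it assumes that sigma_to_weight returns 1 / sigma^2, which is the conversion the add_range_measurement docstring describes. The helper odom_params and the DEFAULT_ODOM_SIGMA value are hypothetical stand-ins and are not part of dsg-jit.

```python
from typing import Optional

# Hypothetical stand-in for SceneGraphNoiseConfig.odom_se3_sigma.
DEFAULT_ODOM_SIGMA = 0.05


def sigma_to_weight(sigma: float) -> float:
    # Assumption: weight = 1 / sigma^2, as stated for add_range_measurement.
    return 1.0 / (sigma * sigma)


def odom_params(dx: float, yaw: float = 0.0, sigma: Optional[float] = None) -> dict:
    """Build a params dict mirroring the one add_odom_se3_geodesic creates."""
    # se(3) layout [tx, ty, tz, rx, ry, rz]: only forward motion (tx) and
    # rotation about z (rz) are populated; every other entry is zero.
    measurement = [float(dx), 0.0, 0.0, 0.0, 0.0, float(yaw)]
    if sigma is None:
        sigma = DEFAULT_ODOM_SIGMA  # fall back to the (hypothetical) config default
    return {"measurement": measurement, "weight": sigma_to_weight(sigma)}


params = odom_params(1.0, yaw=0.5, sigma=0.1)
print(params["measurement"])  # [1.0, 0.0, 0.0, 0.0, 0.0, 0.5]
```

Because only tx and rz are populated, this factor describes planar forward motion combined with a heading change.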

add_place1d(x)

Add a 1D place variable.

:param x: Scalar position along a 1D axis (e.g. a corridor coordinate).
:return: Integer node id of the created place variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_place1d(self, x: float) -> int:
    """
    Add a 1D place variable.

    :param x: Scalar position along a 1D axis (e.g. corridor coordinate).
    :return: Integer node id of the created place variable.
    """
    value = jnp.array([x], dtype=jnp.float32)
    if self._active_template_enabled:
        nid = int(self._assign_var_slot("place1d", value))
    else:
        nid = int(self.wm.add_variable("place1d", value))
    self._remember_node(nid, "place1d", value)
    return nid

add_place3d(name, xyz)

Add a 3D place node (R^3) with a human-readable name.

This is a semantic helper for dynamic scene-graph style usage. The name is recorded in :attr:place_nodes and mapped to the new node id.

:param name: Identifier for the place (for example, "place_A").
:param xyz: Iterable of length 3 giving the world-frame position.
:return: Integer node id of the created place variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_place3d(self, name: str, xyz) -> int:
    """
    Add a 3D place node (R^3) with a human-readable name.

    This is a semantic helper for dynamic scene-graph style usage.

    :param name: Identifier for the place (for example, ``"place_A"``).
    :param xyz: Iterable of length 3 giving the world-frame position.
    :return: Integer node id of the created place variable.
    """
    value = jnp.array(xyz, dtype=jnp.float32).reshape(3,)
    if self._active_template_enabled:
        nid_int = int(self._assign_var_slot("place3d", value))
    else:
        nid_int = int(self.wm.add_variable("place3d", value))
    self._remember_node(nid_int, "place3d", value)
    self.place_nodes[name] = nid_int
    return nid_int

add_place_attachment(pose_id, place_id, coord_index=0, sigma=None)

Attach an SE(3) pose to a place node (1D or 3D).

This is a higher-level, dimension-aware wrapper around the pose_place_attachment residual, and is intended for scene-graph style experiments where places may be either 1D (topological) or 3D (metric positions).

:param pose_id: Node id of the SE(3) pose variable.
:param place_id: Node id of the place variable. The underlying state dimension is inferred at runtime from the factor graph (for example, 1 for place1d or 3 for place3d).
:param coord_index: Index of the pose coordinate to tie to the place (typically 0 for x, 1 for y, etc.). Defaults to 0.
:param sigma: Optional noise standard deviation. If None, falls back to :attr:SceneGraphNoiseConfig.pose_place_sigma.
:return: Integer factor id of the created attachment constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_place_attachment(
    self,
    pose_id: int,
    place_id: int,
    coord_index: int = 0,
    sigma: float | None = None,
) -> int:
    """
    Attach a SE(3) pose to a place node (1D or 3D).

    This is a higher-level, dimension-aware wrapper around the
    ``pose_place_attachment`` residual, and is intended for scene-graph
    style experiments where places may be either 1D (topological) or
    3D (metric positions).

    :param pose_id: Node id of the SE(3) pose variable.
    :param place_id: Node id of the place variable. The underlying state
        dimension is inferred at runtime from the factor graph (for
        example, 1 for ``place1d`` or 3 for ``place3d``).
    :param coord_index: Index of the pose coordinate to tie to the place
        (typically 0 for x, 1 for y, etc.). Defaults to 0.
    :param sigma: Optional noise standard deviation. If ``None``, falls
        back to :attr:`SceneGraphNoiseConfig.pose_place_sigma`.
    :return: Integer factor id of the created attachment constraint.
    """
    # Infer place dimensionality from the underlying variable.
    place_nid = NodeId(place_id)
    place_var = self.wm.fg.variables[place_nid]
    place_dim_val = place_var.value.shape[0]

    pose_dim = jnp.array(6)
    place_dim = jnp.array(place_dim_val)
    pose_coord_index = jnp.array(coord_index)

    if sigma is None:
        sigma = self.noise.pose_place_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "pose_dim": pose_dim,
        "place_dim": place_dim,
        "pose_coord_index": pose_coord_index,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="pose_place_attachment",
        var_ids=(pose_id, place_id),
        params=params,
        relation="pose-place",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_place_attachment", (pose_id, place_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_place_attachment",
        (pose_id, place_id),
        params,
    )
    return int(fid)
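This page does not reproduce the pose_place_attachment residual. To illustrate what coord_index selects, the sketch below makes two assumptions. For a 1D place, the residual is taken to be the chosen pose coordinate minus the place value. The factor's cost is taken to be weight times the squared residual. Both the function names and that residual form are assumptions and do not come from the library.

```python
from typing import Sequence


def pose_place_residual_1d(pose: Sequence[float], place: Sequence[float], coord_index: int = 0) -> float:
    """Hypothetical 1D pose-place residual: selected pose coordinate minus place value."""
    if len(pose) != 6:
        raise ValueError("expected a 6D se(3) pose [tx, ty, tz, rx, ry, rz]")
    if len(place) != 1:
        raise ValueError("this sketch only covers 1D places")
    return pose[coord_index] - place[0]


def weighted_cost(residual: float, weight: float) -> float:
    # Squared residual scaled by the factor weight (assumed cost form).
    return weight * residual ** 2


pose = [2.5, -1.0, 0.0, 0.0, 0.0, 0.3]
r_x = pose_place_residual_1d(pose, [2.0], coord_index=0)   # 0.5: pose x is 0.5 past the place
r_y = pose_place_residual_1d(pose, [-1.0], coord_index=1)  # 0.0: pose y sits on the place
```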

add_pose_landmark_bearing(pose_id, landmark_id, bearing, sigma=None)

Add a bearing-only constraint from pose to landmark.

:param pose_id: Node id of the SE(3) pose variable.
:param landmark_id: Node id of the 3D landmark variable.
:param bearing: Iterable of length 3 giving the bearing vector in the pose frame (normalized internally).
:param sigma: Optional noise standard deviation. If None, :attr:SceneGraphNoiseConfig.pose_landmark_bearing_sigma is used.
:return: Integer factor id of the created bearing constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_pose_landmark_bearing(
    self,
    pose_id: int,
    landmark_id: int,
    bearing,
    sigma: float | None = None,
) -> int:
    """
    Add a bearing-only constraint from pose to landmark.

    :param pose_id: Node id of the SE(3) pose variable.
    :param landmark_id: Node id of the 3D landmark variable.
    :param bearing: Iterable of length 3 giving the bearing vector in the
        pose frame (will be normalized internally).
    :param sigma: Optional noise standard deviation. If ``None``,
        :attr:`SceneGraphNoiseConfig.pose_landmark_bearing_sigma` is used.
    :return: Integer factor id of the created bearing constraint.
    """
    b = jnp.array(bearing, dtype=jnp.float32).reshape(3,)
    b = b / (jnp.linalg.norm(b) + 1e-8)

    if sigma is None:
        sigma = self.noise.pose_landmark_bearing_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "bearing_meas": b,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="pose_landmark_bearing",
        var_ids=(pose_id, landmark_id),
        params=params,
        relation="factor:pose_landmark_bearing",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_landmark_bearing", (pose_id, landmark_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_landmark_bearing",
        (pose_id, landmark_id),
        params,
    )
    return int(fid)
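add_pose_landmark_bearing normalizes the bearing as b / (||b|| + 1e-8). With this form, a zero vector maps to zero instead of producing NaNs. The pure-Python sketch below mirrors that normalization. It also adds an angular error between two unit vectors, which is a hypothetical error metric for comparison only; the library's actual bearing residual is not shown here.

```python
import math
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]


def normalize_bearing(b: Sequence[float], eps: float = 1e-8) -> Vec3:
    """Mirror the b / (||b|| + eps) normalization used by add_pose_landmark_bearing."""
    norm = math.sqrt(sum(c * c for c in b))
    # The epsilon keeps a zero vector from dividing by zero (it maps to zeros).
    return tuple(c / (norm + eps) for c in b)


def angle_between(u: Sequence[float], v: Sequence[float]) -> float:
    """Angle in radians between two unit vectors (hypothetical error metric)."""
    dot = sum(a * b for a, b in zip(u, v))
    return math.acos(max(-1.0, min(1.0, dot)))  # clamp guards against rounding


b = normalize_bearing([3.0, 0.0, 4.0])  # approximately (0.6, 0.0, 0.8)
```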

add_pose_landmark_relative(pose_id, landmark_id, measurement, sigma=None)

Add a relative measurement between a pose and a 3D landmark.

The measurement is expressed in the pose frame.

:param pose_id: Node id of the SE(3) pose variable.
:param landmark_id: Node id of the 3D landmark variable.
:param measurement: Iterable of length 3 giving the expected landmark position in the pose frame.
:param sigma: Optional noise standard deviation. If None, :attr:SceneGraphNoiseConfig.pose_landmark_sigma is used.
:return: Integer factor id of the created relative landmark constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_pose_landmark_relative(
    self,
    pose_id: int,
    landmark_id: int,
    measurement,
    sigma: float | None = None,
) -> int:
    """
    Add a relative measurement between a pose and a 3D landmark.

    The measurement is expressed in the pose frame.

    :param pose_id: Node id of the SE(3) pose variable.
    :param landmark_id: Node id of the 3D landmark variable.
    :param measurement: Iterable of length 3 giving the expected landmark
        position in the pose frame.
    :param sigma: Optional noise standard deviation. If ``None``,
        :attr:`SceneGraphNoiseConfig.pose_landmark_sigma` is used.
    :return: Integer factor id of the created relative landmark constraint.
    """
    meas = jnp.array(measurement, dtype=jnp.float32).reshape(3,)

    if sigma is None:
        sigma = self.noise.pose_landmark_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "measurement": meas,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="pose_landmark_relative",
        var_ids=(pose_id, landmark_id),
        params=params,
        relation="factor:pose_landmark_relative",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_landmark_relative", (pose_id, landmark_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_landmark_relative",
        (pose_id, landmark_id),
        params,
    )
    return int(fid)
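The measurement is expressed in the pose frame. A predicted measurement therefore has to map the landmark's world position into that frame: p_pose = R^T (l - t). To stay dependency-free, the sketch below makes two assumptions. It treats the first three pose components as the translation and the last three as a rotation vector, and it handles only rotation about z (rx = ry = 0). The real residual works with a full SE(3) pose, and its exact form is not shown on this page.

```python
import math
from typing import Sequence, Tuple


def world_to_pose_frame(pose: Sequence[float], landmark: Sequence[float]) -> Tuple[float, float, float]:
    """Express a world-frame landmark in the frame of a yaw-only pose.

    pose uses the layout [tx, ty, tz, rx, ry, rz]; this sketch assumes
    rx = ry = 0, so the rotation is a pure yaw of rz about the z-axis.
    """
    tx, ty, tz, _, _, yaw = pose
    dx, dy, dz = landmark[0] - tx, landmark[1] - ty, landmark[2] - tz
    c, s = math.cos(yaw), math.sin(yaw)
    # Apply R^T (the inverse yaw rotation) to the world-frame offset.
    return (c * dx + s * dy, -s * dx + c * dy, dz)


def relative_residual(pose, landmark, measurement):
    """Predicted minus measured landmark position, both in the pose frame."""
    pred = world_to_pose_frame(pose, landmark)
    return tuple(p - m for p, m in zip(pred, measurement))


# A robot at (1, 0, 0) facing +y sees a world landmark at (1, 2, 0)
# straight ahead, i.e. at roughly (2, 0, 0) in its own frame.
pred = world_to_pose_frame([1.0, 0.0, 0.0, 0.0, 0.0, math.pi / 2], [1.0, 2.0, 0.0])
```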

add_pose_se3(value)

Add a generic SE(3) pose variable.

:param value: Length-6 array-like se(3) vector [tx, ty, tz, rx, ry, rz].
:return: Integer node id of the created pose variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_pose_se3(self, value: jnp.ndarray) -> int:
    """
    Add a generic SE(3) pose variable.

    :param value: Length-6 array-like se(3) vector [tx, ty, tz, rx, ry, rz].
    :return: Integer node id of the created pose variable.
    """
    if self._active_template_enabled:
        nid = int(self._assign_var_slot("pose_se3", jnp.asarray(value)))
    else:
        nid = int(self.wm.add_variable("pose_se3", value))
    self._remember_node(nid, "pose_se3", jnp.asarray(value))
    return nid

add_pose_voxel_point(pose_id, voxel_id, point_meas, sigma=None)

Constrain a voxel cell to align with a point measurement seen from a pose.

:param pose_id: Node id of the SE(3) pose variable.
:param voxel_id: Node id of the voxel cell variable.
:param point_meas: Iterable of length 3 giving a point in the pose frame (for example, a back-projected depth sample).
:param sigma: Optional noise standard deviation. If None, :attr:SceneGraphNoiseConfig.pose_voxel_point_sigma is used.
:return: Integer factor id of the created voxel-point constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_pose_voxel_point(
    self,
    pose_id: int,
    voxel_id: int,
    point_meas,
    sigma: float | None = None,
) -> int:
    """
    Constrain a voxel cell to align with a point measurement seen from a pose.

    :param pose_id: Node id of the SE(3) pose variable.
    :param voxel_id: Node id of the voxel cell variable.
    :param point_meas: Iterable of length 3 giving a point in the pose
        frame (for example, a back-projected depth sample).
    :param sigma: Optional noise standard deviation. If ``None``,
        :attr:`SceneGraphNoiseConfig.pose_voxel_point_sigma` is used.
    :return: Integer factor id of the created voxel-point constraint.
    """
    point_meas = jnp.array(point_meas, dtype=jnp.float32).reshape(3,)

    if sigma is None:
        sigma = self.noise.pose_voxel_point_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "point_meas": point_meas,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="pose_voxel_point",
        var_ids=(pose_id, voxel_id),
        params=params,
        relation="factor:pose_voxel_point",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_voxel_point", (pose_id, voxel_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_voxel_point",
        (pose_id, voxel_id),
        params,
    )
    return int(fid)
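point_meas is described as, for example, a back-projected depth sample. The following is a minimal sketch of standard pinhole back-projection that could produce such a point. The intrinsics fx, fy, cx, cy are hypothetical. The sketch also assumes that the camera frame coincides with the pose frame and that z points along the optical axis. A real rig would apply a camera-to-body extrinsic before calling add_pose_voxel_point.

```python
from typing import Tuple


def back_project(u: float, v: float, depth: float,
                 fx: float, fy: float, cx: float, cy: float) -> Tuple[float, float, float]:
    """Back-project pixel (u, v) with metric depth into a 3D point (pinhole model).

    Inverts the pinhole projection: x = (u - cx) * z / fx, y = (v - cy) * z / fy, z = depth.
    """
    if depth <= 0.0:
        raise ValueError("depth must be positive")
    x = (u - cx) * depth / fx
    y = (v - cy) * depth / fy
    return (x, y, depth)


# Hypothetical 640x480 camera with a 500 px focal length.
point = back_project(420.0, 240.0, 2.0, fx=500.0, fy=500.0, cx=320.0, cy=240.0)
print(point)  # (0.4, 0.0, 2.0)
```

The resulting tuple has the length-3, pose-frame shape that point_meas expects.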

add_range_measurement(pose_nid, target_nid, measured_range, sigma=None, weight=None)

Add a range-only sensor factor between a pose and a 3D target.

This creates a factor of type "range" whose residual is:

r = ||target - pose|| - measured_range

The underlying residual is implemented in slam.measurements.range_residual.

:param pose_nid: Node id of the pose (pose_se3) variable.
:param target_nid: Node id of the target variable (e.g. place3d, voxel_cell, object3d).
:param measured_range: Observed distance (same units as world coordinates).
:param sigma: Optional standard deviation of the measurement noise. If provided, it is converted to a weight as 1 / sigma^2.
:param weight: Optional explicit weight. If both sigma and weight are given, weight takes precedence. If neither is given, the weight defaults to 1.0.
:return: Integer factor id of the created range factor.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_range_measurement(
    self,
    pose_nid: int,
    target_nid: int,
    measured_range: float,
    sigma: float | None = None,
    weight: float | None = None,
) -> int:
    """
    Add a range-only sensor factor between a pose and a 3D target.

    This creates a factor of type ``"range"`` whose residual is:

        r = ||target - pose|| - measured_range

    The underlying residual is implemented in ``slam.measurements.range_residual``.

    :param pose_nid: NodeId of the pose (pose_se3) variable.
    :param target_nid: NodeId of the target variable (e.g. place3d, voxel_cell, object3d).
    :param measured_range: Observed distance (same units as world coordinates).
    :param sigma: Optional standard deviation of the measurement noise. If provided,
                  it will be converted to a weight as 1 / sigma^2.
    :param weight: Optional explicit weight. If both ``sigma`` and ``weight`` are given,
                   ``weight`` takes precedence.
    :return: Integer factor id of the created range factor.
    """
    if weight is not None:
        w = float(weight)
    elif sigma is not None:
        w = sigma_to_weight(sigma)
    else:
        w = 1.0

    meas = jnp.array([float(measured_range)], dtype=jnp.float32)
    params = {"range": meas, "weight": w}
    remembered = self._remember_factor(
        f_type="range",
        var_ids=(pose_nid, target_nid),
        params=params,
        relation="factor:range",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("range", (pose_nid, target_nid), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "range",
        (pose_nid, target_nid),
        params,
    )
    return int(fid)
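The weight precedence in add_range_measurement (explicit weight, then sigma, then 1.0) and the residual formula above fit in a few lines of plain Python. This sketch assumes that "pose" in ||target - pose|| means the translation part of the pose, passed here as pose_xyz. It also assumes that sigma_to_weight computes 1 / sigma^2, as the docstring states. Both helper names are hypothetical.

```python
import math
from typing import Optional, Sequence


def resolve_range_weight(sigma: Optional[float] = None, weight: Optional[float] = None) -> float:
    """Mirror the precedence used by add_range_measurement: weight, then sigma, then 1.0."""
    if weight is not None:
        return float(weight)
    if sigma is not None:
        return 1.0 / (sigma * sigma)  # assumed sigma_to_weight conversion
    return 1.0


def range_residual(pose_xyz: Sequence[float], target_xyz: Sequence[float], measured_range: float) -> float:
    """r = ||target - pose|| - measured_range, using the pose translation."""
    return math.dist(target_xyz, pose_xyz) - measured_range


r = range_residual((0.0, 0.0, 0.0), (3.0, 4.0, 0.0), measured_range=5.0)  # 0.0
```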

add_room(name, center)

Add a 3D room node (R^3 center) with a semantic name.

This is a thin wrapper around a Euclidean variable that exposes a room-level abstraction for dynamic scene-graph experiments. The name is recorded in :attr:room_nodes and mapped to the new node id.

:param name: Identifier for the room (for example, "room_A").
:param center: Iterable of length 3 giving the approximate room centroid in world coordinates.
:return: Integer node id of the created room variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_room(self, name: str, center) -> int:
    """
    Add a 3D room node (R^3 center) with a semantic name.

    This is a thin wrapper around a Euclidean variable, but exposes a
    room-level abstraction for dynamic scene-graph experiments.

    :param name: Identifier for the room (for example, ``"room_A"``).
    :param center: Iterable of length 3 giving the approximate room
        centroid in world coordinates.
    :return: Integer node id of the created room variable.
    """
    value = jnp.array(center, dtype=jnp.float32).reshape(3,)
    if self._active_template_enabled:
        nid_int = int(self._assign_var_slot("room3d", value))
    else:
        nid_int = int(self.wm.add_variable("room3d", value))
    self._remember_node(nid_int, "room3d", value)
    self.room_nodes[name] = nid_int
    return nid_int

add_room1d(x)

Add a 1D 'room' variable (just a scalar, wrapped as a length-1 vector).

The room is stored in :attr:room_nodes under an auto-generated string key of the form "room1d_{k}". Here k is the number of entries in :attr:room_nodes before this room is added, which includes named 3D rooms created with :meth:add_room.

:param x: 1D coordinate, either an array of shape (1,) or a scalar float.
:return: Integer node id of the created room variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_room1d(self, x: jnp.ndarray) -> int:
    """
    Add a 1D 'room' variable (just a scalar, wrapped as a length-1 vector).

    The room is stored in :attr:`room_nodes` using an auto-generated
    string key of the form ``"room1d_{k}"`` where ``k`` is the current
    number of rooms.

    :param x: 1D coordinate, shape ``(1,)`` or a scalar float.
    :return: Integer node id of the created room variable.
    """
    # Normalize to a length-1 float32 vector.
    if isinstance(x, float) or (hasattr(x, "ndim") and x.ndim == 0):
        x = jnp.array([float(x)], dtype=jnp.float32)

    x = jnp.array(x, dtype=jnp.float32).reshape((1,))

    if self._active_template_enabled:
        nid = int(self._assign_var_slot("room1d", x))
    else:
        nid = int(self.wm.add_variable("room1d", x))  # after normalizing x
    self._remember_node(nid, "room1d", x)
    name = f"room1d_{len(self.room_nodes)}"
    self.room_nodes[name] = nid
    return nid
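Because the "room1d_{k}" key is derived from len(self.room_nodes), the suffix counts every registered room, not just 1D rooms. The hypothetical RoomRegistry below is a minimal stand-in for that bookkeeping only. It does not touch a factor graph, and its sequential node ids are an assumption made for illustration.

```python
from typing import Dict


class RoomRegistry:
    """Hypothetical stand-in for the room_nodes bookkeeping in SceneGraphWorld."""

    def __init__(self) -> None:
        self.room_nodes: Dict[str, int] = {}
        self._next_id = 0

    def _new_id(self) -> int:
        nid = self._next_id
        self._next_id += 1
        return nid

    def add_room(self, name: str) -> int:
        nid = self._new_id()
        self.room_nodes[name] = nid
        return nid

    def add_room1d(self) -> int:
        nid = self._new_id()
        # Key built from the count *before* insertion, named rooms included.
        name = f"room1d_{len(self.room_nodes)}"
        self.room_nodes[name] = nid
        return nid


reg = RoomRegistry()
reg.add_room("room_A")
reg.add_room1d()
reg.add_room1d()
print(sorted(reg.room_nodes))  # ['room1d_1', 'room1d_2', 'room_A']
```

As a result, the first 1D room added after a named room is keyed "room1d_1", not "room1d_0".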

add_room_place_edge(room_id, place_id)

Register a semantic edge between a room node and a place node.

This helper is intentionally lightweight: it does not add a numeric factor to the underlying factor graph. Instead it records topological connectivity for visualization and higher-level reasoning, similar to classic dynamic scene-graph frameworks.

:param room_id: Integer node id of the room variable.
:param place_id: Integer node id of the place variable.
:return: None.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_room_place_edge(self, room_id: int, place_id: int) -> None:
    """
    Register a semantic edge between a room node and a place node.

    This helper is intentionally lightweight: it does *not* add a numeric
    factor to the underlying factor graph. Instead it records topological
    connectivity for visualization and higher-level reasoning, similar to
    classic dynamic scene-graph frameworks.

    :param room_id: Integer node id of the room variable.
    :param place_id: Integer node id of the place variable.
    :return: None.
    """
    self.room_place_edges.append((int(room_id), int(place_id)))
    # Also register in the SceneGraph's factor memory for visualization.
    self._remember_factor(
        f_type="semantic_room_place",
        var_ids=(int(room_id), int(place_id)),
        params={},
        relation="room-place",
    )
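add_room_place_edge appends plain (room_id, place_id) tuples to room_place_edges. Higher-level reasoning can therefore group those tuples without going through the factor graph. The helper places_by_room below is a hypothetical example of such a consumer and is not part of dsg-jit.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set, Tuple


def places_by_room(edges: Iterable[Tuple[int, int]]) -> Dict[int, List[int]]:
    """Group (room_id, place_id) pairs into room -> sorted list of place ids."""
    grouped: Dict[int, Set[int]] = defaultdict(set)
    for room_id, place_id in edges:
        grouped[room_id].add(place_id)  # a set drops duplicate registrations
    return {room: sorted(places) for room, places in grouped.items()}


edges = [(10, 1), (10, 2), (11, 3), (10, 2)]
print(places_by_room(edges))  # {10: [1, 2], 11: [3]}
```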

add_temporal_smoothness(pose_id_t, pose_id_t1, sigma=None)

Enforce smoothness between successive poses.

:param pose_id_t: Node id of the pose at time t.
:param pose_id_t1: Node id of the pose at time t+1.
:param sigma: Optional standard deviation of the pose difference; a larger value gives weaker smoothness. If None, :attr:SceneGraphNoiseConfig.smooth_pose_sigma is used.
:return: Integer factor id of the created smoothness constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_temporal_smoothness(
    self,
    pose_id_t: int,
    pose_id_t1: int,
    sigma: float | None = None,
) -> int:
    """
    Enforce smoothness between successive poses.

    :param pose_id_t: Node id of the pose at time ``t``.
    :param pose_id_t1: Node id of the pose at time ``t+1``.
    :param sigma: Optional standard deviation of the pose difference; a
        larger value gives weaker smoothness. If ``None``,
        :attr:`SceneGraphNoiseConfig.smooth_pose_sigma` is used.
    :return: Integer factor id of the created smoothness constraint.
    """
    if sigma is None:
        sigma = self.noise.smooth_pose_sigma
    weight = sigma_to_weight(sigma)

    params = {"weight": weight}
    remembered = self._remember_factor(
        f_type="pose_temporal_smoothness",
        var_ids=(pose_id_t, pose_id_t1),
        params=params,
        relation="factor:pose_temporal_smoothness",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_temporal_smoothness", (pose_id_t, pose_id_t1), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_temporal_smoothness",
        (pose_id_t, pose_id_t1),
        params,
    )
    return int(fid)
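The pose_temporal_smoothness residual is not shown on this page. To illustrate why a larger sigma gives weaker smoothness, this sketch assumes the factor penalizes the component-wise difference between consecutive se(3) vectors. It also assumes a cost of weight * ||pose_t1 - pose_t||^2 with weight = 1 / sigma^2. Both assumptions are illustrative only.

```python
from typing import Sequence


def smoothness_cost(pose_t: Sequence[float], pose_t1: Sequence[float], sigma: float) -> float:
    """Hypothetical smoothness cost: (1 / sigma^2) * ||pose_t1 - pose_t||^2."""
    weight = 1.0 / (sigma * sigma)
    return weight * sum((b - a) ** 2 for a, b in zip(pose_t, pose_t1))


p0 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
p1 = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
tight = smoothness_cost(p0, p1, sigma=0.5)  # 4.0: small sigma, strong penalty
loose = smoothness_cost(p0, p1, sigma=2.0)  # 0.25: large sigma, weak penalty
```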

add_voxel_cell(xyz)

Add a voxel cell center in world coordinates (R^3).

:param xyz: Iterable of length 3 giving the voxel center position.
:return: Integer node id of the created voxel variable.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_voxel_cell(self, xyz) -> int:
    """
    Add a voxel cell center in world coordinates (R^3).

    :param xyz: Iterable of length 3 giving the voxel center position.
    :return: Integer node id of the created voxel variable.
    """
    value = jnp.array(xyz, dtype=jnp.float32).reshape(3,)
    if self._active_template_enabled:
        nid_int = int(self._assign_var_slot("voxel_cell", value))
    else:
        nid_int = int(self.wm.add_variable("voxel_cell", value))
    self._remember_node(nid_int, "voxel_cell", value)
    return nid_int

add_voxel_point_observation(voxel_id, point_world, sigma=None)

Add an observation tying a voxel center to a 3D point in world coordinates.

:param voxel_id: Node id of the voxel cell variable.
:param point_world: Iterable of length 3 giving a world-frame point (for example, from fused depth or a point cloud).
:param sigma: Optional noise standard deviation. If None, :attr:SceneGraphNoiseConfig.voxel_point_obs_sigma is used.
:return: Integer factor id of the created observation constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_voxel_point_observation(
    self,
    voxel_id: int,
    point_world,
    sigma: float | None = None,
) -> int:
    """
    Add an observation tying a voxel center to a 3D point in world coordinates.

    :param voxel_id: Node id of the voxel cell variable.
    :param point_world: Iterable of length 3 giving a world-frame point
        (for example, from fused depth or a point cloud).
    :param sigma: Optional noise standard deviation. If ``None``,
        :attr:`SceneGraphNoiseConfig.voxel_point_obs_sigma` is used.
    :return: Integer factor id of the created observation constraint.
    """
    point_world = jnp.array(point_world, dtype=jnp.float32).reshape(3,)

    if sigma is None:
        sigma = self.noise.voxel_point_obs_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "point_world": point_world,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="voxel_point_obs",
        var_ids=(voxel_id,),
        params=params,
        relation="factor:voxel_point_obs",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("voxel_point_obs", (voxel_id,), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "voxel_point_obs",
        (voxel_id,),
        params,
    )
    return int(fid)
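Suppose the voxel_point_obs residual is voxel - point_world, scaled by its weight; that form is an assumption, since the residual is not shown here. Then several observations of one voxel, with no other factors attached, have a closed-form least-squares optimum: the weight-averaged point. The helper below is a hypothetical sketch of that optimum. Once smoothness or pose factors also touch the voxel, the solver's answer will generally differ.

```python
from typing import Sequence, Tuple


def fuse_voxel_observations(points: Sequence[Sequence[float]],
                            weights: Sequence[float]) -> Tuple[float, float, float]:
    """Minimize sum_k w_k * ||v - p_k||^2 over v; the minimizer is the weighted mean."""
    if len(points) != len(weights) or not points:
        raise ValueError("need one weight per point and at least one point")
    total = sum(weights)
    return tuple(sum(w * p[axis] for p, w in zip(points, weights)) / total for axis in range(3))


center = fuse_voxel_observations([(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)], weights=[1.0, 3.0])
print(center)  # (0.75, 0.0, 0.0)
```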

add_voxel_smoothness(voxel_i_id, voxel_j_id, offset, sigma=None)

Enforce grid-like spacing between two voxel centers.

:param voxel_i_id: Node id of the first voxel cell.
:param voxel_j_id: Node id of the second voxel cell.
:param offset: Iterable of length 3 giving the expected vector from voxel i to voxel j (for example, [dx, 0, 0]).
:param sigma: Optional noise standard deviation. If None, :attr:SceneGraphNoiseConfig.voxel_smoothness_sigma is used.
:return: Integer factor id of the created smoothness constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def add_voxel_smoothness(
    self,
    voxel_i_id: int,
    voxel_j_id: int,
    offset,
    sigma: float | None = None,
) -> int:
    """
    Enforce grid-like spacing between two voxel centers.

    :param voxel_i_id: Node id of the first voxel cell.
    :param voxel_j_id: Node id of the second voxel cell.
    :param offset: Iterable of length 3 giving the expected vector from
        voxel ``i`` to voxel ``j`` (for example, ``[dx, 0, 0]``).
    :param sigma: Optional noise standard deviation. If ``None``,
        :attr:`SceneGraphNoiseConfig.voxel_smoothness_sigma` is used.
    :return: Integer factor id of the created smoothness constraint.
    """
    offset = jnp.array(offset, dtype=jnp.float32).reshape(3,)

    if sigma is None:
        sigma = self.noise.voxel_smoothness_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "offset": offset,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="voxel_smoothness",
        var_ids=(voxel_i_id, voxel_j_id),
        params=params,
        relation="factor:voxel_smoothness",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("voxel_smoothness", (voxel_i_id, voxel_j_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "voxel_smoothness",
        (voxel_i_id, voxel_j_id),
        params,
    )
    return int(fid)
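The offset argument is the expected vector from voxel i to voxel j. On a regular grid, that vector is one voxel_size step along x, y or z. The sketch below enumerates each axis-aligned neighbor pair once, together with the offset one would pass to add_voxel_smoothness. The function name and the grid-index convention are hypothetical.

```python
from itertools import product
from typing import List, Tuple

Index = Tuple[int, int, int]
Offset = Tuple[float, float, float]


def grid_neighbor_pairs(nx: int, ny: int, nz: int, voxel_size: float) -> List[Tuple[Index, Index, Offset]]:
    """Return (i, j, offset) for every pair of axis-aligned neighbors in a regular grid.

    Each pair appears once, stepping in +x, +y or +z; offset is the expected
    vector from voxel i to voxel j (e.g. (voxel_size, 0, 0) for an x-neighbor).
    """
    pairs: List[Tuple[Index, Index, Offset]] = []
    for ix, iy, iz in product(range(nx), range(ny), range(nz)):
        for axis in range(3):
            step = [0, 0, 0]
            step[axis] = 1
            j = (ix + step[0], iy + step[1], iz + step[2])
            if j[0] < nx and j[1] < ny and j[2] < nz:  # skip steps that leave the grid
                offset = (voxel_size * step[0], voxel_size * step[1], voxel_size * step[2])
                pairs.append(((ix, iy, iz), j, offset))
    return pairs


pairs = grid_neighbor_pairs(2, 2, 1, voxel_size=0.1)
print(len(pairs))  # 4
print(pairs[0])    # ((0, 0, 0), (1, 0, 0), (0.1, 0.0, 0.0))
```

To use these pairs, map each grid index to the node id returned by add_voxel_cell, then pass the two ids and the offset to add_voxel_smoothness.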

attach_object_to_pose(pose_id, obj_id, offset=(0.0, 0.0, 0.0), sigma=None)

Attach an object to a pose with an optional 3D offset.

:param pose_id: Node id of the SE(3) pose variable.
:param obj_id: Node id of the 3D object variable.
:param offset: Iterable of length 3 giving the offset from the pose frame to the object in pose coordinates.
:param sigma: Optional noise standard deviation. If None, falls back to :attr:SceneGraphNoiseConfig.object_at_pose_sigma.
:return: Integer factor id of the created object-at-pose constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def attach_object_to_pose(
    self,
    pose_id: int,
    obj_id: int,
    offset=(0.0, 0.0, 0.0),
    sigma: float | None = None,
) -> int:
    """
    Attach an object to a pose with an optional 3D offset.

    :param pose_id: Node id of the SE(3) pose variable.
    :param obj_id: Node id of the 3D object variable.
    :param offset: Iterable of length 3 giving the offset from the pose
        frame to the object in pose coordinates.
    :param sigma: Optional noise standard deviation. If ``None``, falls
        back to :attr:`SceneGraphNoiseConfig.object_at_pose_sigma`.
    :return: Integer factor id of the created object-at-pose constraint.
    """
    pose_dim = jnp.array(6)
    obj_dim = jnp.array(3)
    offset_arr = jnp.array(offset, dtype=jnp.float32).reshape(3,)

    if sigma is None:
        sigma = self.noise.object_at_pose_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "pose_dim": pose_dim,
        "obj_dim": obj_dim,
        "offset": offset_arr,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="object_at_pose",
        var_ids=(pose_id, obj_id),
        params=params,
        relation="factor:object_at_pose",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("object_at_pose", (pose_id, obj_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "object_at_pose",
        (pose_id, obj_id),
        params,
    )
    return int(fid)
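The docstring says the offset is given "in pose coordinates". This suggests the predicted object position is the pose translation plus the offset rotated by the pose orientation, t + R offset. The object_at_pose residual itself is not shown, so that reading is an assumption. The sketch below also limits rotation to yaw about z and treats the first three pose components as the translation.

```python
import math
from typing import Sequence, Tuple


def object_world_position(pose: Sequence[float], offset: Sequence[float]) -> Tuple[float, float, float]:
    """Predict an object's world position from a yaw-only pose and a pose-frame offset.

    Computes t + R_z(yaw) @ offset, assuming pose = [tx, ty, tz, 0, 0, yaw].
    """
    tx, ty, tz, _, _, yaw = pose
    ox, oy, oz = offset
    c, s = math.cos(yaw), math.sin(yaw)
    return (tx + c * ox - s * oy, ty + s * ox + c * oy, tz + oz)


# A pose at the origin turned 90 degrees: an object 1 m ahead lies on the world +y axis.
obj = object_world_position([0.0, 0.0, 0.0, 0.0, 0.0, math.pi / 2], (1.0, 0.0, 0.0))
```

With the default offset of (0.0, 0.0, 0.0), the prediction reduces to the pose translation.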

attach_pose_to_place_x(pose_id, place_id)

Attach a pose to a 1D place along the x-coordinate.

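This is a low-level helper that assumes a 6D pose and a 1D place. It takes no sigma argument: the noise level is always :attr:SceneGraphNoiseConfig.pose_place_sigma. As a result, it is equivalent to calling :meth:add_place_attachment with coord_index=0 and the default sigma.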

:param pose_id: Node id of the SE(3) pose variable.
:param place_id: Node id of the 1D place variable.
:return: Integer factor id of the created attachment constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def attach_pose_to_place_x(self, pose_id: int, place_id: int) -> int:
    """
    Attach a pose to a 1D place along the x-coordinate.

    This is a low-level helper that assumes a 6D pose and 1D place.

    :param pose_id: Node id of the SE(3) pose variable.
    :param place_id: Node id of the 1D place variable.
    :return: Integer factor id of the created attachment constraint.
    """
    pose_dim = jnp.array(6)
    place_dim = jnp.array(1)
    pose_coord_index = jnp.array(0)

    sigma = self.noise.pose_place_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "pose_dim": pose_dim,
        "place_dim": place_dim,
        "pose_coord_index": pose_coord_index,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="pose_place_attachment",
        var_ids=(pose_id, place_id),
        params=params,
        relation="pose-place",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_place_attachment", (pose_id, place_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_place_attachment",
        (pose_id, place_id),
        params,
    )
    return int(fid)

attach_pose_to_room_x(pose_id, room_id)

Attach a pose to a 1D room along the x-coordinate.

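This is analogous to :meth:attach_pose_to_place_x but uses a room node instead of a place node. It reuses the pose_place_attachment factor type and the :attr:SceneGraphNoiseConfig.pose_place_sigma noise level.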

:param pose_id: Node id of the SE(3) pose variable.
:param room_id: Node id of the 1D room variable.
:return: Integer factor id of the created attachment constraint.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def attach_pose_to_room_x(self, pose_id: int, room_id: int) -> int:
    """
    Attach a pose to a 1D room along the x-coordinate.

    This is analogous to :meth:`attach_pose_to_place_x` but uses a room
    node instead of a place node.

    :param pose_id: Node id of the SE(3) pose variable.
    :param room_id: Node id of the 1D room variable.
    :return: Integer factor id of the created attachment constraint.
    """
    pose_dim = jnp.array(6)
    place_dim = jnp.array(1)
    pose_coord_index = jnp.array(0)

    sigma = self.noise.pose_place_sigma
    weight = sigma_to_weight(sigma)

    params = {
        "pose_dim": pose_dim,
        "place_dim": place_dim,
        "pose_coord_index": pose_coord_index,
        "weight": weight,
    }
    remembered = self._remember_factor(
        f_type="pose_place_attachment",
        var_ids=(pose_id, room_id),
        params=params,
        relation="pose-place",
    )
    if self._active_template_enabled:
        self._assign_factor_slot("pose_place_attachment", (pose_id, room_id), params, active=True)
        return int(remembered)

    fid = self.wm.add_factor(
        "pose_place_attachment",
        (pose_id, room_id),
        params,
    )
    return int(fid)

dump_state()

Return a snapshot of all variable values in the world.

:return: Dictionary mapping integer node ids to JAX arrays of values.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def dump_state(self) -> Dict[int, jnp.ndarray]:
    """
    Return a snapshot of all variable values in the world.

    :return: Dictionary mapping integer node ids to JAX arrays of values.
    """
    return {nid: state.value for nid, state in self._memory.items()}

enable_active_template(template)

Enable fixed-capacity active-template mode.

In this mode, SceneGraphWorld retains full persistent memory, but only a bounded active subset is mapped into the WorldModel slots. This enables a single stable JIT compilation and constant-latency solves.

:param template: ActiveWindowTemplate instance (from world.model).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def enable_active_template(self, template) -> None:
    """Enable fixed-capacity active-template mode.

    In this mode, SceneGraphWorld retains full persistent memory, but
    only a bounded active subset is mapped into the WorldModel slots.
    This enables a single stable JIT compilation and constant-latency solves.

    :param template: ActiveWindowTemplate instance (from world.model).
    """
    self.wm.init_active_template(template)
    self._active_template_enabled = True

    self._slot_capacity = {vs.var_type: int(vs.count) for vs in template.var_slots}
    self._factor_slot_capacity = {fs.factor_type: int(fs.count) for fs in template.factor_slots}

    # reset assignment state (SceneGraph memory remains intact)
    self._slot_assign.clear()
    self._slot_fifo.clear()
    self._factor_slot_next.clear()

get_object3d(obj_id)

Return the current 3D position of an object.

:param obj_id: Integer node id of the object variable.
:return: JAX array of shape (3,) giving the object position.
:raises KeyError: If no object with this id is stored in SceneGraph memory.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def get_object3d(self, obj_id: int) -> jnp.ndarray:
    """
    Return the current 3D position of an object.

    :param obj_id: Integer node id of the object variable.
    :return: JAX array of shape ``(3,)`` giving the object position.
    """
    oid = int(obj_id)
    if oid not in self._memory:
        raise KeyError(f"No object registered in SceneGraph memory for id={oid}")
    return self._memory[oid].value

get_place(place_id)

Return the current scalar value of a 1D place.

:param place_id: Integer node id of the place variable.
:return: Floating-point scalar position.
:raises KeyError: If no place with this id is stored in SceneGraph memory.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def get_place(self, place_id: int) -> float:
    """
    Return the current scalar value of a 1D place.

    :param place_id: Integer node id of the place variable.
    :return: Floating-point scalar position.
    """
    pid = int(place_id)
    if pid not in self._memory:
        raise KeyError(f"No place registered in SceneGraph memory for id={pid}")
    return float(self._memory[pid].value[0])

get_pose(pose_id)

Return the current SE(3) pose value.

:param pose_id: Integer node id of the pose variable.
:return: JAX array of shape (6,) containing the se(3) vector.
:raises KeyError: If no pose with this id is stored in SceneGraph memory.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def get_pose(self, pose_id: int) -> jnp.ndarray:
    """
    Return the current SE(3) pose value.

    :param pose_id: Integer node id of the pose variable.
    :return: JAX array of shape ``(6,)`` containing the se(3) vector.
    """
    pid = int(pose_id)
    if pid not in self._memory:
        raise KeyError(f"No pose registered in SceneGraph memory for id={pid}")
    return self._memory[pid].value

optimize(method='gn', iters=40)

Run nonlinear optimization over the current WorldModel factor graph. Depending on configuration, this is either the bounded active-template graph or the unbounded graph. Damping (1e-3) and the maximum step norm (0.5) are fixed by this method. Afterwards, optimized values are copied back into SceneGraph memory for every node that memory tracks.

:param method: Optimization method name (currently "gn" for Gauss–Newton).
:param iters: Maximum number of iterations to run.
:return: None. The internal world model state is updated in-place.

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def optimize(self, method: str = "gn", iters: int = 40) -> None:
    """
    Run nonlinear optimization over the current factor graph.
    This optimizes the current WorldModel factor graph (which may be bounded active-template or unbounded, depending on configuration).

    :param method: Optimization method name (currently ``"gn"`` for
        Gauss–Newton).
    :param iters: Maximum number of iterations to run.
    :return: ``None``. The internal world model state is updated in-place.
    """
    self.wm.optimize(method=method, iters=iters, damping=1e-3, max_step_norm=0.5)

    for nid, var in self.wm.fg.variables.items():
        nid_int = int(nid)
        if nid_int in self._memory:
            self._memory[nid_int].value = var.value

optimize_active_batch(iters=5, damping=0.001)

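Optimize only the currently active, bounded factor graph (active-template mode). Raises RuntimeError if enable_active_template has not been called. Optimized values are copied back into SceneGraph memory for the active slot variables.

:param iters: Maximum number of Gauss–Newton iterations.
:param damping: Damping value forwarded to the Gauss–Newton solver in WorldModel.optimize. This is a damping term, not a convergence tolerance.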

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def optimize_active_batch(self, iters: int = 5, damping: float = 1e-3) -> None:
    """Optimize only the currently active bounded FG (active-template mode).

    :param iters: Maximum number of Gauss–Newton iterations.
    :param damping: Damping value forwarded to the Gauss–Newton solver.
    """
    if not self._active_template_enabled:
        raise RuntimeError(
            "Active-template mode not enabled. Call enable_active_template(...)"
        )

    self.wm.optimize(
        method="gn",
        iters=int(iters),
        damping=float(damping),
        max_step_norm=0.5,
    )

    # Pull optimized values back into SG memory for active slot variables
    for nid, var in self.wm.fg.variables.items():
        nid_int = int(nid)
        if nid_int in self._memory:
            self._memory[nid_int].value = var.value
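Both optimize and optimize_active_batch forward a damping value and a fixed max_step_norm=0.5 to WorldModel.optimize, whose internals are not shown on this page. The pure-Python sketch below shows a single step on a toy two-variable problem. It assumes the usual Levenberg–Marquardt-style reading of these parameters: damping is added to the diagonal of the normal equations, and the step is rescaled when its L2 norm exceeds max_step_norm. The helpers solve_2x2 and damped_gn_step are illustrative only.

```python
import math


def solve_2x2(a, b):
    """Solve the 2x2 linear system a @ x = b with Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [
        (b[0] * a[1][1] - a[0][1] * b[1]) / det,
        (a[0][0] * b[1] - b[0] * a[1][0]) / det,
    ]


def damped_gn_step(jac, res, damping, max_step_norm):
    """One damped Gauss-Newton step for a problem with two state variables.

    Solves (J^T J + damping * I) dx = -J^T r, then rescales dx so that its
    L2 norm never exceeds max_step_norm.
    """
    jtj = [[sum(row[i] * row[j] for row in jac) for j in range(2)] for i in range(2)]
    jtr = [sum(row[i] * r for row, r in zip(jac, res)) for i in range(2)]
    for i in range(2):
        jtj[i][i] += damping  # damping on the diagonal (assumed LM-style)
    dx = solve_2x2(jtj, [-g for g in jtr])
    norm = math.hypot(dx[0], dx[1])
    if norm > max_step_norm:
        dx = [d * max_step_norm / norm for d in dx]
    return dx


# Toy problem: residual r(x) = x - target, so the Jacobian is the identity.
target = [3.0, 4.0]
jac = [[1.0, 0.0], [0.0, 1.0]]
x = [0.0, 0.0]
first_dx = None
for _ in range(40):
    res = [x[0] - target[0], x[1] - target[1]]
    dx = damped_gn_step(jac, res, damping=1e-3, max_step_norm=0.5)
    if first_dx is None:
        first_dx = dx
    x = [x[0] + dx[0], x[1] + dx[1]]

print(first_dx)  # approximately [0.3, 0.4]: the full step of length ~5 is clamped to 0.5
```

With the clamp, the first nine iterations each move exactly 0.5 toward the target. Only then does the nearly undamped Gauss–Newton step take over.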

optimize_global_offline(iters=40, damping=0.001)

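Full batch optimization over the entire persistent SceneGraph memory. This method builds a temporary WorldModel by replaying every stored variable and every active, non-semantic factor. Factors whose type starts with "semantic_" are skipped, as are factors that reference variables missing from memory. After optimization, results are mapped back onto the original node ids in memory; the live WorldModel is not modified.

:param iters: Maximum number of Gauss–Newton iterations.
:param damping: Damping value forwarded to the Gauss–Newton solver.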

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def optimize_global_offline(self, iters: int = 40, damping: float = 1e-3) -> None:
    """Full batch optimization over the entire persistent SceneGraph memory.

    :param iters: Maximum number of Gauss–Newton iterations.
    :param damping: Damping value forwarded to the Gauss–Newton solver.
    """
    tmp = WorldModel()

    # Register residuals from this SceneGraphWorld
    for k, fn in self.wm._residual_registry.items():
        tmp.register_residual(k, fn)

    # Replay variables and keep remap
    remap: Dict[int, int] = {}
    for nid, st in self._memory.items():
        new_id = int(tmp.add_variable(st.var_type, st.value))
        remap[int(nid)] = new_id

    # Replay factors (skip semantic-only and inactive)
    for rec in self._factor_memory.values():
        if not rec.active:
            continue
        if rec.f_type.startswith("semantic_"):
            continue

        mapped = tuple(remap[v] for v in rec.var_ids if v in remap)
        if len(mapped) != len(rec.var_ids):
            continue

        tmp.add_factor(rec.f_type, mapped, dict(rec.params))

    tmp.optimize(method="gn", iters=int(iters), damping=float(damping), max_step_norm=0.5)

    inv = {v: k for k, v in remap.items()}
    for nid, var in tmp.fg.variables.items():
        nid_int = int(nid)
        if nid_int in inv:
            orig = inv[nid_int]
            if orig in self._memory:
                self._memory[orig].value = var.value
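The replay logic above follows a simple pattern: remap persistent ids into a fresh id space, filter the factors, then invert the map to write results back. The following pure-Python sketch isolates that bookkeeping. FactorRecord and replay_factors are hypothetical names. The assumption that the temporary graph hands out ids 0, 1, 2, ... is for illustration only; the real code uses whatever tmp.add_variable returns.

```python
from dataclasses import dataclass


@dataclass
class FactorRecord:
    """Hypothetical stand-in for an entry in SceneGraphWorld._factor_memory."""

    f_type: str
    var_ids: tuple[int, ...]
    active: bool = True


def replay_factors(memory_ids, factor_records):
    """Rebuild a dense id space and the list of factors to replay.

    Mirrors the filtering in optimize_global_offline: inactive factors,
    semantic_* factors and factors touching unknown variables are skipped.
    """
    # Persistent ids -> fresh ids in the temporary graph (assumed 0, 1, 2, ...).
    remap = {nid: new_id for new_id, nid in enumerate(memory_ids)}
    replayed = []
    for rec in factor_records:
        if not rec.active or rec.f_type.startswith("semantic_"):
            continue
        mapped = tuple(remap[v] for v in rec.var_ids if v in remap)
        if len(mapped) != len(rec.var_ids):
            continue  # a referenced variable is missing from memory
        replayed.append((rec.f_type, mapped))
    inverse = {new_id: nid for nid, new_id in remap.items()}
    return remap, inverse, replayed


records = [
    FactorRecord("odom_se3", (100, 105)),
    FactorRecord("semantic_label", (105,)),
    FactorRecord("odom_se3", (105, 999)),  # 999 is not in memory
    FactorRecord("odom_se3", (100, 105), active=False),
]
remap, inverse, replayed = replay_factors([100, 105], records)
print(replayed)    # [('odom_se3', (0, 1))]
print(inverse[1])  # 105
```

Note that silently dropping factors with missing variables means a stale reference never aborts the global solve, but it can leave part of the graph less constrained than expected.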

visualize_web(host='127.0.0.1', port=8000, open_browser=True)

Launch a local Three.js-based 3D viewer for this SceneGraph.

:param host: Host address for the viewer server (default "127.0.0.1", i.e. localhost).
:param port: Port on which the web viewer is served (default 8000).
:param open_browser: Whether to open the viewer in a web browser (default True).

Source code in dsg-jit/dsg_jit/world/scene_graph.py
def visualize_web(
    self,
    host: str = "127.0.0.1",
    port: int = 8000,
    open_browser: bool = True,
) -> None:
    """Launch a local Three.js-based 3D viewer for this SceneGraph.

    :param host: Host address for the viewer server (localhost by default).
    :param port: Port on which the web viewer is served.
    :param open_browser: Whether to open the viewer in a web browser.
    """
    from dsg_jit.world.web_viewer import run_scenegraph_web_viewer

    run_scenegraph_web_viewer(self, host=host, port=port, open_browser=open_browser)

SceneNodeState(node_id, var_type, value) dataclass

Lightweight cache of a scene-graph node's latest value.

This decouples the persistent scene graph from the underlying optimization FactorGraph: even if a variable is marginalized or removed from the FactorGraph (for example, in a sliding-window setup), the SceneGraph can still serve its last optimized value.
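The decoupling described above can be illustrated with plain dictionaries. In this sketch, SceneNodeState mirrors the documented fields (node_id, var_type, value), but value is simplified to a tuple instead of a JAX array. The memory and factor_graph_vars dictionaries are hypothetical stand-ins for the SceneGraph memory and the optimizer's variable table.

```python
from dataclasses import dataclass


@dataclass
class SceneNodeState:
    """Mirror of the documented fields; value simplified to a tuple (an assumption)."""

    node_id: int
    var_type: str
    value: tuple[float, ...]


# Persistent scene-graph memory vs. a bounded optimization graph (both hypothetical dicts).
memory = {7: SceneNodeState(7, "place1d", (1.25,))}
factor_graph_vars = {7: (1.25,)}

# Sliding-window step: the variable is marginalized / removed from the optimizer...
del factor_graph_vars[7]

# ...but the scene graph still serves its last optimized value.
print(memory[7].value)  # (1.25,)
```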


world.voxel_grid

Voxel grid utilities for differentiable volumetric scene representations.

This module defines helpers for constructing voxel-level variables and their associated factors on top of the DSG-JIT world model.

Key responsibilities

  • Create voxel chains or grids:
    • 1D voxel chains (for smooth curves or “lines” in space).
    • Higher-dimensional voxel layouts (as needed by experiments).
  • Register and attach voxel-related factors:
    • Smoothness factors between neighboring voxels (using voxel_smoothness_residual).
    • Point-observation factors tying voxels to measurements in world coordinates (using voxel_point_observation_residual).
    • Optional voxel priors for regularization or supervision.
  • Provide convenience routines for:
    • Initializing voxel positions (e.g. along an axis).
    • Accessing the optimized voxel centers from the packed state.

Role in the DSG-JIT stack

Voxel grids are a key piece of the volumetric side of the engine. They allow us to:

• Represent surfaces or occupancy with a differentiable structure.
• Run Gauss–Newton over large chains / grids of voxels.
• Jointly optimize voxels with SE3 poses and other scene graph nodes
  (hybrid SE3 + voxel experiments and benchmarks).

Integration points

  • Uses world.model.WorldModel to create voxel variables and factors.
  • Relies on residuals defined in slam.measurements for:
    • smoothness,
    • point observations,
    • and priors.
  • Works seamlessly with optimization.solvers.gauss_newton_manifold and related JIT-compiled solvers.

Design goals

  • Scalable: Able to create hundreds or thousands of voxel nodes and factors that still admit fast, JIT-compiled optimization.
  • Composable: Plays nicely with SE3 poses, places, and other world entities in a single factor graph.
  • Experiment-oriented: Keeps the voxel construction boilerplate out of experiment scripts, making it easier to design new voxel-based learning tasks.

VoxelGridSpec(origin, dims, resolution) dataclass

Specification for constructing a regular voxel grid.

This lightweight container defines the spatial layout of a voxel grid, including its world-space origin, discrete grid dimensions, and the physical resolution of each voxel cell.

:param origin: A 3-element array giving the world-space center of voxel coordinate (0, 0, 0). This is the reference point from which all voxel centers are computed.
:param dims: A tuple (nx, ny, nz) giving the number of voxels along the x-, y- and z-axes, respectively.
:param resolution: The edge length of each voxel cell in world units. The spacing between voxel centers equals this resolution.

build_voxel_grid(sg, spec)

Construct a regular voxel grid inside the SceneGraphWorld.

This allocates one voxel_cell variable per grid coordinate (ix, iy, iz) using the voxel resolution and origin defined in spec. Each voxel is positioned at:

center = origin + [ix * res, iy * res, iz * res]

The resulting mapping enables downstream creation of voxel smoothness constraints and scene-graph integration.

:param sg: The active SceneGraphWorld instance where voxel nodes will be created. Must expose add_voxel_cell(center), which returns a node ID.
:param spec: Voxel grid specification containing:
    - spec.origin: 3D world origin of the grid.
    - spec.dims: Tuple (nx, ny, nz) specifying grid dimensions.
    - spec.resolution: Edge length of each voxel cell.
:return: A dictionary mapping each grid index (ix, iy, iz) to the corresponding voxel node ID allocated within the scene graph.

Source code in dsg-jit/dsg_jit/world/voxel_grid.py
def build_voxel_grid(
    sg: SceneGraphWorld,
    spec: VoxelGridSpec,
) -> Dict[GridIndex, int]:
    """
    Construct a regular voxel grid inside the SceneGraphWorld.

    This allocates one `voxel_cell` variable per grid coordinate `(ix, iy, iz)`
    using the voxel resolution and origin defined in `spec`. Each voxel is
    positioned at:

        center = origin + [ix * res, iy * res, iz * res]

    The resulting mapping enables downstream creation of voxel smoothness
    constraints and scene-graph integration.

    :param sg: The active `SceneGraphWorld` instance where voxel nodes will be
        created. Must expose `add_voxel_cell(center)` which returns a node ID.
    :param spec: Voxel grid specification containing:
        - `spec.origin`: 3D world origin of the grid.
        - `spec.dims`: Tuple `(nx, ny, nz)` specifying grid dimensions.
        - `spec.resolution`: Edge length of each voxel cell.
    :return: A dictionary mapping each grid index `(ix, iy, iz)` to the
        corresponding voxel node ID allocated within the scene graph.
    """
    origin = jnp.array(spec.origin, dtype=jnp.float32).reshape(3,)
    nx, ny, nz = spec.dims
    res = float(spec.resolution)

    index_to_id: Dict[GridIndex, int] = {}

    for ix in range(nx):
        for iy in range(ny):
            for iz in range(nz):
                offset = jnp.array([ix * res, iy * res, iz * res], dtype=jnp.float32)
                center = origin + offset
                nid = sg.add_voxel_cell(center)
                index_to_id[(ix, iy, iz)] = nid

    return index_to_id
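To sanity-check a VoxelGridSpec before allocating nodes, the center formula can be evaluated without a SceneGraphWorld. This pure-Python sketch reproduces the triple loop above but returns coordinates instead of node ids. voxel_centers is a hypothetical helper, not part of dsg_jit.

```python
from typing import Dict, Tuple

GridIndex = Tuple[int, int, int]


def voxel_centers(origin, dims, resolution) -> Dict[GridIndex, Tuple[float, float, float]]:
    """Hypothetical helper: center = origin + [ix * res, iy * res, iz * res] for every cell."""
    ox, oy, oz = origin
    nx, ny, nz = dims
    res = float(resolution)
    centers = {}
    for ix in range(nx):
        for iy in range(ny):
            for iz in range(nz):
                centers[(ix, iy, iz)] = (ox + ix * res, oy + iy * res, oz + iz * res)
    return centers


centers = voxel_centers(origin=(0.0, 0.0, 0.0), dims=(4, 2, 1), resolution=0.25)
print(len(centers))        # 8
print(centers[(3, 1, 0)])  # (0.75, 0.25, 0.0)
```

The node count is nx * ny * nz, so the number of variables (and of add_voxel_cell calls) grows cubically with the per-axis grid size.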

connect_grid_neighbors_1d_x(sg, index_to_id, spec, sigma=None)

Connect 3D voxel grid nodes along the +x direction using smoothness factors.

This function iterates over all voxel indices (ix, iy, iz) such that ix + 1 < nx, and adds a voxel smoothness constraint between each voxel and its +x neighbor. The enforced residual encourages:

voxel(ix+1, iy, iz) - voxel(ix, iy, iz) ≈ [resolution, 0, 0]

This is sufficient to enforce a 1D chain structure along the x-axis and is used when constructing structured voxel grids for optimization.

:param sg: The active SceneGraphWorld instance to which smoothness factors will be added. Must expose add_voxel_smoothness(i, j, offset, sigma).
:param index_to_id: Mapping from grid index (ix, iy, iz) to the corresponding node ID in the scene graph or factor graph.
:param spec: Voxel grid specification containing dimensions and voxel resolution. Expected to provide:
    - spec.dims: Tuple (nx, ny, nz) with the number of voxels.
    - spec.resolution: Voxel edge length in world units.
:param sigma: Optional noise standard deviation for the smoothness factor. If None, the default sigma inside sg.add_voxel_smoothness is used.
:return: None. This function mutates the scene graph world in-place by adding smoothness edges between neighboring x-axis voxels.

Source code in dsg-jit/dsg_jit/world/voxel_grid.py
def connect_grid_neighbors_1d_x(
    sg: SceneGraphWorld,
    index_to_id: Dict[GridIndex, int],
    spec: VoxelGridSpec,
    sigma: float | None = None,
) -> None:
    """
    Connect 3D voxel grid nodes along the +x direction using smoothness factors.

    This function iterates over all voxel indices `(ix, iy, iz)` such that
    `ix + 1 < nx`, and adds a voxel smoothness constraint between each voxel
    and its +x neighbor. The enforced residual encourages:

        voxel(ix+1, iy, iz) - voxel(ix, iy, iz) ≈ [resolution, 0, 0]

    This is sufficient to enforce a 1D chain structure along the x-axis
    and is used when constructing structured voxel grids for optimization.

    :param sg: The active `SceneGraphWorld` instance to which smoothness
        factors will be added. Must expose `add_voxel_smoothness(i, j, offset, sigma)`.
    :param index_to_id: Mapping from grid index `(ix, iy, iz)` to the corresponding
        node ID in the scene graph or factor graph.
    :param spec: Voxel grid specification containing dimensions and voxel resolution.
        Expected to provide:
            - `spec.dims`: Tuple `(nx, ny, nz)` with number of voxels.
            - `spec.resolution`: Voxel edge length in world units.
    :param sigma: Optional noise standard deviation for the smoothness factor.
        If `None`, the default sigma inside `sg.add_voxel_smoothness` is used.
    :return: None. This function mutates the scene graph world in-place by
        adding smoothness edges between neighboring x-axis voxels.
    """
    nx, ny, nz = spec.dims
    res = float(spec.resolution)

    offset = jnp.array([res, 0.0, 0.0], dtype=jnp.float32)

    for ix in range(nx - 1):
        for iy in range(ny):
            for iz in range(nz):
                vid_i = index_to_id[(ix, iy, iz)]
                vid_j = index_to_id[(ix + 1, iy, iz)]
                sg.add_voxel_smoothness(vid_i, vid_j, offset, sigma=sigma)
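The sketch below enumerates the edges this function creates and evaluates the relation it enforces. It is written in pure Python with hypothetical helpers. smoothness_residual assumes the unweighted form (v_j - v_i) - offset, which matches the "≈ [resolution, 0, 0]" relation stated above. The library's weighting by sigma is omitted.

```python
from typing import Dict, List, Tuple

GridIndex = Tuple[int, int, int]


def x_neighbor_pairs(index_to_id: Dict[GridIndex, int], dims) -> List[Tuple[int, int]]:
    """Hypothetical helper: the (i, j) node-id pairs that connect_grid_neighbors_1d_x links."""
    nx, ny, nz = dims
    return [
        (index_to_id[(ix, iy, iz)], index_to_id[(ix + 1, iy, iz)])
        for ix in range(nx - 1)
        for iy in range(ny)
        for iz in range(nz)
    ]


def smoothness_residual(vi, vj, offset):
    """Assumed, unweighted smoothness residual: (v_j - v_i) - offset."""
    return [b - a - o for a, b, o in zip(vi, vj, offset)]


dims = (3, 2, 1)
ids = {(ix, iy, iz): 10 * ix + iy for ix in range(3) for iy in range(2) for iz in range(1)}
pairs = x_neighbor_pairs(ids, dims)
print(len(pairs))  # 4 == (3 - 1) * 2 * 1
print(smoothness_residual((0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (0.5, 0.0, 0.0)))  # [0.0, 0.0, 0.0]
```

Only +x neighbors are linked, giving (nx - 1) * ny * nz factors. Voxels in different (iy, iz) rows remain independent chains unless other factors connect them.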

world.training

High-level training utilities for differentiable scene graph experiments.

This module provides a small training harness that sits on top of:

• `world.model.WorldModel` / `world.scene_graph.SceneGraphWorld`
• JAX-based optimizers and Gauss–Newton solvers
• Residual functions from `slam.measurements`

Its main role is to support meta-learning and hyperparameter learning over the differentiable DSG-JIT engine. Examples include:

• Learning factor-type weights (e.g. odometry vs. observation).
• Learning measurement parameters (odom SE3 chains, voxel obs).
• Running outer-loop gradient descent over:
    - log-scale weights,
    - observation locations,
    - or other “theta” parameters that influence the inner solve.

Typical structure

A typical training loop implemented here follows this pattern:

1. Build a world / scene graph for a given scenario.
2. Build a residual function that depends both on:
       - the state x (poses, voxels, etc.), and
       - learnable parameters θ (e.g. measurements, log-scales).
3. Run an inner optimization (Gauss–Newton or gradient descent)
   to obtain x*(θ).
4. Compute a supervised loss L(x*(θ), target).
5. Differentiate L w.r.t. θ using JAX (`jax.grad` or `jax.value_and_grad`).
6. Update θ with an outer optimizer step.
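The six steps above can be traced end to end on a toy problem. In this pure-Python sketch, the scalar state, the two weighted residuals and the target value are all hypothetical. A central finite difference stands in for jax.grad so that the example runs without JAX, and this is not DSGTrainer's implementation. With w_k = exp(log_scale_k), a = 0 and b = 1, the inner solution is x* = w2² / (w1² + w2²). To reach the target 0.8, the outer loop should therefore drive the log-scale gap toward ln 2.

```python
import math


def inner_solve(log_scales, a=0.0, b=1.0, lr=0.1, iters=100):
    """Steps 2-3: gradient descent on 0.5 * (r1^2 + r2^2) for a scalar state x.

    Toy residuals (an assumption): r1 = w1 * (x - a), r2 = w2 * (x - b),
    with w_k = exp(log_scales[k]).
    """
    w1, w2 = (math.exp(s) for s in log_scales)
    x = 0.0
    for _ in range(iters):
        grad = w1 * w1 * (x - a) + w2 * w2 * (x - b)
        x -= lr * grad
    return x


def outer_loss(log_scales, target=0.8):
    """Step 4: supervised loss on the inner solution x*(theta)."""
    return 0.5 * (inner_solve(log_scales) - target) ** 2


def fd_grad(f, theta, eps=1e-4):
    """Step 5: central finite differences, standing in for jax.grad in this sketch."""
    grads = []
    for k in range(len(theta)):
        up = list(theta)
        up[k] += eps
        down = list(theta)
        down[k] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


theta = [0.0, 0.0]  # log-scales for (prior, observation)
for _ in range(200):
    g = fd_grad(outer_loss, theta)
    theta = [t - 2.0 * gk for t, gk in zip(theta, g)]  # step 6: outer GD update

x_star = inner_solve(theta)
print(round(x_star, 3))  # 0.8
```

Only the ratio of the weights affects x*, so the outer loop learns the gap theta[1] - theta[0] while the overall scale is left undetermined. This is a common property of learned factor-type weights.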

The DSGTrainer (or an equivalent helper) encapsulates this pattern, exposing step- or train_step-style methods that return both the new parameters and useful diagnostics (loss, gradient norms, etc.).

Design goals

  • Keep experiments small: Training logic lives here so individual experiments can focus on constructing the world and defining the supervision signal.
  • JAX-first design: Training functions are written to be JIT-able and differentiable, allowing seamless scaling from toy experiments to larger graphs.
  • Research-friendly: The code is intentionally lightweight and easy to modify for new research ideas around learnable costs, priors, and structure.

DSGTrainer(wm, factor_type_order, inner_cfg) dataclass

High-level trainer for differentiable DSG experiments.

This class encapsulates a simple bi-level optimization pattern where: an inner loop solves for the scene graph state x, and an outer loop optimizes meta-parameters such as factor-type weights.

:param wm: World model containing the factor graph and scene graph.
:param factor_type_order: Ordered list of factor type names; each entry corresponds to a log-scale entry in the weight vector.
:param inner_cfg: Configuration for the inner gradient-descent solver applied to the state.

__post_init__()

Post-initialization hook.

This method caches the underlying factor graph from the world model and builds a residual function that accepts per-factor-type log-scales.

Source code in dsg-jit/dsg_jit/world/training.py
def __post_init__(self):
    """
    Post-initialization hook.

    This method caches the underlying factor graph from the world model
    and builds a residual function that accepts per-factor-type log-scales.
    """
    self.fg: FactorGraph = self.wm.fg
    self.residual_w = self.fg.build_residual_function_with_type_weights(
        self.factor_type_order
    )

solve_state(log_scales)

Run the inner optimization to solve for the state vector.

Given a vector of log-scales for factor types, this method performs explicit gradient descent on the objective

0.5 * || r(x, log_scales) ||^2,

where r is the weighted residual function built from the factor graph.

:param log_scales: Array of shape (T,) containing per-factor-type log-scale weights.
:return: Optimized flat state vector x after the inner GD loop.

Source code in dsg-jit/dsg_jit/world/training.py
def solve_state(self, log_scales: jnp.ndarray) -> jnp.ndarray:
    """
    Run the inner optimization to solve for the state vector.

    Given a vector of log-scales for factor types, this method performs
    explicit gradient descent on the objective

        0.5 * || r(x, log_scales) ||^2,

    where r is the weighted residual function built from the factor graph.

    :param log_scales: Array of shape ``(T,)`` containing per-factor-type log-scale weights.
    :return: Optimized flat state vector ``x`` after the inner GD loop.
    """
    x0, _ = self.wm.pack_state()

    def loss_x(x, log_scales):
        r = self.residual_w(x, log_scales)
        return 0.5 * jnp.sum(r * r)

    grad_loss_x = jax.grad(loss_x)

    # plain Python loop (no jax.lax.while_loop, easier to debug)
    x = x0
    for _ in range(self.inner_cfg.max_iters):
        g = grad_loss_x(x, log_scales)

        # gradient may be large; clamp the step
        step = -self.inner_cfg.learning_rate * g
        step_norm = jnp.linalg.norm(step)
        max_norm = self.inner_cfg.max_step_norm

        def clamp_step(step, step_norm, max_norm):
            scale = max_norm / (step_norm + 1e-8)
            return step * scale

        step = jax.lax.cond(
            step_norm > max_norm,
            lambda _: clamp_step(step, step_norm, max_norm),
            lambda _: step,
            operand=None,
        )

        x = x + step

    return x
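The jax.lax.cond in solve_state rescales a step only when its norm exceeds max_step_norm. Stripped of JAX, the same update rule looks like the pure-Python sketch below. clamped_gd and the quadratic objectives are illustrative. InnerGDConfig mirrors the documented fields and defaults. One consequence of clamping is that, over max_iters steps, the state can move at most max_iters * max_step_norm in L2 norm. A far-off initialization may therefore not converge within the default 40 iterations.

```python
import math
from dataclasses import dataclass


@dataclass
class InnerGDConfig:
    """Same fields and defaults as the documented InnerGDConfig."""

    learning_rate: float = 0.01
    max_iters: int = 40
    max_step_norm: float = 1.0


def clamped_gd(grad_fn, x0, cfg):
    """Gradient descent whose per-step L2 norm is capped at cfg.max_step_norm."""
    x = list(x0)
    for _ in range(cfg.max_iters):
        step = [-cfg.learning_rate * g for g in grad_fn(x)]
        norm = math.sqrt(sum(s * s for s in step))
        if norm > cfg.max_step_norm:
            # Same rescaling as solve_state: step * max_norm / (norm + 1e-8)
            step = [s * cfg.max_step_norm / (norm + 1e-8) for s in step]
        x = [xi + si for xi, si in zip(x, step)]
    return x


cfg = InnerGDConfig()
# Quadratic 0.5 * 50 * ||x - c||^2: gradient 50 * (x - c).
near = clamped_gd(lambda x: [50.0 * (x[0] - 3.0), 50.0 * (x[1] - 4.0)], [0.0, 0.0], cfg)
far = clamped_gd(lambda x: [50.0 * (x[0] - 30.0), 50.0 * (x[1] - 40.0)], [0.0, 0.0], cfg)
print(near)                        # close to [3.0, 4.0]
print(math.hypot(far[0], far[1]))  # just under 40.0 = max_iters * max_step_norm
```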

unpack_state(x)

Unpack a flat state vector into a NodeId-keyed dictionary.

This is a thin wrapper around the factor graph's unpack_state that uses the index structure implied by the current world model.

:param x: Flat state vector to be unpacked.
:return: Mapping from NodeId to the corresponding slice of x as a JAX array.

Source code in dsg-jit/dsg_jit/world/training.py
def unpack_state(self, x: jnp.ndarray):
    """
    Unpack a flat state vector into a NodeId-keyed dictionary.

    This is a thin wrapper around the factor graph's ``unpack_state``
    that uses the index structure implied by the current world model.

    :param x: Flat state vector to be unpacked.
    :return: Mapping from ``NodeId`` to the corresponding slice of ``x`` as a JAX array.
    """
    _, index = self.wm.pack_state()
    return self.wm.unpack_state(x, index)
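unpack_state calls pack_state only to obtain the index, so x should come from the same variable layout that the world model currently has. The following minimal sketch shows the flat-vector layout. It assumes that the index maps each node id to a (start, dim) pair, which is one of the index forms mentioned for plot_dynamic_trajectories_3d. Both functions are hypothetical pure-Python stand-ins, and the real WorldModel.pack_state may order and type things differently.

```python
from typing import Dict, List, Sequence, Tuple


def pack_state(values: Dict[int, Sequence[float]]) -> Tuple[List[float], Dict[int, Tuple[int, int]]]:
    """Concatenate per-node values into one flat vector and record (start, dim) per node."""
    x: List[float] = []
    index: Dict[int, Tuple[int, int]] = {}
    for nid in sorted(values):  # deterministic ordering by node id (an assumption)
        start = len(x)
        x.extend(values[nid])
        index[nid] = (start, len(values[nid]))
    return x, index


def unpack_state(x: Sequence[float], index: Dict[int, Tuple[int, int]]) -> Dict[int, List[float]]:
    """Inverse of pack_state: slice the flat vector back into per-node values."""
    return {nid: list(x[start:start + dim]) for nid, (start, dim) in index.items()}


values = {0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 1: [1.5]}
x, index = pack_state(values)
print(index)                              # {0: (0, 6), 1: (6, 1)}
print(unpack_state(x, index) == values)   # True
```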

InnerGDConfig(learning_rate=0.01, max_iters=40, max_step_norm=1.0) dataclass

Configuration for the inner gradient-descent solver.

:param learning_rate: Step size used for each inner GD update on the state.
:param max_iters: Maximum number of inner GD iterations.
:param max_step_norm: Maximum allowed L2 norm of a single GD step; larger updates are rescaled to this norm for numerical stability.

world.visualization

Visualization utilities for DSG-JIT.

This module provides lightweight 2D and 3D rendering tools for visualizing factor graphs, scene graphs, and mixed-level semantic structures. It is designed to support both debugging and demonstration of DSG-JIT’s hierarchical representations, including robot poses, voxel cells, places, rooms, and arbitrary semantic objects.

The visualization pipeline follows three main steps:

  1. Exporting graph data
    export_factor_graph_for_vis() converts an internal FactorGraph into color-coded VisNode and VisEdge lists. Variable types such as pose_se3, voxel_cell, place1d, and room1d are mapped to coarse visualization categories, and heuristic 3D positions are extracted for rendering.

  2. 2D top-down rendering
    plot_factor_graph_2d() produces a Matplotlib top-down view (x–y plane) with automatically computed bounds, node type coloring, and optional label rendering. This is especially useful for SE(3) SLAM chains, grid-based voxel fields, and planar semantic graphs.

  3. Full 3D scene graph rendering
    plot_factor_graph_3d() draws a complete 3D view of poses, voxels, places, rooms, and objects. Edges between nodes represent geometric or semantic relationships. Aspect ratios are normalized so spatial structure remains visually meaningful regardless of scale.

These visualizers are intentionally decoupled from the high-level world model (SceneGraphWorld) so they can be used directly on raw factor graphs produced by optimization procedures or experiment scripts.

Example usage is provided in:
  - experiments/exp17_visual_factor_graph.py (basic 2D + 3D factor graph)
  - experiments/exp18_scenegraph_3d.py (HYDRA-style multi-level scene graph)
  - experiments/exp18_scenegraph_demo.py (HYDRA-style 2D + 3D scene graph)
  - experiments/exp19_dynamic_scene_graph_demo.py (dynamic agent trajectories)

Module contents
  • VisNode: Lightweight typed node container for visualization.
  • VisEdge: Lightweight edge container (factor connections).
  • _infer_node_type(): Maps variable types → canonical visualization types.
  • _extract_position(): Extracts a 3D coordinate from variable states.
  • export_factor_graph_for_vis(): Converts a FactorGraph → vis nodes & edges.
  • plot_factor_graph_2d(): Renders a 2D top-down view of the graph.
  • plot_factor_graph_3d(): Renders a full 3D scene graph with semantic layers.
  • plot_scenegraph_3d(): Renders a scene graph with semantic layers and (optionally) agent trajectories.
  • plot_dynamic_trajectories_3d(): Renders 3D agent trajectories with time-encoded color.

This module is designed to be extendable. For example:
  - Additional node types can be added via _infer_node_type.
  - SceneGraphWorld can later provide richer semantic annotations.
  - Future versions may support interactive or WebGL visualizations.

VisEdge(var_ids, factor_type) dataclass

Lightweight edge representation for visualization.

VisNode(id, type, position, label) dataclass

Lightweight node representation for visualization.

export_factor_graph_for_vis(fg)

Export a FactorGraph into a visualization-friendly node/edge list.

This does not require any SceneGraphWorld; it just uses variables/factors.

:param fg: The factor graph to visualize.
:return: A tuple (nodes, edges), where nodes is a list of VisNode and edges is a list of VisEdge.

Source code in dsg-jit/dsg_jit/world/visualization.py
def export_factor_graph_for_vis(fg: FactorGraph) -> Tuple[List[VisNode], List[VisEdge]]:
    """
    Export a FactorGraph into a visualization-friendly node/edge list.

    This does *not* require any SceneGraphWorld; it just uses variables/factors.

    :param fg: The factor graph to visualize.
    :return: (nodes, edges) where nodes is a list of VisNode and edges is a list of VisEdge.
    """
    nodes: List[VisNode] = []
    edges: List[VisEdge] = []

    # Nodes
    for nid, var in fg.variables.items():
        ntype = _infer_node_type(var.type)
        pos = _extract_position(var.type, var.value)
        nodes.append(
            VisNode(
                id=nid,
                type=ntype,
                position=pos,
                label=f"{ntype}:{int(nid)}",
            )
        )

    # Edges (one edge per factor, between all its variables)
    for f in fg.factors.values():
        edges.append(VisEdge(var_ids=tuple(f.var_ids), factor_type=f.type))

    return nodes, edges
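Because the export works on raw variables and factors, its output can be reproduced without a FactorGraph. In this pure-Python sketch, VisNode and VisEdge mirror the documented dataclass fields. NODE_TYPES and the pad-to-3D position rule are hypothetical stand-ins for _infer_node_type and _extract_position, whose real tables and heuristics are not shown here. The sketch also assumes that the first three entries of a pose vector are its translation.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class VisNode:
    id: int
    type: str
    position: Tuple[float, float, float]
    label: str


@dataclass
class VisEdge:
    var_ids: Tuple[int, ...]
    factor_type: str


# Hypothetical type mapping; the real _infer_node_type table is not shown on this page.
NODE_TYPES = {"pose_se3": "pose", "voxel_cell": "voxel", "place1d": "place", "room1d": "room"}


def export_for_vis(
    variables: Dict[int, Tuple[str, Tuple[float, ...]]],
    factors: List[Tuple[str, Tuple[int, ...]]],
) -> Tuple[List[VisNode], List[VisEdge]]:
    nodes = []
    for nid, (vtype, value) in variables.items():
        ntype = NODE_TYPES.get(vtype, "other")
        # Assumed position heuristic: take the first 3 entries, zero-padding short values.
        pos = tuple((list(value) + [0.0, 0.0, 0.0])[:3])
        nodes.append(VisNode(nid, ntype, pos, f"{ntype}:{nid}"))
    # One edge per factor, spanning all of its variables.
    edges = [VisEdge(tuple(var_ids), ftype) for ftype, var_ids in factors]
    return nodes, edges


nodes, edges = export_for_vis(
    {0: ("pose_se3", (1.0, 2.0, 0.0, 0.0, 0.0, 0.0)), 1: ("place1d", (1.5,))},
    [("pose_place_attachment", (0, 1))],
)
print([n.label for n in nodes])  # ['pose:0', 'place:1']
print(nodes[1].position)         # (1.5, 0.0, 0.0)
```

Since each factor becomes exactly one VisEdge, a factor over three or more variables yields a single multi-node edge. The plotting functions must therefore decide how to draw it.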

plot_dynamic_trajectories_3d(dsg, x_opt, index, title='Dynamic 3D Scene Graph', color_by_time=True)

Render 3D agent trajectories with time encoded as color.

This helper is intended for DynamicSceneGraph-style structures where agents move through time. It treats time as an implicit fourth dimension and visualizes it via either a color gradient or a solid color per agent.

:param dsg: Dynamic scene graph object exposing an iterable agents attribute and a get_agent_trajectory(agent, x_opt, index) method that returns an array of shape (T, 6) or (T, 3). Only the translational components (x, y, z) are visualized.
:param x_opt: Optimized flat state vector used to decode agent poses.
:param index: Mapping from node identifier to a slice or (start, dim) pair describing how to extract each node’s state from x_opt. This is passed through to dsg.get_agent_trajectory.
:param title: Optional figure title for the 3D plot.
:param color_by_time: If True, encode time as a colormap gradient along each trajectory; if False, use a single solid color per agent.
:return: None. The function creates and displays a Matplotlib 3D figure.
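Encoding time as color requires two preparation steps: reduce each trajectory row to (x, y, z) and map step indices onto [0, 1] for a colormap. The sketch below covers both steps in pure Python. trajectory_xyz and time_colors are hypothetical helpers. The sketch also assumes that the first three entries of a 6-component row are the translation, which this page does not state.

```python
from typing import List, Sequence, Tuple


def trajectory_xyz(poses: Sequence[Sequence[float]]) -> List[Tuple[float, float, float]]:
    """Keep only (x, y, z); assumes the first three entries of each row are translation."""
    if any(len(p) not in (3, 6) for p in poses):
        raise ValueError("each pose must have 3 or 6 components")
    return [(p[0], p[1], p[2]) for p in poses]


def time_colors(num_steps: int) -> List[float]:
    """Normalized time in [0, 1], suitable as input to a Matplotlib colormap."""
    if num_steps == 1:
        return [0.0]
    return [t / (num_steps - 1) for t in range(num_steps)]


poses = [
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.5, 0.0, 0.0, 0.0, 0.1],
    [2.0, 1.0, 0.0, 0.0, 0.0, 0.2],
]
print(trajectory_xyz(poses)[-1])  # (2.0, 1.0, 0.0)
print(time_colors(len(poses)))    # [0.0, 0.5, 1.0]
```

In a Matplotlib figure, these normalized values would typically be passed through a colormap such as viridis. With color_by_time=False, a single color per agent would be used instead.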

Source code in dsg-jit/dsg_jit/world/visualization.py
def plot_dynamic_trajectories_3d(
    dsg: Any,
    x_opt: Any,
    index: Dict[Any, Union[slice, tuple]],
    title: str = "Dynamic 3D Scene Graph",
    color_by_time: bool = True,
) -> None:
    """
    Render 3D agent trajectories with time encoded as color.

    This helper is intended for ``DynamicSceneGraph``-style structures where
    agents move through time. It treats time as an implicit fourth
    dimension and visualizes it via either a color gradient or a solid
    color per agent.

    :param dsg: Dynamic scene graph object exposing an iterable
        ``agents`` attribute and a
        ``get_agent_trajectory(agent, x_opt, index)`` method that returns
        an array of shape ``(T, 6)`` or ``(T, 3)``. Only the translational
        components ``(x, y, z)`` are visualized.
    :param x_opt: Optimized flat state vector used to decode agent poses.
    :param index: Mapping from node identifier to slice or ``(start, dim)``
        describing how to extract each node’s state from ``x_opt``. This is
        passed through to ``dsg.get_agent_trajectory``.
    :param title: Optional figure title for the 3D plot.
    :param color_by_time: If ``True``, encode time as a colormap gradient
        along each trajectory; if ``False``, use a single solid color per
        agent.
    :return: None. The function creates and displays a Matplotlib 3D figure.
    """
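    # Validate the input before creating a figure so a bad argument does not
    # leave an empty Matplotlib figure behind.
    if not hasattr(dsg, "agents"):
        raise ValueError("Dynamic scene graph object must expose an 'agents' attribute")

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.set_title(title)

    all_pts: List[np.ndarray] = []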

    # Use a base list of colors when not color-coding by time.
    base_colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["C0", "C1", "C2", "C3"])  # type: ignore[index]

    for i, agent in enumerate(dsg.agents):
        traj = dsg.get_agent_trajectory(agent, x_opt, index)
        traj = np.asarray(traj)
        if traj.ndim != 2 or traj.shape[1] < 3:
            continue

        xyz = traj[:, :3]
        all_pts.append(xyz)

        if color_by_time and xyz.shape[0] > 1:
            # Encode time as a gradient along the trajectory
            t = np.linspace(0.0, 1.0, xyz.shape[0])
            cmap = plt.get_cmap("viridis")
            for j in range(xyz.shape[0] - 1):
                c = cmap(t[j])
                ax.plot(
                    xyz[j : j + 2, 0],
                    xyz[j : j + 2, 1],
                    xyz[j : j + 2, 2],
                    color=c,
                    linewidth=2.0,
                )
        else:
            color = base_colors[i % len(base_colors)]
            ax.plot(
                xyz[:, 0],
                xyz[:, 1],
                xyz[:, 2],
                linewidth=2.0,
                label=f"{agent}_traj",
                color=color,
            )

    # Autoscale axes to include all trajectories
    if all_pts:
        stacked = np.vstack(all_pts)
        mins = stacked.min(axis=0)
        maxs = stacked.max(axis=0)
        center = 0.5 * (mins + maxs)
        extent = float((maxs - mins).max())
        if extent <= 0.0:
            extent = 1.0
        scale = 0.6 * extent
        ax.set_xlim(center[0] - scale, center[0] + scale)
        ax.set_ylim(center[1] - scale, center[1] + scale)
        ax.set_zlim(center[2] - scale, center[2] + scale)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    # Only show legend entries when not using per-segment colors.
    if not color_by_time:
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            uniq = {}
            for h, l in zip(handles, labels):
                if l and l not in uniq:
                    uniq[l] = h
            ax.legend(uniq.values(), uniq.keys(), loc="best")

    plt.tight_layout()
    plt.show()
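The autoscaling step at the end of plot_dynamic_trajectories_3d (plot_scenegraph_3d uses the same rule) fits a cube around every plotted point. The center is the midpoint of the bounding box, and each axis extends 0.6 times the largest box side on either side of it, with a fallback extent of 1.0 when all points coincide. A pure-Python sketch of that computation follows; cube_limits is a hypothetical name, not a DSG-JIT function.

```python
from typing import Iterable, List, Sequence, Tuple


def cube_limits(points: Iterable[Sequence[float]]) -> List[Tuple[float, float]]:
    """Per-axis (low, high) limits of a cube enclosing all 3D points.

    Mirrors the autoscale rule: center = bounding-box midpoint,
    half-width = 0.6 * largest side, with extent 1.0 for degenerate input.
    """
    pts = [tuple(float(c) for c in p[:3]) for p in points]
    if not pts:
        raise ValueError("need at least one point")
    mins = [min(p[k] for p in pts) for k in range(3)]
    maxs = [max(p[k] for p in pts) for k in range(3)]
    center = [0.5 * (lo + hi) for lo, hi in zip(mins, maxs)]
    extent = max(hi - lo for lo, hi in zip(mins, maxs))
    if extent <= 0.0:
        extent = 1.0  # all points coincide: fall back to a unit extent
    scale = 0.6 * extent
    return [(c - scale, c + scale) for c in center]


limits = cube_limits([(0.0, 0.0, 0.0), (2.0, 1.0, 0.0)])
# Every axis gets the same width (2 * 0.6 * 2.0 = 2.4), so the 3D view is not distorted.
```

Using one scale for all three axes is what keeps distances visually comparable; scaling each axis to its own range would stretch flat trajectories.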

plot_factor_graph_2d(fg, show_labels=True)

Simple top-down 2D visualization of the factor graph.

  • nodes colored by type
  • edges drawn between connected variable nodes (projected to x–y)
  • dynamic aspect ratio and bounds based on node extents

:param fg: The factor graph to visualize.
:param show_labels: Whether to draw node labels.
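The edge loop draws each factor as a chain through its variables in the order they appear in var_ids, not as a clique over all pairs. A factor over k variables therefore yields k - 1 segments, and unary factors such as priors draw nothing. A minimal sketch of that pairing rule (factor_segments is a hypothetical helper):

```python
from typing import Hashable, List, Sequence, Tuple


def factor_segments(var_ids: Sequence[Hashable]) -> List[Tuple[Hashable, Hashable]]:
    """Consecutive (a, b) pairs, mirroring how the 2D/3D plots draw a factor."""
    ids = list(var_ids)
    # zip(ids, ids[1:]) yields (ids[0], ids[1]), (ids[1], ids[2]), ...
    return list(zip(ids, ids[1:]))


print(factor_segments([0]))        # []  (unary prior)
print(factor_segments([0, 1]))     # [(0, 1)]  (odometry)
print(factor_segments([3, 4, 7]))  # [(3, 4), (4, 7)]  (no (3, 7) segment)
```

If a clique view were preferred, itertools.combinations(ids, 2) would produce every pair instead.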

Source code in dsg-jit/dsg_jit/world/visualization.py
def plot_factor_graph_2d(fg: FactorGraph, show_labels: bool = True) -> None:
    """
    Simple top-down 2D visualization of the factor graph.

    - nodes colored by type
    - edges drawn between connected variable nodes (projected to x–y)
    - dynamic aspect ratio and bounds based on node extents

    :param fg: The factor graph to visualize.
    :param show_labels: Whether to draw node labels.
    """
    nodes, edges = export_factor_graph_for_vis(fg)

    # color palette per node type
    type_to_color: Dict[NodeType, str] = {
        "pose": "C0",
        "voxel": "C1",
        "place": "C2",
        "room": "C3",
        "other": "C4",
    }

    # Build quick lookup for positions and types
    node_pos: Dict[NodeId, jnp.ndarray] = {n.id: n.position for n in nodes}
    node_type: Dict[NodeId, NodeType] = {n.id: n.type for n in nodes}

    fig, ax = plt.subplots()
    ax.set_aspect("equal")

    # Draw edges: each factor becomes a chain of segments between consecutive
    # variables in var_ids (not all pairs).
    for e in edges:
        var_ids = list(e.var_ids)
        if len(var_ids) < 2:
            continue
        for i in range(len(var_ids) - 1):
            ida = var_ids[i]
            idb = var_ids[i + 1]
            a = node_pos.get(ida)
            b = node_pos.get(idb)
            if a is None or b is None:
                continue

            kind = _classify_edge_kind(node_type.get(ida, "other"),
                                       node_type.get(idb, "other"))

            if kind == "room-place":
                color, ls, lw, alpha = "magenta", "-", 1.5, 0.6
            elif kind == "place-object":
                color, ls, lw, alpha = "magenta", ":", 1.2, 0.6
            elif kind == "pose-edge":
                color, ls, lw, alpha = "gray", "--", 0.8, 0.4
            else:
                color, ls, lw, alpha = "k", ":", 0.5, 0.2

            ax.plot(
                [float(a[0]), float(b[0])],
                [float(a[1]), float(b[1])],
                linewidth=lw,
                alpha=alpha,
                linestyle=ls,
                color=color,
            )

    # Draw nodes
    xs, ys = [], []
    for n in nodes:
        c = type_to_color.get(n.type, "k")
        x, y = float(n.position[0]), float(n.position[1])
        xs.append(x)
        ys.append(y)
        ax.scatter(x, y, s=25, c=c)
        if show_labels:
            ax.text(x + 0.05, y + 0.05, n.label, fontsize=6)

    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_title("DSG-JIT Factor Graph (2D / top-down)")

    # Dynamic bounds with equal aspect
    if xs and ys:
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)
        max_range = max(max_x - min_x, max_y - min_y) / 2.0
        if max_range < 1e-3:
            max_range = 1.0
        mid_x = 0.5 * (max_x + min_x)
        mid_y = 0.5 * (max_y + min_y)
        ax.set_xlim(mid_x - max_range * 1.1, mid_x + max_range * 1.1)
        ax.set_ylim(mid_y - max_range * 1.1, mid_y + max_range * 1.1)

    fig.tight_layout()
    plt.show()
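Both plot_factor_graph_2d and plot_factor_graph_3d repeat the same if/elif chain that maps an edge kind (as returned by the internal _classify_edge_kind helper, whose body is not shown on this page) to a line style. One possible refactor, sketched under the assumption that the kind strings are exactly the ones tested in the code, keeps the styles in a single lookup table:

```python
from typing import Dict, NamedTuple


class EdgeStyle(NamedTuple):
    color: str
    linestyle: str
    linewidth: float
    alpha: float


# Values copied from the if/elif chain in plot_factor_graph_2d / plot_factor_graph_3d.
EDGE_STYLES: Dict[str, EdgeStyle] = {
    "room-place": EdgeStyle("magenta", "-", 1.5, 0.6),
    "place-object": EdgeStyle("magenta", ":", 1.2, 0.6),
    "pose-edge": EdgeStyle("gray", "--", 0.8, 0.4),
}
# Matches the final else branch.
DEFAULT_EDGE_STYLE = EdgeStyle("k", ":", 0.5, 0.2)


def edge_style(kind: str) -> EdgeStyle:
    """Return the plotting style for an edge kind, defaulting like the else branch."""
    return EDGE_STYLES.get(kind, DEFAULT_EDGE_STYLE)


style = edge_style("pose-edge")
```

Because the field names match Matplotlib's Line2D keyword arguments, a call such as ax.plot(xs, ys, **style._asdict()) would apply the style directly, and both plotting functions could share the one table.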

plot_factor_graph_3d(fg, show_labels=True)

3D visualization of the factor graph.

  • Nodes plotted as (x, y, z)
  • Edges drawn as 3D line segments
  • Colors by node type

:param fg: The factor graph to visualize.
:param show_labels: Whether to draw node labels in 3D.
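Unlike the trajectory and scene-graph plots, the factor-graph plots size the view from half the largest axis spread, fall back to 1.0 when that half-range is below 1e-3, and then pad by 10%. A small sketch of the 3D variant (equal_aspect_limits is a hypothetical name):

```python
from typing import List, Sequence, Tuple


def equal_aspect_limits(
    xs: Sequence[float], ys: Sequence[float], zs: Sequence[float]
) -> List[Tuple[float, float]]:
    """Equal per-axis limits following the factor-graph plots' rule."""
    if not (xs and ys and zs):
        raise ValueError("need at least one node")
    axes = [xs, ys, zs]
    # Half of the largest spread across the three axes.
    max_range = max(max(a) - min(a) for a in axes) / 2.0
    if max_range < 1e-3:
        max_range = 1.0  # near-degenerate graph: use a fixed half-range
    limits = []
    for a in axes:
        mid = 0.5 * (max(a) + min(a))
        # 10% padding around the shared half-range.
        limits.append((mid - max_range * 1.1, mid + max_range * 1.1))
    return limits


lims = equal_aspect_limits([0.0, 4.0], [0.0, 1.0], [0.0, 0.0])
# Every axis spans 2 * 2.0 * 1.1 = 4.4 units, centered on its own midpoint.
```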

Source code in dsg-jit/dsg_jit/world/visualization.py
def plot_factor_graph_3d(fg: FactorGraph, show_labels: bool = True) -> None:
    """
    3D visualization of the factor graph.

    - Nodes plotted as (x, y, z)
    - Edges drawn as 3D line segments
    - Colors by node type

    :param fg: The factor graph to visualize.
    :param show_labels: Whether to draw node labels in 3D.
    """
    nodes, edges = export_factor_graph_for_vis(fg)

    type_to_color: Dict[NodeType, str] = {
        "pose": "C0",
        "voxel": "C1",
        "place": "C2",
        "room": "C3",
        "other": "C4",
    }

    node_pos: Dict[NodeId, jnp.ndarray] = {n.id: n.position for n in nodes}
    node_type: Dict[NodeId, NodeType] = {n.id: n.type for n in nodes}

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    # Draw edges
    for e in edges:
        var_ids = list(e.var_ids)
        if len(var_ids) < 2:
            continue
        for i in range(len(var_ids) - 1):
            ida = var_ids[i]
            idb = var_ids[i + 1]
            a = node_pos.get(ida)
            b = node_pos.get(idb)
            if a is None or b is None:
                continue

            kind = _classify_edge_kind(node_type.get(ida, "other"),
                                       node_type.get(idb, "other"))

            if kind == "room-place":
                color, ls, lw, alpha = "magenta", "-", 1.5, 0.6
            elif kind == "place-object":
                color, ls, lw, alpha = "magenta", ":", 1.2, 0.6
            elif kind == "pose-edge":
                color, ls, lw, alpha = "gray", "--", 0.8, 0.4
            else:
                color, ls, lw, alpha = "k", ":", 0.5, 0.2

            ax.plot(
                [float(a[0]), float(b[0])],
                [float(a[1]), float(b[1])],
                [float(a[2]), float(b[2])],
                linewidth=lw,
                alpha=alpha,
                linestyle=ls,
                color=color,
            )

    # Draw nodes
    xs, ys, zs = [], [], []
    for n in nodes:
        c = type_to_color.get(n.type, "k")
        x, y, z = map(float, n.position[:3])
        xs.append(x)
        ys.append(y)
        zs.append(z)
        ax.scatter(x, y, z, s=30, c=c)
        if show_labels:
            ax.text(x, y, z, n.label, fontsize=6)

    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.set_zlabel("z [m]")
    ax.set_title("DSG-JIT Factor Graph (3D)")

    # Make aspect ratio equal in 3D
    if xs and ys and zs:
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)
        min_z, max_z = min(zs), max(zs)
        max_range = max(max_x - min_x, max_y - min_y, max_z - min_z) / 2.0
        if max_range < 1e-3:
            max_range = 1.0
        mid_x = 0.5 * (max_x + min_x)
        mid_y = 0.5 * (max_y + min_y)
        mid_z = 0.5 * (max_z + min_z)
        ax.set_xlim(mid_x - max_range * 1.1, mid_x + max_range * 1.1)
        ax.set_ylim(mid_y - max_range * 1.1, mid_y + max_range * 1.1)
        ax.set_zlim(mid_z - max_range * 1.1, mid_z + max_range * 1.1)

    plt.show()

plot_scenegraph_3d(sg, x_opt=None, index=None, title='Scene Graph 3D', dsg=None)

Render a 3D scene graph with rooms, places, objects, place attachments, and optional agent trajectories.

This function supports two modes:

  • If sg exposes a _memory attribute (the SceneGraph memory layer introduced in SceneGraphWorld), node positions are read from this memory and x_opt and index are not used for them (both are still forwarded to dsg.get_agent_trajectory when dsg is given).
  • If no memory is present, node states are decoded from x_opt using index.

:param sg: Scene-graph world instance. It is expected to expose attributes such as rooms, places, objects, place_parents, object_parents, and place_attachments, following the conventions used by :class:SceneGraphWorld.
:param x_opt: (Optional) Optimized flat state vector (e.g. from :meth:WorldModel.pack_state), containing the current estimates of all node states. Not required if sg exposes a _memory layer.
:param index: (Optional) Mapping from node identifier to either a slice or (start, dim) tuple describing where that node’s state lives inside x_opt. Not required if sg exposes a _memory layer.
:param title: Optional figure title for the Matplotlib 3D axes.
:param dsg: Optional dynamic scene graph used to overlay agent trajectories. It should expose an iterable agents attribute and a get_agent_trajectory(agent, x_opt, index) method that returns an array of shape (T, 6) or (T, 3).
:return: None. The function creates and displays a Matplotlib 3D figure.
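In memory mode, a node's state is looked up under the raw node id first and then under int(nid); otherwise it is sliced out of x_opt through index. Vectors shorter than three entries are zero-padded, so a 1D place still appears on the x-axis. The sketch below reproduces that lookup with plain dicts and lists. resolve_position is hypothetical, and the memory is simplified to a Mapping of raw value lists, whereas the real memory entries are objects exposing a value attribute.

```python
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union


def resolve_position(
    nid: Any,
    memory: Optional[Mapping[Any, Sequence[float]]] = None,
    x_opt: Optional[Sequence[float]] = None,
    index: Optional[Dict[Any, Union[slice, Tuple[int, int]]]] = None,
) -> List[float]:
    """Return a 3D position for nid, mimicking plot_scenegraph_3d's two modes."""
    if memory is not None:
        value = memory.get(nid)
        if value is None:
            # Fallback: the same node may be stored under an integer key.
            try:
                value = memory.get(int(nid))
            except (TypeError, ValueError):
                value = None
        if value is None:
            raise KeyError(f"No state in memory for node id={nid!r}")
        vec = [float(v) for v in value]
    else:
        if x_opt is None or index is None:
            raise ValueError("x_opt and index are required without a memory layer")
        entry = index[nid]
        # Accept either a slice or a (start, dim) tuple.
        sl = entry if isinstance(entry, slice) else slice(entry[0], entry[0] + entry[1])
        vec = [float(v) for v in x_opt[sl]]
    # Zero-pad short states (e.g. 1D places) up to three coordinates.
    vec = vec + [0.0] * (3 - len(vec))
    return vec[:3]


memory = {7: [1.0, 2.0, 3.0], 8: [4.0]}
print(resolve_position("7", memory=memory))  # [1.0, 2.0, 3.0] (found via int("7"))
print(resolve_position(8, memory=memory))    # [4.0, 0.0, 0.0] (zero-padded)
print(resolve_position("p", x_opt=[0.0, 9.0, 8.0, 7.0], index={"p": (1, 3)}))  # [9.0, 8.0, 7.0]
```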

Source code in dsg-jit/dsg_jit/world/visualization.py
def plot_scenegraph_3d(
    sg: Any,
    x_opt: Any = None,
    index: Optional[Dict[Any, Union[slice, tuple]]] = None,
    title: str = "Scene Graph 3D",
    dsg: Optional[Any] = None,
) -> None:
    """
    Render a 3D scene graph with rooms, places, objects, place attachments,
    and optional agent trajectories.

    This function supports two modes:

    - If ``sg`` exposes a ``_memory`` attribute (the SceneGraph memory layer introduced in ``SceneGraphWorld``),
      node positions are read from this memory and ``x_opt`` and ``index`` are not used for them
      (both are still forwarded to ``dsg.get_agent_trajectory`` when ``dsg`` is given).
    - If no memory is present, node states are decoded from ``x_opt`` using ``index``.

    :param sg: Scene-graph world instance. It is expected to expose
        attributes such as ``rooms``, ``places``, ``objects``,
        ``place_parents``, ``object_parents``, and ``place_attachments``,
        following the conventions used by :class:`SceneGraphWorld`.
    :param x_opt: (Optional) Optimized flat state vector (e.g. from
        :meth:`WorldModel.pack_state`), containing the current estimates
        of all node states. Not required if ``sg`` exposes a ``_memory`` layer.
    :param index: (Optional) Mapping from node identifier to either a slice or
        ``(start, dim)`` tuple describing where that node’s state lives
        inside ``x_opt``. Not required if ``sg`` exposes a ``_memory`` layer.
    :param title: Optional figure title for the Matplotlib 3D axes.
    :param dsg: Optional dynamic scene graph used to overlay agent
        trajectories. It should expose an iterable ``agents`` attribute
        and a ``get_agent_trajectory(agent, x_opt, index)`` method that
        returns an array of shape ``(T, 6)`` or ``(T, 3)``.
    :return: None. The function creates and displays a Matplotlib 3D figure.
    """
    has_memory = hasattr(sg, "_memory")
    mem = getattr(sg, "_memory", None)

    def _partition_memory_by_type() -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """
        Derive rooms / places / objects mappings from the SceneGraphWorld
        memory layer when explicit dictionaries (sg.rooms, sg.places,
        sg.objects) are not available or are empty.

        This assumes each memory entry is a small dataclass-like object
        exposing ``node_id`` and ``var_type`` attributes, where
        ``var_type`` starts with e.g. ``"room"``, ``"place"``, or
        ``"object"`` / ``"voxel"``.
        """
        rooms_m: Dict[str, Any] = {}
        places_m: Dict[str, Any] = {}
        objects_m: Dict[str, Any] = {}
        if not has_memory or mem is None:
            return rooms_m, places_m, objects_m

        # Iterate over stored node states and group by var_type prefix.
        for state in getattr(mem, "values", lambda: [])():
            vt = getattr(state, "var_type", "")
            nid = getattr(state, "node_id", None)
            if nid is None:
                continue
            # Construct simple human-readable names when no explicit names exist.
            if vt.startswith("room"):
                key = f"room_{nid}"
                rooms_m[key] = nid
            elif vt.startswith("place"):
                key = f"place_{nid}"
                places_m[key] = nid
            elif vt.startswith("object") or vt.startswith("voxel"):
                key = f"obj_{nid}"
                objects_m[key] = nid
        return rooms_m, places_m, objects_m

    def _has_state(nid: Any) -> bool:
        """
        Check whether we have a stored state for the given node id.

        When using SceneGraphWorld memory, we support both integer keys
        and arbitrary NodeId-like keys by trying ``nid`` directly first
        and then falling back to ``int(nid)`` if conversion is possible.
        """
        if has_memory:
            # Try raw key as-is
            try:
                if nid in mem:
                    return True
            except TypeError:
                # Some key types may not support `in` with this nid
                pass
            # Fallback: try integer-cast key
            try:
                nid_int = int(nid)
            except (TypeError, ValueError):
                return False
            return nid_int in mem
        if index is None:
            return False
        return nid in index

    def _vec(nid: Any) -> np.ndarray:
        if has_memory:
            # Support both direct nid keys and integer-cast keys.
            state = None
            # Try raw nid first
            try:
                state = mem.get(nid)  # type: ignore[call-arg]
            except AttributeError:
                # If _memory is not a Mapping, fall back to direct indexing
                try:
                    state = mem[nid]  # type: ignore[index]
                except Exception:
                    state = None
            if state is None:
                # Fallback: try integer-cast key
                try:
                    nid_int = int(nid)
                except (TypeError, ValueError):
                    raise KeyError(f"No state in SceneGraph memory for node id={nid!r}")
                try:
                    state = mem.get(nid_int)  # type: ignore[call-arg]
                except AttributeError:
                    state = mem[nid_int]  # type: ignore[index]
            if state is None:
                raise KeyError(f"No state in SceneGraph memory for node id={nid!r}")
            v = np.asarray(state.value).reshape(-1)
            return v
        if index is None or x_opt is None:
            raise ValueError("x_opt and index must be provided when SceneGraph memory is not available")
        idx = index[nid]
        if isinstance(idx, slice):
            sl = idx
        else:
            start, length = idx
            sl = slice(start, start + length)
        v = np.asarray(x_opt[sl]).reshape(-1)
        return v

    # Safely grab scene-graph structures (with defaults if missing).
    rooms = getattr(sg, "rooms", {}) or {}
    places = getattr(sg, "places", {}) or {}
    objects = getattr(sg, "objects", {}) or {}

    # If we have a memory layer but some named dicts are missing or empty,
    # derive the missing layers from memory.
    if has_memory and mem is not None and not (rooms and places and objects):
        mem_rooms, mem_places, mem_objects = _partition_memory_by_type()
        # Only fill in a layer from memory when that layer is empty; this way,
        # user-provided names (if any) take precedence.
        if not rooms:
            rooms = mem_rooms
        if not places:
            places = mem_places
        if not objects:
            objects = mem_objects

    place_parents = getattr(sg, "place_parents", {}) or {}
    object_parents = getattr(sg, "object_parents", {}) or {}
    attachments = getattr(sg, "place_attachments", []) or []

    # -------------------------------------------------
    # Collect pose node ids (for rendering trajectories / agent poses).
    # We look in both the memory layer and the place-attachment edges.
    pose_ids: set[Any] = set()

    # From attachments: first element of each tuple is assumed to be a pose node id.
    for pose_nid, _ in attachments:
        pose_ids.add(pose_nid)

    # From memory: any node whose var_type starts with "pose" is treated as a pose.
    if has_memory and mem is not None:
        for state in getattr(mem, "values", lambda: [])():
            vt = getattr(state, "var_type", "")
            if vt.startswith("pose"):
                nid = getattr(state, "node_id", None)
                if nid is not None:
                    pose_ids.add(nid)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.set_title(title)

    all_pts = []

    # ------------------------------
    # Rooms: large semi-transparent markers
    # ------------------------------
    first_room = next(iter(rooms), None)
    for name, nid in rooms.items():
        if not _has_state(nid):
            continue
        p = _vec(nid)
        if p.shape[0] < 3:
            # If we only have 1D, pad to 3D for visualization
            p = np.pad(p, (0, 3 - p.shape[0]), mode="constant")
        all_pts.append(p[:3])
        label = "room" if name == first_room else ""
        ax.scatter(
            p[0],
            p[1],
            p[2],
            s=200,
            marker="s",
            alpha=0.3,
            edgecolor="k",
            label=label,
        )

    # ------------------------------
    # Places: medium spheres
    # ------------------------------
    first_place = next(iter(places), None)
    for name, nid in places.items():
        if not _has_state(nid):
            continue
        p = _vec(nid)
        if p.shape[0] < 3:
            p = np.pad(p, (0, 3 - p.shape[0]), mode="constant")
        all_pts.append(p[:3])
        label = "place" if name == first_place else ""
        ax.scatter(
            p[0],
            p[1],
            p[2],
            s=60,
            marker="o",
            alpha=0.8,
            label=label,
        )

    # ------------------------------
    # Objects: small pyramids/triangles
    # ------------------------------
    first_obj = next(iter(objects), None)
    for name, nid in objects.items():
        if not _has_state(nid):
            continue
        p = _vec(nid)
        if p.shape[0] < 3:
            p = np.pad(p, (0, 3 - p.shape[0]), mode="constant")
        all_pts.append(p[:3])
        label = "object" if name == first_obj else ""
        ax.scatter(
            p[0],
            p[1],
            p[2],
            s=40,
            marker="^",
            alpha=0.9,
            label=label,
        )

    # ------------------------------
    # Poses: agent pose nodes (small spheres)
    # ------------------------------
    first_pose = next(iter(pose_ids), None)
    for nid in pose_ids:
        if not _has_state(nid):
            continue
        p = _vec(nid)
        if p.shape[0] < 3:
            p = np.pad(p, (0, 3 - p.shape[0]), mode="constant")
        all_pts.append(p[:3])
        label = "pose" if nid == first_pose else ""
        ax.scatter(
            p[0],
            p[1],
            p[2],
            s=30,
            marker="o",
            alpha=1.0,
            label=label,
        )

    # ------------------------------
    # Hierarchical edges: room -> place, place -> object
    # ------------------------------
    for place_nid, room_nid in place_parents.items():
        if not (_has_state(place_nid) and _has_state(room_nid)):
            continue
        p = _vec(place_nid)
        r = _vec(room_nid)
        if p.shape[0] < 3:
            p = np.pad(p, (0, 3 - p.shape[0]), mode="constant")
        if r.shape[0] < 3:
            r = np.pad(r, (0, 3 - r.shape[0]), mode="constant")
        ax.plot(
            [p[0], r[0]],
            [p[1], r[1]],
            [p[2], r[2]],
            linestyle="-",
            linewidth=1.0,
            alpha=0.5,
        )

    for obj_nid, place_nid in object_parents.items():
        if not (_has_state(obj_nid) and _has_state(place_nid)):
            continue
        o = _vec(obj_nid)
        p = _vec(place_nid)
        if o.shape[0] < 3:
            o = np.pad(o, (0, 3 - o.shape[0]), mode="constant")
        if p.shape[0] < 3:
            p = np.pad(p, (0, 3 - p.shape[0]), mode="constant")
        ax.plot(
            [o[0], p[0]],
            [o[1], p[1]],
            [o[2], p[2]],
            linestyle="-",
            linewidth=1.0,
            alpha=0.5,
        )

    # ------------------------------
    # Place attachments: pose -> place (dashed)
    # ------------------------------
    for pose_nid, place_nid in attachments:
        if not (_has_state(pose_nid) and _has_state(place_nid)):
            continue
        pose = _vec(pose_nid)
        plc = _vec(place_nid)
        if pose.shape[0] < 3:
            pose = np.pad(pose, (0, 3 - pose.shape[0]), mode="constant")
        if plc.shape[0] < 3:
            plc = np.pad(plc, (0, 3 - plc.shape[0]), mode="constant")
        ax.plot(
            [pose[0], plc[0]],
            [pose[1], plc[1]],
            [pose[2], plc[2]],
            linestyle="--",
            linewidth=1.0,
            alpha=0.7,
        )

    # ------------------------------
    # Optional: agent trajectories from DynamicSceneGraph
    # ------------------------------
    if dsg is not None and hasattr(dsg, "agents"):
        for agent in dsg.agents:
            traj = dsg.get_agent_trajectory(agent, x_opt, index)
            traj = np.asarray(traj)
            if traj.ndim != 2 or traj.shape[1] < 3:
                continue
            xs, ys, zs = traj[:, 0], traj[:, 1], traj[:, 2]
            all_pts.extend(traj[:, :3])
            ax.plot(xs, ys, zs, linewidth=2.0, alpha=0.9, label=f"{agent}_traj")

    # ------------------------------
    # Autoscale axes to fit everything
    # ------------------------------
    if all_pts:
        all_pts_arr = np.vstack(all_pts)
        mins = all_pts_arr.min(axis=0)
        maxs = all_pts_arr.max(axis=0)
        center = 0.5 * (mins + maxs)
        extent = float((maxs - mins).max())
        if extent <= 0.0:
            extent = 1.0
        scale = 0.6 * extent
        ax.set_xlim(center[0] - scale, center[0] + scale)
        ax.set_ylim(center[1] - scale, center[1] + scale)
        ax.set_zlim(center[2] - scale, center[2] + scale)

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    # Deduplicate legend entries
    handles, labels = ax.get_legend_handles_labels()
    if labels:
        uniq = {}
        for h, l in zip(handles, labels):
            if l and l not in uniq:
                uniq[l] = h
        ax.legend(uniq.values(), uniq.keys(), loc="best")

    plt.tight_layout()
    plt.show()
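When sg has a memory layer but lacks named room, place, or object dictionaries, plot_scenegraph_3d synthesizes names from each stored state's var_type prefix: "room..." becomes room_<id>, "place..." becomes place_<id>, and both "object..." and "voxel..." become obj_<id>. Pose states are collected separately for the pose markers. A stand-alone sketch of that grouping, with a hypothetical NodeState tuple standing in for the memory entries (the var_type strings in the example are illustrative only):

```python
from typing import Dict, Iterable, NamedTuple, Set, Tuple


class NodeState(NamedTuple):
    """Hypothetical stand-in for a memory entry (only the fields used here)."""
    node_id: int
    var_type: str


def partition_by_type(
    states: Iterable[NodeState],
) -> Tuple[Dict[str, int], Dict[str, int], Dict[str, int], Set[int]]:
    rooms: Dict[str, int] = {}
    places: Dict[str, int] = {}
    objects: Dict[str, int] = {}
    poses: Set[int] = set()
    for s in states:
        vt = s.var_type
        if vt.startswith("room"):
            rooms[f"room_{s.node_id}"] = s.node_id
        elif vt.startswith("place"):
            places[f"place_{s.node_id}"] = s.node_id
        elif vt.startswith(("object", "voxel")):
            # Voxels are drawn with the object markers.
            objects[f"obj_{s.node_id}"] = s.node_id
        elif vt.startswith("pose"):
            poses.add(s.node_id)
    return rooms, places, objects, poses


states = [
    NodeState(0, "pose_se3"),
    NodeState(1, "place3d"),
    NodeState(2, "room"),
    NodeState(3, "voxel_cell"),
]
rooms, places, objects, poses = partition_by_type(states)
```

The surrounding code uses such derived names only where no user-provided names exist.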

world.dynamic_scene_graph

Dynamic scene-graph utilities built on top of :mod:world.scene_graph.

This module provides a lightweight wrapper around :class:world.scene_graph.SceneGraphWorld that makes dynamic (time-indexed) scene graphs easier to build and reason about.

The goal is to keep all of the optimization and factor-graph logic in the existing engine, while giving users a small, ergonomic API for working with trajectories and other time-varying entities.

Design goals

  • Don't duplicate state: the underlying :class:SceneGraphWorld and :class:WorldModel remain the single source of truth.
  • Time-aware helpers: convenience functions for adding agent trajectories, querying poses across time, and wiring odometry factors between consecutive poses.
  • Engine-friendly: everything ultimately calls into existing SceneGraphWorld methods, so this module is safe to ignore if you want to use the lower-level API directly.

Typical usage

.. code-block:: python

from world.scene_graph import SceneGraphWorld
from world.dynamic_scene_graph import DynamicSceneGraph
import jax.numpy as jnp

sg = SceneGraphWorld()
dsg = DynamicSceneGraph(sg)

agent = "robot0"

# Add a short trajectory
dsg.add_agent_pose(agent, t=0, pose_se3=jnp.zeros(6))
dsg.add_agent_pose(agent, t=1, pose_se3=jnp.array([1.0, 0, 0, 0, 0, 0]))

# Connect poses with odometry in the x-direction
dsg.add_odom_tx(agent, t0=0, t1=1, dx=1.0, weight=10.0)

# Later, after optimization, you can recover the optimized trajectory with
# dsg.get_agent_trajectory(...).

DynamicSceneGraph(world, agents=set()) dataclass

Helper for building dynamic (time-indexed) scene graphs.

This class is a thin façade over :class:world.scene_graph.SceneGraphWorld. It does not introduce new optimization logic or state; instead it organizes common patterns for working with agent trajectories and other dynamic structures.

Parameters

world: The underlying :class:SceneGraphWorld instance. All variables and factors are ultimately added to world.wm.
agents: Optional set of agent identifiers. You usually don't need to pass this explicitly; agents are registered lazily when you call :meth:add_agent or :meth:add_agent_pose.
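Python's dataclasses refuse a plain set() default, because one set would be shared by every instance, so a default factory is the standard way to get the per-instance empty agents set that the signature suggests. Whether the real class declares it exactly this way is not visible on this page; the following is only a sketch of the pattern. It also shows the delegation performed by add_agent_pose, with a hypothetical FakeWorld standing in for SceneGraphWorld.

```python
from dataclasses import dataclass, field
from typing import Any, Hashable, List, Set, Tuple


class FakeWorld:
    """Hypothetical stand-in for SceneGraphWorld that records pose insertions."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Hashable, int, Any]] = []

    def add_agent_pose_se3(self, agent_id: Hashable, t: int, pose_se3: Any) -> int:
        self.calls.append((agent_id, t, pose_se3))
        return len(self.calls) - 1  # pretend node id


@dataclass
class MiniDynamicSceneGraph:
    """Sketch of the façade pattern: no state beyond the agent registry."""

    world: Any
    # default_factory gives each instance its own empty set.
    agents: Set[Hashable] = field(default_factory=set)

    def add_agent(self, agent_id: Hashable) -> Hashable:
        self.agents.add(agent_id)
        return agent_id

    def add_agent_pose(self, agent_id: Hashable, t: int, pose_se3: Any) -> Any:
        self.agents.add(agent_id)  # agents are registered lazily
        return self.world.add_agent_pose_se3(agent_id, t, pose_se3)


dsg_a = MiniDynamicSceneGraph(FakeWorld())
dsg_b = MiniDynamicSceneGraph(FakeWorld())
nid = dsg_a.add_agent_pose("robot0", 0, [0.0] * 6)
```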

add_agent(agent_id)

Register an agent identifier.

This does not create any variables by itself; it simply tracks the identifier so you can discover which agents exist in the graph.

:param agent_id: Hashable identifier for the agent (for example, "robot0").
:type agent_id: Hashable
:return: The same agent_id that was passed in, for convenience.
:rtype: Hashable

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def add_agent(self, agent_id: Hashable) -> Hashable:
    """Register an agent identifier.

    This does not create any variables by itself; it simply tracks the
    identifier so you can discover which agents exist in the graph.

    :param agent_id: Hashable identifier for the agent (for example, ``"robot0"``).
    :type agent_id: Hashable
    :return: The same ``agent_id`` that was passed in, for convenience.
    :rtype: Hashable
    """

    self.agents.add(agent_id)
    return agent_id

add_agent_pose(agent_id, t, pose_se3)

Add an SE(3) pose variable for a given agent and time.

This delegates directly to :meth:SceneGraphWorld.add_agent_pose_se3 and records the agent identifier in :attr:agents.

:param agent_id: Identifier for the agent.
:type agent_id: Hashable
:param t: Discrete time index (for example, frame or step index).
:type t: int
:param pose_se3: 6D se(3) vector [tx, ty, tz, rx, ry, rz].
:type pose_se3: jax.numpy.ndarray
:return: The node identifier of the newly created pose variable.
:rtype: NodeId

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def add_agent_pose(self, agent_id: Hashable, t: int, pose_se3: jnp.ndarray) -> NodeId:
    """Add an SE(3) pose variable for a given agent and time.

    This delegates directly to :meth:`SceneGraphWorld.add_agent_pose_se3`
    and records the agent identifier in :attr:`agents`.

    :param agent_id: Identifier for the agent.
    :type agent_id: Hashable
    :param t: Discrete time index (for example, frame or step index).
    :type t: int
    :param pose_se3: 6D se(3) vector ``[tx, ty, tz, rx, ry, rz]``.
    :type pose_se3: jax.numpy.ndarray
    :return: The node identifier of the newly created pose variable.
    :rtype: NodeId
    """

    self.agents.add(agent_id)
    return self.world.add_agent_pose_se3(agent_id, t, pose_se3)
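The bookkeeping split between add_agent and add_agent_pose can be illustrated with a self-contained sketch. StubWorld and MiniDSG below are hypothetical stand-ins, not part of dsg_jit: the sketch assumes only that add_agent_pose_se3 allocates a fresh node ID and records it in a pose_trajectory dict keyed by (agent, t), which is how the other methods in this section read that dict. Real NodeId values and the variables stored in world.wm will differ.

```python
from dataclasses import dataclass, field
from typing import Dict, Hashable, List, Set, Tuple

TimeKey = Tuple[Hashable, int]


@dataclass
class StubWorld:
    """Hypothetical stand-in for SceneGraphWorld (illustration only)."""

    pose_trajectory: Dict[TimeKey, int] = field(default_factory=dict)
    _next_id: int = 0

    def add_agent_pose_se3(self, agent_id: Hashable, t: int, pose_se3: List[float]) -> int:
        # Allocate a new integer node ID and index it by (agent, t).
        nid = self._next_id
        self._next_id += 1
        self.pose_trajectory[(agent_id, t)] = nid
        return nid


@dataclass
class MiniDSG:
    """Mirrors the registration logic of DynamicSceneGraph."""

    world: StubWorld
    agents: Set[Hashable] = field(default_factory=set)

    def add_agent(self, agent_id: Hashable) -> Hashable:
        self.agents.add(agent_id)  # tracking only; no variables are created
        return agent_id

    def add_agent_pose(self, agent_id: Hashable, t: int, pose_se3: List[float]) -> int:
        self.agents.add(agent_id)  # lazy registration on first pose
        return self.world.add_agent_pose_se3(agent_id, t, pose_se3)


dsg = MiniDSG(world=StubWorld())
dsg.add_agent("robot1")  # registered, but owns no poses
n0 = dsg.add_agent_pose("robot0", 0, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
n1 = dsg.add_agent_pose("robot0", 1, [1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
print(sorted(dsg.agents), dsg.world.pose_trajectory)
# ['robot0', 'robot1'] {('robot0', 0): 0, ('robot0', 1): 1}
```

Note that "robot1" appears in agents even though no pose variable exists for it, so code that iterates over agents should be prepared for agents with empty timelines.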

add_agent_trajectory(agent_id, poses_se3, start_t=0, add_odom=True, default_dx=None, weight=1.0)

Add a contiguous trajectory for one agent and optionally wire odometry.

This is a convenience helper that repeatedly calls :meth:add_agent_pose and, if add_odom is True, :meth:add_odom_tx between consecutive time steps.

:param agent_id: Identifier for the agent.
:type agent_id: Hashable
:param poses_se3: Iterable of se(3) pose vectors. The first element is placed at t = start_t, the next at t = start_t + 1, and so on.
:type poses_se3: Iterable[jax.numpy.ndarray]
:param start_t: Time index to use for the first pose.
:type start_t: int
:param add_odom: If True, automatically connect consecutive poses with a 1D odometry factor along x via :meth:add_odom_tx.
:type add_odom: bool
:param default_dx: If not None, use this value as the expected displacement in x between each consecutive pair of poses. If None and add_odom is True, the displacement is inferred as poses_se3[k+1][0] - poses_se3[k][0].
:type default_dx: float | None
:param weight: Scalar weight used for each odometry factor when add_odom is enabled.
:type weight: float
:return: Node identifiers of all created pose variables, in temporal order.
:rtype: list[NodeId]

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def add_agent_trajectory(
    self,
    agent_id: Hashable,
    poses_se3: Iterable[jnp.ndarray],
    start_t: int = 0,
    add_odom: bool = True,
    default_dx: float | None = None,
    weight: float = 1.0,
) -> List[NodeId]:
    """Add a contiguous trajectory for one agent and optionally wire odometry.

    This is a convenience helper that repeatedly calls :meth:`add_agent_pose`
    and, if ``add_odom`` is ``True``, :meth:`add_odom_tx` between consecutive
    time steps.

    :param agent_id: Identifier for the agent.
    :type agent_id: Hashable
    :param poses_se3: Iterable of se(3) pose vectors. The first element is
        placed at ``t = start_t``, the next at ``t = start_t + 1``, and so on.
    :type poses_se3: Iterable[jax.numpy.ndarray]
    :param start_t: Time index to use for the first pose.
    :type start_t: int
    :param add_odom: If ``True``, automatically connect consecutive poses with
        a 1D odometry factor along ``x`` via :meth:`add_odom_tx`.
    :type add_odom: bool
    :param default_dx: If not ``None``, use this value as the expected
        displacement in ``x`` between each consecutive pair of poses. If
        ``None`` and ``add_odom`` is ``True``, the displacement is inferred as
        ``poses_se3[k+1][0] - poses_se3[k][0]``.
    :type default_dx: float | None
    :param weight: Scalar weight used for each odometry factor when
        ``add_odom`` is enabled.
    :type weight: float
    :return: Node identifiers of all created pose variables, in temporal order.
    :rtype: list[NodeId]
    """

    node_ids: List[NodeId] = []
    t = start_t
    prev_t: int | None = None
    prev_pose: jnp.ndarray | None = None

    for pose in poses_se3:
        nid = self.add_agent_pose(agent_id, t, pose)
        node_ids.append(nid)

        if add_odom and prev_t is not None:
            if default_dx is not None:
                dx = float(default_dx)
            else:
                # Infer displacement along x from the raw pose guesses
                dx = float(pose[0] - prev_pose[0])
            self.add_odom_tx(agent_id, prev_t, t, dx=dx, weight=weight)

        prev_t = t
        prev_pose = pose
        t += 1

    return node_ids
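The odometry wiring in add_agent_trajectory reduces to computing one dx per consecutive pair of poses. The pure-Python sketch below reproduces that rule outside the class. The helper name odom_dx_sequence is hypothetical, and plain lists stand in for jax.numpy arrays because only index 0 (tx) of each pose is read.

```python
from typing import List, Optional, Sequence, Tuple


def odom_dx_sequence(
    poses_se3: Sequence[Sequence[float]],
    start_t: int = 0,
    default_dx: Optional[float] = None,
) -> List[Tuple[int, int, float]]:
    """Return (t0, t1, dx) triples in the order add_odom_tx would be called."""
    edges: List[Tuple[int, int, float]] = []
    for k in range(1, len(poses_se3)):
        if default_dx is not None:
            dx = float(default_dx)
        else:
            # Infer the x displacement from the raw initial guesses.
            dx = float(poses_se3[k][0] - poses_se3[k - 1][0])
        edges.append((start_t + k - 1, start_t + k, dx))
    return edges


guesses = [
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [2.5, 0.0, 0.0, 0.0, 0.0, 0.0],
]
inferred = odom_dx_sequence(guesses, start_t=5)
fixed = odom_dx_sequence(guesses, start_t=5, default_dx=1.0)
print(inferred)  # [(5, 6, 1.0), (6, 7, 1.5)]
print(fixed)     # [(5, 6, 1.0), (6, 7, 1.0)]
```

Assuming the odometry residual vanishes when the x difference equals dx, inferring dx from the initial guesses makes every odometry factor exactly satisfied at initialization, so those factors only resist changes introduced by other terms. Passing default_dx instead imposes a fixed nominal step regardless of the guesses.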

add_odom_tx(agent_id, t0, t1, dx, weight=1.0)

Connect two consecutive poses with a 1D odometry factor in x.

This is a convenience wrapper around :meth:SceneGraphWorld.add_odom_se3_additive, which interprets dx as a translation along the x axis and assumes identity rotation.

:param agent_id: Agent identifier.
:type agent_id: Hashable
:param t0: Time index of the "from" pose.
:type t0: int
:param t1: Time index of the "to" pose.
:type t1: int
:param dx: Expected displacement in x from pose (agent_id, t0) to pose (agent_id, t1).
:type dx: float
:param weight: Scalar weight applied to the odometry residual.
:type weight: float
:return: None.
:rtype: None

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def add_odom_tx(
    self,
    agent_id: Hashable,
    t0: int,
    t1: int,
    dx: float,
    weight: float = 1.0,
) -> None:
    """Connect two consecutive poses with a 1D odometry factor in ``x``.

    This is a convenience wrapper around
    :meth:`SceneGraphWorld.add_odom_se3_additive`, which interprets ``dx`` as a
    translation along the ``x`` axis and assumes identity rotation.

    :param agent_id: Agent identifier.
    :type agent_id: Hashable
    :param t0: Time index of the *from* pose.
    :type t0: int
    :param t1: Time index of the *to* pose.
    :type t1: int
    :param dx: Expected displacement in ``x`` from pose ``(agent_id, t0)`` to
        pose ``(agent_id, t1)``.
    :type dx: float
    :param weight: Scalar weight applied to the odometry residual.
    :type weight: float
    :return: ``None``.
    :rtype: None
    """

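    pose_i = self.world.pose_trajectory[(agent_id, t0)]
    pose_j = self.world.pose_trajectory[(agent_id, t1)]
    # Assuming ``sigma`` is a standard deviation (as in add_range_obs), it is
    # the inverse of a weight: larger sigma = weaker factor. Convert so that a
    # larger ``weight`` tightens the constraint as documented.
    if weight <= 0.0:
        raise ValueError("weight must be positive")
    self.world.add_odom_se3_additive(pose_i, pose_j, dx=dx, sigma=1.0 / weight)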
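The distinction between a weight and a standard deviation matters for how strongly a factor pulls on the solution. The sketch below assumes a whitened 1D residual along x of the additive form ((x_j - x_i) - dx) / sigma, with sigma a standard deviation as in add_range_obs; the actual residual behind SceneGraphWorld.add_odom_se3_additive is not shown here and may differ. Both helper names are hypothetical. Under that assumption, a weight w corresponds to sigma = 1/w.

```python
def whitened_odom_x(x_i: float, x_j: float, dx: float, sigma: float) -> float:
    """Hypothetical whitened residual: ((x_j - x_i) - dx) / sigma."""
    if sigma <= 0.0:
        raise ValueError("sigma must be positive")
    return ((x_j - x_i) - dx) / sigma


def weight_to_sigma(weight: float) -> float:
    """Squared whitened residual is weight**2 * r**2, so sigma = 1 / weight."""
    if weight <= 0.0:
        raise ValueError("weight must be positive")
    return 1.0 / weight


# Poses 1.5 apart in x while odometry says 1.0: the raw error is 0.5.
stiff = whitened_odom_x(0.0, 1.5, 1.0, weight_to_sigma(10.0))  # error scaled up
loose = whitened_odom_x(0.0, 1.5, 1.0, 10.0)                   # error scaled down
print(stiff, loose)
```

Passing a weight straight through as sigma inverts this relationship: weight=10.0 would then shrink the residual by a factor of 10 and loosen the factor instead of tightening it.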

add_range_obs(agent, t, target_nid, measured_range, sigma=0.1)

Add a range measurement from an agent's pose at time t to a target node.

This wraps :meth:SceneGraphWorld.add_range_measurement, using the pose node from pose_trajectory[(agent, t)].

:param agent: Agent key, e.g. "robot0".
:param t: Integer time step.
:param target_nid: NodeId of the target (place3d, voxel_cell, object3d, etc.).
:param measured_range: Observed distance.
:param sigma: Optional measurement noise standard deviation.

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def add_range_obs(
    self,
    agent: str,
    t: int,
    target_nid: int,
    measured_range: float,
    sigma: float | None = 0.1,
) -> None:
    """
    Add a range measurement from an agent's pose at time t to a target node.

    This wraps :meth:`SceneGraphWorld.add_range_measurement`, using the
    pose node from ``pose_trajectory[(agent, t)]``.

    :param agent: Agent key, e.g. ``"robot0"``.
    :param t: Integer time step.
    :param target_nid: NodeId of the target (place3d, voxel_cell, object3d, etc.).
    :param measured_range: Observed distance.
    :param sigma: Optional measurement noise standard deviation.
    """
    pose_nid = self.world.pose_trajectory[(agent, t)]
    self.world.add_range_measurement(
        pose_nid=pose_nid,
        target_nid=target_nid,
        measured_range=measured_range,
        sigma=sigma,
    )
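A range factor typically compares the Euclidean distance between the pose's translation and the target's position with the measured range. The sketch below assumes that form and assumes the target variable's first three entries are its 3D position; both are assumptions about SceneGraphWorld.add_range_measurement, whose residual is not shown in this section, and range_residual is a hypothetical name.

```python
import math
from typing import Sequence


def range_residual(
    pose_se3: Sequence[float],
    target_pos: Sequence[float],
    measured_range: float,
    sigma: float = 0.1,
) -> float:
    """Hypothetical whitened range residual (||t_pose - p_target|| - z) / sigma."""
    translation = list(pose_se3[:3])  # [tx, ty, tz]; rotation part is unused
    predicted = math.dist(translation, list(target_pos[:3]))
    return (predicted - measured_range) / sigma


# Pose at the origin, target at (3, 4, 0): the predicted range is 5.0,
# so a measurement of 4.8 leaves a 0.2 error, i.e. two standard deviations.
r = range_residual([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [3.0, 4.0, 0.0], 4.8, sigma=0.1)
print(r)
```

Because range depends only on translation, a single range factor constrains neither the pose's orientation nor the bearing to the target; several ranges or other factors are needed for those.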

all_pose_time_keys()

Return all (agent, t) keys present in the underlying world.

This is mainly useful for debugging or for building custom visualizations and exporters.

:return: All time-index keys found in :attr:SceneGraphWorld.pose_trajectory.
:rtype: list[TimeKey]

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def all_pose_time_keys(self) -> List[TimeKey]:
    """Return all ``(agent, t)`` keys present in the underlying world.

    This is mainly useful for debugging or for building custom visualizations
    and exporters.

    :return: All time-index keys found in
        :attr:`SceneGraphWorld.pose_trajectory`.
    :rtype: list[TimeKey]
    """

    return list(self.world.pose_trajectory.keys())

get_agent_pose_nodes(agent_id)

Return the sequence of pose node IDs for an agent, ordered by time.

:param agent_id: Agent identifier.
:type agent_id: Hashable
:return: Pose node IDs for the given agent, sorted by their time index.
:rtype: list[NodeId]

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def get_agent_pose_nodes(self, agent_id: Hashable) -> List[NodeId]:
    """Return the sequence of pose node IDs for an agent, ordered by time.

    :param agent_id: Agent identifier.
    :type agent_id: Hashable
    :return: Pose node IDs for the given agent, sorted by their time index.
    :rtype: list[NodeId]
    """

    times = self.get_agent_times(agent_id)
    return [self.world.pose_trajectory[(agent_id, t)] for t in times]

get_agent_times(agent_id)

Return the sorted list of time indices for which this agent has poses.

:param agent_id: Agent identifier.
:type agent_id: Hashable
:return: Sorted time indices where (agent_id, t) exists in :attr:SceneGraphWorld.pose_trajectory.
:rtype: list[int]

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def get_agent_times(self, agent_id: Hashable) -> List[int]:
    """Return the sorted list of time indices for which this agent has poses.

    :param agent_id: Agent identifier.
    :type agent_id: Hashable
    :return: Sorted time indices where ``(agent_id, t)`` exists in
        :attr:`SceneGraphWorld.pose_trajectory`.
    :rtype: list[int]
    """

    times = [t for (a, t) in self.world.pose_trajectory.keys() if a == agent_id]
    return sorted(times)
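get_agent_times scans every key in pose_trajectory on each call, so querying every agent this way costs one full scan per agent. When an exporter or visualizer needs all timelines at once, a single pass over the keys (the same keys all_pose_time_keys returns) is enough. A sketch, with a plain dict standing in for SceneGraphWorld.pose_trajectory and group_times_by_agent as a hypothetical helper:

```python
from collections import defaultdict
from typing import DefaultDict, Dict, Hashable, Iterable, List, Tuple


def group_times_by_agent(keys: Iterable[Tuple[Hashable, int]]) -> Dict[Hashable, List[int]]:
    """Group (agent, t) keys into agent -> sorted time indices in one pass."""
    grouped: DefaultDict[Hashable, List[int]] = defaultdict(list)
    for agent, t in keys:
        grouped[agent].append(t)
    return {agent: sorted(times) for agent, times in grouped.items()}


# Insertion order need not be chronological; sorting restores it.
pose_trajectory = {("robot0", 2): 12, ("robot1", 0): 20, ("robot0", 0): 10, ("robot0", 1): 11}
times = group_times_by_agent(pose_trajectory.keys())
nodes = {a: [pose_trajectory[(a, t)] for t in ts] for a, ts in times.items()}
print(times)  # {'robot0': [0, 1, 2], 'robot1': [0]}
print(nodes)  # {'robot0': [10, 11, 12], 'robot1': [20]}
```

The nodes dictionary is what calling get_agent_pose_nodes once per agent would produce, built here from a single scan.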

get_agent_trajectory(agent_id, x_opt, index)

Extract an optimized trajectory for one agent from a flat state vector.

:param agent_id: Agent identifier.
:type agent_id: Hashable
:param x_opt: Optimized flat state vector produced by one of the Gauss–Newton solvers, such as :func:optimization.solvers.gauss_newton_manifold.
:type x_opt: jax.numpy.ndarray
:param index: Mapping from :class:NodeId to (start, dim) tuples as returned by :meth:world.model.WorldModel.pack_state.
:type index: Mapping[NodeId, Tuple[int, int]]
:return: Array of shape (T, 6) containing the se(3) vectors for each time step in chronological order.
:rtype: jax.numpy.ndarray

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def get_agent_trajectory(
    self,
    agent_id: Hashable,
    x_opt: jnp.ndarray,
    index: Mapping[NodeId, Tuple[int, int]],
) -> jnp.ndarray:
    """Extract an optimized trajectory for one agent from a flat state vector.

    :param agent_id: Agent identifier.
    :type agent_id: Hashable
    :param x_opt: Optimized flat state vector produced by one of the
        Gauss–Newton solvers, such as
        :func:`optimization.solvers.gauss_newton_manifold`.
    :type x_opt: jax.numpy.ndarray
    :param index: Mapping from :class:`NodeId` to ``(start, dim)`` tuples as
        returned by :meth:`world.model.WorldModel.pack_state`.
    :type index: Mapping[NodeId, Tuple[int, int]]
    :return: Array of shape ``(T, 6)`` containing the se(3) vectors for each
        time step in chronological order.
    :rtype: jax.numpy.ndarray
    """

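    nodes = self.get_agent_pose_nodes(agent_id)
    if not nodes:
        # jnp.stack raises on an empty list; an agent registered via
        # add_agent() without any poses yields an empty (0, 6) array instead.
        return jnp.zeros((0, 6))
    traj = []
    for nid in nodes:
        # Each node's block lives at x_opt[start : start + dim].
        start, dim = index[nid]
        traj.append(x_opt[start : start + dim])
    return jnp.stack(traj, axis=0)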
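The slicing in get_agent_trajectory relies on the (start, dim) layout returned by WorldModel.pack_state: each node's block sits at x_opt[start : start + dim]. The toy sketch below builds such an index by hand (the node IDs and layout are invented for illustration, not what pack_state would actually emit) and extracts one agent's poses in time order. It uses plain lists; with jax.numpy arrays the slicing is the same and jnp.stack turns the list into the (T, 6) array.

```python
from typing import Dict, List, Sequence, Tuple


def extract_trajectory(
    node_ids: Sequence[int],
    x_opt: Sequence[float],
    index: Dict[int, Tuple[int, int]],
) -> List[List[float]]:
    """Slice each node's (start, dim) block out of the flat state vector."""
    traj: List[List[float]] = []
    for nid in node_ids:
        start, dim = index[nid]
        traj.append(list(x_opt[start : start + dim]))
    return traj


# Hypothetical layout: a 3D place (node 7) packed between two 6D poses.
index = {0: (0, 6), 7: (6, 3), 1: (9, 6)}
x_opt = [0.0] * 6 + [5.0, 5.0, 5.0] + [1.0, 0.0, 0.0, 0.0, 0.0, 0.1]
traj = extract_trajectory([0, 1], x_opt, index)
print(traj)  # poses for t=0 and t=1; the place block is never touched
```

Because the index is keyed by node ID, an agent's poses need not be contiguous in the state vector; only the per-node (start, dim) entries matter.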

get_all_trajectories(x_opt, index)

Extract trajectories for all known agents from an optimized state.

This is a convenience wrapper around :meth:get_agent_trajectory that iterates over :attr:agents and returns a mapping from agent identifier to a (T_i, 6) array of se(3) poses.

:param x_opt: Optimized flat state vector produced by one of the Gauss–Newton solvers.
:type x_opt: jax.numpy.ndarray
:param index: Mapping from :class:NodeId to (start, dim) tuples as returned by :meth:world.model.WorldModel.pack_state.
:type index: Mapping[NodeId, Tuple[int, int]]
:return: Dictionary mapping each agent identifier to its optimized trajectory as an array of shape (T_i, 6).
:rtype: dict[Hashable, jax.numpy.ndarray]

Source code in dsg-jit/dsg_jit/world/dynamic_scene_graph.py
def get_all_trajectories(
    self,
    x_opt: jnp.ndarray,
    index: Mapping[NodeId, Tuple[int, int]],
) -> Dict[Hashable, jnp.ndarray]:
    """Extract trajectories for all known agents from an optimized state.

    This is a convenience wrapper around :meth:`get_agent_trajectory` that
    iterates over :attr:`agents` and returns a mapping from agent identifier
    to a ``(T_i, 6)`` array of se(3) poses.

    :param x_opt: Optimized flat state vector produced by one of the
        Gauss–Newton solvers.
    :type x_opt: jax.numpy.ndarray
    :param index: Mapping from :class:`NodeId` to ``(start, dim)`` tuples as
        returned by :meth:`world.model.WorldModel.pack_state`.
    :type index: Mapping[NodeId, Tuple[int, int]]
    :return: Dictionary mapping each agent identifier to its optimized
        trajectory as an array of shape ``(T_i, 6)``.
    :rtype: dict[Hashable, jax.numpy.ndarray]
    """

    trajectories: Dict[Hashable, jnp.ndarray] = {}
    for agent_id in self.agents:
        trajectories[agent_id] = self.get_agent_trajectory(agent_id, x_opt, index)
    return trajectories