
Conversation

@marcorudolphflex (Contributor) commented Jan 27, 2026

Note

Medium Risk
Touches the autograd forward/backward execution pipeline and introduces new scheduling logic that could affect gradient correctness/performance and local-file handling. Risk is mitigated by explicit gating (local_gradient + config flags), fallback paths, and extensive new tests.

Overview
Adds parallel adjoint scheduling to the autograd web.run/web.run_async local-gradient pipeline, launching eligible “canonical” adjoint simulations alongside the forward solve and reusing their precomputed VJP contributions during the backward pass, with automatic fallback to the existing sequential adjoint path.

Introduces new config controls config.adjoint.parallel_all_port and config.adjoint.parallel_adjoint_mode_direction_policy, plus new basis/source plumbing (basis specs for mode/diffraction/point-field outputs, deterministic adjoint source factories, and shared VJP accumulation/filtering helpers). Monitor and monitor-data types now expose supports_parallel_adjoint()/parallel_adjoint_bases() for the initially supported monitor outputs (mode amplitudes, diffraction amplitudes, and single-point field probes).

Refactors adjoint simulation construction into make_adjoint_simulation(), updates test emulation/fixtures, and adds a comprehensive test_parallel_adjoint.py suite validating gradient equivalence, fallback warnings, direction-policy behavior, and limit/unused-work guardrails; docs and changelog are updated accordingly.
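The core idea described above, launching canonical unit-amplitude adjoint solves up front and scaling them during the backward pass, relies on the gradient being linear in the adjoint source amplitude. A plain-numpy sketch (names and shapes are illustrative, not tidy3d's actual API):

```python
import numpy as np

# Illustrative sketch (not tidy3d code): each monitor output's unit-amplitude
# adjoint solve yields a gradient "basis" vector; the backward pass only needs
# a weighted sum of these bases, weighted by the VJP coefficients.
rng = np.random.default_rng(0)
num_params, num_outputs = 4, 3

# pretend gradient bases obtained from canonical adjoint solves
unit_bases = rng.normal(size=(num_outputs, num_params)) + 1j * rng.normal(
    size=(num_outputs, num_params)
)

def gradient_sequential(vjp_coeffs):
    # sequential path: conceptually one adjoint solve per nonzero coefficient
    total = np.zeros(num_params, dtype=complex)
    for c, basis in zip(vjp_coeffs, unit_bases):
        if c != 0:
            total += c * basis  # each term would be a separate adjoint run
    return total

def gradient_parallel(vjp_coeffs):
    # parallel path: bases were solved alongside the forward sim;
    # the backward pass reduces to a weighted sum
    return np.tensordot(vjp_coeffs, unit_bases, axes=1)

coeffs = np.array([0.5 - 1j, 0.0, 2.0 + 0.25j])
assert np.allclose(gradient_sequential(coeffs), gradient_parallel(coeffs))
```

Both paths agree by linearity; the parallel path simply moves the solver work ahead of the backward pass.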

Written by Cursor Bugbot for commit cd97953.

Greptile Overview

Greptile Summary

This PR implements parallel adjoint scheduling for autograd simulations, allowing eligible adjoint simulations to run concurrently with forward simulations when local_gradient=True. The feature launches canonical "unit" adjoint solves up front and scales them during the backward pass, reducing gradient computation wall-clock time.

Key changes:

  • Added config.adjoint.parallel_all_port configuration flag to enable the feature
  • Added config.adjoint.parallel_adjoint_mode_direction_policy to control mode direction handling
  • Created ParallelAdjointDescriptor classes for mode, diffraction, and point-field monitors
  • Implemented source factory functions for generating adjoint sources deterministically
  • Extended monitor data classes with supports_parallel_adjoint() and parallel_adjoint_bases() methods
  • Refactored adjoint simulation creation into reusable make_adjoint_simulation() function
  • Added comprehensive test suite verifying parallel vs sequential gradient equivalence
  • Updated documentation with detailed feature description

The implementation includes proper fallback mechanisms when monitors are unsupported or limits are exceeded, ensuring backward compatibility.

Confidence Score: 3/5

  • This PR introduces significant new functionality with good test coverage but has floating-point comparison issues that need addressing.
  • The score reflects a well-architected feature with comprehensive tests and documentation, but five critical floating-point equality comparisons need tolerance-based checks per project standards. The refactoring is clean and maintains backward compatibility with proper fallback mechanisms.
  • Pay close attention to tidy3d/web/api/autograd/parallel_adjoint.py (lines 322, 327, 329) and tidy3d/components/autograd/source_factory.py (lines 94, 207) for floating-point comparison fixes.
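The flagged pattern and its fix can be illustrated in isolation (generic numpy, not the PR's actual code):

```python
import numpy as np

# Exact float equality can misfire after arithmetic; tolerance-based
# comparisons (np.isclose / np.allclose) are the safer check.
a = 0.1 + 0.2
assert not (a == 0.3)      # exact comparison fails due to rounding
assert np.isclose(a, 0.3)  # tolerance-based comparison succeeds

# values that are numerically zero after cancellation
values = np.array([1e-17, -1e-18])
assert not np.all(values == 0)            # exact check misses them
assert np.allclose(values, 0.0, atol=1e-12)  # tolerance check catches them
```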

Important Files Changed

  • tidy3d/web/api/autograd/parallel_adjoint.py: New file implementing parallel adjoint scheduling. Contains floating-point comparison issues (lines 322, 327, 329) that need tolerance-based checks.
  • tidy3d/components/autograd/parallel_adjoint_descriptors.py: New file with descriptor classes for parallel adjoint; well-structured with proper error handling and type checking.
  • tidy3d/components/autograd/source_factory.py: New source factory utilities with floating-point equality issues (lines 94, 207) that should use tolerance-based comparisons.
  • tidy3d/web/api/autograd/autograd.py: Extended autograd pipeline with parallel adjoint integration; adds helper functions for VJP filtering, field map accumulation, and batch processing.
  • tidy3d/components/data/monitor_data.py: Refactored to support parallel adjoint via new supports_parallel_adjoint() and parallel_adjoint_bases() methods; extracted mode source creation to a factory.
  • tidy3d/components/data/sim_data.py: Extracted adjoint simulation creation into a standalone make_adjoint_simulation() function for reuse; clean refactoring with no logic changes.

Sequence Diagram

sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant ParallelAdjoint as Parallel Adjoint
    participant Batch as Batch Executor
    participant Solver as FDTD Solver

    User->>AutogradAPI: run(sim, local_gradient=True)
    AutogradAPI->>ParallelAdjoint: prepare_parallel_adjoint(sim)
    ParallelAdjoint->>ParallelAdjoint: collect descriptors from monitors
    ParallelAdjoint->>ParallelAdjoint: filter by direction policy
    ParallelAdjoint->>ParallelAdjoint: create canonical adjoint sims
    ParallelAdjoint-->>AutogradAPI: ParallelAdjointPayload
    
    alt Parallel Adjoint Enabled
        AutogradAPI->>Batch: run_async({fwd, adj_1, adj_2, ...})
        Batch->>Solver: run forward sim
        Batch->>Solver: run adjoint sim 1
        Batch->>Solver: run adjoint sim 2
        Batch-->>AutogradAPI: BatchData
        AutogradAPI->>AutogradAPI: populate_parallel_adjoint_bases()
        AutogradAPI-->>User: SimulationData + aux_data
    else Parallel Adjoint Disabled
        AutogradAPI->>Solver: run forward sim only
        AutogradAPI-->>User: SimulationData
    end

    User->>AutogradAPI: backward pass (VJP)
    AutogradAPI->>ParallelAdjoint: apply_parallel_adjoint(vjp, bases)
    ParallelAdjoint->>ParallelAdjoint: compute coefficients from VJP
    ParallelAdjoint->>ParallelAdjoint: scale and accumulate basis maps
    ParallelAdjoint-->>AutogradAPI: vjp_parallel + vjp_fallback
    
    alt Has fallback VJPs
        AutogradAPI->>Solver: run sequential adjoint for remaining
        Solver-->>AutogradAPI: adjoint field data
        AutogradAPI->>AutogradAPI: combine vjp_parallel + sequential
    end
    
    AutogradAPI-->>User: gradient

@marcorudolphflex (Contributor, Author) commented:

@greptile

greptile-apps bot left a comment

1 file reviewed, 1 comment

@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch 4 times, most recently from 8386a06 to 0c9fe1b on January 27, 2026 12:37
@marcorudolphflex marked this pull request as ready for review on January 27, 2026 12:46
@marcorudolphflex (Contributor, Author) commented:

Technically still semi-draft; marked as ready to trigger Cursor Bugbot. Will re-request review when it's really ready.

greptile-apps bot left a comment

5 files reviewed, 5 comments

github-actions bot commented Jan 27, 2026

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/parallel_adjoint_bases.py (90.1%): Missing lines 22,61,72,76,105,117,120-121,125,160,164,204,233,237
  • tidy3d/components/autograd/source_factory.py (91.8%): Missing lines 30,85,87,97,126,142,151,165,208,240,244
  • tidy3d/components/autograd/utils.py (70.0%): Missing lines 94-95,99
  • tidy3d/components/data/monitor_data.py (85.7%): Missing lines 196,200,1476,1888,4091
  • tidy3d/components/data/sim_data.py (91.7%): Missing lines 1351
  • tidy3d/components/monitor.py (89.9%): Missing lines 129,194,200,1177,1910,1916,1993
  • tidy3d/config/sections.py (100%)
  • tidy3d/web/api/autograd/autograd.py (100%)
  • tidy3d/web/api/autograd/backward.py (80.0%): Missing lines 38
  • tidy3d/web/api/autograd/constants.py (100%)
  • tidy3d/web/api/autograd/parallel_adjoint.py (87.0%): Missing lines 48,67-69,111,115,140,145,150,164,232,318,321,378-379,382,386-387,389,392,419-420,422,425,445,452-459,469,484,488,526,531
  • tidy3d/web/api/autograd/utils.py (100%)

Summary

  • Total: 820 lines
  • Missing: 80 lines
  • Coverage: 90%

tidy3d/components/autograd/parallel_adjoint_bases.py

Lines 18-26

  18 
  19 def _coord_index(coord_values: np.ndarray, target: object) -> int:
  20     values = np.asarray(coord_values)
  21     if values.size == 0:
! 22         raise ValueError("No coordinate values available to index.")
  23     if values.dtype.kind in ("f", "c"):
  24         matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
  25     else:
  26         matches = np.where(values == target)[0]

Lines 57-65

  57         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  58     ) -> complex:
  59         vjp = data_fields_vjp.get(self.data_path)
  60         if vjp is None:
! 61             return 0.0 + 0.0j
  62         data_index = self._data_index_from_sim_data(sim_data_orig)
  63         vjp_array = np.asarray(vjp)
  64         value = complex(vjp_array[data_index])
  65         return value

Lines 68-80

  68         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  69     ) -> None:
  70         vjp = data_fields_vjp.get(self.data_path)
  71         if vjp is None:
! 72             return
  73         vjp_array = np.asarray(vjp)
  74         vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  75         if vjp_array is not vjp:
! 76             data_fields_vjp[self.data_path] = vjp_array
  77 
  78 
  79 @dataclass(frozen=True)
  80 class DiffractionAdjointBasis:

Lines 101-109

  101         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData, norm: np.ndarray
  102     ) -> complex:
  103         vjp = data_fields_vjp.get(self.data_path)
  104         if vjp is None:
! 105             return 0.0 + 0.0j
  106         try:
  107             data_index = self._data_index_from_sim_data(sim_data_orig)
  108         except ValueError:
  109             return 0.0 + 0.0j

Lines 113-129

  113         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  114     ) -> None:
  115         vjp = data_fields_vjp.get(self.data_path)
  116         if vjp is None:
! 117             return
  118         try:
  119             data_index = self._data_index_from_sim_data(sim_data_orig)
! 120         except ValueError:
! 121             return
  122         vjp_array = np.asarray(vjp)
  123         vjp_array[data_index] = 0.0
  124         if vjp_array is not vjp:
! 125             data_fields_vjp[self.data_path] = vjp_array
  126 
  127 
  128 @dataclass(frozen=True)
  129 class PointFieldAdjointBasis:

Lines 156-168

  156         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  157     ) -> None:
  158         vjp = data_fields_vjp.get(self.data_path)
  159         if vjp is None:
! 160             return
  161         vjp_array = np.asarray(vjp)
  162         vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  163         if vjp_array is not vjp:
! 164             data_fields_vjp[self.data_path] = vjp_array
  165 
  166 
  167 ParallelAdjointBasis = ModeAdjointBasis | DiffractionAdjointBasis | PointFieldAdjointBasis

Lines 200-208

  200 ) -> list[PointFieldAdjointBasis]:
  201     bases: list[PointFieldAdjointBasis] = []
  202     for component, freqs in component_freqs:
  203         if component not in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
! 204             continue
  205         for freq in freqs:
  206             bases.append(
  207                 PointFieldAdjointBasis(
  208                     monitor_index=monitor_index,

Lines 229-241

  229     for order_x in orders_x:
  230         for order_y in orders_y:
  231             angle_theta = float(theta_for(int(order_x), int(order_y)))
  232             if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 233                 continue
  234             for pol in pols:
  235                 pol_str = str(pol)
  236                 if pol_str not in ("s", "p"):
! 237                     continue
  238                 dataset_name = "Ephi" if pol_str == "s" else "Etheta"
  239                 bases.append(
  240                     DiffractionAdjointBasis(
  241                         monitor_index=monitor_index,

tidy3d/components/autograd/source_factory.py

Lines 26-34

  26 def flip_direction(direction: object) -> str:
  27     if hasattr(direction, "values"):
  28         direction = str(direction.values)
  29     if direction not in ("+", "-"):
! 30         raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
  31     return "-" if direction == "+" else "+"
  32 
  33 
  34 def adjoint_fwidth_from_simulation(simulation: Simulation) -> float:
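For a quick standalone sanity check of the direction-flipping helper (copied from the excerpt; the try/except exercises the untested line 30):

```python
# Copied from the coverage excerpt above.
def flip_direction(direction):
    if hasattr(direction, "values"):
        direction = str(direction.values)
    if direction not in ("+", "-"):
        raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
    return "-" if direction == "+" else "+"

assert flip_direction("+") == "-"
assert flip_direction("-") == "+"

# untested line 30: invalid input raises
try:
    flip_direction("x")
    raised_on_invalid = False
except ValueError:
    raised_on_invalid = True
```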

Lines 81-91

  81     coefficient: complex,
  82     fwidth: float,
  83 ) -> CustomCurrentSource | None:
  84     if any(simulation.symmetry):
! 85         raise ValueError("Point-field adjoint sources require symmetry to be disabled.")
  86     if not monitor.colocate:
! 87         raise ValueError("Point-field adjoint sources require colocated field monitors.")
  88 
  89     grid = simulation.discretize_monitor(monitor)
  90     coords = {}
  91     spatial_coords = grid.boundaries

Lines 93-101

   93     for axis, dim in enumerate("xyz"):
   94         if monitor.size[axis] == 0:
   95             coords[dim] = np.array([monitor.center[axis]])
   96         else:
!  97             coords[dim] = np.array(spatial_coords_dict[dim][:-1])
   98     values = (
   99         2
  100         * -1j
  101         * coefficient

Lines 122-130

  122     values *= scaling_factor
  123     values = np.nan_to_num(values, nan=0.0)
  124 
  125     if np.all(values == 0):
! 126         return None
  127 
  128     dataset = FieldDataset(**{component: ScalarFieldDataArray(values, coords=coords)})
  129     return CustomCurrentSource(
  130         center=monitor.geometry.center,

Lines 138-146

  138 def diffraction_monitor_medium(simulation: Simulation, monitor: DiffractionMonitor) -> object:
  139     structures = [simulation.scene.background_structure, *list(simulation.structures or ())]
  140     mediums = simulation.scene.intersecting_media(monitor, structures)
  141     if len(mediums) != 1:
! 142         raise ValueError("Diffraction monitor plane must be homogeneous to build adjoint sources.")
  143     return list(mediums)[0]
  144 
  145 
  146 def bloch_vec_for_axis(simulation: Simulation, axis_name: str) -> float:

Lines 147-155

  147     boundary = simulation.boundary_spec[axis_name]
  148     plus = boundary.plus
  149     if hasattr(plus, "bloch_vec"):
  150         return float(plus.bloch_vec)
! 151     return 0.0
  152 
  153 
  154 def diffraction_order_range(
  155     size: float, bloch_vec: float, freq: float, medium: object

Lines 161-169

  161     limit = abs(index) * freq * size / C_0
  162     order_min = int(np.ceil(-limit - bloch_vec))
  163     order_max = int(np.floor(limit - bloch_vec))
  164     if order_max < order_min:
! 165         return np.array([], dtype=int)
  166     return np.arange(order_min, order_max + 1, dtype=int)
  167 
  168 
  169 def diffraction_source_from_simulation(
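The order-range computation lends itself to a standalone check. In this sketch the refractive index is passed in directly (the real helper derives it from `medium`), and `C_0` is defined locally in SI units purely for illustration; tidy3d's own unit conventions differ:

```python
import numpy as np

C_0 = 2.99792458e8  # speed of light in m/s; illustrative units only

# Standalone version of the excerpt's logic: propagating diffraction orders
# n satisfy |n + bloch_vec| <= |index| * freq * size / C_0.
def diffraction_order_range(size, bloch_vec, freq, index):
    limit = abs(index) * freq * size / C_0
    order_min = int(np.ceil(-limit - bloch_vec))
    order_max = int(np.floor(limit - bloch_vec))
    if order_max < order_min:
        # untested line 165: no propagating orders at all
        return np.array([], dtype=int)
    return np.arange(order_min, order_max + 1, dtype=int)

# period of 1.5 wavelengths supports orders -1, 0, +1
assert list(diffraction_order_range(1.5, 0.0, C_0, 1.0)) == [-1, 0, 1]

# small period plus a Bloch offset can leave no propagating orders
assert diffraction_order_range(0.2, 0.5, C_0, 1.0).size == 0
```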

Lines 204-212

  204     theta_vals, phi_vals = DiffractionData.compute_angles((ux, uy))
  205     angle_theta = float(theta_vals[0, 0, 0])
  206     angle_phi = float(phi_vals[0, 0, 0])
  207     if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 208         raise ValueError("Adjoint source not available for evanescent diffraction order.")
  209 
  210     pol_angle = 0.0 if polarization == "p" else np.pi / 2
  211     bck_eps = medium.eps_model(freq)
  212     return _diffraction_plane_wave(

Lines 236-248

  236     angle_theta = float(theta_data.sel(**angle_sel_kwargs))
  237     angle_phi = float(phi_data.sel(**angle_sel_kwargs))
  238 
  239     if np.isnan(angle_theta):
! 240         return None
  241 
  242     pol_str = str(polarization)
  243     if pol_str not in ("p", "s"):
! 244         raise ValueError(f"Something went wrong, given pol='{pol_str}' in adjoint source.")
  245 
  246     pol_angle = 0.0 if pol_str == "p" else np.pi / 2
  247     bck_eps = diff_data.medium.eps_model(freq)
  248     return _diffraction_plane_wave(

tidy3d/components/autograd/utils.py

Lines 90-103

   90     for k, v in addition.items():
   91         if k in target:
   92             val = target[k]
   93             if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
!  94                 if len(val) != len(v):
!  95                     raise ValueError(
   96                         f"Cannot accumulate field map for key '{k}': "
   97                         f"length mismatch ({len(val)} vs {len(v)})."
   98                     )
!  99                 target[k] = type(val)(x + y for x, y in zip(val, v))
  100             else:
  101                 target[k] += v
  102         else:
  103             target[k] = v
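The accumulation logic above (including the untested length-mismatch branch on lines 94-95) can be checked standalone; the loop body is copied from the excerpt and wrapped under a hypothetical name `accumulate_field_map`:

```python
# Loop body copied from the coverage excerpt; the wrapper name is hypothetical.
def accumulate_field_map(target, addition):
    for k, v in addition.items():
        if k in target:
            val = target[k]
            if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
                if len(val) != len(v):
                    raise ValueError(
                        f"Cannot accumulate field map for key '{k}': "
                        f"length mismatch ({len(val)} vs {len(v)})."
                    )
                # element-wise sum, preserving list/tuple type
                target[k] = type(val)(x + y for x, y in zip(val, v))
            else:
                target[k] += v
        else:
            target[k] = v

t = {"a": 1.0, "b": (1, 2)}
accumulate_field_map(t, {"a": 2.0, "b": (10, 20), "c": 5})
assert t == {"a": 3.0, "b": (11, 22), "c": 5}
```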

tidy3d/components/data/monitor_data.py

Lines 192-204

  192         return []
  193 
  194     def supports_parallel_adjoint(self) -> bool:
  195         """Return ``True`` if this monitor data supports parallel adjoint sources."""
! 196         return False
  197 
  198     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  199         """Return parallel adjoint bases for this monitor data."""
! 200         return []
  201 
  202     @staticmethod
  203     def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
  204         """Get the complex amplitude out of some data."""

Lines 1472-1480

  1472 
  1473     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  1474         """Return parallel adjoint bases for single-point field monitors."""
  1475         if not self.supports_parallel_adjoint():
! 1476             return []
  1477         component_freqs = [
  1478             (str(component), data_array.coords["f"].values)
  1479             for component, data_array in self.field_components.items()
  1480         ]

Lines 1884-1892

  1884         return self
  1885 
  1886     def supports_parallel_adjoint(self) -> bool:
  1887         """Return ``True`` for mode monitor amplitude adjoints."""
! 1888         return True
  1889 
  1890     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  1891         """Return parallel adjoint bases for mode monitor amplitudes."""
  1892         amps = self.amps

Lines 4087-4095

  4087         return DataArray(np.stack([amp_phi, amp_theta], axis=3), coords=coords)
  4088 
  4089     def supports_parallel_adjoint(self) -> bool:
  4090         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 4091         return True
  4092 
  4093     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  4094         """Return parallel adjoint bases for diffraction monitor amplitudes."""
  4095         amps = self.amps

tidy3d/components/data/sim_data.py

Lines 1347-1355

  1347 
  1348     if adjoint_source_info.normalize_sim:
  1349         normalize_index_adj = 0
  1350     else:
! 1351         normalize_index_adj = None
  1352 
  1353     sim_adj_update_dict["normalize_index"] = normalize_index_adj
  1354 
  1355     return sim_original.updated_copy(**sim_adj_update_dict)

tidy3d/components/monitor.py

Lines 125-133

  125     for freq in freqs:
  126         orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
  127         orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
  128         if orders_x.size == 0 or orders_y.size == 0:
! 129             continue
  130 
  131         ux = _reciprocal_coords(
  132             orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
  133         )

Lines 190-198

  190         return self.storage_size(num_cells=num_cells, tmesh=tmesh)
  191 
  192     def supports_parallel_adjoint(self) -> bool:
  193         """Return ``True`` if this monitor can provide parallel adjoint bases."""
! 194         return False
  195 
  196     def parallel_adjoint_bases(
  197         self, simulation: Simulation, monitor_index: int
  198     ) -> list[ParallelAdjointBasis]:

Lines 196-204

  196     def parallel_adjoint_bases(
  197         self, simulation: Simulation, monitor_index: int
  198     ) -> list[ParallelAdjointBasis]:
  199         """Return parallel adjoint bases for this monitor."""
! 200         return []
  201 
  202 
  203 class FreqMonitor(Monitor, ABC):
  204     """:class:`Monitor` that records data in the frequency-domain."""

Lines 1173-1181

  1173         return amps_size + fields_size
  1174 
  1175     def supports_parallel_adjoint(self) -> bool:
  1176         """Return ``True`` for mode monitor amplitude adjoints."""
! 1177         return True
  1178 
  1179     def parallel_adjoint_bases(
  1180         self, simulation: Simulation, monitor_index: int
  1181     ) -> list[ParallelAdjointBasis]:

Lines 1906-1914

  1906         return BYTES_COMPLEX * len(self.ux) * len(self.uy) * len(self.freqs) * 6
  1907 
  1908     def supports_parallel_adjoint(self) -> bool:
  1909         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 1910         return True
  1911 
  1912     def parallel_adjoint_bases(
  1913         self, simulation: Simulation, monitor_index: int
  1914     ) -> list[ParallelAdjointBasis]:

Lines 1912-1920

  1912     def parallel_adjoint_bases(
  1913         self, simulation: Simulation, monitor_index: int
  1914     ) -> list[ParallelAdjointBasis]:
  1915         """Return parallel adjoint bases for diffraction monitor amplitudes."""
! 1916         return _diffraction_parallel_adjoint_bases(self, simulation, monitor_index)
  1917 
  1918 
  1919 class DiffractionMonitor(PlanarMonitor, FreqMonitor):
  1920     """:class:`Monitor` that uses a 2D Fourier transform to compute the

Lines 1989-1997

  1989         return BYTES_COMPLEX * len(self.freqs)
  1990 
  1991     def supports_parallel_adjoint(self) -> bool:
  1992         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 1993         return True
  1994 
  1995     def parallel_adjoint_bases(
  1996         self, simulation: Simulation, monitor_index: int
  1997     ) -> list[ParallelAdjointBasis]:

tidy3d/web/api/autograd/backward.py

Lines 34-42

  34 
  35     td.log.info("Running custom vjp (adjoint) pipeline.")
  36 
  37     if not already_filtered:
! 38         data_fields_vjp = filter_vjp_map(data_fields_vjp)
  39 
  40     # if all entries are zero, there is no adjoint sim to run
  41     if not data_fields_vjp:
  42         return []

tidy3d/web/api/autograd/parallel_adjoint.py

Lines 44-52

  44 def _scale_field_map(field_map: AutogradFieldMap, scale: float) -> AutogradFieldMap:
  45     scaled = {}
  46     for k, v in field_map.items():
  47         if isinstance(v, (list, tuple)):
! 48             scaled[k] = type(v)(scale * x for x in v)
  49         else:
  50             scaled[k] = scale * v
  51     return scaled

Lines 63-73

  63     unsupported: list[str] = []
  64     for monitor_index, monitor in enumerate(simulation.monitors):
  65         try:
  66             bases_for_monitor = monitor.parallel_adjoint_bases(simulation, monitor_index)
! 67         except ValueError:
! 68             unsupported.append(monitor.name)
! 69             continue
  70         if bases_for_monitor:
  71             bases.extend(bases_for_monitor)
  72         elif not monitor.supports_parallel_adjoint():
  73             unsupported.append(monitor.name)

Lines 107-119

  107     basis_spec: object,
  108 ) -> object:
  109     post_norm = sim_data_adj.simulation.post_norm
  110     if not hasattr(basis_spec, "freq"):
! 111         return post_norm
  112     freqs = np.asarray(post_norm.coords["f"].values)
  113     idx = int(np.argmin(np.abs(freqs - basis_spec.freq)))
  114     if not np.isclose(freqs[idx], basis_spec.freq):
! 115         raise td.exceptions.AdjointError(
  116             "Parallel adjoint basis frequency not found in adjoint post-normalization."
  117         )
  118     return post_norm.isel(f=[idx])

Lines 136-154

  136         for key, data_array in monitor_data.field_components.items():
  137             if "f" in data_array.dims:
  138                 freqs = np.asarray(data_array.coords["f"].values)
  139                 if freqs.size == 0:
! 140                     raise td.exceptions.AdjointError(
  141                         "Parallel adjoint expected frequency data but no frequencies were found."
  142                     )
  143                 idx = int(np.argmin(np.abs(freqs - freq)))
  144                 if not np.isclose(freqs[idx], freq, rtol=1e-10, atol=0.0):
! 145                     raise td.exceptions.AdjointError(
  146                         "Parallel adjoint basis frequency not found in monitor data."
  147                     )
  148                 updates[key] = data_array.isel(f=[idx])
  149         return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False, **updates)
! 150     return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False)
  151 
  152 
  153 def _select_sim_data_freq(
  154     sim_data_adj: td.SimulationData,

Lines 160-168

  160     for monitor in sim.monitors:
  161         if hasattr(monitor, "freqs"):
  162             monitor_updated = monitor.updated_copy(freqs=[freq])
  163         else:
! 164             monitor_updated = monitor
  165         monitors.append(monitor_updated)
  166         monitor_map[monitor.name] = monitor_updated
  167     sim_updated = sim.updated_copy(monitors=monitors)

Lines 228-236

  228     simulation: td.Simulation,
  229     basis_sources: list[tuple[ParallelAdjointBasis, Any]],
  230 ) -> list[tuple[list[ParallelAdjointBasis], AdjointSourceInfo]]:
  231     if not basis_sources:
! 232         return []
  233 
  234     sim_data_stub = td.SimulationData(simulation=simulation, data=())
  235     sources = [source for _, source in basis_sources]
  236     sources_processed = td.SimulationData._adjoint_src_width_single(sources)

Lines 314-325

  314             coefficient=coefficient,
  315             fwidth=fwidth,
  316         )
  317         if source is None:
! 318             raise ValueError("Adjoint point source has zero amplitude.")
  319         return adjoint_source_info_single(source)
  320 
! 321     raise ValueError("Unsupported parallel adjoint basis.")
  322 
  323 
  324 @dataclass(frozen=True)
  325 class ParallelAdjointPayload:

Lines 374-396

  374                 simulation=simulation,
  375                 basis=basis,
  376                 coefficient=1.0 + 0.0j,
  377             )
! 378         except ValueError as exc:
! 379             td.log.info(
  380                 f"Skipping parallel adjoint basis for monitor '{basis.monitor_name}': {exc}"
  381             )
! 382             continue
  383         basis_sources.append((basis, source_info.sources[0]))
  384 
  385     if not basis_sources:
! 386         if basis_specs:
! 387             td.log.info("Parallel adjoint produced no simulations for this task.")
  388         else:
! 389             td.log.warning(
  390                 "Parallel adjoint disabled because no eligible monitor outputs were found."
  391             )
! 392         return None
  393 
  394     grouped = _group_parallel_adjoint_bases_by_port(simulation, basis_sources)
  395     if len(grouped) > max_num_adjoint_per_fwd:
  396         raise AdjointError(

Lines 415-429

  415         task_map[adj_task_name] = bases
  416         used_bases.extend(bases)
  417 
  418     if not sims_adj_dict:
! 419         if basis_specs:
! 420             td.log.info("Parallel adjoint produced no simulations for this task.")
  421         else:
! 422             td.log.warning(
  423                 "Parallel adjoint disabled because no eligible monitor outputs were found."
  424             )
! 425         return None
  426 
  427     td.log.info(
  428         "Parallel adjoint enabled: launched "
  429         f"{len(sims_adj_dict)} canonical adjoint simulations for task '{task_name}'."

Lines 441-449

  441     task_paths: dict[str, str],
  442     base_dir: PathLike,
  443 ) -> None:
  444     if not task_names:
! 445         return
  446     target_dir = Path(base_dir) / config.adjoint.local_adjoint_dir
  447     target_dir.mkdir(parents=True, exist_ok=True)
  448     for task_name in task_names:
  449         src_path = task_paths.get(task_name)

Lines 448-463

  448     for task_name in task_names:
  449         src_path = task_paths.get(task_name)
  450         if not src_path:
  451             continue
! 452         src = Path(src_path)
! 453         if not src.exists():
! 454             continue
! 455         dst = target_dir / src.name
! 456         if src.resolve() == dst.resolve():
! 457             continue
! 458         dst.parent.mkdir(parents=True, exist_ok=True)
! 459         src.replace(dst)
  460 
  461 
  462 def apply_parallel_adjoint(
  463     data_fields_vjp: AutogradFieldMap,

Lines 465-473

  465     sim_data_orig: td.SimulationData,
  466 ) -> tuple[AutogradFieldMap, AutogradFieldMap]:
  467     basis_maps = parallel_info.get("basis_maps")
  468     if basis_maps is None:
! 469         return {}, data_fields_vjp
  470 
  471     data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
  472     vjp_parallel: AutogradFieldMap = {}
  473     norm_cache: dict[int, np.ndarray] = {}

Lines 480-492

  480     used_bases = 0
  481     for basis in basis_specs:
  482         basis_map = basis_maps.get(basis)
  483         if basis_map is None:
! 484             continue
  485         basis_real = basis_map.get("real")
  486         basis_imag = basis_map.get("imag")
  487         if basis_real is None or basis_imag is None:
! 488             continue
  489         tracked_bases += 1
  490         if isinstance(basis, DiffractionAdjointBasis):
  491             norm = norm_cache.get(basis.monitor_index)
  492             if norm is None:

Lines 522-535

  522                     f"{unused_sims} simulations were unused. Disable parallel adjoint to avoid "
  523                     "unused precomputations."
  524                 )
  525             else:
! 526                 td.log.warning(
  527                     f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  528                     "evaluation. Disable parallel adjoint to avoid unused precomputations."
  529                 )
  530         else:
! 531             td.log.warning(
  532                 f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  533                 f"evaluation; {unused_bases} had zero VJP coefficients. Disable parallel adjoint "
  534                 "to avoid unused precomputations."
  535             )

@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 0c9fe1b to 4d59014 on January 27, 2026 15:56
@marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 4d59014 to 7cabb37 on January 27, 2026 16:10
Collaborator @yaugenst-flex left a comment:

Thanks @marcorudolphflex, this is pretty great. I had a cursory glance at the PR to try to understand what's going on and left some questions/comments, but I'll look deeper into the implementation when I find some time. One thing to note is that this introduces a lot of new code, even new modules. Not a problem in itself, but I'd take a closer look at whether any of it can be simplified.
Also, could you show some plots/verification against the non-parallel adjoint?

- Mode direction policy (for mode monitors): `config.adjoint.parallel_adjoint_mode_direction_policy`
- `"assume_outgoing"` (default): pick the mode direction based on monitor position relative to the simulation center and flip it for the adjoint.
- `"run_both_directions"`: launch parallel adjoint sources for both `+` and `-` directions.
- `"no_parallel"`: disable parallel adjoint entirely.
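As a sketch of how these policies would be selected (the `td.config.adjoint` access path is illustrative; only the option names above come from this PR):

```python
import tidy3d as td

# Parallel adjoint is only active together with local gradients.
td.config.adjoint.local_gradient = True

# Illustrative: launch adjoint sources for both +/- mode directions
# instead of inferring the outgoing direction from the monitor position.
td.config.adjoint.parallel_adjoint_mode_direction_policy = "run_both_directions"
```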
Collaborator:
Why do mode monitors separately have a flag to turn parallel adjoint off, in addition to the global config?

Contributor Author:

TBD whether users need that in case they want to override the global toggle for this less-deterministic mode-monitor case. Since we have a config field anyway, I think it doesn't hurt. Or could that be confusing for users, given how it interacts with the global toggle?

Contributor:

by less-determined, do you mean that it's harder to predict the adjoint sources to run in parallel and that's why a user would want to turn it off?

Comment on lines +69 to +70
- Only effective when: `config.adjoint.local_gradient = True`
- If `local_gradient=False`, the flag is ignored and behavior remains unchanged.
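A minimal sketch of this gating, assuming both flags live on the same config object (the access path is illustrative):

```python
import tidy3d as td

td.config.adjoint.parallel_all_port = True   # request parallel adjoint scheduling
td.config.adjoint.local_gradient = False     # ...but without local gradients,

# the flag has no effect: gradients follow the unchanged remote path.
```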
Collaborator:

Why only local gradients? Couldn't this be supported in remote too? Maybe it's fine as an initial version but I don't see how this couldnt be done for remote?

Contributor Author:

Probably yes, this was the "easy" start.


#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
Collaborator:

Are all parallel simulations counted as adjoint toward this cap?

Contributor Author:

yes

#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
- If enabling parallel adjoint would exceed the cap, the run logs a warning and proceeds with the sequential path for that forward run (or a safe subset, depending on policy).
Collaborator:

We might not want to proceed at all in that case, not sure. Since this a flag that we wouldn't turn on by default, it means that generally the user will have requested it, so they might want to choose to increase the cap instead of running sequentially.

Contributor Author:

True, changed it to raise an AdjointError, as we currently do for the sequential adjoint.

Collaborator:

It would be important to explain/understand here in which scenarios how many adjoint simulations would get launched in the parallel case and what the edge cases are so there are no surprises.

Contributor Author:

Added a section in the readme.

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 7cabb37 to f351142 Compare January 28, 2026 08:19
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from f351142 to 9abfd55 Compare January 28, 2026 09:46
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 9abfd55 to ac73e25 Compare January 28, 2026 11:34
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from e13d19f to 6b77f75 Compare January 29, 2026 16:37

When enabled, Tidy3D launches eligible adjoint simulations in parallel with the forward simulation by running a set of canonical "unit" adjoint solves up front. During the backward pass, it reuses those precomputed results and scales them with the actual VJP coefficients from your objective.

Net effect: reduced gradient wall-clock time (often close to ~2x faster in the "many deterministic adjoints" regime), at the cost of sometimes running adjoint solves that your objective ultimately does not use.
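The scale-and-reuse step can be illustrated with a toy numpy sketch (all names here are hypothetical; the real pipeline operates on full autograd field maps rather than flat arrays):

```python
import numpy as np

def combine_unit_adjoints(unit_vjps, coefficients):
    """Scale precomputed unit-adjoint gradient contributions by the actual
    VJP coefficients from the objective's backward pass and accumulate them.

    unit_vjps: dict basis_key -> ndarray (contribution of a unit-amplitude
        adjoint solve); coefficients: dict basis_key -> (complex) scalar.
    """
    total = None
    for key, contribution in unit_vjps.items():
        coeff = coefficients.get(key, 0.0)
        if coeff == 0.0:
            continue  # basis unused by the objective -> precomputation wasted
        scaled = np.real(coeff * contribution)
        total = scaled if total is None else total + scaled
    return total

unit_vjps = {"mode0_f0_+": np.array([1.0, 2.0]), "mode1_f0_+": np.array([0.5, -1.0])}
coeffs = {"mode0_f0_+": 2.0, "mode1_f0_+": 0.0}
grad = combine_unit_adjoints(unit_vjps, coeffs)  # only mode0 contributes
```

A basis whose coefficient comes back as zero contributes nothing, which is exactly the "unused precomputation" case discussed in this thread.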
Contributor:

are there cases where the number of unused simulations can be large? I guess this is where `max_adjoint_per_fwd' protects the user from accidentally running a bunch of sims? Would we want to also issue a warning if a large number of adjoint simulations are unused (maybe this is done already)?

Contributor Author:

Yes, this can happen if the objective does not use all modes or frequencies. Added a warning that reports the number of unused parallel simulations.

Parallel adjoint launches one canonical adjoint simulation per eligible “basis,” so the total
count is driven by how many distinct outputs your monitors expose:

- **Mode monitors**: one basis per `(freq, mode_index, direction)`. If
Contributor:

in the sequential version, we do some grouping for adjoint sources based on number of ports versus number of frequencies. Does the same happen for parallel adjoint?

I'm thinking of a multi-port optimization (like multiple mode monitors) and a single frequency optimization. In sequential when we know the vjp coming from the objective, we can launch all the adjoint sources at the same time provided we set the amplitude and phase for the single frequency. Thinking through this, we wouldn't be able to do the same for the parallel case right? Instead, if I'm understanding correctly, we would run them all as single frequency injections from each mode monitor and then combine results after computing the objective function grad.

Contributor Author:

No, I also think this is not possible, as we do not know the individual per-port VJPs beforehand. Added a note in the readme to clarify that.
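To make the counting concrete, here is a hypothetical sketch of the per-basis enumeration for mode monitors, one canonical adjoint per `(freq, mode_index, direction)`, with the hard cap enforced by raising rather than silently running a subset (the function name and error type are illustrative):

```python
from itertools import product

def enumerate_mode_bases(freqs, mode_indices, directions, max_adjoint_per_fwd):
    """Enumerate one canonical adjoint basis per (freq, mode_index, direction)
    and enforce the global cap on adjoint simulations per forward run."""
    bases = list(product(freqs, mode_indices, directions))
    if len(bases) > max_adjoint_per_fwd:
        raise ValueError(
            f"Parallel adjoint would launch {len(bases)} simulations, "
            f"exceeding max_adjoint_per_fwd={max_adjoint_per_fwd}."
        )
    return bases

# 2 freqs x 2 modes x both directions -> 8 canonical adjoint runs
bases = enumerate_mode_bases([1.5e14, 2.0e14], [0, 1], ["+", "-"], max_adjoint_per_fwd=10)
```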

```python
) -> None:
    if not parallel_info or not sims_adj:
        return
    td.log.warning(
```
Contributor:

does this mean part of the process will happen with parallel and part with sequential?

Contributor Author:

Currently yes. But this does not really make sense at that point. Changed it so that we completely fall back to sequential in this case.

```python
        raise td.exceptions.AdjointError(
            "Parallel adjoint basis frequency not found in adjoint post-normalization."
        )
    return post_norm.isel(f=[idx])
```
Contributor:

can this be called without the list? post_norm.isel(f=idx)

Contributor Author:

Using a scalar index would drop the f dimension in xarray, which breaks downstream expectations for post_norm.f and frequency‑aligned broadcasting. The list keeps a length‑1 f dim and matches _select_monitor_data_freq.
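This is standard xarray behavior and easy to verify: a scalar `isel` index drops the dimension, while a length-1 list keeps it.

```python
import numpy as np
import xarray as xr

post_norm = xr.DataArray(
    np.arange(3.0), dims=("f",), coords={"f": [1.0e14, 1.5e14, 2.0e14]}
)

scalar_sel = post_norm.isel(f=1)  # drops the "f" dimension entirely
list_sel = post_norm.isel(f=[1])  # keeps a length-1 "f" dimension

# scalar_sel.dims == ()      -> downstream post_norm.f alignment would break
# list_sel.dims == ("f",)    -> frequency-aligned broadcasting still works
```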

```python
    monitor = simulation.monitors[basis.monitor_index]
    fwidth = adjoint_fwidth_from_simulation(simulation)

    if isinstance(basis, DiffractionAdjointBasis):
```
Contributor:

is it expected that we would have cases of mismatched basis and monitor type?

Contributor Author:

Not in the current setup; it was introduced to detect regressions. But I guess it is over-defensive here, so I removed it.

Contributor @groberts-flex:

Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about the cases that are mostly being accelerated with the parallel adjoint approach and which cases we end up needing to run extra simulations that we don't end up needing as a result (or which may have been able to be grouped if running sequentially).
I think it would be worth us all talking through the approach a bit more and ways we could simplify some of the logic here. At the very least, for code review purposes, it might be easier if this can be broken up into smaller changes.
I'm excited to chat through, it's definitely a cool feature, thanks again for working on it!

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch 3 times, most recently from 2ace906 to 9cdf564 Compare February 3, 2026 11:21
Contributor Author @marcorudolphflex:

Thanks for the review!
I can understand that this is not easy to review. Unfortunately, I am not quite sure how I could exactly split it up as most changes are related, do you have an idea?
We should definitely discuss how individual cases are handled and how this could be further configured.

I think the essential cases where we use more or unused simulations are frequency-grouped simulations and cases where monitor components have a 0-vjp as the objective does not use them.

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 9cdf564 to cacd0bf Compare February 3, 2026 11:59
Contributor @groberts-flex:

Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about the cases that are mostly being accelerated with the parallel adjoint approach and which cases we end up needing to run extra simulations that we don't end up needing as a result (or which may have been able to be grouped if running sequentially). I think it would be worth us all talking through the approach a bit more and ways we could simplify some of the logic here. At the very least, for code review purposes, it might be easier if this can be broken up into smaller changes. I'm excited to chat through, it's definitely a cool feature, thanks again for working on it!

Thanks for the review! I can understand that this is not easy to review. Unfortunately, I am not quite sure how I could exactly split it up as most changes are related, do you have an idea? We should definitely discuss how individual cases are handled and how this could be further configured.

I think the essential cases where we use more or unused simulations are frequency-grouped simulations and cases where monitor components have a 0-vjp as the objective does not use them.

Thanks for your reply on this and the extra information. I was intending to get to thinking through more how it might be broken up today, but ended up spending the day on the CustomMedium debugging. I'll get around to this tomorrow and maybe we could chat later Thursday sometime on the different cases here!

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from cacd0bf to b279009 Compare February 6, 2026 11:52
@cursor bot left a comment:

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch 2 times, most recently from a95423b to 32b0482 Compare February 6, 2026 13:18
Contributor Author @marcorudolphflex commented Feb 6, 2026:

Thanks for your reply on this and the extra information. I was intending to get to thinking through more how it might be broken up today, but ended up spending the day on the CustomMedium debugging. I'll get around to this tomorrow and maybe we could chat later Thursday sometime on the different cases here!

No worries :)
Sorry, missed your comment here. I think it would make sense to draft an RFC for this to open up the discussion. It would still be helpful to chat in person; maybe use next Tuesday's meeting?

@cursor bot left a comment:

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 32b0482 to cd97953 Compare February 6, 2026 14:48
@cursor bot left a comment:

Cursor Bugbot has reviewed your changes and found 3 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

```python
if basis_maps is None:
    return {}, data_fields_vjp

data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
```
Fallback VJP copy coerces list/tuple values to arrays

Medium Severity

apply_parallel_adjoint creates its fallback copy with {k: np.array(v, copy=True) for ...}, which silently converts any list or tuple VJP values into numpy arrays. If a VJP value is a list of arrays (as get_static can produce), np.array(...) would either stack them or create an object array, losing the original container type. Downstream accumulate_field_map branches on isinstance(val, (list, tuple)) and would take the wrong path for such coerced values, leading to incorrect gradient accumulation.
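The coercion is easy to reproduce, and a container-preserving copy (hypothetical helper, not from the PR) avoids it:

```python
import numpy as np

vjp_list = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # list-of-arrays VJP value

coerced = np.array(vjp_list, copy=True)        # stacks into a (2, 2) ndarray
assert not isinstance(coerced, (list, tuple))  # container type is lost

def copy_preserving_container(val):
    """Deep-copy a VJP value while keeping list/tuple containers intact,
    so downstream isinstance(val, (list, tuple)) branches still fire."""
    if isinstance(val, (list, tuple)):
        return type(val)(np.array(v, copy=True) for v in val)
    return np.array(val, copy=True)

safe = copy_preserving_container(vjp_list)
assert isinstance(safe, list) and safe[0] is not vjp_list[0]
```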


@marcorudolphflex marcorudolphflex marked this pull request as draft February 6, 2026 15:05
@marcorudolphflex marcorudolphflex removed the request for review from tylerflex February 6, 2026 15:06
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from cd97953 to f7ab451 Compare February 9, 2026 11:27
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from f7ab451 to 766a33d Compare February 11, 2026 07:49