feat(tidy3d): FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint #3208
Conversation
@greptile
1 file reviewed, 1 comment
Technically still semi-drafty; marked as ready for Cursor Bugbot. Will re-request review when it's really ready.
5 files reviewed, 5 comments
Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

Summary (lines marked with `!` are flagged by the diff coverage report)

tidy3d/components/autograd/parallel_adjoint_bases.py

Lines 18-26
   18
   19  def _coord_index(coord_values: np.ndarray, target: object) -> int:
   20      values = np.asarray(coord_values)
   21      if values.size == 0:
!  22          raise ValueError("No coordinate values available to index.")
   23      if values.dtype.kind in ("f", "c"):
   24          matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
   25      else:
   26          matches = np.where(values == target)[0]

Lines 57-65
   57          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
   58      ) -> complex:
   59          vjp = data_fields_vjp.get(self.data_path)
   60          if vjp is None:
!  61              return 0.0 + 0.0j
   62          data_index = self._data_index_from_sim_data(sim_data_orig)
   63          vjp_array = np.asarray(vjp)
   64          value = complex(vjp_array[data_index])
   65          return value

Lines 68-80
   68          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
   69      ) -> None:
   70          vjp = data_fields_vjp.get(self.data_path)
   71          if vjp is None:
!  72              return
   73          vjp_array = np.asarray(vjp)
   74          vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
   75          if vjp_array is not vjp:
!  76              data_fields_vjp[self.data_path] = vjp_array
   77
   78
   79  @dataclass(frozen=True)
   80  class DiffractionAdjointBasis:

Lines 101-109
  101          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData, norm: np.ndarray
  102      ) -> complex:
  103          vjp = data_fields_vjp.get(self.data_path)
  104          if vjp is None:
! 105              return 0.0 + 0.0j
  106          try:
  107              data_index = self._data_index_from_sim_data(sim_data_orig)
  108          except ValueError:
  109              return 0.0 + 0.0j

Lines 113-129
  113          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  114      ) -> None:
  115          vjp = data_fields_vjp.get(self.data_path)
  116          if vjp is None:
! 117              return
  118          try:
  119              data_index = self._data_index_from_sim_data(sim_data_orig)
! 120          except ValueError:
! 121              return
  122          vjp_array = np.asarray(vjp)
  123          vjp_array[data_index] = 0.0
  124          if vjp_array is not vjp:
! 125              data_fields_vjp[self.data_path] = vjp_array
  126
  127
  128  @dataclass(frozen=True)
  129  class PointFieldAdjointBasis:

Lines 156-168
  156          self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  157      ) -> None:
  158          vjp = data_fields_vjp.get(self.data_path)
  159          if vjp is None:
! 160              return
  161          vjp_array = np.asarray(vjp)
  162          vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  163          if vjp_array is not vjp:
! 164              data_fields_vjp[self.data_path] = vjp_array
  165
  166
  167  ParallelAdjointBasis = ModeAdjointBasis | DiffractionAdjointBasis | PointFieldAdjointBasis

Lines 200-208
  200  ) -> list[PointFieldAdjointBasis]:
  201      bases: list[PointFieldAdjointBasis] = []
  202      for component, freqs in component_freqs:
  203          if component not in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
! 204              continue
  205          for freq in freqs:
  206              bases.append(
  207                  PointFieldAdjointBasis(
  208                      monitor_index=monitor_index,

Lines 229-241
  229      for order_x in orders_x:
  230          for order_y in orders_y:
  231              angle_theta = float(theta_for(int(order_x), int(order_y)))
  232              if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 233                  continue
  234              for pol in pols:
  235                  pol_str = str(pol)
  236                  if pol_str not in ("s", "p"):
! 237                      continue
  238                  dataset_name = "Ephi" if pol_str == "s" else "Etheta"
  239                  bases.append(
  240                      DiffractionAdjointBasis(
  241                          monitor_index=monitor_index,

tidy3d/components/autograd/source_factory.py

Lines 26-34
   26  def flip_direction(direction: object) -> str:
   27      if hasattr(direction, "values"):
   28          direction = str(direction.values)
   29      if direction not in ("+", "-"):
!  30          raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
   31      return "-" if direction == "+" else "+"
   32
   33
   34  def adjoint_fwidth_from_simulation(simulation: Simulation) -> float:

Lines 81-91
   81      coefficient: complex,
   82      fwidth: float,
   83  ) -> CustomCurrentSource | None:
   84      if any(simulation.symmetry):
!  85          raise ValueError("Point-field adjoint sources require symmetry to be disabled.")
   86      if not monitor.colocate:
!  87          raise ValueError("Point-field adjoint sources require colocated field monitors.")
   88
   89      grid = simulation.discretize_monitor(monitor)
   90      coords = {}
   91      spatial_coords = grid.boundaries

Lines 93-101
   93      for axis, dim in enumerate("xyz"):
   94          if monitor.size[axis] == 0:
   95              coords[dim] = np.array([monitor.center[axis]])
   96          else:
!  97              coords[dim] = np.array(spatial_coords_dict[dim][:-1])
   98      values = (
   99          2
  100          * -1j
  101          * coefficient

Lines 122-130
  122      values *= scaling_factor
  123      values = np.nan_to_num(values, nan=0.0)
  124
  125      if np.all(values == 0):
! 126          return None
  127
  128      dataset = FieldDataset(**{component: ScalarFieldDataArray(values, coords=coords)})
  129      return CustomCurrentSource(
  130          center=monitor.geometry.center,

Lines 138-146
  138  def diffraction_monitor_medium(simulation: Simulation, monitor: DiffractionMonitor) -> object:
  139      structures = [simulation.scene.background_structure, *list(simulation.structures or ())]
  140      mediums = simulation.scene.intersecting_media(monitor, structures)
  141      if len(mediums) != 1:
! 142          raise ValueError("Diffraction monitor plane must be homogeneous to build adjoint sources.")
  143      return list(mediums)[0]
  144
  145
  146  def bloch_vec_for_axis(simulation: Simulation, axis_name: str) -> float:

Lines 147-155
  147      boundary = simulation.boundary_spec[axis_name]
  148      plus = boundary.plus
  149      if hasattr(plus, "bloch_vec"):
  150          return float(plus.bloch_vec)
! 151      return 0.0
  152
  153
  154  def diffraction_order_range(
  155      size: float, bloch_vec: float, freq: float, medium: object

Lines 161-169
  161      limit = abs(index) * freq * size / C_0
  162      order_min = int(np.ceil(-limit - bloch_vec))
  163      order_max = int(np.floor(limit - bloch_vec))
  164      if order_max < order_min:
! 165          return np.array([], dtype=int)
  166      return np.arange(order_min, order_max + 1, dtype=int)
  167
  168
  169  def diffraction_source_from_simulation(

Lines 204-212
  204      theta_vals, phi_vals = DiffractionData.compute_angles((ux, uy))
  205      angle_theta = float(theta_vals[0, 0, 0])
  206      angle_phi = float(phi_vals[0, 0, 0])
  207      if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 208          raise ValueError("Adjoint source not available for evanescent diffraction order.")
  209
  210      pol_angle = 0.0 if polarization == "p" else np.pi / 2
  211      bck_eps = medium.eps_model(freq)
  212      return _diffraction_plane_wave(

Lines 236-248
  236      angle_theta = float(theta_data.sel(**angle_sel_kwargs))
  237      angle_phi = float(phi_data.sel(**angle_sel_kwargs))
  238
  239      if np.isnan(angle_theta):
! 240          return None
  241
  242      pol_str = str(polarization)
  243      if pol_str not in ("p", "s"):
! 244          raise ValueError(f"Something went wrong, given pol='{pol_str}' in adjoint source.")
  245
  246      pol_angle = 0.0 if pol_str == "p" else np.pi / 2
  247      bck_eps = diff_data.medium.eps_model(freq)
  248      return _diffraction_plane_wave(

tidy3d/components/autograd/utils.py

Lines 90-103
   90      for k, v in addition.items():
   91          if k in target:
   92              val = target[k]
   93              if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
!  94                  if len(val) != len(v):
!  95                      raise ValueError(
   96                          f"Cannot accumulate field map for key '{k}': "
   97                          f"length mismatch ({len(val)} vs {len(v)})."
   98                      )
!  99                  target[k] = type(val)(x + y for x, y in zip(val, v))
  100              else:
  101                  target[k] += v
  102          else:
  103              target[k] = v

tidy3d/components/data/monitor_data.py

Lines 192-204
  192          return []
  193
  194      def supports_parallel_adjoint(self) -> bool:
  195          """Return ``True`` if this monitor data supports parallel adjoint sources."""
! 196          return False
  197
  198      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  199          """Return parallel adjoint bases for this monitor data."""
! 200          return []
  201
  202      @staticmethod
  203      def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex:
  204          """Get the complex amplitude out of some data."""

Lines 1472-1480
 1472
 1473      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
 1474          """Return parallel adjoint bases for single-point field monitors."""
 1475          if not self.supports_parallel_adjoint():
!1476              return []
 1477          component_freqs = [
 1478              (str(component), data_array.coords["f"].values)
 1479              for component, data_array in self.field_components.items()
 1480          ]

Lines 1884-1892
 1884          return self
 1885
 1886      def supports_parallel_adjoint(self) -> bool:
 1887          """Return ``True`` for mode monitor amplitude adjoints."""
!1888          return True
 1889
 1890      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
 1891          """Return parallel adjoint bases for mode monitor amplitudes."""
 1892          amps = self.amps

Lines 4087-4095
 4087          return DataArray(np.stack([amp_phi, amp_theta], axis=3), coords=coords)
 4088
 4089      def supports_parallel_adjoint(self) -> bool:
 4090          """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
!4091          return True
 4092
 4093      def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
 4094          """Return parallel adjoint bases for diffraction monitor amplitudes."""
 4095          amps = self.amps

tidy3d/components/data/sim_data.py

Lines 1347-1355
 1347
 1348          if adjoint_source_info.normalize_sim:
 1349              normalize_index_adj = 0
 1350          else:
!1351              normalize_index_adj = None
 1352
 1353          sim_adj_update_dict["normalize_index"] = normalize_index_adj
 1354
 1355          return sim_original.updated_copy(**sim_adj_update_dict)

tidy3d/components/monitor.py

Lines 125-133
  125      for freq in freqs:
  126          orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
  127          orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
  128          if orders_x.size == 0 or orders_y.size == 0:
! 129              continue
  130
  131          ux = _reciprocal_coords(
  132              orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
  133          )

Lines 190-198
  190          return self.storage_size(num_cells=num_cells, tmesh=tmesh)
  191
  192      def supports_parallel_adjoint(self) -> bool:
  193          """Return ``True`` if this monitor can provide parallel adjoint bases."""
! 194          return False
  195
  196      def parallel_adjoint_bases(
  197          self, simulation: Simulation, monitor_index: int
  198      ) -> list[ParallelAdjointBasis]:

Lines 196-204
  196      def parallel_adjoint_bases(
  197          self, simulation: Simulation, monitor_index: int
  198      ) -> list[ParallelAdjointBasis]:
  199          """Return parallel adjoint bases for this monitor."""
! 200          return []
  201
  202
  203  class FreqMonitor(Monitor, ABC):
  204      """:class:`Monitor` that records data in the frequency-domain."""

Lines 1173-1181
 1173          return amps_size + fields_size
 1174
 1175      def supports_parallel_adjoint(self) -> bool:
 1176          """Return ``True`` for mode monitor amplitude adjoints."""
!1177          return True
 1178
 1179      def parallel_adjoint_bases(
 1180          self, simulation: Simulation, monitor_index: int
 1181      ) -> list[ParallelAdjointBasis]:

Lines 1906-1914
 1906          return BYTES_COMPLEX * len(self.ux) * len(self.uy) * len(self.freqs) * 6
 1907
 1908      def supports_parallel_adjoint(self) -> bool:
 1909          """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
!1910          return True
 1911
 1912      def parallel_adjoint_bases(
 1913          self, simulation: Simulation, monitor_index: int
 1914      ) -> list[ParallelAdjointBasis]:

Lines 1912-1920
 1912      def parallel_adjoint_bases(
 1913          self, simulation: Simulation, monitor_index: int
 1914      ) -> list[ParallelAdjointBasis]:
 1915          """Return parallel adjoint bases for diffraction monitor amplitudes."""
!1916          return _diffraction_parallel_adjoint_bases(self, simulation, monitor_index)
 1917
 1918
 1919  class DiffractionMonitor(PlanarMonitor, FreqMonitor):
 1920      """:class:`Monitor` that uses a 2D Fourier transform to compute the

Lines 1989-1997
 1989          return BYTES_COMPLEX * len(self.freqs)
 1990
 1991      def supports_parallel_adjoint(self) -> bool:
 1992          """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
!1993          return True
 1994
 1995      def parallel_adjoint_bases(
 1996          self, simulation: Simulation, monitor_index: int
 1997      ) -> list[ParallelAdjointBasis]:

tidy3d/web/api/autograd/backward.py

Lines 34-42
   34
   35      td.log.info("Running custom vjp (adjoint) pipeline.")
   36
   37      if not already_filtered:
!  38          data_fields_vjp = filter_vjp_map(data_fields_vjp)
   39
   40      # if all entries are zero, there is no adjoint sim to run
   41      if not data_fields_vjp:
   42          return []

tidy3d/web/api/autograd/parallel_adjoint.py

Lines 44-52
   44  def _scale_field_map(field_map: AutogradFieldMap, scale: float) -> AutogradFieldMap:
   45      scaled = {}
   46      for k, v in field_map.items():
   47          if isinstance(v, (list, tuple)):
!  48              scaled[k] = type(v)(scale * x for x in v)
   49          else:
   50              scaled[k] = scale * v
   51      return scaled

Lines 63-73
   63      unsupported: list[str] = []
   64      for monitor_index, monitor in enumerate(simulation.monitors):
   65          try:
   66              bases_for_monitor = monitor.parallel_adjoint_bases(simulation, monitor_index)
!  67          except ValueError:
!  68              unsupported.append(monitor.name)
!  69              continue
   70          if bases_for_monitor:
   71              bases.extend(bases_for_monitor)
   72          elif not monitor.supports_parallel_adjoint():
   73              unsupported.append(monitor.name)

Lines 107-119
  107      basis_spec: object,
  108  ) -> object:
  109      post_norm = sim_data_adj.simulation.post_norm
  110      if not hasattr(basis_spec, "freq"):
! 111          return post_norm
  112      freqs = np.asarray(post_norm.coords["f"].values)
  113      idx = int(np.argmin(np.abs(freqs - basis_spec.freq)))
  114      if not np.isclose(freqs[idx], basis_spec.freq):
! 115          raise td.exceptions.AdjointError(
  116              "Parallel adjoint basis frequency not found in adjoint post-normalization."
  117          )
  118      return post_norm.isel(f=[idx])

Lines 136-154
  136      for key, data_array in monitor_data.field_components.items():
  137          if "f" in data_array.dims:
  138              freqs = np.asarray(data_array.coords["f"].values)
  139              if freqs.size == 0:
! 140                  raise td.exceptions.AdjointError(
  141                      "Parallel adjoint expected frequency data but no frequencies were found."
  142                  )
  143              idx = int(np.argmin(np.abs(freqs - freq)))
  144              if not np.isclose(freqs[idx], freq, rtol=1e-10, atol=0.0):
! 145                  raise td.exceptions.AdjointError(
  146                      "Parallel adjoint basis frequency not found in monitor data."
  147                  )
  148              updates[key] = data_array.isel(f=[idx])
  149          return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False, **updates)
! 150      return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False)
  151
  152
  153  def _select_sim_data_freq(
  154      sim_data_adj: td.SimulationData,

Lines 160-168
  160      for monitor in sim.monitors:
  161          if hasattr(monitor, "freqs"):
  162              monitor_updated = monitor.updated_copy(freqs=[freq])
  163          else:
! 164              monitor_updated = monitor
  165          monitors.append(monitor_updated)
  166          monitor_map[monitor.name] = monitor_updated
  167      sim_updated = sim.updated_copy(monitors=monitors)

Lines 228-236
  228      simulation: td.Simulation,
  229      basis_sources: list[tuple[ParallelAdjointBasis, Any]],
  230  ) -> list[tuple[list[ParallelAdjointBasis], AdjointSourceInfo]]:
  231      if not basis_sources:
! 232          return []
  233
  234      sim_data_stub = td.SimulationData(simulation=simulation, data=())
  235      sources = [source for _, source in basis_sources]
  236      sources_processed = td.SimulationData._adjoint_src_width_single(sources)

Lines 314-325
  314              coefficient=coefficient,
  315              fwidth=fwidth,
  316          )
  317          if source is None:
! 318              raise ValueError("Adjoint point source has zero amplitude.")
  319          return adjoint_source_info_single(source)
  320
! 321      raise ValueError("Unsupported parallel adjoint basis.")
  322
  323
  324  @dataclass(frozen=True)
  325  class ParallelAdjointPayload:

Lines 374-396
  374                  simulation=simulation,
  375                  basis=basis,
  376                  coefficient=1.0 + 0.0j,
  377              )
! 378          except ValueError as exc:
! 379              td.log.info(
  380                  f"Skipping parallel adjoint basis for monitor '{basis.monitor_name}': {exc}"
  381              )
! 382              continue
  383          basis_sources.append((basis, source_info.sources[0]))
  384
  385      if not basis_sources:
! 386          if basis_specs:
! 387              td.log.info("Parallel adjoint produced no simulations for this task.")
  388          else:
! 389              td.log.warning(
  390                  "Parallel adjoint disabled because no eligible monitor outputs were found."
  391              )
! 392          return None
  393
  394      grouped = _group_parallel_adjoint_bases_by_port(simulation, basis_sources)
  395      if len(grouped) > max_num_adjoint_per_fwd:
  396          raise AdjointError(

Lines 415-429
  415          task_map[adj_task_name] = bases
  416          used_bases.extend(bases)
  417
  418      if not sims_adj_dict:
! 419          if basis_specs:
! 420              td.log.info("Parallel adjoint produced no simulations for this task.")
  421          else:
! 422              td.log.warning(
  423                  "Parallel adjoint disabled because no eligible monitor outputs were found."
  424              )
! 425          return None
  426
  427      td.log.info(
  428          "Parallel adjoint enabled: launched "
  429          f"{len(sims_adj_dict)} canonical adjoint simulations for task '{task_name}'."

Lines 441-449
  441      task_paths: dict[str, str],
  442      base_dir: PathLike,
  443  ) -> None:
  444      if not task_names:
! 445          return
  446      target_dir = Path(base_dir) / config.adjoint.local_adjoint_dir
  447      target_dir.mkdir(parents=True, exist_ok=True)
  448      for task_name in task_names:
  449          src_path = task_paths.get(task_name)

Lines 448-463
  448      for task_name in task_names:
  449          src_path = task_paths.get(task_name)
  450          if not src_path:
  451              continue
! 452          src = Path(src_path)
! 453          if not src.exists():
! 454              continue
! 455          dst = target_dir / src.name
! 456          if src.resolve() == dst.resolve():
! 457              continue
! 458          dst.parent.mkdir(parents=True, exist_ok=True)
! 459          src.replace(dst)
  460
  461
  462  def apply_parallel_adjoint(
  463      data_fields_vjp: AutogradFieldMap,

Lines 465-473
  465      sim_data_orig: td.SimulationData,
  466  ) -> tuple[AutogradFieldMap, AutogradFieldMap]:
  467      basis_maps = parallel_info.get("basis_maps")
  468      if basis_maps is None:
! 469          return {}, data_fields_vjp
  470
  471      data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
  472      vjp_parallel: AutogradFieldMap = {}
  473      norm_cache: dict[int, np.ndarray] = {}

Lines 480-492
  480      used_bases = 0
  481      for basis in basis_specs:
  482          basis_map = basis_maps.get(basis)
  483          if basis_map is None:
! 484              continue
  485          basis_real = basis_map.get("real")
  486          basis_imag = basis_map.get("imag")
  487          if basis_real is None or basis_imag is None:
! 488              continue
  489          tracked_bases += 1
  490          if isinstance(basis, DiffractionAdjointBasis):
  491              norm = norm_cache.get(basis.monitor_index)
  492              if norm is None:

Lines 522-535
  522                  f"{unused_sims} simulations were unused. Disable parallel adjoint to avoid "
  523                  "unused precomputations."
  524              )
  525          else:
! 526              td.log.warning(
  527                  f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  528                  "evaluation. Disable parallel adjoint to avoid unused precomputations."
  529              )
  530      else:
! 531          td.log.warning(
  532              f"Parallel adjoint used {used_bases} of {tracked_bases} bases after VJP "
  533              f"evaluation; {unused_bases} had zero VJP coefficients. Disable parallel adjoint "
  534              "to avoid unused precomputations."
  535          )
yaugenst-flex left a comment:
Thanks @marcorudolphflex, this is pretty great. I had a cursory glance at the PR to try to understand a bit what's going on and left some questions/comments, but I'll look deeper into the implementation when I find some time. I guess one thing to note is that this introduces a lot of new code, even new modules. Not a problem in itself, but I'd maybe have a closer look at whether any of this can be simplified.
Also, could you show some plots/verification against the non-parallel adjoint?
- Mode direction policy (for mode monitors): `config.adjoint.parallel_adjoint_mode_direction_policy`
  - `"assume_outgoing"` (default): pick the mode direction based on monitor position relative to the simulation center and flip it for the adjoint.
  - `"run_both_directions"`: launch parallel adjoint sources for both `+` and `-` directions.
  - `"no_parallel"`: disable parallel adjoint entirely.
Why do mode monitors separately have a flag to turn parallel adjoint off, in addition to the global config?
TBD whether users need that, in case they want to override the global toggle for this less-determined mode monitor. As we have a config field anyway, I think it doesn't hurt. Or could that be confusing for users regarding its effect alongside the global toggle?
By less-determined, do you mean that it's harder to predict which adjoint sources to run in parallel, and that's why a user would want to turn it off?
- Only effective when: `config.adjoint.local_gradient = True`
- If `local_gradient=False`, the flag is ignored and behavior remains unchanged.
Why only local gradients? Couldn't this be supported in remote too? Maybe it's fine as an initial version, but I don't see how this couldn't be done for remote.
Probably yes, this was the "easy" start.
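For reference, a rough sketch of the gating implied by this exchange (a hypothetical helper, not the actual implementation):

```python
def parallel_adjoint_enabled(config) -> bool:
    """Hypothetical gate: parallel adjoint only takes effect for local gradients.

    If local_gradient is False, the parallel_all_port flag is ignored entirely
    and the existing (remote/sequential) behavior is unchanged.
    """
    return bool(config.adjoint.local_gradient and config.adjoint.parallel_all_port)
```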
#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
Are all parallel simulations counted as adjoint toward this cap?
yes
tidy3d/plugins/autograd/README.md (outdated)
#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
- If enabling parallel adjoint would exceed the cap, the run logs a warning and proceeds with the sequential path for that forward run (or a safe subset, depending on policy).
We might not want to proceed at all in that case, not sure. Since this is a flag that we wouldn't turn on by default, it means that generally the user will have requested it, so they might want to choose to increase the cap instead of running sequentially.
True, changed it to raise an AdjointError, as we currently do for sequential adjoint.
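A sketch of the guardrail behavior described here (the helper name and message wording are placeholders; the coverage excerpt above shows the real check against `max_num_adjoint_per_fwd`):

```python
import tidy3d as td

def check_adjoint_cap(num_parallel_adjoints: int, max_adjoint_per_fwd: int) -> None:
    # Hypothetical guardrail mirroring the updated behavior: raise instead of
    # silently falling back when the requested parallel adjoints exceed the cap.
    if num_parallel_adjoints > max_adjoint_per_fwd:
        raise td.exceptions.AdjointError(
            f"Parallel adjoint would launch {num_parallel_adjoints} simulations, exceeding "
            f"config.adjoint.max_adjoint_per_fwd={max_adjoint_per_fwd}. Increase the cap or "
            "disable parallel adjoint."
        )
```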
It would be important to explain/understand here how many adjoint simulations get launched in which scenarios in the parallel case, and what the edge cases are, so there are no surprises.
Added a section on this in the README.
When enabled, Tidy3D launches eligible adjoint simulations in parallel with the forward simulation by running a set of canonical "unit" adjoint solves up front. During the backward pass, it reuses those precomputed results and scales them with the actual VJP coefficients from your objective.

Net effect: reduced gradient wall-clock time (often close to ~2x faster in the "many deterministic adjoints" regime), at the cost of sometimes running adjoint solves that your objective ultimately does not use.
are there cases where the number of unused simulations can be large? I guess this is where `max_adjoint_per_fwd' protects the user from accidentally running a bunch of sims? Would we want to also issue a warning if a large number of adjoint simulations are unused (maybe this is done already)?
Yes, this can happen if the objective does not use all modes or frequencies. Added a warning that reports the number of unused parallel simulations.
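The scaling step quoted above relies on linearity: each precomputed "unit" basis VJP is weighted by the complex coefficient extracted from the objective's VJP, and bases with zero coefficients correspond to wasted precomputed simulations. A schematic of the accumulation (not the actual code in this PR):

```python
def combine_parallel_adjoint(basis_maps, coefficients):
    """Schematic only: total_vjp = sum_i c_i * precomputed_basis_vjp_i.

    basis_maps: dict basis -> {"real": field_map, "imag": field_map} from the unit adjoint solves
    coefficients: dict basis -> complex coefficient taken from the objective's VJP
    """
    total = {}
    for basis, coeff in coefficients.items():
        if coeff == 0:  # unused basis: its precomputed adjoint simulation was unnecessary
            continue
        maps = basis_maps[basis]
        for key, real_part in maps["real"].items():
            # weight the real- and imag-excited unit solves by the coefficient components
            contrib = coeff.real * real_part + coeff.imag * maps["imag"][key]
            total[key] = total.get(key, 0.0) + contrib
    return total
```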
Parallel adjoint launches one canonical adjoint simulation per eligible "basis," so the total count is driven by how many distinct outputs your monitors expose:

- **Mode monitors**: one basis per `(freq, mode_index, direction)`. If …
In the sequential version, we do some grouping of adjoint sources based on the number of ports versus the number of frequencies. Does the same happen for parallel adjoint?
I'm thinking of a multi-port optimization (like multiple mode monitors) with a single-frequency objective. In the sequential case, once we know the VJP coming from the objective, we can launch all the adjoint sources at the same time, provided we set the amplitude and phase for the single frequency. Thinking through this, we wouldn't be able to do the same for the parallel case, right? Instead, if I'm understanding correctly, we would run them all as single-frequency injections from each mode monitor and then combine results after computing the objective function grad.
No, I also think this is not possible, as we do not know the individual VJPs from the ports beforehand. Added a note in the README to clarify that.
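To make the counting concrete, a small illustrative example of how many canonical adjoint simulations a multi-port setup would need under the per-`(freq, mode_index, direction)` rule discussed above (numbers are made up; the real grouping lives in the PR's parallel_adjoint module):

```python
# Illustrative counting only.
num_ports = 3        # e.g. three ModeMonitors
num_freqs = 2
num_mode_indices = 1
num_directions = 1   # "assume_outgoing" -> one direction per port (2 for "run_both_directions")

parallel_adjoints = num_ports * num_freqs * num_mode_indices * num_directions  # 6 canonical sims

# Sequential adjoint, knowing the objective's VJP, can combine sources across ports
# (as described in the comment above), e.g. roughly one grouped adjoint per frequency:
sequential_adjoints = num_freqs  # 2 grouped sims

print(parallel_adjoints, sequential_adjoints)
```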
) -> None:
    if not parallel_info or not sims_adj:
        return
    td.log.warning(
Does this mean part of the process will happen with parallel and part with sequential?
Currently yes, but this does not really make sense at that point. Changed it so that we completely fall back to sequential in this case.
raise td.exceptions.AdjointError(
    "Parallel adjoint basis frequency not found in adjoint post-normalization."
)
return post_norm.isel(f=[idx])
Can this be called without the list, i.e. `post_norm.isel(f=idx)`?
Using a scalar index would drop the `f` dimension in xarray, which breaks downstream expectations for `post_norm.f` and frequency-aligned broadcasting. The list keeps a length-1 `f` dim and matches `_select_monitor_data_freq`.
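A small standalone xarray illustration of the difference (not tidy3d code):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(3.0), coords={"f": [1e14, 2e14, 3e14]}, dims="f")

scalar_sel = da.isel(f=1)   # drops the "f" dimension (0-d result)
list_sel = da.isel(f=[1])   # keeps a length-1 "f" dimension for frequency-aligned broadcasting

print(scalar_sel.dims)  # ()
print(list_sel.dims)    # ('f',)
```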
monitor = simulation.monitors[basis.monitor_index]
fwidth = adjoint_fwidth_from_simulation(simulation)

if isinstance(basis, DiffractionAdjointBasis):
Is it expected that we would have cases of mismatched basis and monitor type?
Not in the current setup; it was introduced to detect regressions. But I guess this is over-defensive here, so I removed it.
Thanks @marcorudolphflex for working on this, it is a really huge effort! I've been working through the doc you included here and the code and still have a good bit to go. I am curious to understand a bit more about which cases are mostly accelerated by the parallel adjoint approach, and in which cases we end up running extra simulations that we don't need in the end (or that could have been grouped if running sequentially).
Thanks for the review! I think the essential cases where we run more simulations, or simulations that go unused, are frequency-grouped simulations and cases where monitor components have a zero VJP because the objective does not use them.
Thanks for your reply on this and the extra information. I was intending to think more today about how it might be broken up, but ended up spending the day on the CustomMedium debugging. I'll get around to this tomorrow, and maybe we could chat sometime later Thursday about the different cases here!
Cursor Bugbot has reviewed your changes and found 2 potential issues.
No worries :)
Cursor Bugbot has reviewed your changes and found 3 potential issues.
if basis_maps is None:
    return {}, data_fields_vjp

data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
Fallback VJP copy coerces list/tuple values to arrays
Medium Severity
`apply_parallel_adjoint` creates its fallback copy with `{k: np.array(v, copy=True) for ...}`, which silently converts any list or tuple VJP values into numpy arrays. If a VJP value is a list of arrays (as `get_static` can produce), `np.array(...)` would either stack them or create an object array, losing the original container type. Downstream, `accumulate_field_map` branches on `isinstance(val, (list, tuple))` and would take the wrong path for such coerced values, leading to incorrect gradient accumulation.
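One possible way to address this (a sketch, not the PR's actual fix) is to preserve the container type when making the fallback copy:

```python
import numpy as np

def copy_vjp_value(value):
    """Copy a VJP entry while preserving list/tuple containers of arrays."""
    if isinstance(value, (list, tuple)):
        # copy each element individually so the container type survives
        return type(value)(np.array(item, copy=True) for item in value)
    return np.array(value, copy=True)

# hypothetical usage in place of the dict comprehension quoted above:
# fallback = {k: copy_vjp_value(v) for k, v in data_fields_vjp.items()}
```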
Note: Medium Risk

Touches the autograd forward/backward execution pipeline and introduces new scheduling logic that could affect gradient correctness/performance and local-file handling. Risk is mitigated by explicit gating (`local_gradient` plus config flags), fallback paths, and extensive new tests.

Overview

Adds parallel adjoint scheduling to the autograd `web.run`/`web.run_async` local-gradient pipeline, launching eligible "canonical" adjoint simulations alongside the forward solve and reusing their precomputed VJP contributions during the backward pass, with automatic fallback to the existing sequential adjoint path.

Introduces new config controls `config.adjoint.parallel_all_port` and `config.adjoint.parallel_adjoint_mode_direction_policy`, plus new basis/source plumbing (basis specs for mode/diffraction/point-field outputs, deterministic adjoint source factories, and shared VJP accumulation/filtering helpers). Monitor and monitor-data types now expose `supports_parallel_adjoint()` / `parallel_adjoint_bases()` for the initially supported monitor outputs (mode amplitudes, diffraction amplitudes, and single-point field probes).

Refactors adjoint simulation construction into `make_adjoint_simulation()`, updates test emulation/fixtures, and adds a comprehensive `test_parallel_adjoint.py` suite validating gradient equivalence, fallback warnings, direction-policy behavior, and limit/unused-work guardrails; docs and changelog are updated accordingly.

Written by Cursor Bugbot for commit cd97953.
Greptile Overview

Greptile Summary

This PR implements parallel adjoint scheduling for autograd simulations, allowing eligible adjoint simulations to run concurrently with forward simulations when `local_gradient=True`. The feature launches canonical "unit" adjoint solves up front and scales them during the backward pass, reducing gradient computation wall-clock time.

Key changes:

- `config.adjoint.parallel_all_port` configuration flag to enable the feature
- `config.adjoint.parallel_adjoint_mode_direction_policy` to control mode direction handling
- `ParallelAdjointDescriptor` classes for mode, diffraction, and point-field monitors
- `supports_parallel_adjoint()` and `parallel_adjoint_descriptors()` methods
- `make_adjoint_simulation()` function

The implementation includes proper fallback mechanisms when monitors are unsupported or limits are exceeded, ensuring backward compatibility.

Confidence Score: 3/5

See `tidy3d/web/api/autograd/parallel_adjoint.py` (lines 322, 327, 329) and `tidy3d/components/autograd/source_factory.py` (lines 94, 207) for floating-point comparison fixes.

Important Files Changed

- Adds `supports_parallel_adjoint()` and `parallel_adjoint_descriptors()` methods; extracted mode source creation to factory.
- Extracted `make_adjoint_simulation()` function for reuse; clean refactoring with no logic changes.

Sequence Diagram
sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant ParallelAdjoint as Parallel Adjoint
    participant Batch as Batch Executor
    participant Solver as FDTD Solver
    User->>AutogradAPI: run(sim, local_gradient=True)
    AutogradAPI->>ParallelAdjoint: prepare_parallel_adjoint(sim)
    ParallelAdjoint->>ParallelAdjoint: collect descriptors from monitors
    ParallelAdjoint->>ParallelAdjoint: filter by direction policy
    ParallelAdjoint->>ParallelAdjoint: create canonical adjoint sims
    ParallelAdjoint-->>AutogradAPI: ParallelAdjointPayload
    alt Parallel Adjoint Enabled
        AutogradAPI->>Batch: run_async({fwd, adj_1, adj_2, ...})
        Batch->>Solver: run forward sim
        Batch->>Solver: run adjoint sim 1
        Batch->>Solver: run adjoint sim 2
        Batch-->>AutogradAPI: BatchData
        AutogradAPI->>AutogradAPI: populate_parallel_adjoint_bases()
        AutogradAPI-->>User: SimulationData + aux_data
    else Parallel Adjoint Disabled
        AutogradAPI->>Solver: run forward sim only
        AutogradAPI-->>User: SimulationData
    end
    User->>AutogradAPI: backward pass (VJP)
    AutogradAPI->>ParallelAdjoint: apply_parallel_adjoint(vjp, bases)
    ParallelAdjoint->>ParallelAdjoint: compute coefficients from VJP
    ParallelAdjoint->>ParallelAdjoint: scale and accumulate basis maps
    ParallelAdjoint-->>AutogradAPI: vjp_parallel + vjp_fallback
    alt Has fallback VJPs
        AutogradAPI->>Solver: run sequential adjoint for remaining
        Solver-->>AutogradAPI: adjoint field data
        AutogradAPI->>AutogradAPI: combine vjp_parallel + sequential
    end
    AutogradAPI-->>User: gradient