diff --git a/src/cdtools/models/bragg_2d_ptycho.py b/src/cdtools/models/bragg_2d_ptycho.py index b3869a9c..507323d0 100644 --- a/src/cdtools/models/bragg_2d_ptycho.py +++ b/src/cdtools/models/bragg_2d_ptycho.py @@ -285,13 +285,14 @@ def from_dataset( # > dataset.sample_info['orientation'] > transmission geometry if surface_normal is not None: surface_normal = np.asarray(surface_normal) - elif scattering_mode.strip().lower() in {'t', 'transmission'}: - surface_normal = np.array([0.,0.,1.]) - elif scattering_mode.strip().lower() in {'r', 'reflection'}: - outgoing_dir = np.cross(det_basis[:,0], det_basis[:,1]) - outgoing_dir /= np.linalg.norm(outgoing_dir) - surface_normal = outgoing_dir + np.array([0.,0.,1.]) - surface_normal /= np.linalg.norm(outgoing_dir) + elif isinstance(scattering_mode, str): + if scattering_mode.strip().lower() in {'t', 'transmission'}: + surface_normal = np.array([0.,0.,1.]) + elif scattering_mode.strip().lower() in {'r', 'reflection'}: + outgoing_dir = np.cross(det_basis[:,0], det_basis[:,1]) + outgoing_dir /= np.linalg.norm(outgoing_dir) + surface_normal = outgoing_dir + np.array([0.,0.,1.]) + surface_normal /= np.linalg.norm(outgoing_dir) elif scattering_mode is not None: raise ValueError( 'Scattering mode must be either "transmission" ("t"), "reflection" ("r"), or the default of None.'