diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 28f6dfe..95f9ae4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.5.0
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -11,20 +11,20 @@ repos:
- id: check-merge-conflict
- id: detect-private-key
-- repo: https://github.com/psf/black
- rev: 24.1.1
+- repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 26.1.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/isort
- rev: 5.13.2
+ rev: 7.0.0
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pycqa/flake8
- rev: 7.0.0
+ rev: 7.3.0
hooks:
- id: flake8
additional_dependencies: [flake8-docstrings]
diff --git a/README.md b/README.md
index 774ed97..3349b17 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# Factory
-> **⚠️ This project is under active development.**
+> **⚠️ This project is under active development.**
A Python package for crystallographic data processing and analysis.
diff --git a/configs/config_example.yaml b/configs/config_example.yaml
index 303ebf5..3316503 100644
--- a/configs/config_example.yaml
+++ b/configs/config_example.yaml
@@ -22,7 +22,7 @@ model_settings:
weight_decay: 0.000
optimizer_eps: 1e-16
weight_decouple: true
-
+
# Scheduler configuration
# scheduler_step: 55
scheduler_gamma: 0.1
@@ -41,18 +41,18 @@ model_settings:
# scale_prior_variance: 6241.0
# scale_distribution: FoldedNormalDistributionLayer
- metadata_encoder:
+ metadata_encoder:
name: BaseMetadataEncoder
params:
depth: 10
- # metadata_encoder:
+ # metadata_encoder:
# name: SimpleMetadataEncoder
# params:
# hidden_dim: 256
# depth: 5
- profile_encoder:
+ profile_encoder:
name: BaseShoeboxEncoder
params:
# in_channels: 1
@@ -91,7 +91,7 @@ model_settings:
# params:
# concentration: 6.68
# rate: 0.001
-
+
background_prior_distribution:
# name: TorchHalfNormal
# params:
@@ -121,7 +121,3 @@ loss_settings:
phenix_settings:
r_values_reference_path: "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/refine_001.log"
-
-
-
-
diff --git a/docs/Makefile b/docs/Makefile
index 43a2023..6ff62d7 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -20,4 +20,4 @@ help:
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
watch:
- sphinx-autobuild . _build/html --open-browser --watch examples
\ No newline at end of file
+ sphinx-autobuild . _build/html --open-browser --watch examples
diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html
index dd3ed27..15ee00f 100644
--- a/docs/_templates/layout.html
+++ b/docs/_templates/layout.html
@@ -7,4 +7,4 @@
Home
API Reference
-{% endblock %}
\ No newline at end of file
+{% endblock %}
diff --git a/docs/api/index.rst b/docs/api/index.rst
index 76477f4..0cb747f 100644
--- a/docs/api/index.rst
+++ b/docs/api/index.rst
@@ -7,4 +7,4 @@ API Reference
model
distributions
-This section provides detailed API documentation for Factory's modules and functions.
\ No newline at end of file
+This section provides detailed API documentation for Factory's modules and functions.
diff --git a/docs/conf.py b/docs/conf.py
index b2a73d6..fb88eb5 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -227,4 +227,4 @@ def linkcode_resolve(domain, info):
fn = os.path.relpath(fn, start=os.path.dirname("../factory"))
- return f"https://github.com/rs-station/factory/blob/main/factory/{fn}{linespec}" # noqa
\ No newline at end of file
+ return f"https://github.com/rs-station/factory/blob/main/factory/{fn}{linespec}" # noqa
diff --git a/docs/convert_favicon.py b/docs/convert_favicon.py
index ec1aa44..c6abfcf 100644
--- a/docs/convert_favicon.py
+++ b/docs/convert_favicon.py
@@ -1,4 +1,9 @@
import cairosvg
# Convert SVG to PNG
-cairosvg.svg2png(url='images/favicon.svg', write_to='images/favicon_32x32.png', output_width=32, output_height=32)
\ No newline at end of file
+cairosvg.svg2png(
+ url="images/favicon.svg",
+ write_to="images/favicon_32x32.png",
+ output_width=32,
+ output_height=32,
+)
diff --git a/docs/images/favicon.svg b/docs/images/favicon.svg
index 123e16e..76b8aa6 100644
--- a/docs/images/favicon.svg
+++ b/docs/images/favicon.svg
@@ -2,4 +2,4 @@
\ No newline at end of file
+
diff --git a/docs/index.rst b/docs/index.rst
index bac7202..106f96b 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -16,4 +16,4 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
-* :ref:`search`
\ No newline at end of file
+* :ref:`search`
diff --git a/docs/installation.rst b/docs/installation.rst
index 90f41e9..d2c8097 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -21,4 +21,4 @@ For development installation:
git clone https://github.com/rs-station/factory.git
cd factory
- pip install -e ".[test,docs]"
\ No newline at end of file
+ pip install -e ".[test,docs]"
diff --git a/docs/make.bat b/docs/make.bat
index 541f733..922152e 100644
--- a/docs/make.bat
+++ b/docs/make.bat
@@ -32,4 +32,4 @@ goto end
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
-popd
\ No newline at end of file
+popd
diff --git a/docs/usage.rst b/docs/usage.rst
index 0f57a4e..b1e940a 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -14,4 +14,4 @@ Here's a simple example of how to use Factory:
# Create and use your model
TODO
-For more detailed examples and API reference, see the :doc:`api/index` section.
\ No newline at end of file
+For more detailed examples and API reference, see the :doc:`api/index` section.
diff --git a/pymol_figures/refine_001.pdb b/pymol_figures/refine_001.pdb
index ec8d9bb..ad2c5f0 100644
--- a/pymol_figures/refine_001.pdb
+++ b/pymol_figures/refine_001.pdb
@@ -8,24 +8,24 @@ REMARK 3 : Oeffner,Poon,Read,Richardson,Richardson,Sacchettini,
REMARK 3 : Sauter,Sobolev,Storoni,Terwilliger,Williams,Zwart
REMARK 3
REMARK 3 X-RAY DATA.
-REMARK 3
+REMARK 3
REMARK 3 REFINEMENT TARGET : ML
-REMARK 3
+REMARK 3
REMARK 3 DATA USED IN REFINEMENT.
-REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10
-REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16
-REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.33
-REMARK 3 COMPLETENESS FOR RANGE (%) : 90.85
-REMARK 3 NUMBER OF REFLECTIONS : 85983
-REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45512
-REMARK 3
+REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10
+REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16
+REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.33
+REMARK 3 COMPLETENESS FOR RANGE (%) : 90.85
+REMARK 3 NUMBER OF REFLECTIONS : 85983
+REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45512
+REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT.
REMARK 3 R VALUE (WORKING + TEST SET) : 0.1221
REMARK 3 R VALUE (WORKING SET) : 0.1203
REMARK 3 FREE R VALUE : 0.1386
-REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24
-REMARK 3 FREE R VALUE TEST SET COUNT : 8804
-REMARK 3
+REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24
+REMARK 3 FREE R VALUE TEST SET COUNT : 8804
+REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT (IN BINS).
REMARK 3 BIN RESOLUTION RANGE COMPL. NWORK NFREE RWORK RFREE CCWORK CCFREE
REMARK 3 1 56.16 - 3.41 1.00 2849 309 0.1392 0.1691 0.948 0.914
@@ -58,21 +58,21 @@ REMARK 3 27 1.15 - 1.14 0.68 1903 241 0.3115 0.3307 0.747
REMARK 3 28 1.14 - 1.12 0.54 1517 170 0.3396 0.3314 0.703 0.701
REMARK 3 29 1.12 - 1.11 0.34 960 84 0.3824 0.4505 0.607 0.613
REMARK 3 30 1.11 - 1.10 0.11 324 39 0.4270 0.4367 0.652 0.359
-REMARK 3
+REMARK 3
REMARK 3 BULK SOLVENT MODELLING.
REMARK 3 METHOD USED : FLAT BULK SOLVENT MODEL
-REMARK 3 SOLVENT RADIUS : 1.10
-REMARK 3 SHRINKAGE RADIUS : 0.90
-REMARK 3 GRID STEP : 0.60
-REMARK 3
+REMARK 3 SOLVENT RADIUS : 1.10
+REMARK 3 SHRINKAGE RADIUS : 0.90
+REMARK 3 GRID STEP : 0.60
+REMARK 3
REMARK 3 ERROR ESTIMATES.
-REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.10
-REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 12.84
-REMARK 3
+REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.10
+REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 12.84
+REMARK 3
REMARK 3 STRUCTURE FACTORS CALCULATION ALGORITHM : FFT
-REMARK 3
+REMARK 3
REMARK 3 B VALUES.
-REMARK 3 FROM WILSON PLOT (A**2) : 15.22
+REMARK 3 FROM WILSON PLOT (A**2) : 15.22
REMARK 3 Individual atomic B
REMARK 3 min max mean iso aniso
REMARK 3 Overall: 11.38 62.75 20.78 1.59 0 1284
@@ -92,7 +92,7 @@ REMARK 3 42.20 - 47.34 24
REMARK 3 47.34 - 52.48 11
REMARK 3 52.48 - 57.61 2
REMARK 3 57.61 - 62.75 3
-REMARK 3
+REMARK 3
REMARK 3 GEOMETRY RESTRAINTS LIBRARY: GEOSTD + MONOMER LIBRARY + CDL V1.2
REMARK 3 DEVIATIONS FROM IDEAL VALUES - RMSD, RMSZ FOR BONDS AND ANGLES.
REMARK 3 BOND : 0.008 0.066 1227 Z= 0.503
@@ -101,7 +101,7 @@ REMARK 3 CHIRALITY : 0.085 0.253 163
REMARK 3 PLANARITY : 0.009 0.070 228
REMARK 3 DIHEDRAL : 12.968 82.313 464
REMARK 3 MIN NONBONDED DISTANCE : 2.253
-REMARK 3
+REMARK 3
REMARK 3 MOLPROBITY STATISTICS.
REMARK 3 ALL-ATOM CLASHSCORE : 0.00
REMARK 3 RAMACHANDRAN PLOT:
@@ -118,7 +118,7 @@ REMARK 3 CIS-PROLINE : 0.00 %
REMARK 3 CIS-GENERAL : 0.00 %
REMARK 3 TWISTED PROLINE : 0.00 %
REMARK 3 TWISTED GENERAL : 0.00 %
-REMARK 3
+REMARK 3
REMARK 3 RAMA-Z (RAMACHANDRAN PLOT Z-SCORE):
REMARK 3 INTERPRETATION: BAD |RAMA-Z| > 3; SUSPICIOUS 2 < |RAMA-Z| < 3; GOOD |RAMA-Z| < 2.
REMARK 3 SCORES FOR WHOLE/HELIX/SHEET/LOOP ARE SCALED INDEPENDENTLY;
@@ -127,14 +127,14 @@ REMARK 3 WHOLE: -0.12 (0.63), RESIDUES: 173
REMARK 3 HELIX: -1.09 (0.58), RESIDUES: 68
REMARK 3 SHEET: -1.28 (0.81), RESIDUES: 14
REMARK 3 LOOP : 1.25 (0.70), RESIDUES: 91
-REMARK 3
+REMARK 3
REMARK 3 MAX DEVIATION FROM PLANES:
REMARK 3 TYPE MAXDEV MEANDEV LINEINFILE
-REMARK 3 TRP 0.041 0.009 TRP A 123
-REMARK 3 HIS 0.005 0.003 HIS A 15
-REMARK 3 PHE 0.037 0.008 PHE A 38
-REMARK 3 TYR 0.017 0.006 TYR A 20
-REMARK 3 ARG 0.009 0.002 ARG A 114
+REMARK 3 TRP 0.041 0.009 TRP A 123
+REMARK 3 HIS 0.005 0.003 HIS A 15
+REMARK 3 PHE 0.037 0.008 PHE A 38
+REMARK 3 TYR 0.017 0.006 TYR A 20
+REMARK 3 ARG 0.009 0.002 ARG A 114
REMARK 3
HELIX 1 AA1 GLY A 4 HIS A 15 1 12
HELIX 2 AA2 ASN A 19 TYR A 23 5 5
@@ -147,10 +147,10 @@ HELIX 8 AA8 ASP A 119 ARG A 125 5 7
SHEET 1 AA1 3 THR A 43 ARG A 45 0
SHEET 2 AA1 3 THR A 51 TYR A 53 -1 O ASP A 52 N ASN A 44
SHEET 3 AA1 3 ILE A 58 ASN A 59 -1 O ILE A 58 N TYR A 53
-SSBOND 1 CYS A 6 CYS A 127
-SSBOND 2 CYS A 30 CYS A 115
-SSBOND 3 CYS A 64 CYS A 80
-SSBOND 4 CYS A 76 CYS A 94
+SSBOND 1 CYS A 6 CYS A 127
+SSBOND 2 CYS A 30 CYS A 115
+SSBOND 3 CYS A 64 CYS A 80
+SSBOND 4 CYS A 76 CYS A 94
CRYST1 79.424 79.424 37.793 90.00 90.00 90.00 P 43 21 2
SCALE1 0.012591 0.000000 0.000000 0.00000
SCALE2 0.000000 0.012591 0.000000 0.00000
diff --git a/pymol_figures/refine_epoch=03_001.pdb b/pymol_figures/refine_epoch=03_001.pdb
index 6456a71..5b46a09 100644
--- a/pymol_figures/refine_epoch=03_001.pdb
+++ b/pymol_figures/refine_epoch=03_001.pdb
@@ -8,24 +8,24 @@ REMARK 3 : Oeffner,Poon,Read,Richardson,Richardson,Sacchettini,
REMARK 3 : Sauter,Sobolev,Storoni,Terwilliger,Williams,Zwart
REMARK 3
REMARK 3 X-RAY DATA.
-REMARK 3
+REMARK 3
REMARK 3 REFINEMENT TARGET : ML
-REMARK 3
+REMARK 3
REMARK 3 DATA USED IN REFINEMENT.
-REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10
-REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16
-REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32
-REMARK 3 COMPLETENESS FOR RANGE (%) : 91.77
-REMARK 3 NUMBER OF REFLECTIONS : 86205
-REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45591
-REMARK 3
+REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10
+REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16
+REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32
+REMARK 3 COMPLETENESS FOR RANGE (%) : 91.77
+REMARK 3 NUMBER OF REFLECTIONS : 86205
+REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45591
+REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT.
REMARK 3 R VALUE (WORKING + TEST SET) : 0.1313
REMARK 3 R VALUE (WORKING SET) : 0.1292
REMARK 3 FREE R VALUE : 0.1496
-REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24
-REMARK 3 FREE R VALUE TEST SET COUNT : 8826
-REMARK 3
+REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24
+REMARK 3 FREE R VALUE TEST SET COUNT : 8826
+REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT (IN BINS).
REMARK 3 BIN RESOLUTION RANGE COMPL. NWORK NFREE RWORK RFREE CCWORK CCFREE
REMARK 3 1 56.16 - 3.42 1.00 2832 303 0.1380 0.1645 0.948 0.922
@@ -58,21 +58,21 @@ REMARK 3 27 1.15 - 1.14 0.71 1992 232 0.4224 0.3941 0.547
REMARK 3 28 1.14 - 1.13 0.57 1632 177 0.4693 0.4421 0.466 0.457
REMARK 3 29 1.12 - 1.11 0.41 1145 121 0.5244 0.4445 0.352 0.592
REMARK 3 30 1.11 - 1.10 0.21 587 59 0.5427 0.6604 0.314 0.371
-REMARK 3
+REMARK 3
REMARK 3 BULK SOLVENT MODELLING.
REMARK 3 METHOD USED : FLAT BULK SOLVENT MODEL
-REMARK 3 SOLVENT RADIUS : 1.10
-REMARK 3 SHRINKAGE RADIUS : 0.90
-REMARK 3 GRID STEP : 0.60
-REMARK 3
+REMARK 3 SOLVENT RADIUS : 1.10
+REMARK 3 SHRINKAGE RADIUS : 0.90
+REMARK 3 GRID STEP : 0.60
+REMARK 3
REMARK 3 ERROR ESTIMATES.
-REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.14
-REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 17.67
-REMARK 3
+REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.14
+REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 17.67
+REMARK 3
REMARK 3 STRUCTURE FACTORS CALCULATION ALGORITHM : FFT
-REMARK 3
+REMARK 3
REMARK 3 B VALUES.
-REMARK 3 FROM WILSON PLOT (A**2) : 16.28
+REMARK 3 FROM WILSON PLOT (A**2) : 16.28
REMARK 3 Individual atomic B
REMARK 3 min max mean iso aniso
REMARK 3 Overall: 12.92 61.47 22.37 1.60 0 1284
@@ -92,7 +92,7 @@ REMARK 3 42.05 - 46.91 29
REMARK 3 46.91 - 51.76 13
REMARK 3 51.76 - 56.62 4
REMARK 3 56.62 - 61.47 3
-REMARK 3
+REMARK 3
REMARK 3 GEOMETRY RESTRAINTS LIBRARY: GEOSTD + MONOMER LIBRARY + CDL V1.2
REMARK 3 DEVIATIONS FROM IDEAL VALUES - RMSD, RMSZ FOR BONDS AND ANGLES.
REMARK 3 BOND : 0.007 0.066 1227 Z= 0.465
@@ -101,7 +101,7 @@ REMARK 3 CHIRALITY : 0.082 0.243 163
REMARK 3 PLANARITY : 0.009 0.075 228
REMARK 3 DIHEDRAL : 12.908 82.367 464
REMARK 3 MIN NONBONDED DISTANCE : 2.274
-REMARK 3
+REMARK 3
REMARK 3 MOLPROBITY STATISTICS.
REMARK 3 ALL-ATOM CLASHSCORE : 0.00
REMARK 3 RAMACHANDRAN PLOT:
@@ -118,7 +118,7 @@ REMARK 3 CIS-PROLINE : 0.00 %
REMARK 3 CIS-GENERAL : 0.00 %
REMARK 3 TWISTED PROLINE : 0.00 %
REMARK 3 TWISTED GENERAL : 0.00 %
-REMARK 3
+REMARK 3
REMARK 3 RAMA-Z (RAMACHANDRAN PLOT Z-SCORE):
REMARK 3 INTERPRETATION: BAD |RAMA-Z| > 3; SUSPICIOUS 2 < |RAMA-Z| < 3; GOOD |RAMA-Z| < 2.
REMARK 3 SCORES FOR WHOLE/HELIX/SHEET/LOOP ARE SCALED INDEPENDENTLY;
@@ -127,14 +127,14 @@ REMARK 3 WHOLE: -0.43 (0.62), RESIDUES: 173
REMARK 3 HELIX: -1.43 (0.56), RESIDUES: 68
REMARK 3 SHEET: -1.44 (0.81), RESIDUES: 14
REMARK 3 LOOP : 1.15 (0.69), RESIDUES: 91
-REMARK 3
+REMARK 3
REMARK 3 MAX DEVIATION FROM PLANES:
REMARK 3 TYPE MAXDEV MEANDEV LINEINFILE
-REMARK 3 TRP 0.034 0.008 TRP A 123
-REMARK 3 HIS 0.007 0.005 HIS A 15
-REMARK 3 PHE 0.027 0.008 PHE A 38
-REMARK 3 TYR 0.014 0.005 TYR A 20
-REMARK 3 ARG 0.007 0.002 ARG A 114
+REMARK 3 TRP 0.034 0.008 TRP A 123
+REMARK 3 HIS 0.007 0.005 HIS A 15
+REMARK 3 PHE 0.027 0.008 PHE A 38
+REMARK 3 TYR 0.014 0.005 TYR A 20
+REMARK 3 ARG 0.007 0.002 ARG A 114
REMARK 3
HELIX 1 AA1 GLY A 4 HIS A 15 1 12
HELIX 2 AA2 ASN A 19 TYR A 23 5 5
@@ -147,10 +147,10 @@ HELIX 8 AA8 ASP A 119 ARG A 125 5 7
SHEET 1 AA1 3 THR A 43 ARG A 45 0
SHEET 2 AA1 3 THR A 51 TYR A 53 -1 O ASP A 52 N ASN A 44
SHEET 3 AA1 3 ILE A 58 ASN A 59 -1 O ILE A 58 N TYR A 53
-SSBOND 1 CYS A 6 CYS A 127
-SSBOND 2 CYS A 30 CYS A 115
-SSBOND 3 CYS A 64 CYS A 80
-SSBOND 4 CYS A 76 CYS A 94
+SSBOND 1 CYS A 6 CYS A 127
+SSBOND 2 CYS A 30 CYS A 115
+SSBOND 3 CYS A 64 CYS A 80
+SSBOND 4 CYS A 76 CYS A 94
CRYST1 79.424 79.424 37.793 90.00 90.00 90.00 P 43 21 2
SCALE1 0.012591 0.000000 0.000000 0.00000
SCALE2 0.000000 0.012591 0.000000 0.00000
diff --git a/pymol_figures/refine_epoch=127_001.pdb b/pymol_figures/refine_epoch=127_001.pdb
index b18b022..7e1ba94 100644
--- a/pymol_figures/refine_epoch=127_001.pdb
+++ b/pymol_figures/refine_epoch=127_001.pdb
@@ -8,24 +8,24 @@ REMARK 3 : Oeffner,Poon,Read,Richardson,Richardson,Sacchettini,
REMARK 3 : Sauter,Sobolev,Storoni,Terwilliger,Williams,Zwart
REMARK 3
REMARK 3 X-RAY DATA.
-REMARK 3
+REMARK 3
REMARK 3 REFINEMENT TARGET : ML
-REMARK 3
+REMARK 3
REMARK 3 DATA USED IN REFINEMENT.
-REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10
-REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16
-REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32
-REMARK 3 COMPLETENESS FOR RANGE (%) : 91.74
-REMARK 3 NUMBER OF REFLECTIONS : 86181
-REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45584
-REMARK 3
+REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10
+REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16
+REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32
+REMARK 3 COMPLETENESS FOR RANGE (%) : 91.74
+REMARK 3 NUMBER OF REFLECTIONS : 86181
+REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45584
+REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT.
REMARK 3 R VALUE (WORKING + TEST SET) : 0.1330
REMARK 3 R VALUE (WORKING SET) : 0.1309
REMARK 3 FREE R VALUE : 0.1513
-REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24
-REMARK 3 FREE R VALUE TEST SET COUNT : 8823
-REMARK 3
+REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24
+REMARK 3 FREE R VALUE TEST SET COUNT : 8823
+REMARK 3
REMARK 3 FIT TO DATA USED IN REFINEMENT (IN BINS).
REMARK 3 BIN RESOLUTION RANGE COMPL. NWORK NFREE RWORK RFREE CCWORK CCFREE
REMARK 3 1 56.16 - 3.42 1.00 2832 303 0.1397 0.1737 0.945 0.911
@@ -58,21 +58,21 @@ REMARK 3 27 1.15 - 1.14 0.71 1991 232 0.4554 0.4248 0.516
REMARK 3 28 1.14 - 1.13 0.57 1631 176 0.4802 0.4328 0.408 0.473
REMARK 3 29 1.12 - 1.11 0.41 1141 121 0.4715 0.4497 0.315 0.514
REMARK 3 30 1.11 - 1.10 0.21 587 59 0.4710 0.4957 0.177 0.424
-REMARK 3
+REMARK 3
REMARK 3 BULK SOLVENT MODELLING.
REMARK 3 METHOD USED : FLAT BULK SOLVENT MODEL
-REMARK 3 SOLVENT RADIUS : 1.10
-REMARK 3 SHRINKAGE RADIUS : 0.90
-REMARK 3 GRID STEP : 0.60
-REMARK 3
+REMARK 3 SOLVENT RADIUS : 1.10
+REMARK 3 SHRINKAGE RADIUS : 0.90
+REMARK 3 GRID STEP : 0.60
+REMARK 3
REMARK 3 ERROR ESTIMATES.
-REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.16
-REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 20.50
-REMARK 3
+REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.16
+REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 20.50
+REMARK 3
REMARK 3 STRUCTURE FACTORS CALCULATION ALGORITHM : FFT
-REMARK 3
+REMARK 3
REMARK 3 B VALUES.
-REMARK 3 FROM WILSON PLOT (A**2) : 19.58
+REMARK 3 FROM WILSON PLOT (A**2) : 19.58
REMARK 3 Individual atomic B
REMARK 3 min max mean iso aniso
REMARK 3 Overall: 16.05 71.49 25.98 1.68 0 1284
@@ -92,7 +92,7 @@ REMARK 3 49.31 - 54.86 23
REMARK 3 54.86 - 60.40 10
REMARK 3 60.40 - 65.95 2
REMARK 3 65.95 - 71.49 2
-REMARK 3
+REMARK 3
REMARK 3 GEOMETRY RESTRAINTS LIBRARY: GEOSTD + MONOMER LIBRARY + CDL V1.2
REMARK 3 DEVIATIONS FROM IDEAL VALUES - RMSD, RMSZ FOR BONDS AND ANGLES.
REMARK 3 BOND : 0.007 0.063 1227 Z= 0.483
@@ -101,7 +101,7 @@ REMARK 3 CHIRALITY : 0.084 0.249 163
REMARK 3 PLANARITY : 0.009 0.069 228
REMARK 3 DIHEDRAL : 12.938 82.333 464
REMARK 3 MIN NONBONDED DISTANCE : 2.245
-REMARK 3
+REMARK 3
REMARK 3 MOLPROBITY STATISTICS.
REMARK 3 ALL-ATOM CLASHSCORE : 0.00
REMARK 3 RAMACHANDRAN PLOT:
@@ -118,7 +118,7 @@ REMARK 3 CIS-PROLINE : 0.00 %
REMARK 3 CIS-GENERAL : 0.00 %
REMARK 3 TWISTED PROLINE : 0.00 %
REMARK 3 TWISTED GENERAL : 0.00 %
-REMARK 3
+REMARK 3
REMARK 3 RAMA-Z (RAMACHANDRAN PLOT Z-SCORE):
REMARK 3 INTERPRETATION: BAD |RAMA-Z| > 3; SUSPICIOUS 2 < |RAMA-Z| < 3; GOOD |RAMA-Z| < 2.
REMARK 3 SCORES FOR WHOLE/HELIX/SHEET/LOOP ARE SCALED INDEPENDENTLY;
@@ -127,14 +127,14 @@ REMARK 3 WHOLE: -0.33 (0.63), RESIDUES: 173
REMARK 3 HELIX: -1.30 (0.57), RESIDUES: 68
REMARK 3 SHEET: -1.46 (0.79), RESIDUES: 14
REMARK 3 LOOP : 1.18 (0.70), RESIDUES: 91
-REMARK 3
+REMARK 3
REMARK 3 MAX DEVIATION FROM PLANES:
REMARK 3 TYPE MAXDEV MEANDEV LINEINFILE
-REMARK 3 TRP 0.032 0.009 TRP A 123
-REMARK 3 HIS 0.007 0.005 HIS A 15
-REMARK 3 PHE 0.032 0.008 PHE A 38
-REMARK 3 TYR 0.014 0.006 TYR A 20
-REMARK 3 ARG 0.010 0.002 ARG A 114
+REMARK 3 TRP 0.032 0.009 TRP A 123
+REMARK 3 HIS 0.007 0.005 HIS A 15
+REMARK 3 PHE 0.032 0.008 PHE A 38
+REMARK 3 TYR 0.014 0.006 TYR A 20
+REMARK 3 ARG 0.010 0.002 ARG A 114
REMARK 3
HELIX 1 AA1 GLY A 4 HIS A 15 1 12
HELIX 2 AA2 ASN A 19 TYR A 23 5 5
@@ -147,10 +147,10 @@ HELIX 8 AA8 ASP A 119 ARG A 125 5 7
SHEET 1 AA1 3 THR A 43 ARG A 45 0
SHEET 2 AA1 3 THR A 51 TYR A 53 -1 O ASP A 52 N ASN A 44
SHEET 3 AA1 3 ILE A 58 ASN A 59 -1 O ILE A 58 N TYR A 53
-SSBOND 1 CYS A 6 CYS A 127
-SSBOND 2 CYS A 30 CYS A 115
-SSBOND 3 CYS A 64 CYS A 80
-SSBOND 4 CYS A 76 CYS A 94
+SSBOND 1 CYS A 6 CYS A 127
+SSBOND 2 CYS A 30 CYS A 115
+SSBOND 3 CYS A 64 CYS A 80
+SSBOND 4 CYS A 76 CYS A 94
CRYST1 79.424 79.424 37.793 90.00 90.00 90.00 P 43 21 2
SCALE1 0.012591 0.000000 0.000000 0.00000
SCALE2 0.000000 0.012591 0.000000 0.00000
diff --git a/pyproject.toml b/pyproject.toml
index 6e9fb60..fc06aa1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,4 +35,4 @@ docs = [
source = "git"
[tool.hatch.build.targets.wheel]
-packages = ["src/factory"]
\ No newline at end of file
+packages = ["src/factory"]
diff --git a/setup.py b/setup.py
index 2fa3700..90a1d00 100644
--- a/setup.py
+++ b/setup.py
@@ -1,15 +1,12 @@
import sys
-import sys
-sys.stderr.write(
- """
+sys.stderr.write("""
===============================
Unsupported installation method
===============================
factory does not support installation with `python setup.py install`.
Please use `python -m pip install .` instead.
-"""
-)
+""")
sys.exit(1)
@@ -22,4 +19,4 @@
setup( # noqa
name="factory",
install_requires=[],
-)
\ No newline at end of file
+)
diff --git a/src/factory/CC12.py b/src/factory/CC12.py
index a28340d..c349381 100644
--- a/src/factory/CC12.py
+++ b/src/factory/CC12.py
@@ -1,15 +1,18 @@
-import torch
-import wandb
import os
-import numpy as np
+
import matplotlib.pyplot as plt
-from lightning.pytorch.loggers import WandbLogger
+import numpy as np
+import torch
+import wandb
from lightning.pytorch import Trainer
+from lightning.pytorch.loggers import WandbLogger
from model import *
wandb.init(project="CC12")
wandb_logger = WandbLogger(project="CC12", name="CC12-Training")
-artifact = wandb.use_artifact("flaviagiehr-harvard-university/full-model/best_model:latest", type="model")
+artifact = wandb.use_artifact(
+ "flaviagiehr-harvard-university/full-model/best_model:latest", type="model"
+)
artifact_dir = artifact.download()
ckpt_files = [f for f in os.listdir(artifact_dir) if f.endswith(".ckpt")]
@@ -17,27 +20,31 @@
checkpoint_path = os.path.join(artifact_dir, ckpt_files[0])
checkpoint = torch.load(checkpoint_path, map_location="cuda")
-state_dict = checkpoint['state_dict']
+state_dict = checkpoint["state_dict"]
filtered_state_dict = {
- k: v for k, v in state_dict.items()
- if not k.startswith("surrogate_posterior.")
+ k: v for k, v in state_dict.items() if not k.startswith("surrogate_posterior.")
}
settings = Settings()
loss_settings = LossSettings()
-data_loader_settings = data_loader.DataLoaderSettings(data_directory=settings.data_directory,
- data_file_names=settings.data_file_names,
- validation_set_split=0.5, test_set_split=0
- )
+data_loader_settings = data_loader.DataLoaderSettings(
+ data_directory=settings.data_directory,
+ data_file_names=settings.data_file_names,
+ validation_set_split=0.5,
+ test_set_split=0,
+)
_dataloader = data_loader.CrystallographicDataLoader(settings=data_loader_settings)
_dataloader.load_data_()
+
def _train_model(model_no, dataloader, settings, loss_settings):
model = Model(settings=settings, loss_settings=loss_settings, dataloader=dataloader)
print("model type:", type(model))
- missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)
+ missing_keys, unexpected_keys = model.load_state_dict(
+ filtered_state_dict, strict=False
+ )
print("Missing keys (expected for surrogate posterior):", missing_keys)
print("Unexpected keys (should be none):", unexpected_keys)
@@ -47,27 +54,34 @@ def _train_model(model_no, dataloader, settings, loss_settings):
print(list(state_dict.keys())[:10])
trainer = Trainer(
- logger=wandb_logger,
- max_epochs=20,
- # log_every_n_steps=5,
+ logger=wandb_logger,
+ max_epochs=20,
+ # log_every_n_steps=5,
val_check_interval=None,
- accelerator="auto",
- enable_checkpointing=False,
+ accelerator="auto",
+ enable_checkpointing=False,
default_root_dir="/tmp",
# callbacks=[Plotting(), LossLogging(), CorrelationPlotting(), checkpoint_callback] # ScalePlotting()
)
if model_no == 1:
- trainer.fit(model, train_dataloaders=_dataloader.load_data_set_batched_by_image(
- data_set_to_load=_dataloader.train_data_set
- ))
+ trainer.fit(
+ model,
+ train_dataloaders=_dataloader.load_data_set_batched_by_image(
+ data_set_to_load=_dataloader.train_data_set
+ ),
+ )
else:
- trainer.fit(model, train_dataloaders=_dataloader.load_data_set_batched_by_image(
- data_set_to_load=_dataloader.validation_data_set
- ))
+ trainer.fit(
+ model,
+ train_dataloaders=_dataloader.load_data_set_batched_by_image(
+ data_set_to_load=_dataloader.validation_data_set
+ ),
+ )
print("finished training for model", model_no)
return model
+
model1 = _train_model(1, _dataloader, settings, loss_settings)
model2 = _train_model(2, _dataloader, settings, loss_settings)
@@ -89,13 +103,15 @@ def _train_model(model_no, dataloader, settings, loss_settings):
# outputs_1.append(out1.cpu())
# outputs_2.append(out2.cpu())
+
def get_reliable_mask(model):
- if hasattr(model.surrogate_posterior, 'reliable_observations_mask'):
+ if hasattr(model.surrogate_posterior, "reliable_observations_mask"):
return model.surrogate_posterior.reliable_observations_mask(min_observations=7)
else:
return model.surrogate_posterior.observed
+
reliable_mask_1 = get_reliable_mask(model1)
reliable_mask_2 = get_reliable_mask(model2)
@@ -103,27 +119,38 @@ def get_reliable_mask(model):
# Plot histogram of observation counts
fig_hist, ax_hist = plt.subplots(figsize=(10, 6))
-observation_counts_1 = model1.surrogate_posterior.observation_count.detach().cpu().numpy()
-observation_counts_2 = model2.surrogate_posterior.observation_count.detach().cpu().numpy()
-
-ax_hist.hist([observation_counts_1, observation_counts_2],
- bins=30,
- alpha=0.5,
- label=['Model 1', 'Model 2'])
-ax_hist.set_xlabel('Number of Observations')
-ax_hist.set_ylabel('Frequency')
-ax_hist.set_title('Distribution of Observation Counts')
+observation_counts_1 = (
+ model1.surrogate_posterior.observation_count.detach().cpu().numpy()
+)
+observation_counts_2 = (
+ model2.surrogate_posterior.observation_count.detach().cpu().numpy()
+)
+
+ax_hist.hist(
+ [observation_counts_1, observation_counts_2],
+ bins=30,
+ alpha=0.5,
+ label=["Model 1", "Model 2"],
+)
+ax_hist.set_xlabel("Number of Observations")
+ax_hist.set_ylabel("Frequency")
+ax_hist.set_title("Distribution of Observation Counts")
ax_hist.legend()
wandb.log({"Observation Counts Histogram": wandb.Image(fig_hist)})
plt.close(fig_hist)
-structure_factors_1 = torch.cat([model1.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy()
-structure_factors_2 = torch.cat([model2.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy()
+structure_factors_1 = (
+ torch.cat([model1.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy()
+)
+structure_factors_2 = (
+ torch.cat([model2.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy()
+)
structure_factors_1 = structure_factors_1[common_reliable_mask.cpu().numpy()]
structure_factors_2 = structure_factors_2[common_reliable_mask.cpu().numpy()]
+
def compute_cc12(a, b):
# a = a.detach().cpu().numpy()
# b = b.detach().cpu().numpy()
@@ -132,17 +159,16 @@ def compute_cc12(a, b):
mean_b = np.mean(b)
numerator = np.sum((a - mean_a) * (b - mean_b))
- denominator = np.sqrt(np.sum((a - mean_a)**2) * np.sum((b - mean_b)**2))
+ denominator = np.sqrt(np.sum((a - mean_a) ** 2) * np.sum((b - mean_b) ** 2))
return numerator / (denominator + 1e-8)
+
cc12 = compute_cc12(structure_factors_1, structure_factors_2)
print("CC12:", cc12)
fig, ax = plt.subplots(figsize=(10, 6))
-ax.scatter(structure_factors_1,
- structure_factors_2,
- alpha=0.5)
+ax.scatter(structure_factors_1, structure_factors_2, alpha=0.5)
ax.set_xlabel("Model 1 Structure Factors")
ax.set_ylabel("Model 2 Structure Factors")
ax.set_title(f"CC12 = {cc12:.3f}")
diff --git a/src/factory/anomalous_peaks.py b/src/factory/anomalous_peaks.py
index d2f77f1..e1ba0b1 100644
--- a/src/factory/anomalous_peaks.py
+++ b/src/factory/anomalous_peaks.py
@@ -1,26 +1,29 @@
-import torch
-import pandas as pd
-import numpy as np
-import subprocess
import os
-import matplotlib.pyplot as plt
+import subprocess
-import wandb
+import data_loader
+import get_protein_data
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
import pytorch_lightning as L
-
import reciprocalspaceship as rs
import rs_distributions as rsd
-
+import torch
+import wandb
from model import *
-import data_loader
-import get_protein_data
-
-print(rs.read_mtz("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz"))
+print(
+ rs.read_mtz(
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz"
+ )
+)
wandb.init(project="anomalous peaks", name="anomalous-peaks")
-artifact = wandb.use_artifact("flaviagiehr-harvard-university/full-model/best_model:latest", type="model")
+artifact = wandb.use_artifact(
+ "flaviagiehr-harvard-university/full-model/best_model:latest", type="model"
+)
artifact_dir = artifact.download()
ckpt_files = [f for f in os.listdir(artifact_dir) if f.endswith(".ckpt")]
@@ -30,21 +33,27 @@
settings = Settings()
loss_settings = LossSettings()
-model = Model.load_from_checkpoint(checkpoint_path, settings=settings, loss_settings=loss_settings)#, dataloader=dataloader)
+model = Model.load_from_checkpoint(
+ checkpoint_path, settings=settings, loss_settings=loss_settings
+) # , dataloader=dataloader)
model.eval()
repo_dir = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files"
mtz_output_path = os.path.join(repo_dir, "phenix_ready.mtz")
-surrogate_posterior_dataset = list(model.surrogate_posterior.to_dataset(only_observed=True))[0]
+surrogate_posterior_dataset = list(
+ model.surrogate_posterior.to_dataset(only_observed=True)
+)[0]
# Additional validation checks
if len(surrogate_posterior_dataset) == 0:
raise ValueError("Dataset has no reflections!")
# Check if we have the required columns
-required_columns = ['F(+)', 'F(-)', 'SIGF(+)', 'SIGF(-)']
-missing_columns = [col for col in required_columns if col not in surrogate_posterior_dataset.columns]
+required_columns = ["F(+)", "F(-)", "SIGF(+)", "SIGF(-)"]
+missing_columns = [
+ col for col in required_columns if col not in surrogate_posterior_dataset.columns
+]
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
@@ -56,13 +65,16 @@
print(f"Warning: Column {col} contains infinite values")
# Ensure d_min is set
-if not hasattr(surrogate_posterior_dataset, 'd_min') or surrogate_posterior_dataset.d_min is None:
+if (
+ not hasattr(surrogate_posterior_dataset, "d_min")
+ or surrogate_posterior_dataset.d_min is None
+):
print("Setting d_min to 2.5Å")
surrogate_posterior_dataset.d_min = 2.5
# Try to get wavelength from the dataset
wavelength = None
-if hasattr(surrogate_posterior_dataset, 'wavelength'):
+if hasattr(surrogate_posterior_dataset, "wavelength"):
wavelength = surrogate_posterior_dataset.wavelength
print(f"\nWavelength from dataset: {wavelength}Å")
else:
@@ -88,7 +100,9 @@
# mtz_output_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz" # for the reference peaks
-phenix_refine_eff_file_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/phenix.eff"
+phenix_refine_eff_file_path = (
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/phenix.eff"
+)
cmd = f"source {phenix_env} && cd {repo_dir} && phenix.refine {phenix_refine_eff_file_path} {mtz_output_path} overwrite=True"
@@ -144,33 +158,36 @@
# from rs_distributions.peak_finding import peak_finder
subprocess.run(
-"rs.find_peaks *[0-9].mtz *[0-9].pdb " f"-f ANOM -p PANOM -z 5.0 -o peaks.csv",
-shell=True,
-cwd=f"{repo_dir}",
+ "rs.find_peaks *[0-9].mtz *[0-9].pdb " f"-f ANOM -p PANOM -z 5.0 -o peaks.csv",
+ shell=True,
+ cwd=f"{repo_dir}",
)
# Save peaks
-print ("saved peaks")
+print("saved peaks")
-peaks_df = pd.read_csv("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv")
-print(peaks_df.columns)
+peaks_df = pd.read_csv(
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv"
+)
+print(peaks_df.columns)
# print(peaks_df.iloc[0]) #
print(peaks_df)
# Create a nice table of peak heights
-peak_heights_table = pd.DataFrame({
- 'Residue': peaks_df['residue'],
- 'Peak_Height': peaks_df['peakz']
-})
+peak_heights_table = pd.DataFrame(
+ {"Residue": peaks_df["residue"], "Peak_Height": peaks_df["peakz"]}
+)
# Sort by peak height in descending order for better visualization
-peak_heights_table = peak_heights_table.sort_values('Peak_Height', ascending=False).reset_index(drop=True)
+peak_heights_table = peak_heights_table.sort_values(
+ "Peak_Height", ascending=False
+).reset_index(drop=True)
# Add rank column for easier reference
-peak_heights_table.insert(0, 'Rank', range(1, len(peak_heights_table) + 1))
+peak_heights_table.insert(0, "Rank", range(1, len(peak_heights_table) + 1))
# Round peak heights to 3 decimal places for cleaner display
-peak_heights_table['Peak_Height'] = peak_heights_table['Peak_Height'].round(3)
+peak_heights_table["Peak_Height"] = peak_heights_table["Peak_Height"].round(3)
print("\n=== Anomalous Peak Heights ===")
print(peak_heights_table.to_string(index=False))
@@ -183,6 +200,5 @@
peaks_artifact = wandb.Artifact("anomalous_peaks", type="peaks")
-peaks_artifact.add_file(os.path.join(repo_dir,"peaks.csv"))
+peaks_artifact.add_file(os.path.join(repo_dir, "peaks.csv"))
wandb.log_artifact(peaks_artifact)
-
diff --git a/src/factory/callbacks.py b/src/factory/callbacks.py
index cc88fac..2f2443f 100644
--- a/src/factory/callbacks.py
+++ b/src/factory/callbacks.py
@@ -1,19 +1,20 @@
-import torch
-import reciprocalspaceship as rs
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.callbacks import Callback
-import wandb
-from phenix_callback import find_peaks_from_model
+import io
-from torchvision.transforms import ToPILImage
-from lightning.pytorch.utilities import grad_norm
+import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
-import io
+import pandas as pd
+import reciprocalspaceship as rs
+import torch
+import wandb
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.utilities import grad_norm
+from phenix_callback import find_peaks_from_model
from PIL import Image
from scipy.optimize import minimize
-import matplotlib.cm as cm
-import pandas as pd
+from torchvision.transforms import ToPILImage
+
class PeakHeights(Callback):
"Compute Anomalous Peak Heights during training."
@@ -21,8 +22,8 @@ class PeakHeights(Callback):
def __init__(self):
super().__init__()
self.residue_colors = {} # Maps residue to color
- self.peak_history = {} # Maps residue to list of (epoch, peak_height)
- self.colormap = cm.get_cmap('tab20')
+ self.peak_history = {} # Maps residue to list of (epoch, peak_height)
+ self.colormap = cm.get_cmap("tab20")
# self.peakheights_log = open("peakheights.out", "a")
# def __del__(self):
@@ -37,10 +38,14 @@ def _get_color(self, residue):
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
- if epoch%2==0:
+ if epoch % 2 == 0:
try:
- r_work, r_free, path_to_peaks = find_peaks_from_model(model=pl_module, repo_dir=trainer.logger.experiment.dir)
- pl_module.logger.experiment.log({"r/r_work": r_work, "r/r_free": r_free})
+ r_work, r_free, path_to_peaks = find_peaks_from_model(
+ model=pl_module, repo_dir=trainer.logger.experiment.dir
+ )
+ pl_module.logger.experiment.log(
+ {"r/r_work": r_work, "r/r_free": r_free}
+ )
peaks_df = pd.read_csv(path_to_peaks)
epoch = trainer.current_epoch
@@ -48,9 +53,9 @@ def on_train_epoch_end(self, trainer, pl_module):
# Update peak history
print("organize peaks")
for _, row in peaks_df.iterrows():
- residue = row['residue']
+ residue = row["residue"]
print(residue)
- peak_height = row['peakz']
+ peak_height = row["peakz"]
print(peak_height)
if residue not in self.peak_history:
self.peak_history[residue] = []
@@ -64,28 +69,27 @@ def on_train_epoch_end(self, trainer, pl_module):
epochs, heights = zip(*history)
color = self._get_color(residue)
plt.plot(epochs, heights, label=str(residue), color=color)
- plt.xlabel('Epoch')
- plt.ylabel('Peak Height')
- plt.title('Peak Heights per Residue')
- plt.legend(title='Residue', bbox_to_anchor=(1.05, 1), loc='upper left')
+ plt.xlabel("Epoch")
+ plt.ylabel("Peak Height")
+ plt.title("Peak Heights per Residue")
+ plt.legend(title="Residue", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
# Log to wandb
- pl_module.logger.experiment.log({
- "PeakHeights": wandb.Image(plt.gcf()),
- "epoch": epoch
- })
+ pl_module.logger.experiment.log(
+ {"PeakHeights": wandb.Image(plt.gcf()), "epoch": epoch}
+ )
plt.close()
except Exception as e:
- print(f"Failed for {epoch}: {e}")#, file=self.peakheights_log, flush=True)
+ print(
+ f"Failed for {epoch}: {e}"
+ ) # , file=self.peakheights_log, flush=True)
else:
return
-
-
class LossLogging(Callback):
-
+
def __init__(self):
super().__init__()
self.loss_per_epoch: float = 0
@@ -96,48 +100,76 @@ def on_train_epoch_end(self, trainer, pl_module):
pl_module.log("train/loss_epoch", avg, on_step=False, on_epoch=True)
def on_validation_epoch_end(self, trainer, pl_module):
-
+
train_loss = trainer.callback_metrics.get("train/loss_step_epoch")
val_loss = trainer.callback_metrics.get("validation/loss_step_epoch")
if train_loss is not None and val_loss is not None:
difference_train_val_loss = train_loss - val_loss
- pl_module.log("loss/train-val", difference_train_val_loss, on_step=False, on_epoch=True)
+ pl_module.log(
+ "loss/train-val",
+ difference_train_val_loss,
+ on_step=False,
+ on_epoch=True,
+ )
else:
- print("Skipping loss/train-val logging: train or validation loss not available this epoch.")
-
+ print(
+ "Skipping loss/train-val logging: train or validation loss not available this epoch."
+ )
+
+
class Plotting(Callback):
def __init__(self, dataloader):
super().__init__()
self.fixed_batch = None
self.dataloader = dataloader
-
+
def on_fit_start(self, trainer, pl_module):
- _raw_batch = next(iter(self.dataloader.load_data_for_logging_during_training(number_of_shoeboxes_to_log=8)))
+ _raw_batch = next(
+ iter(
+ self.dataloader.load_data_for_logging_during_training(
+ number_of_shoeboxes_to_log=8
+ )
+ )
+ )
_raw_batch = [_batch_item.to(pl_module.device) for _batch_item in _raw_batch]
self.fixed_batch = tuple(_raw_batch)
-
- self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to(pl_module.device)[:, 9]
- self.fixed_batch_intensity_sum_variance = self.fixed_batch[1].to(pl_module.device)[:, 10]
- self.fixed_batch_total_counts = self.fixed_batch[3].to(pl_module.device).sum(dim=1)
- self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[:, 13]
+ self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to(
+ pl_module.device
+ )[:, 9]
+ self.fixed_batch_intensity_sum_variance = self.fixed_batch[1].to(
+ pl_module.device
+ )[:, 10]
+ self.fixed_batch_total_counts = (
+ self.fixed_batch[3].to(pl_module.device).sum(dim=1)
+ )
+
+ self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[
+ :, 13
+ ]
def on_train_epoch_end(self, trainer, pl_module):
with torch.no_grad():
- shoeboxes_batch, photon_rate_output, hkl_batch, counts_batch = pl_module(self.fixed_batch, verbose_output=True)
+ shoeboxes_batch, photon_rate_output, hkl_batch, counts_batch = pl_module(
+ self.fixed_batch, verbose_output=True
+ )
# Extract values from dictionary
- samples_photon_rate = photon_rate_output['photon_rate']
- samples_profile = photon_rate_output['samples_profile']
- samples_background = photon_rate_output['samples_background']
- samples_scale = photon_rate_output['samples_scale']
- samples_predicted_structure_factor = photon_rate_output["samples_predicted_structure_factor"]
-
+ samples_photon_rate = photon_rate_output["photon_rate"]
+ samples_profile = photon_rate_output["samples_profile"]
+ samples_background = photon_rate_output["samples_background"]
+ samples_scale = photon_rate_output["samples_scale"]
+ samples_predicted_structure_factor = photon_rate_output[
+ "samples_predicted_structure_factor"
+ ]
+
del photon_rate_output
-
+
batch_size = shoeboxes_batch.size(0)
- model_intensity = samples_scale * torch.square(samples_predicted_structure_factor)
+ model_intensity = samples_scale * torch.square(
+ samples_predicted_structure_factor
+ )
model_intensity = torch.mean(model_intensity, dim=1)
print("model_intensity", model_intensity.shape)
@@ -146,92 +178,133 @@ def on_train_epoch_end(self, trainer, pl_module):
print("photon rate", photon_rate.shape)
print("profile", samples_profile.shape)
print("background", samples_background.shape)
- print("scale", samples_scale.shape)
+ print("scale", samples_scale.shape)
fig, axes = plt.subplots(
- 2, batch_size,
- figsize=(4*batch_size, 16),
- gridspec_kw={'hspace': 0.05, 'wspace': 0.3}
+ 2,
+ batch_size,
+ figsize=(4 * batch_size, 16),
+ gridspec_kw={"hspace": 0.05, "wspace": 0.3},
)
-
+
im_handles = []
for b in range(batch_size):
- if len(shoeboxes_batch.shape)>2:
- im1 = axes[0,b].imshow(shoeboxes_batch[b,:, -1].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy())
- im2 = axes[1,b].imshow(samples_profile[b,0,:].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy())
+ if len(shoeboxes_batch.shape) > 2:
+ im1 = axes[0, b].imshow(
+ shoeboxes_batch[b, :, -1]
+ .reshape(3, 21, 21)[1:2]
+ .squeeze()
+ .cpu()
+ .detach()
+ .numpy()
+ )
+ im2 = axes[1, b].imshow(
+ samples_profile[b, 0, :]
+ .reshape(3, 21, 21)[1:2]
+ .squeeze()
+ .cpu()
+ .detach()
+ .numpy()
+ )
else:
- im1 = axes[0,b].imshow(shoeboxes_batch[b,:].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy())
- im2 = axes[1,b].imshow(samples_profile[b,0,:].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy())
+ im1 = axes[0, b].imshow(
+ shoeboxes_batch[b, :]
+ .reshape(3, 21, 21)[1:2]
+ .squeeze()
+ .cpu()
+ .detach()
+ .numpy()
+ )
+ im2 = axes[1, b].imshow(
+ samples_profile[b, 0, :]
+ .reshape(3, 21, 21)[1:2]
+ .squeeze()
+ .cpu()
+ .detach()
+ .numpy()
+ )
im_handles.append(im1)
# Build title string with each element on its own row (3 rows)
title_lines = [
- f'Raw Shoebox {b}',
- f'I = {self.fixed_batch_intensity_sum_values[b]:.2f} pm {self.fixed_batch_intensity_sum_variance[b]:.2f}',
- f'Bkg (mean) = {self.fixed_batch_background_mean[b]:.2f}',
- f'total counts = {self.fixed_batch_total_counts[b]:.2f}'
+ f"Raw Shoebox {b}",
+ f"I = {self.fixed_batch_intensity_sum_values[b]:.2f} pm {self.fixed_batch_intensity_sum_variance[b]:.2f}",
+ f"Bkg (mean) = {self.fixed_batch_background_mean[b]:.2f}",
+ f"total counts = {self.fixed_batch_total_counts[b]:.2f}",
]
- axes[0,b].set_title('\n'.join(title_lines), pad=5, fontsize=16)
- axes[0,b].axis('off')
-
+ axes[0, b].set_title("\n".join(title_lines), pad=5, fontsize=16)
+ axes[0, b].axis("off")
+
title_prf_lines = [
- f'Profile {b}',
- f'photon_rate = {photon_rate[b].sum():.2f}',
- f'I(F) = {model_intensity[b].item():.2f}',
- f'Scale = {torch.mean(samples_scale, dim=1)[b].item():.2f}',
- f'F= {torch.mean(samples_predicted_structure_factor, dim=1)[b].item():.2f}',
- f'Bkg (mean) = {torch.mean(samples_background, dim=1)[b].item():.2f}'
+ f"Profile {b}",
+ f"photon_rate = {photon_rate[b].sum():.2f}",
+ f"I(F) = {model_intensity[b].item():.2f}",
+ f"Scale = {torch.mean(samples_scale, dim=1)[b].item():.2f}",
+ f"F= {torch.mean(samples_predicted_structure_factor, dim=1)[b].item():.2f}",
+ f"Bkg (mean) = {torch.mean(samples_background, dim=1)[b].item():.2f}",
]
- axes[1,b].set_title('\n'.join(title_prf_lines), pad=5, fontsize=16)
- axes[1,b].axis('off')
-
+ axes[1, b].set_title("\n".join(title_prf_lines), pad=5, fontsize=16)
+ axes[1, b].axis("off")
+
# Add individual colorbars for each column
# Colorbar for the raw shoebox (top row)
- cbar1 = fig.colorbar(im1, ax=axes[0,b], orientation='vertical', fraction=0.02, pad=0.04)
+ cbar1 = fig.colorbar(
+ im1, ax=axes[0, b], orientation="vertical", fraction=0.02, pad=0.04
+ )
cbar1.ax.tick_params(labelsize=8)
-
+
# Colorbar for the profile (bottom row)
- cbar2 = fig.colorbar(im2, ax=axes[1,b], orientation='vertical', fraction=0.02, pad=0.04)
+ cbar2 = fig.colorbar(
+ im2, ax=axes[1, b], orientation="vertical", fraction=0.02, pad=0.04
+ )
cbar2.ax.tick_params(labelsize=8)
-
-
-
+
plt.tight_layout(h_pad=0.0, w_pad=0.3) # Remove extra vertical padding
buf = io.BytesIO()
- plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
+ plt.savefig(buf, format="png", dpi=100, bbox_inches="tight")
buf.seek(0)
-
+
# Convert buffer to PIL Image
pil_image = Image.open(buf)
-
- pl_module.logger.experiment.log({
- "Profiles": wandb.Image(pil_image),
- "epoch": trainer.current_epoch
- })
-
+
+ pl_module.logger.experiment.log(
+ {"Profiles": wandb.Image(pil_image), "epoch": trainer.current_epoch}
+ )
+
plt.close()
buf.close()
-
+
# Memory optimization: Clear large tensors after use
- del samples_photon_rate, samples_profile, samples_background, samples_scale, samples_predicted_structure_factor
+ del (
+ samples_photon_rate,
+ samples_profile,
+ samples_background,
+ samples_scale,
+ samples_predicted_structure_factor,
+ )
torch.cuda.empty_cache() # Force GPU memory cleanup
+
class CorrelationPlottingBinned(Callback):
-
+
def __init__(self, dataloader):
super().__init__()
self.dials_reference = None
self.fixed_batch = None
# self.fixed_hkl_indices_for_surrogate_posterior = None
self.dataloader = dataloader
-
+
def on_fit_start(self, trainer, pl_module):
print("on fit start callback")
- self.dials_reference = rs.read_mtz(pl_module.model_settings.merged_mtz_file_path)
- self.dials_reference, self.resolution_bins = self.dials_reference.assign_resolution_bins(bins=10)
+ self.dials_reference = rs.read_mtz(
+ pl_module.model_settings.merged_mtz_file_path
+ )
+ self.dials_reference, self.resolution_bins = (
+ self.dials_reference.assign_resolution_bins(bins=10)
+ )
print("Read mtz successfully: ", self.dials_reference["F"].shape)
print("head hkl", self.dials_reference.get_hkls().shape)
@@ -244,42 +317,61 @@ def on_train_epoch_end(self, trainer, pl_module):
print("on train epoch end in callback")
try:
# shoeboxes_batch, photon_rate_output, hkl_batch = pl_module(self.fixed_batch, log_images=True)
-
samples_surrogate_posterior = pl_module.surrogate_posterior.mean
- print("full samples", samples_surrogate_posterior.shape)
+ print("full samples", samples_surrogate_posterior.shape)
# print("gathered ", photon_rate_output['samples_predicted_structure_factor'].shape)
rasu_ids = pl_module.surrogate_posterior.rac.rasu_ids[0]
- print("rasu id (all teh same as the first one?):",pl_module.surrogate_posterior.rac.rasu_ids)
-
- if hasattr(pl_module.surrogate_posterior, 'reliable_observations_mask'):
- reliable_mask = pl_module.surrogate_posterior.reliable_observations_mask(min_observations=50)
+ print(
+ "rasu id (all teh same as the first one?):",
+ pl_module.surrogate_posterior.rac.rasu_ids,
+ )
+
+ if hasattr(pl_module.surrogate_posterior, "reliable_observations_mask"):
+ reliable_mask = (
+ pl_module.surrogate_posterior.reliable_observations_mask(
+ min_observations=50
+ )
+ )
print("take reliable mask for correlation plotting")
-
+
else:
reliable_mask = pl_module.surrogate_posterior.observed
print("observed attribute", pl_module.surrogate_posterior.observed.sum())
- mask_observed_indexed_as_dials_reference = pl_module.surrogate_posterior.rac.gather(
- source=pl_module.surrogate_posterior.observed & reliable_mask, rasu_id=rasu_ids, H=self.dials_reference.get_hkls()
+ mask_observed_indexed_as_dials_reference = (
+ pl_module.surrogate_posterior.rac.gather(
+ source=pl_module.surrogate_posterior.observed & reliable_mask,
+ rasu_id=rasu_ids,
+ H=self.dials_reference.get_hkls(),
+ )
)
ordered_samples = pl_module.surrogate_posterior.rac.gather( # gather wants source to be (rac_size, )
- source=samples_surrogate_posterior.T, rasu_id=rasu_ids, H=self.dials_reference.get_hkls()
- ) #(reflections_from_mtz, number_of_samples)
+ source=samples_surrogate_posterior.T,
+ rasu_id=rasu_ids,
+ H=self.dials_reference.get_hkls(),
+ ) # (reflections_from_mtz, number_of_samples)
print("get_hkl from dials", self.dials_reference.get_hkls().shape)
print("dials ref", self.dials_reference["F"].shape)
print("ordered samples", ordered_samples.shape)
print("check output of assign resolution bins:")
- print( self.dials_reference.shape)
- print( self.dials_reference.head())
+ print(self.dials_reference.shape)
+ print(self.dials_reference.head())
# mask out nonobserved reflections
- print("sum of (reindexed) observed reflections", mask_observed_indexed_as_dials_reference.sum())
- _masked_dials_reference = self.dials_reference[mask_observed_indexed_as_dials_reference.cpu().detach().numpy()]
+ print(
+ "sum of (reindexed) observed reflections",
+ mask_observed_indexed_as_dials_reference.sum(),
+ )
+ _masked_dials_reference = self.dials_reference[
+ mask_observed_indexed_as_dials_reference.cpu().detach().numpy()
+ ]
print("masked dials reference", _masked_dials_reference["F"].shape)
- _masked_ordered_samples = ordered_samples[mask_observed_indexed_as_dials_reference]
+ _masked_ordered_samples = ordered_samples[
+ mask_observed_indexed_as_dials_reference
+ ]
print("masked ordered samples", _masked_ordered_samples.shape)
print("self.resolution_bins", self.resolution_bins)
@@ -300,134 +392,182 @@ def on_train_epoch_end(self, trainer, pl_module):
for i, bin_index in enumerate(self.resolution_bins):
bin_mask = _masked_dials_reference["bin"] == i
print("bin maks", sum(bin_mask))
- reference_to_plot = _masked_dials_reference[bin_mask]["F"].squeeze().to_numpy()
- model_to_plot = _masked_ordered_samples[np.array(bin_mask)].squeeze().cpu().detach().numpy()
+ reference_to_plot = (
+ _masked_dials_reference[bin_mask]["F"].squeeze().to_numpy()
+ )
+ model_to_plot = (
+ _masked_ordered_samples[np.array(bin_mask)]
+ .squeeze()
+ .cpu()
+ .detach()
+ .numpy()
+ )
-
# Remove any NaN or Inf values
- valid_mask = ~(np.isnan(reference_to_plot) | np.isnan(model_to_plot) |
- np.isinf(reference_to_plot) | np.isinf(model_to_plot))
+ valid_mask = ~(
+ np.isnan(reference_to_plot)
+ | np.isnan(model_to_plot)
+ | np.isinf(reference_to_plot)
+ | np.isinf(model_to_plot)
+ )
if not np.any(valid_mask):
print("Warning: No valid data points after filtering")
continue
-
+
reference_to_plot = reference_to_plot[valid_mask]
model_to_plot = model_to_plot[valid_mask]
-
+
print(f"After filtering - valid data points: {len(reference_to_plot)}")
-
- scatter = axes[i].scatter(reference_to_plot,
- model_to_plot,
- marker='x', s=10, alpha=0.5, c="black")
-
- # Fit a straight line through the origin using numpy polyfit
- # For line through origin, we fit y = mx (no intercept)
- # try:
- # fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0]
- # print(f"Polyfit successful, slope: {fitted_slope:.6f}")
- # except np.linalg.LinAlgError as e:
- # print(f"Polyfit failed with LinAlgError: {e}")
- # print("Falling back to manual calculation")
- # # Fallback to manual calculation for line through origin
- # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
- # print(f"Fallback slope: {fitted_slope:.6f}")
- # except Exception as e:
- # print(f"Polyfit failed with other error: {e}")
- # # Fallback to manual calculation
- # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
- # print(f"Fallback slope: {fitted_slope:.6f}")
-
- # # Plot the fitted line
- # x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100)
- # y_fitted = fitted_slope * x_range
- # axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x')
-
+
+ scatter = axes[i].scatter(
+ reference_to_plot,
+ model_to_plot,
+ marker="x",
+ s=10,
+ alpha=0.5,
+ c="black",
+ )
+
+ # Fit a straight line through the origin using numpy polyfit
+ # For line through origin, we fit y = mx (no intercept)
+ # try:
+ # fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0]
+ # print(f"Polyfit successful, slope: {fitted_slope:.6f}")
+ # except np.linalg.LinAlgError as e:
+ # print(f"Polyfit failed with LinAlgError: {e}")
+ # print("Falling back to manual calculation")
+ # # Fallback to manual calculation for line through origin
+ # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ # print(f"Fallback slope: {fitted_slope:.6f}")
+ # except Exception as e:
+ # print(f"Polyfit failed with other error: {e}")
+ # # Fallback to manual calculation
+ # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ # print(f"Fallback slope: {fitted_slope:.6f}")
+
+ # # Plot the fitted line
+ # x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100)
+ # y_fitted = fitted_slope * x_range
+ # axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x')
+
# Set log-log scale
- axes[i].set_xscale('log')
- axes[i].set_yscale('log')
-
+ axes[i].set_xscale("log")
+ axes[i].set_yscale("log")
+
try:
- correlation = np.corrcoef(reference_to_plot,
- model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1]
+ correlation = np.corrcoef(
+ reference_to_plot,
+ model_to_plot
+ * np.max(reference_to_plot)
+ / np.max(model_to_plot),
+ )[0, 1]
print("reference_to_plot", reference_to_plot)
- print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))
+ print(
+ "model to plot scaled",
+ model_to_plot
+ * np.max(reference_to_plot)
+ / np.max(model_to_plot),
+ )
except Exception as e:
print(f"skip correlation computataion {e}")
# print(f"Fitted slope: {fitted_slope:.4f}")
- axes[i].set_xlabel('Dials Reference F')
- axes[i].set_ylabel('Predicted F')
- axes[i].set_title(f'bin {bin_index} (r={correlation:.3f}')
+ axes[i].set_xlabel("Dials Reference F")
+ axes[i].set_ylabel("Predicted F")
+ axes[i].set_title(f"bin {bin_index} (r={correlation:.3f}")
for j in range(i + 1, len(axes)):
fig.delaxes(axes[j])
-
- pl_module.logger.experiment.log({
- "Correlation/structure_factors_binned": wandb.Image(fig),
- # "Correlation/correlation_coefficient": correlation,
- "epoch": trainer.current_epoch
- })
+ pl_module.logger.experiment.log(
+ {
+ "Correlation/structure_factors_binned": wandb.Image(fig),
+ # "Correlation/correlation_coefficient": correlation,
+ "epoch": trainer.current_epoch,
+ }
+ )
plt.close(fig)
print("logged image")
-
+
except Exception as e:
print(f"Error in correlation plotting: {str(e)}")
import traceback
+
traceback.print_exc()
+
class CorrelationPlotting(Callback):
-
+
def __init__(self, dataloader):
super().__init__()
self.dials_reference = None
self.fixed_batch = None
self.dataloader = dataloader
-
+
def on_fit_start(self, trainer, pl_module):
print("on fit start callback")
- self.dials_reference = rs.read_mtz(pl_module.model_settings.merged_mtz_file_path)
+ self.dials_reference = rs.read_mtz(
+ pl_module.model_settings.merged_mtz_file_path
+ )
print("Read mtz successfully: ", self.dials_reference["F"].shape)
print("head hkl", self.dials_reference.get_hkls().shape)
-
-
def on_train_epoch_end(self, trainer, pl_module):
print("on train epoch end in callback")
try:
# shoeboxes_batch, photon_rate_output, hkl_batch = pl_module(self.fixed_batch, log_images=True)
-
+
samples_surrogate_posterior = pl_module.surrogate_posterior.mean
- print("full samples", samples_surrogate_posterior.shape)
+ print("full samples", samples_surrogate_posterior.shape)
# print("gathered ", photon_rate_output['samples_predicted_structure_factor'].shape)
rasu_ids = pl_module.surrogate_posterior.rac.rasu_ids[0]
- print("rasu id (all teh same as the first one?):",pl_module.surrogate_posterior.rac.rasu_ids)
-
- if hasattr(pl_module.surrogate_posterior, 'reliable_observations_mask'):
- reliable_mask = pl_module.surrogate_posterior.reliable_observations_mask(min_observations=50)
+ print(
+ "rasu id (all teh same as the first one?):",
+ pl_module.surrogate_posterior.rac.rasu_ids,
+ )
+
+ if hasattr(pl_module.surrogate_posterior, "reliable_observations_mask"):
+ reliable_mask = (
+ pl_module.surrogate_posterior.reliable_observations_mask(
+ min_observations=50
+ )
+ )
print("take reliable mask for correlation plotting")
-
+
else:
reliable_mask = pl_module.surrogate_posterior.observed
print("observed attribute", pl_module.surrogate_posterior.observed.sum())
- mask_observed_indexed_as_dials_reference = pl_module.surrogate_posterior.rac.gather(
- source=pl_module.surrogate_posterior.observed & reliable_mask, rasu_id=rasu_ids, H=self.dials_reference.get_hkls()
+ mask_observed_indexed_as_dials_reference = (
+ pl_module.surrogate_posterior.rac.gather(
+ source=pl_module.surrogate_posterior.observed & reliable_mask,
+ rasu_id=rasu_ids,
+ H=self.dials_reference.get_hkls(),
+ )
)
ordered_samples = pl_module.surrogate_posterior.rac.gather( # gather wants source to be (rac_size, )
- source=samples_surrogate_posterior.T, rasu_id=rasu_ids, H=self.dials_reference.get_hkls()
- ) #(reflections_from_mtz, number_of_samples)
+ source=samples_surrogate_posterior.T,
+ rasu_id=rasu_ids,
+ H=self.dials_reference.get_hkls(),
+ ) # (reflections_from_mtz, number_of_samples)
print("get_hkl from dials", self.dials_reference.get_hkls().shape)
print("dials ref", self.dials_reference["F"].shape)
print("ordered samples", ordered_samples.shape)
# mask out nonobserved reflections
- print("sum of (reindexed) observed reflections", mask_observed_indexed_as_dials_reference.sum())
- _masked_dials_reference = self.dials_reference[mask_observed_indexed_as_dials_reference.cpu().detach().numpy()]
+ print(
+ "sum of (reindexed) observed reflections",
+ mask_observed_indexed_as_dials_reference.sum(),
+ )
+ _masked_dials_reference = self.dials_reference[
+ mask_observed_indexed_as_dials_reference.cpu().detach().numpy()
+ ]
print("masked dials reference", _masked_dials_reference["F"].shape)
- _masked_ordered_samples = ordered_samples[mask_observed_indexed_as_dials_reference]
+ _masked_ordered_samples = ordered_samples[
+ mask_observed_indexed_as_dials_reference
+ ]
print("masked ordered samples", _masked_ordered_samples.shape)
if _masked_ordered_samples.numel() == 0:
@@ -437,99 +577,134 @@ def on_train_epoch_end(self, trainer, pl_module):
print("Creating correlation plot...")
reference_to_plot = _masked_dials_reference["F"].squeeze().to_numpy()
model_to_plot = _masked_ordered_samples.squeeze().cpu().detach().numpy()
-
+
# Remove any NaN or Inf values
- valid_mask = ~(np.isnan(reference_to_plot) | np.isnan(model_to_plot) |
- np.isinf(reference_to_plot) | np.isinf(model_to_plot))
-
+ valid_mask = ~(
+ np.isnan(reference_to_plot)
+ | np.isnan(model_to_plot)
+ | np.isinf(reference_to_plot)
+ | np.isinf(model_to_plot)
+ )
+
if not np.any(valid_mask):
print("Warning: No valid data points after filtering")
return
-
+
reference_to_plot = reference_to_plot[valid_mask]
model_to_plot = model_to_plot[valid_mask]
-
+
print(f"After filtering - valid data points: {len(reference_to_plot)}")
-
+
fig, axes = plt.subplots(figsize=(10, 10))
- scatter = axes.scatter(reference_to_plot,
- model_to_plot,
- marker='x', s=10, alpha=0.5, c="black")
-
+ scatter = axes.scatter(
+ reference_to_plot, model_to_plot, marker="x", s=10, alpha=0.5, c="black"
+ )
+
# Fit a straight line through the origin using numpy polyfit
# For line through origin, we fit y = mx (no intercept)
try:
- fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0]
+ fitted_slope = np.polyfit(
+ reference_to_plot, model_to_plot, 1, w=None, cov=False
+ )[0]
print(f"Polyfit successful, slope: {fitted_slope:.6f}")
except np.linalg.LinAlgError as e:
print(f"Polyfit failed with LinAlgError: {e}")
print("Falling back to manual calculation")
# Fallback to manual calculation for line through origin
- fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(
+ reference_to_plot**2
+ )
print(f"Fallback slope: {fitted_slope:.6f}")
except Exception as e:
print(f"Polyfit failed with other error: {e}")
# Fallback to manual calculation
- fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(
+ reference_to_plot**2
+ )
print(f"Fallback slope: {fitted_slope:.6f}")
-
+
# Plot the fitted line
x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100)
y_fitted = fitted_slope * x_range
- axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x')
-
+ axes.plot(
+ x_range,
+ y_fitted,
+ "r-",
+ linewidth=2,
+ label=f"Fitted line: y = {fitted_slope:.4f}x",
+ )
+
# Set log-log scale
- axes.set_xscale('log')
- axes.set_yscale('log')
-
- correlation = np.corrcoef(reference_to_plot,
- model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1]
+ axes.set_xscale("log")
+ axes.set_yscale("log")
+
+ correlation = np.corrcoef(
+ reference_to_plot,
+ model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot),
+ )[0, 1]
print("reference_to_plot", reference_to_plot)
- print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))
+ print(
+ "model to plot scaled",
+ model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot),
+ )
print(f"Fitted slope: {fitted_slope:.4f}")
- axes.set_xlabel('Dials Reference F')
- axes.set_ylabel('Predicted F')
- axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})')
+ axes.set_xlabel("Dials Reference F")
+ axes.set_ylabel("Predicted F")
+ axes.set_title(
+ f"Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})"
+ )
axes.legend()
-
+
# Add diagonal line
# min_val = min(axes.get_xlim()[0], axes.get_ylim()[0])
# max_val = max(axes.get_xlim()[1], axes.get_ylim()[1])
# axes.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5)
- pl_module.logger.experiment.log({
- "Correlation/structure_factors": wandb.Image(fig),
- "Correlation/correlation_coefficient": correlation,
- "epoch": trainer.current_epoch
- })
+ pl_module.logger.experiment.log(
+ {
+ "Correlation/structure_factors": wandb.Image(fig),
+ "Correlation/correlation_coefficient": correlation,
+ "epoch": trainer.current_epoch,
+ }
+ )
plt.close(fig)
print("logged image")
-
+
except Exception as e:
print(f"Error in correlation plotting: {str(e)}")
import traceback
+
traceback.print_exc()
+
class ScalePlotting(Callback):
-
+
def __init__(self, dataloader):
super().__init__()
self.fixed_batch = None
- self.scale_history = {
- 'loc': [],
- 'scale': []
- }
+ self.scale_history = {"loc": [], "scale": []}
self.dataloader = dataloader
-
+
def on_fit_start(self, trainer, pl_module):
- _raw_batch = next(iter(self.dataloader.load_data_for_logging_during_training(number_of_shoeboxes_to_log=2000)))
+ _raw_batch = next(
+ iter(
+ self.dataloader.load_data_for_logging_during_training(
+ number_of_shoeboxes_to_log=2000
+ )
+ )
+ )
_raw_batch = [_batch_item.to(pl_module.device) for _batch_item in _raw_batch]
self.fixed_batch = tuple(_raw_batch)
- self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[:, 13]
- self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to(pl_module.device)[:, 9]
- self.fixed_batch_total_counts = self.fixed_batch[3].to(pl_module.device).sum(dim=1)
-
+ self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[
+ :, 13
+ ]
+ self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to(
+ pl_module.device
+ )[:, 9]
+ self.fixed_batch_total_counts = (
+ self.fixed_batch[3].to(pl_module.device).sum(dim=1)
+ )
def on_train_epoch_end(self, trainer, pl_module):
with torch.no_grad():
@@ -537,103 +712,175 @@ def on_train_epoch_end(self, trainer, pl_module):
# Memory optimization: Process in smaller chunks
chunk_size = 20 # Process 20 shoeboxes at a time
batch_size = self.fixed_batch[0].size(0)
-
+
all_scale_means = []
all_background_means = []
-
+
for i in range(0, batch_size, chunk_size):
end_idx = min(i + chunk_size, batch_size)
-
+
# Create chunk batch
- chunk_batch = tuple(tensor[i:end_idx] for tensor in self.fixed_batch)
-
- # Process chunk
- shoebox_representation, metadata_representation, image_representation, shoebox_profile_representation = pl_module._batch_to_representations(batch=chunk_batch)
+ chunk_batch = tuple(
+ tensor[i:end_idx] for tensor in self.fixed_batch
+ )
- scale_per_reflection = pl_module.compute_scale(representations=[image_representation, metadata_representation])
- scale_mean_per_reflection = scale_per_reflection.mean.cpu().detach().numpy()
+ # Process chunk
+ (
+ shoebox_representation,
+ metadata_representation,
+ image_representation,
+ shoebox_profile_representation,
+ ) = pl_module._batch_to_representations(batch=chunk_batch)
+
+ scale_per_reflection = pl_module.compute_scale(
+ representations=[image_representation, metadata_representation]
+ )
+ scale_mean_per_reflection = (
+ scale_per_reflection.mean.cpu().detach().numpy()
+ )
all_scale_means.append(scale_mean_per_reflection)
-
- background_mean_per_reflection = pl_module.compute_background_distribution(shoebox_representation=shoebox_representation).mean.cpu().detach().numpy()
+
+ background_mean_per_reflection = (
+ pl_module.compute_background_distribution(
+ shoebox_representation=shoebox_representation
+ )
+ .mean.cpu()
+ .detach()
+ .numpy()
+ )
all_background_means.append(background_mean_per_reflection)
-
+
# Clear intermediate tensors
- del shoebox_representation, metadata_representation, image_representation, scale_per_reflection
-
+ del (
+ shoebox_representation,
+ metadata_representation,
+ image_representation,
+ scale_per_reflection,
+ )
+
# Combine results
scale_mean_per_reflection = np.concatenate(all_scale_means)
background_mean_per_reflection = np.concatenate(all_background_means)
-
+
print(len(scale_mean_per_reflection))
-
- plt.figure(figsize=(8,5))
- plt.hist(scale_mean_per_reflection, bins=100, edgecolor="black", alpha=0.6, label="Scale_mean")
+
+ plt.figure(figsize=(8, 5))
+ plt.hist(
+ scale_mean_per_reflection,
+ bins=100,
+ edgecolor="black",
+ alpha=0.6,
+ label="Scale_mean",
+ )
plt.xlabel("mean scale per reflection")
plt.ylabel("Number of Observations")
plt.title("Histogram of Per‐Reflection Scale Factors-Model")
plt.tight_layout()
- pl_module.logger.experiment.log({"scale_histogram": wandb.Image(plt.gcf())})
+ pl_module.logger.experiment.log(
+ {"scale_histogram": wandb.Image(plt.gcf())}
+ )
plt.close()
- print("###################### std bkg #########: ",np.std(background_mean_per_reflection))
- plt.figure(figsize=(8,5))
- plt.hist(background_mean_per_reflection, bins=100, edgecolor="black", alpha=0.6, label="Background_mean")
+ print(
+ "###################### std bkg #########: ",
+ np.std(background_mean_per_reflection),
+ )
+ plt.figure(figsize=(8, 5))
+ plt.hist(
+ background_mean_per_reflection,
+ bins=100,
+ edgecolor="black",
+ alpha=0.6,
+ label="Background_mean",
+ )
plt.xlabel("mean background per reflection")
plt.ylabel("Number of Observations")
plt.title("Histogram of Per‐Reflection Background -Model")
plt.tight_layout()
- pl_module.logger.experiment.log({"background_histogram": wandb.Image(plt.gcf())})
+ pl_module.logger.experiment.log(
+ {"background_histogram": wandb.Image(plt.gcf())}
+ )
plt.close()
- reference_to_plot = self.fixed_batch_background_mean.cpu().detach().numpy()
+ reference_to_plot = (
+ self.fixed_batch_background_mean.cpu().detach().numpy()
+ )
model_to_plot = background_mean_per_reflection.squeeze()
print("bkg", reference_to_plot.shape, model_to_plot.shape)
fig, axes = plt.subplots(figsize=(10, 10))
- scatter = axes.scatter(reference_to_plot,
- model_to_plot,
- marker='x', s=10, alpha=0.5, c="black")
+ scatter = axes.scatter(
+ reference_to_plot,
+ model_to_plot,
+ marker="x",
+ s=10,
+ alpha=0.5,
+ c="black",
+ )
try:
- fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0]
+ fitted_slope = np.polyfit(
+ reference_to_plot, model_to_plot, 1, w=None, cov=False
+ )[0]
print(f"Polyfit successful, slope: {fitted_slope:.6f}")
except np.linalg.LinAlgError as e:
print(f"Polyfit failed with LinAlgError: {e}")
print("Falling back to manual calculation")
# Fallback to manual calculation for line through origin
- fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(
+ reference_to_plot**2
+ )
print(f"Fallback slope: {fitted_slope:.6f}")
except Exception as e:
print(f"Polyfit failed with other error: {e}")
# Fallback to manual calculation
- fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(
+ reference_to_plot**2
+ )
print(f"Fallback slope: {fitted_slope:.6f}")
-
+
# Plot the fitted line
- x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100)
+ x_range = np.linspace(
+ reference_to_plot.min(), reference_to_plot.max(), 100
+ )
y_fitted = fitted_slope * x_range
- axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x')
-
+ axes.plot(
+ x_range,
+ y_fitted,
+ "r-",
+ linewidth=2,
+ label=f"Fitted line: y = {fitted_slope:.4f}x",
+ )
+
# Set log-log scale
- axes.set_xscale('log')
- axes.set_yscale('log')
-
- correlation = np.corrcoef(reference_to_plot,
- model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1]
+ axes.set_xscale("log")
+ axes.set_yscale("log")
+
+ correlation = np.corrcoef(
+ reference_to_plot,
+ model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot),
+ )[0, 1]
print("reference_to_plot", reference_to_plot)
- print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))
+ print(
+ "model to plot scaled",
+ model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot),
+ )
print(f"Fitted slope: {fitted_slope:.4f}")
- axes.set_xlabel('Dials Reference Bkg')
- axes.set_ylabel('Predicted Bkg')
- axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})')
+ axes.set_xlabel("Dials Reference Bkg")
+ axes.set_ylabel("Predicted Bkg")
+ axes.set_title(
+ f"Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})"
+ )
axes.legend()
-
- pl_module.logger.experiment.log({
- "Correlation/Background": wandb.Image(fig),
- "Correlation/background_correlation_coefficient": correlation,
- "epoch": trainer.current_epoch
- })
+
+ pl_module.logger.experiment.log(
+ {
+ "Correlation/Background": wandb.Image(fig),
+ "Correlation/background_correlation_coefficient": correlation,
+ "epoch": trainer.current_epoch,
+ }
+ )
plt.close(fig)
print("logged image")
@@ -641,97 +888,164 @@ def on_train_epoch_end(self, trainer, pl_module):
print(f"failed ScalePlotting/Bkg with: {e}")
# Memory optimization: Clear large tensors and force cleanup
- del all_scale_means, all_background_means, scale_mean_per_reflection, background_mean_per_reflection
+ del (
+ all_scale_means,
+ all_background_means,
+ scale_mean_per_reflection,
+ background_mean_per_reflection,
+ )
torch.cuda.empty_cache()
################# Plot Intensity Correlation ##############################
try:
# Memory optimization: Process intensity correlation in chunks without verbose output
all_model_intensities = []
-
+
for i in range(0, batch_size, chunk_size):
end_idx = min(i + chunk_size, batch_size)
- chunk_batch = tuple(tensor[i:end_idx] for tensor in self.fixed_batch)
-
+ chunk_batch = tuple(
+ tensor[i:end_idx] for tensor in self.fixed_batch
+ )
+
# Get representations for this chunk
- shoebox_representation, metadata_representation, image_representation,shoebox_profile_representation = pl_module._batch_to_representations(batch=chunk_batch)
-
+ (
+ shoebox_representation,
+ metadata_representation,
+ image_representation,
+ shoebox_profile_representation,
+ ) = pl_module._batch_to_representations(batch=chunk_batch)
+
# Compute scale and structure factors for this chunk
- scale_distribution = pl_module.compute_scale(representations=[image_representation, metadata_representation])
- scale_samples = scale_distribution.rsample([pl_module.model_settings.number_of_mc_samples]).permute(1, 0, 2)
-
+ scale_distribution = pl_module.compute_scale(
+ representations=[image_representation, metadata_representation]
+ )
+ scale_samples = scale_distribution.rsample(
+ [pl_module.model_settings.number_of_mc_samples]
+ ).permute(1, 0, 2)
+
# Get structure factor samples for this chunk
hkl_chunk = chunk_batch[4] # hkl batch
- samples_surrogate_posterior = pl_module.surrogate_posterior.rsample([pl_module.model_settings.number_of_mc_samples])
+ samples_surrogate_posterior = pl_module.surrogate_posterior.rsample(
+ [pl_module.model_settings.number_of_mc_samples]
+ )
rasu_ids = pl_module.surrogate_posterior.rac.rasu_ids[0]
- samples_predicted_structure_factor = pl_module.surrogate_posterior.rac.gather(
- source=samples_surrogate_posterior.T, rasu_id=rasu_ids, H=hkl_chunk
- ).unsqueeze(-1)
-
+ samples_predicted_structure_factor = (
+ pl_module.surrogate_posterior.rac.gather(
+ source=samples_surrogate_posterior.T,
+ rasu_id=rasu_ids,
+ H=hkl_chunk,
+ ).unsqueeze(-1)
+ )
+
# Compute model intensity for this chunk
- model_intensity_chunk = scale_samples * torch.square(samples_predicted_structure_factor)
- model_intensity_chunk = torch.mean(model_intensity_chunk, dim=1).squeeze()
- all_model_intensities.append(model_intensity_chunk.cpu().detach().numpy())
-
+ model_intensity_chunk = scale_samples * torch.square(
+ samples_predicted_structure_factor
+ )
+ model_intensity_chunk = torch.mean(
+ model_intensity_chunk, dim=1
+ ).squeeze()
+ all_model_intensities.append(
+ model_intensity_chunk.cpu().detach().numpy()
+ )
+
# Clear intermediate tensors
- del shoebox_representation, metadata_representation, image_representation, scale_distribution, scale_samples
- del samples_surrogate_posterior, samples_predicted_structure_factor, model_intensity_chunk
-
+ del (
+ shoebox_representation,
+ metadata_representation,
+ image_representation,
+ scale_distribution,
+ scale_samples,
+ )
+ del (
+ samples_surrogate_posterior,
+ samples_predicted_structure_factor,
+ model_intensity_chunk,
+ )
+
# Combine results
model_to_plot = np.concatenate(all_model_intensities)
- reference_to_plot = self.fixed_batch_intensity_sum_values.cpu().detach().numpy()
-
+ reference_to_plot = (
+ self.fixed_batch_intensity_sum_values.cpu().detach().numpy()
+ )
+
fig, axes = plt.subplots(figsize=(10, 10))
- scatter = axes.scatter(reference_to_plot,
- model_to_plot,
- marker='x', s=10, alpha=0.5, c="black")
+ scatter = axes.scatter(
+ reference_to_plot,
+ model_to_plot,
+ marker="x",
+ s=10,
+ alpha=0.5,
+ c="black",
+ )
try:
- fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0]
+ fitted_slope = np.polyfit(
+ reference_to_plot, model_to_plot, 1, w=None, cov=False
+ )[0]
print(f"Polyfit successful, slope: {fitted_slope:.6f}")
except np.linalg.LinAlgError as e:
print(f"Polyfit failed with LinAlgError: {e}")
print("Falling back to manual calculation")
# Fallback to manual calculation for line through origin
- fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(
+ reference_to_plot**2
+ )
print(f"Fallback slope: {fitted_slope:.6f}")
except Exception as e:
print(f"Polyfit failed with other error: {e}")
# Fallback to manual calculation
- fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
+ fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(
+ reference_to_plot**2
+ )
print(f"Fallback slope: {fitted_slope:.6f}")
-
+
# Plot the fitted line
- x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100)
+ x_range = np.linspace(
+ reference_to_plot.min(), reference_to_plot.max(), 100
+ )
y_fitted = fitted_slope * x_range
- axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x')
-
+ axes.plot(
+ x_range,
+ y_fitted,
+ "r-",
+ linewidth=2,
+ label=f"Fitted line: y = {fitted_slope:.4f}x",
+ )
+
# Set log-log scale
- axes.set_xscale('log')
- axes.set_yscale('log')
-
- correlation = np.corrcoef(reference_to_plot,
- model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1]
+ axes.set_xscale("log")
+ axes.set_yscale("log")
+
+ correlation = np.corrcoef(
+ reference_to_plot,
+ model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot),
+ )[0, 1]
print("reference_to_plot", reference_to_plot)
- print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))
+ print(
+ "model to plot scaled",
+ model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot),
+ )
print(f"Fitted slope: {fitted_slope:.4f}")
- axes.set_xlabel('Dials Reference Intensity')
- axes.set_ylabel('Predicted Intensity')
- axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})')
+ axes.set_xlabel("Dials Reference Intensity")
+ axes.set_ylabel("Predicted Intensity")
+ axes.set_title(
+ f"Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})"
+ )
axes.legend()
-
- pl_module.logger.experiment.log({
- "Correlation/Intensity": wandb.Image(fig),
- "Correlation/intensity_correlation_coefficient": correlation,
- "epoch": trainer.current_epoch
- })
+
+ pl_module.logger.experiment.log(
+ {
+ "Correlation/Intensity": wandb.Image(fig),
+ "Correlation/intensity_correlation_coefficient": correlation,
+ "epoch": trainer.current_epoch,
+ }
+ )
plt.close(fig)
print("logged image")
except Exception as e:
print(f"failed ScalePlotting/Bkg with {e}")
-
# ################# Plot Count Correlation ##############################
# try:
# # Extract values from dictionary
@@ -740,16 +1054,15 @@ def on_train_epoch_end(self, trainer, pl_module):
# print("samples_photon_rate (batch, mc, pix)", samples_photon_rate.shape)
# samples_photon_rate = samples_photon_rate.sum(dim=-1)
# print("samples_photon_rate (batch, mc)", samples_photon_rate.shape)
-
+
# photon_rate_per_sb = torch.mean(samples_photon_rate, dim=1).squeeze() #(batch_size, 1)
# print("photon_rate (batch)", photon_rate_per_sb.shape)
-
# reference_to_plot = self.fixed_batch_total_counts.cpu().detach().numpy()
# model_to_plot = photon_rate_per_sb.cpu().detach().numpy()
# fig, axes = plt.subplots(figsize=(10, 10))
- # scatter = axes.scatter(reference_to_plot,
- # model_to_plot,
+ # scatter = axes.scatter(reference_to_plot,
+ # model_to_plot,
# marker='x', s=10, alpha=0.5, c="black")
# try:
@@ -766,23 +1079,23 @@ def on_train_epoch_end(self, trainer, pl_module):
# # Fallback to manual calculation
# fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2)
# print(f"Fallback slope: {fitted_slope:.6f}")
-
+
# # Plot the fitted line
# x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100)
# y_fitted = fitted_slope * x_range
# axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x')
-
+
# # Set log-log scale
# axes.set_xscale('log')
# axes.set_yscale('log')
-
- # correlation = np.corrcoef(reference_to_plot,
+
+ # correlation = np.corrcoef(reference_to_plot,
# model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1]
# axes.set_xlabel('Dials Reference Intensity')
# axes.set_ylabel('Predicted Intensity')
# axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})')
# axes.legend()
-
+
# pl_module.logger.experiment.log({
# "Correlation/Counts": wandb.Image(fig),
# "Correlation/counts_correlation_coefficient": correlation,
@@ -794,14 +1107,13 @@ def on_train_epoch_end(self, trainer, pl_module):
# except Exception as e:
# print(f"failed corr counts with {e}")
-
-
def on_train_end(self, trainer, pl_module):
unmerged_mtz = rs.read_mtz(pl_module.model_settings.unmerged_mtz_file_path)
scale_dials = unmerged_mtz["SCALEUSED"]
- pl_module.logger.experiment.log({
+ pl_module.logger.experiment.log(
+ {
"scale/mean_dials": scale_dials.mean(),
"scale/max_dials": scale_dials.max(),
- "scale/min_dials": scale_dials.min()
- })
-
+ "scale/min_dials": scale_dials.min(),
+ }
+ )
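
The callbacks above repeatedly fit a line through the origin and rescale the model values by max(reference)/max(model) before taking Pearson's r; note that np.polyfit(x, y, 1) also fits an intercept, so the closed-form sum(x*y)/sum(x**2) used as the fallback is the actual through-origin fit. A minimal standalone sketch of that recipe (the helper name is illustrative, not part of the patch):

import numpy as np


def origin_slope_and_scaled_correlation(reference, model):
    """Slope of the best-fit line through the origin (y = m * x) plus the
    Pearson correlation of the model values rescaled to the reference
    maximum, mirroring the recipe used in the plotting callbacks above."""
    reference = np.asarray(reference, dtype=float)
    model = np.asarray(model, dtype=float)
    # keep only finite pairs, as the callbacks do before plotting
    valid = np.isfinite(reference) & np.isfinite(model)
    reference, model = reference[valid], model[valid]
    # closed-form least-squares slope for a line through the origin
    slope = np.sum(reference * model) / np.sum(reference**2)
    # Pearson r is invariant to rescaling, so the max-ratio scaling only
    # changes the plotted values, not the coefficient itself
    scaled_model = model * reference.max() / model.max()
    r = np.corrcoef(reference, scaled_model)[0, 1]
    return slope, r
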
diff --git a/src/factory/data_loader.py b/src/factory/data_loader.py
index fea43cd..0a00b1e 100644
--- a/src/factory/data_loader.py
+++ b/src/factory/data_loader.py
@@ -1,21 +1,20 @@
"""To load shoeboxes batched by image.
-Call with
+Call with
python data_loader.py --data_directory .
"""
-import torch
-import os
-import pytorch_lightning as pl
-from torch.utils.data import DataLoader, TensorDataset, Subset, Dataset
-from torch.utils.data import Sampler
+import argparse
+import dataclasses
import itertools
+import os
import random
-import dataclasses
-import argparse
-import settings
+import pytorch_lightning as pl
+import settings
+import torch
+from torch.utils.data import DataLoader, Dataset, Sampler, Subset, TensorDataset
class ShoeboxTensorDataset(Dataset):
@@ -45,42 +44,50 @@ def __getitem__(self, idx):
)
-class CrystallographicDataLoader():
+class CrystallographicDataLoader:
full_data_set: torch.utils.data.dataset.Subset
train_data_set: torch.utils.data.dataset.Subset
validation_data_set: torch.utils.data.dataset.Subset
test_data_set: torch.utils.data.dataset.Subset
data_loader_settings: settings.DataLoaderSettings
- def __init__(
- self,
- data_loader_settings: settings.DataLoaderSettings
- ):
+ def __init__(self, data_loader_settings: settings.DataLoaderSettings):
self.data_loader_settings = data_loader_settings
self.use_standard_sampler = False
def _get_raw_shoebox_data_(self):
-
+
# counts = torch.load(
# os.path.join(self.settings.data_directory, self.settings.data_file_names["counts"]), weights_only=True
# )
# print("counts shape:", counts.shape)
dead_pixel_mask = torch.load(
- os.path.join(self.data_loader_settings.data_directory, self.data_loader_settings.data_file_names["masks"]), weights_only=True
+ os.path.join(
+ self.data_loader_settings.data_directory,
+ self.data_loader_settings.data_file_names["masks"],
+ ),
+ weights_only=True,
)
print("dead_pixel_mask shape:", dead_pixel_mask.shape)
shoeboxes = torch.load(
- os.path.join(self.data_loader_settings.data_directory, self.data_loader_settings.data_file_names["counts"]), weights_only=True
+ os.path.join(
+ self.data_loader_settings.data_directory,
+ self.data_loader_settings.data_file_names["counts"],
+ ),
+ weights_only=True,
)
# shoeboxes = shoeboxes[:,:,-1]
# if len(shoeboxes[0]) != 7:
# self.use_standard_sampler = True
-
+
counts = shoeboxes.clone()
print("standard sampler:", self.use_standard_sampler)
- stats = torch.load(os.path.join(self.data_loader_settings.data_directory,"stats.pt"), weights_only=True)
+ stats = torch.load(
+ os.path.join(self.data_loader_settings.data_directory, "stats.pt"),
+ weights_only=True,
+ )
mean = stats[0]
var = stats[1]
@@ -100,20 +107,21 @@ def _get_raw_shoebox_data_(self):
# )
# print("dials refercne shape", dials_reference.shape)
-
metadata = torch.load(
- os.path.join(self.data_loader_settings.data_directory, self.data_loader_settings.data_file_names["metadata"]), weights_only=True
+ os.path.join(
+ self.data_loader_settings.data_directory,
+ self.data_loader_settings.data_file_names["metadata"],
+ ),
+ weights_only=True,
)
-
-
- print("Metadata shape:", metadata.shape) #(d, h,k, l ,x, y, z)
+ print("Metadata shape:", metadata.shape) # (d, h,k, l ,x, y, z)
# metadata = torch.zeros(shoeboxes.shape[0], 7)
# hkl = metadata[:,1:4].to(torch.int)
# print("hkl shape", hkl.shape)
- hkl = metadata[:,6:9].to(torch.int)
+ hkl = metadata[:, 6:9].to(torch.int)
print("hkl shape", hkl.shape)
# Use custom dataset for on-the-fly normalization
@@ -126,81 +134,103 @@ def _get_raw_shoebox_data_(self):
mean=mean,
var=var,
)
- print("Metadata shape full tensor:", self.full_data_set.metadata.shape) #(d, h,k, l ,x, y, z)
+ print(
+ "Metadata shape full tensor:", self.full_data_set.metadata.shape
+ ) # (d, h,k, l ,x, y, z)
# print("concentration shape", torch.load(
# os.path.join(self.settings.data_directory, "concentration.pt"), weights_only=True
# ).shape)
def append_image_id_to_metadata_(self) -> None:
- image_ids = self._get_image_ids_from_shoeboxes(
- shoebox_data_set=self.full_data_set.shoeboxes
- )
- data_as_list = list(self.full_data_set.tensors)
- data_as_list[1] = torch.cat(
- (data_as_list[1], torch.tensor(image_ids).unsqueeze(1)), dim=1
- )
- self.full_data_set.tensors = tuple(data_as_list)
-
+ image_ids = self._get_image_ids_from_shoeboxes(
+ shoebox_data_set=self.full_data_set.shoeboxes
+ )
+ data_as_list = list(self.full_data_set.tensors)
+ data_as_list[1] = torch.cat(
+ (data_as_list[1], torch.tensor(image_ids).unsqueeze(1)), dim=1
+ )
+ self.full_data_set.tensors = tuple(data_as_list)
+
# def _cut_metadata(self, metadata: torch.Tensor) -> torch.Tensor:
# # indices_to_keep = torch.tensor([self.settings.metadata_indices[i] for i in self.settings.metadata_keys_to_keep], dtype=torch.long)
# indices_to_keep = torch.tensor(0,1)
# return torch.index_select(metadata, dim=1, index=indices_to_keep)
-
- def _clean_shoeboxes_(self, shoeboxes: torch.Tensor, dead_pixel_mask, counts, metadata):
- shoebox_mask = (shoeboxes[..., -1].sum(dim=1) < 150000)
- return (shoeboxes[shoebox_mask], dead_pixel_mask[shoebox_mask], counts[shoebox_mask], metadata[shoebox_mask])
+ def _clean_shoeboxes_(
+ self, shoeboxes: torch.Tensor, dead_pixel_mask, counts, metadata
+ ):
+ shoebox_mask = shoeboxes[..., -1].sum(dim=1) < 150000
+ return (
+ shoeboxes[shoebox_mask],
+ dead_pixel_mask[shoebox_mask],
+ counts[shoebox_mask],
+ metadata[shoebox_mask],
+ )
def _split_full_data_(self) -> None:
full_data_set_length = len(self.full_data_set)
validation_data_set_length = int(
full_data_set_length * self.data_loader_settings.validation_set_split
)
- test_data_set_length = int(full_data_set_length * self.data_loader_settings.test_set_split)
+ test_data_set_length = int(
+ full_data_set_length * self.data_loader_settings.test_set_split
+ )
train_data_set_length = (
full_data_set_length - validation_data_set_length - test_data_set_length
)
- self.train_data_set, self.validation_data_set, self.test_data_set = torch.utils.data.random_split(
- self.full_data_set,
- [train_data_set_length, validation_data_set_length, test_data_set_length],
- generator=torch.Generator().manual_seed(42)
+ self.train_data_set, self.validation_data_set, self.test_data_set = (
+ torch.utils.data.random_split(
+ self.full_data_set,
+ [
+ train_data_set_length,
+ validation_data_set_length,
+ test_data_set_length,
+ ],
+ generator=torch.Generator().manual_seed(42),
+ )
)
def load_data_(self) -> None:
self._get_raw_shoebox_data_()
- if not self.use_standard_sampler and self.data_loader_settings.append_image_id_to_metadata:
+ if (
+ not self.use_standard_sampler
+ and self.data_loader_settings.append_image_id_to_metadata
+ ):
self.append_image_id_to_metadata_()
# self._clean_data_()
self._split_full_data_()
def _get_image_ids_from_shoeboxes(self, shoebox_data_set: torch.Tensor) -> list:
- """Returns the list of respective image ids for
- all shoeboxes in the shoebox data set."""
-
+ """Returns the list of respective image ids for
+ all shoeboxes in the shoebox data set."""
+
image_ids = []
for shoebox in shoebox_data_set:
- minimum_dz_index = torch.argmin(shoebox[:, 5]) # 5th index is dz
- image_ids.append(shoebox[minimum_dz_index, 2].item()) # 2nd index is z
-
+ minimum_dz_index = torch.argmin(shoebox[:, 5]) # 5th index is dz
+ image_ids.append(shoebox[minimum_dz_index, 2].item()) # 2nd index is z
+
if len(image_ids) != len(shoebox_data_set):
print(len(image_ids), len(shoebox_data_set))
raise ValueError(
f"The number of shoeboxes {len(shoebox_data_set)} does not match the number of image ids {len(image_ids)}."
)
return image_ids
-
- def _map_images_to_shoeboxes(self, shoebox_data_set: torch.Tensor, metadata:torch.Tensor) -> dict:
- """Returns a dictionary with image ids as keys and indices of all shoeboxes
- belonging to that image as values."""
-
+
+ def _map_images_to_shoeboxes(
+ self, shoebox_data_set: torch.Tensor, metadata: torch.Tensor
+ ) -> dict:
+ """Returns a dictionary with image ids as keys and indices of all shoeboxes
+ belonging to that image as values."""
+
# image_ids = self._get_image_ids_from_shoeboxes(
# shoebox_data_set=shoebox_data_set
# )
print("metadata shape", metadata.shape)
- image_ids = metadata[:,2].round().to(torch.int64) # Convert to integer type
+ image_ids = metadata[:, 2].round().to(torch.int64) # Convert to integer type
print("image ids", image_ids)
import numpy as np
+
unique_ids = torch.unique(image_ids)
# print("Number of unique image IDs:", len(unique_ids))
# print("First few unique image IDs:", unique_ids[:10])
@@ -209,20 +239,23 @@ def _map_images_to_shoeboxes(self, shoebox_data_set: torch.Tensor, metadata:torc
# print("Max image ID:", image_ids.max().item())
# print("Number of negative values:", (image_ids < 0).sum().item())
# print("Number of zero values:", (image_ids == 0).sum().item())
-
+
images_to_shoebox_indices = {}
for shoebox_index, image_id in enumerate(image_ids):
image_id_int = image_id.item() # Convert tensor to Python int
if image_id_int not in images_to_shoebox_indices:
images_to_shoebox_indices[image_id_int] = []
images_to_shoebox_indices[image_id_int].append(shoebox_index)
-
+
# print("Number of keys in dictionary:", len(images_to_shoebox_indices))
return images_to_shoebox_indices
-
class BatchByImageSampler(Sampler):
- def __init__(self, image_id_to_indices: dict, data_loader_settings: settings.DataLoaderSettings):
+ def __init__(
+ self,
+ image_id_to_indices: dict,
+ data_loader_settings: settings.DataLoaderSettings,
+ ):
self.data_loader_settings = data_loader_settings
# each element is a list of all shoebox-indices for one image
self.image_indices_list = list(image_id_to_indices.values())
@@ -240,7 +273,7 @@ def __iter__(self):
batch_count = 0
image_number = 0
for image in images:
- image_number +=1
+ image_number += 1
for shoebox in image:
batch.append(shoebox)
if len(batch) == self.batch_size:
@@ -250,14 +283,17 @@ def __iter__(self):
batch = []
batch_count += 1
if batch_count >= self.num_batches:
- return
+ return
def __len__(self):
return self.num_batches
- def load_data_set_batched_by_image(self,
- data_set_to_load: torch.utils.data.dataset.Subset | torch.utils.data.TensorDataset,
- ) -> torch.utils.data.dataloader.DataLoader:
+ def load_data_set_batched_by_image(
+ self,
+ data_set_to_load: (
+ torch.utils.data.dataset.Subset | torch.utils.data.TensorDataset
+ ),
+ ) -> torch.utils.data.dataloader.DataLoader:
if self.use_standard_sampler:
return DataLoader(
data_set_to_load,
@@ -268,11 +304,14 @@ def load_data_set_batched_by_image(self,
)
if isinstance(data_set_to_load, torch.utils.data.dataset.Subset):
- print("Metadata shape full tensor in loadeing:", self.full_data_set.metadata.shape) #(d, h,k, l ,x, y, z)
+ print(
+ "Metadata shape full tensor in loadeing:",
+ self.full_data_set.metadata.shape,
+ ) # (d, h,k, l ,x, y, z)
image_id_to_indices = self._map_images_to_shoeboxes(
shoebox_data_set=self.full_data_set.shoeboxes[data_set_to_load.indices],
- metadata=self.full_data_set.metadata[data_set_to_load.indices]
+ metadata=self.full_data_set.metadata[data_set_to_load.indices],
)
else:
image_id_to_indices = self._map_images_to_shoeboxes(
@@ -281,8 +320,8 @@ def load_data_set_batched_by_image(self,
batch_by_image_sampler = self.BatchByImageSampler(
image_id_to_indices=image_id_to_indices,
- data_loader_settings=self.data_loader_settings
- )
+ data_loader_settings=self.data_loader_settings,
+ )
return DataLoader(
data_set_to_load,
batch_sampler=batch_by_image_sampler,
@@ -290,30 +329,34 @@ def load_data_set_batched_by_image(self,
pin_memory=self.data_loader_settings.pin_memory,
prefetch_factor=self.data_loader_settings.prefetch_factor,
)
-
- def load_data_for_logging_during_training(self, number_of_shoeboxes_to_log: int = 5) -> torch.utils.data.dataloader.DataLoader:
+
+ def load_data_for_logging_during_training(
+ self, number_of_shoeboxes_to_log: int = 5
+ ) -> torch.utils.data.dataloader.DataLoader:
if self.use_standard_sampler:
- subset = Subset(self.train_data_set, indices=range(number_of_shoeboxes_to_log))
+ subset = Subset(
+ self.train_data_set, indices=range(number_of_shoeboxes_to_log)
+ )
return DataLoader(subset, batch_size=number_of_shoeboxes_to_log)
image_id_to_indices = self._map_images_to_shoeboxes(
- shoebox_data_set=self.full_data_set.shoeboxes[self.train_data_set.indices],
- metadata=self.full_data_set.metadata[self.train_data_set.indices]
- )
- data_loader_settings = dataclasses.replace(self.data_loader_settings, number_of_shoeboxes_per_batch=number_of_shoeboxes_to_log, number_of_batches=1)
+ shoebox_data_set=self.full_data_set.shoeboxes[self.train_data_set.indices],
+ metadata=self.full_data_set.metadata[self.train_data_set.indices],
+ )
+ data_loader_settings = dataclasses.replace(
+ self.data_loader_settings,
+ number_of_shoeboxes_per_batch=number_of_shoeboxes_to_log,
+ number_of_batches=1,
+ )
batch_by_image_sampler = self.BatchByImageSampler(
image_id_to_indices=image_id_to_indices,
- data_loader_settings=data_loader_settings
- )
+ data_loader_settings=data_loader_settings,
+ )
return DataLoader(
self.train_data_set,
batch_sampler=batch_by_image_sampler,
num_workers=self.data_loader_settings.number_of_workers,
pin_memory=self.data_loader_settings.pin_memory,
prefetch_factor=self.data_loader_settings.prefetch_factor,
- persistent_workers=True
+ persistent_workers=True,
)
-
-
-
-
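
The data loader groups shoeboxes by the image they were recorded on and then batches them image by image. A simplified sketch of that grouping and batching logic (helper names are illustrative; the real BatchByImageSampler also shuffles images and caps the number of batches per epoch):

import torch


def group_indices_by_image(image_ids: torch.Tensor) -> dict:
    """Map each integer image id to the list of shoebox indices recorded on
    that image, mirroring _map_images_to_shoeboxes above (simplified)."""
    mapping = {}
    for index, image_id in enumerate(image_ids.round().to(torch.int64).tolist()):
        mapping.setdefault(image_id, []).append(index)
    return mapping


def iter_image_ordered_batches(mapping: dict, batch_size: int):
    """Yield fixed-size index batches by walking shoeboxes image by image, so
    shoeboxes from the same image stay adjacent in each batch."""
    batch = []
    for indices in mapping.values():
        for index in indices:
            batch.append(index)
            if len(batch) == batch_size:
                yield batch
                batch = []
    # a trailing partial batch is dropped in this sketch
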
diff --git a/src/factory/distributions.py b/src/factory/distributions.py
index 2f8755d..1eb1a6d 100644
--- a/src/factory/distributions.py
+++ b/src/factory/distributions.py
@@ -1,26 +1,28 @@
-import torch
-import torch.nn.functional as F
-import networks
-from networks import Linear
import math
-
-from networks import Linear, Constraint
-from torch.distributions import HalfNormal, Exponential
-
from abc import ABC, abstractmethod
+import networks
+import torch
+import torch.nn.functional as F
+from networks import Constraint, Linear
+from torch.distributions import Exponential, HalfNormal
+
class ProfileDistribution(ABC, torch.nn.Module):
"""Base class for profile distributions that can compute KL divergence and generate profiles."""
def __init__(self):
super().__init__()
-
+
def compute_kl_divergence(self, predicted_distribution, target_distribution):
- try:
- return torch.distributions.kl.kl_divergence(predicted_distribution, target_distribution)
+ try:
+ return torch.distributions.kl.kl_divergence(
+ predicted_distribution, target_distribution
+ )
except NotImplementedError:
- print("KL divergence not implemented for this distribution: use sampling method.")
+ print(
+ "KL divergence not implemented for this distribution: use sampling method."
+ )
samples = predicted_distribution.rsample([100])
log_q = predicted_distribution.log_prob(samples)
log_p = target_distribution.log_prob(samples)
@@ -30,32 +32,32 @@ def compute_kl_divergence(self, predicted_distribution, target_distribution):
def _update_prior_distributions(self):
"""Update prior distributions to use current device."""
pass
-
+
@abstractmethod
def kl_divergence_weighted(self, weight):
"""Compute KL divergence between predicted and prior distributions.
-
+
Args:
weight: Weight factor for the KL divergence
"""
pass
-
+
@abstractmethod
def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor:
"""Compute the profile from the distribution.
-
+
Args:
*representations: Variable number of representations to combine
-
+
Returns:
torch.Tensor: The computed profile normalized to sum to 1
"""
pass
-
+
@abstractmethod
def forward(self, *representations):
"""Forward pass of the distribution.
-
+
Args:
*representations: Variable number of representations to combine
"""
@@ -63,23 +65,28 @@ def forward(self, *representations):
def compute_kl_divergence(predicted_distribution, target_distribution):
- try:
- # Create proper copies of the distributions
- # pred_copy = type(predicted_distribution)(
- # loc=predicted_distribution.loc.clone(),
- # scale=predicted_distribution.scale.clone()
- # )
- # target_copy = type(target_distribution)(
- # loc=target_distribution.loc.clone(),
- # scale=target_distribution.scale.clone()
- # )
- return torch.distributions.kl.kl_divergence(predicted_distribution, target_distribution)
- except NotImplementedError:
- print("KL divergence not implemented for this distribution: use sampling method.")
- samples = predicted_distribution.rsample([100])
- log_q = predicted_distribution.log_prob(samples)
- log_p = target_distribution.log_prob(samples)
- return (log_q - log_p).mean(dim=0)
+ try:
+ # Create proper copies of the distributions
+ # pred_copy = type(predicted_distribution)(
+ # loc=predicted_distribution.loc.clone(),
+ # scale=predicted_distribution.scale.clone()
+ # )
+ # target_copy = type(target_distribution)(
+ # loc=target_distribution.loc.clone(),
+ # scale=target_distribution.scale.clone()
+ # )
+ return torch.distributions.kl.kl_divergence(
+ predicted_distribution, target_distribution
+ )
+ except NotImplementedError:
+ print(
+ "KL divergence not implemented for this distribution: use sampling method."
+ )
+ samples = predicted_distribution.rsample([100])
+ log_q = predicted_distribution.log_prob(samples)
+ log_p = target_distribution.log_prob(samples)
+ return (log_q - log_p).mean(dim=0)
+
class DirichletProfile(ProfileDistribution):
"""
@@ -96,21 +103,29 @@ def __init__(self, concentration_vector, dmodel, input_shape=(3, 21, 21)):
# param.requires_grad = False
self.dmodel = dmodel
self.eps = 1e-6
- concentration_vector[concentration_vector>torch.quantile(concentration_vector, 0.99)] *= 40
+ concentration_vector[
+ concentration_vector > torch.quantile(concentration_vector, 0.99)
+ ] *= 40
concentration_vector /= concentration_vector.max()
- self.register_buffer("concentration", concentration_vector) #.reshape((21,21,3)))
+ self.register_buffer(
+ "concentration", concentration_vector
+ ) # .reshape((21,21,3)))
self.q_p = None
-
+
def _update_prior_distributions(self):
"""Update prior distributions to use current device."""
self.dirichlet_prior = torch.distributions.Dirichlet(self.concentration)
-
+
def kl_divergence_weighted(self, weight):
self._update_prior_distributions()
- print("shape inot kl profile", self.q_p.sample().shape, self.dirichlet_prior.sample().shape)
- return self.compute_kl_divergence(self.q_p, self.dirichlet_prior)* weight
-
+ print(
+ "shape inot kl profile",
+ self.q_p.sample().shape,
+ self.dirichlet_prior.sample().shape,
+ )
+ return self.compute_kl_divergence(self.q_p, self.dirichlet_prior) * weight
+
def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor:
dirichlet_profile = self.forward(*representations)
return dirichlet_profile
@@ -120,44 +135,68 @@ def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor:
# return avg_profile
def forward(self, *representations):
- alphas = sum(representations) if len(representations) > 1 else representations[0]
+ alphas = (
+ sum(representations) if len(representations) > 1 else representations[0]
+ )
print("alpha shape", alphas.shape)
-
+
# Check for NaN values in input representations
if torch.isnan(alphas).any():
print("WARNING: NaN values detected in input representations!")
print("NaN count:", torch.isnan(alphas).sum().item())
- print("Input stats - min:", alphas.min().item(), "max:", alphas.max().item(), "mean:", alphas.mean().item())
-
+ print(
+ "Input stats - min:",
+ alphas.min().item(),
+ "max:",
+ alphas.max().item(),
+ "mean:",
+ alphas.mean().item(),
+ )
+
if self.dmodel is not None:
alphas = self.alpha_layer(alphas)
# Check for NaN values after linear layer
if torch.isnan(alphas).any():
print("WARNING: NaN values detected after alpha_layer!")
print("NaN count:", torch.isnan(alphas).sum().item())
- print("Alpha layer output stats - min:", alphas.min().item(), "max:", alphas.max().item(), "mean:", alphas.mean().item())
-
+ print(
+ "Alpha layer output stats - min:",
+ alphas.min().item(),
+ "max:",
+ alphas.max().item(),
+ "mean:",
+ alphas.mean().item(),
+ )
+
alphas = F.softplus(alphas) + self.eps
print("profile alphas shape", alphas.shape)
-
+
# Check for NaN values after softplus
if torch.isnan(alphas).any():
print("WARNING: NaN values detected after softplus!")
print("NaN count:", torch.isnan(alphas).sum().item())
- print("Softplus output stats - min:", alphas.min().item(), "max:", alphas.max().item(), "mean:", alphas.mean().item())
-
+ print(
+ "Softplus output stats - min:",
+ alphas.min().item(),
+ "max:",
+ alphas.max().item(),
+ "mean:",
+ alphas.mean().item(),
+ )
+
# Replace NaN values with a safe default
-
+
# Ensure all values are positive and finite
-
+
self.q_p = torch.distributions.Dirichlet(alphas)
print("profile q_p shape", self.q_p.rsample().shape)
return self.q_p
+
class Distribution(torch.nn.Module):
def __init__(self):
super().__init__()
-
+
def forward(self, representation):
pass
@@ -188,6 +227,7 @@ def forward(self, representation):
norm = self.distribution(params)
return norm
+
class ExponentialDistribution(torch.nn.Module):
def __init__(
self,
@@ -232,7 +272,7 @@ def __init__(self, dmodel=64, number_of_mc_samples=100):
self.d = 3
self.h = 21
- self.w = 21
+ self.w = 21
z_coords = torch.arange(self.d).float() - (self.d - 1) / 2
y_coords = torch.arange(self.h).float() - (self.h - 1) / 2
@@ -250,11 +290,11 @@ def __init__(self, dmodel=64, number_of_mc_samples=100):
self.register_buffer("pixel_positions", pixel_positions)
# Initialize prior distributions
- self.register_buffer('prior_mvn_mean_loc', torch.tensor(0.0))
- self.register_buffer('prior_mvn_mean_scale', torch.tensor(5.0))
- self.register_buffer('prior_mvn_cov_factor_loc', torch.tensor(0.0))
- self.register_buffer('prior_mvn_cov_factor_scale', torch.tensor(0.01))
- self.register_buffer('prior_mvn_cov_scale_scale', torch.tensor(0.01))
+ self.register_buffer("prior_mvn_mean_loc", torch.tensor(0.0))
+ self.register_buffer("prior_mvn_mean_scale", torch.tensor(5.0))
+ self.register_buffer("prior_mvn_cov_factor_loc", torch.tensor(0.0))
+ self.register_buffer("prior_mvn_cov_factor_scale", torch.tensor(0.01))
+ self.register_buffer("prior_mvn_cov_scale_scale", torch.tensor(0.01))
self.mvn_mean_mean_layer = Linear(
in_features=self.hidden_dim,
@@ -280,17 +320,15 @@ def __init__(self, dmodel=64, number_of_mc_samples=100):
def _update_prior_distributions(self):
"""Update prior distributions to use current device."""
self.prior_mvn_mean = torch.distributions.Normal(
- loc=self.prior_mvn_mean_loc,
- scale=self.prior_mvn_mean_scale
+ loc=self.prior_mvn_mean_loc, scale=self.prior_mvn_mean_scale
)
self.prior_mvn_cov_factor = torch.distributions.Normal(
- loc=self.prior_mvn_cov_factor_loc,
- scale=self.prior_mvn_cov_factor_scale
+ loc=self.prior_mvn_cov_factor_loc, scale=self.prior_mvn_cov_factor_scale
)
self.prior_mvn_cov_scale = torch.distributions.HalfNormal(
scale=self.prior_mvn_cov_scale_scale
)
-
+
def kl_divergence_weighted(self, weights):
"""Compute the KL divergence between the predicted and prior distributions.
@@ -300,15 +338,22 @@ def kl_divergence_weighted(self, weights):
Returns:
torch.Tensor: The weighted sum of KL divergences
"""
- _kl_divergence = compute_kl_divergence(
- self.mvn_mean_distribution, self.prior_mvn_mean
- ).sum() * weights[0]
- _kl_divergence += compute_kl_divergence(
- self.mvn_cov_factor_distribution, self.prior_mvn_cov_factor
- ).sum() * weights[1]
- _kl_divergence += compute_kl_divergence(
- self.mvn_cov_scale_distribution, self.prior_mvn_cov_scale
- ).sum() * weights[2]
+ _kl_divergence = (
+ compute_kl_divergence(self.mvn_mean_distribution, self.prior_mvn_mean).sum()
+ * weights[0]
+ )
+ _kl_divergence += (
+ compute_kl_divergence(
+ self.mvn_cov_factor_distribution, self.prior_mvn_cov_factor
+ ).sum()
+ * weights[1]
+ )
+ _kl_divergence += (
+ compute_kl_divergence(
+ self.mvn_cov_scale_distribution, self.prior_mvn_cov_scale
+ ).sum()
+ * weights[2]
+ )
return _kl_divergence
def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor:
@@ -334,26 +379,38 @@ def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor:
avg_profile = profile.mean(dim=1)
avg_profile = avg_profile / (avg_profile.sum(dim=-1, keepdim=True) + 1e-10)
- return profile_mvn_distribution #profile
+ return profile_mvn_distribution # profile
- def forward(self, *representations: torch.Tensor) -> torch.distributions.LowRankMultivariateNormal:
+ def forward(
+ self, *representations: torch.Tensor
+ ) -> torch.distributions.LowRankMultivariateNormal:
"""Forward pass of the LRMVN distribution."""
# Update prior distributions to ensure they're on the correct device
self._update_prior_distributions()
-
+
self.batch_size = representations[0].shape[0]
# Combine all representations if more than one is provided
- combined_representation = sum(representations) if len(representations) > 1 else representations[0]
+ combined_representation = (
+ sum(representations) if len(representations) > 1 else representations[0]
+ )
mvn_mean_mean = self.mvn_mean_mean_layer(representations[0]).unsqueeze(-1)
- mvn_mean_std = F.softplus(self.mvn_mean_std_layer(representations[0])).unsqueeze(-1)
+ mvn_mean_std = F.softplus(
+ self.mvn_mean_std_layer(representations[0])
+ ).unsqueeze(-1)
- mvn_cov_factor_mean = self.mvn_cov_factor_mean_layer(combined_representation).unsqueeze(-1)
- mvn_cov_factor_std = F.softplus(self.mvn_cov_factor_std_layer(combined_representation)).unsqueeze(-1)
+ mvn_cov_factor_mean = self.mvn_cov_factor_mean_layer(
+ combined_representation
+ ).unsqueeze(-1)
+ mvn_cov_factor_std = F.softplus(
+ self.mvn_cov_factor_std_layer(combined_representation)
+ ).unsqueeze(-1)
# for the inverse gammas
- mvn_cov_scale_parameter = F.softplus(self.mvn_cov_diagonal_scale_layer(representations[0])).unsqueeze(-1)
+ mvn_cov_scale_parameter = F.softplus(
+ self.mvn_cov_diagonal_scale_layer(representations[0])
+ ).unsqueeze(-1)
self.mvn_mean_distribution = torch.distributions.Normal(
loc=mvn_mean_mean,
@@ -368,16 +425,31 @@ def forward(self, *representations: torch.Tensor) -> torch.distributions.LowRank
self.mvn_cov_scale_distribution = torch.distributions.half_normal.HalfNormal(
scale=mvn_cov_scale_parameter,
)
- mvn_mean_samples = self.mvn_mean_distribution.rsample([self.number_of_mc_samples]).squeeze(-1).permute(1, 0, 2)
- mvn_cov_factor_samples = self.mvn_cov_factor_distribution.rsample([self.number_of_mc_samples]).permute(1, 0, 2, 3)
+ mvn_mean_samples = (
+ self.mvn_mean_distribution.rsample([self.number_of_mc_samples])
+ .squeeze(-1)
+ .permute(1, 0, 2)
+ )
+ mvn_cov_factor_samples = self.mvn_cov_factor_distribution.rsample(
+ [self.number_of_mc_samples]
+ ).permute(1, 0, 2, 3)
diag_samples = (
- self.mvn_cov_scale_distribution.rsample([self.number_of_mc_samples]).squeeze(-1).permute(1, 0, 2) + 1e-6
+ self.mvn_cov_scale_distribution.rsample([self.number_of_mc_samples])
+ .squeeze(-1)
+ .permute(1, 0, 2)
+ + 1e-6
)
self.profile_mvn_distribution = (
torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(
- loc=mvn_mean_samples.view(self.batch_size, self.number_of_mc_samples, 1, 3),
- cov_factor=mvn_cov_factor_samples.view(self.batch_size, self.number_of_mc_samples, 1, 3, 1),
- cov_diag=diag_samples.view(self.batch_size, self.number_of_mc_samples, 1, 3),
+ loc=mvn_mean_samples.view(
+ self.batch_size, self.number_of_mc_samples, 1, 3
+ ),
+ cov_factor=mvn_cov_factor_samples.view(
+ self.batch_size, self.number_of_mc_samples, 1, 3, 1
+ ),
+ cov_diag=diag_samples.view(
+ self.batch_size, self.number_of_mc_samples, 1, 3
+ ),
)
)
@@ -387,12 +459,12 @@ def to(self, *args, **kwargs):
super().to(*args, **kwargs)
# Get device from the first parameter
device = next(self.parameters()).device
-
+
# Move all registered buffers
for name, buffer in self.named_buffers():
setattr(self, name, buffer.to(device))
-
+
# Update prior distributions
self._update_prior_distributions()
-
- return self
\ No newline at end of file
+
+ return self
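
compute_kl_divergence above falls back to a Monte Carlo estimate whenever torch has no registered analytic KL for the pair of distributions. A self-contained sketch of that fallback (q must support reparameterized sampling for the fallback branch):

import torch


def kl_divergence_mc(q, p, num_samples: int = 100) -> torch.Tensor:
    """KL(q || p): use torch's analytic form when one is registered, otherwise
    average log q(x) - log p(x) over reparameterized samples from q."""
    try:
        return torch.distributions.kl.kl_divergence(q, p)
    except NotImplementedError:
        samples = q.rsample([num_samples])   # draws from the approximating distribution
        log_q = q.log_prob(samples)
        log_p = p.log_prob(samples)
        return (log_q - log_p).mean(dim=0)   # Monte Carlo average over the sample axis
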
diff --git a/src/factory/encoder.py b/src/factory/encoder.py
index 61235a3..aa54253 100644
--- a/src/factory/encoder.py
+++ b/src/factory/encoder.py
@@ -1,6 +1,6 @@
# adapted from Luis Aldama @https://github.com/Hekstra-Lab/integrator.git
import torch
-
+
class CNN_3d_(torch.nn.Module):
def __init__(self, Z=3, H=21, W=21, conv_channels=64, use_norm=True):
@@ -59,7 +59,8 @@ def forward(self, x, mask=None):
x = torch.flatten(x, 1)
return x
-
+
+
class CNN_3d(torch.nn.Module):
def __init__(self, out_dim=64):
"""
@@ -105,8 +106,8 @@ def forward(self, x, mask=None):
x = x.view(x.size(0), -1)
rep = F.relu(self.fc(x))
return rep
-
-
+
+
class MLP(torch.nn.Module):
def __init__(self, width, depth, dropout=None, output_dims=None):
super().__init__()
@@ -131,7 +132,8 @@ def forward(self, data):
if num_pixels is not None:
out = out.view(batch_size, num_pixels, -1) # Reshape back if needed
return out
-
+
+
class ResidualLayer(torch.nn.Module):
def __init__(self, width, dropout=None):
super().__init__()
@@ -177,6 +179,7 @@ def forward(self, shoebox_data):
out = self.mlp_1(out)
return out
+
def main():
# Test the model
model = CNN_3d()
@@ -189,4 +192,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/src/factory/generate_eff.py b/src/factory/generate_eff.py
index 93ed724..ed173dd 100644
--- a/src/factory/generate_eff.py
+++ b/src/factory/generate_eff.py
@@ -66,36 +66,34 @@
}}
"""
+
def main():
parser = argparse.ArgumentParser(
description="Generate a Phenix .eff file with dynamic MTZ paths."
)
parser.add_argument(
- "--sf-mtz", required=True,
- help="Path to the structure-factor MTZ file"
+ "--sf-mtz", required=True, help="Path to the structure-factor MTZ file"
)
parser.add_argument(
- "--rfree-mtz", required=True,
- help="Path to the R-free flags MTZ file"
+ "--rfree-mtz", required=True, help="Path to the R-free flags MTZ file"
)
parser.add_argument(
- "--out", required=True,
- help="Output path for the generated .eff file"
+ "--out", required=True, help="Output path for the generated .eff file"
)
parser.add_argument(
- "--phenix-out-mtz", required=True,
- help="Specific name of the output .mtz file of phenix."
+ "--phenix-out-mtz",
+ required=True,
+ help="Specific name of the output .mtz file of phenix.",
)
args = parser.parse_args()
# Fill in the template and write to the output file
content = EFF_TEMPLATE.format(
- sf_mtz=args.sf_mtz,
- rfree_mtz=args.rfree_mtz,
- phenix_out_mtz=args.phenix_out_mtz
+ sf_mtz=args.sf_mtz, rfree_mtz=args.rfree_mtz, phenix_out_mtz=args.phenix_out_mtz
)
Path(args.out).write_text(content)
print(f"Wrote .eff file to {args.out}")
+
if __name__ == "__main__":
main()
diff --git a/src/factory/get_protein_data.py b/src/factory/get_protein_data.py
index 4507d8c..018f44e 100644
--- a/src/factory/get_protein_data.py
+++ b/src/factory/get_protein_data.py
@@ -1,6 +1,7 @@
import gemmi
import requests
+
def get_protein_data(url: str):
response = requests.get(url)
if not response.ok:
@@ -16,7 +17,8 @@ def get_protein_data(url: str):
if dmin_str is None:
raise Exception("Resolution (dmin) not found in the CIF file")
dmin = float(dmin_str)
- return{"unit_cell": unit_cell, "spacegroup": spacegroup, "dmin": dmin_str}
+ return {"unit_cell": unit_cell, "spacegroup": spacegroup, "dmin": dmin_str}
+
if __name__ == "__main__":
data = get_protein_data("https://files.rcsb.org/download/9B7C.cif")
diff --git a/src/factory/import gemmi.py b/src/factory/import gemmi.py
index 17a04b8..3883956 100644
--- a/src/factory/import gemmi.py
+++ b/src/factory/import gemmi.py
@@ -1,6 +1,7 @@
import gemmi
import pandas as pd
+
def read_mtz_to_data_frame(file_path: str) -> pd.DataFrame:
mtz = gemmi.read_mtz_file(file_path)
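
The stub above stops after opening the MTZ. A minimal completion sketch, assuming gemmi's numpy bridge (numpy.array(mtz, copy=False) plus Mtz.column_labels()) rather than anything taken from this patch:

import gemmi
import numpy as np
import pandas as pd


def read_mtz_to_data_frame(file_path: str) -> pd.DataFrame:
    """Load every MTZ column into a pandas DataFrame, one column per label."""
    mtz = gemmi.read_mtz_file(file_path)
    data = np.array(mtz, copy=False)  # (n_reflections, n_columns) float view
    return pd.DataFrame(data=data, columns=mtz.column_labels())
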
diff --git a/src/factory/inspect_background.py b/src/factory/inspect_background.py
index cbc2315..d9997a6 100644
--- a/src/factory/inspect_background.py
+++ b/src/factory/inspect_background.py
@@ -10,13 +10,14 @@
dials.python ../factory/src/factory/inspect_background.py
"""
-import sys
import os
-from dials.array_family import flex
-import matplotlib.pyplot as plt
-import wandb
import statistics
+import sys
+
+import matplotlib.pyplot as plt
import numpy as np
+import wandb
+from dials.array_family import flex
from scipy.stats import gamma
@@ -45,16 +46,24 @@ def main(refl_file):
print(f"Variance of background.sum.value: {variance:.3f}")
plt.figure(figsize=(8, 5))
- plt.hist(bg_sum, bins=100, edgecolor='black', alpha=0.6, label='background.sum.value')
- plt.xlabel('Background Sum Value')
- plt.ylabel('Count')
- plt.title('Histogram of background.sum.value')
+ plt.hist(
+ bg_sum, bins=100, edgecolor="black", alpha=0.6, label="background.sum.value"
+ )
+ plt.xlabel("Background Sum Value")
+ plt.ylabel("Count")
+ plt.title("Histogram of background.sum.value")
# Add Gamma distribution histogram
alpha = 1.9802
scale = 75.4010
gamma_samples = gamma.rvs(a=alpha, scale=scale, size=len(bg_sum))
- plt.hist(gamma_samples, bins=100, edgecolor='red', alpha=0.4, label='Gamma(alpha=1.98, scale=75.4)')
+ plt.hist(
+ gamma_samples,
+ bins=100,
+ edgecolor="red",
+ alpha=0.4,
+ label="Gamma(alpha=1.98, scale=75.4)",
+ )
plt.legend()
plt.tight_layout()
@@ -62,6 +71,9 @@ def main(refl_file):
wandb.log({"background_histogram": wandb.Image(plt.gcf())})
wandb.finish()
+
if __name__ == "__main__":
- refl_file = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl"
+ refl_file = (
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl"
+ )
main(refl_file)
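The hard-coded Gamma(alpha=1.9802, scale=75.4010) overlay in inspect_background.py looks like a moment-matched fit to the background sums; assuming that is the intent, the parameters could be derived from the data like this (illustrative helper, not part of the repository):

import numpy as np


def fit_gamma_by_moments(bg_sum: np.ndarray) -> tuple[float, float]:
    # Method of moments for a Gamma distribution: mean = alpha*scale, var = alpha*scale**2.
    mean, var = bg_sum.mean(), bg_sum.var()
    return float(mean**2 / var), float(var / mean)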
diff --git a/src/factory/inspect_scale.py b/src/factory/inspect_scale.py
index 918d10c..839fe4c 100755
--- a/src/factory/inspect_scale.py
+++ b/src/factory/inspect_scale.py
@@ -1,42 +1,44 @@
-"""run with
+"""run with
source /n/hekstra_lab/people/aldama/software/dials-v3-16-1/dials_env.sh
-dials.python ../factory/src/factory/inspect_scale.py """
+dials.python ../factory/src/factory/inspect_scale.py"""
#!/usr/bin/env dials.python
import sys
-from dials.array_family import flex
+
import matplotlib.pyplot as plt
+
# matplotlib.use("Agg") # ← use a headless backend
import wandb
+from dials.array_family import flex
+
def main(refl_file):
refl = flex.reflection_table.from_file(refl_file)
if "inverse_scale_factor" not in refl:
raise RuntimeError("No inverse_scale_factor column found in " + refl_file)
-
+
# Extract the scale factors
scales = refl["inverse_scale_factor"]
scales_list = list(scales)
- scales = [1/x for x in scales_list]
+ scales = [1 / x for x in scales_list]
scales_list = scales
mean_scale = sum(scales) / len(scales)
- variance_scale = sum((s - mean_scale)**2 for s in scales) / len(scales)
+ variance_scale = sum((s - mean_scale) ** 2 for s in scales) / len(scales)
max_scale = max(scales)
min_scale = min(scales)
print("Mean of the 1/scale factors:", mean_scale)
print("Max of the 1/scale factors:", max_scale)
print("Min of the 1/scale factors:", min_scale)
print("Variance of the 1/scale factors:", variance_scale)
-
+
# Convert to a regular Python list for plotting
-
# Plot histogram
- plt.figure(figsize=(8,5))
- plt.hist(scales_list, bins=100, edgecolor='black', alpha=0.6, label='Data')
+ plt.figure(figsize=(8, 5))
+ plt.hist(scales_list, bins=100, edgecolor="black", alpha=0.6, label="Data")
plt.xlabel("1/Inverse Scale Factor")
plt.ylabel("Number of Observations")
plt.title("Histogram of Per‐Reflection 1/Scale Factors")
@@ -45,15 +47,18 @@ def main(refl_file):
# Overlay Gamma distribution
import numpy as np
from scipy.stats import gamma
+
x = np.linspace(min(scales_list), max(scales_list), 1000)
gamma_shape = 6.68
gamma_scale = 0.155
y = gamma.pdf(x, a=gamma_shape, scale=gamma_scale)
# Scale the gamma PDF to match the histogram
- y_scaled = y * len(scales_list) * (max(scales_list) - min(scales_list)) / 100 # 100 bins
- plt.plot(x, y_scaled, 'r-', lw=2, label=f'Gamma({gamma_shape}, {gamma_scale})')
+ y_scaled = (
+ y * len(scales_list) * (max(scales_list) - min(scales_list)) / 100
+ ) # 100 bins
+ plt.plot(x, y_scaled, "r-", lw=2, label=f"Gamma({gamma_shape}, {gamma_scale})")
plt.legend()
-
+
# Show or save
# out_png = "scale_histogram.png"
# plt.savefig(out_png, dpi=300) # ← write to disk
@@ -66,5 +71,6 @@ def main(refl_file):
# If you’d rather save to disk, uncomment:
# plt.savefig("scale_histogram.png", dpi=300)
+
if __name__ == "__main__":
- main("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl")
\ No newline at end of file
+ main("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl")
diff --git a/src/factory/loss_function.py b/src/factory/loss_function.py
index 09ccd06..8db9a0c 100644
--- a/src/factory/loss_function.py
+++ b/src/factory/loss_function.py
@@ -1,8 +1,9 @@
import torch
+
class LossFunction(torch.nn.Module):
def __init__(self):
pass
-
+
def forward(self):
- pass
\ No newline at end of file
+ pass
diff --git a/src/factory/loss_functionality.py b/src/factory/loss_functionality.py
index 9dce0a7..316d956 100644
--- a/src/factory/loss_functionality.py
+++ b/src/factory/loss_functionality.py
@@ -1,14 +1,16 @@
-import torch
-import model
import dataclasses
from typing import Optional
+
+import model
import settings
+import torch
from tensordict.nn.distributions import Delta
from wrap_folded_normal import SparseFoldedNormalPosterior
+
@dataclasses.dataclass
-class LossOutput():
+class LossOutput:
loss: Optional[torch.Tensor] = None
kl_structure_factors: Optional[torch.Tensor] = None
kl_background: Optional[torch.Tensor] = None
@@ -16,24 +18,33 @@ class LossOutput():
kl_scale: Optional[torch.Tensor] = None
log_likelihood: Optional[torch.Tensor] = None
+
class LossFunction(torch.nn.Module):
- def __init__(self):#, model_settings: settings.ModelSettings, loss_settings:settings.LossSettings):
+ def __init__(
+ self,
+ ): # , model_settings: settings.ModelSettings, loss_settings:settings.LossSettings):
super().__init__()
# self.model_settings = model_settings
# self.loss_settings = loss_settings
def compute_kl_divergence(self, predicted_distribution, target_distribution):
- try:
- return torch.distributions.kl.kl_divergence(predicted_distribution, target_distribution)
+ try:
+ return torch.distributions.kl.kl_divergence(
+ predicted_distribution, target_distribution
+ )
except NotImplementedError:
- print("KL divergence not implemented for this distribution: use sampling method.")
+ print(
+ "KL divergence not implemented for this distribution: use sampling method."
+ )
samples = predicted_distribution.rsample([50])
log_q = predicted_distribution.log_prob(samples)
log_p = target_distribution.log_prob(samples)
return (log_q - log_p).mean(dim=0)
- def compute_kl_divergence_verbose(self, predicted_distribution, target_distribution):
-
+ def compute_kl_divergence_verbose(
+ self, predicted_distribution, target_distribution
+ ):
+
samples = predicted_distribution.rsample([50])
log_q = predicted_distribution.log_prob(samples)
print(" pred", log_q)
@@ -41,7 +52,9 @@ def compute_kl_divergence_verbose(self, predicted_distribution, target_distribut
print(" target", log_p)
return (log_q - log_p).mean(dim=0)
- def compute_kl_divergence_surrogate_parameters(self, predicted_distribution, target_distribution, ordered_miller_indices):
+ def compute_kl_divergence_surrogate_parameters(
+ self, predicted_distribution, target_distribution, ordered_miller_indices
+ ):
samples = predicted_distribution.rsample([50])
log_q = predicted_distribution.log_prob(samples)
@@ -49,17 +62,20 @@ def compute_kl_divergence_surrogate_parameters(self, predicted_distribution, tar
log_p = target_distribution.log_prob(samples)
rasu_ids = surrogate_posterior.rac.rasu_ids[0]
- print("prior_intensity.rsample([20]).permute(1,0)", prior_intensity.rsample([20]).permute(1,0))
+ print(
+ "prior_intensity.rsample([20]).permute(1,0)",
+ prior_intensity.rsample([20]).permute(1, 0),
+ )
prior_intensity_samples = surrogate_posterior.rac.gather(
- source=prior_intensity.rsample([20]).permute(1,0).T, rasu_id=rasu_ids, H=ordered_miller_indices.to(torch.int)
+ source=prior_intensity.rsample([20]).permute(1, 0).T,
+ rasu_id=rasu_ids,
+ H=ordered_miller_indices.to(torch.int),
)
-
-
log_p = target_distribution.log_prob(samples)
print("scale: target", log_p)
return (log_q - log_p).mean(dim=0)
-
+
def get_prior(self, distribution, device):
# Move parameters to device and recreate the distribution
# if isinstance(distribution, torch.distributions.HalfCauchy):
@@ -68,9 +84,20 @@ def get_prior(self, distribution, device):
# return torch.distributions.HalfNormal(distribution.scale.to(device))
# Add more cases as needed
return distribution
-
- def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_distribution, scale_distribution, surrogate_posterior,
- profile_distribution, model_settings, loss_settings) -> LossOutput:
+
+ def forward(
+ self,
+ counts,
+ mask,
+ ordered_miller_indices,
+ photon_rate,
+ background_distribution,
+ scale_distribution,
+ surrogate_posterior,
+ profile_distribution,
+ model_settings,
+ loss_settings,
+ ) -> LossOutput:
print("in loss")
try:
device = photon_rate.device
@@ -79,37 +106,54 @@ def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_
loss_output = LossOutput()
# print("cuda=", next(self.model_settings.intensity_prior_distibution.parameters()).device)
- prior_background = self.get_prior(model_settings.background_prior_distribution, device=device)
+ prior_background = self.get_prior(
+ model_settings.background_prior_distribution, device=device
+ )
print("got priors bg")
- prior_intensity = model_settings.build_intensity_prior_distribution(model_settings.rac.to(device))
-
-
+ prior_intensity = model_settings.build_intensity_prior_distribution(
+ model_settings.rac.to(device)
+ )
+
# self.get_prior(model_settings.intensity_prior_distibution, device=device)
print("got priors sf")
- prior_scale = self.get_prior(model_settings.scale_prior_distibution, device=device)
+ prior_scale = self.get_prior(
+ model_settings.scale_prior_distibution, device=device
+ )
print("got priors")
# Compute KL divergence for structure factors
if model_settings.use_surrogate_parameters:
- rasu_ids = torch.tensor([0 for _ in range(len(ordered_miller_indices))], device=device)
+ rasu_ids = torch.tensor(
+ [0 for _ in range(len(ordered_miller_indices))], device=device
+ )
_kl_divergence_folded_normal = self.compute_kl_divergence(
- surrogate_posterior, prior_intensity.distribution(rasu_ids, ordered_miller_indices))
+ surrogate_posterior,
+ prior_intensity.distribution(rasu_ids, ordered_miller_indices),
+ )
# else:
# # Handle case where no valid parameters were found
# raise ValueError("No valid surrogate parameters found")
elif isinstance(surrogate_posterior, SparseFoldedNormalPosterior):
_kl_divergence_folded_normal = self.compute_kl_divergence(
- surrogate_posterior.get_distribution(ordered_miller_indices), prior_intensity.distribution())
+ surrogate_posterior.get_distribution(ordered_miller_indices),
+ prior_intensity.distribution(),
+ )
else:
- rasu_ids = torch.tensor([0 for _ in range(len(ordered_miller_indices))], device=device)
+ rasu_ids = torch.tensor(
+ [0 for _ in range(len(ordered_miller_indices))], device=device
+ )
_kl_divergence_folded_normal = self.compute_kl_divergence_verbose(
- surrogate_posterior, prior_intensity.distribution(rasu_ids, ordered_miller_indices))
- print("structure factor distr", surrogate_posterior.rsample([1]).shape)#, prior_intensity.rsample([1]).shape)
+ surrogate_posterior,
+ prior_intensity.distribution(rasu_ids, ordered_miller_indices),
+ )
+ print(
+ "structure factor distr", surrogate_posterior.rsample([1]).shape
+ ) # , prior_intensity.rsample([1]).shape)
# rasu_ids = surrogate_posterior.rac.rasu_ids[0]
@@ -117,44 +161,64 @@ def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_
# source=_kl_divergence_folded_normal_all_reflections.T, rasu_id=rasu_ids, H=ordered_miller_indices.to(torch.int)
# )
- kl_structure_factors = _kl_divergence_folded_normal * loss_settings.prior_structure_factors_weight
+ kl_structure_factors = (
+ _kl_divergence_folded_normal
+ * loss_settings.prior_structure_factors_weight
+ )
# print("bkg", background_distribution.sample().shape, prior_background.sample())
# print(background_distribution.scale.device, prior_background.concentration.device)
- kl_background = self.compute_kl_divergence(
- background_distribution, prior_background
- ) * loss_settings.prior_background_weight
+ kl_background = (
+ self.compute_kl_divergence(background_distribution, prior_background)
+ * loss_settings.prior_background_weight
+ )
print("shape background kl", kl_background.shape)
- kl_profile = (
- profile_distribution.kl_divergence_weighted(
- loss_settings.prior_profile_weight,
- )
+ kl_profile = profile_distribution.kl_divergence_weighted(
+ loss_settings.prior_profile_weight,
)
print("shape profile kl", kl_profile.shape)
print("scale_distribution type:", type(scale_distribution))
if isinstance(scale_distribution, Delta):
- kl_scale = torch.tensor(0.0, device=scale_distribution.mean.device, dtype=scale_distribution.mean.dtype)
+ kl_scale = torch.tensor(
+ 0.0,
+ device=scale_distribution.mean.device,
+ dtype=scale_distribution.mean.dtype,
+ )
else:
- kl_scale = self.compute_kl_divergence_verbose(
- scale_distribution, prior_scale
- ) * loss_settings.prior_scale_weight
+ kl_scale = (
+ self.compute_kl_divergence_verbose(scale_distribution, prior_scale)
+ * loss_settings.prior_scale_weight
+ )
print("kl scale", kl_scale.shape, kl_scale)
- rate = photon_rate #.clamp(min=loss_settings.eps)
- log_likelihood = torch.distributions.Poisson(rate=rate+loss_settings.eps).log_prob(
- counts.unsqueeze(1)) #n_batches, mc_samples, pixels
+ rate = photon_rate # .clamp(min=loss_settings.eps)
+ log_likelihood = torch.distributions.Poisson(
+ rate=rate + loss_settings.eps
+ ).log_prob(
+ counts.unsqueeze(1)
+ ) # n_batches, mc_samples, pixels
print("dim loglikelihood", log_likelihood.shape)
- log_likelihood_mean = torch.mean(log_likelihood, dim=1) * mask # .unsqueeze(-1)
+ log_likelihood_mean = (
+ torch.mean(log_likelihood, dim=1) * mask
+ ) # .unsqueeze(-1)
print("log_likelihood_mean: n_batches, pixels", log_likelihood_mean.shape)
print("mask", mask.shape)
- negative_log_likelihood_per_batch = (-log_likelihood_mean).sum(dim=1)# batch_size
+ negative_log_likelihood_per_batch = (-log_likelihood_mean).sum(
+ dim=1
+ ) # batch_size
+
+ loss_per_batch = (
+ negative_log_likelihood_per_batch
+ + kl_structure_factors
+ + kl_background
+ + kl_profile
+ + kl_scale
+ )
- loss_per_batch = negative_log_likelihood_per_batch + kl_structure_factors + kl_background + kl_profile + kl_scale
-
loss_output.loss = loss_per_batch.mean()
loss_output.kl_structure_factors = kl_structure_factors.mean()
@@ -167,5 +231,3 @@ def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_
return loss_output
except Exception as e:
print(f"loss computation failed with {e}")
-
-
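compute_kl_divergence above falls back to a Monte Carlo estimate whenever torch has no closed-form KL for the given pair: KL(q || p) ~= mean_s[log q(x_s) - log p(x_s)] with x_s drawn from q via rsample, so gradients still flow through the samples. A standalone sketch of that estimator (the 50 samples mirror the code above):

import torch


def mc_kl(q: torch.distributions.Distribution,
          p: torch.distributions.Distribution,
          n_samples: int = 50) -> torch.Tensor:
    # Monte Carlo KL(q || p) using reparameterized samples from q.
    x = q.rsample([n_samples])  # shape: (n_samples, *batch_shape)
    return (q.log_prob(x) - p.log_prob(x)).mean(dim=0)


# Sanity check against the analytic KL for two Normals.
q = torch.distributions.Normal(0.0, 1.0)
p = torch.distributions.Normal(0.5, 1.5)
print(mc_kl(q, p), torch.distributions.kl_divergence(q, p))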
diff --git a/src/factory/metadata_encoder.py b/src/factory/metadata_encoder.py
index a791a85..f9b786c 100644
--- a/src/factory/metadata_encoder.py
+++ b/src/factory/metadata_encoder.py
@@ -1,9 +1,11 @@
-import torch
import math
+
+import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
+
def weight_initializer(weight):
fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2])
std = math.sqrt(1.0 / fan_avg / 10.0)
@@ -57,6 +59,7 @@ def forward(self, x):
return out
+
class SimpleMetadataEncoder(nn.Module):
def __init__(self, feature_dim, depth=8, dropout=0.0, output_dims=None):
super().__init__()
@@ -122,6 +125,7 @@ def forward(self, x):
return out
+
class BaseMetadataEncoder(nn.Module):
def __init__(self, feature_dim=2, depth=10, dropout=0.0, output_dims=64):
super().__init__()
@@ -150,15 +154,29 @@ def forward(self, x):
if torch.isnan(x).any():
print("WARNING: NaN values in BaseMetadataEncoder input!")
print("NaN count:", torch.isnan(x).sum().item())
- print("Stats - min:", x.min().item(), "max:", x.max().item(), "mean:", x.mean().item())
-
+ print(
+ "Stats - min:",
+ x.min().item(),
+ "max:",
+ x.max().item(),
+ "mean:",
+ x.mean().item(),
+ )
+
# Process through the model
x = self.model(x)
-
+
# Check output for NaN values
if torch.isnan(x).any():
print("WARNING: NaN values in BaseMetadataEncoder output!")
print("NaN count:", torch.isnan(x).sum().item())
- print("Stats - min:", x.min().item(), "max:", x.max().item(), "mean:", x.mean().item())
-
- return x
\ No newline at end of file
+ print(
+ "Stats - min:",
+ x.min().item(),
+ "max:",
+ x.max().item(),
+ "mean:",
+ x.mean().item(),
+ )
+
+ return x
diff --git a/src/factory/model.py b/src/factory/model.py
index da22ee6..4e2575e 100644
--- a/src/factory/model.py
+++ b/src/factory/model.py
@@ -1,73 +1,68 @@
-
-import sys, os
+import os
+import sys
repo_root = os.path.abspath(os.path.join(__file__, os.pardir))
-inner_pkg = os.path.join(repo_root, "abismal_torch")
+inner_pkg = os.path.join(repo_root, "abismal_torch")
sys.path.insert(0, inner_pkg)
sys.path.insert(0, repo_root)
import cProfile
-
-import pandas as pd
+import dataclasses
+import io
import data_loader
-import torch
-import numpy as np
+import distributions
+import get_protein_data
import lightning as L
-import dataclasses
+import loss_functionality
+import matplotlib.pyplot as plt
import metadata_encoder
+import numpy as np
+import pandas as pd
+import reciprocalspaceship as rs
+import rs_distributions.modules as rsm
+import settings
import shoebox_encoder
+import torch
import torch.nn.functional as F
-from networks import *
-import distributions
-import loss_functionality
-from callbacks import LossLogging, Plotting, CorrelationPlotting, ScalePlotting
-import get_protein_data
-import reciprocalspaceship as rs
-from abismal_torch.prior import WilsonPrior
-from abismal_torch.symmetry.reciprocal_asu import ReciprocalASU, ReciprocalASUGraph
-from reciprocalspaceship.utils import apply_to_hkl, generate_reciprocal_asu
-from rasu import *
+import wandb
from abismal_torch.likelihood import NormalLikelihood
-from abismal_torch.surrogate_posterior import FoldedNormalPosterior
-from wrap_folded_normal import FrequencyTrackingPosterior, SparseFoldedNormalPosterior
from abismal_torch.merging import VariationalMergingModel
+from abismal_torch.prior import WilsonPrior
from abismal_torch.scaling import ImageScaler
-from positional_encoding import positional_encoding
-
-from lightning.pytorch.loggers import WandbLogger
+from abismal_torch.surrogate_posterior import FoldedNormalPosterior
+from abismal_torch.symmetry.reciprocal_asu import ReciprocalASU, ReciprocalASUGraph
+from adabelief_pytorch import AdaBelief
+from callbacks import CorrelationPlotting, LossLogging, Plotting, ScalePlotting
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
-import wandb
-# from pytorch_lightning.callbacks import ModelCheckpoint
-
-from torchvision.transforms import ToPILImage
+from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities import grad_norm
-import matplotlib.pyplot as plt
-import io
+from masked_adam import MaskedAdam, make_lazy_adabelief_for_surrogate_posterior
+from networks import *
from PIL import Image
+from positional_encoding import positional_encoding
+from rasu import *
+from reciprocalspaceship.utils import apply_to_hkl, generate_reciprocal_asu
+from settings import LossSettings, ModelSettings
+from torch.optim.lr_scheduler import MultiStepLR, StepLR
+from torchvision.transforms import ToPILImage
-from settings import ModelSettings, LossSettings
-import settings
+from wrap_folded_normal import FrequencyTrackingPosterior, SparseFoldedNormalPosterior
-import rs_distributions.modules as rsm
-from masked_adam import MaskedAdam, make_lazy_adabelief_for_surrogate_posterior
+# from pytorch_lightning.callbacks import ModelCheckpoint
-from torch.optim.lr_scheduler import StepLR, MultiStepLR
-from adabelief_pytorch import AdaBelief
-import wandb
torch.set_float32_matmul_precision("medium")
import torch
-
class Model(L.LightningModule):
@@ -80,30 +75,56 @@ class Model(L.LightningModule):
def __init__(self, model_settings: ModelSettings, loss_settings: LossSettings):
super().__init__()
self.model_settings = model_settings
- self.profile_encoder = model_settings.profile_encoder #encoder.CNN_3d()
+ self.profile_encoder = model_settings.profile_encoder # encoder.CNN_3d()
self.metadata_encoder = model_settings.metadata_encoder
self.intensity_encoder = model_settings.intensity_encoder
self.scale_function = self.model_settings.scale_function
- self.profile_distribution = self.model_settings.build_shoebox_profile_distribution(torch.load(os.path.join(self.model_settings.data_directory, "concentration.pt"), weights_only=True), dmodel=self.model_settings.dmodel)
- self.background_distribution = self.model_settings.build_background_distribution(dmodel=self.model_settings.dmodel)
+ self.profile_distribution = (
+ self.model_settings.build_shoebox_profile_distribution(
+ torch.load(
+ os.path.join(
+ self.model_settings.data_directory, "concentration.pt"
+ ),
+ weights_only=True,
+ ),
+ dmodel=self.model_settings.dmodel,
+ )
+ )
+ self.background_distribution = (
+ self.model_settings.build_background_distribution(
+ dmodel=self.model_settings.dmodel
+ )
+ )
self.surrogate_posterior = self.initialize_surrogate_posterior()
- self.dispersion_param = torch.tensor(loss_settings.dispersion_parameter) #torch.nn.Parameter(torch.tensor(0.001))torch.nn.Parameter(torch.tensor(0.001)) #torch.tensor(loss_settings.dispersion_parameter) #torch.nn.Parameter(torch.tensor(0.001))
+ self.dispersion_param = torch.tensor(
+ loss_settings.dispersion_parameter
+        )  # alternative: torch.nn.Parameter(torch.tensor(0.001))
self.loss_settings = loss_settings
self.loss_function = loss_functionality.LossFunction()
self.loss = torch.Tensor(0)
- self.lr = self.model_settings.learning_rate
+ self.lr = self.model_settings.learning_rate
# self._train_dataloader = train_dataloader
- pdb_data = get_protein_data.get_protein_data(self.model_settings.protein_pdb_url)
- rac = ReciprocalASUGraph(*[ReciprocalASU(
- cell=pdb_data["unit_cell"],
- spacegroup=pdb_data["spacegroup"],
- dmin=float(pdb_data["dmin"]),
- anomalous=True,
- )])
- self.intensity_prior_distibution = self.model_settings.build_intensity_prior_distribution(rac)
- H_rasu = generate_reciprocal_asu(pdb_data["unit_cell"], pdb_data["spacegroup"], float(pdb_data["dmin"]), True)
+ pdb_data = get_protein_data.get_protein_data(
+ self.model_settings.protein_pdb_url
+ )
+ rac = ReciprocalASUGraph(
+ *[
+ ReciprocalASU(
+ cell=pdb_data["unit_cell"],
+ spacegroup=pdb_data["spacegroup"],
+ dmin=float(pdb_data["dmin"]),
+ anomalous=True,
+ )
+ ]
+ )
+ self.intensity_prior_distibution = (
+ self.model_settings.build_intensity_prior_distribution(rac)
+ )
+ H_rasu = generate_reciprocal_asu(
+ pdb_data["unit_cell"], pdb_data["spacegroup"], float(pdb_data["dmin"]), True
+ )
self.kl_structure_factors_loss = torch.tensor(0)
def setup(self, stage=None):
@@ -112,7 +133,6 @@ def setup(self, stage=None):
self.model_settings = Model.move_settings(self.model_settings, device)
self.loss_settings = Model.move_settings(self.loss_settings, device)
-
@property
def device(self):
return next(self.parameters()).device
@@ -125,25 +145,24 @@ def to(self, *args, **kwargs):
self.model_settings = Model.move_settings(self.model_settings, self.device)
self.loss_settings = Model.move_settings(self.loss_settings, self.device)
- if hasattr(self, 'surrogate_posterior'):
+ if hasattr(self, "surrogate_posterior"):
self.surrogate_posterior.to(self.device)
-
- if hasattr(self, 'profile_distribution'):
+
+ if hasattr(self, "profile_distribution"):
self.profile_distribution.to(self.device)
-
- if hasattr(self, 'background_distribution'):
+
+ if hasattr(self, "background_distribution"):
self.background_distribution.to(self.device)
- if hasattr(self, 'scale_function'):
+ if hasattr(self, "scale_function"):
self.scale_function.to(self.device)
-
- if hasattr(self, 'dataloader') and self.dataloader is not None:
- if hasattr(self.dataloader, 'pin_memory'):
+
+ if hasattr(self, "dataloader") and self.dataloader is not None:
+ if hasattr(self.dataloader, "pin_memory"):
self.dataloader.pin_memory = True
return self
-
@staticmethod
def move_settings(settings: dataclasses.dataclass, device: torch.device):
moved = {}
@@ -153,8 +172,14 @@ def move_settings(settings: dataclasses.dataclass, device: torch.device):
val = getattr(settings, field.name)
if isinstance(val, torch.distributions.Distribution):
param_names = val.arg_constraints.keys()
- params = {name: getattr(val, name).to(device) if hasattr(getattr(val, name), 'to') else getattr(val, name)
- for name in param_names}
+ params = {
+ name: (
+ getattr(val, name).to(device)
+ if hasattr(getattr(val, name), "to")
+ else getattr(val, name)
+ )
+ for name in param_names
+ }
moved[field.name] = type(val)(**params)
# print("moved", val)
@@ -162,72 +187,126 @@ def move_settings(settings: dataclasses.dataclass, device: torch.device):
try:
moved[field.name] = val.to(device)
except Exception as e:
- moved[field.name] = val
+ moved[field.name] = val
else:
moved[field.name] = val
return dataclasses.replace(settings, **moved)
-
def initialize_surrogate_posterior(self):
- initial_mean = self.model_settings.intensity_prior_distibution.distribution().mean() # self.model_settings.intensity_distribution_initial_location
+ initial_mean = (
+ self.model_settings.intensity_prior_distibution.distribution().mean()
+ ) # self.model_settings.intensity_distribution_initial_location
print("intensity folded normal mean shape", initial_mean.shape)
- initial_scale = 0.05 * initial_mean # self.model_settings.intensity_distribution_initial_scale
-
- surrogate_posterior = FrequencyTrackingPosterior.from_unconstrained_loc_and_scale(
- rac=self.model_settings.rac, loc=initial_mean, scale=initial_scale # epsilon=settings.epsilon
+ initial_scale = (
+ 0.05 * initial_mean
+ ) # self.model_settings.intensity_distribution_initial_scale
+
+ surrogate_posterior = (
+ FrequencyTrackingPosterior.from_unconstrained_loc_and_scale(
+ rac=self.model_settings.rac,
+ loc=initial_mean,
+ scale=initial_scale, # epsilon=settings.epsilon
+ )
)
def check_grad_hook(grad):
if grad is not None:
- print(f"Gradient stats: min={grad.min()}, max={grad.max()}, mean={grad.mean()}, any_nan={torch.isnan(grad).any()}, all_finite={torch.isfinite(grad).all()}")
+ print(
+ f"Gradient stats: min={grad.min()}, max={grad.max()}, mean={grad.mean()}, any_nan={torch.isnan(grad).any()}, all_finite={torch.isfinite(grad).all()}"
+ )
return grad
+
surrogate_posterior.distribution.loc.register_hook(check_grad_hook)
return surrogate_posterior
-
- def compute_scale(self, representations: list[torch.Tensor]) -> torch.distributions.Normal:
- joined_representation = sum(representations) if len(representations) > 1 else representations[0] #self._add_representation_(image_representation, metadata_representation) # (batch_size, dmodel)
+
+ def compute_scale(
+ self, representations: list[torch.Tensor]
+ ) -> torch.distributions.Normal:
+ joined_representation = (
+ sum(representations) if len(representations) > 1 else representations[0]
+ ) # self._add_representation_(image_representation, metadata_representation) # (batch_size, dmodel)
scale = self.scale_function(joined_representation)
return scale.distribution
-
+
def _add_representation_(self, representation1, representation2):
return representation1 + representation2
-
+
def compute_shoebox_profile(self, representations: list[torch.Tensor]):
return self.profile_distribution.compute_profile(*representations)
def compute_background_distribution(self, intensity_representation):
return self.background_distribution(intensity_representation)
- def compute_photon_rate(self, scale_distribution, background_distribution, profile, surrogate_posterior, ordered_miller_indices, metadata, verbose_output) -> torch.Tensor:
-
- _samples_predicted_structure_factor = surrogate_posterior.rsample([self.model_settings.number_of_mc_samples]).unsqueeze(-1).permute(1, 0, 2)
- samples_predicted_structure_factor = self.surrogate_posterior.rac.gather(_samples_predicted_structure_factor, torch.tensor([0]), ordered_miller_indices)
+ def compute_photon_rate(
+ self,
+ scale_distribution,
+ background_distribution,
+ profile,
+ surrogate_posterior,
+ ordered_miller_indices,
+ metadata,
+ verbose_output,
+ ) -> torch.Tensor:
+
+ _samples_predicted_structure_factor = (
+ surrogate_posterior.rsample([self.model_settings.number_of_mc_samples])
+ .unsqueeze(-1)
+ .permute(1, 0, 2)
+ )
+ samples_predicted_structure_factor = self.surrogate_posterior.rac.gather(
+ _samples_predicted_structure_factor,
+ torch.tensor([0]),
+ ordered_miller_indices,
+ )
- samples_scale = scale_distribution.rsample([self.model_settings.number_of_mc_samples]).permute(1, 0, 2) #(batch_size, number_of_samples, 1)
+ samples_scale = scale_distribution.rsample(
+ [self.model_settings.number_of_mc_samples]
+ ).permute(
+ 1, 0, 2
+ ) # (batch_size, number_of_samples, 1)
self.logger.experiment.log({"scale_samples_average": torch.mean(samples_scale)})
- samples_profile = profile.rsample([self.model_settings.number_of_mc_samples]).permute(1,0,2)
- samples_background = background_distribution.rsample([self.model_settings.number_of_mc_samples]).permute(1, 0, 2) #(batch_size, number_of_samples, 1)
+ samples_profile = profile.rsample(
+ [self.model_settings.number_of_mc_samples]
+ ).permute(1, 0, 2)
+ samples_background = background_distribution.rsample(
+ [self.model_settings.number_of_mc_samples]
+ ).permute(
+ 1, 0, 2
+ ) # (batch_size, number_of_samples, 1)
if verbose_output:
- photon_rate = samples_scale * torch.square(samples_predicted_structure_factor) * samples_profile + samples_background
+ photon_rate = (
+ samples_scale
+ * torch.square(samples_predicted_structure_factor)
+ * samples_profile
+ + samples_background
+ )
return {
- 'photon_rate': photon_rate,
- 'samples_profile': samples_profile,
- 'samples_predicted_structure_factor': samples_predicted_structure_factor,
- 'samples_scale': samples_scale,
- 'samples_background': samples_background
+ "photon_rate": photon_rate,
+ "samples_profile": samples_profile,
+ "samples_predicted_structure_factor": samples_predicted_structure_factor,
+ "samples_scale": samples_scale,
+ "samples_background": samples_background,
}
else:
- photon_rate = samples_scale * torch.square(samples_predicted_structure_factor) * samples_profile + samples_background # [batch_size, mc_samples, pixels]
+ photon_rate = (
+ samples_scale
+ * torch.square(samples_predicted_structure_factor)
+ * samples_profile
+ + samples_background
+ ) # [batch_size, mc_samples, pixels]
return photon_rate
-
+
def _cut_metadata(self, metadata) -> torch.Tensor:
# metadata_cut = torch.index_select(metadata, dim=1, index=torch.tensor(self.model_settings.metadata_indices_to_keep, device=self.device))
if self.model_settings.use_positional_encoding:
- encoded_metadata = positional_encoding(X=metadata, L=self.model_settings.number_of_frequencies_in_positional_encoding)
+ encoded_metadata = positional_encoding(
+ X=metadata,
+ L=self.model_settings.number_of_frequencies_in_positional_encoding,
+ )
print("encoding dim metadata", encoded_metadata.shape)
return encoded_metadata
else:
@@ -242,28 +321,52 @@ def _log_representation_stats(self, tensor: torch.Tensor, name: str):
f"{name}/min": tensor.min().item(),
f"{name}/max": tensor.max().item(),
}
-
- if hasattr(self, 'logger') and self.logger is not None:
- self.logger.experiment.log(stats)
+ if hasattr(self, "logger") and self.logger is not None:
+ self.logger.experiment.log(stats)
def _batch_to_representations(self, batch: tuple):
- shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl_batch, processed_metadata_batch = batch
- standardized_counts = shoeboxes_batch #[:,:,-1].reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21)
+ (
+ shoeboxes_batch,
+ metadata_batch,
+ dead_pixel_mask_batch,
+ counts_batch,
+ hkl_batch,
+ processed_metadata_batch,
+ ) = batch
+ standardized_counts = (
+ shoeboxes_batch # [:,:,-1].reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21)
+ )
print("sb shape", standardized_counts.shape)
- profile_representation = self.profile_encoder(standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21), mask=dead_pixel_mask_batch)
+ profile_representation = self.profile_encoder(
+ standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21),
+ mask=dead_pixel_mask_batch,
+ )
- intensity_representation = self.intensity_encoder(standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21), mask=dead_pixel_mask_batch)
+ intensity_representation = self.intensity_encoder(
+ standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21),
+ mask=dead_pixel_mask_batch,
+ )
+ metadata_representation = self.metadata_encoder(
+ self._cut_metadata(processed_metadata_batch).float()
+ ) # (batch_size, dmodel)
+ print(
+ "metadata representation (batch size, dmodel)",
+ metadata_representation.shape,
+ )
- metadata_representation = self.metadata_encoder(self._cut_metadata(processed_metadata_batch).float()) # (batch_size, dmodel)
- print("metadata representation (batch size, dmodel)", metadata_representation.shape)
+ joined_shoebox_representation = (
+ intensity_representation + metadata_representation
+ )
- joined_shoebox_representation = intensity_representation + metadata_representation
-
- pooled_image_representation = torch.max(joined_shoebox_representation, dim=0, keepdim=True)[0] # (1, dmodel)
- image_representation = pooled_image_representation #+ metadata_representation
+ pooled_image_representation = torch.max(
+ joined_shoebox_representation, dim=0, keepdim=True
+ )[
+ 0
+ ] # (1, dmodel)
+ image_representation = pooled_image_representation # + metadata_representation
print("image rep (1, dmodel)", image_representation.shape)
if torch.isnan(metadata_representation).any():
@@ -272,29 +375,53 @@ def _batch_to_representations(self, batch: tuple):
raise ValueError("MLP profile_representation produced NaNs!")
if torch.isnan(intensity_representation).any():
raise ValueError("MLP intensity_representation produced NaNs!")
-
-
- return intensity_representation, metadata_representation, image_representation, profile_representation
-
+
+ return (
+ intensity_representation,
+ metadata_representation,
+ image_representation,
+ profile_representation,
+ )
+
def forward(self, batch, verbose_output=False):
try:
- intensity_representation, metadata_representation, image_representation, profile_representation = self._batch_to_representations(batch=batch)
+ (
+ intensity_representation,
+ metadata_representation,
+ image_representation,
+ profile_representation,
+ ) = self._batch_to_representations(batch=batch)
self._log_representation_stats(profile_representation, "profile_rep")
self._log_representation_stats(metadata_representation, "metadata_rep")
self._log_representation_stats(image_representation, "image_rep")
self._log_representation_stats(intensity_representation, "intensity_rep")
-
- shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl_batch, processed_metadata_batch = batch
+
+ (
+ shoeboxes_batch,
+ metadata_batch,
+ dead_pixel_mask_batch,
+ counts_batch,
+ hkl_batch,
+ processed_metadata_batch,
+ ) = batch
print("unique hkls", len(torch.unique(hkl_batch)))
- scale_distribution = self.compute_scale(representations=[metadata_representation, image_representation])#[image_representation, metadata_representation]) # torch.distributions.Normal instance
+ scale_distribution = self.compute_scale(
+ representations=[metadata_representation, image_representation]
+ ) # [image_representation, metadata_representation]) # torch.distributions.Normal instance
print("compute profile")
- shoebox_profile = self.compute_shoebox_profile(representations=[profile_representation])#, image_representation])
+ shoebox_profile = self.compute_shoebox_profile(
+ representations=[profile_representation]
+ ) # , image_representation])
print("compute bg")
- background_distribution = self.compute_background_distribution(intensity_representation=intensity_representation)# shoebox_representation)#+image_representation)
- self.surrogate_posterior.update_observed(rasu_id=self.surrogate_posterior.rac.rasu_ids[0], H=hkl_batch)
+ background_distribution = self.compute_background_distribution(
+ intensity_representation=intensity_representation
+ ) # shoebox_representation)#+image_representation)
+ self.surrogate_posterior.update_observed(
+ rasu_id=self.surrogate_posterior.rac.rasu_ids[0], H=hkl_batch
+ )
print("compute rate")
photon_rate = self.compute_photon_rate(
@@ -304,12 +431,20 @@ def forward(self, batch, verbose_output=False):
surrogate_posterior=self.surrogate_posterior,
ordered_miller_indices=hkl_batch,
metadata=metadata_batch,
- verbose_output=verbose_output
+ verbose_output=verbose_output,
)
if verbose_output:
return (shoeboxes_batch, photon_rate, hkl_batch, counts_batch)
- return (counts_batch, dead_pixel_mask_batch, hkl_batch, photon_rate, background_distribution, scale_distribution, self.surrogate_posterior,
- self.profile_distribution)
+ return (
+ counts_batch,
+ dead_pixel_mask_batch,
+ hkl_batch,
+ photon_rate,
+ background_distribution,
+ scale_distribution,
+ self.surrogate_posterior,
+ self.profile_distribution,
+ )
except Exception as e:
print(f"failed in forward: {e}")
for param in self.parameters():
@@ -323,33 +458,43 @@ def training_step_(self, batch):
for n, p in self.named_parameters():
if not n.startswith("surrogate_posterior."):
p.requires_grad_(False)
- local_runs = 3 if (self.current_epoch < 0.3*self.trainer.max_epochs) else 1
+ local_runs = 3 if (self.current_epoch < 0.3 * self.trainer.max_epochs) else 1
for _ in range(local_runs):
output = self(batch=batch)
- local_loss = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param).loss
+ local_loss = self.loss_function(
+ *output, self.model_settings, self.loss_settings, self.dispersion_param
+ ).loss
opt_loc.zero_grad(set_to_none=True)
- self.manual_backward(local_loss) # builds sparse grads on rows in idx
+ self.manual_backward(local_loss) # builds sparse grads on rows in idx
opt_loc.step()
self.train()
for n, p in self.named_parameters():
if not n.startswith("surrogate_posterior."):
p.requires_grad_(True)
output = self(batch=batch)
- loss_output = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param)
+ loss_output = self.loss_function(
+ *output, self.model_settings, self.loss_settings, self.dispersion_param
+ )
opt_main.zero_grad()
loss = loss_output.loss
loss.backward()
opt_main.step()
self.loss = loss.detach()
- self.logger.experiment.log({
+ self.logger.experiment.log(
+ {
"loss/kl_structure_factors_step": loss_output.kl_structure_factors.item(),
"loss/kl_background": loss_output.kl_background.item(),
"loss/kl_profile": loss_output.kl_profile.item(),
"loss/kl_scale": loss_output.kl_scale.item(),
"loss/log_likelihood": loss_output.log_likelihood.item(),
- })
+ }
+ )
self.log("disp_param", self.dispersion_param, on_step=True, on_epoch=True)
- self.log("loss/kl_structure_factors", loss_output.kl_structure_factors.item(),on_epoch=True)
+ self.log(
+ "loss/kl_structure_factors",
+ loss_output.kl_structure_factors.item(),
+ on_epoch=True,
+ )
self.loss = loss
print("log loss step")
self.log("train/loss_step", self.loss, on_step=True, on_epoch=True)
@@ -357,20 +502,31 @@ def training_step_(self, batch):
for name, norm in norms.items():
self.log(f"grad_norm/{name}", norm)
return self.loss
-
-
+
def training_step(self, batch):
output = self(batch=batch)
- loss_output = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param, self.current_epoch)
+ loss_output = self.loss_function(
+ *output,
+ self.model_settings,
+ self.loss_settings,
+ self.dispersion_param,
+ self.current_epoch,
+ )
self.loss = loss_output.loss.detach()
- self.logger.experiment.log({
+ self.logger.experiment.log(
+ {
"loss/kl_structure_factors": loss_output.kl_structure_factors.item(),
"loss/kl_background": loss_output.kl_background.item(),
"loss/kl_profile": loss_output.kl_profile.item(),
"loss/kl_scale": loss_output.kl_scale.item(),
"loss/log_likelihood": loss_output.log_likelihood.item(),
- })
- self.log("loss/kl_structure_factors", loss_output.kl_structure_factors.item(),on_epoch=True)
+ }
+ )
+ self.log(
+ "loss/kl_structure_factors",
+ loss_output.kl_structure_factors.item(),
+ on_epoch=True,
+ )
self.loss = loss_output.loss
self.log("train/loss_step", self.loss, on_step=True, on_epoch=True)
norms = grad_norm(self, norm_type=2)
@@ -378,43 +534,75 @@ def training_step(self, batch):
self.log(f"grad_norm/{name}", norm)
return self.loss
-
- def validation_step(self, batch, batch_idx):
+ def validation_step(self, batch, batch_idx):
print(f"Validation step called for batch {batch_idx}")
output = self(batch=batch)
- loss_output = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param, self.current_epoch)
+ loss_output = self.loss_function(
+ *output,
+ self.model_settings,
+ self.loss_settings,
+ self.dispersion_param,
+ self.current_epoch,
+ )
- self.log("validation_loss/kl_structure_factors", loss_output.kl_structure_factors.item(), on_step=False, on_epoch=True)
- self.log("validation_loss/kl_background", loss_output.kl_background.item(), on_step=False, on_epoch=True)
- self.log("validation_loss/kl_profile", loss_output.kl_profile.item(), on_step=False, on_epoch=True)
- self.log("validation_loss/log_likelihood", loss_output.log_likelihood.item(), on_step=False, on_epoch=True)
- self.log("validation_loss/loss", loss_output.loss.item(), on_step=True, on_epoch=True)
- self.log("validation/loss_step_epoch", loss_output.loss.item(), on_step=False, on_epoch=True)
+ self.log(
+ "validation_loss/kl_structure_factors",
+ loss_output.kl_structure_factors.item(),
+ on_step=False,
+ on_epoch=True,
+ )
+ self.log(
+ "validation_loss/kl_background",
+ loss_output.kl_background.item(),
+ on_step=False,
+ on_epoch=True,
+ )
+ self.log(
+ "validation_loss/kl_profile",
+ loss_output.kl_profile.item(),
+ on_step=False,
+ on_epoch=True,
+ )
+ self.log(
+ "validation_loss/log_likelihood",
+ loss_output.log_likelihood.item(),
+ on_step=False,
+ on_epoch=True,
+ )
+ self.log(
+ "validation_loss/loss", loss_output.loss.item(), on_step=True, on_epoch=True
+ )
+ self.log(
+ "validation/loss_step_epoch",
+ loss_output.loss.item(),
+ on_step=False,
+ on_epoch=True,
+ )
return loss_output.loss
-
-
def surrogate_full_sweep(self, dataloader=None, K=1):
dl = self.dataloader
opt_main, opt_loc = self.optimizers()
for n, p in self.named_parameters():
if not n.startswith("surrogate_posterior."):
p.requires_grad_(False)
- self.eval()
+ self.eval()
self.surrogate_posterior.train()
for xb in dl:
xb = self.transfer_batch_to_device(xb, self.device)
-
- for _ in range(K):
+
+ for _ in range(K):
out = self(batch=xb)
- loss_obj = self.loss_function(*out, self.model_settings, self.loss_settings, self.dispersion_param)
+ loss_obj = self.loss_function(
+ *out, self.model_settings, self.loss_settings, self.dispersion_param
+ )
opt_loc.zero_grad(set_to_none=True)
self.manual_backward(loss_obj.loss)
opt_loc.step()
- self.train()
+ self.train()
for n, p in self.named_parameters():
if not n.startswith("surrogate_posterior."):
p.requires_grad_(True)
@@ -431,23 +619,24 @@ def on_train_epoch_end(self):
# # if metric_value is not None:
# sched.step(metric_value.detach().cpu().item())
- sched.step()
+ sched.step()
def configure_optimizers(self):
"""Configure optimizer based on settings from config YAML."""
-
+
# Get optimizer parameters from settings
optimizer_name = self.model_settings.optimizer_name
lr = self.model_settings.learning_rate
betas = self.model_settings.optimizer_betas
weight_decay = self.model_settings.weight_decay
eps = float(self.model_settings.optimizer_eps)
-
+
if optimizer_name == "AdaBelief":
- opt = AdaBelief( self.parameters(),
- # [
+ opt = AdaBelief(
+ self.parameters(),
+ # [
# {'params': [p for n, p in self.named_parameters() if not n.startswith("surrogate_posterior.")], 'lr': 1e-4},
- # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3},
+ # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3},
# ],
lr=lr,
eps=eps,
@@ -478,26 +667,32 @@ def configure_optimizers(self):
elif optimizer_name == "LazyAdam":
opt = MaskedAdam(
self.parameters(),
- lr = lr,
- betas = betas,
+ lr=lr,
+ betas=betas,
surrogate_posterior_module=self.surrogate_posterior,
weight_decay=weight_decay,
)
- elif optimizer_name =="LazyAdaBelief":
+ elif optimizer_name == "LazyAdaBelief":
opt = make_lazy_adabelief_for_surrogate_posterior(
- self, lr=lr, weight_decay=weight_decay,
- decoupled_weight_decay=self.model_settings.weight_decouple, zero_tol=1e-12
+ self,
+ lr=lr,
+ weight_decay=weight_decay,
+ decoupled_weight_decay=self.model_settings.weight_decouple,
+ zero_tol=1e-12,
+ )
+ print(
+ "self.model_settings.weight_decouple",
+ self.model_settings.weight_decouple,
)
- print("self.model_settings.weight_decouple", self.model_settings.weight_decouple)
else:
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
# Configure scheduler
# opt_local = AdaBelief(
- # [
+ # [
# #{'params': [p for n, p in self.named_parameters() if not n.startswith("surrogate_posterior.")], 'lr': 1e-4},
- # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3},
+ # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3},
# ],
# lr=lr,
# eps=eps,
@@ -506,7 +701,11 @@ def configure_optimizers(self):
# weight_decay=weight_decay,
# )
- sch = MultiStepLR(opt, milestones=self.model_settings.milestones, gamma=self.model_settings.scheduler_gamma)
+ sch = MultiStepLR(
+ opt,
+ milestones=self.model_settings.milestones,
+ gamma=self.model_settings.scheduler_gamma,
+ )
# sch = DebugStepLR(opt, step_size=self.model_settings.scheduler_step, gamma=self.model_settings.scheduler_gamma)
@@ -517,36 +716,55 @@ def configure_optimizers(self):
# return [opt], [{"scheduler": sch, "interval": "epoch", "monitor": "loss/kl_structure_factors"}]
return [opt], [{"scheduler": sch, "interval": "epoch"}]
-
+
def val_dataloader(self):
return self.dataloader.load_data_set_batched_by_image(
- data_set_to_load=self.dataloader.validation_data_set,
+ data_set_to_load=self.dataloader.validation_data_set,
)
-
- def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted_intensities.mtz"):
- self.eval()
+ def inference_to_intensities(
+ self, dataloader, wandb_dir, output_path="predicted_intensities.mtz"
+ ):
+ self.eval()
results = []
for batch_idx, batch in enumerate(dataloader):
# Move batch to device
# batch = self.transfer_batch_to_device(batch, self.device)
- shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl_batch, processed_metadata_batch = batch
-
+ (
+ shoeboxes_batch,
+ metadata_batch,
+ dead_pixel_mask_batch,
+ counts_batch,
+ hkl_batch,
+ processed_metadata_batch,
+ ) = batch
# Compute representations
- profile_representation, metadata_representation, image_representation, shoebox_profile_representation = self._batch_to_representations(batch=batch)
+ (
+ profile_representation,
+ metadata_representation,
+ image_representation,
+ shoebox_profile_representation,
+ ) = self._batch_to_representations(batch=batch)
# Get scale and structure factor distributions
- scale_distribution = self.compute_scale([metadata_representation, image_representation])
+ scale_distribution = self.compute_scale(
+ [metadata_representation, image_representation]
+ )
S0 = scale_distribution.rsample([1]).squeeze()
print("shape so", S0.shape)
structure_factor_dist = self.surrogate_posterior
params = self.surrogate_posterior.rac.gather(
- source=torch.stack([self.surrogate_posterior.distribution.loc,
- self.surrogate_posterior.distribution.scale], dim=1),
- rasu_id=torch.tensor([0], device=self.device),
- H=hkl_batch
+ source=torch.stack(
+ [
+ self.surrogate_posterior.distribution.loc,
+ self.surrogate_posterior.distribution.scale,
+ ],
+ dim=1,
+ ),
+ rasu_id=torch.tensor([0], device=self.device),
+ H=hkl_batch,
)
F_loc = params[:, 0]
@@ -556,7 +774,7 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted
print("F_scale", F_scale.shape)
I_mean = S0 * (F_loc**2 + F_scale**2)
- I_sigma = S0 * torch.sqrt(2*F_scale**4 + 4*F_scale**2 * F_loc**2)
+ I_sigma = S0 * torch.sqrt(2 * F_scale**4 + 4 * F_scale**2 * F_loc**2)
# Collect HKL indices
HKL = hkl_batch.cpu().detach().numpy()
@@ -564,13 +782,15 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted
I_sigma_np = I_sigma.cpu().detach().numpy().flatten()
for i in range(HKL.shape[0]):
- results.append({
- "H": int(HKL[i, 0]),
- "K": int(HKL[i, 1]),
- "L": int(HKL[i, 2]),
- "I_mean": float(I_mean_np[i]),
- "I_sigma": float(I_sigma_np[i])
- })
+ results.append(
+ {
+ "H": int(HKL[i, 0]),
+ "K": int(HKL[i, 1]),
+ "L": int(HKL[i, 2]),
+ "I_mean": float(I_mean_np[i]),
+ "I_sigma": float(I_sigma_np[i]),
+ }
+ )
print(f"Processed batch {batch_idx+1}/{len(dataloader)}")
@@ -596,7 +816,7 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted
ds = ds.set_dtypes({"I": "J", "SIGI": "Q"})
else:
ds._mtz_dtypes = {"I": "J", "SIGI": "Q"}
-
+
ds.cell = gemmi.UnitCell(79.1, 79.1, 38.4, 90, 90, 90)
ds.spacegroup = gemmi.SpaceGroup("P43212")
@@ -606,13 +826,14 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted
return mtz_path
+
def compute_I_mean_sigma(F_loc, F_scale, S=1.0):
- """
- F_loc, F_scale: torch tensors of folded normal parameters (loc, scale)
- S: deterministic scale factor
- Returns: I_mean, I_sigma
- """
- I_mean = S * (F_loc**2 + F_scale**2)
- Var_F2 = 2*F_scale**4 + 4*F_scale**2 * F_loc**2
- I_sigma = S * torch.sqrt(Var_F2)
- return I_mean, I_sigma
\ No newline at end of file
+ """
+ F_loc, F_scale: torch tensors of folded normal parameters (loc, scale)
+ S: deterministic scale factor
+ Returns: I_mean, I_sigma
+ """
+ I_mean = S * (F_loc**2 + F_scale**2)
+ Var_F2 = 2 * F_scale**4 + 4 * F_scale**2 * F_loc**2
+ I_sigma = S * torch.sqrt(Var_F2)
+ return I_mean, I_sigma
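The moments used in inference_to_intensities and compute_I_mean_sigma follow from squaring the underlying Normal(loc, scale): folding does not change the square, so E[F^2] = loc^2 + scale^2 and Var[F^2] = 4*loc^2*scale^2 + 2*scale^4, giving I_mean = S * E[F^2] and I_sigma = S * sqrt(Var[F^2]). A quick simulation check of those identities (arbitrary example values):

import torch

loc, scale, S = torch.tensor(2.0), torch.tensor(0.5), 1.3
F = torch.distributions.Normal(loc, scale).sample([1_000_000]).abs()  # folded-normal samples
I = S * F**2
print(I.mean(), S * (loc**2 + scale**2))                              # ~ I_mean
print(I.std(), S * torch.sqrt(2 * scale**4 + 4 * scale**2 * loc**2))  # ~ I_sigma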
diff --git a/src/factory/networks.py b/src/factory/networks.py
index 754e1d7..a4dfa1d 100644
--- a/src/factory/networks.py
+++ b/src/factory/networks.py
@@ -1,10 +1,12 @@
-import torch
+import dataclasses
import math
+
+import torch
import torch.nn.functional as F
-import dataclasses
+from tensordict.nn.distributions import Delta
from torch.distributions import TransformedDistribution
from torch.distributions.transforms import AbsTransform
-from tensordict.nn.distributions import Delta
+
def weight_initializer(weight):
fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2])
@@ -12,34 +14,50 @@ def weight_initializer(weight):
a = -2.0 * std
b = 2.0 * std
torch.nn.init.trunc_normal_(weight, 0.0, std, a, b)
-
+
return weight
+
class Linear(torch.nn.Linear):
def __init__(self, in_features: int, out_features: int, bias=False):
super().__init__(in_features, out_features, bias=bias) # Set bias=False
def reset_parameters(self) -> None:
self.weight = weight_initializer(self.weight)
-
+
def forward(self, input):
# Check for NaN values in input
if torch.isnan(input).any():
print(f"WARNING: NaN values in Linear layer input! Shape: {input.shape}")
print("NaN count:", torch.isnan(input).sum().item())
-
+
output = super().forward(input)
-
+
# Check for NaN values in output
if torch.isnan(output).any():
print(f"WARNING: NaN values in Linear layer output! Shape: {output.shape}")
print("NaN count:", torch.isnan(output).sum().item())
- print("Weight stats - min:", self.weight.min().item(), "max:", self.weight.max().item(), "mean:", self.weight.mean().item())
+ print(
+ "Weight stats - min:",
+ self.weight.min().item(),
+ "max:",
+ self.weight.max().item(),
+ "mean:",
+ self.weight.mean().item(),
+ )
if self.bias is not None:
- print("Bias stats - min:", self.bias.min().item(), "max:", self.bias.max().item(), "mean:", self.bias.mean().item())
-
+ print(
+ "Bias stats - min:",
+ self.bias.min().item(),
+ "max:",
+ self.bias.max().item(),
+ "mean:",
+ self.bias.mean().item(),
+ )
+
return output
+
class Constraint(torch.nn.Module):
def __init__(self, eps=1e-12, beta=1.0):
super().__init__()
@@ -58,26 +76,30 @@ def __init__(self, hidden_dim=64, input_dim=7, number_of_layers=2):
self.input_dim = input_dim
self.add_bias = True
self._build_mlp_(number_of_layers=number_of_layers)
-
+
def _build_mlp_(self, number_of_layers):
mlp_layers = []
for i in range(number_of_layers):
- mlp_layers.append(torch.nn.Linear(
- in_features=self.input_dim if i == 0 else self.hidden_dimension,
- out_features=self.hidden_dimension,
- bias=self.add_bias,
- ))
+ mlp_layers.append(
+ torch.nn.Linear(
+ in_features=self.input_dim if i == 0 else self.hidden_dimension,
+ out_features=self.hidden_dimension,
+ bias=self.add_bias,
+ )
+ )
mlp_layers.append(self.activation)
self.network = torch.nn.Sequential(*mlp_layers)
-
+
def forward(self, x):
return self.network(x)
-
+
+
@dataclasses.dataclass
class ScaleOutput:
distribution: torch.distributions.Normal
network: torch.Tensor
+
class BaseDistributionLayer(torch.nn.Module):
def __init__():
@@ -85,7 +107,8 @@ def __init__():
def forward():
pass
-
+
+
class DeltaDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -93,11 +116,12 @@ def __init__(self):
self.bijector = torch.nn.Softplus()
def forward(self, hidden_representation):
- loc = hidden_representation#torch.unbind(hidden_representation, dim=-1)
+ loc = hidden_representation # torch.unbind(hidden_representation, dim=-1)
print("shape los delta dist", loc.shape)
loc = self.bijector(loc) + 1e-3
return Delta(param=loc)
+
class NormalDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -109,6 +133,7 @@ def forward(self, hidden_representation):
scale = self.bijector(scale) + 1e-3
return torch.distributions.Normal(loc=loc, scale=scale)
+
class SoftplusNormalDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -121,6 +146,7 @@ def forward(self, hidden_representation):
# self.normal = torch.distributions.Normal(loc=loc, scale=scale)
return SoftplusNormal(loc=loc, scale=scale)
+
class SoftplusNormal(torch.nn.Module):
def __init__(self, loc, scale):
super().__init__()
@@ -130,7 +156,7 @@ def __init__(self, loc, scale):
def rsample(self, sample_shape=[1]):
return self.bijector(self.normal.rsample(sample_shape)).unsqueeze(-1)
- def log_prob(self,x):
+ def log_prob(self, x):
# x: sample (must be > 0), mu and sigma are parameters
z = torch.log(torch.expm1(x)) # inverse softplus
normal = self.normal
@@ -141,6 +167,7 @@ def log_prob(self,x):
def forward(self, x, number_of_samples=1):
return self.bijector(self.normal(x))
+
class TruncatedNormalDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -153,15 +180,20 @@ def forward(self, hidden_representation):
# self.normal = torch.distributions.Normal(loc=loc, scale=scale)
return PositiveTruncatedNormal(loc=loc, scale=scale)
+
class PositiveTruncatedNormal(torch.nn.Module):
def __init__(self, loc, scale):
super().__init__()
self.normal = torch.distributions.Normal(loc, scale)
self.a = (0.0 - loc) / scale # standardized lower bound = 0
- self.Z = torch.tensor(1.0 - self.normal.cdf(0.0), device=loc.device, dtype=loc.dtype) # normalization constant
+ self.Z = torch.tensor(
+ 1.0 - self.normal.cdf(0.0), device=loc.device, dtype=loc.dtype
+ ) # normalization constant
def rsample(self, sample_shape=torch.Size()):
- u = torch.rand(sample_shape + self.a.shape, device=self.a.device, dtype=self.a.dtype)
+ u = torch.rand(
+ sample_shape + self.a.shape, device=self.a.device, dtype=self.a.dtype
+ )
u = u * self.Z + self.normal.cdf(0.0) # map [0,1] to [cdf(0), 1]
z = self.normal.icdf(u)
return z
@@ -170,6 +202,7 @@ def log_prob(self, x):
logp = self.normal.log_prob(x)
return logp - torch.log(self.Z)
+
class NormalIRSample(torch.autograd.Function):
@staticmethod
def forward(ctx, loc, scale, samples, dFdmu, dFdsig, q):
@@ -186,6 +219,7 @@ def backward(ctx, grad_output):
) = ctx.saved_tensors
return grad_output * dzdmu, grad_output * dzdsig, None, None, None, None
+
class FoldedNormal(torch.distributions.Distribution):
"""
Folded Normal distribution class
@@ -197,7 +231,10 @@ class FoldedNormal(torch.distributions.Distribution):
Default is None.
"""
- arg_constraints = {"loc": torch.distributions.constraints.real, "scale": torch.distributions.constraints.positive}
+ arg_constraints = {
+ "loc": torch.distributions.constraints.real,
+ "scale": torch.distributions.constraints.positive,
+ }
support = torch.distributions.constraints.nonnegative
def __init__(self, loc, scale, validate_args=None):
@@ -323,6 +360,7 @@ def rsample(self, sample_shape=torch.Size()):
samples.requires_grad_(True)
return self._irsample(self.loc, self.scale, samples, dFdmu, dFdsigma, q)
+
class FoldedNormalDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -334,6 +372,7 @@ def forward(self, hidden_representation):
scale = self.bijector(scale) + 1e-3
return FoldedNormal(loc=loc.unsqueeze(-1), scale=scale.unsqueeze(-1))
+
class LogNormalDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -345,12 +384,20 @@ def forward(self, hidden_representation):
scale = self.bijector(scale) + 1e-3
print("lognormal loc, scale", loc.shape, scale.shape)
dist = torch.distributions.LogNormal(loc=loc, scale=scale)
- print("torch.distributions.LogNormal(loc=loc, scale=scale)", dist.rsample().shape)
- dist = torch.distributions.LogNormal(loc=loc.unsqueeze(-1), scale=scale.unsqueeze(-1))
- print("torch.distributions.LogNormal(loc=loc, scale=scale) unsqueeze", dist.rsample().shape)
+ print(
+ "torch.distributions.LogNormal(loc=loc, scale=scale)", dist.rsample().shape
+ )
+ dist = torch.distributions.LogNormal(
+ loc=loc.unsqueeze(-1), scale=scale.unsqueeze(-1)
+ )
+ print(
+ "torch.distributions.LogNormal(loc=loc, scale=scale) unsqueeze",
+ dist.rsample().shape,
+ )
return dist
+
class GammaDistributionLayer(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -361,10 +408,20 @@ def forward(self, hidden_representation):
concentration, rate = torch.unbind(hidden_representation, dim=-1)
rate = self.bijector(rate) + 1e-3
concentration = self.bijector(concentration) + 1e-3
- return torch.distributions.Gamma(concentration=concentration.unsqueeze(-1), rate=rate.unsqueeze(-1))
+ return torch.distributions.Gamma(
+ concentration=concentration.unsqueeze(-1), rate=rate.unsqueeze(-1)
+ )
+
class MLPScale(torch.nn.Module):
- def __init__(self, input_dimension=64, scale_distribution=FoldedNormalDistributionLayer, hidden_dimension=64, number_of_layers=1, initial_scale_guess=2/140):
+ def __init__(
+ self,
+ input_dimension=64,
+ scale_distribution=FoldedNormalDistributionLayer,
+ hidden_dimension=64,
+ number_of_layers=1,
+ initial_scale_guess=2 / 140,
+ ):
super().__init__()
self.activation = torch.nn.ReLU()
self.hidden_dimension = hidden_dimension
@@ -377,11 +434,15 @@ def __init__(self, input_dimension=64, scale_distribution=FoldedNormalDistributi
def _build_mlp_(self, number_of_layers):
mlp_layers = []
for i in range(number_of_layers):
- mlp_layers.append(torch.nn.Linear(
- in_features=self.input_dimension if i == 0 else self.hidden_dimension,
- out_features=self.hidden_dimension,
- bias=self.add_bias,
- ))
+ mlp_layers.append(
+ torch.nn.Linear(
+ in_features=(
+ self.input_dimension if i == 0 else self.hidden_dimension
+ ),
+ out_features=self.hidden_dimension,
+ bias=self.add_bias,
+ )
+ )
mlp_layers.append(self.activation)
self.network = torch.nn.Sequential(*mlp_layers)
@@ -391,16 +452,17 @@ def _build_mlp_(self, number_of_layers):
out_features=self.scale_distribution_layer.len_params,
bias=self.add_bias,
)
-
if self.add_bias:
with torch.no_grad():
for i in range(self.scale_distribution_layer.len_params):
- if i==self.scale_distribution_layer.len_params-1:
- final_linear.bias[i] = torch.log(torch.tensor(self.initial_scale_guess))
+ if i == self.scale_distribution_layer.len_params - 1:
+ final_linear.bias[i] = torch.log(
+ torch.tensor(self.initial_scale_guess)
+ )
else:
final_linear.bias[i] = torch.tensor(0.1)
-
+
map_to_distribution_layers.append(final_linear)
map_to_distribution_layers.append(self.scale_distribution_layer)
@@ -408,8 +470,13 @@ def _build_mlp_(self, number_of_layers):
def forward(self, x) -> ScaleOutput:
h = self.network(x)
- print("MLP output:", h.mean().item(), h.min().item(), h.max().item(),
- "any nan?", torch.isnan(h).any().item())
-
+ print(
+ "MLP output:",
+ h.mean().item(),
+ h.min().item(),
+ h.max().item(),
+ "any nan?",
+ torch.isnan(h).any().item(),
+ )
+
return ScaleOutput(distribution=self.distribution(h), network=h)
-
\ No newline at end of file
diff --git a/src/factory/parse_yaml.py b/src/factory/parse_yaml.py
index bd1dade..b9da134 100644
--- a/src/factory/parse_yaml.py
+++ b/src/factory/parse_yaml.py
@@ -1,14 +1,15 @@
-import yaml
import os
-import wandb
-from settings import ModelSettings, LossSettings, DataLoaderSettings, PhenixSettings
-from abismal_torch.prior import WilsonPrior
+
import distributions
-import networks
-from lightning.pytorch.loggers import WandbLogger
import metadata_encoder
+import networks
import shoebox_encoder
import torch
+import wandb
+import yaml
+from abismal_torch.prior import WilsonPrior
+from lightning.pytorch.loggers import WandbLogger
+from settings import DataLoaderSettings, LossSettings, ModelSettings, PhenixSettings
REGISTRY = {
"HalfNormalDistribution": distributions.HalfNormalDistribution,
@@ -34,14 +35,15 @@
"TorchHalfNormal": torch.distributions.HalfNormal,
"TorchExponential": torch.distributions.Exponential,
"rsFoldedNormal": networks.FoldedNormal,
-
}
+
def _log_settings_from_yaml(path: str, logger: WandbLogger):
- with open(path, 'r') as file:
+ with open(path, "r") as file:
config = yaml.safe_load(file)
logger.experiment.config.update(config)
+
def instantiate_from_config(config: dict):
def resolve_value(v):
if isinstance(v, dict):
@@ -60,23 +62,33 @@ def resolve_value(v):
def consistancy_check(model_settings_dict: dict):
-
+
dmodel = model_settings_dict.get("dmodel")
metadata_indices = model_settings_dict.get("metadata_indices_to_keep", [])
if "shoebox_encoder" in model_settings_dict:
- model_settings_dict["shoebox_encoder"].setdefault("params", {})["out_dim"] = dmodel
+ model_settings_dict["shoebox_encoder"].setdefault("params", {})[
+ "out_dim"
+ ] = dmodel
if "metadata_encoder" in model_settings_dict:
if model_settings_dict.get("use_positional_encoding", False) == True:
- number_of_frequencies = model_settings_dict.get("number_of_frequencies_in_positional_encoding", 2)
- model_settings_dict["metadata_encoder"].setdefault("params", {})["feature_dim"] = len(metadata_indices) * number_of_frequencies * 2
+ number_of_frequencies = model_settings_dict.get(
+ "number_of_frequencies_in_positional_encoding", 2
+ )
+ model_settings_dict["metadata_encoder"].setdefault("params", {})[
+ "feature_dim"
+ ] = (len(metadata_indices) * number_of_frequencies * 2)
else:
- model_settings_dict["metadata_encoder"].setdefault("params", {})["feature_dim"] = len(metadata_indices)
+ model_settings_dict["metadata_encoder"].setdefault("params", {})[
+ "feature_dim"
+ ] = len(metadata_indices)
model_settings_dict["metadata_encoder"]["params"]["output_dims"] = dmodel
if "scale_function" in model_settings_dict:
- model_settings_dict["scale_function"].setdefault("params", {})["input_dimension"] = dmodel
+ model_settings_dict["scale_function"].setdefault("params", {})[
+ "input_dimension"
+ ] = dmodel
def load_settings_from_yaml(config):
@@ -84,7 +96,7 @@ def load_settings_from_yaml(config):
# config = yaml.safe_load(file)
model_settings_dict = config.get("model_settings", {})
-
+
for key, value in model_settings_dict.items():
if value == "None":
model_settings_dict[key] = None
@@ -100,7 +112,11 @@ def load_settings_from_yaml(config):
phenix_settings = PhenixSettings(**config.get("phenix_settings", {}))
dataloader_settings_dict = config.get("dataloader_settings", {})
- dataloader_settings = DataLoaderSettings(**dataloader_settings_dict) if dataloader_settings_dict else None
+ dataloader_settings = (
+ DataLoaderSettings(**dataloader_settings_dict)
+ if dataloader_settings_dict
+ else None
+ )
return model_settings, loss_settings, phenix_settings, dataloader_settings
@@ -116,4 +132,4 @@ def resolve_builders(settings_dict):
if field in settings_dict:
settings_dict[field] = REGISTRY[settings_dict[field]]
- return settings_dict
\ No newline at end of file
+ return settings_dict
diff --git a/src/factory/phenix_callback.py b/src/factory/phenix_callback.py
index 3e6a2dd..08bf118 100644
--- a/src/factory/phenix_callback.py
+++ b/src/factory/phenix_callback.py
@@ -1,22 +1,24 @@
+import concurrent.futures
+import glob
+import multiprocessing as mp
import os
import re
import subprocess
-import pandas as pd
-import wandb
-import settings
-from model import *
-import glob
import time
+
import generate_eff
-import concurrent.futures
-import multiprocessing as mp
import matplotlib.pyplot as plt
+import pandas as pd
+import settings
+import wandb
+from model import *
+
def extract_r_values_from_pdb(pdb_file_path: str) -> tuple[float, float]:
-
+
r_work, r_free = None, None
try:
- with open(pdb_file_path, 'r') as f:
+ with open(pdb_file_path, "r") as f:
for line in f:
if "REMARK 3 R VALUE (WORKING SET)" in line:
r_work = float(line.strip().split(":")[1])
@@ -28,7 +30,7 @@ def extract_r_values_from_pdb(pdb_file_path: str) -> tuple[float, float]:
def extract_r_values_from_log(log_file_path: str) -> tuple:
-
+
pattern1 = re.compile(r"Start R-work")
pattern2 = re.compile(r"Final R-work")
@@ -36,7 +38,7 @@ def extract_r_values_from_log(log_file_path: str) -> tuple:
with open(log_file_path, "r") as f:
lines = f.readlines()
print("opened log file")
-
+
match = re.search(r"epoch=(\d+)", log_file_path)
if match:
epoch = match.group(1)
@@ -48,7 +50,7 @@ def extract_r_values_from_log(log_file_path: str) -> tuple:
matched_lines_final = [line.strip() for line in lines if pattern2.search(line)]
# Initialize all values as None
-
+
rwork_start = rwork_final = rfree_start = rfree_final = None
if matched_lines_start:
@@ -59,7 +61,9 @@ def extract_r_values_from_log(log_file_path: str) -> tuple:
rwork_start = float(numbers[0])
rfree_start = float(numbers[1])
else:
- print(f"Start R-work line found but not enough numbers: {matched_lines_start[0]}")
+ print(
+ f"Start R-work line found but not enough numbers: {matched_lines_start[0]}"
+ )
else:
print("No Start R-work line found in log.")
@@ -69,21 +73,28 @@ def extract_r_values_from_log(log_file_path: str) -> tuple:
rwork_final = float(numbers[0])
rfree_final = float(numbers[1])
else:
- print(f"Final R-work line found but not enough numbers: {matched_lines_final[0]}")
+ print(
+ f"Final R-work line found but not enough numbers: {matched_lines_final[0]}"
+ )
else:
print("No Final R-work line found in log.")
print("finished r-val extraction")
- return epoch_start, epoch_final, [rwork_start, rwork_final], [rfree_start, rfree_final]
+ return (
+ epoch_start,
+ epoch_final,
+ [rwork_start, rwork_final],
+ [rfree_start, rfree_final],
+ )
except Exception as e:
print(f"Could not open log file: {e}")
return None, None, [None, None], [None, None]
-
-
def run_phenix(repo_dir: str, checkpoint_name: str, mtz_output_path: str):
- phenix_env = "/n/hekstra_lab/garden_backup/phenix-1.21/phenix-1.21.1-5286/phenix_env.sh"
+ phenix_env = (
+ "/n/hekstra_lab/garden_backup/phenix-1.21/phenix-1.21.1-5286/phenix_env.sh"
+ )
phenix_output_dir = os.path.join(repo_dir, "phenix_output")
os.makedirs(phenix_output_dir, exist_ok=True)
path_to_pdb_model = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/9b7c.pdb"
@@ -95,31 +106,41 @@ def run_phenix(repo_dir: str, checkpoint_name: str, mtz_output_path: str):
r_free_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/rfree.mtz"
pdb_model_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/9b7c.pdb"
- specific_phenix_eff_file_path = os.path.join(repo_dir, f"phenix_{checkpoint_name}.eff")
+ specific_phenix_eff_file_path = os.path.join(
+ repo_dir, f"phenix_{checkpoint_name}.eff"
+ )
phenix_out_name = f"refine_{checkpoint_name}"
- subprocess.run([
- "python", "/n/holylabs/hekstra_lab/Users/fgiehr/factory/src/factory/generate_eff.py",
- "--sf-mtz", mtz_output_path,
- "--rfree-mtz", r_free_path,
- "--out", specific_phenix_eff_file_path,
- "--phenix-out-mtz", phenix_out_name,
- ], check=True)
+ subprocess.run(
+ [
+ "python",
+ "/n/holylabs/hekstra_lab/Users/fgiehr/factory/src/factory/generate_eff.py",
+ "--sf-mtz",
+ mtz_output_path,
+ "--rfree-mtz",
+ r_free_path,
+ "--out",
+ specific_phenix_eff_file_path,
+ "--phenix-out-mtz",
+ phenix_out_name,
+ ],
+ check=True,
+ )
- cmd = f"source {phenix_env} && cd {repo_dir} && phenix.refine {specific_phenix_eff_file_path} overwrite=True" # {mtz_output_path} {r_free_path} {phenix_refine_eff_file_path} {pdb_model_path} overwrite=True"
+ cmd = f"source {phenix_env} && cd {repo_dir} && phenix.refine {specific_phenix_eff_file_path} overwrite=True" # {mtz_output_path} {r_free_path} {phenix_refine_eff_file_path} {pdb_model_path} overwrite=True"
print("\nRunning phenix with command:", flush=True)
print(cmd, flush=True)
try:
result = subprocess.run(
- cmd,
- shell=True,
- executable="/bin/bash",
- capture_output=True,
- text=True,
- check=True,
- )
+ cmd,
+ shell=True,
+ executable="/bin/bash",
+ capture_output=True,
+ text=True,
+ check=True,
+ )
print("Phenix command completed successfully")
except subprocess.CalledProcessError as e:
@@ -141,10 +162,13 @@ def run_phenix(repo_dir: str, checkpoint_name: str, mtz_output_path: str):
return phenix_out_name
+
def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str):
- surrogate_posterior_dataset = list(model.surrogate_posterior.to_dataset(only_observed=True))[0]
+ surrogate_posterior_dataset = list(
+ model.surrogate_posterior.to_dataset(only_observed=True)
+ )[0]
- checkpoint_name, _ = os.path.splitext(os.path.basename(checkpoint_path))
+ checkpoint_name, _ = os.path.splitext(os.path.basename(checkpoint_path))
mtz_output_path = os.path.join(repo_dir, f"phenix_ready_{checkpoint_name}.mtz")
@@ -156,18 +180,21 @@ def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str):
else:
print("Error: Failed to write MTZ file")
- phenix_out_name = run_phenix(repo_dir=repo_dir, checkpoint_name=checkpoint_name, mtz_output_path=mtz_output_path)
+ phenix_out_name = run_phenix(
+ repo_dir=repo_dir,
+ checkpoint_name=checkpoint_name,
+ mtz_output_path=mtz_output_path,
+ )
# phenix_out_name = f"refine_{checkpoint_name}"
# solved_artifact = wandb.Artifact("phenix_solved", type="mtz")
pdb_file_path = glob.glob(os.path.join(repo_dir, f"{phenix_out_name}*.pdb"))
print("pdb_file_path", pdb_file_path)
- pdb_file_path=pdb_file_path[-1]
+ pdb_file_path = pdb_file_path[-1]
mtz_file_path = glob.glob(os.path.join(repo_dir, f"{phenix_out_name}*.mtz"))
print("mtz_file_path", mtz_file_path)
- mtz_file_path=mtz_file_path[0]
-
+ mtz_file_path = mtz_file_path[0]
if pdb_file_path:
print(f"Using PDB file: {pdb_file_path}")
@@ -175,12 +202,14 @@ def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str):
print("log_file_path", log_file_path)
log_file_path = log_file_path[-1]
print("log_file_path2", log_file_path)
- epoch_start, epoch_final, r_work, r_free = extract_r_values_from_log(log_file_path)
+ epoch_start, epoch_final, r_work, r_free = extract_r_values_from_log(
+ log_file_path
+ )
print("rvals", r_work, r_free)
else:
print("No PDB files found matching pattern!")
r_work, r_free = None, None
-
+
# pdb_file_path = glob.glob(f"{phenix_out_name}*.pdb")[-1]
# mtz_file_path = glob.glob(f"{phenix_out_name}*.mtz")[-1]
print("run find peaks")
@@ -190,21 +219,30 @@ def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str):
peaks_file_path = os.path.join(repo_dir, f"peaks_{phenix_out_name}.csv")
import reciprocalspaceship as rs
+
ds = rs.read_mtz(mtz_file_path)
print(ds.columns.tolist())
subprocess.run(
- f"rs.find_peaks {mtz_file_path} {pdb_file_path} -f ANOM -p PANOM -z 5.0 -o {peaks_file_path}",
- shell=True,
- # cwd=f"{repo_dir}",
+ f"rs.find_peaks {mtz_file_path} {pdb_file_path} -f ANOM -p PANOM -z 5.0 -o {peaks_file_path}",
+ shell=True,
+ # cwd=f"{repo_dir}",
)
print("finished find peaks")
- return epoch_start, r_work, r_free, os.path.join(repo_dir, f"peaks_{phenix_out_name}.csv")
+ return (
+ epoch_start,
+ r_work,
+ r_free,
+ os.path.join(repo_dir, f"peaks_{phenix_out_name}.csv"),
+ )
-def process_checkpoint(checkpoint_path, artifact_dir, model_settings, loss_settings, wandb_directory):
+def process_checkpoint(
+ checkpoint_path, artifact_dir, model_settings, loss_settings, wandb_directory
+):
import os
import re
+
import pandas as pd
from model import Model
@@ -213,34 +251,49 @@ def process_checkpoint(checkpoint_path, artifact_dir, model_settings, loss_setti
print("checkpoint_path_", checkpoint_path_)
print("artifact_dir", artifact_dir)
- try:
+ try:
match = re.search(r"best-(\d+)-", checkpoint_path)
epoch = int(match.group(1)) if match else None
- model = Model.load_from_checkpoint(checkpoint_path, model_settings=model_settings, loss_settings=loss_settings)
+ model = Model.load_from_checkpoint(
+ checkpoint_path, model_settings=model_settings, loss_settings=loss_settings
+ )
model.eval()
- epoch_start, r_work, r_free, path_to_peaks = find_peaks_from_model(model=model, repo_dir=wandb_directory, checkpoint_path=checkpoint_path)
+ epoch_start, r_work, r_free, path_to_peaks = find_peaks_from_model(
+ model=model, repo_dir=wandb_directory, checkpoint_path=checkpoint_path
+ )
peaks_df = pd.read_csv(path_to_peaks)
rows = []
for _, peak in peaks_df.iterrows():
- residue = peak['residue']
- peak_height = peak['peakz']
- rows.append({
- 'Epoch': epoch,
- 'Checkpoint': checkpoint_path,
- 'Residue': residue,
- 'Peak_Height': peak_height
- })
+ residue = peak["residue"]
+ peak_height = peak["peakz"]
+ rows.append(
+ {
+ "Epoch": epoch,
+ "Checkpoint": checkpoint_path,
+ "Residue": residue,
+ "Peak_Height": peak_height,
+ }
+ )
return epoch_start, rows, r_work, r_free
except Exception as e:
print(f"Failed for {checkpoint_path}: {e}")
return None, [], None, None
-def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settings, artifact_dir, checkpoint_paths, wandb_directory):
+
+def run_phenix_over_all_checkpoints(
+ model_settings,
+ loss_settings,
+ phenix_settings,
+ artifact_dir,
+ checkpoint_paths,
+ wandb_directory,
+):
+ import multiprocessing as mp
+
import pandas as pd
import wandb
- import multiprocessing as mp
all_rows = []
r_works_start = []
@@ -249,7 +302,6 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin
r_frees_final = []
max_peak_heights = []
epochs = []
-
ctx = mp.get_context("spawn")
# Prepare argument tuples for each checkpoint
@@ -264,10 +316,14 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin
for epoch_start, rows, r_work, r_free in results:
all_rows.extend(rows)
if (
- isinstance(r_work, list) and len(r_work) == 2 and
- isinstance(r_free, list) and len(r_free) == 2 and
- r_work[0] is not None and r_work[1] is not None and
- r_free[0] is not None and r_free[1] is not None
+ isinstance(r_work, list)
+ and len(r_work) == 2
+ and isinstance(r_free, list)
+ and len(r_free) == 2
+ and r_work[0] is not None
+ and r_work[1] is not None
+ and r_free[0] is not None
+ and r_free[1] is not None
):
r_works_start.append(r_work[0])
r_works_final.append(r_work[1])
@@ -286,10 +342,10 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin
try:
fig, ax = plt.subplots(figsize=(10, 6))
- ax.plot(epochs, max_peak_heights, marker='o', label='Max Peak Height')
- ax.set_xlabel('Epoch')
- ax.set_ylabel('Peak Height')
- ax.set_title('Max Peak Height vs Epoch')
+ ax.plot(epochs, max_peak_heights, marker="o", label="Max Peak Height")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Peak Height")
+ ax.set_title("Max Peak Height vs Epoch")
ax.legend()
plt.tight_layout()
wandb.log({"max_peak_heights_plot": wandb.Image(fig)})
@@ -297,9 +353,6 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin
except Exception as e:
print(f"failed to plot max peak heights: {e}")
-
-
-
final_df = pd.DataFrame(all_rows)
if not final_df.empty and "Peak_Height" in final_df.columns:
@@ -311,26 +364,27 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin
if final_df.empty:
print("No peak heights to log.")
-
# def plot_peak_heights_vs_epoch(df):
try:
df = final_df
import matplotlib.cm as cm
-
+
if df.empty:
print("No data to plot for peak heights vs epoch.")
return
plt.figure(figsize=(10, 6))
- residues = df['Residue'].unique()
- colormap = cm.get_cmap('tab20', len(residues))
+ residues = df["Residue"].unique()
+ colormap = cm.get_cmap("tab20", len(residues))
color_map = {res: colormap(i) for i, res in enumerate(residues)}
for res in residues:
- sub = df[df['Residue'] == res]
- plt.scatter(sub['Epoch'], sub['Peak_Height'], label=str(res), color=color_map[res])
- plt.xlabel('Epoch')
- plt.ylabel('Peak Height')
- plt.title('Peak Height vs Epoch per Residue')
- plt.legend(title='Residue', bbox_to_anchor=(1.05, 1), loc='upper left')
+ sub = df[df["Residue"] == res]
+ plt.scatter(
+ sub["Epoch"], sub["Peak_Height"], label=str(res), color=color_map[res]
+ )
+ plt.xlabel("Epoch")
+ plt.ylabel("Peak Height")
+ plt.title("Peak Height vs Epoch per Residue")
+ plt.legend(title="Residue", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
wandb.log({"PeakHeight_vs_Epoch": wandb.Image(plt.gcf())})
plt.close()
@@ -339,45 +393,74 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin
except Exception as e:
print(f"failed to plot peak heights: {e}")
-
try:
- _, _, r_work_reference, r_free_reference = extract_r_values_from_log(phenix_settings.r_values_reference_path)
+ _, _, r_work_reference, r_free_reference = extract_r_values_from_log(
+ phenix_settings.r_values_reference_path
+ )
r_work_reference = r_work_reference[-1]
r_free_reference = r_free_reference[-1]
except Exception as e:
print(f"failed to extract reference r-vals: {e}")
-
try:
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
- min_y = min(min(r_works_final), min(r_frees_final),r_work_reference, r_free_reference) - 0.01
- max_y = max(max(r_works_final), max(r_frees_final), r_work_reference, r_free_reference) + 0.01
+ min_y = (
+ min(
+ min(r_works_final),
+ min(r_frees_final),
+ r_work_reference,
+ r_free_reference,
+ )
+ - 0.01
+ )
+ max_y = (
+ max(
+ max(r_works_final),
+ max(r_frees_final),
+ r_work_reference,
+ r_free_reference,
+ )
+ + 0.01
+ )
- plt.style.use('seaborn-v0_8-darkgrid')
+ plt.style.use("seaborn-v0_8-darkgrid")
fig, ax = plt.subplots(figsize=(8, 5))
x = epochs # Arbitrary units (e.g., checkpoint index)
# ax.plot(x, r_works_start, marker='o', label='R-work (start)', c="red", alpha=0.7)
- ax.plot(x, r_works_final, marker='o', label='R-work (final)', c="red")
+ ax.plot(x, r_works_final, marker="o", label="R-work (final)", c="red")
# ax.plot(x, r_frees_start, marker='s', label='R-free (start)', c="blue", alpha=0.7)
- ax.plot(x, r_frees_final, marker='s', label='R-free (final)', c="blue")
+ ax.plot(x, r_frees_final, marker="s", label="R-free (final)", c="blue")
if r_work_reference is not None:
- ax.axhline(r_work_reference, color='red', linestyle='dotted', linewidth=0.7, alpha=0.5, label='r_work_reference')
+ ax.axhline(
+ r_work_reference,
+ color="red",
+ linestyle="dotted",
+ linewidth=0.7,
+ alpha=0.5,
+ label="r_work_reference",
+ )
if r_free_reference is not None:
- ax.axhline(r_free_reference, color='blue', linestyle='dotted', linewidth=0.7, alpha=0.5, label='r_free_reference')
+ ax.axhline(
+ r_free_reference,
+ color="blue",
+ linestyle="dotted",
+ linewidth=0.7,
+ alpha=0.5,
+ label="r_free_reference",
+ )
- ax.set_xlabel('Training Epoch')
- ax.set_ylabel('R Value')
+ ax.set_xlabel("Training Epoch")
+ ax.set_ylabel("R Value")
ax.set_ylim(min_y, max_y)
ax.set_yticks(np.linspace(min_y, max_y, 12))
- ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.4f'))
- ax.set_title('R-work and R-free')
- ax.legend(loc='best', fontsize=10, frameon=True)
+ ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.4f"))
+ ax.set_title("R-work and R-free")
+ ax.legend(loc="best", fontsize=10, frameon=True)
plt.tight_layout()
wandb.log({"R_values_plot": wandb.Image(fig)})
plt.close(fig)
except Exception as e:
print(f"failed plotting r-values: {e}")
-
diff --git a/src/factory/plot_peaks.py b/src/factory/plot_peaks.py
index dcb7799..bc3a33b 100644
--- a/src/factory/plot_peaks.py
+++ b/src/factory/plot_peaks.py
@@ -1,26 +1,26 @@
-import pandas as pd
+import matplotlib.pyplot as plt
import numpy as np
+import pandas as pd
+import wandb
from Bio.PDB import PDBParser
from scipy.spatial import cKDTree
-import matplotlib.pyplot as plt
-import wandb
wandb.init(project="plot anomolous peaks")
-peaks_df = pd.read_csv("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv")
-print(peaks_df.columns)
+peaks_df = pd.read_csv(
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv"
+)
+print(peaks_df.columns)
# print(peaks_df.iloc[0]) #
print(peaks_df)
# Create a table of peak heights
-peak_heights_table = pd.DataFrame({
- 'Peak_Height': peaks_df['peakz']
-})
+peak_heights_table = pd.DataFrame({"Peak_Height": peaks_df["peakz"]})
# Log the peak heights table to wandb
wandb.log({"peak_heights": wandb.Table(dataframe=peak_heights_table)})
-peak_coords = peaks_df[['cenx', 'ceny', 'cenz']].values
+peak_coords = peaks_df[["cenx", "ceny", "cenz"]].values
model_file = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/shaken_9b7c_refine_001.pdb"
parser = PDBParser(QUIET=True)
model = parser.get_structure("model", model_file)
@@ -41,19 +41,21 @@
for i, peak in enumerate(peak_coords):
dist, idx = tree.query(peak)
atom = atoms[idx]
- matches.append({
- "Peak_X": peak[0],
- "Peak_Y": peak[1],
- "Peak_Z": peak[2],
- "Peak_Height": peaks_df.loc[i, 'peakz'], # adjust if column name differs
- "Atom_Name": atom.get_name(),
- "Residue": atom.get_parent().get_resname(),
- "Residue_ID": atom.get_parent().get_id()[1],
- "Chain": atom.get_parent().get_parent().id,
- "B_factor": atom.get_bfactor(),
- "Occupancy": atom.get_occupancy(),
- "Distance": dist
- })
+ matches.append(
+ {
+ "Peak_X": peak[0],
+ "Peak_Y": peak[1],
+ "Peak_Z": peak[2],
+ "Peak_Height": peaks_df.loc[i, "peakz"], # adjust if column name differs
+ "Atom_Name": atom.get_name(),
+ "Residue": atom.get_parent().get_resname(),
+ "Residue_ID": atom.get_parent().get_id()[1],
+ "Chain": atom.get_parent().get_parent().id,
+ "B_factor": atom.get_bfactor(),
+ "Occupancy": atom.get_occupancy(),
+ "Distance": dist,
+ }
+ )
# Create output DataFrame
results_df = pd.DataFrame(matches)
@@ -94,4 +96,4 @@
# # Log the figure to wandb
# wandb.log({"peaks_vs_model": wandb.Image(fig)})
-# plt.show()
\ No newline at end of file
+# plt.show()
diff --git a/src/factory/positional_encoding.py b/src/factory/positional_encoding.py
index 63f915b..9501be4 100644
--- a/src/factory/positional_encoding.py
+++ b/src/factory/positional_encoding.py
@@ -2,6 +2,7 @@
import torch
+
def positional_encoding(X, L):
"""
X: metadata (batch_size, feature_size)
@@ -15,18 +16,21 @@ def positional_encoding(X, L):
# Get min and max values along the last dimension
min_vals = X.min(dim=-1, keepdim=True)[0]
max_vals = X.max(dim=-1, keepdim=True)[0]
-
+
# Normalize between -1 and 1
- p = 2. * (X - min_vals) / (max_vals - min_vals + 1e-8) - 1.
-
+ p = 2.0 * (X - min_vals) / (max_vals - min_vals + 1e-8) - 1.0
+
# Create frequency bands
L_range = torch.arange(L, dtype=X.dtype, device=X.device)
f = torch.pi * 2**L_range
-
+
# Compute positional encoding
fp = (f[..., None, :] * p[..., :, None]).reshape(p.shape[:-1] + (-1,))
-
- return torch.cat((
- torch.cos(fp),
- torch.sin(fp),
- ), dim=-1)
\ No newline at end of file
+
+ return torch.cat(
+ (
+ torch.cos(fp),
+ torch.sin(fp),
+ ),
+ dim=-1,
+ )
diff --git a/src/factory/rasu.py b/src/factory/rasu.py
index 4738b1b..2a1ca2f 100644
--- a/src/factory/rasu.py
+++ b/src/factory/rasu.py
@@ -3,6 +3,7 @@
import gemmi
import numpy as np
import torch
+from abismal_torch.symmetry.op import Op
from reciprocalspaceship.decorators import cellify, spacegroupify
from reciprocalspaceship.utils import (
apply_to_hkl,
@@ -10,8 +11,6 @@
generate_reciprocal_cell,
)
-from abismal_torch.symmetry.op import Op
-
class ReciprocalASU(torch.nn.Module):
@cellify
@@ -22,7 +21,7 @@ def __init__(
spacegroup: gemmi.SpaceGroup,
dmin: float,
anomalous: Optional[bool] = True,
- **kwargs
+ **kwargs,
) -> None:
"""
Base Layer that maps observed reflections to the reciprocal asymmetric unit (rasu).
@@ -152,8 +151,18 @@ def __init__(self, *rasus: ReciprocalASU, **kwargs) -> None:
for rasu_id, rasu in enumerate(self.reciprocal_asus):
Hcell = generate_reciprocal_cell(rasu.cell, dmin=rasu.dmin)
h, k, l = Hcell.T
- self.reflection_id_grid[torch.tensor(rasu_id, dtype=torch.int),torch.tensor(h, dtype=torch.int), torch.tensor(k, dtype=torch.int), torch.tensor(l, dtype=torch.int)] = (
- rasu.reflection_id_grid[torch.tensor(h, dtype=torch.int), torch.tensor(k, dtype=torch.int), torch.tensor(l, dtype=torch.int)] + offset
+ self.reflection_id_grid[
+ torch.tensor(rasu_id, dtype=torch.int),
+ torch.tensor(h, dtype=torch.int),
+ torch.tensor(k, dtype=torch.int),
+ torch.tensor(l, dtype=torch.int),
+ ] = (
+ rasu.reflection_id_grid[
+ torch.tensor(h, dtype=torch.int),
+ torch.tensor(k, dtype=torch.int),
+ torch.tensor(l, dtype=torch.int),
+ ]
+ + offset
)
# self.reflection_id_grid[rasu_id, h, k, l] = (
# rasu.reflection_id_grid[h,k,l] + offset
@@ -195,7 +204,7 @@ def __init__(
*rasus: ReciprocalASU,
parents: Optional[torch.Tensor] = None,
reindexing_ops: Optional[Sequence[str]] = None,
- **kwargs
+ **kwargs,
) -> None:
"""
A graph of rasu objects.
diff --git a/src/factory/run_data_loader.py b/src/factory/run_data_loader.py
index 3253956..06e8317 100644
--- a/src/factory/run_data_loader.py
+++ b/src/factory/run_data_loader.py
@@ -1,8 +1,9 @@
-from data_loader import *
import argparse
import json
import CNN_3d
+from data_loader import *
+
def parse_args():
parser = argparse.ArgumentParser(description="Pass DataLoader settings")
@@ -19,11 +20,12 @@ def parse_args():
)
return parser.parse_args()
+
def main():
print("main function")
# args = parse_args()
# if args.data_file_names is not None:
- # settings = DataLoaderSettings(data_directory=args.data_directory,
+ # settings = DataLoaderSettings(data_directory=args.data_directory,
# data_file_names=args.data_file_names,
# test_set_split=0.01
# )
@@ -39,18 +41,19 @@ def main():
data_directory = "/n/hekstra_lab/people/aldama/subset"
data_file_names = {
- "shoeboxes": "standardized_shoeboxes_subset.pt",
- "counts": "raw_counts_subset.pt",
- "metadata": "shoebox_features_subset.pt",
- "masks": "masks_subset.pt",
- "true_reference": "metadata_subset.pt",
+ "shoeboxes": "standardized_shoeboxes_subset.pt",
+ "counts": "raw_counts_subset.pt",
+ "metadata": "shoebox_features_subset.pt",
+ "masks": "masks_subset.pt",
+ "true_reference": "metadata_subset.pt",
}
print("load settings")
- settings = DataLoaderSettings(data_directory=data_directory,
- data_file_names=data_file_names,
- test_set_split=0.5
- )
+ settings = DataLoaderSettings(
+ data_directory=data_directory,
+ data_file_names=data_file_names,
+ test_set_split=0.5,
+ )
print("load settings done")
test_data = test(settings=settings)
batch = next(iter(test_data))
@@ -66,4 +69,3 @@ def main():
if __name__ == "__main__":
main()
-
diff --git a/src/factory/run_factory.py b/src/factory/run_factory.py
index 6d227cf..b8aae5b 100644
--- a/src/factory/run_factory.py
+++ b/src/factory/run_factory.py
@@ -1,40 +1,46 @@
# from model_playground import *
-from model import *
-from settings import *
-from parse_yaml import load_settings_from_yaml, _log_settings_from_yaml
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
-from pytorch_lightning.profilers import AdvancedProfiler
+import argparse
import glob
import os
+from datetime import datetime
+
import phenix_callback
+import torch
import wandb
-from pytorch_lightning.tuner.tuning import Tuner
-import argparse
-from datetime import datetime
-from callbacks import *
import yaml
-import torch
+from callbacks import *
+from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger
+from model import *
from omegaconf import OmegaConf
-from lightning.pytorch.callbacks import LearningRateMonitor
+from parse_yaml import _log_settings_from_yaml, load_settings_from_yaml
+from pytorch_lightning.profilers import AdvancedProfiler
+from pytorch_lightning.tuner.tuning import Tuner
+from settings import *
+
-def configure_settings(config_path=None):
- model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(config_path)
+def configure_settings(config_path=None):
+    model_settings, loss_settings, phenix_settings, dataloader_settings = (
+        load_settings_from_yaml(config_path)
+    )
return model_settings, loss_settings, dataloader_settings
-
-def run(config_path, save_directory, run_from_version=False):
+
+def run(config_path, save_directory, run_from_version=False):
now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-
if run_from_version:
run_name = f"{now_str}_2025-11-05_13-59-42_from-start"
- wandb_logger = WandbLogger(project="full-model", name=run_name, save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs")
+ wandb_logger = WandbLogger(
+ project="full-model",
+ name=run_name,
+ save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
+ )
wandb.init(
project="full-model",
name=run_name,
- dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs"
+ dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
)
config_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml"
@@ -43,9 +49,10 @@ def run(config_path, save_directory, run_from_version=False):
# config_path = os.path.join(config_artifact_dir, "config_example.yaml")
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
- model_settings, loss_settings, phenix_settings, dataloader_settings = load_settings_from_yaml(config_dict)
+ model_settings, loss_settings, phenix_settings, dataloader_settings = (
+ load_settings_from_yaml(config_dict)
+ )
-
slurm_job_id = os.environ.get("SLURM_JOB_ID")
if slurm_job_id is not None:
wandb.config.update({"slurm_job_id": slurm_job_id})
@@ -56,8 +63,9 @@ def run(config_path, save_directory, run_from_version=False):
# checkpoint_path = os.path.join(artifact_dir, ckpt_files[-1])
checkpoint_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20251106_125602-hgjymq80/files/checkpoints/last.ckpt"
- model = Model.load_from_checkpoint(checkpoint_path, model_settings=model_settings, loss_settings=loss_settings)
-
+ model = Model.load_from_checkpoint(
+ checkpoint_path, model_settings=model_settings, loss_settings=loss_settings
+ )
else:
run_name = f"{now_str}_from-start"
@@ -66,7 +74,9 @@ def run(config_path, save_directory, run_from_version=False):
base_cfg = OmegaConf.load(config_path)
- wandb_logger = WandbLogger(project="full-model", name=run_name, save_dir=save_directory)
+ wandb_logger = WandbLogger(
+ project="full-model", name=run_name, save_dir=save_directory
+ )
wandb_run = wandb.init(
project="full-model",
name=run_name,
@@ -79,17 +89,21 @@ def run(config_path, save_directory, run_from_version=False):
wandb_config = wandb_run.config
- model_settings, loss_settings, phenix_settings, dataloader_settings = load_settings_from_yaml(wandb_config) #wandb_config)
+ model_settings, loss_settings, phenix_settings, dataloader_settings = (
+ load_settings_from_yaml(wandb_config)
+ ) # wandb_config)
slurm_job_id = os.environ.get("SLURM_JOB_ID")
if slurm_job_id is not None:
wandb.config.update({"slurm_job_id": slurm_job_id})
model = Model(model_settings=model_settings, loss_settings=loss_settings)
-
+
config_artifact = wandb.Artifact("config", type="config")
config_artifact.add_file(config_path) # Path to your config file
logged_artifact = wandb_logger.experiment.log_artifact(config_artifact)
- dataloader = data_loader.CrystallographicDataLoader(data_loader_settings=dataloader_settings)
+ dataloader = data_loader.CrystallographicDataLoader(
+ data_loader_settings=dataloader_settings
+ )
print("Loading data ...")
dataloader.load_data_()
print("Data loaded successfully.")
@@ -104,7 +118,6 @@ def run(config_path, save_directory, run_from_version=False):
if hasattr(model, "set_dataloader"):
model.set_dataloader(train_dataloader)
-
# if model_settings.enable_checkpointing:
checkpoint_callback = ModelCheckpoint(
dirpath=wandb_logger.experiment.dir + "/checkpoints",
@@ -112,20 +125,20 @@ def run(config_path, save_directory, run_from_version=False):
# save_top_k=6,
every_n_epochs=15,
# mode="min",
- filename="{epoch:02d}",#-{step}-{validation_loss/loss:.2f}",
+ filename="{epoch:02d}", # -{step}-{validation_loss/loss:.2f}",
save_weights_only=True,
# every_n_train_steps=None, # Disable step-based checkpointing
- save_last=True # Don't save the last checkpoint
+        save_last=True,  # Save the last checkpoint
)
trainer = L.pytorch.Trainer(
- logger=wandb_logger,
+ logger=wandb_logger,
max_epochs=160,
- log_every_n_steps=200,
- val_check_interval=200,
+ log_every_n_steps=200,
+ val_check_interval=200,
# limit_val_batches=21,
- accelerator="auto",
- enable_checkpointing=True, #model_settings.enable_checkpointing,
+ accelerator="auto",
+ enable_checkpointing=True, # model_settings.enable_checkpointing,
# default_root_dir="/tmp",
# profiler="simple",
# profiler=AdvancedProfiler(dirpath=wandb_logger.experiment.dir, filename="profiler.txt"),
@@ -133,23 +146,28 @@ def run(config_path, save_directory, run_from_version=False):
callbacks=[
Plotting(dataloader=dataloader),
LossLogging(),
- CorrelationPlottingBinned(dataloader=dataloader),
- CorrelationPlotting(dataloader=dataloader),
+ CorrelationPlottingBinned(dataloader=dataloader),
+ CorrelationPlotting(dataloader=dataloader),
ScalePlotting(dataloader=dataloader),
checkpoint_callback,
# PeakHeights(),
- LearningRateMonitor(logging_interval='epoch'),
+ LearningRateMonitor(logging_interval="epoch"),
# ValidationSetCorrelationPlot(dataloader=dataloader),
- ]
+ ],
)
print("lr_scheduler_configs:", trainer.lr_scheduler_configs)
-
- trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
-
+ trainer.fit(
+ model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
+ )
+
if model_settings.enable_checkpointing:
- checkpoint_dir = checkpoint_callback.dirpath if hasattr(checkpoint_callback, "dirpath") else "/tmp"
+ checkpoint_dir = (
+ checkpoint_callback.dirpath
+ if hasattr(checkpoint_callback, "dirpath")
+ else "/tmp"
+ )
print("checkpoint_dir", checkpoint_dir)
checkpoint_pattern = os.path.join(checkpoint_dir, "**", "*.ckpt")
@@ -157,14 +175,20 @@ def run(config_path, save_directory, run_from_version=False):
print("checkpoint_paths", checkpoint_paths)
- print("model_settings.enable_checkpointing",model_settings.enable_checkpointing)
- print("checkpoint_callback.best_model_path",checkpoint_callback.best_model_path)
+ print(
+ "model_settings.enable_checkpointing", model_settings.enable_checkpointing
+ )
+ print(
+ "checkpoint_callback.best_model_path", checkpoint_callback.best_model_path
+ )
if run_from_version == False:
- config_artifact_ref = logged_artifact.name # e.g., 'your-entity/your-project/config:v0'
+ config_artifact_ref = (
+ logged_artifact.name
+ ) # e.g., 'your-entity/your-project/config:v0'
version_alias = config_artifact_ref.split(":")[-1]
- artifact = wandb.Artifact('models', type='model')
+ artifact = wandb.Artifact("models", type="model")
for ckpt_file in checkpoint_paths:
artifact.add_file(ckpt_file)
print(f"Logged {len(checkpoint_paths)} checkpoints to W&B artifact.")
@@ -173,15 +197,14 @@ def run(config_path, save_directory, run_from_version=False):
print("run phenix")
phenix_callback.run_phenix_over_all_checkpoints(
- model_settings=model_settings,
- loss_settings=loss_settings,
+ model_settings=model_settings,
+ loss_settings=loss_settings,
phenix_settings=phenix_settings,
- artifact_dir=checkpoint_dir,
+ artifact_dir=checkpoint_dir,
checkpoint_paths=checkpoint_paths,
- wandb_directory=wandb_logger.experiment.dir
+ wandb_directory=wandb_logger.experiment.dir,
)
-
# for ckpt_file in checkpoint_paths:
# try:
# os.remove(ckpt_file)
@@ -191,14 +214,26 @@ def run(config_path, save_directory, run_from_version=False):
wandb.finish()
+
def main():
- parser = argparse.ArgumentParser(description='Run factory with config file')
- parser.add_argument('--config', type=str, default='../../configs/config_example.yaml', help='Path to config YAML file')
+ parser = argparse.ArgumentParser(description="Run factory with config file")
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="../../configs/config_example.yaml",
+ help="Path to config YAML file",
+ )
args, _ = parser.parse_known_args()
- parser.add_argument('--save_directory', type=str, default='./factory_results', help='Directory for W&B logs and checkpoint artifacts')
+ parser.add_argument(
+ "--save_directory",
+ type=str,
+ default="./factory_results",
+ help="Directory for W&B logs and checkpoint artifacts",
+ )
args, _ = parser.parse_known_args()
-
+
run(config_path=args.config, save_directory=args.save_directory)
+
if __name__ == "__main__":
main()
diff --git a/src/factory/run_phenix_from_model.py b/src/factory/run_phenix_from_model.py
index 85185ff..97f917d 100644
--- a/src/factory/run_phenix_from_model.py
+++ b/src/factory/run_phenix_from_model.py
@@ -1,15 +1,16 @@
-from model import *
-from settings import *
-from parse_yaml import load_settings_from_yaml
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
-
import glob
import os
+from datetime import datetime
+
import phenix_callback
import wandb
-from datetime import datetime
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger
+from model import *
+from parse_yaml import load_settings_from_yaml
from pytorch_lightning.tuner.tuning import Tuner
+from settings import *
+
def run():
@@ -18,48 +19,64 @@ def run():
now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
- wandb_logger = WandbLogger(project="full-model", name=f"{now_str}_phenix_from_{model_alias}", save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs")
+ wandb_logger = WandbLogger(
+ project="full-model",
+ name=f"{now_str}_phenix_from_{model_alias}",
+ save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
+ )
wandb.init(
project="full-model",
name=f"{now_str}_phenix_from_{model_alias}",
- dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs"
+ dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
)
slurm_job_id = os.environ.get("SLURM_JOB_ID")
if slurm_job_id is not None:
wandb.config.update({"slurm_job_id": slurm_job_id})
- print("wandb_logger.experiment.dir",wandb_logger.experiment.dir)
+ print("wandb_logger.experiment.dir", wandb_logger.experiment.dir)
- artifact = wandb.use_artifact(f"flaviagiehr-harvard-university/full-model/models:{model_alias}", type="model")
+ artifact = wandb.use_artifact(
+ f"flaviagiehr-harvard-university/full-model/models:{model_alias}", type="model"
+ )
artifact_dir = artifact.download()
print("Files in artifact_dir:", os.listdir(artifact_dir))
- config_artifact = wandb.use_artifact(f"flaviagiehr-harvard-university/full-model/config:{config_alias}", type="config")
+ config_artifact = wandb.use_artifact(
+ f"flaviagiehr-harvard-university/full-model/config:{config_alias}",
+ type="config",
+ )
config_artifact_dir = config_artifact.download()
config_path = os.path.join(config_artifact_dir, "config_example.yaml")
# config_path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml"
- model_settings, loss_settings, phenix_settings, dataloader_settings = load_settings_from_yaml(path=config_path)
+ model_settings, loss_settings, phenix_settings, dataloader_settings = (
+ load_settings_from_yaml(path=config_path)
+ )
# model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml")
-
- checkpoint_paths = [os.path.join(artifact_dir, f) for f in os.listdir(artifact_dir) if f.endswith(".ckpt")]
+ checkpoint_paths = [
+ os.path.join(artifact_dir, f)
+ for f in os.listdir(artifact_dir)
+ if f.endswith(".ckpt")
+ ]
# artifact_dir = "/n/holylabs/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20250718_124342-22j3mlzs/files/checkpoints"
print(f"downloaded {len(checkpoint_paths)} checkpoint files.")
print("run phenix")
phenix_callback.run_phenix_over_all_checkpoints(
- model_settings=model_settings,
- loss_settings=loss_settings,
+ model_settings=model_settings,
+ loss_settings=loss_settings,
phenix_settings=phenix_settings,
- artifact_dir=artifact_dir,
+ artifact_dir=artifact_dir,
checkpoint_paths=checkpoint_paths,
- wandb_directory=wandb_logger.experiment.dir,#"/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20250720_200430-py737ahw/files",#wandb_logger.experiment.dir
- )
+ wandb_directory=wandb_logger.experiment.dir, # "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20250720_200430-py737ahw/files",#wandb_logger.experiment.dir
+ )
+
def main():
# cProfile.run('run()')
run()
+
if __name__ == "__main__":
main()
diff --git a/src/factory/run_reference_processing.py b/src/factory/run_reference_processing.py
index a181fba..e0c3c7a 100644
--- a/src/factory/run_reference_processing.py
+++ b/src/factory/run_reference_processing.py
@@ -23,7 +23,9 @@ def run_dials(dials_env, command):
def run():
- refl_file="/n/hekstra_lab/people/aldama/subset/small_dataset/pass1/reflections_.refl"
+ refl_file = (
+ "/n/hekstra_lab/people/aldama/subset/small_dataset/pass1/reflections_.refl"
+ )
scale_command = (
f"dials.scale '{refl_file}' '{expt_file}' "
f"output.reflections='{scaled_refl_out}' "
@@ -36,5 +38,6 @@ def run():
dials_env = "/n/hekstra_lab/people/aldama/software/dials-v3-16-1/dials_env.sh"
run_dials(dials_env, scale_command)
-if name==main:
- run()
\ No newline at end of file
+
+if __name__ == "__main__":
+ run()
diff --git a/src/factory/settings.py b/src/factory/settings.py
index 52e61fb..998e665 100644
--- a/src/factory/settings.py
+++ b/src/factory/settings.py
@@ -1,41 +1,46 @@
import dataclasses
-import torch
-import torch
-import torch.nn.functional as F
-from networks import *
+
import distributions
import get_protein_data
+import metadata_encoder as me
import reciprocalspaceship as rs
-from abismal_torch.prior import WilsonPrior
-
-from rasu import *
+import shoebox_encoder as se
+import torch
+import torch.nn.functional as F
from abismal_torch.likelihood import NormalLikelihood
+from abismal_torch.prior import WilsonPrior
from abismal_torch.surrogate_posterior import FoldedNormalPosterior
+from networks import *
+from rasu import *
+
from wrap_folded_normal import FrequencyTrackingPosterior, SparseFoldedNormalPosterior
-import shoebox_encoder as se
-import metadata_encoder as me
# from lazy_adam import LazyAdamW
@dataclasses.dataclass
-class ModelSettings():
+class ModelSettings:
run_from_version: str | None = None
data_directory: str = "/n/hekstra_lab/people/aldama/subset/small_dataset/pass1"
- data_file_names: dict = dataclasses.field(default_factory=lambda: {
- # "shoeboxes": "standardized_counts.pt",
- "counts": "counts.pt",
- "metadata": "reference_.pt",
- "masks": "masks.pt",
- })
+ data_file_names: dict = dataclasses.field(
+ default_factory=lambda: {
+ # "shoeboxes": "standardized_counts.pt",
+ "counts": "counts.pt",
+ "metadata": "reference_.pt",
+ "masks": "masks.pt",
+ }
+ )
build_background_distribution: type = distributions.HalfNormalDistribution
build_profile_distribution: type = distributions.Distribution
# build_shoebox_profile_distribution = distributions.LRMVN_Distribution
build_shoebox_profile_distribution: type = distributions.DirichletProfile
background_prior_distribution: torch.distributions.Gamma = dataclasses.field(
- default_factory=lambda: torch.distributions.HalfNormal(0.5))
+ default_factory=lambda: torch.distributions.HalfNormal(0.5)
+ )
scale_prior_distibution: torch.distributions.Gamma = dataclasses.field(
- default_factory=lambda: torch.distributions.Gamma(concentration=torch.tensor(6.68), rate=torch.tensor(6.4463))
+ default_factory=lambda: torch.distributions.Gamma(
+ concentration=torch.tensor(6.68), rate=torch.tensor(6.4463)
+ )
)
scale_function: MLPScale = dataclasses.field(
default_factory=lambda: MLPScale(
@@ -43,14 +48,14 @@ class ModelSettings():
scale_distribution=LogNormalDistributionLayer(hidden_dimension=64),
hidden_dimension=64,
number_of_layers=1,
- initial_scale_guess=2/140
+ initial_scale_guess=2 / 140,
)
)
build_intensity_prior_distribution: WilsonPrior = WilsonPrior
intensity_prior_distibution: WilsonPrior = dataclasses.field(init=False)
use_surrogate_parameters: bool = False
-
+
shoebox_encoder: type = se.BaseShoeboxEncoder()
metadata_encoder: type = me.BaseMetadataEncoder()
intensity_encoder: type = se.IntensityEncoder()
@@ -59,12 +64,12 @@ class ModelSettings():
use_positional_encoding: bool = False
number_of_frequencies_in_positional_encoding: int = 2
-
-
- optimizer: type = torch.optim.AdamW #torch.optim.AdamW
+ optimizer: type = torch.optim.AdamW # torch.optim.AdamW
optimizer_betas: tuple = (0.9, 0.9)
learning_rate: float = 0.001
- surrogate_posterior_learning_rate: float = 0.01 # Higher learning rate for surrogate posterior
+ surrogate_posterior_learning_rate: float = (
+ 0.01 # Higher learning rate for surrogate posterior
+ )
weight_decay: float = 0.0001 # Weight decay for main parameters
surrogate_weight_decay: float = 0.0001 # Weight decay for surrogate parameters
dmodel: int = 64
@@ -73,10 +78,16 @@ class ModelSettings():
number_of_mc_samples: int = 20
enable_checkpointing: bool = True
- lysozyme_sequence_file_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/data/lysozyme.seq"
+ lysozyme_sequence_file_path: str = (
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/data/lysozyme.seq"
+ )
- merged_mtz_file_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz"
- unmerged_mtz_file_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz"
+ merged_mtz_file_path: str = (
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz"
+ )
+ unmerged_mtz_file_path: str = (
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz"
+ )
protein_pdb_url: str = "https://files.rcsb.org/download/9B7C.cif"
rac: ReciprocalASUCollection = dataclasses.field(init=False)
pdb_data: dict = dataclasses.field(init=False)
@@ -84,30 +95,41 @@ class ModelSettings():
def __post_init__(self):
self.pdb_data = get_protein_data.get_protein_data(self.protein_pdb_url)
- self.rac = ReciprocalASUGraph(*[ReciprocalASU(
- cell=self.pdb_data["unit_cell"],
- spacegroup=self.pdb_data["spacegroup"],
- dmin=float(self.pdb_data["dmin"]),
- anomalous=True,
- )])
- self.intensity_prior_distibution = self.build_intensity_prior_distribution(self.rac)
+ self.rac = ReciprocalASUGraph(
+ *[
+ ReciprocalASU(
+ cell=self.pdb_data["unit_cell"],
+ spacegroup=self.pdb_data["spacegroup"],
+ dmin=float(self.pdb_data["dmin"]),
+ anomalous=True,
+ )
+ ]
+ )
+ self.intensity_prior_distibution = self.build_intensity_prior_distribution(
+ self.rac
+ )
@dataclasses.dataclass
-class LossSettings():
+class LossSettings:
prior_background_weight: float = 0.0001
prior_structure_factors_weight: float = 0.0001
prior_scale_weight: float = 0.0001
- prior_profile_weight: list[float] = dataclasses.field(default_factory=lambda: [0.0001, 0.0001, 0.0001])
+ prior_profile_weight: list[float] = dataclasses.field(
+ default_factory=lambda: [0.0001, 0.0001, 0.0001]
+ )
eps: float = 0.00001
+
@dataclasses.dataclass
-class PhenixSettings():
- r_values_reference_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/refine_001.log"
+class PhenixSettings:
+ r_values_reference_path: str = (
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/refine_001.log"
+ )
-@ dataclasses.dataclass
-class DataLoaderSettings():
+@dataclasses.dataclass
+class DataLoaderSettings:
# data_directory: str = "/n/hekstra_lab/people/aldama/subset"
# data_file_names: dict = dataclasses.field(default_factory=lambda: {
# "shoeboxes": "shoebox_subset.pt",
@@ -116,27 +138,28 @@ class DataLoaderSettings():
# "masks": "mask_subset.pt",
# "true_reference": "true_reference_subset.pt",
# })
- metadata_indices: dict = dataclasses.field(default_factory=lambda: {
- "d": 0,
- "h": 1,
- "k": 2,
- "l": 3,
- "x": 4,
- "y": 5,
- "z": 6,
- })
- metadata_keys_to_keep: list = dataclasses.field(default_factory=lambda: [
- "x", "y"
- ])
-
+ metadata_indices: dict = dataclasses.field(
+ default_factory=lambda: {
+ "d": 0,
+ "h": 1,
+ "k": 2,
+ "l": 3,
+ "x": 4,
+ "y": 5,
+ "z": 6,
+ }
+ )
+ metadata_keys_to_keep: list = dataclasses.field(default_factory=lambda: ["x", "y"])
data_directory: str = "/n/hekstra_lab/people/aldama/subset/small_dataset/pass1"
- data_file_names: dict = dataclasses.field(default_factory=lambda: {
- # "shoeboxes": "standardized_counts.pt",
- "counts": "counts.pt",
- "metadata": "reference_.pt",
- "masks": "masks.pt",
- })
+ data_file_names: dict = dataclasses.field(
+ default_factory=lambda: {
+ # "shoeboxes": "standardized_counts.pt",
+ "counts": "counts.pt",
+ "metadata": "reference_.pt",
+ "masks": "masks.pt",
+ }
+ )
validation_set_split: float = 0.2
test_set_split: float = 0
@@ -145,9 +168,9 @@ class DataLoaderSettings():
number_of_batches: int = 1444
number_of_workers: int = 16
pin_memory: bool = True
- prefetch_factor: int|None = 2
+ prefetch_factor: int | None = 2
shuffle_indices: bool = True
shuffle_groups: bool = True
optimize_shoeboxes_per_batch: bool = True
append_image_id_to_metadata: bool = False
- verbose: bool = True
\ No newline at end of file
+ verbose: bool = True
diff --git a/src/factory/shoebox_encoder.py b/src/factory/shoebox_encoder.py
index 0607352..4e0460a 100644
--- a/src/factory/shoebox_encoder.py
+++ b/src/factory/shoebox_encoder.py
@@ -1,21 +1,20 @@
-import torch
import math
+
+import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
+
class ShoeboxEncoder(nn.Module):
- def __init__(
- self,
- input_shape=(3, 21, 21),
- out_dim=64
- ):
+ def __init__(self, input_shape=(3, 21, 21), out_dim=64):
super().__init__()
def forward(self, x):
pass
+
class BaseShoeboxEncoder(nn.Module):
def __init__(
self,
@@ -76,6 +75,7 @@ def forward(self, x, mask=None):
x = x.view(x.size(0), -1)
return F.relu(self.fc(x))
+
class SimpleShoeboxEncoder(nn.Module):
def __init__(
self,
@@ -134,33 +134,41 @@ def forward(self, x, mask=None):
if torch.isnan(x).any():
print("WARNING: NaN values in BaseShoeboxEncoder input!")
print("NaN count:", torch.isnan(x).sum().item())
-
+
x = F.relu(self.norm1(self.conv1(x)))
-
+
# Check after first conv
if torch.isnan(x).any():
print("WARNING: NaN values after conv1 in BaseShoeboxEncoder!")
print("NaN count:", torch.isnan(x).sum().item())
-
+
x = self.pool(x)
x = F.relu(self.norm2(self.conv2(x)))
-
+
# Check after second conv
if torch.isnan(x).any():
print("WARNING: NaN values after conv2 in BaseShoeboxEncoder!")
print("NaN count:", torch.isnan(x).sum().item())
-
+
x = x.view(x.size(0), -1)
x = F.relu(self.fc(x))
-
+
# Check final output
if torch.isnan(x).any():
print("WARNING: NaN values in BaseShoeboxEncoder output!")
print("NaN count:", torch.isnan(x).sum().item())
- print("Stats - min:", x.min().item(), "max:", x.max().item(), "mean:", x.mean().item())
-
+ print(
+ "Stats - min:",
+ x.min().item(),
+ "max:",
+ x.max().item(),
+ "mean:",
+ x.mean().item(),
+ )
+
return x
+
class IntensityEncoder(torch.nn.Module):
def __init__(
self,
diff --git a/src/factory/simple_encoder.py b/src/factory/simple_encoder.py
index 0ce7e78..acfbf2b 100644
--- a/src/factory/simple_encoder.py
+++ b/src/factory/simple_encoder.py
@@ -1,9 +1,11 @@
-import torch
import math
+
+import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
+
class ShoeboxEncoder(nn.Module):
def __init__(
self,
@@ -64,6 +66,7 @@ def forward(self, x, mask=None):
x = x.view(x.size(0), -1)
return F.relu(self.fc(x))
+
def weight_initializer(weight):
fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2])
std = math.sqrt(1.0 / fan_avg / 10.0)
@@ -117,6 +120,7 @@ def forward(self, x):
return out
+
class MLPMetadataEncoder(nn.Module):
def __init__(self, feature_dim, depth=8, dropout=0.0, output_dims=None):
super().__init__()
@@ -143,4 +147,4 @@ def __init__(self, feature_dim, depth=8, dropout=0.0, output_dims=None):
def forward(self, x):
# Process through the model
x = self.model(x)
- return x
\ No newline at end of file
+ return x
diff --git a/src/factory/simplest_encode.py b/src/factory/simplest_encode.py
index 2ce2c0b..ceb8847 100644
--- a/src/factory/simplest_encode.py
+++ b/src/factory/simplest_encode.py
@@ -1,9 +1,11 @@
-import torch
import math
+
+import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
+
class ShoeboxEncoder(nn.Module):
def __init__(
self,
@@ -64,6 +66,7 @@ def forward(self, x, mask=None):
x = x.view(x.size(0), -1)
return F.relu(self.fc(x))
+
def weight_initializer(weight):
fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2])
std = math.sqrt(1.0 / fan_avg / 10.0)
@@ -117,6 +120,7 @@ def forward(self, x):
return out
+
class MLPMetadataEncoder(nn.Module):
def __init__(self, feature_dim, depth=10, dropout=0.0, output_dims=None):
super().__init__()
@@ -143,4 +147,4 @@ def __init__(self, feature_dim, depth=10, dropout=0.0, output_dims=None):
def forward(self, x):
# Process through the model
x = self.model(x)
- return x
\ No newline at end of file
+ return x
diff --git a/src/factory/test_callbacks.py b/src/factory/test_callbacks.py
index fa25bdf..0c197d2 100644
--- a/src/factory/test_callbacks.py
+++ b/src/factory/test_callbacks.py
@@ -1,15 +1,16 @@
-from model import *
-from settings import *
-from parse_yaml import load_settings_from_yaml
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
-
import glob
import os
+from datetime import datetime
+
import phenix_callback
import wandb
-from datetime import datetime
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger
+from model import *
+from parse_yaml import load_settings_from_yaml
from pytorch_lightning.tuner.tuning import Tuner
+from settings import *
+
def run():
@@ -17,18 +18,25 @@ def run():
now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
- wandb_logger = WandbLogger(project="full-model", name=f"{now_str}_phenix_from_{model_alias}", save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs")
+ wandb_logger = WandbLogger(
+ project="full-model",
+ name=f"{now_str}_phenix_from_{model_alias}",
+ save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
+ )
wandb.init(
project="full-model",
name=f"{now_str}_phenix_from_{model_alias}",
- dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs"
+ dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
)
slurm_job_id = os.environ.get("SLURM_JOB_ID")
if slurm_job_id is not None:
wandb.config.update({"slurm_job_id": slurm_job_id})
- print("wandb_logger.experiment.dir",wandb_logger.experiment.dir)
+ print("wandb_logger.experiment.dir", wandb_logger.experiment.dir)
- artifact = wandb.use_artifact(f"flaviagiehr-harvard-university/full-model/best_models:{model_alias}", type="model")
+ artifact = wandb.use_artifact(
+ f"flaviagiehr-harvard-university/full-model/best_models:{model_alias}",
+ type="model",
+ )
artifact_dir = artifact.download()
print("Files in artifact_dir:", os.listdir(artifact_dir))
@@ -37,18 +45,19 @@ def run():
# config_path = os.path.join(config_artifact.download(), "config.yaml")
# model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(path=config_path)
- model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml")
-
+ model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(
+ path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml"
+ )
ckpt_files = [f for f in os.listdir(artifact_dir) if f.endswith(".ckpt")]
ckpt_file = ckpt_files[-1]
print(f"downloaded {len(ckpt_files)} checkpoint files.")
-
def main():
# cProfile.run('run()')
run()
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
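
A side note on the checkpoint handling above: picking the newest .ckpt out of the downloaded artifact directory is easiest with a small helper. A sketch under the same assumptions as the script (the Lightning model class name is not shown here and is left as a placeholder):

import os

def latest_checkpoint(artifact_dir: str) -> str:
    """Return the path of the last .ckpt file (sorted by name) in a downloaded artifact directory."""
    ckpt_files = sorted(f for f in os.listdir(artifact_dir) if f.endswith(".ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"no .ckpt files found in {artifact_dir}")
    return os.path.join(artifact_dir, ckpt_files[-1])

# ckpt_path = latest_checkpoint(artifact.download())
# model = SomeLightningModule.load_from_checkpoint(ckpt_path)  # class name is a placeholder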
diff --git a/src/factory/test_dataloader.py b/src/factory/test_dataloader.py
index 18c743f..4944a52 100644
--- a/src/factory/test_dataloader.py
+++ b/src/factory/test_dataloader.py
@@ -1,23 +1,27 @@
+import argparse
+
import settings
from data_loader import CrystallographicDataLoader
+
def test_dataloader(data_loader_settings: settings.DataLoaderSettings):
print("enter test")
dataloader = CrystallographicDataLoader(data_loader_settings=data_loader_settings)
- print("Loading data ...")
+ print("Loading data ...")
dataloader.load_data_()
print("Data loaded successfully.")
-
+
test_data = dataloader.load_data_set_batched_by_image(
data_set_to_load=dataloader.train_data_set
)
print("Test data loaded successfully.")
for batch in test_data:
- shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl = batch
+ shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl = (
+ batch
+ )
print("Batch shoeboxes :", shoeboxes_batch.shape)
- print("image ids", metadata_batch[:,2])
+ print("image ids", metadata_batch[:, 2])
return test_data
+
def parse_args():
parser = argparse.ArgumentParser(description="Pass DataLoader settings")
parser.add_argument(
@@ -27,10 +31,12 @@ def parse_args():
)
return parser.parse_args()
+
def main():
# args = parse_args()
data_loader_settings = settings.DataLoaderSettings()
test_dataloader(data_loader_settings=data_loader_settings)
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/src/factory/unmerged_mtz.py b/src/factory/unmerged_mtz.py
index cfb69c7..d0b6cd9 100644
--- a/src/factory/unmerged_mtz.py
+++ b/src/factory/unmerged_mtz.py
@@ -1,7 +1,9 @@
import reciprocalspaceship as rs
# Load the MTZ file
-ds = rs.read_mtz("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz")
+ds = rs.read_mtz(
+ "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz"
+)
# Show the column labels
print("Columns:", ds.columns)
diff --git a/src/factory/wrap_folded_normal.py b/src/factory/wrap_folded_normal.py
index 86e20a2..9ffa3cb 100644
--- a/src/factory/wrap_folded_normal.py
+++ b/src/factory/wrap_folded_normal.py
@@ -1,8 +1,9 @@
-import torch
from typing import Optional
+
+import rs_distributions.modules as rsm
+import torch
from abismal_torch.surrogate_posterior import FoldedNormalPosterior
from abismal_torch.symmetry import ReciprocalASUCollection
-import rs_distributions.modules as rsm
class FrequencyTrackingPosterior(FoldedNormalPosterior):
@@ -13,11 +14,11 @@ def __init__(
loc: torch.Tensor,
scale: torch.Tensor,
epsilon: Optional[float] = 1e-12,
- **kwargs
+ **kwargs,
):
"""
A surrogate posterior that tracks how many times each HKL has been observed.
-
+
Args:
rac (ReciprocalASUCollection): ReciprocalASUCollection.
loc (torch.Tensor): Unconstrained location parameter of the distribution.
@@ -25,11 +26,11 @@ def __init__(
epsilon (float, optional): Epsilon value for numerical stability. Defaults to 1e-12.
"""
super().__init__(rac, loc, scale, epsilon, **kwargs)
-
+
self.register_buffer(
"observation_count", torch.zeros(self.rac.rac_size, dtype=torch.long)
)
-
+
# def to(self, *args, **kwargs):
# super().to(*args, **kwargs)
# device = next(self.parameters()).device
@@ -38,26 +39,24 @@ def __init__(
# if hasattr(self, 'observed'):
# self.observed = self.observed.to(device)
# return self
-
+
def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
"""
Update both the observed buffer and the observation count.
-
+
Args:
rasu_id (torch.Tensor): A tensor of shape (n_refln,) that contains the
rasu ID of each reflection.
H (torch.Tensor): A tensor of shape (n_refln, 3).
"""
-
+
h, k, l = H.T
-
observed_idx = self.rac.reflection_id_grid[rasu_id, h, k, l]
-
self.observed[observed_idx] = True
self.observation_count[observed_idx] += 1
-
+
def reliable_observations_mask(self, min_observations: int = 5) -> torch.Tensor:
"""
Returns a boolean tensor indicating which HKLs have been observed enough times
@@ -74,16 +73,16 @@ def get_distribution(self, rasu_ids: torch.Tensor, H: torch.Tensor):
return rsm.FoldedNormal(gathered_loc, gathered_scale)
-
-
from typing import Optional
+
+import rs_distributions.distributions as rsd
import torch
import torch.nn as nn
import torch.nn.functional as F
-import rs_distributions.distributions as rsd
from abismal_torch.surrogate_posterior.base import PosteriorBase
from abismal_torch.symmetry import ReciprocalASUCollection
+
class SparseFoldedNormalPosterior(PosteriorBase):
def __init__(
self,
@@ -91,7 +90,7 @@ def __init__(
loc: torch.Tensor,
scale: torch.Tensor,
epsilon: Optional[float] = 1e-6,
- **kwargs
+ **kwargs,
):
# Call parent with a dummy distribution; we'll override it below
super().__init__(rac, distribution=None, epsilon=epsilon, **kwargs)
@@ -100,13 +99,13 @@ def __init__(
num_params = rac.rac_size # or however many unique positions you have
# Embeddings whose gradients will be sparse
- self.loc_embed = nn.Embedding(num_params, 1, sparse=True)
+ self.loc_embed = nn.Embedding(num_params, 1, sparse=True)
self.scale_embed = nn.Embedding(num_params, 1, sparse=True)
# Initialize them to your starting guesses
with torch.no_grad():
# loc_init & scale_init should each be shape [num_params]
- self.loc_embed.weight[:, 0] = loc
+ self.loc_embed.weight[:, 0] = loc
self.scale_embed.weight[:, 0] = scale
self.epsilon = epsilon
@@ -115,13 +114,12 @@ def __init__(
"observation_count", torch.zeros(self.rac.rac_size, dtype=torch.long)
)
-
def get_distribution(self, hkl_indices):
# Map each reflection to its ASU‐index
indices = hkl_indices
# Gather embeddings and squeeze out the singleton dim
- loc_raw = self.loc_embed(indices).squeeze(-1)
+ loc_raw = self.loc_embed(indices).squeeze(-1)
scale_raw = self.scale_embed(indices).squeeze(-1)
# Enforce positivity (you could also use transform_to())
@@ -129,26 +127,24 @@ def get_distribution(self, hkl_indices):
# Functional FoldedNormal has no trainable params—it just wraps a torch.distributions
return rsd.FoldedNormal(loc=loc_raw, scale=scale)
-
+
def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
"""
Update both the observed buffer and the observation count.
-
+
Args:
rasu_id (torch.Tensor): A tensor of shape (n_refln,) that contains the
rasu ID of each reflection.
H (torch.Tensor): A tensor of shape (n_refln, 3).
"""
-
+
h, k, l = H.T
-
observed_idx = self.rac.reflection_id_grid[rasu_id, h, k, l]
-
self.observed[observed_idx] = True
self.observation_count[observed_idx] += 1
-
+
def reliable_observations_mask(self, min_observations: int = 5) -> torch.Tensor:
"""
Returns a boolean tensor indicating which HKLs have been observed enough times
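
For reference, the observation-count bookkeeping in these posteriors is meant to be driven from the training loop roughly as follows. A minimal sketch, assuming posterior is an already-constructed FrequencyTrackingPosterior (building the ReciprocalASUCollection is out of scope here) and that rasu_id and H follow the shapes given in the docstrings:

import torch

def record_batch(posterior, rasu_id: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    """Update the observed buffer and counts for one batch, then return the reliability mask."""
    # rasu_id: (n_refln,), H: (n_refln, 3), as documented in update_observed().
    posterior.update_observed(rasu_id, H)
    # HKLs seen at least min_observations times count as reliable (default threshold is 5).
    return posterior.reliable_observations_mask(min_observations=5)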
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index b6407a9..5f19e05 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -1,31 +1,34 @@
-import torch
import pytest
+import torch
+
from factory import distributions
+
def test_lrmvn_distribution():
"""Test LRMVN distribution initialization and forward pass."""
batch_size = 2
hidden_dim = 64
distribution = distributions.LRMVN_Distribution(hidden_dim=hidden_dim)
-
+
# Create dummy input
shoebox_representation = torch.randn(batch_size, hidden_dim)
image_representation = torch.randn(batch_size, hidden_dim)
-
+
# Test forward pass
output = distribution(shoebox_representation, image_representation)
assert isinstance(output, torch.distributions.LowRankMultivariateNormal)
assert output.loc.shape[0] == batch_size
+
def test_dirichlet_profile():
"""Test Dirichlet profile initialization and forward pass."""
batch_size = 2
hidden_dim = 64
profile = distributions.DirichletProfile(dmodel=hidden_dim)
-
+
# Create dummy input
representation = torch.randn(batch_size, hidden_dim)
-
+
# Test forward pass
output = profile(representation)
- assert isinstance(output, torch.distributions.Dirichlet)
\ No newline at end of file
+ assert isinstance(output, torch.distributions.Dirichlet)
diff --git a/tox.ini b/tox.ini
index 0982115..4dd931d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,4 +26,4 @@ setenv =
extras =
test
commands =
- pytest -v --cov=rs-template --cov-report=xml --color=yes --basetemp={envtmpdir} {posargs}
\ No newline at end of file
+ pytest -v --cov=rs-template --cov-report=xml --color=yes --basetemp={envtmpdir} {posargs}
diff --git a/wrap_folded_normal.py b/wrap_folded_normal.py
index db3a485..5f30fdf 100644
--- a/wrap_folded_normal.py
+++ b/wrap_folded_normal.py
@@ -1,8 +1,10 @@
-import torch
from typing import Optional
+
+import torch
from abismal_torch.surrogate_posterior import FoldedNormalPosterior
from abismal_torch.symmetry import ReciprocalASUCollection
+
class FrequencyTrackingPosterior(FoldedNormalPosterior):
def __init__(
self,
@@ -10,11 +12,11 @@ def __init__(
loc: torch.Tensor,
scale: torch.Tensor,
epsilon: Optional[float] = 1e-12,
- **kwargs
+ **kwargs,
):
"""
A surrogate posterior that tracks how many times each HKL has been observed.
-
+
Args:
rac (ReciprocalASUCollection): ReciprocalASUCollection.
loc (torch.Tensor): Unconstrained location parameter of the distribution.
@@ -22,15 +24,15 @@ def __init__(
epsilon (float, optional): Epsilon value for numerical stability. Defaults to 1e-12.
"""
super().__init__(rac, loc, scale, epsilon, **kwargs)
-
+
self.register_buffer(
"observation_count", torch.zeros(self.rac.rac_size, dtype=torch.long)
)
-
+
def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
"""
Update both the observed buffer and the observation count.
-
+
Args:
rasu_id (torch.Tensor): A tensor of shape (n_refln,) that contains the
rasu ID of each reflection.
@@ -40,10 +42,10 @@ def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
observed_idx = self.rac.reflection_id_grid[rasu_id, h, k, l]
self.observed[observed_idx] = True
self.observation_count[observed_idx] += 1
-
+
def reliable_observations_mask(self, min_observations: int = 5) -> torch.Tensor:
"""
Returns a boolean tensor indicating which HKLs have been observed enough times
to be considered reliable.
"""
- return self.observation_count >= min_observations
\ No newline at end of file
+ return self.observation_count >= min_observations