diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28f6dfe..95f9ae4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -11,20 +11,20 @@ repos: - id: check-merge-conflict - id: detect-private-key -- repo: https://github.com/psf/black - rev: 24.1.1 +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 26.1.0 hooks: - id: black language_version: python3 - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 7.0.0 hooks: - id: isort args: ["--profile", "black"] - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 7.3.0 hooks: - id: flake8 additional_dependencies: [flake8-docstrings] diff --git a/README.md b/README.md index 774ed97..3349b17 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Factory -> **⚠️ This project is under active development.** +> **⚠️ This project is under active development.** A Python package for crystallographic data processing and analysis. diff --git a/configs/config_example.yaml b/configs/config_example.yaml index 303ebf5..3316503 100644 --- a/configs/config_example.yaml +++ b/configs/config_example.yaml @@ -22,7 +22,7 @@ model_settings: weight_decay: 0.000 optimizer_eps: 1e-16 weight_decouple: true - + # Scheduler configuration # scheduler_step: 55 scheduler_gamma: 0.1 @@ -41,18 +41,18 @@ model_settings: # scale_prior_variance: 6241.0 # scale_distribution: FoldedNormalDistributionLayer - metadata_encoder: + metadata_encoder: name: BaseMetadataEncoder params: depth: 10 - # metadata_encoder: + # metadata_encoder: # name: SimpleMetadataEncoder # params: # hidden_dim: 256 # depth: 5 - profile_encoder: + profile_encoder: name: BaseShoeboxEncoder params: # in_channels: 1 @@ -91,7 +91,7 @@ model_settings: # params: # concentration: 6.68 # rate: 0.001 - + background_prior_distribution: # name: TorchHalfNormal # params: @@ -121,7 +121,3 @@ loss_settings: phenix_settings: r_values_reference_path: "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/refine_001.log" - - - - diff --git a/docs/Makefile b/docs/Makefile index 43a2023..6ff62d7 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -20,4 +20,4 @@ help: @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) watch: - sphinx-autobuild . _build/html --open-browser --watch examples \ No newline at end of file + sphinx-autobuild . _build/html --open-browser --watch examples diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index dd3ed27..15ee00f 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -7,4 +7,4 @@
  • Home
  • API Reference
  • -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/docs/api/index.rst b/docs/api/index.rst index 76477f4..0cb747f 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -7,4 +7,4 @@ API Reference model distributions -This section provides detailed API documentation for Factory's modules and functions. \ No newline at end of file +This section provides detailed API documentation for Factory's modules and functions. diff --git a/docs/conf.py b/docs/conf.py index b2a73d6..fb88eb5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -227,4 +227,4 @@ def linkcode_resolve(domain, info): fn = os.path.relpath(fn, start=os.path.dirname("../factory")) - return f"https://github.com/rs-station/factory/blob/main/factory/{fn}{linespec}" # noqa \ No newline at end of file + return f"https://github.com/rs-station/factory/blob/main/factory/{fn}{linespec}" # noqa diff --git a/docs/convert_favicon.py b/docs/convert_favicon.py index ec1aa44..c6abfcf 100644 --- a/docs/convert_favicon.py +++ b/docs/convert_favicon.py @@ -1,4 +1,9 @@ import cairosvg # Convert SVG to PNG -cairosvg.svg2png(url='images/favicon.svg', write_to='images/favicon_32x32.png', output_width=32, output_height=32) \ No newline at end of file +cairosvg.svg2png( + url="images/favicon.svg", + write_to="images/favicon_32x32.png", + output_width=32, + output_height=32, +) diff --git a/docs/images/favicon.svg b/docs/images/favicon.svg index 123e16e..76b8aa6 100644 --- a/docs/images/favicon.svg +++ b/docs/images/favicon.svg @@ -2,4 +2,4 @@ F - \ No newline at end of file + diff --git a/docs/index.rst b/docs/index.rst index bac7202..106f96b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,4 +16,4 @@ Indices and tables * :ref:`genindex` * :ref:`modindex` -* :ref:`search` \ No newline at end of file +* :ref:`search` diff --git a/docs/installation.rst b/docs/installation.rst index 90f41e9..d2c8097 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -21,4 +21,4 @@ For development installation: git clone https://github.com/rs-station/factory.git cd factory - pip install -e ".[test,docs]" \ No newline at end of file + pip install -e ".[test,docs]" diff --git a/docs/make.bat b/docs/make.bat index 541f733..922152e 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -32,4 +32,4 @@ goto end %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end -popd \ No newline at end of file +popd diff --git a/docs/usage.rst b/docs/usage.rst index 0f57a4e..b1e940a 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -14,4 +14,4 @@ Here's a simple example of how to use Factory: # Create and use your model TODO -For more detailed examples and API reference, see the :doc:`api/index` section. \ No newline at end of file +For more detailed examples and API reference, see the :doc:`api/index` section. diff --git a/pymol_figures/refine_001.pdb b/pymol_figures/refine_001.pdb index ec8d9bb..ad2c5f0 100644 --- a/pymol_figures/refine_001.pdb +++ b/pymol_figures/refine_001.pdb @@ -8,24 +8,24 @@ REMARK 3 : Oeffner,Poon,Read,Richardson,Richardson,Sacchettini, REMARK 3 : Sauter,Sobolev,Storoni,Terwilliger,Williams,Zwart REMARK 3 REMARK 3 X-RAY DATA. -REMARK 3 +REMARK 3 REMARK 3 REFINEMENT TARGET : ML -REMARK 3 +REMARK 3 REMARK 3 DATA USED IN REFINEMENT. 
-REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10 -REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16 -REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.33 -REMARK 3 COMPLETENESS FOR RANGE (%) : 90.85 -REMARK 3 NUMBER OF REFLECTIONS : 85983 -REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45512 -REMARK 3 +REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10 +REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16 +REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.33 +REMARK 3 COMPLETENESS FOR RANGE (%) : 90.85 +REMARK 3 NUMBER OF REFLECTIONS : 85983 +REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45512 +REMARK 3 REMARK 3 FIT TO DATA USED IN REFINEMENT. REMARK 3 R VALUE (WORKING + TEST SET) : 0.1221 REMARK 3 R VALUE (WORKING SET) : 0.1203 REMARK 3 FREE R VALUE : 0.1386 -REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24 -REMARK 3 FREE R VALUE TEST SET COUNT : 8804 -REMARK 3 +REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24 +REMARK 3 FREE R VALUE TEST SET COUNT : 8804 +REMARK 3 REMARK 3 FIT TO DATA USED IN REFINEMENT (IN BINS). REMARK 3 BIN RESOLUTION RANGE COMPL. NWORK NFREE RWORK RFREE CCWORK CCFREE REMARK 3 1 56.16 - 3.41 1.00 2849 309 0.1392 0.1691 0.948 0.914 @@ -58,21 +58,21 @@ REMARK 3 27 1.15 - 1.14 0.68 1903 241 0.3115 0.3307 0.747 REMARK 3 28 1.14 - 1.12 0.54 1517 170 0.3396 0.3314 0.703 0.701 REMARK 3 29 1.12 - 1.11 0.34 960 84 0.3824 0.4505 0.607 0.613 REMARK 3 30 1.11 - 1.10 0.11 324 39 0.4270 0.4367 0.652 0.359 -REMARK 3 +REMARK 3 REMARK 3 BULK SOLVENT MODELLING. REMARK 3 METHOD USED : FLAT BULK SOLVENT MODEL -REMARK 3 SOLVENT RADIUS : 1.10 -REMARK 3 SHRINKAGE RADIUS : 0.90 -REMARK 3 GRID STEP : 0.60 -REMARK 3 +REMARK 3 SOLVENT RADIUS : 1.10 +REMARK 3 SHRINKAGE RADIUS : 0.90 +REMARK 3 GRID STEP : 0.60 +REMARK 3 REMARK 3 ERROR ESTIMATES. -REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.10 -REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 12.84 -REMARK 3 +REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.10 +REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 12.84 +REMARK 3 REMARK 3 STRUCTURE FACTORS CALCULATION ALGORITHM : FFT -REMARK 3 +REMARK 3 REMARK 3 B VALUES. -REMARK 3 FROM WILSON PLOT (A**2) : 15.22 +REMARK 3 FROM WILSON PLOT (A**2) : 15.22 REMARK 3 Individual atomic B REMARK 3 min max mean iso aniso REMARK 3 Overall: 11.38 62.75 20.78 1.59 0 1284 @@ -92,7 +92,7 @@ REMARK 3 42.20 - 47.34 24 REMARK 3 47.34 - 52.48 11 REMARK 3 52.48 - 57.61 2 REMARK 3 57.61 - 62.75 3 -REMARK 3 +REMARK 3 REMARK 3 GEOMETRY RESTRAINTS LIBRARY: GEOSTD + MONOMER LIBRARY + CDL V1.2 REMARK 3 DEVIATIONS FROM IDEAL VALUES - RMSD, RMSZ FOR BONDS AND ANGLES. REMARK 3 BOND : 0.008 0.066 1227 Z= 0.503 @@ -101,7 +101,7 @@ REMARK 3 CHIRALITY : 0.085 0.253 163 REMARK 3 PLANARITY : 0.009 0.070 228 REMARK 3 DIHEDRAL : 12.968 82.313 464 REMARK 3 MIN NONBONDED DISTANCE : 2.253 -REMARK 3 +REMARK 3 REMARK 3 MOLPROBITY STATISTICS. REMARK 3 ALL-ATOM CLASHSCORE : 0.00 REMARK 3 RAMACHANDRAN PLOT: @@ -118,7 +118,7 @@ REMARK 3 CIS-PROLINE : 0.00 % REMARK 3 CIS-GENERAL : 0.00 % REMARK 3 TWISTED PROLINE : 0.00 % REMARK 3 TWISTED GENERAL : 0.00 % -REMARK 3 +REMARK 3 REMARK 3 RAMA-Z (RAMACHANDRAN PLOT Z-SCORE): REMARK 3 INTERPRETATION: BAD |RAMA-Z| > 3; SUSPICIOUS 2 < |RAMA-Z| < 3; GOOD |RAMA-Z| < 2. 
REMARK 3 SCORES FOR WHOLE/HELIX/SHEET/LOOP ARE SCALED INDEPENDENTLY; @@ -127,14 +127,14 @@ REMARK 3 WHOLE: -0.12 (0.63), RESIDUES: 173 REMARK 3 HELIX: -1.09 (0.58), RESIDUES: 68 REMARK 3 SHEET: -1.28 (0.81), RESIDUES: 14 REMARK 3 LOOP : 1.25 (0.70), RESIDUES: 91 -REMARK 3 +REMARK 3 REMARK 3 MAX DEVIATION FROM PLANES: REMARK 3 TYPE MAXDEV MEANDEV LINEINFILE -REMARK 3 TRP 0.041 0.009 TRP A 123 -REMARK 3 HIS 0.005 0.003 HIS A 15 -REMARK 3 PHE 0.037 0.008 PHE A 38 -REMARK 3 TYR 0.017 0.006 TYR A 20 -REMARK 3 ARG 0.009 0.002 ARG A 114 +REMARK 3 TRP 0.041 0.009 TRP A 123 +REMARK 3 HIS 0.005 0.003 HIS A 15 +REMARK 3 PHE 0.037 0.008 PHE A 38 +REMARK 3 TYR 0.017 0.006 TYR A 20 +REMARK 3 ARG 0.009 0.002 ARG A 114 REMARK 3 HELIX 1 AA1 GLY A 4 HIS A 15 1 12 HELIX 2 AA2 ASN A 19 TYR A 23 5 5 @@ -147,10 +147,10 @@ HELIX 8 AA8 ASP A 119 ARG A 125 5 7 SHEET 1 AA1 3 THR A 43 ARG A 45 0 SHEET 2 AA1 3 THR A 51 TYR A 53 -1 O ASP A 52 N ASN A 44 SHEET 3 AA1 3 ILE A 58 ASN A 59 -1 O ILE A 58 N TYR A 53 -SSBOND 1 CYS A 6 CYS A 127 -SSBOND 2 CYS A 30 CYS A 115 -SSBOND 3 CYS A 64 CYS A 80 -SSBOND 4 CYS A 76 CYS A 94 +SSBOND 1 CYS A 6 CYS A 127 +SSBOND 2 CYS A 30 CYS A 115 +SSBOND 3 CYS A 64 CYS A 80 +SSBOND 4 CYS A 76 CYS A 94 CRYST1 79.424 79.424 37.793 90.00 90.00 90.00 P 43 21 2 SCALE1 0.012591 0.000000 0.000000 0.00000 SCALE2 0.000000 0.012591 0.000000 0.00000 diff --git a/pymol_figures/refine_epoch=03_001.pdb b/pymol_figures/refine_epoch=03_001.pdb index 6456a71..5b46a09 100644 --- a/pymol_figures/refine_epoch=03_001.pdb +++ b/pymol_figures/refine_epoch=03_001.pdb @@ -8,24 +8,24 @@ REMARK 3 : Oeffner,Poon,Read,Richardson,Richardson,Sacchettini, REMARK 3 : Sauter,Sobolev,Storoni,Terwilliger,Williams,Zwart REMARK 3 REMARK 3 X-RAY DATA. -REMARK 3 +REMARK 3 REMARK 3 REFINEMENT TARGET : ML -REMARK 3 +REMARK 3 REMARK 3 DATA USED IN REFINEMENT. -REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10 -REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16 -REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32 -REMARK 3 COMPLETENESS FOR RANGE (%) : 91.77 -REMARK 3 NUMBER OF REFLECTIONS : 86205 -REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45591 -REMARK 3 +REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10 +REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16 +REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32 +REMARK 3 COMPLETENESS FOR RANGE (%) : 91.77 +REMARK 3 NUMBER OF REFLECTIONS : 86205 +REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45591 +REMARK 3 REMARK 3 FIT TO DATA USED IN REFINEMENT. REMARK 3 R VALUE (WORKING + TEST SET) : 0.1313 REMARK 3 R VALUE (WORKING SET) : 0.1292 REMARK 3 FREE R VALUE : 0.1496 -REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24 -REMARK 3 FREE R VALUE TEST SET COUNT : 8826 -REMARK 3 +REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24 +REMARK 3 FREE R VALUE TEST SET COUNT : 8826 +REMARK 3 REMARK 3 FIT TO DATA USED IN REFINEMENT (IN BINS). REMARK 3 BIN RESOLUTION RANGE COMPL. NWORK NFREE RWORK RFREE CCWORK CCFREE REMARK 3 1 56.16 - 3.42 1.00 2832 303 0.1380 0.1645 0.948 0.922 @@ -58,21 +58,21 @@ REMARK 3 27 1.15 - 1.14 0.71 1992 232 0.4224 0.3941 0.547 REMARK 3 28 1.14 - 1.13 0.57 1632 177 0.4693 0.4421 0.466 0.457 REMARK 3 29 1.12 - 1.11 0.41 1145 121 0.5244 0.4445 0.352 0.592 REMARK 3 30 1.11 - 1.10 0.21 587 59 0.5427 0.6604 0.314 0.371 -REMARK 3 +REMARK 3 REMARK 3 BULK SOLVENT MODELLING. 
REMARK 3 METHOD USED : FLAT BULK SOLVENT MODEL -REMARK 3 SOLVENT RADIUS : 1.10 -REMARK 3 SHRINKAGE RADIUS : 0.90 -REMARK 3 GRID STEP : 0.60 -REMARK 3 +REMARK 3 SOLVENT RADIUS : 1.10 +REMARK 3 SHRINKAGE RADIUS : 0.90 +REMARK 3 GRID STEP : 0.60 +REMARK 3 REMARK 3 ERROR ESTIMATES. -REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.14 -REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 17.67 -REMARK 3 +REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.14 +REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 17.67 +REMARK 3 REMARK 3 STRUCTURE FACTORS CALCULATION ALGORITHM : FFT -REMARK 3 +REMARK 3 REMARK 3 B VALUES. -REMARK 3 FROM WILSON PLOT (A**2) : 16.28 +REMARK 3 FROM WILSON PLOT (A**2) : 16.28 REMARK 3 Individual atomic B REMARK 3 min max mean iso aniso REMARK 3 Overall: 12.92 61.47 22.37 1.60 0 1284 @@ -92,7 +92,7 @@ REMARK 3 42.05 - 46.91 29 REMARK 3 46.91 - 51.76 13 REMARK 3 51.76 - 56.62 4 REMARK 3 56.62 - 61.47 3 -REMARK 3 +REMARK 3 REMARK 3 GEOMETRY RESTRAINTS LIBRARY: GEOSTD + MONOMER LIBRARY + CDL V1.2 REMARK 3 DEVIATIONS FROM IDEAL VALUES - RMSD, RMSZ FOR BONDS AND ANGLES. REMARK 3 BOND : 0.007 0.066 1227 Z= 0.465 @@ -101,7 +101,7 @@ REMARK 3 CHIRALITY : 0.082 0.243 163 REMARK 3 PLANARITY : 0.009 0.075 228 REMARK 3 DIHEDRAL : 12.908 82.367 464 REMARK 3 MIN NONBONDED DISTANCE : 2.274 -REMARK 3 +REMARK 3 REMARK 3 MOLPROBITY STATISTICS. REMARK 3 ALL-ATOM CLASHSCORE : 0.00 REMARK 3 RAMACHANDRAN PLOT: @@ -118,7 +118,7 @@ REMARK 3 CIS-PROLINE : 0.00 % REMARK 3 CIS-GENERAL : 0.00 % REMARK 3 TWISTED PROLINE : 0.00 % REMARK 3 TWISTED GENERAL : 0.00 % -REMARK 3 +REMARK 3 REMARK 3 RAMA-Z (RAMACHANDRAN PLOT Z-SCORE): REMARK 3 INTERPRETATION: BAD |RAMA-Z| > 3; SUSPICIOUS 2 < |RAMA-Z| < 3; GOOD |RAMA-Z| < 2. REMARK 3 SCORES FOR WHOLE/HELIX/SHEET/LOOP ARE SCALED INDEPENDENTLY; @@ -127,14 +127,14 @@ REMARK 3 WHOLE: -0.43 (0.62), RESIDUES: 173 REMARK 3 HELIX: -1.43 (0.56), RESIDUES: 68 REMARK 3 SHEET: -1.44 (0.81), RESIDUES: 14 REMARK 3 LOOP : 1.15 (0.69), RESIDUES: 91 -REMARK 3 +REMARK 3 REMARK 3 MAX DEVIATION FROM PLANES: REMARK 3 TYPE MAXDEV MEANDEV LINEINFILE -REMARK 3 TRP 0.034 0.008 TRP A 123 -REMARK 3 HIS 0.007 0.005 HIS A 15 -REMARK 3 PHE 0.027 0.008 PHE A 38 -REMARK 3 TYR 0.014 0.005 TYR A 20 -REMARK 3 ARG 0.007 0.002 ARG A 114 +REMARK 3 TRP 0.034 0.008 TRP A 123 +REMARK 3 HIS 0.007 0.005 HIS A 15 +REMARK 3 PHE 0.027 0.008 PHE A 38 +REMARK 3 TYR 0.014 0.005 TYR A 20 +REMARK 3 ARG 0.007 0.002 ARG A 114 REMARK 3 HELIX 1 AA1 GLY A 4 HIS A 15 1 12 HELIX 2 AA2 ASN A 19 TYR A 23 5 5 @@ -147,10 +147,10 @@ HELIX 8 AA8 ASP A 119 ARG A 125 5 7 SHEET 1 AA1 3 THR A 43 ARG A 45 0 SHEET 2 AA1 3 THR A 51 TYR A 53 -1 O ASP A 52 N ASN A 44 SHEET 3 AA1 3 ILE A 58 ASN A 59 -1 O ILE A 58 N TYR A 53 -SSBOND 1 CYS A 6 CYS A 127 -SSBOND 2 CYS A 30 CYS A 115 -SSBOND 3 CYS A 64 CYS A 80 -SSBOND 4 CYS A 76 CYS A 94 +SSBOND 1 CYS A 6 CYS A 127 +SSBOND 2 CYS A 30 CYS A 115 +SSBOND 3 CYS A 64 CYS A 80 +SSBOND 4 CYS A 76 CYS A 94 CRYST1 79.424 79.424 37.793 90.00 90.00 90.00 P 43 21 2 SCALE1 0.012591 0.000000 0.000000 0.00000 SCALE2 0.000000 0.012591 0.000000 0.00000 diff --git a/pymol_figures/refine_epoch=127_001.pdb b/pymol_figures/refine_epoch=127_001.pdb index b18b022..7e1ba94 100644 --- a/pymol_figures/refine_epoch=127_001.pdb +++ b/pymol_figures/refine_epoch=127_001.pdb @@ -8,24 +8,24 @@ REMARK 3 : Oeffner,Poon,Read,Richardson,Richardson,Sacchettini, REMARK 3 : Sauter,Sobolev,Storoni,Terwilliger,Williams,Zwart REMARK 3 REMARK 3 X-RAY DATA. 
-REMARK 3 +REMARK 3 REMARK 3 REFINEMENT TARGET : ML -REMARK 3 +REMARK 3 REMARK 3 DATA USED IN REFINEMENT. -REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10 -REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16 -REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32 -REMARK 3 COMPLETENESS FOR RANGE (%) : 91.74 -REMARK 3 NUMBER OF REFLECTIONS : 86181 -REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45584 -REMARK 3 +REMARK 3 RESOLUTION RANGE HIGH (ANGSTROMS) : 1.10 +REMARK 3 RESOLUTION RANGE LOW (ANGSTROMS) : 56.16 +REMARK 3 MIN(FOBS/SIGMA_FOBS) : 1.32 +REMARK 3 COMPLETENESS FOR RANGE (%) : 91.74 +REMARK 3 NUMBER OF REFLECTIONS : 86181 +REMARK 3 NUMBER OF REFLECTIONS (NON-ANOMALOUS) : 45584 +REMARK 3 REMARK 3 FIT TO DATA USED IN REFINEMENT. REMARK 3 R VALUE (WORKING + TEST SET) : 0.1330 REMARK 3 R VALUE (WORKING SET) : 0.1309 REMARK 3 FREE R VALUE : 0.1513 -REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24 -REMARK 3 FREE R VALUE TEST SET COUNT : 8823 -REMARK 3 +REMARK 3 FREE R VALUE TEST SET SIZE (%) : 10.24 +REMARK 3 FREE R VALUE TEST SET COUNT : 8823 +REMARK 3 REMARK 3 FIT TO DATA USED IN REFINEMENT (IN BINS). REMARK 3 BIN RESOLUTION RANGE COMPL. NWORK NFREE RWORK RFREE CCWORK CCFREE REMARK 3 1 56.16 - 3.42 1.00 2832 303 0.1397 0.1737 0.945 0.911 @@ -58,21 +58,21 @@ REMARK 3 27 1.15 - 1.14 0.71 1991 232 0.4554 0.4248 0.516 REMARK 3 28 1.14 - 1.13 0.57 1631 176 0.4802 0.4328 0.408 0.473 REMARK 3 29 1.12 - 1.11 0.41 1141 121 0.4715 0.4497 0.315 0.514 REMARK 3 30 1.11 - 1.10 0.21 587 59 0.4710 0.4957 0.177 0.424 -REMARK 3 +REMARK 3 REMARK 3 BULK SOLVENT MODELLING. REMARK 3 METHOD USED : FLAT BULK SOLVENT MODEL -REMARK 3 SOLVENT RADIUS : 1.10 -REMARK 3 SHRINKAGE RADIUS : 0.90 -REMARK 3 GRID STEP : 0.60 -REMARK 3 +REMARK 3 SOLVENT RADIUS : 1.10 +REMARK 3 SHRINKAGE RADIUS : 0.90 +REMARK 3 GRID STEP : 0.60 +REMARK 3 REMARK 3 ERROR ESTIMATES. -REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.16 -REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 20.50 -REMARK 3 +REMARK 3 COORDINATE ERROR (MAXIMUM-LIKELIHOOD BASED) : 0.16 +REMARK 3 PHASE ERROR (DEGREES, MAXIMUM-LIKELIHOOD BASED) : 20.50 +REMARK 3 REMARK 3 STRUCTURE FACTORS CALCULATION ALGORITHM : FFT -REMARK 3 +REMARK 3 REMARK 3 B VALUES. -REMARK 3 FROM WILSON PLOT (A**2) : 19.58 +REMARK 3 FROM WILSON PLOT (A**2) : 19.58 REMARK 3 Individual atomic B REMARK 3 min max mean iso aniso REMARK 3 Overall: 16.05 71.49 25.98 1.68 0 1284 @@ -92,7 +92,7 @@ REMARK 3 49.31 - 54.86 23 REMARK 3 54.86 - 60.40 10 REMARK 3 60.40 - 65.95 2 REMARK 3 65.95 - 71.49 2 -REMARK 3 +REMARK 3 REMARK 3 GEOMETRY RESTRAINTS LIBRARY: GEOSTD + MONOMER LIBRARY + CDL V1.2 REMARK 3 DEVIATIONS FROM IDEAL VALUES - RMSD, RMSZ FOR BONDS AND ANGLES. REMARK 3 BOND : 0.007 0.063 1227 Z= 0.483 @@ -101,7 +101,7 @@ REMARK 3 CHIRALITY : 0.084 0.249 163 REMARK 3 PLANARITY : 0.009 0.069 228 REMARK 3 DIHEDRAL : 12.938 82.333 464 REMARK 3 MIN NONBONDED DISTANCE : 2.245 -REMARK 3 +REMARK 3 REMARK 3 MOLPROBITY STATISTICS. REMARK 3 ALL-ATOM CLASHSCORE : 0.00 REMARK 3 RAMACHANDRAN PLOT: @@ -118,7 +118,7 @@ REMARK 3 CIS-PROLINE : 0.00 % REMARK 3 CIS-GENERAL : 0.00 % REMARK 3 TWISTED PROLINE : 0.00 % REMARK 3 TWISTED GENERAL : 0.00 % -REMARK 3 +REMARK 3 REMARK 3 RAMA-Z (RAMACHANDRAN PLOT Z-SCORE): REMARK 3 INTERPRETATION: BAD |RAMA-Z| > 3; SUSPICIOUS 2 < |RAMA-Z| < 3; GOOD |RAMA-Z| < 2. 
REMARK 3 SCORES FOR WHOLE/HELIX/SHEET/LOOP ARE SCALED INDEPENDENTLY; @@ -127,14 +127,14 @@ REMARK 3 WHOLE: -0.33 (0.63), RESIDUES: 173 REMARK 3 HELIX: -1.30 (0.57), RESIDUES: 68 REMARK 3 SHEET: -1.46 (0.79), RESIDUES: 14 REMARK 3 LOOP : 1.18 (0.70), RESIDUES: 91 -REMARK 3 +REMARK 3 REMARK 3 MAX DEVIATION FROM PLANES: REMARK 3 TYPE MAXDEV MEANDEV LINEINFILE -REMARK 3 TRP 0.032 0.009 TRP A 123 -REMARK 3 HIS 0.007 0.005 HIS A 15 -REMARK 3 PHE 0.032 0.008 PHE A 38 -REMARK 3 TYR 0.014 0.006 TYR A 20 -REMARK 3 ARG 0.010 0.002 ARG A 114 +REMARK 3 TRP 0.032 0.009 TRP A 123 +REMARK 3 HIS 0.007 0.005 HIS A 15 +REMARK 3 PHE 0.032 0.008 PHE A 38 +REMARK 3 TYR 0.014 0.006 TYR A 20 +REMARK 3 ARG 0.010 0.002 ARG A 114 REMARK 3 HELIX 1 AA1 GLY A 4 HIS A 15 1 12 HELIX 2 AA2 ASN A 19 TYR A 23 5 5 @@ -147,10 +147,10 @@ HELIX 8 AA8 ASP A 119 ARG A 125 5 7 SHEET 1 AA1 3 THR A 43 ARG A 45 0 SHEET 2 AA1 3 THR A 51 TYR A 53 -1 O ASP A 52 N ASN A 44 SHEET 3 AA1 3 ILE A 58 ASN A 59 -1 O ILE A 58 N TYR A 53 -SSBOND 1 CYS A 6 CYS A 127 -SSBOND 2 CYS A 30 CYS A 115 -SSBOND 3 CYS A 64 CYS A 80 -SSBOND 4 CYS A 76 CYS A 94 +SSBOND 1 CYS A 6 CYS A 127 +SSBOND 2 CYS A 30 CYS A 115 +SSBOND 3 CYS A 64 CYS A 80 +SSBOND 4 CYS A 76 CYS A 94 CRYST1 79.424 79.424 37.793 90.00 90.00 90.00 P 43 21 2 SCALE1 0.012591 0.000000 0.000000 0.00000 SCALE2 0.000000 0.012591 0.000000 0.00000 diff --git a/pyproject.toml b/pyproject.toml index 6e9fb60..fc06aa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,4 +35,4 @@ docs = [ source = "git" [tool.hatch.build.targets.wheel] -packages = ["src/factory"] \ No newline at end of file +packages = ["src/factory"] diff --git a/setup.py b/setup.py index 2fa3700..90a1d00 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,12 @@ import sys -import sys -sys.stderr.write( - """ +sys.stderr.write(""" =============================== Unsupported installation method =============================== factory does not support installation with `python setup.py install`. Please use `python -m pip install .` instead. 
-""" -) +""") sys.exit(1) @@ -22,4 +19,4 @@ setup( # noqa name="factory", install_requires=[], -) \ No newline at end of file +) diff --git a/src/factory/CC12.py b/src/factory/CC12.py index a28340d..c349381 100644 --- a/src/factory/CC12.py +++ b/src/factory/CC12.py @@ -1,15 +1,18 @@ -import torch -import wandb import os -import numpy as np + import matplotlib.pyplot as plt -from lightning.pytorch.loggers import WandbLogger +import numpy as np +import torch +import wandb from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from model import * wandb.init(project="CC12") wandb_logger = WandbLogger(project="CC12", name="CC12-Training") -artifact = wandb.use_artifact("flaviagiehr-harvard-university/full-model/best_model:latest", type="model") +artifact = wandb.use_artifact( + "flaviagiehr-harvard-university/full-model/best_model:latest", type="model" +) artifact_dir = artifact.download() ckpt_files = [f for f in os.listdir(artifact_dir) if f.endswith(".ckpt")] @@ -17,27 +20,31 @@ checkpoint_path = os.path.join(artifact_dir, ckpt_files[0]) checkpoint = torch.load(checkpoint_path, map_location="cuda") -state_dict = checkpoint['state_dict'] +state_dict = checkpoint["state_dict"] filtered_state_dict = { - k: v for k, v in state_dict.items() - if not k.startswith("surrogate_posterior.") + k: v for k, v in state_dict.items() if not k.startswith("surrogate_posterior.") } settings = Settings() loss_settings = LossSettings() -data_loader_settings = data_loader.DataLoaderSettings(data_directory=settings.data_directory, - data_file_names=settings.data_file_names, - validation_set_split=0.5, test_set_split=0 - ) +data_loader_settings = data_loader.DataLoaderSettings( + data_directory=settings.data_directory, + data_file_names=settings.data_file_names, + validation_set_split=0.5, + test_set_split=0, +) _dataloader = data_loader.CrystallographicDataLoader(settings=data_loader_settings) _dataloader.load_data_() + def _train_model(model_no, dataloader, settings, loss_settings): model = Model(settings=settings, loss_settings=loss_settings, dataloader=dataloader) print("model type:", type(model)) - missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False) + missing_keys, unexpected_keys = model.load_state_dict( + filtered_state_dict, strict=False + ) print("Missing keys (expected for surrogate posterior):", missing_keys) print("Unexpected keys (should be none):", unexpected_keys) @@ -47,27 +54,34 @@ def _train_model(model_no, dataloader, settings, loss_settings): print(list(state_dict.keys())[:10]) trainer = Trainer( - logger=wandb_logger, - max_epochs=20, - # log_every_n_steps=5, + logger=wandb_logger, + max_epochs=20, + # log_every_n_steps=5, val_check_interval=None, - accelerator="auto", - enable_checkpointing=False, + accelerator="auto", + enable_checkpointing=False, default_root_dir="/tmp", # callbacks=[Plotting(), LossLogging(), CorrelationPlotting(), checkpoint_callback] # ScalePlotting() ) if model_no == 1: - trainer.fit(model, train_dataloaders=_dataloader.load_data_set_batched_by_image( - data_set_to_load=_dataloader.train_data_set - )) + trainer.fit( + model, + train_dataloaders=_dataloader.load_data_set_batched_by_image( + data_set_to_load=_dataloader.train_data_set + ), + ) else: - trainer.fit(model, train_dataloaders=_dataloader.load_data_set_batched_by_image( - data_set_to_load=_dataloader.validation_data_set - )) + trainer.fit( + model, + train_dataloaders=_dataloader.load_data_set_batched_by_image( + 
data_set_to_load=_dataloader.validation_data_set + ), + ) print("finished training for model", model_no) return model + model1 = _train_model(1, _dataloader, settings, loss_settings) model2 = _train_model(2, _dataloader, settings, loss_settings) @@ -89,13 +103,15 @@ def _train_model(model_no, dataloader, settings, loss_settings): # outputs_1.append(out1.cpu()) # outputs_2.append(out2.cpu()) + def get_reliable_mask(model): - if hasattr(model.surrogate_posterior, 'reliable_observations_mask'): + if hasattr(model.surrogate_posterior, "reliable_observations_mask"): return model.surrogate_posterior.reliable_observations_mask(min_observations=7) else: return model.surrogate_posterior.observed + reliable_mask_1 = get_reliable_mask(model1) reliable_mask_2 = get_reliable_mask(model2) @@ -103,27 +119,38 @@ def get_reliable_mask(model): # Plot histogram of observation counts fig_hist, ax_hist = plt.subplots(figsize=(10, 6)) -observation_counts_1 = model1.surrogate_posterior.observation_count.detach().cpu().numpy() -observation_counts_2 = model2.surrogate_posterior.observation_count.detach().cpu().numpy() - -ax_hist.hist([observation_counts_1, observation_counts_2], - bins=30, - alpha=0.5, - label=['Model 1', 'Model 2']) -ax_hist.set_xlabel('Number of Observations') -ax_hist.set_ylabel('Frequency') -ax_hist.set_title('Distribution of Observation Counts') +observation_counts_1 = ( + model1.surrogate_posterior.observation_count.detach().cpu().numpy() +) +observation_counts_2 = ( + model2.surrogate_posterior.observation_count.detach().cpu().numpy() +) + +ax_hist.hist( + [observation_counts_1, observation_counts_2], + bins=30, + alpha=0.5, + label=["Model 1", "Model 2"], +) +ax_hist.set_xlabel("Number of Observations") +ax_hist.set_ylabel("Frequency") +ax_hist.set_title("Distribution of Observation Counts") ax_hist.legend() wandb.log({"Observation Counts Histogram": wandb.Image(fig_hist)}) plt.close(fig_hist) -structure_factors_1 = torch.cat([model1.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy() -structure_factors_2 = torch.cat([model2.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy() +structure_factors_1 = ( + torch.cat([model1.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy() +) +structure_factors_2 = ( + torch.cat([model2.surrogate_posterior.mean], dim=0).flatten().detach().cpu().numpy() +) structure_factors_1 = structure_factors_1[common_reliable_mask.cpu().numpy()] structure_factors_2 = structure_factors_2[common_reliable_mask.cpu().numpy()] + def compute_cc12(a, b): # a = a.detach().cpu().numpy() # b = b.detach().cpu().numpy() @@ -132,17 +159,16 @@ def compute_cc12(a, b): mean_b = np.mean(b) numerator = np.sum((a - mean_a) * (b - mean_b)) - denominator = np.sqrt(np.sum((a - mean_a)**2) * np.sum((b - mean_b)**2)) + denominator = np.sqrt(np.sum((a - mean_a) ** 2) * np.sum((b - mean_b) ** 2)) return numerator / (denominator + 1e-8) + cc12 = compute_cc12(structure_factors_1, structure_factors_2) print("CC12:", cc12) fig, ax = plt.subplots(figsize=(10, 6)) -ax.scatter(structure_factors_1, - structure_factors_2, - alpha=0.5) +ax.scatter(structure_factors_1, structure_factors_2, alpha=0.5) ax.set_xlabel("Model 1 Structure Factors") ax.set_ylabel("Model 2 Structure Factors") ax.set_title(f"CC12 = {cc12:.3f}") diff --git a/src/factory/anomalous_peaks.py b/src/factory/anomalous_peaks.py index d2f77f1..e1ba0b1 100644 --- a/src/factory/anomalous_peaks.py +++ b/src/factory/anomalous_peaks.py @@ -1,26 +1,29 @@ -import torch -import pandas as pd -import numpy 
as np -import subprocess import os -import matplotlib.pyplot as plt +import subprocess -import wandb +import data_loader +import get_protein_data +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import pytorch_lightning as L - import reciprocalspaceship as rs import rs_distributions as rsd - +import torch +import wandb from model import * -import data_loader -import get_protein_data - -print(rs.read_mtz("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz")) +print( + rs.read_mtz( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz" + ) +) wandb.init(project="anomalous peaks", name="anomalous-peaks") -artifact = wandb.use_artifact("flaviagiehr-harvard-university/full-model/best_model:latest", type="model") +artifact = wandb.use_artifact( + "flaviagiehr-harvard-university/full-model/best_model:latest", type="model" +) artifact_dir = artifact.download() ckpt_files = [f for f in os.listdir(artifact_dir) if f.endswith(".ckpt")] @@ -30,21 +33,27 @@ settings = Settings() loss_settings = LossSettings() -model = Model.load_from_checkpoint(checkpoint_path, settings=settings, loss_settings=loss_settings)#, dataloader=dataloader) +model = Model.load_from_checkpoint( + checkpoint_path, settings=settings, loss_settings=loss_settings +) # , dataloader=dataloader) model.eval() repo_dir = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files" mtz_output_path = os.path.join(repo_dir, "phenix_ready.mtz") -surrogate_posterior_dataset = list(model.surrogate_posterior.to_dataset(only_observed=True))[0] +surrogate_posterior_dataset = list( + model.surrogate_posterior.to_dataset(only_observed=True) +)[0] # Additional validation checks if len(surrogate_posterior_dataset) == 0: raise ValueError("Dataset has no reflections!") # Check if we have the required columns -required_columns = ['F(+)', 'F(-)', 'SIGF(+)', 'SIGF(-)'] -missing_columns = [col for col in required_columns if col not in surrogate_posterior_dataset.columns] +required_columns = ["F(+)", "F(-)", "SIGF(+)", "SIGF(-)"] +missing_columns = [ + col for col in required_columns if col not in surrogate_posterior_dataset.columns +] if missing_columns: raise ValueError(f"Missing required columns: {missing_columns}") @@ -56,13 +65,16 @@ print(f"Warning: Column {col} contains infinite values") # Ensure d_min is set -if not hasattr(surrogate_posterior_dataset, 'd_min') or surrogate_posterior_dataset.d_min is None: +if ( + not hasattr(surrogate_posterior_dataset, "d_min") + or surrogate_posterior_dataset.d_min is None +): print("Setting d_min to 2.5Å") surrogate_posterior_dataset.d_min = 2.5 # Try to get wavelength from the dataset wavelength = None -if hasattr(surrogate_posterior_dataset, 'wavelength'): +if hasattr(surrogate_posterior_dataset, "wavelength"): wavelength = surrogate_posterior_dataset.wavelength print(f"\nWavelength from dataset: {wavelength}Å") else: @@ -88,7 +100,9 @@ # mtz_output_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz" # for the reference peaks -phenix_refine_eff_file_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/phenix.eff" +phenix_refine_eff_file_path = ( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/phenix.eff" +) cmd = f"source {phenix_env} && cd {repo_dir} && phenix.refine {phenix_refine_eff_file_path} {mtz_output_path} overwrite=True" @@ -144,33 +158,36 @@ # from rs_distributions.peak_finding import peak_finder subprocess.run( -"rs.find_peaks *[0-9].mtz 
*[0-9].pdb " f"-f ANOM -p PANOM -z 5.0 -o peaks.csv", -shell=True, -cwd=f"{repo_dir}", + "rs.find_peaks *[0-9].mtz *[0-9].pdb " f"-f ANOM -p PANOM -z 5.0 -o peaks.csv", + shell=True, + cwd=f"{repo_dir}", ) # Save peaks -print ("saved peaks") +print("saved peaks") -peaks_df = pd.read_csv("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv") -print(peaks_df.columns) +peaks_df = pd.read_csv( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv" +) +print(peaks_df.columns) # print(peaks_df.iloc[0]) # print(peaks_df) # Create a nice table of peak heights -peak_heights_table = pd.DataFrame({ - 'Residue': peaks_df['residue'], - 'Peak_Height': peaks_df['peakz'] -}) +peak_heights_table = pd.DataFrame( + {"Residue": peaks_df["residue"], "Peak_Height": peaks_df["peakz"]} +) # Sort by peak height in descending order for better visualization -peak_heights_table = peak_heights_table.sort_values('Peak_Height', ascending=False).reset_index(drop=True) +peak_heights_table = peak_heights_table.sort_values( + "Peak_Height", ascending=False +).reset_index(drop=True) # Add rank column for easier reference -peak_heights_table.insert(0, 'Rank', range(1, len(peak_heights_table) + 1)) +peak_heights_table.insert(0, "Rank", range(1, len(peak_heights_table) + 1)) # Round peak heights to 3 decimal places for cleaner display -peak_heights_table['Peak_Height'] = peak_heights_table['Peak_Height'].round(3) +peak_heights_table["Peak_Height"] = peak_heights_table["Peak_Height"].round(3) print("\n=== Anomalous Peak Heights ===") print(peak_heights_table.to_string(index=False)) @@ -183,6 +200,5 @@ peaks_artifact = wandb.Artifact("anomalous_peaks", type="peaks") -peaks_artifact.add_file(os.path.join(repo_dir,"peaks.csv")) +peaks_artifact.add_file(os.path.join(repo_dir, "peaks.csv")) wandb.log_artifact(peaks_artifact) - diff --git a/src/factory/callbacks.py b/src/factory/callbacks.py index cc88fac..2f2443f 100644 --- a/src/factory/callbacks.py +++ b/src/factory/callbacks.py @@ -1,19 +1,20 @@ -import torch -import reciprocalspaceship as rs -from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch.callbacks import Callback -import wandb -from phenix_callback import find_peaks_from_model +import io -from torchvision.transforms import ToPILImage -from lightning.pytorch.utilities import grad_norm +import matplotlib.cm as cm import matplotlib.pyplot as plt import numpy as np -import io +import pandas as pd +import reciprocalspaceship as rs +import torch +import wandb +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.utilities import grad_norm +from phenix_callback import find_peaks_from_model from PIL import Image from scipy.optimize import minimize -import matplotlib.cm as cm -import pandas as pd +from torchvision.transforms import ToPILImage + class PeakHeights(Callback): "Compute Anomalous Peak Heights during training." 
@@ -21,8 +22,8 @@ class PeakHeights(Callback): def __init__(self): super().__init__() self.residue_colors = {} # Maps residue to color - self.peak_history = {} # Maps residue to list of (epoch, peak_height) - self.colormap = cm.get_cmap('tab20') + self.peak_history = {} # Maps residue to list of (epoch, peak_height) + self.colormap = cm.get_cmap("tab20") # self.peakheights_log = open("peakheights.out", "a") # def __del__(self): @@ -37,10 +38,14 @@ def _get_color(self, residue): def on_train_epoch_end(self, trainer, pl_module): epoch = trainer.current_epoch - if epoch%2==0: + if epoch % 2 == 0: try: - r_work, r_free, path_to_peaks = find_peaks_from_model(model=pl_module, repo_dir=trainer.logger.experiment.dir) - pl_module.logger.experiment.log({"r/r_work": r_work, "r/r_free": r_free}) + r_work, r_free, path_to_peaks = find_peaks_from_model( + model=pl_module, repo_dir=trainer.logger.experiment.dir + ) + pl_module.logger.experiment.log( + {"r/r_work": r_work, "r/r_free": r_free} + ) peaks_df = pd.read_csv(path_to_peaks) epoch = trainer.current_epoch @@ -48,9 +53,9 @@ def on_train_epoch_end(self, trainer, pl_module): # Update peak history print("organize peaks") for _, row in peaks_df.iterrows(): - residue = row['residue'] + residue = row["residue"] print(residue) - peak_height = row['peakz'] + peak_height = row["peakz"] print(peak_height) if residue not in self.peak_history: self.peak_history[residue] = [] @@ -64,28 +69,27 @@ def on_train_epoch_end(self, trainer, pl_module): epochs, heights = zip(*history) color = self._get_color(residue) plt.plot(epochs, heights, label=str(residue), color=color) - plt.xlabel('Epoch') - plt.ylabel('Peak Height') - plt.title('Peak Heights per Residue') - plt.legend(title='Residue', bbox_to_anchor=(1.05, 1), loc='upper left') + plt.xlabel("Epoch") + plt.ylabel("Peak Height") + plt.title("Peak Heights per Residue") + plt.legend(title="Residue", bbox_to_anchor=(1.05, 1), loc="upper left") plt.tight_layout() # Log to wandb - pl_module.logger.experiment.log({ - "PeakHeights": wandb.Image(plt.gcf()), - "epoch": epoch - }) + pl_module.logger.experiment.log( + {"PeakHeights": wandb.Image(plt.gcf()), "epoch": epoch} + ) plt.close() except Exception as e: - print(f"Failed for {epoch}: {e}")#, file=self.peakheights_log, flush=True) + print( + f"Failed for {epoch}: {e}" + ) # , file=self.peakheights_log, flush=True) else: return - - class LossLogging(Callback): - + def __init__(self): super().__init__() self.loss_per_epoch: float = 0 @@ -96,48 +100,76 @@ def on_train_epoch_end(self, trainer, pl_module): pl_module.log("train/loss_epoch", avg, on_step=False, on_epoch=True) def on_validation_epoch_end(self, trainer, pl_module): - + train_loss = trainer.callback_metrics.get("train/loss_step_epoch") val_loss = trainer.callback_metrics.get("validation/loss_step_epoch") if train_loss is not None and val_loss is not None: difference_train_val_loss = train_loss - val_loss - pl_module.log("loss/train-val", difference_train_val_loss, on_step=False, on_epoch=True) + pl_module.log( + "loss/train-val", + difference_train_val_loss, + on_step=False, + on_epoch=True, + ) else: - print("Skipping loss/train-val logging: train or validation loss not available this epoch.") - + print( + "Skipping loss/train-val logging: train or validation loss not available this epoch." 
+ ) + + class Plotting(Callback): def __init__(self, dataloader): super().__init__() self.fixed_batch = None self.dataloader = dataloader - + def on_fit_start(self, trainer, pl_module): - _raw_batch = next(iter(self.dataloader.load_data_for_logging_during_training(number_of_shoeboxes_to_log=8))) + _raw_batch = next( + iter( + self.dataloader.load_data_for_logging_during_training( + number_of_shoeboxes_to_log=8 + ) + ) + ) _raw_batch = [_batch_item.to(pl_module.device) for _batch_item in _raw_batch] self.fixed_batch = tuple(_raw_batch) - - self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to(pl_module.device)[:, 9] - self.fixed_batch_intensity_sum_variance = self.fixed_batch[1].to(pl_module.device)[:, 10] - self.fixed_batch_total_counts = self.fixed_batch[3].to(pl_module.device).sum(dim=1) - self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[:, 13] + self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to( + pl_module.device + )[:, 9] + self.fixed_batch_intensity_sum_variance = self.fixed_batch[1].to( + pl_module.device + )[:, 10] + self.fixed_batch_total_counts = ( + self.fixed_batch[3].to(pl_module.device).sum(dim=1) + ) + + self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[ + :, 13 + ] def on_train_epoch_end(self, trainer, pl_module): with torch.no_grad(): - shoeboxes_batch, photon_rate_output, hkl_batch, counts_batch = pl_module(self.fixed_batch, verbose_output=True) + shoeboxes_batch, photon_rate_output, hkl_batch, counts_batch = pl_module( + self.fixed_batch, verbose_output=True + ) # Extract values from dictionary - samples_photon_rate = photon_rate_output['photon_rate'] - samples_profile = photon_rate_output['samples_profile'] - samples_background = photon_rate_output['samples_background'] - samples_scale = photon_rate_output['samples_scale'] - samples_predicted_structure_factor = photon_rate_output["samples_predicted_structure_factor"] - + samples_photon_rate = photon_rate_output["photon_rate"] + samples_profile = photon_rate_output["samples_profile"] + samples_background = photon_rate_output["samples_background"] + samples_scale = photon_rate_output["samples_scale"] + samples_predicted_structure_factor = photon_rate_output[ + "samples_predicted_structure_factor" + ] + del photon_rate_output - + batch_size = shoeboxes_batch.size(0) - model_intensity = samples_scale * torch.square(samples_predicted_structure_factor) + model_intensity = samples_scale * torch.square( + samples_predicted_structure_factor + ) model_intensity = torch.mean(model_intensity, dim=1) print("model_intensity", model_intensity.shape) @@ -146,92 +178,133 @@ def on_train_epoch_end(self, trainer, pl_module): print("photon rate", photon_rate.shape) print("profile", samples_profile.shape) print("background", samples_background.shape) - print("scale", samples_scale.shape) + print("scale", samples_scale.shape) fig, axes = plt.subplots( - 2, batch_size, - figsize=(4*batch_size, 16), - gridspec_kw={'hspace': 0.05, 'wspace': 0.3} + 2, + batch_size, + figsize=(4 * batch_size, 16), + gridspec_kw={"hspace": 0.05, "wspace": 0.3}, ) - + im_handles = [] for b in range(batch_size): - if len(shoeboxes_batch.shape)>2: - im1 = axes[0,b].imshow(shoeboxes_batch[b,:, -1].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy()) - im2 = axes[1,b].imshow(samples_profile[b,0,:].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy()) + if len(shoeboxes_batch.shape) > 2: + im1 = axes[0, b].imshow( + shoeboxes_batch[b, :, -1] + .reshape(3, 21, 21)[1:2] + .squeeze() + .cpu() + 
.detach() + .numpy() + ) + im2 = axes[1, b].imshow( + samples_profile[b, 0, :] + .reshape(3, 21, 21)[1:2] + .squeeze() + .cpu() + .detach() + .numpy() + ) else: - im1 = axes[0,b].imshow(shoeboxes_batch[b,:].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy()) - im2 = axes[1,b].imshow(samples_profile[b,0,:].reshape(3,21,21)[1:2].squeeze().cpu().detach().numpy()) + im1 = axes[0, b].imshow( + shoeboxes_batch[b, :] + .reshape(3, 21, 21)[1:2] + .squeeze() + .cpu() + .detach() + .numpy() + ) + im2 = axes[1, b].imshow( + samples_profile[b, 0, :] + .reshape(3, 21, 21)[1:2] + .squeeze() + .cpu() + .detach() + .numpy() + ) im_handles.append(im1) # Build title string with each element on its own row (3 rows) title_lines = [ - f'Raw Shoebox {b}', - f'I = {self.fixed_batch_intensity_sum_values[b]:.2f} pm {self.fixed_batch_intensity_sum_variance[b]:.2f}', - f'Bkg (mean) = {self.fixed_batch_background_mean[b]:.2f}', - f'total counts = {self.fixed_batch_total_counts[b]:.2f}' + f"Raw Shoebox {b}", + f"I = {self.fixed_batch_intensity_sum_values[b]:.2f} pm {self.fixed_batch_intensity_sum_variance[b]:.2f}", + f"Bkg (mean) = {self.fixed_batch_background_mean[b]:.2f}", + f"total counts = {self.fixed_batch_total_counts[b]:.2f}", ] - axes[0,b].set_title('\n'.join(title_lines), pad=5, fontsize=16) - axes[0,b].axis('off') - + axes[0, b].set_title("\n".join(title_lines), pad=5, fontsize=16) + axes[0, b].axis("off") + title_prf_lines = [ - f'Profile {b}', - f'photon_rate = {photon_rate[b].sum():.2f}', - f'I(F) = {model_intensity[b].item():.2f}', - f'Scale = {torch.mean(samples_scale, dim=1)[b].item():.2f}', - f'F= {torch.mean(samples_predicted_structure_factor, dim=1)[b].item():.2f}', - f'Bkg (mean) = {torch.mean(samples_background, dim=1)[b].item():.2f}' + f"Profile {b}", + f"photon_rate = {photon_rate[b].sum():.2f}", + f"I(F) = {model_intensity[b].item():.2f}", + f"Scale = {torch.mean(samples_scale, dim=1)[b].item():.2f}", + f"F= {torch.mean(samples_predicted_structure_factor, dim=1)[b].item():.2f}", + f"Bkg (mean) = {torch.mean(samples_background, dim=1)[b].item():.2f}", ] - axes[1,b].set_title('\n'.join(title_prf_lines), pad=5, fontsize=16) - axes[1,b].axis('off') - + axes[1, b].set_title("\n".join(title_prf_lines), pad=5, fontsize=16) + axes[1, b].axis("off") + # Add individual colorbars for each column # Colorbar for the raw shoebox (top row) - cbar1 = fig.colorbar(im1, ax=axes[0,b], orientation='vertical', fraction=0.02, pad=0.04) + cbar1 = fig.colorbar( + im1, ax=axes[0, b], orientation="vertical", fraction=0.02, pad=0.04 + ) cbar1.ax.tick_params(labelsize=8) - + # Colorbar for the profile (bottom row) - cbar2 = fig.colorbar(im2, ax=axes[1,b], orientation='vertical', fraction=0.02, pad=0.04) + cbar2 = fig.colorbar( + im2, ax=axes[1, b], orientation="vertical", fraction=0.02, pad=0.04 + ) cbar2.ax.tick_params(labelsize=8) - - - + plt.tight_layout(h_pad=0.0, w_pad=0.3) # Remove extra vertical padding buf = io.BytesIO() - plt.savefig(buf, format='png', dpi=100, bbox_inches='tight') + plt.savefig(buf, format="png", dpi=100, bbox_inches="tight") buf.seek(0) - + # Convert buffer to PIL Image pil_image = Image.open(buf) - - pl_module.logger.experiment.log({ - "Profiles": wandb.Image(pil_image), - "epoch": trainer.current_epoch - }) - + + pl_module.logger.experiment.log( + {"Profiles": wandb.Image(pil_image), "epoch": trainer.current_epoch} + ) + plt.close() buf.close() - + # Memory optimization: Clear large tensors after use - del samples_photon_rate, samples_profile, samples_background, samples_scale, 
samples_predicted_structure_factor + del ( + samples_photon_rate, + samples_profile, + samples_background, + samples_scale, + samples_predicted_structure_factor, + ) torch.cuda.empty_cache() # Force GPU memory cleanup + class CorrelationPlottingBinned(Callback): - + def __init__(self, dataloader): super().__init__() self.dials_reference = None self.fixed_batch = None # self.fixed_hkl_indices_for_surrogate_posterior = None self.dataloader = dataloader - + def on_fit_start(self, trainer, pl_module): print("on fit start callback") - self.dials_reference = rs.read_mtz(pl_module.model_settings.merged_mtz_file_path) - self.dials_reference, self.resolution_bins = self.dials_reference.assign_resolution_bins(bins=10) + self.dials_reference = rs.read_mtz( + pl_module.model_settings.merged_mtz_file_path + ) + self.dials_reference, self.resolution_bins = ( + self.dials_reference.assign_resolution_bins(bins=10) + ) print("Read mtz successfully: ", self.dials_reference["F"].shape) print("head hkl", self.dials_reference.get_hkls().shape) @@ -244,42 +317,61 @@ def on_train_epoch_end(self, trainer, pl_module): print("on train epoch end in callback") try: # shoeboxes_batch, photon_rate_output, hkl_batch = pl_module(self.fixed_batch, log_images=True) - samples_surrogate_posterior = pl_module.surrogate_posterior.mean - print("full samples", samples_surrogate_posterior.shape) + print("full samples", samples_surrogate_posterior.shape) # print("gathered ", photon_rate_output['samples_predicted_structure_factor'].shape) rasu_ids = pl_module.surrogate_posterior.rac.rasu_ids[0] - print("rasu id (all teh same as the first one?):",pl_module.surrogate_posterior.rac.rasu_ids) - - if hasattr(pl_module.surrogate_posterior, 'reliable_observations_mask'): - reliable_mask = pl_module.surrogate_posterior.reliable_observations_mask(min_observations=50) + print( + "rasu id (all teh same as the first one?):", + pl_module.surrogate_posterior.rac.rasu_ids, + ) + + if hasattr(pl_module.surrogate_posterior, "reliable_observations_mask"): + reliable_mask = ( + pl_module.surrogate_posterior.reliable_observations_mask( + min_observations=50 + ) + ) print("take reliable mask for correlation plotting") - + else: reliable_mask = pl_module.surrogate_posterior.observed print("observed attribute", pl_module.surrogate_posterior.observed.sum()) - mask_observed_indexed_as_dials_reference = pl_module.surrogate_posterior.rac.gather( - source=pl_module.surrogate_posterior.observed & reliable_mask, rasu_id=rasu_ids, H=self.dials_reference.get_hkls() + mask_observed_indexed_as_dials_reference = ( + pl_module.surrogate_posterior.rac.gather( + source=pl_module.surrogate_posterior.observed & reliable_mask, + rasu_id=rasu_ids, + H=self.dials_reference.get_hkls(), + ) ) ordered_samples = pl_module.surrogate_posterior.rac.gather( # gather wants source to be (rac_size, ) - source=samples_surrogate_posterior.T, rasu_id=rasu_ids, H=self.dials_reference.get_hkls() - ) #(reflections_from_mtz, number_of_samples) + source=samples_surrogate_posterior.T, + rasu_id=rasu_ids, + H=self.dials_reference.get_hkls(), + ) # (reflections_from_mtz, number_of_samples) print("get_hkl from dials", self.dials_reference.get_hkls().shape) print("dials ref", self.dials_reference["F"].shape) print("ordered samples", ordered_samples.shape) print("check output of assign resolution bins:") - print( self.dials_reference.shape) - print( self.dials_reference.head()) + print(self.dials_reference.shape) + print(self.dials_reference.head()) # mask out nonobserved reflections - print("sum 
of (reindexed) observed reflections", mask_observed_indexed_as_dials_reference.sum()) - _masked_dials_reference = self.dials_reference[mask_observed_indexed_as_dials_reference.cpu().detach().numpy()] + print( + "sum of (reindexed) observed reflections", + mask_observed_indexed_as_dials_reference.sum(), + ) + _masked_dials_reference = self.dials_reference[ + mask_observed_indexed_as_dials_reference.cpu().detach().numpy() + ] print("masked dials reference", _masked_dials_reference["F"].shape) - _masked_ordered_samples = ordered_samples[mask_observed_indexed_as_dials_reference] + _masked_ordered_samples = ordered_samples[ + mask_observed_indexed_as_dials_reference + ] print("masked ordered samples", _masked_ordered_samples.shape) print("self.resolution_bins", self.resolution_bins) @@ -300,134 +392,182 @@ def on_train_epoch_end(self, trainer, pl_module): for i, bin_index in enumerate(self.resolution_bins): bin_mask = _masked_dials_reference["bin"] == i print("bin maks", sum(bin_mask)) - reference_to_plot = _masked_dials_reference[bin_mask]["F"].squeeze().to_numpy() - model_to_plot = _masked_ordered_samples[np.array(bin_mask)].squeeze().cpu().detach().numpy() + reference_to_plot = ( + _masked_dials_reference[bin_mask]["F"].squeeze().to_numpy() + ) + model_to_plot = ( + _masked_ordered_samples[np.array(bin_mask)] + .squeeze() + .cpu() + .detach() + .numpy() + ) - # Remove any NaN or Inf values - valid_mask = ~(np.isnan(reference_to_plot) | np.isnan(model_to_plot) | - np.isinf(reference_to_plot) | np.isinf(model_to_plot)) + valid_mask = ~( + np.isnan(reference_to_plot) + | np.isnan(model_to_plot) + | np.isinf(reference_to_plot) + | np.isinf(model_to_plot) + ) if not np.any(valid_mask): print("Warning: No valid data points after filtering") continue - + reference_to_plot = reference_to_plot[valid_mask] model_to_plot = model_to_plot[valid_mask] - + print(f"After filtering - valid data points: {len(reference_to_plot)}") - - scatter = axes[i].scatter(reference_to_plot, - model_to_plot, - marker='x', s=10, alpha=0.5, c="black") - - # Fit a straight line through the origin using numpy polyfit - # For line through origin, we fit y = mx (no intercept) - # try: - # fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0] - # print(f"Polyfit successful, slope: {fitted_slope:.6f}") - # except np.linalg.LinAlgError as e: - # print(f"Polyfit failed with LinAlgError: {e}") - # print("Falling back to manual calculation") - # # Fallback to manual calculation for line through origin - # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) - # print(f"Fallback slope: {fitted_slope:.6f}") - # except Exception as e: - # print(f"Polyfit failed with other error: {e}") - # # Fallback to manual calculation - # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) - # print(f"Fallback slope: {fitted_slope:.6f}") - - # # Plot the fitted line - # x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100) - # y_fitted = fitted_slope * x_range - # axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x') - + + scatter = axes[i].scatter( + reference_to_plot, + model_to_plot, + marker="x", + s=10, + alpha=0.5, + c="black", + ) + + # Fit a straight line through the origin using numpy polyfit + # For line through origin, we fit y = mx (no intercept) + # try: + # fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0] + # print(f"Polyfit successful, 
slope: {fitted_slope:.6f}") + # except np.linalg.LinAlgError as e: + # print(f"Polyfit failed with LinAlgError: {e}") + # print("Falling back to manual calculation") + # # Fallback to manual calculation for line through origin + # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + # print(f"Fallback slope: {fitted_slope:.6f}") + # except Exception as e: + # print(f"Polyfit failed with other error: {e}") + # # Fallback to manual calculation + # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + # print(f"Fallback slope: {fitted_slope:.6f}") + + # # Plot the fitted line + # x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100) + # y_fitted = fitted_slope * x_range + # axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x') + # Set log-log scale - axes[i].set_xscale('log') - axes[i].set_yscale('log') - + axes[i].set_xscale("log") + axes[i].set_yscale("log") + try: - correlation = np.corrcoef(reference_to_plot, - model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1] + correlation = np.corrcoef( + reference_to_plot, + model_to_plot + * np.max(reference_to_plot) + / np.max(model_to_plot), + )[0, 1] print("reference_to_plot", reference_to_plot) - print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot)) + print( + "model to plot scaled", + model_to_plot + * np.max(reference_to_plot) + / np.max(model_to_plot), + ) except Exception as e: print(f"skip correlation computataion {e}") # print(f"Fitted slope: {fitted_slope:.4f}") - axes[i].set_xlabel('Dials Reference F') - axes[i].set_ylabel('Predicted F') - axes[i].set_title(f'bin {bin_index} (r={correlation:.3f}') + axes[i].set_xlabel("Dials Reference F") + axes[i].set_ylabel("Predicted F") + axes[i].set_title(f"bin {bin_index} (r={correlation:.3f}") for j in range(i + 1, len(axes)): fig.delaxes(axes[j]) - - pl_module.logger.experiment.log({ - "Correlation/structure_factors_binned": wandb.Image(fig), - # "Correlation/correlation_coefficient": correlation, - "epoch": trainer.current_epoch - }) + pl_module.logger.experiment.log( + { + "Correlation/structure_factors_binned": wandb.Image(fig), + # "Correlation/correlation_coefficient": correlation, + "epoch": trainer.current_epoch, + } + ) plt.close(fig) print("logged image") - + except Exception as e: print(f"Error in correlation plotting: {str(e)}") import traceback + traceback.print_exc() + class CorrelationPlotting(Callback): - + def __init__(self, dataloader): super().__init__() self.dials_reference = None self.fixed_batch = None self.dataloader = dataloader - + def on_fit_start(self, trainer, pl_module): print("on fit start callback") - self.dials_reference = rs.read_mtz(pl_module.model_settings.merged_mtz_file_path) + self.dials_reference = rs.read_mtz( + pl_module.model_settings.merged_mtz_file_path + ) print("Read mtz successfully: ", self.dials_reference["F"].shape) print("head hkl", self.dials_reference.get_hkls().shape) - - def on_train_epoch_end(self, trainer, pl_module): print("on train epoch end in callback") try: # shoeboxes_batch, photon_rate_output, hkl_batch = pl_module(self.fixed_batch, log_images=True) - + samples_surrogate_posterior = pl_module.surrogate_posterior.mean - print("full samples", samples_surrogate_posterior.shape) + print("full samples", samples_surrogate_posterior.shape) # print("gathered ", photon_rate_output['samples_predicted_structure_factor'].shape) rasu_ids = 
pl_module.surrogate_posterior.rac.rasu_ids[0] - print("rasu id (all teh same as the first one?):",pl_module.surrogate_posterior.rac.rasu_ids) - - if hasattr(pl_module.surrogate_posterior, 'reliable_observations_mask'): - reliable_mask = pl_module.surrogate_posterior.reliable_observations_mask(min_observations=50) + print( + "rasu id (all teh same as the first one?):", + pl_module.surrogate_posterior.rac.rasu_ids, + ) + + if hasattr(pl_module.surrogate_posterior, "reliable_observations_mask"): + reliable_mask = ( + pl_module.surrogate_posterior.reliable_observations_mask( + min_observations=50 + ) + ) print("take reliable mask for correlation plotting") - + else: reliable_mask = pl_module.surrogate_posterior.observed print("observed attribute", pl_module.surrogate_posterior.observed.sum()) - mask_observed_indexed_as_dials_reference = pl_module.surrogate_posterior.rac.gather( - source=pl_module.surrogate_posterior.observed & reliable_mask, rasu_id=rasu_ids, H=self.dials_reference.get_hkls() + mask_observed_indexed_as_dials_reference = ( + pl_module.surrogate_posterior.rac.gather( + source=pl_module.surrogate_posterior.observed & reliable_mask, + rasu_id=rasu_ids, + H=self.dials_reference.get_hkls(), + ) ) ordered_samples = pl_module.surrogate_posterior.rac.gather( # gather wants source to be (rac_size, ) - source=samples_surrogate_posterior.T, rasu_id=rasu_ids, H=self.dials_reference.get_hkls() - ) #(reflections_from_mtz, number_of_samples) + source=samples_surrogate_posterior.T, + rasu_id=rasu_ids, + H=self.dials_reference.get_hkls(), + ) # (reflections_from_mtz, number_of_samples) print("get_hkl from dials", self.dials_reference.get_hkls().shape) print("dials ref", self.dials_reference["F"].shape) print("ordered samples", ordered_samples.shape) # mask out nonobserved reflections - print("sum of (reindexed) observed reflections", mask_observed_indexed_as_dials_reference.sum()) - _masked_dials_reference = self.dials_reference[mask_observed_indexed_as_dials_reference.cpu().detach().numpy()] + print( + "sum of (reindexed) observed reflections", + mask_observed_indexed_as_dials_reference.sum(), + ) + _masked_dials_reference = self.dials_reference[ + mask_observed_indexed_as_dials_reference.cpu().detach().numpy() + ] print("masked dials reference", _masked_dials_reference["F"].shape) - _masked_ordered_samples = ordered_samples[mask_observed_indexed_as_dials_reference] + _masked_ordered_samples = ordered_samples[ + mask_observed_indexed_as_dials_reference + ] print("masked ordered samples", _masked_ordered_samples.shape) if _masked_ordered_samples.numel() == 0: @@ -437,99 +577,134 @@ def on_train_epoch_end(self, trainer, pl_module): print("Creating correlation plot...") reference_to_plot = _masked_dials_reference["F"].squeeze().to_numpy() model_to_plot = _masked_ordered_samples.squeeze().cpu().detach().numpy() - + # Remove any NaN or Inf values - valid_mask = ~(np.isnan(reference_to_plot) | np.isnan(model_to_plot) | - np.isinf(reference_to_plot) | np.isinf(model_to_plot)) - + valid_mask = ~( + np.isnan(reference_to_plot) + | np.isnan(model_to_plot) + | np.isinf(reference_to_plot) + | np.isinf(model_to_plot) + ) + if not np.any(valid_mask): print("Warning: No valid data points after filtering") return - + reference_to_plot = reference_to_plot[valid_mask] model_to_plot = model_to_plot[valid_mask] - + print(f"After filtering - valid data points: {len(reference_to_plot)}") - + fig, axes = plt.subplots(figsize=(10, 10)) - scatter = axes.scatter(reference_to_plot, - model_to_plot, - marker='x', s=10, 
alpha=0.5, c="black") - + scatter = axes.scatter( + reference_to_plot, model_to_plot, marker="x", s=10, alpha=0.5, c="black" + ) + # Fit a straight line through the origin using numpy polyfit # For line through origin, we fit y = mx (no intercept) try: - fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0] + fitted_slope = np.polyfit( + reference_to_plot, model_to_plot, 1, w=None, cov=False + )[0] print(f"Polyfit successful, slope: {fitted_slope:.6f}") except np.linalg.LinAlgError as e: print(f"Polyfit failed with LinAlgError: {e}") print("Falling back to manual calculation") # Fallback to manual calculation for line through origin - fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum( + reference_to_plot**2 + ) print(f"Fallback slope: {fitted_slope:.6f}") except Exception as e: print(f"Polyfit failed with other error: {e}") # Fallback to manual calculation - fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum( + reference_to_plot**2 + ) print(f"Fallback slope: {fitted_slope:.6f}") - + # Plot the fitted line x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100) y_fitted = fitted_slope * x_range - axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x') - + axes.plot( + x_range, + y_fitted, + "r-", + linewidth=2, + label=f"Fitted line: y = {fitted_slope:.4f}x", + ) + # Set log-log scale - axes.set_xscale('log') - axes.set_yscale('log') - - correlation = np.corrcoef(reference_to_plot, - model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1] + axes.set_xscale("log") + axes.set_yscale("log") + + correlation = np.corrcoef( + reference_to_plot, + model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot), + )[0, 1] print("reference_to_plot", reference_to_plot) - print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot)) + print( + "model to plot scaled", + model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot), + ) print(f"Fitted slope: {fitted_slope:.4f}") - axes.set_xlabel('Dials Reference F') - axes.set_ylabel('Predicted F') - axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})') + axes.set_xlabel("Dials Reference F") + axes.set_ylabel("Predicted F") + axes.set_title( + f"Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})" + ) axes.legend() - + # Add diagonal line # min_val = min(axes.get_xlim()[0], axes.get_ylim()[0]) # max_val = max(axes.get_xlim()[1], axes.get_ylim()[1]) # axes.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5) - pl_module.logger.experiment.log({ - "Correlation/structure_factors": wandb.Image(fig), - "Correlation/correlation_coefficient": correlation, - "epoch": trainer.current_epoch - }) + pl_module.logger.experiment.log( + { + "Correlation/structure_factors": wandb.Image(fig), + "Correlation/correlation_coefficient": correlation, + "epoch": trainer.current_epoch, + } + ) plt.close(fig) print("logged image") - + except Exception as e: print(f"Error in correlation plotting: {str(e)}") import traceback + traceback.print_exc() + class ScalePlotting(Callback): - + def __init__(self, dataloader): super().__init__() self.fixed_batch = None - self.scale_history = { - 'loc': [], - 'scale': [] - } + self.scale_history = {"loc": [], "scale": []} 
self.dataloader = dataloader - + def on_fit_start(self, trainer, pl_module): - _raw_batch = next(iter(self.dataloader.load_data_for_logging_during_training(number_of_shoeboxes_to_log=2000))) + _raw_batch = next( + iter( + self.dataloader.load_data_for_logging_during_training( + number_of_shoeboxes_to_log=2000 + ) + ) + ) _raw_batch = [_batch_item.to(pl_module.device) for _batch_item in _raw_batch] self.fixed_batch = tuple(_raw_batch) - self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[:, 13] - self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to(pl_module.device)[:, 9] - self.fixed_batch_total_counts = self.fixed_batch[3].to(pl_module.device).sum(dim=1) - + self.fixed_batch_background_mean = self.fixed_batch[1].to(pl_module.device)[ + :, 13 + ] + self.fixed_batch_intensity_sum_values = self.fixed_batch[1].to( + pl_module.device + )[:, 9] + self.fixed_batch_total_counts = ( + self.fixed_batch[3].to(pl_module.device).sum(dim=1) + ) def on_train_epoch_end(self, trainer, pl_module): with torch.no_grad(): @@ -537,103 +712,175 @@ def on_train_epoch_end(self, trainer, pl_module): # Memory optimization: Process in smaller chunks chunk_size = 20 # Process 20 shoeboxes at a time batch_size = self.fixed_batch[0].size(0) - + all_scale_means = [] all_background_means = [] - + for i in range(0, batch_size, chunk_size): end_idx = min(i + chunk_size, batch_size) - + # Create chunk batch - chunk_batch = tuple(tensor[i:end_idx] for tensor in self.fixed_batch) - - # Process chunk - shoebox_representation, metadata_representation, image_representation, shoebox_profile_representation = pl_module._batch_to_representations(batch=chunk_batch) + chunk_batch = tuple( + tensor[i:end_idx] for tensor in self.fixed_batch + ) - scale_per_reflection = pl_module.compute_scale(representations=[image_representation, metadata_representation]) - scale_mean_per_reflection = scale_per_reflection.mean.cpu().detach().numpy() + # Process chunk + ( + shoebox_representation, + metadata_representation, + image_representation, + shoebox_profile_representation, + ) = pl_module._batch_to_representations(batch=chunk_batch) + + scale_per_reflection = pl_module.compute_scale( + representations=[image_representation, metadata_representation] + ) + scale_mean_per_reflection = ( + scale_per_reflection.mean.cpu().detach().numpy() + ) all_scale_means.append(scale_mean_per_reflection) - - background_mean_per_reflection = pl_module.compute_background_distribution(shoebox_representation=shoebox_representation).mean.cpu().detach().numpy() + + background_mean_per_reflection = ( + pl_module.compute_background_distribution( + shoebox_representation=shoebox_representation + ) + .mean.cpu() + .detach() + .numpy() + ) all_background_means.append(background_mean_per_reflection) - + # Clear intermediate tensors - del shoebox_representation, metadata_representation, image_representation, scale_per_reflection - + del ( + shoebox_representation, + metadata_representation, + image_representation, + scale_per_reflection, + ) + # Combine results scale_mean_per_reflection = np.concatenate(all_scale_means) background_mean_per_reflection = np.concatenate(all_background_means) - + print(len(scale_mean_per_reflection)) - - plt.figure(figsize=(8,5)) - plt.hist(scale_mean_per_reflection, bins=100, edgecolor="black", alpha=0.6, label="Scale_mean") + + plt.figure(figsize=(8, 5)) + plt.hist( + scale_mean_per_reflection, + bins=100, + edgecolor="black", + alpha=0.6, + label="Scale_mean", + ) plt.xlabel("mean scale per reflection") 
plt.ylabel("Number of Observations") plt.title("Histogram of Per‐Reflection Scale Factors-Model") plt.tight_layout() - pl_module.logger.experiment.log({"scale_histogram": wandb.Image(plt.gcf())}) + pl_module.logger.experiment.log( + {"scale_histogram": wandb.Image(plt.gcf())} + ) plt.close() - print("###################### std bkg #########: ",np.std(background_mean_per_reflection)) - plt.figure(figsize=(8,5)) - plt.hist(background_mean_per_reflection, bins=100, edgecolor="black", alpha=0.6, label="Background_mean") + print( + "###################### std bkg #########: ", + np.std(background_mean_per_reflection), + ) + plt.figure(figsize=(8, 5)) + plt.hist( + background_mean_per_reflection, + bins=100, + edgecolor="black", + alpha=0.6, + label="Background_mean", + ) plt.xlabel("mean background per reflection") plt.ylabel("Number of Observations") plt.title("Histogram of Per‐Reflection Background -Model") plt.tight_layout() - pl_module.logger.experiment.log({"background_histogram": wandb.Image(plt.gcf())}) + pl_module.logger.experiment.log( + {"background_histogram": wandb.Image(plt.gcf())} + ) plt.close() - reference_to_plot = self.fixed_batch_background_mean.cpu().detach().numpy() + reference_to_plot = ( + self.fixed_batch_background_mean.cpu().detach().numpy() + ) model_to_plot = background_mean_per_reflection.squeeze() print("bkg", reference_to_plot.shape, model_to_plot.shape) fig, axes = plt.subplots(figsize=(10, 10)) - scatter = axes.scatter(reference_to_plot, - model_to_plot, - marker='x', s=10, alpha=0.5, c="black") + scatter = axes.scatter( + reference_to_plot, + model_to_plot, + marker="x", + s=10, + alpha=0.5, + c="black", + ) try: - fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0] + fitted_slope = np.polyfit( + reference_to_plot, model_to_plot, 1, w=None, cov=False + )[0] print(f"Polyfit successful, slope: {fitted_slope:.6f}") except np.linalg.LinAlgError as e: print(f"Polyfit failed with LinAlgError: {e}") print("Falling back to manual calculation") # Fallback to manual calculation for line through origin - fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum( + reference_to_plot**2 + ) print(f"Fallback slope: {fitted_slope:.6f}") except Exception as e: print(f"Polyfit failed with other error: {e}") # Fallback to manual calculation - fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum( + reference_to_plot**2 + ) print(f"Fallback slope: {fitted_slope:.6f}") - + # Plot the fitted line - x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100) + x_range = np.linspace( + reference_to_plot.min(), reference_to_plot.max(), 100 + ) y_fitted = fitted_slope * x_range - axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x') - + axes.plot( + x_range, + y_fitted, + "r-", + linewidth=2, + label=f"Fitted line: y = {fitted_slope:.4f}x", + ) + # Set log-log scale - axes.set_xscale('log') - axes.set_yscale('log') - - correlation = np.corrcoef(reference_to_plot, - model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1] + axes.set_xscale("log") + axes.set_yscale("log") + + correlation = np.corrcoef( + reference_to_plot, + model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot), + )[0, 1] print("reference_to_plot", reference_to_plot) - print("model to plot scaled", 
model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot)) + print( + "model to plot scaled", + model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot), + ) print(f"Fitted slope: {fitted_slope:.4f}") - axes.set_xlabel('Dials Reference Bkg') - axes.set_ylabel('Predicted Bkg') - axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})') + axes.set_xlabel("Dials Reference Bkg") + axes.set_ylabel("Predicted Bkg") + axes.set_title( + f"Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})" + ) axes.legend() - - pl_module.logger.experiment.log({ - "Correlation/Background": wandb.Image(fig), - "Correlation/background_correlation_coefficient": correlation, - "epoch": trainer.current_epoch - }) + + pl_module.logger.experiment.log( + { + "Correlation/Background": wandb.Image(fig), + "Correlation/background_correlation_coefficient": correlation, + "epoch": trainer.current_epoch, + } + ) plt.close(fig) print("logged image") @@ -641,97 +888,164 @@ def on_train_epoch_end(self, trainer, pl_module): print(f"failed ScalePlotting/Bkg with: {e}") # Memory optimization: Clear large tensors and force cleanup - del all_scale_means, all_background_means, scale_mean_per_reflection, background_mean_per_reflection + del ( + all_scale_means, + all_background_means, + scale_mean_per_reflection, + background_mean_per_reflection, + ) torch.cuda.empty_cache() ################# Plot Intensity Correlation ############################## try: # Memory optimization: Process intensity correlation in chunks without verbose output all_model_intensities = [] - + for i in range(0, batch_size, chunk_size): end_idx = min(i + chunk_size, batch_size) - chunk_batch = tuple(tensor[i:end_idx] for tensor in self.fixed_batch) - + chunk_batch = tuple( + tensor[i:end_idx] for tensor in self.fixed_batch + ) + # Get representations for this chunk - shoebox_representation, metadata_representation, image_representation,shoebox_profile_representation = pl_module._batch_to_representations(batch=chunk_batch) - + ( + shoebox_representation, + metadata_representation, + image_representation, + shoebox_profile_representation, + ) = pl_module._batch_to_representations(batch=chunk_batch) + # Compute scale and structure factors for this chunk - scale_distribution = pl_module.compute_scale(representations=[image_representation, metadata_representation]) - scale_samples = scale_distribution.rsample([pl_module.model_settings.number_of_mc_samples]).permute(1, 0, 2) - + scale_distribution = pl_module.compute_scale( + representations=[image_representation, metadata_representation] + ) + scale_samples = scale_distribution.rsample( + [pl_module.model_settings.number_of_mc_samples] + ).permute(1, 0, 2) + # Get structure factor samples for this chunk hkl_chunk = chunk_batch[4] # hkl batch - samples_surrogate_posterior = pl_module.surrogate_posterior.rsample([pl_module.model_settings.number_of_mc_samples]) + samples_surrogate_posterior = pl_module.surrogate_posterior.rsample( + [pl_module.model_settings.number_of_mc_samples] + ) rasu_ids = pl_module.surrogate_posterior.rac.rasu_ids[0] - samples_predicted_structure_factor = pl_module.surrogate_posterior.rac.gather( - source=samples_surrogate_posterior.T, rasu_id=rasu_ids, H=hkl_chunk - ).unsqueeze(-1) - + samples_predicted_structure_factor = ( + pl_module.surrogate_posterior.rac.gather( + source=samples_surrogate_posterior.T, + rasu_id=rasu_ids, + H=hkl_chunk, + ).unsqueeze(-1) + ) + # Compute model intensity for this chunk - 
model_intensity_chunk = scale_samples * torch.square(samples_predicted_structure_factor) - model_intensity_chunk = torch.mean(model_intensity_chunk, dim=1).squeeze() - all_model_intensities.append(model_intensity_chunk.cpu().detach().numpy()) - + model_intensity_chunk = scale_samples * torch.square( + samples_predicted_structure_factor + ) + model_intensity_chunk = torch.mean( + model_intensity_chunk, dim=1 + ).squeeze() + all_model_intensities.append( + model_intensity_chunk.cpu().detach().numpy() + ) + # Clear intermediate tensors - del shoebox_representation, metadata_representation, image_representation, scale_distribution, scale_samples - del samples_surrogate_posterior, samples_predicted_structure_factor, model_intensity_chunk - + del ( + shoebox_representation, + metadata_representation, + image_representation, + scale_distribution, + scale_samples, + ) + del ( + samples_surrogate_posterior, + samples_predicted_structure_factor, + model_intensity_chunk, + ) + # Combine results model_to_plot = np.concatenate(all_model_intensities) - reference_to_plot = self.fixed_batch_intensity_sum_values.cpu().detach().numpy() - + reference_to_plot = ( + self.fixed_batch_intensity_sum_values.cpu().detach().numpy() + ) + fig, axes = plt.subplots(figsize=(10, 10)) - scatter = axes.scatter(reference_to_plot, - model_to_plot, - marker='x', s=10, alpha=0.5, c="black") + scatter = axes.scatter( + reference_to_plot, + model_to_plot, + marker="x", + s=10, + alpha=0.5, + c="black", + ) try: - fitted_slope = np.polyfit(reference_to_plot, model_to_plot, 1, w=None, cov=False)[0] + fitted_slope = np.polyfit( + reference_to_plot, model_to_plot, 1, w=None, cov=False + )[0] print(f"Polyfit successful, slope: {fitted_slope:.6f}") except np.linalg.LinAlgError as e: print(f"Polyfit failed with LinAlgError: {e}") print("Falling back to manual calculation") # Fallback to manual calculation for line through origin - fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum( + reference_to_plot**2 + ) print(f"Fallback slope: {fitted_slope:.6f}") except Exception as e: print(f"Polyfit failed with other error: {e}") # Fallback to manual calculation - fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) + fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum( + reference_to_plot**2 + ) print(f"Fallback slope: {fitted_slope:.6f}") - + # Plot the fitted line - x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100) + x_range = np.linspace( + reference_to_plot.min(), reference_to_plot.max(), 100 + ) y_fitted = fitted_slope * x_range - axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x') - + axes.plot( + x_range, + y_fitted, + "r-", + linewidth=2, + label=f"Fitted line: y = {fitted_slope:.4f}x", + ) + # Set log-log scale - axes.set_xscale('log') - axes.set_yscale('log') - - correlation = np.corrcoef(reference_to_plot, - model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1] + axes.set_xscale("log") + axes.set_yscale("log") + + correlation = np.corrcoef( + reference_to_plot, + model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot), + )[0, 1] print("reference_to_plot", reference_to_plot) - print("model to plot scaled", model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot)) + print( + "model to plot scaled", + model_to_plot * np.max(reference_to_plot) / np.max(model_to_plot), + ) print(f"Fitted 
slope: {fitted_slope:.4f}") - axes.set_xlabel('Dials Reference Intensity') - axes.set_ylabel('Predicted Intensity') - axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})') + axes.set_xlabel("Dials Reference Intensity") + axes.set_ylabel("Predicted Intensity") + axes.set_title( + f"Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})" + ) axes.legend() - - pl_module.logger.experiment.log({ - "Correlation/Intensity": wandb.Image(fig), - "Correlation/intensity_correlation_coefficient": correlation, - "epoch": trainer.current_epoch - }) + + pl_module.logger.experiment.log( + { + "Correlation/Intensity": wandb.Image(fig), + "Correlation/intensity_correlation_coefficient": correlation, + "epoch": trainer.current_epoch, + } + ) plt.close(fig) print("logged image") except Exception as e: print(f"failed ScalePlotting/Bkg with {e}") - # ################# Plot Count Correlation ############################## # try: # # Extract values from dictionary @@ -740,16 +1054,15 @@ def on_train_epoch_end(self, trainer, pl_module): # print("samples_photon_rate (batch, mc, pix)", samples_photon_rate.shape) # samples_photon_rate = samples_photon_rate.sum(dim=-1) # print("samples_photon_rate (batch, mc)", samples_photon_rate.shape) - + # photon_rate_per_sb = torch.mean(samples_photon_rate, dim=1).squeeze() #(batch_size, 1) # print("photon_rate (batch)", photon_rate_per_sb.shape) - # reference_to_plot = self.fixed_batch_total_counts.cpu().detach().numpy() # model_to_plot = photon_rate_per_sb.cpu().detach().numpy() # fig, axes = plt.subplots(figsize=(10, 10)) - # scatter = axes.scatter(reference_to_plot, - # model_to_plot, + # scatter = axes.scatter(reference_to_plot, + # model_to_plot, # marker='x', s=10, alpha=0.5, c="black") # try: @@ -766,23 +1079,23 @@ def on_train_epoch_end(self, trainer, pl_module): # # Fallback to manual calculation # fitted_slope = np.sum(reference_to_plot * model_to_plot) / np.sum(reference_to_plot**2) # print(f"Fallback slope: {fitted_slope:.6f}") - + # # Plot the fitted line # x_range = np.linspace(reference_to_plot.min(), reference_to_plot.max(), 100) # y_fitted = fitted_slope * x_range # axes.plot(x_range, y_fitted, 'r-', linewidth=2, label=f'Fitted line: y = {fitted_slope:.4f}x') - + # # Set log-log scale # axes.set_xscale('log') # axes.set_yscale('log') - - # correlation = np.corrcoef(reference_to_plot, + + # correlation = np.corrcoef(reference_to_plot, # model_to_plot*np.max(reference_to_plot)/np.max(model_to_plot))[0,1] # axes.set_xlabel('Dials Reference Intensity') # axes.set_ylabel('Predicted Intensity') # axes.set_title(f'Correlation Plot (Log-Log) (r={correlation:.3f}, slope={fitted_slope:.4f})') # axes.legend() - + # pl_module.logger.experiment.log({ # "Correlation/Counts": wandb.Image(fig), # "Correlation/counts_correlation_coefficient": correlation, @@ -794,14 +1107,13 @@ def on_train_epoch_end(self, trainer, pl_module): # except Exception as e: # print(f"failed corr counts with {e}") - - def on_train_end(self, trainer, pl_module): unmerged_mtz = rs.read_mtz(pl_module.model_settings.unmerged_mtz_file_path) scale_dials = unmerged_mtz["SCALEUSED"] - pl_module.logger.experiment.log({ + pl_module.logger.experiment.log( + { "scale/mean_dials": scale_dials.mean(), "scale/max_dials": scale_dials.max(), - "scale/min_dials": scale_dials.min() - }) - + "scale/min_dials": scale_dials.min(), + } + ) diff --git a/src/factory/data_loader.py b/src/factory/data_loader.py index fea43cd..0a00b1e 100644 --- 
a/src/factory/data_loader.py +++ b/src/factory/data_loader.py @@ -1,21 +1,20 @@ """To load shoeboxes batched by image. -Call with +Call with python data_loader.py --data_directory . """ -import torch -import os -import pytorch_lightning as pl -from torch.utils.data import DataLoader, TensorDataset, Subset, Dataset -from torch.utils.data import Sampler +import argparse +import dataclasses import itertools +import os import random -import dataclasses -import argparse -import settings +import pytorch_lightning as pl +import settings +import torch +from torch.utils.data import DataLoader, Dataset, Sampler, Subset, TensorDataset class ShoeboxTensorDataset(Dataset): @@ -45,42 +44,50 @@ def __getitem__(self, idx): ) -class CrystallographicDataLoader(): +class CrystallographicDataLoader: full_data_set: torch.utils.data.dataset.Subset train_data_set: torch.utils.data.dataset.Subset validation_data_set: torch.utils.data.dataset.Subset test_data_set: torch.utils.data.dataset.Subset data_loader_settings: settings.DataLoaderSettings - def __init__( - self, - data_loader_settings: settings.DataLoaderSettings - ): + def __init__(self, data_loader_settings: settings.DataLoaderSettings): self.data_loader_settings = data_loader_settings self.use_standard_sampler = False def _get_raw_shoebox_data_(self): - + # counts = torch.load( # os.path.join(self.settings.data_directory, self.settings.data_file_names["counts"]), weights_only=True # ) # print("counts shape:", counts.shape) dead_pixel_mask = torch.load( - os.path.join(self.data_loader_settings.data_directory, self.data_loader_settings.data_file_names["masks"]), weights_only=True + os.path.join( + self.data_loader_settings.data_directory, + self.data_loader_settings.data_file_names["masks"], + ), + weights_only=True, ) print("dead_pixel_mask shape:", dead_pixel_mask.shape) shoeboxes = torch.load( - os.path.join(self.data_loader_settings.data_directory, self.data_loader_settings.data_file_names["counts"]), weights_only=True + os.path.join( + self.data_loader_settings.data_directory, + self.data_loader_settings.data_file_names["counts"], + ), + weights_only=True, ) # shoeboxes = shoeboxes[:,:,-1] # if len(shoeboxes[0]) != 7: # self.use_standard_sampler = True - + counts = shoeboxes.clone() print("standard sampler:", self.use_standard_sampler) - stats = torch.load(os.path.join(self.data_loader_settings.data_directory,"stats.pt"), weights_only=True) + stats = torch.load( + os.path.join(self.data_loader_settings.data_directory, "stats.pt"), + weights_only=True, + ) mean = stats[0] var = stats[1] @@ -100,20 +107,21 @@ def _get_raw_shoebox_data_(self): # ) # print("dials refercne shape", dials_reference.shape) - metadata = torch.load( - os.path.join(self.data_loader_settings.data_directory, self.data_loader_settings.data_file_names["metadata"]), weights_only=True + os.path.join( + self.data_loader_settings.data_directory, + self.data_loader_settings.data_file_names["metadata"], + ), + weights_only=True, ) - - - print("Metadata shape:", metadata.shape) #(d, h,k, l ,x, y, z) + print("Metadata shape:", metadata.shape) # (d, h,k, l ,x, y, z) # metadata = torch.zeros(shoeboxes.shape[0], 7) # hkl = metadata[:,1:4].to(torch.int) # print("hkl shape", hkl.shape) - hkl = metadata[:,6:9].to(torch.int) + hkl = metadata[:, 6:9].to(torch.int) print("hkl shape", hkl.shape) # Use custom dataset for on-the-fly normalization @@ -126,81 +134,103 @@ def _get_raw_shoebox_data_(self): mean=mean, var=var, ) - print("Metadata shape full tensor:", self.full_data_set.metadata.shape) #(d, 
h,k, l ,x, y, z) + print( + "Metadata shape full tensor:", self.full_data_set.metadata.shape + ) # (d, h,k, l ,x, y, z) # print("concentration shape", torch.load( # os.path.join(self.settings.data_directory, "concentration.pt"), weights_only=True # ).shape) def append_image_id_to_metadata_(self) -> None: - image_ids = self._get_image_ids_from_shoeboxes( - shoebox_data_set=self.full_data_set.shoeboxes - ) - data_as_list = list(self.full_data_set.tensors) - data_as_list[1] = torch.cat( - (data_as_list[1], torch.tensor(image_ids).unsqueeze(1)), dim=1 - ) - self.full_data_set.tensors = tuple(data_as_list) - + image_ids = self._get_image_ids_from_shoeboxes( + shoebox_data_set=self.full_data_set.shoeboxes + ) + data_as_list = list(self.full_data_set.tensors) + data_as_list[1] = torch.cat( + (data_as_list[1], torch.tensor(image_ids).unsqueeze(1)), dim=1 + ) + self.full_data_set.tensors = tuple(data_as_list) + # def _cut_metadata(self, metadata: torch.Tensor) -> torch.Tensor: # # indices_to_keep = torch.tensor([self.settings.metadata_indices[i] for i in self.settings.metadata_keys_to_keep], dtype=torch.long) # indices_to_keep = torch.tensor(0,1) # return torch.index_select(metadata, dim=1, index=indices_to_keep) - - def _clean_shoeboxes_(self, shoeboxes: torch.Tensor, dead_pixel_mask, counts, metadata): - shoebox_mask = (shoeboxes[..., -1].sum(dim=1) < 150000) - return (shoeboxes[shoebox_mask], dead_pixel_mask[shoebox_mask], counts[shoebox_mask], metadata[shoebox_mask]) + def _clean_shoeboxes_( + self, shoeboxes: torch.Tensor, dead_pixel_mask, counts, metadata + ): + shoebox_mask = shoeboxes[..., -1].sum(dim=1) < 150000 + return ( + shoeboxes[shoebox_mask], + dead_pixel_mask[shoebox_mask], + counts[shoebox_mask], + metadata[shoebox_mask], + ) def _split_full_data_(self) -> None: full_data_set_length = len(self.full_data_set) validation_data_set_length = int( full_data_set_length * self.data_loader_settings.validation_set_split ) - test_data_set_length = int(full_data_set_length * self.data_loader_settings.test_set_split) + test_data_set_length = int( + full_data_set_length * self.data_loader_settings.test_set_split + ) train_data_set_length = ( full_data_set_length - validation_data_set_length - test_data_set_length ) - self.train_data_set, self.validation_data_set, self.test_data_set = torch.utils.data.random_split( - self.full_data_set, - [train_data_set_length, validation_data_set_length, test_data_set_length], - generator=torch.Generator().manual_seed(42) + self.train_data_set, self.validation_data_set, self.test_data_set = ( + torch.utils.data.random_split( + self.full_data_set, + [ + train_data_set_length, + validation_data_set_length, + test_data_set_length, + ], + generator=torch.Generator().manual_seed(42), + ) ) def load_data_(self) -> None: self._get_raw_shoebox_data_() - if not self.use_standard_sampler and self.data_loader_settings.append_image_id_to_metadata: + if ( + not self.use_standard_sampler + and self.data_loader_settings.append_image_id_to_metadata + ): self.append_image_id_to_metadata_() # self._clean_data_() self._split_full_data_() def _get_image_ids_from_shoeboxes(self, shoebox_data_set: torch.Tensor) -> list: - """Returns the list of respective image ids for - all shoeboxes in the shoebox data set.""" - + """Returns the list of respective image ids for + all shoeboxes in the shoebox data set.""" + image_ids = [] for shoebox in shoebox_data_set: - minimum_dz_index = torch.argmin(shoebox[:, 5]) # 5th index is dz - image_ids.append(shoebox[minimum_dz_index, 2].item()) # 2nd 
index is z - + minimum_dz_index = torch.argmin(shoebox[:, 5]) # 5th index is dz + image_ids.append(shoebox[minimum_dz_index, 2].item()) # 2nd index is z + if len(image_ids) != len(shoebox_data_set): print(len(image_ids), len(shoebox_data_set)) raise ValueError( f"The number of shoeboxes {len(shoebox_data_set)} does not match the number of image ids {len(image_ids)}." ) return image_ids - - def _map_images_to_shoeboxes(self, shoebox_data_set: torch.Tensor, metadata:torch.Tensor) -> dict: - """Returns a dictionary with image ids as keys and indices of all shoeboxes - belonging to that image as values.""" - + + def _map_images_to_shoeboxes( + self, shoebox_data_set: torch.Tensor, metadata: torch.Tensor + ) -> dict: + """Returns a dictionary with image ids as keys and indices of all shoeboxes + belonging to that image as values.""" + # image_ids = self._get_image_ids_from_shoeboxes( # shoebox_data_set=shoebox_data_set # ) print("metadata shape", metadata.shape) - image_ids = metadata[:,2].round().to(torch.int64) # Convert to integer type + image_ids = metadata[:, 2].round().to(torch.int64) # Convert to integer type print("image ids", image_ids) import numpy as np + unique_ids = torch.unique(image_ids) # print("Number of unique image IDs:", len(unique_ids)) # print("First few unique image IDs:", unique_ids[:10]) @@ -209,20 +239,23 @@ def _map_images_to_shoeboxes(self, shoebox_data_set: torch.Tensor, metadata:torc # print("Max image ID:", image_ids.max().item()) # print("Number of negative values:", (image_ids < 0).sum().item()) # print("Number of zero values:", (image_ids == 0).sum().item()) - + images_to_shoebox_indices = {} for shoebox_index, image_id in enumerate(image_ids): image_id_int = image_id.item() # Convert tensor to Python int if image_id_int not in images_to_shoebox_indices: images_to_shoebox_indices[image_id_int] = [] images_to_shoebox_indices[image_id_int].append(shoebox_index) - + # print("Number of keys in dictionary:", len(images_to_shoebox_indices)) return images_to_shoebox_indices - class BatchByImageSampler(Sampler): - def __init__(self, image_id_to_indices: dict, data_loader_settings: settings.DataLoaderSettings): + def __init__( + self, + image_id_to_indices: dict, + data_loader_settings: settings.DataLoaderSettings, + ): self.data_loader_settings = data_loader_settings # each element is a list of all shoebox-indices for one image self.image_indices_list = list(image_id_to_indices.values()) @@ -240,7 +273,7 @@ def __iter__(self): batch_count = 0 image_number = 0 for image in images: - image_number +=1 + image_number += 1 for shoebox in image: batch.append(shoebox) if len(batch) == self.batch_size: @@ -250,14 +283,17 @@ def __iter__(self): batch = [] batch_count += 1 if batch_count >= self.num_batches: - return + return def __len__(self): return self.num_batches - def load_data_set_batched_by_image(self, - data_set_to_load: torch.utils.data.dataset.Subset | torch.utils.data.TensorDataset, - ) -> torch.utils.data.dataloader.DataLoader: + def load_data_set_batched_by_image( + self, + data_set_to_load: ( + torch.utils.data.dataset.Subset | torch.utils.data.TensorDataset + ), + ) -> torch.utils.data.dataloader.DataLoader: if self.use_standard_sampler: return DataLoader( data_set_to_load, @@ -268,11 +304,14 @@ def load_data_set_batched_by_image(self, ) if isinstance(data_set_to_load, torch.utils.data.dataset.Subset): - print("Metadata shape full tensor in loadeing:", self.full_data_set.metadata.shape) #(d, h,k, l ,x, y, z) + print( + "Metadata shape full tensor in loadeing:", + 
self.full_data_set.metadata.shape, + ) # (d, h,k, l ,x, y, z) image_id_to_indices = self._map_images_to_shoeboxes( shoebox_data_set=self.full_data_set.shoeboxes[data_set_to_load.indices], - metadata=self.full_data_set.metadata[data_set_to_load.indices] + metadata=self.full_data_set.metadata[data_set_to_load.indices], ) else: image_id_to_indices = self._map_images_to_shoeboxes( @@ -281,8 +320,8 @@ def load_data_set_batched_by_image(self, batch_by_image_sampler = self.BatchByImageSampler( image_id_to_indices=image_id_to_indices, - data_loader_settings=self.data_loader_settings - ) + data_loader_settings=self.data_loader_settings, + ) return DataLoader( data_set_to_load, batch_sampler=batch_by_image_sampler, @@ -290,30 +329,34 @@ def load_data_set_batched_by_image(self, pin_memory=self.data_loader_settings.pin_memory, prefetch_factor=self.data_loader_settings.prefetch_factor, ) - - def load_data_for_logging_during_training(self, number_of_shoeboxes_to_log: int = 5) -> torch.utils.data.dataloader.DataLoader: + + def load_data_for_logging_during_training( + self, number_of_shoeboxes_to_log: int = 5 + ) -> torch.utils.data.dataloader.DataLoader: if self.use_standard_sampler: - subset = Subset(self.train_data_set, indices=range(number_of_shoeboxes_to_log)) + subset = Subset( + self.train_data_set, indices=range(number_of_shoeboxes_to_log) + ) return DataLoader(subset, batch_size=number_of_shoeboxes_to_log) image_id_to_indices = self._map_images_to_shoeboxes( - shoebox_data_set=self.full_data_set.shoeboxes[self.train_data_set.indices], - metadata=self.full_data_set.metadata[self.train_data_set.indices] - ) - data_loader_settings = dataclasses.replace(self.data_loader_settings, number_of_shoeboxes_per_batch=number_of_shoeboxes_to_log, number_of_batches=1) + shoebox_data_set=self.full_data_set.shoeboxes[self.train_data_set.indices], + metadata=self.full_data_set.metadata[self.train_data_set.indices], + ) + data_loader_settings = dataclasses.replace( + self.data_loader_settings, + number_of_shoeboxes_per_batch=number_of_shoeboxes_to_log, + number_of_batches=1, + ) batch_by_image_sampler = self.BatchByImageSampler( image_id_to_indices=image_id_to_indices, - data_loader_settings=data_loader_settings - ) + data_loader_settings=data_loader_settings, + ) return DataLoader( self.train_data_set, batch_sampler=batch_by_image_sampler, num_workers=self.data_loader_settings.number_of_workers, pin_memory=self.data_loader_settings.pin_memory, prefetch_factor=self.data_loader_settings.prefetch_factor, - persistent_workers=True + persistent_workers=True, ) - - - - diff --git a/src/factory/distributions.py b/src/factory/distributions.py index 2f8755d..1eb1a6d 100644 --- a/src/factory/distributions.py +++ b/src/factory/distributions.py @@ -1,26 +1,28 @@ -import torch -import torch.nn.functional as F -import networks -from networks import Linear import math - -from networks import Linear, Constraint -from torch.distributions import HalfNormal, Exponential - from abc import ABC, abstractmethod +import networks +import torch +import torch.nn.functional as F +from networks import Constraint, Linear +from torch.distributions import Exponential, HalfNormal + class ProfileDistribution(ABC, torch.nn.Module): """Base class for profile distributions that can compute KL divergence and generate profiles.""" def __init__(self): super().__init__() - + def compute_kl_divergence(self, predicted_distribution, target_distribution): - try: - return torch.distributions.kl.kl_divergence(predicted_distribution, target_distribution) + try: 
+ return torch.distributions.kl.kl_divergence( + predicted_distribution, target_distribution + ) except NotImplementedError: - print("KL divergence not implemented for this distribution: use sampling method.") + print( + "KL divergence not implemented for this distribution: use sampling method." + ) samples = predicted_distribution.rsample([100]) log_q = predicted_distribution.log_prob(samples) log_p = target_distribution.log_prob(samples) @@ -30,32 +32,32 @@ def compute_kl_divergence(self, predicted_distribution, target_distribution): def _update_prior_distributions(self): """Update prior distributions to use current device.""" pass - + @abstractmethod def kl_divergence_weighted(self, weight): """Compute KL divergence between predicted and prior distributions. - + Args: weight: Weight factor for the KL divergence """ pass - + @abstractmethod def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor: """Compute the profile from the distribution. - + Args: *representations: Variable number of representations to combine - + Returns: torch.Tensor: The computed profile normalized to sum to 1 """ pass - + @abstractmethod def forward(self, *representations): """Forward pass of the distribution. - + Args: *representations: Variable number of representations to combine """ @@ -63,23 +65,28 @@ def forward(self, *representations): def compute_kl_divergence(predicted_distribution, target_distribution): - try: - # Create proper copies of the distributions - # pred_copy = type(predicted_distribution)( - # loc=predicted_distribution.loc.clone(), - # scale=predicted_distribution.scale.clone() - # ) - # target_copy = type(target_distribution)( - # loc=target_distribution.loc.clone(), - # scale=target_distribution.scale.clone() - # ) - return torch.distributions.kl.kl_divergence(predicted_distribution, target_distribution) - except NotImplementedError: - print("KL divergence not implemented for this distribution: use sampling method.") - samples = predicted_distribution.rsample([100]) - log_q = predicted_distribution.log_prob(samples) - log_p = target_distribution.log_prob(samples) - return (log_q - log_p).mean(dim=0) + try: + # Create proper copies of the distributions + # pred_copy = type(predicted_distribution)( + # loc=predicted_distribution.loc.clone(), + # scale=predicted_distribution.scale.clone() + # ) + # target_copy = type(target_distribution)( + # loc=target_distribution.loc.clone(), + # scale=target_distribution.scale.clone() + # ) + return torch.distributions.kl.kl_divergence( + predicted_distribution, target_distribution + ) + except NotImplementedError: + print( + "KL divergence not implemented for this distribution: use sampling method." 
+ ) + samples = predicted_distribution.rsample([100]) + log_q = predicted_distribution.log_prob(samples) + log_p = target_distribution.log_prob(samples) + return (log_q - log_p).mean(dim=0) + class DirichletProfile(ProfileDistribution): """ @@ -96,21 +103,29 @@ def __init__(self, concentration_vector, dmodel, input_shape=(3, 21, 21)): # param.requires_grad = False self.dmodel = dmodel self.eps = 1e-6 - concentration_vector[concentration_vector>torch.quantile(concentration_vector, 0.99)] *= 40 + concentration_vector[ + concentration_vector > torch.quantile(concentration_vector, 0.99) + ] *= 40 concentration_vector /= concentration_vector.max() - self.register_buffer("concentration", concentration_vector) #.reshape((21,21,3))) + self.register_buffer( + "concentration", concentration_vector + ) # .reshape((21,21,3))) self.q_p = None - + def _update_prior_distributions(self): """Update prior distributions to use current device.""" self.dirichlet_prior = torch.distributions.Dirichlet(self.concentration) - + def kl_divergence_weighted(self, weight): self._update_prior_distributions() - print("shape inot kl profile", self.q_p.sample().shape, self.dirichlet_prior.sample().shape) - return self.compute_kl_divergence(self.q_p, self.dirichlet_prior)* weight - + print( + "shape inot kl profile", + self.q_p.sample().shape, + self.dirichlet_prior.sample().shape, + ) + return self.compute_kl_divergence(self.q_p, self.dirichlet_prior) * weight + def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor: dirichlet_profile = self.forward(*representations) return dirichlet_profile @@ -120,44 +135,68 @@ def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor: # return avg_profile def forward(self, *representations): - alphas = sum(representations) if len(representations) > 1 else representations[0] + alphas = ( + sum(representations) if len(representations) > 1 else representations[0] + ) print("alpha shape", alphas.shape) - + # Check for NaN values in input representations if torch.isnan(alphas).any(): print("WARNING: NaN values detected in input representations!") print("NaN count:", torch.isnan(alphas).sum().item()) - print("Input stats - min:", alphas.min().item(), "max:", alphas.max().item(), "mean:", alphas.mean().item()) - + print( + "Input stats - min:", + alphas.min().item(), + "max:", + alphas.max().item(), + "mean:", + alphas.mean().item(), + ) + if self.dmodel is not None: alphas = self.alpha_layer(alphas) # Check for NaN values after linear layer if torch.isnan(alphas).any(): print("WARNING: NaN values detected after alpha_layer!") print("NaN count:", torch.isnan(alphas).sum().item()) - print("Alpha layer output stats - min:", alphas.min().item(), "max:", alphas.max().item(), "mean:", alphas.mean().item()) - + print( + "Alpha layer output stats - min:", + alphas.min().item(), + "max:", + alphas.max().item(), + "mean:", + alphas.mean().item(), + ) + alphas = F.softplus(alphas) + self.eps print("profile alphas shape", alphas.shape) - + # Check for NaN values after softplus if torch.isnan(alphas).any(): print("WARNING: NaN values detected after softplus!") print("NaN count:", torch.isnan(alphas).sum().item()) - print("Softplus output stats - min:", alphas.min().item(), "max:", alphas.max().item(), "mean:", alphas.mean().item()) - + print( + "Softplus output stats - min:", + alphas.min().item(), + "max:", + alphas.max().item(), + "mean:", + alphas.mean().item(), + ) + # Replace NaN values with a safe default - + # Ensure all values are positive and finite - + self.q_p 
= torch.distributions.Dirichlet(alphas) print("profile q_p shape", self.q_p.rsample().shape) return self.q_p + class Distribution(torch.nn.Module): def __init__(self): super().__init__() - + def forward(self, representation): pass @@ -188,6 +227,7 @@ def forward(self, representation): norm = self.distribution(params) return norm + class ExponentialDistribution(torch.nn.Module): def __init__( self, @@ -232,7 +272,7 @@ def __init__(self, dmodel=64, number_of_mc_samples=100): self.d = 3 self.h = 21 - self.w = 21 + self.w = 21 z_coords = torch.arange(self.d).float() - (self.d - 1) / 2 y_coords = torch.arange(self.h).float() - (self.h - 1) / 2 @@ -250,11 +290,11 @@ def __init__(self, dmodel=64, number_of_mc_samples=100): self.register_buffer("pixel_positions", pixel_positions) # Initialize prior distributions - self.register_buffer('prior_mvn_mean_loc', torch.tensor(0.0)) - self.register_buffer('prior_mvn_mean_scale', torch.tensor(5.0)) - self.register_buffer('prior_mvn_cov_factor_loc', torch.tensor(0.0)) - self.register_buffer('prior_mvn_cov_factor_scale', torch.tensor(0.01)) - self.register_buffer('prior_mvn_cov_scale_scale', torch.tensor(0.01)) + self.register_buffer("prior_mvn_mean_loc", torch.tensor(0.0)) + self.register_buffer("prior_mvn_mean_scale", torch.tensor(5.0)) + self.register_buffer("prior_mvn_cov_factor_loc", torch.tensor(0.0)) + self.register_buffer("prior_mvn_cov_factor_scale", torch.tensor(0.01)) + self.register_buffer("prior_mvn_cov_scale_scale", torch.tensor(0.01)) self.mvn_mean_mean_layer = Linear( in_features=self.hidden_dim, @@ -280,17 +320,15 @@ def __init__(self, dmodel=64, number_of_mc_samples=100): def _update_prior_distributions(self): """Update prior distributions to use current device.""" self.prior_mvn_mean = torch.distributions.Normal( - loc=self.prior_mvn_mean_loc, - scale=self.prior_mvn_mean_scale + loc=self.prior_mvn_mean_loc, scale=self.prior_mvn_mean_scale ) self.prior_mvn_cov_factor = torch.distributions.Normal( - loc=self.prior_mvn_cov_factor_loc, - scale=self.prior_mvn_cov_factor_scale + loc=self.prior_mvn_cov_factor_loc, scale=self.prior_mvn_cov_factor_scale ) self.prior_mvn_cov_scale = torch.distributions.HalfNormal( scale=self.prior_mvn_cov_scale_scale ) - + def kl_divergence_weighted(self, weights): """Compute the KL divergence between the predicted and prior distributions. 
@@ -300,15 +338,22 @@ def kl_divergence_weighted(self, weights): Returns: torch.Tensor: The weighted sum of KL divergences """ - _kl_divergence = compute_kl_divergence( - self.mvn_mean_distribution, self.prior_mvn_mean - ).sum() * weights[0] - _kl_divergence += compute_kl_divergence( - self.mvn_cov_factor_distribution, self.prior_mvn_cov_factor - ).sum() * weights[1] - _kl_divergence += compute_kl_divergence( - self.mvn_cov_scale_distribution, self.prior_mvn_cov_scale - ).sum() * weights[2] + _kl_divergence = ( + compute_kl_divergence(self.mvn_mean_distribution, self.prior_mvn_mean).sum() + * weights[0] + ) + _kl_divergence += ( + compute_kl_divergence( + self.mvn_cov_factor_distribution, self.prior_mvn_cov_factor + ).sum() + * weights[1] + ) + _kl_divergence += ( + compute_kl_divergence( + self.mvn_cov_scale_distribution, self.prior_mvn_cov_scale + ).sum() + * weights[2] + ) return _kl_divergence def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor: @@ -334,26 +379,38 @@ def compute_profile(self, *representations: torch.Tensor) -> torch.Tensor: avg_profile = profile.mean(dim=1) avg_profile = avg_profile / (avg_profile.sum(dim=-1, keepdim=True) + 1e-10) - return profile_mvn_distribution #profile + return profile_mvn_distribution # profile - def forward(self, *representations: torch.Tensor) -> torch.distributions.LowRankMultivariateNormal: + def forward( + self, *representations: torch.Tensor + ) -> torch.distributions.LowRankMultivariateNormal: """Forward pass of the LRMVN distribution.""" # Update prior distributions to ensure they're on the correct device self._update_prior_distributions() - + self.batch_size = representations[0].shape[0] # Combine all representations if more than one is provided - combined_representation = sum(representations) if len(representations) > 1 else representations[0] + combined_representation = ( + sum(representations) if len(representations) > 1 else representations[0] + ) mvn_mean_mean = self.mvn_mean_mean_layer(representations[0]).unsqueeze(-1) - mvn_mean_std = F.softplus(self.mvn_mean_std_layer(representations[0])).unsqueeze(-1) + mvn_mean_std = F.softplus( + self.mvn_mean_std_layer(representations[0]) + ).unsqueeze(-1) - mvn_cov_factor_mean = self.mvn_cov_factor_mean_layer(combined_representation).unsqueeze(-1) - mvn_cov_factor_std = F.softplus(self.mvn_cov_factor_std_layer(combined_representation)).unsqueeze(-1) + mvn_cov_factor_mean = self.mvn_cov_factor_mean_layer( + combined_representation + ).unsqueeze(-1) + mvn_cov_factor_std = F.softplus( + self.mvn_cov_factor_std_layer(combined_representation) + ).unsqueeze(-1) # for the inverse gammas - mvn_cov_scale_parameter = F.softplus(self.mvn_cov_diagonal_scale_layer(representations[0])).unsqueeze(-1) + mvn_cov_scale_parameter = F.softplus( + self.mvn_cov_diagonal_scale_layer(representations[0]) + ).unsqueeze(-1) self.mvn_mean_distribution = torch.distributions.Normal( loc=mvn_mean_mean, @@ -368,16 +425,31 @@ def forward(self, *representations: torch.Tensor) -> torch.distributions.LowRank self.mvn_cov_scale_distribution = torch.distributions.half_normal.HalfNormal( scale=mvn_cov_scale_parameter, ) - mvn_mean_samples = self.mvn_mean_distribution.rsample([self.number_of_mc_samples]).squeeze(-1).permute(1, 0, 2) - mvn_cov_factor_samples = self.mvn_cov_factor_distribution.rsample([self.number_of_mc_samples]).permute(1, 0, 2, 3) + mvn_mean_samples = ( + self.mvn_mean_distribution.rsample([self.number_of_mc_samples]) + .squeeze(-1) + .permute(1, 0, 2) + ) + mvn_cov_factor_samples = 
self.mvn_cov_factor_distribution.rsample( + [self.number_of_mc_samples] + ).permute(1, 0, 2, 3) diag_samples = ( - self.mvn_cov_scale_distribution.rsample([self.number_of_mc_samples]).squeeze(-1).permute(1, 0, 2) + 1e-6 + self.mvn_cov_scale_distribution.rsample([self.number_of_mc_samples]) + .squeeze(-1) + .permute(1, 0, 2) + + 1e-6 ) self.profile_mvn_distribution = ( torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal( - loc=mvn_mean_samples.view(self.batch_size, self.number_of_mc_samples, 1, 3), - cov_factor=mvn_cov_factor_samples.view(self.batch_size, self.number_of_mc_samples, 1, 3, 1), - cov_diag=diag_samples.view(self.batch_size, self.number_of_mc_samples, 1, 3), + loc=mvn_mean_samples.view( + self.batch_size, self.number_of_mc_samples, 1, 3 + ), + cov_factor=mvn_cov_factor_samples.view( + self.batch_size, self.number_of_mc_samples, 1, 3, 1 + ), + cov_diag=diag_samples.view( + self.batch_size, self.number_of_mc_samples, 1, 3 + ), ) ) @@ -387,12 +459,12 @@ def to(self, *args, **kwargs): super().to(*args, **kwargs) # Get device from the first parameter device = next(self.parameters()).device - + # Move all registered buffers for name, buffer in self.named_buffers(): setattr(self, name, buffer.to(device)) - + # Update prior distributions self._update_prior_distributions() - - return self \ No newline at end of file + + return self diff --git a/src/factory/encoder.py b/src/factory/encoder.py index 61235a3..aa54253 100644 --- a/src/factory/encoder.py +++ b/src/factory/encoder.py @@ -1,6 +1,6 @@ # adapted from Luis Aldama @https://github.com/Hekstra-Lab/integrator.git import torch - + class CNN_3d_(torch.nn.Module): def __init__(self, Z=3, H=21, W=21, conv_channels=64, use_norm=True): @@ -59,7 +59,8 @@ def forward(self, x, mask=None): x = torch.flatten(x, 1) return x - + + class CNN_3d(torch.nn.Module): def __init__(self, out_dim=64): """ @@ -105,8 +106,8 @@ def forward(self, x, mask=None): x = x.view(x.size(0), -1) rep = F.relu(self.fc(x)) return rep - - + + class MLP(torch.nn.Module): def __init__(self, width, depth, dropout=None, output_dims=None): super().__init__() @@ -131,7 +132,8 @@ def forward(self, data): if num_pixels is not None: out = out.view(batch_size, num_pixels, -1) # Reshape back if needed return out - + + class ResidualLayer(torch.nn.Module): def __init__(self, width, dropout=None): super().__init__() @@ -177,6 +179,7 @@ def forward(self, shoebox_data): out = self.mlp_1(out) return out + def main(): # Test the model model = CNN_3d() @@ -189,4 +192,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/factory/generate_eff.py b/src/factory/generate_eff.py index 93ed724..ed173dd 100644 --- a/src/factory/generate_eff.py +++ b/src/factory/generate_eff.py @@ -66,36 +66,34 @@ }} """ + def main(): parser = argparse.ArgumentParser( description="Generate a Phenix .eff file with dynamic MTZ paths." 
) parser.add_argument( - "--sf-mtz", required=True, - help="Path to the structure-factor MTZ file" + "--sf-mtz", required=True, help="Path to the structure-factor MTZ file" ) parser.add_argument( - "--rfree-mtz", required=True, - help="Path to the R-free flags MTZ file" + "--rfree-mtz", required=True, help="Path to the R-free flags MTZ file" ) parser.add_argument( - "--out", required=True, - help="Output path for the generated .eff file" + "--out", required=True, help="Output path for the generated .eff file" ) parser.add_argument( - "--phenix-out-mtz", required=True, - help="Specific name of the output .mtz file of phenix." + "--phenix-out-mtz", + required=True, + help="Specific name of the output .mtz file of phenix.", ) args = parser.parse_args() # Fill in the template and write to the output file content = EFF_TEMPLATE.format( - sf_mtz=args.sf_mtz, - rfree_mtz=args.rfree_mtz, - phenix_out_mtz=args.phenix_out_mtz + sf_mtz=args.sf_mtz, rfree_mtz=args.rfree_mtz, phenix_out_mtz=args.phenix_out_mtz ) Path(args.out).write_text(content) print(f"Wrote .eff file to {args.out}") + if __name__ == "__main__": main() diff --git a/src/factory/get_protein_data.py b/src/factory/get_protein_data.py index 4507d8c..018f44e 100644 --- a/src/factory/get_protein_data.py +++ b/src/factory/get_protein_data.py @@ -1,6 +1,7 @@ import gemmi import requests + def get_protein_data(url: str): response = requests.get(url) if not response.ok: @@ -16,7 +17,8 @@ def get_protein_data(url: str): if dmin_str is None: raise Exception("Resolution (dmin) not found in the CIF file") dmin = float(dmin_str) - return{"unit_cell": unit_cell, "spacegroup": spacegroup, "dmin": dmin_str} + return {"unit_cell": unit_cell, "spacegroup": spacegroup, "dmin": dmin_str} + if __name__ == "__main__": data = get_protein_data("https://files.rcsb.org/download/9B7C.cif") diff --git a/src/factory/import gemmi.py b/src/factory/import gemmi.py index 17a04b8..3883956 100644 --- a/src/factory/import gemmi.py +++ b/src/factory/import gemmi.py @@ -1,6 +1,7 @@ import gemmi import pandas as pd + def read_mtz_to_data_frame(file_path: string) -> pd.DataFramr mtz = gemmi.read_mtz_file(file_path) diff --git a/src/factory/inspect_background.py b/src/factory/inspect_background.py index cbc2315..d9997a6 100644 --- a/src/factory/inspect_background.py +++ b/src/factory/inspect_background.py @@ -10,13 +10,14 @@ dials.python ../factory/src/factory/inspect_background.py """ -import sys import os -from dials.array_family import flex -import matplotlib.pyplot as plt -import wandb import statistics +import sys + +import matplotlib.pyplot as plt import numpy as np +import wandb +from dials.array_family import flex from scipy.stats import gamma @@ -45,16 +46,24 @@ def main(refl_file): print(f"Variance of background.sum.value: {variance:.3f}") plt.figure(figsize=(8, 5)) - plt.hist(bg_sum, bins=100, edgecolor='black', alpha=0.6, label='background.sum.value') - plt.xlabel('Background Sum Value') - plt.ylabel('Count') - plt.title('Histogram of background.sum.value') + plt.hist( + bg_sum, bins=100, edgecolor="black", alpha=0.6, label="background.sum.value" + ) + plt.xlabel("Background Sum Value") + plt.ylabel("Count") + plt.title("Histogram of background.sum.value") # Add Gamma distribution histogram alpha = 1.9802 scale = 75.4010 gamma_samples = gamma.rvs(a=alpha, scale=scale, size=len(bg_sum)) - plt.hist(gamma_samples, bins=100, edgecolor='red', alpha=0.4, label='Gamma(alpha=1.98, scale=75.4)') + plt.hist( + gamma_samples, + bins=100, + edgecolor="red", + alpha=0.4, + 
label="Gamma(alpha=1.98, scale=75.4)", + ) plt.legend() plt.tight_layout() @@ -62,6 +71,9 @@ def main(refl_file): wandb.log({"background_histogram": wandb.Image(plt.gcf())}) wandb.finish() + if __name__ == "__main__": - refl_file = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl" + refl_file = ( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl" + ) main(refl_file) diff --git a/src/factory/inspect_scale.py b/src/factory/inspect_scale.py index 918d10c..839fe4c 100755 --- a/src/factory/inspect_scale.py +++ b/src/factory/inspect_scale.py @@ -1,42 +1,44 @@ -"""run with +"""run with source /n/hekstra_lab/people/aldama/software/dials-v3-16-1/dials_env.sh -dials.python ../factory/src/factory/inspect_scale.py """ +dials.python ../factory/src/factory/inspect_scale.py""" #!/usr/bin/env dials.python import sys -from dials.array_family import flex + import matplotlib.pyplot as plt + # matplotlib.use("Agg") # ← use a headless backend import wandb +from dials.array_family import flex + def main(refl_file): refl = flex.reflection_table.from_file(refl_file) if "inverse_scale_factor" not in refl: raise RuntimeError("No inverse_scale_factor column found in " + refl_file) - + # Extract the scale factors scales = refl["inverse_scale_factor"] scales_list = list(scales) - scales = [1/x for x in scales_list] + scales = [1 / x for x in scales_list] scales_list = scales mean_scale = sum(scales) / len(scales) - variance_scale = sum((s - mean_scale)**2 for s in scales) / len(scales) + variance_scale = sum((s - mean_scale) ** 2 for s in scales) / len(scales) max_scale = max(scales) min_scale = min(scales) print("Mean of the 1/scale factors:", mean_scale) print("Max of the 1/scale factors:", max_scale) print("Min of the 1/scale factors:", min_scale) print("Variance of the 1/scale factors:", variance_scale) - + # Convert to a regular Python list for plotting - # Plot histogram - plt.figure(figsize=(8,5)) - plt.hist(scales_list, bins=100, edgecolor='black', alpha=0.6, label='Data') + plt.figure(figsize=(8, 5)) + plt.hist(scales_list, bins=100, edgecolor="black", alpha=0.6, label="Data") plt.xlabel("1/Inverse Scale Factor") plt.ylabel("Number of Observations") plt.title("Histogram of Per‐Reflection 1/Scale Factors") @@ -45,15 +47,18 @@ def main(refl_file): # Overlay Gamma distribution import numpy as np from scipy.stats import gamma + x = np.linspace(min(scales_list), max(scales_list), 1000) gamma_shape = 6.68 gamma_scale = 0.155 y = gamma.pdf(x, a=gamma_shape, scale=gamma_scale) # Scale the gamma PDF to match the histogram - y_scaled = y * len(scales_list) * (max(scales_list) - min(scales_list)) / 100 # 100 bins - plt.plot(x, y_scaled, 'r-', lw=2, label=f'Gamma({gamma_shape}, {gamma_scale})') + y_scaled = ( + y * len(scales_list) * (max(scales_list) - min(scales_list)) / 100 + ) # 100 bins + plt.plot(x, y_scaled, "r-", lw=2, label=f"Gamma({gamma_shape}, {gamma_scale})") plt.legend() - + # Show or save # out_png = "scale_histogram.png" # plt.savefig(out_png, dpi=300) # ← write to disk @@ -66,5 +71,6 @@ def main(refl_file): # If you’d rather save to disk, uncomment: # plt.savefig("scale_histogram.png", dpi=300) + if __name__ == "__main__": - main("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl") \ No newline at end of file + main("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/scaled.refl") diff --git a/src/factory/loss_function.py b/src/factory/loss_function.py index 09ccd06..8db9a0c 100644 --- 
a/src/factory/loss_function.py +++ b/src/factory/loss_function.py @@ -1,8 +1,9 @@ import torch + class LossFunction(torch.nn.Module): def __init__(self): pass - + def forward(self): - pass \ No newline at end of file + pass diff --git a/src/factory/loss_functionality.py b/src/factory/loss_functionality.py index 9dce0a7..316d956 100644 --- a/src/factory/loss_functionality.py +++ b/src/factory/loss_functionality.py @@ -1,14 +1,16 @@ -import torch -import model import dataclasses from typing import Optional + +import model import settings +import torch from tensordict.nn.distributions import Delta from wrap_folded_normal import SparseFoldedNormalPosterior + @dataclasses.dataclass -class LossOutput(): +class LossOutput: loss: Optional[torch.Tensor] = None kl_structure_factors: Optional[torch.Tensor] = None kl_background: Optional[torch.Tensor] = None @@ -16,24 +18,33 @@ class LossOutput(): kl_scale: Optional[torch.Tensor] = None log_likelihood: Optional[torch.Tensor] = None + class LossFunction(torch.nn.Module): - def __init__(self):#, model_settings: settings.ModelSettings, loss_settings:settings.LossSettings): + def __init__( + self, + ): # , model_settings: settings.ModelSettings, loss_settings:settings.LossSettings): super().__init__() # self.model_settings = model_settings # self.loss_settings = loss_settings def compute_kl_divergence(self, predicted_distribution, target_distribution): - try: - return torch.distributions.kl.kl_divergence(predicted_distribution, target_distribution) + try: + return torch.distributions.kl.kl_divergence( + predicted_distribution, target_distribution + ) except NotImplementedError: - print("KL divergence not implemented for this distribution: use sampling method.") + print( + "KL divergence not implemented for this distribution: use sampling method." 
+ ) samples = predicted_distribution.rsample([50]) log_q = predicted_distribution.log_prob(samples) log_p = target_distribution.log_prob(samples) return (log_q - log_p).mean(dim=0) - def compute_kl_divergence_verbose(self, predicted_distribution, target_distribution): - + def compute_kl_divergence_verbose( + self, predicted_distribution, target_distribution + ): + samples = predicted_distribution.rsample([50]) log_q = predicted_distribution.log_prob(samples) print(" pred", log_q) @@ -41,7 +52,9 @@ def compute_kl_divergence_verbose(self, predicted_distribution, target_distribut print(" target", log_p) return (log_q - log_p).mean(dim=0) - def compute_kl_divergence_surrogate_parameters(self, predicted_distribution, target_distribution, ordered_miller_indices): + def compute_kl_divergence_surrogate_parameters( + self, predicted_distribution, target_distribution, ordered_miller_indices + ): samples = predicted_distribution.rsample([50]) log_q = predicted_distribution.log_prob(samples) @@ -49,17 +62,20 @@ def compute_kl_divergence_surrogate_parameters(self, predicted_distribution, tar log_p = target_distribution.log_prob(samples) rasu_ids = surrogate_posterior.rac.rasu_ids[0] - print("prior_intensity.rsample([20]).permute(1,0)", prior_intensity.rsample([20]).permute(1,0)) + print( + "prior_intensity.rsample([20]).permute(1,0)", + prior_intensity.rsample([20]).permute(1, 0), + ) prior_intensity_samples = surrogate_posterior.rac.gather( - source=prior_intensity.rsample([20]).permute(1,0).T, rasu_id=rasu_ids, H=ordered_miller_indices.to(torch.int) + source=prior_intensity.rsample([20]).permute(1, 0).T, + rasu_id=rasu_ids, + H=ordered_miller_indices.to(torch.int), ) - - log_p = target_distribution.log_prob(samples) print("scale: target", log_p) return (log_q - log_p).mean(dim=0) - + def get_prior(self, distribution, device): # Move parameters to device and recreate the distribution # if isinstance(distribution, torch.distributions.HalfCauchy): @@ -68,9 +84,20 @@ def get_prior(self, distribution, device): # return torch.distributions.HalfNormal(distribution.scale.to(device)) # Add more cases as needed return distribution - - def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_distribution, scale_distribution, surrogate_posterior, - profile_distribution, model_settings, loss_settings) -> LossOutput: + + def forward( + self, + counts, + mask, + ordered_miller_indices, + photon_rate, + background_distribution, + scale_distribution, + surrogate_posterior, + profile_distribution, + model_settings, + loss_settings, + ) -> LossOutput: print("in loss") try: device = photon_rate.device @@ -79,37 +106,54 @@ def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_ loss_output = LossOutput() # print("cuda=", next(self.model_settings.intensity_prior_distibution.parameters()).device) - prior_background = self.get_prior(model_settings.background_prior_distribution, device=device) + prior_background = self.get_prior( + model_settings.background_prior_distribution, device=device + ) print("got priors bg") - prior_intensity = model_settings.build_intensity_prior_distribution(model_settings.rac.to(device)) - - + prior_intensity = model_settings.build_intensity_prior_distribution( + model_settings.rac.to(device) + ) + # self.get_prior(model_settings.intensity_prior_distibution, device=device) print("got priors sf") - prior_scale = self.get_prior(model_settings.scale_prior_distibution, device=device) + prior_scale = self.get_prior( + 
model_settings.scale_prior_distibution, device=device + ) print("got priors") # Compute KL divergence for structure factors if model_settings.use_surrogate_parameters: - rasu_ids = torch.tensor([0 for _ in range(len(ordered_miller_indices))], device=device) + rasu_ids = torch.tensor( + [0 for _ in range(len(ordered_miller_indices))], device=device + ) _kl_divergence_folded_normal = self.compute_kl_divergence( - surrogate_posterior, prior_intensity.distribution(rasu_ids, ordered_miller_indices)) + surrogate_posterior, + prior_intensity.distribution(rasu_ids, ordered_miller_indices), + ) # else: # # Handle case where no valid parameters were found # raise ValueError("No valid surrogate parameters found") elif isinstance(surrogate_posterior, SparseFoldedNormalPosterior): _kl_divergence_folded_normal = self.compute_kl_divergence( - surrogate_posterior.get_distribution(ordered_miller_indices), prior_intensity.distribution()) + surrogate_posterior.get_distribution(ordered_miller_indices), + prior_intensity.distribution(), + ) else: - rasu_ids = torch.tensor([0 for _ in range(len(ordered_miller_indices))], device=device) + rasu_ids = torch.tensor( + [0 for _ in range(len(ordered_miller_indices))], device=device + ) _kl_divergence_folded_normal = self.compute_kl_divergence_verbose( - surrogate_posterior, prior_intensity.distribution(rasu_ids, ordered_miller_indices)) - print("structure factor distr", surrogate_posterior.rsample([1]).shape)#, prior_intensity.rsample([1]).shape) + surrogate_posterior, + prior_intensity.distribution(rasu_ids, ordered_miller_indices), + ) + print( + "structure factor distr", surrogate_posterior.rsample([1]).shape + ) # , prior_intensity.rsample([1]).shape) # rasu_ids = surrogate_posterior.rac.rasu_ids[0] @@ -117,44 +161,64 @@ def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_ # source=_kl_divergence_folded_normal_all_reflections.T, rasu_id=rasu_ids, H=ordered_miller_indices.to(torch.int) # ) - kl_structure_factors = _kl_divergence_folded_normal * loss_settings.prior_structure_factors_weight + kl_structure_factors = ( + _kl_divergence_folded_normal + * loss_settings.prior_structure_factors_weight + ) # print("bkg", background_distribution.sample().shape, prior_background.sample()) # print(background_distribution.scale.device, prior_background.concentration.device) - kl_background = self.compute_kl_divergence( - background_distribution, prior_background - ) * loss_settings.prior_background_weight + kl_background = ( + self.compute_kl_divergence(background_distribution, prior_background) + * loss_settings.prior_background_weight + ) print("shape background kl", kl_background.shape) - kl_profile = ( - profile_distribution.kl_divergence_weighted( - loss_settings.prior_profile_weight, - ) + kl_profile = profile_distribution.kl_divergence_weighted( + loss_settings.prior_profile_weight, ) print("shape profile kl", kl_profile.shape) print("scale_distribution type:", type(scale_distribution)) if isinstance(scale_distribution, Delta): - kl_scale = torch.tensor(0.0, device=scale_distribution.mean.device, dtype=scale_distribution.mean.dtype) + kl_scale = torch.tensor( + 0.0, + device=scale_distribution.mean.device, + dtype=scale_distribution.mean.dtype, + ) else: - kl_scale = self.compute_kl_divergence_verbose( - scale_distribution, prior_scale - ) * loss_settings.prior_scale_weight + kl_scale = ( + self.compute_kl_divergence_verbose(scale_distribution, prior_scale) + * loss_settings.prior_scale_weight + ) print("kl scale", kl_scale.shape, kl_scale) - 
rate = photon_rate #.clamp(min=loss_settings.eps) - log_likelihood = torch.distributions.Poisson(rate=rate+loss_settings.eps).log_prob( - counts.unsqueeze(1)) #n_batches, mc_samples, pixels + rate = photon_rate # .clamp(min=loss_settings.eps) + log_likelihood = torch.distributions.Poisson( + rate=rate + loss_settings.eps + ).log_prob( + counts.unsqueeze(1) + ) # n_batches, mc_samples, pixels print("dim loglikelihood", log_likelihood.shape) - log_likelihood_mean = torch.mean(log_likelihood, dim=1) * mask # .unsqueeze(-1) + log_likelihood_mean = ( + torch.mean(log_likelihood, dim=1) * mask + ) # .unsqueeze(-1) print("log_likelihood_mean: n_batches, pixels", log_likelihood_mean.shape) print("mask", mask.shape) - negative_log_likelihood_per_batch = (-log_likelihood_mean).sum(dim=1)# batch_size + negative_log_likelihood_per_batch = (-log_likelihood_mean).sum( + dim=1 + ) # batch_size + + loss_per_batch = ( + negative_log_likelihood_per_batch + + kl_structure_factors + + kl_background + + kl_profile + + kl_scale + ) - loss_per_batch = negative_log_likelihood_per_batch + kl_structure_factors + kl_background + kl_profile + kl_scale - loss_output.loss = loss_per_batch.mean() loss_output.kl_structure_factors = kl_structure_factors.mean() @@ -167,5 +231,3 @@ def forward(self, counts, mask, ordered_miller_indices, photon_rate, background_ return loss_output except Exception as e: print(f"loss computation failed with {e}") - - diff --git a/src/factory/metadata_encoder.py b/src/factory/metadata_encoder.py index a791a85..f9b786c 100644 --- a/src/factory/metadata_encoder.py +++ b/src/factory/metadata_encoder.py @@ -1,9 +1,11 @@ -import torch import math + +import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Linear + def weight_initializer(weight): fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2]) std = math.sqrt(1.0 / fan_avg / 10.0) @@ -57,6 +59,7 @@ def forward(self, x): return out + class SimpleMetadataEncoder(nn.Module): def __init__(self, feature_dim, depth=8, dropout=0.0, output_dims=None): super().__init__() @@ -122,6 +125,7 @@ def forward(self, x): return out + class BaseMetadataEncoder(nn.Module): def __init__(self, feature_dim=2, depth=10, dropout=0.0, output_dims=64): super().__init__() @@ -150,15 +154,29 @@ def forward(self, x): if torch.isnan(x).any(): print("WARNING: NaN values in BaseMetadataEncoder input!") print("NaN count:", torch.isnan(x).sum().item()) - print("Stats - min:", x.min().item(), "max:", x.max().item(), "mean:", x.mean().item()) - + print( + "Stats - min:", + x.min().item(), + "max:", + x.max().item(), + "mean:", + x.mean().item(), + ) + # Process through the model x = self.model(x) - + # Check output for NaN values if torch.isnan(x).any(): print("WARNING: NaN values in BaseMetadataEncoder output!") print("NaN count:", torch.isnan(x).sum().item()) - print("Stats - min:", x.min().item(), "max:", x.max().item(), "mean:", x.mean().item()) - - return x \ No newline at end of file + print( + "Stats - min:", + x.min().item(), + "max:", + x.max().item(), + "mean:", + x.mean().item(), + ) + + return x diff --git a/src/factory/model.py b/src/factory/model.py index da22ee6..4e2575e 100644 --- a/src/factory/model.py +++ b/src/factory/model.py @@ -1,73 +1,68 @@ - -import sys, os +import os +import sys repo_root = os.path.abspath(os.path.join(__file__, os.pardir)) -inner_pkg = os.path.join(repo_root, "abismal_torch") +inner_pkg = os.path.join(repo_root, "abismal_torch") sys.path.insert(0, inner_pkg) sys.path.insert(0, repo_root) import cProfile - 
-import pandas as pd +import dataclasses +import io import data_loader -import torch -import numpy as np +import distributions +import get_protein_data import lightning as L -import dataclasses +import loss_functionality +import matplotlib.pyplot as plt import metadata_encoder +import numpy as np +import pandas as pd +import reciprocalspaceship as rs +import rs_distributions.modules as rsm +import settings import shoebox_encoder +import torch import torch.nn.functional as F -from networks import * -import distributions -import loss_functionality -from callbacks import LossLogging, Plotting, CorrelationPlotting, ScalePlotting -import get_protein_data -import reciprocalspaceship as rs -from abismal_torch.prior import WilsonPrior -from abismal_torch.symmetry.reciprocal_asu import ReciprocalASU, ReciprocalASUGraph -from reciprocalspaceship.utils import apply_to_hkl, generate_reciprocal_asu -from rasu import * +import wandb from abismal_torch.likelihood import NormalLikelihood -from abismal_torch.surrogate_posterior import FoldedNormalPosterior -from wrap_folded_normal import FrequencyTrackingPosterior, SparseFoldedNormalPosterior from abismal_torch.merging import VariationalMergingModel +from abismal_torch.prior import WilsonPrior from abismal_torch.scaling import ImageScaler -from positional_encoding import positional_encoding - -from lightning.pytorch.loggers import WandbLogger +from abismal_torch.surrogate_posterior import FoldedNormalPosterior +from abismal_torch.symmetry.reciprocal_asu import ReciprocalASU, ReciprocalASUGraph +from adabelief_pytorch import AdaBelief +from callbacks import CorrelationPlotting, LossLogging, Plotting, ScalePlotting from lightning.pytorch.callbacks import Callback, ModelCheckpoint -import wandb -# from pytorch_lightning.callbacks import ModelCheckpoint - -from torchvision.transforms import ToPILImage +from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.utilities import grad_norm -import matplotlib.pyplot as plt -import io +from masked_adam import MaskedAdam, make_lazy_adabelief_for_surrogate_posterior +from networks import * from PIL import Image +from positional_encoding import positional_encoding +from rasu import * +from reciprocalspaceship.utils import apply_to_hkl, generate_reciprocal_asu +from settings import LossSettings, ModelSettings +from torch.optim.lr_scheduler import MultiStepLR, StepLR +from torchvision.transforms import ToPILImage -from settings import ModelSettings, LossSettings -import settings +from wrap_folded_normal import FrequencyTrackingPosterior, SparseFoldedNormalPosterior -import rs_distributions.modules as rsm -from masked_adam import MaskedAdam, make_lazy_adabelief_for_surrogate_posterior +# from pytorch_lightning.callbacks import ModelCheckpoint -from torch.optim.lr_scheduler import StepLR, MultiStepLR -from adabelief_pytorch import AdaBelief -import wandb torch.set_float32_matmul_precision("medium") import torch - class Model(L.LightningModule): @@ -80,30 +75,56 @@ class Model(L.LightningModule): def __init__(self, model_settings: ModelSettings, loss_settings: LossSettings): super().__init__() self.model_settings = model_settings - self.profile_encoder = model_settings.profile_encoder #encoder.CNN_3d() + self.profile_encoder = model_settings.profile_encoder # encoder.CNN_3d() self.metadata_encoder = model_settings.metadata_encoder self.intensity_encoder = model_settings.intensity_encoder self.scale_function = self.model_settings.scale_function - self.profile_distribution = 
self.model_settings.build_shoebox_profile_distribution(torch.load(os.path.join(self.model_settings.data_directory, "concentration.pt"), weights_only=True), dmodel=self.model_settings.dmodel) - self.background_distribution = self.model_settings.build_background_distribution(dmodel=self.model_settings.dmodel) + self.profile_distribution = ( + self.model_settings.build_shoebox_profile_distribution( + torch.load( + os.path.join( + self.model_settings.data_directory, "concentration.pt" + ), + weights_only=True, + ), + dmodel=self.model_settings.dmodel, + ) + ) + self.background_distribution = ( + self.model_settings.build_background_distribution( + dmodel=self.model_settings.dmodel + ) + ) self.surrogate_posterior = self.initialize_surrogate_posterior() - self.dispersion_param = torch.tensor(loss_settings.dispersion_parameter) #torch.nn.Parameter(torch.tensor(0.001))torch.nn.Parameter(torch.tensor(0.001)) #torch.tensor(loss_settings.dispersion_parameter) #torch.nn.Parameter(torch.tensor(0.001)) + self.dispersion_param = torch.tensor( + loss_settings.dispersion_parameter + ) # torch.nn.Parameter(torch.tensor(0.001))torch.nn.Parameter(torch.tensor(0.001)) #torch.tensor(loss_settings.dispersion_parameter) #torch.nn.Parameter(torch.tensor(0.001)) self.loss_settings = loss_settings self.loss_function = loss_functionality.LossFunction() self.loss = torch.Tensor(0) - self.lr = self.model_settings.learning_rate + self.lr = self.model_settings.learning_rate # self._train_dataloader = train_dataloader - pdb_data = get_protein_data.get_protein_data(self.model_settings.protein_pdb_url) - rac = ReciprocalASUGraph(*[ReciprocalASU( - cell=pdb_data["unit_cell"], - spacegroup=pdb_data["spacegroup"], - dmin=float(pdb_data["dmin"]), - anomalous=True, - )]) - self.intensity_prior_distibution = self.model_settings.build_intensity_prior_distribution(rac) - H_rasu = generate_reciprocal_asu(pdb_data["unit_cell"], pdb_data["spacegroup"], float(pdb_data["dmin"]), True) + pdb_data = get_protein_data.get_protein_data( + self.model_settings.protein_pdb_url + ) + rac = ReciprocalASUGraph( + *[ + ReciprocalASU( + cell=pdb_data["unit_cell"], + spacegroup=pdb_data["spacegroup"], + dmin=float(pdb_data["dmin"]), + anomalous=True, + ) + ] + ) + self.intensity_prior_distibution = ( + self.model_settings.build_intensity_prior_distribution(rac) + ) + H_rasu = generate_reciprocal_asu( + pdb_data["unit_cell"], pdb_data["spacegroup"], float(pdb_data["dmin"]), True + ) self.kl_structure_factors_loss = torch.tensor(0) def setup(self, stage=None): @@ -112,7 +133,6 @@ def setup(self, stage=None): self.model_settings = Model.move_settings(self.model_settings, device) self.loss_settings = Model.move_settings(self.loss_settings, device) - @property def device(self): return next(self.parameters()).device @@ -125,25 +145,24 @@ def to(self, *args, **kwargs): self.model_settings = Model.move_settings(self.model_settings, self.device) self.loss_settings = Model.move_settings(self.loss_settings, self.device) - if hasattr(self, 'surrogate_posterior'): + if hasattr(self, "surrogate_posterior"): self.surrogate_posterior.to(self.device) - - if hasattr(self, 'profile_distribution'): + + if hasattr(self, "profile_distribution"): self.profile_distribution.to(self.device) - - if hasattr(self, 'background_distribution'): + + if hasattr(self, "background_distribution"): self.background_distribution.to(self.device) - if hasattr(self, 'scale_function'): + if hasattr(self, "scale_function"): self.scale_function.to(self.device) - - if hasattr(self, 'dataloader') 
and self.dataloader is not None: - if hasattr(self.dataloader, 'pin_memory'): + + if hasattr(self, "dataloader") and self.dataloader is not None: + if hasattr(self.dataloader, "pin_memory"): self.dataloader.pin_memory = True return self - @staticmethod def move_settings(settings: dataclasses.dataclass, device: torch.device): moved = {} @@ -153,8 +172,14 @@ def move_settings(settings: dataclasses.dataclass, device: torch.device): val = getattr(settings, field.name) if isinstance(val, torch.distributions.Distribution): param_names = val.arg_constraints.keys() - params = {name: getattr(val, name).to(device) if hasattr(getattr(val, name), 'to') else getattr(val, name) - for name in param_names} + params = { + name: ( + getattr(val, name).to(device) + if hasattr(getattr(val, name), "to") + else getattr(val, name) + ) + for name in param_names + } moved[field.name] = type(val)(**params) # print("moved", val) @@ -162,72 +187,126 @@ def move_settings(settings: dataclasses.dataclass, device: torch.device): try: moved[field.name] = val.to(device) except Exception as e: - moved[field.name] = val + moved[field.name] = val else: moved[field.name] = val return dataclasses.replace(settings, **moved) - def initialize_surrogate_posterior(self): - initial_mean = self.model_settings.intensity_prior_distibution.distribution().mean() # self.model_settings.intensity_distribution_initial_location + initial_mean = ( + self.model_settings.intensity_prior_distibution.distribution().mean() + ) # self.model_settings.intensity_distribution_initial_location print("intensity folded normal mean shape", initial_mean.shape) - initial_scale = 0.05 * initial_mean # self.model_settings.intensity_distribution_initial_scale - - surrogate_posterior = FrequencyTrackingPosterior.from_unconstrained_loc_and_scale( - rac=self.model_settings.rac, loc=initial_mean, scale=initial_scale # epsilon=settings.epsilon + initial_scale = ( + 0.05 * initial_mean + ) # self.model_settings.intensity_distribution_initial_scale + + surrogate_posterior = ( + FrequencyTrackingPosterior.from_unconstrained_loc_and_scale( + rac=self.model_settings.rac, + loc=initial_mean, + scale=initial_scale, # epsilon=settings.epsilon + ) ) def check_grad_hook(grad): if grad is not None: - print(f"Gradient stats: min={grad.min()}, max={grad.max()}, mean={grad.mean()}, any_nan={torch.isnan(grad).any()}, all_finite={torch.isfinite(grad).all()}") + print( + f"Gradient stats: min={grad.min()}, max={grad.max()}, mean={grad.mean()}, any_nan={torch.isnan(grad).any()}, all_finite={torch.isfinite(grad).all()}" + ) return grad + surrogate_posterior.distribution.loc.register_hook(check_grad_hook) return surrogate_posterior - - def compute_scale(self, representations: list[torch.Tensor]) -> torch.distributions.Normal: - joined_representation = sum(representations) if len(representations) > 1 else representations[0] #self._add_representation_(image_representation, metadata_representation) # (batch_size, dmodel) + + def compute_scale( + self, representations: list[torch.Tensor] + ) -> torch.distributions.Normal: + joined_representation = ( + sum(representations) if len(representations) > 1 else representations[0] + ) # self._add_representation_(image_representation, metadata_representation) # (batch_size, dmodel) scale = self.scale_function(joined_representation) return scale.distribution - + def _add_representation_(self, representation1, representation2): return representation1 + representation2 - + def compute_shoebox_profile(self, representations: list[torch.Tensor]): return 
self.profile_distribution.compute_profile(*representations) def compute_background_distribution(self, intensity_representation): return self.background_distribution(intensity_representation) - def compute_photon_rate(self, scale_distribution, background_distribution, profile, surrogate_posterior, ordered_miller_indices, metadata, verbose_output) -> torch.Tensor: - - _samples_predicted_structure_factor = surrogate_posterior.rsample([self.model_settings.number_of_mc_samples]).unsqueeze(-1).permute(1, 0, 2) - samples_predicted_structure_factor = self.surrogate_posterior.rac.gather(_samples_predicted_structure_factor, torch.tensor([0]), ordered_miller_indices) + def compute_photon_rate( + self, + scale_distribution, + background_distribution, + profile, + surrogate_posterior, + ordered_miller_indices, + metadata, + verbose_output, + ) -> torch.Tensor: + + _samples_predicted_structure_factor = ( + surrogate_posterior.rsample([self.model_settings.number_of_mc_samples]) + .unsqueeze(-1) + .permute(1, 0, 2) + ) + samples_predicted_structure_factor = self.surrogate_posterior.rac.gather( + _samples_predicted_structure_factor, + torch.tensor([0]), + ordered_miller_indices, + ) - samples_scale = scale_distribution.rsample([self.model_settings.number_of_mc_samples]).permute(1, 0, 2) #(batch_size, number_of_samples, 1) + samples_scale = scale_distribution.rsample( + [self.model_settings.number_of_mc_samples] + ).permute( + 1, 0, 2 + ) # (batch_size, number_of_samples, 1) self.logger.experiment.log({"scale_samples_average": torch.mean(samples_scale)}) - samples_profile = profile.rsample([self.model_settings.number_of_mc_samples]).permute(1,0,2) - samples_background = background_distribution.rsample([self.model_settings.number_of_mc_samples]).permute(1, 0, 2) #(batch_size, number_of_samples, 1) + samples_profile = profile.rsample( + [self.model_settings.number_of_mc_samples] + ).permute(1, 0, 2) + samples_background = background_distribution.rsample( + [self.model_settings.number_of_mc_samples] + ).permute( + 1, 0, 2 + ) # (batch_size, number_of_samples, 1) if verbose_output: - photon_rate = samples_scale * torch.square(samples_predicted_structure_factor) * samples_profile + samples_background + photon_rate = ( + samples_scale + * torch.square(samples_predicted_structure_factor) + * samples_profile + + samples_background + ) return { - 'photon_rate': photon_rate, - 'samples_profile': samples_profile, - 'samples_predicted_structure_factor': samples_predicted_structure_factor, - 'samples_scale': samples_scale, - 'samples_background': samples_background + "photon_rate": photon_rate, + "samples_profile": samples_profile, + "samples_predicted_structure_factor": samples_predicted_structure_factor, + "samples_scale": samples_scale, + "samples_background": samples_background, } else: - photon_rate = samples_scale * torch.square(samples_predicted_structure_factor) * samples_profile + samples_background # [batch_size, mc_samples, pixels] + photon_rate = ( + samples_scale + * torch.square(samples_predicted_structure_factor) + * samples_profile + + samples_background + ) # [batch_size, mc_samples, pixels] return photon_rate - + def _cut_metadata(self, metadata) -> torch.Tensor: # metadata_cut = torch.index_select(metadata, dim=1, index=torch.tensor(self.model_settings.metadata_indices_to_keep, device=self.device)) if self.model_settings.use_positional_encoding: - encoded_metadata = positional_encoding(X=metadata, L=self.model_settings.number_of_frequencies_in_positional_encoding) + encoded_metadata = 
positional_encoding( + X=metadata, + L=self.model_settings.number_of_frequencies_in_positional_encoding, + ) print("encoding dim metadata", encoded_metadata.shape) return encoded_metadata else: @@ -242,28 +321,52 @@ def _log_representation_stats(self, tensor: torch.Tensor, name: str): f"{name}/min": tensor.min().item(), f"{name}/max": tensor.max().item(), } - - if hasattr(self, 'logger') and self.logger is not None: - self.logger.experiment.log(stats) + if hasattr(self, "logger") and self.logger is not None: + self.logger.experiment.log(stats) def _batch_to_representations(self, batch: tuple): - shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl_batch, processed_metadata_batch = batch - standardized_counts = shoeboxes_batch #[:,:,-1].reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21) + ( + shoeboxes_batch, + metadata_batch, + dead_pixel_mask_batch, + counts_batch, + hkl_batch, + processed_metadata_batch, + ) = batch + standardized_counts = ( + shoeboxes_batch # [:,:,-1].reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21) + ) print("sb shape", standardized_counts.shape) - profile_representation = self.profile_encoder(standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21), mask=dead_pixel_mask_batch) + profile_representation = self.profile_encoder( + standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21), + mask=dead_pixel_mask_batch, + ) - intensity_representation = self.intensity_encoder(standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21), mask=dead_pixel_mask_batch) + intensity_representation = self.intensity_encoder( + standardized_counts.reshape(shoeboxes_batch.shape[0], 1, 3, 21, 21), + mask=dead_pixel_mask_batch, + ) + metadata_representation = self.metadata_encoder( + self._cut_metadata(processed_metadata_batch).float() + ) # (batch_size, dmodel) + print( + "metadata representation (batch size, dmodel)", + metadata_representation.shape, + ) - metadata_representation = self.metadata_encoder(self._cut_metadata(processed_metadata_batch).float()) # (batch_size, dmodel) - print("metadata representation (batch size, dmodel)", metadata_representation.shape) + joined_shoebox_representation = ( + intensity_representation + metadata_representation + ) - joined_shoebox_representation = intensity_representation + metadata_representation - - pooled_image_representation = torch.max(joined_shoebox_representation, dim=0, keepdim=True)[0] # (1, dmodel) - image_representation = pooled_image_representation #+ metadata_representation + pooled_image_representation = torch.max( + joined_shoebox_representation, dim=0, keepdim=True + )[ + 0 + ] # (1, dmodel) + image_representation = pooled_image_representation # + metadata_representation print("image rep (1, dmodel)", image_representation.shape) if torch.isnan(metadata_representation).any(): @@ -272,29 +375,53 @@ def _batch_to_representations(self, batch: tuple): raise ValueError("MLP profile_representation produced NaNs!") if torch.isnan(intensity_representation).any(): raise ValueError("MLP intensity_representation produced NaNs!") - - - return intensity_representation, metadata_representation, image_representation, profile_representation - + + return ( + intensity_representation, + metadata_representation, + image_representation, + profile_representation, + ) + def forward(self, batch, verbose_output=False): try: - intensity_representation, metadata_representation, image_representation, profile_representation = self._batch_to_representations(batch=batch) + ( + intensity_representation, + 
metadata_representation, + image_representation, + profile_representation, + ) = self._batch_to_representations(batch=batch) self._log_representation_stats(profile_representation, "profile_rep") self._log_representation_stats(metadata_representation, "metadata_rep") self._log_representation_stats(image_representation, "image_rep") self._log_representation_stats(intensity_representation, "intensity_rep") - - shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl_batch, processed_metadata_batch = batch + + ( + shoeboxes_batch, + metadata_batch, + dead_pixel_mask_batch, + counts_batch, + hkl_batch, + processed_metadata_batch, + ) = batch print("unique hkls", len(torch.unique(hkl_batch))) - scale_distribution = self.compute_scale(representations=[metadata_representation, image_representation])#[image_representation, metadata_representation]) # torch.distributions.Normal instance + scale_distribution = self.compute_scale( + representations=[metadata_representation, image_representation] + ) # [image_representation, metadata_representation]) # torch.distributions.Normal instance print("compute profile") - shoebox_profile = self.compute_shoebox_profile(representations=[profile_representation])#, image_representation]) + shoebox_profile = self.compute_shoebox_profile( + representations=[profile_representation] + ) # , image_representation]) print("compute bg") - background_distribution = self.compute_background_distribution(intensity_representation=intensity_representation)# shoebox_representation)#+image_representation) - self.surrogate_posterior.update_observed(rasu_id=self.surrogate_posterior.rac.rasu_ids[0], H=hkl_batch) + background_distribution = self.compute_background_distribution( + intensity_representation=intensity_representation + ) # shoebox_representation)#+image_representation) + self.surrogate_posterior.update_observed( + rasu_id=self.surrogate_posterior.rac.rasu_ids[0], H=hkl_batch + ) print("compute rate") photon_rate = self.compute_photon_rate( @@ -304,12 +431,20 @@ def forward(self, batch, verbose_output=False): surrogate_posterior=self.surrogate_posterior, ordered_miller_indices=hkl_batch, metadata=metadata_batch, - verbose_output=verbose_output + verbose_output=verbose_output, ) if verbose_output: return (shoeboxes_batch, photon_rate, hkl_batch, counts_batch) - return (counts_batch, dead_pixel_mask_batch, hkl_batch, photon_rate, background_distribution, scale_distribution, self.surrogate_posterior, - self.profile_distribution) + return ( + counts_batch, + dead_pixel_mask_batch, + hkl_batch, + photon_rate, + background_distribution, + scale_distribution, + self.surrogate_posterior, + self.profile_distribution, + ) except Exception as e: print(f"failed in forward: {e}") for param in self.parameters(): @@ -323,33 +458,43 @@ def training_step_(self, batch): for n, p in self.named_parameters(): if not n.startswith("surrogate_posterior."): p.requires_grad_(False) - local_runs = 3 if (self.current_epoch < 0.3*self.trainer.max_epochs) else 1 + local_runs = 3 if (self.current_epoch < 0.3 * self.trainer.max_epochs) else 1 for _ in range(local_runs): output = self(batch=batch) - local_loss = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param).loss + local_loss = self.loss_function( + *output, self.model_settings, self.loss_settings, self.dispersion_param + ).loss opt_loc.zero_grad(set_to_none=True) - self.manual_backward(local_loss) # builds sparse grads on rows in idx + self.manual_backward(local_loss) # builds sparse grads on 
rows in idx opt_loc.step() self.train() for n, p in self.named_parameters(): if not n.startswith("surrogate_posterior."): p.requires_grad_(True) output = self(batch=batch) - loss_output = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param) + loss_output = self.loss_function( + *output, self.model_settings, self.loss_settings, self.dispersion_param + ) opt_main.zero_grad() loss = loss_output.loss loss.backward() opt_main.step() self.loss = loss.detach() - self.logger.experiment.log({ + self.logger.experiment.log( + { "loss/kl_structure_factors_step": loss_output.kl_structure_factors.item(), "loss/kl_background": loss_output.kl_background.item(), "loss/kl_profile": loss_output.kl_profile.item(), "loss/kl_scale": loss_output.kl_scale.item(), "loss/log_likelihood": loss_output.log_likelihood.item(), - }) + } + ) self.log("disp_param", self.dispersion_param, on_step=True, on_epoch=True) - self.log("loss/kl_structure_factors", loss_output.kl_structure_factors.item(),on_epoch=True) + self.log( + "loss/kl_structure_factors", + loss_output.kl_structure_factors.item(), + on_epoch=True, + ) self.loss = loss print("log loss step") self.log("train/loss_step", self.loss, on_step=True, on_epoch=True) @@ -357,20 +502,31 @@ def training_step_(self, batch): for name, norm in norms.items(): self.log(f"grad_norm/{name}", norm) return self.loss - - + def training_step(self, batch): output = self(batch=batch) - loss_output = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param, self.current_epoch) + loss_output = self.loss_function( + *output, + self.model_settings, + self.loss_settings, + self.dispersion_param, + self.current_epoch, + ) self.loss = loss_output.loss.detach() - self.logger.experiment.log({ + self.logger.experiment.log( + { "loss/kl_structure_factors": loss_output.kl_structure_factors.item(), "loss/kl_background": loss_output.kl_background.item(), "loss/kl_profile": loss_output.kl_profile.item(), "loss/kl_scale": loss_output.kl_scale.item(), "loss/log_likelihood": loss_output.log_likelihood.item(), - }) - self.log("loss/kl_structure_factors", loss_output.kl_structure_factors.item(),on_epoch=True) + } + ) + self.log( + "loss/kl_structure_factors", + loss_output.kl_structure_factors.item(), + on_epoch=True, + ) self.loss = loss_output.loss self.log("train/loss_step", self.loss, on_step=True, on_epoch=True) norms = grad_norm(self, norm_type=2) @@ -378,43 +534,75 @@ def training_step(self, batch): self.log(f"grad_norm/{name}", norm) return self.loss - - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx): print(f"Validation step called for batch {batch_idx}") output = self(batch=batch) - loss_output = self.loss_function(*output, self.model_settings, self.loss_settings, self.dispersion_param, self.current_epoch) + loss_output = self.loss_function( + *output, + self.model_settings, + self.loss_settings, + self.dispersion_param, + self.current_epoch, + ) - self.log("validation_loss/kl_structure_factors", loss_output.kl_structure_factors.item(), on_step=False, on_epoch=True) - self.log("validation_loss/kl_background", loss_output.kl_background.item(), on_step=False, on_epoch=True) - self.log("validation_loss/kl_profile", loss_output.kl_profile.item(), on_step=False, on_epoch=True) - self.log("validation_loss/log_likelihood", loss_output.log_likelihood.item(), on_step=False, on_epoch=True) - self.log("validation_loss/loss", loss_output.loss.item(), on_step=True, on_epoch=True) - 
self.log("validation/loss_step_epoch", loss_output.loss.item(), on_step=False, on_epoch=True) + self.log( + "validation_loss/kl_structure_factors", + loss_output.kl_structure_factors.item(), + on_step=False, + on_epoch=True, + ) + self.log( + "validation_loss/kl_background", + loss_output.kl_background.item(), + on_step=False, + on_epoch=True, + ) + self.log( + "validation_loss/kl_profile", + loss_output.kl_profile.item(), + on_step=False, + on_epoch=True, + ) + self.log( + "validation_loss/log_likelihood", + loss_output.log_likelihood.item(), + on_step=False, + on_epoch=True, + ) + self.log( + "validation_loss/loss", loss_output.loss.item(), on_step=True, on_epoch=True + ) + self.log( + "validation/loss_step_epoch", + loss_output.loss.item(), + on_step=False, + on_epoch=True, + ) return loss_output.loss - - def surrogate_full_sweep(self, dataloader=None, K=1): dl = self.dataloader opt_main, opt_loc = self.optimizers() for n, p in self.named_parameters(): if not n.startswith("surrogate_posterior."): p.requires_grad_(False) - self.eval() + self.eval() self.surrogate_posterior.train() for xb in dl: xb = self.transfer_batch_to_device(xb, self.device) - - for _ in range(K): + + for _ in range(K): out = self(batch=xb) - loss_obj = self.loss_function(*out, self.model_settings, self.loss_settings, self.dispersion_param) + loss_obj = self.loss_function( + *out, self.model_settings, self.loss_settings, self.dispersion_param + ) opt_loc.zero_grad(set_to_none=True) self.manual_backward(loss_obj.loss) opt_loc.step() - self.train() + self.train() for n, p in self.named_parameters(): if not n.startswith("surrogate_posterior."): p.requires_grad_(True) @@ -431,23 +619,24 @@ def on_train_epoch_end(self): # # if metric_value is not None: # sched.step(metric_value.detach().cpu().item()) - sched.step() + sched.step() def configure_optimizers(self): """Configure optimizer based on settings from config YAML.""" - + # Get optimizer parameters from settings optimizer_name = self.model_settings.optimizer_name lr = self.model_settings.learning_rate betas = self.model_settings.optimizer_betas weight_decay = self.model_settings.weight_decay eps = float(self.model_settings.optimizer_eps) - + if optimizer_name == "AdaBelief": - opt = AdaBelief( self.parameters(), - # [ + opt = AdaBelief( + self.parameters(), + # [ # {'params': [p for n, p in self.named_parameters() if not n.startswith("surrogate_posterior.")], 'lr': 1e-4}, - # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3}, + # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3}, # ], lr=lr, eps=eps, @@ -478,26 +667,32 @@ def configure_optimizers(self): elif optimizer_name == "LazyAdam": opt = MaskedAdam( self.parameters(), - lr = lr, - betas = betas, + lr=lr, + betas=betas, surrogate_posterior_module=self.surrogate_posterior, weight_decay=weight_decay, ) - elif optimizer_name =="LazyAdaBelief": + elif optimizer_name == "LazyAdaBelief": opt = make_lazy_adabelief_for_surrogate_posterior( - self, lr=lr, weight_decay=weight_decay, - decoupled_weight_decay=self.model_settings.weight_decouple, zero_tol=1e-12 + self, + lr=lr, + weight_decay=weight_decay, + decoupled_weight_decay=self.model_settings.weight_decouple, + zero_tol=1e-12, + ) + print( + "self.model_settings.weight_decouple", + self.model_settings.weight_decouple, ) - print("self.model_settings.weight_decouple", self.model_settings.weight_decouple) else: raise ValueError(f"Unsupported optimizer: {optimizer_name}") # Configure scheduler # opt_local = AdaBelief( - # [ + # [ # #{'params': [p 
for n, p in self.named_parameters() if not n.startswith("surrogate_posterior.")], 'lr': 1e-4}, - # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3}, + # {'params': self.surrogate_posterior.parameters(), 'lr': 1e-3}, # ], # lr=lr, # eps=eps, @@ -506,7 +701,11 @@ def configure_optimizers(self): # weight_decay=weight_decay, # ) - sch = MultiStepLR(opt, milestones=self.model_settings.milestones, gamma=self.model_settings.scheduler_gamma) + sch = MultiStepLR( + opt, + milestones=self.model_settings.milestones, + gamma=self.model_settings.scheduler_gamma, + ) # sch = DebugStepLR(opt, step_size=self.model_settings.scheduler_step, gamma=self.model_settings.scheduler_gamma) @@ -517,36 +716,55 @@ def configure_optimizers(self): # return [opt], [{"scheduler": sch, "interval": "epoch", "monitor": "loss/kl_structure_factors"}] return [opt], [{"scheduler": sch, "interval": "epoch"}] - + def val_dataloader(self): return self.dataloader.load_data_set_batched_by_image( - data_set_to_load=self.dataloader.validation_data_set, + data_set_to_load=self.dataloader.validation_data_set, ) - - def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted_intensities.mtz"): - self.eval() + def inference_to_intensities( + self, dataloader, wandb_dir, output_path="predicted_intensities.mtz" + ): + self.eval() results = [] for batch_idx, batch in enumerate(dataloader): # Move batch to device # batch = self.transfer_batch_to_device(batch, self.device) - shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl_batch, processed_metadata_batch = batch - + ( + shoeboxes_batch, + metadata_batch, + dead_pixel_mask_batch, + counts_batch, + hkl_batch, + processed_metadata_batch, + ) = batch # Compute representations - profile_representation, metadata_representation, image_representation, shoebox_profile_representation = self._batch_to_representations(batch=batch) + ( + profile_representation, + metadata_representation, + image_representation, + shoebox_profile_representation, + ) = self._batch_to_representations(batch=batch) # Get scale and structure factor distributions - scale_distribution = self.compute_scale([metadata_representation, image_representation]) + scale_distribution = self.compute_scale( + [metadata_representation, image_representation] + ) S0 = scale_distribution.rsample([1]).squeeze() print("shape so", S0.shape) structure_factor_dist = self.surrogate_posterior params = self.surrogate_posterior.rac.gather( - source=torch.stack([self.surrogate_posterior.distribution.loc, - self.surrogate_posterior.distribution.scale], dim=1), - rasu_id=torch.tensor([0], device=self.device), - H=hkl_batch + source=torch.stack( + [ + self.surrogate_posterior.distribution.loc, + self.surrogate_posterior.distribution.scale, + ], + dim=1, + ), + rasu_id=torch.tensor([0], device=self.device), + H=hkl_batch, ) F_loc = params[:, 0] @@ -556,7 +774,7 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted print("F_scale", F_scale.shape) I_mean = S0 * (F_loc**2 + F_scale**2) - I_sigma = S0 * torch.sqrt(2*F_scale**4 + 4*F_scale**2 * F_loc**2) + I_sigma = S0 * torch.sqrt(2 * F_scale**4 + 4 * F_scale**2 * F_loc**2) # Collect HKL indices HKL = hkl_batch.cpu().detach().numpy() @@ -564,13 +782,15 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted I_sigma_np = I_sigma.cpu().detach().numpy().flatten() for i in range(HKL.shape[0]): - results.append({ - "H": int(HKL[i, 0]), - "K": int(HKL[i, 1]), - "L": int(HKL[i, 2]), - "I_mean": 
float(I_mean_np[i]), - "I_sigma": float(I_sigma_np[i]) - }) + results.append( + { + "H": int(HKL[i, 0]), + "K": int(HKL[i, 1]), + "L": int(HKL[i, 2]), + "I_mean": float(I_mean_np[i]), + "I_sigma": float(I_sigma_np[i]), + } + ) print(f"Processed batch {batch_idx+1}/{len(dataloader)}") @@ -596,7 +816,7 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted ds = ds.set_dtypes({"I": "J", "SIGI": "Q"}) else: ds._mtz_dtypes = {"I": "J", "SIGI": "Q"} - + ds.cell = gemmi.UnitCell(79.1, 79.1, 38.4, 90, 90, 90) ds.spacegroup = gemmi.SpaceGroup("P43212") @@ -606,13 +826,14 @@ def inference_to_intensities(self, dataloader, wandb_dir, output_path="predicted return mtz_path + def compute_I_mean_sigma(F_loc, F_scale, S=1.0): - """ - F_loc, F_scale: torch tensors of folded normal parameters (loc, scale) - S: deterministic scale factor - Returns: I_mean, I_sigma - """ - I_mean = S * (F_loc**2 + F_scale**2) - Var_F2 = 2*F_scale**4 + 4*F_scale**2 * F_loc**2 - I_sigma = S * torch.sqrt(Var_F2) - return I_mean, I_sigma \ No newline at end of file + """ + F_loc, F_scale: torch tensors of folded normal parameters (loc, scale) + S: deterministic scale factor + Returns: I_mean, I_sigma + """ + I_mean = S * (F_loc**2 + F_scale**2) + Var_F2 = 2 * F_scale**4 + 4 * F_scale**2 * F_loc**2 + I_sigma = S * torch.sqrt(Var_F2) + return I_mean, I_sigma diff --git a/src/factory/networks.py b/src/factory/networks.py index 754e1d7..a4dfa1d 100644 --- a/src/factory/networks.py +++ b/src/factory/networks.py @@ -1,10 +1,12 @@ -import torch +import dataclasses import math + +import torch import torch.nn.functional as F -import dataclasses +from tensordict.nn.distributions import Delta from torch.distributions import TransformedDistribution from torch.distributions.transforms import AbsTransform -from tensordict.nn.distributions import Delta + def weight_initializer(weight): fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2]) @@ -12,34 +14,50 @@ def weight_initializer(weight): a = -2.0 * std b = 2.0 * std torch.nn.init.trunc_normal_(weight, 0.0, std, a, b) - + return weight + class Linear(torch.nn.Linear): def __init__(self, in_features: int, out_features: int, bias=False): super().__init__(in_features, out_features, bias=bias) # Set bias=False def reset_parameters(self) -> None: self.weight = weight_initializer(self.weight) - + def forward(self, input): # Check for NaN values in input if torch.isnan(input).any(): print(f"WARNING: NaN values in Linear layer input! Shape: {input.shape}") print("NaN count:", torch.isnan(input).sum().item()) - + output = super().forward(input) - + # Check for NaN values in output if torch.isnan(output).any(): print(f"WARNING: NaN values in Linear layer output! 
Shape: {output.shape}") print("NaN count:", torch.isnan(output).sum().item()) - print("Weight stats - min:", self.weight.min().item(), "max:", self.weight.max().item(), "mean:", self.weight.mean().item()) + print( + "Weight stats - min:", + self.weight.min().item(), + "max:", + self.weight.max().item(), + "mean:", + self.weight.mean().item(), + ) if self.bias is not None: - print("Bias stats - min:", self.bias.min().item(), "max:", self.bias.max().item(), "mean:", self.bias.mean().item()) - + print( + "Bias stats - min:", + self.bias.min().item(), + "max:", + self.bias.max().item(), + "mean:", + self.bias.mean().item(), + ) + return output + class Constraint(torch.nn.Module): def __init__(self, eps=1e-12, beta=1.0): super().__init__() @@ -58,26 +76,30 @@ def __init__(self, hidden_dim=64, input_dim=7, number_of_layers=2): self.input_dim = input_dim self.add_bias = True self._build_mlp_(number_of_layers=number_of_layers) - + def _build_mlp_(self, number_of_layers): mlp_layers = [] for i in range(number_of_layers): - mlp_layers.append(torch.nn.Linear( - in_features=self.input_dim if i == 0 else self.hidden_dimension, - out_features=self.hidden_dimension, - bias=self.add_bias, - )) + mlp_layers.append( + torch.nn.Linear( + in_features=self.input_dim if i == 0 else self.hidden_dimension, + out_features=self.hidden_dimension, + bias=self.add_bias, + ) + ) mlp_layers.append(self.activation) self.network = torch.nn.Sequential(*mlp_layers) - + def forward(self, x): return self.network(x) - + + @dataclasses.dataclass class ScaleOutput: distribution: torch.distributions.Normal network: torch.Tensor + class BaseDistributionLayer(torch.nn.Module): def __init__(): @@ -85,7 +107,8 @@ def __init__(): def forward(): pass - + + class DeltaDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -93,11 +116,12 @@ def __init__(self): self.bijector = torch.nn.Softplus() def forward(self, hidden_representation): - loc = hidden_representation#torch.unbind(hidden_representation, dim=-1) + loc = hidden_representation # torch.unbind(hidden_representation, dim=-1) print("shape los delta dist", loc.shape) loc = self.bijector(loc) + 1e-3 return Delta(param=loc) + class NormalDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -109,6 +133,7 @@ def forward(self, hidden_representation): scale = self.bijector(scale) + 1e-3 return torch.distributions.Normal(loc=loc, scale=scale) + class SoftplusNormalDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -121,6 +146,7 @@ def forward(self, hidden_representation): # self.normal = torch.distributions.Normal(loc=loc, scale=scale) return SoftplusNormal(loc=loc, scale=scale) + class SoftplusNormal(torch.nn.Module): def __init__(self, loc, scale): super().__init__() @@ -130,7 +156,7 @@ def __init__(self, loc, scale): def rsample(self, sample_shape=[1]): return self.bijector(self.normal.rsample(sample_shape)).unsqueeze(-1) - def log_prob(self,x): + def log_prob(self, x): # x: sample (must be > 0), mu and sigma are parameters z = torch.log(torch.expm1(x)) # inverse softplus normal = self.normal @@ -141,6 +167,7 @@ def log_prob(self,x): def forward(self, x, number_of_samples=1): return self.bijector(self.normal(x)) + class TruncatedNormalDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -153,15 +180,20 @@ def forward(self, hidden_representation): # self.normal = torch.distributions.Normal(loc=loc, scale=scale) return PositiveTruncatedNormal(loc=loc, scale=scale) + class 
PositiveTruncatedNormal(torch.nn.Module): def __init__(self, loc, scale): super().__init__() self.normal = torch.distributions.Normal(loc, scale) self.a = (0.0 - loc) / scale # standardized lower bound = 0 - self.Z = torch.tensor(1.0 - self.normal.cdf(0.0), device=loc.device, dtype=loc.dtype) # normalization constant + self.Z = torch.tensor( + 1.0 - self.normal.cdf(0.0), device=loc.device, dtype=loc.dtype + ) # normalization constant def rsample(self, sample_shape=torch.Size()): - u = torch.rand(sample_shape + self.a.shape, device=self.a.device, dtype=self.a.dtype) + u = torch.rand( + sample_shape + self.a.shape, device=self.a.device, dtype=self.a.dtype + ) u = u * self.Z + self.normal.cdf(0.0) # map [0,1] to [cdf(0), 1] z = self.normal.icdf(u) return z @@ -170,6 +202,7 @@ def log_prob(self, x): logp = self.normal.log_prob(x) return logp - torch.log(self.Z) + class NormalIRSample(torch.autograd.Function): @staticmethod def forward(ctx, loc, scale, samples, dFdmu, dFdsig, q): @@ -186,6 +219,7 @@ def backward(ctx, grad_output): ) = ctx.saved_tensors return grad_output * dzdmu, grad_output * dzdsig, None, None, None, None + class FoldedNormal(torch.distributions.Distribution): """ Folded Normal distribution class @@ -197,7 +231,10 @@ class FoldedNormal(torch.distributions.Distribution): Default is None. """ - arg_constraints = {"loc": torch.distributions.constraints.real, "scale": torch.distributions.constraints.positive} + arg_constraints = { + "loc": torch.distributions.constraints.real, + "scale": torch.distributions.constraints.positive, + } support = torch.distributions.constraints.nonnegative def __init__(self, loc, scale, validate_args=None): @@ -323,6 +360,7 @@ def rsample(self, sample_shape=torch.Size()): samples.requires_grad_(True) return self._irsample(self.loc, self.scale, samples, dFdmu, dFdsigma, q) + class FoldedNormalDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -334,6 +372,7 @@ def forward(self, hidden_representation): scale = self.bijector(scale) + 1e-3 return FoldedNormal(loc=loc.unsqueeze(-1), scale=scale.unsqueeze(-1)) + class LogNormalDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -345,12 +384,20 @@ def forward(self, hidden_representation): scale = self.bijector(scale) + 1e-3 print("lognormal loc, scale", loc.shape, scale.shape) dist = torch.distributions.LogNormal(loc=loc, scale=scale) - print("torch.distributions.LogNormal(loc=loc, scale=scale)", dist.rsample().shape) - dist = torch.distributions.LogNormal(loc=loc.unsqueeze(-1), scale=scale.unsqueeze(-1)) - print("torch.distributions.LogNormal(loc=loc, scale=scale) unsqueeze", dist.rsample().shape) + print( + "torch.distributions.LogNormal(loc=loc, scale=scale)", dist.rsample().shape + ) + dist = torch.distributions.LogNormal( + loc=loc.unsqueeze(-1), scale=scale.unsqueeze(-1) + ) + print( + "torch.distributions.LogNormal(loc=loc, scale=scale) unsqueeze", + dist.rsample().shape, + ) return dist + class GammaDistributionLayer(torch.nn.Module): def __init__(self): super().__init__() @@ -361,10 +408,20 @@ def forward(self, hidden_representation): concentration, rate = torch.unbind(hidden_representation, dim=-1) rate = self.bijector(rate) + 1e-3 concentration = self.bijector(concentration) + 1e-3 - return torch.distributions.Gamma(concentration=concentration.unsqueeze(-1), rate=rate.unsqueeze(-1)) + return torch.distributions.Gamma( + concentration=concentration.unsqueeze(-1), rate=rate.unsqueeze(-1) + ) + class MLPScale(torch.nn.Module): - def __init__(self, 
input_dimension=64, scale_distribution=FoldedNormalDistributionLayer, hidden_dimension=64, number_of_layers=1, initial_scale_guess=2/140): + def __init__( + self, + input_dimension=64, + scale_distribution=FoldedNormalDistributionLayer, + hidden_dimension=64, + number_of_layers=1, + initial_scale_guess=2 / 140, + ): super().__init__() self.activation = torch.nn.ReLU() self.hidden_dimension = hidden_dimension @@ -377,11 +434,15 @@ def __init__(self, input_dimension=64, scale_distribution=FoldedNormalDistributi def _build_mlp_(self, number_of_layers): mlp_layers = [] for i in range(number_of_layers): - mlp_layers.append(torch.nn.Linear( - in_features=self.input_dimension if i == 0 else self.hidden_dimension, - out_features=self.hidden_dimension, - bias=self.add_bias, - )) + mlp_layers.append( + torch.nn.Linear( + in_features=( + self.input_dimension if i == 0 else self.hidden_dimension + ), + out_features=self.hidden_dimension, + bias=self.add_bias, + ) + ) mlp_layers.append(self.activation) self.network = torch.nn.Sequential(*mlp_layers) @@ -391,16 +452,17 @@ def _build_mlp_(self, number_of_layers): out_features=self.scale_distribution_layer.len_params, bias=self.add_bias, ) - if self.add_bias: with torch.no_grad(): for i in range(self.scale_distribution_layer.len_params): - if i==self.scale_distribution_layer.len_params-1: - final_linear.bias[i] = torch.log(torch.tensor(self.initial_scale_guess)) + if i == self.scale_distribution_layer.len_params - 1: + final_linear.bias[i] = torch.log( + torch.tensor(self.initial_scale_guess) + ) else: final_linear.bias[i] = torch.tensor(0.1) - + map_to_distribution_layers.append(final_linear) map_to_distribution_layers.append(self.scale_distribution_layer) @@ -408,8 +470,13 @@ def _build_mlp_(self, number_of_layers): def forward(self, x) -> ScaleOutput: h = self.network(x) - print("MLP output:", h.mean().item(), h.min().item(), h.max().item(), - "any nan?", torch.isnan(h).any().item()) - + print( + "MLP output:", + h.mean().item(), + h.min().item(), + h.max().item(), + "any nan?", + torch.isnan(h).any().item(), + ) + return ScaleOutput(distribution=self.distribution(h), network=h) - \ No newline at end of file diff --git a/src/factory/parse_yaml.py b/src/factory/parse_yaml.py index bd1dade..b9da134 100644 --- a/src/factory/parse_yaml.py +++ b/src/factory/parse_yaml.py @@ -1,14 +1,15 @@ -import yaml import os -import wandb -from settings import ModelSettings, LossSettings, DataLoaderSettings, PhenixSettings -from abismal_torch.prior import WilsonPrior + import distributions -import networks -from lightning.pytorch.loggers import WandbLogger import metadata_encoder +import networks import shoebox_encoder import torch +import wandb +import yaml +from abismal_torch.prior import WilsonPrior +from lightning.pytorch.loggers import WandbLogger +from settings import DataLoaderSettings, LossSettings, ModelSettings, PhenixSettings REGISTRY = { "HalfNormalDistribution": distributions.HalfNormalDistribution, @@ -34,14 +35,15 @@ "TorchHalfNormal": torch.distributions.HalfNormal, "TorchExponential": torch.distributions.Exponential, "rsFoldedNormal": networks.FoldedNormal, - } + def _log_settings_from_yaml(path: str, logger: WandbLogger): - with open(path, 'r') as file: + with open(path, "r") as file: config = yaml.safe_load(file) logger.experiment.config.update(config) + def instantiate_from_config(config: dict): def resolve_value(v): if isinstance(v, dict): @@ -60,23 +62,33 @@ def resolve_value(v): def consistancy_check(model_settings_dict: dict): - + dmodel = 
model_settings_dict.get("dmodel") metadata_indices = model_settings_dict.get("metadata_indices_to_keep", []) if "shoebox_encoder" in model_settings_dict: - model_settings_dict["shoebox_encoder"].setdefault("params", {})["out_dim"] = dmodel + model_settings_dict["shoebox_encoder"].setdefault("params", {})[ + "out_dim" + ] = dmodel if "metadata_encoder" in model_settings_dict: if model_settings_dict.get("use_positional_encoding", False) == True: - number_of_frequencies = model_settings_dict.get("number_of_frequencies_in_positional_encoding", 2) - model_settings_dict["metadata_encoder"].setdefault("params", {})["feature_dim"] = len(metadata_indices) * number_of_frequencies * 2 + number_of_frequencies = model_settings_dict.get( + "number_of_frequencies_in_positional_encoding", 2 + ) + model_settings_dict["metadata_encoder"].setdefault("params", {})[ + "feature_dim" + ] = (len(metadata_indices) * number_of_frequencies * 2) else: - model_settings_dict["metadata_encoder"].setdefault("params", {})["feature_dim"] = len(metadata_indices) + model_settings_dict["metadata_encoder"].setdefault("params", {})[ + "feature_dim" + ] = len(metadata_indices) model_settings_dict["metadata_encoder"]["params"]["output_dims"] = dmodel if "scale_function" in model_settings_dict: - model_settings_dict["scale_function"].setdefault("params", {})["input_dimension"] = dmodel + model_settings_dict["scale_function"].setdefault("params", {})[ + "input_dimension" + ] = dmodel def load_settings_from_yaml(config): @@ -84,7 +96,7 @@ def load_settings_from_yaml(config): # config = yaml.safe_load(file) model_settings_dict = config.get("model_settings", {}) - + for key, value in model_settings_dict.items(): if value == "None": model_settings_dict[key] = None @@ -100,7 +112,11 @@ def load_settings_from_yaml(config): phenix_settings = PhenixSettings(**config.get("phenix_settings", {})) dataloader_settings_dict = config.get("dataloader_settings", {}) - dataloader_settings = DataLoaderSettings(**dataloader_settings_dict) if dataloader_settings_dict else None + dataloader_settings = ( + DataLoaderSettings(**dataloader_settings_dict) + if dataloader_settings_dict + else None + ) return model_settings, loss_settings, phenix_settings, dataloader_settings @@ -116,4 +132,4 @@ def resolve_builders(settings_dict): if field in settings_dict: settings_dict[field] = REGISTRY[settings_dict[field]] - return settings_dict \ No newline at end of file + return settings_dict diff --git a/src/factory/phenix_callback.py b/src/factory/phenix_callback.py index 3e6a2dd..08bf118 100644 --- a/src/factory/phenix_callback.py +++ b/src/factory/phenix_callback.py @@ -1,22 +1,24 @@ +import concurrent.futures +import glob +import multiprocessing as mp import os import re import subprocess -import pandas as pd -import wandb -import settings -from model import * -import glob import time + import generate_eff -import concurrent.futures -import multiprocessing as mp import matplotlib.pyplot as plt +import pandas as pd +import settings +import wandb +from model import * + def extract_r_values_from_pdb(pdb_file_path: str) -> tuple[float, float]: - + r_work, r_free = None, None try: - with open(pdb_file_path, 'r') as f: + with open(pdb_file_path, "r") as f: for line in f: if "REMARK 3 R VALUE (WORKING SET)" in line: r_work = float(line.strip().split(":")[1]) @@ -28,7 +30,7 @@ def extract_r_values_from_pdb(pdb_file_path: str) -> tuple[float, float]: def extract_r_values_from_log(log_file_path: str) -> tuple: - + pattern1 = re.compile(r"Start R-work") pattern2 = 
re.compile(r"Final R-work") @@ -36,7 +38,7 @@ def extract_r_values_from_log(log_file_path: str) -> tuple: with open(log_file_path, "r") as f: lines = f.readlines() print("opened log file") - + match = re.search(r"epoch=(\d+)", log_file_path) if match: epoch = match.group(1) @@ -48,7 +50,7 @@ def extract_r_values_from_log(log_file_path: str) -> tuple: matched_lines_final = [line.strip() for line in lines if pattern2.search(line)] # Initialize all values as None - + rwork_start = rwork_final = rfree_start = rfree_final = None if matched_lines_start: @@ -59,7 +61,9 @@ def extract_r_values_from_log(log_file_path: str) -> tuple: rwork_start = float(numbers[0]) rfree_start = float(numbers[1]) else: - print(f"Start R-work line found but not enough numbers: {matched_lines_start[0]}") + print( + f"Start R-work line found but not enough numbers: {matched_lines_start[0]}" + ) else: print("No Start R-work line found in log.") @@ -69,21 +73,28 @@ def extract_r_values_from_log(log_file_path: str) -> tuple: rwork_final = float(numbers[0]) rfree_final = float(numbers[1]) else: - print(f"Final R-work line found but not enough numbers: {matched_lines_final[0]}") + print( + f"Final R-work line found but not enough numbers: {matched_lines_final[0]}" + ) else: print("No Final R-work line found in log.") print("finished r-val extraction") - return epoch_start, epoch_final, [rwork_start, rwork_final], [rfree_start, rfree_final] + return ( + epoch_start, + epoch_final, + [rwork_start, rwork_final], + [rfree_start, rfree_final], + ) except Exception as e: print(f"Could not open log file: {e}") return None, None, [None, None], [None, None] - - def run_phenix(repo_dir: str, checkpoint_name: str, mtz_output_path: str): - phenix_env = "/n/hekstra_lab/garden_backup/phenix-1.21/phenix-1.21.1-5286/phenix_env.sh" + phenix_env = ( + "/n/hekstra_lab/garden_backup/phenix-1.21/phenix-1.21.1-5286/phenix_env.sh" + ) phenix_output_dir = os.path.join(repo_dir, "phenix_output") os.makedirs(phenix_output_dir, exist_ok=True) path_to_pdb_model = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/9b7c.pdb" @@ -95,31 +106,41 @@ def run_phenix(repo_dir: str, checkpoint_name: str, mtz_output_path: str): r_free_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/rfree.mtz" pdb_model_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/9b7c.pdb" - specific_phenix_eff_file_path = os.path.join(repo_dir, f"phenix_{checkpoint_name}.eff") + specific_phenix_eff_file_path = os.path.join( + repo_dir, f"phenix_{checkpoint_name}.eff" + ) phenix_out_name = f"refine_{checkpoint_name}" - subprocess.run([ - "python", "/n/holylabs/hekstra_lab/Users/fgiehr/factory/src/factory/generate_eff.py", - "--sf-mtz", mtz_output_path, - "--rfree-mtz", r_free_path, - "--out", specific_phenix_eff_file_path, - "--phenix-out-mtz", phenix_out_name, - ], check=True) + subprocess.run( + [ + "python", + "/n/holylabs/hekstra_lab/Users/fgiehr/factory/src/factory/generate_eff.py", + "--sf-mtz", + mtz_output_path, + "--rfree-mtz", + r_free_path, + "--out", + specific_phenix_eff_file_path, + "--phenix-out-mtz", + phenix_out_name, + ], + check=True, + ) - cmd = f"source {phenix_env} && cd {repo_dir} && phenix.refine {specific_phenix_eff_file_path} overwrite=True" # {mtz_output_path} {r_free_path} {phenix_refine_eff_file_path} {pdb_model_path} overwrite=True" + cmd = f"source {phenix_env} && cd {repo_dir} && phenix.refine {specific_phenix_eff_file_path} overwrite=True" # 
{mtz_output_path} {r_free_path} {phenix_refine_eff_file_path} {pdb_model_path} overwrite=True" print("\nRunning phenix with command:", flush=True) print(cmd, flush=True) try: result = subprocess.run( - cmd, - shell=True, - executable="/bin/bash", - capture_output=True, - text=True, - check=True, - ) + cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + text=True, + check=True, + ) print("Phenix command completed successfully") except subprocess.CalledProcessError as e: @@ -141,10 +162,13 @@ def run_phenix(repo_dir: str, checkpoint_name: str, mtz_output_path: str): return phenix_out_name + def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str): - surrogate_posterior_dataset = list(model.surrogate_posterior.to_dataset(only_observed=True))[0] + surrogate_posterior_dataset = list( + model.surrogate_posterior.to_dataset(only_observed=True) + )[0] - checkpoint_name, _ = os.path.splitext(os.path.basename(checkpoint_path)) + checkpoint_name, _ = os.path.splitext(os.path.basename(checkpoint_path)) mtz_output_path = os.path.join(repo_dir, f"phenix_ready_{checkpoint_name}.mtz") @@ -156,18 +180,21 @@ def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str): else: print("Error: Failed to write MTZ file") - phenix_out_name = run_phenix(repo_dir=repo_dir, checkpoint_name=checkpoint_name, mtz_output_path=mtz_output_path) + phenix_out_name = run_phenix( + repo_dir=repo_dir, + checkpoint_name=checkpoint_name, + mtz_output_path=mtz_output_path, + ) # phenix_out_name = f"refine_{checkpoint_name}" # solved_artifact = wandb.Artifact("phenix_solved", type="mtz") pdb_file_path = glob.glob(os.path.join(repo_dir, f"{phenix_out_name}*.pdb")) print("pdb_file_path", pdb_file_path) - pdb_file_path=pdb_file_path[-1] + pdb_file_path = pdb_file_path[-1] mtz_file_path = glob.glob(os.path.join(repo_dir, f"{phenix_out_name}*.mtz")) print("mtz_file_path", mtz_file_path) - mtz_file_path=mtz_file_path[0] - + mtz_file_path = mtz_file_path[0] if pdb_file_path: print(f"Using PDB file: {pdb_file_path}") @@ -175,12 +202,14 @@ def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str): print("log_file_path", log_file_path) log_file_path = log_file_path[-1] print("log_file_path2", log_file_path) - epoch_start, epoch_final, r_work, r_free = extract_r_values_from_log(log_file_path) + epoch_start, epoch_final, r_work, r_free = extract_r_values_from_log( + log_file_path + ) print("rvals", r_work, r_free) else: print("No PDB files found matching pattern!") r_work, r_free = None, None - + # pdb_file_path = glob.glob(f"{phenix_out_name}*.pdb")[-1] # mtz_file_path = glob.glob(f"{phenix_out_name}*.mtz")[-1] print("run find peaks") @@ -190,21 +219,30 @@ def find_peaks_from_model(model, repo_dir: str, checkpoint_path: str): peaks_file_path = os.path.join(repo_dir, f"peaks_{phenix_out_name}.csv") import reciprocalspaceship as rs + ds = rs.read_mtz(mtz_file_path) print(ds.columns.tolist()) subprocess.run( - f"rs.find_peaks {mtz_file_path} {pdb_file_path} -f ANOM -p PANOM -z 5.0 -o {peaks_file_path}", - shell=True, - # cwd=f"{repo_dir}", + f"rs.find_peaks {mtz_file_path} {pdb_file_path} -f ANOM -p PANOM -z 5.0 -o {peaks_file_path}", + shell=True, + # cwd=f"{repo_dir}", ) print("finished find peaks") - return epoch_start, r_work, r_free, os.path.join(repo_dir, f"peaks_{phenix_out_name}.csv") + return ( + epoch_start, + r_work, + r_free, + os.path.join(repo_dir, f"peaks_{phenix_out_name}.csv"), + ) -def process_checkpoint(checkpoint_path, artifact_dir, model_settings, loss_settings, 
wandb_directory): +def process_checkpoint( + checkpoint_path, artifact_dir, model_settings, loss_settings, wandb_directory +): import os import re + import pandas as pd from model import Model @@ -213,34 +251,49 @@ def process_checkpoint(checkpoint_path, artifact_dir, model_settings, loss_setti print("checkpoint_path_", checkpoint_path_) print("artifact_dir", artifact_dir) - try: + try: match = re.search(r"best-(\d+)-", checkpoint_path) epoch = int(match.group(1)) if match else None - model = Model.load_from_checkpoint(checkpoint_path, model_settings=model_settings, loss_settings=loss_settings) + model = Model.load_from_checkpoint( + checkpoint_path, model_settings=model_settings, loss_settings=loss_settings + ) model.eval() - epoch_start, r_work, r_free, path_to_peaks = find_peaks_from_model(model=model, repo_dir=wandb_directory, checkpoint_path=checkpoint_path) + epoch_start, r_work, r_free, path_to_peaks = find_peaks_from_model( + model=model, repo_dir=wandb_directory, checkpoint_path=checkpoint_path + ) peaks_df = pd.read_csv(path_to_peaks) rows = [] for _, peak in peaks_df.iterrows(): - residue = peak['residue'] - peak_height = peak['peakz'] - rows.append({ - 'Epoch': epoch, - 'Checkpoint': checkpoint_path, - 'Residue': residue, - 'Peak_Height': peak_height - }) + residue = peak["residue"] + peak_height = peak["peakz"] + rows.append( + { + "Epoch": epoch, + "Checkpoint": checkpoint_path, + "Residue": residue, + "Peak_Height": peak_height, + } + ) return epoch_start, rows, r_work, r_free except Exception as e: print(f"Failed for {checkpoint_path}: {e}") return None, [], None, None -def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settings, artifact_dir, checkpoint_paths, wandb_directory): + +def run_phenix_over_all_checkpoints( + model_settings, + loss_settings, + phenix_settings, + artifact_dir, + checkpoint_paths, + wandb_directory, +): + import multiprocessing as mp + import pandas as pd import wandb - import multiprocessing as mp all_rows = [] r_works_start = [] @@ -249,7 +302,6 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin r_frees_final = [] max_peak_heights = [] epochs = [] - ctx = mp.get_context("spawn") # Prepare argument tuples for each checkpoint @@ -264,10 +316,14 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin for epoch_start, rows, r_work, r_free in results: all_rows.extend(rows) if ( - isinstance(r_work, list) and len(r_work) == 2 and - isinstance(r_free, list) and len(r_free) == 2 and - r_work[0] is not None and r_work[1] is not None and - r_free[0] is not None and r_free[1] is not None + isinstance(r_work, list) + and len(r_work) == 2 + and isinstance(r_free, list) + and len(r_free) == 2 + and r_work[0] is not None + and r_work[1] is not None + and r_free[0] is not None + and r_free[1] is not None ): r_works_start.append(r_work[0]) r_works_final.append(r_work[1]) @@ -286,10 +342,10 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin try: fig, ax = plt.subplots(figsize=(10, 6)) - ax.plot(epochs, max_peak_heights, marker='o', label='Max Peak Height') - ax.set_xlabel('Epoch') - ax.set_ylabel('Peak Height') - ax.set_title('Max Peak Height vs Epoch') + ax.plot(epochs, max_peak_heights, marker="o", label="Max Peak Height") + ax.set_xlabel("Epoch") + ax.set_ylabel("Peak Height") + ax.set_title("Max Peak Height vs Epoch") ax.legend() plt.tight_layout() wandb.log({"max_peak_heights_plot": wandb.Image(fig)}) @@ -297,9 +353,6 @@ def 
run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin except Exception as e: print(f"failed to plot max peak heights: {e}") - - - final_df = pd.DataFrame(all_rows) if not final_df.empty and "Peak_Height" in final_df.columns: @@ -311,26 +364,27 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin if final_df.empty: print("No peak heights to log.") - # def plot_peak_heights_vs_epoch(df): try: df = final_df import matplotlib.cm as cm - + if df.empty: print("No data to plot for peak heights vs epoch.") return plt.figure(figsize=(10, 6)) - residues = df['Residue'].unique() - colormap = cm.get_cmap('tab20', len(residues)) + residues = df["Residue"].unique() + colormap = cm.get_cmap("tab20", len(residues)) color_map = {res: colormap(i) for i, res in enumerate(residues)} for res in residues: - sub = df[df['Residue'] == res] - plt.scatter(sub['Epoch'], sub['Peak_Height'], label=str(res), color=color_map[res]) - plt.xlabel('Epoch') - plt.ylabel('Peak Height') - plt.title('Peak Height vs Epoch per Residue') - plt.legend(title='Residue', bbox_to_anchor=(1.05, 1), loc='upper left') + sub = df[df["Residue"] == res] + plt.scatter( + sub["Epoch"], sub["Peak_Height"], label=str(res), color=color_map[res] + ) + plt.xlabel("Epoch") + plt.ylabel("Peak Height") + plt.title("Peak Height vs Epoch per Residue") + plt.legend(title="Residue", bbox_to_anchor=(1.05, 1), loc="upper left") plt.tight_layout() wandb.log({"PeakHeight_vs_Epoch": wandb.Image(plt.gcf())}) plt.close() @@ -339,45 +393,74 @@ def run_phenix_over_all_checkpoints(model_settings, loss_settings, phenix_settin except Exception as e: print(f"failed to plot peak heights: {e}") - try: - _, _, r_work_reference, r_free_reference = extract_r_values_from_log(phenix_settings.r_values_reference_path) + _, _, r_work_reference, r_free_reference = extract_r_values_from_log( + phenix_settings.r_values_reference_path + ) r_work_reference = r_work_reference[-1] r_free_reference = r_free_reference[-1] except Exception as e: print(f"failed to extract reference r-vals: {e}") - try: import matplotlib.pyplot as plt import matplotlib.ticker as mtick - min_y = min(min(r_works_final), min(r_frees_final),r_work_reference, r_free_reference) - 0.01 - max_y = max(max(r_works_final), max(r_frees_final), r_work_reference, r_free_reference) + 0.01 + min_y = ( + min( + min(r_works_final), + min(r_frees_final), + r_work_reference, + r_free_reference, + ) + - 0.01 + ) + max_y = ( + max( + max(r_works_final), + max(r_frees_final), + r_work_reference, + r_free_reference, + ) + + 0.01 + ) - plt.style.use('seaborn-v0_8-darkgrid') + plt.style.use("seaborn-v0_8-darkgrid") fig, ax = plt.subplots(figsize=(8, 5)) x = epochs # Arbitrary units (e.g., checkpoint index) # ax.plot(x, r_works_start, marker='o', label='R-work (start)', c="red", alpha=0.7) - ax.plot(x, r_works_final, marker='o', label='R-work (final)', c="red") + ax.plot(x, r_works_final, marker="o", label="R-work (final)", c="red") # ax.plot(x, r_frees_start, marker='s', label='R-free (start)', c="blue", alpha=0.7) - ax.plot(x, r_frees_final, marker='s', label='R-free (final)', c="blue") + ax.plot(x, r_frees_final, marker="s", label="R-free (final)", c="blue") if r_work_reference is not None: - ax.axhline(r_work_reference, color='red', linestyle='dotted', linewidth=0.7, alpha=0.5, label='r_work_reference') + ax.axhline( + r_work_reference, + color="red", + linestyle="dotted", + linewidth=0.7, + alpha=0.5, + label="r_work_reference", + ) if r_free_reference is not None: - 
ax.axhline(r_free_reference, color='blue', linestyle='dotted', linewidth=0.7, alpha=0.5, label='r_free_reference') + ax.axhline( + r_free_reference, + color="blue", + linestyle="dotted", + linewidth=0.7, + alpha=0.5, + label="r_free_reference", + ) - ax.set_xlabel('Training Epoch') - ax.set_ylabel('R Value') + ax.set_xlabel("Training Epoch") + ax.set_ylabel("R Value") ax.set_ylim(min_y, max_y) ax.set_yticks(np.linspace(min_y, max_y, 12)) - ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.4f')) - ax.set_title('R-work and R-free') - ax.legend(loc='best', fontsize=10, frameon=True) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.4f")) + ax.set_title("R-work and R-free") + ax.legend(loc="best", fontsize=10, frameon=True) plt.tight_layout() wandb.log({"R_values_plot": wandb.Image(fig)}) plt.close(fig) except Exception as e: print(f"failed plotting r-values: {e}") - diff --git a/src/factory/plot_peaks.py b/src/factory/plot_peaks.py index dcb7799..bc3a33b 100644 --- a/src/factory/plot_peaks.py +++ b/src/factory/plot_peaks.py @@ -1,26 +1,26 @@ -import pandas as pd +import matplotlib.pyplot as plt import numpy as np +import pandas as pd +import wandb from Bio.PDB import PDBParser from scipy.spatial import cKDTree -import matplotlib.pyplot as plt -import wandb wandb.init(project="plot anomolous peaks") -peaks_df = pd.read_csv("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv") -print(peaks_df.columns) +peaks_df = pd.read_csv( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/peaks.csv" +) +print(peaks_df.columns) # print(peaks_df.iloc[0]) # print(peaks_df) # Create a table of peak heights -peak_heights_table = pd.DataFrame({ - 'Peak_Height': peaks_df['peakz'] -}) +peak_heights_table = pd.DataFrame({"Peak_Height": peaks_df["peakz"]}) # Log the peak heights table to wandb wandb.log({"peak_heights": wandb.Table(dataframe=peak_heights_table)}) -peak_coords = peaks_df[['cenx', 'ceny', 'cenz']].values +peak_coords = peaks_df[["cenx", "ceny", "cenz"]].values model_file = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/shaken_9b7c_refine_001.pdb" parser = PDBParser(QUIET=True) model = parser.get_structure("model", model_file) @@ -41,19 +41,21 @@ for i, peak in enumerate(peak_coords): dist, idx = tree.query(peak) atom = atoms[idx] - matches.append({ - "Peak_X": peak[0], - "Peak_Y": peak[1], - "Peak_Z": peak[2], - "Peak_Height": peaks_df.loc[i, 'peakz'], # adjust if column name differs - "Atom_Name": atom.get_name(), - "Residue": atom.get_parent().get_resname(), - "Residue_ID": atom.get_parent().get_id()[1], - "Chain": atom.get_parent().get_parent().id, - "B_factor": atom.get_bfactor(), - "Occupancy": atom.get_occupancy(), - "Distance": dist - }) + matches.append( + { + "Peak_X": peak[0], + "Peak_Y": peak[1], + "Peak_Z": peak[2], + "Peak_Height": peaks_df.loc[i, "peakz"], # adjust if column name differs + "Atom_Name": atom.get_name(), + "Residue": atom.get_parent().get_resname(), + "Residue_ID": atom.get_parent().get_id()[1], + "Chain": atom.get_parent().get_parent().id, + "B_factor": atom.get_bfactor(), + "Occupancy": atom.get_occupancy(), + "Distance": dist, + } + ) # Create output DataFrame results_df = pd.DataFrame(matches) @@ -94,4 +96,4 @@ # # Log the figure to wandb # wandb.log({"peaks_vs_model": wandb.Image(fig)}) -# plt.show() \ No newline at end of file +# plt.show() diff --git a/src/factory/positional_encoding.py b/src/factory/positional_encoding.py index 63f915b..9501be4 100644 --- 
a/src/factory/positional_encoding.py +++ b/src/factory/positional_encoding.py @@ -2,6 +2,7 @@ import torch + def positional_encoding(X, L): """ X: metadata (batch_size, feature_size) @@ -15,18 +16,21 @@ def positional_encoding(X, L): # Get min and max values along the last dimension min_vals = X.min(dim=-1, keepdim=True)[0] max_vals = X.max(dim=-1, keepdim=True)[0] - + # Normalize between -1 and 1 - p = 2. * (X - min_vals) / (max_vals - min_vals + 1e-8) - 1. - + p = 2.0 * (X - min_vals) / (max_vals - min_vals + 1e-8) - 1.0 + # Create frequency bands L_range = torch.arange(L, dtype=X.dtype, device=X.device) f = torch.pi * 2**L_range - + # Compute positional encoding fp = (f[..., None, :] * p[..., :, None]).reshape(p.shape[:-1] + (-1,)) - - return torch.cat(( - torch.cos(fp), - torch.sin(fp), - ), dim=-1) \ No newline at end of file + + return torch.cat( + ( + torch.cos(fp), + torch.sin(fp), + ), + dim=-1, + ) diff --git a/src/factory/rasu.py b/src/factory/rasu.py index 4738b1b..2a1ca2f 100644 --- a/src/factory/rasu.py +++ b/src/factory/rasu.py @@ -3,6 +3,7 @@ import gemmi import numpy as np import torch +from abismal_torch.symmetry.op import Op from reciprocalspaceship.decorators import cellify, spacegroupify from reciprocalspaceship.utils import ( apply_to_hkl, @@ -10,8 +11,6 @@ generate_reciprocal_cell, ) -from abismal_torch.symmetry.op import Op - class ReciprocalASU(torch.nn.Module): @cellify @@ -22,7 +21,7 @@ def __init__( spacegroup: gemmi.SpaceGroup, dmin: float, anomalous: Optional[bool] = True, - **kwargs + **kwargs, ) -> None: """ Base Layer that maps observed reflections to the reciprocal asymmetric unit (rasu). @@ -152,8 +151,18 @@ def __init__(self, *rasus: ReciprocalASU, **kwargs) -> None: for rasu_id, rasu in enumerate(self.reciprocal_asus): Hcell = generate_reciprocal_cell(rasu.cell, dmin=rasu.dmin) h, k, l = Hcell.T - self.reflection_id_grid[torch.tensor(rasu_id, dtype=torch.int),torch.tensor(h, dtype=torch.int), torch.tensor(k, dtype=torch.int), torch.tensor(l, dtype=torch.int)] = ( - rasu.reflection_id_grid[torch.tensor(h, dtype=torch.int), torch.tensor(k, dtype=torch.int), torch.tensor(l, dtype=torch.int)] + offset + self.reflection_id_grid[ + torch.tensor(rasu_id, dtype=torch.int), + torch.tensor(h, dtype=torch.int), + torch.tensor(k, dtype=torch.int), + torch.tensor(l, dtype=torch.int), + ] = ( + rasu.reflection_id_grid[ + torch.tensor(h, dtype=torch.int), + torch.tensor(k, dtype=torch.int), + torch.tensor(l, dtype=torch.int), + ] + + offset ) # self.reflection_id_grid[rasu_id, h, k, l] = ( # rasu.reflection_id_grid[h,k,l] + offset @@ -195,7 +204,7 @@ def __init__( *rasus: ReciprocalASU, parents: Optional[torch.Tensor] = None, reindexing_ops: Optional[Sequence[str]] = None, - **kwargs + **kwargs, ) -> None: """ A graph of rasu objects. 
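Note on the positional_encoding helper above: the encoded metadata width grows as feature_size * L * 2 (one cosine and one sine for each of the L frequency bands per input column), which is the same bookkeeping that consistancy_check in parse_yaml.py applies when use_positional_encoding is set. A minimal standalone sketch of that shape relationship follows; the helper name encode and the sizes are illustrative only, and it mirrors positional_encoding rather than importing the package.

import torch

def encode(X: torch.Tensor, L: int) -> torch.Tensor:
    # Normalize each sample's metadata vector to [-1, 1], then expand it with
    # L sine/cosine frequency bands, matching the math in positional_encoding.
    min_vals = X.min(dim=-1, keepdim=True)[0]
    max_vals = X.max(dim=-1, keepdim=True)[0]
    p = 2.0 * (X - min_vals) / (max_vals - min_vals + 1e-8) - 1.0
    f = torch.pi * 2 ** torch.arange(L, dtype=X.dtype)
    fp = (f[None, None, :] * p[..., None]).reshape(p.shape[:-1] + (-1,))
    return torch.cat((torch.cos(fp), torch.sin(fp)), dim=-1)

batch, n_columns, n_frequencies = 4, 2, 2  # e.g. two metadata columns kept, L = 2
out = encode(torch.randn(batch, n_columns), n_frequencies)
assert out.shape == (batch, n_columns * n_frequencies * 2)  # feature_dim = 8

The factor len(metadata_indices) * number_of_frequencies * 2 that consistancy_check writes into the metadata encoder's feature_dim is exactly this output width, so the two code paths stay consistent.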
diff --git a/src/factory/run_data_loader.py b/src/factory/run_data_loader.py index 3253956..06e8317 100644 --- a/src/factory/run_data_loader.py +++ b/src/factory/run_data_loader.py @@ -1,8 +1,9 @@ -from data_loader import * import argparse import json import CNN_3d +from data_loader import * + def parse_args(): parser = argparse.ArgumentParser(description="Pass DataLoader settings") @@ -19,11 +20,12 @@ def parse_args(): ) return parser.parse_args() + def main(): print("main function") # args = parse_args() # if args.data_file_names is not None: - # settings = DataLoaderSettings(data_directory=args.data_directory, + # settings = DataLoaderSettings(data_directory=args.data_directory, # data_file_names=args.data_file_names, # test_set_split=0.01 # ) @@ -39,18 +41,19 @@ def main(): data_directory = "/n/hekstra_lab/people/aldama/subset" data_file_names = { - "shoeboxes": "standardized_shoeboxes_subset.pt", - "counts": "raw_counts_subset.pt", - "metadata": "shoebox_features_subset.pt", - "masks": "masks_subset.pt", - "true_reference": "metadata_subset.pt", + "shoeboxes": "standardized_shoeboxes_subset.pt", + "counts": "raw_counts_subset.pt", + "metadata": "shoebox_features_subset.pt", + "masks": "masks_subset.pt", + "true_reference": "metadata_subset.pt", } print("load settings") - settings = DataLoaderSettings(data_directory=data_directory, - data_file_names=data_file_names, - test_set_split=0.5 - ) + settings = DataLoaderSettings( + data_directory=data_directory, + data_file_names=data_file_names, + test_set_split=0.5, + ) print("load settings done") test_data = test(settings=settings) batch = next(iter(test_data)) @@ -66,4 +69,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/src/factory/run_factory.py b/src/factory/run_factory.py index 6d227cf..b8aae5b 100644 --- a/src/factory/run_factory.py +++ b/src/factory/run_factory.py @@ -1,40 +1,46 @@ # from model_playground import * -from model import * -from settings import * -from parse_yaml import load_settings_from_yaml, _log_settings_from_yaml -from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.profilers import AdvancedProfiler +import argparse import glob import os +from datetime import datetime + import phenix_callback +import torch import wandb -from pytorch_lightning.tuner.tuning import Tuner -import argparse -from datetime import datetime -from callbacks import * import yaml -import torch +from callbacks import * +from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger +from model import * from omegaconf import OmegaConf -from lightning.pytorch.callbacks import LearningRateMonitor +from parse_yaml import _log_settings_from_yaml, load_settings_from_yaml +from pytorch_lightning.profilers import AdvancedProfiler +from pytorch_lightning.tuner.tuning import Tuner +from settings import * + -def configure_settings(config_path=None): - model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(config_path) +def configure_settings(config_path=None): + model_settings, loss_settings, dataloader_settings = load_settings_from_yaml( + config_path + ) return model_settings, loss_settings, dataloader_settings - -def run(config_path, save_directory, run_from_version=False): + +def run(config_path, save_directory, run_from_version=False): now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if run_from_version: run_name = 
f"{now_str}_2025-11-05_13-59-42_from-start" - wandb_logger = WandbLogger(project="full-model", name=run_name, save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs") + wandb_logger = WandbLogger( + project="full-model", + name=run_name, + save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs", + ) wandb.init( project="full-model", name=run_name, - dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs" + dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs", ) config_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml" @@ -43,9 +49,10 @@ def run(config_path, save_directory, run_from_version=False): # config_path = os.path.join(config_artifact_dir, "config_example.yaml") with open(config_path, "r") as f: config_dict = yaml.safe_load(f) - model_settings, loss_settings, phenix_settings, dataloader_settings = load_settings_from_yaml(config_dict) + model_settings, loss_settings, phenix_settings, dataloader_settings = ( + load_settings_from_yaml(config_dict) + ) - slurm_job_id = os.environ.get("SLURM_JOB_ID") if slurm_job_id is not None: wandb.config.update({"slurm_job_id": slurm_job_id}) @@ -56,8 +63,9 @@ def run(config_path, save_directory, run_from_version=False): # checkpoint_path = os.path.join(artifact_dir, ckpt_files[-1]) checkpoint_path = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20251106_125602-hgjymq80/files/checkpoints/last.ckpt" - model = Model.load_from_checkpoint(checkpoint_path, model_settings=model_settings, loss_settings=loss_settings) - + model = Model.load_from_checkpoint( + checkpoint_path, model_settings=model_settings, loss_settings=loss_settings + ) else: run_name = f"{now_str}_from-start" @@ -66,7 +74,9 @@ def run(config_path, save_directory, run_from_version=False): base_cfg = OmegaConf.load(config_path) - wandb_logger = WandbLogger(project="full-model", name=run_name, save_dir=save_directory) + wandb_logger = WandbLogger( + project="full-model", name=run_name, save_dir=save_directory + ) wandb_run = wandb.init( project="full-model", name=run_name, @@ -79,17 +89,21 @@ def run(config_path, save_directory, run_from_version=False): wandb_config = wandb_run.config - model_settings, loss_settings, phenix_settings, dataloader_settings = load_settings_from_yaml(wandb_config) #wandb_config) + model_settings, loss_settings, phenix_settings, dataloader_settings = ( + load_settings_from_yaml(wandb_config) + ) # wandb_config) slurm_job_id = os.environ.get("SLURM_JOB_ID") if slurm_job_id is not None: wandb.config.update({"slurm_job_id": slurm_job_id}) model = Model(model_settings=model_settings, loss_settings=loss_settings) - + config_artifact = wandb.Artifact("config", type="config") config_artifact.add_file(config_path) # Path to your config file logged_artifact = wandb_logger.experiment.log_artifact(config_artifact) - dataloader = data_loader.CrystallographicDataLoader(data_loader_settings=dataloader_settings) + dataloader = data_loader.CrystallographicDataLoader( + data_loader_settings=dataloader_settings + ) print("Loading data ...") dataloader.load_data_() print("Data loaded successfully.") @@ -104,7 +118,6 @@ def run(config_path, save_directory, run_from_version=False): if hasattr(model, "set_dataloader"): model.set_dataloader(train_dataloader) - # if model_settings.enable_checkpointing: checkpoint_callback = ModelCheckpoint( dirpath=wandb_logger.experiment.dir + "/checkpoints", @@ -112,20 +125,20 @@ def run(config_path, save_directory, 
run_from_version=False): # save_top_k=6, every_n_epochs=15, # mode="min", - filename="{epoch:02d}",#-{step}-{validation_loss/loss:.2f}", + filename="{epoch:02d}", # -{step}-{validation_loss/loss:.2f}", save_weights_only=True, # every_n_train_steps=None, # Disable step-based checkpointing - save_last=True # Don't save the last checkpoint + save_last=True, # Don't save the last checkpoint ) trainer = L.pytorch.Trainer( - logger=wandb_logger, + logger=wandb_logger, max_epochs=160, - log_every_n_steps=200, - val_check_interval=200, + log_every_n_steps=200, + val_check_interval=200, # limit_val_batches=21, - accelerator="auto", - enable_checkpointing=True, #model_settings.enable_checkpointing, + accelerator="auto", + enable_checkpointing=True, # model_settings.enable_checkpointing, # default_root_dir="/tmp", # profiler="simple", # profiler=AdvancedProfiler(dirpath=wandb_logger.experiment.dir, filename="profiler.txt"), @@ -133,23 +146,28 @@ def run(config_path, save_directory, run_from_version=False): callbacks=[ Plotting(dataloader=dataloader), LossLogging(), - CorrelationPlottingBinned(dataloader=dataloader), - CorrelationPlotting(dataloader=dataloader), + CorrelationPlottingBinned(dataloader=dataloader), + CorrelationPlotting(dataloader=dataloader), ScalePlotting(dataloader=dataloader), checkpoint_callback, # PeakHeights(), - LearningRateMonitor(logging_interval='epoch'), + LearningRateMonitor(logging_interval="epoch"), # ValidationSetCorrelationPlot(dataloader=dataloader), - ] + ], ) print("lr_scheduler_configs:", trainer.lr_scheduler_configs) - - trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) - + trainer.fit( + model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader + ) + if model_settings.enable_checkpointing: - checkpoint_dir = checkpoint_callback.dirpath if hasattr(checkpoint_callback, "dirpath") else "/tmp" + checkpoint_dir = ( + checkpoint_callback.dirpath + if hasattr(checkpoint_callback, "dirpath") + else "/tmp" + ) print("checkpoint_dir", checkpoint_dir) checkpoint_pattern = os.path.join(checkpoint_dir, "**", "*.ckpt") @@ -157,14 +175,20 @@ def run(config_path, save_directory, run_from_version=False): print("checkpoint_paths", checkpoint_paths) - print("model_settings.enable_checkpointing",model_settings.enable_checkpointing) - print("checkpoint_callback.best_model_path",checkpoint_callback.best_model_path) + print( + "model_settings.enable_checkpointing", model_settings.enable_checkpointing + ) + print( + "checkpoint_callback.best_model_path", checkpoint_callback.best_model_path + ) if run_from_version == False: - config_artifact_ref = logged_artifact.name # e.g., 'your-entity/your-project/config:v0' + config_artifact_ref = ( + logged_artifact.name + ) # e.g., 'your-entity/your-project/config:v0' version_alias = config_artifact_ref.split(":")[-1] - artifact = wandb.Artifact('models', type='model') + artifact = wandb.Artifact("models", type="model") for ckpt_file in checkpoint_paths: artifact.add_file(ckpt_file) print(f"Logged {len(checkpoint_paths)} checkpoints to W&B artifact.") @@ -173,15 +197,14 @@ def run(config_path, save_directory, run_from_version=False): print("run phenix") phenix_callback.run_phenix_over_all_checkpoints( - model_settings=model_settings, - loss_settings=loss_settings, + model_settings=model_settings, + loss_settings=loss_settings, phenix_settings=phenix_settings, - artifact_dir=checkpoint_dir, + artifact_dir=checkpoint_dir, checkpoint_paths=checkpoint_paths, - 
wandb_directory=wandb_logger.experiment.dir + wandb_directory=wandb_logger.experiment.dir, ) - # for ckpt_file in checkpoint_paths: # try: # os.remove(ckpt_file) @@ -191,14 +214,26 @@ def run(config_path, save_directory, run_from_version=False): wandb.finish() + def main(): - parser = argparse.ArgumentParser(description='Run factory with config file') - parser.add_argument('--config', type=str, default='../../configs/config_example.yaml', help='Path to config YAML file') + parser = argparse.ArgumentParser(description="Run factory with config file") + parser.add_argument( + "--config", + type=str, + default="../../configs/config_example.yaml", + help="Path to config YAML file", + ) args, _ = parser.parse_known_args() - parser.add_argument('--save_directory', type=str, default='./factory_results', help='Directory for W&B logs and checkpoint artifacts') + parser.add_argument( + "--save_directory", + type=str, + default="./factory_results", + help="Directory for W&B logs and checkpoint artifacts", + ) args, _ = parser.parse_known_args() - + run(config_path=args.config, save_directory=args.save_directory) + if __name__ == "__main__": main() diff --git a/src/factory/run_phenix_from_model.py b/src/factory/run_phenix_from_model.py index 85185ff..97f917d 100644 --- a/src/factory/run_phenix_from_model.py +++ b/src/factory/run_phenix_from_model.py @@ -1,15 +1,16 @@ -from model import * -from settings import * -from parse_yaml import load_settings_from_yaml -from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch.callbacks import Callback, ModelCheckpoint - import glob import os +from datetime import datetime + import phenix_callback import wandb -from datetime import datetime +from lightning.pytorch.callbacks import Callback, ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger +from model import * +from parse_yaml import load_settings_from_yaml from pytorch_lightning.tuner.tuning import Tuner +from settings import * + def run(): @@ -18,48 +19,64 @@ def run(): now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - wandb_logger = WandbLogger(project="full-model", name=f"{now_str}_phenix_from_{model_alias}", save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs") + wandb_logger = WandbLogger( + project="full-model", + name=f"{now_str}_phenix_from_{model_alias}", + save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs", + ) wandb.init( project="full-model", name=f"{now_str}_phenix_from_{model_alias}", - dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs" + dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs", ) slurm_job_id = os.environ.get("SLURM_JOB_ID") if slurm_job_id is not None: wandb.config.update({"slurm_job_id": slurm_job_id}) - print("wandb_logger.experiment.dir",wandb_logger.experiment.dir) + print("wandb_logger.experiment.dir", wandb_logger.experiment.dir) - artifact = wandb.use_artifact(f"flaviagiehr-harvard-university/full-model/models:{model_alias}", type="model") + artifact = wandb.use_artifact( + f"flaviagiehr-harvard-university/full-model/models:{model_alias}", type="model" + ) artifact_dir = artifact.download() print("Files in artifact_dir:", os.listdir(artifact_dir)) - config_artifact = wandb.use_artifact(f"flaviagiehr-harvard-university/full-model/config:{config_alias}", type="config") + config_artifact = wandb.use_artifact( + f"flaviagiehr-harvard-university/full-model/config:{config_alias}", + type="config", + ) config_artifact_dir = config_artifact.download() config_path = 
os.path.join(config_artifact_dir, "config_example.yaml") # config_path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml" - model_settings, loss_settings, phenix_settings, dataloader_settings = load_settings_from_yaml(path=config_path) + model_settings, loss_settings, phenix_settings, dataloader_settings = ( + load_settings_from_yaml(path=config_path) + ) # model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml") - - checkpoint_paths = [os.path.join(artifact_dir, f) for f in os.listdir(artifact_dir) if f.endswith(".ckpt")] + checkpoint_paths = [ + os.path.join(artifact_dir, f) + for f in os.listdir(artifact_dir) + if f.endswith(".ckpt") + ] # artifact_dir = "/n/holylabs/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20250718_124342-22j3mlzs/files/checkpoints" print(f"downloaded {len(checkpoint_paths)} checkpoint files.") print("run phenix") phenix_callback.run_phenix_over_all_checkpoints( - model_settings=model_settings, - loss_settings=loss_settings, + model_settings=model_settings, + loss_settings=loss_settings, phenix_settings=phenix_settings, - artifact_dir=artifact_dir, + artifact_dir=artifact_dir, checkpoint_paths=checkpoint_paths, - wandb_directory=wandb_logger.experiment.dir,#"/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20250720_200430-py737ahw/files",#wandb_logger.experiment.dir - ) + wandb_directory=wandb_logger.experiment.dir, # "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs/wandb/run-20250720_200430-py737ahw/files",#wandb_logger.experiment.dir + ) + def main(): # cProfile.run('run()') run() + if __name__ == "__main__": main() diff --git a/src/factory/run_reference_processing.py b/src/factory/run_reference_processing.py index a181fba..e0c3c7a 100644 --- a/src/factory/run_reference_processing.py +++ b/src/factory/run_reference_processing.py @@ -23,7 +23,9 @@ def run_dials(dials_env, command): def run(): - refl_file="/n/hekstra_lab/people/aldama/subset/small_dataset/pass1/reflections_.refl" + refl_file = ( + "/n/hekstra_lab/people/aldama/subset/small_dataset/pass1/reflections_.refl" + ) scale_command = ( f"dials.scale '{refl_file}' '{expt_file}' " f"output.reflections='{scaled_refl_out}' " @@ -36,5 +38,6 @@ def run(): dials_env = "/n/hekstra_lab/people/aldama/software/dials-v3-16-1/dials_env.sh" run_dials(dials_env, scale_command) -if name==main: - run() \ No newline at end of file + +if name == main: + run() diff --git a/src/factory/settings.py b/src/factory/settings.py index 52e61fb..998e665 100644 --- a/src/factory/settings.py +++ b/src/factory/settings.py @@ -1,41 +1,46 @@ import dataclasses -import torch -import torch -import torch.nn.functional as F -from networks import * + import distributions import get_protein_data +import metadata_encoder as me import reciprocalspaceship as rs -from abismal_torch.prior import WilsonPrior - -from rasu import * +import shoebox_encoder as se +import torch +import torch.nn.functional as F from abismal_torch.likelihood import NormalLikelihood +from abismal_torch.prior import WilsonPrior from abismal_torch.surrogate_posterior import FoldedNormalPosterior +from networks import * +from rasu import * + from wrap_folded_normal import FrequencyTrackingPosterior, SparseFoldedNormalPosterior -import shoebox_encoder as se -import metadata_encoder as me # from lazy_adam import LazyAdamW @dataclasses.dataclass -class ModelSettings(): +class ModelSettings: 
run_from_version: str | None = None data_directory: str = "/n/hekstra_lab/people/aldama/subset/small_dataset/pass1" - data_file_names: dict = dataclasses.field(default_factory=lambda: { - # "shoeboxes": "standardized_counts.pt", - "counts": "counts.pt", - "metadata": "reference_.pt", - "masks": "masks.pt", - }) + data_file_names: dict = dataclasses.field( + default_factory=lambda: { + # "shoeboxes": "standardized_counts.pt", + "counts": "counts.pt", + "metadata": "reference_.pt", + "masks": "masks.pt", + } + ) build_background_distribution: type = distributions.HalfNormalDistribution build_profile_distribution: type = distributions.Distribution # build_shoebox_profile_distribution = distributions.LRMVN_Distribution build_shoebox_profile_distribution: type = distributions.DirichletProfile background_prior_distribution: torch.distributions.Gamma = dataclasses.field( - default_factory=lambda: torch.distributions.HalfNormal(0.5)) + default_factory=lambda: torch.distributions.HalfNormal(0.5) + ) scale_prior_distibution: torch.distributions.Gamma = dataclasses.field( - default_factory=lambda: torch.distributions.Gamma(concentration=torch.tensor(6.68), rate=torch.tensor(6.4463)) + default_factory=lambda: torch.distributions.Gamma( + concentration=torch.tensor(6.68), rate=torch.tensor(6.4463) + ) ) scale_function: MLPScale = dataclasses.field( default_factory=lambda: MLPScale( @@ -43,14 +48,14 @@ class ModelSettings(): scale_distribution=LogNormalDistributionLayer(hidden_dimension=64), hidden_dimension=64, number_of_layers=1, - initial_scale_guess=2/140 + initial_scale_guess=2 / 140, ) ) build_intensity_prior_distribution: WilsonPrior = WilsonPrior intensity_prior_distibution: WilsonPrior = dataclasses.field(init=False) use_surrogate_parameters: bool = False - + shoebox_encoder: type = se.BaseShoeboxEncoder() metadata_encoder: type = me.BaseMetadataEncoder() intensity_encoder: type = se.IntensityEncoder() @@ -59,12 +64,12 @@ class ModelSettings(): use_positional_encoding: bool = False number_of_frequencies_in_positional_encoding: int = 2 - - - optimizer: type = torch.optim.AdamW #torch.optim.AdamW + optimizer: type = torch.optim.AdamW # torch.optim.AdamW optimizer_betas: tuple = (0.9, 0.9) learning_rate: float = 0.001 - surrogate_posterior_learning_rate: float = 0.01 # Higher learning rate for surrogate posterior + surrogate_posterior_learning_rate: float = ( + 0.01 # Higher learning rate for surrogate posterior + ) weight_decay: float = 0.0001 # Weight decay for main parameters surrogate_weight_decay: float = 0.0001 # Weight decay for surrogate parameters dmodel: int = 64 @@ -73,10 +78,16 @@ class ModelSettings(): number_of_mc_samples: int = 20 enable_checkpointing: bool = True - lysozyme_sequence_file_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/data/lysozyme.seq" + lysozyme_sequence_file_path: str = ( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/data/lysozyme.seq" + ) - merged_mtz_file_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz" - unmerged_mtz_file_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz" + merged_mtz_file_path: str = ( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/merged.mtz" + ) + unmerged_mtz_file_path: str = ( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz" + ) protein_pdb_url: str = "https://files.rcsb.org/download/9B7C.cif" rac: ReciprocalASUCollection = dataclasses.field(init=False) pdb_data: dict = 
dataclasses.field(init=False) @@ -84,30 +95,41 @@ class ModelSettings(): def __post_init__(self): self.pdb_data = get_protein_data.get_protein_data(self.protein_pdb_url) - self.rac = ReciprocalASUGraph(*[ReciprocalASU( - cell=self.pdb_data["unit_cell"], - spacegroup=self.pdb_data["spacegroup"], - dmin=float(self.pdb_data["dmin"]), - anomalous=True, - )]) - self.intensity_prior_distibution = self.build_intensity_prior_distribution(self.rac) + self.rac = ReciprocalASUGraph( + *[ + ReciprocalASU( + cell=self.pdb_data["unit_cell"], + spacegroup=self.pdb_data["spacegroup"], + dmin=float(self.pdb_data["dmin"]), + anomalous=True, + ) + ] + ) + self.intensity_prior_distibution = self.build_intensity_prior_distribution( + self.rac + ) @dataclasses.dataclass -class LossSettings(): +class LossSettings: prior_background_weight: float = 0.0001 prior_structure_factors_weight: float = 0.0001 prior_scale_weight: float = 0.0001 - prior_profile_weight: list[float] = dataclasses.field(default_factory=lambda: [0.0001, 0.0001, 0.0001]) + prior_profile_weight: list[float] = dataclasses.field( + default_factory=lambda: [0.0001, 0.0001, 0.0001] + ) eps: float = 0.00001 + @dataclasses.dataclass -class PhenixSettings(): - r_values_reference_path: str = "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/refine_001.log" +class PhenixSettings: + r_values_reference_path: str = ( + "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/anomalous_peaks_files/pdb_model/refine_001.log" + ) -@ dataclasses.dataclass -class DataLoaderSettings(): +@dataclasses.dataclass +class DataLoaderSettings: # data_directory: str = "/n/hekstra_lab/people/aldama/subset" # data_file_names: dict = dataclasses.field(default_factory=lambda: { # "shoeboxes": "shoebox_subset.pt", @@ -116,27 +138,28 @@ class DataLoaderSettings(): # "masks": "mask_subset.pt", # "true_reference": "true_reference_subset.pt", # }) - metadata_indices: dict = dataclasses.field(default_factory=lambda: { - "d": 0, - "h": 1, - "k": 2, - "l": 3, - "x": 4, - "y": 5, - "z": 6, - }) - metadata_keys_to_keep: list = dataclasses.field(default_factory=lambda: [ - "x", "y" - ]) - + metadata_indices: dict = dataclasses.field( + default_factory=lambda: { + "d": 0, + "h": 1, + "k": 2, + "l": 3, + "x": 4, + "y": 5, + "z": 6, + } + ) + metadata_keys_to_keep: list = dataclasses.field(default_factory=lambda: ["x", "y"]) data_directory: str = "/n/hekstra_lab/people/aldama/subset/small_dataset/pass1" - data_file_names: dict = dataclasses.field(default_factory=lambda: { - # "shoeboxes": "standardized_counts.pt", - "counts": "counts.pt", - "metadata": "reference_.pt", - "masks": "masks.pt", - }) + data_file_names: dict = dataclasses.field( + default_factory=lambda: { + # "shoeboxes": "standardized_counts.pt", + "counts": "counts.pt", + "metadata": "reference_.pt", + "masks": "masks.pt", + } + ) validation_set_split: float = 0.2 test_set_split: float = 0 @@ -145,9 +168,9 @@ class DataLoaderSettings(): number_of_batches: int = 1444 number_of_workers: int = 16 pin_memory: bool = True - prefetch_factor: int|None = 2 + prefetch_factor: int | None = 2 shuffle_indices: bool = True shuffle_groups: bool = True optimize_shoeboxes_per_batch: bool = True append_image_id_to_metadata: bool = False - verbose: bool = True \ No newline at end of file + verbose: bool = True diff --git a/src/factory/shoebox_encoder.py b/src/factory/shoebox_encoder.py index 0607352..4e0460a 100644 --- a/src/factory/shoebox_encoder.py +++ b/src/factory/shoebox_encoder.py @@ -1,21 +1,20 @@ -import torch 
import math + +import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Linear + class ShoeboxEncoder(nn.Module): - def __init__( - self, - input_shape=(3, 21, 21), - out_dim=64 - ): + def __init__(self, input_shape=(3, 21, 21), out_dim=64): super().__init__() def forward(self, x): pass + class BaseShoeboxEncoder(nn.Module): def __init__( self, @@ -76,6 +75,7 @@ def forward(self, x, mask=None): x = x.view(x.size(0), -1) return F.relu(self.fc(x)) + class SimpleShoeboxEncoder(nn.Module): def __init__( self, @@ -134,33 +134,41 @@ def forward(self, x, mask=None): if torch.isnan(x).any(): print("WARNING: NaN values in BaseShoeboxEncoder input!") print("NaN count:", torch.isnan(x).sum().item()) - + x = F.relu(self.norm1(self.conv1(x))) - + # Check after first conv if torch.isnan(x).any(): print("WARNING: NaN values after conv1 in BaseShoeboxEncoder!") print("NaN count:", torch.isnan(x).sum().item()) - + x = self.pool(x) x = F.relu(self.norm2(self.conv2(x))) - + # Check after second conv if torch.isnan(x).any(): print("WARNING: NaN values after conv2 in BaseShoeboxEncoder!") print("NaN count:", torch.isnan(x).sum().item()) - + x = x.view(x.size(0), -1) x = F.relu(self.fc(x)) - + # Check final output if torch.isnan(x).any(): print("WARNING: NaN values in BaseShoeboxEncoder output!") print("NaN count:", torch.isnan(x).sum().item()) - print("Stats - min:", x.min().item(), "max:", x.max().item(), "mean:", x.mean().item()) - + print( + "Stats - min:", + x.min().item(), + "max:", + x.max().item(), + "mean:", + x.mean().item(), + ) + return x + class IntensityEncoder(torch.nn.Module): def __init__( self, diff --git a/src/factory/simple_encoder.py b/src/factory/simple_encoder.py index 0ce7e78..acfbf2b 100644 --- a/src/factory/simple_encoder.py +++ b/src/factory/simple_encoder.py @@ -1,9 +1,11 @@ -import torch import math + +import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Linear + class ShoeboxEncoder(nn.Module): def __init__( self, @@ -64,6 +66,7 @@ def forward(self, x, mask=None): x = x.view(x.size(0), -1) return F.relu(self.fc(x)) + def weight_initializer(weight): fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2]) std = math.sqrt(1.0 / fan_avg / 10.0) @@ -117,6 +120,7 @@ def forward(self, x): return out + class MLPMetadataEncoder(nn.Module): def __init__(self, feature_dim, depth=8, dropout=0.0, output_dims=None): super().__init__() @@ -143,4 +147,4 @@ def __init__(self, feature_dim, depth=8, dropout=0.0, output_dims=None): def forward(self, x): # Process through the model x = self.model(x) - return x \ No newline at end of file + return x diff --git a/src/factory/simplest_encode.py b/src/factory/simplest_encode.py index 2ce2c0b..ceb8847 100644 --- a/src/factory/simplest_encode.py +++ b/src/factory/simplest_encode.py @@ -1,9 +1,11 @@ -import torch import math + +import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Linear + class ShoeboxEncoder(nn.Module): def __init__( self, @@ -64,6 +66,7 @@ def forward(self, x, mask=None): x = x.view(x.size(0), -1) return F.relu(self.fc(x)) + def weight_initializer(weight): fan_avg = 0.5 * (weight.shape[-1] + weight.shape[-2]) std = math.sqrt(1.0 / fan_avg / 10.0) @@ -117,6 +120,7 @@ def forward(self, x): return out + class MLPMetadataEncoder(nn.Module): def __init__(self, feature_dim, depth=10, dropout=0.0, output_dims=None): super().__init__() @@ -143,4 +147,4 @@ def __init__(self, feature_dim, depth=10, dropout=0.0, output_dims=None): def forward(self, x): # 
Process through the model
         x = self.model(x)
-        return x
\ No newline at end of file
+        return x
diff --git a/src/factory/test_callbacks.py b/src/factory/test_callbacks.py
index fa25bdf..0c197d2 100644
--- a/src/factory/test_callbacks.py
+++ b/src/factory/test_callbacks.py
@@ -1,15 +1,16 @@
-from model import *
-from settings import *
-from parse_yaml import load_settings_from_yaml
-from lightning.pytorch.loggers import WandbLogger
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
-
 import glob
 import os
+from datetime import datetime
+
 import phenix_callback
 import wandb
-from datetime import datetime
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger
+from model import *
+from parse_yaml import load_settings_from_yaml
 from pytorch_lightning.tuner.tuning import Tuner
+from settings import *
+
 
 def run():
@@ -17,18 +18,25 @@ def run():
 
     now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
 
-    wandb_logger = WandbLogger(project="full-model", name=f"{now_str}_phenix_from_{model_alias}", save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs")
+    wandb_logger = WandbLogger(
+        project="full-model",
+        name=f"{now_str}_phenix_from_{model_alias}",
+        save_dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
+    )
     wandb.init(
         project="full-model",
         name=f"{now_str}_phenix_from_{model_alias}",
-        dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs"
+        dir="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/jobs/lightning_logs",
    )
     slurm_job_id = os.environ.get("SLURM_JOB_ID")
     if slurm_job_id is not None:
         wandb.config.update({"slurm_job_id": slurm_job_id})
 
-    print("wandb_logger.experiment.dir",wandb_logger.experiment.dir)
+    print("wandb_logger.experiment.dir", wandb_logger.experiment.dir)
 
-    artifact = wandb.use_artifact(f"flaviagiehr-harvard-university/full-model/best_models:{model_alias}", type="model")
+    artifact = wandb.use_artifact(
+        f"flaviagiehr-harvard-university/full-model/best_models:{model_alias}",
+        type="model",
+    )
     artifact_dir = artifact.download()
 
     print("Files in artifact_dir:", os.listdir(artifact_dir))
@@ -37,18 +45,19 @@ def run():
     # config_path = os.path.join(config_artifact.download(), "config.yaml")
     # model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(path=config_path)
 
-    model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml")
-
+    model_settings, loss_settings, dataloader_settings = load_settings_from_yaml(
+        path="/n/holylabs/LABS/hekstra_lab/Users/fgiehr/factory/configs/config_example.yaml"
+    )
 
     ckpt_file = [f for f in os.listdir(artifact_dir) if f.endswith(".ckpt")][-1]
     print(f"downloaded {len(ckpt_files)} checkpoint files.")
 
-
 def main():
     # cProfile.run('run()')
     run()
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/src/factory/test_dataloader.py b/src/factory/test_dataloader.py
index 18c743f..4944a52 100644
--- a/src/factory/test_dataloader.py
+++ b/src/factory/test_dataloader.py
@@ -1,23 +1,27 @@
 import settings
 from data_loader import CrystallographicDataLoader
 
+
 def test_dataloader(data_loader_settings: settings.DataLoaderSettings):
     print("enter test")
     dataloader = CrystallographicDataLoader(data_loader_settings=data_loader_settings)
-    print("Loading data ...") 
+    print("Loading data ...")
     dataloader.load_data_()
     print("Data loaded successfully.")
-
+
     test_data = dataloader.load_data_set_batched_by_image(
         data_set_to_load=dataloader.train_data_set
     )
     print("Test data loaded successfully.")
     for batch in test_data:
-        shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl = batch
+        shoeboxes_batch, metadata_batch, dead_pixel_mask_batch, counts_batch, hkl = (
+            batch
+        )
         print("Batch shoeboxes :", shoeboxes_batch.shape)
-        print("image ids", metadata_batch[:,2])
+        print("image ids", metadata_batch[:, 2])
     return test_data
 
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Pass DataLoader settings")
     parser.add_argument(
@@ -27,10 +31,12 @@ def parse_args():
     )
     return parser.parse_args()
 
+
 def main():
     # args = parse_args()
     data_loader_settings = settings.DataLoaderSettings()
     test_dataloader(data_loader_settings=data_loader_settings)
 
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/src/factory/unmerged_mtz.py b/src/factory/unmerged_mtz.py
index cfb69c7..d0b6cd9 100644
--- a/src/factory/unmerged_mtz.py
+++ b/src/factory/unmerged_mtz.py
@@ -1,7 +1,9 @@
 import reciprocalspaceship as rs
 
 # Load the MTZ file
-ds = rs.read_mtz("/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz")
+ds = rs.read_mtz(
+    "/n/holylabs/LABS/hekstra_lab/Users/fgiehr/creat_dials_unmerged/unmerged.mtz"
+)
 
 # Show the column labels
 print("Columns:", ds.columns)
diff --git a/src/factory/wrap_folded_normal.py b/src/factory/wrap_folded_normal.py
index 86e20a2..9ffa3cb 100644
--- a/src/factory/wrap_folded_normal.py
+++ b/src/factory/wrap_folded_normal.py
@@ -1,8 +1,9 @@
-import torch
 from typing import Optional
+
+import rs_distributions.modules as rsm
+import torch
 from abismal_torch.surrogate_posterior import FoldedNormalPosterior
 from abismal_torch.symmetry import ReciprocalASUCollection
-import rs_distributions.modules as rsm
 
 
 class FrequencyTrackingPosterior(FoldedNormalPosterior):
@@ -13,11 +14,11 @@ def __init__(
         self,
         rac: ReciprocalASUCollection,
         loc: torch.Tensor,
         scale: torch.Tensor,
         epsilon: Optional[float] = 1e-12,
-        **kwargs
+        **kwargs,
     ):
         """
         A surrogate posterior that tracks how many times each HKL has been observed.
-
+
         Args:
             rac (ReciprocalASUCollection): ReciprocalASUCollection.
             loc (torch.Tensor): Unconstrained location parameter of the distribution.
@@ -25,11 +26,11 @@ def __init__(
             epsilon (float, optional): Epsilon value for numerical stability. Defaults to 1e-12.
         """
         super().__init__(rac, loc, scale, epsilon, **kwargs)
-
+
         self.register_buffer(
             "observation_count", torch.zeros(self.rac.rac_size, dtype=torch.long)
         )
-
+
         # def to(self, *args, **kwargs):
         #     super().to(*args, **kwargs)
         #     device = next(self.parameters()).device
@@ -38,26 +39,24 @@ def __init__(
         #     if hasattr(self, 'observed'):
         #         self.observed = self.observed.to(device)
         #     return self
-
+
     def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
         """
         Update both the observed buffer and the observation count.
-
+
         Args:
             rasu_id (torch.Tensor): A tensor of shape (n_refln,) that contains the rasu ID
                 of each reflection.
            H (torch.Tensor): A tensor of shape (n_refln, 3).
         """
-
+
         h, k, l = H.T
-
         observed_idx = self.rac.reflection_id_grid[rasu_id, h, k, l]
-
         self.observed[observed_idx] = True
         self.observation_count[observed_idx] += 1
-
+
     def reliable_observations_mask(self, min_observations: int = 5) -> torch.Tensor:
         """
         Returns a boolean tensor indicating which HKLs have been observed enough times
@@ -74,16 +73,16 @@ def get_distribution(self, rasu_ids: torch.Tensor, H: torch.Tensor):
 
         return rsm.FoldedNormal(gathered_loc, gathered_scale)
 
-
-
 from typing import Optional
+
+import rs_distributions.distributions as rsd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import rs_distributions.distributions as rsd
 from abismal_torch.surrogate_posterior.base import PosteriorBase
 from abismal_torch.symmetry import ReciprocalASUCollection
 
+
 class SparseFoldedNormalPosterior(PosteriorBase):
     def __init__(
         self,
@@ -91,7 +90,7 @@ def __init__(
         loc: torch.Tensor,
         scale: torch.Tensor,
         epsilon: Optional[float] = 1e-6,
-        **kwargs
+        **kwargs,
     ):
         # Call parent with a dummy distribution; we'll override it below
         super().__init__(rac, distribution=None, epsilon=epsilon, **kwargs)
@@ -100,13 +99,13 @@ def __init__(
         num_params = rac.rac_size  # or however many unique positions you have
 
         # Embeddings whose gradients will be sparse
-        self.loc_embed = nn.Embedding(num_params, 1, sparse=True) 
+        self.loc_embed = nn.Embedding(num_params, 1, sparse=True)
         self.scale_embed = nn.Embedding(num_params, 1, sparse=True)
 
         # Initialize them to your starting guesses
         with torch.no_grad():
             # loc_init & scale_init should each be shape [num_params]
-            self.loc_embed.weight[:, 0] = loc 
+            self.loc_embed.weight[:, 0] = loc
             self.scale_embed.weight[:, 0] = scale
 
         self.epsilon = epsilon
@@ -115,13 +114,12 @@ def __init__(
             "observation_count", torch.zeros(self.rac.rac_size, dtype=torch.long)
         )
 
-
     def get_distribution(self, hkl_indices):
         # Map each reflection to its ASU‐index
         indices = hkl_indices
 
         # Gather embeddings and squeeze out the singleton dim
-        loc_raw = self.loc_embed(indices).squeeze(-1) 
+        loc_raw = self.loc_embed(indices).squeeze(-1)
        scale_raw = self.scale_embed(indices).squeeze(-1)
 
         # Enforce positivity (you could also use transform_to())
@@ -129,26 +127,24 @@ def get_distribution(self, hkl_indices):
 
         # Functional FoldedNormal has no trainable params—it just wraps a torch.distributions
         return rsd.FoldedNormal(loc=loc_raw, scale=scale)
-
+
     def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
         """
         Update both the observed buffer and the observation count.
-
+
         Args:
             rasu_id (torch.Tensor): A tensor of shape (n_refln,) that contains the rasu ID
                 of each reflection.
             H (torch.Tensor): A tensor of shape (n_refln, 3).
         """
-
+
         h, k, l = H.T
-
         observed_idx = self.rac.reflection_id_grid[rasu_id, h, k, l]
-
         self.observed[observed_idx] = True
         self.observation_count[observed_idx] += 1
-
+
     def reliable_observations_mask(self, min_observations: int = 5) -> torch.Tensor:
         """
         Returns a boolean tensor indicating which HKLs have been observed enough times
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index b6407a9..5f19e05 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -1,31 +1,34 @@
-import torch
 import pytest
+import torch
+
 from factory import distributions
 
+
 def test_lrmvn_distribution():
     """Test LRMVN distribution initialization and forward pass."""
     batch_size = 2
     hidden_dim = 64
     distribution = distributions.LRMVN_Distribution(hidden_dim=hidden_dim)
-
+
     # Create dummy input
     shoebox_representation = torch.randn(batch_size, hidden_dim)
     image_representation = torch.randn(batch_size, hidden_dim)
-
+
     # Test forward pass
     output = distribution(shoebox_representation, image_representation)
     assert isinstance(output, torch.distributions.LowRankMultivariateNormal)
     assert output.loc.shape[0] == batch_size
 
+
 def test_dirichlet_profile():
     """Test Dirichlet profile initialization and forward pass."""
     batch_size = 2
     hidden_dim = 64
     profile = distributions.DirichletProfile(dmodel=hidden_dim)
-
+
     # Create dummy input
     representation = torch.randn(batch_size, hidden_dim)
-
+
     # Test forward pass
     output = profile(representation)
-    assert isinstance(output, torch.distributions.Dirichlet)
\ No newline at end of file
+    assert isinstance(output, torch.distributions.Dirichlet)
diff --git a/tox.ini b/tox.ini
index 0982115..4dd931d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -26,4 +26,4 @@ setenv =
 extras =
     test
 commands =
-    pytest -v --cov=rs-template --cov-report=xml --color=yes --basetemp={envtmpdir} {posargs}
\ No newline at end of file
+    pytest -v --cov=rs-template --cov-report=xml --color=yes --basetemp={envtmpdir} {posargs}
diff --git a/wrap_folded_normal.py b/wrap_folded_normal.py
index db3a485..5f30fdf 100644
--- a/wrap_folded_normal.py
+++ b/wrap_folded_normal.py
@@ -1,8 +1,10 @@
-import torch
 from typing import Optional
+
+import torch
 from abismal_torch.surrogate_posterior import FoldedNormalPosterior
 from abismal_torch.symmetry import ReciprocalASUCollection
 
+
 class FrequencyTrackingPosterior(FoldedNormalPosterior):
     def __init__(
         self,
@@ -10,11 +12,11 @@ def __init__(
         loc: torch.Tensor,
         scale: torch.Tensor,
         epsilon: Optional[float] = 1e-12,
-        **kwargs
+        **kwargs,
     ):
         """
         A surrogate posterior that tracks how many times each HKL has been observed.
-
+
         Args:
             rac (ReciprocalASUCollection): ReciprocalASUCollection.
             loc (torch.Tensor): Unconstrained location parameter of the distribution.
@@ -22,15 +24,15 @@ def __init__(
             epsilon (float, optional): Epsilon value for numerical stability. Defaults to 1e-12.
         """
         super().__init__(rac, loc, scale, epsilon, **kwargs)
-
+
         self.register_buffer(
             "observation_count", torch.zeros(self.rac.rac_size, dtype=torch.long)
         )
-
+
     def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
         """
         Update both the observed buffer and the observation count.
-
+
         Args:
             rasu_id (torch.Tensor): A tensor of shape (n_refln,) that contains the rasu ID
                 of each reflection.
@@ -40,10 +42,10 @@ def update_observed(self, rasu_id: torch.Tensor, H: torch.Tensor) -> None:
         observed_idx = self.rac.reflection_id_grid[rasu_id, h, k, l]
         self.observed[observed_idx] = True
         self.observation_count[observed_idx] += 1
-
+
     def reliable_observations_mask(self, min_observations: int = 5) -> torch.Tensor:
         """
         Returns a boolean tensor indicating which HKLs have been observed enough times
         to be considered reliable.
         """
-        return self.observation_count >= min_observations
\ No newline at end of file
+        return self.observation_count >= min_observations