diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 53bf10a1..f3424f15 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,6 +16,7 @@ env: jobs: tests: runs-on: ${{ matrix.os }} + if: contains(github.event.pull_request.labels.*.name, 'skip-ci') == false strategy: matrix: include: @@ -88,6 +89,7 @@ jobs: docs: runs-on: ubuntu-latest + if: contains(github.event.pull_request.labels.*.name, 'skip-ci') == false steps: - uses: actions/checkout@v5 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3e9bb969..d1df0c9f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,9 +17,20 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace + - repo: https://github.com/codespell-project/codespell + rev: v2.4.1 + hooks: + - id: codespell + additional_dependencies: [tomli] + - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.14.3" hooks: - id: ruff-format - id: ruff-check args: ["--fix", "--show-fixes"] + + - repo: https://github.com/sphinx-contrib/sphinx-lint + rev: v1.0.0 + hooks: + - id: sphinx-lint diff --git a/CHANGES.rst b/CHANGES.rst index 311c232b..ca575630 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -95,7 +95,7 @@ Maintenance - Deleted unused functions - - Deleted unsued architectures + - Deleted unused architectures - Renamed symmetry function [`#166 `__] - Added radionets logo to README [`#169 `__] @@ -220,7 +220,7 @@ Radionets 0.2.0 (2023-01-31) API Changes ----------- -- Train on half-sized iamges and applying symmetry afterward is a backward incompatible change +- Train on half-sized images and applying symmetry afterward is a backward incompatible change - Models trained with early versions of ``radionets`` are not supported anymore [`#140 `__] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..a2b8950d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,84 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [INSERT CONTACT METHOD]. All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations diff --git a/configs/radionets_default_train_config.toml b/configs/radionets_default_train_config.toml index 98c10e20..e3898ade 100644 --- a/configs/radionets_default_train_config.toml +++ b/configs/radionets_default_train_config.toml @@ -1,42 +1,76 @@ -# This is a TOML document. - -title = "Train configuration" - -[mode] -quiet = true -gpu = false - -[logging] -comet_ml = true -project_name = "VLA" -plot_n_epochs = 2 -scale = true +title = "Radionets Default Training Configuration" +# ───────────────────────────────────────────────────────────────────────────── +# PATHS & I/O +# ───────────────────────────────────────────────────────────────────────────── [paths] data_path = "./example_data/" model_path = "./build/example_model/example.model" -pre_model = "none" +checkpoint = false -[general] +# ───────────────────────────────────────────────────────────────────────────── +# MODEL & ARCHITECTURE +# ───────────────────────────────────────────────────────────────────────────── +[model] +arch_name = "SRResNet18" fourier = true amp_phase = true normalize = false -source_list = false -arch_name = "filter_deep" -loss_func = "splitted_L1" -num_epochs = 5 -inspection = true -output_format = "png" -switch_loss = false -when_switch = 25 - -[hypers] -batch_size = 100 -lr = 1e-3 - -[param_scheduling] -use = true -lr_start = 7e-2 -lr_max = 3e-1 -lr_stop = 5e-2 -lr_ratio = 0.25 + +# ───────────────────────────────────────────────────────────────────────────── +# TRAINING +# ───────────────────────────────────────────────────────────────────────────── +[training] +num_epochs = 50 +batch_size = 16 + +[training.loss] +loss_func = "MSELoss" + +[training.optimizer] +optimizer = "AdamW" +lr = 0.001 + +[training.lr_scheduling] +scheduler = "OneCycleLR" + +# kwargs for OneCycleLR +# See: https://docs.pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html +max_lr = 1e-3 # Maxmimum at 1e-3 +div_factor = 1e2 # Initial at 1e-3 / 1e2 = 1e-5 +final_div_factor = 1e2 # Final at initial_lr / 1e2 = 1e-7 +pct_start = 0.25 + +# ───────────────────────────────────────────────────────────────────────────── +# ACCELERATORS & DATALOADING +# ───────────────────────────────────────────────────────────────────────────── +[devices] +accelerator = "gpu" +num_devices = "auto" +precision = "32-true" +deepspeed = "deepspeed_stage_2" + +[dataloader] +module = "WebDatasetModule" +num_workers = 10 +prefetch_factor = 2 +persistent_workers = false + +# ───────────────────────────────────────────────────────────────────────────── +# CALLBACKS +# ───────────────────────────────────────────────────────────────────────────── +[callbacks.checkpoint] +every_n_epochs = 2 +save_top_k = -1 + +[callbacks.batch_size_finder] +mode = "binsearch" + +# ───────────────────────────────────────────────────────────────────────────── +# LOGGING +# ───────────────────────────────────────────────────────────────────────────── +[logging] +project_name = "Radionets Experiment" +plot_n_epochs = 1 +scale = true +comet_ml = true diff --git a/docs/changes/194.maintenance.2.rst b/docs/changes/194.maintenance.2.rst deleted file mode 100644 index 4ffd7420..00000000 --- a/docs/changes/194.maintenance.2.rst +++ /dev/null @@ -1 +0,0 @@ -Fix plot sizes in classes of :mod:`radionets.core.callbacks` diff --git a/docs/changes/194.maintenance.rst b/docs/changes/194.maintenance.rst index 94ac79e2..72cd55aa 100644 --- a/docs/changes/194.maintenance.rst +++ b/docs/changes/194.maintenance.rst @@ -1 +1,3 @@ -Added docstrings to classses of :mod:`radionets.core.callbacks` +Added docstrings to classes of :mod:`radionets.core.callbacks` + +Fix plot sizes in classes of :mod:`radionets.core.callbacks` diff --git a/docs/developer-guide/contributions.md b/docs/developer-guide/contributions.md index da93fcc6..afd7eac5 100644 --- a/docs/developer-guide/contributions.md +++ b/docs/developer-guide/contributions.md @@ -182,7 +182,7 @@ use the imperative, a short description as the first line, followed by a blank l and then followed by details if needed, e.g. as a bullet list. ```{seealso} -[Convetional Commits][conventionalcommits] for examples and information +[Conventional Commits][conventionalcommits] for examples and information on how to write good commit messages. ``` Make sure you frequently test the code during development (see {ref}`testing`). diff --git a/pyproject.toml b/pyproject.toml index a9480d56..f2a5ebb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,10 +102,7 @@ dev = [ ] [project.scripts] -radionets-simulation = "radionets.simulations.scripts.simulate_images:main" -radionets-training = "radionets.training.scripts.start_training:main" -radionets-evaluation = "radionets.evaluation.scripts.start_evaluation:main" -radionets-quickstart = "radionets.tools.quickstart:quickstart" +radionets = "radionets.tools.cli:main" [tool.hatch.version] source = "vcs" @@ -133,7 +130,25 @@ omit = [ output = "coverage.xml" [tool.pytest.ini_options] -addopts = "--verbose" +norecursedirs = [ + ".git", + ".github", + "dist", + "build", + "docs", +] +addopts = [ + "--strict-markers", + "--doctest-modules", + "--color=yes", + "--disable-pytest-warnings", + "--ignore=legacy/checkpoints", +] +filterwarnings = [ + "error::FutureWarning", + "ignore:You are using `torch.load` with `weights_only=False`.*:FutureWarning", +] +xfail_strict = true [tool.towncrier] package = "radionets" @@ -174,7 +189,7 @@ addopts = "--verbose" [tool.ruff] target-version = "py311" line-length = 88 -extend-exclude = ["tests", "examples"] +extend-exclude = ["examples"] [tool.ruff.lint] extend-select = [ @@ -192,6 +207,9 @@ unfixable = [] [tool.ruff.lint.per-file-ignores] "examples/**" = ["I"] +"tests/**" = [ + "E402", +] [tool.ruff.format] quote-style = "double" @@ -202,3 +220,23 @@ docstring-code-format = true [tool.ruff.lint.isort] known-first-party = ["radionets"] + +[tool.mypy] +files = [ + "src/radionets", +] + +install_types = "True" +non_interactive = "True" +disallow_untyped_defs = "True" +ignore_missing_imports = "True" +show_error_codes = "True" +warn_redundant_casts = "True" +warn_unused_configs = "True" +warn_unused_ignores = "True" +allow_redefinition = "True" +warn_no_return = "False" + +[tool.codespell] +skip = "examples/**" # NOTE: For now we skip the examples +ignore-words-list = "RIME,bund" diff --git a/src/radionets/architecture/__init__.py b/src/radionets/architecture/__init__.py index 64660d39..7330434d 100644 --- a/src/radionets/architecture/__init__.py +++ b/src/radionets/architecture/__init__.py @@ -11,7 +11,7 @@ ) from .blocks import BottleneckResBlock, Decoder, Encoder, NNBlock, SRBlock from .layers import LocallyConnected2d -from .unc_archs import Uncertainty, UncertaintyWrapper +from .uncertainty_archs import Uncertainty, UncertaintyWrapper __all__ = [ "BottleneckResBlock", diff --git a/src/radionets/architecture/archs.py b/src/radionets/architecture/archs.py index 8d1ddd48..da8ecdcc 100644 --- a/src/radionets/architecture/archs.py +++ b/src/radionets/architecture/archs.py @@ -3,9 +3,9 @@ import torch from torch import nn -from radionets.architecture.activation import GeneralReLU -from radionets.architecture.blocks import ComplexSRBlock, SRBlock -from radionets.architecture.layers import ( +from .activation import GeneralReLU +from .blocks import ComplexSRBlock, SRBlock +from .layers import ( ComplexConv2d, ComplexInstanceNorm2d, ComplexPReLU, @@ -65,10 +65,10 @@ def __init__(self): ), ) - def _create_blocks(self, n_blocks): + def _create_blocks(self, n_blocks, **kwargs): blocks = [] for _ in range(n_blocks): - blocks.append(SRBlock(64, 64)) + blocks.append(SRBlock(64, 64, **kwargs)) self.blocks = nn.Sequential(*blocks) @@ -118,10 +118,10 @@ def __init__(self): ), ) - def _create_blocks(self, n_blocks): + def _create_blocks(self, n_blocks, **kwargs): blocks = [] for _ in range(n_blocks): - blocks.append(ComplexSRBlock(64, 64)) + blocks.append(ComplexSRBlock(64, 64, **kwargs)) self.blocks = nn.Sequential(*blocks) diff --git a/src/radionets/architecture/blocks.py b/src/radionets/architecture/blocks.py index 723a43f2..4f9cc7b9 100644 --- a/src/radionets/architecture/blocks.py +++ b/src/radionets/architecture/blocks.py @@ -36,7 +36,7 @@ class NNBlock(nn.Module): Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -70,7 +70,7 @@ def __init__( Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ super().__init__() @@ -121,7 +121,7 @@ class SRBlock(NNBlock): Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -164,7 +164,7 @@ def __init__( Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -228,7 +228,7 @@ class ComplexSRBlock(NNBlock): Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -271,7 +271,7 @@ def __init__( Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -329,7 +329,7 @@ class BottleneckResBlock(NNBlock): Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -365,7 +365,7 @@ def __init__( Controls the behavior of input and output groups. See :class:`~torch.nn.Conv2d`. Default: 1 dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -451,7 +451,7 @@ class Encoder(NNBlock): bias : bool Whether to apply bias. Default: False dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False batchnorm : bool, optional If ``True``, add a batchnorm layer to the @@ -492,7 +492,7 @@ def __init__( bias : bool Whether to apply bias. Default: False dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False batchnorm : bool, optional If ``True``, add a batchnorm layer to the @@ -561,7 +561,7 @@ class Decoder(NNBlock): bias : bool Whether to apply bias. Default: False dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ @@ -601,7 +601,7 @@ def __init__( bias : bool Whether to apply bias. Default: False dropout : bool or float, optional - Wether to apply dropout. If float > 0 this is + Whether to apply dropout. If float > 0 this is the dropout percentage. Default: False """ diff --git a/src/radionets/architecture/loss.py b/src/radionets/architecture/loss.py new file mode 100644 index 00000000..a2487daf --- /dev/null +++ b/src/radionets/architecture/loss.py @@ -0,0 +1,105 @@ +import numpy as np +import torch +from torch import Tensor, nn + +from radionets.evaluation.utils import apply_symmetry + + +class SplittedL1Loss(nn.Module): + def __init__(self, reduction: str = "mean") -> None: + super().__init__() + self.reduction = reduction + + def forward(self, pred: Tensor, target: Tensor) -> Tensor: + """ + Runs the forward pass. + """ + inp_amp = pred[:, 0, :] + inp_phase = pred[:, 1, :] + + tar_amp = target[:, 0, :] + tar_phase = target[:, 1, :] + + l1 = nn.L1Loss(self.reduction) + loss_amp = l1(inp_amp, tar_amp) + loss_phase = l1(inp_phase, tar_phase) + loss = loss_amp + loss_phase + + return loss + + +class MaskedSplittedL1Loss(nn.Module): + def __init__( + self, + size_average: bool = None, + reduce: bool = None, + reduction: str = "mean", + center: list | tuple = None, + radius: int = 30, + ) -> None: + super().__init__() + + self.reduction = reduction + self.center = center + self.radius = radius + + # Assign mask so it can be cached during forward call; + # None at first, then torch.Tensor after caching + self._mask: torch.Tensor | None = None + + def _create_circular_mask( + self, + w: int, + h: int, + center: list | tuple = None, + radius: int = None, + device: torch.device = None, + ) -> np.ndarray: + if center is None: + center = (int(w / 2), int(h / 2)) + + if radius is None: + radius = min(center[0], center[1], w - center[0], h - center[1]) + + x = torch.arange(w, device=device).view(1, -1) + y = torch.arange(h, device=device).view(-1, 1) + dist_from_center = torch.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2) + + mask = dist_from_center <= radius + + return mask + + def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: + inputs = apply_symmetry(inputs) + targets = apply_symmetry(targets) + + _, _, h, w = targets.shape + + inp_amp = inputs[:, 0] + inp_phase = inputs[:, 1] + + tar_amp = targets[:, 0] + tar_phase = targets[:, 1] + + if self._mask is None or self._mask.device != inputs.device: + self._mask = self._create_circular_mask( + w=w, + h=h, + center=self.center, + radius=self.radius, + device=inputs.device, + ) + + weight = torch.where(self._mask, 1.0, 0.3) + + inp_amp *= weight + inp_phase *= weight + tar_amp *= weight + tar_phase *= weight + + l1 = nn.L1Loss(reduction=self.reduction) + loss_amp = l1(inp_amp, tar_amp) + loss_phase = l1(inp_phase, tar_phase) + loss = loss_amp + loss_phase + + return loss diff --git a/src/radionets/architecture/unc_archs.py b/src/radionets/architecture/uncertainty_archs.py similarity index 100% rename from src/radionets/architecture/unc_archs.py rename to src/radionets/architecture/uncertainty_archs.py diff --git a/src/radionets/core/__init__.py b/src/radionets/core/__init__.py index 86a03079..d9d66461 100644 --- a/src/radionets/core/__init__.py +++ b/src/radionets/core/__init__.py @@ -1,54 +1,4 @@ -from .callbacks import ( - AvgLossCallback, - CometCallback, - CudaCallback, - DataAug, - GradientCallback, - Normalize, - PredictionImageGradient, - SaveTempCallback, - SwitchLoss, -) -from .data import ( - DataBunch, - H5DataSet, - get_bundles, - get_dls, - load_data, - open_bundle, - open_bundle_pack, - open_fft_bundle, - save_bundle, - save_fft_pair, -) -from .learner import define_learner, get_learner -from .logging import setup_logger -from .model import init_cnn, load_pre_model, save_model +from .callbacks import Callbacks +from .logging import Loggers -__all__ = [ - "AvgLossCallback", - "CometCallback", - "CudaCallback", - "DataAug", - "DataBunch", - "GradientCallback", - "H5DataSet", - "Normalize", - "PredictionImageGradient", - "SaveTempCallback", - "SwitchLoss", - "define_learner", - "get_bundles", - "get_dls", - "get_learner", - "init_cnn", - "load_data", - "load_pre_model", - "open_bundle", - "open_bundle_pack", - "open_fft_bundle", - "save_bundle", - "save_fft_pair", - "save_model", - "setup_logger", -] +__all__ = ["Callbacks", "Loggers"] diff --git a/src/radionets/core/callbacks.py b/src/radionets/core/callbacks.py index 5631c522..d8243f16 100644 --- a/src/radionets/core/callbacks.py +++ b/src/radionets/core/callbacks.py @@ -1,521 +1,523 @@ +import warnings +from abc import ABC from pathlib import Path -import comet_ml -import kornia as K +import lightning as L import matplotlib.pyplot as plt import numpy as np -import torch -from fastai.callback.core import Callback, CancelBackwardException +import pandas as pd +from lightning.pytorch.callbacks import ( + BatchSizeFinder, + DeviceStatsMonitor, + EarlyStopping, + LearningRateMonitor, + ModelCheckpoint, + RichProgressBar, + Timer, +) +from lightning.pytorch.callbacks import Callback as LightningCallback +from lightning.pytorch.loggers import CometLogger, MLFlowLogger from matplotlib.colors import PowerNorm +from pydantic import BaseModel -from radionets.core.logging import setup_logger -from radionets.core.model import save_model -from radionets.core.utils import _maybe_item, get_ifft_torch -from radionets.evaluation.utils import ( - apply_normalization, - apply_symmetry, - check_vmin_vmax, - eval_model, - get_ifft, - get_images, - load_data, - load_pretrained_model, - make_axes_nice, - rescale_normalization, -) +from radionets.evaluation.contour import analyse_intensity, area_of_contour +from radionets.evaluation.utils import apply_symmetry, get_ifft +from radionets.plotting.utils import get_vmin_vmax, set_cbar -__all__ = [ - "CometCallback", - "AvgLossCallback", - "PredictionImageGradient", - "GradientCallback", - "SwitchLoss", - "SaveTempCallback", - "Normalize", - "DataAug", - "CudaCallback", -] - -LOGGER = setup_logger(namespace=__name__) - - -class CometCallback(Callback): - """Callback for logging training metrics and visualizations - to Comet ML. - - This callback logs training and validation losses, and - creates plots for predictions and Fourier-transformed - data for monitoring during training. - - Parameters - ---------- - name : str - Project name for the Comet ML experiment. - validation_data : str or Path - Path to the validation dataset. - plot_n_epochs : int - Frequency of plotting (every n epochs). - amp_phase : bool - Whether to use amplitude-phase representation. - scale : str - Scaling method for data. - """ - - def __init__(self, name, validation_data, plot_n_epochs, amp_phase, scale): - self.experiment = comet_ml.Experiment(project_name=name) - self.data_path = validation_data - self.plot_epoch = plot_n_epochs - self.test_ds = load_data(self.data_path, mode="test", fourier=True) - self.amp_phase = amp_phase - self.scale = scale - self.uncertainty = False - - def after_train(self): - self.experiment.log_metric( - "Train Loss", - self.recorder._train_mets.map(_maybe_item), - epoch=self.epoch + 1, - ) - def after_validate(self): - self.experiment.log_metric( - "Validation Loss", - self.recorder._valid_mets.map(_maybe_item), - epoch=self.epoch + 1, - ) +class Callbacks: + @classmethod + def get_callbacks(cls, train_config: BaseModel) -> list: + default_callback = RichProgressBar() + callbacks = [default_callback] - def plot_test_pred(self): - img_test, img_true, _ = get_images(self.test_ds, 1, rand=False) - img_test = img_test.unsqueeze(0) - img_true = img_true.unsqueeze(0) - model = self.model + if train_config.callbacks.model_checkpoint: + model_checkpoint = ModelCheckpoint( + **train_config.callbacks.model_checkpoint.model_dump() + ) + callbacks.append(model_checkpoint) - try: - if self.learn.normalize.mode == "all": - norm_dict = {"all": 0} - img_test, norm_dict = apply_normalization(img_test, norm_dict) - except AttributeError: - pass + if train_config.callbacks.batch_size_finder: + batch_size_finder = BatchSizeFinder( + **train_config.callbacks.batch_size_finder.model_dump() + ) + callbacks.append(batch_size_finder) - with self.experiment.test(), torch.no_grad(): - pred = eval_model(img_test, model) + if train_config.callbacks.early_stopping: + early_stopping = EarlyStopping( + **train_config.callbacks.early_stopping.model_dump() + ) + callbacks.append(early_stopping) - try: - if self.learn.normalize.mode == "all": - pred = rescale_normalization(pred, norm_dict) - except AttributeError: - pass - - if pred.shape[1] == 4: - self.uncertainty = True - pred = torch.stack((pred[:, 0, :], pred[:, 2, :]), dim=1) - - images = {"pred": pred, "truth": img_true} - images = apply_symmetry(images) - pred = images["pred"] - img_true = images["truth"] - - fig, ax = plt.subplots(2, 2, figsize=(11, 8.5), layout="constrained") - ax = ax.ravel() - - lim_amp = check_vmin_vmax(img_true[0, 0]) - lim_phase = check_vmin_vmax(img_true[0, 1]) - im1 = ax[0].imshow( - pred[0, 0], cmap="radionets.PuOr", vmin=-lim_amp, vmax=lim_amp - ) - make_axes_nice(fig, ax[0], im1, "Real") + if train_config.callbacks.lr_monitor: + lr_monitor = LearningRateMonitor( + **train_config.callbacks.lr_monitor.model_dump() + ) + callbacks.append(lr_monitor) - im2 = ax[1].imshow( - pred[0, 1], cmap="radionets.PuOr", vmin=-lim_phase, vmax=lim_phase - ) - make_axes_nice(fig, ax[1], im2, "Imaginary") + if train_config.callbacks.device_stats_monitor: + callbacks.append(DeviceStatsMonitor()) - im3 = ax[2].imshow( - img_true[0, 0], cmap="radionets.PuOr", vmin=-lim_amp, vmax=lim_amp - ) - make_axes_nice(fig, ax[2], im3, "Org. Real") + if train_config.callbacks.timer: + timer = Timer(**train_config.callbacks.timer.model_dump()) + callbacks.append(timer) - im4 = ax[3].imshow( - img_true[0, 1], cmap="radionets.PuOr", vmin=-lim_phase, vmax=lim_phase - ) - make_axes_nice(fig, ax[3], im4, "Org. Imaginary") + if train_config.logging.comet_ml: + callbacks.append(CometCallback(train_config)) - self.experiment.log_figure( - figure=fig, figure_name=f"{self.epoch + 1}_pred_epoch" - ) - plt.close() + if train_config.logging.mlflow: + callbacks.append(MLFlowCallback(train_config)) - def plot_test_fft(self): - img_test, img_true, _ = get_images(self.test_ds, 1, rand=False) - img_test = img_test.unsqueeze(0) - img_true = img_true.unsqueeze(0) - model = self.model + if train_config.logging.codecarbon: + callbacks.append(MLFlowCodeCarbonCallback(train_config)) - try: - if self.learn.normalize.mode == "all": - norm_dict = {"all": 0} - img_test, norm_dict = apply_normalization(img_test, norm_dict) - except AttributeError: - pass + callbacks.append(LogAdditionalParamsCallback(train_config)) - with self.experiment.test(), torch.no_grad(): - pred = eval_model(img_test, model) + return callbacks - try: - if self.learn.normalize.mode == "all": - pred = rescale_normalization(pred, norm_dict) - except AttributeError: - pass - if self.uncertainty: - pred = torch.stack((pred[:, 0, :], pred[:, 2, :]), dim=1) +class PlottingCallbackABC(ABC, LightningCallback): + def __init__(self, train_config, *args, **kwargs): + super().__init__() + self.train_config = train_config + self.amp_phase = train_config.model.amp_phase + self.scale = train_config.logging.scale + + self.cached_batch = None - images = {"pred": pred, "truth": img_true} - images = apply_symmetry(images) - pred = images["pred"] - img_true = images["truth"] + data_types = ["Amplitude", "Phase"] if self.amp_phase else ["Real", "Imaginary"] + results = [" Prediction", " Ground Truth"] + self.pred_plot_titles = [t + r for r in results for t in data_types] - ifft_pred = get_ifft_torch( - pred, + def plot_val_pred(self, predictions, targets, current_epoch: int): + self.fig, self.axs = plt.subplots( + 2, 2, figsize=(12, 8.5), layout="constrained", sharex=True, sharey=True + ) + self.axs = self.axs.flatten() + + limits_0 = get_vmin_vmax(targets[0, 0]) # Limits for amp/real + limits_1 = get_vmin_vmax(targets[0, 1]) # Limits for phase/imaginary + + im0 = self.axs[0].imshow( + predictions[0, 0], + cmap="radionets.PuOr", + vmin=-limits_0, + vmax=limits_0, + origin="lower", + ) + im1 = self.axs[1].imshow( + predictions[0, 1], + cmap="radionets.PuOr", + vmin=-limits_1, + vmax=limits_1, + origin="lower", + ) + im2 = self.axs[2].imshow( + targets[0, 0], + cmap="radionets.PuOr", + vmin=-limits_0, + vmax=limits_0, + origin="lower", + ) + im3 = self.axs[3].imshow( + targets[0, 1], + cmap="radionets.PuOr", + vmin=-limits_1, + vmax=limits_1, + origin="lower", + ) + + for ax, im, title in zip( + self.axs, + [im0, im1, im2, im3], + self.pred_plot_titles, + ): + set_cbar(self.fig, ax, im, title=title, phase="Phase" in title) + + self.axs[0].set(ylabel="Frequels") + self.axs[2].set(xlabel="Frequels", ylabel="Frequels") + self.axs[3].set(xlabel="Frequels") + + def plot_val_fft(self, predictions, targets, current_epoch): + ifft_pred = get_ifft( + predictions, amp_phase=self.amp_phase, scale=self.scale, - uncertainty=self.uncertainty, - ) - ifft_truth = get_ifft_torch( - img_true, amp_phase=self.amp_phase, scale=self.scale ) + ifft_truth = get_ifft(targets, amp_phase=self.amp_phase, scale=self.scale) - fig, ax = plt.subplots(1, 3, figsize=(16, 4.5), layout="constrained") + self.fig, self.axs = plt.subplots(1, 3, figsize=(16, 4.5), layout="constrained") - im1 = ax[0].imshow( - ifft_pred, norm=PowerNorm(0.25, vmax=ifft_truth.max()), cmap="inferno" + im0 = self.axs[0].imshow( + ifft_pred, + norm=PowerNorm(0.25, vmax=ifft_truth.max()), + cmap="inferno", + origin="lower", + ) + im1 = self.axs[1].imshow( + ifft_truth, + norm=PowerNorm(0.25), + cmap="inferno", + origin="lower", ) - im2 = ax[1].imshow(ifft_truth, norm=PowerNorm(0.25), cmap="inferno") - a = check_vmin_vmax(ifft_pred - ifft_truth) - im3 = ax[2].imshow( - ifft_pred - ifft_truth, cmap="radionets.PuOr", vmin=-a, vmax=a + + limits = get_vmin_vmax(ifft_pred - ifft_truth) + im2 = self.axs[2].imshow( + ifft_pred - ifft_truth, + cmap="radionets.PuOr", + vmin=-limits, + vmax=limits, + origin="lower", ) - make_axes_nice(fig, ax[0], im1, "FFT Prediction") - make_axes_nice(fig, ax[1], im2, "FFT Truth") - make_axes_nice(fig, ax[2], im3, "FFT Diff") + for ax, im, title in zip( + self.axs, + [im0, im1, im2], + ["Prediction", "Truth", "Difference"], + ): + set_cbar(self.fig, ax, im, title="FFT " + title) - ax[0].set( + self.axs[0].set( ylabel="Pixels", xlabel="Pixels", ) - ax[1].set_xlabel("Pixels") - ax[2].set_xlabel("Pixels") + self.axs[1].set_xlabel("Pixels") + self.axs[2].set_xlabel("Pixels") - self.experiment.log_figure( - figure=fig, figure_name=f"{self.epoch + 1}_fft_epoch" - ) - plt.close() - - def after_epoch(self): - if (self.epoch + 1) % self.plot_epoch == 0: - self.plot_test_pred() - self.plot_test_fft() - - -class AvgLossCallback(Callback): - """Callback for tracking and plotting average training - and validation losses. - - Saves the average loss for training and validation - that is printed to the terminal. - """ - - def __init__(self): - if not hasattr(self, "loss_train"): - self.loss_train = [] - if not hasattr(self, "loss_valid"): - self.loss_valid = [] - if not hasattr(self, "lrs"): - self.lrs = [] - - def after_train(self): - self.loss_train.append(self.recorder._train_mets.map(_maybe_item)) - - def after_validate(self): - self.loss_valid.append(self.recorder._valid_mets.map(_maybe_item)) - - def after_batch(self): - self.lrs.append(self.opt.hypers[-1]["lr"]) - - def plot_loss(self): - min_epoch = np.argmin(self.loss_valid) - plt.plot(self.loss_train, label="Training loss") - plt.plot(self.loss_valid, label="Validation loss") - plt.axvline( - min_epoch, - color="black", - linestyle="dashed", - label=f"Minimum at Epoch {min_epoch}", - ) - plt.xlabel(r"Number of Epochs") - plt.ylabel(r"Loss") - plt.legend() + def on_validation_epoch_end(self, trainer: L.Trainer, pl_module) -> None: + """Log predictions at validation epoch end.""" - train = np.array(self.loss_train) - valid = np.array(self.loss_valid) + if self.cached_batch is None: + val_dataloader = trainer.datamodule.val_dataloader() + batch = next(iter(val_dataloader)) - return bool(len(train[train < 0]) == 0 or len(valid[valid < 0]) == 0) + # cache only one sample + self.cached_batch = ( + batch[0][0][None, ...].cpu(), + batch[1][0][None, ...].cpu(), + ) - def plot_lrs(self): - plt.plot(self.lrs) - plt.xlabel(r"Number of Batches") - plt.ylabel(r"Learning rate") + if (trainer.current_epoch + 1) % self.train_config.logging.plot_n_epochs == 0: + batch = ( + self.cached_batch[0].to(pl_module.device), + self.cached_batch[1].to(pl_module.device), + ) + predictions = pl_module.predict_step(batch, batch_idx=0).cpu() + targets = batch[1].cpu() -class CudaCallback(Callback): - """Callback to move model to CUDA device before training. + # check if images are half or full + if predictions.shape[-2] != predictions.shape[-1]: + predictions = apply_symmetry(predictions) + targets = apply_symmetry(targets) - Simple callback that ensures the model is moved to the - GPU before the training loop. + self.plot_val_pred( + predictions, + targets, + current_epoch=trainer.current_epoch, + ) - Attributes - ---------- - _order : int - Callback execution order (3). - """ + self.plot_val_fft( + predictions, + targets, + current_epoch=trainer.current_epoch, + ) - _order = 3 - def before_fit(self): - self.model.cuda() +class CometCallback(PlottingCallbackABC): + def __init__(self, train_config, *args, **kwargs): + super().__init__(train_config, *args, **kwargs) + self.experiment = None + def plot_val_pred(self, predictions, targets, current_epoch: int) -> None: + super().plot_val_pred(predictions, targets, current_epoch) -class DataAug(Callback): - """Callback that applies data augmentation using - random rotations. + self.experiment.log_figure( + figure=self.fig, + figure_name=f"fourier_pred_{current_epoch:0>4}", + ) - Applies random multiples of 90-degree rotations to both - input and target tensors before each batch to augment - the training data. - """ + plt.close(self.fig) - _order = 3 + def plot_val_fft(self, predictions, targets, current_epoch: int) -> None: + super().plot_val_fft(predictions, targets, current_epoch) - def before_batch(self): - x = self.xb[0].clone() - y = self.yb[0].clone() + self.experiment.log_figure( + figure=self.fig, + figure_name=f"fft_pred_{current_epoch:0>4}", + ) + plt.close(self.fig) - randint = np.random.randint(0, 1, x.shape[0]) * 2 - last_axis = len(x.shape) - 1 + def on_validation_epoch_end(self, trainer: L.Trainer, pl_module) -> None: + """Log predictions at validation epoch end.""" + if self.experiment is None: + try: + self.experiment = next( + logger.experiment + for logger in trainer.loggers + if isinstance(logger, CometLogger) + ) + except StopIteration as e: + raise ValueError( + f"Could not find a CometLogger instance in {trainer.loggers}." + ) from e - for i in range(x.shape[0]): - x[i] = torch.rot90(x[i], int(randint[i]), [last_axis - 2, last_axis - 1]) - y[i] = torch.rot90(y[i], int(randint[i]), [last_axis - 2, last_axis - 1]) + super().on_validation_epoch_end(trainer, pl_module) - x = x.squeeze(1) - y = y.squeeze(1) - self.learn.xb = [x] - self.learn.yb = [y] +class MLFlowCallback(PlottingCallbackABC): + def __init__(self, train_config, *args, **kwargs): + super().__init__(train_config, *args, **kwargs) + self.experiment = None -class Normalize(Callback): - """Normalization callback for input and target data. + def plot_val_pred(self, predictions, targets, current_epoch: int) -> None: + super().plot_val_pred(predictions, targets, current_epoch) - Parameters - ---------- - conf : dict - Dictionary containing the normalization type stored - under the ``'normalize'`` key. - """ + artifact_file = f"fourier_pred_{current_epoch:0>4}.png" - _order = 4 + self.experiment.log_figure( + figure=self.fig, + artifact_file=artifact_file, + run_id=self.logger._run_id, + ) - def __init__(self, conf): - self.mode = conf["normalize"] - if self.mode == "mean": - self.mean_real = conf["norm_factors"]["mean_real"] - self.mean_imag = conf["norm_factors"]["mean_imag"] - self.std_real = conf["norm_factors"]["std_real"] - self.std_imag = conf["norm_factors"]["std_imag"] + plt.close(self.fig) - def normalize(self, x, m, s): - return (x - m) / s + def plot_val_fft(self, predictions, targets, current_epoch: int) -> None: + super().plot_val_fft(predictions, targets, current_epoch) - def before_batch(self): - x = self.xb[0].clone() - y = self.yb[0].clone() + artifact_file = f"fft_pred_{current_epoch:0>4}.png" - if self.mode == "max": - x[:, 0] *= 1 / torch.amax(x[:, 0], dim=(-2, -1), keepdim=True) - x[:, 1] *= 1 / torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True) - y[:, 0] *= 1 / torch.amax(x[:, 0], dim=(-2, -1), keepdim=True) - y[:, 1] *= 1 / torch.amax(torch.abs(x[:, 1]), dim=(-2, -1), keepdim=True) + self.experiment.log_figure( + figure=self.fig, + artifact_file=artifact_file, + run_id=self.logger._run_id, + ) + plt.close(self.fig) - elif self.mode == "mean": - x[:, 0][x[:, 0] != 0] = self.normalize( - x[:, 0][x[:, 0] != 0], self.mean_real, self.std_real - ) + def on_validation_epoch_end(self, trainer: L.Trainer, pl_module) -> None: + """Log predictions at validation epoch end.""" + if self.experiment is None: + try: + self.logger = next( + logger + for logger in trainer.loggers + if isinstance(logger, MLFlowLogger) + ) + self.experiment = self.logger.experiment - x[:, 1][x[:, 1] != 0] = self.normalize( - x[:, 1][x[:, 1] != 0], self.mean_imag, self.std_imag - ) + self.base_dir = ( + self.train_config.paths.model_path / f"mlflow/{self.logger._run_id}" + ) + self.base_dir.mkdir(parents=True) - y[:, 0] = self.normalize(y[:, 0], self.mean_real, self.std_real) - y[:, 1] = self.normalize(y[:, 1], self.mean_imag, self.std_imag) + except StopIteration as e: + raise ValueError( + f"Could not find a MLFlowLogger instance in {trainer.loggers}." + ) from e - elif self.mode == "all": - # normalize each image so that mean=0 and std=1 - means = x.mean(axis=-1).mean(axis=-1).reshape(x.shape[0], x.shape[1], 1, 1) - stds = x.std(axis=-1).std(axis=-1).reshape(x.shape[0], x.shape[1], 1, 1) - x = self.normalize(x, means, stds) - y = self.normalize(y, means, stds) + super().on_validation_epoch_end(trainer, pl_module) - self.learn.xb = [x] - self.learn.yb = [y] +class MLFlowCodeCarbonCallback(LightningCallback): + def __init__(self, train_config, *args, **kwargs): + self.train_config = train_config -class SaveTempCallback(Callback): - """Callback for saving temporary model checkpoints - during training. + self.experiment = None - Parameters - ---------- - model_path : str or Path - Path where temporary models will be saved. - """ + def on_fit_end(self, trainer, pl_module): + if self.experiment is None: + self._set_up_experiment(trainer) + self.num_samples = trainer.datamodule.train_length + self.num_samples += trainer.datamodule.valid_length - _order = 95 + try: + self._log_metrics() + except (FileNotFoundError, KeyError) as e: + warnings.warn(f"{e}. No emissions were logged.", stacklevel=2) - def __init__(self, model_path): - self.model_path = model_path + self._log_params() - def after_epoch(self): - p = Path(self.model_path).parent - p.mkdir(parents=True, exist_ok=True) + def on_test_end(self, trainer, pl_module): + if self.experiment is None: + self._set_up_experiment(trainer) + self.num_samples = trainer.datamodule.test_length - if (self.epoch + 1) % 10 == 0: - out = p / f"temp_{self.epoch + 1}.model" - save_model(self, out) - LOGGER.info(f"Finished Epoch {self.epoch + 1}, model saved.") + try: + self._log_metrics() + except (FileNotFoundError, KeyError) as e: + warnings.warn(f"{e}. No emissions were logged.", stacklevel=2) + self._log_params() -class SwitchLoss(Callback): - """Callback for switching loss functions during training. + def on_predict_end(self, trainer, pl_module): + if self.experiment is None: + self._set_up_experiment(trainer) + self.num_samples = trainer.datamodule.predict_length - Changes the loss function to a different one after a specified - number of epochs. + try: + self._log_metrics() + except (FileNotFoundError, KeyError) as e: + warnings.warn(f"{e}. No emissions were logged.", stacklevel=2) - Parameters - ---------- - second_loss : callable - The loss function to switch to. - when_switch : int - Epoch number after which to switch loss functions. - """ + self._log_params() - _order = 5 + def _set_up_experiment(self, trainer): + try: + self.logger = next( + logger for logger in trainer.loggers if isinstance(logger, MLFlowLogger) + ) + trainer.carbontracker.tracker.stop() + self.experiment = self.logger.experiment + self.task = trainer.radionets_task + + except StopIteration as e: + raise ValueError( + f"Could not find a MLFlowLogger instance in {trainer.loggers}." + ) from e + + def _log_metrics(self): + emission_file = Path( + self.train_config.logging.codecarbon.output_dir + "/emissions.csv" + ) + emission_data = pd.read_csv(emission_file).to_dict() - def __init__(self, second_loss, when_switch): - self.second_loss = second_loss - self.when_switch = when_switch + eval_res = dict( + running_time_total=emission_data["duration"][0], + running_time=emission_data["duration"][0] / self.num_samples, + power_draw_total=emission_data["energy_consumed"][0] * 3.6e6, + power_draw=emission_data["energy_consumed"][0] * 3.6e6 / self.num_samples, + ) - def before_epoch(self): - if (self.epoch + 1) > self.when_switch: - self.learn.loss_func = self.second_loss + for key, val in eval_res.items(): + self.experiment.log_metric( + key=key, + value=val, + run_id=self.logger._run_id, + ) + self.architecture = emission_data["gpu_model"][0] -class GradientCallback(Callback): - """Callback for gradient and prediction tracking. + # Remove file after logging all important metrics to mlflow. + # This prevents codecarbon from creating 'emissions.csv_%d.bak' + # files in the save directory + if emission_file.is_file(): + emission_file.unlink() - Parameters - ---------- - num_epochs : int - Number of training epochs. - validation_data : str or Path - Path to the validation dataset. - arch_name : str - Name of the architecture used for the model. - amp_phase : bool - Whether to use amplitude-phase representation. - """ + def _log_params(self): + dataset = self.train_config.paths.data_path.name + dataset += "_amp_phase" if self.train_config.model.amp_phase else "_real_imag" - def __init__(self, num_epochs, validation_data, arch_name, amp_phase): - self.num_epochs = num_epochs - self.data_path = validation_data - self.test_ds = load_data( - self.data_path, mode="test", fourier=True, source_list=False - ) - self.arch_name = arch_name - self.amp_phase = amp_phase - - def before_backward(self): - raise CancelBackwardException - - def after_cancel_backward(self): - self.learn.loss.backward() - - # access gradients of weights of layers (with specified batch and epoch) - if self.epoch == self.num_epochs - 1 and self.iter == self.n_iter - 1: - grads = [] - for param in self.learn.model.parameters(): - grads.append(param.grad.view(-1)) - # print or save - - def after_epoch(self): - img_test, img_true = get_images(self.test_ds, 1, rand=False) - - # for each epoch put test image through model and save to csv - fname_template = "pred_{i}.csv" - np.savetxt( - fname_template.format(i=self.epoch), - get_ifft(eval_model(img_test, self.model), self.amp_phase), - delimiter=",", - ) + model = "Radionets" + model += "_" + str(self.train_config.model.arch_name().__class__.__name__) + model += "_" + str(self.train_config.training.optimizer.optimizer.__name__) - # # fourier space - amp_names = "pred_amp_{i}.csv" - phase_names = "pred_phase_{i}.csv" - output = eval_model(img_test, self.model) - np.savetxt( - amp_names.format(i=self.epoch), output[0][0].cpu().numpy(), delimiter="," + if self.train_config.training.lr_scheduling: + model += "_" + str( + self.train_config.training.lr_scheduling.scheduler.__name__ + ) + + params_dict = dict( + model=model, + dataset=dataset, + task=self.task, + architecture=self.architecture, ) - np.savetxt( - phase_names.format(i=self.epoch), output[0][1].cpu().numpy(), delimiter="," + for key, val in params_dict.items(): + self.experiment.log_param( + key=key, + value=val, + run_id=self.logger._run_id, + ) + + +class LogAdditionalParamsCallback(LightningCallback): + def __init__(self, train_config, *args, **kwargs): + self.train_config = train_config + self.amp_phase = train_config.model.amp_phase + + self.experiment = None + + def on_fit_end(self, trainer, pl_module): + if self.experiment is None: + self._set_up_experiment(trainer) + + self._log_metrics( + dataloader=trainer.datamodule.val_dataloader(), + pl_module=pl_module, ) + def on_test_end(self, trainer, pl_module): + if self.experiment is None: + self._set_up_experiment(trainer) -class PredictionImageGradient(Callback): - """Callback for computing spatial gradients - of model predictions. - - Parameters - ---------- - validation_data : str or Path - Path to validation dataset. - model : str or Path - Path to pretrained model. - amp_phase : bool - Whether to use amplitude-phase representation. - arch_name : str - Name of the architecture used for the model. - """ - - def __init__(self, validation_data, model, amp_phase, arch_name): - self.data_path = validation_data - self.test_ds = load_data( - self.data_path, mode="test", fourier=True, source_list=False + self._log_metrics( + dataloader=trainer.datamodule.test_dataloader(), + pl_module=pl_module, ) - self.model = model - self.amp_phase = amp_phase - self.arch_name = arch_name - def save_output_pred(self): - img_test, img_true = get_images(self.test_ds, 5, rand=False) + def on_predict_end(self, trainer, pl_module): + if self.experiment is None: + self._set_up_experiment(trainer) - img_size = img_test[0].shape[-1] - model_used = load_pretrained_model(self.arch_name, self.model, img_size) + self._log_metrics( + dataloader=trainer.datamodule.predict_dataloader(), + pl_module=pl_module, + ) - output = eval_model(img_test[0], model_used) - gradient = K.filters.spatial_gradient(output) + def _set_up_experiment(self, trainer): + try: + self.logger = next( + logger for logger in trainer.loggers if isinstance(logger, MLFlowLogger) + ) + self.experiment = self.logger.experiment + + except StopIteration as e: + raise ValueError( + f"Could not find a MLFlowLogger instance in {trainer.loggers}." + ) from e + + def _log_metrics(self, dataloader, pl_module): + area = [] + total_flux = [] + peak_flux = [] + for batch in dataloader: + preds = pl_module.predict_step(batch[0], batch_idx=0).detach().cpu() + targets = batch[1].detach().cpu() + + # check if images are half or full + if preds.shape[-2] != preds.shape[-1]: + preds = apply_symmetry(preds) + targets = apply_symmetry(targets) + + ifft_preds = get_ifft(preds, amp_phase=self.amp_phase) + ifft_targets = get_ifft(targets, amp_phase=self.amp_phase) + + area.extend( + [ + area_of_contour(ifft_pred, ifft_target) + for ifft_pred, ifft_target in zip(ifft_preds, ifft_targets) + ] + ) - grads_x = get_ifft(gradient[:, :, 0], self.amp_phase) - grads_y = get_ifft(gradient[:, :, 1], self.amp_phase) + total, peak = analyse_intensity(ifft_preds, ifft_targets) + total_flux.extend(total) + peak_flux.extend(peak) - return grads_x, grads_y + trainable_params = sum( + p.numel() for p in pl_module.parameters() if p.requires_grad + ) + additional_metrics = dict( + num_trainable_parameters=trainable_params, + mean_area_ratio=np.abs(1.0 - np.mean(area)), + mean_total_flux=np.abs(1.0 - np.mean(total_flux)), + mean_peak_flux=np.abs(1.0 - np.mean(peak_flux)), + ) + + for key, val in additional_metrics.items(): + self.experiment.log_metric( + key=key, + value=val, + run_id=self.logger._run_id, + ) diff --git a/src/radionets/core/data.py b/src/radionets/core/data.py deleted file mode 100644 index a089ecc9..00000000 --- a/src/radionets/core/data.py +++ /dev/null @@ -1,223 +0,0 @@ -import re -from pathlib import Path - -import h5py -import numpy as np -import torch -from torch.utils.data import DataLoader, Dataset - -__all__ = [ - "DataBunch", - "H5DataSet", - "get_bundles", - "get_dls", - "load_data", - "open_bundle", - "open_bundle_pack", - "open_fft_bundle", - "save_bundle", - "save_fft_pair", -] - - -class H5DataSet: - def __init__(self, bundle_paths, tar_fourier): - """ - Save the bundle paths and the number of bundles in one file. - """ - if bundle_paths == []: - raise ValueError("No bundles found! Please check the names of your files.") - self.bundles = bundle_paths - self.num_img = len(self.open_bundle(self.bundles[0], "x")) - self.tar_fourier = tar_fourier - - def __call__(self): - return print("This is the H5DataSet class.") - - def __len__(self): - """ - Returns the total number of pictures in this dataset - """ - return len(self.bundles) * self.num_img - - def __getitem__(self, i): - x = self.open_image("x", i) - y = self.open_image("y", i) - return x, y - - def open_bundle(self, bundle_path, var): - bundle = h5py.File(bundle_path, "r") - data = bundle[var] - return data - - def open_image(self, var, i): - if isinstance(i, int): - i = torch.tensor([i]) - - elif isinstance(i, np.ndarray): - i = torch.tensor(i) - - indices, _ = torch.sort(i) - bundle = torch.div(indices, self.num_img, rounding_mode="floor") - image = indices - bundle * self.num_img - bundle_unique = torch.unique(bundle) - - bundle_paths = [ - h5py.File(self.bundles[bundle], "r") for bundle in bundle_unique - ] - bundle_paths_str = list(map(str, bundle_paths)) - - data = torch.tensor( - np.array( - [ - bund[var][img] - for bund, bund_str in zip(bundle_paths, bundle_paths_str) - for img in image[ - bundle == bundle_unique[bundle_paths_str.index(bund_str)] - ] - ] - ) - ) - - if self.tar_fourier is False and data.shape[1] == 2: - raise ValueError( - "Two channeled data is used despite Fourier being False.\ - Set Fourier to True!" - ) - - if data.shape[0] == 1: - data = data.squeeze(0) - return data.float() - - -def get_dls(train_ds, valid_ds, batch_size, **kwargs): - return ( - DataLoader(train_ds, batch_size=batch_size, shuffle=True, **kwargs), - DataLoader(valid_ds, batch_size=batch_size, shuffle=True, **kwargs), - ) - - -class DataBunch: - def __init__( - self, - train_dl: DataLoader, - valid_dl: DataLoader, - num_classes: int | None = None, - ): - self.train_dl = train_dl - self.valid_dl = valid_dl - self.num_classes = num_classes - - @property - def train_ds(self) -> Dataset: - return self.train_dl.dataset - - @property - def valid_ds(self) -> Dataset: - return self.valid_dl.dataset - - def __call__(self): - return print("This is the DataBunch class.") - - def __repr__(self) -> str: - return ( - f"DataBunch(train_size={len(self.train_ds)}, " - f"valid_size={len(self.valid_ds)}, " - f"num_classes={self.num_classes})" - ) - - -def save_bundle(path, bundle, counter, name="gs_bundle"): - with h5py.File(str(path) + str(counter) + ".h5", "w") as hf: - hf.create_dataset(name, data=bundle) - hf.close() - - -def open_bundle(path): - """ - open radio galaxy bundles created in first analysis step - """ - f = h5py.File(path, "r") - bundle = np.array(f["gs_bundle"]) - return bundle - - -def open_fft_bundle(path): - """ - open radio galaxy bundles created in first analysis step - """ - f = h5py.File(path, "r") - x = np.array(f["x"]) - y = np.array(f["y"]) - return x, y - - -def get_bundles(path): - """ - returns list of bundle paths located in a directory - """ - data_path = Path(path) - bundles = np.array([x for x in data_path.iterdir()]) - return bundles - - -def save_fft_pair(path, x, y, z=None, name_x="x", name_y="y", name_z="z"): - """ - write fft_pairs created in second analysis step to h5 file - """ - with h5py.File(path, "w") as hf: - hf.create_dataset(name_x, data=x) - hf.create_dataset(name_y, data=y) - if z is not None: - [hf.create_dataset(name_z + str(i), data=z[i]) for i in range(len(z))] - hf.close() - - -def open_bundle_pack(path): - bundle_x = [] - bundle_y = [] - bundle_z = [] - - f = h5py.File(path, "r") - - bundle_size = len(f) // 3 - for i in range(bundle_size): - bundle_x_i = np.array(f["x" + str(i)]) - bundle_x.append(bundle_x_i) - bundle_y_i = np.array(f["y" + str(i)]) - bundle_y.append(bundle_y_i) - bundle_z_i = np.array(f["z" + str(i)]) - bundle_z.append(bundle_z_i) - - f.close() - - return np.array(bundle_x), np.array(bundle_y), bundle_z - - -def load_data(data_path, mode, fourier=False): - """ - Load data set from a directory and return it as H5DataSet. - - Parameters - ---------- - data_path : str - path to data directory - mode : str - specify data set type, e.g. test - fourier : bool - use Fourier images as target if True, default is False - - Returns - ------- - test_ds : H5DataSet - dataset containing x and y images - """ - bundle_paths = get_bundles(data_path) - - data = np.sort( - [path for path in bundle_paths if re.findall("samp_" + mode, path.name)] - ) - data = sorted(data, key=lambda f: int("".join(filter(str.isdigit, str(f))))) - - ds = H5DataSet(data, tar_fourier=fourier) - return ds diff --git a/src/radionets/core/learner.py b/src/radionets/core/learner.py deleted file mode 100644 index e9ed16b7..00000000 --- a/src/radionets/core/learner.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch.nn as nn -from fastai.callback.schedule import ParamScheduler, combined_cos -from fastai.data.core import DataLoaders -from fastai.learner import Learner -from fastai.optimizer import Adam - -import radionets.core.loss_functions as loss_functions -from radionets.core.callbacks import ( - AvgLossCallback, - CometCallback, - CudaCallback, - DataAug, - Normalize, - SaveTempCallback, - SwitchLoss, -) -from radionets.core.model import init_cnn - -__all__ = ["get_learner", "define_learner"] - - -def get_learner(data, arch, lr, loss_func=None, cb_funcs=None, opt_func=Adam, **kwargs): - if not loss_func: - loss_func = nn.MSELoss() - - init_cnn(arch) - dls = DataLoaders.from_dsets( - data.train_ds, data.valid_ds, bs=data.train_dl.batch_size - ) - return Learner(dls, arch, loss_func, lr=lr, cbs=cb_funcs, opt_func=opt_func) - - -def define_learner(data, arch, train_conf, lr_find=False, plot_loss=False): - cbfs = [] - model_path = train_conf["model_path"] - lr = train_conf["lr"] - opt_func = Adam - - if train_conf["param_scheduling"]: - sched = { - "lr": combined_cos( - train_conf["lr_ratio"], - train_conf["lr_start"], - train_conf["lr_max"], - train_conf["lr_stop"], - ) - } - cbfs.extend([ParamScheduler(sched)]) - - if train_conf["gpu"]: - cbfs.extend([CudaCallback]) - - cbfs.extend( - [ - SaveTempCallback(model_path=model_path), - AvgLossCallback, - DataAug, - ] - ) - - # use switch loss - if train_conf["switch_loss"]: - cbfs.extend( - [ - SwitchLoss( - second_loss=loss_functions.comb_likelihood, - when_switch=train_conf["when_switch"], - ), - ] - ) - - if train_conf["comet_ml"] and not lr_find and not plot_loss: - cbfs.extend( - [ - CometCallback( - name=train_conf["project_name"], - validation_data=train_conf["data_path"], - plot_n_epochs=train_conf["plot_n_epochs"], - amp_phase=train_conf["amp_phase"], - scale=train_conf["scale"], - ), - ] - ) - - if not plot_loss and train_conf["normalize"] != "none": - cbfs.extend([Normalize(train_conf)]) - # get loss func - if train_conf["loss_func"] == "feature_loss": - loss_func = loss_functions.init_feature_loss() - else: - loss_func = getattr(loss_functions, train_conf["loss_func"]) - - # Combine model and data in learner - learn = get_learner( - data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func - ) - return learn diff --git a/src/radionets/core/logging.py b/src/radionets/core/logging.py index be76f22f..2f6e4782 100644 --- a/src/radionets/core/logging.py +++ b/src/radionets/core/logging.py @@ -1,9 +1,68 @@ import logging +from lightning.pytorch.loggers import CSVLogger +from pydantic import BaseModel from rich.logging import RichHandler +logging.basicConfig( + level="INFO", + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(rich_tracebacks=True)], +) -def setup_logger(namespace="rich", level="INFO", **kwargs): + +class Loggers: + @classmethod + def get_loggers(cls, train_config: BaseModel) -> list: + default_logger = CSVLogger( + save_dir=train_config.paths.model_path, + **train_config.logging.default_logger.model_dump(), + ) + default_logger._name = train_config.logging.project_name + + loggers = [default_logger] + + if train_config.logging.comet_ml: + try: + from lightning.pytorch.loggers import CometLogger + except ImportError as e: + raise ModuleNotFoundError( + "'comet_ml' was set to 'true' in your training config but " + "radionets could not import 'CometLogger'. This usually " + "indicates that 'comet_ml' is missing from your environment. " + "You can install it using 'uv pip install comet_ml'." + ) from e + + comet_logger = CometLogger( + project=train_config.logging.project_name, + api_key=train_config.logging.comet_ml.api_key.get_secret_value(), + **train_config.logging.comet_ml.model_dump(exclude="api_key"), + ) + loggers.append(comet_logger) + + if train_config.logging.mlflow: + try: + from lightning.pytorch.loggers import MLFlowLogger + except ImportError as e: + raise ModuleNotFoundError( + "'mlflow' was set to 'true' in your training config but " + "radionets could not import 'MLflowLogger'. This usually " + "indicates that 'mlflow' is missing from your environment. " + "You can install it using 'uv pip install mlflow'." + ) from e + + mlflow_logger = MLFlowLogger( + experiment_name=train_config.logging.project_name, + save_dir=train_config.paths.model_path, + **train_config.logging.mlflow.model_dump(), + ) + loggers.append(mlflow_logger) + + return loggers + + +def _setup_logger(namespace="rich", level="INFO", **kwargs): """Basic logging setup. Uses :class:`~rich.logging.RichHandler` for formatting and highlighting of the log. diff --git a/src/radionets/core/loss_functions.py b/src/radionets/core/loss_functions.py deleted file mode 100644 index aa353cd2..00000000 --- a/src/radionets/core/loss_functions.py +++ /dev/null @@ -1,134 +0,0 @@ -import numpy as np -import torch -from torch import nn - -__all__ = [ - "beta_nll_loss", - "create_circular_mask", - "jet_seg", - "l1", - "mse", - "splitted_L1", - "splitted_L1_masked", -] - - -def l1(x, y): - pred = x["pred"] - - l1 = nn.L1Loss() - loss = l1(pred, y) - - return loss - - -def create_circular_mask(h, w, center=None, radius=None, bs=64): - if center is None: - center = (int(w / 2), int(h / 2)) - - if radius is None: - radius = min(center[0], center[1], w - center[0], h - center[1]) - - Y, X = np.ogrid[:h, :w] - dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2) - - mask = dist_from_center <= radius - - return np.repeat([mask], bs, axis=0) - - -def splitted_L1_masked(x, y): - pred = x["pred"] - inp_amp = pred[:, 0, :] - inp_phase = pred[:, 1, :] - - tar_amp = y[:, 0, :] - tar_phase = y[:, 1, :] - - mask = torch.tensor(create_circular_mask(256, 256, radius=50, bs=y.shape[0])) - - inp_amp[~mask] *= 0.3 - inp_phase[~mask] *= 0.3 - tar_amp[~mask] *= 0.3 - tar_phase[~mask] *= 0.3 - - l1 = nn.L1Loss() - loss_amp = l1(inp_amp, tar_amp) - loss_phase = l1(inp_phase, tar_phase) - loss = loss_amp + loss_phase - - return loss - - -def splitted_L1(x, y): - pred = x["pred"] - inp_amp = pred[:, 0, :] - inp_phase = pred[:, 1, :] - - tar_amp = y[:, 0, :] - tar_phase = y[:, 1, :] - - l1 = nn.L1Loss() - loss_amp = l1(inp_amp, tar_amp) - loss_phase = l1(inp_phase, tar_phase) - loss = loss_amp + loss_phase - - return loss - - -def beta_nll_loss(x: torch.tensor, y: torch.tensor, beta: float = 0.5): - """Compute beta-NLL loss - - Parameters - ---------- - x : :func:`torch.tensor` - Prediction of the model. - y : :func:`torch.tensor` - Ground truth. - beta : float - Parameter from range [0, 1] controlling relative - weighting between data points, where "0" corresponds to - high weight on low error points and "1" to an equal weighting. - - Returns - ------- - float : Loss per batch element of shape B - """ - pred = x["pred"] - pred_amp = pred[:, 0, :] - pred_phase = pred[:, 2, :] - mean = torch.stack([pred_amp, pred_phase], axis=1) - - unc_amp = pred[:, 1, :] - unc_phase = pred[:, 3, :] - variance = torch.stack([unc_amp, unc_phase], axis=1) - - tar_amp = y[:, 0, :] - tar_phase = y[:, 1, :] - target = torch.stack([tar_amp, tar_phase], axis=1) - - loss = 0.5 * ((target - mean) ** 2 / variance + variance.log()) - - if beta > 0: - loss = loss * variance.detach() ** beta - - return loss.mean() - - -def mse(x, y): - pred = x["pred"] - mse = nn.MSELoss() - loss = mse(pred, y) - - return loss - - -def jet_seg(x, y): - pred = x["pred"] - - # weight components farer outside more - loss_l1_weighted = 0 - for i in range(pred.shape[1]): - loss_l1_weighted += l1(pred[:, i], y[:, i]) * (i + 1) - - return loss_l1_weighted diff --git a/src/radionets/core/model.py b/src/radionets/core/model.py deleted file mode 100644 index 7502336c..00000000 --- a/src/radionets/core/model.py +++ /dev/null @@ -1,104 +0,0 @@ -from pathlib import Path - -import torch -from torch import nn - -from radionets.core.logging import setup_logger - -__all__ = [ - "init_cnn", - "load_pre_model", - "save_model", -] - -LOGGER = setup_logger(namespace=__name__) - - -def _init_cnn(m, f): - if isinstance(m, nn.Conv2d): - f(m.weight, a=0.1) - if getattr(m, "bias", None) is not None: - m.bias.data.zero_() - for c in m.children(): - _init_cnn(c, f) - - -def init_cnn(m, uniform=False): - f = nn.init.kaiming_uniform_ if uniform else nn.init.kaiming_normal_ - _init_cnn(m, f) - - -def load_pre_model(learn, pre_path, visualize=False, plot_loss=False): - """Loads a previously saved model as pre-model. - - Parameters - ---------- - learn : learner - Object of type learner. - pre_path : str - Path to the pre-model. - visualize : bool - Default: False - plot_loss : bool - Default: False - """ - name_pretrained = Path(pre_path).stem - LOGGER.info(f"Load pretrained model: {name_pretrained}") - - if torch.cuda.is_available() and not plot_loss: - checkpoint = torch.load(pre_path) - else: - checkpoint = torch.load(pre_path, map_location=torch.device("cpu")) - - if visualize: - learn.load_state_dict(checkpoint["model"]) - return checkpoint["norm_dict"] - elif plot_loss: - learn.avg_loss.loss_train = checkpoint["train_loss"] - learn.avg_loss.loss_valid = checkpoint["valid_loss"] - learn.avg_loss.lrs = checkpoint["lrs"] - else: - learn.model.load_state_dict(checkpoint["model"]) - learn.opt.load_state_dict(checkpoint["opt"]) - learn.epoch = checkpoint["epoch"] - learn.avg_loss.loss_train = checkpoint["train_loss"] - learn.avg_loss.loss_valid = checkpoint["valid_loss"] - learn.avg_loss.lrs = checkpoint["lrs"] - learn.recorder.iters = checkpoint["iters"] - learn.recorder.values = checkpoint["vals"] - - -def save_model(learn, model_path): - if hasattr(learn, "normalize"): - if learn.normalize.mode == "mean": - norm_dict = { - "mean_real": learn.normalize.mean_real, - "mean_imag": learn.normalize.mean_imag, - "std_real": learn.normalize.std_real, - "std_imag": learn.normalize.std_imag, - } - elif learn.normalize.mode == "max": - norm_dict = {"max_scaling": 0} - elif learn.normalize.mode == "all": - norm_dict = {"all": 0} - elif not learn.normalize.mode: - norm_dict = {} - else: - raise ValueError(f"Undefined mode {learn.normalize.mode}, check for typos") - else: - norm_dict = {} - - torch.save( - { - "model": learn.model.state_dict(), - "opt": learn.opt.state_dict(), - "epoch": learn.epoch, - "iters": learn.recorder.iters, - "vals": learn.recorder.values, - "train_loss": learn.avg_loss.loss_train, - "valid_loss": learn.avg_loss.loss_valid, - "lrs": learn.avg_loss.lrs, - "norm_dict": norm_dict, - }, - model_path, - ) diff --git a/src/radionets/core/utils.py b/src/radionets/core/utils.py deleted file mode 100644 index ed11c878..00000000 --- a/src/radionets/core/utils.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np -import torch - - -def _maybe_item(t): - t = t.value - return t.item() if isinstance(t, torch.Tensor) and t.numel() == 1 else t - - -def get_ifft_torch(array, amp_phase=False, scale=False, uncertainty=False): - if len(array.shape) == 3: - array = array.unsqueeze(0) - - if amp_phase: - amp = 10 ** (10 * array[:, 0] - 10) - 1e-10 if scale else array[:, 0] - - if uncertainty: - a = amp * torch.cos(array[:, 2]) - b = amp * torch.sin(array[:, 2]) - else: - a = amp * torch.cos(array[:, 1]) - b = amp * torch.sin(array[:, 1]) - compl = a + b * 1j - else: - compl = array[:, 0] + array[:, 1] * 1j - - if compl.shape[0] == 1: - compl = compl.squeeze(0) - - return torch.abs(torch.fft.ifftshift(torch.fft.ifft2(torch.fft.fftshift(compl)))) - - -def split_real_imag(array): - """ - takes a complex array and returns the real and the imaginary part - """ - return array.real, array.imag - - -def split_amp_phase(array): - """ - takes a complex array and returns the amplitude and the phase - """ - amp = np.abs(array) - phase = np.angle(array) - return amp, phase diff --git a/src/radionets/evaluation/blob_detection.py b/src/radionets/evaluation/blob_detection.py deleted file mode 100644 index dbec8da8..00000000 --- a/src/radionets/evaluation/blob_detection.py +++ /dev/null @@ -1,86 +0,0 @@ -from math import sqrt - -import numpy as np -import torch -from skimage.feature import blob_log - - -def calc_blobs(ifft_pred, ifft_truth): - if isinstance(ifft_pred, torch.Tensor): - ifft_pred = ifft_pred.numpy() - - if isinstance(ifft_truth, torch.Tensor): - ifft_truth = ifft_truth.numpy() - - tresh = ifft_truth.max() * 0.1 - kwargs = { - "min_sigma": 1, - "max_sigma": 10, - "num_sigma": 100, - "threshold": tresh, - "overlap": 0.9, - } - - blobs_log_pred = blob_log(ifft_pred, **kwargs) - blobs_log_truth = blob_log(ifft_truth, **kwargs) - - # Compute radii in the 3rd column. - blobs_log_pred[:, 2] = blobs_log_pred[:, 2] * sqrt(2) - blobs_log_truth[:, 2] = blobs_log_truth[:, 2] * sqrt(2) - - return blobs_log_pred, blobs_log_truth - - -def crop_first_component(pred, truth, blob_truth): - """Returns the cropped image with the first component of the - true image for both prediction and truth. - - Parameters - ---------- - pred : ndarray - predicted source image - truth : ndarray - true source image - blob_truth : list - list with the coordiantes for the first component - - Returns - ------- - ndarray - cropped images - """ - y, x, r = blob_truth - x_coord, y_coord = corners(y, x, r) - - flux_truth = truth[x_coord[0] : x_coord[1], y_coord[0] : y_coord[1]] - flux_pred = pred[x_coord[0] : x_coord[1], y_coord[0] : y_coord[1]] - - return flux_pred, flux_truth - - -def corners(x, y, r): - """Generates the value range for cropping the first component out of - the images. - - Parameters - ---------- - y : float - y coordiante - x : float - x coordinate - r : float - radius of the first component - - Returns - ------- - list - start and end point for the cropping - """ - r = int(np.round(r)) - x = int(x) - y = int(y) - - x_coord = [x - r, x + r + 1] - y_coord = [y - r, y + r + 1] - - return x_coord, y_coord diff --git a/src/radionets/evaluation/contour.py b/src/radionets/evaluation/contour.py index 42516c42..e6743457 100644 --- a/src/radionets/evaluation/contour.py +++ b/src/radionets/evaluation/contour.py @@ -1,52 +1,70 @@ -import matplotlib as mpl +from __future__ import annotations + +from typing import TYPE_CHECKING + import matplotlib.pyplot as plt import numpy as np +from torch import Tensor + +if TYPE_CHECKING: + from matplotlib.contour import QuadContourSet + from numpy.typing import ArrayLike -def compute_area_ratio(CS_pred, CS_truth): - """Compute the ratio of the areas of truth and prediction. +def _compute_source_area(vertices: ArrayLike) -> float: + """Helper function to compute area of a source + using the shoelace formula. Parameters ---------- - CS_pred : contour object - contour object of prediction - CS_truth : contour object - contour object of truth + vertices : :func:`~numpy.ndarray`, shape (N, 2) + Polygon (source) vertices as (x, y) coordinates. Returns ------- float - ratio between area of truth and prediction + Area of the source. """ - areas_truth = np.array([]) - areas_pred = np.array([]) + x = vertices[:, 0] + y = vertices[:, 1] + + s1 = np.dot(x, np.roll(y, -1)) + s2 = np.dot(y, np.roll(x, -1)) - for area in CS_truth.get_paths(): - truth_x = area.vertices[:, 0] - truth_y = area.vertices[:, 1] + return 0.5 * np.abs(s1 - s2) - area_truth = 0.5 * np.sum( - truth_y[:-1] * np.diff(truth_x) - truth_x[:-1] * np.diff(truth_y) - ) - area_truth = np.abs(area_truth) - areas_truth = np.append(areas_truth, area_truth) - for area in CS_pred.get_paths(): - pred_x = area.vertices[:, 0] - pred_y = area.vertices[:, 1] +def compute_area_ratio(cs_pred: QuadContourSet, cs_truth: QuadContourSet) -> float: + """Computes the ratio of true and predicted source areas. - area_pred = 0.5 * np.sum( - pred_y[:-1] * np.diff(pred_x) - pred_x[:-1] * np.diff(pred_y) - ) - area_pred = np.abs(area_pred) - areas_pred = np.append(areas_pred, area_pred) + Parameters + ---------- + cs_pred : :class:`~matplotlib.contour.QuadContourSet` + contour object of prediction + cs_truth : :class:`~matplotlib.contour.QuadContourSet` + contour object of truth + + Returns + ------- + float + Ratio between true and predicted source areas. + """ + areas_pred = np.array( + [_compute_source_area(path.vertices) for path in cs_pred.get_paths()] + ) + areas_truth = np.array( + [_compute_source_area(path.vertices) for path in cs_truth.get_paths()] + ) return areas_pred.sum() / areas_truth.sum() -def area_of_contour(ifft_pred, ifft_truth): - """Create first contour of prediction and truth and return - the area ratio. +def area_of_contour( + ifft_pred: ArrayLike, + ifft_truth: ArrayLike, + level: float = 0.05, +) -> float: + """Compute area ratio at 5% of the maximum of prediction and truth. Parameters ---------- @@ -60,34 +78,52 @@ def area_of_contour(ifft_pred, ifft_truth): float area difference """ - mpl.use("Agg") + levels = [ifft_truth.max() * level] + + fig, ax = plt.subplots() + cs_pred = ax.contour(ifft_pred, levels=levels) + cs_truth = ax.contour(ifft_truth, levels=levels) + plt.close(fig) + + return compute_area_ratio(cs_pred, cs_truth) + - levels = [ifft_truth.max() * 0.05] +def analyse_intensity(pred: ArrayLike, truth: ArrayLike) -> tuple[float, float]: + """Compute intensity ratios between prediction + and ground truth images. - CS1 = plt.contour(ifft_pred, levels=levels) + Parameters + ---------- + pred : :func:`~numpy.ndarray`, shape (..., H, W) + Prediction image(s). + truth : :func:`~numpy.ndarray`, shape (..., H, W) + Ground truth image(s). - plt.close() + Returns + ------- + sum_ratio : :func:`~numpy.ndarray` + Ratio of summed intensities (prediction / truth). + peak_ratio : :func:`~numpy.ndarray` + Ratio of peak intensities (prediction / truth). + """ + if pred.ndim == 2: + pred = pred[None, ...] - CS2 = plt.contour(ifft_truth, levels=levels) + if truth.ndim == 2: + truth = truth[None, ...] - val = compute_area_ratio(CS1, CS2) - mpl.rcParams.update(mpl.rcParamsDefault) - return val + if isinstance(pred, Tensor): + pred = pred.detach().cpu().numpy() + if isinstance(truth, Tensor): + truth = truth.detach().cpu().numpy() -def analyse_intensity(pred, truth): - if len(pred.shape) == 2: - pred = pred.reshape(1, pred.shape[-2], pred.shape[-1]) - truth = truth.reshape(1, truth.shape[-2], truth.shape[-1]) + threshold = truth.max(axis=(-2, -1), keepdims=True) * 0.05 - threshold = (truth.max(-1).max(-1) * 0.05).reshape(truth.shape[0], 1, 1) source_truth = np.where(truth > threshold, truth, 0) source_pred = np.where(pred > threshold, pred, 0) - sum_truth = source_truth.sum(-1).sum(-1) - sum_pred = source_pred.sum(-1).sum(-1) - - peak_truth = source_truth.max(-1).max(-1) - peak_pred = source_pred.max(-1).max(-1) + sum_ratio = source_pred.sum(axis=(-2, -1)) / source_truth.sum(axis=(-2, -1)) + peak_ratio = source_pred.max(axis=(-2, -1)) / source_truth.max(axis=(-2, -1)) - return sum_pred / sum_truth, peak_pred / peak_truth + return sum_ratio, peak_ratio diff --git a/src/radionets/evaluation/dynamic_range.py b/src/radionets/evaluation/dynamic_range.py index 5458153e..a15346f5 100644 --- a/src/radionets/evaluation/dynamic_range.py +++ b/src/radionets/evaluation/dynamic_range.py @@ -1,73 +1,181 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray -def get_boxsize(num_corners, num_pixel=63): - factors = np.array([0.3, 0.22, 0.16]) - size = int(num_pixel * factors[num_corners - 2]) - return size +_BOX_FACTORS = np.array([0.3, 0.22, 0.16]) -def select_box(rms, sensitivity=1e-6): - for arr in rms: - arr[arr > sensitivity] = 0 - rms_boxes = rms.astype(bool).sum(axis=0) - return rms_boxes +def get_boxsize(num_corners: int, num_pixel: int = 63) -> int: + """ + Compute corner box size based on number of corners used. + Parameters + ---------- + num_corners : int + Number of corners to use (2, 3, or 4). + num_pixel : int, optional + Image size in pixels. Default: 63 -def compute_rms(batch, size): - rms1 = rms2 = rms3 = rms4 = np.ones(len(batch)) * -1 - rms1 = np.sqrt((batch[:, :size, :size].reshape(-1, size**2) ** 2).mean(axis=1)) - rms2 = np.sqrt((batch[:, :size, -size:].reshape(-1, size**2) ** 2).mean(axis=1)) - rms3 = np.sqrt((batch[:, -size:, :size].reshape(-1, size**2) ** 2).mean(axis=1)) - rms4 = np.sqrt((batch[:, -size:, -size:].reshape(-1, size**2) ** 2).mean(axis=1)) - return np.stack([rms1, rms2, rms3, rms4], axis=0) + Returns + ------- + int + Box size in pixels. + """ + return int(num_pixel * _BOX_FACTORS[num_corners - 2]) -def get_rms(ifft_truth, ifft_pred): - rms_4_truth = compute_rms(ifft_truth, get_boxsize(4)) - rms_boxes = select_box(rms_4_truth, 1e-6) - rms_3_truth = compute_rms(ifft_truth, get_boxsize(3)) - select_box(rms_3_truth) - rms_2_truth = compute_rms(ifft_truth, get_boxsize(2)) - select_box(rms_2_truth) +def select_box(rms: NDArray, sensitivity: float = 1e-6) -> NDArray: + """ + Select valid corner boxes based on RMS threshold. - rms_4_pred = compute_rms(ifft_pred, get_boxsize(4)) - rms_3_pred = compute_rms(ifft_pred, get_boxsize(3)) - rms_2_pred = compute_rms(ifft_pred, get_boxsize(2)) + Parameters + ---------- + rms : :func:`~numpy.ndarray`, shape (4, B) + RMS values for each corner. + sensitivity : float, optional + Threshold below which corners are considered valid. + Default: 1e-6. - rms_3_pred[rms_3_truth == 0] = 0 - rms_2_pred[rms_2_truth == 0] = 0 + Returns + ------- + :func:`numpy.ndarray`, shape (B,) + Number of valid corners per sample. + """ + valid_corners = rms <= sensitivity + return valid_corners.sum(axis=0) - rms_truth = np.zeros(len(rms_boxes)) - rms_truth[rms_boxes == 4] = ( - np.sqrt(rms_4_truth[0:4, rms_boxes == 4] ** 2).sum(axis=0) / 4 - ) - rms_truth[rms_boxes == 3] = ( - np.sqrt(rms_3_truth[0:4, rms_boxes == 3] ** 2).sum(axis=0) / 3 - ) - rms_truth[rms_boxes == 2] = ( - np.sqrt(rms_2_truth[0:4, rms_boxes == 2] ** 2).sum(axis=0) / 2 - ) - rms_pred = np.zeros(len(rms_boxes)) - rms_pred[rms_boxes == 4] = ( - np.sqrt(rms_4_pred[0:4, rms_boxes == 4] ** 2).sum(axis=0) / 4 - ) - rms_pred[rms_boxes == 3] = ( - np.sqrt(rms_3_pred[0:4, rms_boxes == 3] ** 2).sum(axis=0) / 3 - ) - rms_pred[rms_boxes == 2] = ( - np.sqrt(rms_2_pred[0:4, rms_boxes == 2] ** 2).sum(axis=0) / 2 +def compute_rms(batch: ArrayLike, size: int) -> NDArray: + """ + Compute RMS in all four corner boxes. + + Parameters + ---------- + batch : :func:`~numpy.ndarray` + Batch of images, shape (B, H, W). + size : int + Corner box size in pixels. + + Returns + ------- + :func:`numpy.ndarray` + RMS values for each corner, shape (4, B). + """ + corners = np.stack( + [ + batch[:, :size, :size], # top left + batch[:, :size, -size:], # top right + batch[:, -size:, :size], # bottom left + batch[:, -size:, -size:], # bottom right + ] ) - corners = np.ones((rms_4_truth.shape[-1], 4)) - corners[rms_4_truth.swapaxes(1, 0) == 0] = 0 + + return np.sqrt((corners.reshape(4, len(batch), size * size) ** 2).mean(axis=2)) + + +def get_rms( + ifft_truth: NDArray, + ifft_pred: NDArray, + sensitivity: float = 1e-6, +) -> tuple[NDArray, NDArray, NDArray, NDArray]: + """ + Compute RMS values for ground truth and prediction. + + Parameters + ---------- + ifft_truth : :func:`numpy.ndarray`, shape (B, H, W) + Ground truth images. + ifft_pred : :func:`numpy.ndarray`, shape (B, H, W) + Predicted images. + sensitivity : float, optional + Threshold below which corners are considered valid. + Default: 1e-6. + + Returns + ------- + rms_truth : :func:`~numpy.ndarray`, shape (B,) + Averaged RMS for ground truth. + rms_pred : :func:`~numpy.ndarray`, shape (B,) + Averaged RMS for predictions. + rms_boxes : :func:`~numpy.ndarray`, shape (B,) + Number of valid corners per sample. + corners : :func:`~numpy.ndarray`, shape (B, 4) + Corner validity mask. + """ + _rms_truth_boxes = {} + _rms_pred_boxes = {} + + for num_corners in [4, 3, 2]: + size = get_boxsize(num_corners) + _rms_truth_boxes[num_corners] = compute_rms(ifft_truth, size) + _rms_pred_boxes[num_corners] = compute_rms(ifft_pred, size) + + rms_boxes = select_box(_rms_truth_boxes[4], sensitivity=sensitivity) + current_batch_size = len(ifft_pred) + + corners = (_rms_truth_boxes[4] <= sensitivity).T.astype(np.float64) + + for num_corners in [3, 2]: + invalid_mask = _rms_truth_boxes[num_corners] > sensitivity + _rms_pred_boxes[4][invalid_mask] = 0 + + rms_truth = np.zeros(current_batch_size) + rms_pred = np.zeros(current_batch_size) + + for num_corners in [4, 3, 2]: + mask = rms_boxes == num_corners + + if not mask.any(): + continue + + rms_truth[mask] = ( + np.abs(_rms_truth_boxes[num_corners][:, mask]).sum(axis=0) / num_corners + ) + rms_pred[mask] = ( + np.abs(_rms_pred_boxes[num_corners][:, mask]).sum(axis=0) / num_corners + ) + return rms_truth, rms_pred, rms_boxes, corners def calc_dr(ifft_truth, ifft_pred): + """ + Calculate dynamic range for ground truth and predicted images. + + The dynamic range is the peak value divided by RMS + noise in corner (off-)regions (i.e., where no signal is expected). + + Parameters + ---------- + ifft_truth : :func:`~numpy.ndarray` + Ground truth inverse FFT images (image space), shape (B, H, W). + ifft_pred : :func:`~numpy.ndarray` + Predicted inverse FFT images (image space), shape (B, H, W). + + Returns + ------- + dr_truth : :func:`~numpy.ndarray` + Dynamic range for truth. + dr_pred : :func:`~numpy.ndarray` + Dynamic range for predictions. + rms_boxes : np. ndarray + Number of valid corners per sample. + corners : :func:`~numpy.ndarray` + Corner validity mask. + """ rms_truth, rms_pred, rms_boxes, corners = get_rms(ifft_truth, ifft_pred) - peak_vals_truth = ifft_truth.reshape(-1, ifft_truth.shape[-1] ** 2).max(axis=1) - peak_vals_pred = ifft_pred.reshape(-1, ifft_pred.shape[-1] ** 2).max(axis=1) - dr_truth = peak_vals_truth[rms_truth != 0] / rms_truth[rms_truth != 0] - dr_pred = peak_vals_pred[rms_pred != 0] / rms_pred[rms_pred != 0] + + peak_truth = ifft_truth.reshape(len(ifft_truth), -1).max(axis=1) + peak_pred = ifft_pred.reshape(len(ifft_pred), -1).max(axis=1) + + valid_truth = rms_truth != 0 + valid_pred = rms_pred != 0 + dr_truth = peak_truth[valid_truth] / rms_truth[valid_truth] + dr_pred = peak_pred[valid_pred] / rms_pred[valid_pred] + return dr_truth, dr_pred, rms_boxes, corners diff --git a/src/radionets/evaluation/feature.py b/src/radionets/evaluation/feature.py new file mode 100644 index 00000000..a29729cc --- /dev/null +++ b/src/radionets/evaluation/feature.py @@ -0,0 +1,124 @@ +"""Feature detection submodule.""" + +from __future__ import annotations + +from math import sqrt +from typing import TYPE_CHECKING + +import numpy as np +import torch +from skimage.feature import blob_log + +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + +def calc_blobs( + ifft_pred: ArrayLike, + ifft_truth: ArrayLike, +) -> tuple[NDArray, NDArray]: + """Detect blobs using Laplacian of Gaussian in prediction + and truth images. + + Parameters + ---------- + ifft_pred : :class:`~torch.Tensor` or :func:`numpy.ndarray` + Predicted image (inverse FFT result), shape (N, 3). + ifft_truth : :class:`~torch.Tensor` or :func:`numpy.ndarray` + Ground truth image (inverse FFT result), shape (N, 3). + + Returns + ------- + blobs_log_pred : :func:`~numpy.ndarray` + Detected blobs in prediction, shape (N, 3) with columns [y, x, radius]. + blobs_log_truth : :func:`~numpy.ndarray` + Detected blobs in ground truth, shape (N, 3) with columns [y, x, radius]. + """ + if isinstance(ifft_pred, torch.Tensor): + ifft_pred = ifft_pred.detach().cpu().numpy() + + if isinstance(ifft_truth, torch.Tensor): + ifft_truth = ifft_truth.detach().cpu().numpy() + + threshold = ifft_truth.max() * 0.1 + kwargs = { + "min_sigma": 1, + "max_sigma": 10, + "num_sigma": 100, + "threshold": threshold, + "overlap": 0.9, + } + + blobs_log_pred = blob_log(ifft_pred, **kwargs) + blobs_log_truth = blob_log(ifft_truth, **kwargs) + + # Compute radii in the 3rd column. + blobs_log_pred[:, 2] = blobs_log_pred[:, 2] * sqrt(2) + blobs_log_truth[:, 2] = blobs_log_truth[:, 2] * sqrt(2) + + return blobs_log_pred, blobs_log_truth + + +def crop_first_component( + pred: ArrayLike, + truth: ArrayLike, + blob_truth: list | tuple, +) -> tuple[NDArray, NDArray]: + """Return cropped images around the first component of the true image. + + Parameters + ---------- + pred : :func:`~numpy.ndarray` + Predicted source image. + truth : :func:`~numpy.ndarray` + True source image. + blob_truth : list or tuple + Coordinates (y, x, r) for the first component. + + Returns + ------- + flux_pred : :func:`~numpy.ndarray` + Cropped prediction image. + flux_truth : :func:`~numpy.ndarray` + Cropped truth image. + """ + y, x, r = blob_truth + x_coord, y_coord = _corners(y, x, r) + + flux_truth = truth[x_coord[0] : x_coord[1], y_coord[0] : y_coord[1]] + flux_pred = pred[x_coord[0] : x_coord[1], y_coord[0] : y_coord[1]] + + return flux_pred, flux_truth + + +def _corners( + x: int | float, + y: int | float, + r: int | float, +) -> tuple[list[int], list[int]]: + """Generate coordinate ranges for cropping the first component. + + Parameters + ---------- + x : int or float + X coordinate of the component center. + y : int or float + Y coordinate of the component center. + r : int or float + Radius of the first component. + + Returns + ------- + x_coord : list of int + Start and end indices for x-axis cropping. + y_coord : list of int + Start and end indices for y-axis cropping. + """ + r = int(np.round(r)) + x = int(x) + y = int(y) + + x_coord = [x - r, x + r + 1] + y_coord = [y - r, y + r + 1] + + return x_coord, y_coord diff --git a/src/radionets/evaluation/jet_angle.py b/src/radionets/evaluation/jet_angle.py index 270db78b..a7c4fb2b 100644 --- a/src/radionets/evaluation/jet_angle.py +++ b/src/radionets/evaluation/jet_angle.py @@ -3,109 +3,146 @@ import torch -def bmul(vec, mat, axis=0): +def bmul(vec: torch.Tensor, mat: torch.Tensor, axis: int = 0) -> torch.Tensor: """Expand vector for batchwise matrix multiplication. Parameters ---------- - vec : 2dtensor - vector for multiplication - mat : 3dtensor - matrix for multiplication + vec : :class:`~torch.Tensor`, shape (B, N) + Vector for multiplication. + mat : :class:`~torch.Tensor`, shape (B, N, M) + Matrix for multiplication. axis : int, optional - batch axis, by default 0 + Batch axis. Default: ``0`` Returns ------- - 3dtensor - Product of matrix multiplication. (bs, n, m) + :class:`~torch.Tensor`, shape (B, N, M) + Product of matrix multiplication. """ mat = mat.transpose(axis, -1) return (mat * vec.expand_as(mat)).transpose(axis, -1) -def pca(image): - """Compute the major components of an image. The Image is treated as a - distribution. +def _im2array_value( + image: torch.tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transforms the image to an array of pixel coordinates and + its intensities. Parameters ---------- - image : Image or 2DArray (N, M) - Image to be used as distribution + image: :class:`~torch.Tensor`, shape (B, H, W) + Batch of images to be transformed. Returns ------- - cog_x : + x_coords : :class:`~torch.Tensor`, shape (B, H * W) + Contains the x-position of every pixel in the image + y_coords : :class:`~torch.Tensor`, shape (B, H * W) + Contains the y-position of every pixel in the image + value : :class:`~torch.Tensor`, shape (B, H * W) + Contains the intensity value corresponding to every x-y-pair + """ + # NOTE: This assumes quadratic images + batch_size, img_size, _ = image.shape + device = image.device + + a = torch.arange(img_size, device=device) + grid_x, grid_y = torch.meshgrid(a, a, indexing="xy") + + x_coords = grid_x.ravel().unsqueeze(0).expand(batch_size, -1) + y_coords = grid_y.ravel().unsqueeze(0).expand(batch_size, -1) + value = image.reshape(-1, img_size**2) + + return x_coords, y_coords, value + + +def pca( + image: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the major components of an image. The image is treated + as a 2D distribution. + + Parameters + ---------- + image : :class:`~torch.Tensor`, shape (B, H, W) + Images to be used as distribution + + Returns + ------- + cog_x : :class:`~torch.Tensor`, shape (B, 1) X-position of the distributions center of gravity - cog_y : + cog_y : :class:`~torch.Tensor`, shape (B, 1) Y-position of the distributions center of gravity - psi : - Angle between first mjor component and x-axis + psi : :class:`~torch.Tensor`, shape (B,) + Angle between first major component and x-axis """ - torch.set_printoptions(precision=16) + pix_x, pix_y, image = _im2array_value(image) - pix_x, pix_y, image = im_to_array_value(image) - - cog_x = (torch.sum(pix_x * image, axis=1) / torch.sum(image, axis=1)).unsqueeze(-1) - cog_y = (torch.sum(pix_y * image, axis=1) / torch.sum(image, axis=1)).unsqueeze(-1) + image_sum = image.sum(dim=1, keepdim=True) + cog_x = (pix_x * image).sum(dim=1, keepdim=True) / image_sum + cog_y = (pix_y * image).sum(dim=1, keepdim=True) / image_sum delta_x = pix_x - cog_x delta_y = pix_y - cog_y - inp = torch.cat([delta_x.unsqueeze(1), delta_y.unsqueeze(1)], dim=1) + inp = torch.stack([delta_x, delta_y], dim=1) cov_w = bmul( (cog_x - 1 * torch.sum(image * image, axis=1).unsqueeze(-1) / cog_x).squeeze(1), (torch.matmul(image.unsqueeze(1) * inp, inp.transpose(1, 2))), ) - eig_vals_torch, eig_vecs_torch = torch.linalg.eigh(cov_w, UPLO="U") - + _, eig_vecs_torch = torch.linalg.eigh(cov_w, UPLO="U") psi_torch = torch.atan(eig_vecs_torch[:, 1, 1] / eig_vecs_torch[:, 0, 1]) return cog_x, cog_y, psi_torch -def calc_jet_angle(image): - """Caluclate the jet angle from an image created with gaussian sources. This - is achieved by a PCA. +def calc_jet_angle( + image: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate the jet angle from an image consisting of + (approx.) gaussian sources using a PCA. Parameters ---------- - image : ndarray - input image + image : :class:`~torch.Tensor`, shape (B, H, W) + Input images Returns ------- - float - slope of the line - float - intercept of the line - float - angle between the horizontal axis and the jet axis + m : :class:`~torch.Tensor`, shape (B,) + Slope of the line + n : :class:`~torch.Tensor`, shape (B,) + Intercept of the line + alpha : :class:`~torch.Tensor`, shape (B,) + Angle between the horizontal axis and the jet axis """ if not isinstance(image, torch.Tensor): - image = torch.tensor(image) + image = torch.as_tensor(image) + image = image.clone() - img_size = image.shape[-1] - # ignore negative pixels, which can appear in predictions - image[image < 0] = 0 - if len(image.shape) == 2: + # ignore negative pixels that can appear in predictions + image = image.clamp(min=0) + + if image.ndim == 2: image = image.unsqueeze(0) - batch_size = image.shape[0] + batch_size, img_size, _ = image.shape - # only use brightest pixel - max_val = torch.tensor([(i.max() * 0.4) for i in image]) - max_arr = (torch.ones(img_size, img_size, batch_size) * max_val).permute(2, 0, 1) - image[image < max_arr] = 0 + # only use pixels above 40% of peak flux + max_vals = image.view(1, -1).max(dim=1).values + threshold = (0.4 * max_vals).view(batch_size, 1, 1) + image = torch.where(image >= threshold, image, torch.zeros_like(image)) _, _, alpha_pca = pca(image) # Search for sources with two maxima maxima = [] - for i in range(image.shape[0]): - a = torch.where(image[i] == image[i].max()) + for img in image: + a = torch.where(img == img.max()) if len(a[0]) > 1: # if two maxima are found, interpolate to the middle in x and y direction mid_x = (a[0][1] - a[0][0]) // 2 + a[0][0] @@ -118,37 +155,8 @@ def calc_jet_angle(image): x_mid = vals[:, 0] y_mid = vals[:, 1] - m = torch.tan(pi / 2 - alpha_pca) + m = torch.tan(torch.tensor(pi / 2, device=image.device) - alpha_pca) n = y_mid - m * x_mid - alpha = (alpha_pca) * 180 / pi - return m, n, alpha - - -def im_to_array_value(image): - """Transforms the image to an array of pixel coordinates and the containt - intensity - - Parameters - ---------- - image: Image or 2DArray (N, M) - Image to be transformed + alpha = torch.rad2deg(alpha_pca) - Returns - ------- - x_coords : array_like - Contains the x-pixel-position of every pixel in the image - y_coords: array_like - Contains the y-pixel-position of every pixel in the image - value: array_like - Contains the image-value corresponding to every x-y-pair - - """ - num = image.shape[0] - pix = image.shape[-1] - - a = torch.arange(0, pix, 1) - grid_x, grid_y = torch.meshgrid(a, a, indexing="xy") - x_coords = torch.cat(num * [grid_x.flatten().unsqueeze(0)]) - y_coords = torch.cat(num * [grid_y.flatten().unsqueeze(0)]) - value = image.reshape(-1, pix**2) - return x_coords, y_coords, value + return m, n, alpha diff --git a/src/radionets/evaluation/jets.py b/src/radionets/evaluation/jets.py index 61510513..baae6ad7 100644 --- a/src/radionets/evaluation/jets.py +++ b/src/radionets/evaluation/jets.py @@ -1,18 +1,29 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np from astropy.modeling import fitting, models -from radionets.core.logging import setup_logger -from radionets.plotting.visualization import plot_fitgaussian +from radionets.core.logging import _setup_logger + +if TYPE_CHECKING: + from torch import Tensor -LOGGER = setup_logger(namespace=__name__) +LOGGER = _setup_logger(namespace=__name__) -def fitgaussian_crop(data, amp_scale=0.97, crop_size=0.1): + +def fitgaussian_crop( + image: Tensor | np.ndarray, + amp_scale: float = 0.97, + crop_size: float = 0.1, +): """Fitting a gaussian around the maximum Parameters ---------- - data : 2d array + image : 2d array Image amp_scale : float Reduces the fitted amplitude, encounters (partially) the problem @@ -21,96 +32,95 @@ def fitgaussian_crop(data, amp_scale=0.97, crop_size=0.1): crop_size : float proportionate size of the image after cropping - Returns ------- - result_lmf : astropy model - Fitted astropy model object + :class:`~astropy.modeling.models.Gaussian2D` + Fitted astropy Gaussian model. """ - size = data.shape[-1] - data[data < 0] = 0 + if isinstance(image, Tensor): + image = image.detach().cpu().numpy() + + size = image.shape[-1] + image = image.clip(min=0) crop_dist = int((size * crop_size) // 2) - maximum = np.unravel_index(data.argmax(), data.shape) - crop_xmin = crop_xmax = crop_ymin = crop_ymax = crop_dist - if maximum[0] < crop_dist: - crop_xmin = maximum[0] - if maximum[1] < crop_dist: - crop_ymin = maximum[1] - if size - maximum[0] < crop_dist: - crop_xmax = size - maximum[0] - if size - maximum[1] < crop_dist: - crop_ymax = size - maximum[1] - data_crop = data[ + maximum = np.unravel_index(image.argmax(), image.shape) + + crop_xmin = min(crop_dist, maximum[0]) + crop_xmax = min(crop_dist, size - maximum[0]) + crop_ymin = min(crop_dist, maximum[1]) + crop_ymax = min(crop_dist, size - maximum[1]) + + image_crop = image[ maximum[0] - crop_xmin : maximum[0] + crop_xmax, maximum[1] - crop_ymin : maximum[1] + crop_ymax, ] - M = models.Gaussian2D() + gaussian = models.Gaussian2D() lmf = fitting.LevMarLSQFitter() - xx, yy = np.indices([data_crop.shape[0], data_crop.shape[1]]) - result_lmf = lmf(M, xx, yy, data_crop) + xx, yy = np.indices(image_crop.shape) + result_lmf = lmf(gaussian, xx, yy, image_crop) + # the parameters can't be adjusted directly, need help-array params = result_lmf.parameters params[0] *= amp_scale params[1] += maximum[0] - crop_xmin params[2] += maximum[1] - crop_ymin result_lmf.parameters = params + return result_lmf -def fitgaussian_iterativ( - data, i=0, visualize=False, path=None, save=False, plot_format="pdf" +def fitgaussian_iterative( + image: Tensor | np.ndarray, + threshold: float = 0.05, + max_iter: int = 10, ): """Fitting a gaussian iteratively around the maxima. - Fit -> Substract -> Fit -> Substract ... until stopping criteria + Fit -> Subtract -> Fit -> Subtract ... until stopping criteria Parameters ---------- - data : 2d array - Image - i : int - Index of input image - visualize : bool - If the gauss should be plotted or not - path : string - Path to where the image is saved - save : bool - If the image is saved in path or not - plot_format : str - Format of the saved filed (png, pdf, ...) + image : :class:`~torch.Tensor` or :func:`~numpy.ndarray` + Input image. + threshold : float, optional + The threshold at which the iterations are stopped. + Default: 0.05 + max_iter : int, optional + Maximum iterations. Default: 10 Returns ------- - result_lmf : astropy model - Fitted astropy model object + params_list : list of :class:`~astropy.modeling.models.Gaussian2D` + List of fitted astropy model object(s). + + fits_list : list + List of :func:`~numpy.ndarray` with fits. """ - if visualize and path is None: - LOGGER.warning("Visualize is True, but no path is given.") - if not visualize and path is not None: - LOGGER.warning("Visualize is False, but a path is given.") + if isinstance(image, Tensor): + image = image.detach().cpu().numpy() params_list = [] - fit_list = [] - j = 0 - max_iterations = 10 - data_backup = data - - while data.max() > 0.05 and j < max_iterations: - result_lmf = fitgaussian_crop(data) - xx, yy = np.indices([data.shape[-1], data.shape[-1]]) + fits_list = [] + + for _ in range(max_iter): + if image.max() <= threshold: + break + + result_lmf = fitgaussian_crop(image) + xx, yy = np.indices([image.shape[-1], image.shape[-1]]) fit = result_lmf(xx, yy) + params = result_lmf.parameters params[1], params[2] = params[2], params[1] params[3], params[4] = params[4], params[3] result_lmf.parameters = params + # save, if gauss is not too narrow (e.g. one large pixel isn't meaningful here) - if not np.array(params[3:5] < data.shape[-1] / 40).any(): + if not np.any(params[3:5] < image.shape[-1] / 40): params_list.append(result_lmf) - fit_list.append(fit) - data -= fit - j += 1 - if visualize: - plot_fitgaussian(data_backup, fit_list, params_list, i, path, save, plot_format) + fits_list.append(fit) + + image -= fit - return params_list + return params_list, fits_list diff --git a/src/radionets/evaluation/pointsources.py b/src/radionets/evaluation/pointsources.py deleted file mode 100644 index ce20c703..00000000 --- a/src/radionets/evaluation/pointsources.py +++ /dev/null @@ -1,123 +0,0 @@ -import numpy as np - - -def get_min_max(element, index): - min_val = element[index, :][element[4, :] == 1.0].min() - max_val = element[index, :][element[4, :] == 1.0].max() - arg_max = np.argmax(element[index, :][element[4, :] == 1.0]) - arg_min = np.argmin(element[index, :][element[4, :] == 1.0]) - - return min_val, max_val, arg_min, arg_max - - -def get_length_extended(element): - """Identify the first and the last gaussian component in the extended source and - compute the distance. Last, add the extension of the source to the length. - - Parameters - ---------- - element : ndarray - array which contains all sources in the given image - - Returns - ------- - ndarray - length of extended source - """ - - x_min, x_max, arg_min_x, arg_max_x = get_min_max(element, 0) - y_min, y_max, arg_min_y, arg_max_y = get_min_max(element, 1) - - sig_x_max = element[2, :][element[4, :] == 1.0][arg_max_x] - sig_x_min = element[2, :][element[4, :] == 1.0][arg_min_x] - sig_y_max = element[3, :][element[4, :] == 1.0][arg_max_y] - sig_y_min = element[3, :][element[4, :] == 1.0][arg_min_y] - - # multiply with the factor for full width-half maximum - extend_max = np.sqrt(sig_x_max**2 + sig_y_max**2) * 2.35 - extend_min = np.sqrt(sig_x_min**2 + sig_y_min**2) * 2.35 - # compute amount of vector - laenge_x = (x_max - x_min) ** 2 - laenge_y = (y_max - y_min) ** 2 - laenge_extend = np.sqrt(laenge_x + laenge_y) + extend_max + extend_min - - return laenge_extend - - -def get_length_point(element): - """Compute linear extend of pointsources. - - Parameters - ---------- - element : ndarray - array which contains all sources in the given image - - Returns - ------- - ndarray - lengths of the pointsources in the image - """ - - laenge_y_point = element[3, :][element[4, :] == 0.0] * 2 - laenge_x_point = element[2, :][element[4, :] == 0.0] * 2 - laenge_point = np.max([laenge_x_point, laenge_y_point], axis=0) - - return laenge_point - - -def flux_comparison(pred, truth, source_list): - fluxes_pred = [] - fluxes_truth = [] - sigs_x = [] - sigs_y = [] - laenge = np.array([]) - for i, element in enumerate(source_list): - mean_pred = np.array([]) - mean_truth = np.array([]) - for blob in element.T: - y, x, sig_x, sig_y, mask = blob - - x_low = int(np.floor(x - sig_x)) - if x_low < 0: - x_low = 0 - - x_high = int(np.ceil(x + sig_x + 1)) - if x_high > 63: - x_high = 63 - - y_low = int(np.floor(y - sig_y)) - if y_low < 0: - y_low = 0 - - y_high = int(np.ceil(y + sig_y + 1)) - if y_high > 63: - y_high = 63 - - flux_truth = truth[i, int(x_low) : int(x_high), int(y_low) : int(y_high)] - flux_pred = pred[i, int(x_low) : int(x_high), int(y_low) : int(y_high)] - - mean_pred = np.append(mean_pred, flux_pred.mean()) - mean_truth = np.append(mean_truth, flux_truth.mean()) - sigs_x.append(sig_x) - sigs_y.append(sig_y) - - # sum over extended fluxes for truth and pred - c = mean_pred[element[4, :] == 1.0].sum() - mean_pred = np.append(mean_pred[element[4, :] == 0.0], c) - c = mean_truth[element[4, :] == 1.0].sum() - mean_truth = np.append(mean_truth[element[4, :] == 0.0], c) - - fluxes_pred.append(mean_pred) - fluxes_truth.append(mean_truth) - - # append lengths for point and extended sources - laenge_extend = get_length_extended(element) - laenge_point = get_length_point(element) - laenge = np.append(laenge, laenge_point) - laenge = np.append(laenge, laenge_extend) - - return ( - np.array(fluxes_pred, dtype="object"), - np.array(fluxes_truth, dtype="object"), - laenge, - ) diff --git a/src/radionets/evaluation/scripts/__init__.py b/src/radionets/evaluation/scripts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/radionets/evaluation/scripts/start_evaluation.py b/src/radionets/evaluation/scripts/start_evaluation.py deleted file mode 100644 index 21dc5dae..00000000 --- a/src/radionets/evaluation/scripts/start_evaluation.py +++ /dev/null @@ -1,164 +0,0 @@ -import click -import numpy as np -import toml -from rich.pretty import pretty_repr - -from radionets.core.callbacks import PredictionImageGradient -from radionets.core.logging import setup_logger -from radionets.evaluation.train_inspection import ( - create_contour_plots, - create_inspection_plots, - create_predictions, - create_source_plots, - create_uncertainty_plots, - evaluate_area, - evaluate_area_sampled, - evaluate_dynamic_range, - evaluate_gan_sources, - evaluate_intensity, - evaluate_intensity_sampled, - evaluate_mean_diff, - evaluate_ms_ssim, - evaluate_ms_ssim_sampled, - evaluate_point, - evaluate_unc, - evaluate_viewing_angle, - save_sampled, -) -from radionets.evaluation.utils import check_outpath, check_samp_file, read_config - -LOGGER = setup_logger(namespace=__name__, tracebacks_suppress=[click]) - - -@click.command() -@click.argument("configuration_path", type=click.Path(exists=True, dir_okay=False)) -def main(configuration_path): - """ - Start evaluation of trained deep learning model. - - Parameters - ---------- - configuration_path : str - Path to the configuration toml file - """ - conf = toml.load(configuration_path) - eval_conf = read_config(conf) - - LOGGER.info("Evaluation config:") - LOGGER.info(pretty_repr(eval_conf)) - - if eval_conf["sample_unc"]: - LOGGER.info("Sampling test data set.") - save_sampled(eval_conf) - - for entry in conf["inspection"]: - if ( - conf["inspection"][entry] is not False - and isinstance(conf["inspection"][entry], bool) - and entry != "random" - ) and ( - not check_outpath(eval_conf["model_path"]) or conf["inspection"]["random"] - ): - create_predictions(eval_conf) - break - - if eval_conf["unc"]: - evaluate_unc(eval_conf) - create_uncertainty_plots( - eval_conf, num_images=eval_conf["num_images"], rand=eval_conf["random"] - ) - LOGGER.info(f"Created {eval_conf['num_images']} uncertainty images.") - - if eval_conf["vis_pred"]: - create_inspection_plots( - eval_conf, num_images=eval_conf["num_images"], rand=eval_conf["random"] - ) - - LOGGER.info(f"Created {eval_conf['num_images']} test predictions.") - - if eval_conf["vis_ms_ssim"]: - LOGGER.info("Visualization of ms ssim is enabled for source plots.") - - if eval_conf["vis_dr"]: - LOGGER.info(f"Created {eval_conf['num_images']} dynamic range plots.") - - if eval_conf["vis_source"]: - create_source_plots( - eval_conf, num_images=eval_conf["num_images"], rand=eval_conf["random"] - ) - - LOGGER.info(f"Created {eval_conf['num_images']} source predictions.") - - if eval_conf["plot_contour"]: - create_contour_plots( - eval_conf, num_images=eval_conf["num_images"], rand=eval_conf["random"] - ) - - LOGGER.info(f"Created {eval_conf['num_images']} contour plots.") - - if eval_conf["viewing_angle"]: - LOGGER.info("Start evaluation of viewing angles.") - evaluate_viewing_angle(eval_conf) - - if eval_conf["dynamic_range"]: - LOGGER.info("Start evaluation of dynamic ranges.") - evaluate_dynamic_range(eval_conf) - - if eval_conf["ms_ssim"]: - LOGGER.info("Start evaluation of ms ssim.") - samp_file = check_samp_file(eval_conf) - if samp_file: - evaluate_ms_ssim_sampled(eval_conf) - else: - evaluate_ms_ssim(eval_conf) - - if eval_conf["intensity"]: - LOGGER.info("Start evaluation of intensity.") - samp_file = check_samp_file(eval_conf) - if samp_file: - evaluate_intensity_sampled(eval_conf) - else: - evaluate_intensity(eval_conf) - - if eval_conf["mean_diff"]: - LOGGER.info("Start evaluation of mean difference.") - evaluate_mean_diff(eval_conf) - - if eval_conf["area"]: - LOGGER.info("Start evaluation of the area.") - samp_file = check_samp_file(eval_conf) - if samp_file: - evaluate_area_sampled(eval_conf) - else: - evaluate_area(eval_conf) - - if eval_conf["point"]: - LOGGER.info("Start evaluation of point sources.") - evaluate_point(eval_conf) - - if eval_conf["predict_grad"]: - output = PredictionImageGradient( - test_data=eval_conf["data_path"], - model=eval_conf["model_path"], - amp_phase=eval_conf["amp_phase"], - arch_name=eval_conf["arch_name"], - ) - output = output.save_output_pred() - grads_x, grads_y = output - - # specify names of saved gradients in x and y - np.savetxt("grads_x.csv", grads_x, delimiter=",") - np.savetxt("grads_y.csv", grads_y, delimiter=",") - - # # save image (no gradients) - np.savetxt("test_img.csv", output, delimiter=",") - - # # save x and y grads for fourier amplitude and phase - np.savetxt("grads_x_amp.csv", grads_x[0][0].cpu().numpy(), delimiter=",") - np.savetxt("grads_x_phase.csv", grads_x[0][1].cpu().numpy(), delimiter=",") - np.savetxt("grads_y_amp.csv", grads_y[0][0].cpu().numpy(), delimiter=",") - np.savetxt("grads_y_phase.csv", grads_y[0][1].cpu().numpy(), delimiter=",") - - if eval_conf["gan"]: - LOGGER.info("Start evaluation of GAN sources.") - evaluate_gan_sources(eval_conf) diff --git a/src/radionets/evaluation/utils.py b/src/radionets/evaluation/utils.py index e0111322..5bf61e85 100644 --- a/src/radionets/evaluation/utils.py +++ b/src/radionets/evaluation/utils.py @@ -1,844 +1,90 @@ -from pathlib import Path - -import h5py import numpy as np import torch import torch.nn.functional as F -from numba import set_num_threads, vectorize -from torch.utils.data import DataLoader - -from radionets import architecture -from radionets.core.data import load_data -from radionets.core.model import load_pre_model - - -def source_list_collate(batch): - """Collate function for the DataLoader with source list - - Parameters - ---------- - batch : tuple - input and target images alongside with the corresponding source_list - - Returns - ------- - tuple - stacked images and list for source_list values - """ - - x = [item[0] for item in batch] - y = [item[1] for item in batch] - z = [item[2][0] for item in batch] - return torch.stack(x), torch.stack(y), z - - -def create_databunch(data_path, fourier, source_list, batch_size): - """Create a dataloader object, which feeds the data batch-wise - - Parameters - ---------- - data_path : str - path to the data - fourier : bool - true, if data in Fourier space is used - source_list : bool - true, if source_list data is used - batch_size : int - number of images for one batch - - Returns - ------- - DataLoader - dataloader object - """ - # Load data sets - test_ds = load_data(data_path, mode="test", fourier=fourier) - - # Create databunch with defined batchsize and check for source_list - if source_list: - data = DataLoader( - test_ds, batch_size=batch_size, shuffle=True, collate_fn=source_list_collate - ) - else: - data = DataLoader(test_ds, batch_size=batch_size, shuffle=False) - return data - - -def create_sampled_databunch(data_path, batch_size): - """Create a dataloader object, which feeds the data batch-wise - - Parameters - ---------- - data_path : str - path to the data - fourier : bool - true, if data in Fourier space is used - source_list : bool - true, if source_list data is used - batch_size : int - number of images for one batch - Returns - ------- - DataLoader - dataloader object - """ - # Load data sets - test_ds = sampled_dataset(data_path) - - data = DataLoader(test_ds, batch_size=batch_size, shuffle=True) - return data - - -def read_config(config): - """Parse the toml config file - - Parameters - ---------- - config : dict - dict which contains the configurations loaded with toml.load - - Returns - ------- - dict - dict containing all configurations with unique keywords - """ - eval_conf = {} - eval_conf["data_path"] = config["paths"]["data_path"] - eval_conf["model_path"] = config["paths"]["model_path"] - eval_conf["model_path_2"] = config["paths"]["model_path_2"] - - eval_conf["quiet"] = config["mode"]["quiet"] - eval_conf["format"] = config["general"]["output_format"] - eval_conf["fourier"] = config["general"]["fourier"] - eval_conf["amp_phase"] = config["general"]["amp_phase"] - eval_conf["arch_name"] = config["general"]["arch_name"] - eval_conf["source_list"] = config["general"]["source_list"] - eval_conf["arch_name_2"] = config["general"]["arch_name_2"] - eval_conf["diff"] = config["general"]["diff"] - - eval_conf["vis_pred"] = config["inspection"]["visualize_prediction"] - eval_conf["vis_source"] = config["inspection"]["visualize_source_reconstruction"] - eval_conf["sample_unc"] = config["inspection"]["sample_uncertainty"] - eval_conf["unc"] = config["inspection"]["visualize_uncertainty"] - eval_conf["plot_contour"] = config["inspection"]["visualize_contour"] - eval_conf["vis_dr"] = config["inspection"]["visualize_dynamic_range"] - eval_conf["vis_ms_ssim"] = config["inspection"]["visualize_ms_ssim"] - eval_conf["num_images"] = config["inspection"]["num_images"] - eval_conf["random"] = config["inspection"]["random"] - - eval_conf["viewing_angle"] = config["eval"]["evaluate_viewing_angle"] - eval_conf["dynamic_range"] = config["eval"]["evaluate_dynamic_range"] - eval_conf["ms_ssim"] = config["eval"]["evaluate_ms_ssim"] - eval_conf["intensity"] = config["eval"]["evaluate_intensity"] - eval_conf["mean_diff"] = config["eval"]["evaluate_mean_diff"] - eval_conf["area"] = config["eval"]["evaluate_area"] - eval_conf["batch_size"] = config["eval"]["batch_size"] - eval_conf["point"] = config["eval"]["evaluate_point"] - eval_conf["predict_grad"] = config["eval"]["predict_grad"] - eval_conf["gan"] = config["eval"]["evaluate_gan"] - eval_conf["save_vals"] = config["eval"]["save_vals"] - eval_conf["save_path"] = config["eval"]["save_path"] - return eval_conf - - -def reshape_2d(array): - """Reshape 1d arrays into 2d ones. - - Parameters - ---------- - array: 1d array - input array +def get_ifft(image, amp_phase=False, scale=False, uncertainty=False): + """Get inverse FFT of provided image data. Returns ------- - array: 2d array - reshaped array + torch.tensor + Inverse FFT of provided image data. """ - shape = [int(np.sqrt(array.shape[-1]))] * 2 - return array.reshape(-1, *shape) - - -def make_axes_nice(fig, ax, im, title, phase=False, phase_diff=False, unc=False): - """Create nice colorbars with bigger label size for every axis in a subplot. - Also use ticks for the phase. - - Parameters - ---------- - fig : figure object - current figure - ax : axis object - current axis - im : ndarray - plotted image - title : str - title of subplot - """ - from mpl_toolkits.axes_grid1 import make_axes_locatable - - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.05) - ax.set_title(title) - - if phase: - cbar = fig.colorbar( - im, - cax=cax, - orientation="vertical", - ticks=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi], - ) - cbar.set_label("Phase / rad") - elif phase_diff: - cbar = fig.colorbar( - im, - cax=cax, - orientation="vertical", - ticks=[-2 * np.pi, -np.pi, 0, np.pi, 2 * np.pi], - ) - cbar.set_label("Phase / rad") - elif unc: - cbar = fig.colorbar(im, cax=cax, orientation="vertical") - cbar.set_label(r"$\sigma$ / $\mathrm{Jy \cdot px^{-1}}$") - else: - cbar = fig.colorbar(im, cax=cax, orientation="vertical") - cbar.set_label(r"$\mathrm{Flux \ density / Jy \cdot px^{-1}}$") - - if phase: - # set ticks for colorbar - cbar.ax.set_yticklabels([r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"]) - elif phase_diff: - # set ticks for colorbar - cbar.ax.set_yticklabels([r"$-2\pi$", r"$-\pi$", r"$0$", r"$\pi$", r"$2\pi$"]) - - -def check_vmin_vmax(inp): - """Check wether the absolute of the maxmimum or the minimum is bigger. - If the minimum is bigger, return value with minus. Otherwise return - maximum. - - Parameters - ---------- - inp : float - input image - Returns - ------- - float - negative minimal or maximal value - """ - a = -inp.min() if np.abs(inp.min()) > np.abs(inp.max()) else inp.max() - return a - - -def load_pretrained_model(arch_name, model_path, img_size=63): - """Load model architecture and pretrained weigths. - - Parameters - ---------- - arch_name : str - name of the architecture - model_path : str - path to pretrained model - - Returns - ------- - arch : architecture object - architecture with pretrained weigths - """ - if ( - "filter_deep" in arch_name - or "resnet" in arch_name - or "Uncertainty" in arch_name - ): - arch = getattr(architecture, arch_name)(img_size) - else: - arch = getattr(architecture, arch_name)() - norm_dict = load_pre_model(arch, model_path, visualize=True) - return arch, norm_dict - - -def get_images(test_ds, num_images, rand=False, indices=None): - """Get n random test and truth images or mean, standard deviation and - true images from an already sampled dataset. - - Parameters - ---------- - test_ds : H5DataSet - data set with test images - num_images : int - number of test images - rand : bool - true if images should be drawn random - indices : list - list of indices to be used - - Returns - ------- - img_test : n 2d arrays - test images - img_true : n 2d arrays - truth images - """ - if hasattr(test_ds, "tar_fourier"): - indices = torch.arange(num_images) - if rand: - indices = torch.randint(0, len(test_ds), size=(num_images,)) - - # remove dublicate indices - while len(torch.unique(indices)) < len(indices): - new_indices = torch.randint( - 0, len(test_ds), size=(num_images - len(torch.unique(indices)),) - ) - indices = torch.cat((torch.unique(indices), new_indices)) - - # sort after getting indices - indices, _ = torch.sort(indices) - - img_test = test_ds[indices][0] - img_true = test_ds[indices][1] - return img_test, img_true, indices - else: - mean = test_ds[indices][0] - std = test_ds[indices][1] - img_true = test_ds[indices][2] - return mean, std, img_true - - -def eval_model(img, model): - """Put model into eval mode and evaluate test images. - - Parameters - ---------- - img : str - test image - model : architecture object - architecture with pretrained weigths - - Returns - ------- - pred : n 1d arrays - predicted images - """ - if len(img.shape) == (3): - img = img.unsqueeze(0) - model.eval() - if torch.cuda.is_available(): - model.cuda() - with torch.no_grad(): - if torch.cuda.is_available(): - pred = model(img.float().cuda())["pred"] - else: - pred = model(img.float())["pred"] - return pred.cpu() - - -def get_ifft(array, amp_phase=False, scale=False): - """Compute the inverse Fourier transformation + if isinstance(image, np.ndarray): + image = torch.from_numpy(image) - Parameters - ---------- - array : ndarray - array with shape (2, img_size, img_size) with optional batch size - amp_phase : bool, optional - true, if splitting in amplitude and phase was used, by default True + if len(image.shape) == 3: + image = image.unsqueeze(0) - Returns - ------- - ndarray - image(s) in image space - """ - if len(array.shape) == 3: - array = array.unsqueeze(0) if hasattr(array, "numpy") else array[np.newaxis, :] if amp_phase: - amp = 10 ** (10 * array[:, 0] - 10) - 1e-10 if scale else array[:, 0] + amp = 10 ** (10 * image[:, 0] - 10) - 1e-10 if scale else image[:, 0] + + index = 2 if uncertainty else 1 + a = amp * torch.cos(image[:, index]) + b = amp * torch.sin(image[:, index]) - a = amp * np.cos(array[:, 1]) - b = amp * np.sin(array[:, 1]) compl = a + b * 1j else: - compl = array[:, 0] + array[:, 1] * 1j + compl = image[:, 0] + image[:, 1] * 1j + if compl.shape[0] == 1: compl = compl.squeeze(0) - return np.abs(np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(compl)))) - -def save_pred(path, img): - """Write test data and predictions to h5 file.""" - with h5py.File(path, "w") as hf: - for key, value in img.items(): - hf.create_dataset(key, data=value) - hf.close() + return torch.abs(torch.fft.ifftshift(torch.fft.ifft2(torch.fft.fftshift(compl)))) -def read_pred(path): - """Read data saved with save_pred from h5 file.""" - images = {} - with h5py.File(path, "r") as hf: - for key in hf: - images[key] = np.array(hf[key]) - hf.close() - return images +def apply_symmetry(image, uncertainty: bool = False) -> torch.tensor: + """Applies symmetry operations on an array. - -def check_outpath(model_path): - """Checks if there is already a predictions file in the evaluation folder - - Parameters - ---------- - model_path : str - path to the model - - Returns - ------- - bool - true, if the file exists - """ - name_model = Path(model_path).stem - model_path = Path(model_path).parent / "evaluation" / f"predictions_{name_model}.h5" - path = Path(model_path) - exists = path.exists() - return exists - - -def symmetry(image, key): - """Symmetry function to complete the images. + This follows Figure 5.3 in http://dx.doi.org/10.17877/DE290R-24834 Parameters ---------- - image : torch.Tensor - (stack of) half images + image : array_like + Input array of half images. + uncertainty : bool, optional + Whether image data contains uncertainty data. + Default: False Returns ------- - torch.Tensor - quadratic images after utilizing symmetry + symmetrical : torch.tensor + Torch tensor containing the symmetrical image. """ if isinstance(image, np.ndarray): - image = torch.tensor(image) - if len(image.shape) == 3: - image = image.view(1, image.shape[0], image.shape[1], image.shape[2]) - half_image = image.shape[-1] // 2 - upper_half = image[:, :, :half_image, :].clone() - a = torch.rot90(upper_half, 2, dims=[-2, -1]) - - image[:, 0, half_image + 1 :, 1:] = a[:, 0, :-1, :-1] - image[:, 0, half_image + 1 :, 0] = a[:, 0, :-1, -1] - - if key == "unc": - image[:, 1, half_image + 1 :, 1:] = a[:, 1, :-1, :-1] - image[:, 1, half_image + 1 :, 0] = a[:, 1, :-1, -1] - else: - image[:, 1, half_image + 1 :, 1:] = -a[:, 1, :-1, :-1] - image[:, 1, half_image + 1 :, 0] = -a[:, 1, :-1, -1] - - return image - - -def apply_symmetry(img_dict): - """Pads and applies symmetry to half images. - Takes a dict as input. - - Parameters - ---------- - img_dict : dict - input dict which contains the half images - - Returns - ------- - dict - input dict with quadratic images - """ - for key in img_dict: - if key != "indices": - if isinstance(img_dict[key], np.ndarray): - img_dict[key] = torch.tensor(img_dict[key]) - half_image = img_dict[key].shape[-1] // 2 - output = F.pad( - input=img_dict[key], - pad=(0, 0, 0, half_image - 5), - mode="constant", - value=0, - ) - output = symmetry(output, key) - img_dict[key] = output - - return img_dict - - -@vectorize(["float64(float64, float64, float64, float64)"], target="cpu") -def tn_numba_vec_cpu(mu, sig, a, b): - rv = np.random.normal(loc=mu, scale=sig) - cond = rv > a and rv < b - while not cond: - rv = np.random.normal(loc=mu, scale=sig) - cond = rv > a and rv < b - - return rv - - -@vectorize(["float64(float64, float64, float64, float64)"], target="parallel") -def tn_numba_vec_parallel(mu, sig, a, b): - rv = np.random.normal(loc=mu, scale=sig) - cond = rv > a and rv < b - while not cond: - rv = np.random.normal(loc=mu, scale=sig) - cond = rv > a and rv < b - - return rv - - -def trunc_rvs(mu, sig, num_samples, mode, target="cpu", nthreads=1): - if mode == "amp": - a = 0 - b = np.inf - elif mode == "phase": - a = -np.pi - b = np.pi - elif mode == "real" or mode == "imag": - a = -np.inf - b = np.inf - else: - raise ValueError("Unsupported mode, use either ``phase`` or ``amp``.") - mu = np.tile(mu, (num_samples, 1, 1, 1)) - sig = np.tile(sig, (num_samples, 1, 1, 1)) - - if target == "cpu": - if nthreads > 1: - raise ValueError( - f"Target is ``cpu`` but nthreads is {nthreads}, " - "use target=``parallel`` instead." - ) - res = tn_numba_vec_cpu(mu, sig, a, b) - elif target == "parallel": - if nthreads == 1: - raise ValueError( - "Target is ``parallel`` but nthreaads is 1, use target=``cpu`` instead." - ) - set_num_threads(int(nthreads)) - res = tn_numba_vec_parallel(mu, sig, a, b) - else: - raise ValueError("Unsupported target, use cpu or parallel.") - - return res.swapaxes(0, 1) - - -def sample_images(mean, std, num_samples, conf): - """Samples for every pixel in Fourier space from a - truncated Gaussian distribution based on the output - of the network. - - Parameters - ---------- - mean : torch.tensor - mean values of the pixels with shape (number of images, number of samples, - image size // 2 + 1, image_size) - std : torch.tensor - uncertainty values of the pixels with shape (number of images, - number of samples, image size // 2 + 1, image_size) - num_samples : int - number of samples in Fourier space - - Returns - ------- - dict - resulting mean and standard deviation - """ - mean_amp, mean_phase = mean[:, 0], mean[:, 1] - std_amp, std_phase = std[:, 0], std[:, 1] - num_img = mean_amp.shape[0] - - mode = ["amp", "phase"] if conf["amp_phase"] else ["real", "imag"] - - # amplitude - sampled_gauss_amp = trunc_rvs( - mu=mean_amp, - sig=std_amp, - mode=mode[0], - num_samples=num_samples, - ).reshape(num_img * num_samples, mean_amp.shape[-2], mean_amp.shape[-1]) - - # phase - sampled_gauss_phase = trunc_rvs( - mu=mean_phase, - sig=std_phase, - mode=mode[1], - num_samples=num_samples, - ).reshape(num_img * num_samples, mean_phase.shape[-2], mean_phase.shape[-1]) - - # masks - if conf["amp_phase"]: - mask_invalid_amp = sampled_gauss_amp <= (0 - 1e-4) - mask_invalid_phase = (sampled_gauss_phase <= (-np.pi - 1e-4)) | ( - sampled_gauss_phase >= (np.pi + 1e-4) - ) - - assert mask_invalid_amp.sum() == 0 - assert mask_invalid_phase.sum() == 0 - - sampled_gauss = np.stack([sampled_gauss_amp, sampled_gauss_phase], axis=1) - - # pad resulting images and utilize symmetry - sampled_gauss = F.pad( - input=torch.tensor(sampled_gauss), - pad=(0, 0, 0, mean_amp.shape[-2] - 2), - mode="constant", - value=0, - ) - sampled_gauss_symmetry = symmetry(sampled_gauss, None) - - fft_sampled_symmetry = get_ifft( - sampled_gauss_symmetry, amp_phase=conf["amp_phase"], scale=False - ).reshape(num_img, num_samples, mean_amp.shape[-1], mean_amp.shape[-1]) - - results = { - "mean": fft_sampled_symmetry.mean(axis=1), - "std": fft_sampled_symmetry.std(axis=1), - } - return results + image = torch.from_numpy(image) + if image.ndim == 3: + image = image.unsqueeze(0) -def mergeDictionary(dict_1, dict_2): - dict_3 = {**dict_1, **dict_2} - for key, value in dict_3.items(): - if key in dict_1 and key in dict_2: - dict_3[key] = np.append(dict_1[key], value) - return dict_3 + _, _, H, W = image.shape + # Assume images are square; get target height from full width + # NOTE: This may have to be changed should we allow different + # aspect ratios in the future + half_width = W // 2 -class sampled_dataset: - def __init__(self, bundle_path): - """ - Save the bundle paths and the number of bundles in one file. - """ - if bundle_path == []: - raise ValueError("No bundles found! Please check the names of your files.") - self.bundle_path = bundle_path + # Calculate the overlap from difference of half_image and + # input height H, so we do not need to pass it anymore + overlap = H - half_width - def __len__(self): - """Returns the total number of pictures in this dataset""" - bundle = h5py.File(self.bundle_path, "r") - data = bundle["mean"] - return data.shape[0] + pad_bottom = half_width - overlap + full_image = F.pad(image, pad=(0, 0, 0, pad_bottom), mode="constant", value=0) - def __getitem__(self, i): - mean = self.open_image("mean", i) - std = self.open_image("std", i) - true = self.open_image("true", i) - return mean, std, true + upper_half = image[..., :half_width, :] - def open_image(self, var, i): - bundle = h5py.File(self.bundle_path, "r") - data = bundle[var] - data = data[i] - return data + # flip along image axes W and H to rotate image by 180 deg + rotated = upper_half.flip(-2, -1) + # Shift columns to the right by 1 to account for central pixel + # and drop last row + lower_half = torch.roll(rotated, shifts=1, dims=-1) + lower_half = lower_half[..., :-1, :] -def apply_normalization(img_test, norm_dict): - """Applies one of currently two normalization - methods if the training was normalized + if not uncertainty: + lower_half[:, 1, ...] *= -1 - Parameters - ---------- - img_test : torch.Tensor - input image - norm_dict : dictionary - either empty (no normalization) or containing the factors - - Returns - ------- - img_test : torch.Tensor - normalized image - norm_dict : dictionary - updated dictionary - """ - # normalize using mean and std for whole dataset - if norm_dict and "mean_real" in norm_dict: - img_test[:, 0][img_test[:, 0] != 0] = ( - img_test[:, 0][img_test[:, 0] != 0] - norm_dict["mean_real"] - ) / norm_dict["std_real"] - - img_test[:, 1][img_test[:, 1] != 0] = ( - img_test[:, 1][img_test[:, 1] != 0] - norm_dict["mean_imag"] - ) / norm_dict["std_imag"] - - # scale with the maximum value of each image - elif norm_dict and "max_scaling" in norm_dict: - max_factors_real = torch.amax(img_test[:, 0], dim=(-2, -1), keepdim=True) - max_factors_imag = torch.amax( - torch.abs(img_test[:, 1]), dim=(-2, -1), keepdim=True - ) - img_test[:, 0] *= 1 / torch.amax(img_test[:, 0], dim=(-2, -1), keepdim=True) - img_test[:, 1] *= 1 / torch.amax( - torch.abs(img_test[:, 1]), dim=(-2, -1), keepdim=True - ) - norm_dict["max_factors_real"] = max_factors_real - norm_dict["max_factors_imag"] = max_factors_imag - - # normalize each image to mean=0 and std=1 - elif norm_dict and "all" in norm_dict: - means = ( - img_test.mean(axis=-1) - .mean(axis=-1) - .reshape(img_test.shape[0], img_test.shape[1], 1, 1) - ) - stds = ( - img_test.std(axis=-1) - .std(axis=-1) - .reshape(img_test.shape[0], img_test.shape[1], 1, 1) - ) - img_test = (img_test - means) / stds - norm_dict["means"] = means - norm_dict["stds"] = stds - - return img_test, norm_dict - - -def rescale_normalization(pred, norm_dict): - """Rescale the prediction after normalized training - - Parameters - ---------- - pred : torch.Tensor - predicted image - norm_dict : dictionary - either empty (no normalization) or containing the factors - - Returns - ------- - pred : torch.Tensor - recaled predicted image - """ - if norm_dict and "mean_real" in norm_dict: - pred[:, 0] = pred[:, 0] * norm_dict["std_real"] + norm_dict["mean_real"] - if pred.shape[1] == 4: - pred[:, 2] = pred[:, 2] * norm_dict["std_imag"] + norm_dict["mean_imag"] - else: - pred[:, 1] = pred[:, 1] * norm_dict["std_imag"] + norm_dict["mean_imag"] - - elif norm_dict and "max_scaling" in norm_dict: - pred[:, 0] *= norm_dict["max_factors_real"] - pred[:, 1] *= norm_dict["max_factors_imag"] - - elif norm_dict and "all" in norm_dict: - pred[:, 0] = pred[:, 0] * norm_dict["stds"][:, 0] + norm_dict["means"][:, 0] - if pred.shape[1] == 4: - pred[:, 2] = pred[:, 2] * norm_dict["stds"][:, 1] + norm_dict["means"][:, 1] - else: - pred[:, 1] = pred[:, 1] * norm_dict["stds"][:, 1] + norm_dict["means"][:, 1] - - return pred - - -def preprocessing(conf): - """Makes the necessary preprocessing for the evaluation - methods analyzing the whole test dataset. - - Parameters - ---------- - conf : dictionary - config file containing the settings - - Returns - ------- - model : architecture - model initialized with save file - model_2 : architecture - model initialized with save file - loader : torch.Dataloader - feeds the data batch-wise - norm_dict : dictionary - dict containing the normalization factors - out_path : Path object - path to the evaluation folder - """ - # create DataLoader - loader = create_databunch( - conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"] - ) - model_path = conf["model_path"] - out_path = Path(model_path).parent / "evaluation" - out_path.mkdir(parents=True, exist_ok=True) - - img_size = loader.dataset[0][0][0].shape[-1] - model, norm_dict = load_pretrained_model( - conf["arch_name"], conf["model_path"], img_size - ) - - # Loads second model if the two channels were trainined separately - model_2 = None - if conf["model_path_2"] != "none": - model_2, norm_dict = load_pretrained_model( - conf["arch_name_2"], conf["model_path_2"], img_size - ) - - return model, model_2, loader, norm_dict, out_path - - -def process_prediction(conf, img_test, img_true, norm_dict, model, model_2): - """Applies the normalization, gets and rescales a - prediction and performs the inverse Fourier transformation. - - Parameters - ---------- - conf : dictionary - config files containing the settings - img_test : torch.Tensor - input file for the network - img_true : torch.tensor - true image - norm_dict : dictionary - dict containing the normalization factors - model : architecture - model initialized with save file - model_2 : - model initialized with save file - - Returns - ------- - ifft_pred : ndarray - predicted source in image space - ifft_truth : ndarray - true source in image space - """ - img_test, norm_dict = apply_normalization(img_test, norm_dict) - pred = eval_model(img_test, model) - pred = rescale_normalization(pred, norm_dict) - if model_2 is not None: - pred_2 = eval_model(img_test, model_2) - pred_2 = rescale_normalization(pred_2, norm_dict) - pred = torch.cat((pred, pred_2), dim=1) - - # apply symmetry - if pred.shape[-2] < pred.shape[-1]: - img_dict = {"truth": img_true, "pred": pred} - img_dict = apply_symmetry(img_dict) - img_true = img_dict["truth"] - pred = img_dict["pred"] - - ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) - - return ifft_pred, ifft_truth - - -def check_samp_file(eval_conf): - """Checks if a file with sampled images - is located in the evaluation folder - - Parameters - ---------- - eval_conf : dict - contains the evaluation parameters - - Returns - ------- - bool - true if file exists, otherwise false - """ - model_path = eval_conf["model_path"] - out_path = Path(model_path).parent / "evaluation" - out_path.mkdir(parents=True, exist_ok=True) + full_image[..., half_width + 1 :, :] = lower_half - name_model = Path(model_path).stem - data_path = out_path / f"sampled_imgs_{name_model}.h5" - return data_path.is_file() + return full_image diff --git a/src/radionets/io/__init__.py b/src/radionets/io/__init__.py new file mode 100644 index 00000000..c15bd6ee --- /dev/null +++ b/src/radionets/io/__init__.py @@ -0,0 +1,4 @@ +from .data import H5DataModule, WebDatasetModule +from .train_config import TrainConfig + +__all__ = ["TrainConfig", "H5DataModule", "WebDatasetModule"] diff --git a/src/radionets/io/data.py b/src/radionets/io/data.py new file mode 100644 index 00000000..da902cab --- /dev/null +++ b/src/radionets/io/data.py @@ -0,0 +1,539 @@ +from collections.abc import Callable +from pathlib import Path + +import h5py +import numpy as np +import pyarrow.parquet as pq +import torch +import webdataset as wds +from lightning import LightningDataModule +from natsort import natsorted +from torch.utils.data import DataLoader, Dataset + + +class H5DataSet(Dataset): + def __init__(self, data_dir, tar_fourier, mode="train"): + """ + Save the bundle paths and the number of bundles in one file. + """ + if not isinstance(data_dir, Path): + data_dir = Path(data_dir) + + bundle_paths = data_dir.glob(f"samp_{mode}_*.h5") + self.bundles = natsorted(bundle_paths) + self.num_img = len(self.open_bundle(self.bundles[0], "x")) + self.tar_fourier = tar_fourier + + if not self.bundles: + raise ValueError("No bundles found! Please check the names of your files.") + + def __call__(self): + return print("This is the H5DataSet class.") + + def __len__(self): + """ + Returns the total number of pictures in this dataset + """ + return len(self.bundles) * self.num_img + + def __getitem__(self, i): + x = self.open_image("x", i) + y = self.open_image("y", i) + return x, y + + def open_bundle(self, bundle_path, var): + bundle = h5py.File(bundle_path, "r") + data = bundle[var] + return data + + def open_image(self, var, i): + if isinstance(i, int): + i = torch.tensor([i]) + + elif isinstance(i, np.ndarray): + i = torch.tensor(i) + + indices, _ = torch.sort(i) + bundle = torch.div(indices, self.num_img, rounding_mode="floor") + image = indices - bundle * self.num_img + bundle_unique = torch.unique(bundle) + + bundle_paths = [ + h5py.File(self.bundles[bundle], "r") for bundle in bundle_unique + ] + bundle_paths_str = list(map(str, bundle_paths)) + + data = torch.from_numpy( + np.array( + [ + bund[var][img] + for bund, bund_str in zip(bundle_paths, bundle_paths_str) + for img in image[ + bundle == bundle_unique[bundle_paths_str.index(bund_str)] + ] + ] + ) + ) + + if self.tar_fourier is False and data.shape[1] == 2: + raise ValueError( + "Two channeled data is used despite Fourier being False.\ + Set Fourier to True!" + ) + + if data.shape[0] == 1: + data = data.squeeze(0) + return data.float() + + +class H5DataModule(LightningDataModule): + """ + PyTorch Lightning DataModule for handling visibility + data from HDF5 files. + + This DataModule manages the loading and preparation + of the visibility datasets for training, validation, + testing, and prediction stages of radionets. + + Parameters + ---------- + data_dir : str or :class:`pathlib.Path` + Directory path containing the HDF5 data files. + batch_size : int, optional + Number of samples per batch. + Default: ``32`` + fourier : bool, optional + Whether to use Fourier space targets. + Default: ``False`` + num_workers : int, optional + Number of worker processes for data loading. + Default ``10`` + + Attributes + ---------- + fourier : bool + Flag indicating whether Fourier space inputs/targets + are used. + data_dir : str or Path + Directory path to the data. + batch_size : int + Batch size for data loaders. + num_workers : int + Number of worker processes. + vis_train : H5DataSet + Training dataset used in the Trainer.fit stage. + vis_val : H5DataSet + Validation dataset used in the Trainer.validate + stage. + vis_test : H5DataSet + Test dataset used in the Trainer.test stage. + vis_predict : H5DataSet + Prediction dataset used for inference in the + Trainer.predict stage. + """ + + def __init__( + self, + data_dir: str | Path, + *, + batch_size: int = 32, + fourier: bool = False, + num_workers: int = 10, + **kwargs, + ): + super().__init__() + self.fourier = fourier + self.data_dir = data_dir + self.batch_size = batch_size + self.num_workers = num_workers + + self.train_length = None + self.valid_length = None + self.test_length = None + self.predict_length = None + + self.save_hyperparameters() + self.setup("fit") + + def setup(self, stage: str): + """ + Set up datasets for the specified stage. + + Creates H5DataSet instances for the appropriate data + splits based on the stage of the workflow + (fit, test, or predict). + + Parameters + ---------- + stage : str + The stage of the workflow. Must be one of: + + - 'fit': Prepares training and validation datasets + - 'test': Prepares test dataset + - 'predict': Prepares prediction dataset + + Raises + ------ + ValueError + If the provided stage is not one of 'fit', 'test', + or 'predict'. + + Notes + ----- + This method is called automatically by PyTorch Lightning + before any one of the training, validation, testing, or + prediction loop begins. + """ + match stage: + case "fit": + self.vis_train = H5DataSet( + self.data_dir, + tar_fourier=self.fourier, + mode="train", + ) + self.vis_val = H5DataSet( + self.data_dir, + tar_fourier=self.fourier, + mode="valid", + ) + self.train_length = len(self.vis_train) + self.valid_length = len(self.vis_val) + + case "test": + self.vis_test = H5DataSet( + self.data_dir, + tar_fourier=self.fourier, + mode="test", + ) + self.test_length = len(self.vis_test) + case "predict": + self.vis_predict = H5DataSet( + self.data_dir, + tar_fourier=self.fourier, + mode="test", + ) + # NOTE: For now, this will look for test files, + # but in the future this stage should be used for + # inference only + self.predict_length = len(self.vis_predict) + case _: + raise ValueError( + f"Stage {stage} is not available in {self.__class__.__name__}" + ) + + def train_dataloader(self): + """ + Create and return the training DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the training dataset with + configured batch size and number of workers. + """ + return DataLoader( + self.vis_train, + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + """ + Create and return the validation DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the validation dataset + with configured batch size and number of workers. + """ + return DataLoader( + self.vis_val, + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + """ + Create and return the test DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the test dataset with + configured batch size and number of workers. + """ + return DataLoader( + self.vis_test, + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + + def predict_dataloader(self): + """ + Create and return the prediction DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the prediction dataset + with configured batch size and number of workers. + """ + return DataLoader( + self.vis_predict, + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + + +def identity(x): + """Identity function for no-op transformations.""" + return x + + +class WebDatasetModule(LightningDataModule): + """ + PyTorch Lightning DataModule for handling visibility + data from WebDataset files. + + This DataModule manages the loading and preparation + of the visibility datasets for training, validation, + testing, and prediction stages of radionets. + + Parameters + ---------- + data_dir : str or Path + Directory containing WebDataset tar files. + Expected structure: + - train-{000000..NNNNN}.tar + - valid-{000000..NNNNN}.tar + - test-{000000..NNNNN}.tar + epochs : int + Number of epochs. Used for WebDataset.with_epoch(). + batch_size : int, optional + Number of samples per batch. Default: 32 + fourier : bool, optional + Whether inputs/targets are in Fourier space. Default: False + num_workers : int, optional + Number of worker processes. Default: 10 + prefetch_factor : int, optional + Number of batches to prefetch per worker. Default: 2 + persistent_workers : bool, optional + Keep workers alive between epochs. Default: True + transform : Callable, optional + Transform applied to inputs. Default: None + target_transform : Callable, optional + Transform applied to targets. Default: None + shuffle_buffer : int, optional + Size of shuffle buffer for training. Default: None + + Notes + ----- + WebDataset files should be created with the following structure: + - __key__: unique identifier + - *.input.npy: input visibility data as numpy array in binary file format + - *.target.npy: target image data as numpy array in binary file format + """ + + def __init__( + self, + data_dir: str | Path, + *, + batch_size: int = 32, + fourier: bool = False, + num_workers: int = 10, + prefetch_factor: int = 2, + persistent_workers: bool = True, + transform: Callable | None = None, + target_transform: Callable | None = None, + shuffle_buffer: int | None = None, + compressed: bool = False, + **kwargs, + ): + super().__init__() + self.data_dir = Path(data_dir) + self.batch_size = batch_size + self.fourier = fourier + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.persistent_workers = persistent_workers + self.transform = transform or identity + self.target_transform = target_transform or identity + self.shuffle_buffer = shuffle_buffer + self.compressed = compressed + + self._get_dataset_lengths() + + self.save_hyperparameters(ignore=["transform", "target_transform"]) + + def _get_dataset_lengths(self): + train_parquet = list(self.data_dir.glob("train-*.parquet"))[0] + valid_parquet = list(self.data_dir.glob("valid-*.parquet"))[0] + test_parquet = list(self.data_dir.glob("test-*.parquet"))[0] + + self.train_length = ( + pq.read_table(train_parquet) + .to_pandas()["total_samples_in_dataset"] + .values[0] + ) + self.valid_length = ( + pq.read_table(valid_parquet) + .to_pandas()["total_samples_in_dataset"] + .values[0] + ) + self.test_length = ( + pq.read_table(test_parquet) + .to_pandas()["total_samples_in_dataset"] + .values[0] + ) + + def _create_dataset(self, mode: str, shuffle: bool = True): + """ + Create a WebDataset pipeline for the specified mode. + + Parameters + ---------- + mode : str + One of 'train', 'valid', or 'test'. + shuffle : bool, optional + Whether to shuffle the data. Default: False + + Returns + ------- + wds.WebDataset + Configured WebDataset pipeline. + """ + suffix = "tar.gz" if self.compressed else "tar" + + urls = sorted(map(str, self.data_dir.glob(f"{mode}-*.{suffix}"))) + + if not urls: + raise ValueError( + f"No WebDataset shards found for mode '{mode}' in {self.data_dir}. " + f"Expected pattern: {mode}-{{000000..NNNNN}}.{suffix}" + ) + if shuffle: + shuffle = self.batch_size + + if not self.shuffle_buffer: + self.shuffle_buffer = 10 * self.batch_size + + dataset = ( + wds.WebDataset(urls, shardshuffle=True, nodesplitter=wds.split_by_node) + .decode() + .to_tuple("input.npy", "target.npy") + .map_tuple( + lambda x: torch.from_numpy(x).float(), + lambda y: torch.from_numpy(y).float(), + ) + .map_tuple(self.transform, self.target_transform) + .batched(self.batch_size) + ) + + return dataset + + def setup(self, stage: str): + """ + Set up datasets for the specified stage. + + Parameters + ---------- + stage : str + One of 'fit', 'test', or 'predict'. + """ + match stage: + case "fit": + self.train_dataset = self._create_dataset("train", shuffle=True) + self.val_dataset = self._create_dataset("valid", shuffle=False) + case "test": + self.test_dataset = self._create_dataset("test", shuffle=False) + case "predict": + self.predict_dataset = self._create_dataset("test", shuffle=False) + + predict_parquet = list(self.data_dir.glob("predict-*.parquet"))[0] + self.predict_length = ( + pq.read_table(predict_parquet) + .to_pandas()["total_samples_in_dataset"] + .values[0] + ) + case _: + raise ValueError( + f"Stage '{stage}' is not available in {self.__class__.__name__}" + ) + + def _create_dataloader(self, dataset): + """ + Create a WebLoader with shuffling and batching. + + Parameters + ---------- + dataset : wds.WebDataset + WebDataset to wrap. + + Returns + ------- + wds.WebLoader + WebLoader DataLoader. + """ + return ( + wds.WebLoader( + dataset, + num_workers=self.num_workers, + batch_size=None, + prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None, + persistent_workers=self.persistent_workers + if self.num_workers > 0 + else False, + pin_memory=True, + ) + .unbatched() + .shuffle(self.shuffle_buffer) + .batched(self.batch_size) + ) + + def train_dataloader(self): + """Create training DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the training dataset + with configured batch size and number of workers. + """ + loader = self._create_dataloader(self.train_dataset) + return loader.with_epoch(self.train_length // (self.batch_size * 1)) + + def val_dataloader(self): + """Create validation DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the validation dataset + with configured batch size and number of workers. + """ + return self._create_dataloader(self.val_dataset).with_epoch( + self.test_length // (self.batch_size * 1) + ) + + def test_dataloader(self): + """Create test DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the test dataset + with configured batch size and number of workers. + """ + return self._create_dataloader(self.test_dataset) + + def predict_dataloader(self): + """Create prediction DataLoader. + + Returns + ------- + :class:`torch.utils.data.DataLoader` + PyTorch DataLoader for the prediction dataset + with configured batch size and number of workers. + """ + return self._create_dataloader(self.predict_dataset) diff --git a/src/radionets/io/train_config.py b/src/radionets/io/train_config.py new file mode 100644 index 00000000..6cd22a60 --- /dev/null +++ b/src/radionets/io/train_config.py @@ -0,0 +1,336 @@ +import inspect +import os +import tomllib +from collections.abc import Callable +from pathlib import Path +from typing import Literal + +import torch +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from radionets.architecture import archs + +from . import data +from .training import ( + BatchSizeFinderCallbackConfig, + CodeCarbonEmissionTrackerConfig, + CometLoggerConfig, + CSVLoggerConfig, + DeepSpeedConfig, + EarlyStoppingCallbackConfig, + LearningRateMonitorCallbackConfig, + LossConfig, + LRSchedulerConfig, + MLFlowLoggerConfig, + ModelCheckpointCallbackConfig, + OptimizerConfig, + TimerCallbackConfig, +) + + +class PathsConfig(BaseModel): + """File paths configuration.""" + + data_path: Path = Path("./example_data/") + model_path: Path = Path("./build/example_model/") + checkpoint: Path | None | Literal[False] = None + + @field_validator("data_path", "model_path", "checkpoint") + @classmethod + def expand_path(cls, v: Path) -> Path: + """Expand and resolve paths.""" + + if v in {None, False}: + v = None + else: + v.expanduser().resolve() + + return v + + +class ModelConfig(BaseModel): + arch_name: str | Callable = archs.SRResNet18 + fourier: bool = True + amp_phase: bool = True + normalize: bool = False + + @field_validator("arch_name") + @classmethod + def load_arch_instance(cls, arch: str): + avail_archs = {} + + for member in inspect.getmembers(archs): + if inspect.isclass(member[1]): + avail_archs[member[0]] = member[1] + + try: + arch = avail_archs[arch] + except KeyError as e: + raise ValueError( + f"Unknown architecture: TrainConfig got {arch} but expected " + f"one of {avail_archs.keys()}!" + ) from e + + return arch + + +class TrainingConfig(BaseModel): + """Hyperparameters configuration.""" + + num_epochs: int = Field(default=50, gt=0) + batch_size: int = Field(default=100, gt=0) + loss: LossConfig = LossConfig() + optimizer: OptimizerConfig = OptimizerConfig() + lr_scheduling: bool | LRSchedulerConfig = False + + @field_validator("loss", mode="after") + @classmethod + def validate_loss(cls, v): + if isinstance(v, dict): + return LossConfig(**v) + + return v + + @field_validator("optimizer", mode="after") + @classmethod + def validate_optimizer(cls, v): + if isinstance(v, dict): + return OptimizerConfig(**v) + + return v + + @field_validator("lr_scheduling", mode="after") + @classmethod + def validate_lr_scheduler(cls, v: bool | LRSchedulerConfig): + if isinstance(v, str): + return v + elif isinstance(v, dict): + return LRSchedulerConfig(**v) + elif v is True: + return LRSchedulerConfig() + + return v + + +class DeviceConfig(BaseModel): + """Device configuration settings.""" + + accelerator: str = "auto" + num_devices: str | list | int = "auto" + precision: str | int = "32-true" + deepspeed: bool | str | DeepSpeedConfig = False + strategy: str = "auto" + + @model_validator(mode="after") + def check_device_count(self) -> None: + if self.accelerator in ["gpu", "tpu", "hpu"] and not torch.cuda.is_available(): + raise ValueError( + f"'accelerator' is set to {self.accelerator} in the " + "configuration but CUDA is not available. Please " + "ensure CUDA is installed or set accelerator to 'cpu'." + ) + + if ( + self.accelerator in ["gpu", "tpu", "hpu"] + and isinstance(self.num_devices, int) > torch.cuda.device_count() + ): + raise ValueError( + f"'num_devices' exceeds the number of available {self.accelerator}s " + f"({self.num_devices} > {torch.cuda.device_count})" + ) + + return self + + @field_validator("deepspeed", mode="after") + @classmethod + def validate_deepspeed(cls, v: bool | str | DeepSpeedConfig): + if isinstance(v, str): + return v + elif isinstance(v, dict): + return DeepSpeedConfig(**v) + elif v is True: + return DeepSpeedConfig() + + return v + + +class DataLoaderConfig(BaseModel): + """DataLoader configuration.""" + + module: str | Callable = data.H5DataModule + num_workers: int = Field(default=10, gt=0) + compressed: bool = False + + model_config = ConfigDict(extra="allow") + + @field_validator("module") + @classmethod + def load_data_module_instance(cls, name: str): + if isinstance(name, type): + return name + + avail_data_modules = {} + + for member in inspect.getmembers(data): + if inspect.isclass(member[1]): + avail_data_modules[member[0]] = member[1] + + try: + if name.lower() in ["h5", "hdf5"]: + data_module = avail_data_modules["H5DataModule"] + elif name.lower() in ["wds", "webdataset"]: + data_module = avail_data_modules["WebDatasetModule"] + else: + data_module = avail_data_modules[name] + except KeyError as e: + raise ValueError( + f"Unknown optimizer: TrainConfig got {name} but expected " + f"one of {set(avail_data_modules)}!" + ) from e + + return data_module + + +class CallbacksConfig(BaseModel): + "Callbacks configuration." + + model_checkpoint: bool | ModelCheckpointCallbackConfig = False + batch_size_finder: bool | BatchSizeFinderCallbackConfig = False + early_stopping: bool | EarlyStoppingCallbackConfig = False + lr_monitor: bool | LearningRateMonitorCallbackConfig = False + timer: bool | TimerCallbackConfig = False + device_stats_monitor: bool = False + + @field_validator("model_checkpoint", mode="after") + @classmethod + def validate_model_checkpoint(cls, v): + if isinstance(v, dict): + return ModelCheckpointCallbackConfig(**v) + elif v is True: + return ModelCheckpointCallbackConfig() # Return defaults + + return v + + @field_validator("batch_size_finder", mode="after") + @classmethod + def validate_batch_size_finder(cls, v): + if isinstance(v, dict): + return BatchSizeFinderCallbackConfig(**v) + elif v is True: + return BatchSizeFinderCallbackConfig() # Return defaults + + return v + + @field_validator("early_stopping", mode="after") + @classmethod + def validate_early_stopping(cls, v): + if isinstance(v, dict): + return EarlyStoppingCallbackConfig(**v) + elif v is True: + return EarlyStoppingCallbackConfig() # Return defaults + + return v + + @field_validator("lr_monitor", mode="after") + @classmethod + def validate_lr_monitor(cls, v): + if isinstance(v, dict): + return LearningRateMonitorCallbackConfig(**v) + elif v is True: + return LearningRateMonitorCallbackConfig() # Return defaults + + return v + + @field_validator("timer", mode="after") + @classmethod + def validate_timer(cls, v): + if isinstance(v, dict): + return TimerCallbackConfig(**v) + elif v is True: + return TimerCallbackConfig() # Return defaults + + return v + + +class LoggingConfig(BaseModel): + """Logging and experiment tracking configuration.""" + + project_name: str = "Radionets" + plot_n_epochs: int = Field(default=10, gt=0) + scale: bool = True + default_logger: CSVLoggerConfig = CSVLoggerConfig() + comet_ml: bool | CometLoggerConfig = False + mlflow: bool | MLFlowLoggerConfig = False + codecarbon: bool | CodeCarbonEmissionTrackerConfig = False + + @field_validator("default_logger", mode="after") + @classmethod + def validate_default_logger(cls, v): + if isinstance(v, dict): + return CSVLoggerConfig(**v) + + return v + + @field_validator("comet_ml", mode="after") + @classmethod + def validate_comet_ml(cls, v): + if isinstance(v, dict): + return CometLoggerConfig(**v) + elif v is True: + return CometLoggerConfig() # Return defaults + + return v + + @field_validator("mlflow", mode="after") + @classmethod + def validate_mlflow(cls, v): + if isinstance(v, dict): + return MLFlowLoggerConfig(**v) + elif v is True: + return MLFlowLoggerConfig() # Return defaults + + return v + + @field_validator("codecarbon", mode="after") + @classmethod + def validate_codecarbon(cls, v: bool | CodeCarbonEmissionTrackerConfig): + if isinstance(v, dict): + return CodeCarbonEmissionTrackerConfig( + **v, project_name=cls.logging.project_name + ) + elif v is True: + return CodeCarbonEmissionTrackerConfig( + project_name=cls.logging.project_name + ) + + # NOTE: CometML automatically logs with codecarbon + # if codecarbon is installed. This should ensure + # that codecarbon is only used when set in the config + os.environ["COMET_AUTO_LOG_CO2"] = "false" + + return v + + +class TrainConfig(BaseModel): + """Main training configuration.""" + + title: str = "Train configuration" + paths: PathsConfig = Field(default_factory=PathsConfig) + model: ModelConfig = Field(default_factory=ModelConfig) + training: TrainingConfig = Field(default_factory=TrainingConfig) + devices: DeviceConfig = Field(default_factory=DeviceConfig) + dataloader: DataLoaderConfig = Field(default_factory=DataLoaderConfig) + callbacks: CallbacksConfig = Field(default_factory=CallbacksConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) + + @classmethod + def from_toml(cls, path: str | Path) -> "TrainConfig": + """Load configuration from a TOML file.""" + with open(path, "rb") as f: + data = tomllib.load(f) + + return cls(**data) + + def to_dict(self) -> dict: + """Export configuration as a dictionary.""" + return self.model_dump() diff --git a/src/radionets/io/training/__init__.py b/src/radionets/io/training/__init__.py new file mode 100644 index 00000000..c07a9ab1 --- /dev/null +++ b/src/radionets/io/training/__init__.py @@ -0,0 +1,35 @@ +from ._accelerators import DeepSpeedConfig +from ._callbacks import ( + BatchSizeFinderCallbackConfig, + EarlyStoppingCallbackConfig, + LearningRateMonitorCallbackConfig, + ModelCheckpointCallbackConfig, + TimerCallbackConfig, +) +from ._logging import ( + CodeCarbonEmissionTrackerConfig, + CometLoggerConfig, + CSVLoggerConfig, + MLFlowLoggerConfig, +) +from ._training import ( + LossConfig, + LRSchedulerConfig, + OptimizerConfig, +) + +__all__ = [ + "BatchSizeFinderCallbackConfig", + "CSVLoggerConfig", + "CodeCarbonEmissionTrackerConfig", + "CometLoggerConfig", + "DeepSpeedConfig", + "EarlyStoppingCallbackConfig", + "LRSchedulerConfig", + "LearningRateMonitorCallbackConfig", + "LossConfig", + "MLFlowLoggerConfig", + "ModelCheckpointCallbackConfig", + "OptimizerConfig", + "TimerCallbackConfig", +] diff --git a/src/radionets/io/training/_accelerators.py b/src/radionets/io/training/_accelerators.py new file mode 100644 index 00000000..6fb07f52 --- /dev/null +++ b/src/radionets/io/training/_accelerators.py @@ -0,0 +1,52 @@ +from pathlib import Path + +from pydantic import BaseModel + +__all__ = [ + "DeepSpeedConfig", +] + + +class DeepSpeedConfig(BaseModel): + """Lightning DeepSpeedStrategy configuration""" + + zero_optimization: bool = True + stage: int = 2 + remote_device: str | None = None + offload_optimizer: bool = False + offload_parameters: bool = False + offload_params_device: str = "cpu" + nvme_path: str = "/local_nvme" + params_buffer_count: int = 5 + params_buffer_size: int = 100_000_000 + max_in_cpu: int = 1_000_000_000 + offload_optimizer_device: str = "cpu" + optimizer_buffer_count: int = 4 + block_size: int = 1048576 + queue_depth: int = 8 + single_submit: bool = False + overlap_events: bool = True + thread_count: int = 1 + pin_memory: bool = False + sub_group_size: int = 1_000_000_000_000 + contiguous_gradients: bool = True + overlap_comm: bool = True + allgather_partitions: bool = True + reduce_scatter: bool = True + allgather_bucket_size: int = 200_000_000 + reduce_bucket_size: int = 200_000_000 + zero_allow_untested_optimizer: bool = True + logging_batch_size_per_gpu: str | int = "auto" + config: Path | dict | None = None + logging_level: int = 30 + loss_scale: float = 0 + initial_scale_power: int = 16 + loss_scale_window: int = 1000 + hysteresis: int = 2 + min_loss_scale: int = 1 + partition_activations: bool = False + cpu_checkpointing: bool = False + contiguous_memory_optimization: bool = False + synchronize_checkpoint_boundary: bool = False + load_full_weights: bool = False + exclude_frozen_parameters: bool = False diff --git a/src/radionets/io/training/_callbacks.py b/src/radionets/io/training/_callbacks.py new file mode 100644 index 00000000..57423a8a --- /dev/null +++ b/src/radionets/io/training/_callbacks.py @@ -0,0 +1,71 @@ +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel + +__all__ = [ + "BatchSizeFinderCallbackConfig", + "EarlyStoppingCallbackConfig", + "LearningRateMonitorCallbackConfig", + "ModelCheckpointCallbackConfig", + "TimerCallbackConfig", +] + + +class ModelCheckpointCallbackConfig(BaseModel): + """Lightning ModelCheckpoint callback config""" + + dirpath: str | Path | None = None + filename: str | Path | None = None + monitor: str | None = None + verbose: bool = False + save_last: bool | Literal["link"] | None = None + save_top_k: int = 1 + save_on_exception: bool = False + save_weights_only: bool = False + mode: str = "min" + auto_insert_metric_name: bool = True + every_n_train_steps: int | None = None + every_n_epochs: int | None = None + save_on_train_epoch_end: bool | None = None + enable_version_counter: bool = True + + +class BatchSizeFinderCallbackConfig(BaseModel): + """Lightning BatchSizeFinder callback config""" + + mode: Literal["power", "binsearch"] = "power" + steps_per_trial: int = 3 + max_trials: int = 25 + + +class EarlyStoppingCallbackConfig(BaseModel): + """Lightning EarlyStopping callback config""" + + monitor: str = "val_loss" + min_delta: float = 0.0 + patience: int = 3 + verbose: bool = False + mode: str = "min" + strict: bool = True + check_finite: bool = True + stopping_threshold: float | None = None + divergence_threshold: float | None = None + check_on_train_epoch_end: bool | None = None + log_rank_zero_only: bool = False + + +class LearningRateMonitorCallbackConfig(BaseModel): + """Lightning LearningRateMonitor callback config""" + + logging_interval: str | None = "epoch" + log_momentum: bool = False + log_weight_decay: bool = False + + +class TimerCallbackConfig(BaseModel): + """Lightning Timer callback config""" + + duration: str | None = "14:00:00:00" + interval: Literal["epoch", "step"] = "epoch" + verbose: bool = True diff --git a/src/radionets/io/training/_logging.py b/src/radionets/io/training/_logging.py new file mode 100644 index 00000000..e175cb06 --- /dev/null +++ b/src/radionets/io/training/_logging.py @@ -0,0 +1,78 @@ +import os +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, SecretStr, field_validator + +__all__ = [ + "CSVLoggerConfig", + "CometLoggerConfig", + "MLFlowLoggerConfig", + "CodeCarbonEmissionTrackerConfig", +] + + +class CSVLoggerConfig(BaseModel): + """Lightning CSVLogger logging config""" + + name: str | None = "lightning_logs" + version: int | str | None = None + prefix: str = "" + flush_logs_every_n_steps: int = 100 + + +class CometLoggerConfig(BaseModel): + """Lightning CometLogger logging config""" + + api_key: SecretStr = SecretStr(os.getenv("COMET_API_KEY")) + workspace: str | None = None + experiment_key: str | None = None + mode: Literal["get_or_create", "get", "create"] | None = None + online: bool | None = None + prefix: str | None = None + + @field_validator("api_key", mode="before") + @classmethod + def validate_api_key(cls, key: SecretStr | None) -> SecretStr | None: + key = SecretStr(key) if key else None + + return key + + +class MLFlowLoggerConfig(BaseModel): + """Lightning MLFlowLogger logging config""" + + run_name: str | None = None + tracking_uri: str | None = "http://127.0.0.1:5000" + tags: dict | None = None + log_model: Literal[True, False, "all"] = False + prefix: str = "" + artifact_location: str | None = None + run_id: str | None = None + synchronous: bool | None = None + + +class CodeCarbonEmissionTrackerConfig(BaseModel): + """Codecarbon emission tracker configuration""" + + log_level: str | int = "error" + country_iso_code: str = "DEU" + output_dir: str | None = None + + @field_validator("output_dir", mode="after") + @classmethod + def expand_path(cls, v: str | Path) -> str: + """Expand and resolve paths.""" + + if not isinstance(v, Path): + v = Path(v) + + if v in {None, False}: + v = Path(os.getcwd()) + else: + v.expanduser().resolve() + + if not v.exists(): + v.mkdir(parents=True) + + return str(v) diff --git a/src/radionets/io/training/_training.py b/src/radionets/io/training/_training.py new file mode 100644 index 00000000..59512e53 --- /dev/null +++ b/src/radionets/io/training/_training.py @@ -0,0 +1,98 @@ +import inspect +from collections.abc import Callable + +import torch +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from radionets.architecture import loss + +__all__ = ["LossConfig", "OptimizerConfig", "LRSchedulerConfig"] + + +class LossConfig(BaseModel): + loss_func: str | Callable = torch.nn.MSELoss + + model_config = ConfigDict(extra="allow") + + @field_validator("loss_func") + @classmethod + def load_loss_func_instance(cls, loss_func: str): + if isinstance(loss_func, type): + return loss_func + + avail_loss_funcs = {} + + for member in inspect.getmembers(torch.nn): + if inspect.isclass(member[1]): + avail_loss_funcs[member[0]] = member[1] + + for member in inspect.getmembers(loss): + if inspect.isclass(member[1]): + avail_loss_funcs[member[0]] = member[1] + + try: + loss_func = avail_loss_funcs[loss_func] + except KeyError as e: + raise ValueError( + f"Unknown optimizer: TrainConfig got {loss_func} but expected " + f"one of {set(avail_loss_funcs)}!" + ) from e + + return loss_func + + +class OptimizerConfig(BaseModel): + optimizer: str | Callable = torch.optim.AdamW + lr: float = Field(default=1e-3, gt=0.0) + + model_config = ConfigDict(extra="allow") + + @field_validator("optimizer") + @classmethod + def load_optimizer_instance(cls, optimizer: str): + avail_optimizers = {} + + for member in inspect.getmembers(torch.optim): + if inspect.isclass(member[1]): + avail_optimizers[member[0]] = member[1] + + try: + optimizer = avail_optimizers[optimizer] + except KeyError as e: + raise ValueError( + f"Unknown optimizer: TrainConfig got {optimizer} but expected " + f"one of {set(avail_optimizers)}!" + ) from e + + return optimizer + + +class LRSchedulerConfig(BaseModel): + """Learning rate scheduling configuration.""" + + scheduler: str | Callable = torch.optim.lr_scheduler.OneCycleLR + monitor: str = "train_loss" + interval: str = "step" + frequency: int = 1 + strict: bool = True + + model_config = ConfigDict(extra="allow") + + @field_validator("scheduler") + @classmethod + def load_optimizer_instance(cls, scheduler: str): + avail_schedulers = {} + + for member in inspect.getmembers(torch.optim.lr_scheduler): + if inspect.isclass(member[1]): + avail_schedulers[member[0]] = member[1] + + try: + scheduler = avail_schedulers[scheduler] + except KeyError as e: + raise ValueError( + f"Unknown optimizer: TrainConfig got {scheduler} but expected " + f"one of {set(avail_schedulers)}!" + ) from e + + return scheduler diff --git a/src/radionets/plotting/__init__.py b/src/radionets/plotting/__init__.py index afccf9f3..e69de29b 100644 --- a/src/radionets/plotting/__init__.py +++ b/src/radionets/plotting/__init__.py @@ -1,9 +0,0 @@ -from .hist import Hist -from .inspection import plot_loss, plot_lr, plot_lr_loss - -__all__ = [ - "Hist", - "plot_loss", - "plot_lr", - "plot_lr_loss", -] diff --git a/src/radionets/plotting/hist.py b/src/radionets/plotting/hist.py deleted file mode 100644 index 9b31809c..00000000 --- a/src/radionets/plotting/hist.py +++ /dev/null @@ -1,411 +0,0 @@ -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -from matplotlib.patches import Rectangle - -__all__ = ["Hist"] - - -class Hist: - def __init__( - self, - outpath, - plot_format: str = "png", - hist_kwargs: dict | None = None, - save_kwargs: dict | None = None, - ): - self.outpath = outpath - self.plot_format = plot_format - - if not Path(self.outpath).exists(): - Path(self.outpath).mkdir(parents=True, exist_ok=True) - - self.hist_kwargs = hist_kwargs - if not self.hist_kwargs: - self.hist_kwargs = dict( - color="darkorange", - linewidth=3, - histtype="step", - alpha=0.75, - ) - - self.save_kwargs = save_kwargs - if not self.save_kwargs: - self.save_kwargs = dict( - bbox_inches="tight", - pad_inches=0.01, - dpi=150, - ) - - def _preproc_vals( - self, vals: torch.Tensor | np.ndarray - ) -> tuple[np.ndarray, float, float]: - if torch.is_tensor(vals): - vals = vals.numpy() - - mean = vals.mean() - # NOTE: passing the mean to std() prevents its recalculation - std = vals.std(ddof=1, mean=mean) - - return vals, mean, std - - def _add_mean_std_text(self, ax: plt.axes, mean: float, std: float): - ax.text( - 0.1, - 0.8, - f"Mean: {mean:.2f}\nStd: {std:.2f}", - horizontalalignment="left", - verticalalignment="center", - transform=ax.transAxes, - bbox=dict( - boxstyle="round", - facecolor="white", - edgecolor="lightgray", - alpha=0.8, - ), - ) - - def _get_rect_patch(self) -> Rectangle | Rectangle: - kwargs = dict( - width=1, - height=1, - fc="w", - fill=False, - edgecolor="#1f77b4", - linewidth=2, - ) - - rect_1 = Rectangle((0, 0), **kwargs) - rect_2 = Rectangle((0, 0), **kwargs) - return rect_1, rect_2 - - def area( - self, - vals: torch.tensor, - bins: int = 30, - return_fig: bool = False, - ): - vals, mean, std = self._preproc_vals(vals) - - fig, ax = plt.subplots(1, figsize=(6, 4)) - - ax.hist( - vals, - bins=bins, - **self.hist_kwargs, - ) - ax.axvline(1, color="red", linestyle="dashed") - ax.set(xlabel="Ratio of areas", ylabel="Number of sources") - - self._add_mean_std_text(ax, mean, std) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/hist_area.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def dynamic_ranges( - self, - dr_truth: torch.tensor, - dr_pred: torch.tensor, - return_fig: bool = False, - ): - fig, ax = plt.subplots(2, 1, figsize=(6, 12), layout="constrained") - ax[0].hist(dr_truth, 51, **self.hist_kwargs) - ax[0].set( - title="True Images", - xlabel="Dynamic range", - ylabel="Number of sources", - ) - - ax[1].hist(dr_pred, 25, **self.hist_kwargs) - ax[1].set( - title="Predictions", - xlabel="Dynamic range", - ylabel="Number of sources", - ) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/dynamic_ranges.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def gan_sources( - self, - ratio, - num_zero, - above_zero, - below_zero, - num_images, - ): - bins = np.arange(0, ratio.max() + 0.1, 0.1) - - fig, ax = plt.subplots(1, layout="constrained") - ax.hist( - ratio, - bins=bins, - histtype="step", - label=f"mean: {ratio.mean():.2f}, max: {ratio.max():.2f}", - ) - ax.set( - xlabel=r"Maximum difference to maximum true flux ratio", - ylabel=r"Number of sources", - ) - ax.legend(loc="best") - - outpath = str(self.outpath) + f"/ratio.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - plt.close(fig) - - fig, ax = plt.subplots(1, layout="constrained") - bins = np.arange(0, 102, 2) - num_zero = num_zero.reshape(4, num_images) - - for i, label in enumerate(["1e-4", "1e-3", "1e-2", "1e-1"]): - ax.hist(num_zero[i], bins=bins, histtype="step", label=label) - - ax.set( - xlabel=r"Proportion of pixels close to 0 / %", - ylabel=r"Number of sources", - ) - ax.legend(loc="upper center") - - outpath = str(self.outpath) + f"/num_zeros.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - plt.close(fig) - - fig, ax = plt.subplots(1, layout="constrained") - bins = np.arange(0, 102, 2) - ax.hist( - above_zero, - bins=bins, - histtype="step", - label=f"Above, mean: {above_zero.mean():.2f}%, max: {above_zero.max():.2f}%", # noqa: E501 - ) - ax.hist( - below_zero, - bins=bins, - histtype="step", - label=f"Below, mean: {below_zero.mean():.2f}%, max: {below_zero.max():.2f}%", # noqa: E501 - ) - ax.set( - xlabel=r"Proportion of pixels below or above 0%", - ylabel=r"Number of sources", - ) - ax.legend(loc="upper center") - - outpath = str(self.outpath) + f"/above_below.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def jet_angles(self, vals: torch.tensor, return_fig: bool = False): - vals, mean, std = self._preproc_vals(vals) - - fig, ax = plt.subplots(2, 1, figsize=(6, 8)) - ax[0].hist(vals, 51, **self.hist_kwargs) - ax[0].set( - xlabel="Offset / deg", - ylabel="Number of sources", - ) - - extra_1, extra_2 = self._get_rect_patch() - ax[0].legend([extra_1, extra_2], (f"Mean: {mean:.2f}", f"Std: {std:.2f}")) - - ax[1].hist(vals[(vals > -10) & (vals < 10)], 25, **self.hist_kwargs) - ax[1].set( - xticks=[-10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10], - xlabel="Offset / deg", - ylabel="Number of sources", - ) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/jet_offsets.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def jet_gaussian_distance(self, dist: torch.tensor, return_fig: bool = False): - """ - Plotting the distances between predicted and true component of several images. - Parameters - ---------- - dist: 2d array - array of shape (n, 2), where n is the number of distances - """ - - ran = [0, 50] - - fig, ax = plt.subplots(1, layout="constrained") - - for i in range(10): - ax.hist( - dist[dist[:, 0] == i][:, 1], - bins=20, - range=ran, - alpha=0.7, - label=f"Component {i}", - ) - - ax.set(xlabel="Distance", ylabel="Counts") - ax.legend() - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/hist_jet_gaussian_distance.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def mean_diff(self, vals: torch.tensor, return_fig: bool = False): - vals, mean, std = self._preproc_vals(vals) - - fig, ax = plt.subplots(1, figsize=(6, 4)) - - ax.hist(vals, 51, **self.hist_kwargs) - ax.set( - xlabel="Mean flux deviation / %", - ylabel="Number of sources", - ) - - extra_1, extra_2 = self._get_rect_patch() - ax.legend([extra_1, extra_2], (f"Mean: {mean:.2f}", f"Std: {std:.2f}")) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/mean_diff.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def ms_ssim( - self, - vals: torch.tensor, - bins: int = 30, - return_fig: bool = False, - ): - vals, mean, std = self._preproc_vals(vals) - - fig, ax = plt.subplots(1, figsize=(6, 4), layout="constrained") - ax.hist(vals, bins=bins, **self.hist_kwargs) - ax.set( - xlabel="ms ssim", - ylabel="Number of sources", - ) - - self._add_mean_std_text(ax, mean, std) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/ms_ssim.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def peak_intensity( - self, - vals: torch.tensor, - bins: int = 30, - return_fig: bool = False, - ): - vals, mean, std = self._preproc_vals(vals) - - fig, ax = plt.subplots(1, figsize=(6, 4), layout="constrained") - - ax.hist(vals, bins=bins, **self.hist_kwargs) - ax.axvline(1, color="red", linestyle="dashed") - ax.set( - xlabel="Ratio of peak flux densities", - ylabel="Number of sources", - ) - - self._add_mean_std_text(ax, mean, std) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/intensity_peak.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def point(self, vals: torch.tensor, mask: torch.tensor, return_fig: bool = False): - binwidth = 5 - min_all = vals.min() - bins = np.arange(min_all, 100 + binwidth, binwidth) - - mean_point = np.mean(vals[mask]) - std_point = np.std(vals[mask], ddof=1) - mean_extent = np.mean(vals[~mask]) - std_extent = np.std(vals[~mask], ddof=1) - - fig, ax = plt.subplots(1, figsize=(6, 4), layout="constrained") - - ax.hist(vals[mask], bins=bins, **self.hist_kwargs) - ax.hist(vals[~mask], bins=bins, **self.hist_kwargs) - - ax.axvline(0, linestyle="dotted", color="red") - ax.set( - xlabel="Mean specific intensity deviation", - ylabel="Number of sources", - ) - - extra_1, extra_2 = self._get_rect_patch() - ax.legend( - [extra_1, extra_2], - [ - rf"Point: $({mean_point:.2f}\pm{std_point:.2f})\,\%$", - rf"Extended: $({mean_extent:.2f}\pm{std_extent:.2f})\,\%$", - ], - ) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/hist_point.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def sum_intensity( - self, - vals: torch.tensor, - bins: int = 30, - return_fig: bool = False, - ): - vals, mean, std = self._preproc_vals(vals) - - fig, ax = plt.subplots(1, figsize=(6, 4), layout="constrained") - - ax.hist(vals, bins=bins, **self.hist_kwargs) - ax.axvline(1, color="red", linestyle="dashed") - ax.set( - xlabel="Ratio of integrated flux densities", - ylabel="Number of sources", - ) - - self._add_mean_std_text(ax, mean, std) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/intensity_sum.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) - - def unc(self, vals: torch.tensor, return_fig: bool = False): - vals, mean, std = self._preproc_vals(vals) - - bins = np.arange(0, 105, 5) - - fig, ax = plt.subplots(1, figsize=(6, 4), layout="constrained") - - ax.hist(vals, bins=bins, **self.hist_kwargs) - ax.set( - xlabel="Percentage of matching pixels", - ylabel="Number of sources", - ) - - self._add_mean_std_text(ax, mean, std) - - if return_fig: - return fig, ax - - outpath = str(self.outpath) + f"/hist_unc.{self.plot_format}" - fig.savefig(outpath, **self.save_kwargs) diff --git a/src/radionets/plotting/inspection.py b/src/radionets/plotting/inspection.py deleted file mode 100644 index 268c21f4..00000000 --- a/src/radionets/plotting/inspection.py +++ /dev/null @@ -1,98 +0,0 @@ -from pathlib import Path - -import matplotlib as mpl -import matplotlib.pyplot as plt - -from radionets.core.logging import setup_logger - -LOGGER = setup_logger(namespace=__name__) - - -def plot_loss(learn, model_path: str | Path, output_format: str = "png") -> None: - """ - Plot train and valid loss of model. - - Parameters - ---------- - learn : learner-object - learner containing data and model - model_path : str - path to trained model - """ - if isinstance(model_path, str): - model_path = Path(model_path) - - save_path = model_path.with_suffix("") - LOGGER.info(f"Plotting Loss for: {model_path.stem}") - - logscale = learn.avg_loss.plot_loss() - title = str(model_path.stem).replace("_", " ") - plt.title(rf"{title}") - - if logscale: - plt.yscale("log") - - plt.savefig( - f"{save_path}_loss.{output_format}", bbox_inches="tight", pad_inches=0.01 - ) - plt.clf() - - mpl.rcParams.update(mpl.rcParamsDefault) - - -def plot_lr(learn, model_path: str | Path, output_format: str = "png") -> None: - """ - Plot learning rate of model. - - Parameters - ---------- - learn : learner-object - learner containing data and model - model_path : str or Path - path to trained model - output_format : - """ - if isinstance(model_path, str): - model_path = Path(model_path) - - save_path = model_path.with_suffix("") - LOGGER.info(f"Plotting Learning rate for: {model_path.stem}") - - learn.avg_loss.plot_lrs() - - plt.savefig(f"{save_path}_lr.{output_format}", bbox_inches="tight", pad_inches=0.01) - plt.clf() - - mpl.rcParams.update(mpl.rcParamsDefault) - - -def plot_lr_loss( - learn, arch_name: str, out_path: str | Path, skip_last, output_format="png" -): - """ - Plot loss of learning rate finder. - - Parameters - ---------- - learn : learner-object - learner containing data and model - arch_path : str - name of the architecture - out_path : str - path to save loss plot - skip_last : int - skip n last points - """ - if isinstance(out_path, str): - out_path = Path(out_path) - - LOGGER.info(f"Plotting Lr vs Loss for architecture: {arch_name}") - - learn.recorder.plot_lr_find() - out_path.mkdir(parents=True, exist_ok=True) - - plt.savefig( - out_path / f"lr_loss.{output_format}", bbox_inches="tight", pad_inches=0.01 - ) - - mpl.rcParams.update(mpl.rcParamsDefault) diff --git a/src/radionets/plotting/utils.py b/src/radionets/plotting/utils.py new file mode 100644 index 00000000..5f04c44f --- /dev/null +++ b/src/radionets/plotting/utils.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from mpl_toolkits.axes_grid1 import make_axes_locatable + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.figure import Figure + from matplotlib.image import AxesImage + from numpy.typing import ArrayLike + + +def set_cbar( + fig: Figure, + ax: Axes, + image: AxesImage, + title: str, + phase: bool = False, + unc: bool = False, +) -> None: + """Create nice colorbars with bigger label size + for every axis in a subplot. Also use ticks for the phase. + + Parameters + ---------- + fig : :class:`~matplotlib.figure.Figure` + Current figure object. + ax : :class:`~matplotlib.axes.Axes` + Current axis object. + image : :class:`~matplotlib.image.AxesImage` + Plotted image. + title : str + Title of subplot. + phase : bool, optional + If ``True``, sets colorbar to units of π. Default: False + unc : bool, optional + If ``True``, sets colorbar label to uncertainty. + """ + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + ax.set_title(title) + + if phase: + cbar = fig.colorbar(image, cax=cax, orientation="vertical", label="Phase / rad") + cbar.set_ticks( + ticks=[-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi], + labels=[r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$"], + ) + elif unc: + cbar = fig.colorbar( + image, + cax=cax, + orientation="vertical", + label=r"$\sigma$ / $\mathrm{Jy \cdot px^{-1}}$", + ) + else: + cbar = fig.colorbar( + image, + cax=cax, + orientation="vertical", + label=r"$\mathrm{Flux \ density / Jy \cdot px^{-1}}$", + ) + + +def get_vmin_vmax(image: ArrayLike): + """Check whether the absolute of the maximum or the minimum is bigger. + If the minimum is bigger, return value with negative sign. Otherwise return + maximum. + + Parameters + ---------- + image : array_like + Input image. + Returns + ------- + float + Negative minimum value or maximum value otherwise. + """ + a = -image.min() if np.abs(image.min()) > np.abs(image.max()) else image.max() + return a diff --git a/src/radionets/plotting/visualization.py b/src/radionets/plotting/visualization.py deleted file mode 100644 index 1874c265..00000000 --- a/src/radionets/plotting/visualization.py +++ /dev/null @@ -1,873 +0,0 @@ -from math import pi -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -from matplotlib.colors import LogNorm, PowerNorm -from mpl_toolkits.axes_grid1 import make_axes_locatable -from pytorch_msssim import ms_ssim -from tqdm import tqdm - -from radionets.evaluation.contour import compute_area_ratio -from radionets.evaluation.dynamic_range import calc_dr, get_boxsize -from radionets.evaluation.utils import check_vmin_vmax, make_axes_nice, reshape_2d - - -def plot_target(h5_dataset, log=False): - index = np.random.randint(len(h5_dataset) - 1) - - plt.figure(figsize=(5.78, 3.57)) - - target = reshape_2d(h5_dataset[index][1]).squeeze(0) - if log: - plt.imshow(target, norm=LogNorm()) - else: - plt.imshow(target) - - plt.xlabel("Pixels") - plt.ylabel("Pixels") - plt.colorbar(label="Intensity / a.u.") - - -def plot_inp_tar(h5_dataset, fourier=False, amp_phase=False): - index = np.random.randint(len(h5_dataset) - 1) - - if fourier is False: - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14.45, 3.57)) - - inp1 = h5_dataset[index][0][0] - lim1 = check_vmin_vmax(inp1) - im1 = ax1.imshow(inp1, cmap="RdBu", vmin=-lim1, vmax=lim1) - make_axes_nice(fig, ax1, im1, "Input: real part") - - inp2 = h5_dataset[index][0][1] - lim2 = check_vmin_vmax(inp2) - im2 = ax2.imshow(inp2, cmap="RdBu", vmin=-lim2, vmax=lim2) - make_axes_nice(fig, ax2, im2, "Input: imaginary part") - - tar = reshape_2d(h5_dataset[index][1]).squeeze(0) - im3 = ax3.imshow(tar, cmap="inferno") - make_axes_nice(fig, ax3, im3, "Target: source image") - - if fourier is True: - if amp_phase is False: - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14.45, 8.92)) - - inp1 = h5_dataset[index][0][0] - lim1 = check_vmin_vmax(inp1) - im1 = ax1.imshow(inp1, cmap="RdBu", vmin=-lim1, vmax=lim1) - make_axes_nice(fig, ax1, im1, "Input: real part") - - inp2 = h5_dataset[index][0][1] - lim2 = check_vmin_vmax(inp2) - im2 = ax2.imshow(inp2, cmap="RdBu", vmin=-lim2, vmax=lim2) - make_axes_nice(fig, ax2, im2, "Input: imaginary part") - - tar1 = h5_dataset[index][1][0] - lim_t1 = check_vmin_vmax(tar1) - im3 = ax3.imshow(tar1, cmap="RdBu", vmin=-lim_t1, vmax=lim_t1) - make_axes_nice(fig, ax3, im3, "Target: real part") - - tar2 = h5_dataset[index][1][1] - lim_t2 = check_vmin_vmax(tar2) - im4 = ax4.imshow(tar2, cmap="RdBu", vmin=-lim_t2, vmax=lim_t2) - make_axes_nice(fig, ax4, im4, "Target: imaginary part") - - if amp_phase is True: - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14.45, 8.92)) - - inp1 = h5_dataset[index][0][0] - im1 = ax1.imshow(inp1, cmap="inferno") - make_axes_nice(fig, ax1, im1, "Input: amplitude") - - inp2 = h5_dataset[index][0][1] - lim2 = check_vmin_vmax(inp2) - im2 = ax2.imshow(inp2, cmap="RdBu", vmin=-pi, vmax=pi) - make_axes_nice(fig, ax2, im2, "Input: phase") - - tar1 = h5_dataset[index][1][0] - im3 = ax3.imshow(tar1, cmap="inferno") - make_axes_nice(fig, ax3, im3, "Target: amplitude") - - tar2 = h5_dataset[index][1][1] - im4 = ax4.imshow(tar2, cmap="RdBu", vmin=-pi, vmax=pi) - make_axes_nice(fig, ax4, im4, "Target: phase") - - -def visualize_with_fourier( - i: int, - img_input: torch.tensor, - img_pred: torch.tensor, - img_truth: torch.tensor, - amp_phase: bool, - out_path: Path, - plot_format: str = "png", - return_fig: bool = False, - kwargs: list[dict] | None = None, -): - """Visualizes how the target variables are displayed in fourier space. - - Parameters - ---------- - i : int - Current index given form the loop - img_input : :func:`torch.tensor` - Current input image as a :func:`~numpy.array` or - :func:`~torch.tensor` with shape [M, N] - img_pred : :func:`torch.tensor` - Current prediction image as a :func:`~numpy.array` or - :func:`~torch.tensor` with shape [M, N] - img_truth : :func:`torch.tensor` - Current true image as a :func:`~numpy.array` or - :func:`~torch.tensor` with shape [M, N] - amp_phase : bool - Whether the image contains real/imaginary information - or amplitude/phase information. - out_path : str which contains the output path - Output path of the figure. Skipped if ``return_fig`` is - set to ``True``. - plot_format : str, optional - Output file format. Default: png - return_fig : bool, optional - Whether to return the :func:`~matplotlib.pyplot.figure` object - instead of saving the figure to a file. Default: ``False`` - **kwargs : list[dict] or None, optional - Additional list of dictionaries with keyword arguments - for each subplot. Default: ``None`` - - Returns - ------- - fig : :func:`~matplotlib.pyplot.figure` - Figure object if ``return_fig`` is set to ``True``. - """ - # reshaping and splitting in real and imaginary part if necessary - inp_real, inp_imag = img_input[0], img_input[1] - real_pred, imag_pred = img_pred[0], img_pred[1] - real_truth, imag_truth = img_truth[0], img_truth[1] - - if not kwargs: - kwargs = [{}] * 8 - - a = check_vmin_vmax(inp_imag) - if amp_phase: - __defaults = dict( - cmap=["inferno"] * 3 + ["radionets.PuOr"] * 5, - vmin=[None, None, None, None, -a, -np.pi, -np.pi, None], - vmax=[None, None, None, None, a, np.pi, np.pi, None], - name=["Amplitude"] * 4 + ["Phase"] * 4, - ) - else: - __defaults = dict( - cmap=["radionets.PuOr"] * 8, - vmin=[None] * 8, - vmax=[None] * 8, - name=["Real"] * 4 + ["Imaginary"] * 4, - ) - - for i, kwarg in enumerate(kwargs): - if "cmap" not in kwarg: - kwarg["cmap"] = __defaults["cmap"][i] - if "vmin" not in kwarg: - kwarg["vmin"] = __defaults["vmin"][i] - if "vmax" not in kwarg: - kwarg["vmax"] = __defaults["vmax"][i] - - fig, ax = plt.subplots(2, 4, figsize=(16, 10), sharex=True, sharey=True) - ax = ax.ravel() - - im1 = ax[0].imshow(inp_real, **kwargs[0]) - make_axes_nice(fig, ax[0], im1, f"{__defaults['name'][0]} Input") - - im2 = ax[1].imshow(real_pred, **kwargs[1]) - make_axes_nice(fig, ax[1], im2, f"{__defaults['name'][1]} Prediction") - - im3 = ax[2].imshow(real_truth, **kwargs[2]) - make_axes_nice(fig, ax[2], im3, f"{__defaults['name'][2]} Truth") - - im4 = ax[3].imshow(real_truth - real_pred, **kwargs[3]) - make_axes_nice(fig, ax[3], im4, f"{__defaults['name'][3]} Difference") - - im5 = ax[4].imshow(inp_imag, **kwargs[4]) - make_axes_nice( - fig, - ax[4], - im5, - f"{__defaults['name'][4]} Input", - phase=bool(amp_phase), - ) - - im6 = ax[5].imshow(imag_pred, **kwargs[5]) - make_axes_nice( - fig, - ax[5], - im6, - f"{__defaults['name'][5]} Prediction", - phase=bool(amp_phase), - ) - - im7 = ax[6].imshow(imag_truth, **kwargs[6]) - make_axes_nice( - fig, - ax[6], - im7, - f"{__defaults['name'][6]} Truth", - phase=bool(amp_phase), - ) - - im8 = ax[7].imshow(imag_truth - imag_pred, **kwargs[7]) - make_axes_nice(fig, ax[7], im8, f"{__defaults['name'][7]} Difference") - - ax[0].set_ylabel("Pixels") - ax[4].set_ylabel("Pixels") - - for axs in ax[4:]: - axs.set_xlabel("Pixels") - - if return_fig: - return fig, ax - - outpath = str(out_path) + f"/prediction_{i}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01) - - -def visualize_with_fourier_diff( - i, - img_pred, - img_truth, - amp_phase, - out_path, - plot_format="png", -): - """ - Visualizing, if the target variables are displayed in fourier space. - - Parameters - ---------- - i : int - Current index given form the loop - img_input : array_like - Current input image as a numpy array in shape (2*img_size^2) - img_pred : array_like - Current prediction image as a numpy array with shape (2*img_size^2) - img_truth: array_like - Current true image as a numpy array with shape (2*img_size^2) - out_path: str - Which contains the output path - """ - # reshaping and splitting in real and imaginary part if necessary - real_pred, imag_pred = img_pred[0], img_pred[1] - real_truth, imag_truth = img_truth[0], img_truth[1] - - # plotting - # plt.style.use('./paper_large_3_2.rc') - fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots( - 2, 3, figsize=(16, 10), sharex=True, sharey=True - ) - - if amp_phase: - im1 = ax1.imshow(real_pred, cmap="inferno") - make_axes_nice(fig, ax1, im1, r"Amplitude Prediction") - - im2 = ax2.imshow(real_truth, cmap="inferno") - make_axes_nice(fig, ax2, im2, r"Amplitude Truth") - - a = check_vmin_vmax(real_pred - real_truth) - im3 = ax3.imshow(real_pred - real_truth, cmap="radionets.PuOr", vmin=-a, vmax=a) - make_axes_nice(fig, ax3, im3, r"Amplitude Difference") - - a = check_vmin_vmax(imag_truth) - im4 = ax4.imshow(imag_pred, cmap="radionets.PuOr", vmin=-np.pi, vmax=np.pi) - make_axes_nice(fig, ax4, im4, r"Phase Prediction", phase=True) - - a = check_vmin_vmax(imag_truth) - im5 = ax5.imshow(imag_truth, cmap="radionets.PuOr", vmin=-np.pi, vmax=np.pi) - make_axes_nice(fig, ax5, im5, r"Phase Truth", phase=True) - - a = check_vmin_vmax(imag_pred - imag_truth) - im6 = ax6.imshow( - imag_pred - imag_truth, - cmap="radionets.PuOr", - vmin=-2 * np.pi, - vmax=2 * np.pi, - ) - make_axes_nice(fig, ax6, im6, r"Phase Difference", phase_diff=True) - - else: - im1 = ax1.imshow(real_pred, cmap="inferno") - make_axes_nice(fig, ax1, im1, r"Real Prediction") - - im2 = ax2.imshow(real_truth, cmap="inferno") - make_axes_nice(fig, ax2, im2, "Real Truth") - - a = check_vmin_vmax(real_pred - real_truth) - im3 = ax3.imshow(real_pred - real_truth, cmap="radionets.PuOr", vmin=-a, vmax=a) - make_axes_nice(fig, ax3, im3, r"Real Difference") - - im4 = ax4.imshow(imag_pred, cmap="radionets.PuOr") - make_axes_nice(fig, ax4, im4, r"Imaginary Prediction") - - im5 = ax5.imshow(imag_truth, cmap="radionets.PuOr") - make_axes_nice(fig, ax5, im5, r"Imaginary Truth") - - im6 = ax6.imshow(imag_pred - imag_truth, cmap="radionets.PuOr") - make_axes_nice(fig, ax6, im6, r"Imaginary Difference") - - ax1.set_ylabel(r"Pixels") - ax4.set_ylabel(r"Pixels") - ax4.set_xlabel(r"Pixels") - ax5.set_xlabel(r"Pixels") - ax6.set_xlabel(r"Pixels") - - outpath = str(out_path) + f"/prediction_{i}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05) - plt.close("all") - - -def visualize_source_reconstruction( - ifft_pred, - ifft_truth, - out_path, - i, - dr=False, - msssim=False, - plot_format="png", -): - # plt.style.use("./paper_large_3.rc") - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 10), sharey=True) - - # Plot prediction - im1 = ax1.imshow(ifft_pred, vmax=ifft_truth.max(), cmap="inferno") - - # Plot truth - im2 = ax2.imshow(ifft_truth, cmap="inferno") - - a = check_vmin_vmax(ifft_pred - ifft_truth) - im3 = ax3.imshow(ifft_pred - ifft_truth, cmap="radionets.PuOr", vmin=-a, vmax=a) - - make_axes_nice(fig, ax1, im1, r"FFT Prediction") - make_axes_nice(fig, ax2, im2, r"FFT Truth") - make_axes_nice(fig, ax3, im3, r"FFT Diff") - - ax1.set_ylabel(r"Pixels") - ax1.set_xlabel(r"Pixels") - ax2.set_xlabel(r"Pixels") - ax3.set_xlabel(r"Pixels") - - if dr: - dr_truth, dr_pred, num_boxes, corners = calc_dr( - ifft_truth[None, ...], ifft_pred[None, ...] - ) - ax1.plot([], [], " ", label=f"DR: {int(dr_pred[0])}") - ax2.plot([], [], " ", label=f"DR: {int(dr_truth[0])}") - - plot_box(ax1, num_boxes, corners[0]) - plot_box(ax2, num_boxes, corners[0]) - - if msssim: - val = ms_ssim( - torch.tensor(ifft_pred).unsqueeze(0).unsqueeze(0), - torch.tensor(ifft_truth).unsqueeze(0).unsqueeze(0), - data_range=1, - win_size=7, - size_average=False, - ) - val = val.numpy()[0] - ax1.plot([], [], " ", label=f"MS-SSIM: {val:.2f}") - ax1.legend(loc="best") - - outpath = str(out_path) + f"/fft_pred_{i}.{plot_format}" - - plt.savefig(outpath, bbox_inches="tight", pad_inches=0.05) - plt.close("all") - return np.abs(ifft_pred), np.abs(ifft_truth) - - -def visualize_uncertainty( - i, img_pred, img_truth, img_unc, amp_phase, out_path, plot_format="png" -): - pred_amp, pred_phase = img_pred[0], img_pred[1] - true_amp, true_phase = img_truth[0], img_truth[1] - unc_amp, unc_phase = img_unc[0], img_unc[1] - - # amplitude - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots( - 2, 2, sharey=True, sharex=True, figsize=(12, 10) - ) - - im1 = ax1.imshow(true_amp) - - im2 = ax2.imshow(pred_amp) - - im3 = ax3.imshow(unc_amp) - - a = check_vmin_vmax(true_amp - pred_amp) - im4 = ax4.imshow(true_amp - pred_amp, cmap="radionets.PuOr", vmin=-a, vmax=a) - - make_axes_nice(fig, ax1, im1, r"Simulation") - make_axes_nice(fig, ax2, im2, r"Predicted $\mu$") - make_axes_nice(fig, ax3, im3, r"Predicted $\sigma^2$", unc=True) - make_axes_nice(fig, ax4, im4, r"Difference") - - ax1.set_ylabel(r"pixels") - ax3.set_ylabel(r"pixels") - ax3.set_xlabel(r"pixels") - ax4.set_xlabel(r"pixels") - - outpath = str(out_path) + f"/unc_amp{i}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05) - - # phase - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots( - 2, 2, sharey=True, sharex=True, figsize=(12, 10) - ) - - im1 = ax1.imshow(true_phase, cmap="radionets.PuOr") - - im2 = ax2.imshow(pred_phase, cmap="radionets.PuOr") - - im3 = ax3.imshow(unc_phase) - - a = check_vmin_vmax(true_phase - pred_phase) - im4 = ax4.imshow(true_phase - pred_phase, cmap="radionets.PuOr", vmin=-a, vmax=a) - - make_axes_nice(fig, ax1, im1, r"Simulation") - make_axes_nice(fig, ax2, im2, r"Predicted $\mu$") - make_axes_nice(fig, ax3, im3, r"Predicted $\sigma^2$", unc=True) - make_axes_nice(fig, ax4, im4, r"Difference") - - ax1.set_ylabel(r"pixels") - ax3.set_ylabel(r"pixels") - ax3.set_xlabel(r"pixels") - ax4.set_xlabel(r"pixels") - - outpath = str(out_path) + f"/unc_phase{i}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05) - plt.close("all") - - -def visualize_sampled_unc(i, mean, std, ifft_truth, out_path, plot_format): - # plt.style.use('../paper_large_3.rc') - fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots( - 2, 2, figsize=(12, 10), sharey=True, sharex=True - ) - - im1 = ax1.imshow(ifft_truth) - im2 = ax2.imshow(mean) - im3 = ax3.imshow(std) - a = check_vmin_vmax(mean - ifft_truth) - im4 = ax4.imshow(mean - ifft_truth, cmap="radionets.PuOr", vmin=-a, vmax=a) - - ax1.text( - 90, - 110, - "Simulation", - ha="center", - size=9, - bbox=dict( - boxstyle="round", - fc="w", - ec="gray", - alpha=0.75, - ), - ) - ax2.text( - 90, - 110, - "Prediction", - ha="center", - size=9, - bbox=dict( - boxstyle="round", - fc="w", - ec="gray", - alpha=0.75, - ), - ) - ax3.text( - 90, - 110, - "Uncertainty", - ha="center", - size=9, - bbox=dict( - boxstyle="round", - fc="w", - ec="gray", - alpha=0.75, - ), - ) - ax4.text( - 90, - 110, - "Difference", - ha="center", - size=9, - bbox=dict( - boxstyle="round", - fc="w", - ec="gray", - alpha=0.75, - ), - ) - - make_axes_nice(fig, ax1, im1, r"Simulation") - make_axes_nice(fig, ax2, im2, r"Prediction") - make_axes_nice(fig, ax3, im3, r"Uncertainty", unc=True) - make_axes_nice(fig, ax4, im4, r"Difference") - - ax1.set_ylabel(r"pixels") - ax3.set_xlabel(r"pixels") - ax3.set_ylabel(r"pixels") - ax4.set_xlabel(r"pixels") - outpath = str(out_path) + f"/unc_samp{i}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.05) - plt.close("all") - - -def plot_contour( - ifft_pred, - ifft_truth, - out_path, - i, - plot_format="png", - norm_scale: float = 0.4, - labels: list | None = None, - colors: list | None = None, - levels: list | None = None, -): - if not labels: - labels = ["5%", "10%", "30%", "50%", "80%"] - if not colors: - colors = ["#454CC7", "#1984DE", "#50B3D7", "#ABD9DC", "#FFFFFF"] - - if not levels: - levels = [ - ifft_truth.max() * 0.05, - ifft_truth.max() * 0.1, - ifft_truth.max() * 0.3, - ifft_truth.max() * 0.5, - ifft_truth.max() * 0.8, - ] - - fig, ax = plt.subplots(1, 2, figsize=(10, 8), sharey=True) - - im1 = ax[0].imshow( - ifft_pred, - cmap="inferno", - norm=PowerNorm(norm_scale, vmin=ifft_truth.min(), vmax=ifft_truth.max()), - ) - CS1 = ax[0].contour(ifft_pred, levels=levels, colors=colors) - make_axes_nice(fig, ax[0], im1, "Prediction") - - im2 = ax[1].imshow( - ifft_truth, - cmap="inferno", - norm=PowerNorm(norm_scale, vmin=ifft_truth.min(), vmax=ifft_truth.max()), - ) - CS2 = ax[1].contour(ifft_truth, levels=levels, colors=colors) - diff = np.round(compute_area_ratio(CS1, CS2), 2) - make_axes_nice(fig, ax[1], im2, f"Truth, ratio: {diff}") - outpath = str(out_path) + f"/contour_{diff}_{i}.{plot_format}" - - cl1, _ = CS1.legend_elements() - cl2, _ = CS2.legend_elements() - - # plotting legend - ax[0].legend(cl1, labels, loc="best") - ax[1].legend(cl2, labels, loc="best") - - ax[0].set_ylabel(r"Pixels") - ax[0].set_xlabel(r"Pixels") - ax[1].set_xlabel(r"Pixels") - - plt.savefig(outpath, bbox_inches="tight", pad_inches=0.05) - plt.close("all") - - -def plot_box(ax, num_boxes, corners): - size = get_boxsize(num_boxes) - img_size = 64 - if corners[2]: - ax.axvspan( - xmin=0, - xmax=size, - ymin=(img_size - size) / img_size, - ymax=0.99, - color="red", - fill=False, - ) - if corners[3]: - ax.axvspan( - xmin=img_size - size, - xmax=img_size - 1, - ymin=(img_size - size) / img_size, - ymax=0.99, - color="red", - fill=False, - ) - if corners[0]: - ax.axvspan( - xmin=0, - xmax=size, - ymin=0.01, - ymax=(size) / img_size, - color="red", - fill=False, - ) - if corners[1]: - ax.axvspan( - xmin=img_size - size, - xmax=img_size - 1, - ymin=0.01, - ymax=(size) / img_size, - color="red", - fill=False, - ) - - -def plot_length_point(length, vals, mask, out_path, plot_format="png"): - fig, (ax1) = plt.subplots(1, figsize=(6, 4)) - ax1.plot( - length[mask], - vals[mask], - ".", - markersize=1, - color="darkorange", - label="Point sources", - ) - ax1.plot( - length[~mask], - vals[~mask], - ".", - markersize=1, - color="#1f77b4", - label="Extended sources", - ) - ax1.set_ylabel("Mean specific intensity deviation") - ax1.set_xlabel("Linear extent / pixels") - plt.grid() - plt.legend(loc="best", markerscale=10) - - outpath = str(out_path) + "/extend_point.png" - plt.savefig(outpath, bbox_inches="tight", pad_inches=0.01, dpi=150) - - -def plot_jet_results(inp, pred, truth, path, save=False, plot_format="pdf"): - """ - Plot input images, prediction, true and diff image of the overall prediction. - (Not component wise) - - Parameters - ---------- - inp : n 4d arrays with 1 channel - input images - pred : n 4d arrays with multiple channels - predicted images - truth : n 4d arrays with multiple channels - true images - """ - if truth.shape[1] > 2: - truth = torch.sum(truth[:, 0:-1], axis=1) - pred = torch.sum(pred[:, 0:-1], axis=1) - elif truth.shape[1] == 2: - truth = truth[:, 0:-1].squeeze() - pred = pred[:, 0:-1].squeeze() - else: - truth = truth.squeeze() - pred = pred.squeeze() - - for i in tqdm(range(len(inp))): - fig, ax = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(4, 7)) - - im1 = ax[0].imshow(inp[i, 0], cmap=plt.cm.inferno) - ax[0].set_xlabel(r"Pixels") - ax[0].set_ylabel(r"Pixels") - divider = make_axes_locatable(ax[0]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im1, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - diff = pred[i] - truth[i] - im2 = ax[1].imshow(diff, cmap=plt.cm.inferno) - ax[1].set_xlabel(r"Pixels") - ax[1].set_ylabel(r"Pixels") - divider = make_axes_locatable(ax[1]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im2, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - if save: - Path(path).mkdir(parents=True, exist_ok=True) - outpath = str(path) + f"/prediction_{i}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01) - plt.close() - - -def plot_jet_components_results(inp, pred, truth, path, save=False, plot_format="pdf"): - """ - Plot input images, prediction and true image. - - Parameters - ---------- - inp : n 4d arrays with 1 channel - input images - pred : n 4d arrays with multiple channels - predicted images - truth : n 4d arrays with multiple channels - true images - """ - X, Y = np.meshgrid(np.arange(inp.shape[-1]), np.arange(inp.shape[-1])) - for i in tqdm(range(len(inp))): - c = truth.shape[1] - 1 # -1 because last one is the background - for j in range(c): - truth_max = torch.max(truth[i, j]) - fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(8, 7)) - if truth_max != 0: - pred_max = torch.max(pred[i, j]) - axs[0, 0].contour( - X, Y, truth[i, j], levels=[truth_max * 0.32], colors="white" - ) - axs[0, 1].contour( - X, Y, truth[i, j], levels=[truth_max * 0.32], colors="white" - ) - axs[1, 0].contour( - X, Y, truth[i, j], levels=[truth_max * 0.32], colors="white" - ) - axs[1, 0].contour( - X, - Y, - pred[i, j], - levels=[pred_max * 0.32], - colors="cyan", - linestyles="dashed", - ) - - im1 = axs[0, 0].imshow(inp[i, 0], cmap=plt.cm.inferno) - axs[0, 0].set_xlabel(r"Pixels") - axs[0, 0].set_ylabel(r"Pixels") - divider = make_axes_locatable(axs[0, 0]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im1, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - im2 = axs[0, 1].imshow(truth[i, j], cmap=plt.cm.inferno) - axs[0, 1].set_xlabel(r"Pixels") - axs[0, 1].set_ylabel(r"Pixels") - divider = make_axes_locatable(axs[0, 1]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im2, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - im1 = axs[1, 0].imshow(pred[i, j], cmap=plt.cm.inferno) - axs[1, 0].set_xlabel(r"Pixels") - axs[1, 0].set_ylabel(r"Pixels") - divider = make_axes_locatable(axs[1, 0]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im1, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - im4 = axs[1, 1].imshow(pred[i, j] - truth[i, j], cmap=plt.cm.inferno) - divider = make_axes_locatable(axs[1, 1]) - axs[1, 1].set_xlabel(r"Pixels") - axs[1, 1].set_ylabel(r"Pixels") - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im4, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - if save: - Path(path).mkdir(parents=True, exist_ok=True) - outpath = str(path) + f"/prediction_{i}_comp_{j}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01) - plt.close() - - -def plot_fitgaussian( - data, fit_list, params_list, iteration, path, save=False, plot_format="pdf" -): - """ - Plotting the sky image with the fitted gaussian distributian and the related - parameters. - - Parameters - ---------- - data : 2d array - skymap, usually the prediction of the NN - fit : 2d array - gaussian fit around the maxima - params : list - parameters related to the gaussian: height, x, y, width_x, width_y, theta - """ - fig, axs = plt.subplots( - 1, - len(params_list), - sharex=True, - sharey=True, - figsize=(4 * len(params_list), 3.5), - ) - for i, (fit, params) in enumerate(zip(fit_list, params_list)): - im = axs[i].imshow(data, cmap=plt.cm.inferno) - axs[i].set_xlabel(r"Pixels") - axs[i].set_ylabel(r"Pixels") - divider = make_axes_locatable(axs[i]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(im, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - axs[i].contour(fit, cmap=plt.cm.gray_r) - data -= fit - (height, x, y, width_x, width_y, theta) = params.parameters - plt.text( - 0.95, - 0.02, - f""" - height : {height:.2f} - x : {x:.1f} - y : {y:.1f} - width_x : {width_x:.1f} - width_y : {width_y:.1f} - theta : {theta:.2f}""", - fontsize=8, - horizontalalignment="right", - c="w", - verticalalignment="bottom", - transform=axs[i].transAxes, - ) - - if save: - Path(path).mkdir(parents=True, exist_ok=True) - outpath = str(path) + f"/eval_iterativ_gaussian_{iteration}.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01) - plt.close() - - -def plot_data(x, path, rows=1, cols=1, save=False, plot_format="pdf"): - """ - Plotting image of the dataset - - Parameters - ---------- - x : array - array of shape (n, 1, size, size), n must be at least rows * cols - rows : int - number of rows in the plot - cols : int - number of cols in the plot - """ - fig, ax = plt.subplots( - rows, cols, sharex=True, sharey=True, figsize=(4 * cols, 3.5 * rows) - ) - for i in range(rows): - for j in range(cols): - img = ax[i, j].imshow(x[i * cols + j, 0], cmap=plt.cm.inferno) - ax[i, j].set_xlabel(r"Pixels") - ax[i, j].set_ylabel(r"Pixels") - divider = make_axes_locatable(ax[i, j]) - cax = divider.append_axes("right", size="5%", pad=0.05) - cbar = fig.colorbar(img, cax=cax, orientation="vertical") - cbar.set_label(r"Specific Intensity / a.u.") - - if save: - Path(path).mkdir(parents=True, exist_ok=True) - outpath = str(path) + f"/simulation_examples.{plot_format}" - fig.savefig(outpath, bbox_inches="tight", pad_inches=0.01) - plt.close() diff --git a/src/radionets/simulations/README.md b/src/radionets/simulations/README.md deleted file mode 100644 index 39cb3fb4..00000000 --- a/src/radionets/simulations/README.md +++ /dev/null @@ -1,49 +0,0 @@ -## Tools to simulate radio interferometric observations and Gaussian sources - - -### gaussian_simulations - -Functions to simulate radio galaxies consisting of Gaussian components, by adding 2d Gaussians to a two dimensional grid. Simulations are implemented for all pixel sizes of squared images. - -Varied parameters: -* Number of components -* Jet rotation -* Flux: peak, logarithmic decrease for jet components -* Extension of components -* One- and two-sided jets - -ToDo: -* Different power law indices for jet components flux decrease -* Lorentz factor for counter jet -* Flexible (more random) distance between components -* More varied extension of jet components (simulate FRI and FRII) -* Scale sources to image size - -### uv_simulations - -A antenna and a source class are used to simulate radio interferometric observations. Both classes hold information about the coordinates of the antennas/sources. It is possible to create masks to simulate (u, v)-sampling and apply these mask to -simulated (u, v)-spaces. In this way, toy monte carlo datasets are created. - - -### uv_plots - -Functions to visualize simulated sources and (u, v)-coverages. It is possible to create gifs of array layouts and -(u, v)-space sampling during an simulated observation. These functions are mainly used to create images for -presentations. - -### Examples: visualize_sampling -Creates visualization plots, which can be found in the examples directory. Run `make examples`. - -Explanations: -* ***gaussian_source.pdf***: A source consisting of Gaussian components. Can be one or two sided. - The flux is decresing logarithmically towards the outer components. -* ***fft_gaussian_source.pdf***: The Fourier transformation of the Gaussian source. Low frequencies - are located at the center, higher frequencies at the outer parts. -* ***uv_coverage.pdf***: Simulated (uv)-coverage for a radio interferometric observation. The used antenna - positions correspond to the layout of the [VLBA](https://science.nrao.edu/facilities/vlba/introduction-to-the-VLBA). -* ***baselines.gif/uv_coverage.gif***: These gifs visualize the sampling during a radio interferometric observation. -* ***mask.pdf***: [2d histogram](https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram2d.html) - of the (uv)-coverage. Masks are used to sample the Fourier space of Gaussian sources. By - doing so, incomplete measurements of the sources are simulated. -* ***sampled_frequs.pdf***: Visualizes the sampled frequencies of the source's Fourier space. -* ***recons_source.pdf***: The inverse Fourier transformation of the incomplete sample. diff --git a/src/radionets/simulations/__init__.py b/src/radionets/simulations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/radionets/simulations/gaussians.py b/src/radionets/simulations/gaussians.py deleted file mode 100644 index e3a83fcb..00000000 --- a/src/radionets/simulations/gaussians.py +++ /dev/null @@ -1,411 +0,0 @@ -import numpy as np -from scipy import ndimage -from scipy.ndimage import gaussian_filter -from tqdm import tqdm - -from radionets.core.data import save_fft_pair -from radionets.simulations.utils import add_noise, add_white_noise, adjust_outpath - - -def simulate_gaussian_sources( - data_path, - option, - num_bundles, - bundle_size, - img_size, - num_comp_ext, - noise, - noise_level, - white_noise, - mean_real, - std_real, - mean_imag, - std_imag, - source_list, -): - for _ in tqdm(range(num_bundles)): - grid = create_grid(img_size, bundle_size) - list_sources = None - - if num_comp_ext is not None: - bundle = create_ext_gauss_bundle(grid) - - if source_list: - print("Not implemented warning!") - - images = bundle.copy() - - if noise: - images = add_noise(images, noise_level) - - bundle_fft = np.array( - [np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(img))) for img in images] - ) - if white_noise: - bundle_fft = add_white_noise( - bundle_fft, mean_real, std_real, mean_imag, std_imag - ) - path = adjust_outpath(data_path, "/fft_" + option) - save_fft_pair(path, bundle_fft, bundle, list_sources) - - -def create_grid(pixel, bundle_size): - """ - Creates a square 2d grid. - - Parameters - ---------- - pixel : int - number of pixel in x and y - - Returns - ------- - grid : ndarray - 2d grid with 1e-10 pixels, X meshgrid, Y meshgrid - """ - x = np.linspace(0, pixel - 1, num=pixel) - y = np.linspace(0, pixel - 1, num=pixel) - X, Y = np.meshgrid(x, y) - grid = np.array([np.zeros(X.shape) + 1e-10, X, Y]) - grid = np.repeat( - grid[None, :, :, :], - bundle_size, - axis=0, - ) - return grid - - -# draw random parameters for extended gaussian sources - - -def gauss_paramters(): - """ - Generate a random set of Gaussian parameters. - - Returns - ------- - comps : int - Number of components - amp : float - Amplitude of the core component - x : array - x positions of components - y : array - y positions of components - sig_x : - standard deviation in x - sig_y : - standard deviation in y - rot : int - rotation in degree - sides : int - 0 for one-sided and 1 for two-sided jets - """ - # random number of components between 4 and 9 - comps = np.random.randint(4, 7) # decrease for smaller images - - # start amplitude between 10 and 1e-3 - amp_start = (np.random.randint(0, 100) * np.random.random()) / 10 - # if start amp is 0, draw a new number - while amp_start == 0: - amp_start = (np.random.randint(0, 100) * np.random.random()) / 10 - # logarithmic decrease to outer components - amp = np.array([amp_start / np.exp(i) for i in range(comps)]) - - # linear distance bestween the components - x = np.arange(0, comps) * 5 - y = np.zeros(comps) - - # extension of components - # random start value between 1 - 0.375 and 1 - 0 - # linear distance between components - # distances scaled by factor between 0.25 and 0.5 - # randomnized for each sigma - off1 = (np.random.random() + 0.5) / 4 - off2 = (np.random.random() + 0.5) / 4 - fac1 = (np.random.random() + 1) / 4 - fac2 = (np.random.random() + 1) / 4 - sig_x = (np.arange(1, comps + 1) - off1) * fac1 - sig_y = (np.arange(1, comps + 1) - off2) * fac2 - - # jet rotation - rot = np.random.randint(0, 360) - # jet one- or two-sided - sides = np.random.randint(0, 2) - - return comps, amp, x, y, sig_x, sig_y, rot, sides - - -def create_rot_mat(alpha): - """ - Create 2d rotation matrix for given alpha - - Parameters - ---------- - alpha : float - rotation angle in rad - - Returns - ------- - rot_mat : 2darray - 2d rotation matrix - """ - rot_mat = np.array( - [[np.cos(alpha), -np.sin(alpha)], [np.sin(alpha), np.cos(alpha)]] - ) - return rot_mat - - -def gaussian_component(x, y, flux, x_fwhm, y_fwhm, rot, center=None): - """ - Adds a gaussian component to a 2d grid. - - Parameters - ---------- - x : 2darray - x coordinates of 2d meshgrid - y : 2darray - y coordinates of 2d meshgrid - flux : float - peak amplitude of component - x_fwhm : float - full-width-half-maximum in x direction (sigma_x) - y_fwhm : float - full-width-half-maximum in y direction (sigma_y) - rot : int - rotation of component in degree - center : 2darray - enter of component - - Returns - ------- - gauss : 2darray - 2d grid with gaussian component - """ - if center is None: - x_0 = y_0 = len(x) // 2 - else: - rot_mat = create_rot_mat(np.deg2rad(rot)) - x_0, y_0 = ((center - len(x) // 2) @ rot_mat) + len(x) // 2 - gauss = flux * np.exp( - -((x_0 - x) ** 2 / (2 * (x_fwhm) ** 2) + (y_0 - y) ** 2 / (2 * (y_fwhm) ** 2)) - ) - return gauss - - -def add_gaussian(grid, amp, x, y, sig_x, sig_y, rot): - """ - Takes a grid and adds n Gaussian component relative to the center. - - Parameters - ---------- - grid : 2darray - 2d grid - amp : float - amplitude of gaussian component - x : float - x position, will be calculated rel. to center - y : float - y position, will be calculated rel. to center - sig_x : float - standard deviation in x - sig_y : float - standard deviation in y - rot : int - rotation in degree - - Returns - ------- - gaussian : 2darray - grid with gaussian component - """ - cent = np.array([len(grid[0]) // 2 + x, len(grid[0]) // 2 + y]) - X = grid[1] - Y = grid[2] - gaussian = grid[0] - gaussian += gaussian_component( - X, - Y, - amp, - sig_x, - sig_y, - rot, - center=cent, - ) - - return gaussian - - -def create_gaussian_source( - grid, comps, amp, x, y, sig_x, sig_y, rot, sides=0, blur=True -): - """ - Combines Gaussian components on a 2d grid to create a Gaussian source - - takes grid - side: one-sided or two-sided - core dominated or lobe dominated - number of components - angle of the jet - - Parameters - ---------- - grid : ndarray - 2dgrid + X and Y meshgrid - comps : int - number of components - amp : 1darray - amplitudes of components - x : 1darray - x positions of components - y : 1darray - y positions of components - sig_x : 1darray - standard deviations of components in x - sig_y : 1darray - standard deviations of components in y - rot : int - rotation of the jet in degree - sides : int - 0 one-sided, 1 two-sided jet - blur : bool - use Gaussian filter to blur image - - Returns - ------- - source : 2darray - 2d grid containing Gaussian source - - Comments - -------- - components should not have too big gaps between each other - """ - if sides == 1: - comps += comps - 1 - amp = np.append(amp, amp[1:]) - x = np.append(x, -x[1:]) - y = np.append(y, -y[1:]) - sig_x = np.append(sig_x, sig_x[1:]) - sig_y = np.append(sig_y, sig_y[1:]) - - for i in range(comps): - source = add_gaussian( - grid=grid, - amp=amp[i], - x=x[i], - y=y[i], - sig_x=sig_x[i], - sig_y=sig_y[i], - rot=rot, - ) - if blur is True: - source = gaussian_filter(source, sigma=1.5) - return source - - -def gaussian_source(grid): - """ - Creates random Gaussian source parameters and returns an image - of a Gaussian source. - - Parameters - ---------- - grid : nd array - array holding 2d grid and axis for one image - - Returns - ------- - s : 2darray - Image containing a simulated Gaussian source. - """ - # grid = create_grid(img_size) - comps, amp, x, y, sig_x, sig_y, rot, sides = gauss_paramters() - s = create_gaussian_source( - grid, comps, amp, x, y, sig_x, sig_y, rot, sides, blur=True - ) - return s - - -def create_ext_gauss_bundle(grid): - """ - Creates a bundle of Gaussian sources. - - Parameters - ---------- - grid : nd array - array holding 2d grid and axis for whole bundle - - Returns - ------- - bundle ndarray - bundle of Gaussian sources - """ - bundle = np.array([gaussian_source(g) for g in grid]) - return bundle - - -# pointlike gaussians - - -def create_gauss(img, N, sources, spherical, source_list): - # img = [img] - mx = np.random.randint(1, 63, size=(N, sources)) - my = np.random.randint(1, 63, size=(N, sources)) - amp = ( - np.random.randint(0.001, 100, size=(N)) * 1 / 10 * np.random.randint(5, 10) - ) / 1e2 - - if spherical: - sx = np.random.randint(3, 8, size=(N, sources)) - sy = sx - else: - sx = np.random.randint(1, 15, size=(N, sources)) - sy = np.random.randint(1, 15, size=(N, sources)) - theta = np.random.randint(0, 360, size=(N, sources)) - - s = np.zeros((N, sources, 1)) # changed from 5 - for i in range(N): - for j in range(sources): - g = gauss(mx[i, j], my[i, j], sx[i, j], sy[i, j], amp[i]) - # s[i,j] = np.array([mx[i,j],my[i,j],sx[i,j],sy[i,j],amp[i]]) - s[i, j] = np.array([mx[i, j]]) - if spherical: - img[i] += g - else: - # rotation around center of the source - padX = [g.shape[0] - mx[i, j], mx[i, j]] - padY = [g.shape[1] - my[i, j], my[i, j]] - imgP = np.pad(g, [padY, padX], "constant") - imgR = ndimage.rotate(imgP, theta[i, j], reshape=False) - imgC = imgR[padY[0] : -padY[1], padX[0] : -padX[1]] - img[i] += imgC - if source_list: - return img, s - else: - return img - - -# pointsources - - -def gauss_pointsources(img, N, sources, source_list): - mx = np.random.randint(0, 63, size=(N, sources)) - my = np.random.randint(0, 63, size=(N, sources)) - amp = np.random.randint(1, 10, size=(N)) - sigma = 0.005 - s = np.zeros((N, sources, 3)) # changed from 5 - for i in range(N): - for j in range(sources): - g = gauss(mx[i, j], my[i, j], sigma, sigma, amp[i]) - s[i, j] = np.array([mx[i, j], my[i, j], amp[i]]) - img[i] += g - print(s.shape) - if source_list: - return img, s - return np.array(img) - - -def gauss(mx, my, sx, sy, amp=0.01): - x = np.arange(63)[None].astype(np.float) - y = x.T - return amp * np.exp(-((y - my) ** 2) / sy).dot(np.exp(-((x - mx) ** 2) / sx)) diff --git a/src/radionets/simulations/layouts/__init__.py b/src/radionets/simulations/layouts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/radionets/simulations/layouts/layouts.py b/src/radionets/simulations/layouts/layouts.py deleted file mode 100644 index 0bbe8ed5..00000000 --- a/src/radionets/simulations/layouts/layouts.py +++ /dev/null @@ -1,11 +0,0 @@ -from pathlib import Path - -import numpy as np - -file_dir = Path(__file__).parent.resolve() - - -def vlba(): - x, y, z, _, _ = np.genfromtxt(file_dir / "vlba.txt", unpack=True) - ant_pos = np.array([x, y, z]) - return ant_pos diff --git a/src/radionets/simulations/layouts/vlba.txt b/src/radionets/simulations/layouts/vlba.txt deleted file mode 100644 index ed9a9986..00000000 --- a/src/radionets/simulations/layouts/vlba.txt +++ /dev/null @@ -1,11 +0,0 @@ -# x y z dish_dia station --2112065.261576 -3705356.502894 4726813.649034 25 BR --1324009.374466 -5332181.952906 3231962.377936 25 FD -1446374.806874 -4447939.68308 4322306.19968 25 HN --1995678.891541 -5037317.693848 3357328.002045 25 KP --1449752.637707 -4975298.573645 3709123.828301 25 LA --5464075.238656 -2495247.871441 2148297.485741 25 MK --130872.556729 -4762317.087045 4226850.993404 25 NL --2409150.471188 -4478573.093114 3838617.326057 25 OV --1640953.992891 -5014816.024485 3575411.772132 25 PT -2607848.664542 -5488069.500452 1932739.778597 25 SC diff --git a/src/radionets/simulations/point_sources.py b/src/radionets/simulations/point_sources.py deleted file mode 100644 index 64eccdbb..00000000 --- a/src/radionets/simulations/point_sources.py +++ /dev/null @@ -1,218 +0,0 @@ -import h5py -import numpy as np -from tqdm import tqdm - -from radionets.simulations.gaussians import create_grid, create_rot_mat - - -def gaussian_component(x, y, flux, x_fwhm, y_fwhm, rot, center=None): - """ - Adds a gaussian component to a 2d grid. - - Parameters - ---------- - x : 2darray - x coordinates of 2d meshgrid - y : 2darray - y coordinates of 2d meshgrid - flux : float - peak amplitude of component - x_fwhm : float - full-width-half-maximum in x direction (sigma_x) - y_fwhm : float - full-width-half-maximum in y direction (sigma_y) - rot : int - rotation of component in degree - center : 2darray - enter of component - - Returns - ------- - gauss : 2darray - 2d grid with gaussian component - """ - if center is None: - x_0 = y_0 = len(x) // 2 - else: - rot_mat = create_rot_mat(np.deg2rad(rot)) - x_0, y_0 = ((center - len(x) // 2) @ rot_mat) + len(x) // 2 - gauss = flux * np.exp( - -((x_0 - x) ** 2 / (2 * (x_fwhm) ** 2) + (y_0 - y) ** 2 / (2 * (y_fwhm) ** 2)) - ) - params = np.array([x_0, y_0, x_fwhm, y_fwhm]) - return gauss, params - - -def gauss_parameters(): - # random number of components between 4 and 9 - comps = np.random.randint(4, 7) # decrease for smaller images - - rng = np.random.default_rng() - # start amplitude between 1 and 10 - amp_start = rng.uniform(5, 10) - # logarithmic decrease to outer components - amp = np.array([amp_start / (np.exp(i * 0.6)) for i in range(comps)]) - - # linear distance bestween the components - x = np.arange(0, comps) * 5 - y = np.zeros(comps) - - sig_start = rng.uniform(1, 1.2) - fac = rng.uniform(0.25, 0.5) - sig = (np.arange(0, comps) * fac) + sig_start - - # jet rotation - rot = np.random.randint(0, 360) - # jet one- or two-sided - sides = np.random.randint(0, 2) - - return comps, amp, x, y, sig, sig, rot, sides - - -def add_gaussian(grid, amp, x, y, sig_x, sig_y, rot): - cent = np.array([len(grid[0]) // 2 + x, len(grid[0]) // 2 + y]) - X = grid[1] - Y = grid[2] - gaussian = grid[0] - comp, params = gaussian_component( - X, - Y, - amp, - sig_x, - sig_y, - rot, - center=cent, - ) - gaussian += comp - - return gaussian, params - - -def create_gaussian_source(grid, comps, amp, x, y, sig_x, sig_y, rot, sides=0): - if sides == 1: - comps += comps - 1 - amp = np.append(amp, amp[1:]) - x = np.append(x, -x[1:]) - y = np.append(y, -y[1:]) - sig_x = np.append(sig_x, sig_x[1:]) - sig_y = np.append(sig_y, sig_y[1:]) - - params = np.array([]) - for i in range(comps): - source, param = add_gaussian( - grid=grid, - amp=amp[i], - x=x[i], - y=y[i], - sig_x=sig_x[i], - sig_y=sig_y[i], - rot=rot, - ) - params = np.append(params, param) - return source, params - - -def gaussian_source(grid): - comps, amp, x, y, sig_x, sig_y, rot, sides = gauss_parameters() - s, params = create_gaussian_source(grid, comps, amp, x, y, sig_x, sig_y, rot, sides) - return s, params - - -def gauss(img_size, mx, my, sx, sy, amp=0.01): - x = np.arange(img_size)[None].astype(np.float) - y = x.T - return amp * np.exp(-((y - my) ** 2) / sy).dot(np.exp(-((x - mx) ** 2) / sx)) - - -def create_gauss(img, num_sources, source_list, img_size=63): - mx = np.random.randint(1, img_size, size=(num_sources)) - my = np.random.randint(1, img_size, size=(num_sources)) - rng = np.random.default_rng() - amp = rng.uniform(1, 10, num_sources) - sx = np.random.randint( - round(1 / 8 * (img_size**2) / 720), - 1 / 2 * (img_size**2) / 360, - size=(num_sources), - ) - sy = sx - idx = [] - for n in range(num_sources): - if img[mx[n], my[n]] <= 5e-10: - g = gauss(img_size, mx[n], my[n], sx[n], sy[n], amp[n]) - img += g - else: - idx.append(n) - mx = np.delete(mx, idx) - my = np.delete(my, idx) - sx = np.delete(sx, idx) - sy = np.delete(sy, idx) - # assert np.isnan(img).any() == False - if source_list: - return img, [mx, my], [sx, sy] - else: - return img - - -def create_point_source_img( - img_size, bundle_size, num_bundles, path, option, extended=False -): - for num_bundle in tqdm(range(num_bundles)): - with h5py.File(path + "/fft_" + option + str(num_bundle) + ".h5", "w") as hf: - for num_img in range(bundle_size): - grid = create_grid(img_size, 1) - num_point_sources = np.random.randint(2, 5) - - if extended: - gs, params_extended = gaussian_source(grid[0]) - x_off = np.random.randint(1, 20) - y_off = np.random.randint(1, 20) - params_extended[0::4] += x_off - params_extended[1::4] += y_off - gs = np.pad(gs, ((y_off, 0), (x_off, 0)), constant_values=(1e-10))[ - :-y_off, :-x_off - ] - tag_ext = np.ones(len(params_extended) // 4) - - g, p_point, s_point = create_gauss( - gs, num_point_sources, True, img_size - ) - - tag_point = np.zeros(len(p_point[0])) - - comps = np.array( - [ - np.concatenate([p_point[0], params_extended[0::4]]), - np.concatenate([p_point[1], params_extended[1::4]]), - np.concatenate([s_point[0], params_extended[2::4]]), - np.concatenate([s_point[1], params_extended[3::4]]), - np.concatenate([tag_point, tag_ext]), - ] - ) - - # crop image size - mask = ( - (comps[0] >= 0) - & (comps[0] <= img_size - 1) - & (comps[1] >= 0) - & (comps[1] <= img_size - 1) - ) - list_x = comps[0][mask] - list_y = comps[1][mask] - list_sx = comps[2][mask] - list_sy = comps[3][mask] - list_tag = comps[4][mask] - assert ( - list_x.shape - == list_y.shape - == list_sx.shape - == list_sy.shape - == list_tag.shape - ) - - source_list = np.array([list_x, list_y, list_sx, list_sy, list_tag]) - g_fft = np.array(np.fft.fftshift(np.fft.fft2(g.copy()))) - hf.create_dataset("x" + str(num_img), data=g_fft) - hf.create_dataset("y" + str(num_img), data=g) - hf.create_dataset("z" + str(num_img), data=source_list) - - hf.close() diff --git a/src/radionets/simulations/sampling.py b/src/radionets/simulations/sampling.py deleted file mode 100644 index f98ec92c..00000000 --- a/src/radionets/simulations/sampling.py +++ /dev/null @@ -1,96 +0,0 @@ -import os - -import numpy as np -from numpy import savez_compressed -from tqdm import tqdm - -from radionets.core.data import ( - open_bundle_pack, - open_fft_bundle, - save_fft_pair, -) -from radionets.simulations.utils import ( - get_fft_bundle_paths, - interpol, - prepare_fft_images, -) -from radionets.simulations.uv_simulations import sample_freqs - - -def sample_frequencies( - data_path, - amp_phase, - real_imag, - fourier, - compressed, - interpolation, - specific_mask, - antenna_config, - lon=None, - lat=None, - steps=None, - multi_channel=False, - bandwidths=4, - source_type="point_sources", -): - for mode in ["train", "valid", "test"]: - print(f"\n Sampling {mode} data set.\n") - - bundle_paths = get_fft_bundle_paths(data_path, "fft", mode) - - if bundle_paths.size == 0: - print(f"\n No {mode} data set fft images available.\n") - - for path in tqdm(bundle_paths): - if source_type != "point_sources": - fft, truth = open_fft_bundle(path) - source_list = None - else: - fft, truth, source_list = open_bundle_pack(path) - - size = fft.shape[-1] - - fft_scaled = prepare_fft_images(fft.copy(), amp_phase, real_imag) - truth_fft = np.array( - [np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(img))) for img in truth] - ) - fft_scaled_truth = prepare_fft_images(truth_fft, amp_phase, real_imag) - - if specific_mask is True: - fft_samp = sample_freqs( - fft_scaled.copy(), - antenna_config, - size, - lon, - lat, - steps, - plot=False, - test=False, - multi_channel=multi_channel, - bandwidths=bandwidths, - ) - else: - fft_samp = sample_freqs( - fft_scaled.copy(), - antenna_config, - num_steps=steps, - size=size, - specific_mask=False, - multi_channel=multi_channel, - bandwidths=bandwidths, - ) - - if interpolation: - for i in range(len(fft_samp[:, 0, 0, 0])): - fft_samp[i] = interpol(fft_samp[i]) - - out = data_path + "/samp_" + path.name.split("_")[-1] - - if fourier: - if compressed: - savez_compressed(out, x=fft_samp, y=fft_scaled) - os.remove(path) - else: - save_fft_pair(out, fft_samp, fft_scaled_truth, source_list) - else: - save_fft_pair(out, fft_samp, truth) diff --git a/src/radionets/simulations/scripts/__init__.py b/src/radionets/simulations/scripts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/radionets/simulations/scripts/simulate_images.py b/src/radionets/simulations/scripts/simulate_images.py deleted file mode 100644 index ffd629bf..00000000 --- a/src/radionets/simulations/scripts/simulate_images.py +++ /dev/null @@ -1,47 +0,0 @@ -import click -import toml - -from radionets.simulations.simulate import create_fft_images, sample_fft_images -from radionets.simulations.utils import check_outpath, read_config - - -@click.command() -@click.argument("configuration_path", type=click.Path(exists=True, dir_okay=False)) -def main(configuration_path): - """ - Generate monte carlo simulation data sets to train and test neural networks for - reconstruction of radio interferometric data. - - Parameter - --------- - configuration_path : str - Path to the config toml file - """ - config = toml.load(configuration_path) - - # check out path and look for existing files - data_path = config["paths"]["data_path"] - sim_fft, sim_sampled = check_outpath( - data_path, - data_format=config["paths"]["data_format"], - quiet=config["mode"]["quiet"], - ) - - # declare source options - sim_conf = read_config(config) - - click.echo("\n Simulation config:") - print(sim_conf, "\n") - - # start simulations - if sim_fft is True: - click.echo("Starting simulation of fft_files!") - create_fft_images(sim_conf) - - if sim_sampled is True: - click.echo("Start sampling fft_files!") - sample_fft_images(sim_conf) - - -if __name__ == "__main__": - main() diff --git a/src/radionets/simulations/simulate.py b/src/radionets/simulations/simulate.py deleted file mode 100644 index abd2ea74..00000000 --- a/src/radionets/simulations/simulate.py +++ /dev/null @@ -1,80 +0,0 @@ -from pathlib import Path - -import click - -from radionets.simulations.gaussians import simulate_gaussian_sources -from radionets.simulations.point_sources import create_point_source_img -from radionets.simulations.sampling import sample_frequencies - - -def create_fft_images(sim_conf): - """ - Create fft source images and save them to h5 files. - - Parameters - ---------- - sim_conf : dict - dict holding simulation parameters - """ - if sim_conf["type"] == "gaussians": - for opt in ["train", "valid", "test"]: - simulate_gaussian_sources( - data_path=sim_conf["data_path"], - option=opt, - num_bundles=sim_conf["bundles_" + str(opt)], - bundle_size=sim_conf["bundle_size"], - img_size=sim_conf["img_size"], - num_comp_ext=sim_conf["num_components"], - noise=sim_conf["noise"], - noise_level=sim_conf["noise_level"], - source_list=sim_conf["source_list"], - white_noise=sim_conf["white_noise"], - mean_real=sim_conf["mean_real"], - std_real=sim_conf["std_real"], - mean_imag=sim_conf["mean_imag"], - std_imag=sim_conf["std_imag"], - ) - - if sim_conf["type"] == "point_sources": - for opt in ["train", "valid", "test"]: - create_point_source_img( - img_size=sim_conf["img_size"], - bundle_size=sim_conf["bundle_size"], - num_bundles=sim_conf["bundles_" + str(opt)], - path=sim_conf["data_path"], - option=opt, - extended=sim_conf["add_extended"], - ) - - -def sample_fft_images(sim_conf): - """ - check for fft files - keep fft_files? - """ - sample_frequencies( - data_path=sim_conf["data_path"], - amp_phase=sim_conf["amp_phase"], - real_imag=sim_conf["real_imag"], - specific_mask=sim_conf["specific_mask"], - antenna_config=sim_conf["antenna_config"], - lon=sim_conf["lon"], - lat=sim_conf["lat"], - steps=sim_conf["steps"], - fourier=sim_conf["fourier"], - compressed=sim_conf["compressed"], - interpolation=sim_conf["interpolation"], - multi_channel=sim_conf["multi_channel"], - bandwidths=sim_conf["bandwidths"], - source_type=sim_conf["type"], - ) - if sim_conf["keep_fft_files"] is not True: # noqa: SIM102 - if click.confirm("Do you really want to delete the fft_files?", abort=False): - fft = { - p - for p in Path(sim_conf["data_path"]).rglob( - "*fft*." + str(sim_conf["data_format"]) - ) - if p.is_file() - } - [p.unlink() for p in fft] diff --git a/src/radionets/simulations/utils.py b/src/radionets/simulations/utils.py deleted file mode 100644 index c251a76d..00000000 --- a/src/radionets/simulations/utils.py +++ /dev/null @@ -1,269 +0,0 @@ -import os -import re -import sys -from pathlib import Path - -import click -import numpy as np -from scipy import interpolate - -from radionets.core.data import get_bundles -from radionets.core.utils import ( - split_amp_phase, - split_real_imag, -) - - -def check_outpath(outpath, data_format, quiet=False): - """ - Check if outpath exists. Check for existing fft_files and sampled-files. - Ask to overwrite or reuse existing files. - - Parameters - ---------- - path : str - path to out directory - - Returns - ------- - sim_fft : bool - flag to enable/disable fft routine - sim_sampled : bool - flag to enable/disable sampling routine - """ - path = Path(outpath) - exists = path.exists() - if exists is True: - fft = {p for p in path.rglob("*fft*." + str(data_format)) if p.is_file()} - samp = {p for p in path.rglob("*samp*." + str(data_format)) if p.is_file()} - if fft: - click.echo("Found existing fft_files!") - if quiet or click.confirm( - "Do you really want to overwrite the files?", abort=False - ): - click.echo("Overwriting old fft_files!") - [p.unlink() for p in fft] - [p.unlink() for p in samp] - sim_fft = True - sim_sampled = True - return sim_fft, sim_sampled - else: - click.echo("Using old fft_files!") - sim_fft = False - else: - sim_fft = True - if samp: - click.echo("Found existing samp_files!") - if quiet or click.confirm( - "Do you really want to overwrite the files?", abort=False - ): - click.echo("Overwriting old samp_files!") - [p.unlink() for p in samp] - sim_sampled = True - else: - click.echo("No new images sampled!") - sim_sampled = False - sys.exit() - else: - sim_sampled = True - else: - Path(path).mkdir(parents=True, exist_ok=False) - sim_fft = True - sim_sampled = True - return sim_fft, sim_sampled - - -def read_config(config): - sim_conf = {} - sim_conf["data_path"] = config["paths"]["data_path"] - sim_conf["data_format"] = config["paths"]["data_format"] - if config["gaussians"]["simulate"]: - click.echo("Create fft_images from gaussian data set! \n") - - sim_conf["type"] = "gaussians" - sim_conf["num_components"] = config["gaussians"]["num_components"] - click.echo("Adding extended gaussian sources.") - - if config["point_sources"]["simulate"]: - click.echo("Create fft_images from point source data set! \n") - - sim_conf["type"] = "point_sources" - sim_conf["add_extended"] = config["point_sources"]["add_extended"] - click.echo("Adding point sources.") - - sim_conf["bundles_train"] = config["image_options"]["bundles_train"] - sim_conf["bundles_valid"] = config["image_options"]["bundles_valid"] - sim_conf["bundles_test"] = config["image_options"]["bundles_test"] - sim_conf["bundle_size"] = config["image_options"]["bundle_size"] - sim_conf["img_size"] = config["image_options"]["img_size"] - sim_conf["noise"] = config["image_options"]["noise"] - sim_conf["noise_level"] = config["image_options"]["noise_level"] - sim_conf["white_noise"] = config["image_options"]["white_noise"] - sim_conf["mean_real"] = config["image_options"]["mean_real"] - sim_conf["std_real"] = config["image_options"]["std_real"] - sim_conf["mean_imag"] = config["image_options"]["mean_imag"] - sim_conf["std_imag"] = config["image_options"]["std_imag"] - - sim_conf["amp_phase"] = config["sampling_options"]["amp_phase"] - sim_conf["real_imag"] = config["sampling_options"]["real_imag"] - sim_conf["source_list"] = config["sampling_options"]["source_list"] - sim_conf["antenna_config"] = config["sampling_options"]["antenna_config"] - sim_conf["specific_mask"] = config["sampling_options"]["specific_mask"] - sim_conf["lon"] = config["sampling_options"]["lon"] - sim_conf["lat"] = config["sampling_options"]["lat"] - sim_conf["steps"] = config["sampling_options"]["steps"] - sim_conf["fourier"] = config["sampling_options"]["fourier"] - sim_conf["compressed"] = config["sampling_options"]["compressed"] - sim_conf["keep_fft_files"] = config["sampling_options"]["keep_fft_files"] - sim_conf["interpolation"] = config["sampling_options"]["interpolation"] - sim_conf["multi_channel"] = config["sampling_options"]["multi_channel"] - sim_conf["bandwidths"] = config["sampling_options"]["bandwidths"] - return sim_conf - - -def adjust_outpath(path, option, form="h5"): - """ - Add number to out path when filename already exists. - - Parameters - ---------- - path : str - path to save directory - option : str - additional keyword to add to path - - Returns - ------- - out : str - adjusted path - """ - counter = 0 - filename = str(path) + (option + "{}." + form) - while os.path.isfile(filename.format(counter)): - counter += 1 - out = filename.format(counter) - return out - - -def get_fft_bundle_paths(data_path, ftype, mode): - bundles = get_bundles(data_path) - bundle_paths = np.sort( - [path for path in bundles if re.findall(f"{ftype}_{mode}", path.name)] - ) - return bundle_paths - - -def prepare_fft_images(fft_images, amp_phase, real_imag): - if amp_phase: - amp, phase = split_amp_phase(fft_images) - amp = (np.log10(amp + 1e-10) / 10) + 1 - - # Test new masking for 511 Pixel pictures - if amp.shape[1] == 511: - mask = amp > 0.1 - phase[~mask] = 0 - fft_scaled = np.stack((amp, phase), axis=1) - else: - real, imag = split_real_imag(fft_images) - fft_scaled = np.stack((real, imag), axis=1) - return fft_scaled - - -def get_noise(image, scale, mean=0, std=1): - """ - Calculate random noise values for all image pixels. - - Parameters - ---------- - image : 2darray - 2d image - scale : float - scaling factor to increase noise - mean : float - mean of noise values - std : float - standard deviation of noise values - - Returns - ------- - out : ndarray - array with noise values in image shape - """ - return np.random.normal(mean, std, size=image.shape) * scale - - -def add_noise(bundle, noise_level): - """ - Used for adding noise and plotting the original and noised picture, - if asked. Using 0.05 * max(image) as scaling factor. - - Parameters - ---------- - bundle : path - path to hdf5 bundle file - noise_level : int - noise level in percent - - Returns - ------- - bundle_noised hdf5_file - bundle with noised images - """ - bundle_noised = np.array( - [img + get_noise(img, (img.max() * noise_level / 100)) for img in bundle] - ) - return bundle_noised - - -def interpol(img): - """Interpolates fft sampled amplitude and phase data. - Parameters - ---------- - img : array - array with shape 2,width,heigth - input image array with amplitude and phase on axis 0 - Returns - ------- - array - array with shape 2,width,heigth - interpolated image array with amplitude and phase on axis 0 - """ - grid_x, grid_y = np.mgrid[0 : len(img[0, 0]) : 1, 0 : len(img[0, 0]) : 1] - - idx_amp = np.nonzero(img[0]) - amp = interpolate.griddata( - (idx_amp[0], idx_amp[1]), img[0][idx_amp], (grid_x, grid_y), method="nearest" - ) - - img[1][img[1] < 0] = 0 - idx_phase = np.nonzero(img[1]) - phase = interpolate.griddata( - (idx_phase[0], idx_phase[1]), - img[1][idx_phase], - (grid_x, grid_y), - method="nearest", - ) - - mask = np.ones((len(img[0, 0]), len(img[0, 0]))) - mask[1::2, 1::2] = 0 - mask[::2, ::2] = 0 - for i in range(len(img[0, 0])): - mask[i, len(img[0, 0]) - 1 - i :] = 1 - mask[i, len(img[0, 0]) - 1 - i :] - - phase_fl = -np.flip(phase, [0, 1]) - phase = phase * mask + phase_fl * (1 - mask) - - return np.array([amp, phase]) - - -def add_white_noise(images, mean_real=25, std_real=1.25, mean_imag=7, std_imag=0.35): - img_size = images.shape[2] - noise_real = np.random.normal( - mean_real, std_real, size=(images.shape[0], img_size, img_size) - ) - noise_imag = np.random.normal( - mean_imag, std_imag, size=(images.shape[0], img_size, img_size) - ) - images.real += noise_real - images.imag += noise_imag - return images diff --git a/src/radionets/simulations/uv_plots.py b/src/radionets/simulations/uv_plots.py deleted file mode 100644 index d937d2ee..00000000 --- a/src/radionets/simulations/uv_plots.py +++ /dev/null @@ -1,298 +0,0 @@ -import cartopy.crs as ccrs -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.animation import FuncAnimation, PillowWriter -from matplotlib.colors import LogNorm - -from radionets.simulations.uv_simulations import get_uv_coverage - -# make nice Latex friendly plots -# mpl.use("pgf") -# mpl.rcParams.update( -# { -# "font.size": 12, -# "font.family": "sans-serif", -# "text.usetex": True, -# "pgf.rcfonts": False, -# "pgf.texsystem": "lualatex", -# } -# ) - - -def plot_uv_coverage(u, v): - """ - Visualize (uv)-coverage - - Parameters - ---------- - u : 1darray - array of u coordinates - v : 1darray - array of v coordinates - """ - plt.plot(u, v, marker="o", linestyle="none", markersize=2, color="#1f77b4") - plt.xlabel(r"u / $\lambda$", fontsize=20) - plt.ylabel(r"v / $\lambda$", fontsize=20) - plt.tight_layout() - - -def plot_baselines(antenna): - """ - Visualize baselines of an antenna layout - - Parameters - ---------- - antenna : antenna class object - class object with antenna positions and baselines between telescopes - """ - x_base, y_base = antenna.get_baselines() - plt.plot( - x_base, - y_base, - linestyle="--", - color="#2ca02c", - zorder=0, - label="Baselines", - alpha=0.35, - ) - plt.tight_layout() - - -def plot_antenna_distribution(source_lon, source_lat, source, antenna, baselines=False): - """ - Visualize antenna distribution seen from a specific source position - - Parameters - ---------- - source_lon : float - longitude of the source - source_lat : float - latitude of the source - source : source class object - class object containing source position - antenna : antenna class object - class object with antenna positions and baselines between telescopes - baselines : bool - enable baseline plotting - """ - x, y, z = source.to_ecef(val=[source_lon, source_lat]) # only use source ? - x_enu_ant, y_enu_ant = antenna.to_enu(x, y, z) - - ax = plt.axes(projection=ccrs.Orthographic(source_lon, source_lat)) - ax.set_global() - ax.coastlines() - - plt.plot( - x_enu_ant, - y_enu_ant, - marker="o", - markersize=6, - color="#1f77b4", - linestyle="none", - label="Antenna positions", - ) - plt.plot( - x, - y, - marker="*", - linestyle="none", - color="#ff7f0e", - markersize=15, - transform=ccrs.Geodetic(), - zorder=10, - label="Projected source", - ) - - if baselines is True: - plot_baselines(antenna) - - plt.legend(fontsize=16, markerscale=1.5) - plt.tight_layout() - - -def animate_baselines(source, antenna, filename, fps=5): - """ - Create gif to animate change of baselines during an observation - - Parameters - ---------- - source : source class object - class object containing source position - antenna : antenna class object - class object with antenna positions and baselines between telescopes - filename : str - name of the created gif - fps : int - frames per seconds of the gif - """ - s_lon = source.lon_prop - s_lat = source.lat_prop - - fig = plt.figure(figsize=(6, 6), dpi=100) - - def init(): - pass - - def update(frame): - lon = s_lon[frame] - lat = s_lat[frame] - plot_antenna_distribution(lon, lat, source, antenna, baselines=True) - - ani = FuncAnimation( - fig, update, frames=len(s_lon), init_func=init, interval=1000 / fps - ) - - ani.save(str(filename) + ".gif", writer=PillowWriter(fps=fps)) - - -def animate_uv_coverage(source, antenna, filename, fps=5): - """ - Create gif to animate improvement of (uv)-coverage during an observation - - Parameters - ---------- - source : source class object - class object containing source position - antenna : antenna class object - class object with antenna positions and baselines between telescopes - filename : str - name of the created gif - fps : int - frames per seconds of the gif - """ - u, v, steps = get_uv_coverage(source, antenna, iterate=True) - - fig = plt.figure(figsize=(6, 6), dpi=100) - - def init(): - pass - - def update(frame): - plot_uv_coverage(u[frame], v[frame]) - plt.ylim(-5e8, 5e8) - plt.xlim(-5e8, 5e8) - - ani = FuncAnimation( - fig, update, frames=steps, init_func=init, interval=0.001, repeat=False - ) - - ani.save(str(filename) + ".gif", dpi=80, writer=PillowWriter(fps=fps)) - - -def plot_source(img, ft=False, log=False, ft2=False): - """ - Visualize a radio source - - Parameters - ---------- - img : 2darray - values of Gaussian source - ft : bool - if True, the Fourier transformation (frequency space) of the image is plotted - """ - # plt.rcParams.update({"font.size": 18}) - fig = plt.figure(figsize=(8, 6)) - ax = fig.add_subplot(111) - if ft is False: - img = np.abs(img) - ax.set_xlabel("l", fontsize=20) - ax.set_ylabel("m", fontsize=20) - if log is True: - s = ax.imshow(img, cmap="inferno", norm=LogNorm(vmin=1e-8, vmax=img.max())) - else: - s = ax.imshow(img, cmap="inferno") - cbar = fig.colorbar(s, label="Intensity / a.u.") - cbar.set_label("Intensity / a.u.", size=20) - cbar.ax.tick_params(labelsize=20) - else: - img = np.abs(FT2(img)) if ft2 else np.abs(FT(img)) - - ax.set_xlabel("u", fontsize=20) - ax.set_ylabel("v", fontsize=20) - if log is True: - s = ax.imshow(img, cmap="inferno", norm=LogNorm()) - else: - s = ax.imshow(img, cmap="inferno") - cbar = fig.colorbar(s, label="Intensity / a.u.") - cbar.set_label("Intensity / a.u.", size=20) - cbar.ax.tick_params(labelsize=20) - - ax.set_yticklabels([]) - ax.set_xticklabels([]) - ax.xaxis.set_ticks_position("none") - ax.yaxis.set_ticks_position("none") - plt.tight_layout() - - -def plot_mask(fig, mask): - from mpl_toolkits.axes_grid1 import make_axes_locatable - - ax = fig.add_subplot(111) - s = plt.imshow(mask.astype(int), cmap="inferno") - divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="5%", pad=0.3) - cbar = plt.colorbar(s, cax=cax) - cbar.ax.tick_params(labelsize=20) - - ax.set_yticklabels([]) - ax.set_xticklabels([]) - ax.xaxis.set_ticks_position("none") - ax.yaxis.set_ticks_position("none") - ax.set_xlabel("u", fontsize=20) - ax.set_ylabel("v", fontsize=20) - plt.tight_layout() - - -def FT(img): - """ - Computes the 2d Fourier trafo of an image - - Parameters - ---------- - img : 2darray - values of Gaussian source - - Returns - ------- - 2darray - Fourier transform of input array - """ - return np.fft.fftshift(np.fft.fft2(img)) - - -def FT2(img): - """ - Computes the 2d Fourier trafo of an image - - Parameters - ---------- - img : 2darray - values of Gaussian source - - Returns - ------- - 2darray - Fourier transform of input array - """ - return np.fft.ifft2(np.fft.ifftshift(img)) - - -def apply_mask(img, mask): - """ - Applies a boolean mask to a 2d image - - Parameters - ---------- - img : 2darray - values of Gaussian source - mask : bool - mask for sampling frequencies - - Returns - ------- - 2darray - array with sampled frequencies - """ - img = img.copy() - img[~mask.astype(bool)] = 0 - return img diff --git a/src/radionets/simulations/uv_simulations.py b/src/radionets/simulations/uv_simulations.py deleted file mode 100644 index f6faf75a..00000000 --- a/src/radionets/simulations/uv_simulations.py +++ /dev/null @@ -1,453 +0,0 @@ -import astropy.coordinates as ac -import numpy as np - -import radionets.simulations.layouts.layouts as layouts - - -class Source: - """ - Source class that holds longitude and latitude information. - Can be converted to geocentric coordinates. Position of source - can be propagated to simulate an ongoing observation. - """ - - def __init__(self, lon, lat): - """Initializes the source class. - - Paramters - --------- - lon : float - longitude of source - lat : float - latitude of source - """ - self.lon = lon - self.lat = lat - - def to_ecef(self, val=None, prop=False): - """ - Converts from geodetic to geocentric coordinates - - Parameters - ---------- - val : list with [lon, lat] - A specific geodetic position - prop : bool - use True on lists of propagated source positions, default is False - - Returns - ------- - x, y, z : 1darrays - Positions in geocentric coordinates - """ - if prop is True: - quant = ac.EarthLocation(self.lon_prop, self.lat_prop).to_geocentric() - else: - quant = ac.EarthLocation([self.lon], [self.lat]).to_geocentric() - if val is not None: - quant = ac.EarthLocation([val[0]], [val[1]]).to_geocentric() - x = quant[0].value - y = quant[1].value - z = quant[2].value - return x, y, z - - def propagate(self, num_steps=None, multi_pointing=False): - """ - Propagates a source position with random parameters - - Parameters - ---------- - num_steps : int - number of propagation steps - multi_pointing : bool - when True observation blocks are simulated, default is False - - Returns - ------- - lon : 1darray - array with propagated lons - lat : 1darray - array with propagated lats - """ - if num_steps is None: - num_steps = np.random.randint(30, 60) - lon_start = self.lon - lon_stop = lon_start - num_steps - lon_step = 0.5 - lon = np.arange(lon_stop, lon_start, lon_step) - lon = lon[::-1] - - lat_start = self.lat - direction = np.sign(np.random.randint(0, 1) - 0.5) - lat_stop = round((lat_start + direction * num_steps / 50) + 0.005, 3) - lat_step = 0.01 - if lat_start > lat_stop: - lat = np.arange(lat_stop, lat_start, lat_step) - lat = lat[::-1] - else: - lat = np.arange(lat_start, lat_stop, lat_step) - - if len(lon) != len(lat): - raise ValueError("Length of lon and lat are different!") - - if multi_pointing is True: - lon = self.mod_delete(lon, 5, 10) - lat = self.mod_delete(lat, 5, 10) - - self.lon_prop = lon - self.lat_prop = lat - return lon, lat - - def mod_delete(self, a, n, m): - """ - Deletes all m steps n values in a - - Parameters - ---------- - a : 1darray - array with coordinates - n : int - number of deleted points - m : int - range between two deletions - - Returns - ------- - a : 1darray - array with reduced coordinate points - """ - return a[np.mod(np.arange(a.size), n + m) < n] - - -class Antenna: - """ - Antenna class that holds information about the geocentric coordinates of the - radio telescopes. Can be converted to geodetic. All baselines between the - the telescopes can be computed. Antenna positions can be shifted into a ENU frame - of a specific observation, for which the (u, v)-coverage can be computed. - """ - - def __init__(self, X, Y, Z): - """Initializes the class. - - Parameters - ---------- - X, Y, Z : array - X, Y, Z coordinates of antennas - """ - self.all = np.array(list(zip(X, Y, Z))) - self.len = len(self.all) - self.baselines = self.len * (self.len - 1) - self.X = X - self.Y = Y - self.Z = Z - - def to_geodetic(self, x_ref, y_ref, z_ref, enu=False): - """ - Converts geocentric coordinates to geodetic. - - Parameters - ---------- - x_ref, y_ref, z_ref : float - x, y, z reference positon - enu : bool - when True: - """ - import astropy.units as u - - quant = ac.EarthLocation(x_ref, y_ref, z_ref, u.meter).to_geodetic() - if enu is True: - return quant.lon.deg, quant.lat.deg - else: - self.lon = quant.lon.deg - self.lat = quant.lat.deg - - def get_baselines(self): - """ - Calculates baselines between antenna pairs - - Returns - ------- - x_base, y_base : 1darrays - x, y values of the baselines - """ - x_base = [] - y_base = [] - for i in range(self.len): - ref = np.ones((self.len, 2)) * ([self.x_enu[i], self.y_enu[i]]) - pairs = np.array([self.x_enu, self.y_enu]) - baselines = np.array(list(zip(ref, pairs.T))).ravel() - x = baselines[0::2] - y = baselines[1::2] - x_base = np.append(x_base, x) - y_base = np.append(y_base, y) - - drops = np.asarray( - [ - ((i * 2 + np.array([1, 2])) - 1) + (i * self.len * 2) - for i in range(self.len) - ] - ) - coords = np.delete(np.stack([x_base, y_base], axis=1), drops.ravel(), axis=0).T - x_base = coords[0] - y_base = coords[1] - return x_base, y_base - - def to_enu(self, x_ref, y_ref, z_ref): - """ - Converts from geodetic to geocentric coordinates projected onto 2d plane - - Parameters - ---------- - x_ref, y_ref, z_ref : 1darrays - x, y, z reference coordinates - """ - lon_ref, lat_ref = self.to_geodetic(x_ref, y_ref, z_ref, enu=True) - if isinstance(x_ref, int): - x_ref, y_ref, z_ref = [x_ref], [y_ref], [z_ref] - lon_ref, lat_ref = [lon_ref], [lat_ref] - ref = np.array(list(zip(x_ref, y_ref, z_ref))) - - def rot(lon, lat): - """ - Calculates roytation matrix - """ - lon = np.deg2rad(lon) - lat = np.deg2rad(lat) - return np.array( - [ - [-np.sin(lon), np.cos(lon), 0], - [ - -np.sin(lat) * np.cos(lon), - -np.sin(lat) * np.sin(lon), - np.cos(lat), - ], - [np.cos(lat) * np.cos(lon), np.cos(lat) * np.sin(lon), np.sin(lat)], - ] - ) - - enu = np.array( - [ - rot(lon_ref[j], lat_ref[j]) @ (self.all[i] - ref[j]) - for i in range(self.len) - for j in range(len(lon_ref)) - ] - ) - self.ant_enu = enu - self.x_enu = enu.ravel()[0::3] - self.y_enu = enu.ravel()[1::3] - self.z_enu = enu.ravel()[2::3] - return self.x_enu, self.y_enu - - def get_uv(self): - """ - Calculates (u, v)-coordinates - - Returns - ------- - u, v : 1d arrays - u, v coordinates - steps : int - number of observation steps - """ - u = [] - v = [] - steps = int(len(self.x_enu) / self.len) - for j in range(steps): - for i in range(self.len): - x = self.x_enu[j::steps] - y = self.y_enu[j::steps] - x_ref = x[i] * np.ones(self.len) - y_ref = y[i] * np.ones(self.len) - x_base = x - x_ref - y_base = y - y_ref - x_base = x_base[x_base != 0] / 0.02 - y_base = -y_base[y_base != 0] / 0.02 - u = np.append(u, x_base) - v = np.append(v, y_base) - - if len(u) != len(v): - raise ValueError("Length of u and v are different!") - return u, v, steps - - -def get_uv_coverage(source, antenna, multi_channel=False, bandwidths=4, iterate=False): - """ - Converts source position and antenna positions into an (u, v)-coverage. - - Parameters - ---------- - source : source class object - source class containing source positions - antenna : antenna clas object - antenna class containing antenna positions - iterate : bool - use True while creating (u, v)-coverage gif - - Returns - ------- - u : 1darray - u coordinates - v : 1darray - v coordinates - steps : 1darray - number of observation steps - """ - antenna.to_enu(*source.to_ecef(prop=True)) - u, v, steps = antenna.get_uv() - - if multi_channel: - u = np.repeat(u[None], bandwidths, axis=0) - v = np.repeat(v[None], bandwidths, axis=0) - scales = np.arange(bandwidths, dtype=float) - scales *= 0.02 - scales += 1 - u *= scales[:, None] - v *= scales[:, None] - else: - u = u[None] - v = v[None] - - if iterate is True: - num_base = antenna.baselines - u.resize((steps, num_base)) - v.resize((steps, num_base)) - - return u, v, steps - - -def create_mask(u, v, size=64): - """Create 2d mask from a given (uv)-coverage - - u : array of u coordinates - v : array of v coordinates - size : number of bins - """ - uv_hist, _, _ = np.histogram2d(u.ravel(), v.ravel(), bins=size) - # exclude center - limit = 2 if size % 2 == 0 else 3 - - ex_l = size // 2 - 2 - ex_h = size // 2 + limit - uv_hist[ex_l:ex_h, ex_l:ex_h] = 0 - mask = uv_hist > 0 - return np.rot90(mask) - - -def test_mask(bundle_size, num_channel, img_size): - """ - Test mask for filter tests - """ - mask = np.ones((bundle_size, num_channel, img_size, img_size)) - mask[:, :, 19, 30] = 0 - mask[:, :, 23, 23] = 0 - mask[:, :, 30, 19] = 0 - mask[:, :, 43, 32] = 0 - mask[:, :, 39, 39] = 0 - mask[:, :, 32, 43] = 0 - mask[:, :, 33:35, 33:35] = 0 - mask[:, :, 28:30, 28:30] = 0 - return mask - - -def sample_freqs( - img, - ant_config, - size=64, - lon=None, - lat=None, - num_steps=None, - plot=False, - test=False, - specific_mask=True, - multi_channel=False, - bandwidths=4, -): - """ - Sample specific frequencies in 2d Fourier space. Using antenna and source class to - simulate a radio interferometric observation. - - Parameters - ---------- - img : 2darray - 2d Fourier space - ant_config : str - name of antenna config - size : int - pixel size of input image, default 64x64 pixel - lon : float - start lon of source, if None: random start value between -90 and -70 is used - lat : float - start lat of source, if None a random start value between 30 and 80 is used - num_steps : int - number of observation steps - plot : bool - if True: returns sampled Fourier spectrum and sampling mask - test_mask : bool - if True: use same test mask for every image - - Returns - ------- - img : 2darray - sampled Fourier Spectrum - """ - - def get_mask( - lon, - lat, - num_steps, - ant, - size, - multi_channel=multi_channel, - bandwidths=bandwidths, - ): - s = Source(lon, lat) - s.propagate(num_steps=num_steps, multi_pointing=False) - u, v, _ = get_uv_coverage( - s, ant, multi_channel=multi_channel, iterate=False, bandwidths=bandwidths - ) - single_mask = create_mask(u, v, size) - return single_mask - - bundle_size = img.shape[0] - num_channel = img.shape[1] - img_size = img.shape[2] - if test: - mask = test_mask(bundle_size, num_channel, img_size) - else: - layout = getattr(layouts, ant_config) - ant = Antenna(*layout()) - if specific_mask is True: - s = Source(lon, lat) - s.propagate(num_steps=num_steps, multi_pointing=False) - u, v, _ = get_uv_coverage( - s, - ant, - multi_channel=multi_channel, - iterate=False, - bandwidths=bandwidths, - ) - single_mask = create_mask(u, v, size) - mask = np.repeat( - np.repeat(single_mask[None, None, :, :], num_channel, axis=1), - bundle_size, - axis=0, - ) - else: - mask = np.array([None, None, None]) - num_steps = np.random.randint(40, 60, size=(bundle_size,)) - lon = np.random.randint(-90, -70, size=(bundle_size,)) - lat = np.random.randint(30, 80, size=(bundle_size,)) - mask_woc = np.asarray( - [ - get_mask(lon[i], lat[i], num_steps[i], ant, size) - for i in range(bundle_size) - ] - ) - mask = np.repeat(mask_woc[:, None, :, :], num_channel, axis=1) - img = img.copy() - img[~mask.astype(bool)] = 0 - if plot is True: - return img, mask - else: - return img diff --git a/src/radionets/simulations/visualize_simulations.py b/src/radionets/simulations/visualize_simulations.py deleted file mode 100644 index 10574458..00000000 --- a/src/radionets/simulations/visualize_simulations.py +++ /dev/null @@ -1,369 +0,0 @@ -from pathlib import Path - -import cartopy.crs as ccrs -import cartopy.io.img_tiles as cimgt -import matplotlib.patches as mpatches -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import cm -from matplotlib.colors import ListedColormap, LogNorm - -import radionets.simulations.layouts.layouts as layouts -from radionets.evaluation.utils import make_axes_nice -from radionets.simulations.uv_simulations import Antenna, Source - - -def create_OrBu(): - top = cm.get_cmap("Blues_r", 128) - bottom = cm.get_cmap("Oranges", 128) - white = np.array([256 / 256, 256 / 256, 256 / 256, 1]) - newcolors = np.vstack((top(np.linspace(0, 1, 128)), bottom(np.linspace(0, 1, 128)))) - newcolors[128, :] = white - newcmp = ListedColormap(newcolors, name="OrangeBlue") - return newcmp - - -OrBu = create_OrBu() - - -def create_path(path): - p = Path(path).parent - p.mkdir(parents=True, exist_ok=True) - - -def vlba_basic(center_lon=-110, center_lat=27.75): - layout = layouts.vlba - ant = Antenna(*layout()) - ant.to_geodetic(ant.X, ant.Y, ant.Z) - - s = Source(center_lon, center_lat) - s.to_ecef(prop=False) - - ant_lon = ant.lon - ant_lat = ant.lat - - ant.to_enu(*s.to_ecef(prop=False)) - base_lon, base_lat = ant.get_baselines() - return ant_lon, ant_lat, base_lon, base_lat - - -def plot_vlba( - out_path, ant_lon, ant_lat, base_lon, base_lat, center_lon=-110, center_lat=27.75 -): - extent = [-155, -65, 10, 45.5] - central_lon = np.mean(extent[:2]) - central_lat = np.mean(extent[2:]) - - stamen_terrain = cimgt.Stamen("terrain-background") - - plt.figure(figsize=(5.78 * 2, 3.57)) - ax = plt.axes(projection=ccrs.Orthographic(central_lon, central_lat)) - ax.set_extent(extent) - - ax.plot( - ant_lon, - ant_lat, - marker=".", - color="black", - linestyle="none", - markersize=6, - zorder=10, - transform=ccrs.Geodetic(), - label="Antenna positions", - ) - ax.plot( - base_lon, - base_lat, - zorder=5, - linestyle="-", - linewidth=0.5, - alpha=0.7, - color="#d62728", - label="Baselines", - ) - - ax.add_image(stamen_terrain, 4) - - leg = plt.legend(markerscale=1.5, fontsize=7, loc=2) - for legobj in leg.legendHandles: - legobj.set_linewidth(1.5) - - plt.savefig(out_path, dpi=100, bbox_inches="tight", pad_inches=0.05) - - -def create_vlba_overview(out_path): - ant_lon, ant_lat, base_lon, base_lat = vlba_basic() - create_path(out_path) - plot_vlba(out_path, ant_lon, ant_lat, base_lon, base_lat) - - -def plot_source(img, log=False, out_path=None): - fig = plt.figure() - ax = fig.add_subplot(111) - img = np.abs(img) - ax.set_xlabel("l") - ax.set_ylabel("m") - if log is True: - s = ax.imshow(img, cmap="inferno", norm=LogNorm(vmin=1e-8, vmax=img.max())) - else: - s = ax.imshow(img, cmap="inferno") - make_axes_nice(fig, ax, s, "") - - ax.set_yticklabels([]) - ax.set_xticklabels([]) - ax.xaxis.set_ticks_position("none") - ax.yaxis.set_ticks_position("none") - plt.tight_layout(pad=0) - - if out_path is not None: - plt.savefig(out_path, bbox_inches="tight") - - -def plot_comparison(img1, img2, log=False, out_path=None): - fig, (ax1, ax2) = plt.subplots( - 2, - 1, - ) - if log is True: - im1 = ax1.imshow(img1, cmap="inferno", norm=LogNorm(vmin=1e-8, vmax=img1.max())) - im2 = ax2.imshow(img2, cmap="inferno", norm=LogNorm(vmin=1e-8, vmax=img2.max())) - else: - im1 = ax1.imshow(img1, cmap="inferno") - im2 = ax2.imshow(img2, cmap="inferno") - - make_axes_nice(fig, ax1, im1, "") - ax1.set_xlabel("l") - ax1.set_ylabel("m") - - make_axes_nice(fig, ax2, im2, "") - ax2.set_xlabel("l") - ax2.set_ylabel("m") - - plt.tight_layout(pad=0) - - if out_path is not None: - plt.savefig(out_path, dpi=600, bbox_inches="tight") - - -def plot(img, ax, phase=False, grey=False): - if grey: - if phase: - im = ax.imshow( - np.abs(img), cmap="binary_r", vmin=-np.pi, vmax=np.pi, alpha=1 - ) - else: - im = ax.imshow( - img, - cmap="binary_r", - alpha=1, - # vmin=1e-8, - norm=LogNorm(), - ) - else: - if phase: - im = ax.imshow(img, cmap=OrBu, vmin=-np.pi, vmax=np.pi) - else: - im = ax.imshow(img, cmap="inferno", norm=LogNorm()) - - ax.set_yticklabels([]) - ax.set_xticklabels([]) - ax.xaxis.set_ticks_position("none") - ax.yaxis.set_ticks_position("none") - return im - - -def plot_spectrum(img, out_path=None): - fig, (ax1, ax2) = plt.subplots( - 1, - 2, - ) - amp = np.abs(img) - phase = np.angle(img) - - im1 = plot(amp, ax1) - make_axes_nice(fig, ax1, im1, "") - ax1.set_xlabel("u") - ax1.set_ylabel("v") - - im2 = plot(phase, ax2, phase=True) - make_axes_nice(fig, ax2, im2, "", phase=True) - ax2.set_xlabel("u") - ax2.set_ylabel("v") - - plt.tight_layout(pad=0) - - if out_path is not None: - plt.savefig(out_path, dpi=600, bbox_inches="tight") - - -def plot_spectrum_grey(img1, img2, out_path=None): - fig, (ax1, ax2) = plt.subplots( - 2, - 1, - ) - amp1 = np.abs(img1) - phase1 = np.angle(img1) - - amp2 = np.ma.masked_where(np.abs(img2) == 0, np.abs(img2)) - phase2 = np.ma.masked_where(np.angle(img2) == 0, np.angle(img2)) - - plot(amp1, ax1, grey=True) - im1 = plot(amp2, ax1) - make_axes_nice(fig, ax1, im1, "") - ax1.set_xlabel("u") - ax1.set_ylabel("v") - - plot(phase1, ax2, phase=True, grey=True) - im2 = plot(phase2, ax2, phase=True) - make_axes_nice(fig, ax2, im2, "", phase=True) - ax2.set_xlabel("u") - ax2.set_ylabel("v") - - fig.tight_layout(pad=0) - plt.subplots_adjust(hspace=0.22) - - if out_path is not None: - plt.savefig(out_path, dpi=600, bbox_inches="tight") - - -def ft(img): - return np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(img))) - - -def plot_uv_coverage(u, v, ax): - ax.plot( - u, v, marker="o", linestyle="none", markersize=2, color="#1f77b4", label="Data" - ) - ax.set_xlabel(r"u / $\lambda$") - ax.set_ylabel(r"v / $\lambda$") - - -def plot_vlba_uv(u, v, out_path=None): - fig, ax = plt.subplots(1, figsize=(3.57, 3.57), dpi=100) - plot_uv_coverage(u, v, ax) - ax.set_ylim(-5e8, 5e8) - ax.set_xlim(-5e8, 5e8) - plt.tick_params(axis="both") - - ax.axis("equal") - # plt.legend() - fig.tight_layout() - - if out_path is not None: - plt.savefig(out_path, dpi=100, bbox_inches="tight", pad_inches=0.05) - - -def plot_baselines(antenna): - x_base, y_base = antenna.get_baselines() - plt.plot( - x_base, - y_base, - linestyle="-", - color="#2ca02c", - zorder=0, - label="Baselines", - alpha=0.35, - ) - - -def plot_antenna_distribution( - source_lon, - source_lat, - source, - antenna, - baselines=False, - end=False, - lon_start=None, - lat_start=None, - out_path=None, -): - x, y, z = source.to_ecef(val=[source_lon, source_lat]) # only use source ? - x_enu_ant, y_enu_ant = antenna.to_enu(x, y, z) - - plt.figure(figsize=(5.78, 3.57), dpi=100) - ax = plt.axes(projection=ccrs.Orthographic(source_lon, source_lat)) - ax.set_global() - ax.coastlines() - - plt.plot( - x_enu_ant, - y_enu_ant, - marker="o", - markersize=3, - color="#1f77b4", - linestyle="none", - label="Antenna positions", - ) - plt.plot( - x, - y, - marker="*", - linestyle="none", - color="#ff7f0e", - markersize=10, - transform=ccrs.Geodetic(), - zorder=10, - label="Projected source", - ) - - if baselines: - plot_baselines(antenna) # projected baselines - - if end: - x_start, y_start, _ = source.to_ecef(val=[lon_start, lat_start]) - - ax.plot( - np.array([x, x_start]), - np.array([y, y_start]), - marker=".", - linestyle="--", - color="#d62728", - linewidth=1, - transform=ccrs.Geodetic(), - zorder=10, - label="Source path", - ) - ax.plot( - x_start, - y_start, - marker=".", - color="green", - zorder=10, - label="hi", - transform=ccrs.Geodetic(), - ) - - plt.legend( - fontsize=9, markerscale=1.5, bbox_to_anchor=(0.95, 1), loc=2, borderaxespad=0.0 - ) - - if out_path is not None: - plt.savefig(out_path, dpi=100, bbox_inches="tight", pad_inches=0.05) - - -def plot_mask(mask): - fig = plt.figure() - ax = fig.add_subplot(111) - - im = ax.imshow(mask, cmap="inferno") - values = np.unique(mask.ravel()) - colors = [im.cmap(im.norm(value)) for value in values] - names = ["Unsampled", "Sampled"] - patches = [ - mpatches.Patch(color=colors[i], label=f"{names[i]}") for i in range(len(values)) - ] - plt.legend(handles=patches, loc=0, framealpha=1) - - ax.set_yticklabels([]) - ax.set_xticklabels([]) - ax.xaxis.set_ticks_position("none") - ax.yaxis.set_ticks_position("none") - ax.set_xlabel("u") - ax.set_ylabel("v") - plt.tight_layout() - - -def apply_mask(img, mask): - img = img.copy() - img[~mask.astype(bool)] = 0 - return img diff --git a/src/radionets/tools/cli.py b/src/radionets/tools/cli.py new file mode 100644 index 00000000..aabbc009 --- /dev/null +++ b/src/radionets/tools/cli.py @@ -0,0 +1,53 @@ +import rich_click as click + +from radionets import __version__ + +from .model_cli import main as model +from .quickstart import main as quickstart + +click.rich_click.COMMAND_GROUPS = { + "radionets": [ + { + "name": "Model Operations", + "commands": ["train", "test", "inference", "predict"], + }, + { + "name": "Setup", + "commands": ["quickstart"], + }, + ] +} + + +@click.group( + help=f""" + This is the [dark_orange]Radionets[/] + [cornflower_blue]v{__version__}[/] main CLI tool. + """ +) +def main(): + pass + + +def create_mode_command(mode, cmd_alias=None): + """Factory function to create mode-specific commands""" + if cmd_alias is None: + cmd_alias = mode + + @click.command(name=cmd_alias, help=f"Run radionets in {mode} mode") + @click.argument("config_path", type=click.Path(exists=True, dir_okay=False)) + @click.pass_context + def command(ctx, config_path): + ctx.invoke(model, config_path=config_path, mode=mode) + + return command + + +main.add_command(quickstart, name="quickstart") +main.add_command(create_mode_command("train")) +main.add_command(create_mode_command("test")) +main.add_command(create_mode_command("predict")) +main.add_command(create_mode_command("predict", "inference")) # NOTE: Subject to change + +if __name__ == "__main__": + main() diff --git a/src/radionets/tools/model_cli.py b/src/radionets/tools/model_cli.py new file mode 100644 index 00000000..86bebefc --- /dev/null +++ b/src/radionets/tools/model_cli.py @@ -0,0 +1,119 @@ +from pathlib import Path + +import lightning as L +import rich_click as click +from rich import print + +from radionets.core.callbacks import Callbacks +from radionets.core.logging import Loggers +from radionets.io import TrainConfig +from radionets.training import TrainModule +from radionets.utils._paths import _validate_pre_model_path +from radionets.utils.carbon_tracking import CarbonTracker + + +@click.command() +@click.argument( + "mode", + type=click.Choice(["train", "test", "predict"], case_sensitive=False), + default="train", +) +@click.argument("config_path", type=click.Path(exists=True, dir_okay=False)) +@click.option("--premodel", "-p", type=click.Path(exists=True, dir_okay=False)) +def main(config_path, mode="train", premodel=None): + """Starts the radionets training process with + options specified in configuration file. + + Parameters + ---------- + configuration_path : str + Path to the configuration toml file. + mode : str, optional + Operation mode, can be one of {'train', 'test', 'predict'}. + Default: 'train' + """ + if not isinstance(config_path, Path): + config_path = Path(config_path) + + train_config = TrainConfig.from_toml(config_path) + + print(train_config) + + if mode == "test": + train_config.paths.model_path /= "testing" + + if mode == "predict": + train_config.paths.model_path /= "inference" + + if premodel: + # if the premodel cli option is used, overwrite config path + train_config.paths.checkpoint = premodel + + data_module = train_config.dataloader.module( + data_dir=train_config.paths.data_path, + batch_size=train_config.training.batch_size, + fourier=train_config.model.fourier, + **train_config.dataloader.model_dump(), + ) + + # Dumping train_config to a dict here results in nicer + # formatting in hparams.yaml files generated by lightning + train_module = TrainModule(train_config.model_dump(), data_module.train_length) + + loggers = Loggers.get_loggers(train_config) + callbacks = Callbacks.get_callbacks(train_config) + + # with rich_training_layout(train_config, callbacks) as layout_callbacks: + trainer = L.Trainer( + limit_train_batches=data_module.train_length // train_config.training.batch_size + if data_module.train_length + else train_config.training.batch_size, + limit_val_batches=data_module.valid_length // train_config.training.batch_size + if data_module.valid_length + else train_config.training.batch_size, + limit_test_batches=data_module.test_length // train_config.training.batch_size + if data_module.test_length + else train_config.training.batch_size, + max_epochs=train_config.training.num_epochs, + callbacks=callbacks, + logger=loggers, + log_every_n_steps=train_config.training.batch_size, + devices=train_config.devices.num_devices, + accelerator=train_config.devices.accelerator, + precision=train_config.devices.precision, + strategy=train_config.devices.deepspeed + if train_config.devices.deepspeed + else train_config.devices.strategy, + ) + + trainer.radionets_task = mode.lower() + + if mode.lower() == "train": + # let mlflow callback stop the tracker + stop_inside_scope = train_config.logging.mlflow + + with CarbonTracker( + train_config=train_config, stop_inside_scope=stop_inside_scope + ) as tracker: + trainer.carbontracker = tracker + trainer.fit(model=train_module, datamodule=data_module) + + elif mode.lower() == "test": + _validate_pre_model_path(train_config) + + train_module = TrainModule.load_from_checkpoint(train_config.paths.pre_model) + preds = trainer.test(model=train_module, datamodule=data_module) + + return preds + + elif mode.lower() == "predict": + _validate_pre_model_path(train_config) + + train_module = TrainModule.load_from_checkpoint(train_config.paths.pre_model) + preds = trainer.predict(model=train_module, datamodule=data_module) + + return preds + + +if __name__ == "__main__": + main() diff --git a/src/radionets/tools/quickstart.py b/src/radionets/tools/quickstart.py index e0a903a2..dbf55be6 100644 --- a/src/radionets/tools/quickstart.py +++ b/src/radionets/tools/quickstart.py @@ -6,7 +6,7 @@ from rich.pretty import pretty_repr from radionets import __version__ -from radionets.core.logging import setup_logger +from radionets.core.logging import _setup_logger @click.command() @@ -14,6 +14,14 @@ "config_path", type=click.Path(dir_okay=True), ) +@click.option( + "-m", + "--mode", + type=click.Choice(["train", "eval"]), + default="train", + help="""What config file to create at config_path. + Valid are {train, eval}. Default: train""", +) @click.option( "-y", "--yes", @@ -22,28 +30,38 @@ is_flag=True, help="Overwrite file if it already exists.", ) -def quickstart( +def main( config_path: str | Path, + mode: str = "train", overwrite: bool = False, ) -> None: - """Quickstart CLI tool for pyvisgen. Creates - a copy of the default simulation configuration + """Quickstart CLI tool for radionets. Creates + a copy of the default train or eval configuration file at the specified path. Parameters ---------- config_path : str or Path Path to write the config to. + mode : str, optional + Determines the type of config. Either 'train' + or 'eval' are valid. Default: 'train' + overwrite : bool, optional + If ``True``, overwrites the config file if it already + exists. Default: ``False`` Notes ----- If a directory is given, this tool will create - a file called 'radionets_default_train_config.toml' + a file called 'radionets_default_{train,eval}_config.toml' inside that directory. """ - log = setup_logger(namespace=__name__, tracebacks_suppress=[click]) + if mode not in ["train", "eval"]: + raise ValueError("Unknown mode: Expected one of {train, eval}.") - msg = f"This is the pyvisgen [blue]v{__version__}[/] quickstart tool" + log = _setup_logger(namespace=__name__, tracebacks_suppress=[click]) + + msg = f"This is the radionets [blue]v{__version__}[/] quickstart tool" log.info(msg, extra={"markup": True, "highlighter": None}) log.info((len(msg) - len("[blue][/]")) * "=") @@ -51,18 +69,24 @@ def quickstart( config_path = Path(config_path) root = sysconfig.get_path("data", sysconfig.get_default_scheme()) - default_config_path = Path( - root + "/share/configs/radionets_default_train_config.toml" - ) + if mode == "train": + default_config_path = Path( + root + "/share/configs/radionets_default_train_config.toml" + ) + else: + default_config_path = Path( + root + "/share/configs/radionets_default_eval_config.toml" + ) + + log.info(f"Loading default radionets {mode} configuration...") with open(default_config_path) as f: default_config = toml.load(f) - log.info("Loading default pyvisgen configuration:") log.info(pretty_repr(default_config)) if config_path.is_dir(): - config_path /= "radionets_default_train_config.toml" + config_path /= f"radionets_default_{mode}_config.toml" # write_file is used below; the following if statement acts as # a switch, toggling write_file to False if the user does not @@ -86,4 +110,4 @@ def quickstart( if __name__ == "__main__": - quickstart() + main() diff --git a/src/radionets/training/__init__.py b/src/radionets/training/__init__.py index e69de29b..28ae880f 100644 --- a/src/radionets/training/__init__.py +++ b/src/radionets/training/__init__.py @@ -0,0 +1,3 @@ +from .trainers import TrainModule + +__all__ = ["TrainModule"] diff --git a/src/radionets/training/scripts/__init__.py b/src/radionets/training/scripts/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/radionets/training/scripts/start_training.py b/src/radionets/training/scripts/start_training.py deleted file mode 100644 index fdabf766..00000000 --- a/src/radionets/training/scripts/start_training.py +++ /dev/null @@ -1,178 +0,0 @@ -import sys -from pathlib import Path - -import click -import toml -from rich.pretty import pretty_repr - -from radionets.core.learner import define_learner -from radionets.core.logging import setup_logger -from radionets.core.model import load_pre_model -from radionets.evaluation.train_inspection import after_training_plots -from radionets.plotting.inspection import plot_loss, plot_lr, plot_lr_loss -from radionets.training.utils import ( - check_outpath, - create_databunch, - define_arch, - end_training, - get_normalisation_factors, - pop_interrupt, - read_config, -) - -LOGGER = setup_logger(tracebacks_suppress=[click]) - - -@click.command() -@click.argument("configuration_path", type=click.Path(exists=True, dir_okay=False)) -@click.option( - "--mode", - type=click.Choice( - ["train", "lr_find", "plot_loss", "fine_tune"], case_sensitive=False - ), - default="train", -) -def main(configuration_path, mode): - """ - Start DNN training with options specified in configuration file. - - Parameters - ---------- - configuration_path : str - Path to the configuration toml file - - Notes - ----- - train : start training of deep learning model (default option) - lr_find : execute learning rate finder - plot_loss : plot losscurve of existing model - """ - config = toml.load(configuration_path) - train_conf = read_config(config) - - LOGGER.info("Train config:") - LOGGER.info(pretty_repr(train_conf)) - - # create databunch - data = create_databunch( - data_path=train_conf["data_path"], - fourier=train_conf["fourier"], - batch_size=train_conf["batch_size"], - ) - - # get image size - train_conf["image_size"] = data.train_ds[0][0][0].shape[1] - - # define architecture - arch = define_arch( - arch_name=train_conf["arch_name"], img_size=train_conf["image_size"] - ) - - if mode == "train": - if train_conf["normalize"] == "mean": - train_conf["norm_factors"] = get_normalisation_factors(data) - # check out path and look for existing model files - check_outpath(train_conf["model_path"], train_conf) - - LOGGER.info("Start training of the model.") - - # define_learner - learn = define_learner(data, arch, train_conf) - - # load pretrained model - if train_conf["pre_model"] != "none": - learn.create_opt() - load_pre_model(learn, train_conf["pre_model"]) - - # Train the model, except interrupt - # train_conf["comet_ml"] = True - try: - if train_conf["comet_ml"]: - learn.comet.experiment.log_parameters(train_conf) - with learn.comet.experiment.train(): - learn.fit(train_conf["num_epochs"]) - else: - learn.fit(train_conf["num_epochs"]) - except KeyboardInterrupt: - pop_interrupt(learn, train_conf) - - end_training(learn, train_conf) - - if train_conf["inspection"]: - after_training_plots(train_conf, rand=True) - - if mode == "fine_tune": - LOGGER.info("Start fine tuning of the model.") - - # define_learner - learn = define_learner( - data, - arch, - train_conf, - ) - - # load pretrained model - if train_conf["pre_model"] == "none": - LOGGER.warning("Need a pre-trained modle for fine tuning!") - return - - learn.create_opt() - load_pre_model(learn, train_conf["pre_model"]) - - # Train the model, except interrupt - try: - learn.fine_tune(train_conf["num_epochs"]) - except KeyboardInterrupt: - pop_interrupt(learn, train_conf) - - end_training(learn, train_conf) - if train_conf["inspection"]: - after_training_plots(train_conf, rand=True) - - if mode == "lr_find": - LOGGER.info("Start lr_find.") - if train_conf["normalize"] == "mean": - train_conf["norm_factors"] = get_normalisation_factors(data) - - # define_learner - learn = define_learner(data, arch, train_conf, lr_find=True) - - # load pretrained model - if train_conf["pre_model"] != "none": - learn.create_opt() - load_pre_model(learn, train_conf["pre_model"]) - - learn.lr_find() - - # save loss plot - plot_lr_loss( - learn, - train_conf["arch_name"], - Path(train_conf["model_path"]).parent, - skip_last=5, - output_format=train_conf["format"], - ) - - if mode == "plot_loss": - LOGGER.info("Start plotting loss.") - - # define_learner - learn = define_learner(data, arch, train_conf, plot_loss=True) - # load pretrained model - if Path(train_conf["model_path"]).exists: - load_pre_model(learn, train_conf["model_path"], plot_loss=True) - else: - LOGGER.warning("Selected model does not exist.") - LOGGER.info("Exiting.") - sys.exit() - - plot_lr( - learn, Path(train_conf["model_path"]), output_format=train_conf["format"] - ) - plot_loss( - learn, Path(train_conf["model_path"]), output_format=train_conf["format"] - ) - - -if __name__ == "__main__": - main() diff --git a/src/radionets/training/trainers.py b/src/radionets/training/trainers.py new file mode 100644 index 00000000..a904d476 --- /dev/null +++ b/src/radionets/training/trainers.py @@ -0,0 +1,106 @@ +from inspect import signature + +from lightning import LightningModule + + +class TrainModule(LightningModule): + def __init__(self, train_config: dict, train_length: int = None): + super().__init__() + self.save_hyperparameters() + + self.train_config = train_config + self.model = train_config["model"]["arch_name"]() + self.optimizer = train_config["training"]["optimizer"]["optimizer"] + self.loss_fn = train_config["training"]["loss"]["loss_func"]() + self.train_length = train_length + self.num_epochs = train_config["training"]["num_epochs"] + self.batch_size = train_config["training"]["batch_size"] + + def forward(self, inputs): + return self.model(inputs) + + def training_step(self, batch, batch_idinputs): + inputs, targets = self._extract_inputs_targets(batch) + + logits = self(inputs)["pred"] + loss = self.loss_fn(logits, targets) + self.log("train_loss", loss, prog_bar=True, sync_dist=True) + + return loss + + def validation_step(self, batch, batch_idx): + inputs, targets = self._extract_inputs_targets(batch) + + logits = self(inputs)["pred"] + loss = self.loss_fn(logits, targets) + self.log("val_loss", loss, prog_bar=True, sync_dist=True) + + def _extract_inputs_targets(self, batch): + if isinstance(batch, dict): + inputs = batch["inputs"] + targets = batch.get("target", None) + elif isinstance(batch, list | tuple): + if len(batch) >= 2 and hasattr(batch[1], "__array__"): + inputs, targets = batch[0], batch[1] + else: + inputs, targets = batch[0], None + else: + inputs, targets = batch, None + + return inputs, targets + + def test_step(self, batch, batch_idx): + inputs, targets = self._extract_inputs_targets(batch) + preds = self(inputs)["pred"] + + if targets is not None: + loss = self.loss_fn(preds, targets) + self.log("test_loss", loss, prog_bar=True, sync_dist=True) + + return preds, targets + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + inputs, _ = self._extract_inputs_targets(batch) + preds = self(inputs)["pred"] + + return preds + + def configure_optimizers(self): + optimizer = self.optimizer( + self.parameters(), + lr=self.train_config["training"]["optimizer"]["lr"], + ) + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + # optimizer, mode="min", factor=0.1, patience=10, threshold=1e-5 + # ) + + sched_config = self.train_config["training"]["lr_scheduling"] + if sched_config: + sched_sig_set = set(signature(sched_config["scheduler"]).parameters.keys()) + + if "epochs" not in sched_config: + sched_config["epochs"] = self.num_epochs + + if "steps_per_epoch" not in sched_config: + sched_config["steps_per_epoch"] = ( + int(self.train_length) // self.batch_size + ) + + sched_config_set = set(sched_config) + sched_keys = sched_config_set.intersection(sched_sig_set) + sched_kwargs = dict(zip(sched_keys, map(sched_config.get, sched_keys))) + + scheduler = sched_config["scheduler"](optimizer, **sched_kwargs) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": sched_config["monitor"], + "interval": sched_config["interval"], + "frequency": sched_config["frequency"], + "strict": sched_config["strict"], + }, + } + + return {"optimizer": optimizer} diff --git a/src/radionets/training/utils.py b/src/radionets/training/utils.py deleted file mode 100644 index c065287b..00000000 --- a/src/radionets/training/utils.py +++ /dev/null @@ -1,152 +0,0 @@ -import sys -from pathlib import Path - -import click -import torch -from tqdm import tqdm - -from radionets import architecture -from radionets.core.data import DataBunch, get_dls, load_data -from radionets.core.logging import setup_logger -from radionets.core.model import save_model -from radionets.evaluation.train_inspection import create_inspection_plots -from radionets.plotting.inspection import plot_loss - -LOGGER = setup_logger(namespace=__name__) - - -def create_databunch(data_path, fourier, batch_size): - # Load data sets - train_ds = load_data(data_path, "train", fourier=fourier) - valid_ds = load_data(data_path, "valid", fourier=fourier) - - # Create databunch with defined batchsize - data = DataBunch(*get_dls(train_ds, valid_ds, batch_size)) - return data - - -def read_config(config): - train_conf = {} - train_conf["data_path"] = config["paths"]["data_path"] - train_conf["model_path"] = config["paths"]["model_path"] - train_conf["pre_model"] = config["paths"]["pre_model"] - - train_conf["quiet"] = config["mode"]["quiet"] - train_conf["gpu"] = config["mode"]["gpu"] - - train_conf["comet_ml"] = config["logging"]["comet_ml"] - train_conf["plot_n_epochs"] = config["logging"]["plot_n_epochs"] - train_conf["project_name"] = config["logging"]["project_name"] - train_conf["scale"] = config["logging"]["scale"] - - train_conf["batch_size"] = config["hypers"]["batch_size"] - train_conf["lr"] = config["hypers"]["lr"] - - train_conf["fourier"] = config["general"]["fourier"] - train_conf["amp_phase"] = config["general"]["amp_phase"] - train_conf["normalize"] = config["general"]["normalize"] - train_conf["arch_name"] = config["general"]["arch_name"] - train_conf["loss_func"] = config["general"]["loss_func"] - train_conf["num_epochs"] = config["general"]["num_epochs"] - train_conf["inspection"] = config["general"]["inspection"] - train_conf["separate"] = False - train_conf["format"] = config["general"]["output_format"] - train_conf["switch_loss"] = config["general"]["switch_loss"] - train_conf["when_switch"] = config["general"]["when_switch"] - - train_conf["param_scheduling"] = config["param_scheduling"]["use"] - train_conf["lr_start"] = config["param_scheduling"]["lr_start"] - train_conf["lr_max"] = config["param_scheduling"]["lr_max"] - train_conf["lr_stop"] = config["param_scheduling"]["lr_stop"] - train_conf["lr_ratio"] = config["param_scheduling"]["lr_ratio"] - - train_conf["source_list"] = config["general"]["source_list"] - - return train_conf - - -def check_outpath(model_path, train_conf): - path = Path(model_path) - - exists = path.exists() - if exists: - if train_conf["quiet"]: - LOGGER.info("Overwriting existing model file!") - path.unlink() - else: - if click.confirm( - "Do you really want to overwrite existing model file?", abort=True - ): - LOGGER.info("Overwriting existing model file!") - path.unlink() - - -def define_arch(arch_name, img_size): - if ( - "filter_deep" in arch_name - or "resnet" in arch_name - or "Uncertainty" in arch_name - ): - arch = getattr(architecture, arch_name)(img_size) - else: - arch = getattr(architecture, arch_name)() - - return arch - - -def pop_interrupt(learn, train_conf): - if click.confirm("KeyboardInterrupt, do you want to save the model?", abort=False): - model_path = train_conf["model_path"] - # save model - LOGGER.info(f"Saving the model after epoch {learn.epoch}") - save_model(learn, model_path) - - # plot loss - plot_loss(learn, model_path) - - # Plot input, prediction and true image if asked - if train_conf["inspection"]: - create_inspection_plots(learn, train_conf) - else: - LOGGER.info(f"Stopping after epoch {learn.epoch}") - - sys.exit(1) - - -def end_training(learn, train_conf): - # Save model - save_model(learn, Path(train_conf["model_path"])) - - # Plot loss - plot_loss(learn, Path(train_conf["model_path"])) - - -def get_normalisation_factors(data): - mean_real = [] - mean_imag = [] - std_real = [] - std_imag = [] - - for inp, _ in tqdm(data.train_ds): - mean_batch_imag = inp[1].mean() - mean_batch_real = inp[0].mean() - std_batch_imag = inp[1].std() - std_batch_real = inp[0].std() - mean_real.append(mean_batch_real) - mean_imag.append(mean_batch_imag) - std_real.append(std_batch_real) - std_imag.append(std_batch_imag) - - mean_real = torch.tensor(mean_real).mean() - mean_imag = torch.tensor(mean_imag).mean() - std_real = torch.tensor(std_real).std() - std_imag = torch.tensor(std_imag).std() - - norm_factors = { - "mean_real": mean_real, - "mean_imag": mean_imag, - "std_real": std_real, - "std_imag": std_imag, - } - - return norm_factors diff --git a/src/radionets/utils/_paths.py b/src/radionets/utils/_paths.py new file mode 100644 index 00000000..3b0e34ec --- /dev/null +++ b/src/radionets/utils/_paths.py @@ -0,0 +1,13 @@ +def _validate_pre_model_path(train_config): + if not train_config.paths.pre_model: + raise ValueError( + f"'pre_model' path is {train_config.paths.pre_model} " + "even though testing mode was started. Please make sure " + "you provide a valid path to a model checkpoint file (.ckpt) " + "in your configuration." + ) + if not train_config.paths.pre_model.is_file(): + raise ValueError( + f"'pre_model' path is {train_config.paths.pre_model}, " + "but not a valid path to a model checkpoint file (.ckpt)." + ) diff --git a/src/radionets/utils/carbon_tracking.py b/src/radionets/utils/carbon_tracking.py new file mode 100644 index 00000000..495603ad --- /dev/null +++ b/src/radionets/utils/carbon_tracking.py @@ -0,0 +1,42 @@ +from typing import Self + +try: + from codecarbon import OfflineEmissionsTracker + + _CODECARBON_AVAILABLE = True +except ImportError: + _CODECARBON_AVAILABLE = False + + +__all__ = ["CarbonTracker"] + + +class DummyTracker: + def start(self): + pass + + def stop(self): + pass + + +class CarbonTracker: + def __init__(self, train_config, stop_inside_scope=True, *args, **kwargs): + self.train_config = train_config + self.use = _CODECARBON_AVAILABLE and train_config.logging.codecarbon + self.stop = stop_inside_scope + + def __enter__(self) -> Self: + if self.use: + self.tracker = OfflineEmissionsTracker( + **self.train_config.logging.codecarbon.model_dump() + ) + self.tracker.start() + else: + self.tracker = DummyTracker() + + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if self.stop: + self.tracker.stop() + return None diff --git a/tests/test_architecture_layers.py b/tests/test_architecture_layers.py index 0358b06e..e70dd1ae 100644 --- a/tests/test_architecture_layers.py +++ b/tests/test_architecture_layers.py @@ -1,6 +1,6 @@ +import pytest import torch import torch.nn as nn -import pytest from radionets.architecture.layers import LocallyConnected2d @@ -18,7 +18,7 @@ def test_initialization_without_bias(self): output_size=(8, 8), kernel_size=3, stride=1, - bias=False + bias=False, ) # Check weight shape @@ -40,7 +40,7 @@ def test_initialization_with_bias(self): output_size=(8, 8), kernel_size=3, stride=1, - bias=True + bias=True, ) # Check bias shape @@ -56,7 +56,7 @@ def test_forward_pass_basic(self): output_size=(8, 8), kernel_size=3, stride=1, - bias=False + bias=False, ) # Input that should produce 8x8 output with kernel_size=3, stride=1 @@ -76,7 +76,7 @@ def test_forward_pass_with_bias(self): output_size=(8, 8), kernel_size=3, stride=1, - bias=True + bias=True, ) input_tensor = torch.randn(2, 3, 10, 10) @@ -93,7 +93,7 @@ def test_different_stride(self): output_size=(4, 4), kernel_size=3, stride=2, - bias=False + bias=False, ) # Input size: 10x10 -> (10-3+1)/2 = 4x4 @@ -111,7 +111,7 @@ def test_single_channel_input_output(self): output_size=(6, 6), kernel_size=2, stride=1, - bias=False + bias=False, ) # Input size: 7x7 -> (7-2+1)/1 = 6x6 @@ -129,7 +129,7 @@ def test_gradient_flow(self): output_size=(5, 5), kernel_size=2, stride=1, - bias=True + bias=True, ) input_tensor = torch.randn(1, 2, 6, 6, requires_grad=True) @@ -157,7 +157,7 @@ def test_batch_processing(self): output_size=(3, 3), kernel_size=3, stride=1, - bias=False + bias=False, ) # Test different batch sizes @@ -175,7 +175,7 @@ def test_deterministic_output(self): output_size=(3, 3), kernel_size=2, stride=1, - bias=True + bias=True, ) input_tensor = torch.randn(1, 2, 4, 4) @@ -195,7 +195,7 @@ def test_weight_initialization_range(self): output_size=(8, 8), kernel_size=3, stride=1, - bias=True + bias=True, ) # Check that weights are not all zeros or all the same @@ -217,7 +217,7 @@ def test_large_kernel_size(self): output_size=(1, 1), kernel_size=5, stride=1, - bias=False + bias=False, ) # Input size: 5x5 -> (5-5+1)/1 = 1x1 @@ -231,13 +231,18 @@ def test_large_kernel_size(self): expected_weight_shape = (1, 2, 1, 1, 1, 25) assert layer.weight.shape == expected_weight_shape - @pytest.mark.parametrize("in_channels,out_channels,kernel_size,stride", [ - (1, 1, 2, 1), - (3, 8, 3, 1), - (16, 32, 2, 2), - (8, 16, 4, 2), - ]) - def test_parametrized_configurations(self, in_channels, out_channels, kernel_size, stride): + @pytest.mark.parametrize( + "in_channels,out_channels,kernel_size,stride", + [ + (1, 1, 2, 1), + (3, 8, 3, 1), + (16, 32, 2, 2), + (8, 16, 4, 2), + ], + ) + def test_parametrized_configurations( + self, in_channels, out_channels, kernel_size, stride + ): """Test various parameter combinations.""" # Calculate expected output size for a 10x10 input output_h = (10 - kernel_size) // stride + 1 @@ -250,7 +255,7 @@ def test_parametrized_configurations(self, in_channels, out_channels, kernel_siz output_size=output_size, kernel_size=kernel_size, stride=stride, - bias=True + bias=True, ) input_tensor = torch.randn(2, in_channels, 10, 10) @@ -271,7 +276,7 @@ def test_minimum_input_size(): output_size=(1, 1), kernel_size=3, stride=1, - bias=False + bias=False, ) # Minimum input size for 3x3 kernel to produce 1x1 output @@ -281,7 +286,6 @@ def test_minimum_input_size(): assert output.shape == (1, 1, 1, 1) -import numpy as np from radionets.architecture.layers import ComplexConv2d @@ -385,11 +389,7 @@ def test_forward_basic(self): def test_forward_output_calculation(self): """Test that forward pass calculates complex convolution correctly.""" conv = ComplexConv2d( - in_channels=2, - out_channels=2, - kernel_size=3, - stride=1, - bias=False + in_channels=2, out_channels=2, kernel_size=3, stride=1, bias=False ) # Create simple input @@ -445,11 +445,7 @@ def test_forward_different_input_sizes(self): def test_chunk_operation(self): """Test the chunk operation in forward method.""" conv = ComplexConv2d( - in_channels=4, - out_channels=8, - kernel_size=3, - stride=1, - bias=True + in_channels=4, out_channels=8, kernel_size=3, stride=1, bias=True ) # Create input @@ -466,11 +462,7 @@ def test_chunk_operation(self): def test_gradient_flow(self): """Test that gradients flow properly through the network.""" conv = ComplexConv2d( - in_channels=2, - out_channels=4, - kernel_size=3, - stride=1, - bias=True + in_channels=2, out_channels=4, kernel_size=3, stride=1, bias=True ) # Create input that requires grad @@ -497,33 +489,25 @@ def test_gradient_flow(self): def test_device_compatibility(self): """Test that the module works on different devices.""" conv = ComplexConv2d( - in_channels=2, - out_channels=16, - kernel_size=3, - stride=1, - bias=True + in_channels=2, out_channels=16, kernel_size=3, stride=1, bias=True ) # Test on CPU x_cpu = torch.randn(1, 2, 16, 16, dtype=torch.float32) output_cpu = conv.forward(x_cpu) - assert output_cpu.device.type == 'cpu' + assert output_cpu.device.type == "cpu" # Test on GPU if available if torch.cuda.is_available(): conv_gpu = conv.cuda() x_gpu = x_cpu.cuda() output_gpu = conv_gpu.forward(x_gpu) - assert output_gpu.device.type == 'cuda' + assert output_gpu.device.type == "cuda" def test_module_inheritance(self): """Test that ComplexConv2d properly inherits from nn.Module.""" conv = ComplexConv2d( - in_channels=2, - out_channels=2, - kernel_size=3, - stride=1, - bias=True + in_channels=2, out_channels=2, kernel_size=3, stride=1, bias=True ) assert isinstance(conv, nn.Module) @@ -531,7 +515,7 @@ def test_module_inheritance(self): # Test that it can be added to a sequential model model = nn.Sequential( conv, - nn.ReLU() # Note: ReLU won't work with complex numbers in practice + nn.ReLU(), # Note: ReLU won't work with complex numbers in practice ) assert len(list(model.parameters())) > 0 @@ -544,7 +528,7 @@ def test_parameter_count(self): out_channels=out_channels, kernel_size=kernel_size, stride=1, - bias=True + bias=True, ) # Count parameters @@ -552,7 +536,9 @@ def test_parameter_count(self): # Expected: 2 conv layers, each with weight and bias and # half input and output channels - expected_weight_params = 2 * out_channels // 2 * in_channels // 2 * kernel_size * kernel_size + expected_weight_params = ( + 2 * out_channels // 2 * in_channels // 2 * kernel_size * kernel_size + ) expected_bias_params = 2 * out_channels // 2 expected_total = expected_weight_params + expected_bias_params @@ -565,11 +551,7 @@ class TestComplexConv2dEdgeCases: def test_zero_input(self): """Test with zero input.""" conv = ComplexConv2d( - in_channels=2, - out_channels=2, - kernel_size=3, - stride=1, - bias=False + in_channels=2, out_channels=2, kernel_size=3, stride=1, bias=False ) x_zero = torch.zeros(1, 2, 8, 8, dtype=torch.float32) @@ -582,11 +564,7 @@ def test_zero_input(self): def test_single_pixel_input(self): """Test with single pixel input.""" conv = ComplexConv2d( - in_channels=2, - out_channels=2, - kernel_size=1, - stride=1, - bias=True + in_channels=2, out_channels=2, kernel_size=1, stride=1, bias=True ) x_single = torch.randn(1, 2, 1, 1, dtype=torch.float32) @@ -597,11 +575,7 @@ def test_single_pixel_input(self): def test_large_kernel_size(self): """Test with kernel size larger than input.""" conv = ComplexConv2d( - in_channels=2, - out_channels=2, - kernel_size=5, - stride=1, - bias=True + in_channels=2, out_channels=2, kernel_size=5, stride=1, bias=True ) # Input smaller than kernel @@ -628,10 +602,10 @@ def test_init_basic(self): assert norm.affine == True # Check learnable parameters exist - assert hasattr(norm, 'weight_real') - assert hasattr(norm, 'weight_imag') - assert hasattr(norm, 'bias_real') - assert hasattr(norm, 'bias_imag') + assert hasattr(norm, "weight_real") + assert hasattr(norm, "weight_imag") + assert hasattr(norm, "bias_real") + assert hasattr(norm, "bias_imag") # Check parameter shapes assert norm.weight_real.shape == (32,) @@ -654,10 +628,10 @@ def test_init_no_affine(self): assert norm.affine == False # Check that affine parameters don't exist - assert not hasattr(norm, 'weight_real') - assert not hasattr(norm, 'weight_imag') - assert not hasattr(norm, 'bias_real') - assert not hasattr(norm, 'bias_imag') + assert not hasattr(norm, "weight_real") + assert not hasattr(norm, "weight_imag") + assert not hasattr(norm, "bias_real") + assert not hasattr(norm, "bias_imag") def test_init_different_parameters(self): """Test initialization with different parameter combinations.""" @@ -755,9 +729,9 @@ def test_forward_different_input_sizes(self): # Test different input sizes input_sizes = [ - (1, 64, 1, 1), # Single pixel - (1, 64, 8, 8), # Small image - (4, 64, 32, 32), # Medium batch and image + (1, 64, 1, 1), # Single pixel + (1, 64, 8, 8), # Small image + (4, 64, 32, 32), # Medium batch and image (2, 64, 128, 256), # Large image ] @@ -773,8 +747,12 @@ def test_forward_different_input_sizes(self): real_means = real_out.mean(dim=[2, 3]) imag_means = imag_out.mean(dim=[2, 3]) - assert torch.allclose(real_means, torch.zeros_like(real_means), atol=1e-4) - assert torch.allclose(imag_means, torch.zeros_like(imag_means), atol=1e-4) + assert torch.allclose( + real_means, torch.zeros_like(real_means), atol=1e-4 + ) + assert torch.allclose( + imag_means, torch.zeros_like(imag_means), atol=1e-4 + ) def test_chunk_operation(self): """Test the chunk operation in forward method.""" @@ -821,7 +799,9 @@ def test_gradient_flow(self): # Check that gradients are non-zero (indicating proper flow) assert not torch.allclose(x.grad, torch.zeros_like(x.grad)) - assert not torch.allclose(norm.weight_real.grad, torch.zeros_like(norm.weight_real.grad)) + assert not torch.allclose( + norm.weight_real.grad, torch.zeros_like(norm.weight_real.grad) + ) def test_device_compatibility(self): """Test that the module works on different devices.""" @@ -830,14 +810,14 @@ def test_device_compatibility(self): # Test on CPU x_cpu = torch.randn(1, 16, 8, 8) output_cpu = norm.forward(x_cpu) - assert output_cpu.device.type == 'cpu' + assert output_cpu.device.type == "cpu" # Test on GPU if available if torch.cuda.is_available(): norm_gpu = norm.cuda() x_gpu = x_cpu.cuda() output_gpu = norm_gpu.forward(x_gpu) - assert output_gpu.device.type == 'cuda' + assert output_gpu.device.type == "cuda" # Results should be similar (allowing for minor numerical differences) assert torch.allclose(output_cpu, output_gpu.cpu(), atol=1e-5) @@ -849,10 +829,7 @@ def test_module_inheritance(self): assert isinstance(norm, nn.Module) # Test that it can be added to a sequential model - model = nn.Sequential( - norm, - nn.ReLU() - ) + model = nn.Sequential(norm, nn.ReLU()) # Test parameter counting params = list(norm.parameters()) @@ -963,8 +940,8 @@ def test_init_basic(self): assert prelu.num_parameters == 1 # Check learnable parameters exist - assert hasattr(prelu, 'weight_real') - assert hasattr(prelu, 'weight_imag') + assert hasattr(prelu, "weight_real") + assert hasattr(prelu, "weight_imag") assert isinstance(prelu.weight_real, nn.Parameter) assert isinstance(prelu.weight_imag, nn.Parameter) @@ -1023,10 +1000,12 @@ def test_forward_positive_values_unchanged(self): prelu = ComplexPReLU(num_parameters=1, init=0.2) # Create input with known positive and negative values - x = torch.tensor([ - [[[2.0, -1.0], [3.0, -2.0]], [[2.0, -1.0], [3.0, -2.0]]], - [[[1.5, -0.5], [-1.0, 4.0]], [[2.0, -1.0], [3.0, -2.0]]] - ]) + x = torch.tensor( + [ + [[[2.0, -1.0], [3.0, -2.0]], [[2.0, -1.0], [3.0, -2.0]]], + [[[1.5, -0.5], [-1.0, 4.0]], [[2.0, -1.0], [3.0, -2.0]]], + ] + ) print(x.shape) output = prelu.forward(x) real_out, imag_out = output.chunk(2, dim=1) @@ -1042,10 +1021,12 @@ def test_forward_negative_values_scaled(self): prelu = ComplexPReLU(num_parameters=1, init=init_val) # Create input with known negative values - x = torch.tensor([ - [[[-2.0, -1.0], [-3.0, -0.5]], [[-2.0, -1.0], [-3.0, -0.5]]], - [[[-1.5, -2.5], [-1.0, -4.0]], [[-2.0, -1.0], [-3.0, -0.5]]] - ]) + x = torch.tensor( + [ + [[[-2.0, -1.0], [-3.0, -0.5]], [[-2.0, -1.0], [-3.0, -0.5]]], + [[[-1.5, -2.5], [-1.0, -4.0]], [[-2.0, -1.0], [-3.0, -0.5]]], + ] + ) output = prelu.forward(x) real_out, imag_out = output.chunk(2, dim=1) @@ -1064,10 +1045,12 @@ def test_forward_mixed_values(self): prelu = ComplexPReLU(num_parameters=1, init=init_val) # Create input with mixed values - x = torch.tensor([ - [[[2.0, -1.0], [-3.0, 4.0]], [[2.0, -1.0], [-3.0, 4.0]]], - [[[-1.5, 2.5], [1.0, -4.0]], [[-1.5, 2.5], [1.0, -4.0]]] - ]) + x = torch.tensor( + [ + [[[2.0, -1.0], [-3.0, 4.0]], [[2.0, -1.0], [-3.0, 4.0]]], + [[[-1.5, 2.5], [1.0, -4.0]], [[-1.5, 2.5], [1.0, -4.0]]], + ] + ) output = prelu.forward(x) real_out, imag_out = output.chunk(2, dim=1) @@ -1098,11 +1081,11 @@ def test_forward_per_channel_parameters(self): # Check that each channel is scaled by its respective parameter for c in range(num_channels // 2): - expected_real_c = real_in[:, c:c+1] * prelu.weight_real[c] - expected_imag_c = imag_in[:, c:c+1] * prelu.weight_imag[c] + expected_real_c = real_in[:, c : c + 1] * prelu.weight_real[c] + expected_imag_c = imag_in[:, c : c + 1] * prelu.weight_imag[c] - assert torch.allclose(real_out[:, c:c+1], expected_real_c, atol=1e-6) - assert torch.allclose(imag_out[:, c:c+1], expected_imag_c, atol=1e-6) + assert torch.allclose(real_out[:, c : c + 1], expected_real_c, atol=1e-6) + assert torch.allclose(imag_out[:, c : c + 1], expected_imag_c, atol=1e-6) def test_forward_different_input_sizes(self): """Test forward pass with different input sizes.""" @@ -1110,10 +1093,10 @@ def test_forward_different_input_sizes(self): # Test different input sizes input_sizes = [ - (1, 2, 1, 1), # Single pixel, 1 complex channel - (1, 8, 4, 4), # Small image, 4 complex channels - (4, 16, 8, 8), # Medium batch and image - (2, 32, 16, 32), # Large image, 16 complex channels + (1, 2, 1, 1), # Single pixel, 1 complex channel + (1, 8, 4, 4), # Small image, 4 complex channels + (4, 16, 8, 8), # Medium batch and image + (2, 32, 16, 32), # Large image, 16 complex channels ] for batch_size, channels, height, width in input_sizes: @@ -1174,8 +1157,12 @@ def test_gradient_flow(self): assert prelu.weight_imag.grad is not None # Check that gradients are reasonable (non-zero for learnable parameters) - assert not torch.allclose(prelu.weight_real.grad, torch.zeros_like(prelu.weight_real.grad)) - assert not torch.allclose(prelu.weight_imag.grad, torch.zeros_like(prelu.weight_imag.grad)) + assert not torch.allclose( + prelu.weight_real.grad, torch.zeros_like(prelu.weight_real.grad) + ) + assert not torch.allclose( + prelu.weight_imag.grad, torch.zeros_like(prelu.weight_imag.grad) + ) def test_gradient_flow_per_channel(self): """Test gradient flow with per-channel parameters.""" @@ -1189,7 +1176,7 @@ def test_gradient_flow_per_channel(self): output = prelu.forward(x) # Create loss that depends on all parameters - loss = (output ** 2).sum() + loss = (output**2).sum() # Backward pass loss.backward() @@ -1207,14 +1194,14 @@ def test_device_compatibility(self): # Test on CPU x_cpu = torch.randn(1, 8, 4, 4) output_cpu = prelu.forward(x_cpu) - assert output_cpu.device.type == 'cpu' + assert output_cpu.device.type == "cpu" # Test on GPU if available if torch.cuda.is_available(): prelu_gpu = prelu.cuda() x_gpu = x_cpu.cuda() output_gpu = prelu_gpu.forward(x_gpu) - assert output_gpu.device.type == 'cuda' + assert output_gpu.device.type == "cuda" # Results should be similar (allowing for minor numerical differences) assert torch.allclose(output_cpu, output_gpu.cpu(), atol=1e-6) @@ -1226,10 +1213,7 @@ def test_module_inheritance(self): assert isinstance(prelu, nn.Module) # Test that it can be added to a sequential model - model = nn.Sequential( - prelu, - nn.Flatten() - ) + model = nn.Sequential(prelu, nn.Flatten()) # Test parameter counting params = list(prelu.parameters()) @@ -1271,15 +1255,11 @@ def test_activation_properties(self): neg_imag_mask = imag_in < 0 if neg_real_mask.any(): assert torch.allclose( - real_out[neg_real_mask], - real_in[neg_real_mask] * 0.2, - atol=1e-6 + real_out[neg_real_mask], real_in[neg_real_mask] * 0.2, atol=1e-6 ) if neg_imag_mask.any(): assert torch.allclose( - imag_out[neg_imag_mask], - imag_in[neg_imag_mask] * 0.2, - atol=1e-6 + imag_out[neg_imag_mask], imag_in[neg_imag_mask] * 0.2, atol=1e-6 ) @@ -1291,19 +1271,17 @@ def test_extreme_values(self): prelu = ComplexPReLU(num_parameters=1, init=0.1) # Test with very large values - x_large = torch.tensor([ - [[[1e6, -1e6]], [[1e6, -1e6]]], - [[[1e5, -1e5]], [[1e5, -1e5]]] - ]) + x_large = torch.tensor( + [[[[1e6, -1e6]], [[1e6, -1e6]]], [[[1e5, -1e5]], [[1e5, -1e5]]]] + ) output_large = prelu.forward(x_large) assert torch.isfinite(output_large).all() # Test with very small values - x_small = torch.tensor([ - [[[1e-6, -1e-6]], [[1e-6, -1e-6]]], - [[[1e-7, -1e-7]], [[1e-7, -1e-7]]] - ]) + x_small = torch.tensor( + [[[[1e-6, -1e-6]], [[1e-6, -1e-6]]], [[[1e-7, -1e-7]], [[1e-7, -1e-7]]]] + ) output_small = prelu.forward(x_small) assert torch.isfinite(output_small).all() @@ -1313,20 +1291,28 @@ def test_boundary_values(self): prelu = ComplexPReLU(num_parameters=1, init=0.25) # Input with exact zeros - x = torch.tensor([ - [[[0.0, -1.0], [1.0, 0.0]], [[0.0, -1.0], [1.0, 0.0]]], - [[[0.0, 1.0], [-1.0, 0.0]], [[0.0, 1.0], [-1.0, 0.0]]] - ]) + x = torch.tensor( + [ + [[[0.0, -1.0], [1.0, 0.0]], [[0.0, -1.0], [1.0, 0.0]]], + [[[0.0, 1.0], [-1.0, 0.0]], [[0.0, 1.0], [-1.0, 0.0]]], + ] + ) output = prelu.forward(x) real_out, imag_out = output.chunk(2, dim=1) # Zeros should remain zeros - zero_positions_real = (x[:, :1] == 0.0) - zero_positions_imag = (x[:, 1:] == 0.0) + zero_positions_real = x[:, :1] == 0.0 + zero_positions_imag = x[:, 1:] == 0.0 - assert torch.equal(real_out[zero_positions_real], torch.zeros_like(real_out[zero_positions_real])) - assert torch.equal(imag_out[zero_positions_imag], torch.zeros_like(imag_out[zero_positions_imag])) + assert torch.equal( + real_out[zero_positions_real], + torch.zeros_like(real_out[zero_positions_real]), + ) + assert torch.equal( + imag_out[zero_positions_imag], + torch.zeros_like(imag_out[zero_positions_imag]), + ) def test_single_pixel_input(self): """Test with single pixel input.""" diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py index 3fda838e..b52e1494 100644 --- a/tests/test_evaluation.py +++ b/tests/test_evaluation.py @@ -2,6 +2,7 @@ import pytest from scipy.stats import truncnorm + def truncnorm_moments(mu, sig, a, b): a, b = (a - mu) / sig, (b - mu) / sig sampled_gauss = truncnorm(a, b, loc=mu, scale=sig) diff --git a/tests/test_simulation.py b/tests/test_simulation.py deleted file mode 100644 index 6372dca8..00000000 --- a/tests/test_simulation.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -from click.testing import CliRunner - - -@pytest.mark.order("first") -def test_simulation(): - from radionets.simulations.scripts.simulate_images import main - - runner = CliRunner() - result = runner.invoke(main, "tests/simulate.toml") - assert result.exit_code == 0