diff --git a/Justfile b/Justfile
index f8da47479..30a8b8b09 100644
--- a/Justfile
+++ b/Justfile
@@ -1,7 +1,7 @@
 # Setup

 install:
-    uv sync --no-cache --frozen
+    uv sync --group dev --group docs --no-cache --frozen

 # Packaging

@@ -16,8 +16,7 @@ publish:
 # Testing

 test:
-    export TEST_TOKEN=$(cat ~/.latch/token) &&\
-    pytest -s tests
+    export TEST_TOKEN=$(cat ~/.latch/token) && pytest -s

 # Docs

diff --git a/pyproject.toml b/pyproject.toml
index a0dcbd72e..3fe4f7d8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,15 +12,15 @@ include = ["src/**/*.py", "src/**/py.typed", "src/latch_cli/services/init/*"]

 [project]
 name = "latch"
-version = "2.67.5"
+version = "2.67.5a4"
 description = "The Latch SDK"
 authors = [{ name = "Kenny Workman", email = "kenny@latch.bio" }]
 maintainers = [{ name = "Ayush Kamat", email = "ayush@latch.bio" }]
 readme = "README.md"
 license = { file = "LICENSE" }
-
 requires-python = ">=3.9"
+
 dependencies = [
     "kubernetes>=24.2.0",
     "pyjwt>=0.2.0",
@@ -31,7 +31,7 @@ dependencies = [
     "scp>=0.14.0",
     "boto3>=1.26.0",
     "tqdm>=4.63.0",
-    "lytekit==0.15.29",
+    "lytekit==0.15.30",
     "lytekitplugins-pods==0.7.4",
     "typing-extensions>=4.12.0",
     "apscheduler>=3.10.0",
@@ -72,7 +72,11 @@ classifiers = [

 [project.optional-dependencies]
 pandas = ["pandas>=2.0.0"]
-snakemake = ["snakemake>=7.18.0,<7.30.2", "pulp>=2.0,<2.8"]
+snakemake = [
+    "snakemake",
+    "snakemake-storage-plugin-latch==0.1.11",
+    "snakemake-executor-plugin-latch==0.1.9",
+]

 [project.scripts]
 latch = "latch_cli.main:main"
@@ -96,11 +100,10 @@ docs = [
 ]

 [tool.ruff]
+line-length = 100
 target-version = "py39"

 [tool.ruff.lint]
-preview = true
-
 pydocstyle = { convention = "google" }
 extend-select = [
     "F",
diff --git a/src/latch/registry/upstream_types/values.py b/src/latch/registry/upstream_types/values.py
index e1b79027a..f236cdf6b 100644
--- a/src/latch/registry/upstream_types/values.py
+++ b/src/latch/registry/upstream_types/values.py
@@ -3,6 +3,8 @@

 from typing_extensions import Self, TypeAlias

+from latch.utils import Singleton
+

 class InvalidValue(TypedDict):
     rawValue: str

@@ -62,9 +64,7 @@ class PrimitiveUnresolvedBlobValueValid(TypedDict):
     valid: Literal[True]


-PrimitiveUnresolvedBlobValue: TypeAlias = Union[
-    PrimitiveUnresolvedBlobValueValid, InvalidValue
-]
+PrimitiveUnresolvedBlobValue: TypeAlias = Union[PrimitiveUnresolvedBlobValueValid, InvalidValue]


 class LinkValue(TypedDict):

@@ -108,23 +108,11 @@ class UnionValue(TypedDict):
 DBValue: TypeAlias = Union[PrimitiveValue, ArrayValue, UnionValue]


-@dataclass(frozen=True)
-class EmptyCell:
+class EmptyCell(Singleton):
     """Empty Registry :class:`Record` value. Singleton.

-    The constructor returns a referentially identical instance each call. That is,
-    `EmptyCell() is EmptyCell()`
-
     Used to distinguish explicit `None` values from missing values.
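+    e.g. a cell that was never written reads as EmptyCell(), while a cell explicitly set to null reads as None.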
""" - _singleton: ClassVar[Optional["EmptyCell"]] = None - - def __new__(cls) -> Self: - if cls._singleton is None: - cls._singleton = super().__new__(cls) - - return cls._singleton - Value: TypeAlias = Union[DBValue, EmptyCell] diff --git a/src/latch/resources/tasks.py b/src/latch/resources/tasks.py index db2bba2d7..48d1e334d 100644 --- a/src/latch/resources/tasks.py +++ b/src/latch/resources/tasks.py @@ -52,18 +52,8 @@ def get_v100_x1_pod() -> Pod: primary_container = V1Container(name="primary") resources = V1ResourceRequirements( - requests={ - "cpu": "7", - "memory": "48Gi", - "nvidia.com/gpu": 1, - "ephemeral-storage": "4500Gi", - }, - limits={ - "cpu": "7", - "memory": "48Gi", - "nvidia.com/gpu": 1, - "ephemeral-storage": "5000Gi", - }, + requests={"cpu": "7", "memory": "48Gi", "nvidia.com/gpu": 1, "ephemeral-storage": "4500Gi"}, + limits={"cpu": "7", "memory": "48Gi", "nvidia.com/gpu": 1, "ephemeral-storage": "5000Gi"}, ) primary_container.resources = resources @@ -94,12 +84,7 @@ def get_v100_x4_pod() -> Pod: "nvidia.com/gpu": 4, "ephemeral-storage": "4500Gi", }, - limits={ - "cpu": "30", - "memory": "230Gi", - "nvidia.com/gpu": 4, - "ephemeral-storage": "5000Gi", - }, + limits={"cpu": "30", "memory": "230Gi", "nvidia.com/gpu": 4, "ephemeral-storage": "5000Gi"}, ) primary_container.resources = resources @@ -135,12 +120,7 @@ def get_v100_x8_pod() -> Pod: "nvidia.com/gpu": 8, "ephemeral-storage": "4500Gi", }, - limits={ - "cpu": "62", - "memory": "400Gi", - "nvidia.com/gpu": 8, - "ephemeral-storage": "5000Gi", - }, + limits={"cpu": "62", "memory": "400Gi", "nvidia.com/gpu": 8, "ephemeral-storage": "5000Gi"}, ) primary_container.resources = resources @@ -205,21 +185,14 @@ def _get_small_gpu_pod() -> Pod: "nvidia.com/gpu": "1", "ephemeral-storage": "1500Gi", }, - limits={ - "cpu": "7", - "memory": "30Gi", - "nvidia.com/gpu": "1", - "ephemeral-storage": "1500Gi", - }, + limits={"cpu": "7", "memory": "30Gi", "nvidia.com/gpu": "1", "ephemeral-storage": "1500Gi"}, ) primary_container.resources = resources return Pod( pod_spec=V1PodSpec( containers=[primary_container], - tolerations=[ - V1Toleration(effect="NoSchedule", key="ng", value="gpu-small") - ], + tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="gpu-small")], ), primary_container_name="primary", ) @@ -244,9 +217,7 @@ def _get_large_pod() -> Pod: pod_spec=V1PodSpec( runtime_class_name="sysbox-runc", containers=[primary_container], - tolerations=[ - V1Toleration(effect="NoSchedule", key="ng", value="cpu-96-spot") - ], + tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="cpu-96-spot")], ), primary_container_name="primary", ) @@ -271,9 +242,7 @@ def _get_medium_pod() -> Pod: pod_spec=V1PodSpec( runtime_class_name="sysbox-runc", containers=[primary_container], - tolerations=[ - V1Toleration(effect="NoSchedule", key="ng", value="cpu-32-spot") - ], + tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="cpu-32-spot")], ), primary_container_name="primary", ) @@ -295,10 +264,7 @@ def _get_small_pod() -> Pod: "private:uidmapping=0:1048576:65536;gidmapping=0:1048576:65536" ) }, - pod_spec=V1PodSpec( - runtime_class_name="sysbox-runc", - containers=[primary_container], - ), + pod_spec=V1PodSpec(runtime_class_name="sysbox-runc", containers=[primary_container]), primary_container_name="primary", ) @@ -466,8 +432,7 @@ def custom_memory_optimized_task(cpu: int, memory: int): ) elif memory > 485: raise ValueError( - f"custom memory optimized task requires too much RAM: {memory} GiB (max 485" - " GiB)" + f"custom 
memory optimized task requires too much RAM: {memory} GiB (max 485 GiB)" ) primary_container = V1Container(name="primary") @@ -485,9 +450,7 @@ def custom_memory_optimized_task(cpu: int, memory: int): pod_spec=V1PodSpec( runtime_class_name="sysbox-runc", containers=[primary_container], - tolerations=[ - V1Toleration(effect="NoSchedule", key="ng", value="mem-512-spot") - ], + tolerations=[V1Toleration(effect="NoSchedule", key="ng", value="mem-512-spot")], ), primary_container_name="primary", ) @@ -517,11 +480,7 @@ class _NGConfig: max_storage_gb_ish = int(max_storage_gib * Units.GiB / Units.GB) -def _custom_task_config( - cpu: int, - memory: int, - storage_gib: int, -) -> Pod: +def _custom_task_config(cpu: int, memory: int, storage_gib: int) -> Pod: target_ng = None for ng in taint_data: if ( @@ -547,11 +506,7 @@ def _custom_task_config( "memory": f"{memory}Gi", "ephemeral-storage": f"{storage_gib}Gi", }, - limits={ - "cpu": str(cpu), - "memory": f"{memory}Gi", - "ephemeral-storage": f"{storage_gib}Gi", - }, + limits={"cpu": str(cpu), "memory": f"{memory}Gi", "ephemeral-storage": f"{storage_gib}Gi"}, ) primary_container.resources = resources return Pod( @@ -564,9 +519,7 @@ def _custom_task_config( runtime_class_name="sysbox-runc", containers=[primary_container], tolerations=[ - V1Toleration( - effect="NoSchedule", key="ng", value=target_ng.toleration_value - ) + V1Toleration(effect="NoSchedule", key="ng", value=target_ng.toleration_value) ], ), primary_container_name="primary", @@ -591,18 +544,12 @@ def custom_task( """ if callable(cpu) or callable(memory) or callable(storage_gib): task_config = DynamicTaskConfig( - cpu=cpu, - memory=memory, - storage=storage_gib, - pod_config=_get_small_pod(), + cpu=cpu, memory=memory, storage=storage_gib, pod_config=_get_small_pod() ) return functools.partial(task, task_config=task_config, timeout=timeout) return functools.partial( - task, - task_config=_custom_task_config(cpu, memory, storage_gib), - timeout=timeout, - **kwargs, + task, task_config=_custom_task_config(cpu, memory, storage_gib), timeout=timeout, **kwargs ) @@ -610,12 +557,9 @@ def lustre_setup_task(): primary_container = V1Container( name="primary", resources=V1ResourceRequirements( - requests={"cpu": "500m", "memory": "500Mi"}, - limits={"cpu": "500m", "memory": "500Mi"}, + requests={"cpu": "500m", "memory": "500Mi"}, limits={"cpu": "500m", "memory": "500Mi"} ), - volume_mounts=[ - V1VolumeMount(mount_path="/nf-workdir", name="nextflow-workdir") - ], + volume_mounts=[V1VolumeMount(mount_path="/nf-workdir", name="nextflow-workdir")], ) task_config = Pod( @@ -659,6 +603,30 @@ def nextflow_runtime_task(cpu: int, memory: int, storage_gib: int = 50): return functools.partial(task, task_config=task_config) +def snakemake_runtime_task(*, cpu: int, memory: int, storage_gib: int = 50): + task_config = _custom_task_config(cpu, memory, storage_gib) + + task_config.pod_spec.automount_service_account_token = True + + assert len(task_config.pod_spec.containers) == 1 + task_config.pod_spec.containers[0].volume_mounts = [ + V1VolumeMount(mount_path="/snakemake-workdir", name="snakemake-workdir") + ] + + task_config.pod_spec.volumes = [ + V1Volume( + name="snakemake-workdir", + persistent_volume_claim=V1PersistentVolumeClaimVolumeSource( + # this value will be injected by flytepropeller + # ayush: this is also used by snakemake bc why not + claim_name="nextflow-pvc-placeholder" + ), + ) + ] + + return functools.partial(task, task_config=task_config) + + def _get_l40s_pod(instance_type: str, cpu: int, 
memory_gib: int, gpus: int) -> Pod: """Helper function to create L40s GPU pod configurations.""" primary_container = V1Container(name="primary") @@ -685,66 +653,50 @@ def _get_l40s_pod(instance_type: str, cpu: int, memory_gib: int, gpus: int) -> P return Pod( pod_spec=V1PodSpec( containers=[primary_container], - tolerations=[ - V1Toleration( - effect="NoSchedule", - key="ng", - value=instance_type - ) - ], + tolerations=[V1Toleration(effect="NoSchedule", key="ng", value=instance_type)], ), primary_container_name="primary", - annotations={ - "cluster-autoscaler.kubernetes.io/safe-to-evict": "false", - }, + annotations={"cluster-autoscaler.kubernetes.io/safe-to-evict": "false"}, ) g6e_xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-xlarge", cpu=4, memory_gib=32, gpus=1) + task, task_config=_get_l40s_pod("g6e-xlarge", cpu=4, memory_gib=32, gpus=1) ) """4 vCPUs, 32 GiB RAM, 1 L40s GPU""" g6e_2xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-2xlarge", cpu=8, memory_gib=64, gpus=1) + task, task_config=_get_l40s_pod("g6e-2xlarge", cpu=8, memory_gib=64, gpus=1) ) """8 vCPUs, 64 GiB RAM, 1 L40s GPU""" g6e_4xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-4xlarge", cpu=16, memory_gib=128, gpus=1) + task, task_config=_get_l40s_pod("g6e-4xlarge", cpu=16, memory_gib=128, gpus=1) ) """16 vCPUs, 128 GiB RAM, 1 L40s GPU""" g6e_8xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-8xlarge", cpu=32, memory_gib=256, gpus=1) + task, task_config=_get_l40s_pod("g6e-8xlarge", cpu=32, memory_gib=256, gpus=1) ) """32 vCPUs, 256 GiB RAM, 1 L40s GPU""" g6e_12xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-12xlarge", cpu=48, memory_gib=384, gpus=4) + task, task_config=_get_l40s_pod("g6e-12xlarge", cpu=48, memory_gib=384, gpus=4) ) """48 vCPUs, 384 GiB RAM, 4 L40s GPUs""" g6e_16xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-16xlarge", cpu=64, memory_gib=512, gpus=1) + task, task_config=_get_l40s_pod("g6e-16xlarge", cpu=64, memory_gib=512, gpus=1) ) """64 vCPUs, 512 GiB RAM, 1 L40s GPUs""" g6e_24xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-24xlarge", cpu=96, memory_gib=768, gpus=4) + task, task_config=_get_l40s_pod("g6e-24xlarge", cpu=96, memory_gib=768, gpus=4) ) """96 vCPUs, 768 GiB RAM, 4 L40s GPUs""" g6e_48xlarge_task = functools.partial( - task, - task_config=_get_l40s_pod("g6e-48xlarge", cpu=192, memory_gib=1536, gpus=8) + task, task_config=_get_l40s_pod("g6e-48xlarge", cpu=192, memory_gib=1536, gpus=8) ) """192 vCPUs, 1536 GiB RAM, 8 L40s GPUs""" diff --git a/src/latch/resources/workflow.py b/src/latch/resources/workflow.py index e0f741f17..0812514d1 100644 --- a/src/latch/resources/workflow.py +++ b/src/latch/resources/workflow.py @@ -11,12 +11,7 @@ from flytekit.core.interface import transform_function_to_interface from flytekit.core.workflow import PythonFunctionWorkflow -from latch.types.metadata import ( - LatchAuthor, - LatchMetadata, - LatchParameter, - NextflowMetadata, -) +from latch.types.metadata import LatchAuthor, LatchMetadata, LatchParameter, NextflowMetadata from latch_cli.utils import best_effort_display_name @@ -44,9 +39,7 @@ def _inject_metadata(f: Callable, metadata: LatchMetadata) -> None: # this weird Union thing is to ensure backwards compatibility, # so that when users call @workflow without any arguments or # parentheses, the workflow still serializes as expected -def workflow( - metadata: Union[LatchMetadata, 
Callable], -) -> Union[PythonFunctionWorkflow, Callable]: +def workflow(metadata: Union[LatchMetadata, Callable]) -> Union[PythonFunctionWorkflow, Callable]: if isinstance(metadata, Callable): f = metadata if f.__doc__ is None or "__metadata__:" not in f.__doc__: @@ -107,9 +100,7 @@ def decorator(f: Callable): raise click.exceptions.Exit(1) arg_origin = get_origin(args[0]) - valid = is_dataclass(args[0]) or ( - arg_origin is not None and is_dataclass(arg_origin) - ) + valid = is_dataclass(args[0]) or (arg_origin is not None and is_dataclass(arg_origin)) if not valid: click.secho( f"parameter marked as samplesheet is not valid: {name} " @@ -148,9 +139,13 @@ def decorator(f: Callable): return decorator -def nextflow_workflow( - metadata: NextflowMetadata, -) -> Callable[[Callable], PythonFunctionWorkflow]: +def nextflow_workflow(metadata: NextflowMetadata) -> Callable[[Callable], PythonFunctionWorkflow]: + metadata._non_standard["unpack_records"] = True + + return workflow(metadata) + + +def snakemake_workflow(metadata: NextflowMetadata) -> Callable[[Callable], PythonFunctionWorkflow]: metadata._non_standard["unpack_records"] = True return workflow(metadata) diff --git a/src/latch/types/directory.py b/src/latch/types/directory.py index 916d9287e..bb3f2a6f9 100644 --- a/src/latch/types/directory.py +++ b/src/latch/types/directory.py @@ -12,10 +12,7 @@ from flytekit.models.core.types import BlobType from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types.directory.types import ( - FlyteDirectory, - FlyteDirToMultipartBlobTransformer, -) +from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer from typing_extensions import Annotated from latch.ldata.path import LPath @@ -96,10 +93,7 @@ def task(dir: LatchFile): """ def __init__( - self, - path: Union[str, PathLike], - remote_path: Optional[PathLike] = None, - **kwargs, + self, path: Union[str, PathLike], remote_path: Optional[PathLike] = None, **kwargs ): if path is None: raise ValueError("Unable to instantiate LatchDir with None") @@ -253,10 +247,7 @@ def __repr__(self): if self.remote_path is None: return f"LatchDir({repr(format_path(self.local_path))})" - return ( - f"LatchDir({repr(self.path)}," - f" remote_path={repr(format_path(self.remote_path))})" - ) + return f"LatchDir({repr(self.path)}, remote_path={repr(format_path(self.remote_path))})" def __str__(self): if self.remote_path is None: @@ -294,9 +285,7 @@ def to_literal( if remote_directory is None: remote_directory = ctx.file_access.get_random_remote_directory() - put_res = ctx.file_access.put_data( - python_val.path, remote_directory, is_multipart=True - ) + put_res = ctx.file_access.put_data(python_val.path, remote_directory, is_multipart=True) if put_res is None: put_res = {} @@ -305,8 +294,7 @@ def to_literal( blob=Blob( metadata=BlobMetadata( type=BlobType( - format="", - dimensionality=BlobType.BlobDimensionality.MULTIPART, + format="", dimensionality=BlobType.BlobDimensionality.MULTIPART ) ), uri=python_val.remote_path, @@ -316,24 +304,17 @@ def to_literal( ) def to_python_value( - self, - ctx: FlyteContext, - lv: Literal, - expected_python_type: Union[type[LatchDir], PathLike], + self, ctx: FlyteContext, lv: Literal, expected_python_type: Union[type[LatchDir], PathLike] ) -> FlyteDirectory: uri = lv.scalar.blob.uri if expected_python_type is PathLike: - raise TypeError( - "Casting from Pathlike to LatchDir is currently not supported." 
- ) + raise TypeError("Casting from Pathlike to LatchDir is currently not supported.") while get_origin(expected_python_type) == Annotated: expected_python_type = get_args(expected_python_type)[0] if not issubclass(expected_python_type, LatchDir): - raise TypeError( - f"Neither os.PathLike nor LatchDir specified {expected_python_type}" - ) + raise TypeError(f"Neither os.PathLike nor LatchDir specified {expected_python_type}") # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. diff --git a/src/latch/types/file.py b/src/latch/types/file.py index 1930bb933..33c6c4115 100644 --- a/src/latch/types/file.py +++ b/src/latch/types/file.py @@ -160,10 +160,7 @@ def __repr__(self): if self.remote_path is None: return f"LatchFile({repr(format_path(self.local_path))})" - return ( - f"LatchFile({repr(self.path)}," - f" remote_path={repr(format_path(self.remote_path))})" - ) + return f"LatchFile({repr(self.path)}, remote_path={repr(format_path(self.remote_path))})" def __str__(self): if self.remote_path is None: @@ -200,9 +197,7 @@ def to_literal( if remote_path is None: remote_path = ctx.file_access.get_random_remote_path() - put_res = ctx.file_access.put_data( - python_val.path, remote_path, is_multipart=False - ) + put_res = ctx.file_access.put_data(python_val.path, remote_path, is_multipart=False) if put_res is None: put_res = {} @@ -210,9 +205,7 @@ def to_literal( scalar=Scalar( blob=Blob( metadata=BlobMetadata( - type=BlobType( - format="", dimensionality=BlobType.BlobDimensionality.SINGLE - ) + type=BlobType(format="", dimensionality=BlobType.BlobDimensionality.SINGLE) ), uri=python_val.remote_path, ) @@ -221,21 +214,14 @@ def to_literal( ) def to_python_value( - self, - ctx: FlyteContext, - lv: Literal, - expected_python_type: Union[type[LatchFile], PathLike], + self, ctx: FlyteContext, lv: Literal, expected_python_type: Union[type[LatchFile], PathLike] ) -> LatchFile: uri = lv.scalar.blob.uri if expected_python_type is PathLike: - raise TypeError( - "Casting from Pathlike to LatchFile is currently not supported." - ) + raise TypeError("Casting from Pathlike to LatchFile is currently not supported.") if not issubclass(expected_python_type, LatchFile): - raise TypeError( - f"Neither os.PathLike nor LatchFile specified {expected_python_type}" - ) + raise TypeError(f"Neither os.PathLike nor LatchFile specified {expected_python_type}") # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. 
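The hunks above collapse the multi-line `__repr__` and error strings in `LatchFile`/`LatchDir` into single lines without changing behavior. A quick sketch of what the new single-line reprs look like (illustrative only; the exact rendering produced by `format_path` depends on the path and workspace, so the comments show expected shape, not SDK-verified output):

```python
from latch.types.directory import LatchDir
from latch.types.file import LatchFile

f = LatchFile("/root/reads.fastq", remote_path="latch:///reads.fastq")
d = LatchDir("/root/outputs", remote_path="latch:///outputs")

# Both now render on one line, roughly:
#   LatchFile('/root/reads.fastq', remote_path='latch:///reads.fastq')
#   LatchDir('/root/outputs', remote_path='latch:///outputs')
print(repr(f))
print(repr(d))
```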
diff --git a/src/latch/types/metadata/__init__.py b/src/latch/types/metadata/__init__.py new file mode 100644 index 000000000..3dc713798 --- /dev/null +++ b/src/latch/types/metadata/__init__.py @@ -0,0 +1,66 @@ +# for backwards compatibility so as not to break existing imports + +from .flows import FlowBase, Fork, ForkBranch, Params, Section, Spoiler, Text, Title +from .latch import ( + LatchAppearance, + LatchAppearanceEnum, + LatchAppearanceType, + LatchAuthor, + LatchMetadata, + LatchParameter, + LatchRule, + Multiselect, + MultiselectOption, +) +from .nextflow import ( + NextflowMetadata, + NextflowParameter, + NextflowRuntimeResources, + _nextflow_metadata, + _samplesheet_constructor, + _samplesheet_repr, +) +from .snakemake import ( + DockerMetadata, + EnvironmentConfig, + FileMetadata, + SnakemakeFileMetadata, + SnakemakeFileParameter, + SnakemakeMetadata, + SnakemakeParameter, + _snakemake_metadata, +) + +__all__ = [ + "FlowBase", + "Fork", + "ForkBranch", + "Params", + "Section", + "Spoiler", + "Text", + "Title", + "LatchAppearance", + "LatchAppearanceEnum", + "LatchAppearanceType", + "LatchAuthor", + "LatchMetadata", + "LatchParameter", + "LatchRule", + "Multiselect", + "MultiselectOption", + "NextflowMetadata", + "NextflowParameter", + "NextflowRuntimeResources", + "_nextflow_metadata", + "_samplesheet_constructor", + "_samplesheet_repr", + "DockerMetadata", + "EnvironmentConfig", + "FileMetadata", + "SnakemakeFileMetadata", + "SnakemakeFileParameter", + "SnakemakeMetadata", + "SnakemakeParameter", + "_snakemake_metadata", +] diff --git a/src/latch/types/metadata/flows.py b/src/latch/types/metadata/flows.py new file mode 100644 index 000000000..3d0a21d8e --- /dev/null +++ b/src/latch/types/metadata/flows.py @@ -0,0 +1,197 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class FlowBase: + """Parent class for all flow elements + + Available flow elements: + + * :class:`~latch.types.metadata.Params` + + * :class:`~latch.types.metadata.Text` + + * :class:`~latch.types.metadata.Title` + + * :class:`~latch.types.metadata.Section` + + * :class:`~latch.types.metadata.Spoiler` + + * :class:`~latch.types.metadata.Fork` + """ + + +@dataclass(frozen=True, init=False) +class Section(FlowBase): + """Flow element that displays a child flow in a card with a given title + + Example: + + .. image:: ../assets/flow-example/flow_example_1.png + :alt: Example of a user interface for a workflow with a custom flow + + .. image:: ../assets/flow-example/flow_example_spoiler.png + :alt: Example of a spoiler flow element + + + The `LatchMetadata` for the example above can be defined as follows: + + .. code-block:: python + + from latch.types import LatchMetadata, LatchParameter + from latch.types.metadata import FlowBase, Section, Text, Params, Fork, Spoiler + from latch import workflow + + flow = [ + Section( + "Samples", + Text( + "Sample provided has to include an identifier for the sample (Sample name)" + " and one or two files corresponding to the reads (single-end or paired-end, respectively)" + ), + Fork( + "sample_fork", + "Choose read type", + paired_end=ForkBranch("Paired-end", Params("paired_end")), + single_end=ForkBranch("Single-end", Params("single_end")), + ), + ), + Section( + "Quality threshold", + Text( + "Select the quality value in which a base is qualified." 
+ "Quality value refers to a Phred quality score" + ), + Params("quality_threshold"), + ), + Spoiler( + "Output directory", + Text("Name of the output directory to send results to."), + Params("output_directory"), + ), + ] + + metadata = LatchMetadata( + display_name="fastp - Flow Tutorial", + author=LatchAuthor( + name="LatchBio", + ), + parameters={ + "sample_fork": LatchParameter(), + "paired_end": LatchParameter( + display_name="Paired-end reads", + description="FASTQ files", + batch_table_column=True, + ), + "single_end": LatchParameter( + display_name="Single-end reads", + description="FASTQ files", + batch_table_column=True, + ), + "output_directory": LatchParameter( + display_name="Output directory", + ), + }, + flow=flow, + ) + + @workflow(metadata) + def fastp( + sample_fork: str, + paired_end: PairedEnd, + single_end: Optional[SingleEnd] = None, + output_directory: str = "fastp_results", + ) -> LatchDir: + ... + """ + + section: str + """Title of the section""" + flow: list[FlowBase] + """Flow displayed in the section card""" + + def __init__(self, section: str, *flow: FlowBase): + object.__setattr__(self, "section", section) + object.__setattr__(self, "flow", list(flow)) + + +@dataclass(frozen=True) +class Text(FlowBase): + """Flow element that displays a markdown string""" + + text: str + """Markdown body text""" + + +@dataclass(frozen=True) +class Title(FlowBase): + """Flow element that displays a markdown title""" + + title: str + """Markdown title text""" + + +@dataclass(frozen=True, init=False) +class Params(FlowBase): + """Flow element that displays parameter widgets""" + + params: list[str] + """ + Names of parameters whose widgets will be displayed. + Order is preserved. Duplicates are allowed + """ + + def __init__(self, *args: str): + object.__setattr__(self, "params", list(args)) + + +@dataclass(frozen=True, init=False) +class Spoiler(FlowBase): + """Flow element that displays a collapsible card with a given title""" + + spoiler: str + """Title of the spoiler""" + flow: list[FlowBase] + """Flow displayed in the spoiler card""" + + def __init__(self, spoiler: str, *flow: FlowBase): + object.__setattr__(self, "spoiler", spoiler) + object.__setattr__(self, "flow", list(flow)) + + +@dataclass(frozen=True, init=False) +class ForkBranch: + """Definition of a :class:`~latch.types.metadata.Fork` branch""" + + display_name: str + """String displayed in the fork's multibutton""" + flow: list[FlowBase] + """Child flow displayed in the fork card when the branch is active""" + + def __init__(self, display_name: str, *flow: FlowBase): + object.__setattr__(self, "display_name", display_name) + object.__setattr__(self, "flow", list(flow)) + + +@dataclass(frozen=True, init=False) +class Fork(FlowBase): + """Flow element that displays a set of mutually exclusive alternatives + + Displays a title, followed by a horizontal multibutton for selecting a branch, + then a card for the active branch + """ + + fork: str + """Name of a `str`-typed parameter to store the active branch's key""" + display_name: str + """Title shown above the fork selector""" + flows: dict[str, ForkBranch] + """ + Mapping between branch keys to branch definitions. 
+    Order determines the order of options in the multibutton
+    """
+
+    def __init__(self, fork: str, display_name: str, **flows: ForkBranch):
+        object.__setattr__(self, "fork", fork)
+        object.__setattr__(self, "display_name", display_name)
+        object.__setattr__(self, "flows", flows)
diff --git a/src/latch/types/metadata/latch.py b/src/latch/types/metadata/latch.py
new file mode 100644
index 000000000..b4c45f064
--- /dev/null
+++ b/src/latch/types/metadata/latch.py
@@ -0,0 +1,293 @@
+from __future__ import annotations
+
+import re
+from dataclasses import asdict, dataclass, field
+from enum import Enum
+from textwrap import indent
+from typing import TYPE_CHECKING, Any, Union
+
+import yaml
+from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:
+    from .flows import FlowBase
+
+
+@dataclass
+class LatchRule:
+    """Class describing a rule that a parameter input must follow"""
+
+    regex: str
+    """A string regular expression which inputs must match"""
+    message: str
+    """The message to render when an input does not match the regex"""
+
+    @property
+    def dict(self):
+        return asdict(self)
+
+    def __post_init__(self):
+        try:
+            re.compile(self.regex)
+        except re.error as e:
+            raise ValueError(f"Malformed regex {self.regex}: {e.msg}") from e
+
+
+class LatchAppearanceEnum(Enum):
+    line = "line"
+    paragraph = "paragraph"
+
+
+@dataclass(frozen=True)
+class MultiselectOption:
+    name: str
+    value: object
+
+
+@dataclass(frozen=True)
+class Multiselect:
+    options: list[MultiselectOption] = field(default_factory=list)
+    allow_custom: bool = False
+
+
+# backwards compatibility
+LatchAppearanceType = LatchAppearanceEnum
+
+LatchAppearance: TypeAlias = Union[LatchAppearanceEnum, Multiselect]
+
+
+@dataclass
+class LatchAuthor:
+    """Class describing metadata about the workflow author"""
+
+    name: str | None = None
+    """The name of the author"""
+    email: str | None = None
+    """The email of the author"""
+    github: str | None = None
+    """A link to the github profile of the author"""
+
+
+@dataclass
+class LatchParameter:
+    """Class for organizing parameter metadata"""
+
+    display_name: str | None = None
+    """The name used to display the parameter on Latch Console"""
+    description: str | None = None
+    """The description of the parameter's role in the workflow"""
+    hidden: bool = False
+    """Whether or not the parameter should be hidden by default"""
+    section_title: str | None = None
+    """Title of a new section to begin at this parameter, if any"""
+    placeholder: str | None = None
+    """
+    What should be rendered as a placeholder in the input box
+    of the parameter before any value is inputted.
+    """
+    comment: str | None = None
+    """Any comment on the parameter itself"""
+    output: bool = False
+    """
+    Whether or not this parameter is an output (used to disable
+    path validation before launching a workflow)
+    """
+    batch_table_column: bool = False
+    """
+    Whether this parameter should be given a column in the batch
+    table at the top of the workflow inputs
+    """
+    allow_dir: bool = True
+    """
+    Whether or not this parameter should accept directories in the UI
+    """
+    allow_file: bool = True
+    """
+    Whether or not this parameter should accept files in the UI.
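+    If both allow_file and allow_dir are False, the parameter is rendered with file_type NONE and accepts neither (see the dict property below).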
+ """ + appearance_type: LatchAppearance = LatchAppearanceEnum.line + """ + Whether the parameter should be rendered as a line or paragraph + (must be exactly one of either LatchAppearanceType.line or + LatchAppearanceType.paragraph) + """ + rules: list[LatchRule] = field(default_factory=list) + """ + A list of LatchRule objects that inputs to this parameter must follow + """ + detail: str | None = None + samplesheet: bool | None = None + """ + Use samplesheet input UI. Allows importing from Latch Registry. + Parameter type must be a list of dataclasses + """ + allowed_tables: list[int] | None = None + """ + If using the samplesheet component, specify a set of Registry Tables (by ID) to allow selection from. + If not provided, all Tables are allowed. + + Only has an effect if `samplesheet=True`. + """ + _custom_ingestion: str | None = None + + def __str__(self): + metadata_yaml = yaml.safe_dump(self.dict, sort_keys=False) + if self.description is not None: + return f"{self.description}\n{metadata_yaml}" + return metadata_yaml + + @property + def dict(self): + parameter_dict: dict[str, Any] = {"display_name": self.display_name} + + if self.output: + parameter_dict["output"] = True + if self.batch_table_column: + parameter_dict["batch_table_column"] = True + if self.samplesheet: + parameter_dict["samplesheet"] = True + if self.allowed_tables is not None: + parameter_dict["allowed_tables"] = [str(x) for x in self.allowed_tables] + + temp_dict: dict[str, Any] = {"hidden": self.hidden} + if self.section_title is not None: + temp_dict["section_title"] = self.section_title + if self._custom_ingestion is not None: + temp_dict["custom_ingestion"] = self._custom_ingestion + + parameter_dict["_tmp"] = temp_dict + + appearance_dict: dict[str, Any] + if isinstance(self.appearance_type, LatchAppearanceEnum): + appearance_dict = {"type": self.appearance_type.value} + elif isinstance(self.appearance_type, Multiselect): + appearance_dict = {"multiselect": asdict(self.appearance_type)} + else: + appearance_dict = {} + + if self.placeholder is not None: + appearance_dict["placeholder"] = self.placeholder + if self.comment is not None: + appearance_dict["comment"] = self.comment + if self.detail is not None: + appearance_dict["detail"] = self.detail + + appearance_dict["file_type"] = ( + "ANY" + if self.allow_file and self.allow_dir + else "FILE" + if self.allow_file + else "DIR" + if self.allow_dir + else "NONE" + ) + + parameter_dict["appearance"] = appearance_dict + + if len(self.rules) > 0: + parameter_dict["rules"] = [rule.dict for rule in self.rules] + + return {"__metadata__": parameter_dict} + + +@dataclass +class LatchMetadata: + """Class for organizing workflow metadata + + Example: + + .. code-block:: python + + from latch.types import LatchMetadata, LatchAuthor, LatchRule, LatchAppearanceType + + metadata = LatchMetadata( + parameters={ + "read1": LatchParameter( + display_name="Read 1", + description="Paired-end read 1 file to be assembled.", + hidden=True, + section_title="Sample Reads", + placeholder="Select a file", + comment="This is a comment", + output=False, + appearance_type=LatchAppearanceType.paragraph, + rules=[ + LatchRule( + regex="(.fasta|.fa|.faa|.fas)$", + message="Only .fasta, .fa, .fas, or .faa extensions are valid" + ) + ], + batch_table_column=True, # Show this parameter in batched mode. 
+ # The below parameters will be displayed on the side bar of the workflow + documentation="https://github.com/author/my_workflow/README.md", + author=LatchAuthor( + name="Workflow Author", + email="licensing@company.com", + github="https://github.com/author", + ), + repository="https://github.com/author/my_workflow", + license="MIT", + # If the workflow is public, display it under the defined categories on Latch to be more easily discovered by users + tags=["NGS", "MAG"], + ), + ) + + @workflow(metadata) + def wf(read1: LatchFile): + ... + + """ + + display_name: str + """The human-readable name of the workflow""" + author: LatchAuthor + """ A `LatchAuthor` object that describes the author of the workflow""" + documentation: str | None = None + """A link to documentation for the workflow itself""" + repository: str | None = None + """A link to the repository where the code for the workflow is hosted""" + license: str = "MIT" + """A SPDX identifier""" + parameters: dict[str, LatchParameter] = field(default_factory=dict) + """A dictionary mapping parameter names (strings) to `LatchParameter` objects""" + wiki_url: str | None = None + video_tutorial: str | None = None + tags: list[str] = field(default_factory=list) + flow: list[FlowBase] = field(default_factory=list) + + no_standard_bulk_execution: bool = False + """ + Disable the standard CSV-based bulk execution. Intended for workflows that + support an alternative way of processing bulk data e.g. using a samplesheet + parameter + """ + _non_standard: dict[str, object] = field(default_factory=dict) + + @property + def dict(self): + metadata_dict = asdict(self) + # remove parameters since that will be handled by each parameters' dict() method + del metadata_dict["parameters"] + metadata_dict["license"] = {"id": self.license} + + # flows override all other rendering, so disable them entirely if not provided + if len(self.flow) == 0: + del metadata_dict["flow"] + + for key in self._non_standard: + metadata_dict[key] = self._non_standard[key] + + return {"__metadata__": metadata_dict} + + def __str__(self): + def _parameter_str(t: tuple[str, LatchParameter]): + parameter_name, parameter_meta = t + return f"{parameter_name}:\n" + indent( + str(parameter_meta), " ", lambda _: True + ) + + metadata_yaml = yaml.safe_dump(self.dict, sort_keys=False) + parameter_yaml = "".join(map(_parameter_str, self.parameters.items())) + return ( + metadata_yaml + "Args:\n" + indent(parameter_yaml, " ", lambda _: True) + ).strip("\n ") diff --git a/src/latch/types/metadata/nextflow.py b/src/latch/types/metadata/nextflow.py new file mode 100644 index 000000000..7e1cf4b4a --- /dev/null +++ b/src/latch/types/metadata/nextflow.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import csv +import functools +from dataclasses import dataclass, field, fields, is_dataclass +from enum import Enum +from pathlib import Path +from textwrap import dedent +from typing import Any, Callable, Generic, Literal, get_args, get_origin + +import click + +from latch_cli.utils import identifier_suffix_from_str + +from ..directory import LatchDir, LatchOutputDir +from ..file import LatchFile +from .latch import LatchMetadata, LatchParameter +from .utils import DC, P + + +@dataclass +class NextflowParameter(Generic[P], LatchParameter): + type: type[P] | None = None + """ + The python type of the parameter. 
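+    For samplesheet parameters this must be a list of dataclasses; when results_paths is set it must be LatchDir or LatchOutputDir (both enforced in __post_init__ below).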
+ """ + default: P | None = None + """ + Default value of the parameter + """ + + samplesheet_type: Literal["csv", "tsv", None] = None + """ + The type of samplesheet to construct from the input parameter. + + Only used if the provided parameter is a samplesheet (samplesheet=True) + """ + samplesheet_constructor: Callable[[P], Path] | None = None + """ + A custom samplesheet constructor. + + Should return the path of the constructed samplesheet. If samplesheet_type is also specified, this takes precedence. + Only used if the provided parameter is a samplesheet (samplesheet=True) + """ + results_paths: list[Path] | None = None + """ + Output sub-paths that will be exposed in the UI under the "Results" tab on the workflow execution page. + + Only valid where the `type` attribute is a LatchDir + """ + + def __post_init__(self): + if self.results_paths is not None and self.type not in { + LatchDir, + LatchOutputDir, + }: + click.secho( + "`results_paths` attribute can only be defined for parameters" + " of type `LatchDir`.", + fg="red", + ) + raise click.exceptions.Exit(1) + + if not self.samplesheet or self.samplesheet_constructor is not None: + return + + t = self.type + if get_origin(t) is not list or not is_dataclass(get_args(t)[0]): + click.secho("Samplesheets must be a list of dataclasses.", fg="red") + raise click.exceptions.Exit(1) + + if self.samplesheet_type is not None: + delim = "," if self.samplesheet_type == "csv" else "\t" + self.samplesheet_constructor = functools.partial( + _samplesheet_constructor, t=get_args(self.type)[0], delim=delim + ) + return + + click.secho( + dedent("""\ + A Samplesheet constructor is required for a samplesheet parameter. Please either provide a value for + `samplesheet_type` or provide a custom callable to the `samplesheet_constructor` argument. + """), + fg="red", + ) + raise click.exceptions.Exit(1) + + +def _samplesheet_repr(v: Any) -> str: + if v is None: + return "" + if isinstance(v, (LatchFile, LatchDir)): + return str(v.remote_path) + if isinstance(v, Enum): + return v.value + + return str(v) + + +def _samplesheet_constructor(samples: list[DC], t: DC, delim: str = ",") -> Path: + samplesheet = Path("samplesheet.csv") + + with samplesheet.open("w") as f: + writer = csv.DictWriter(f, [f.name for f in fields(t)], delimiter=delim) + writer.writeheader() + + for sample in samples: + row_data = { + f.name: _samplesheet_repr(getattr(sample, f.name)) + for f in fields(sample) + } + writer.writerow(row_data) + + return samplesheet + + +@dataclass(frozen=True) +class NextflowRuntimeResources: + """Resources for Nextflow runtime tasks""" + + cpus: int | None = 4 + """ + Number of CPUs required for the task + """ + memory: int | None = 8 + """ + Memory required for the task in GiB + """ + storage_gib: int | None = 100 + """ + Storage required for the task in GiB + """ + storage_expiration_hours: int = 7 * 24 + """ + Number of hours after execution failure that workdir should be retained in EFS. + Warning: Increasing this number will increase your Nextflow Storage costs. 
+ """ + + +@dataclass +class NextflowMetadata(LatchMetadata): + name: str | None = None + """ + Name of the workflow + """ + parameters: dict[str, NextflowParameter[Any]] = field(default_factory=dict) + """ + A dictionary mapping parameter names (strings) to `NextflowParameter` objects + """ + runtime_resources: NextflowRuntimeResources = field( + default_factory=NextflowRuntimeResources + ) + """ + Resources (cpu/memory/storage) for Nextflow runtime task + """ + execution_profiles: list[str] = field(default_factory=list) + """ + Execution config profiles to expose to users in the Latch console + """ + log_dir: LatchDir | None = None + """ + Directory to dump Nextflow logs + """ + upload_command_logs: bool = False + """ + Upload .command.* logs to Latch Data after each task execution + """ + about_page_path: Path | None = None + """ + Path to a markdown file containing information about the pipeline - rendered in the About page. + """ + + def validate(self): + if self.about_page_path is not None and not isinstance( + self.about_page_path, Path + ): # type: ignore + click.secho( + f"`about_page_path` parameter ({self.about_page_path}) must be a" + " Path object.", + fg="red", + ) + + @property + def dict(self): + d = super().dict + del d["__metadata__"]["about_page_path"] + return d + + def __post_init__(self): + self.validate() + + if self.name is None: + if self.display_name is None: + click.secho( + "Name or display_name must be provided in metadata", fg="red" + ) + self.name = f"nf_{identifier_suffix_from_str(self.display_name.lower())}" + else: + self.name = identifier_suffix_from_str(self.name) + + global _nextflow_metadata + _nextflow_metadata = self + + +_nextflow_metadata: NextflowMetadata | None = None diff --git a/src/latch/types/metadata/snakemake.py b/src/latch/types/metadata/snakemake.py new file mode 100644 index 000000000..298a3c339 --- /dev/null +++ b/src/latch/types/metadata/snakemake.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, Generic, Literal, Union + +import click +from typing_extensions import TypeAlias + +from latch_cli.snakemake.config.utils import validate_snakemake_type +from latch_cli.utils import identifier_suffix_from_str + +from ..directory import LatchDir +from ..file import LatchFile +from .latch import LatchMetadata, LatchParameter +from .utils import P, ParameterType + + +@dataclass +class SnakemakeParameter(LatchParameter, Generic[P]): + type: type[P] | None = None + """ + The python type of the parameter. + """ + default: P | None = None + """ + Optional default value for this parameter + """ + + samplesheet_type: Literal["csv", "tsv", None] = None + """ + The type of samplesheet to construct from the input parameter. + + Only used if the provided parameter is a samplesheet (samplesheet=True) + """ + samplesheet_constructor: Callable[[P], Path] | None = None + """ + A custom samplesheet constructor. + + Should return the path of the constructed samplesheet. If samplesheet_type is also specified, this takes precedence. 
+ Only used if the provided parameter is a samplesheet (samplesheet=True) + """ + + def __post_init__(self): + if self.type is None: + click.secho("All SnakemakeParameter objects must specify a type.", fg="red") + raise click.exceptions.Exit(1) + + +@dataclass +class SnakemakeFileParameter(SnakemakeParameter[Union[LatchFile, LatchDir]]): + """Deprecated: use `file_metadata` keyword in `SnakemakeMetadata` instead""" + + type: type[LatchFile | LatchDir] | None = None + """ + The python type of the parameter. + """ + path: Path | None = None + """ + The path where the file passed to this parameter will be copied. + """ + config: bool = False + """ + Whether or not the file path is exposed in the Snakemake config + """ + download: bool = False + """ + Whether or not the file is downloaded in the JIT step + """ + + +@dataclass +class SnakemakeFileMetadata: + path: Path + """ + The local path where the file passed to this parameter will be copied + """ + config: bool = False + """ + If `True`, expose the file in the Snakemake config + """ + download: bool = False + """ + If `True`, download the file in the JIT step + """ + + +@dataclass +class DockerMetadata: + """Class describing credentials for private docker repositories""" + + username: str + """ + The account username for the private repository + """ + secret_name: str + """ + The name of the Latch Secret that contains the password for the private repository + """ + + +@dataclass +class EnvironmentConfig: + """Class describing environment for spawning Snakemake tasks""" + + use_conda: bool = False + """ + Use Snakemake `conda` directive to spawn tasks in conda environments + """ + use_container: bool = False + """ + Use Snakemake `container` directive to spawn tasks in Docker containers + """ + container_args: list[str] = field(default_factory=list) + """ + Additional arguments to use when running Docker containers + """ + + +FileMetadata: TypeAlias = dict[str, Union[SnakemakeFileMetadata, "FileMetadata"]] + + +@dataclass +class SnakemakeMetadata(LatchMetadata): + """Class for organizing Snakemake workflow metadata""" + + output_dir: LatchDir | None = None + """ + Directory for snakemake workflow outputs + """ + name: str | None = None + """ + Name of the workflow + """ + docker_metadata: DockerMetadata | None = None + """ + Credentials configuration for private docker repositories + """ + env_config: EnvironmentConfig = field(default_factory=EnvironmentConfig) + """ + Environment configuration for spawning Snakemake tasks + """ + parameters: dict[str, SnakemakeParameter[ParameterType]] = field(default_factory=dict) + """ + A dictionary mapping parameter names (strings) to `SnakemakeParameter` objects + """ + file_metadata: FileMetadata = field(default_factory=dict) + """ + A dictionary mapping parameter names to `SnakemakeFileMetadata` objects + """ + cores: int = 4 + """ + Number of cores to use for Snakemake tasks (equivalent of Snakemake's `--cores` flag) + """ + about_page_content: Path | None = None + """ + Path to a markdown file containing information about the pipeline - rendered in the About page. 
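+    Must be a Path object; enforced in validate().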
+ """ + + def validate(self): + if self.about_page_content is not None: + if not isinstance(self.about_page_content, Path): + click.secho( + f"`about_page_content` parameter ({self.about_page_content}) must" + " be a Path object.", + fg="red", + ) + raise click.exceptions.Exit(1) + + for name, param in self.parameters.items(): + if param.default is None: + continue + try: + validate_snakemake_type(name, param.type, param.default) + except ValueError as e: + click.secho(e, fg="red") + raise click.exceptions.Exit(1) from e + + def __post_init__(self): + self.validate() + + if self.name is None: + self.name = f"snakemake_{identifier_suffix_from_str(self.display_name.lower())}" + + global _snakemake_metadata + _snakemake_metadata = self + + @property + def dict(self): + d = super().dict + # ayush: Paths aren't JSON serializable but ribosome doesn't need it anyway so we can just delete it + del d["__metadata__"]["about_page_content"] + return d + + +_snakemake_metadata: SnakemakeMetadata | None = None diff --git a/src/latch/types/metadata/snakemake_v2.py b/src/latch/types/metadata/snakemake_v2.py new file mode 100644 index 000000000..44f53f59b --- /dev/null +++ b/src/latch/types/metadata/snakemake_v2.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import click + +from latch_cli.utils import identifier_suffix_from_str + +from .latch import LatchMetadata +from .snakemake import SnakemakeParameter # noqa: TCH001 + + +@dataclass(frozen=True) +class SnakemakeRuntimeResources: + """Resources for Snakemake runtime tasks""" + + cpus: int = 1 + """ + Number of CPUs required for the task + """ + memory: int = 2 + """ + Memory required for the task in GiB + """ + storage_gib: int = 50 + """ + Storage required for the task in GiB + """ + + +@dataclass +class SnakemakeV2Metadata(LatchMetadata): + parameters: dict[str, SnakemakeParameter[Any]] = field(default_factory=dict) + """ + A dictionary mapping parameter names (strings) to `SnakemakeParameter` objects + """ + about_page_path: Path | None = None + """ + Path to a markdown file containing information about the pipeline - rendered in the About page. 
+ """ + runtime_resources: SnakemakeRuntimeResources = field(default_factory=SnakemakeRuntimeResources) + + def validate(self): + if self.about_page_path is not None and not isinstance(self.about_page_path, Path): + click.secho( + f"SnakemakeV2Metadata.about_page_path ({self.about_page_path}) must be a" + " `Path` object.", + fg="red", + ) + raise click.exceptions.Exit(1) + + def __post_init__(self): + self.validate() + + self.name = identifier_suffix_from_str(f"snakemake_v2_{self.display_name}".lower()) + + global _snakemake_v2_metadata + _snakemake_v2_metadata = self + + @property + def dict(self): + d = super().dict + del d["__metadata__"]["about_page_path"] + return d + + +_snakemake_v2_metadata: SnakemakeV2Metadata | None = None diff --git a/src/latch/types/metadata/utils.py b/src/latch/types/metadata/utils.py new file mode 100644 index 000000000..f2bcb1374 --- /dev/null +++ b/src/latch/types/metadata/utils.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from collections.abc import Collection +from enum import Enum +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Protocol, TypeVar, Union + +from typing_extensions import TypeAlias + +from ..directory import LatchDir +from ..file import LatchFile + +if TYPE_CHECKING: + from dataclasses import Field + + +# https://stackoverflow.com/questions/54668000/type-hint-for-an-instance-of-a-non-specific-dataclass +class _IsDataclass(Protocol): + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] + + +DC = TypeVar("DC", bound=_IsDataclass) + +ParameterType: TypeAlias = Union[ + None, + int, + float, + str, + bool, + LatchFile, + LatchDir, + Enum, + _IsDataclass, + Collection["ParameterType"], +] + +P = TypeVar("P", bound=ParameterType) diff --git a/src/latch/utils.py b/src/latch/utils.py index 23d5bb334..a6656261e 100644 --- a/src/latch/utils.py +++ b/src/latch/utils.py @@ -1,9 +1,12 @@ import itertools import os -from typing import Dict, TypedDict +from dataclasses import dataclass +from typing import ClassVar, Dict, Optional, TypedDict import gql import jwt +from typing_extensions import Self + from latch_sdk_config.user import user_config from latch_sdk_gql.execute import execute @@ -112,9 +115,7 @@ def get_workspaces() -> Dict[str, WSInfo]: owned_org_teams = [x["teamInfosByOrgId"]["nodes"] for x in res["orgInfos"]["nodes"]] owned_org_teams = list(itertools.chain(*owned_org_teams)) - member_org_teams = [ - x["org"]["teamInfosByOrgId"]["nodes"] for x in res["orgMembers"]["nodes"] - ] + member_org_teams = [x["org"]["teamInfosByOrgId"]["nodes"] for x in res["orgMembers"]["nodes"]] member_org_teams = list(itertools.chain(*member_org_teams)) default_account = ( @@ -130,11 +131,7 @@ def get_workspaces() -> Dict[str, WSInfo]: ) for x in owned_teams + member_teams - + ( - [res["teamInfoByAccountId"]] - if res["teamInfoByAccountId"] is not None - else [] - ) + + ([res["teamInfoByAccountId"]] if res["teamInfoByAccountId"] is not None else []) + owned_org_teams + member_org_teams } @@ -167,7 +164,7 @@ def current_workspace() -> str: } } } - """), + """) )["accountInfoCurrent"] ws = res["id"] @@ -180,3 +177,20 @@ def current_workspace() -> str: class NotFoundError(ValueError): ... + + +@dataclass(frozen=True) +class Singleton: + """Base class for singleton objects. + + The constructor returns a referentially identical instance each call. 
That is, + `Singleton() is Singleton()` + """ + + _singleton: ClassVar[Optional[Self]] = None + + def __new__(cls) -> Self: + if cls._singleton is None: + cls._singleton = super().__new__(cls) + + return cls._singleton diff --git a/src/latch_cli/centromere/ast_parsing.py b/src/latch_cli/centromere/ast_parsing.py index 1df006cec..db9c746da 100644 --- a/src/latch_cli/centromere/ast_parsing.py +++ b/src/latch_cli/centromere/ast_parsing.py @@ -41,7 +41,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef): # noqa: N802 # 3. save fully qualified name for tasks (need to parse based on import graph) for decorator in node.decorator_list: if isinstance(decorator, ast.Name): - if decorator.id in {"workflow", "nextflow_workflow"}: + if decorator.id in {"workflow", "nextflow_workflow", "snakemake_workflow"}: self.flyte_objects.append(FlyteObject("workflow", fqn)) elif decorator.id in task_decorators: self.flyte_objects.append(FlyteObject("task", fqn)) @@ -53,10 +53,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef): # noqa: N802 if func.id not in task_decorators and func.id not in { "workflow", "nextflow_workflow", + "snakemake_workflow", }: continue - if func.id in {"workflow", "nextflow_workflow"}: + if func.id in {"workflow", "nextflow_workflow", "snakemake_workflow"}: self.flyte_objects.append(FlyteObject("workflow", fqn)) continue @@ -105,9 +106,7 @@ def get_flyte_objects(module: Path) -> list[FlyteObject]: if file.suffix != ".py": continue - module_name = str(file.with_suffix("").relative_to(module.parent)).replace( - os.sep, "." - ) + module_name = str(file.with_suffix("").relative_to(module.parent)).replace(os.sep, ".") v = Visitor(file, module_name) @@ -115,9 +114,7 @@ def get_flyte_objects(module: Path) -> list[FlyteObject]: parsed = ast.parse(file.read_text(), filename=file) except SyntaxError as e: traceback.print_exc() - click.secho( - "\nRegistration failed due to a syntax error (see above)", fg="red" - ) + click.secho("\nRegistration failed due to a syntax error (see above)", fg="red") raise click.exceptions.Exit(1) from e v.visit(parsed) diff --git a/src/latch_cli/centromere/ctx.py b/src/latch_cli/centromere/ctx.py index da07ab1d7..3ce1e295c 100644 --- a/src/latch_cli/centromere/ctx.py +++ b/src/latch_cli/centromere/ctx.py @@ -157,8 +157,7 @@ def __init__( pass except Exception as e: click.secho( - "WARN: Exception occurred while getting git hash from" - f" {self.pkg_root}: {e}", + f"WARN: Exception occurred while getting git hash from {self.pkg_root}: {e}", fg="yellow", ) @@ -334,10 +333,12 @@ def __init__( import subprocess if system == "Linux": - res = subprocess.run([ - "xdg-open", - new_meta, - ]).returncode + res = subprocess.run( + [ + "xdg-open", + new_meta, + ] + ).returncode elif system == "Darwin": res = subprocess.run(["open", new_meta]).returncode elif system == "Windows": @@ -457,6 +458,7 @@ def _patched_create_paramiko_client(self, base_url): else: self.dkr_client = _construct_dkr_client() + self.remote_conn_info = None except (Exception, KeyboardInterrupt) as e: self.cleanup() raise e diff --git a/src/latch_cli/docker_utils/__init__.py b/src/latch_cli/docker_utils/__init__.py index e82462e6d..1d05edb25 100644 --- a/src/latch_cli/docker_utils/__init__.py +++ b/src/latch_cli/docker_utils/__init__.py @@ -52,11 +52,6 @@ class DockerfileBuilder: direnv: Optional[Path] = None def get_prologue(self): - if self.wf_type == WorkflowType.snakemake: - library_name = '"latch[snakemake]"' - else: - library_name = "latch" - self.commands.append( DockerCmdBlock( 
comment="Prologue", @@ -83,21 +78,25 @@ def get_prologue(self): "env LANG='en_US.UTF-8'", "", "arg DEBIAN_FRONTEND=noninteractive", - "", - "# Latch SDK", - "# DO NOT REMOVE", - f"run pip install {library_name}=={self.config.latch_version}", - "run mkdir /opt/latch", ], order=DockerCmdBlockOrder.precopy, ) ) def get_epilogue(self): + if self.wf_type == WorkflowType.snakemake: + library_name = '"latch[snakemake]"' + else: + library_name = "latch" + self.commands.append( DockerCmdBlock( comment="Epilogue", commands=[ + "", + "# Latch SDK", + "# DO NOT REMOVE", + f"run pip install {library_name}=={self.config.latch_version}", "", "# Latch workflow registration metadata", "# DO NOT CHANGE", @@ -294,10 +293,7 @@ def infer_env_commands(self): return click.echo( - " ".join([ - click.style(f"{self.direnv.name}:", bold=True), - "Environment variable setup", - ]) + " ".join([click.style(f"{self.direnv.name}:", bold=True), "Environment variable setup"]) ) envs: list[str] = [] for line in self.direnv.read_text().splitlines(): @@ -327,15 +323,6 @@ def infer_dependencies(self): def get_copy_file_commands(self): cmd = ["copy . /root/"] - if self.wf_type == WorkflowType.snakemake: - cmd.extend([ - "", - "# Latch snakemake workflow entrypoint", - "# DO NOT CHANGE", - "", - "copy .latch/snakemake_jit_entrypoint.py /root/snakemake_jit_entrypoint.py", - ]) - self.commands.append( DockerCmdBlock( comment="Copy workflow data (use .dockerignore to skip files)", @@ -351,20 +338,13 @@ def generate(self, *, dest: Optional[Path] = None, overwrite: bool = False): if ( dest.exists() and not overwrite - and not ( - click.confirm(f"Dockerfile already exists at `{dest}`. Overwrite?") - ) + and not (click.confirm(f"Dockerfile already exists at `{dest}`. Overwrite?")) ): return click.secho("Generating Dockerfile", bold=True) - click.echo( - " ".join([ - click.style("Base image:", fg="bright_blue"), - self.config.base_image, - ]) - ) + click.echo(" ".join([click.style("Base image:", fg="bright_blue"), self.config.base_image])) click.echo( " ".join([ click.style("Latch SDK version:", fg="bright_blue"), @@ -398,15 +378,10 @@ def generate(self, *, dest: Optional[Path] = None, overwrite: bool = False): click.secho(f"Successfully generated dockerfile `{dest}`", fg="green") -def generate_dockerignore( - dest: Path, *, wf_type: WorkflowType, overwrite: bool = False -) -> None: +def generate_dockerignore(dest: Path, *, wf_type: WorkflowType, overwrite: bool = False) -> None: if dest.exists(): if dest.is_dir(): - click.secho( - f".dockerignore already exists at `{dest}` and is a directory.", - fg="red", - ) + click.secho(f".dockerignore already exists at `{dest}` and is a directory.", fg="red") raise click.exceptions.Exit(1) if not overwrite and not click.confirm( @@ -424,9 +399,7 @@ def generate_dockerignore( click.secho(f"Successfully generated .dockerignore `{dest}`", fg="green") -def get_default_dockerfile( - pkg_root: Path, *, wf_type: WorkflowType, overwrite: bool = False -): +def get_default_dockerfile(pkg_root: Path, *, wf_type: WorkflowType, overwrite: bool = False): default_dockerfile = pkg_root / "Dockerfile" config = get_or_create_workflow_config( diff --git a/src/latch_cli/main.py b/src/latch_cli/main.py index 69a75dc42..6ecd4f830 100644 --- a/src/latch_cli/main.py +++ b/src/latch_cli/main.py @@ -17,6 +17,7 @@ from latch.ldata._transfer.progress import Progress as _Progress # noqa: PLC2701 from latch.utils import current_workspace from latch_cli.click_utils import EnumChoice +from latch_cli.docker_utils import 
DockerfileBuilder from latch_cli.exceptions.handler import CrashHandler from latch_cli.services.cp.autocomplete import complete as cp_complete from latch_cli.services.cp.autocomplete import remote_complete @@ -30,7 +31,7 @@ get_local_package_version, hash_directory, ) -from latch_cli.workflow_config import BaseImageOptions +from latch_cli.workflow_config import BaseImageOptions, get_or_create_workflow_config from latch_sdk_gql.execute import execute as gql_execute latch_cli.click_utils.patch() @@ -95,10 +96,7 @@ def main(): @main.command("login") @click.option( - "--connection", - type=str, - default=None, - help="Specific AuthO connection name e.g. for SSO.", + "--connection", type=str, default=None, help="Specific AuthO connection name e.g. for SSO." ) def login(connection: Optional[str]): """Manually login to Latch.""" @@ -176,9 +174,7 @@ def init( @main.command("dockerfile") -@click.argument( - "pkg_root", type=click.Path(exists=True, file_okay=False, path_type=Path) -) +@click.argument("pkg_root", type=click.Path(exists=True, file_okay=False, path_type=Path)) @click.option( "-s", "--snakemake", @@ -431,23 +427,14 @@ def generate_metadata( raise click.exceptions.Exit(1) generate_metadata( - config_file, - metadata_root, - skip_confirmation=yes, - infer_files=not no_infer_files, - generate_defaults=not no_defaults, + config_file, metadata_root, skip_confirmation=yes, generate_defaults=not no_defaults ) @main.command("develop") @click.argument("pkg_root", nargs=1, type=click.Path(exists=True, path_type=Path)) @click.option( - "--yes", - "-y", - is_flag=True, - default=False, - type=bool, - help="Skip the confirmation dialog.", + "--yes", "-y", is_flag=True, default=False, type=bool, help="Skip the confirmation dialog." ) @click.option( "--wf-version", @@ -522,9 +509,7 @@ def local_development( @main.command("exec") -@click.option( - "--execution-id", "-e", type=str, help="Optional execution ID to inspect." -) +@click.option("--execution-id", "-e", type=str, help="Optional execution ID to inspect.") @click.option("--egn-id", "-g", type=str, help="Optional task execution ID to inspect.") @click.option( "--container-index", @@ -533,9 +518,7 @@ def local_development( help="Optional container index to inspect (only used for Map Tasks)", ) @requires_login -def execute( - execution_id: Optional[str], egn_id: Optional[str], container_index: Optional[int] -): +def execute(execution_id: Optional[str], egn_id: Optional[str], container_index: Optional[int]): """Drops the user into an interactive shell from within a task.""" from latch_cli.services.k8s.execute import exec as _exec @@ -552,8 +535,7 @@ def execute( default=False, type=bool, help=( - "Whether to automatically bump the version of the workflow each time register" - " is called." + "Whether to automatically bump the version of the workflow each time register is called." ), ) @click.option( @@ -574,12 +556,7 @@ def execute( ), ) @click.option( - "-y", - "--yes", - is_flag=True, - default=False, - type=bool, - help="Skip the confirmation dialog.", + "-y", "--yes", is_flag=True, default=False, type=bool, help="Skip the confirmation dialog." ) @click.option( "--open", @@ -620,10 +597,7 @@ def execute( is_flag=True, default=False, type=bool, - help=( - "Whether or not to cache snakemake tasks. Ignored if --snakefile is not" - " provided." - ), + help=("Whether or not to cache snakemake tasks. 
Ignored if --snakefile is not provided."), ) @click.option( "--nf-script", @@ -738,9 +712,7 @@ def register( @main.command("launch") @click.argument("params_file", nargs=1, type=click.Path(exists=True)) @click.option( - "--version", - default=None, - help="The version of the workflow to launch. Defaults to latest.", + "--version", default=None, help="The version of the workflow to launch. Defaults to latest." ) @requires_login def launch(params_file: Path, version: Union[str, None] = None): @@ -773,16 +745,13 @@ def launch(params_file: Path, version: Union[str, None] = None): version = "latest" click.secho( - f"Successfully launched workflow named {wf_name} with version {version}.", - fg="green", + f"Successfully launched workflow named {wf_name} with version {version}.", fg="green" ) @main.command("get-params") @click.argument("wf_name", nargs=1) -@click.option( - "--version", default=None, help="The version of the workflow. Defaults to latest." -) +@click.option("--version", default=None, help="The version of the workflow. Defaults to latest.") @requires_login def get_params(wf_name: Union[str, None], version: Union[str, None] = None): """[DEPRECATED] Generate a python parameter map for a workflow. @@ -818,9 +787,7 @@ def get_params(wf_name: Union[str, None], version: Union[str, None] = None): @main.command("get-wf") @click.option( - "--name", - default=None, - help="The name of the workflow to list. Will display all versions", + "--name", default=None, help="The name of the workflow to list. Will display all versions" ) @requires_login def get_wf(name: Union[str, None] = None): @@ -840,9 +807,7 @@ def get_wf(name: Union[str, None] = None): version_padding = max(version_padding, version_len) # TODO(ayush): make this much better - click.secho( - f"ID{id_padding * ' '}\tName{name_padding * ' '}\tVersion{version_padding * ' '}" - ) + click.secho(f"ID{id_padding * ' '}\tName{name_padding * ' '}\tVersion{version_padding * ' '}") for wf in wfs: click.secho( f"{wf[0]}{(id_padding - len(str(wf[0]))) * ' '}\t{wf[1]}{(name_padding - len(wf[1])) * ' '}\t{wf[2]}{(version_padding - len(wf[2])) * ' '}" @@ -905,13 +870,9 @@ def get_executions(): default=False, show_default=True, ) +@click.option("--cores", help="Manually specify number of cores to parallelize over", type=int) @click.option( - "--cores", help="Manually specify number of cores to parallelize over", type=int -) -@click.option( - "--chunk-size-mib", - help="Manually specify the upload chunk size in MiB. Must be >= 5", - type=int, + "--chunk-size-mib", help="Manually specify the upload chunk size in MiB. Must be >= 5", type=int ) @requires_login def cp( @@ -1001,12 +962,7 @@ def ls(paths: tuple[str], group_directories_first: bool): @main.command("rmr") @click.argument("remote_path", nargs=1, type=str) @click.option( - "-y", - "--yes", - is_flag=True, - default=False, - type=bool, - help="Skip the confirmation dialog.", + "-y", "--yes", is_flag=True, default=False, type=bool, help="Skip the confirmation dialog." ) @click.option( "--no-glob", @@ -1052,27 +1008,18 @@ def mkdir(remote_directory: str): @click.argument("srcs", nargs=-1) @click.argument("dst", nargs=1) @click.option( - "--delete", - help="Delete extraneous files from destination.", - is_flag=True, - default=False, + "--delete", help="Delete extraneous files from destination.", is_flag=True, default=False ) @click.option( "--ignore-unsyncable", - help=( - "Synchronize even if some source paths do not exist or refer to special files." 
-    ),
+    help=("Synchronize even if some source paths do not exist or refer to special files."),
     is_flag=True,
     default=False,
 )
 @click.option("--cores", help="Number of cores to use for parallel syncing.", type=int)
 @requires_login
 def sync(
-    srcs: list[str],
-    dst: str,
-    delete: bool,
-    ignore_unsyncable: bool,
-    cores: Optional[int] = None,
+    srcs: list[str], dst: str, delete: bool, ignore_unsyncable: bool, cores: Optional[int] = None
 ):
     """Update the contents of a remote directory with local data."""
     from latch_cli.services.sync import sync
 
@@ -1132,7 +1079,7 @@ def version(pkg_root: Path):
     ),
 )
 @click.option("--yes", "-y", is_flag=True, help="Skip the confirmation dialog.")
-def generate_entrypoint(
+def nf_generate_entrypoint(
     pkg_root: Path,
     metadata_root: Optional[Path],
     nf_script: Path,
@@ -1151,9 +1098,7 @@
 
     output = output.with_suffix(".py")
 
-    if not yes and not click.confirm(
-        f"Will generate an entrypoint at {output}. Proceed?"
-    ):
+    if not yes and not click.confirm(f"Will generate an entrypoint at {output}. Proceed?"):
         raise click.exceptions.Abort
 
     output.parent.mkdir(exist_ok=True)
@@ -1161,9 +1106,7 @@
     if (
         not yes
        and output.exists()
-        and not click.confirm(
-            f"Nextflow entrypoint already exists at `{output}`. Overwrite?"
-        )
+        and not click.confirm(f"Nextflow entrypoint already exists at `{output}`. Overwrite?")
     ):
         return
 
@@ -1178,9 +1121,7 @@
     if metadata._nextflow_metadata is None:
         click.secho(
             dedent(f"""\
-            Failed to generate Nextflow entrypoint.
-            Make sure the project root contains a `{meta}`
-            with a `NextflowMetadata` object defined.
+            Failed to generate Nextflow entrypoint. Make sure the project root contains a `{meta}` with a `NextflowMetadata` object defined.
             """),
             fg="red",
         )
@@ -1192,9 +1133,7 @@
 
 
 @nextflow.command("attach")
-@click.option(
-    "--execution-id", "-e", type=str, help="Optional execution ID to inspect."
-)
+@click.option("--execution-id", "-e", type=str, help="Optional execution ID to inspect.")
 @requires_login
 def attach(execution_id: Optional[str]):
     """Drops the user into an interactive shell to inspect the workdir of a nextflow execution."""
@@ -1205,9 +1144,7 @@ def attach(execution_id: Optional[str]):
 
 
 @nextflow.command("register")
-@click.argument(
-    "pkg_root", type=click.Path(exists=True, file_okay=False, path_type=Path)
-)
+@click.argument("pkg_root", type=click.Path(exists=True, file_okay=False, path_type=Path))
 @click.option("--yes", "-y", is_flag=True, help="Skip confirmation dialogs.")
 @click.option(
     "--no-ignore",
@@ -1286,9 +1223,7 @@ def nf_register(
             sha = repo.head.commit.hexsha[:6]
             components.append(sha)
             click.echo(f"Tagging version with git commit {sha}.")
-            click.secho(
-                "  Disable with --disable-git-version/-G", dim=True, italic=True
-            )
+            click.secho("  Disable with --disable-git-version/-G", dim=True, italic=True)
 
             if repo.is_dirty():
                 components.append("wip")
@@ -1386,11 +1321,106 @@ def nf_register(
     click.echo()
 
     register(
-        pkg_root,
-        config=RegisterConfig(workflow_name, version, Path(script_path), not no_ignore),
+        pkg_root, config=RegisterConfig(workflow_name, version, Path(script_path), not no_ignore)
     )
 
 
+@main.group()
+def snakemake():
+    """Manage snakemake-specific commands"""
+
+
+# todo(ayush): allow providing destinations for
+# - config path
+# - dockerfile path
+# - entrypoint output
+@snakemake.command("generate-entrypoint")
+@click.argument("pkg-root", nargs=1, type=click.Path(exists=True, path_type=Path))
+@click.option(
+    "--metadata-root",
+    type=click.Path(exists=True, path_type=Path, file_okay=False),
+    help="Path to a directory containing a python package defining a SnakemakeV2Metadata "
+    "object. If not provided, will default to searching the package root for a directory called "
+    "`latch_metadata`.",
+)
+@click.option(
+    "--snakefile",
+    required=False,
+    type=click.Path(exists=True, path_type=Path, dir_okay=False),
+    help="Path to the Snakefile to register. If not provided, will default to searching the package "
+    "root for a file named `Snakefile`.",
+)
+@click.option(
+    "--no-dockerfile",
+    "-D",
+    is_flag=True,
+    default=False,
+    type=bool,
+    help="Disable automatically generating a Dockerfile.",
+)
+def sm_generate_entrypoint(
+    pkg_root: Path, metadata_root: Optional[Path], snakefile: Optional[Path], no_dockerfile: bool
+):
+    """Generate a `wf/entrypoint.py` file from a Snakemake workflow"""
+
+    from latch_cli.services.register.utils import import_module_by_path
+    from latch_cli.snakemake.v2.workflow import get_entrypoint_content
+
+    dest = pkg_root / "wf" / "entrypoint.py"
+    dest.parent.mkdir(exist_ok=True)
+
+    if dest.exists() and not click.confirm(
+        f"Workflow entrypoint already exists at `{dest}`. Overwrite?"
+    ):
+        return
+
+    if metadata_root is None:
+        metadata_root = pkg_root / "latch_metadata"
+
+    metadata_path = metadata_root / "__init__.py"
+    if metadata_path.exists():
+        click.echo(f"Using metadata file {click.style(metadata_path, italic=True)}")
+        import_module_by_path(metadata_path)
+    else:
+        click.secho(
+            f"Unable to find file `{metadata_path}` with a `SnakemakeV2Metadata` object "
+            "defined. If you have a custom metadata root please provide a path "
+            "to it using the `--metadata-root` option",
+            fg="red",
+        )
+        raise click.exceptions.Exit(1)
+
+    import latch.types.metadata.snakemake_v2 as metadata
+
+    if metadata._snakemake_v2_metadata is None:
+        click.secho(
+            "Failed to generate entrypoint. Make sure the python package at path "
+            f"`{metadata_path}` defines a `SnakemakeV2Metadata` object.",
+            fg="red",
+        )
+        raise click.exceptions.Exit(1)
+
+    if snakefile is None:
+        snakefile = pkg_root / "Snakefile"
+
+    if not snakefile.exists():
+        click.secho(
+            f"Unable to find a Snakefile at `{snakefile}`. If your Snakefile is "
+            "in a different location please provide an explicit path to it "
+            "using the `--snakefile` option."
+        )
+        raise click.exceptions.Exit(1)
+
+    if not no_dockerfile:
+        config = get_or_create_workflow_config(
+            pkg_root / ".latch/config", base_image_type=BaseImageOptions.default
+        )
+        DockerfileBuilder(pkg_root, config, wf_type=WorkflowType.snakemake).generate()
+
+    dest.write_text(get_entrypoint_content(pkg_root, metadata_path, snakefile))
+    click.secho(f"Successfully generated entrypoint file `{dest}`", fg="green")
+
+
 """ POD COMMANDS """
@@ -1424,8 +1454,7 @@ def stop_pod(pod_id: Optional[int] = None):
             err_str = f"Error reading Pod ID from `{id_path}`"
 
         click.secho(
-            f"{err_str} -- please provide a Pod ID as a command line argument.",
-            fg="red",
+            f"{err_str} -- please provide a Pod ID as a command line argument.", fg="red"
         )
         return
diff --git a/src/latch_cli/nextflow/config.py b/src/latch_cli/nextflow/config.py
index 2c07fd821..204e1a24d 100644
--- a/src/latch_cli/nextflow/config.py
+++ b/src/latch_cli/nextflow/config.py
@@ -16,19 +16,11 @@
 from latch.types.directory import LatchDir
 from latch.types.file import LatchFile
 from latch.types.samplesheet_item import SamplesheetItem
-from latch_cli.snakemake.config.utils import get_preamble
-from latch_cli.utils import best_effort_display_name, identifier_from_str
+from ..snakemake.config.utils import get_preamble
+from ..utils import best_effort_display_name, best_effort_title_case, identifier_from_str
 from .parse_schema import NfType, parse_schema
 
-underscores = re.compile(r"_+")
-spaces = re.compile(r"\s+")
-
-
-def best_effort_title_case(s: str) -> str:
-    return identifier_from_str(spaces.sub("", underscores.sub(" ", s).title()))
-
-
 T = TypeVar("T")
 
@@ -118,14 +110,10 @@ def get_python_type_inner(
         else:
             defaults.append((field_name, field_type, field_obj))
 
-        return make_dataclass(
-            f"{best_effort_title_case(param_name)}Type", no_defaults + defaults
-        )
+        return make_dataclass(f"{best_effort_title_case(param_name)}Type", no_defaults + defaults)
 
     if typ["type"] == "samplesheet":
-        dc = get_python_type(
-            param_name, {**typ, "type": "object", "properties": typ["schema"]}
-        )
+        dc = get_python_type(param_name, {**typ, "type": "object", "properties": typ["schema"]})
         return list[SamplesheetItem[dc]]
 
     assert typ["type"] == "enum", f"unsupported type {typ['typ']!r}"
@@ -157,9 +145,7 @@ def get_python_type(
     return Optional[inner]
 
 
-def generate_flow(
-    raw_schema_content: dict[str, object], parsed: dict[str, NfType]
-) -> str:
+def generate_flow(raw_schema_content: dict[str, object], parsed: dict[str, NfType]) -> str:
     if "$defs" not in raw_schema_content:
         return "generated_flow = None"
 
@@ -220,9 +206,7 @@ def generate_flow(
     return f"generated_flow = [{', '.join(flow_elements)}]"
 
 
-def generate_metadata(
-    schema_path: Path, metadata_root: Path, *, skip_confirmation: bool = False
-):
+def generate_metadata(schema_path: Path, metadata_root: Path, *, skip_confirmation: bool = False):
     raw_schema_content: dict[str, object] = json.loads(schema_path.read_text())
 
     display_name: Optional[str] = raw_schema_content.get("title")
diff --git a/src/latch_cli/snakemake/config/parser.py b/src/latch_cli/snakemake/config/parser.py
index 35a8df546..b238b5f15 100644
--- a/src/latch_cli/snakemake/config/parser.py
+++ b/src/latch_cli/snakemake/config/parser.py
@@ -1,16 +1,21 @@
-from dataclasses import fields, is_dataclass
+from dataclasses import Field, field, fields, is_dataclass, make_dataclass
 from pathlib import Path
-from typing import Dict, List, Tuple, Type, TypeVar, get_args, get_origin
+from typing import Annotated, TypeVar, Union, get_args, get_origin
 
 import click
+import google.protobuf.json_format as gpjson
 import yaml
-from typing_extensions import Annotated
+from flytekit.core.annotation import FlyteAnnotation
+from flytekit.core.context_manager import FlyteContextManager
+from flytekit.core.type_engine import TypeEngine
 
 from latch.types.directory import LatchDir
 from latch.types.file import LatchFile
+from latch.utils import Singleton
 from latch_cli.snakemake.utils import reindent
 from latch_cli.utils import best_effort_display_name, identifier_from_str
 
+from ...utils import best_effort_title_case, exit
 from .utils import (
     JSONValue,
     get_preamble,
@@ -24,29 +29,20 @@
 T = TypeVar("T")
 
 
-def parse_config(
-    config_path: Path,
-    *,
-    infer_files: bool = False,
-) -> Dict[str, Tuple[Type[T], T]]:
+class NoValue(Singleton): ...
+
+
+def parse_config(config_path: Path) -> dict[str, tuple[type[T], Union[T, NoValue]]]:
     if not config_path.exists():
-        click.secho(
-            f"No config file found at {config_path}.",
-            fg="red",
-        )
-        raise click.exceptions.Exit(1)
+        raise exit(f"No config file found at {config_path}.")
 
     if config_path.is_dir():
-        click.secho(
-            f"Path {config_path} points to a directory.",
-            fg="red",
-        )
-        raise click.exceptions.Exit(1)
+        raise exit(f"Path {config_path} points to a directory.")
 
     try:
         res: JSONValue = yaml.safe_load(config_path.read_text())
     except yaml.YAMLError as e:
-        click.secho(
+        raise exit(
             reindent(
                 f"""
                 Error loading config from {config_path}:
@@ -54,86 +50,39 @@
                 {e}
                 """,
                 0,
-            ),
-            fg="red",
-        )
-        raise click.exceptions.Exit(1) from e
+            )
+        ) from e
 
-    if not isinstance(res, dict):
-        # ayush: this case doesn't matter bc a non-dict .yaml file isn't valid snakemake
-        return {"snakemake_parameter": (parse_type(res, infer_files=infer_files), res)}
+    assert isinstance(res, dict)
 
-    parsed: Dict[str, Type] = {}
+    parsed: dict[str, tuple[type[T], Union[T, NoValue]]] = {}
     for k, v in res.items():
         try:
-            typ = parse_type(v, k, infer_files=infer_files)
+            typ = parse_type(v, k)
         except ValueError as e:
-            click.secho(
-                f"WARNING: Skipping parameter {k}. Failed to parse type: {e}.",
-                fg="yellow",
-            )
+            click.secho(f"WARNING: Skipping parameter {k}. Failed to parse type: {e}.", fg="yellow")
             continue
 
-        val, default = parse_value(typ, v)
-        parsed[k] = (typ, (val, default))
-
-    return parsed
-
-
-def file_metadata_str(typ: Type, value: JSONValue, level: int = 0) -> str:
-    if get_origin(typ) is Annotated:
-        args = get_args(typ)
-        assert len(args) > 0
-        return file_metadata_str(args[0], value, level)
+        default = NoValue()
+        try:
+            default = parse_value(typ, v)
+        except AssertionError as e:
+            click.secho(f"WARNING: Unable to parse default for parameter {k}: {e}.", fg="yellow")
 
-    if is_primitive_type(typ):
-        return ""
+        parsed[k] = (typ, default)
 
-    if typ in {LatchFile, LatchDir}:
-        return reindent(
-            f"""\
-            SnakemakeFileMetadata(
-                path={repr(value)},
-                config=True,
-            ),\n""",
-            level,
-        )
+    return parsed
 
-    metadata: List[str] = []
-    if is_list_type(typ):
-        template = """
-            [
-            __metadata__],\n"""
-
-        args = get_args(typ)
-        assert len(args) > 0
-        for val in value:
-            metadata_str = file_metadata_str(get_args(typ)[0], val, level + 1)
-            if metadata_str == "":
-                continue
-            metadata.append(metadata_str)
-    else:
-        template = """
-            {
-            __metadata__},\n"""
-
-        assert is_dataclass(typ)
-        for field in fields(typ):
-            metadata_str = file_metadata_str(
-                field.type, getattr(value, field.name), level
-            )
-            if metadata_str == "":
-                continue
-            metadata_str = f"{repr(identifier_from_str(field.name))}: {metadata_str}"
-            metadata.append(reindent(metadata_str, level + 1))
 
-    if len(metadata) == 0:
-        return ""
+# note(ayush): a bare `lambda: variable_name` does not work here because the lambda is called
+# (to get and print its return value) long after it is created - by that point `variable_name`
+# holds whatever the enclosing loop assigned last, not the value it had when the lambda was
+# created. `get_lambda` instead captures the current value in its own scope.
+def get_lambda(value: object):
+    def inner():
+        return value
 
-    return reindent(
-        template,
-        level,
-    ).replace("__metadata__", "".join(metadata), level + 1)
+    return inner
 
 
 # todo(ayush): print informative stuff here ala register
@@ -143,40 +92,42 @@ def generate_metadata(
     *,
     skip_confirmation: bool = False,
     generate_defaults: bool = False,
-    infer_files: bool = False,
 ):
-    parsed = parse_config(config_path, infer_files=infer_files)
+    parsed = parse_config(config_path)
 
-    preambles: List[str] = []
-    params: List[str] = []
-    file_metadata: List[str] = []
+    no_defaults: list[tuple[str, type, Field[object]]] = []
+    defaults: list[tuple[str, type, Field[object]]] = []
 
-    for k, (typ, (val, default)) in parsed.items():
-        preambles.append(get_preamble(typ))
+    ctx = FlyteContextManager.current_context()
 
-        param_str = reindent(
-            f"""\
-            {repr(identifier_from_str(k))}: SnakemakeParameter(
-                display_name={repr(best_effort_display_name(k))},
-                type={type_repr(typ)},
-                __default__),""",
-            0,
-        )
+    for k, (typ, default) in parsed.items():
+        name = identifier_from_str(k)
 
-        default_str = ""
-        if generate_defaults and default is not None:
-            default_str = f"                default={repr(default)},\n"
+        annotations: dict[str, object] = {
+            "display_name": best_effort_display_name(k),
+            "output": name == "outdir",
+        }
+        annotated_typ = Annotated[typ, FlyteAnnotation(annotations)]
 
-        param_str = param_str.replace("__default__", default_str)
+        if not generate_defaults or default is NoValue():
+            no_defaults.append((name, annotated_typ, field()))
+            continue
 
-        param_str = reindent(param_str, 1)
-        params.append(param_str)
+        annotations["default"] = gpjson.MessageToDict(
+            TypeEngine.to_literal(ctx, default, typ, TypeEngine.to_literal_type(typ)).to_flyte_idl()
+        )
 
-        metadata_str = file_metadata_str(typ, val)
-        if metadata_str == "":
+        if isinstance(default, (list, dict, LatchFile, LatchDir)):
+            defaults.append((name, annotated_typ, field(default_factory=get_lambda(default))))
             continue
-        metadata_str = f"{repr(identifier_from_str(k))}: {metadata_str}"
-        file_metadata.append(reindent(metadata_str, 1))
+
+        if is_dataclass(default):
+            defaults.append((name, annotated_typ, field(default_factory=default)))
+            continue
+
+        defaults.append((name, annotated_typ, field(default=default)))
+
+    generated_args_type = make_dataclass("SnakemakeArgsType", no_defaults + defaults)
 
     if metadata_root.is_file():
         if not click.confirm(f"A file exists at `{metadata_root}`. Delete it?"):
@@ -187,47 +138,28 @@ def generate_metadata(
     metadata_root.mkdir(exist_ok=True)
 
     metadata_path = metadata_root / Path("__init__.py")
-    old_metadata_path = Path("latch_metadata.py")
-
-    if old_metadata_path.exists() and not metadata_path.exists():
-        if click.confirm(
-            "Found legacy `latch_metadata.py` file in current directory. This is"
-            " deprecated and will be ignored in future releases. Move to"
-            f" `{metadata_path}`? (This will not change file contents)"
-        ):
-            old_metadata_path.rename(metadata_path)
-    elif old_metadata_path.exists() and metadata_path.exists():
-        click.secho(
-            "Warning: Found both `latch_metadata.py` and"
-            f" `{metadata_path}` in current directory."
-            " `latch_metadata.py` will be ignored.",
-            fg="yellow",
-        )
 
     if not metadata_path.exists():
         metadata_path.write_text(
             reindent(
                 r"""
-                from latch.types.metadata import SnakemakeMetadata, LatchAuthor, EnvironmentConfig
-                from latch.types.directory import LatchDir
+                from latch.types.metadata import LatchAuthor
+                from latch.types.metadata.snakemake_v2 import SnakemakeV2Metadata, SnakemakeParameter
 
-                from .parameters import generated_parameters, file_metadata
+                from .generated import SnakemakeArgsType
 
-                SnakemakeMetadata(
-                    output_dir=LatchDir("latch:///your_output_directory"),
+                class WorkflowArgsType(SnakemakeArgsType):
+                    # add custom parameters here
+                    ...
+
+                SnakemakeV2Metadata(
                     display_name="Your Workflow Name",
                     author=LatchAuthor(
                         name="Your Name",
                     ),
-                    env_config=EnvironmentConfig(
-                        use_conda=False,
-                        use_container=False,
-                    ),
-                    cores=4,
-                    # Add more parameters
-                    parameters=generated_parameters,
-                    file_metadata=file_metadata,
-
+                    parameters={
+                        "args": SnakemakeParameter(type=WorkflowArgsType)
+                    },
                 )
                 """,
                 0,
            )
        )

        click.secho(f"Generated `{metadata_path}`.", fg="green")

-    params_path = metadata_root / Path("parameters.py")
+    params_path = metadata_root / Path("generated.py")
     if (
         params_path.exists()
         and not skip_confirmation
     ):
         params_path.write_text(
             reindent(
-                r"""
-                from dataclasses import dataclass
+                rf"""
+                # This file is auto-generated, PLEASE DO NOT EDIT DIRECTLY! To update, run
+                #
+                # $ latch generate-metadata --snakemake {config_path}
+                #
+                # Add any custom logic or parameters in `latch_metadata/__init__.py`.
+                import typing
-                import typing_extensions
+                from dataclasses import dataclass, field
+                from enum import Enum
 
+                import typing_extensions
                 from flytekit.core.annotation import FlyteAnnotation
 
-                from latch.types.metadata import SnakemakeParameter, SnakemakeFileParameter, SnakemakeFileMetadata
-                from latch.types.file import LatchFile
+                from latch.ldata.path import LPath
                 from latch.types.directory import LatchDir
+                from latch.types.file import LatchFile
+                from latch.types.metadata import Params, Section, Spoiler, Text
+                from latch.types.samplesheet_item import SamplesheetItem
 
                 __preambles__
 
-                # Import these into your `__init__.py` file:
-                #
-                # from .parameters import generated_parameters, file_metadata
-
-                generated_parameters = {
-                __params__
-                }
-
-                file_metadata = {
-                __file_metadata__}
-
                 """,
                 0,
-            )
-            .replace("__preambles__", "".join(preambles))
-            .replace("__params__", "\n".join(params))
-            .replace("__file_metadata__", "".join(file_metadata))
+            ).replace("__preambles__", get_preamble(generated_args_type))
         )
 
     click.secho(f"Generated `{params_path}`.", fg="green")
diff --git a/src/latch_cli/snakemake/config/utils.py b/src/latch_cli/snakemake/config/utils.py
index 9577eb35a..c18262691 100644
--- a/src/latch_cli/snakemake/config/utils.py
+++ b/src/latch_cli/snakemake/config/utils.py
@@ -1,29 +1,27 @@
-from dataclasses import MISSING, Field, field, fields, is_dataclass, make_dataclass
+from __future__ import annotations
+
+import re
+import sys
+from dataclasses import MISSING, Field, fields, is_dataclass, make_dataclass
 from enum import Enum
-from types import MappingProxyType
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Type,
-    Union,
-    get_args,
-    get_origin,
-)
+from typing import Annotated, Any, Callable, TypeVar, Union, get_args, get_origin
 
 from flytekit.core.annotation import FlyteAnnotation
-from typing_extensions import Annotated, TypeAlias, TypeGuard
+from typing_extensions import TypeAlias, TypeGuard
 
 from latch.ldata.path import LPath
 from latch.types.directory import LatchDir
 from latch.types.file import LatchFile
 from latch.types.samplesheet_item import SamplesheetItem
-from latch_cli.utils import identifier_from_str
+from latch_cli.utils import best_effort_title_case, identifier_from_str
+
+JSONValue: TypeAlias = Union[int, str, bool, float, None, list["JSONValue"], "JSONDict"]
+JSONDict: TypeAlias = dict[str, "JSONValue"]
 
-JSONValue: TypeAlias = Union[int, str, bool, float, None, List["JSONValue"], "JSONDict"]
-JSONDict: TypeAlias = Dict[str, "JSONValue"]
+if sys.version_info >= (3, 10):
+    from types import UnionType
+else:
+    UnionType = Union
 
 
 # ayush: yoinked from console
 valid_extensions = {
@@ -104,52 +102,74 @@
 }
 
 
-def parse_type(
-    v: JSONValue, name: Optional[str] = None, *, infer_files: bool = False
-) -> Type:
+expr = re.compile(
+    r"""
+    ^(
+        (latch://.*)
+        | (s3://.*)
+        | (
+            /?
+            ([^/]*/)+
+            [^/]*
+        )
+    )$
+    """,
+    re.VERBOSE,
+)
+
+
+def is_file_like(name: str, value: str) -> bool:
+    if name == "outdir":
+        return True
+
+    if expr.match(value):
+        return True
+
+    return any(value.endswith(x) for x in valid_extensions)
+
+
+def parse_type(v: JSONValue, name: str) -> type:
     if v is None:
         return str
 
-    if infer_files and isinstance(v, str):
-        if any([v.endswith(ext) for ext in valid_extensions]):
-            return LatchFile
-        elif v.endswith("/"):
+    if isinstance(v, str) and is_file_like(name, v):
+        if v.endswith("/") or name == "outdir":
             return LatchDir
+        return LatchFile
+
     if is_primitive_value(v):
         return type(v)
 
     if isinstance(v, list):
-        parsed_types = tuple(parse_type(x, name, infer_files=infer_files) for x in v)
+        parsed_types = tuple(parse_type(x, name) for x in v)
         if len(set(parsed_types)) != 1:
             raise ValueError(
                 "Generic Lists are not supported - please"
                 f" ensure that all elements in {name} are of the same type"
             )
+
         typ = parsed_types[0]
-        if typ in {LatchFile, LatchDir}:
-            return Annotated[List[typ], FlyteAnnotation({"size": len(v)})]
-        return List[typ]
-    assert isinstance(v, dict)
+        return list[typ]
 
-    if name is None:
-        name = "SnakemakeRecord"
+    assert isinstance(v, dict)
 
-    fields: Dict[str, Type] = {}
+    fields: dict[str, type] = {}
     for k, x in v.items():
-        fields[identifier_from_str(k)] = parse_type(
-            x, f"{name}_{k}", infer_files=infer_files
-        )
+        fields[identifier_from_str(k)] = parse_type(x, f"{name}_{k}")
+
+    return make_dataclass(best_effort_title_case(f"{name}_type"), fields.items())
 
-    return make_dataclass(identifier_from_str(name), fields.items())
 
+T = TypeVar("T")
 
-# returns raw value and generated default
-def parse_value(t: Type, v: JSONValue):
+
+def parse_value(t: type[T], v: JSONValue) -> T:
     if v is None:
-        return None, None
+        assert t is type(None)
+        return None
 
     if get_origin(t) is Annotated:
         args = get_args(t)
@@ -157,12 +177,11 @@
         return parse_value(args[0], v)
 
     if t in {LatchFile, LatchDir}:
-        # ayush: autogenerated defaults don't make sense for files/dirs since their
-        # value in the config is their local path
-        return v, None
+        assert isinstance(v, str)
+        return t(v)
 
     if is_primitive_value(v):
-        return v, v
+        return v
 
     if isinstance(v, list):
         assert get_origin(t) is list
@@ -171,29 +190,26 @@
         args = get_args(t)
         assert len(args) > 0
         sub_type = args[0]
 
-        res = [parse_value(sub_type, x) for x in v]
-        return [x[0] for x in res], [x[1] for x in res]
+        return [parse_value(sub_type, x) for x in v]
 
     assert isinstance(v, dict), v
     assert is_dataclass(t), t
 
-    ret = {}
     defaults = {}
 
     fs = {identifier_from_str(f.name): f for f in fields(t)}
 
     for k, x in v.items():
         sanitized = identifier_from_str(k)
         assert sanitized in fs, sanitized
 
-        val, default = parse_value(fs[sanitized].type, x)
-        ret[sanitized] = val
+        default = parse_value(fs[sanitized].type, x)
         defaults[sanitized] = default
 
-    return t(**ret), t(**defaults)
+    return t(**defaults)
 
 
 def is_primitive_type(
-    typ: Type,
-) -> TypeGuard[Union[Type[None], Type[str], Type[bool], Type[int], Type[float]]]:
+    typ: type,
+) -> TypeGuard[Union[type[None], type[str], type[bool], type[int], type[float]]]:
     return typ in {type(None), str, bool, int, float}
 
 
@@ -201,11 +217,14 @@ def is_primitive_value(val: object) -> TypeGuard[Union[None, str, bool, int, float]]:
     return is_primitive_type(type(val))
 
 
-def is_list_type(typ: Type) -> TypeGuard[Type[List]]:
+def is_list_type(typ: type) -> TypeGuard[type[list[object]]]:
     return get_origin(typ) is list
 
 
-def type_repr(typ: Type, *, add_namespace: bool = False) -> str:
+def type_repr(t: type[Any] | str, *, add_namespace: bool = False) -> str:
+    if isinstance(t, str):
+        return type_repr(eval(t), add_namespace=add_namespace)
+
     if is_primitive_type(t) or t in {LatchFile, LatchDir}:
         return t.__name__
 
@@ -227,9 +246,16 @@
 
         return "typing.List"
 
-    if get_origin(t) is Union:
+    if get_origin(t) is dict:
         args = get_args(t)
+        if len(args) != 2:
+            return "typing.Dict"
+
+        s = ", ".join([type_repr(x, add_namespace=add_namespace) for x in args])
+        return f"typing.Dict[{s}]"
 
+    if get_origin(t) is Union:
+        args = get_args(t)
         if len(args) != 2 or args[1] is not type(None):
             raise ValueError("Union types other than Optional are not yet supported")
 
@@ -241,7 +267,7 @@
         if isinstance(args[1], FlyteAnnotation):
             return (
                 f"typing_extensions.Annotated[{type_repr(args[0], add_namespace=add_namespace)},"
-                f" FlyteAnnotation({repr(args[1].data)})]"
+                f" FlyteAnnotation({args[1].data!r})]"
             )
 
         return type_repr(args[0], add_namespace=add_namespace)
 
@@ -273,14 +299,12 @@ def field_repr(f: Field[object]) -> str:
 
     suffix = ""
     if len(args) > 0:
-        suffix = (
-            f" = field({', '.join(f'{k}={value_repr(v)}' for k, v in args.items())})"
-        )
+        suffix = f" = field({', '.join(f'{k}={value_repr(v)}' for k, v in args.items())})"
 
     return f"{f.name}: {type_repr(f.type)}{suffix}"
 
 
-def dataclass_repr(typ: Type) -> str:
+def dataclass_repr(typ: type) -> str:
     assert is_dataclass(typ)
 
     lines = ["@dataclass", f"class {typ.__name__}:"]
@@ -290,39 +314,56 @@
     return "\n".join(lines) + "\n\n\n"
 
 
-def enum_repr(typ: Type) -> str:
+def enum_repr(typ: type) -> str:
     assert issubclass(typ, Enum), typ
 
     lines = [f"class {typ.__name__}(Enum):"]
     for name, val in typ._member_map_.items():
-        lines.append(f"    {name} = {repr(val.value)}")
+        lines.append(f"    {name} = {val.value!r}")
 
     return "\n".join(lines) + "\n\n\n"
 
 
-def get_preamble(typ: Type) -> str:
+def get_preamble(typ: type[Any] | str, *, defined_names: set[str] | None = None) -> str:
+    # ayush: some dataclass fields have strings as their types so attempt to eval them here
+    if isinstance(typ, str):
+        try:
+            typ = eval(typ)
+        except Exception:
+            return ""
+
+    assert not isinstance(typ, str)
+
+    if defined_names is None:
+        defined_names = set()
+
     if get_origin(typ) is Annotated:
         args = get_args(typ)
         assert len(args) > 0
-        return get_preamble(args[0])
+        return get_preamble(args[0], defined_names=defined_names)
 
     if is_primitive_type(typ) or typ in {LatchFile, LatchDir, LPath}:
         return ""
 
-    if get_origin(typ) in {Union, list, SamplesheetItem}:
-        return "".join([get_preamble(t) for t in get_args(typ)])
+    if get_origin(typ) in {Union, UnionType, list, dict, SamplesheetItem}:
+        return "".join([get_preamble(t, defined_names=defined_names) for t in get_args(typ)])
+
+    if typ.__name__ in defined_names:
+        return ""
+
+    defined_names.add(typ.__name__)
 
     if issubclass(typ, Enum):
         return enum_repr(typ)
 
     assert is_dataclass(typ), typ
 
-    preamble = "".join([get_preamble(f.type) for f in fields(typ)])
+    preamble = "".join([get_preamble(f.type, defined_names=defined_names) for f in fields(typ)])
 
     return "".join([preamble, dataclass_repr(typ)])
 
 
-def validate_snakemake_type(name: str, t: Type, param: Any) -> None:
+def validate_snakemake_type(name: str, t: type, param: Any) -> None:
     if t is type(None) and param is not None:
         raise ValueError("parameter of type `NoneType` must be None")
 
@@ -356,8 +397,7 @@ def validate_snakemake_type(name: str, t: Type, param: Any) -> None:
         args = get_args(t)
         if len(args) == 0:
             raise ValueError(
-                "Generic Lists are not supported - please specify a subtype,"
-                " e.g. List[LatchFile]"
+                "Generic Lists are not supported - please specify a subtype, e.g. List[LatchFile]"
             )
         list_typ = args[0]
         for i, val in enumerate(param):
@@ -366,6 +406,4 @@ def validate_snakemake_type(name: str, t: Type, param: Any) -> None:
     else:
         assert is_dataclass(t)
         for field in fields(t):
-            validate_snakemake_type(
-                f"{name}.{field.name}", field.type, getattr(param, field.name)
-            )
+            validate_snakemake_type(f"{name}.{field.name}", field.type, getattr(param, field.name))
diff --git a/src/latch_cli/snakemake/v2/__init__.py b/src/latch_cli/snakemake/v2/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/latch_cli/snakemake/v2/utils.py b/src/latch_cli/snakemake/v2/utils.py
new file mode 100644
index 000000000..0e9431f6e
--- /dev/null
+++ b/src/latch_cli/snakemake/v2/utils.py
@@ -0,0 +1,31 @@
+from dataclasses import fields, is_dataclass
+from enum import Enum
+from urllib.parse import urlparse
+
+from latch.types.directory import LatchDir
+from latch.types.file import LatchFile
+
+
+def get_config_val(val: object):
+    if isinstance(val, list):
+        return [get_config_val(x) for x in val]
+    if isinstance(val, dict):
+        return {k: get_config_val(v) for k, v in val.items()}
+    if isinstance(val, (LatchFile, LatchDir)):
+        if val.remote_path is None:
+            return str(val.path)
+
+        parsed = urlparse(val.remote_path)
+        domain = parsed.netloc
+        if domain == "":
+            domain = "inferred"
+
+        return f"/ldata/{domain}{parsed.path}"
+    if isinstance(val, (int, float, bool, type(None))):
+        return val
+    if is_dataclass(val):
+        return {f.name: get_config_val(getattr(val, f.name)) for f in fields(val)}
+    if isinstance(val, Enum):
+        return val.value
+
+    return str(val)
diff --git a/src/latch_cli/snakemake/v2/workflow.py b/src/latch_cli/snakemake/v2/workflow.py
new file mode 100644
index 000000000..0594bcf53
--- /dev/null
+++ b/src/latch_cli/snakemake/v2/workflow.py
@@ -0,0 +1,134 @@
+from pathlib import Path
+
+import latch.types.metadata.snakemake_v2 as snakemake
+
+_template = """\
+import json
+import os
+import shutil
+import subprocess
+import sys
+import typing
+import typing_extensions
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+
+import requests
+
+from latch.resources.tasks import custom_task, snakemake_runtime_task
+from latch.resources.workflow import snakemake_workflow
+from latch.types.directory import LatchDir, LatchOutputDir
+from latch.types.file import LatchFile
+from latch_cli.snakemake.v2.utils import get_config_val
+from latch_cli.services.register.utils import import_module_by_path
+
+latch_metadata = import_module_by_path(Path({metadata_path}))
+
+import latch.types.metadata.snakemake_v2 as smv2
+
+
+@custom_task(cpu=0.25, memory=0.5, storage_gib=1)
+def initialize() -> str:
+    token = os.environ.get("FLYTE_INTERNAL_EXECUTION_ID")
+    if token is None:
+        raise RuntimeError("failed to get execution token")
+
+    headers = {{"Authorization": f"Latch-Execution-Token {{token}}"}}
+
+    print("Provisioning shared storage volume... ", end="")
+    resp = requests.post(
+        "http://nf-dispatcher-service.flyte.svc.cluster.local/provision-storage-ofs",
+        headers=headers,
+        json={{
+            "storage_expiration_hours": 0,
+            "version": 2,
+            "snakemake": True,
+        }},
+    )
+    resp.raise_for_status()
+    print("Done.")
+
+    return resp.json()["name"]
+
+
+@snakemake_runtime_task(cpu=1, memory=2, storage_gib=50)
+def snakemake_runtime(pvc_name: str, args: latch_metadata.WorkflowArgsType):
+    print(f"Using shared filesystem: {{pvc_name}}")
+
+    shared = Path("/snakemake-workdir")
+    snakefile = shared / {snakefile_path}
+
+    config = get_config_val(args)
+
+    config_path = (shared / "__latch.config.json").resolve()
+    config_path.write_text(json.dumps(config, indent=2))
+
+    ignore_list = [
+        "latch",
+        ".latch",
+        ".git",
+        "nextflow",
+        ".nextflow",
+        ".snakemake",
+        "results",
+        "miniconda",
+        "anaconda3",
+        "mambaforge",
+    ]
+
+    shutil.copytree(
+        Path("/root"),
+        shared,
+        ignore=lambda src, names: ignore_list,
+        ignore_dangling_symlinks=True,
+        dirs_exist_ok=True,
+    )
+
+    cmd = [
+        "snakemake",
+        "--snakefile",
+        str(snakefile),
+        "--configfile",
+        str(config_path),
+        "--executor",
+        "latch",
+        "--default-storage-provider",
+        "latch",
+        "--jobs",
+        "1000",
+    ]
+
+    print("Launching Snakemake Runtime")
+    print(" ".join(cmd), flush=True)
+
+    failed = False
+    try:
+        subprocess.run(cmd, cwd=shared, check=True)
+    except subprocess.CalledProcessError:
+        failed = True
+    finally:
+        if not failed:
+            return
+
+        sys.exit(1)
+
+
+@snakemake_workflow(smv2._snakemake_v2_metadata)
+def {workflow_name}(args: latch_metadata.WorkflowArgsType):
+    \"\"\"
+    Sample Description
+    \"\"\"
+
+    snakemake_runtime(pvc_name=initialize(), args=args)
+"""
+
+
+def get_entrypoint_content(pkg_root: Path, metadata_path: Path, snakefile_path: Path) -> str:
+    metadata = snakemake._snakemake_v2_metadata
+    assert metadata is not None
+
+    return _template.format(
+        metadata_path=repr(str(metadata_path.relative_to(pkg_root))),
+        snakefile_path=repr(str(snakefile_path.relative_to(pkg_root))),
+        workflow_name=metadata.name,
+    )
diff --git a/src/latch_cli/utils/__init__.py b/src/latch_cli/utils/__init__.py
index 3548e6853..2137d70e7 100644
--- a/src/latch_cli/utils/__init__.py
+++ b/src/latch_cli/utils/__init__.py
@@ -96,8 +96,7 @@ def sub_from_jwt(token: str) -> str:
         sub = payload["sub"]
     except KeyError:
         raise ValueError(
-            "Provided token lacks a user sub in the data payload"
-            " and is not a valid token."
+            "Provided token lacks a user sub in the data payload and is not a valid token."
         )
 
     return sub
@@ -156,9 +155,7 @@ def human_readable_time(t_seconds: float) -> str:
 def hash_directory(dir_path: Path, *, silent: bool = False) -> str:
     # todo(maximsmol): store per-file hashes to show which files triggered a version change
     if not silent:
-        click.secho(
-            "Calculating workflow version based on file content hash", bold=True
-        )
+        click.secho("Calculating workflow version based on file content hash", bold=True)
         click.secho("  Disable with --disable-auto-version/-d", italic=True, dim=True)
 
     m = hashlib.new("sha256")
@@ -203,8 +200,7 @@ def hash_directory(dir_path: Path, *, silent: bool = False) -> str:
         if not stat.S_ISREG(p_stat.st_mode):
             if not silent:
                 click.secho(
-                    f"{p.relative_to(dir_path.resolve())} is not a regular file."
-                    " Ignoring contents",
+                    f"{p.relative_to(dir_path.resolve())} is not a regular file. Ignoring contents",
                     fg="yellow",
                     bold=True,
                 )
@@ -225,9 +221,7 @@ def hash_directory(dir_path: Path, *, silent: bool = False) -> str:
     return m.hexdigest()
 
 
-def generate_temporary_ssh_credentials(
-    ssh_key_path: Path, *, add_to_agent: bool = True
-) -> str:
+def generate_temporary_ssh_credentials(ssh_key_path: Path, *, add_to_agent: bool = True) -> str:
     # check if there is already a valid key at that path, and if so, use that
     # otherwise, if its not valid, remove it
     if ssh_key_path.exists():
@@ -242,14 +236,10 @@ def generate_temporary_ssh_credentials(
                 raise
 
             # if both files are valid and their fingerprints match, use them instead of generating a new pair
-            click.secho(
-                f"Found existing key pair at {ssh_key_path}.", dim=True, italic=True
-            )
+            click.secho(f"Found existing key pair at {ssh_key_path}.", dim=True, italic=True)
         except:
             click.secho(
-                f"Found malformed key-pair at {ssh_key_path}. Overwriting.",
-                dim=True,
-                italic=True,
+                f"Found malformed key-pair at {ssh_key_path}. Overwriting.", dim=True, italic=True
             )
             ssh_key_path.unlink(missing_ok=True)
 
@@ -363,13 +353,8 @@ def generate(self):
         self._public_key = generate_temporary_ssh_credentials(self._ssh_key_path)
 
     def cleanup(self):
-        if (
-            self._ssh_key_path.exists()
-            and self._ssh_key_path.with_suffix(".pub").exists()
-        ):
-            subprocess.run(
-                ["ssh-add", "-d", self._ssh_key_path], check=True, capture_output=True
-            )
+        if self._ssh_key_path.exists() and self._ssh_key_path.with_suffix(".pub").exists():
+            subprocess.run(["ssh-add", "-d", self._ssh_key_path], check=True, capture_output=True)
             self._ssh_key_path.unlink(missing_ok=True)
             self._ssh_key_path.with_suffix(".pub").unlink(missing_ok=True)
 
@@ -442,10 +427,7 @@ def check_exists_and_rename(old: Path, new: Path):
         return
 
     if new.is_file():
-        print(
-            f"Warning: {old} is a directory but {new} is not. {new} will be"
-            " overwritten."
-        )
+        print(f"Warning: {old} is a directory but {new} is not. {new} will be overwritten.")
         shutil.rmtree(new)
         os.renames(old, new)
         return
@@ -455,7 +437,17 @@ def check_exists_and_rename(old: Path, new: Path):
 
 
 underscores = re.compile(r"_+")
+spaces = re.compile(r"\s+")
 
 
 def best_effort_display_name(x: str) -> str:
     return underscores.sub(" ", x).title().strip()
+
+
+def best_effort_title_case(s: str) -> str:
+    return identifier_from_str(spaces.sub("", underscores.sub(" ", s).title()))
+
+
+def exit(msg: str, *, exit_code: int = 1) -> click.exceptions.Exit:
+    click.secho(msg, fg="red")
+    return click.exceptions.Exit(exit_code)
diff --git a/src/latch_cli/utils/stateful_writer.py b/src/latch_cli/utils/stateful_writer.py
new file mode 100644
index 000000000..13e7ea0bb
--- /dev/null
+++ b/src/latch_cli/utils/stateful_writer.py
@@ -0,0 +1,29 @@
+from contextlib import contextmanager
+
+
+class StatefulWriter:
+    def __init__(self, indent: int = 4):
+        self._indent = " " * indent
+
+        self._buf = []
+        self._cur = ""
+
+    @contextmanager
+    def indent(self):
+        self._cur += self._indent
+        yield
+        self._cur = self._cur.removesuffix(self._indent)
+
+    def clear(self):
+        self._buf = []
+        self._cur = ""
+
+    def write(self, s: str, *, nl: bool = True):
+        self._buf.append(self._cur)
+        self._buf.append(s)
+
+        if nl:
+            self._buf.append("\n")
+
+    def get(self):
+        return "".join(self._buf)
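
Reviewer note (not part of the patch): a minimal usage sketch of the helpers this diff introduces, useful for sanity-checking the new modules outside of a workflow. It assumes this branch is installed so that `latch.utils.Singleton`, `latch_cli.snakemake.v2.utils.get_config_val`, and `latch_cli.utils.stateful_writer.StatefulWriter` are importable as defined above; the dataclass, its field values, and the paths are invented for illustration, and the expected results in the asserts/comments reflect the intended behavior rather than captured output.

from dataclasses import dataclass

from latch.utils import Singleton
from latch_cli.snakemake.v2.utils import get_config_val
from latch_cli.utils.stateful_writer import StatefulWriter


class NoValue(Singleton): ...


# Singleton subclasses hand back one shared instance, so sentinel defaults can
# be compared with `is`, exactly as parse_config does with `default is NoValue()`.
assert NoValue() is NoValue()


@dataclass
class Args:
    threads: int
    outdir: str


# get_config_val lowers dataclasses field-by-field into JSON-compatible values,
# the same shape the runtime task writes to `__latch.config.json`.
assert get_config_val(Args(threads=4, outdir="results/")) == {"threads": 4, "outdir": "results/"}

# StatefulWriter accumulates indented lines for code generation; indent()
# widens the line prefix for the duration of the `with` block.
w = StatefulWriter()
w.write("rule all:")
with w.indent():
    w.write('input: "results/done.txt"')
print(w.get())
# rule all:
#     input: "results/done.txt"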