diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000..54d397dc0a
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,14 @@
+# syntax = docker/dockerfile:1.4.1
+
+# Image bundling bash, curl, a headless Java 21 runtime, jq, the local
+# `nextflow` launcher and the pre-populated `.nextflow` directory.
+from alpine:3.22.0
+
+# --no-cache: use a fresh package index and do not store it in the image layer.
+run apk add --no-cache \
+    bash \
+    curl \
+    openjdk21-jre-headless
+
+# -f makes curl exit non-zero on an HTTP error instead of saving the error page
+# as /bin/jq. Download and chmod share one layer so the binary is stored once.
+run curl -fsSL https://github.com/jqlang/jq/releases/download/jq-1.8.1/jq-linux-amd64 -o /bin/jq \
+    && chmod +x /bin/jq
+
+copy ./.nextflow /root/.nextflow
+copy ./nextflow /usr/bin/nextflow
diff --git a/Justfile b/Justfile index 12953c0b00..2ad570420c 100644 --- a/Justfile +++ b/Justfile @@ -8,6 +8,17 @@ build-sync: CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -ldflags '-extldflags "-static"' -o custom_fsync.bin custom_fsync/sync.go chmod +x custom_fsync +image_name := "812206152185.dkr.ecr.us-west-2.amazonaws.com/forch-nf-runtime" + +@dbnp: + cp -rf ~/.nextflow ./ + rm -rf .nextflow/plugins/* + + docker build --platform linux/amd64 -t {{image_name}}:$( { // return the run command as result runCommand = result.toString() + log.warn(runCommand) + // use an explicit 'docker rm' command since the --rm flag may fail.
See https://groups.google.com/d/msg/docker-user/0Ayim0wv2Ls/tDC-tlAK03YJ if( remove && name ) { removeCommand = 'docker rm ' + name diff --git a/modules/nextflow/src/main/groovy/nextflow/executor/BashWrapperBuilder.groovy b/modules/nextflow/src/main/groovy/nextflow/executor/BashWrapperBuilder.groovy index 196ad9cfa5..92c34659ac 100644 --- a/modules/nextflow/src/main/groovy/nextflow/executor/BashWrapperBuilder.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/executor/BashWrapperBuilder.groovy @@ -444,14 +444,8 @@ class BashWrapperBuilder { int attempt=0 while( true ) { try { - // note(taras): always sync to disk to ensure that the file is visible to other clients - try( - FileOutputStream fos = new FileOutputStream(path.toFile()); - BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(fos)) - ) { + try (BufferedWriter writer=Files.newBufferedWriter(path, CREATE,WRITE,TRUNCATE_EXISTING)) { writer.write(data) - writer.flush() - fos.getFD().sync() } return path } diff --git a/modules/nextflow/src/main/groovy/nextflow/executor/ExecutorFactory.groovy b/modules/nextflow/src/main/groovy/nextflow/executor/ExecutorFactory.groovy index 57a1535b6c..3bd3253422 100644 --- a/modules/nextflow/src/main/groovy/nextflow/executor/ExecutorFactory.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/executor/ExecutorFactory.groovy @@ -22,6 +22,7 @@ import groovy.transform.PackageScope import groovy.util.logging.Slf4j import nextflow.Session import nextflow.executor.local.LocalExecutor +import nextflow.forch.ForchExecutor import nextflow.k8s.K8sExecutor import nextflow.script.BodyDef import nextflow.script.ProcessConfig @@ -61,7 +62,8 @@ class ExecutorFactory { 'nqsii': NqsiiExecutor, 'moab': MoabExecutor, 'oar': OarExecutor, - 'hq': HyperQueueExecutor + 'hq': HyperQueueExecutor, + 'forch': ForchExecutor, ] @PackageScope Map> executorsMap diff --git a/modules/nextflow/src/main/groovy/nextflow/executor/LatchPathFactory.groovy 
b/modules/nextflow/src/main/groovy/nextflow/executor/LatchPathFactory.groovy index 0c5de932b1..001cf99dc1 100644 --- a/modules/nextflow/src/main/groovy/nextflow/executor/LatchPathFactory.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/executor/LatchPathFactory.groovy @@ -41,7 +41,7 @@ class LatchPathFactory extends FileSystemPathFactory { @Override protected String getBashLib(Path target) { - if (target.scheme != "latch") { + if (target == null || target.scheme != "latch") { return null }
diff --git a/modules/nextflow/src/main/groovy/nextflow/forch/ForchExecutor.groovy b/modules/nextflow/src/main/groovy/nextflow/forch/ForchExecutor.groovy
new file mode 100644
index 0000000000..af8612cd16
--- /dev/null
+++ b/modules/nextflow/src/main/groovy/nextflow/forch/ForchExecutor.groovy
@@ -0,0 +1,64 @@
+package nextflow.forch
+
+import java.nio.file.Path
+
+import groovy.util.logging.Slf4j
+import nextflow.executor.Executor
+import nextflow.extension.FilesEx
+import nextflow.processor.TaskHandler
+import nextflow.processor.TaskMonitor
+import nextflow.processor.TaskPollingMonitor
+import nextflow.processor.TaskRun
+import nextflow.util.DispatcherClient
+import nextflow.util.ForchClient
+import nextflow.util.Duration
+
+/**
+ * Executor that runs each task through {@link ForchTaskHandler}, submitting
+ * work via {@link ForchClient} and reporting the overall execution status via
+ * {@link DispatcherClient}.
+ */
+@Slf4j
+class ForchExecutor extends Executor {
+
+    /** Location the session `bin` directory was copied to; null when there is nothing to upload. */
+    Path remoteBinDir = null
+    private ForchClient forchClient
+
+    /**
+     * Creates the polling monitor for this executor.
+     * NOTE(review): 100 and "15s" are presumably the default queue size and
+     * poll interval -- confirm against TaskPollingMonitor.create.
+     */
+    @Override
+    protected TaskMonitor createTaskMonitor() {
+        return TaskPollingMonitor.create(session, name, 100, Duration.of("15s"))
+    }
+
+    /**
+     * Creates the API clients, schedules the "RUNNING" status report for
+     * session ignition and uploads the session `bin` directory.
+     */
+    @Override
+    protected void register() {
+        // todo(ayush): decouple dispatcher and executor
+        this.dispatcherClient = new DispatcherClient()
+        this.forchClient = new ForchClient()
+
+        this.session.addIgniter {
+            this.dispatcherClient.updateExecutionStatus("RUNNING")
+        }
+
+        uploadBinDir()
+    }
+
+    /** @return a new {@link ForchTaskHandler} bound to this executor's clients and session */
+    @Override
+    TaskHandler createTaskHandler(TaskRun task) {
+        return new ForchTaskHandler(task, remoteBinDir, session, this.forchClient, this.dispatcherClient)
+    }
+
+    /** Copies the session `bin` directory to the executor temp dir when it exists and is not empty. */
+    protected void uploadBinDir() {
+        if( session.binDir && !session.binDir.empty() ) {
+            def s3 = getTempDir()
+            log.info "Uploading local `bin` scripts folder to ${s3.toUriString()}/bin"
+            remoteBinDir = FilesEx.copyTo(session.binDir, s3)
+        }
+    }
+
+    /**
+     * Reports the final execution status, aborts the NFS server task named by
+     * the `nfs_server_task_id` environment variable (when set) and delegates
+     * to the parent shutdown.
+     *
+     * The steps are chained with try/finally so that a failure in an earlier
+     * step (e.g. the status update throwing) cannot leave the NFS server task
+     * running or skip the parent shutdown. Exceptions still propagate; if
+     * several steps fail, the last one thrown wins.
+     */
+    @Override
+    void shutdown() {
+        try {
+            def status = session.success ? "SUCCEEDED" : ((session.aborted || session.cancelled) ? "ABORTED" : "FAILED")
+            this.dispatcherClient.updateExecutionStatus(status)
+        }
+        finally {
+            try {
+                String nfsServerTaskId = System.getenv("nfs_server_task_id")
+                if (nfsServerTaskId != null)
+                    this.forchClient.abortTasks([Integer.parseInt(nfsServerTaskId)])
+            }
+            finally {
+                super.shutdown()
+            }
+        }
+    }
+}
diff --git a/modules/nextflow/src/main/groovy/nextflow/forch/ForchFileCopyStrategy.groovy b/modules/nextflow/src/main/groovy/nextflow/forch/ForchFileCopyStrategy.groovy new file mode 100644 index 0000000000..5a3d4ef764 --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/forch/ForchFileCopyStrategy.groovy @@ -0,0 +1,104 @@ +package nextflow.forch + +import java.nio.file.Path + +import nextflow.executor.BashFunLib +import nextflow.executor.SimpleFileCopyStrategy +import nextflow.util.Escape + +class ForchFileCopyStrategy extends SimpleFileCopyStrategy { + + @Override + String getBeforeStartScript() { + def lib = new BashFunLib().coreLib() + + return lib + "\n\n" + """\ + nxf_s5cmd_upload() { + local name=\$1 + local s3path=\$2 + if [[ "\$name" == - ]]; then + echo 's5cmd --no-verify-ssl pipe "\$s3path"' + s5cmd --no-verify-ssl pipe "\$s3path" + elif [[ -d "\$name" ]]; then + s5cmd --no-verify-ssl cp "\$name" "\$s3path/" + else + s5cmd --no-verify-ssl cp "\$name" "\$s3path/\$name" + fi + } + + nxf_s5cmd_download() { + local source=\$1 + local target=\$2 + local file_name=\$(basename \$1) + local is_dir=\$(s5cmd --no-verify-ssl ls \$source | grep -F "DIR \${file_name}/" -c) + if [[ \$is_dir == 1 ]]; then + s5cmd --no-verify-ssl cp "\$source/*" "\$target" + else + s5cmd --no-verify-ssl cp "\$source" "\$target" + fi + } + + + """.stripIndent() + } + + @Override + String
getStageInputFilesScript(Map inputFiles) {
+        // Wrap the parent's staging script: declare the `downloads` bash array
+        // first, then hand the collected entries to nxf_parallel.
+        final header = 'downloads=(true)\n'
+        final staging = super.getStageInputFilesScript(inputFiles) + '\n'
+        final runner = 'nxf_parallel "${downloads[@]}"\n'
+        return header + staging + runner
+    }
+
+    /** Emits a bash statement appending one nxf_s5cmd_download call to the `downloads` array. */
+    @Override
+    protected String stageInCommand(String source, String target, String mode) {
+        final from = Escape.path(source)
+        final to = Escape.path(target)
+        return "downloads+=(\"nxf_s5cmd_download s3:/${from} ${to}\")"
+    }
+
+    /**
+     * Builds the bash snippet that queues one nxf_s5cmd_upload call per
+     * matching output name and runs them with nxf_parallel.
+     *
+     * @return the snippet, or null when there are no output patterns
+     */
+    @Override
+    String getUnstageOutputFilesScript(List outputFiles, Path targetDir) {
+        final patterns = normalizeGlobStarPaths(outputFiles)
+        if( !patterns )
+            return null
+
+        final quoted = patterns.collect { String pattern -> Escape.path(pattern) }
+        final listing = quoted.join(' ')
+        final destination = Escape.path(targetDir)
+
+        return """\
+            uploads=()
+            IFS=\$'\\n'
+            for name in \$(eval "ls -1d ${listing}" | sort | uniq); do
+                uploads+=("nxf_s5cmd_upload '\$name' s3:/${destination}")
+            done
+            unset IFS
+            nxf_parallel "\${uploads[@]}"
+            """.stripIndent(true)
+    }
+
+    /** Pipes the literal text `start` into the given file via s5cmd. */
+    @Override
+    String touchFile(Path file) {
+        final remote = Escape.path(file)
+        return "echo start | s5cmd --no-verify-ssl pipe s3:/${remote}"
+    }
+
+    /** @return the escaped file name (last path element) of the given path */
+    @Override
+    String fileStr( Path path ) {
+        final leaf = path.getFileName()
+        return Escape.path(leaf)
+    }
+
+    /** @return the s5cmd command copying the local `name` to the remote `target` */
+    @Override
+    String copyFile( String name, Path target ) {
+        final local = Escape.path(name)
+        final remote = Escape.path(target)
+        return "s5cmd --no-verify-ssl cp ${local} s3:/${remote}"
+    }
+
+    /** @return a pipeline suffix that streams its input to the given file; failures are ignored (`|| true`) */
+    @Override
+    String exitFile(Path file) {
+        final remote = Escape.path(file)
+        return "| s5cmd --no-verify-ssl pipe s3:/${remote} || true"
+    }
+
+    /** @return a stdin redirection reading from the file name of the given path */
+    @Override
+    String pipeInputFile(Path file) {
+        final leaf = Escape.path(file.getFileName())
+        return " < ${leaf}"
+    }
+}
diff --git a/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskHandler.groovy b/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskHandler.groovy new file mode 100644 index 0000000000..e212c19b03 --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskHandler.groovy @@ -0,0 +1,140 @@ +package nextflow.forch + +import nextflow.util.DispatcherClient +import nextflow.util.ForchClient + +import java.nio.file.Path + +import
groovy.util.logging.Slf4j
+
+import nextflow.Session
+
+import nextflow.processor.TaskHandler
+import nextflow.processor.TaskRun
+import nextflow.processor.TaskStatus
+import nextflow.script.ProcessConfig
+import nextflow.util.MemoryUnit
+
+/**
+ * Drives one Nextflow task as a Forch task: submits it through
+ * {@link ForchClient}, polls its status and exit code, and records the Forch
+ * task id through {@link DispatcherClient}.
+ */
+@Slf4j
+class ForchTaskHandler extends TaskHandler {
+
+    ProcessConfig processConfig
+    /** Id returned by ForchClient.submitTask; null until submit() has succeeded. */
+    Integer forchTaskId
+    Path remoteBinDir = null
+    private ForchClient forchClient
+    private DispatcherClient dispatcherClient
+    Session session
+
+    ForchTaskHandler(TaskRun task, Path remoteBinDir, Session session, ForchClient forchClient, DispatcherClient dispatcherClient) {
+        super(task)
+
+        this.processConfig = task.processor.config
+        this.remoteBinDir = remoteBinDir
+        this.forchClient = forchClient
+        this.dispatcherClient = dispatcherClient
+
+        this.session = session
+    }
+
+    /** @return the status string reported for the Forch task, or null when nothing was submitted yet */
+    private String getCurrentStatus() {
+        if (this.forchTaskId == null) return null
+
+        return this.forchClient.getTaskStatus(this.forchTaskId)
+    }
+
+    /**
+     * Reports the SUBMITTED -> RUNNING transition.
+     *
+     * Returns true only once: the `isSubmitted()` guard stops the monitor from
+     * being told on every poll that the task "started". A task that is already
+     * SUCCEEDED/FAILED has necessarily started, so terminal states count too;
+     * otherwise a task finishing between two polls would never be marked as
+     * started.
+     */
+    @Override
+    boolean checkIfRunning() {
+        if (!isSubmitted()) return false
+
+        def started = this.currentStatus in ['RUNNING', 'SUCCEEDED', 'FAILED']
+        if (started)
+            status = TaskStatus.RUNNING
+        return started
+    }
+
+    /**
+     * Marks the task COMPLETED once Forch reports SUCCEEDED or FAILED and
+     * stores its exit code.
+     * NOTE(review): any other terminal status Forch may report (e.g. an
+     * externally aborted task) is not handled here and would be polled
+     * forever -- confirm the set of possible statuses.
+     */
+    @Override
+    boolean checkIfCompleted() {
+        def cur = this.currentStatus
+        if (cur != "SUCCEEDED" && cur != "FAILED") return false
+
+        // todo(ayush): single query
+        task.exitStatus = this.forchClient.getTaskExitCode(this.forchTaskId)
+
+        // todo(ayush): logs, retries
+        task.stdout = ""
+        task.stderr = ""
+        status = TaskStatus.COMPLETED
+        return true
+    }
+
+    /** Aborts the Forch task; a no-op when the task was never submitted. */
+    @Override
+    void kill() {
+        if (forchTaskId == null) return
+
+        forchClient.abortTasks([forchTaskId])
+    }
+
+    /** Builds the task wrapper scripts via {@link ForchTaskWrapperBuilder}. */
+    @Override
+    void prepareLauncher() {
+        new ForchTaskWrapperBuilder(this.task.toTaskBean()).build()
+    }
+
+    /**
+     * Submits the task to Forch and records the returned task id.
+     *
+     * The container entrypoint mounts the NFS export of the server named by
+     * the `latch_internal_nfs_server_ip` environment variable at
+     * `session.baseDir` and then runs the task's run script from the work
+     * directory.
+     *
+     * @throws RuntimeException when `latch_internal_nfs_server_ip` is not set
+     */
+    @Override
+    void submit() {
+        int cpus = task.config.getCpus()
+        MemoryUnit memory = task.config.getMemory() ?: MemoryUnit.of("2GiB")
+
+        final containerOpts = task.config.getContainerOptionsMap()
+
+        MemoryUnit shm = null
+        if (containerOpts != null && containerOpts.exists("shm-size")) {
+            shm = new MemoryUnit(containerOpts.getFirstValue("shm-size") as String)
+        }
+
+        // todo(ayush): gpu support
+        // AcceleratorResource acc = task.config.getAccelerator()
+
+        def serverIp = System.getenv("latch_internal_nfs_server_ip")
+        if (serverIp == null)
+            throw new RuntimeException("failed to get server ip")
+
+        // Paths are double-quoted so directories containing whitespace survive
+        // shell word splitting.
+        // NOTE(review): `2>&1 > /dev/null` discards stdout only; stderr still
+        // reaches the task's stdout. If full silence was intended the order
+        // must be `> /dev/null 2>&1` -- left unchanged pending confirmation.
+        // NOTE(review): the mount loop retries forever; consider bounding it.
+        String cmd = """\
+            mkdir --parents "${session.baseDir}"
+
+            chown -R root:root /usr/bin/mount 2>&1 > /dev/null
+
+            until mount -t nfs4 [${serverIp}]:/ "${session.baseDir}" 2>&1 > /dev/null
+            do
+              sleep 5
+            done
+
+            cat "${task.workDir}/${TaskRun.CMD_RUN}" | bash 2>&1
+        """.stripIndent().trim()
+
+        if (remoteBinDir != null) {
+            cmd = """\
+                mkdir -p /nextflow-bin
+                cp "${remoteBinDir}"/* /nextflow-bin
+                chmod +x /nextflow-bin/*
+                export PATH=/nextflow-bin:\$PATH
+            """.stripIndent() + cmd
+        }
+
+        List entrypoint = [
+            "/bin/bash",
+            "-c",
+            cmd,
+        ]
+
+        this.forchTaskId = this.forchClient.submitTask(
+            this.task.name,
+            this.task.container,
+            entrypoint,
+            cpus,
+            memory.bytes,
+            shm?.bytes ?: 0
+        )
+
+        // todo(rahul): put this in a single transaction with submitTask
+        this.dispatcherClient.updateForchTaskId(
+            this.taskExecutionId,
+            this.forchTaskId
+        )
+
+        // checkIfRunning() relies on this to detect the SUBMITTED -> RUNNING transition
+        status = TaskStatus.SUBMITTED
+    }
+}
diff --git a/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskMonitor.groovy b/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskMonitor.groovy
new file mode 100644
index 0000000000..35150dd9d9
--- /dev/null
+++ b/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskMonitor.groovy
@@ -0,0 +1,28 @@
+package nextflow.forch
+
+import groovy.util.logging.Slf4j
+import nextflow.processor.TaskHandler
+import nextflow.processor.TaskMonitor
+
+/**
+ * Placeholder {@link TaskMonitor}: scheduling and signalling do nothing and
+ * nothing is ever evicted. ForchExecutor.createTaskMonitor returns a
+ * TaskPollingMonitor rather than this class.
+ */
+@Slf4j
+class ForchTaskMonitor implements TaskMonitor {
+    @Override
+    void schedule(TaskHandler handler) {
+
+    }
+
+    @Override
+    boolean evict(TaskHandler handler) {
+        return false
+    }
+
+    /** @return this monitor, so callers chaining on start() never receive null */
+    @Override
+    TaskMonitor start() {
+        return this
+    }
+
+    @Override
+    void signal() {
+        // noop
+    }
+}
diff --git
a/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskWrapperBuilder.groovy b/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskWrapperBuilder.groovy
new file mode 100644
index 0000000000..64556c9b8a
--- /dev/null
+++ b/modules/nextflow/src/main/groovy/nextflow/forch/ForchTaskWrapperBuilder.groovy
@@ -0,0 +1,30 @@
+package nextflow.forch
+
+import java.nio.file.Path
+
+import nextflow.executor.BashWrapperBuilder
+import nextflow.executor.SimpleFileCopyStrategy
+import nextflow.processor.TaskBean
+import nextflow.processor.TaskRun
+
+/**
+ * Wrapper-script builder for Forch tasks. Runs the task in a scratch
+ * directory by default and registers the task's own control files as inputs
+ * so they are staged next to the task.
+ */
+class ForchTaskWrapperBuilder extends BashWrapperBuilder {
+    // entirely lifted from AWS Batch Wrapper
+    ForchTaskWrapperBuilder(TaskBean bean) {
+        // Build the copy strategy from the bean so the task's stage-in/stage-out
+        // modes and work/target directories are honoured; the no-arg constructor
+        // leaves all of them null.
+        // NOTE(review): this makes the default unstage mode follow the stock
+        // strategy (copy when workDir == targetDir) instead of always `move`.
+        super(bean, new SimpleFileCopyStrategy(bean))
+        // default to a scratch directory so outputs are copied back to the work dir
+        if( scratch==null )
+            scratch = true
+
+        // include task script as an input to force its staging in the container work directory
+        bean.inputFiles[TaskRun.CMD_SCRIPT] = bean.workDir.resolve(TaskRun.CMD_SCRIPT)
+        // add the wrapper file when stats are enabled
+        // NOTE: this must match the logic that uses the run script in BashWrapperBuilder
+        if( isTraceRequired() ) {
+            bean.inputFiles[TaskRun.CMD_RUN] = bean.workDir.resolve(TaskRun.CMD_RUN)
+        }
+        // include task stdin file
+        if( bean.input != null ) {
+            bean.inputFiles[TaskRun.CMD_INFILE] = bean.workDir.resolve(TaskRun.CMD_INFILE)
+        }
+    }
+}
diff --git a/modules/nextflow/src/main/groovy/nextflow/k8s/model/PodSpecBuilder.groovy b/modules/nextflow/src/main/groovy/nextflow/k8s/model/PodSpecBuilder.groovy index a443e50470..615d8e2542 100644 --- a/modules/nextflow/src/main/groovy/nextflow/k8s/model/PodSpecBuilder.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/k8s/model/PodSpecBuilder.groovy @@ -671,7 +671,7 @@ class PodSpecBuilder { @PackageScope - void validateAccelerator(AcceleratorResource accelerator) { + static void validateAccelerator(AcceleratorResource accelerator) { // gpu-small: nvidia-t4 (1) // gpu-large:
nvidia-a10g (1) // v100-x1: nvidia-v100 (1) @@ -687,15 +687,15 @@ class PodSpecBuilder { } throw new VerifyError("""\ -Invalid GPU configuration. Latch only allows the following combinations: - - accelerator 1, type: "nvidia-t4" - - accelerator 1, type: "nvidia-a10g" - - accelerator 1, type: "nvidia-v100" - - accelerator 4, type: "nvidia-v100" - - accelerator 8, type: "nvidia-v100" - -You provided ${accelerator.type}, ${accelerator.limit} - """) + Invalid GPU configuration. Latch only allows the following combinations: + - accelerator 1, type: "nvidia-t4" + - accelerator 1, type: "nvidia-a10g" + - accelerator 1, type: "nvidia-v100" + - accelerator 4, type: "nvidia-v100" + - accelerator 8, type: "nvidia-v100" + + You provided ${accelerator.type}, ${accelerator.limit} + """.stripIndent().trim()) } diff --git a/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy b/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy index 489601cc59..40d7e4dad8 100644 --- a/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy @@ -17,6 +17,8 @@ package nextflow.processor import nextflow.file.http.LatchPath import java.nio.file.StandardCopyOption + +import nextflow.k8s.K8sExecutor import nextflow.trace.TraceRecord import static nextflow.processor.ErrorStrategy.* diff --git a/modules/nextflow/src/main/groovy/nextflow/script/BindableDef.groovy b/modules/nextflow/src/main/groovy/nextflow/script/BindableDef.groovy index e9b9fc502c..b91da88f48 100644 --- a/modules/nextflow/src/main/groovy/nextflow/script/BindableDef.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/script/BindableDef.groovy @@ -40,8 +40,9 @@ abstract class BindableDef extends ComponentDef { // use this instance an workflow template, therefore clone it final String prefix = ExecutionStack.workflow()?.name final fqName = prefix ? 
prefix+SCOPE_SEP+name : name + log.debug("bindable fqName: $fqName") if( this instanceof ProcessDef && !invocations.add(fqName) ) { - log.debug "Bindable invocations=$invocations" + log.debug "Bindable invocations=$invocations, ${this.toString()}" final msg = "Process '$name' has been already used -- If you need to reuse the same component, include it with a different name or include it in a different workflow context" throw new DuplicateProcessInvocation(msg) } diff --git a/modules/nextflow/src/main/groovy/nextflow/util/DispatcherClient.groovy b/modules/nextflow/src/main/groovy/nextflow/util/DispatcherClient.groovy index 318c5f5460..11651d9063 100644 --- a/modules/nextflow/src/main/groovy/nextflow/util/DispatcherClient.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/util/DispatcherClient.groovy @@ -10,41 +10,101 @@ class DispatcherClient { private GQLClient client = new GQLClient() - int createProcessNode(String processName) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { - return 1 + public boolean debug = System.getenv("LATCH_NF_DEBUG") == "true" + + void updateExecutionStatus(String status) { + if (debug) { + return } - String executionToken = System.getenv("FLYTE_INTERNAL_EXECUTION_ID") - if (executionToken == null) - throw new RuntimeException("unable to get execution token") + String executionId = System.getenv("forch_execution_id") + if (executionId == null) { + throw new RuntimeException("failed to update execution status: execution id not found") + } - Map res = client.execute(""" - mutation CreateNode(\$executionToken: String!, \$name: String!) { - createNfProcessNodeByExecutionToken(input: {argExecutionToken: \$executionToken, argName: \$name}) { - nodeId + client.execute(""" + mutation UpdateExecutionStatus( + \$argExecutionId: BigInt! + \$argStatus: ExecutionStatus! 
+ ) { + updateExecutionInfo( + input: { + id: \$argExecutionId, + patch: { + status: \$argStatus + } + } + ) { + clientMutationId } } """, [ - executionToken: executionToken, - name: processName, + argExecutionId: executionId, + argStatus: status, ] - )["createNfProcessNodeByExecutionToken"] as Map + ) + } - if (res == null) - throw new RuntimeException("failed to create remote process node for: processName=${processName}") + int createProcessNode(String processName) { + if (debug) { + return 1 + } - return (res.nodeId as String).toInteger() + String executionToken = System.getenv("FLYTE_INTERNAL_EXECUTION_ID") + if (executionToken != null) { + Map res = client.execute(""" + mutation CreateNode(\$executionToken: String!, \$name: String!) { + createNfProcessNodeByExecutionToken(input: {argExecutionToken: \$executionToken, argName: \$name}) { + nodeId + } + } + """, + [ + executionToken: executionToken, + name: processName, + ] + )["createNfProcessNodeByExecutionToken"] as Map + + if (res == null) + throw new RuntimeException("failed to create remote process node for: processName=${processName}") + + return (res.nodeId as String).toInteger() + } + + String executionId = System.getenv("forch_execution_id") + if (executionId != null) { + Map res = client.execute(""" + mutation CreateNode(\$executionId: BigInt!, \$name: String!) 
{ + createNfProcessNode(input: {nfProcessNode: {executionId: \$executionId, name: \$name } }) { + nfProcessNode { + id + } + } + } + """, + [ + executionId: executionId, + name: processName, + ] + )["createNfProcessNode"] as Map + + if (res == null || res["nfProcessNode"] == null) + throw new RuntimeException("failed to create remote process node for: processName=${processName}") + + return (res["nfProcessNode"]["id"] as String).toInteger() + } + + throw new RuntimeException("failed to create process node: unable to get source execution") } void closeProcessNode(int nodeId, int numTasks) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { + if (debug) { return } client.execute(""" - mutation CreateTaskInfo(\$nodeId: BigInt!, \$numTasks: BigInt!) { + mutation UpdateTaskInfo(\$nodeId: BigInt!, \$numTasks: BigInt!) { updateNfProcessNode( input: { id: \$nodeId, @@ -65,7 +125,7 @@ class DispatcherClient { } void createProcessEdge(int from, int to) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { + if (debug) { return } @@ -91,7 +151,7 @@ class DispatcherClient { } int createProcessTask(int processNodeId, int index, String tag) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { + if (debug) { return 1 } @@ -150,10 +210,71 @@ class DispatcherClient { } int createTaskExecution(int taskId, int attemptIdx, String hash, String status = null) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { + if (debug) { return 1 } + String forchExecutionId = System.getenv("forch_execution_id") + if (forchExecutionId != null) { + try { + Map res = client.execute(""" + mutation CreateForchTaskExecutionInfo(\$taskId: BigInt!, \$attemptIdx: BigInt!, \$cached: Boolean!, \$hash: String) { + createNfForchTaskExecutionInfo( + input: { + nfForchTaskExecutionInfo: { + taskId: \$taskId, + attemptIdx: \$attemptIdx, + cached: \$cached, + hash: \$hash + } + } + ) { + nfForchTaskExecutionInfo { + id + } + } + } + """, + [ + taskId: taskId, + attemptIdx: attemptIdx, + cached: status == 'SKIPPED', 
+ hash: hash, + ] + )["createNfForchTaskExecutionInfo"] as Map + + if (res == null) + throw new RuntimeException("failed to create remote task execution for: taskId=${taskId} attempt=${attemptIdx} hash=${hash}") + + return ((res.nfForchTaskExecutionInfo as Map).id as String).toInteger() + } catch (GQLQueryException e) { + + // note(rahul): the gql client uses the HTTP Retry Client. As a result, it may retry a request after + // successfully committing the row to the DB (for example, if the connection fails) + if (!e.message.contains("duplicate key value violates unique constraint")) { + throw e + } + } + + Map res = client.execute(""" + query GetNfForchTaskExecutionInfo(\$taskId: BigInt!, \$attemptIdx: BigInt!) { + nfForchTaskExecutionInfoByTaskIdAndAttemptIdx(attemptIdx: \$attemptIdx, taskId: \$taskId) { + id + } + } + """, + [ + taskId: taskId, + attemptIdx: attemptIdx, + ] + )["nfForchTaskExecutionInfoByTaskIdAndAttemptIdx"] as Map + + if (res == null) + throw new RuntimeException("failed to get forch task execution id for: taskId=${taskId} attemptIdx=${attemptIdx}") + + return (res.id as String).toInteger() + } + try { Map res = client.execute(""" mutation CreateTaskExecutionInfo(\$taskId: BigInt!, \$attemptIdx: BigInt!, \$hash: String, \$status: TaskExecutionStatus!) { @@ -218,6 +339,8 @@ class DispatcherClient { } void submitPod(int taskExecutionId, Map pod) { + if (debug) return + client.execute(""" mutation UpdateTaskExecution(\$taskExecutionId: BigInt!, \$podSpec: String!) 
updateNfTaskExecutionInfo( @@ -241,7 +364,7 @@ class DispatcherClient { } void updateTaskStatus(int taskExecutionId, String status) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { + if (debug) { return } @@ -267,7 +390,7 @@ class DispatcherClient { } Map getTaskStatus(int taskExecutionId) { - if (System.getenv("LATCH_NF_DEBUG") == "true") { + if (debug) { return null } @@ -292,4 +415,30 @@ class DispatcherClient { return res } + + void updateForchTaskId(int taskExecutionId, int forchTaskId) { + if (debug) { + return + } + + client.execute(""" + mutation UpdateTaskExecution(\$taskExecutionId: BigInt!, \$forchTaskId: BigInt!) { + updateNfForchTaskExecutionInfo( + input: { + id: \$taskExecutionId, + patch: { + forchTaskId: \$forchTaskId + }, + } + ) { + clientMutationId + } + } + """, + [ + taskExecutionId: taskExecutionId, + forchTaskId: forchTaskId + ] + ) + } }
diff --git a/modules/nextflow/src/main/groovy/nextflow/util/ForchClient.groovy b/modules/nextflow/src/main/groovy/nextflow/util/ForchClient.groovy
new file mode 100644
index 0000000000..4f121d9c24
--- /dev/null
+++ b/modules/nextflow/src/main/groovy/nextflow/util/ForchClient.groovy
@@ -0,0 +1,156 @@
+package nextflow.util
+
+import groovy.util.logging.Slf4j
+import nextflow.file.http.GQLClient
+
+/**
+ * GraphQL client for Forch task operations: create a task, read its status
+ * and exit code, and stop tasks. Requests go through a {@link GQLClient}
+ * created with Forch authentication enabled.
+ */
+@Slf4j
+class ForchClient {
+    private GQLClient client = new GQLClient(true)
+
+    /**
+     * Creates a Forch task and returns its id.
+     *
+     * The resource group, billing group and NFS server task are read from the
+     * `forch_resource_group_id`, `forch_billing_group_id` and
+     * `nfs_server_task_id` environment variables; the region comes from
+     * `host_region` and defaults to "us-west-2".
+     *
+     * @param shmBytes shared-memory size in bytes; 0 means "not set" and is sent as null
+     * @return the id of the created task
+     * @throws RuntimeException when a required environment variable is missing
+     *         or the response carries no task id
+     */
+    public int submitTask(
+        String displayName,
+        String image,
+        List entrypoint,
+        int cpus,
+        long memoryBytes,
+        long shmBytes // 0 means "not set"
+    ) {
+        String resourceGroup = System.getenv("forch_resource_group_id")
+        if (resourceGroup == null)
+            throw new RuntimeException("unable to get resource group")
+
+        String billingGroup = System.getenv("forch_billing_group_id")
+        if (billingGroup == null)
+            throw new RuntimeException("unable to get billing group")
+
+        String nfsServerTaskId = System.getenv("nfs_server_task_id")
+        if (nfsServerTaskId == null)
+            throw new RuntimeException("unable to get NFS server task id")
+
+        String region = System.getenv("host_region") ?: "us-west-2"
+
+        Map res = client.execute("""
+            mutation CreateForchTask(
+                \$displayName: String!,
+                \$containerImage: String!,
+                \$containerEntrypoint: [String]!,
+                \$cpus: Int!,
+                \$memoryBytes: BigInt!,
+                \$shmBytes: BigInt,
+                \$gpuType: String,
+                \$gpus: Int!,
+                \$groupId: BigInt!,
+                \$billedTo: BigInt!,
+                \$nfsServerTaskId: BigInt!,
+                \$targetRegion: String!
+            ) {
+                nfCreateForchTask(
+                    input: {
+                        argDisplayName: \$displayName,
+                        argContainerImage: \$containerImage,
+                        argContainerEntrypoint: \$containerEntrypoint,
+                        argCpus: \$cpus,
+                        argMemoryBytes: \$memoryBytes,
+                        argShmBytes: \$shmBytes,
+                        argGpuType: \$gpuType,
+                        argGpus: \$gpus,
+                        argGroupId: \$groupId,
+                        argBilledTo: \$billedTo,
+                        argNfsServerTaskId: \$nfsServerTaskId,
+                        argTargetRegion: \$targetRegion
+                    }
+                ) {
+                    resTaskId
+                }
+            }
+            """,
+            [
+                "displayName" : displayName,
+                "containerImage" : image,
+                "containerEntrypoint" : entrypoint,
+                "cpus" : cpus,
+                "memoryBytes" : memoryBytes,
+                "shmBytes": shmBytes == 0 ? null : shmBytes,
+                "gpuType" : null,
+                "gpus" : 0,
+                // the variables are declared BigInt: parse as long so ids above
+                // the 32-bit range are not rejected
+                "groupId": resourceGroup.toLong(),
+                "billedTo": billingGroup.toLong(),
+                "nfsServerTaskId": nfsServerTaskId,
+                "targetRegion": region,
+            ]
+        )["nfCreateForchTask"] as Map
+
+        if (res == null || res.resTaskId == null)
+            throw new RuntimeException("failed to create forch task")
+
+        return (res.resTaskId as String).toInteger()
+    }
+
+    /**
+     * @return the `status` field of the given task
+     * @throws RuntimeException when the response is empty or contains no such task
+     */
+    String getTaskStatus(int forchTaskId) {
+        Map res = client.execute("""
+            query GetTaskStatus(\$taskId: BigInt!) {
+                task(id: \$taskId) {
+                    id
+                    status
+                }
+            }
+            """,
+            [
+                taskId: forchTaskId
+            ]
+        ) as Map
+
+        // `task` is null for an unknown id: fail with a clear message rather than an NPE
+        Map taskInfo = res?.get("task") as Map
+        if (taskInfo == null)
+            throw new RuntimeException("failed to get task status for ${forchTaskId}")
+
+        return taskInfo["status"] as String
+    }
+
+    /**
+     * @return the exit status of the most recent container-exited event of the
+     *         task, or -1 when no such event (or no exit status) is recorded
+     * @throws RuntimeException when the response carries no `taskEvents`
+     */
+    int getTaskExitCode(int forchTaskId) {
+        Map res = client.execute("""
+            query GetTaskExitCode(\$taskId: BigInt!) {
+                taskEvents(
+                    condition: {taskId: \$taskId},
+                    filter: {taskEventContainerExitedDatumByIdExists: true},
+                    orderBy: TIME_DESC,
+                    first: 1
+                ) {
+                    nodes {
+                        id
+                        type
+                        taskEventContainerExitedDatumById {
+                            id
+                            exitStatus
+                        }
+                    }
+                }
+            }
+            """,
+            [
+                taskId: forchTaskId
+            ]
+        )["taskEvents"] as Map
+
+        if (res == null)
+            throw new RuntimeException("failed to get exit code for ${forchTaskId}")
+
+        List nodes = res["nodes"] as List
+        if (nodes == null || nodes.size() == 0)
+            return -1
+
+        // a missing datum or exit status is reported like "no event" instead of
+        // failing on the null -> int cast
+        Map datum = (nodes[0] as Map)?.get("taskEventContainerExitedDatumById") as Map
+        def exitStatus = datum?.get("exitStatus")
+        return exitStatus != null ? (exitStatus as int) : -1
+    }
+
+    /**
+     * Requests that the given tasks be stopped. Null ids are dropped and
+     * nothing is sent when no id remains, since the mutation declares the
+     * list elements as non-null.
+     */
+    void abortTasks(List taskIds) {
+        List ids = taskIds?.findAll { it != null }
+        if (!ids)
+            return
+
+        client.execute("""
+            mutation AbortTask(\$argTaskIds: [BigInt!]!) {
+                nfStopForchTasks(input: { argTaskIds: \$argTaskIds }) {
+                    clientMutationId
+                }
+            }
+            """,
+            ["argTaskIds": ids]
+        )
+    }
+}
diff --git a/modules/nf-commons/src/main/nextflow/util/MemoryUnit.groovy b/modules/nf-commons/src/main/nextflow/util/MemoryUnit.groovy index 26495a81f4..dede90ba00 100644 --- a/modules/nf-commons/src/main/nextflow/util/MemoryUnit.groovy +++ b/modules/nf-commons/src/main/nextflow/util/MemoryUnit.groovy @@ -33,7 +33,7 @@ class MemoryUnit implements Comparable, Serializable, Cloneable { final static public MemoryUnit ZERO = new MemoryUnit(0) - final static private Pattern FORMAT = ~/([0-9\.]+)\s*(\S)?B?/ + final static private Pattern FORMAT = ~/([0-9\.]+)\s*(\S)?i?B?/ final static public List UNITS = [ "B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB" ] diff --git a/modules/nf-httpfs/src/main/nextflow/file/http/GQLClient.groovy b/modules/nf-httpfs/src/main/nextflow/file/http/GQLClient.groovy index 5ea8767073..c4c3286ba8 100644 --- a/modules/nf-httpfs/src/main/nextflow/file/http/GQLClient.groovy +++ b/modules/nf-httpfs/src/main/nextflow/file/http/GQLClient.groovy @@ -11,6 +11,7 @@ import groovy.json.JsonSlurper class GQLClient { private String endpoint private HttpRetryClient client + private boolean useForchAuth class GQLQueryException extends Exception {
GQLQueryException(String msg) { @@ -18,7 +19,8 @@ class GQLClient { } } - GQLClient() { + GQLClient(boolean useForchAuth = false) { + this.useForchAuth = useForchAuth endpoint = "https://vacuole.latch.bio/graphql" String domain = System.getenv("LATCH_SDK_DOMAIN") @@ -43,7 +45,7 @@ class GQLClient { .uri(URI.create(this.endpoint)) .timeout(Duration.ofSeconds(90)) .header("Content-Type", "application/json") - .header("Authorization", LatchPathUtils.getAuthHeader()) + .header("Authorization", LatchPathUtils.getAuthHeader(useForchAuth)) HttpRequest req = requestBuilder.POST(HttpRequest.BodyPublishers.ofString(builder.toString())).build() HttpResponse response = this.client.send(req) diff --git a/modules/nf-httpfs/src/main/nextflow/file/http/LatchPath.groovy b/modules/nf-httpfs/src/main/nextflow/file/http/LatchPath.groovy index 1902185fd7..2fd570b146 100644 --- a/modules/nf-httpfs/src/main/nextflow/file/http/LatchPath.groovy +++ b/modules/nf-httpfs/src/main/nextflow/file/http/LatchPath.groovy @@ -326,11 +326,14 @@ class LatchPath extends XPath { JsonBuilder builder = new JsonBuilder() builder(["path": this.toUriString(), "part_count": numParts, "content_type": mimeType]) + def latchToken = System.getenv("latch_execution_token") + def authHeader = latchToken != null ? 
"Latch-Execution-Token $latchToken" : LatchPathUtils.getAuthHeader() + def request = HttpRequest.newBuilder() .uri(URI.create("${host}/ldata/start-upload")) .timeout(Duration.ofSeconds(90)) .header("Content-Type", "application/json") - .header("Authorization", LatchPathUtils.getAuthHeader()) + .header("Authorization", authHeader) .POST(HttpRequest.BodyPublishers.ofString(builder.toString())) .build() @@ -424,7 +427,7 @@ class LatchPath extends XPath { .uri(URI.create("${host}/ldata/end-upload")) .timeout(Duration.ofSeconds(90)) .header("Content-Type", "application/json") - .header("Authorization", LatchPathUtils.getAuthHeader()) + .header("Authorization", authHeader) .POST(HttpRequest.BodyPublishers.ofString(endUploadBody)) .build() @@ -435,11 +438,14 @@ class LatchPath extends XPath { JsonBuilder builder = new JsonBuilder() builder(["path": this.toUriString()]) + def latchToken = System.getenv("latch_execution_token") + def authHeader = latchToken != null ? "Latch-Execution-Token $latchToken" : LatchPathUtils.getAuthHeader() + def request = HttpRequest.newBuilder() .uri(URI.create("${host}/ldata/get-signed-url")) .timeout(Duration.ofSeconds(90)) .header("Content-Type", "application/json") - .header("Authorization", LatchPathUtils.getAuthHeader()) + .header("Authorization", authHeader) .POST(HttpRequest.BodyPublishers.ofString(builder.toString())) .build() diff --git a/modules/nf-httpfs/src/main/nextflow/file/http/LatchPathUtils.groovy b/modules/nf-httpfs/src/main/nextflow/file/http/LatchPathUtils.groovy index 3e89103a58..5c64efe7a8 100644 --- a/modules/nf-httpfs/src/main/nextflow/file/http/LatchPathUtils.groovy +++ b/modules/nf-httpfs/src/main/nextflow/file/http/LatchPathUtils.groovy @@ -7,15 +7,18 @@ class LatchPathUtils { static class UnauthenticatedException extends Exception {} - static String getAuthHeader() { - def flyteToken = System.getenv("FLYTE_INTERNAL_EXECUTION_ID") - if (flyteToken != null) - return "Latch-Execution-Token $flyteToken" - - String home 
= System.getProperty("user.home") - File tokenFile = new File("$home/.latch/token") - if (tokenFile.exists()) - return "Latch-SDK-Token ${tokenFile.text.strip()}" + static String getAuthHeader(boolean useForchAuth = false) { + if (useForchAuth) { + def forchToken = System.getenv("forch_auth_token") + if (forchToken != null) return "Forch-Auth-Token $forchToken" + } else { + def flyteToken = System.getenv("FLYTE_INTERNAL_EXECUTION_ID") + if (flyteToken != null) return "Latch-Execution-Token $flyteToken" + + String home = System.getProperty("user.home") + File tokenFile = new File("$home/.latch/token") + if (tokenFile.exists()) return "Latch-SDK-Token ${tokenFile.text.strip()}" + } throw new UnauthenticatedException() } @@ -75,7 +78,7 @@ class LatchPathUtils { defaultAccount } } - } + } """)["accountInfoCurrent"] as Map if (accInfo == null)