Skip to content

Commit 1eb6b61

Browse files
committed
Support accelerator directive for local executor
Signed-off-by: Ben Sherman <bentshermann@gmail.com>
1 parent 3fb8a58 commit 1eb6b61

File tree

8 files changed

+400
-5
lines changed

8 files changed

+400
-5
lines changed

docs/executor.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ The `local` executor is useful for developing and testing a pipeline script on y
227227

228228
Resource requests and other job characteristics can be controlled via the following process directives:
229229

230+
- {ref}`process-accelerator`
230231
- {ref}`process-cpus`
231232
- {ref}`process-memory`
232233
- {ref}`process-time`
@@ -241,6 +242,25 @@ The local executor supports two types of tasks:
241242
- Script tasks (processes with a `script` or `shell` block) - executed via a Bash wrapper
242243
- Native tasks (processes with an `exec` block) - executed directly in the JVM.
243244

245+
(local-accelerators)=
246+
247+
### Accelerators
248+
249+
:::{versionadded} 25.10.0
250+
:::
251+
252+
The local executor can use the `accelerator` directive to allocate accelerators, such as GPUs. To use accelerators, set the corresponding environment variable:
253+
254+
- `CUDA_VISIBLE_DEVICES` for [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-environment-variables) applications
255+
256+
- `HIP_VISIBLE_DEVICES` for [HIP](https://rocm.docs.amd.com/projects/HIP/en/docs-develop/reference/env_variables.html) applications
257+
258+
- `ROCR_VISIBLE_DEVICES` for [AMD ROCm](https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html) applications
259+
260+
Set the environment variable to a comma-separated list of device IDs for Nextflow to access. Nextflow uses this environment variable to allocate accelerators for tasks that request them.
261+
262+
For example, to use all GPUs on a node with four NVIDIA GPUs, set `CUDA_VISIBLE_DEVICES` to `0,1,2,3`. If four tasks each request one GPU, they will be executed with `CUDA_VISIBLE_DEVICES` set to `0`, `1`, `2`, and `3`, respectively.
263+
244264
(lsf-executor)=
245265

246266
## LSF

docs/migrations/25-10.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ export NXF_PLUGINS_REGISTRY_URL="https://raw.githubusercontent.com/nextflow-io/p
9494
Plugin developers will not be able to submit PRs to the legacy plugin index once the plugin registry is generally available. Plugins should be updated to publish to the Nextflow plugin registry using the {ref}`Nextflow Gradle plugin <gradle-plugin-page>` instead. See {ref}`migrate-plugin-registry-page` for details.
9595
:::
9696

97+
<h3>GPU scheduling for local executor</h3>
98+
99+
The local executor can now schedule GPUs using the `accelerator` directive. This feature is useful when running Nextflow on a single machine with multiple GPUs.
100+
101+
See {ref}`local-accelerators` for details.
102+
97103
<h3>New syntax for workflow handlers</h3>
98104

99105
The workflow `onComplete` and `onError` handlers were previously defined by calling `workflow.onComplete` and `workflow.onError` in the pipeline script. You can now define handlers as `onComplete` and `onError` sections in an entry workflow:

modules/nextflow/src/main/groovy/nextflow/executor/local/LocalTaskHandler.groovy

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
8080

8181
private volatile TaskResult result
8282

83+
String acceleratorEnv
84+
85+
List<String> acceleratorIds
86+
8387
LocalTaskHandler(TaskRun task, LocalExecutor executor) {
8488
super(task)
8589
// create the task handler
@@ -142,11 +146,13 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
142146
final workDir = task.workDir.toFile()
143147
final logFile = new File(workDir, TaskRun.CMD_LOG)
144148

145-
return new ProcessBuilder()
149+
final pb = new ProcessBuilder()
146150
.redirectErrorStream(true)
147151
.redirectOutput(logFile)
148152
.directory(workDir)
149153
.command(cmd)
154+
prepareAccelerators(pb)
155+
return pb
150156
}
151157

152158
protected ProcessBuilder fusionProcessBuilder() {
@@ -162,10 +168,18 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
162168

163169
final logPath = Files.createTempFile('nf-task','.log')
164170

165-
return new ProcessBuilder()
171+
final pb = new ProcessBuilder()
166172
.redirectErrorStream(true)
167173
.redirectOutput(logPath.toFile())
168174
.command(List.of('sh','-c', cmd))
175+
prepareAccelerators(pb)
176+
return pb
177+
}
178+
179+
protected void prepareAccelerators(ProcessBuilder pb) {
180+
if( !acceleratorEnv )
181+
return
182+
pb.environment().put(acceleratorEnv, acceleratorIds.join(','))
169183
}
170184

171185
protected ProcessBuilder createLaunchProcessBuilder() {
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nextflow.processor
18+
19+
import groovy.transform.CompileStatic
20+
import nextflow.SysEnv
21+
import nextflow.util.TrackingSemaphore
22+
23+
/**
24+
* Specialized semaphore that keeps track of accelerators by
25+
* id. The id can be an integer or a UUID.
26+
*
27+
* @author Ben Sherman <bentshermann@gmail.com>
28+
*/
29+
@CompileStatic
30+
class AcceleratorTracker {
31+
32+
private static final List<String> DEVICE_ENV_NAMES = [
33+
'CUDA_VISIBLE_DEVICES',
34+
'HIP_VISIBLE_DEVICES',
35+
'ROCR_VISIBLE_DEVICES'
36+
]
37+
38+
static AcceleratorTracker create() {
39+
return create(SysEnv.get())
40+
}
41+
42+
static AcceleratorTracker create(Map<String,String> env) {
43+
return DEVICE_ENV_NAMES.stream()
44+
.filter(name -> env.containsKey(name))
45+
.map((name) -> {
46+
final ids = env.get(name).tokenize(',')
47+
return new AcceleratorTracker(name, ids)
48+
})
49+
.findFirst().orElse(new AcceleratorTracker())
50+
}
51+
52+
private final String name
53+
private final TrackingSemaphore semaphore
54+
55+
private AcceleratorTracker(String name, List<String> ids) {
56+
this.name = name
57+
this.semaphore = new TrackingSemaphore(ids)
58+
}
59+
60+
private AcceleratorTracker() {
61+
this(null, [])
62+
}
63+
64+
String name() {
65+
return name
66+
}
67+
68+
int total() {
69+
return semaphore.totalPermits()
70+
}
71+
72+
int available() {
73+
return semaphore.availablePermits()
74+
}
75+
76+
List<String> acquire(int permits) {
77+
return semaphore.acquire(permits)
78+
}
79+
80+
void release(List<String> ids) {
81+
semaphore.release(ids)
82+
}
83+
84+
}

modules/nextflow/src/main/groovy/nextflow/processor/LocalPollingMonitor.groovy

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import groovy.util.logging.Slf4j
2424
import nextflow.Session
2525
import nextflow.executor.ExecutorConfig
2626
import nextflow.exception.ProcessUnrecoverableException
27+
import nextflow.executor.local.LocalTaskHandler
2728
import nextflow.util.Duration
2829
import nextflow.util.MemoryUnit
2930

@@ -59,6 +60,11 @@ class LocalPollingMonitor extends TaskPollingMonitor {
5960
*/
6061
private final long maxMemory
6162

63+
/**
64+
* Tracks the total and available accelerators in the system
65+
*/
66+
private AcceleratorTracker acceleratorTracker
67+
6268
/**
6369
* Create the task polling monitor with the provided named parameters object.
6470
* <p>
@@ -76,6 +82,7 @@ class LocalPollingMonitor extends TaskPollingMonitor {
7682
super(params)
7783
this.availCpus = maxCpus = params.cpus as int
7884
this.availMemory = maxMemory = params.memory as long
85+
this.acceleratorTracker = AcceleratorTracker.create()
7986
assert availCpus>0, "Local avail `cpus` attribute cannot be zero"
8087
assert availMemory>0, "Local avail `memory` attribute cannot zero"
8188
}
@@ -154,6 +161,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
154161
handler.task.getConfig()?.getMemory()?.toBytes() ?: 1L
155162
}
156163

164+
/**
165+
* @param handler
166+
* A {@link TaskHandler} instance
167+
* @return
168+
* The number of accelerators requested to execute the specified task
169+
*/
170+
private static int accelerators(TaskHandler handler) {
171+
handler.task.getConfig()?.getAccelerator()?.getRequest() ?: 0
172+
}
173+
157174
/**
158175
* Determines if a task can be submitted for execution checking if the resources required
159176
* (cpus and memory) match the amount of avail resource
@@ -179,9 +196,13 @@ class LocalPollingMonitor extends TaskPollingMonitor {
179196
if( taskMemory>maxMemory)
180197
throw new ProcessUnrecoverableException("Process requirement exceeds available memory -- req: ${new MemoryUnit(taskMemory)}; avail: ${new MemoryUnit(maxMemory)}")
181198

182-
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory
199+
final taskAccelerators = accelerators(handler)
200+
if( taskAccelerators > acceleratorTracker.total() )
201+
throw new ProcessUnrecoverableException("Process requirement exceeds available accelerators -- req: $taskAccelerators; avail: ${acceleratorTracker.total()}")
202+
203+
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory && taskAccelerators <= acceleratorTracker.available()
183204
if( !result && log.isTraceEnabled( ) ) {
184-
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)}"
205+
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)} && taskAccelerators: $taskAccelerators <= availAccelerators: ${acceleratorTracker.available()}"
185206
}
186207
return result
187208
}
@@ -194,9 +215,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
194215
*/
195216
@Override
196217
protected void submit(TaskHandler handler) {
197-
super.submit(handler)
198218
availCpus -= cpus(handler)
199219
availMemory -= mem(handler)
220+
221+
final taskAccelerators = accelerators(handler)
222+
if( handler instanceof LocalTaskHandler && taskAccelerators > 0 ) {
223+
handler.acceleratorEnv = acceleratorTracker.name()
224+
handler.acceleratorIds = acceleratorTracker.acquire(taskAccelerators)
225+
}
226+
227+
super.submit(handler)
200228
}
201229

202230
/**
@@ -209,11 +237,14 @@ class LocalPollingMonitor extends TaskPollingMonitor {
209237
* {@code true} when the task is successfully removed from polling queue,
210238
* {@code false} otherwise
211239
*/
240+
@Override
212241
protected boolean remove(TaskHandler handler) {
213242
final result = super.remove(handler)
214243
if( result ) {
215244
availCpus += cpus(handler)
216245
availMemory += mem(handler)
246+
if( handler instanceof LocalTaskHandler )
247+
acceleratorTracker.release(handler.acceleratorIds ?: Collections.<String>emptyList())
217248
}
218249
return result
219250
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2013-2024, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nextflow.util
18+
19+
import java.util.concurrent.Semaphore
20+
21+
import groovy.transform.CompileStatic
22+
23+
/**
24+
* Specialized semaphore that keeps track of which permits
25+
* are being used.
26+
*
27+
* @author Ben Sherman <bentshermann@gmail.com>
28+
*/
29+
@CompileStatic
30+
class TrackingSemaphore {
31+
private final Semaphore semaphore
32+
private final Map<String,Boolean> availIds
33+
34+
TrackingSemaphore(List<String> ids) {
35+
semaphore = new Semaphore(ids.size())
36+
availIds = new HashMap<>(ids.size())
37+
for( final id : ids )
38+
availIds.put(id, true)
39+
}
40+
41+
int totalPermits() {
42+
return availIds.size()
43+
}
44+
45+
int availablePermits() {
46+
return semaphore.availablePermits()
47+
}
48+
49+
List<String> acquire(int permits) {
50+
semaphore.acquire(permits)
51+
final result = new ArrayList<String>(permits)
52+
for( final entry : availIds.entrySet() ) {
53+
if( entry.getValue() ) {
54+
entry.setValue(false)
55+
result.add(entry.getKey())
56+
}
57+
if( result.size() == permits )
58+
break
59+
}
60+
return result
61+
}
62+
63+
void release(List<String> ids) {
64+
semaphore.release(ids.size())
65+
for( final id : ids )
66+
availIds.put(id, true)
67+
}
68+
69+
}

0 commit comments

Comments
 (0)