-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathagent.py
More file actions
85 lines (63 loc) · 2.16 KB
/
agent.py
File metadata and controls
85 lines (63 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import asyncio
import argparse
import shlex
from enum import Enum
from typing import List
import aiohttp
class JobStatus(Enum):
    """Lifecycle states of a job as reported to the tracking server.

    Each member's value is the string placed in the ``"status"`` field of
    the JSON payload sent by ``send_status``.
    """

    # Reported before the command is launched.
    STARTED = "started"
    # Reported when the command exits with return code 0.
    COMPLETED = "completed"
    # Reported when the command exits with a non-zero return code.
    FAILED = "failed"
async def run(command: List[str]) -> JobStatus:
    """Execute ``command`` as a child process and report how it ended.

    Args:
        command: The program to execute followed by its arguments; the
            items are passed positionally to
            ``asyncio.create_subprocess_exec``.

    Returns:
        ``JobStatus.COMPLETED`` if the process exits with return code 0,
        ``JobStatus.FAILED`` for any other return code.
    """
    child = await asyncio.create_subprocess_exec(*command)
    # wait() blocks until the child terminates and returns its return code.
    exit_code = await child.wait()
    return JobStatus.COMPLETED if exit_code == 0 else JobStatus.FAILED
async def send_status(url: str, status: JobStatus) -> None:
    """PUT the given job status to ``url`` as a JSON document.

    The request body is ``{"status": <status.value>}``. A new
    ``aiohttp.ClientSession`` is opened for the single request and closed
    afterwards; the response is not inspected.

    Args:
        url: Endpoint that receives the status update.
        status: The status to report.
    """
    async with aiohttp.ClientSession() as http:
        await http.put(url, json={"status": status.value})
async def heartbeat(url: str, period: float) -> None:
    """Send a body-less PUT request to ``url`` in an endless loop.

    One ``aiohttp.ClientSession`` is shared by all requests. The loop never
    terminates on its own, so the caller is expected to cancel the task
    that runs this coroutine.

    Args:
        url: Endpoint that receives the heartbeat requests.
        period: Number of seconds to sleep between two consecutive
            requests.
    """
    async with aiohttp.ClientSession() as session:
        while True:
            # Enter the request as a context manager so that every response
            # is released before sleeping; otherwise an unconsumed response
            # could keep its connection checked out on each iteration of
            # this infinite loop.
            async with session.put(url):
                pass
            await asyncio.sleep(period)
async def main(
    command: List[str], tracking_server_url: str, heartbeat_period: float
) -> None:
    """Run ``command`` while reporting its progress to a tracking server.

    The sequence is: report ``JobStatus.STARTED``, start the background
    heartbeat, run the command, report its final status, then stop the
    heartbeat. Statuses are sent to ``<tracking_server_url>/status`` and
    heartbeats to ``<tracking_server_url>/heartbeat``.

    Args:
        command: The program to execute followed by its arguments.
        tracking_server_url: Base URL of the tracking server.
        heartbeat_period: Seconds to sleep between heartbeats.

    Raises:
        Exception: Whatever ``run`` raises is re-raised after
            ``JobStatus.FAILED`` has been reported.
    """
    status_url = f"{tracking_server_url}/status"
    heartbeat_url = f"{tracking_server_url}/heartbeat"

    await send_status(status_url, JobStatus.STARTED)

    # Use asyncio.create_task to start running the heartbeat coroutine
    # immediately, concurrently with the command.
    heartbeat_future = asyncio.create_task(
        heartbeat(heartbeat_url, heartbeat_period)
    )
    try:
        try:
            final_status = await run(command)
        except Exception:
            # The command could not be run at all (for example the
            # executable does not exist). Tell the tracking server the job
            # failed instead of leaving it at "started", then propagate.
            await send_status(status_url, JobStatus.FAILED)
            raise
        await send_status(status_url, final_status)
    finally:
        # Cancel the heartbeat (it is an infinite loop) and wait for it to
        # finish. Doing this in ``finally`` guarantees the task does not
        # outlive this coroutine when an error occurs above.
        heartbeat_future.cancel()
        try:
            await heartbeat_future
        except asyncio.CancelledError:
            pass
def cli():
    """Parse the command-line arguments and run the agent.

    Positional arguments are the job command (a single shell-style string
    that is split with ``shlex.split``) and the tracking server URL. The
    optional ``--heartbeat-period`` sets the seconds between heartbeats
    (default 1.0).

    A command that cannot be split, or that is empty after splitting, is
    rejected through ``parser.error``, which prints a usage message and
    exits with status 2 before anything is sent to the tracking server.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("command", help="The job command to run")
    parser.add_argument("tracking_url", help="The URL of the tracking server")
    parser.add_argument(
        "--heartbeat-period",
        type=float,
        default=1.0,
        help="The period on which to send heartbeats to the tracking server "
        "(in seconds)",
    )
    args = parser.parse_args()

    try:
        command_parts = shlex.split(args.command)
    except ValueError as err:
        # shlex raises ValueError for e.g. an unterminated quotation.
        parser.error(f"invalid command: {err}")
    if not command_parts:
        # An empty argument list cannot be executed as a subprocess.
        parser.error("command must not be empty")

    asyncio.run(main(command_parts, args.tracking_url, args.heartbeat_period))
# Start the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    cli()