Skip to content

Commit 6bd88c2

Browse files
committed
Add support for OAR Scheduler (merged with the new structure 1.0a1)
1 parent 55c30c1 commit 6bd88c2

File tree

4 files changed

+276
-0
lines changed

4 files changed

+276
-0
lines changed

pydra/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def pytest_generate_tests(metafunc):
3434
available_workers.append("dask")
3535
if bool(shutil.which("sbatch")):
3636
available_workers.append("slurm")
37+
if bool(shutil.which("oarsub")):
38+
available_workers.append("oar")
3739
else:
3840
available_workers = [only_worker]
3941
# Set the available workers as a parameter to the

pydra/engine/tests/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
not (bool(shutil.which("qsub")) and bool(shutil.which("qacct"))),
3838
reason="sge not available",
3939
)
40+
# Skip marker for tests that require an OAR cluster: both the submission
# command (``oarsub``) and the status command (``oarstat``) must be on PATH.
need_oar = pytest.mark.skipif(
    not (bool(shutil.which("oarsub")) and bool(shutil.which("oarstat"))),
    reason="oar not available",
)
4044

4145

4246
def num_python_cache_roots(cache_path: Path) -> int:

pydra/workers/oar.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import asyncio
2+
import os
3+
import sys
4+
import json
5+
import re
6+
import typing as ty
7+
from tempfile import gettempdir
8+
from pathlib import Path
9+
from shutil import copyfile
10+
import logging
11+
import attrs
12+
from pydra.engine.job import Job, save
13+
from pydra.workers import base
14+
15+
16+
logger = logging.getLogger("pydra.worker")
17+
18+
if ty.TYPE_CHECKING:
19+
from pydra.engine.result import Result
20+
21+
22+
@attrs.define
class OarWorker(base.Worker):
    """A worker to execute tasks on OAR systems.

    Each job is pickled into the cache root, wrapped in a small shell batch
    script, submitted with ``oarsub``, and then polled with ``oarstat`` until
    it reaches a terminal state.
    """

    _cmd = "oarsub"

    # seconds to wait between successive ``oarstat`` polls (clamped to >= 0)
    poll_delay: int = attrs.field(default=1, converter=base.ensure_non_negative)
    # extra command-line arguments appended verbatim to every ``oarsub`` call
    oarsub_args: str = ""
    # maps OAR job id -> path of that job's stderr file
    error: dict[str, ty.Any] = attrs.field(factory=dict)

    def __getstate__(self) -> dict[str, ty.Any]:
        """Return state for pickling, dropping the per-run ``error`` map."""
        state = super().__getstate__()
        del state["error"]
        return state

    def __setstate__(self, state: dict[str, ty.Any]):
        """Set state for unpickling, restoring an empty ``error`` map."""
        state["error"] = {}
        super().__setstate__(state)

    def _prepare_runscripts(self, job, interpreter="/bin/sh", rerun=False):
        """Pickle *job* and write the batch script that re-loads and runs it.

        Parameters
        ----------
        job : Job or tuple
            Either a ``Job`` instance, or a 2-tuple whose first element is an
            index and whose last element provides ``cache_root``/``uid`` plus
            the pickled-job path to copy.
        interpreter : str
            Shebang interpreter for the generated batch script.
        rerun : bool
            Passed through to ``load_and_run`` inside the script.

        Returns
        -------
        tuple
            The script directory and the batch-script path.
        """
        if isinstance(job, Job):
            cache_root = job.cache_root
            ind = None
            uid = job.uid
        else:
            assert isinstance(job, tuple), f"Expecting a job or a tuple, not {job!r}"
            assert len(job) == 2, f"Expecting a tuple of length 2, not {job!r}"
            ind = job[0]
            cache_root = job[-1].cache_root
            uid = f"{job[-1].uid}_{ind}"

        script_dir = cache_root / f"{self.plugin_name()}_scripts" / uid
        script_dir.mkdir(parents=True, exist_ok=True)
        if ind is None:
            if not (script_dir / "_job.pklz").exists():
                save(script_dir, job=job)
        else:
            copyfile(job[1], script_dir / "_job.pklz")

        job_pkl = script_dir / "_job.pklz"
        if not job_pkl.exists() or not job_pkl.stat().st_size:
            raise Exception("Missing or empty job!")

        batchscript = script_dir / f"batchscript_{uid}.sh"
        python_string = (
            f"""'from pydra.engine.job import load_and_run; """
            f"""load_and_run("{job_pkl}", rerun={rerun}) '"""
        )
        bcmd = "\n".join(
            (
                f"#!{interpreter}",
                f"{sys.executable} -c " + python_string,
            )
        )
        with batchscript.open("wt") as fp:
            fp.writelines(bcmd)
        # owner r+x, group/other read — the script only needs to be executable
        os.chmod(batchscript, 0o544)
        return script_dir, batchscript

    async def run(self, job: "Job[base.TaskType]", rerun: bool = False) -> "Result":
        """Worker submission API: submit *job* via oarsub and await completion."""
        script_dir, batch_script = self._prepare_runscripts(job, rerun=rerun)
        # FIX: the original compared a Path against the *string* returned by
        # gettempdir(), which can never be equal, so the warning never fired.
        if Path(gettempdir()) in script_dir.parents:
            logger.warning("Temporary directories may not be shared across computers")
        script_dir = job.cache_root / f"{self.plugin_name()}_scripts" / job.uid
        sargs = self.oarsub_args.split()
        jobname = re.search(r"(?<=-n )\S+|(?<=--name=)\S+", self.oarsub_args)
        if not jobname:
            jobname = ".".join((job.name, job.uid))
            sargs.append(f"--name={jobname}")
        output = re.search(r"(?<=-O )\S+|(?<=--stdout=)\S+", self.oarsub_args)
        if not output:
            output_file = str(script_dir / "oar-%jobid%.out")
            sargs.append(f"--stdout={output_file}")
        error = re.search(r"(?<=-E )\S+|(?<=--stderr=)\S+", self.oarsub_args)
        if not error:
            error_file = str(script_dir / "oar-%jobid%.err")
            sargs.append(f"--stderr={error_file}")
        else:
            # FIX: keep the user-supplied stderr path (the original stored
            # None, so ``self.error[jobid]`` was never set and the failure
            # branch below raised KeyError instead of reporting the error).
            error_file = error.group(0)
        sargs.append(str(batch_script))
        # TO CONSIDER: add random sleep to avoid overloading calls
        rc, stdout, stderr = await base.read_and_display_async(
            self._cmd, *sargs, hide_display=True
        )
        jobid = re.search(r"OAR_JOB_ID=(\d+)", stdout)
        if rc:
            raise RuntimeError(f"Error returned from oarsub: {stderr}")
        elif not jobid:
            raise RuntimeError("Could not extract job ID")
        jobid = jobid.group(1)
        # FIX: replace once and record the resolved path (the original called
        # .replace twice and only when error_file was truthy)
        error_file = error_file.replace("%jobid%", jobid)
        self.error[jobid] = error_file
        # intermittent polling
        while True:
            # 4 possibilities
            # False: job is still pending/working
            # Terminated: job is complete
            # Error + idempotent: job has been stopped and resubmitted with another jobid
            # Error: job failure
            done = await self._poll_job(jobid)
            if not done:
                await asyncio.sleep(self.poll_delay)
            elif done == "Terminated":
                # NOTE(review): annotated to return a Result but returns True —
                # confirm against the base.Worker contract
                return True
            elif done == "Error" and "idempotent" in self.oarsub_args:
                new_jobid = await self._handle_resubmission(jobid, job)
                # FIX: register an error file for the resubmitted job so the
                # failure branch can look it up without a KeyError
                if new_jobid and new_jobid not in self.error:
                    self.error[new_jobid] = self.error[jobid]
                jobid = new_jobid
                continue
            else:
                error_file = self.error[jobid]
                if not Path(error_file).exists():
                    logger.debug(
                        f"No error file for job {jobid}. Checking if job was resubmitted by OAR..."
                    )
                    new_jobid = await self._handle_resubmission(jobid, job)
                    if new_jobid:
                        # FIX: same error-file carry-over as above
                        if new_jobid not in self.error:
                            self.error[new_jobid] = self.error[jobid]
                        jobid = new_jobid
                        continue
                # give a slow shared filesystem a few seconds to expose the file
                for _ in range(5):
                    if Path(error_file).exists():
                        break
                    await asyncio.sleep(1)
                else:
                    raise RuntimeError(
                        f"OAR error file not found: {error_file}, and no resubmission detected."
                    )
                # FIX: the original indexed split("\n")[-2], which raises
                # IndexError on an empty error file; take the last non-empty
                # line instead (same result for a trailing-newline file).
                lines = [ln for ln in Path(error_file).read_text().splitlines() if ln]
                error_line = lines[-1] if lines else ""
                if "Exception" in error_line:
                    error_message = error_line.replace("Exception: ", "")
                elif "Error" in error_line:
                    error_message = error_line.replace("Error: ", "")
                else:
                    error_message = "Job failed (unknown reason - TODO)"
                raise Exception(error_message)

    async def _poll_job(self, jobid):
        """Return False while the job is queued/running, else its OAR state.

        Uses ``oarstat -J -s -j <jobid>``, which prints a JSON mapping of
        job id to state.
        """
        cmd = ("oarstat", "-J", "-s", "-j", jobid)
        logger.debug(f"Polling job {jobid}")
        _, stdout, _ = await base.read_and_display_async(*cmd, hide_display=True)
        if not stdout:
            raise RuntimeError("Job information not found")
        status = json.loads(stdout)[jobid]
        # any non-terminal OAR state counts as "still pending/working"
        if status in ["Waiting", "Launching", "Running", "Finishing"]:
            return False
        return status

    async def _handle_resubmission(self, jobid, job):
        """Look up the job that OAR resubmitted in place of *jobid*.

        Also removes the task's lock file (if any) so the resubmitted run can
        acquire it. Returns the new job id, or None if no resubmission is
        found.
        """
        logger.debug(f"Job {jobid} has been stopped. Looking for its resubmission...")
        # loading info about task with a specific uid
        info_file = job.cache_root / f"{job.uid}_info.json"
        if info_file.exists():
            checksum = json.loads(info_file.read_text())["checksum"]
            lock_file = job.cache_root / f"{checksum}.lock"
            if lock_file.exists():
                lock_file.unlink()
        cmd_re = ("oarstat", "-J", "--sql", f"resubmit_job_id='{jobid}'")
        _, stdout, _ = await base.read_and_display_async(*cmd_re, hide_display=True)
        if stdout:
            return next(iter(json.loads(stdout).keys()), None)
        return None
186+
187+
188+
# Alias so it can be referred to as oar.Worker (the module-level ``Worker``
# name other worker modules also expose)
Worker = OarWorker

pydra/workers/tests/test_worker.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
need_sge,
2323
need_slurm,
2424
need_singularity,
25+
need_oar,
2526
BasicWorkflow,
2627
BasicWorkflowWithThreadCount,
2728
BasicWorkflowWithThreadCountConcurrent,
@@ -602,6 +603,86 @@ def test_sge_no_limit_maxthreads(tmpdir):
602603
assert job_1_endtime > job_2_starttime
603604

604605

606+
@need_oar
def test_oar_wf(tmpdir):
    """Run a basic workflow with the workflow and every task as OAR jobs."""
    workflow = BasicWorkflow(x=1)
    with Submitter(worker="oar", cache_root=tmpdir) as sub:
        result = sub(workflow)

    assert result.outputs.out == 5
    scripts = tmpdir / "oar_scripts"
    assert scripts.exists()
    # one script directory per task executed through oar
    task_dirs = [entry for entry in scripts.listdir() if entry.isdir()]
    assert len(task_dirs) == 2
619+
620+
621+
@pytest.mark.skip(
    reason=(
        "There currently isn't a way to specify a worker to run a whole workflow within "
        "a single OAR job"
    )
)
@need_oar
def test_oar_wf_cf(tmpdir):
    """Submit the entire workflow as one OAR job executing with the cf worker."""
    workflow = BasicWorkflow(x=1)
    with Submitter(worker="oar", cache_root=tmpdir) as sub:
        result = sub(workflow)

    assert result.outputs.out == 5
    scripts = tmpdir / "oar_scripts"
    assert scripts.exists()
    # only the workflow itself should have been executed with oar
    task_dirs = [entry for entry in scripts.listdir() if entry.isdir()]
    assert len(task_dirs) == 1
    # the single oar script dir is named after the workflow's uid
    assert task_dirs[0].basename == workflow.uid
643+
644+
645+
@need_oar
def test_oar_wf_state(tmpdir):
    """Run a split (state) workflow, each task submitted as an OAR job."""
    wf = BasicWorkflow().split(x=[5, 6])
    with Submitter(worker="oar", cache_root=tmpdir) as sub:
        result = sub(wf)

    assert result.outputs.out == [9, 10]
    scripts = tmpdir / "oar_scripts"
    assert scripts.exists()
    # two tasks for each value of the split state
    task_dirs = [entry for entry in scripts.listdir() if entry.isdir()]
    assert len(task_dirs) == 2 * len(wf.x)
657+
658+
659+
@need_oar
def test_oar_args_1(tmpdir):
    """testing oarsub_args provided to the submitter"""
    task = SleepAddOne(x=1)
    # submit workflow and every task as oar job
    with Submitter(worker="oar", cache_root=tmpdir, oarsub_args="-l nodes=2") as sub:
        res = sub(task)

    assert res.outputs.out == 2
    script_dir = tmpdir / "oar_scripts"
    assert script_dir.exists()
670+
671+
672+
@need_oar
def test_oar_args_2(tmpdir):
    """Invalid options in oarsub_args should surface as a RuntimeError."""
    task = SleepAddOne(x=1)
    # the oarsub submission itself is expected to fail
    with pytest.raises(RuntimeError, match="Error returned from oarsub:"):
        with Submitter(
            worker="oar", cache_root=tmpdir, oarsub_args="-l nodes=2 --invalid"
        ) as sub:
            sub(task)
684+
685+
605686
def test_hash_changes_in_task_inputs_file(tmp_path):
606687
@python.define
607688
def cache_dir_as_input(out_dir: Directory) -> Directory:

0 commit comments

Comments
 (0)