diff --git a/requirements.txt b/requirements.txt index a805aacd8..1a4b08061 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ openshift-client==1.0.18 rich==12.5.1 ray[default]==2.1.0 kubernetes==26.1.0 -codeflare-torchx==0.5.0.dev5 +codeflare-torchx==0.6.0.dev0 diff --git a/src/codeflare_sdk/job/jobs.py b/src/codeflare_sdk/job/jobs.py index d7581538d..d89269220 100644 --- a/src/codeflare_sdk/job/jobs.py +++ b/src/codeflare_sdk/job/jobs.py @@ -61,6 +61,7 @@ def __init__( max_retries: int = 0, mounts: Optional[List[str]] = None, rdzv_port: int = 29500, + rdzv_backend: str = None, scheduler_args: Optional[Dict[str, str]] = None, image: Optional[str] = None, ): @@ -81,6 +82,7 @@ def __init__( self.max_retries = max_retries self.mounts: List[str] = mounts if mounts is not None else [] self.rdzv_port = rdzv_port + self.rdzv_backend = rdzv_backend self.scheduler_args: Dict[str, str] = ( scheduler_args if scheduler_args is not None else dict() ) @@ -104,6 +106,9 @@ def _dry_run(self, cluster: "Cluster"): env=self.env, max_retries=self.max_retries, rdzv_port=self.rdzv_port, + rdzv_backend=self.rdzv_backend + if self.rdzv_backend is not None + else "static", mounts=self.mounts, ), scheduler=cluster.torchx_scheduler, @@ -142,6 +147,9 @@ def _dry_run_no_cluster(self): env=self.env, # should this still exist? max_retries=self.max_retries, rdzv_port=self.rdzv_port, # should this still exist? + rdzv_backend=self.rdzv_backend + if self.rdzv_backend is not None + else "c10d", mounts=self.mounts, image=self.image if self.image is not None