Skip to content

Commit f65074f

Browse files
committed
Raise SessionStatusError on get(_or_create) endpoint
Instead of manually handling specific state cases
1 parent d03b03b commit f65074f

File tree

9 files changed

+146
-114
lines changed

9 files changed

+146
-114
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
* Improve error message when a query is interrupted by a signal (SIGINT or SIGTERM).
1717
* Improve error message if session is expired.
1818
* Improve robustness of Arrow client against connection errors such as `FlightUnavailableError` and `FlightTimedOutError`.
19+
* Return dedicated error class `SessionStatusError` if a session failed or expired.
1920

2021

2122
## Other changes

graphdatascience/session/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .algorithm_category import AlgorithmCategory
2+
from .aura_api import AuraApiError, SessionStatusError
23
from .cloud_location import CloudLocation
34
from .dbms_connection_info import DbmsConnectionInfo
45
from .gds_sessions import AuraAPICredentials, GdsSessions
@@ -14,4 +15,6 @@
1415
"SessionMemory",
1516
"SessionMemoryValue",
1617
"AlgorithmCategory",
18+
"SessionStatusError",
19+
"AuraApiError",
1720
]

graphdatascience/session/aura_api.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
InstanceDetails,
2424
InstanceSpecificDetails,
2525
SessionDetails,
26+
SessionDetailsWithErrors,
27+
SessionErrorData,
2628
TenantDetails,
2729
WaitResult,
2830
)
@@ -32,12 +34,27 @@
3234

3335

3436
class AuraApiError(Exception):
37+
"""
38+
Raised when an API call to the AuraAPI fails (after retries).
39+
"""
40+
3541
def __init__(self, message: str, status_code: int):
3642
super().__init__(self, message)
3743
self.status_code = status_code
3844
self.message = message
3945

4046

47+
class SessionStatusError(Exception):
48+
"""
49+
Raised when a session is in a non-healthy state. Such as after a session failed or got expired.
50+
"""
51+
52+
def __init__(self, errors: list[SessionErrorData]):
53+
message = f"Session is in an unhealthy state. Details: {[str(e) for e in errors]}"
54+
55+
super().__init__(self, message)
56+
57+
4158
class AuraApi:
4259
API_VERSION = "v1beta5"
4360

@@ -126,7 +143,9 @@ def get_or_create_session(
126143
self._check_resp(response)
127144

128145
raw_json: dict[str, Any] = response.json()
129-
return SessionDetails.from_json(raw_json["data"], raw_json.get("errors", []))
146+
self._check_errors(raw_json)
147+
148+
return SessionDetails.from_json(raw_json["data"])
130149

131150
def get_session(self, session_id: str) -> Optional[SessionDetails]:
132151
response = self._request_session.get(
@@ -139,9 +158,11 @@ def get_session(self, session_id: str) -> Optional[SessionDetails]:
139158
self._check_resp(response)
140159

141160
raw_json: dict[str, Any] = response.json()
142-
return SessionDetails.from_json(raw_json["data"], raw_json.get("errors", []))
161+
self._check_errors(raw_json)
162+
163+
return SessionDetails.from_json(raw_json["data"])
143164

144-
def list_sessions(self, dbid: Optional[str] = None) -> list[SessionDetails]:
165+
def list_sessions(self, dbid: Optional[str] = None) -> list[SessionDetailsWithErrors]:
145166
# these are query parameters (not passed in the body)
146167
params = {
147168
"tenantId": self._tenant_id,
@@ -161,7 +182,7 @@ def list_sessions(self, dbid: Optional[str] = None) -> list[SessionDetails]:
161182
for error in raw_json.get("errors", []):
162183
errors_per_session[error["id"]].append(error)
163184

164-
return [SessionDetails.from_json(s, errors_per_session[s["id"]]) for s in data]
185+
return [SessionDetailsWithErrors.from_json_with_error(s, errors_per_session[s["id"]]) for s in data] # noqa: F821
165186

166187
def wait_for_session_running(
167188
self,
@@ -177,14 +198,6 @@ def wait_for_session_running(
177198
return WaitResult.from_error(f"Session `{session_id}` not found -- please retry")
178199
elif session.status == "Ready":
179200
return WaitResult.from_connection_url(session.bolt_connection_url())
180-
elif session.status == "Failed":
181-
return WaitResult.from_error(
182-
f"Session `{session_id}` with name `{session.name}` failed due to: {session.errors}"
183-
)
184-
elif session.is_expired():
185-
return WaitResult.from_error(
186-
f"Session `{session_id}` with name `{session.name}` is expired. Expired due to: {session.errors}"
187-
)
188201
else:
189202
self._logger.debug(
190203
f"Session `{session_id}` is not yet running. "
@@ -324,6 +337,13 @@ def tenant_details(self) -> TenantDetails:
324337
self._tenant_details = TenantDetails.from_json(response.json()["data"])
325338
return self._tenant_details
326339

340+
def _check_errors(self, raw_json: dict[str, Any]) -> None:
341+
errors = raw_json.get("errors", [])
342+
typed_errors = [SessionErrorData.from_json(error) for error in errors] if errors else None
343+
344+
if typed_errors:
345+
raise SessionStatusError(typed_errors)
346+
327347
def _check_resp(self, resp: requests.Response) -> None:
328348
self._check_status_code(resp)
329349
self._check_endpoint_deprecation(resp)

graphdatascience/session/aura_api_responses.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,15 @@ class SessionDetails:
2727
user_id: str
2828
tenant_id: str
2929
cloud_location: Optional[CloudLocation] = None
30-
errors: Optional[list[SessionError]] = None
3130

3231
@classmethod
33-
def from_json(cls, data: dict[str, Any], errors: list[dict[str, Any]]) -> SessionDetails:
32+
def from_json(cls, data: dict[str, Any]) -> SessionDetails:
3433
id = data["id"]
3534
expiry_date = data.get("expiry_date")
3635
ttl: Any | None = data.get("ttl")
3736
instance_id = data.get("instance_id")
3837
cloud_location = CloudLocation(data["cloud_provider"], data["region"]) if data.get("cloud_provider") else None
3938

40-
session_errors = [SessionError.from_json(error) for error in errors] if errors else None
41-
4239
return cls(
4340
id=id,
4441
name=data["name"],
@@ -52,18 +49,45 @@ def from_json(cls, data: dict[str, Any], errors: list[dict[str, Any]]) -> Sessio
5249
tenant_id=data["tenant_id"],
5350
user_id=data["user_id"],
5451
cloud_location=cloud_location,
55-
errors=session_errors,
5652
)
5753

5854
def bolt_connection_url(self) -> str:
5955
return f"neo4j+s://{self.host}"
6056

61-
def is_expired(self) -> bool:
62-
return self.status == "Expired"
57+
58+
@dataclass(repr=True, frozen=True)
59+
class SessionDetailsWithErrors(SessionDetails):
60+
errors: Optional[list[SessionErrorData]] = None
61+
62+
@classmethod
63+
def from_json_with_error(cls, data: dict[str, Any], errors: list[dict[str, Any]]) -> SessionDetailsWithErrors:
64+
session_errors = [SessionErrorData.from_json(error) for error in errors] if errors else None
65+
66+
id = data["id"]
67+
expiry_date = data.get("expiry_date")
68+
ttl: Any | None = data.get("ttl")
69+
instance_id = data.get("instance_id")
70+
cloud_location = CloudLocation(data["cloud_provider"], data["region"]) if data.get("cloud_provider") else None
71+
72+
return cls(
73+
id=id,
74+
name=data["name"],
75+
instance_id=instance_id if instance_id else None,
76+
memory=SessionMemoryValue.fromApiResponse(data["memory"]),
77+
status=data["status"],
78+
host=data["host"],
79+
expiry_date=TimeParser.fromisoformat(expiry_date) if expiry_date else None,
80+
created_at=TimeParser.fromisoformat(data["created_at"]),
81+
ttl=Timedelta(ttl).to_pytimedelta() if ttl else None, # datetime has no support for parsing timedelta
82+
tenant_id=data["tenant_id"],
83+
user_id=data["user_id"],
84+
cloud_location=cloud_location,
85+
errors=session_errors,
86+
)
6387

6488

6589
@dataclass(repr=True, frozen=True)
66-
class SessionError:
90+
class SessionErrorData:
6791
"""
6892
Represents information about a session errors.
6993
Indicates that session is in `Failed` state.
@@ -77,12 +101,15 @@ class SessionError:
77101
reason: str
78102

79103
@classmethod
80-
def from_json(cls, json: dict[str, Any]) -> SessionError:
104+
def from_json(cls, json: dict[str, Any]) -> SessionErrorData:
81105
return cls(
82106
reason=json["reason"],
83107
message=json["message"],
84108
)
85109

110+
def __str__(self) -> str:
111+
return f"Reason: {self.reason}, Message: {self.message}"
112+
86113

87114
@dataclass(repr=True, frozen=True)
88115
class InstanceDetails:

graphdatascience/session/dedicated_sessions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str
116116
return False
117117

118118
def list(self, dbid: Optional[str] = None) -> list[SessionInfo]:
119-
sessions: list[SessionDetails] = self._aura_api.list_sessions(dbid)
120-
119+
sessions = self._aura_api.list_sessions(dbid)
121120
return [SessionInfo.from_session_details(i) for i in sessions]
122121

123122
def _find_existing_session(self, session_name: str) -> Optional[SessionDetails]:

graphdatascience/session/gds_sessions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str
127127

128128
def list(self) -> list[SessionInfo]:
129129
"""
130-
Retrieves the list of GDS sessions visible by the user asscociated by the given api-credentials.
130+
Retrieves the list of GDS sessions visible by the user associated by the given api-credentials.
131131
132132
Returns:
133133
A list of SessionInfo objects representing the GDS sessions.

graphdatascience/session/session_info.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import datetime, timedelta
55
from typing import Optional
66

7-
from graphdatascience.session.aura_api_responses import SessionDetails, SessionError
7+
from graphdatascience.session.aura_api_responses import SessionDetails, SessionDetailsWithErrors, SessionErrorData
88
from graphdatascience.session.cloud_location import CloudLocation
99
from graphdatascience.session.session_sizes import SessionMemoryValue
1010

@@ -38,10 +38,14 @@ class SessionInfo:
3838
user_id: str
3939
cloud_location: Optional[CloudLocation]
4040
ttl: Optional[timedelta] = None
41-
errors: Optional[list[SessionError]] = None
41+
errors: Optional[list[SessionErrorData]] = None
4242

4343
@classmethod
44-
def from_session_details(cls, details: SessionDetails) -> SessionInfo:
44+
def from_session_details(cls, details: SessionDetailsWithErrors | SessionDetails) -> SessionInfo:
45+
errors: Optional[list[SessionErrorData]] = None
46+
if isinstance(details, SessionDetailsWithErrors):
47+
errors = details.errors
48+
4549
return SessionInfo(
4650
id=details.id,
4751
name=details.name,
@@ -53,5 +57,5 @@ def from_session_details(cls, details: SessionDetails) -> SessionInfo:
5357
user_id=details.user_id,
5458
cloud_location=details.cloud_location,
5559
ttl=details.ttl,
56-
errors=details.errors,
60+
errors=errors,
5761
)

0 commit comments

Comments
 (0)