Skip to content

Commit f4a50d9

Browse files
jhoetterJWittmeyer
and
JWittmeyer
authored
adds role mgmt for the sdk (#24)
Co-authored-by: JWittmeyer <jens.wittmeyer@onetask.ai>
1 parent f2019aa commit f4a50d9

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

refinery/__init__.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,11 @@ def get_project_details(self) -> Dict[str, str]:
7979
Dict[str, str]: dictionary containing the above information
8080
"""
8181
url = settings.get_project_url(self.project_id)
82-
api_response = api_calls.get_request(url, self.session_token)
82+
api_response = api_calls.get_request(
83+
url,
84+
self.session_token,
85+
self.project_id,
86+
)
8387
return api_response
8488

8589
def get_primary_keys(self) -> List[str]:
@@ -107,7 +111,11 @@ def get_lookup_list(self, list_id: str) -> Dict[str, str]:
107111
Dict[str, str]: Containing the specified lookup list of your project.
108112
"""
109113
url = settings.get_lookup_list_url(self.project_id, list_id)
110-
api_response = api_calls.get_request(url, self.session_token)
114+
api_response = api_calls.get_request(
115+
url,
116+
self.session_token,
117+
self.project_id,
118+
)
111119
return api_response
112120

113121
def get_lookup_lists(self) -> List[Dict[str, str]]:
@@ -140,7 +148,7 @@ def get_record_export(
140148
"""
141149
url = settings.get_export_url(self.project_id)
142150
api_response = api_calls.get_request(
143-
url, self.session_token, **{"num_samples": num_samples}
151+
url, self.session_token, self.project_id, **{"num_samples": num_samples}
144152
)
145153
df = pd.DataFrame(api_response)
146154

@@ -212,6 +220,7 @@ def post_associations(
212220
"source_type": source_type,
213221
},
214222
self.session_token,
223+
self.project_id,
215224
)
216225
return api_response
217226

@@ -234,13 +243,15 @@ def post_records(self, records: List[Dict[str, Any]]):
234243
"is_last": False,
235244
},
236245
self.session_token,
246+
self.project_id,
237247
)
238248
batch_responses.append(api_response)
239249
time.sleep(0.5) # wait half a second to avoid server overload
240250
api_calls.post_request(
241251
url,
242252
{"request_uuid": request_uuid, "records": [], "is_last": True},
243253
self.session_token,
254+
self.project_id,
244255
)
245256
return batch_responses
246257

@@ -281,6 +292,7 @@ def post_file_import(
281292
config_api_response = api_calls.get_request(
282293
config_url,
283294
self.session_token,
295+
self.project_id,
284296
)
285297
endpoint = config_api_response.get("KERN_S3_ENDPOINT")
286298

@@ -294,6 +306,7 @@ def post_file_import(
294306
"import_file_options": import_file_options,
295307
},
296308
self.session_token,
309+
self.project_id,
297310
)
298311
credentials = credentials_api_response["Credentials"]
299312
access_key = credentials["AccessKeyId"]
@@ -362,6 +375,8 @@ def __monitor_task(self, upload_task_id: str) -> None:
362375

363376
def __get_task(self, upload_task_id: str) -> Dict[str, Any]:
364377
api_response = api_calls.get_request(
365-
settings.get_task(self.project_id, upload_task_id), self.session_token
378+
settings.get_task(self.project_id, upload_task_id),
379+
self.session_token,
380+
self.project_id,
366381
)
367382
return api_response

refinery/api_calls.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
version = "noversion"
1313

1414

15-
def post_request(url: str, body: Dict[str, Any], session_token: str) -> str:
15+
def post_request(
16+
url: str, body: Dict[str, Any], session_token: str, project_id: str
17+
) -> str:
1618
headers = _build_headers(session_token)
1719
response = requests.post(url=url, json=body, headers=headers)
18-
return _handle_response(response)
20+
return _handle_response(response, project_id)
1921

2022

21-
def get_request(url: str, session_token: str, **query_params) -> str:
23+
def get_request(url: str, session_token: str, project_id: str, **query_params) -> str:
2224
headers = _build_headers(session_token)
2325
response = requests.get(url=url, headers=headers, params=query_params)
24-
return _handle_response(response)
26+
return _handle_response(response, project_id)
2527

2628

2729
def _build_headers(session_token: str) -> Dict[str, str]:
@@ -33,7 +35,7 @@ def _build_headers(session_token: str) -> Dict[str, str]:
3335
}
3436

3537

36-
def _handle_response(response: requests.Response) -> str:
38+
def _handle_response(response: requests.Response, project_id: str) -> str:
3739
status_code = response.status_code
3840
if status_code == 200:
3941
json_data = response.json()
@@ -53,5 +55,6 @@ def _handle_response(response: requests.Response) -> str:
5355
status_code=status_code,
5456
error_code=error_code,
5557
error_message=error_message,
58+
project_id=project_id,
5659
)
5760
raise exception

refinery/exceptions.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,27 @@ class PrimaryKeyError(LocalError):
1616

1717
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses
1818
class APIError(Exception):
19-
def __init__(self, message: Optional[str] = None):
19+
def __init__(self, project_id: str, message: Optional[str] = None):
20+
self.project_id = project_id
2021
if message is None:
2122
message = "Please check the SDK documentation at https://github.com/code-kern-ai/refinery-python"
2223
super().__init__(message)
2324

2425

2526
# 401 Unauthorized
2627
class UnauthorizedError(APIError):
27-
pass
28+
def __init__(self, project_id: str, message: Optional[str] = None):
29+
super().__init__(project_id, message)
30+
31+
32+
# 403 Forbidden
33+
class ForbiddenError(APIError):
34+
def __init__(
35+
self,
36+
project_id: str,
37+
message="You can't access the project with your current role.",
38+
):
39+
super().__init__(project_id, message)
2840

2941

3042
# 404 Not Found
@@ -33,15 +45,17 @@ class NotFoundError(APIError):
3345

3446

3547
class UnknownProjectError(APIError):
36-
def __init__(self, project_id: str):
48+
def __init__(self, project_id: str, message: Optional[str] = None):
3749
super().__init__(
38-
message=f"Could not find project '{project_id}'. Please check your input."
50+
project_id,
51+
f"Could not find project '{project_id}'. Please check your input.",
3952
)
4053

4154

4255
# 500 Server Error
4356
class InternalServerError(APIError):
44-
pass
57+
def __init__(self, project_id: str, message: Optional[str] = None):
58+
super().__init__(project_id, message)
4559

4660

4761
class FileImportError(Exception):
@@ -51,11 +65,13 @@ class FileImportError(Exception):
5165
# mirror this from the rest api class ErrorCodes
5266
class ErrorCodes:
5367
UNRECOGNIZED_USER = "UNRECOGNIZED_USER" # not actively used in SDK
68+
FORBIDDEN_USER = "FORBIDDEN_USER"
5469
PROJECT_NOT_FOUND = "PROJECT_NOT_FOUND"
5570

5671

5772
RESPONSE_CODES_API_EXCEPTION_MAP = {
5873
401: {"*": UnauthorizedError},
74+
403: {"*": ForbiddenError},
5975
404: {"*": NotFoundError, ErrorCodes.PROJECT_NOT_FOUND: UnknownProjectError},
6076
500: {"*": InternalServerError},
6177
}
@@ -65,10 +81,11 @@ def get_api_exception_class(
6581
status_code: int,
6682
error_code: Optional[str] = None,
6783
error_message: Optional[str] = None,
84+
project_id: Optional[str] = None,
6885
) -> APIError:
6986
exception_or_dict = RESPONSE_CODES_API_EXCEPTION_MAP.get(status_code, APIError)
7087
if isinstance(exception_or_dict, dict):
7188
exception_class = exception_or_dict.get(error_code, exception_or_dict["*"])
89+
return exception_class(project_id=project_id)
7290
else:
73-
exception_class = exception_or_dict
74-
return exception_class(error_message)
91+
return exception_or_dict(project_id=project_id, message=error_message)

0 commit comments

Comments
 (0)