diff --git a/refinery/__init__.py b/refinery/__init__.py index 71025fa..f2a968a 100644 --- a/refinery/__init__.py +++ b/refinery/__init__.py @@ -79,7 +79,11 @@ def get_project_details(self) -> Dict[str, str]: Dict[str, str]: dictionary containing the above information """ url = settings.get_project_url(self.project_id) - api_response = api_calls.get_request(url, self.session_token) + api_response = api_calls.get_request( + url, + self.session_token, + self.project_id, + ) return api_response def get_primary_keys(self) -> List[str]: @@ -107,7 +111,11 @@ def get_lookup_list(self, list_id: str) -> Dict[str, str]: Dict[str, str]: Containing the specified lookup list of your project. """ url = settings.get_lookup_list_url(self.project_id, list_id) - api_response = api_calls.get_request(url, self.session_token) + api_response = api_calls.get_request( + url, + self.session_token, + self.project_id, + ) return api_response def get_lookup_lists(self) -> List[Dict[str, str]]: @@ -140,7 +148,7 @@ def get_record_export( """ url = settings.get_export_url(self.project_id) api_response = api_calls.get_request( - url, self.session_token, **{"num_samples": num_samples} + url, self.session_token, self.project_id, **{"num_samples": num_samples} ) df = pd.DataFrame(api_response) @@ -212,6 +220,7 @@ def post_associations( "source_type": source_type, }, self.session_token, + self.project_id, ) return api_response @@ -234,6 +243,7 @@ def post_records(self, records: List[Dict[str, Any]]): "is_last": False, }, self.session_token, + self.project_id, ) batch_responses.append(api_response) time.sleep(0.5) # wait half a second to avoid server overload @@ -241,6 +251,7 @@ def post_records(self, records: List[Dict[str, Any]]): url, {"request_uuid": request_uuid, "records": [], "is_last": True}, self.session_token, + self.project_id, ) return batch_responses @@ -281,6 +292,7 @@ def post_file_import( config_api_response = api_calls.get_request( config_url, self.session_token, + self.project_id, ) endpoint = config_api_response.get("KERN_S3_ENDPOINT") @@ -294,6 +306,7 @@ def post_file_import( "import_file_options": import_file_options, }, self.session_token, + self.project_id, ) credentials = credentials_api_response["Credentials"] access_key = credentials["AccessKeyId"] @@ -362,6 +375,8 @@ def __monitor_task(self, upload_task_id: str) -> None: def __get_task(self, upload_task_id: str) -> Dict[str, Any]: api_response = api_calls.get_request( - settings.get_task(self.project_id, upload_task_id), self.session_token + settings.get_task(self.project_id, upload_task_id), + self.session_token, + self.project_id, ) return api_response diff --git a/refinery/api_calls.py b/refinery/api_calls.py index 46246e7..68de015 100644 --- a/refinery/api_calls.py +++ b/refinery/api_calls.py @@ -12,16 +12,18 @@ version = "noversion" -def post_request(url: str, body: Dict[str, Any], session_token: str) -> str: +def post_request( + url: str, body: Dict[str, Any], session_token: str, project_id: str +) -> str: headers = _build_headers(session_token) response = requests.post(url=url, json=body, headers=headers) - return _handle_response(response) + return _handle_response(response, project_id) -def get_request(url: str, session_token: str, **query_params) -> str: +def get_request(url: str, session_token: str, project_id: str, **query_params) -> str: headers = _build_headers(session_token) response = requests.get(url=url, headers=headers, params=query_params) - return _handle_response(response) + return _handle_response(response, project_id) def _build_headers(session_token: str) -> Dict[str, str]: @@ -33,7 +35,7 @@ def _build_headers(session_token: str) -> Dict[str, str]: } -def _handle_response(response: requests.Response) -> str: +def _handle_response(response: requests.Response, project_id: str) -> str: status_code = response.status_code if status_code == 200: json_data = response.json() @@ -53,5 +55,6 @@ def _handle_response(response: requests.Response) -> str: status_code=status_code, error_code=error_code, error_message=error_message, + project_id=project_id, ) raise exception diff --git a/refinery/exceptions.py b/refinery/exceptions.py index fbbb77a..4cf64c6 100644 --- a/refinery/exceptions.py +++ b/refinery/exceptions.py @@ -16,7 +16,8 @@ class PrimaryKeyError(LocalError): # https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses class APIError(Exception): - def __init__(self, message: Optional[str] = None): + def __init__(self, project_id: str, message: Optional[str] = None): + self.project_id = project_id if message is None: message = "Please check the SDK documentation at https://github.com/code-kern-ai/refinery-python" super().__init__(message) @@ -24,7 +25,18 @@ def __init__(self, message: Optional[str] = None): # 401 Unauthorized class UnauthorizedError(APIError): - pass + def __init__(self, project_id: str, message: Optional[str] = None): + super().__init__(project_id, message) + + +# 403 Forbidden +class ForbiddenError(APIError): + def __init__( + self, + project_id: str, + message="You can't access the project with your current role.", + ): + super().__init__(project_id, message) # 404 Not Found @@ -33,15 +45,17 @@ class NotFoundError(APIError): class UnknownProjectError(APIError): - def __init__(self, project_id: str): + def __init__(self, project_id: str, message: Optional[str] = None): super().__init__( - message=f"Could not find project '{project_id}'. Please check your input." + project_id, + f"Could not find project '{project_id}'. Please check your input.", ) # 500 Server Error class InternalServerError(APIError): - pass + def __init__(self, project_id: str, message: Optional[str] = None): + super().__init__(project_id, message) class FileImportError(Exception): @@ -51,11 +65,13 @@ class FileImportError(Exception): # mirror this from the rest api class ErrorCodes class ErrorCodes: UNRECOGNIZED_USER = "UNRECOGNIZED_USER" # not actively used in SDK + FORBIDDEN_USER = "FORBIDDEN_USER" PROJECT_NOT_FOUND = "PROJECT_NOT_FOUND" RESPONSE_CODES_API_EXCEPTION_MAP = { 401: {"*": UnauthorizedError}, + 403: {"*": ForbiddenError}, 404: {"*": NotFoundError, ErrorCodes.PROJECT_NOT_FOUND: UnknownProjectError}, 500: {"*": InternalServerError}, } @@ -65,10 +81,11 @@ def get_api_exception_class( status_code: int, error_code: Optional[str] = None, error_message: Optional[str] = None, + project_id: Optional[str] = None, ) -> APIError: exception_or_dict = RESPONSE_CODES_API_EXCEPTION_MAP.get(status_code, APIError) if isinstance(exception_or_dict, dict): exception_class = exception_or_dict.get(error_code, exception_or_dict["*"]) + return exception_class(project_id=project_id) else: - exception_class = exception_or_dict - return exception_class(error_message) + return exception_or_dict(project_id=project_id, message=error_message)