Skip to content

adds role mgmt for the sdk #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions refinery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def get_project_details(self) -> Dict[str, str]:
Dict[str, str]: dictionary containing the above information
"""
url = settings.get_project_url(self.project_id)
api_response = api_calls.get_request(url, self.session_token)
api_response = api_calls.get_request(
url,
self.session_token,
self.project_id,
)
return api_response

def get_primary_keys(self) -> List[str]:
Expand Down Expand Up @@ -107,7 +111,11 @@ def get_lookup_list(self, list_id: str) -> Dict[str, str]:
Dict[str, str]: Containing the specified lookup list of your project.
"""
url = settings.get_lookup_list_url(self.project_id, list_id)
api_response = api_calls.get_request(url, self.session_token)
api_response = api_calls.get_request(
url,
self.session_token,
self.project_id,
)
return api_response

def get_lookup_lists(self) -> List[Dict[str, str]]:
Expand Down Expand Up @@ -140,7 +148,7 @@ def get_record_export(
"""
url = settings.get_export_url(self.project_id)
api_response = api_calls.get_request(
url, self.session_token, **{"num_samples": num_samples}
url, self.session_token, self.project_id, **{"num_samples": num_samples}
)
df = pd.DataFrame(api_response)

Expand Down Expand Up @@ -212,6 +220,7 @@ def post_associations(
"source_type": source_type,
},
self.session_token,
self.project_id,
)
return api_response

Expand All @@ -234,13 +243,15 @@ def post_records(self, records: List[Dict[str, Any]]):
"is_last": False,
},
self.session_token,
self.project_id,
)
batch_responses.append(api_response)
time.sleep(0.5) # wait half a second to avoid server overload
api_calls.post_request(
url,
{"request_uuid": request_uuid, "records": [], "is_last": True},
self.session_token,
self.project_id,
)
return batch_responses

Expand Down Expand Up @@ -281,6 +292,7 @@ def post_file_import(
config_api_response = api_calls.get_request(
config_url,
self.session_token,
self.project_id,
)
endpoint = config_api_response.get("KERN_S3_ENDPOINT")

Expand All @@ -294,6 +306,7 @@ def post_file_import(
"import_file_options": import_file_options,
},
self.session_token,
self.project_id,
)
credentials = credentials_api_response["Credentials"]
access_key = credentials["AccessKeyId"]
Expand Down Expand Up @@ -362,6 +375,8 @@ def __monitor_task(self, upload_task_id: str) -> None:

def __get_task(self, upload_task_id: str) -> Dict[str, Any]:
api_response = api_calls.get_request(
settings.get_task(self.project_id, upload_task_id), self.session_token
settings.get_task(self.project_id, upload_task_id),
self.session_token,
self.project_id,
)
return api_response
13 changes: 8 additions & 5 deletions refinery/api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
version = "noversion"


def post_request(url: str, body: Dict[str, Any], session_token: str) -> str:
def post_request(
url: str, body: Dict[str, Any], session_token: str, project_id: str
) -> str:
headers = _build_headers(session_token)
response = requests.post(url=url, json=body, headers=headers)
return _handle_response(response)
return _handle_response(response, project_id)


def get_request(url: str, session_token: str, **query_params) -> str:
def get_request(url: str, session_token: str, project_id: str, **query_params) -> str:
headers = _build_headers(session_token)
response = requests.get(url=url, headers=headers, params=query_params)
return _handle_response(response)
return _handle_response(response, project_id)


def _build_headers(session_token: str) -> Dict[str, str]:
Expand All @@ -33,7 +35,7 @@ def _build_headers(session_token: str) -> Dict[str, str]:
}


def _handle_response(response: requests.Response) -> str:
def _handle_response(response: requests.Response, project_id: str) -> str:
status_code = response.status_code
if status_code == 200:
json_data = response.json()
Expand All @@ -53,5 +55,6 @@ def _handle_response(response: requests.Response) -> str:
status_code=status_code,
error_code=error_code,
error_message=error_message,
project_id=project_id,
)
raise exception
31 changes: 24 additions & 7 deletions refinery/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,27 @@ class PrimaryKeyError(LocalError):

# https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses
class APIError(Exception):
def __init__(self, message: Optional[str] = None):
def __init__(self, project_id: str, message: Optional[str] = None):
self.project_id = project_id
if message is None:
message = "Please check the SDK documentation at https://github.com/code-kern-ai/refinery-python"
super().__init__(message)


# 401 Unauthorized
class UnauthorizedError(APIError):
pass
def __init__(self, project_id: str, message: Optional[str] = None):
super().__init__(project_id, message)


# 403 Forbidden
class ForbiddenError(APIError):
def __init__(
self,
project_id: str,
message="You can't access the project with your current role.",
):
super().__init__(project_id, message)


# 404 Not Found
Expand All @@ -33,15 +45,17 @@ class NotFoundError(APIError):


class UnknownProjectError(APIError):
def __init__(self, project_id: str):
def __init__(self, project_id: str, message: Optional[str] = None):
super().__init__(
message=f"Could not find project '{project_id}'. Please check your input."
project_id,
f"Could not find project '{project_id}'. Please check your input.",
)


# 500 Server Error
class InternalServerError(APIError):
pass
def __init__(self, project_id: str, message: Optional[str] = None):
super().__init__(project_id, message)


class FileImportError(Exception):
Expand All @@ -51,11 +65,13 @@ class FileImportError(Exception):
# mirror this from the rest api class ErrorCodes
class ErrorCodes:
UNRECOGNIZED_USER = "UNRECOGNIZED_USER" # not actively used in SDK
FORBIDDEN_USER = "FORBIDDEN_USER"
PROJECT_NOT_FOUND = "PROJECT_NOT_FOUND"


RESPONSE_CODES_API_EXCEPTION_MAP = {
401: {"*": UnauthorizedError},
403: {"*": ForbiddenError},
404: {"*": NotFoundError, ErrorCodes.PROJECT_NOT_FOUND: UnknownProjectError},
500: {"*": InternalServerError},
}
Expand All @@ -65,10 +81,11 @@ def get_api_exception_class(
status_code: int,
error_code: Optional[str] = None,
error_message: Optional[str] = None,
project_id: Optional[str] = None,
) -> APIError:
exception_or_dict = RESPONSE_CODES_API_EXCEPTION_MAP.get(status_code, APIError)
if isinstance(exception_or_dict, dict):
exception_class = exception_or_dict.get(error_code, exception_or_dict["*"])
return exception_class(project_id=project_id)
else:
exception_class = exception_or_dict
return exception_class(error_message)
return exception_or_dict(project_id=project_id, message=error_message)