Skip to content

Commit 9a3f2fe

Browse files
jhoetterJWittmeyer
andauthored
adds user role mgmt to proxy (#9)
Co-authored-by: JWittmeyer <91723236+JWittmeyer@users.noreply.github.com>
1 parent bb167dd commit 9a3f2fe

File tree

1 file changed

+25
-15
lines changed

1 file changed

+25
-15
lines changed

app.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
class ErrorCodes:
1414
UNRECOGNIZED_USER = "UNRECOGNIZED_USER"
15+
FORBIDDEN_USER = "FORBIDDEN_USER"
1516
PROJECT_NOT_FOUND = "PROJECT_NOT_FOUND"
1617

1718

@@ -24,6 +25,23 @@ def get_user_id_from_request(request):
2425
return user_id
2526

2627

28+
def handle_response(response):
29+
if response.status_code == status.HTTP_200_OK:
30+
return responses.JSONResponse(
31+
status_code=status.HTTP_200_OK, content=response.json()
32+
)
33+
elif response.status_code == status.HTTP_404_NOT_FOUND:
34+
return responses.JSONResponse(
35+
status_code=status.HTTP_404_NOT_FOUND,
36+
content={"error_code": ErrorCodes.PROJECT_NOT_FOUND},
37+
)
38+
elif response.status_code == status.HTTP_403_FORBIDDEN:
39+
return responses.JSONResponse(
40+
status_code=status.HTTP_403_FORBIDDEN,
41+
content={"error_code": ErrorCodes.FORBIDDEN_USER},
42+
)
43+
44+
2745
@app.get("/project/{project_id}/export")
2846
def get_export(request: Request, project_id: str, num_samples: Optional[int] = None):
2947
try:
@@ -35,7 +53,7 @@ def get_export(request: Request, project_id: str, num_samples: Optional[int] = N
3553
)
3654
url = f"{BASE_URI}/project/{project_id}/export"
3755
resp = requests.get(url, params={"user_id": user_id, "num_samples": num_samples})
38-
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=resp.json())
56+
return handle_response(resp)
3957

4058

4159
@app.get("/project/{project_id}/lookup_list/{lookup_list_id}")
@@ -49,7 +67,7 @@ def get_lookup_list(request: Request, project_id: str, lookup_list_id: str):
4967
)
5068
url = f"{BASE_URI}/project/{project_id}/knowledge_base/{lookup_list_id}"
5169
resp = requests.get(url, params={"user_id": user_id})
52-
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=resp.json())
70+
return handle_response(resp)
5371

5472

5573
@app.post("/project/{project_id}/import_file")
@@ -72,7 +90,7 @@ async def post_import_file(request: Request, project_id: str):
7290
"file_import_options": request_body.get("file_import_options"),
7391
},
7492
)
75-
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=resp.json())
93+
return handle_response(resp)
7694

7795

7896
@app.post("/project/{project_id}/import_json")
@@ -118,7 +136,7 @@ async def post_associations(request: Request, project_id: str):
118136
"source_type": request_body["source_type"],
119137
},
120138
)
121-
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=resp.json())
139+
return handle_response(resp)
122140

123141

124142
@app.get("/project/{project_id}")
@@ -132,15 +150,7 @@ def get_details(request: Request, project_id: str):
132150
)
133151
url = f"{BASE_URI}/project/{project_id}"
134152
resp = requests.get(url, params={"user_id": user_id})
135-
if resp.status_code == 200:
136-
return responses.JSONResponse(
137-
status_code=status.HTTP_200_OK, content=resp.json()
138-
)
139-
else:
140-
return responses.JSONResponse(
141-
status_code=status.HTTP_404_NOT_FOUND,
142-
content={"error_code": ErrorCodes.PROJECT_NOT_FOUND},
143-
)
153+
return handle_response(resp)
144154

145155

146156
@app.get("/project/{project_id}/import/base_config")
@@ -154,7 +164,7 @@ def get_base_config(request: Request, project_id: str):
154164
)
155165
url = f"{CONFIG_URI}/base_config"
156166
resp = requests.get(url)
157-
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=resp.json())
167+
return handle_response(resp)
158168

159169

160170
@app.get("/project/{project_id}/import/task/{task_id}")
@@ -168,4 +178,4 @@ def get_details(request: Request, project_id: str, task_id: str):
168178
)
169179
url = f"{BASE_URI}/project/{project_id}/import/task/{task_id}"
170180
resp = requests.get(url, params={"user_id": user_id})
171-
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=resp.json())
181+
return handle_response(resp)

0 commit comments

Comments
 (0)