Commit f2019aa

enables upload of json (#23)
* enables upload of json
* adds upload of dfs
* adds batching to import uploads
* handle json upload as file upload
1 parent: 7f3d0a6

3 files changed (+64 −3 lines)


refinery/__init__.py

Lines changed: 40 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 
+from uuid import uuid4
 from black import Any
 from wasabi import msg
 import pandas as pd
@@ -214,6 +215,44 @@ def post_associations(
         )
         return api_response
 
+    def post_records(self, records: List[Dict[str, Any]]):
+        """Posts records to the server.
+
+        Args:
+            records (List[Dict[str, Any]]): List of records to post.
+        """
+        request_uuid = str(uuid4())
+        url = settings.get_import_json_url(self.project_id)
+
+        batch_responses = []
+        for records_batch in util.batch(records, settings.BATCH_SIZE_DEFAULT):
+            api_response = api_calls.post_request(
+                url,
+                {
+                    "request_uuid": request_uuid,
+                    "records": records_batch,
+                    "is_last": False,
+                },
+                self.session_token,
+            )
+            batch_responses.append(api_response)
+            time.sleep(0.5)  # wait half a second to avoid server overload
+        api_calls.post_request(
+            url,
+            {"request_uuid": request_uuid, "records": [], "is_last": True},
+            self.session_token,
+        )
+        return batch_responses
+
+    def post_df(self, df: pd.DataFrame):
+        """Posts a DataFrame to the server.
+
+        Args:
+            df (pd.DataFrame): DataFrame to post.
+        """
+        records = df.to_dict(orient="records")
+        return self.post_records(records)
+
     def post_file_import(
         self, path: str, import_file_options: Optional[str] = ""
     ) -> bool:
@@ -246,7 +285,7 @@ def post_file_import(
         endpoint = config_api_response.get("KERN_S3_ENDPOINT")
 
         # credentials
-        credentials_url = settings.get_import_url(self.project_id)
+        credentials_url = settings.get_import_file_url(self.project_id)
         credentials_api_response = api_calls.post_request(
             credentials_url,
             {
```
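
Taken together, the new methods give the client a JSON upload path alongside the existing file import. A minimal usage sketch, not part of the commit — the `Client` constructor arguments below are placeholders and the real signature may differ:

```python
import pandas as pd
from refinery import Client

# Placeholder credentials and project id - adjust to your setup.
client = Client("user@example.com", "secret-password", "my-project-id")

# post_records sends the data in chunks of settings.BATCH_SIZE_DEFAULT
# (1000 records), all tagged with one shared request_uuid, then closes
# the import with an empty is_last=True request.
records = [{"headline": "first row"}, {"headline": "second row"}]
batch_responses = client.post_records(records)

# post_df is a thin wrapper: DataFrame -> list of dicts -> post_records.
df = pd.DataFrame({"headline": ["first row", "second row"]})
batch_responses = client.post_df(df)
```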

refinery/settings.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -2,6 +2,8 @@
 BASE_URI: str
 DEFAULT_URI: str = "https://app.kern.ai"
 
+BATCH_SIZE_DEFAULT: int = 1000
+
 
 def set_base_uri(uri: str):
     global BASE_URI
@@ -40,8 +42,13 @@ def get_export_url(project_id: str) -> str:
     return f"{get_project_url(project_id)}/export"
 
 
-def get_import_url(project_id: str) -> str:
-    return f"{get_project_url(project_id)}/import"
+def get_import_file_url(project_id: str) -> str:
+    return f"{get_project_url(project_id)}/import_file"
+
+
+def get_import_json_url(project_id: str) -> str:
+    return f"{get_project_url(project_id)}/import_json"
+
 
 def get_associations_url(project_id: str) -> str:
     return f"{get_project_url(project_id)}/associations"
```
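
The commit splits the old generic `/import` route into two explicit endpoints. A quick sketch of what the helpers now return, assuming `get_project_url` (not shown in this diff) prefixes the base URI and a project path:

```python
from refinery import settings

settings.set_base_uri("https://app.kern.ai")

project_id = "my-project-id"  # placeholder

# Assuming get_project_url() yields "<BASE_URI>/project/<project_id>",
# the two import helpers resolve to distinct endpoints:
print(settings.get_import_file_url(project_id))
# e.g. https://app.kern.ai/project/my-project-id/import_file
print(settings.get_import_json_url(project_id))
# e.g. https://app.kern.ai/project/my-project-id/import_json
```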

refinery/util.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 import boto3
 from botocore.client import Config
+from typing import List, Dict, Any
 
 
 def s3_upload(
@@ -29,3 +30,17 @@ def s3_upload(
     with open(file_path, "rb") as file:
         s3_object.put(Body=file)
     return True
+
+
+def batch(records: List[Dict[str, Any]], batch_size: int):
+    """Batches records into batches of size `batch_size`.
+
+    Args:
+        records (List[Dict[str, Any]]): List of records to batch.
+        batch_size (int): Size of the batches.
+
+    Yields:
+        List[Dict[str, Any]]: Batches of records.
+    """
+    for i in range(0, len(records), batch_size):
+        yield records[i : i + batch_size]
```
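
`batch` is a plain slicing generator: it never pads or drops records, and the last chunk simply carries the remainder. A small demonstration with made-up records:

```python
from refinery import util

records = [{"id": i} for i in range(5)]

# Yields successive slices of size batch_size; the final slice
# holds whatever remains.
for chunk in util.batch(records, batch_size=2):
    print(chunk)
# [{'id': 0}, {'id': 1}]
# [{'id': 2}, {'id': 3}]
# [{'id': 4}]
```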
