Skip to content

Commit 22cadbf

Browse files
committed
minor adaptions
1 parent a273a85 commit 22cadbf

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

refinery/callbacks/__init__.py

Whitespace-only changes.

refinery/callbacks/inference.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pandas as pd
2+
from refinery import exceptions
3+
4+
5+
class ModelCallback:
6+
def __init__(
7+
self, client, inference_fn, preprocessing_fn=None, postprocessing_fn=None
8+
):
9+
self.client = client
10+
self.inference_fn = inference_fn
11+
self.preprocessing_fn = preprocessing_fn
12+
self.postprocessing_fn = postprocessing_fn
13+
14+
self.primary_keys = client.get_primary_keys()
15+
16+
@staticmethod
17+
def __batch(documents):
18+
BATCH_SIZE = 32
19+
length = len(documents)
20+
for idx in range(0, length, BATCH_SIZE):
21+
yield documents[idx : min(idx + BATCH_SIZE, length)]
22+
23+
def run(self, inputs, indices):
24+
indices_df = pd.DataFrame(indices)
25+
if not all([key in indices_df.columns for key in self.primary_keys]):
26+
raise exceptions.PrimaryKeyError("Errorneous primary keys given for index.")
27+
28+
index_generator = ModelCallback.__batch(indices)
29+
for batched_inputs in ModelCallback.__batch(inputs):
30+
batched_indices = next(index_generator)
31+
32+
if self.preprocessing_fn is not None:
33+
batched_inputs = self.preprocessing_fn(batched_inputs)
34+
35+
batched_outputs = self.inference_fn(batched_inputs)
36+
37+
if self.postprocessing_fn is not None:
38+
batched_outputs = self.postprocessing_fn(batched_outputs)
39+
40+
yield {"index": batched_indices, "associations": batched_outputs}

refinery/exceptions.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# -*- coding: utf-8 -*-
22
from typing import Optional
33

4+
45
class LocalError(Exception):
56
pass
67

8+
79
class UnknownItemError(LocalError):
810
pass
911

12+
13+
class PrimaryKeyError(LocalError):
14+
pass
15+
16+
1017
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses
1118
class APIError(Exception):
1219
def __init__(self, message: Optional[str] = None):
@@ -24,9 +31,12 @@ class UnauthorizedError(APIError):
2431
class NotFoundError(APIError):
2532
pass
2633

34+
2735
class UnknownProjectError(APIError):
2836
def __init__(self, project_id: str):
29-
super().__init__(message=f"Could not find project '{project_id}'. Please check your input.")
37+
super().__init__(
38+
message=f"Could not find project '{project_id}'. Please check your input."
39+
)
3040

3141

3242
# 500 Server Error
@@ -37,9 +47,10 @@ class InternalServerError(APIError):
3747
class FileImportError(Exception):
3848
pass
3949

50+
4051
# mirror this from the rest api class ErrorCodes
4152
class ErrorCodes:
42-
UNRECOGNIZED_USER = "UNRECOGNIZED_USER" # not actively used in SDK
53+
UNRECOGNIZED_USER = "UNRECOGNIZED_USER" # not actively used in SDK
4354
PROJECT_NOT_FOUND = "PROJECT_NOT_FOUND"
4455

4556

0 commit comments

Comments
 (0)