Skip to content

Commit 37b21d3

Browse files
committed
Add modin support; fixes for ANN emulators
1 parent 509dbba commit 37b21d3

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

sklbench/datasets/transformer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17+
import os
1718
import numpy as np
1819
import pandas as pd
1920
from scipy.sparse import csr_matrix
@@ -40,7 +41,19 @@ def convert_data(data, dformat: str, order: str, dtype: str):
4041
if data.ndim == 1:
4142
return pd.Series(data)
4243
return pd.DataFrame(data)
43-
elif dformat == "modin":
44+
elif dformat.startswith("modin"):
45+
if dformat.endswith("ray"):
46+
os.environ["MODIN_ENGINE"] = "ray"
47+
elif dformat.endswith("dask"):
48+
os.environ["MODIN_ENGINE"] = "dask"
49+
elif dformat.endswith("unidist"):
50+
os.environ["MODIN_ENGINE"] = "unidist"
51+
os.environ["UNIDIST_BACKEND"] = "mpi"
52+
else:
53+
logger.info(
54+
"Modin engine is unknown or not specified. Default engine will be used."
55+
)
56+
4457
import modin.pandas as modin_pd
4558

4659
if data.ndim == 1:

sklbench/emulators/faiss/neighbors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(
4848

4949
def fit(self, X, y=None):
5050
d = X.shape[1]
51+
if isinstance(self.m_subvectors, float):
52+
self.m_subvectors = self.get_m_subvectors(self.m_subvectors, d)
5153
self._base_index = faiss.IndexFlatL2(d)
5254
if self.algorithm == "brute":
5355
self._index = self._base_index
@@ -56,8 +58,6 @@ def fit(self, X, y=None):
5658
self._base_index, d, self.n_lists, faiss.METRIC_L2
5759
)
5860
elif self.algorithm == "ivf_pq":
59-
if isinstance(self.m_subvectors, float):
60-
self.m_subvectors = self.get_m_subvectors(self.m_subvectors, d)
6161
self._index = faiss.IndexIVFPQ(
6262
self._base_index,
6363
d,

sklbench/emulators/raft/neighbors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def __init__(
5050

5151
def fit(self, X, y=None):
5252
d = X.shape[1]
53+
if isinstance(self.m_subvectors, float):
54+
self.m_subvectors = self.get_m_subvectors(self.m_subvectors, d)
5355
if self.algorithm == "brute":
5456
self._X_fit = X
5557
elif self.algorithm == "ivf_flat":
5658
index_params = ivf_flat.IndexParams(n_lists=self.n_lists, metric=self.metric)
5759
self._index = ivf_flat.build(index_params, X, handle=self._handle)
5860
elif self.algorithm == "ivf_pq":
59-
if isinstance(self.m_subvectors, float):
60-
self.m_subvectors = self.get_m_subvectors(self.m_subvectors, d)
6161
index_params = ivf_pq.IndexParams(
6262
n_lists=self.n_lists,
6363
metric=self.metric,

sklbench/report/implementation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,11 @@ def apply_rules_for_sheet(sheet, time_color_scale, metric_color_scale):
288288
]
289289
)
290290
is_time = any(
291-
[isinstance(cell.value, str) and "time[ms]" in cell.value for cell in column]
291+
[
292+
isinstance(cell.value, str)
293+
and ("time[ms]" in cell.value or "throughput[samples/ms]" in cell.value)
294+
for cell in column
295+
]
292296
)
293297
if is_rel_impr:
294298
cell_range = f"${column_idx}1:${column_idx}{len(column)}"

0 commit comments

Comments
 (0)