Skip to content

Commit 509dbba

Browse files
committed
Add modin format; fix for faiss ivf_pq compatibility
1 parent 917cc32 commit 509dbba

File tree

3 files changed: 9 additions and 3 deletions.

envs/conda-env-sklearn.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@ dependencies:
99
- catboost
1010
- lightgbm
1111
- faiss-cpu
12-
- intel::scikit-learn-intelex
13-
- intel::daal4py
12+
- scikit-learn-intelex
13+
- modin-all
1414
# sklbench dependencies
1515
- scikit-learn
1616
- pandas

sklbench/datasets/transformer.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,12 @@ def convert_data(data, dformat: str, order: str, dtype: str):
4040
if data.ndim == 1:
4141
return pd.Series(data)
4242
return pd.DataFrame(data)
43+
elif dformat == "modin":
44+
import modin.pandas as modin_pd
45+
46+
if data.ndim == 1:
47+
return modin_pd.Series(data)
48+
return modin_pd.DataFrame(data)
4349
elif dformat == "cudf":
4450
import cudf
4551

sklbench/emulators/common/neighbors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def get_m_subvectors(self, percentile, d):
4242
"""Method to get `m_subvectors` closest to specific percentile and
4343
compatible with RAFT and FAISS"""
4444
raft_comp = np.arange(1, d // 16) * 16
45-
faiss_comp = np.arange(1, d)
45+
faiss_comp = np.array([1, 2, 3, 4, 8, 12, 16, 20, 24, 28, 32, 40, 48])
4646
faiss_comp = faiss_comp[d % faiss_comp == 0]
4747
intersection = np.intersect1d(raft_comp, faiss_comp)
4848
if len(intersection) == 0:

0 commit comments

Comments
 (0)