Skip to content

Commit 22ce12e

Browse files
committed
Add SVS NearestNeighbors emulator
1 parent 37b21d3 commit 22ce12e

File tree

5 files changed

+103
-6
lines changed

5 files changed

+103
-6
lines changed

configs/experiments/nearest_neighbors.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@
3535
"m_subvectors": 0.2
3636
}
3737
}
38+
},
39+
{
40+
"algorithm": {
41+
"library": "sklbench.emulators.svs",
42+
"device": "cpu",
43+
"estimator_params": {
44+
"algorithm": "vamana",
45+
"graph_max_degree": 128,
46+
"window_size": 256
47+
}
48+
}
3849
}
3950
],
4051
"nearest neighbors common parameters": {

sklbench/emulators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17-
__all__ = ["common", "faiss", "raft"]
17+
__all__ = ["common", "faiss", "raft", "svs"]

sklbench/emulators/common/neighbors.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ def get_params(self):
2727
"metric": self.metric,
2828
"metric_params": None,
2929
"p": 2 if "euclidean" in self.metric else None,
30-
"n_lists": self.n_lists,
31-
"n_probes": self.n_probes,
32-
"m_subvectors": self.m_subvectors,
33-
"n_bits": self.n_bits,
3430
}
35-
optional_keys = ["intermediate_graph_degree", "graph_degree"]
31+
optional_keys = [
32+
"n_lists",
33+
"n_probes",
34+
"m_subvectors",
35+
"n_bits",
36+
"intermediate_graph_degree",
37+
"graph_degree",
38+
]
3639
for optional_key in optional_keys:
3740
if hasattr(self, optional_key):
3841
result[optional_key] = getattr(self, optional_key)

sklbench/emulators/svs/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# ===============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ===============================================================================
16+
17+
from .neighbors import NearestNeighbors
18+
19+
20+
__all__ = ["NearestNeighbors"]

sklbench/emulators/svs/neighbors.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# ===============================================================================
2+
# Copyright 2024 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ===============================================================================
16+
17+
import pysvs
18+
from psutil import cpu_count
19+
from ..common.neighbors import NearestNeighborsBase
20+
21+
22+
class NearestNeighbors(NearestNeighborsBase):
    """
    Minimal class emulating `sklearn.neighbors.NearestNeighbors` estimator
    on top of a `pysvs` Vamana graph index.

    Parameters
    ----------
    n_neighbors : int
        Default number of neighbors returned by `kneighbors`.
    algorithm : str
        Stored for reporting only; `fit` always builds a Vamana index.
    metric : str
        Distance metric name. Only Euclidean metrics are accepted by `fit`,
        because the index is always built with `pysvs.DistanceType.L2`.
    graph_max_degree : int
        Forwarded to `pysvs.VamanaBuildParameters`.
    window_size : int
        Forwarded to `pysvs.VamanaBuildParameters`.
    n_jobs : int or None
        Number of threads used to build the index. The default is the
        physical core count reported by `psutil`, which may be `None` when
        it cannot be determined; `fit` falls back to the logical core count
        (or 1) in that case.
    """

    def __init__(
        self,
        n_neighbors=5,
        algorithm="vamana",
        metric="euclidean",
        graph_max_degree=64,
        window_size=128,
        n_jobs=cpu_count(logical=False),
    ):
        self.n_neighbors = n_neighbors
        self.algorithm = algorithm
        self.metric = metric
        self.graph_max_degree = graph_max_degree
        self.window_size = window_size
        self.n_jobs = n_jobs

    def _resolve_num_threads(self):
        """Return a concrete thread count, even when `n_jobs` is None."""
        if self.n_jobs is not None:
            return self.n_jobs
        # psutil could not determine the physical core count (or the caller
        # passed None explicitly): use logical cores, and at least one thread.
        return cpu_count(logical=True) or 1

    def fit(self, X, y=None):
        """
        Build the Vamana index over `X`.

        `y` is accepted for scikit-learn API compatibility and ignored.

        Raises
        ------
        ValueError
            If `metric` is not Euclidean: the index is built with L2
            distance, so any other metric would be silently mis-measured.
        """
        # Same "euclidean" substring check that the common get_params() uses.
        if "euclidean" not in self.metric:
            raise ValueError(
                f"Unsupported metric '{self.metric}': "
                "SVS emulator only supports euclidean (L2) distance"
            )
        num_threads = self._resolve_num_threads()
        build_params = pysvs.VamanaBuildParameters(
            graph_max_degree=self.graph_max_degree,
            window_size=self.window_size,
            num_threads=num_threads,
        )
        self._index = pysvs.Vamana.build(
            build_params,
            X,
            pysvs.DistanceType.L2,
            num_threads=num_threads,
        )
        return self

    def kneighbors(self, X, n_neighbors=None, return_distance=True):
        """
        Search the fitted index for the nearest neighbors of each row of `X`.

        Uses `n_neighbors` when given, otherwise the value set at
        construction. Returns `(distances, indices)` when `return_distance`
        is true, else only `indices`. Must be called after `fit`.
        """
        k = self.n_neighbors if n_neighbors is None else n_neighbors
        # The index search yields (indices, distances); the order is swapped
        # below to match the scikit-learn return convention.
        indices, distances = self._index.search(X, k)
        if return_distance:
            return distances, indices
        else:
            return indices

0 commit comments

Comments
 (0)