Skip to content

Commit 58a0a76

Browse files
committed
adding the filter for nodes and relationships
unit tests support for includes regardless of path difference
1 parent c6248f6 commit 58a0a76

File tree

7 files changed

+163
-50
lines changed

7 files changed

+163
-50
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "redisgraph-bulk-loader"
3-
version = "0.11.0"
3+
version = "1.0.0"
44
description = "RedisGraph Bulk Import Tool"
55
authors = ["Redis Inc <oss@redis.com>"]
66
license = "BSD-3-Clause"

redisgraph_bulk_loader/bulk_insert.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,27 @@
44
import click
55
import redis
66

7-
from .config import Config
8-
from .label import Label
9-
from .query_buffer import QueryBuffer
10-
from .relation_type import RelationType
7+
try:
8+
from .config import Config
9+
from .label import Label
10+
from .query_buffer import QueryBuffer
11+
from .relation_type import RelationType
12+
except:
13+
from config import Config
14+
from label import Label
15+
from query_buffer import QueryBuffer
16+
from relation_type import RelationType
1117

12-
13-
def parse_schemas(cls, query_buf, path_to_csv, csv_tuples, config):
18+
def parse_schemas(cls, query_buf, path_to_csv, csv_tuples, config, label_column):
1419
schemas = [None] * (len(path_to_csv) + len(csv_tuples))
1520
for idx, in_csv in enumerate(path_to_csv):
1621
# Build entity descriptor from input CSV
17-
schemas[idx] = cls(query_buf, in_csv, None, config)
22+
schemas[idx] = cls(query_buf, in_csv, None, config, label_column)
1823

1924
offset = len(path_to_csv)
2025
for idx, csv_tuple in enumerate(csv_tuples):
2126
# Build entity descriptor from input CSV
22-
schemas[idx + offset] = cls(query_buf, csv_tuple[1], csv_tuple[0], config)
27+
schemas[idx + offset] = cls(query_buf, csv_tuple[1], csv_tuple[0], config, label_column)
2328
return schemas
2429

2530

@@ -54,6 +59,7 @@ def process_entities(entities):
5459
"--redis-url", "-u", default="redis://127.0.0.1:6379", help="Redis connection url"
5560
)
5661
@click.option("--nodes", "-n", multiple=True, help="Path to node csv file")
62+
@click.option("--node-label-column", "-L", default=None, nargs=2, help="Import based on <column> having <value>")
5763
@click.option(
5864
"--nodes-with-label",
5965
"-N",
@@ -62,6 +68,7 @@ def process_entities(entities):
6268
help="Label string followed by path to node csv file",
6369
)
6470
@click.option("--relations", "-r", multiple=True, help="Path to relation csv file")
71+
@click.option("--relation-type-column", "-T", default=None, nargs=2, help="Import based on <column> having <value>")
6572
@click.option(
6673
"--relations-with-type",
6774
"-R",
@@ -144,8 +151,10 @@ def bulk_insert(
144151
graph,
145152
redis_url,
146153
nodes,
154+
node_label_column,
147155
nodes_with_label,
148156
relations,
157+
relation_type_column,
149158
relations_with_type,
150159
separator,
151160
enforce_schema,
@@ -160,9 +169,7 @@ def bulk_insert(
160169
index,
161170
full_text_index,
162171
):
163-
if sys.version_info.major < 3 or sys.version_info.minor < 6:
164-
raise Exception("Python >= 3.6 is required for the RedisGraph bulk loader.")
165-
172+
166173
if not (any(nodes) or any(nodes_with_label)):
167174
raise Exception("At least one node file must be specified.")
168175

@@ -216,9 +223,9 @@ def bulk_insert(
216223
query_buf = QueryBuffer(graph, client, config)
217224

218225
# Read the header rows of each input CSV and save its schema.
219-
labels = parse_schemas(Label, query_buf, nodes, nodes_with_label, config)
226+
labels = parse_schemas(Label, query_buf, nodes, nodes_with_label, config, node_label_column)
220227
reltypes = parse_schemas(
221-
RelationType, query_buf, relations, relations_with_type, config
228+
RelationType, query_buf, relations, relations_with_type, config, relation_type_column,
222229
)
223230

224231
process_entities(labels)

redisgraph_bulk_loader/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from .exceptions import SchemaError
1+
try:
2+
from .exceptions import SchemaError
3+
except:
4+
from exceptions import SchemaError
25

36

47
class Config:

redisgraph_bulk_loader/entity_file.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
import sys
88
from enum import Enum
99

10-
from .exceptions import CSVError, SchemaError
10+
try:
11+
from .exceptions import CSVError, SchemaError
12+
except:
13+
from exceptions import CSVError, SchemaError
1114

1215
csv.field_size_limit(sys.maxsize) # Don't limit the size of user input fields.
1316

@@ -179,7 +182,7 @@ def inferred_prop_to_binary(prop_val):
179182
class EntityFile(object):
180183
"""Superclass for Label and RelationType classes"""
181184

182-
def __init__(self, filename, label, config):
185+
def __init__(self, filename, label, config, filter_column=None):
183186
# The configurations for this run.
184187
self.config = config
185188

@@ -204,10 +207,30 @@ def __init__(self, filename, label, config):
204207
self.packed_header = b""
205208
self.binary_entities = []
206209
self.binary_size = 0 # size of binary token
207-
210+
208211
self.convert_header() # Extract data from header row.
209212
self.count_entities() # Count number of entities/row in file.
213+
214+
if filter_column is None:
215+
self.__FILTER_ID__ = -1
216+
self.__FILTER_VALUE__ = None
217+
else:
218+
try:
219+
self.__FILTER_ID__ = self.column_names.index(filter_column[0])
220+
self.__FILTER_VALUE__ = filter_column[1]
221+
except ValueError: # it doesn't have to apply in the multiple file case
222+
self.__FILTER_ID__ = -1
223+
self.__FILTER_VALUE__ = None
224+
210225
next(self.reader) # Skip the header row.
226+
227+
@property
228+
def filter_value(self):
229+
return self.__FILTER_VALUE__
230+
231+
@property
232+
def filter_column_id(self):
233+
return self.__FILTER_ID__
211234

212235
# Count number of rows in file.
213236
def count_entities(self):

redisgraph_bulk_loader/label.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,21 @@
33

44
import click
55

6-
from .entity_file import EntityFile, Type
7-
from .exceptions import SchemaError
6+
try:
7+
from .entity_file import EntityFile, Type
8+
from .exceptions import SchemaError
9+
except:
10+
from entity_file import EntityFile, Type
11+
from exceptions import SchemaError
812

913

1014
class Label(EntityFile):
1115
"""Handler class for processing Label CSV files."""
1216

13-
def __init__(self, query_buffer, infile, label_str, config):
17+
def __init__(self, query_buffer, infile, label_str, config, filter_column=None):
1418
self.id_namespace = None
1519
self.query_buffer = query_buffer
16-
super(Label, self).__init__(infile, label_str, config)
20+
super(Label, self).__init__(infile, label_str, config, filter_column)
1721

1822
def process_schemaless_header(self, header):
1923
# The first column is the ID.
@@ -70,6 +74,8 @@ def process_entities(self):
7074
) as reader:
7175
for row in reader:
7276
self.validate_row(row)
77+
if self.filter_value is not None and row[self.filter_column_id] != self.filter_value:
78+
continue
7379

7480
# Update the node identifier dictionary if necessary
7581
if self.config.store_node_identifiers:
@@ -107,5 +113,5 @@ def process_entities(self):
107113
self.binary_size += row_binary_len
108114
self.binary_entities.append(row_binary)
109115
self.query_buffer.labels.append(self.to_binary())
110-
self.infile.close()
116+
self.infile.close()
111117
print("%d nodes created with label '%s'" % (entities_created, self.entity_str))

redisgraph_bulk_loader/relation_type.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33

44
import click
55

6-
from .entity_file import EntityFile, Type
7-
from .exceptions import CSVError, SchemaError
6+
try:
7+
from .entity_file import EntityFile, Type
8+
from .exceptions import CSVError, SchemaError
9+
except:
10+
from entity_file import EntityFile, Type
11+
from exceptions import CSVError, SchemaError
812

913

1014
# Handler class for processing relation csv files.
1115
class RelationType(EntityFile):
12-
def __init__(self, query_buffer, infile, type_str, config):
13-
super(RelationType, self).__init__(infile, type_str, config)
16+
def __init__(self, query_buffer, infile, type_str, config, filter_column=None):
17+
super(RelationType, self).__init__(infile, type_str, config, filter_column=None)
1418
self.query_buffer = query_buffer
1519

1620
def process_schemaless_header(self, header):
@@ -63,6 +67,9 @@ def process_entities(self):
6367
) as reader:
6468
for row in reader:
6569
self.validate_row(row)
70+
if self.filter_value is not None and row[self.filter_column_id] != self.filter_value:
71+
continue
72+
6673
try:
6774
start_id = row[self.start_id]
6875
if self.start_namespace:

test/test_bulk_loader.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ def row_count(in_csv):
2929

3030
class TestBulkLoader:
3131

32+
csv_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)),
33+
"..", "example"))
34+
person_file = os.path.join(csv_path, "Person.csv")
35+
country_file = os.path.join(csv_path ,"Country.csv")
36+
knows_file = os.path.join(csv_path, "KNOWS.csv")
37+
visited_file = os.path.join(csv_path, "VISITED.csv")
38+
3239
redis_con = redis.Redis(decode_responses=True)
3340

3441
@classmethod
@@ -62,29 +69,23 @@ def test_social_graph(self):
6269
graphname = "social"
6370
runner = CliRunner()
6471

65-
csv_path = os.path.dirname(os.path.abspath(__file__)) + "/../example/"
66-
person_file = csv_path + "Person.csv"
67-
country_file = csv_path + "Country.csv"
68-
knows_file = csv_path + "KNOWS.csv"
69-
visited_file = csv_path + "VISITED.csv"
70-
7172
# Set the globals for node edge counts, as they will be reused.
72-
person_count = str(row_count(person_file))
73-
country_count = str(row_count(country_file))
74-
knows_count = str(row_count(knows_file))
75-
visited_count = str(row_count(visited_file))
73+
person_count = str(row_count(self.person_file))
74+
country_count = str(row_count(self.country_file))
75+
knows_count = str(row_count(self.knows_file))
76+
visited_count = str(row_count(self.visited_file))
7677

7778
res = runner.invoke(
7879
bulk_insert,
7980
[
8081
"--nodes",
81-
person_file,
82+
self.person_file,
8283
"--nodes",
83-
country_file,
84+
self.country_file,
8485
"--relations",
85-
knows_file,
86+
self.knows_file,
8687
"--relations",
87-
visited_file,
88+
self.visited_file,
8889
graphname,
8990
],
9091
)
@@ -290,20 +291,86 @@ def test_reused_identifier(self):
290291
# The script should succeed and create 3 nodes
291292
assert res.exit_code == 0
292293
assert "3 nodes created" in res.output
294+
295+
def test_filtered_nodes(self):
296+
"""Create a nodeset using a filtered set"""
297+
self.redis_con.flushall()
298+
graphname = "filtered_set"
299+
runner = CliRunner()
300+
301+
res = runner.invoke(
302+
bulk_insert,
303+
[
304+
"--nodes",
305+
self.person_file,
306+
"-L",
307+
"status",
308+
"single",
309+
graphname,
310+
],
311+
catch_exceptions=False
312+
)
313+
assert res.exit_code == 0
314+
assert "4 nodes created" in res.output
315+
316+
# and now multiple files at once
317+
self.redis_con.flushall()
318+
graphname = "filtered_set"
319+
runner = CliRunner()
320+
321+
res = runner.invoke(
322+
bulk_insert,
323+
[
324+
"-n",
325+
self.person_file,
326+
"-n",
327+
self.country_file,
328+
"-L",
329+
"status",
330+
"single",
331+
graphname,
332+
],
333+
catch_exceptions=False
334+
)
335+
assert res.exit_code == 0
336+
assert "13 nodes created" in res.output
337+
338+
def test_filtered_relations(self):
339+
"""Create a filtered relation set"""
340+
self.redis_con.flushall()
341+
graphname = "filtered_set"
342+
runner = CliRunner()
343+
344+
res = runner.invoke(
345+
bulk_insert,
346+
[
347+
"--nodes",
348+
self.person_file,
349+
"--nodes",
350+
self.country_file,
351+
"--relations",
352+
self.knows_file,
353+
"--relations",
354+
self.visited_file,
355+
"-T",
356+
"purpose",
357+
"pleasure",
358+
graphname,
359+
],
360+
catch_exceptions=False
361+
)
362+
assert res.exit_code == 0
363+
assert "48 relations created" in res.output
293364

294365
def test_batched_build(self):
295366
"""
296367
Create a graph using many batches.
297368
Reuses the inputs of test01_social_graph
298369
"""
370+
self.test_social_graph()
299371
graphname = "batched_graph"
300372
runner = CliRunner()
301373

302-
csv_path = os.path.dirname(os.path.abspath(__file__)) + "/../example/"
303-
person_file = csv_path + "Person.csv"
304-
country_file = csv_path + "Country.csv"
305-
knows_file = csv_path + "KNOWS.csv"
306-
visited_file = csv_path + "VISITED.csv"
307374
csv_path = (
308375
os.path.dirname(os.path.abspath(__file__))
309376
+ "/../../demo/bulk_insert/resources/"
@@ -313,13 +380,13 @@ def test_batched_build(self):
313380
bulk_insert,
314381
[
315382
"--nodes",
316-
person_file,
383+
self.person_file,
317384
"--nodes",
318-
country_file,
385+
self.country_file,
319386
"--relations",
320-
knows_file,
387+
self.knows_file,
321388
"--relations",
322-
visited_file,
389+
self.visited_file,
323390
"--max-token-count",
324391
1,
325392
graphname,

0 commit comments

Comments
 (0)