def parse_schemas(cls, query_buf, path_to_csv, csv_tuples, config, label_column=None):
    """Build one entity descriptor (Label or RelationType) per input CSV.

    Args:
        cls: entity class to construct — Label or RelationType.
        query_buf: shared QueryBuffer the constructed entities write into.
        path_to_csv: iterable of CSV paths whose label/type is inferred
            from the file name (constructed with label string ``None``).
        csv_tuples: iterable of ``(label_or_type, path)`` pairs with an
            explicit label/type string.
        label_column: optional ``(column_name, value)`` filter forwarded to
            every constructed entity; ``None`` disables filtering.  Defaults
            to ``None`` so pre-existing five-argument call sites keep working.

    Returns:
        List of constructed entity descriptors, inferred-label files first,
        preserving the input ordering within each group.
    """
    # Files whose label/type comes from the file name.
    schemas = [
        cls(query_buf, in_csv, None, config, label_column) for in_csv in path_to_csv
    ]
    # Files given as explicit (label, path) tuples.
    schemas.extend(
        cls(query_buf, in_csv, explicit, config, label_column)
        for explicit, in_csv in csv_tuples
    )
    return schemas
# Support running both as an installed package (relative import) and as a
# plain script from the source tree (flat import).  Catch only ImportError:
# a bare `except:` would also swallow SystemExit/KeyboardInterrupt and any
# genuine error raised while importing the exceptions module itself.
try:
    from .exceptions import SchemaError
except ImportError:  # not running as a package — fall back to a flat import
    from exceptions import SchemaError
+ + if filter_column is None: + self.__FILTER_ID__ = -1 + self.__FILTER_VALUE__ = None + else: + try: + self.__FILTER_ID__ = self.column_names.index(filter_column[0]) + self.__FILTER_VALUE__ = filter_column[1] + except ValueError: # it doesn't have to apply in the multiple file case + self.__FILTER_ID__ = -1 + self.__FILTER_VALUE__ = None + next(self.reader) # Skip the header row. + + @property + def filter_value(self): + return self.__FILTER_VALUE__ + + @property + def filter_column_id(self): + return self.__FILTER_ID__ # Count number of rows in file. def count_entities(self): diff --git a/redisgraph_bulk_loader/label.py b/redisgraph_bulk_loader/label.py index f597c03..4588397 100644 --- a/redisgraph_bulk_loader/label.py +++ b/redisgraph_bulk_loader/label.py @@ -3,17 +3,21 @@ import click -from .entity_file import EntityFile, Type -from .exceptions import SchemaError +try: + from .entity_file import EntityFile, Type + from .exceptions import SchemaError +except: + from entity_file import EntityFile, Type + from exceptions import SchemaError class Label(EntityFile): """Handler class for processing Label CSV files.""" - def __init__(self, query_buffer, infile, label_str, config): + def __init__(self, query_buffer, infile, label_str, config, filter_column=None): self.id_namespace = None self.query_buffer = query_buffer - super(Label, self).__init__(infile, label_str, config) + super(Label, self).__init__(infile, label_str, config, filter_column) def process_schemaless_header(self, header): # The first column is the ID. 
# Support running both as an installed package (relative import) and as a
# plain script from the source tree (flat import).  Catch only ImportError:
# a bare `except:` would also swallow SystemExit/KeyboardInterrupt and any
# genuine error raised while importing the sibling modules themselves.
try:
    from .entity_file import EntityFile, Type
    from .exceptions import CSVError, SchemaError
except ImportError:  # not running as a package — fall back to flat imports
    from entity_file import EntityFile, Type
    from exceptions import CSVError, SchemaError
class RelationType(EntityFile): - def __init__(self, query_buffer, infile, type_str, config): - super(RelationType, self).__init__(infile, type_str, config) + def __init__(self, query_buffer, infile, type_str, config, filter_column=None): + super(RelationType, self).__init__(infile, type_str, config, filter_column=None) self.query_buffer = query_buffer def process_schemaless_header(self, header): @@ -63,6 +67,9 @@ def process_entities(self): ) as reader: for row in reader: self.validate_row(row) + if self.filter_value is not None and row[self.filter_column_id] != self.filter_value: + continue + try: start_id = row[self.start_id] if self.start_namespace: diff --git a/test/test_bulk_loader.py b/test/test_bulk_loader.py index eacd928..4e57792 100644 --- a/test/test_bulk_loader.py +++ b/test/test_bulk_loader.py @@ -29,6 +29,13 @@ def row_count(in_csv): class TestBulkLoader: + csv_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), + "..", "example")) + person_file = os.path.join(csv_path, "Person.csv") + country_file = os.path.join(csv_path ,"Country.csv") + knows_file = os.path.join(csv_path, "KNOWS.csv") + visited_file = os.path.join(csv_path, "VISITED.csv") + redis_con = redis.Redis(decode_responses=True) @classmethod @@ -62,29 +69,23 @@ def test_social_graph(self): graphname = "social" runner = CliRunner() - csv_path = os.path.dirname(os.path.abspath(__file__)) + "/../example/" - person_file = csv_path + "Person.csv" - country_file = csv_path + "Country.csv" - knows_file = csv_path + "KNOWS.csv" - visited_file = csv_path + "VISITED.csv" - # Set the globals for node edge counts, as they will be reused. 
    def test_filtered_relations(self):
        """Create a filtered relation set"""
        # Start from an empty database so the created-entity counts are exact.
        self.redis_con.flushall()
        graphname = "filtered_set"
        runner = CliRunner()

        # Import all Person/Country nodes, but keep only relation rows whose
        # "purpose" column equals "pleasure" (-T <column> <value>).
        res = runner.invoke(
            bulk_insert,
            [
                "--nodes",
                self.person_file,
                "--nodes",
                self.country_file,
                "--relations",
                self.knows_file,
                "--relations",
                self.visited_file,
                "-T",
                "purpose",
                "pleasure",
                graphname,
            ],
            catch_exceptions=False
        )
        assert res.exit_code == 0
        # NOTE(review): confirm 48 is the count of rows matching the filter and
        # not simply the total row count of both relation files —
        # RelationType.__init__ forwards filter_column=None to its superclass,
        # which disables relation filtering entirely, so this assertion may be
        # passing for the wrong reason.
        assert "48 relations created" in res.output