diff --git a/diffsync/__init__.py b/diffsync/__init__.py index b82b015b..498233e9 100644 --- a/diffsync/__init__.py +++ b/diffsync/__init__.py @@ -23,7 +23,7 @@ from .diff import Diff from .enum import DiffSyncModelFlags, DiffSyncFlags, DiffSyncStatus -from .exceptions import ObjectAlreadyExists, ObjectStoreWrongType, ObjectNotFound +from .exceptions import DiffClassMismatch, ObjectAlreadyExists, ObjectStoreWrongType, ObjectNotFound from .helpers import DiffSyncDiffer, DiffSyncSyncer @@ -461,7 +461,8 @@ def sync_from( diff_class: Type[Diff] = Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, callback: Optional[Callable[[Text, int, int], None]] = None, - ): + diff: Optional[Diff] = None, + ): # pylint: disable=too-many-arguments: """Synchronize data from the given source DiffSync object into the current DiffSync object. Args: @@ -470,8 +471,17 @@ def sync_from( flags (DiffSyncFlags): Flags influencing the behavior of this sync. callback (function): Function with parameters (stage, current, total), to be called at intervals as the calculation of the diff and subsequent sync proceed. + diff (Diff): An existing diff to be used rather than generating a completely new diff. """ - diff = self.diff_from(source, diff_class=diff_class, flags=flags, callback=callback) + if diff_class and diff: + if not isinstance(diff, diff_class): + raise DiffClassMismatch( + f"The provided diff's class ({diff.__class__.__name__}) does not match the diff_class: {diff_class.__name__}", + ) + + # Generate the diff if an existing diff was not provided + if not diff: + diff = self.diff_from(source, diff_class=diff_class, flags=flags, callback=callback) syncer = DiffSyncSyncer(diff=diff, src_diffsync=source, dst_diffsync=self, flags=flags, callback=callback) result = syncer.perform_sync() if result: @@ -483,7 +493,8 @@ def sync_to( diff_class: Type[Diff] = Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, callback: Optional[Callable[[Text, int, int], None]] = None, - ): + diff: Optional[Diff] = None, + ): # pylint: disable=too-many-arguments """Synchronize data from the current DiffSync object into the given target DiffSync object. Args: @@ -492,8 +503,9 @@ def sync_to( flags (DiffSyncFlags): Flags influencing the behavior of this sync. callback (function): Function with parameters (stage, current, total), to be called at intervals as the calculation of the diff and subsequent sync proceed. + diff (Diff): An existing diff that will be used when determining what needs to be synced. """ - target.sync_from(self, diff_class=diff_class, flags=flags, callback=callback) + target.sync_from(self, diff_class=diff_class, flags=flags, callback=callback, diff=diff) def sync_complete( self, diff --git a/diffsync/exceptions.py b/diffsync/exceptions.py index 809954e0..b604c74e 100644 --- a/diffsync/exceptions.py +++ b/diffsync/exceptions.py @@ -51,3 +51,11 @@ class ObjectNotFound(ObjectStoreException): class ObjectStoreWrongType(ObjectStoreException): """Exception raised when trying to store a DiffSyncModel of the wrong type.""" + + +class DiffException(Exception): + """Base class for various failures related to Diff operations.""" + + +class DiffClassMismatch(DiffException): + """Exception raised when a diff object is not the same as the expected diff_class.""" diff --git a/examples/01-multiple-data-sources/README.md b/examples/01-multiple-data-sources/README.md index e833dc18..cbb0e0ed 100644 --- a/examples/01-multiple-data-sources/README.md +++ b/examples/01-multiple-data-sources/README.md @@ -57,6 +57,8 @@ Synchronize A and B (update B with the contents of A): ```python a.sync_to(b) print(a.diff_to(b).str()) +# Alternatively you can pass in the diff object from above to prevent another diff calculation +# a.sync_to(b, diff=diff_a_b) ``` Now A and B will show no differences: diff --git a/examples/01-multiple-data-sources/main.py b/examples/01-multiple-data-sources/main.py index fb5cff99..0665059f 100755 --- a/examples/01-multiple-data-sources/main.py +++ b/examples/01-multiple-data-sources/main.py @@ -69,7 +69,7 @@ def main(): pprint.pprint(diff_a_b.dict(), width=120) print("Syncing changes from Backend A to Backend B...") - backend_a.sync_to(backend_b) + backend_a.sync_to(backend_b, diff=diff_a_b) print("Getting updated diffs from Backend A to Backend B...") print(backend_a.diff_to(backend_b).str()) diff --git a/examples/03-remote-system/main.py b/examples/03-remote-system/main.py index 66dc6ffd..f00317eb 100644 --- a/examples/03-remote-system/main.py +++ b/examples/03-remote-system/main.py @@ -40,8 +40,10 @@ def main(): print(diff.str()) if args.sync: + if not args.diff: + diff = None print("Updating the list of countries in Nautobot ...") - nautobot.sync_from(local, flags=flags, diff_class=AlphabeticalOrderDiff) + nautobot.sync_from(local, flags=flags, diff_class=AlphabeticalOrderDiff, diff=diff) if __name__ == "__main__": diff --git a/examples/04-get-update-instantiate/backends.py b/examples/04-get-update-instantiate/backends.py index a79e5031..bbc06e43 100644 --- a/examples/04-get-update-instantiate/backends.py +++ b/examples/04-get-update-instantiate/backends.py @@ -15,7 +15,7 @@ limitations under the License. """ -from models import Site, Device, Interface +from models import Site, Device, Interface # pylint: disable=no-name-in-module from diffsync import DiffSync BACKEND_DATA_A = [ diff --git a/tests/unit/test_diffsync.py b/tests/unit/test_diffsync.py index 4d32f98d..47c06c92 100644 --- a/tests/unit/test_diffsync.py +++ b/tests/unit/test_diffsync.py @@ -20,7 +20,7 @@ import pytest from diffsync import DiffSync, DiffSyncModel, DiffSyncFlags, DiffSyncModelFlags -from diffsync.exceptions import ObjectAlreadyExists, ObjectNotFound, ObjectCrudException +from diffsync.exceptions import DiffClassMismatch, ObjectAlreadyExists, ObjectNotFound, ObjectCrudException from .conftest import Site, Device, Interface, TrackedDiff, BackendA, PersonA @@ -468,6 +468,57 @@ def callback(stage, current, total): assert last_value == {"current": expected, "total": expected} +def test_diffsync_sync_to_w_different_diff_class_raises(backend_a, backend_b): + diff = backend_b.diff_to(backend_a) + with pytest.raises(DiffClassMismatch) as failure: + backend_b.sync_to(backend_a, diff_class=TrackedDiff, diff=diff) + assert failure.value.args[0] == "The provided diff's class (Diff) does not match the diff_class: TrackedDiff" + + +def test_diffsync_sync_to_w_diff_no_mocks(backend_a, backend_b): + diff = backend_b.diff_to(backend_a) + assert diff.has_diffs() + # Perform full sync + backend_b.sync_to(backend_a, diff=diff) + # Assert there are no diffs after synchronization + post_diff = backend_b.diff_to(backend_a) + assert not post_diff.has_diffs() + + +def test_diffsync_sync_to_w_diff(backend_a, backend_b): + diff = backend_b.diff_to(backend_a) + assert diff.has_diffs() + # Mock diff_from to make sure it's not called when passing in an existing diff + backend_b.diff_from = mock.Mock() + backend_b.diff_to = mock.Mock() + backend_a.diff_from = mock.Mock() + backend_a.diff_to = mock.Mock() + # Perform full sync + backend_b.sync_to(backend_a, diff=diff) + # Assert none of the diff methods have been called + assert not backend_b.diff_from.called + assert not backend_b.diff_to.called + assert not backend_a.diff_from.called + assert not backend_a.diff_to.called + + +def test_diffsync_sync_from_w_diff(backend_a, backend_b): + diff = backend_a.diff_from(backend_b) + assert diff.has_diffs() + # Mock diff_from to make sure it's not called when passing in an existing diff + backend_a.diff_from = mock.Mock() + backend_a.diff_to = mock.Mock() + backend_b.diff_from = mock.Mock() + backend_b.diff_to = mock.Mock() + # Perform full sync + backend_a.sync_from(backend_b, diff=diff) + # Assert none of the diff methods have been called + assert not backend_a.diff_from.called + assert not backend_a.diff_to.called + assert not backend_b.diff_from.called + assert not backend_b.diff_to.called + + def test_diffsync_sync_from(backend_a, backend_b): backend_a.sync_complete = mock.Mock() backend_b.sync_complete = mock.Mock() @@ -542,7 +593,6 @@ def check_successful_sync_log_sanity(log, src, dst, flags): def check_sync_logs_against_diff(diffsync, diff, log, errors_permitted=False): """Given a Diff, make sure the captured structlogs correctly correspond to its contents/actions.""" for element in diff.get_children(): - print(element) # This is kinda gross, but needed since a DiffElement stores a shortname and keys, not a unique_id uid = getattr(diffsync, element.type).create_unique_id(**element.keys)