Skip to content

Added operation context to authentication #1298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ private void doAdvanceOrThrow(final Throwable attemptException,
final boolean onlyRuntimeExceptions) throws Throwable {
assertTrue(attempt() < attempts);
assertNotNull(attemptException);
if (attemptException instanceof MongoOperationTimeoutException) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conditional seems unrelated to the rest of the PR. What's the intention of including it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a test failing due to the wrapping of the exception that later occurs.

throw attemptException;
}
if (onlyRuntimeExceptions) {
assertTrue(isRuntime(attemptException));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ <T> T getNonNullMechanismProperty(final String key, @Nullable final T defaultVal

}

abstract void authenticate(InternalConnection connection, ConnectionDescription connectionDescription);
abstract void authenticate(InternalConnection connection, ConnectionDescription connectionDescription,
OperationContext operationContext);

abstract void authenticateAsync(InternalConnection connection, ConnectionDescription connectionDescription,
SingleResultCallback<Void> callback);
OperationContext operationContext, SingleResultCallback<Void> callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
import com.mongodb.MongoServerException;
import com.mongodb.ServerApi;
import com.mongodb.connection.ClusterConnectionMode;
import com.mongodb.internal.IgnorableRequestContext;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.internal.session.SessionContext;
import com.mongodb.internal.validator.NoOpFieldNameValidator;
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;
Expand All @@ -47,26 +43,25 @@ public final class CommandHelper {
static final String LEGACY_HELLO_LOWER = LEGACY_HELLO.toLowerCase(Locale.ROOT);

static BsonDocument executeCommand(final String database, final BsonDocument command, final ClusterConnectionMode clusterConnectionMode,
@Nullable final ServerApi serverApi, final InternalConnection internalConnection) {
return sendAndReceive(database, command, clusterConnectionMode, serverApi, internalConnection);
@Nullable final ServerApi serverApi, final InternalConnection internalConnection, final OperationContext operationContext) {
return sendAndReceive(database, command, clusterConnectionMode, serverApi, internalConnection, operationContext);
}

static BsonDocument executeCommandWithoutCheckingForFailure(final String database, final BsonDocument command,
final ClusterConnectionMode clusterConnectionMode,
@Nullable final ServerApi serverApi,
final InternalConnection internalConnection) {
final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi,
final InternalConnection internalConnection, final OperationContext operationContext) {
try {
return sendAndReceive(database, command, clusterConnectionMode, serverApi, internalConnection);
return executeCommand(database, command, clusterConnectionMode, serverApi, internalConnection, operationContext);
} catch (MongoServerException e) {
return new BsonDocument();
}
}

static void executeCommandAsync(final String database, final BsonDocument command, final ClusterConnectionMode clusterConnectionMode,
@Nullable final ServerApi serverApi, final InternalConnection internalConnection,
final SingleResultCallback<BsonDocument> callback) {
@Nullable final ServerApi serverApi, final InternalConnection internalConnection, final OperationContext operationContext,
final SingleResultCallback<BsonDocument> callback) {
internalConnection.sendAndReceiveAsync(getCommandMessage(database, command, internalConnection, clusterConnectionMode, serverApi),
new BsonDocumentCodec(), createOperationContext(NoOpSessionContext.INSTANCE, serverApi),
new BsonDocumentCodec(), operationContext,
(result, t) -> {
if (t != null) {
callback.onResult(null, t);
Expand All @@ -90,19 +85,15 @@ static boolean isCommandOk(final BsonDocument response) {
}
}

static OperationContext createOperationContext(final SessionContext sessionContext, @Nullable final ServerApi serverApi) {
return new OperationContext(IgnorableRequestContext.INSTANCE, sessionContext,
new TimeoutContext(TimeoutSettings.DEFAULT), serverApi);
}

private static BsonDocument sendAndReceive(final String database, final BsonDocument command,
final ClusterConnectionMode clusterConnectionMode,
@Nullable final ServerApi serverApi,
final InternalConnection internalConnection) {
final InternalConnection internalConnection,
final OperationContext operationContext) {
return assertNotNull(
internalConnection.sendAndReceive(
getCommandMessage(database, command, internalConnection, clusterConnectionMode, serverApi),
new BsonDocumentCodec(), createOperationContext(NoOpSessionContext.INSTANCE, serverApi))
new BsonDocumentCodec(), operationContext)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ class DefaultAuthenticator extends Authenticator implements SpeculativeAuthentic
}

@Override
void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) {
void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription,
final OperationContext operationContext) {
if (serverIsLessThanVersionFourDotZero(connectionDescription)) {
new ScramShaAuthenticator(getMongoCredentialWithCache().withMechanism(SCRAM_SHA_1), getClusterConnectionMode(), getServerApi())
.authenticate(connection, connectionDescription);
.authenticate(connection, connectionDescription, operationContext);
} else {
try {
setDelegate(connectionDescription);
delegate.authenticate(connection, connectionDescription);
delegate.authenticate(connection, connectionDescription, operationContext);
} catch (Exception e) {
throw wrapException(e);
}
Expand All @@ -63,13 +64,13 @@ void authenticate(final InternalConnection connection, final ConnectionDescripti

@Override
void authenticateAsync(final InternalConnection connection, final ConnectionDescription connectionDescription,
final SingleResultCallback<Void> callback) {
final OperationContext operationContext, final SingleResultCallback<Void> callback) {
if (serverIsLessThanVersionFourDotZero(connectionDescription)) {
new ScramShaAuthenticator(getMongoCredentialWithCache().withMechanism(SCRAM_SHA_1), getClusterConnectionMode(), getServerApi())
.authenticateAsync(connection, connectionDescription, callback);
.authenticateAsync(connection, connectionDescription, operationContext, callback);
} else {
setDelegate(connectionDescription);
delegate.authenticateAsync(connection, connectionDescription, callback);
delegate.authenticateAsync(connection, connectionDescription, operationContext, callback);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,10 @@ private void initialize() {

private void pingServer(final InternalConnection connection) {
long start = System.nanoTime();
OperationContext operationContext = operationContextFactory.create();
executeCommand("admin",
new BsonDocument(getHandshakeCommandName(connection.getInitialServerDescription()), new BsonInt32(1)),
clusterConnectionMode, serverApi, connection);
clusterConnectionMode, serverApi, connection, operationContext);
long elapsedTimeNanos = System.nanoTime() - start;
roundTripTimeSampler.addSample(elapsedTimeNanos);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@

interface InternalConnectionInitializer {

InternalConnectionInitializationDescription startHandshake(InternalConnection internalConnection);
InternalConnectionInitializationDescription startHandshake(InternalConnection internalConnection,
OperationContext operationContext);

InternalConnectionInitializationDescription finishHandshake(InternalConnection internalConnection,
InternalConnectionInitializationDescription description);
InternalConnectionInitializationDescription description,
OperationContext operationContext);

void startHandshakeAsync(InternalConnection internalConnection,
OperationContext operationContext,
SingleResultCallback<InternalConnectionInitializationDescription> callback);

void finishHandshakeAsync(InternalConnection internalConnection, InternalConnectionInitializationDescription description,
void finishHandshakeAsync(InternalConnection internalConnection,
InternalConnectionInitializationDescription description,
OperationContext operationContext,
SingleResultCallback<InternalConnectionInitializationDescription> callback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,10 @@ public void open(final OperationContext operationContext) {
try {
stream.open(operationContext);

InternalConnectionInitializationDescription initializationDescription = connectionInitializer.startHandshake(this);
InternalConnectionInitializationDescription initializationDescription = connectionInitializer.startHandshake(this, operationContext);
initAfterHandshakeStart(initializationDescription);

initializationDescription = connectionInitializer.finishHandshake(this, initializationDescription);
initializationDescription = connectionInitializer.finishHandshake(this, initializationDescription, operationContext);
initAfterHandshakeFinish(initializationDescription);
} catch (Throwable t) {
close();
Expand All @@ -226,7 +226,7 @@ public void openAsync(final OperationContext operationContext, final SingleResul

@Override
public void completed(@Nullable final Void aVoid) {
connectionInitializer.startHandshakeAsync(InternalStreamConnection.this,
connectionInitializer.startHandshakeAsync(InternalStreamConnection.this, operationContext,
(initialResult, initialException) -> {
if (initialException != null) {
close();
Expand All @@ -235,7 +235,7 @@ public void completed(@Nullable final Void aVoid) {
assertNotNull(initialResult);
initAfterHandshakeStart(initialResult);
connectionInitializer.finishHandshakeAsync(InternalStreamConnection.this,
initialResult, (completedResult, completedException) -> {
initialResult, operationContext, (completedResult, completedException) -> {
if (completedException != null) {
close();
callback.onResult(null, completedException);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,29 @@ public InternalStreamConnectionInitializer(final ClusterConnectionMode clusterCo
}

@Override
public InternalConnectionInitializationDescription startHandshake(final InternalConnection internalConnection) {
public InternalConnectionInitializationDescription startHandshake(final InternalConnection internalConnection,
final OperationContext operationContext) {
notNull("internalConnection", internalConnection);

return initializeConnectionDescription(internalConnection);
return initializeConnectionDescription(internalConnection, operationContext);
}

public InternalConnectionInitializationDescription finishHandshake(final InternalConnection internalConnection,
final InternalConnectionInitializationDescription description) {
final InternalConnectionInitializationDescription description,
final OperationContext operationContext) {
notNull("internalConnection", internalConnection);
notNull("description", description);

authenticate(internalConnection, description.getConnectionDescription());
return completeConnectionDescriptionInitialization(internalConnection, description);
authenticate(internalConnection, description.getConnectionDescription(), operationContext);
return completeConnectionDescriptionInitialization(internalConnection, description, operationContext);
}

@Override
public void startHandshakeAsync(final InternalConnection internalConnection,
public void startHandshakeAsync(final InternalConnection internalConnection, final OperationContext operationContext,
final SingleResultCallback<InternalConnectionInitializationDescription> callback) {
long startTime = System.nanoTime();
executeCommandAsync("admin", createHelloCommand(authenticator, internalConnection), clusterConnectionMode, serverApi,
internalConnection, (helloResult, t) -> {
internalConnection, operationContext, (helloResult, t) -> {
if (t != null) {
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
} else {
Expand All @@ -105,31 +107,35 @@ public void startHandshakeAsync(final InternalConnection internalConnection,
@Override
public void finishHandshakeAsync(final InternalConnection internalConnection,
final InternalConnectionInitializationDescription description,
final OperationContext operationContext,
final SingleResultCallback<InternalConnectionInitializationDescription> callback) {
if (authenticator == null || description.getConnectionDescription().getServerType()
== ServerType.REPLICA_SET_ARBITER) {
completeConnectionDescriptionInitializationAsync(internalConnection, description, callback);
completeConnectionDescriptionInitializationAsync(internalConnection, description, operationContext, callback);
} else {
authenticator.authenticateAsync(internalConnection, description.getConnectionDescription(),
authenticator.authenticateAsync(internalConnection, description.getConnectionDescription(), operationContext,
(result1, t1) -> {
if (t1 != null) {
callback.onResult(null, t1);
} else {
completeConnectionDescriptionInitializationAsync(internalConnection, description, callback);
completeConnectionDescriptionInitializationAsync(internalConnection, description, operationContext, callback);
}
});
}
}

private InternalConnectionInitializationDescription initializeConnectionDescription(final InternalConnection internalConnection) {
private InternalConnectionInitializationDescription initializeConnectionDescription(final InternalConnection internalConnection,
final OperationContext operationContext) {
BsonDocument helloResult;
BsonDocument helloCommandDocument = createHelloCommand(authenticator, internalConnection);

long start = System.nanoTime();
try {
helloResult = executeCommand("admin", helloCommandDocument, clusterConnectionMode, serverApi, internalConnection);
helloResult = executeCommand("admin", helloCommandDocument, clusterConnectionMode, serverApi, internalConnection, operationContext);
} catch (MongoException e) {
throw mapHelloException(e);
} finally {
operationContext.getTimeoutContext().resetMaintenanceTimeout();
}
setSpeculativeAuthenticateResponse(helloResult);
return createInitializationDescription(helloResult, internalConnection, start);
Expand Down Expand Up @@ -189,21 +195,23 @@ private BsonDocument createHelloCommand(final Authenticator authenticator, final

private InternalConnectionInitializationDescription completeConnectionDescriptionInitialization(
final InternalConnection internalConnection,
final InternalConnectionInitializationDescription description) {
final InternalConnectionInitializationDescription description,
final OperationContext operationContext) {

if (description.getConnectionDescription().getConnectionId().getServerValue() != null) {
return description;
}

return applyGetLastErrorResult(executeCommandWithoutCheckingForFailure("admin",
new BsonDocument("getlasterror", new BsonInt32(1)), clusterConnectionMode, serverApi,
internalConnection),
internalConnection, operationContext),
description);
}

private void authenticate(final InternalConnection internalConnection, final ConnectionDescription connectionDescription) {
private void authenticate(final InternalConnection internalConnection, final ConnectionDescription connectionDescription,
final OperationContext operationContext) {
if (authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER) {
authenticator.authenticate(internalConnection, connectionDescription);
authenticator.authenticate(internalConnection, connectionDescription, operationContext);
}
}

Expand All @@ -217,6 +225,7 @@ private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult)
private void completeConnectionDescriptionInitializationAsync(
final InternalConnection internalConnection,
final InternalConnectionInitializationDescription description,
final OperationContext operationContext,
final SingleResultCallback<InternalConnectionInitializationDescription> callback) {

if (description.getConnectionDescription().getConnectionId().getServerValue() != null) {
Expand All @@ -225,7 +234,7 @@ private void completeConnectionDescriptionInitializationAsync(
}

executeCommandAsync("admin", new BsonDocument("getlasterror", new BsonInt32(1)), clusterConnectionMode, serverApi,
internalConnection,
internalConnection, operationContext,
(result, t) -> {
if (t != null) {
callback.onResult(description, null);
Expand Down
Loading