Skip to content

Commit 953d12a

Browse files
alts: Forward-fix of ALTS queuing of handshake requests. (#6906)
* alts: Forward-fix of ALTS queuing of handshake requests.
1 parent 6ce73bf commit 953d12a

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

credentials/alts/alts_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import (
4444
)
4545

4646
const (
47-
defaultTestLongTimeout = 10 * time.Second
47+
defaultTestLongTimeout = 60 * time.Second
4848
defaultTestShortTimeout = 10 * time.Millisecond
4949
)
5050

@@ -392,17 +392,23 @@ func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress stri
392392
ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
393393
defer cancel()
394394
c := testgrpc.NewTestServiceClient(conn)
395+
success := false
395396
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
396397
_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{})
397398
if err == nil {
399+
success = true
398400
break
399401
}
400-
if code := status.Code(err); code == codes.Unavailable {
401-
// The server is not ready yet. Try again.
402+
if code := status.Code(err); code == codes.Unavailable || code == codes.DeadlineExceeded {
403+
// The server is not ready yet or there were too many concurrent handshakes.
404+
// Try again.
402405
continue
403406
}
404407
t.Fatalf("c.UnaryCall() failed: %v", err)
405408
}
409+
if !success {
410+
t.Fatalf("c.UnaryCall() timed out after %v", defaultTestShortTimeout)
411+
}
406412
}
407413

408414
func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ var (
6161
// control number of concurrent created (but not closed) handshakes.
6262
clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
6363
serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
64-
// errDropped occurs when maxPendingHandshakes is reached.
65-
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
6664
// errOutOfBound occurs when the handshake service returns a consumed
6765
// bytes value larger than the buffer that was passed to it originally.
6866
errOutOfBound = errors.New("handshaker service consumed bytes value is out-of-bound")
@@ -156,8 +154,8 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn,
156154
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
157155
// done, ClientHandshake returns a secure connection.
158156
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
159-
if !clientHandshakes.TryAcquire(1) {
160-
return nil, nil, errDropped
157+
if err := clientHandshakes.Acquire(ctx, 1); err != nil {
158+
return nil, nil, err
161159
}
162160
defer clientHandshakes.Release(1)
163161

@@ -209,8 +207,8 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
209207
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
210208
// done, ServerHandshake returns a secure connection.
211209
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
212-
if !serverHandshakes.TryAcquire(1) {
213-
return nil, nil, errDropped
210+
if err := serverHandshakes.Acquire(ctx, 1); err != nil {
211+
return nil, nil, err
214212
}
215213
defer serverHandshakes.Release(1)
216214

credentials/alts/internal/handshaker/handshaker_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ func (s) TestClientHandshake(t *testing.T) {
193193
}()
194194
}
195195

196-
// Ensure all errors are expected.
196+
// Ensure that there are no errors.
197197
for i := 0; i < testCase.numberOfHandshakes; i++ {
198-
if err := <-errc; err != nil && err != errDropped {
199-
t.Errorf("ClientHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
198+
if err := <-errc; err != nil {
199+
t.Errorf("ClientHandshake() = _, %v, want _, <nil>", err)
200200
}
201201
}
202202

@@ -250,10 +250,10 @@ func (s) TestServerHandshake(t *testing.T) {
250250
}()
251251
}
252252

253-
// Ensure all errors are expected.
253+
// Ensure that there are no errors.
254254
for i := 0; i < testCase.numberOfHandshakes; i++ {
255-
if err := <-errc; err != nil && err != errDropped {
256-
t.Errorf("ServerHandshake() = _, %v, want _, <nil> or %v", err, errDropped)
255+
if err := <-errc; err != nil {
256+
t.Errorf("ServerHandshake() = _, %v, want _, <nil>", err)
257257
}
258258
}
259259

0 commit comments

Comments
 (0)