|
19 | 19 | import org.springframework.web.server.ServerWebExchange;
|
20 | 20 | import org.springframework.web.server.WebFilter;
|
21 | 21 | import org.springframework.web.server.WebFilterChain;
|
| 22 | +import reactor.core.publisher.Flux; |
22 | 23 | import reactor.core.publisher.Mono;
|
23 | 24 | import reactor.core.scheduler.Schedulers;
|
24 | 25 |
|
25 | 26 | import java.time.Instant;
|
| 27 | +import java.util.Collection; |
| 28 | +import java.util.List; |
26 | 29 | import java.util.Objects;
|
27 | 30 | import java.util.Optional;
|
28 | 31 |
|
@@ -63,95 +66,97 @@ public Mono<Void> filter(@Nonnull ServerWebExchange exchange, WebFilterChain cha
|
63 | 66 | .map(user -> {
|
64 | 67 |
|
65 | 68 | Connection activeConnection = null;
|
66 |
| - String orgId = null; |
| 69 | + List<String> orgIds = List.of(); |
67 | 70 |
|
68 | 71 | Optional<Connection> activeConnectionOptional = user.getConnections()
|
69 | 72 | .stream()
|
70 | 73 | .filter(connection -> connection.getAuthId().equals(user.getActiveAuthId()))
|
71 | 74 | .findFirst();
|
72 | 75 |
|
73 | 76 | if(!activeConnectionOptional.isPresent()) {
|
74 |
| - return Triple.of(user, activeConnection, orgId); |
| 77 | + return Triple.of(user, activeConnection, orgIds); |
75 | 78 | }
|
76 | 79 |
|
77 | 80 | activeConnection = activeConnectionOptional.get();
|
78 | 81 |
|
79 | 82 | if(!activeConnection.getAuthId().equals(DEFAULT_AUTH_CONFIG.getId())) {
|
80 | 83 | if(activeConnection.getAuthConnectionAuthToken().getExpireAt() == 0) {
|
81 |
| - return Triple.of(user, activeConnection, orgId); |
| 84 | + return Triple.of(user, activeConnection, orgIds); |
82 | 85 | }
|
83 | 86 | boolean isAccessTokenExpired = (activeConnection.getAuthConnectionAuthToken().getExpireAt()*1000) < Instant.now().toEpochMilli();
|
84 | 87 | if(isAccessTokenExpired) {
|
85 | 88 |
|
86 |
| - Optional<String> orgIdOptional = activeConnection.getOrgIds().stream().findFirst(); |
87 |
| - if(!orgIdOptional.isPresent()) { |
88 |
| - return Triple.of(user, activeConnection, orgId); |
| 89 | + List<String> activeOrgIds = activeConnection.getOrgIds().stream().toList(); |
| 90 | + if(!activeOrgIds.isEmpty()) { |
| 91 | + return Triple.of(user, activeConnection, activeOrgIds); |
89 | 92 | }
|
90 |
| - orgId = orgIdOptional.get(); |
| 93 | + orgIds = activeOrgIds; |
91 | 94 | }
|
92 | 95 | }
|
93 | 96 |
|
94 |
| - return Triple.of(user, activeConnection, orgId); |
| 97 | + return Triple.of(user, activeConnection, orgIds); |
95 | 98 |
|
96 | 99 | }).flatMap(this::refreshOauthToken)
|
97 | 100 | .flatMap(user -> chain.filter(exchange).contextWrite(withAuthentication(toAuthentication(user)))
|
98 | 101 | .then(service.extendValidity(cookieToken))
|
99 | 102 | );
|
100 | 103 | }
|
101 | 104 |
|
102 |
| - private Mono<User> refreshOauthToken(Triple<User, Connection, String> triple) { |
| 105 | + private Mono<User> refreshOauthToken(Triple<User, Connection, List<String>> triple) { |
103 | 106 |
|
104 | 107 | User user = triple.getLeft();
|
105 | 108 | Connection connection = triple.getMiddle();
|
106 |
| - String orgId = triple.getRight(); |
| 109 | + Collection<String> orgIds = triple.getRight(); |
107 | 110 |
|
108 |
| - if (connection == null || orgId == null) { |
| 111 | + if (connection == null || orgIds == null || orgIds.isEmpty()) { |
109 | 112 | return Mono.just(user);
|
110 | 113 | }
|
111 | 114 |
|
112 |
| - OAuth2RequestContext oAuth2RequestContext = new OAuth2RequestContext(triple.getRight(), null, null); |
| 115 | + return Flux.fromIterable(orgIds).flatMap(orgId -> { |
| 116 | + OAuth2RequestContext oAuth2RequestContext = new OAuth2RequestContext(orgId, null, null); |
113 | 117 |
|
114 |
| - log.info("Refreshing token for user: [ name: {}, id: {} ], orgId: {}, activeConnection: [ authId: {}, name: {}, orgIds: ({})]", |
115 |
| - user.getName(), user.getId(), |
116 |
| - orgId, |
117 |
| - connection.getAuthId(), connection.getName(), StringUtils.join(connection.getOrgIds(), ", ")); |
| 118 | + log.info("Refreshing token for user: [ name: {}, id: {} ], orgIds: {}, activeConnection: [ authId: {}, name: {}, orgIds: ({})]", |
| 119 | + user.getName(), user.getId(), |
| 120 | + orgIds, |
| 121 | + connection.getAuthId(), connection.getName(), StringUtils.join(connection.getOrgIds(), ", ")); |
118 | 122 |
|
119 |
| - return authenticationService |
120 |
| - .findAuthConfigByAuthId(orgId, connection.getAuthId()) |
121 |
| - .switchIfEmpty(Mono.empty()) |
122 |
| - .flatMap(findAuthConfig -> { |
| 123 | + return authenticationService |
| 124 | + .findAllAuthConfigs(orgId, true) |
| 125 | + .filter(findAuthConfig -> findAuthConfig.authConfig().getId().equals(connection.getAuthId())) |
| 126 | + .switchIfEmpty(Mono.empty()) |
| 127 | + .flatMap(findAuthConfig -> { |
123 | 128 |
|
124 |
| - Mono<AuthRequest> authRequestMono = Mono.empty(); |
| 129 | + Mono<AuthRequest> authRequestMono = Mono.empty(); |
125 | 130 |
|
126 |
| - if(findAuthConfig == null) { |
127 |
| - return authRequestMono; |
128 |
| - } |
129 |
| - oAuth2RequestContext.setAuthConfig(findAuthConfig.authConfig()); |
130 |
| - |
131 |
| - return authRequestFactory.build(oAuth2RequestContext); |
132 |
| - }) |
133 |
| - .publishOn(Schedulers.boundedElastic()).flatMap(authRequest -> { |
134 |
| - if(authRequest == null) { |
135 |
| - return Mono.just(user); |
136 |
| - } |
137 |
| - try { |
138 |
| - if (StringUtils.isEmpty(connection.getAuthConnectionAuthToken().getRefreshToken())) { |
139 |
| - log.error("Refresh token is empty"); |
140 |
| - throw new Exception("Refresh token is empty"); |
| 131 | + if (findAuthConfig == null) { |
| 132 | + return authRequestMono; |
141 | 133 | }
|
142 |
| - AuthUser authUser = authRequest.refresh(connection.getAuthConnectionAuthToken().getRefreshToken()).block(); |
143 |
| - authUser.setAuthContext(oAuth2RequestContext); |
144 |
| - authenticationApiService.updateConnection(authUser, user); |
145 |
| - return userService.update(user.getId(), user); |
146 |
| - } catch (Exception e) { |
147 |
| - log.error("Failed to refresh access token. Removing user sessions/tokens."); |
148 |
| - connection.getTokens().forEach(token -> { |
149 |
| - service.removeUserSession(token).block(); |
150 |
| - }); |
151 |
| - } |
152 |
| - return Mono.just(user); |
153 |
| - }); |
| 134 | + oAuth2RequestContext.setAuthConfig(findAuthConfig.authConfig()); |
154 | 135 |
|
| 136 | + return authRequestFactory.build(oAuth2RequestContext); |
| 137 | + }) |
| 138 | + .publishOn(Schedulers.boundedElastic()).flatMap(authRequest -> { |
| 139 | + if (authRequest == null) { |
| 140 | + return Mono.just(user); |
| 141 | + } |
| 142 | + try { |
| 143 | + if (StringUtils.isEmpty(connection.getAuthConnectionAuthToken().getRefreshToken())) { |
| 144 | + log.error("Refresh token is empty"); |
| 145 | + throw new Exception("Refresh token is empty"); |
| 146 | + } |
| 147 | + AuthUser authUser = authRequest.refresh(connection.getAuthConnectionAuthToken().getRefreshToken()).block(); |
| 148 | + authUser.setAuthContext(oAuth2RequestContext); |
| 149 | + authenticationApiService.updateConnection(authUser, user); |
| 150 | + return userService.update(user.getId(), user); |
| 151 | + } catch (Exception e) { |
| 152 | + log.error("Failed to refresh access token. Removing user sessions/tokens."); |
| 153 | + connection.getTokens().forEach(token -> { |
| 154 | + service.removeUserSession(token).block(); |
| 155 | + }); |
| 156 | + } |
| 157 | + return Mono.just(user); |
| 158 | + }); |
| 159 | + }).next(); |
155 | 160 | }
|
156 | 161 |
|
157 | 162 | }
|
0 commit comments