2
2
3
3
from __future__ import (absolute_import , division , print_function )
4
4
5
- import os
6
- import sys
7
5
import getpass
6
+ import os
8
7
import subprocess
8
+ import sys
9
+
9
10
import ldap
10
11
import ldap .filter
11
12
13
+ BASE_DN = 'dc=scripts,dc=mit,dc=edu'
14
+
12
15
def get_pool (username ):
16
+ """
17
+ Check what pool(s) a locker is on.
18
+
19
+ Returns: (default vhost pool IP, [(pool name, vhost name)] if multiple pools)
20
+ """
13
21
ldap_uri = ldap .get_option (ldap .OPT_URI )
14
22
15
23
ll = ldap .initialize (ldap_uri )
16
24
17
25
users = ll .search_s (
18
- 'dc=scripts,dc=mit,dc=edu' ,
26
+ BASE_DN ,
19
27
ldap .SCOPE_SUBTREE ,
20
28
ldap .filter .filter_format ('(&(objectClass=posixAccount)(uid=%s))' , [username ]),
21
29
[],
@@ -26,31 +34,40 @@ def get_pool(username):
26
34
27
35
pool_ips = set ()
28
36
vhost_pools = {}
29
- for dn , attrs in ll .search_s (
30
- 'dc=scripts,dc=mit,dc=edu' ,
31
- ldap .SCOPE_SUBTREE ,
32
- ldap .filter .filter_format ('(&(objectClass=scriptsVhost)(scriptsVhostAccount=%s))' , [user_dn ]),
33
- ['scriptsVhostName' , 'scriptsVhostPoolIPv4' ],
37
+ for _ , attrs in ll .search_s (
38
+ BASE_DN ,
39
+ ldap .SCOPE_SUBTREE ,
40
+ ldap .filter .filter_format (
41
+ '(&(objectClass=scriptsVhost)(scriptsVhostAccount=%s))' ,
42
+ [user_dn ]),
43
+ ['scriptsVhostName' , 'scriptsVhostPoolIPv4' ],
34
44
):
35
45
vhost_pools [attrs ['scriptsVhostName' ][0 ]] = attrs ['scriptsVhostPoolIPv4' ][0 ]
36
46
pool_ips .add (attrs ['scriptsVhostPoolIPv4' ][0 ])
37
47
38
48
pool_names = {}
39
49
for dn , attrs in ll .search_s (
40
- 'dc=scripts,dc=mit,dc=edu' ,
50
+ BASE_DN ,
41
51
ldap .SCOPE_SUBTREE ,
42
- '(&(objectClass=scriptsVhostPool)(|' + '' .join (ldap .filter .filter_format ('(scriptsVhostPoolIPv4=%s)' , [ip ]) for ip in pool_ips )+ '))' ,
52
+ '(&(objectClass=scriptsVhostPool)(|' + '' .join (
53
+ ldap .filter .filter_format ('(scriptsVhostPoolIPv4=%s)' , [ip ])
54
+ for ip in pool_ips
55
+ )+ '))' ,
43
56
['cn' , 'scriptsVhostPoolIPv4' ],
44
57
):
45
58
pool_names [attrs ['scriptsVhostPoolIPv4' ][0 ]] = attrs ['cn' ][0 ]
46
59
47
60
main_pool = vhost_pools .get (username + '.scripts.mit.edu' )
48
61
other_pools = None
49
62
if len (pool_ips ) > 1 :
50
- other_pools = sorted ((pool_names .get (pool , pool ), vhost ) for vhost , pool in vhost_pools .items ())
63
+ other_pools = sorted (
64
+ (pool_names .get (pool , pool ), vhost )
65
+ for vhost , pool in vhost_pools .items ()
66
+ )
51
67
return main_pool , other_pools
52
68
53
69
def should_forward ():
70
+ """Check if we were invoked by ssh on a vip that requires forwarding."""
54
71
ssh_connection = os .environ .get ('SSH_CONNECTION' )
55
72
if not ssh_connection :
56
73
return False
@@ -64,9 +81,15 @@ def should_forward():
64
81
return False
65
82
66
83
def has_pool (ip ):
84
+ """Check if the current machine is binding a vip."""
67
85
return len (subprocess .check_output (['/sbin/ip' , 'addr' , 'show' , 'to' , ip ])) > 0
68
86
69
87
def maybe_forward ():
88
+ """
89
+ Forward the invocation if appropriate.
90
+
91
+ exec's when forwarding, so returning means we should run locally.
92
+ """
70
93
if not should_forward ():
71
94
return
72
95
command = None
@@ -79,10 +102,11 @@ def maybe_forward():
79
102
main_pool , other_pools = get_pool (user )
80
103
forward = main_pool and not has_pool (main_pool )
81
104
if forward :
82
- # TODO: Check if we're already on the right server.
83
105
print ("Forwarding to the server for %s.scripts.mit.edu." % (user ,), file = sys .stderr )
84
106
if other_pools :
85
- print ("Your account has virtual hosts on multiple server pools; to connect to a server for a particular host, connect to a specific server:" , file = sys .stderr )
107
+ print ("Your account has virtual hosts on multiple server pools; "
108
+ "to connect to a server for a particular host, "
109
+ "connect to a specific server:" , file = sys .stderr )
86
110
print (file = sys .stderr )
87
111
for name , vhost in other_pools :
88
112
print ("%s - ssh %s" % (vhost , name ), file = sys .stderr )
@@ -93,6 +117,7 @@ def maybe_forward():
93
117
args .append (command )
94
118
os .execv ('/usr/bin/ssh' , args )
95
119
96
- maybe_forward ()
120
+ if __name__ == '__main__' :
121
+ maybe_forward ()
97
122
98
- os .execv ("@bash_path@" , ["bash" , "--rcfile" , "/usr/local/etc/mbashrc" ] + sys .argv [1 :])
123
+ os .execv ("@bash_path@" , ["bash" , "--rcfile" , "/usr/local/etc/mbashrc" ] + sys .argv [1 :])
0 commit comments