Skip to content

Commit 2ecf5fc

Browse files
committed
Lint and review comments
1 parent b06360d commit 2ecf5fc

File tree

1 file changed

+40
-15
lines changed

1 file changed

+40
-15
lines changed

server/common/oursrc/accountadm/mbash.in

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,28 @@
22

33
from __future__ import (absolute_import, division, print_function)
44

5-
import os
6-
import sys
75
import getpass
6+
import os
87
import subprocess
8+
import sys
9+
910
import ldap
1011
import ldap.filter
1112

13+
BASE_DN = 'dc=scripts,dc=mit,dc=edu'
14+
1215
def get_pool(username):
16+
"""
17+
Check what pool(s) a locker is on.
18+
19+
Returns: (default vhost pool IP, [(pool name, vhost name)] if multiple pools)
20+
"""
1321
ldap_uri = ldap.get_option(ldap.OPT_URI)
1422

1523
ll = ldap.initialize(ldap_uri)
1624

1725
users = ll.search_s(
18-
'dc=scripts,dc=mit,dc=edu',
26+
BASE_DN,
1927
ldap.SCOPE_SUBTREE,
2028
ldap.filter.filter_format('(&(objectClass=posixAccount)(uid=%s))', [username]),
2129
[],
@@ -26,31 +34,40 @@ def get_pool(username):
2634

2735
pool_ips = set()
2836
vhost_pools = {}
29-
for dn, attrs in ll.search_s(
30-
'dc=scripts,dc=mit,dc=edu',
31-
ldap.SCOPE_SUBTREE,
32-
ldap.filter.filter_format('(&(objectClass=scriptsVhost)(scriptsVhostAccount=%s))', [user_dn]),
33-
['scriptsVhostName', 'scriptsVhostPoolIPv4'],
37+
for _, attrs in ll.search_s(
38+
BASE_DN,
39+
ldap.SCOPE_SUBTREE,
40+
ldap.filter.filter_format(
41+
'(&(objectClass=scriptsVhost)(scriptsVhostAccount=%s))',
42+
[user_dn]),
43+
['scriptsVhostName', 'scriptsVhostPoolIPv4'],
3444
):
3545
vhost_pools[attrs['scriptsVhostName'][0]] = attrs['scriptsVhostPoolIPv4'][0]
3646
pool_ips.add(attrs['scriptsVhostPoolIPv4'][0])
3747

3848
pool_names = {}
3949
for dn, attrs in ll.search_s(
40-
'dc=scripts,dc=mit,dc=edu',
50+
BASE_DN,
4151
ldap.SCOPE_SUBTREE,
42-
'(&(objectClass=scriptsVhostPool)(|'+''.join(ldap.filter.filter_format('(scriptsVhostPoolIPv4=%s)', [ip]) for ip in pool_ips)+'))',
52+
'(&(objectClass=scriptsVhostPool)(|'+''.join(
53+
ldap.filter.filter_format('(scriptsVhostPoolIPv4=%s)', [ip])
54+
for ip in pool_ips
55+
)+'))',
4356
['cn', 'scriptsVhostPoolIPv4'],
4457
):
4558
pool_names[attrs['scriptsVhostPoolIPv4'][0]] = attrs['cn'][0]
4659

4760
main_pool = vhost_pools.get(username + '.scripts.mit.edu')
4861
other_pools = None
4962
if len(pool_ips) > 1:
50-
other_pools = sorted((pool_names.get(pool, pool), vhost) for vhost, pool in vhost_pools.items())
63+
other_pools = sorted(
64+
(pool_names.get(pool, pool), vhost)
65+
for vhost, pool in vhost_pools.items()
66+
)
5167
return main_pool, other_pools
5268

5369
def should_forward():
70+
"""Check if we were invoked by ssh on a vip that requires forwarding."""
5471
ssh_connection = os.environ.get('SSH_CONNECTION')
5572
if not ssh_connection:
5673
return False
@@ -64,9 +81,15 @@ def should_forward():
6481
return False
6582

6683
def has_pool(ip):
84+
"""Check if the current machine is binding a vip."""
6785
return len(subprocess.check_output(['/sbin/ip', 'addr', 'show', 'to', ip])) > 0
6886

6987
def maybe_forward():
88+
"""
89+
Forward the invocation if appropriate.
90+
91+
exec's when forwarding, so returning means we should run locally.
92+
"""
7093
if not should_forward():
7194
return
7295
command = None
@@ -79,10 +102,11 @@ def maybe_forward():
79102
main_pool, other_pools = get_pool(user)
80103
forward = main_pool and not has_pool(main_pool)
81104
if forward:
82-
# TODO: Check if we're already on the right server.
83105
print("Forwarding to the server for %s.scripts.mit.edu." % (user,), file=sys.stderr)
84106
if other_pools:
85-
print("Your account has virtual hosts on multiple server pools; to connect to a server for a particular host, connect to a specific server:", file=sys.stderr)
107+
print("Your account has virtual hosts on multiple server pools; "
108+
"to connect to a server for a particular host, "
109+
"connect to a specific server:", file=sys.stderr)
86110
print(file=sys.stderr)
87111
for name, vhost in other_pools:
88112
print("%s - ssh %s" % (vhost, name), file=sys.stderr)
@@ -93,6 +117,7 @@ def maybe_forward():
93117
args.append(command)
94118
os.execv('/usr/bin/ssh', args)
95119

96-
maybe_forward()
120+
if __name__ == '__main__':
121+
maybe_forward()
97122

98-
os.execv("@bash_path@", ["bash", "--rcfile", "/usr/local/etc/mbashrc"] + sys.argv[1:])
123+
os.execv("@bash_path@", ["bash", "--rcfile", "/usr/local/etc/mbashrc"] + sys.argv[1:])

0 commit comments

Comments
 (0)