Skip to content

Commit 5b4b2c2

Browse files
committed
Establish connections to multiple LDAP servers and spread requests across them
1 parent 4c8dca2 commit 5b4b2c2

File tree

3 files changed

+117
-32
lines changed

3 files changed

+117
-32
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package ldap
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"sync"
7+
"time"
8+
9+
ldap "gopkg.in/ldap.v3"
10+
)
11+
12+
type conn struct {
13+
server, baseDn string
14+
// mu protects conn during reconnect cycles
15+
// TODO: The ldap package supports multiple in-flight queries;
16+
// by using a Mutex we are only going to issue one at a
17+
// time. We should figure out how to do retry/reconnect
18+
// behavior with parallel queries.
19+
mu sync.Mutex
20+
conn *ldap.Conn
21+
}
22+
23+
func (c *conn) reconnect() {
24+
c.mu.Lock()
25+
defer c.mu.Unlock()
26+
if c.conn != nil {
27+
c.conn.Close()
28+
}
29+
var err error
30+
for {
31+
log.Printf("connecting to %s", c.server)
32+
c.conn, err = ldap.Dial("tcp", c.server)
33+
if err == nil {
34+
return
35+
}
36+
log.Printf("connecting to %s: %v", c.server, err)
37+
time.Sleep(100 * time.Millisecond)
38+
}
39+
}
40+
41+
func (c *conn) resolvePool(hostname string) (string, error) {
42+
c.mu.Lock()
43+
defer c.mu.Unlock()
44+
45+
escapedHostname := ldap.EscapeFilter(hostname)
46+
req := ldap.NewSearchRequest(
47+
c.baseDn,
48+
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
49+
fmt.Sprintf("(|(scriptsVhostName=%s)(scriptsVhostAlias=%s))", escapedHostname, escapedHostname),
50+
[]string{"scriptsVhostPoolIPv4"},
51+
nil,
52+
)
53+
sr, err := c.conn.Search(req)
54+
if err != nil {
55+
return "", err
56+
}
57+
for _, entry := range sr.Entries {
58+
return entry.GetAttributeValue("scriptsVhostPoolIPv4"), nil
59+
}
60+
// Not found is not an error
61+
return "", nil
62+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package ldap
2+
3+
import "log"
4+
5+
type Pool struct {
6+
retries int
7+
// connCh holds open connections to servers.
8+
connCh chan *conn
9+
}
10+
11+
func NewPool(servers []string, baseDn string, retries int) *Pool {
12+
p := &Pool{
13+
retries: retries,
14+
connCh: make(chan *conn, len(servers)),
15+
}
16+
for _, s := range servers {
17+
c := &conn{
18+
server: s,
19+
baseDn: baseDn,
20+
}
21+
go p.reconnect(c)
22+
}
23+
return p
24+
}
25+
26+
func (p *Pool) reconnect(c *conn) {
27+
c.reconnect()
28+
p.connCh <- c
29+
}
30+
31+
func (p *Pool) ResolvePool(hostname string) (string, error) {
32+
var ip string
33+
var err error
34+
for i := 0; i < p.retries; i++ {
35+
c := <-p.connCh
36+
ip, err = c.resolvePool(hostname)
37+
if err == nil {
38+
p.connCh <- c
39+
return ip, err
40+
}
41+
log.Printf("resolving %q on %s: %v", hostname, c.server, err)
42+
go p.reconnect(c)
43+
}
44+
return ip, err
45+
}

server/common/oursrc/scripts-proxy/main.go

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,39 +8,41 @@ import (
88
"net"
99
"strings"
1010

11-
ldap "gopkg.in/ldap.v3"
11+
"github.com/mit-scripts/scripts/server/common/oursrc/scripts-proxy/ldap"
1212
"inet.af/tcpproxy"
1313
)
1414

1515
var (
1616
httpAddrs = flag.String("http_addrs", ":80", "comma-separated addresses to listen for HTTP traffic on")
1717
sniAddrs = flag.String("sni_addrs", ":443,:444", "comma-separated addresses to listen for SNI traffic on")
18-
ldapServer = flag.String("ldap_server", "scripts-ldap.mit.edu:389", "LDAP server to query")
18+
ldapServers = flag.String("ldap_servers", "scripts-ldap.mit.edu:389", "comma-spearated LDAP servers to query")
1919
defaultHost = flag.String("default_host", "scripts.mit.edu", "default host to route traffic to if SNI/Host header cannot be parsed or cannot be found in LDAP")
2020
baseDn = flag.String("base_dn", "ou=VirtualHosts,dc=scripts,dc=mit,dc=edu", "base DN to query for hosts")
2121
localRange = flag.String("local_range", "18.4.86.0/24", "IP block for client IP spoofing")
2222
)
2323

24+
const ldapRetries = 3
25+
2426
func always(context.Context, string) bool {
2527
return true
2628
}
2729

2830
type ldapTarget struct {
2931
localPoolRange *net.IPNet
30-
ldap *ldap.Conn
32+
ldap *ldap.Pool
3133
}
3234

3335
func (l *ldapTarget) HandleConn(netConn net.Conn) {
3436
var pool string
3537
var err error
3638
if conn, ok := netConn.(*tcpproxy.Conn); ok {
37-
pool, err = l.resolvePool(conn.HostName)
39+
pool, err = l.ldap.ResolvePool(conn.HostName)
3840
if err != nil {
3941
log.Printf("resolving %q: %v", conn.HostName, err)
4042
}
4143
}
4244
if pool == "" {
43-
pool, err = l.resolvePool(*defaultHost)
45+
pool, err = l.ldap.ResolvePool(*defaultHost)
4446
if err != nil {
4547
log.Printf("resolving default pool: %v", err)
4648
}
@@ -72,44 +74,20 @@ func (l *ldapTarget) HandleConn(netConn net.Conn) {
7274
dp.HandleConn(netConn)
7375
}
7476

75-
func (l *ldapTarget) resolvePool(hostname string) (string, error) {
76-
escapedHostname := ldap.EscapeFilter(hostname)
77-
req := ldap.NewSearchRequest(
78-
*baseDn,
79-
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
80-
fmt.Sprintf("(|(scriptsVhostName=%s)(scriptsVhostAlias=%s))", escapedHostname, escapedHostname),
81-
[]string{"scriptsVhostPoolIPv4"},
82-
nil,
83-
)
84-
sr, err := l.ldap.Search(req)
85-
if err != nil {
86-
return "", err
87-
}
88-
for _, entry := range sr.Entries {
89-
return entry.GetAttributeValue("scriptsVhostPoolIPv4"), nil
90-
}
91-
// Not found is not an error
92-
return "", nil
93-
}
94-
9577
func main() {
9678
flag.Parse()
9779

98-
l, err := ldap.Dial("tcp", *ldapServer)
99-
if err != nil {
100-
log.Fatal(err)
101-
}
102-
defer l.Close()
103-
10480
_, ipnet, err := net.ParseCIDR(*localRange)
10581
if err != nil {
10682
log.Fatal(err)
10783
}
10884

85+
ldapPool := ldap.NewPool(strings.Split(*ldapServers, ","), *baseDn, ldapRetries)
86+
10987
var p tcpproxy.Proxy
11088
t := &ldapTarget{
11189
localPoolRange: ipnet,
112-
ldap: l,
90+
ldap: ldapPool,
11391
}
11492
for _, addr := range strings.Split(*httpAddrs, ",") {
11593
p.AddHTTPHostMatchRoute(addr, always, t)

0 commit comments

Comments
 (0)