Skip to content

Instantly share code, notes, and snippets.

@emsearcy
Last active April 5, 2021 20:28
Show Gist options
  • Select an option

  • Save emsearcy/cba3295d1a06d4c432ab4f6173b65e4f to your computer and use it in GitHub Desktop.

Select an option

Save emsearcy/cba3295d1a06d4c432ab4f6173b65e4f to your computer and use it in GitHub Desktop.
go-ldap LDAP connection management
type (
	// ldapConnResult is what a factory worker hands out on each receive: the
	// worker's current connection together with the error from its most
	// recent connect attempt. Either field may be nil.
	ldapConnResult struct {
		Conn  *ldap.Conn
		Error error
	}

	// userLdapPool keeps two shared LDAP connections, each owned by its own
	// factory worker goroutine and reached only through channels.
	userLdapPool struct {
		// rwConn yields the connection used for read/write operations;
		// sending on rwConnReset asks its worker to reconnect.
		rwConn      chan ldapConnResult
		rwConnReset chan bool
		// authConn yields the connection used to authenticate users
		// (re-binding as them); sending on authConnReset asks its worker to
		// reconnect.
		authConn      chan ldapConnResult
		authConnReset chan bool
	}
)
var (
	// maxConnAttempts bounds network-error retries. The retry loops stop once
	// their 1-based attempt counter exceeds this value, so an operation is
	// attempted at most maxConnAttempts+1 times.
	maxConnAttempts uint8 = 2

	// attrEscapeChars matches every character that is backslash-escaped in a
	// DN attribute value (RFC 4514 section 2.4): the required specials
	// `"`, `+`, `,`, `;`, `<`, `>`, `\` plus `#` and `=`, which may be escaped
	// anywhere. A raw string is used so that `\\` reaches the regexp engine
	// as a literal backslash; the previous interpreted string turned it into
	// `\#`, which matched only '#' and left backslashes unescaped.
	attrEscapeChars = regexp.MustCompile(`[,\\#+<>;"=]`)
)
// userDB is the active user/group store. userLdapPool implements
// userGroupReaderWriter; other, cache-based mock implementations also exist.
var userDB userGroupReaderWriter
//
// ...
//
// New initialises the pool's channels and starts one ldapFactoryWorker for
// the read/write connection and one for the authentication connection.
func (db *userLdapPool) New() {
	// start creates the channel pair for one worker and launches it.
	start := func() (chan ldapConnResult, chan bool) {
		// Unbuffered: every receive is served directly by the worker with
		// its current state rather than a queued, possibly stale, value.
		results := make(chan ldapConnResult)
		// One slot lets reconnectLDAPConn queue a reset with a
		// non-blocking send.
		resets := make(chan bool, 1)
		go ldapFactoryWorker(resets, results)
		return results, resets
	}

	db.rwConn, db.rwConnReset = start()
	db.authConn, db.authConnReset = start()
}
// Add creates an LDAP entry for attributes on the shared read/write
// connection (db.rwConn).
//
// It returns (true, nil) when the entry was created, (false, nil) when the
// server answers LDAPResultEntryAlreadyExists, and (false, err) for any other
// failure, including failure to obtain or re-establish a connection.
// An add that fails with ldap.ErrorNetwork is retried after reconnecting;
// because the loop only stops once n > maxConnAttempts, at most
// maxConnAttempts+1 add attempts are made.
func (db *userLdapPool) Add(ctx context.Context, attributes *userAttributes) (created bool, err error) {
// ...
// NOTE(review): addRequest is presumably built in the elided section above;
// it is not visible in this excerpt.
// Get shared LDAP connection.
l, err := getLDAPConn(ctx, db.rwConn, db.rwConnReset)
if err != nil {
return false, err
}
var n uint8
for n = 1; ; n++ {
// Each attempt is traced as an X-Ray "LDAP add" subsegment.
_, seg := xray.BeginSubsegment(ctx, "LDAP add")
err = l.Add(addRequest)
xrayFinishSegment(ctx, seg)
if !ldap.IsErrorWithCode(err, ldap.ErrorNetwork) || n > maxConnAttempts {
// Break out of loop for successes, non-network errors, or max
// network-error retries.
break
}
logrus.WithFields(logrus.Fields{
"operation": "add",
"attempt": n,
"error": err,
}).Info("ldap network error, retrying")
// Ask the worker for a fresh connection before retrying; give up if
// reconnecting itself fails.
l, err = reconnectLDAPConn(ctx, db.rwConn, db.rwConnReset)
if err != nil {
return false, err
}
}
if ldap.IsErrorWithCode(err, ldap.LDAPResultEntryAlreadyExists) {
// user already exists
return false, nil
}
if err != nil {
// includes network errors when n > maxConnAttempts
return false, err
}
return true, nil
}
// VerifyPass reports whether password is valid for username. It attempts a
// simple bind as uid=<escaped username>,<cfg.LDAPUserSearchBase> on the
// shared authentication connection (db.authConn).
//
// It returns (true, nil) when the bind succeeds, and (false, nil) when the
// password is empty or the server answers LDAPResultInvalidCredentials. Any
// other failure, including failure to obtain or re-establish a connection,
// yields (false, err). A bind that fails with ldap.ErrorNetwork is retried
// after reconnecting, with at most maxConnAttempts+1 bind attempts in total.
func (db *userLdapPool) VerifyPass(ctx context.Context, username string, password string) (result bool, err error) {
	// A simple bind with a DN and an empty password is an "unauthenticated
	// bind" (RFC 4513 section 5.1.2), which some servers report as success.
	// Never let that count as a valid login.
	if password == "" {
		return false, nil
	}

	// Get shared LDAP connection.
	l, err := getLDAPConn(ctx, db.authConn, db.authConnReset)
	if err != nil {
		return false, err
	}

	dn := fmt.Sprintf("uid=%s,%s", ldapEscapeAttrValue(username), cfg.LDAPUserSearchBase)
	bindRequest := ldap.NewSimpleBindRequest(dn, password, nil)

	var n uint8
	for n = 1; ; n++ {
		_, seg := xray.BeginSubsegment(ctx, "LDAP bind")
		_, err = l.SimpleBind(bindRequest)
		xrayFinishSegment(ctx, seg)
		if !ldap.IsErrorWithCode(err, ldap.ErrorNetwork) || n > maxConnAttempts {
			// Break out of loop for successes, non-network errors, or max
			// network-error retries.
			break
		}
		logrus.WithFields(logrus.Fields{
			"operation": "bind",
			"attempt":   n,
			"error":     err,
		}).Info("ldap network error, retrying")
		// Check the error rather than the connection: a failed connect may
		// still come with a non-nil (closed) connection. This matches Add.
		l, err = reconnectLDAPConn(ctx, db.authConn, db.authConnReset)
		if err != nil {
			return false, err
		}
	}

	if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) {
		// invalid credentials
		return false, nil
	}
	if err != nil {
		// includes network errors when n > maxConnAttempts
		return false, err
	}
	return true, nil
}
// ldapFactoryWorker owns one shared LDAP connection and serves it to any
// number of callers. Every receive on c yields the worker's current
// connection and the error from its most recent connect attempt; the
// connection is nil when that attempt failed. A value received on reset makes
// the worker close the current connection and dial a new one.
//
// The worker loops forever; nothing in this function stops it.
func ldapFactoryWorker(reset <-chan bool, c chan<- ldapConnResult) {
	l, err := ldapConnect()
	if err != nil {
		logrus.WithField("err", err).Warn("initial ldap connection failed")
	}

	// reconnect replaces the current connection (and error) with a fresh
	// connect attempt.
	reconnect := func() {
		if l != nil {
			l.Close()
		}
		logrus.Infoln("ldap reconnecting")
		l, err = ldapConnect()
	}

	for {
		// Honour an already-queued reset before offering the connection
		// again. In a single two-case select, Go picks randomly among ready
		// cases, so the caller that just queued a reset (and is now waiting
		// on c) could be handed the stale connection it is trying to
		// replace. This closes that window for resets queued before the
		// worker loops around. A reset that arrives between this check and
		// the select below can still race.
		select {
		case <-reset:
			reconnect()
			continue
		default:
		}

		select {
		case <-reset:
			reconnect()
		case c <- ldapConnResult{Conn: l, Error: err}:
		}
	}
}
// getLDAPConn receives the current ldapConnResult from a factory worker's
// channel c and returns its connection and error separately. When the worker
// currently has no connection (Conn is nil, e.g. after a failed connect), it
// requests a single reconnect through reset and returns whatever that yields.
func getLDAPConn(ctx context.Context, c <-chan ldapConnResult, reset chan<- bool) (*ldap.Conn, error) {
	current := <-c
	if current.Conn == nil {
		return reconnectLDAPConn(ctx, c, reset)
	}
	return current.Conn, current.Error
}
// reconnectLDAPConn asks the factory worker behind c/reset to replace its
// connection, then waits for the worker's next ldapConnResult and returns its
// connection and error separately.
//
// The reset request is a non-blocking send. If it cannot proceed immediately
// (for example because a reset is already queued), no extra request is made
// and this caller just waits alongside the others. The whole call is traced
// as an X-Ray "LDAP connect" subsegment.
func reconnectLDAPConn(ctx context.Context, c <-chan ldapConnResult, reset chan<- bool) (*ldap.Conn, error) {
	_, seg := xray.BeginSubsegment(ctx, "LDAP connect")

	requested := false
	select {
	case reset <- true:
		requested = true
	default:
	}
	if requested {
		logrus.Debugln("attempting connection")
	} else {
		logrus.Debugln("already connecting")
	}

	// Blocks until the worker hands out its next result, normally the
	// freshly dialed connection.
	next := <-c
	xrayFinishSegment(ctx, seg)
	return next.Conn, next.Error
}
// ldapConnect dials cfg.LDAPHost:cfg.LDAPPort over a secure transport and
// binds as the configured service account (cfg.LDAPBindDN).
//
// With cfg.LDAPStartTLS it connects in plaintext (usually port 389) and
// upgrades via StartTLS. Otherwise it dials wrapped TLS directly (usually
// port 636). In both cases the certificate is checked against cfg.LDAPHost.
//
// On any failure, a partially opened connection is closed and a nil
// connection is returned with the error. Previously a StartTLS failure
// returned the already-closed connection alongside the error.
func ldapConnect() (*ldap.Conn, error) {
	addr := fmt.Sprintf("%s:%d", cfg.LDAPHost, cfg.LDAPPort)
	tlsConfig := &tls.Config{ServerName: cfg.LDAPHost}

	var (
		l   *ldap.Conn
		err error
	)
	if cfg.LDAPStartTLS {
		// connect plain first, then upgrade the same connection to TLS
		l, err = ldap.Dial("tcp", addr)
		if err != nil {
			return nil, err
		}
		if err = l.StartTLS(tlsConfig); err != nil {
			l.Close()
			return nil, err
		}
	} else {
		// the other secure alternative is wrapped SSL/TLS
		l, err = ldap.DialTLS("tcp", addr, tlsConfig)
		if err != nil {
			return nil, err
		}
	}

	// now bind with our service user
	if err = l.Bind(cfg.LDAPBindDN, cfg.LDAPBindPass); err != nil {
		l.Close()
		return nil, err
	}
	return l, nil
}
// ldapEscapeAttrValue escapes in for use as an attribute value inside a DN
// string (RFC 4514 section 2.4). It does the following:
//   - backslash-escapes `"`, `+`, `,`, `;`, `<`, `>` and `\`, plus `#` and
//     `=` (escaping those is allowed anywhere);
//   - encodes NUL as `\00`;
//   - escapes a leading or trailing space.
//
// The previous version did not escape backslash or NUL, and turned a single
// space into `\\ ` (an escaped backslash followed by a bare trailing space).
//
// Iterating bytes is safe for UTF-8 input: every character handled here is
// ASCII, and bytes of multi-byte sequences are all >= 0x80.
func ldapEscapeAttrValue(in string) string {
	var b strings.Builder
	b.Grow(len(in) + 2)
	last := len(in) - 1
	for i := 0; i < len(in); i++ {
		c := in[i]
		switch {
		case c == 0:
			b.WriteString(`\00`)
		case strings.IndexByte(`"+,;<>\#=`, c) >= 0,
			c == ' ' && (i == 0 || i == last):
			b.WriteByte('\\')
			b.WriteByte(c)
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment