Last active: April 5, 2021 20:28
-
-
Save emsearcy/cba3295d1a06d4c432ab4f6173b65e4f to your computer and use it in GitHub Desktop.
go-ldap LDAP connection management
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ldapConnResult is the value ldapFactoryWorker hands out on its connection
// channel: the worker's current connection together with the error from the
// most recent attempt to establish it.
type ldapConnResult struct {
	Conn  *ldap.Conn // current connection; may be nil when the last connect attempt failed
	Error error      // error from the last connect attempt, or nil
}
| type userLdapPool struct { | |
| // maintain two "pool" connections managed through channels | |
| // one for read/write operations and another for authentication (re-binding) | |
| rwConn chan ldapConnResult | |
| rwConnReset chan bool | |
| authConn chan ldapConnResult | |
| authConnReset chan bool | |
| } | |
var (
	// maxConnAttempts is how many times an LDAP operation is retried after a
	// network error. The retry loops in Add and VerifyPass attempt the
	// operation at most maxConnAttempts+1 times in total.
	maxConnAttempts uint8 = 2

	// attrEscapeChars matches the characters that get a backslash prefix when
	// a value is embedded in a DN. The backslash itself must be in the set:
	// otherwise an input such as `x\` would escape the following comma and
	// change the structure of the DN. A raw string is used so the regex reads
	// exactly as written. (In the old "[,\\#+<>;\"=]" literal, `\#` was only
	// an escaped '#', and the backslash was never matched.)
	attrEscapeChars = regexp.MustCompile(`[,\\#+<>;"=]`)
)
// userDB is the package-level user/group store. *userLdapPool implements
// userGroupReaderWriter. Per the original author, other (cache-based mock)
// implementations exist as well.
var userDB userGroupReaderWriter
| // | |
| // ... | |
| // | |
| func (db *userLdapPool) New() { | |
| db.rwConn = make(chan ldapConnResult) | |
| db.rwConnReset = make(chan bool, 1) | |
| go ldapFactoryWorker(db.rwConnReset, db.rwConn) | |
| db.authConn = make(chan ldapConnResult) | |
| db.authConnReset = make(chan bool, 1) | |
| go ldapFactoryWorker(db.authConnReset, db.authConn) | |
| } | |
// Add creates an LDAP entry for attributes using the shared read/write
// connection.
//
// Returns:
//   - (true, nil) when the entry was added;
//   - (false, nil) when the server reports entryAlreadyExists;
//   - (false, err) for any other failure.
//
// Network errors are retried on a freshly requested connection up to
// maxConnAttempts times before the error is returned.
func (db *userLdapPool) Add(ctx context.Context, attributes *userAttributes) (created bool, err error) {
	// ...
	// NOTE(review): addRequest is built in the elided section above,
	// presumably from attributes; confirm.
	// Get shared LDAP connection.
	l, err := getLDAPConn(ctx, db.rwConn, db.rwConnReset)
	if err != nil {
		return false, err
	}
	var n uint8
	for n = 1; ; n++ {
		_, seg := xray.BeginSubsegment(ctx, "LDAP add")
		err = l.Add(addRequest)
		xrayFinishSegment(ctx, seg)
		if !ldap.IsErrorWithCode(err, ldap.ErrorNetwork) || n > maxConnAttempts {
			// Break out of loop for successes, non-network errors, or max
			// network-error retries.
			break
		}
		logrus.WithFields(logrus.Fields{
			"operation": "add",
			"attempt":   n,
			"error":     err,
		}).Info("ldap network error, retrying")
		// Ask the worker to replace the broken connection and retry the add
		// on whatever it hands back.
		l, err = reconnectLDAPConn(ctx, db.rwConn, db.rwConnReset)
		if err != nil {
			return false, err
		}
	}
	if ldap.IsErrorWithCode(err, ldap.LDAPResultEntryAlreadyExists) {
		// user already exists
		return false, nil
	}
	if err != nil {
		// includes network errors when n > maxConnAttempts
		return false, err
	}
	return true, nil
}
| func (db *userLdapPool) VerifyPass(ctx context.Context, username string, password string) (result bool, err error) { | |
| // Get shared LDAP connection. | |
| l, err := getLDAPConn(ctx, db.authConn, db.authConnReset) | |
| if err != nil { | |
| return false, err | |
| } | |
| dn := fmt.Sprintf("uid=%s,%s", ldapEscapeAttrValue(username), cfg.LDAPUserSearchBase) | |
| bindRequest := ldap.NewSimpleBindRequest(dn, password, nil) | |
| var n uint8 | |
| for n = 1; ; n++ { | |
| _, seg := xray.BeginSubsegment(ctx, "LDAP bind") | |
| _, err = l.SimpleBind(bindRequest) | |
| xrayFinishSegment(ctx, seg) | |
| if !ldap.IsErrorWithCode(err, ldap.ErrorNetwork) || n > maxConnAttempts { | |
| // Break out of loop for successes, non-network errors, or max | |
| // network-error retries. | |
| break | |
| } | |
| logrus.WithFields(logrus.Fields{ | |
| "operation": "bind", | |
| "attempt": n, | |
| "error": err, | |
| }).Info("ldap network error, retrying") | |
| l, err = reconnectLDAPConn(db.authConn, db.authConnReset) | |
| if l == nil { | |
| return false, err | |
| } | |
| } | |
| if ldap.IsErrorWithCode(err, ldap.LDAPResultInvalidCredentials) { | |
| // invalid credentials | |
| return false, nil | |
| } | |
| if err != nil { | |
| // includes network errors when n > maxConnAttempts | |
| return false, err | |
| } | |
| return true, nil | |
| } | |
| func ldapFactoryWorker(reset <-chan bool, c chan<- ldapConnResult) { | |
| var err error | |
| var l *ldap.Conn | |
| l, err = ldapConnect() | |
| if err != nil { | |
| logrus.WithField("err", err).Warn("initial ldap connection failed") | |
| } | |
| for { | |
| select { | |
| case _ = <-reset: | |
| if l != nil { | |
| l.Close() | |
| } | |
| // reconnect | |
| logrus.Infoln("ldap reconnecting") | |
| l, err = ldapConnect() | |
| case c <- ldapConnResult{l, err}: | |
| } | |
| } | |
| } | |
| // getLDAPConn is a small helper to break out the result and error from our | |
| // ldapResultConn channel. If the returned connection is nil, it attempts a | |
| // single reconnect. | |
| func getLDAPConn(ctx context.Context, c <-chan ldapConnResult, reset chan<- bool) (*ldap.Conn, error) { | |
| r := <-c | |
| if r.Conn != nil { | |
| return r.Conn, r.Error | |
| } | |
| return reconnectLDAPConn(ctx, c, reset) | |
| } | |
| // reconnectLDAPConn resets the connection and attempts to return a new | |
| // connection (breaking out the result and error from the ldapResultConn | |
| // channel). | |
| func reconnectLDAPConn(ctx context.Context, c <-chan ldapConnResult, reset chan<- bool) (*ldap.Conn, error) { | |
| _, seg := xray.BeginSubsegment(ctx, "LDAP connect") | |
| select { | |
| // Send non-blocking connection reset request. | |
| case reset <- true: | |
| logrus.Debugln("attempting connection") | |
| default: | |
| logrus.Debugln("already connecting") | |
| } | |
| // Fetch client (blocks on pending reconnection). | |
| result := <-c | |
| xrayFinishSegment(ctx, seg) | |
| return result.Conn, result.Error | |
| } | |
| func ldapConnect() (l *ldap.Conn, err error) { | |
| if cfg.LDAPStartTLS { | |
| // connect plain first (usually to port 389) | |
| l, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", cfg.LDAPHost, cfg.LDAPPort)) | |
| if err != nil { | |
| return | |
| } | |
| // reconnect with TLS | |
| err = l.StartTLS(&tls.Config{ServerName: cfg.LDAPHost}) | |
| if err != nil { | |
| l.Close() | |
| return | |
| } | |
| } else { | |
| // the other secure alternative is wrapped SSL (usually 636) | |
| l, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", cfg.LDAPHost, cfg.LDAPPort), &tls.Config{ServerName: cfg.LDAPHost}) | |
| if err != nil { | |
| return | |
| } | |
| } | |
| // now bind with our user | |
| err = l.Bind(cfg.LDAPBindDN, cfg.LDAPBindPass) | |
| if err != nil { | |
| l.Close() | |
| return nil, err | |
| } | |
| return | |
| } | |
// ldapEscapeAttrValue escapes in for use as an attribute value inside a DN,
// such as the uid in "uid=<value>,<base>", following RFC 4514 section 2.4:
//
//   - the characters  , \ # + < > ; " =  get a backslash prefix (RFC 4514
//     permits escaping '#' and '=' anywhere, not only where required);
//   - a NUL byte is written as the hex escape \00;
//   - a leading or trailing space is written as "\ ".
//
// Escaping the backslash matters: without it, an input such as `x\` would
// escape the following comma and change the structure of the resulting DN.
// A single-space input correctly becomes `\ `.
func ldapEscapeAttrValue(in string) string {
	var b strings.Builder
	b.Grow(len(in))
	last := len(in) - 1
	// Byte-wise iteration is safe for UTF-8: every byte of a multi-byte
	// sequence is >= 0x80 and so never matches the ASCII checks below.
	for i := 0; i < len(in); i++ {
		c := in[i]
		switch {
		case c == 0:
			b.WriteString(`\00`)
		case strings.IndexByte(`,\#+<>;"=`, c) >= 0:
			b.WriteByte('\\')
			b.WriteByte(c)
		case c == ' ' && (i == 0 || i == last):
			b.WriteString(`\ `)
		default:
			b.WriteByte(c)
		}
	}
	return b.String()
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.