// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "fmt" ber "github.com/go-asn1-ber/asn1-ber" ) // AddMessage is an add request message type AddMessage struct { baseMessage // DN identifies the entry being added DN string // Attributes list the attributes of the new entry Attributes []Attribute // Controls hold optional controls to send with the request Controls []Control } // Attribute represents an LDAP attribute within AddMessage type Attribute struct { // Type is the name of the LDAP attribute Type string // Vals are the LDAP attribute values Vals []string } func (a *Attribute) encode() *ber.Packet { seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type")) set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") for _, value := range a.Vals { set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) } seq.AppendChild(set) return seq } func decodeAttribute(berPacket *ber.Packet) (*Attribute, error) { const op = "gldap.decodeAttribute" const ( childType = 0 childVals = 1 childControls = 2 ) if berPacket == nil { return nil, fmt.Errorf("%s: missing ber packet: %w", op, ErrInvalidParameter) } var decodedAttribute Attribute seq := &packet{ Packet: berPacket, } if err := seq.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence)); err != nil { return nil, fmt.Errorf("%s: missing/invalid attributes ber packet: %w", op, ErrInvalidParameter) } if err := seq.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childType)); err != nil { return nil, fmt.Errorf("%s: missing/invalid attributes type: %w", op, ErrInvalidParameter) } decodedAttribute.Type = seq.Children[childType].Data.String() if err := seq.assert(ber.ClassUniversal, ber.TypeConstructed, 
withTag(ber.TagSet), withAssertChild(childVals)); err != nil { return nil, fmt.Errorf("%s: missing/invalid attributes values: %w", op, ErrInvalidParameter) } valuesPacket := &packet{ Packet: seq.Children[childVals], } decodedAttribute.Vals = make([]string, 0, len(valuesPacket.Children)) for idx := range valuesPacket.Children { if err := valuesPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(idx)); err != nil { return nil, fmt.Errorf("%s: invalid attribute values packet: %w", op, err) } decodedAttribute.Vals = append(decodedAttribute.Vals, valuesPacket.Children[idx].Data.String()) } return &decodedAttribute, nil }
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT

package gldap

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"
	"time"

	ber "github.com/go-asn1-ber/asn1-ber"
	"github.com/hashicorp/go-hclog"
)

// conn is a connection to an ldap client
type conn struct {
	// mu is held while reading a packet (readPacket) and while (re)assigning
	// netConn/reader/writer (initConn), so those two never overlap.
	mu sync.Mutex // mutex for the conn

	connID      int
	netConn     net.Conn
	logger      hclog.Logger
	router      *Mux
	shutdownCtx context.Context
	// requestsWg counts request handlers running in their own goroutines;
	// close waits on it before closing netConn.
	requestsWg sync.WaitGroup
	reader     *bufio.Reader
	writer     *bufio.Writer
	writerMu   sync.Mutex // shared lock across all ResponseWriter's to prevent write data races
}

// newConn will create a new Conn from an accepted net.Conn which will be used
// to serve requests to an ldap client.
//
// Every argument is required (connID must be non-zero); a missing one returns
// an error wrapping ErrInvalidParameter. The buffered reader/writer are set up
// by initConn.
func newConn(shutdownCtx context.Context, connID int, netConn net.Conn, logger hclog.Logger, router *Mux) (*conn, error) {
	const op = "gldap.NewConn"
	if shutdownCtx == nil {
		return nil, fmt.Errorf("%s: missing shutdown context: %w", op, ErrInvalidParameter)
	}
	if connID == 0 {
		return nil, fmt.Errorf("%s: missing connection id: %w", op, ErrInvalidParameter)
	}
	if netConn == nil {
		return nil, fmt.Errorf("%s: missing connection: %w", op, ErrInvalidParameter)
	}
	if logger == nil {
		return nil, fmt.Errorf("%s: missing logger: %w", op, ErrInvalidParameter)
	}
	if router == nil {
		return nil, fmt.Errorf("%s: missing router: %w", op, ErrInvalidParameter)
	}
	c := &conn{
		connID:      connID,
		netConn:     netConn,
		shutdownCtx: shutdownCtx,
		logger:      logger,
		router:      router,
	}
	if err := c.initConn(netConn); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	return c, nil
}

// serveRequests until the connection is closed or the shutdownCtx is cancelled
// as the server stops.
//
// Each loop iteration assigns the next request ID (starting at 1), creates a
// ResponseWriter sharing c.writerMu, and then:
//   - if shutdownCtx is already done: writes an unsolicited "server stopping"
//     disconnection response, sets a 1ms read deadline and returns nil;
//   - otherwise reads the next request (EOF-style errors end the loop with nil);
//   - an unbind request runs the optional unbind route and returns nil;
//   - a StartTLS request is served synchronously on this goroutine;
//   - any other request is served in its own goroutine tracked by requestsWg.
func (c *conn) serveRequests() error {
	const op = "gldap.serveRequests"

	requestID := 0
	for {
		requestID++
		w, err := newResponseWriter(c.writer, &c.writerMu, c.logger, c.connID, requestID)
		if err != nil {
			return fmt.Errorf("%s: %w", op, err)
		}

		// Non-blocking shutdown check: the default branch lets the loop
		// continue on to the (blocking) read below.
		select {
		case <-c.shutdownCtx.Done():
			c.logger.Debug("received shutdown cancellation", "op", op, "conn", c.connID, "requestID", w.requestID)
			// build a request by hand, since this is not a normal situation
			// where we've read a request... and we need to make this check
			// before blocking on reading the next request.
			req := &Request{
				ID:           w.requestID,
				conn:         c,
				message:      &ExtendedOperationMessage{baseMessage: baseMessage{id: 0}},
				routeOp:      routeOperation(ExtendedOperationDisconnection),
				extendedName: ExtendedOperationDisconnection,
			}
			resp := req.NewResponse(WithResponseCode(ResultUnwillingToPerform), WithDiagnosticMessage("server stopping"))
			if err := w.Write(resp); err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
			// NOTE(review): presumably this near-immediate deadline makes any
			// further read on the conn fail fast during shutdown; confirm.
			if err := c.netConn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
			return nil
		default:
			// need a default to fall through to rest of loop...
		}
		r, err := c.readRequest(w.requestID)
		if err != nil {
			// The string match catches "unexpected EOF" errors that are not
			// wrapped as io.ErrUnexpectedEOF.
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || strings.Contains(err.Error(), "unexpected EOF") {
				return nil // connection is closed
			}
			return fmt.Errorf("%s: error reading request: %w", op, err)
		}

		switch {
		// TODO: rate limit in-flight requests per conn and send a
		// BusyResponse when the limit is reached.  This limit per conn
		// should be configurable

		case r.routeOp == unbindRouteOperation:
			// support an optional unbind route
			if c.router.unbindRoute != nil {
				c.router.unbindRoute.handler()(w, r)
			}
			// stop serving requests when UnbindRequest is received
			return nil

		// If it's a StartTLS request, then we can't dispatch it concurrently,
		// since the conn needs to complete its TLS negotiation before handling
		// any other requests.
		// see: https://datatracker.ietf.org/doc/html/rfc4511#section-4.14.1
		case r.extendedName == ExtendedOperationStartTLS:
			c.router.serve(w, r)
		default:
			// Add before starting the goroutine so close() cannot observe a
			// zero count while this handler is about to run.
			c.requestsWg.Add(1)
			go func() {
				defer func() {
					c.logger.Debug("requestsWg done", "op", op, "conn", c.connID, "requestID", w.requestID)
					c.requestsWg.Done()
				}()
				c.router.serve(w, r)
			}()
		}
	}
}

// readRequest reads the next packet from the connection and converts it into
// a Request with the given requestID. Errors are wrapped with %w, so callers
// can still match io.EOF from the underlying read.
func (c *conn) readRequest(requestID int) (*Request, error) {
	const op = "gldap.(Conn).readRequest"

	p, err := c.readPacket(requestID)
	if err != nil {
		return nil, fmt.Errorf("%s: error reading packet for %d/%d: %w", op, c.connID, requestID, err)
	}

	r, err := newRequest(requestID, c, p)
	if err != nil {
		return nil, fmt.Errorf("%s: unable to create new in-memory request for %d/%d: %w", op, c.connID, requestID, err)
	}

	return r, nil
}

// readPacket reads one BER packet from c.reader while holding c.mu. It logs
// the packet when debug logging is enabled and runs basicValidation on it
// before returning.
func (c *conn) readPacket(requestID int) (*packet, error) {
	const op = "gldap.readPacket"
	// read a request
	berPacket, err := func() (*ber.Packet, error) {
		c.mu.Lock()
		defer c.mu.Unlock()
		berPacket, err := ber.ReadPacket(c.reader)
		switch {
		// NOTE(review): this specific decode error is treated as a hint that
		// the client sent a TLS handshake to a plain-text listener.
		case err != nil && strings.Contains(err.Error(), "invalid character for IA5String at pos 2"):
			return nil, fmt.Errorf("%s: error reading ber packet for %d/%d (possible attempt to use TLS with a non-TLS server): %w", op, c.connID, requestID, err)
		case err != nil:
			return nil, fmt.Errorf("%s: error reading ber packet for %d/%d: %w", op, c.connID, requestID, err)
		}
		return berPacket, nil
	}()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}

	p := &packet{Packet: berPacket}
	if c.logger.IsDebug() {
		c.logger.Debug("packet read", "op", op, "conn", c.connID, "requestID", requestID)
		p.Log(c.logger.StandardWriter(&hclog.StandardLoggerOptions{}), 0, false)
	}

	// Simple header is first... let's make sure it's an ldap packet with 2
	// children containing:
	//		[0] is a message ID
	//		[1] is a request header
	if err := p.basicValidation(); err != nil {
		return nil, fmt.Errorf("%s: failed validation: %w", op, err)
	}
	return p, nil
}

// initConn (re)binds the conn to netConn under c.mu, creating fresh buffered
// reader and writer over it. A nil netConn returns an error wrapping
// ErrInvalidParameter.
func (c *conn) initConn(netConn net.Conn) error {
	const op = "gldap.(Conn).initConn"
	if netConn == nil {
		return fmt.Errorf("%s: missing net conn: %w", op, ErrInvalidParameter)
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.netConn = netConn
	c.reader = bufio.NewReader(c.netConn)
	c.writer = bufio.NewWriter(c.netConn)
	return nil
}

// close waits for all in-flight request handler goroutines started by
// serveRequests to finish, then closes the underlying net.Conn.
func (c *conn) close() error {
	const op = "gldap.(Conn).close"
	c.requestsWg.Wait()
	if err := c.netConn.Close(); err != nil {
		return fmt.Errorf("%s: error closing conn: %w", op, err)
	}
	return nil
}
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "fmt" "strconv" ber "github.com/go-asn1-ber/asn1-ber" ) const ( // ControlTypePaging - https://www.ietf.org/rfc/rfc2696.txt ControlTypePaging = "1.2.840.113556.1.4.319" // ControlTypeBeheraPasswordPolicy - https://tools.ietf.org/html/draft-behera-ldap-password-policy-10 ControlTypeBeheraPasswordPolicy = "1.3.6.1.4.1.42.2.27.8.5.1" // ControlTypeVChuPasswordMustChange - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 ControlTypeVChuPasswordMustChange = "2.16.840.1.113730.3.4.4" // ControlTypeVChuPasswordWarning - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5" // ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296 ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2" // ControlTypeWhoAmI - https://tools.ietf.org/html/rfc4532 ControlTypeWhoAmI = "1.3.6.1.4.1.4203.1.11.3" // ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528" // ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417" // ControlTypeMicrosoftServerLinkTTL - https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN ControlTypeMicrosoftServerLinkTTL = "1.2.840.113556.1.4.2309" ) // ControlTypeMap maps controls to text descriptions var ControlTypeMap = map[string]string{ ControlTypePaging: "Paging", ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft", ControlTypeManageDsaIT: "Manage DSA IT", ControlTypeMicrosoftNotification: "Change Notification - Microsoft", ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft", ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft", } // Ldap Behera 
Password Policy Draft 10 (https://tools.ietf.org/html/draft-behera-ldap-password-policy-10) const ( BeheraPasswordExpired = 0 BeheraAccountLocked = 1 BeheraChangeAfterReset = 2 BeheraPasswordModNotAllowed = 3 BeheraMustSupplyOldPassword = 4 BeheraInsufficientPasswordQuality = 5 BeheraPasswordTooShort = 6 BeheraPasswordTooYoung = 7 BeheraPasswordInHistory = 8 ) // BeheraPasswordPolicyErrorMap contains human readable descriptions of Behera Password Policy error codes var BeheraPasswordPolicyErrorMap = map[int8]string{ BeheraPasswordExpired: "Password expired", BeheraAccountLocked: "Account locked", BeheraChangeAfterReset: "Password must be changed", BeheraPasswordModNotAllowed: "Policy prevents password modification", BeheraMustSupplyOldPassword: "Policy requires old password in order to change password", BeheraInsufficientPasswordQuality: "Password fails quality checks", BeheraPasswordTooShort: "Password is too short for policy", BeheraPasswordTooYoung: "Password has been changed too recently", BeheraPasswordInHistory: "New password is in list of old passwords", } // Control defines a common interface for all ldap controls type Control interface { // GetControlType returns the OID GetControlType() string // Encode returns the ber packet representation Encode() *ber.Packet // String returns a human-readable description String() string } func encodeControls(controls []Control) *ber.Packet { packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls") for _, control := range controls { packet.AppendChild(control.Encode()) } return packet } func decodeControl(packet *ber.Packet) (Control, error) { const op = "gldap.decodeControl" var ( ControlType = "" Criticality = false value *ber.Packet ) if packet == nil { return nil, fmt.Errorf("%s: packet is nil: %w", op, ErrInvalidParameter) } switch len(packet.Children) { case 0: // at least one child is required for a control type return nil, fmt.Errorf("%s: at least one child is required for control type", 
op) case 1: // just type, no critically or value packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" ControlType = packet.Children[0].Value.(string) case 2: packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" ControlType = packet.Children[0].Value.(string) // Children[1] could be criticality or value (both are optional) // duck-type on whether this is a boolean if _, ok := packet.Children[1].Value.(bool); ok { packet.Children[1].Description = "Criticality" Criticality = packet.Children[1].Value.(bool) } else { packet.Children[1].Description = "Control Value" value = packet.Children[1] } case 3: packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" ControlType = packet.Children[0].Value.(string) packet.Children[1].Description = "Criticality" Criticality = packet.Children[1].Value.(bool) packet.Children[2].Description = "Control Value" value = packet.Children[2] default: // more than 3 children is invalid return nil, fmt.Errorf("%s: more than 3 children is invalid for controls", op) } switch ControlType { case ControlTypeManageDsaIT: return NewControlManageDsaIT(WithCriticality(Criticality)) case ControlTypePaging: if value == nil { return new(ControlPaging), nil } value.Description += " (Paging)" c := new(ControlPaging) if value.Value != nil { valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) if err != nil { return nil, fmt.Errorf("%s, failed to decode data bytes: %w", op, err) } value.Data.Truncate(0) value.Value = nil value.AppendChild(valueChildren) } if len(value.Children) < 1 { return nil, fmt.Errorf("%s: paging control value must have a least 1 child: %w", op, ErrInvalidParameter) } value = value.Children[0] value.Description = "Search Control Value" value.Children[0].Description = "Paging Size" value.Children[1].Description = "Cookie" c.PagingSize = uint32(value.Children[0].Value.(int64)) c.Cookie = value.Children[1].Data.Bytes() 
value.Children[1].Value = c.Cookie return c, nil case ControlTypeBeheraPasswordPolicy: if value == nil { c, err := NewControlBeheraPasswordPolicy() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } return c, nil } value.Description += " (Password Policy - Behera)" c, err := NewControlBeheraPasswordPolicy() if err != nil { return nil, fmt.Errorf("%s: failed to create behera password control", op) } if value.Value != nil { valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) if err != nil { return nil, fmt.Errorf("%s: failed to decode data bytes: %w", op, err) } value.Data.Truncate(0) value.Value = nil value.AppendChild(valueChildren) } if len(value.Children) == 0 { return nil, fmt.Errorf("%s: behera control value must have a least 1 child: %w", op, ErrInvalidParameter) } sequence := value.Children[0] for _, child := range sequence.Children { if child.Tag == 0 { // Warning warningPacket := child.Children[0] val, err := ber.ParseInt64(warningPacket.Data.Bytes()) if err != nil { return nil, fmt.Errorf("%s: failed to decode data bytes: %w", op, err) } if warningPacket.Tag == 0 { // timeBeforeExpiration c.expire = val warningPacket.Value = c.expire } else if warningPacket.Tag == 1 { // graceAuthNsRemaining c.grace = val warningPacket.Value = c.grace } } else if child.Tag == 1 { // Error bs := child.Data.Bytes() if len(bs) != 1 || bs[0] > 8 { return nil, fmt.Errorf("%s: failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value", op) } val := int8(bs[0]) c.error = val child.Value = c.error c.errorString = BeheraPasswordPolicyErrorMap[c.error] } } return c, nil case ControlTypeVChuPasswordMustChange: c := &ControlVChuPasswordMustChange{MustChange: true} return c, nil case ControlTypeVChuPasswordWarning: if value == nil { return &ControlVChuPasswordWarning{Expire: -1}, nil } c := &ControlVChuPasswordWarning{Expire: -1} expireStr := ber.DecodeString(value.Data.Bytes()) expire, err := strconv.ParseInt(expireStr, 10, 64) if err != nil { 
return nil, fmt.Errorf("%s: failed to parse value as int: %w", op, err) } c.Expire = expire value.Value = c.Expire return c, nil case ControlTypeMicrosoftNotification: return NewControlMicrosoftNotification() case ControlTypeMicrosoftShowDeleted: return NewControlMicrosoftShowDeleted() case ControlTypeMicrosoftServerLinkTTL: return NewControlMicrosoftServerLinkTTL() default: c := new(ControlString) c.ControlType = ControlType c.Criticality = Criticality if value != nil { c.ControlValue = value.Value.(string) } return c, nil } } // ControlString implements the Control interface for simple controls type ControlString struct { ControlType string Criticality bool ControlValue string } // GetControlType returns the OID func (c *ControlString) GetControlType() string { return c.ControlType } // Encode returns the ber packet representation func (c *ControlString) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")")) if c.Criticality { packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality")) } if c.ControlValue != "" { packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(c.ControlValue), "Control Value")) } return packet } // String returns a human-readable description func (c *ControlString) String() string { return fmt.Sprintf("Control Type: %s (%q) Criticality: %t Control Value: %s", ControlTypeMap[c.ControlType], c.ControlType, c.Criticality, c.ControlValue) } // NewControlString returns a generic control. 
Options supported: // WithCriticality and WithControlValue func NewControlString(controlType string, opt ...Option) (*ControlString, error) { const op = "gldap.NewControlString" if controlType == "" { return nil, fmt.Errorf("%s: missing control type: %w", op, ErrInvalidParameter) } opts := getControlOpts(opt...) return &ControlString{ ControlType: controlType, Criticality: opts.withCriticality, ControlValue: opts.withControlValue, }, nil } // ControlManageDsaIT implements the control described in https://tools.ietf.org/html/rfc3296 type ControlManageDsaIT struct { // Criticality indicates if this control is required Criticality bool } // Encode returns the ber packet representation func (c *ControlManageDsaIT) Encode() *ber.Packet { // FIXME packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeManageDsaIT, "Control Type ("+ControlTypeMap[ControlTypeManageDsaIT]+")")) if c.Criticality { packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality")) } return packet } // GetControlType returns the OID func (c *ControlManageDsaIT) GetControlType() string { return ControlTypeManageDsaIT } // String returns a human-readable description func (c *ControlManageDsaIT) String() string { return fmt.Sprintf( "Control Type: %s (%q) Criticality: %t", ControlTypeMap[ControlTypeManageDsaIT], ControlTypeManageDsaIT, c.Criticality) } // NewControlManageDsaIT returns a ControlManageDsaIT control. Supported // options: WithCriticality func NewControlManageDsaIT(opt ...Option) (*ControlManageDsaIT, error) { opts := getControlOpts(opt...) 
return &ControlManageDsaIT{Criticality: opts.withCriticality}, nil } // ControlMicrosoftNotification implements the control described in https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx type ControlMicrosoftNotification struct{} // GetControlType returns the OID func (c *ControlMicrosoftNotification) GetControlType() string { return ControlTypeMicrosoftNotification } // Encode returns the ber packet representation func (c *ControlMicrosoftNotification) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftNotification, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftNotification]+")")) return packet } // String returns a human-readable description func (c *ControlMicrosoftNotification) String() string { return fmt.Sprintf( "Control Type: %s (%q)", ControlTypeMap[ControlTypeMicrosoftNotification], ControlTypeMicrosoftNotification) } // NewControlMicrosoftNotification returns a ControlMicrosoftNotification // control. No options are currently supported. 
func NewControlMicrosoftNotification(_ ...Option) (*ControlMicrosoftNotification, error) {
	ctrl := new(ControlMicrosoftNotification)
	return ctrl, nil
}

// ControlMicrosoftServerLinkTTL implements the control described in https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
// It has no fields: only the control type OID is ever encoded.
type ControlMicrosoftServerLinkTTL struct{}

// GetControlType reports the OID of the server link TTL control.
func (*ControlMicrosoftServerLinkTTL) GetControlType() string {
	return ControlTypeMicrosoftServerLinkTTL
}

// Encode builds a universal SEQUENCE holding just the control type OID as an
// OCTET STRING (no criticality, no value).
func (*ControlMicrosoftServerLinkTTL) Encode() *ber.Packet {
	oid := ControlTypeMicrosoftServerLinkTTL
	typePacket := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, oid, "Control Type ("+ControlTypeMap[oid]+")")
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(typePacket)
	return control
}

// String renders the control's description and quoted OID.
func (*ControlMicrosoftServerLinkTTL) String() string {
	oid := ControlTypeMicrosoftServerLinkTTL
	return fmt.Sprintf("Control Type: %s (%q)", ControlTypeMap[oid], oid)
}

// NewControlMicrosoftServerLinkTTL returns a ControlMicrosoftServerLinkTTL
// control. No options are currently supported.
func NewControlMicrosoftServerLinkTTL(_ ...Option) (*ControlMicrosoftServerLinkTTL, error) { return &ControlMicrosoftServerLinkTTL{}, nil } // ControlMicrosoftShowDeleted implements the control described in https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx type ControlMicrosoftShowDeleted struct{} // GetControlType returns the OID func (c *ControlMicrosoftShowDeleted) GetControlType() string { return ControlTypeMicrosoftShowDeleted } // Encode returns the ber packet representation func (c *ControlMicrosoftShowDeleted) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftShowDeleted, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftShowDeleted]+")")) return packet } // String returns a human-readable description func (c *ControlMicrosoftShowDeleted) String() string { return fmt.Sprintf( "Control Type: %s (%q)", ControlTypeMap[ControlTypeMicrosoftShowDeleted], ControlTypeMicrosoftShowDeleted) } // NewControlMicrosoftShowDeleted returns a ControlMicrosoftShowDeleted control. // No options are currently supported. func NewControlMicrosoftShowDeleted(_ ...Option) (*ControlMicrosoftShowDeleted, error) { return &ControlMicrosoftShowDeleted{}, nil } // ControlBeheraPasswordPolicy implements the control described in https://tools.ietf.org/html/draft-behera-ldap-password-policy-10 type ControlBeheraPasswordPolicy struct { // expire contains the number of seconds before a password will expire expire int64 // grace indicates the remaining number of times a user will be allowed to authenticate with an expired password grace int64 // error indicates the error code error int8 // errorString is a human readable error errorString string } // Grace returns the remaining number of times a user will be allowed to // authenticate with an expired password. 
A value of -1 indicates it hasn't been // set. func (c *ControlBeheraPasswordPolicy) Grace() int { return int(c.grace) } // Expire contains the number of seconds before a password will expire. A value // of -1 indicates it hasn't been set. func (c *ControlBeheraPasswordPolicy) Expire() int { return int(c.expire) } // ErrorCode is the error code and a human readable string. A value of -1 and // empty string indicates it hasn't been set. func (c *ControlBeheraPasswordPolicy) ErrorCode() (int, string) { return int(c.error), c.errorString } // GetControlType returns the OID func (c *ControlBeheraPasswordPolicy) GetControlType() string { return ControlTypeBeheraPasswordPolicy } // Encode returns the ber packet representation func (c *ControlBeheraPasswordPolicy) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeBeheraPasswordPolicy, "Control Type ("+ControlTypeMap[ControlTypeBeheraPasswordPolicy]+")")) switch { case c.grace >= 0: // control value packet for GraceAuthNsRemaining valuePacket := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "") sequencePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "") // it's a warning. so it's the end of a context (ber.TagEOC) contextPacket := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0x00, nil, "") // "0x01" tag indicates an grace logins contextPacket.AppendChild(ber.NewInteger(ber.ClassContext, ber.TypePrimitive, 0x01, c.grace, "")) sequencePacket.AppendChild(contextPacket) valuePacket.AppendChild(sequencePacket) packet.AppendChild(valuePacket) return packet // I believe you can only have either Grace or Expire for a response.... not both. 
case c.expire >= 0: // control value packet for timeBeforeExpiration valuePacket := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "") sequencePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "") // it's a warning. so it's the end of a context (ber.TagEOC) contextPacket := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0x00, nil, "") // "0x00" tag indicates an expires in contextPacket.AppendChild(ber.NewInteger(ber.ClassContext, ber.TypePrimitive, 0x00, c.expire, "")) sequencePacket.AppendChild(contextPacket) valuePacket.AppendChild(sequencePacket) packet.AppendChild(valuePacket) return packet // I believe you can only have either Grace or Expire for a response.... not both. case c.error >= 0: valuePacket := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "") sequencePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "") contextPacket := ber.NewInteger(ber.ClassContext, ber.TypePrimitive, 0x01, c.error, "") sequencePacket.AppendChild(contextPacket) valuePacket.AppendChild(sequencePacket) packet.AppendChild(valuePacket) } return packet } // String returns a human-readable description func (c *ControlBeheraPasswordPolicy) String() string { return fmt.Sprintf( "Control Type: %s (%q) Criticality: %t Expire: %d Grace: %d Error: %d, ErrorString: %s", ControlTypeMap[ControlTypeBeheraPasswordPolicy], ControlTypeBeheraPasswordPolicy, false, c.expire, c.grace, c.error, c.errorString) } // NewControlBeheraPasswordPolicy returns a ControlBeheraPasswordPolicy. // Options supported: WithExpire, WithGrace, WithErrorCode func NewControlBeheraPasswordPolicy(opt ...Option) (*ControlBeheraPasswordPolicy, error) { const op = "NewControlBeheraPolicy" opts := getControlOpts(opt...) 
switch { case opts.withGrace != -1 && opts.withExpire != -1: return nil, fmt.Errorf("%s: behera policies cannot have both grace and expire set: %w", op, ErrInvalidParameter) case opts.withGrace != -1 && opts.withErrorCode != -1: return nil, fmt.Errorf("%s: behera policies cannot have both grace and error codes set: %w", op, ErrInvalidParameter) case opts.withExpire != -1 && opts.withErrorCode != -1: return nil, fmt.Errorf("%s: behera polices cannot have both expire and error codes set: %w", op, ErrInvalidParameter) case opts.withErrorCode > 8: return nil, fmt.Errorf("%s: %d is not a valid behera policy error code (must be between 0-8: %w", op, opts.withErrorCode, ErrInvalidParameter) } c := &ControlBeheraPasswordPolicy{ expire: int64(opts.withExpire), grace: int64(opts.withGrace), error: int8(opts.withErrorCode), } if opts.withErrorCode != -1 { c.errorString = BeheraPasswordPolicyErrorMap[int8(opts.withErrorCode)] } return c, nil } // ControlVChuPasswordMustChange implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 type ControlVChuPasswordMustChange struct { // MustChange indicates if the password is required to be changed MustChange bool } // GetControlType returns the OID func (c *ControlVChuPasswordMustChange) GetControlType() string { return ControlTypeVChuPasswordMustChange } // Encode returns the ber packet representation func (c *ControlVChuPasswordMustChange) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") // I believe, just the control type child is require... not criticality or // value is require... 
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeVChuPasswordMustChange, "Control Type ("+ControlTypeMap[ControlTypeVChuPasswordMustChange]+")")) return packet } // String returns a human-readable description func (c *ControlVChuPasswordMustChange) String() string { return fmt.Sprintf( "Control Type: %s (%q) Criticality: %t MustChange: %v", ControlTypeMap[ControlTypeVChuPasswordMustChange], ControlTypeVChuPasswordMustChange, false, c.MustChange) } // ControlVChuPasswordWarning implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00 type ControlVChuPasswordWarning struct { // Expire indicates the time in seconds until the password expires Expire int64 } // GetControlType returns the OID func (c *ControlVChuPasswordWarning) GetControlType() string { return ControlTypeVChuPasswordWarning } // Encode returns the ber packet representation func (c *ControlVChuPasswordWarning) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeVChuPasswordWarning, "Control Type ("+ControlTypeMap[ControlTypeVChuPasswordWarning]+")")) // I believe, it's a string in the spec expStr := strconv.FormatInt(c.Expire, 10) packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, expStr, "Control Value")) return packet } // String returns a human-readable description func (c *ControlVChuPasswordWarning) String() string { return fmt.Sprintf( "Control Type: %s (%q) Criticality: %t Expire: %d", ControlTypeMap[ControlTypeVChuPasswordWarning], ControlTypeVChuPasswordWarning, false, c.Expire) } // ControlPaging implements the paging control described in https://www.ietf.org/rfc/rfc2696.txt type ControlPaging struct { // PagingSize indicates the page size PagingSize uint32 // Cookie is an opaque value returned by 
the server to track a paging cursor Cookie []byte } // GetControlType returns the OID func (c *ControlPaging) GetControlType() string { return ControlTypePaging } // Encode returns the ber packet representation func (c *ControlPaging) Encode() *ber.Packet { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")")) p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)") seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value") seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.PagingSize), "Paging Size")) cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie") cookie.Value = c.Cookie cookie.Data.Write(c.Cookie) seq.AppendChild(cookie) p2.AppendChild(seq) packet.AppendChild(p2) return packet } // String returns a human-readable description func (c *ControlPaging) String() string { return fmt.Sprintf( "Control Type: %s (%q) Criticality: %t PagingSize: %d Cookie: %q", ControlTypeMap[ControlTypePaging], ControlTypePaging, false, c.PagingSize, c.Cookie) } // SetCookie stores the given cookie in the paging control func (c *ControlPaging) SetCookie(cookie []byte) { c.Cookie = cookie } // NewControlPaging returns a paging control func NewControlPaging(pagingSize uint32, _ ...Option) (*ControlPaging, error) { return &ControlPaging{PagingSize: pagingSize}, nil } func addControlDescriptions(packet *ber.Packet) error { const op = "gldap.addControlDescriptions" if packet == nil { return fmt.Errorf("%s: missing packet: %w", op, ErrInvalidParameter) } packet.Description = "Controls" for _, child := range packet.Children { var value *ber.Packet controlType := "" child.Description = "Control" switch 
len(child.Children) { case 0: // at least one child is required for control type return fmt.Errorf("at least one child is required for a control type") case 1: // just type, no criticality or value controlType = child.Children[0].Value.(string) child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" case 2: controlType = child.Children[0].Value.(string) child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" // Children[1] could be criticality or value (both are optional) // duck-type on whether this is a boolean if _, ok := child.Children[1].Value.(bool); ok { child.Children[1].Description = "Criticality" } else { child.Children[1].Description = "Control Value" value = child.Children[1] } case 3: // criticality and value present controlType = child.Children[0].Value.(string) child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")" child.Children[1].Description = "Criticality" child.Children[2].Description = "Control Value" value = child.Children[2] default: // more than 3 children is invalid return fmt.Errorf("more than 3 children for control packet found") } if value == nil { continue } switch controlType { case ControlTypePaging: value.Description += " (Paging)" if value.Value != nil { valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) if err != nil { return fmt.Errorf("failed to decode data bytes: %s", err) } value.Data.Truncate(0) value.Value = nil valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes() value.AppendChild(valueChildren) } value.Children[0].Description = "Real Search Control Value" value.Children[0].Children[0].Description = "Paging Size" value.Children[0].Children[1].Description = "Cookie" case ControlTypeBeheraPasswordPolicy: value.Description += " (Password Policy - Behera Draft)" if value.Value != nil { valueChildren, err := ber.DecodePacketErr(value.Data.Bytes()) if err != nil { return fmt.Errorf("failed to decode data bytes: 
%s", err) } value.Data.Truncate(0) value.Value = nil value.AppendChild(valueChildren) } sequence := value.Children[0] for _, child := range sequence.Children { if child.Tag == 0 { // Warning warningPacket := child.Children[0] val, err := ber.ParseInt64(warningPacket.Data.Bytes()) if err != nil { return fmt.Errorf("failed to decode data bytes: %s", err) } if warningPacket.Tag == 0 { // timeBeforeExpiration value.Description += " (TimeBeforeExpiration)" warningPacket.Value = val } else if warningPacket.Tag == 1 { // graceAuthNsRemaining value.Description += " (GraceAuthNsRemaining)" warningPacket.Value = val } } else if child.Tag == 1 { // Error bs := child.Data.Bytes() if len(bs) != 1 || bs[0] > 8 { return fmt.Errorf("failed to decode data bytes: %s", "invalid PasswordPolicyResponse enum value") } val := int8(bs[0]) child.Description = "Error" child.Value = val } } } } return nil }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap type controlOptions struct { withGrace int withExpire int withErrorCode int withCriticality bool withControlValue string // test options withTestType string withTestToString string } func controlDefaults() controlOptions { return controlOptions{ withGrace: -1, withExpire: -1, withErrorCode: -1, } } func getControlOpts(opt ...Option) controlOptions { opts := controlDefaults() applyOpts(&opts, opt...) return opts } // WithGraceAuthNsRemaining specifies the number of grace authentication // remaining. func WithGraceAuthNsRemaining(remaining uint) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withGrace = int(remaining) } } } // WithSecondsBeforeExpiration specifies the number of seconds before a password // will expire func WithSecondsBeforeExpiration(seconds uint) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withExpire = int(seconds) } } } // WithErrorCode specifies the error code func WithErrorCode(code uint) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withErrorCode = int(code) } } } // WithCriticality specifies the criticality func WithCriticality(criticality bool) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withCriticality = criticality } } } // WithControlValue specifies the control value func WithControlValue(value string) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withControlValue = value } } } func withTestType(s string) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withTestType = s } } } func withTestToString(s string) Option { return func(o interface{}) { if o, ok := o.(*controlOptions); ok { o.withTestToString = s } } }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "fmt" "os" "sort" "strings" ber "github.com/go-asn1-ber/asn1-ber" ) // Entry represents an ldap entry type Entry struct { // DN is the distinguished name of the entry DN string // Attributes are the returned attributes for the entry Attributes []*EntryAttribute } // GetAttributeValues returns the values for the named attribute, or an empty list func (e *Entry) GetAttributeValues(attribute string) []string { for _, attr := range e.Attributes { if attr.Name == attribute { return attr.Values } } return []string{} } // NewEntry returns an Entry object with the specified distinguished name and attribute key-value pairs. // The map of attributes is accessed in alphabetical order of the keys in order to ensure that, for the // same input map of attributes, the output entry will contain the same order of attributes func NewEntry(dn string, attributes map[string][]string) *Entry { var attributeNames []string for attributeName := range attributes { attributeNames = append(attributeNames, attributeName) } sort.Strings(attributeNames) var encodedAttributes []*EntryAttribute for _, attributeName := range attributeNames { encodedAttributes = append(encodedAttributes, NewEntryAttribute(attributeName, attributes[attributeName])) } return &Entry{ DN: dn, Attributes: encodedAttributes, } } // PrettyPrint outputs a human-readable description indenting. Supported // options: WithWriter func (e *Entry) PrettyPrint(indent int, opt ...Option) { opts := getGeneralOpts(opt...) if opts.withWriter == nil { opts.withWriter = os.Stdout } fmt.Fprintf(opts.withWriter, "%sDN: %s\n", strings.Repeat(" ", indent), e.DN) for _, attr := range e.Attributes { attr.PrettyPrint(indent+2, opt...) } } // PrettyPrint outputs a human-readable description with indenting. Supported // options: WithWriter func (e *EntryAttribute) PrettyPrint(indent int, opt ...Option) { opts := getGeneralOpts(opt...) 
if opts.withWriter == nil { opts.withWriter = os.Stdout } fmt.Fprintf(opts.withWriter, "%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values) } // EntryAttribute holds a single attribute type EntryAttribute struct { // Name is the name of the attribute Name string // Values contain the string values of the attribute Values []string // ByteValues contain the raw values of the attribute ByteValues [][]byte } // NewEntryAttribute returns a new EntryAttribute with the desired key-value pair func NewEntryAttribute(name string, values []string) *EntryAttribute { var bytes [][]byte for _, value := range values { bytes = append(bytes, []byte(value)) } return &EntryAttribute{ Name: name, Values: values, ByteValues: bytes, } } // AddValue to an existing EntryAttribute func (e *EntryAttribute) AddValue(value ...string) { for _, v := range value { e.ByteValues = append(e.ByteValues, []byte(v)) e.Values = append(e.Values, v) } } func (e *EntryAttribute) encode() *ber.Packet { seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, e.Name, "Type")) set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") for _, value := range e.Values { set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) } seq.AppendChild(set) return seq }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "fmt" ) // Scope represents the scope of a search (see: https://ldap.com/the-ldap-search-operation/) type Scope int64 const ( // BaseObject (often referred to as “base”): Indicates that only the entry // specified as the search base should be considered. None of its // subordinates will be considered. BaseObject Scope = 0 // SingleLevel (often referred to as “one”): Indicates that only the // immediate children of the entry specified as the search base should be // considered. The base entry itself should not be considered, nor any // descendants of the immediate children of the base entry. SingleLevel Scope = 1 // WholeSubtree (often referred to as “sub”): Indicates that the entry // specified as the search base, and all of its subordinates to any depth, // should be considered. Note that in the special case that the search base // DN is the null DN, the root DSE should not be considered in a // wholeSubtree search. 
WholeSubtree Scope = 2 ) // AuthChoice defines the authentication choice for bind message type AuthChoice string // SimpleAuthChoice specifies a simple user/password authentication choice for // the bind message const SimpleAuthChoice AuthChoice = "simple" type requestType string const ( unknownRequestType requestType = "" bindRequestType requestType = "bind" searchRequestType requestType = "search" extendedRequestType requestType = "extended" modifyRequestType requestType = "modify" addRequestType requestType = "add" deleteRequestType requestType = "delete" unbindRequestType requestType = "unbind" ) // Message defines a common interface for all messages type Message interface { // GetID returns the message ID GetID() int64 } // baseMessage defines a common base type for all messages (typically embedded) type baseMessage struct { id int64 } // GetID() returns the message ID func (m baseMessage) GetID() int64 { return m.id } // SearchMessage is a search request message type SearchMessage struct { baseMessage // BaseDN for the request BaseDN string // Scope of the request Scope Scope // DerefAliases for the request DerefAliases int // TimeLimit is the max time in seconds to spend processing TimeLimit int64 // SizeLimit is the max number of results to return SizeLimit int64 // TypesOnly is true if the client only expects type info TypesOnly bool // Filter for the request Filter string // Attributes requested Attributes []string // Controls requested Controls []Control } // SimpleBindMessage is a simple bind request message type SimpleBindMessage struct { baseMessage // AuthChoice for the request (SimpleAuthChoice) AuthChoice AuthChoice // UserName for the bind request UserName string // Password for the bind request Password Password // Controls are optional controls for the bind request Controls []Control } // ExtendedOperationMessage is an extended operation request message type ExtendedOperationMessage struct { baseMessage // Name of the extended operation Name 
ExtendedOperationName // Value of the extended operation Value string } // DeleteMessage is an delete request message type DeleteMessage struct { baseMessage // DN identifies the entry being added DN string // Controls hold optional controls to send with the request Controls []Control } // UnbindMessage is an unbind request message type UnbindMessage struct { baseMessage } // newMessage will create a new message from the packet. func newMessage(p *packet) (Message, error) { const op = "gldap.NewMessage" reqType, err := p.requestType() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } msgID, err := p.requestMessageID() if err != nil { return nil, fmt.Errorf("%s: unable to get message id: %w", op, err) } switch reqType { case unbindRequestType: return &UnbindMessage{ baseMessage: baseMessage{ id: msgID, }, }, nil case bindRequestType: u, pass, controls, err := p.simpleBindParameters() if err != nil { return nil, fmt.Errorf("%s: invalid bind message: %w", op, err) } return &SimpleBindMessage{ baseMessage: baseMessage{ id: msgID, }, UserName: u, Password: pass, AuthChoice: SimpleAuthChoice, Controls: controls, }, nil case searchRequestType: parameters, err := p.searchParmeters() if err != nil { return nil, fmt.Errorf("%s: invalid search message: %w", op, err) } return &SearchMessage{ baseMessage: baseMessage{ id: msgID, }, BaseDN: parameters.baseDN, Scope: Scope(parameters.scope), DerefAliases: int(parameters.derefAliases), SizeLimit: parameters.sizeLimit, TimeLimit: parameters.timeLimit, TypesOnly: parameters.typesOnly, Filter: parameters.filter, Attributes: parameters.attributes, Controls: parameters.controls, }, nil case extendedRequestType: opName, err := p.extendedOperationName() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } return &ExtendedOperationMessage{ baseMessage: baseMessage{ id: msgID, }, Name: opName, }, nil case modifyRequestType: parameters, err := p.modifyParameters() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } 
return &ModifyMessage{ baseMessage: baseMessage{ id: msgID, }, DN: parameters.dn, Changes: parameters.changes, Controls: parameters.controls, }, nil case addRequestType: parameters, err := p.addParameters() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } return &AddMessage{ baseMessage: baseMessage{ id: msgID, }, DN: parameters.dn, Attributes: parameters.attributes, Controls: parameters.controls, }, nil case deleteRequestType: dn, controls, err := p.deleteParameters() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } return &DeleteMessage{ baseMessage: baseMessage{ id: msgID, }, DN: dn, Controls: controls, }, nil default: return &ExtendedOperationMessage{ baseMessage: baseMessage{ id: msgID, }, Name: ExtendedOperationUnknown, }, nil } }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ber "github.com/go-asn1-ber/asn1-ber" type messageOptions struct { withMinChildren *int withLenChildren *int withAssertChild *int withTag *ber.Tag } func messageDefaults() messageOptions { return messageOptions{} } func getMessageOpts(opt ...Option) messageOptions { opts := messageDefaults() applyOpts(&opts, opt...) return opts } func withMinChildren(min int) Option { return func(o interface{}) { if o, ok := o.(*messageOptions); ok { o.withMinChildren = &min } } } // we'll see if we start using this again in the near future, // but for now ignore the warning // //nolint:unused func withLenChildren(len int) Option { return func(o interface{}) { if o, ok := o.(*messageOptions); ok { o.withLenChildren = &len } } } func withAssertChild(idx int) Option { return func(o interface{}) { if o, ok := o.(*messageOptions); ok { o.withAssertChild = &idx } } } func withTag(t ber.Tag) Option { return func(o interface{}) { if o, ok := o.(*messageOptions); ok { o.withTag = &t } } }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ber "github.com/go-asn1-ber/asn1-ber" // Change operation choices const ( AddAttribute = 0 DeleteAttribute = 1 ReplaceAttribute = 2 IncrementAttribute = 3 // (https://tools.ietf.org/html/rfc4525) ) // ModifyMessage as defined in https://tools.ietf.org/html/rfc4511 type ModifyMessage struct { baseMessage DN string Changes []Change Controls []Control } // Change for a ModifyMessage as defined in https://tools.ietf.org/html/rfc4511 type Change struct { // Operation is the type of change to be made Operation int64 // Modification is the attribute to be modified Modification PartialAttribute } // PartialAttribute for a ModifyMessage as defined in https://tools.ietf.org/html/rfc4511 type PartialAttribute struct { // Type is the type of the partial attribute Type string // Vals are the values of the partial attribute Vals []string } func (c *Change) encode() *ber.Packet { change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change") change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(c.Operation), "Operation")) change.AppendChild(c.Modification.encode()) return change } func (p *PartialAttribute) encode() *ber.Packet { seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute") seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.Type, "Type")) set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue") for _, value := range p.Vals { set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals")) } seq.AppendChild(set) return seq }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "fmt" "sync" ) // Mux is an ldap request multiplexer. It matches the inbound request against a // list of registered route handlers. Routes are matched in the order they're // added and only one route is called per request. type Mux struct { mu sync.Mutex routes []route defaultRoute route unbindRoute route } // NewMux creates a new multiplexer. func NewMux(opt ...Option) (*Mux, error) { return &Mux{ routes: []route{}, }, nil } // Bind will register a handler for bind requests. // Options supported: WithLabel func (m *Mux) Bind(bindFn HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Bind" if bindFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) r := &simpleBindRoute{ baseRoute: &baseRoute{ h: bindFn, routeOp: bindRouteOperation, label: opts.withLabel, }, authChoice: SimpleAuthChoice, } m.mu.Lock() defer m.mu.Unlock() m.routes = append(m.routes, r) return nil } // Unbind will register a handler for unbind requests and override the default // unbind handler. Registering an unbind handler is optional and regardless of // whether or not an unbind route is defined the server will stop serving // requests for a connection after an unbind request is received. Options // supported: WithLabel func (m *Mux) Unbind(bindFn HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Unbind" if bindFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) r := &unbindRoute{ baseRoute: &baseRoute{ h: bindFn, routeOp: bindRouteOperation, label: opts.withLabel, }, } m.mu.Lock() defer m.mu.Unlock() m.unbindRoute = r return nil } // Search will register a handler for search requests. 
// Options supported: WithLabel, WithBaseDN, WithScope func (m *Mux) Search(searchFn HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Search" if searchFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) r := &searchRoute{ baseRoute: &baseRoute{ h: searchFn, routeOp: searchRouteOperation, label: opts.withLabel, }, basedn: opts.withBaseDN, filter: opts.withFilter, scope: opts.withScope, } m.mu.Lock() defer m.mu.Unlock() m.routes = append(m.routes, r) return nil } // ExtendedOperation will register a handler for extended operation requests. // Options supported: WithLabel func (m *Mux) ExtendedOperation(operationFn HandlerFunc, exName ExtendedOperationName, opt ...Option) error { const op = "gldap.(Mux).Search" if operationFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) r := &extendedRoute{ baseRoute: &baseRoute{ h: operationFn, routeOp: extendedRouteOperation, label: opts.withLabel, }, extendedName: exName, } m.mu.Lock() defer m.mu.Unlock() m.routes = append(m.routes, r) return nil } // Modify will register a handler for modify operation requests. // Options supported: WithLabel func (m *Mux) Modify(modifyFn HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Modify" if modifyFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) r := &modifyRoute{ baseRoute: &baseRoute{ h: modifyFn, routeOp: modifyRouteOperation, label: opts.withLabel, }, } m.mu.Lock() defer m.mu.Unlock() m.routes = append(m.routes, r) return nil } // Add will register a handler for add operation requests. // Options supported: WithLabel func (m *Mux) Add(addFn HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Add" if addFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) 
r := &addRoute{ baseRoute: &baseRoute{ h: addFn, routeOp: addRouteOperation, label: opts.withLabel, }, } m.mu.Lock() defer m.mu.Unlock() m.routes = append(m.routes, r) return nil } // Delete will register a handler for delete operation requests. // Options supported: WithLabel func (m *Mux) Delete(modifyFn HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Delete" if modifyFn == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } opts := getRouteOpts(opt...) r := &deleteRoute{ baseRoute: &baseRoute{ h: modifyFn, routeOp: deleteRouteOperation, label: opts.withLabel, }, } m.mu.Lock() defer m.mu.Unlock() m.routes = append(m.routes, r) return nil } // DefaultRoute will register a default handler requests which have no other // registered handler. func (m *Mux) DefaultRoute(noRouteFN HandlerFunc, opt ...Option) error { const op = "gldap.(Mux).Bind" if noRouteFN == nil { return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter) } r := &baseRoute{ h: noRouteFN, routeOp: bindRouteOperation, } m.mu.Lock() defer m.mu.Unlock() m.defaultRoute = r return nil } // serveRequests will find a matching route to serve the request func (m *Mux) serve(w *ResponseWriter, req *Request) { const op = "gldap.(Mux).serve" defer func() { w.logger.Debug("finished serving request", "op", op, "connID", w.connID, "requestID", w.requestID) }() if w == nil { // this should be unreachable, and if it is then we'll just panic panic(fmt.Errorf("%s: %d/%d missing response writer: %w", op, w.connID, w.requestID, ErrInternal).Error()) } if req == nil { w.logger.Error("missing request", "op", op, "connID", w.connID, "requestID", w.requestID) return } // find the first matching route to dispatch the request to and then return for _, r := range m.routes { if !r.match(req) { continue } h := r.handler() if h == nil { w.logger.Error("route is missing handler", "op", op, "connID", w.connID, "requestID", w.requestID, "route", r.op) return } // the handler 
intentionally doesn't return errors, since we want the // handler to response to the connection's client with errors. h(w, req) return } if m.defaultRoute != nil { h := m.defaultRoute.handler() h(w, req) return } w.logger.Error("no matching handler found for request and returning internal error", "op", op, "connID", w.connID, "requestID", w.requestID, "routeOp", req.routeOp) resp := req.NewResponse(WithResponseCode(ResultUnwillingToPerform), WithDiagnosticMessage("No matching handler found")) _ = w.Write(resp) }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "io" "reflect" ) // Option defines a common functional options type which can be used in a // variadic parameter pattern. type Option func(interface{}) // applyOpts takes a pointer to the options struct as a set of default options // and applies the slice of opts as overrides. func applyOpts(opts interface{}, opt ...Option) { for _, o := range opt { if o == nil { // ignore any nil Options continue } o(opts) } } type generalOptions struct { withWriter io.Writer } func generalDefaults() generalOptions { return generalOptions{} } func getGeneralOpts(opt ...Option) generalOptions { opts := generalDefaults() applyOpts(&opts, opt...) return opts } // WithWriter allows you to specify an optional writer. func WithWriter(w io.Writer) Option { return func(o interface{}) { if o, ok := o.(*generalOptions); ok { if !isNil(w) { o.withWriter = w } } } } func isNil(i interface{}) bool { if i == nil { return true } switch reflect.TypeOf(i).Kind() { case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: return reflect.ValueOf(i).IsNil() } return false }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "fmt" "io" ber "github.com/go-asn1-ber/asn1-ber" "github.com/go-ldap/ldap/v3" "github.com/hashicorp/go-hclog" ) type packet struct { *ber.Packet validated bool } func (p *packet) basicValidation() error { const ( op = "gldap.(packet).basicValidation" // messageID packet + Request packet childMinChildren = 2 ) if p.validated { return nil } // Simple header is first... let's make sure it's an ldap packet with 2 // children containing: // [0] is a message ID // [1] is a request header if err := p.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withMinChildren(childMinChildren)); err != nil { return fmt.Errorf("%s: invalid ldap packet 0: %w", op, ErrInvalidParameter) } p.validated = true return nil } func (p *packet) requestMessageID() (int64, error) { const ( op = "gldap.(packet).requestMessageID" childMessageID = 0 ) if err := p.basicValidation(); err != nil { return 0, fmt.Errorf("%s: %w", op, err) } msgIDPacket := &packet{Packet: p.Children[childMessageID]} // assert it's capable of holding the message ID if err := msgIDPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger)); err != nil { return 0, fmt.Errorf("%s: missing/invalid packet: %w", op, err) } id, ok := msgIDPacket.Value.(int64) if !ok { return 0, fmt.Errorf("%s: expected int64 message ID and got %t: %w", op, msgIDPacket.Value, ErrInvalidParameter) } return id, nil } // returns nil, nil if there's no control packet func (p *packet) controlPacket() (*packet, error) { const ( op = "gldap.(packet).controlPacket" childControl = 2 ) if len(p.Children) <= 2 { // no control packet return nil, nil } controlPacket := &packet{Packet: p.Children[childControl]} if err := controlPacket.assert(ber.ClassContext, ber.TypeConstructed); err != nil { return nil, fmt.Errorf("%s: invalid control packet: %w", op, ErrInvalidParameter) } return controlPacket, nil } func (p *packet) requestPacket() 
(*packet, error) { const ( op = "gldap.(packet).requestPacket" childApplicationRequest = 1 childVersionNumber = 0 // first child of the app request packet ) if err := p.basicValidation(); err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if err := p.assertApplicationRequest(); err != nil { return nil, fmt.Errorf("%s: missing request child packet: %w", op, err) } requestPacket := &packet{Packet: p.Children[childApplicationRequest]} switch requestPacket.Packet.Tag { case ApplicationBindRequest: // assert it's ldap v3 if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger), withAssertChild(childVersionNumber)); err != nil { return nil, fmt.Errorf("%s: missing/invalid packet: %w", op, err) } ldapVersion, ok := requestPacket.Packet.Children[childVersionNumber].Value.(int64) if !ok { return nil, fmt.Errorf("%s: %v is not the expected int64 type: %w", op, requestPacket.Packet.Children[childVersionNumber].Value, ErrInvalidParameter) } if ldapVersion != 3 { return nil, fmt.Errorf("%s: incorrect ldap version, expected 3 but got %v", op, requestPacket.Value.(int64)) } default: // nothing to do or see here, move along please... 
:) } return &packet{Packet: p.Children[childApplicationRequest]}, nil } func (p *packet) requestType() (requestType, error) { const op = "gldap.(Packet).requestType" requestPacket, err := p.requestPacket() if err != nil { return unknownRequestType, fmt.Errorf("%s: %w", op, err) } switch requestPacket.Tag { case ApplicationBindRequest: return bindRequestType, nil case ApplicationSearchRequest: return searchRequestType, nil case ApplicationExtendedRequest: return extendedRequestType, nil case ApplicationModifyRequest: return modifyRequestType, nil case ApplicationAddRequest: return addRequestType, nil case ApplicationDelRequest: return deleteRequestType, nil case ApplicationUnbindRequest: return unbindRequestType, nil default: return unknownRequestType, fmt.Errorf("%s: unhandled request type %d: %w", op, requestPacket.Tag, ErrInternal) } } type modifyParameters struct { dn string changes []Change controls []Control } // return the DN, changes, and controls func (p *packet) modifyParameters() (*modifyParameters, error) { const ( op = "gldap.(packet).modifyParameters" childDN = 0 childChanges = 1 childOperation = 0 childModification = 1 childModificationType = 0 childModificationValues = 1 childControls = 2 ) requestPacket, err := p.requestPacket() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if requestPacket.Packet.Tag != ApplicationModifyRequest { return nil, fmt.Errorf("%s: not an modify request, expected tag %d and got %d: %w", op, ApplicationModifyRequest, requestPacket.Tag, ErrInvalidParameter) } var parameters modifyParameters if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childDN)); err != nil { return nil, fmt.Errorf("%s: modify dn packet: %w", op, ErrInvalidParameter) } parameters.dn = requestPacket.Children[childDN].Data.String() // assert changes packet if err := requestPacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), 
withAssertChild(childChanges)); err != nil { return nil, fmt.Errorf("%s: modify changes packet: %w", op, ErrInvalidParameter) } changesPacket := requestPacket.Children[childChanges] parameters.changes = make([]Change, 0, len(changesPacket.Children)) for _, c := range changesPacket.Children { changePacket := packet{Packet: c} // assert this is a "Change" packet if err := changePacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence)); err != nil { return nil, fmt.Errorf("%s: modify changes child packet: %w", op, ErrInvalidParameter) } // assert the change operation child if err := changePacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagEnumerated), withAssertChild(childOperation)); err != nil { return nil, fmt.Errorf("%s: modify changes child operation packet: %w", op, ErrInvalidParameter) } var ok bool var chg Change if chg.Operation, ok = changePacket.Children[childOperation].Value.(int64); !ok { return nil, fmt.Errorf("%s: change operation is not an int64: %t", op, changePacket.Children[childOperation].Value) } // assert the change modification child if err := changePacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childModification)); err != nil { return nil, fmt.Errorf("%s: change modification child packet: %w", op, ErrInvalidParameter) } // get the modification type modificationPacket := packet{Packet: changePacket.Children[childModification]} if err := modificationPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childModificationType)); err != nil { return nil, fmt.Errorf("%s: modification type packet: %w", op, ErrInvalidParameter) } chg.Modification.Type = modificationPacket.Children[childModificationType].Data.String() // get the modification values if len(modificationPacket.Children) < childModificationValues+1 { return nil, fmt.Errorf("%s: missing modification values packet: %w", op, ErrInvalidParameter) } 
chg.Modification.Vals = make([]string, 0, len(modificationPacket.Children)-1) for _, value := range modificationPacket.Children[1:] { chg.Modification.Vals = append(chg.Modification.Vals, value.Data.String()) } parameters.changes = append(parameters.changes, chg) } controlPacket, err := p.controlPacket() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if controlPacket != nil { parameters.controls = make([]Control, 0, len(controlPacket.Children)) for _, c := range controlPacket.Children { ctrl, err := decodeControl(c) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } parameters.controls = append(parameters.controls, ctrl) } } return ¶meters, nil } func (p *packet) extendedOperationName() (ExtendedOperationName, error) { const ( op = "gldap.(Packet).simpleBindParameters" childExtendedOperationName = 0 ) requestPacket, err := p.requestPacket() if err != nil { return "", fmt.Errorf("%s: %w", op, err) } if requestPacket.Packet.Tag != ApplicationExtendedRequest { return "", fmt.Errorf("%s: not an extended operation request, expected tag %d and got %d: %w", op, ApplicationExtendedRequest, requestPacket.Tag, ErrInvalidParameter) } if err := requestPacket.assert(ber.ClassContext, ber.TypePrimitive, withTag(0), withAssertChild(childExtendedOperationName)); err != nil { return "", fmt.Errorf("%s: missing/invalid username packet: %w", op, ErrInvalidParameter) } n := requestPacket.Children[childExtendedOperationName].Data.String() return ExtendedOperationName(n), nil } // Password is a simple bind request password type Password string func (p *packet) simpleBindParameters() (string, Password, []Control, error) { const ( op = "gldap.(Packet).simpleBindParameters" childBindUserName = 1 childBindPassword = 2 ) requestPacket, err := p.requestPacket() if err != nil { return "", "", nil, fmt.Errorf("%s: %w", op, err) } if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childBindUserName)); err != 
nil { return "", "", nil, fmt.Errorf("%s: missing/invalid username packet: %w", op, ErrInvalidParameter) } userName := requestPacket.Children[childBindUserName].Data.String() // check if there's even an password packet in the request if len(requestPacket.Children) > 3 { return userName, "", nil, nil } if err := requestPacket.assert(ber.ClassContext, ber.TypePrimitive, withTag(0), withAssertChild(childBindPassword)); err != nil { return "", "", nil, fmt.Errorf("%s: missing/invalid password packet: %w", op, ErrInvalidParameter) } password := requestPacket.Children[childBindPassword].Data.String() var controls []Control controlPacket, err := p.controlPacket() if err != nil { return "", "", nil, fmt.Errorf("%s: %w", op, err) } if controlPacket != nil { controls = make([]Control, 0, len(controlPacket.Children)) for _, c := range controlPacket.Children { ctrl, err := decodeControl(c) if err != nil { return "", "", nil, fmt.Errorf("%s: %w", op, err) } controls = append(controls, ctrl) } } return userName, Password(password), controls, nil } type addParameters struct { dn string attributes []Attribute controls []Control } // addParameters decodes the add request parameters from the packet func (p *packet) addParameters() (*addParameters, error) { const op = "gldap.(Packet).addParameters" const ( childDN = 0 childAttributes = 1 childControls = 2 ) var add addParameters requestPacket, err := p.requestPacket() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } // validate that it's a search request if requestPacket.Packet.Tag != ApplicationAddRequest { return nil, fmt.Errorf("%s: not an add request, expected tag %d and got %d: %w", op, ApplicationAddRequest, requestPacket.Tag, ErrInvalidParameter) } // DN child if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childDN)); err != nil { return nil, fmt.Errorf("%s: missing/invalid DN: %w", op, ErrInvalidParameter) } add.dn = 
requestPacket.Children[childDN].Data.String() if err := requestPacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childAttributes)); err != nil { return nil, fmt.Errorf("%s: missing/invalid attributes: %w", op, ErrInvalidParameter) } attributesPackets := packet{ Packet: requestPacket.Children[childAttributes], } for _, attribute := range attributesPackets.Children { attr, err := decodeAttribute(attribute) if err != nil { return nil, fmt.Errorf("%s: failed to decode attribute packet: %w", op, err) } add.attributes = append(add.attributes, *attr) } controlPacket, err := p.controlPacket() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if controlPacket != nil { add.controls = make([]Control, 0, len(controlPacket.Children)) for _, c := range controlPacket.Children { ctrl, err := decodeControl(c) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } add.controls = append(add.controls, ctrl) } } return &add, nil } type searchParameters struct { baseDN string scope int64 derefAliases int64 sizeLimit int64 timeLimit int64 typesOnly bool filter string attributes []string controls []Control } func (p *packet) searchParmeters() (*searchParameters, error) { const op = "gldap.(Packet).searchParmeters" const ( childBaseDN = 0 childScope = 1 childDerefAliases = 2 childSizeLimit = 3 childTimeLimit = 4 childTypesOnly = 5 childFilter = 6 childAttributes = 7 ) var ok bool var searchFor searchParameters requestPacket, err := p.requestPacket() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } // validate that it's a search request if requestPacket.Packet.Tag != ApplicationSearchRequest { return nil, fmt.Errorf("%s: not an search request, expected tag %d and got %d: %w", op, ApplicationSearchRequest, requestPacket.Tag, ErrInvalidParameter) } // baseDN child if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childBaseDN)); err != nil { return nil, 
fmt.Errorf("%s: missing/invalid baseDN: %w", op, ErrInvalidParameter) } searchFor.baseDN = requestPacket.Children[childBaseDN].Data.String() // scope child if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagEnumerated), withAssertChild(childScope)); err != nil { return nil, fmt.Errorf("%s: missing/invalid scope: %w", op, ErrInvalidParameter) } if searchFor.scope, ok = requestPacket.Children[childScope].Value.(int64); !ok { return nil, fmt.Errorf("%s: scope is not an int64", op) } // deref aliases if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagEnumerated), withAssertChild(childDerefAliases)); err != nil { return nil, fmt.Errorf("%s: missing/invalid deref aliases: %w", op, ErrInvalidParameter) } if searchFor.derefAliases, ok = requestPacket.Children[childDerefAliases].Value.(int64); !ok { return nil, fmt.Errorf("%s: deref aliases is not an int64", op) } // size limit if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger), withAssertChild(childSizeLimit)); err != nil { return nil, fmt.Errorf("%s: missing/invalid size limit: %w", op, ErrInvalidParameter) } if searchFor.sizeLimit, ok = requestPacket.Children[childSizeLimit].Value.(int64); !ok { return nil, fmt.Errorf("%s: size limit is not an int64", op) } // time limit if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger), withAssertChild(childTimeLimit)); err != nil { return nil, fmt.Errorf("%s: missing/invalid time limit: %w", op, ErrInvalidParameter) } if searchFor.timeLimit, ok = requestPacket.Children[childTimeLimit].Value.(int64); !ok { return nil, fmt.Errorf("%s: time limit is not an int64", op) } // types only if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagBoolean), withAssertChild(childTypesOnly)); err != nil { return nil, fmt.Errorf("%s: missing/invalid types only: %w", op, ErrInvalidParameter) } if searchFor.typesOnly, ok 
= requestPacket.Children[childTypesOnly].Value.(bool); !ok { return nil, fmt.Errorf("%s: types only is not a bool", op) } if len(requestPacket.Children) < childFilter+1 { return nil, fmt.Errorf("%s: missing filter: %w", op, ErrInvalidParameter) } filter, err := ldap.DecompileFilter(requestPacket.Children[childFilter]) if err != nil { return nil, fmt.Errorf("%s: unable to decompile filter: %w", op, err) } searchFor.filter = filter // check for attributes packet if len(requestPacket.Children) < childAttributes+1 { return &searchFor, nil // there's none, so just return } if err := requestPacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childAttributes)); err != nil { return nil, fmt.Errorf("%s: invalid attributes: %w", op, err) } attributesPacket := packet{ Packet: requestPacket.Children[childAttributes], } searchFor.attributes = make([]string, 0, len(attributesPacket.Children)) for idx, attribute := range attributesPacket.Children { if err := attributesPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(idx)); err != nil { return nil, fmt.Errorf("%s: invalid attribute child packet: %w", op, err) } searchFor.attributes = append(searchFor.attributes, attribute.Data.String()) } controlPacket, err := p.controlPacket() if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } if controlPacket != nil { searchFor.controls = make([]Control, 0, len(controlPacket.Children)) for _, c := range controlPacket.Children { ctrl, err := decodeControl(c) if err != nil { return nil, fmt.Errorf("%s: %w", op, err) } searchFor.controls = append(searchFor.controls, ctrl) } } return &searchFor, nil } func (p *packet) assert(cl ber.Class, ty ber.Type, opt ...Option) error { const op = "gldap.assert" opts := getMessageOpts(opt...) 
if opts.withLenChildren != nil { if len(p.Children) != *opts.withLenChildren { return fmt.Errorf("%s: not the correct number of children packets, expected %d but got %d", op, *opts.withLenChildren, len(p.Children)) } } if opts.withMinChildren != nil { if len(p.Children) < *opts.withMinChildren { return fmt.Errorf("%s: not enough children packets, expected %d but got %d", op, *opts.withMinChildren, len(p.Children)) } } chkPacket := p.Packet if opts.withAssertChild != nil { if len(p.Children) < *opts.withAssertChild+1 { return fmt.Errorf("%s: missing asserted child %d, but there are only %d", op, *opts.withAssertChild, len(p.Children)) } chkPacket = p.Packet.Children[*opts.withAssertChild] } if chkPacket.ClassType != cl { return fmt.Errorf("%s: incorrect class, expected %v but got %v", op, cl, chkPacket.ClassType) } if chkPacket.TagType != ty { return fmt.Errorf("%s: incorrect type, expected %v but got %v", op, ty, chkPacket.TagType) } if opts.withTag != nil && chkPacket.Tag != *opts.withTag { return fmt.Errorf("%s: incorrect tag, expected %v but got %v", op, *opts.withTag, chkPacket.Tag) } return nil } func (p *packet) assertApplicationRequest() error { const ( op = "gldap.(packet).assertApplicationRequest" childApplicationRequest = 1 ) if len(p.Children) < childApplicationRequest+1 { return fmt.Errorf("%s: missing asserted application request child, but there are only %d", op, len(p.Children)) } chkPacket := p.Packet.Children[childApplicationRequest] if chkPacket.ClassType != ber.ClassApplication { return fmt.Errorf("%s: incorrect class, expected %v (ber.ClassApplication) but got %v", op, ber.ClassApplication, chkPacket.ClassType) } switch chkPacket.TagType { case ber.TypePrimitive: if chkPacket.Tag != ApplicationDelRequest && chkPacket.Tag != ApplicationUnbindRequest { return fmt.Errorf("%s: incorrect type, primitive %q must be a delete request %q or an unbind request %q, but got %q", op, ber.TypePrimitive, ApplicationDelRequest, ApplicationUnbindRequest, 
chkPacket.Tag) } case ber.TypeConstructed: default: return fmt.Errorf("%s: incorrect type, expected ber.TypeConstructed %q but got %v", op, ber.TypeConstructed, chkPacket.TagType) } return nil } func (p *packet) debug() { testLogger := hclog.New(&hclog.LoggerOptions{ Name: "debug-logger", Level: hclog.Debug, }) p.Log(testLogger.StandardWriter(&hclog.StandardLoggerOptions{}), 0, false) } // Log will pretty print log a packet func (p *packet) Log(out io.Writer, indent int, printBytes bool) { indentStr := "" for len(indentStr) != indent { indentStr += " " } classStr := ber.ClassMap[p.ClassType] tagtypeStr := ber.TypeMap[p.TagType] tagStr := fmt.Sprintf("0x%02X", p.Tag) if p.ClassType == ber.ClassUniversal { tagStr = tagMap[p.Tag] } value := fmt.Sprint(p.Value) description := "" if p.Description != "" { description = p.Description + ": " } fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagtypeStr, tagStr, p.Data.Len(), value) if printBytes { ber.PrintBytes(out, p.Bytes(), indentStr) } for _, child := range p.Children { childPacket := packet{Packet: child} childPacket.Log(out, indent+1, printBytes) } } func (p *packet) deleteParameters() (string, []Control, error) { const op = "gldap.(packet).deleteDN" requestPacket, err := p.requestPacket() if err != nil { return "", nil, fmt.Errorf("%s: %w", op, err) } if requestPacket.Packet.Tag != ApplicationDelRequest { return "", nil, fmt.Errorf("%s: not a delete request, expected tag %d and got %d: %w", op, ApplicationDelRequest, requestPacket.Tag, ErrInvalidParameter) } dn := requestPacket.Data.String() controlPacket, err := p.controlPacket() if err != nil { return "", nil, fmt.Errorf("%s: %w", op, err) } var controls []Control if controlPacket != nil { controls = make([]Control, 0, len(controlPacket.Children)) for _, c := range controlPacket.Children { ctrl, err := decodeControl(c) if err != nil { return "", nil, fmt.Errorf("%s: %w", op, err) } controls = append(controls, ctrl) } } return dn, 
controls, nil } var tagMap = map[ber.Tag]string{ ber.TagEOC: "EOC (End-of-Content)", ber.TagBoolean: "Boolean", ber.TagInteger: "Integer", ber.TagBitString: "Bit String", ber.TagOctetString: "Octet String", ber.TagNULL: "NULL", ber.TagObjectIdentifier: "Object Identifier", ber.TagObjectDescriptor: "Object Descriptor", ber.TagExternal: "External", ber.TagRealFloat: "Real (float)", ber.TagEnumerated: "Enumerated", ber.TagEmbeddedPDV: "Embedded PDV", ber.TagUTF8String: "UTF8 String", ber.TagRelativeOID: "Relative-OID", ber.TagSequence: "Sequence and Sequence of", ber.TagSet: "Set and Set OF", ber.TagNumericString: "Numeric String", ber.TagPrintableString: "Printable String", ber.TagT61String: "T61 String", ber.TagVideotexString: "Videotex String", ber.TagIA5String: "IA5 String", ber.TagUTCTime: "UTC Time", ber.TagGeneralizedTime: "Generalized Time", ber.TagGraphicString: "Graphic String", ber.TagVisibleString: "Visible String", ber.TagGeneralString: "General String", ber.TagUniversalString: "Universal String", ber.TagCharacterString: "Character String", ber.TagBMPString: "BMP String", }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "crypto/tls" "errors" "fmt" ber "github.com/go-asn1-ber/asn1-ber" ) // ExtendedOperationName is an extended operation request/response name type ExtendedOperationName string // Extended operation response/request names const ( ExtendedOperationDisconnection ExtendedOperationName = "1.3.6.1.4.1.1466.2003" ExtendedOperationCancel ExtendedOperationName = "1.3.6.1.1.8" ExtendedOperationStartTLS ExtendedOperationName = "1.3.6.1.4.1.1466.20037" ExtendedOperationWhoAmI ExtendedOperationName = "1.3.6.1.4.1.4203.1.11.3" ExtendedOperationGetConnectionID ExtendedOperationName = "1.3.6.1.4.1.26027.1.6.2" ExtendedOperationPasswordModify ExtendedOperationName = "1.3.6.1.4.1.4203.1.11.1" ExtendedOperationUnknown ExtendedOperationName = "Unknown" ) // Request represents an ldap request type Request struct { // ID is the request number for a specific connection. Every connection has // its own request counter which starts at 1. ID int // conn is needed this for cancellation among other things. 
conn *conn message Message routeOp routeOperation extendedName ExtendedOperationName } func newRequest(id int, c *conn, p *packet) (*Request, error) { const op = "gldap.newRequest" if c == nil { return nil, fmt.Errorf("%s: missing connection: %w", op, ErrInvalidParameter) } if p == nil { return nil, fmt.Errorf("%s: missing packet: %w", op, ErrInvalidParameter) } m, err := newMessage(p) if err != nil { return nil, fmt.Errorf("%s: unable to build message for request %d: %w", op, id, err) } var extendedName ExtendedOperationName var routeOp routeOperation switch v := m.(type) { case *SimpleBindMessage: routeOp = bindRouteOperation case *SearchMessage: routeOp = searchRouteOperation case *ExtendedOperationMessage: routeOp = extendedRouteOperation extendedName = v.Name case *ModifyMessage: routeOp = modifyRouteOperation case *AddMessage: routeOp = addRouteOperation case *DeleteMessage: routeOp = deleteRouteOperation case *UnbindMessage: routeOp = unbindRouteOperation default: // this should be unreachable, since newMessage defaults to returning an // *ExtendedOperationMessage return nil, fmt.Errorf("%s: %v is an unsupported route operation: %w", op, v, ErrInternal) } r := &Request{ ID: id, conn: c, message: m, routeOp: routeOp, extendedName: extendedName, } return r, nil } // ConnectionID returns the request's connection ID which enables you to know // "who" (i.e. which connection) made a request. Using the connection ID you // can do things like ensure a connection performing a search operation has // successfully authenticated (a.k.a. performed a successful bind operation). func (r *Request) ConnectionID() int { return r.conn.connID } // NewModifyResponse creates a modify response // Supported options: WithResponseCode, WithDiagnosticMessage, WithMatchedDN func (r *Request) NewModifyResponse(opt ...Option) *ModifyResponse { opts := getResponseOpts(opt...) 
return &ModifyResponse{ GeneralResponse: r.NewResponse( WithApplicationCode(ApplicationModifyResponse), WithResponseCode(*opts.withResponseCode), WithDiagnosticMessage(opts.withDiagnosticMessage), WithMatchedDN(opts.withMatchedDN), ), } } // StartTLS will start a TLS connection using the Message's existing connection func (r *Request) StartTLS(tlsconfig *tls.Config) error { const op = "gldap.(Message).StartTLS" if tlsconfig == nil { return fmt.Errorf("%s: missing tls configuration: %w", op, ErrInvalidParameter) } tlsConn := tls.Server(r.conn.netConn, tlsconfig) if err := tlsConn.Handshake(); err != nil { return fmt.Errorf("%s: handshake error: %w", op, err) } if err := r.conn.initConn(tlsConn); err != nil { return fmt.Errorf("%s: %w", op, err) } return nil } // NewResponse creates a general response (not necessarily to any specific // request because you can set WithApplicationCode). // Supported options: WithResponseCode, WithApplicationCode, // WithDiagnosticMessage, WithMatchedDN func (r *Request) NewResponse(opt ...Option) *GeneralResponse { const op = "gldap.NewResponse" // nolint:unused opts := getResponseOpts(opt...) if opts.withResponseCode == nil { opts.withResponseCode = intPtr(ResultUnwillingToPerform) } if opts.withApplicationCode == nil { opts.withApplicationCode = intPtr(ApplicationExtendedResponse) } return &GeneralResponse{ baseResponse: &baseResponse{ messageID: r.message.GetID(), code: int16(*opts.withResponseCode), diagMessage: opts.withDiagnosticMessage, matchedDN: opts.withMatchedDN, }, applicationCode: *opts.withApplicationCode, } } // NewExtendedResponse creates a new extended response. // Supported options: WithResponseCode func (r *Request) NewExtendedResponse(opt ...Option) *ExtendedResponse { const op = "gldap.NewExtendedResponse" // nolint:unused opts := getResponseOpts(opt...) 
resp := &ExtendedResponse{ baseResponse: &baseResponse{ messageID: r.message.GetID(), }, } if opts.withResponseCode != nil { resp.code = int16(*opts.withResponseCode) } return resp } // NewBindResponse creates a new bind response. // Supported options: WithResponseCode func (r *Request) NewBindResponse(opt ...Option) *BindResponse { const op = "gldap.NewBindResponse" // nolint:unused opts := getResponseOpts(opt...) resp := &BindResponse{ baseResponse: &baseResponse{ messageID: r.message.GetID(), }, } if opts.withResponseCode != nil { resp.code = int16(*opts.withResponseCode) } return resp } // GetSimpleBindMessage retrieves the SimpleBindMessage from the request, which // allows you handle the request based on the message attributes. func (r *Request) GetSimpleBindMessage() (*SimpleBindMessage, error) { const op = "gldap.(Request).GetSimpleBindMessage" s, ok := r.message.(*SimpleBindMessage) if !ok { return nil, fmt.Errorf("%s: %T not a simple bind request: %w", op, r.message, ErrInvalidParameter) } return s, nil } // NewSearchDoneResponse creates a new search done response. If there are no // results found, then set the response code by adding the option // WithResponseCode(ResultNoSuchObject) // // Supported options: WithResponseCode func (r *Request) NewSearchDoneResponse(opt ...Option) *SearchResponseDone { const op = "gldap.(Request).NewSearchDoneResponse" // nolint:unused opts := getResponseOpts(opt...) resp := &SearchResponseDone{ baseResponse: &baseResponse{ messageID: r.message.GetID(), }, } if opts.withResponseCode != nil { resp.code = int16(*opts.withResponseCode) } return resp } // GetSearchMessage retrieves the SearchMessage from the request, which // allows you handle the request based on the message attributes. 
func (r *Request) GetSearchMessage() (*SearchMessage, error) { const op = "gldap.(Request).GetSearchMessage" m, ok := r.message.(*SearchMessage) if !ok { return nil, fmt.Errorf("%s: %T not a search request: %w", op, r.message, ErrInvalidParameter) } return m, nil } // NewSearchResponseEntry is a search response entry. // Supported options: WithAttributes func (r *Request) NewSearchResponseEntry(entryDN string, opt ...Option) *SearchResponseEntry { opts := getResponseOpts(opt...) newAttrs := make([]*EntryAttribute, 0, len(opts.withAttributes)) for name, values := range opts.withAttributes { newAttrs = append(newAttrs, NewEntryAttribute(name, values)) } return &SearchResponseEntry{ baseResponse: &baseResponse{ messageID: r.message.GetID(), }, entry: Entry{ DN: entryDN, Attributes: newAttrs, }, } } // GetModifyMessage retrieves the ModifyMessage from the request, which // allows you handle the request based on the message attributes. func (r *Request) GetModifyMessage() (*ModifyMessage, error) { const op = "gldap.(Request).GetModifyMessage" m, ok := r.message.(*ModifyMessage) if !ok { return nil, fmt.Errorf("%s: %T not a modify request: %w", op, r.message, ErrInvalidParameter) } return m, nil } // GetAddMessage retrieves the AddMessage from the request, which // allows you handle the request based on the message attributes. func (r *Request) GetAddMessage() (*AddMessage, error) { const op = "gldap.(Request).GetAddMessage" m, ok := r.message.(*AddMessage) if !ok { return nil, fmt.Errorf("%s: %T not a add request: %w", op, r.message, ErrInvalidParameter) } return m, nil } // GetDeleteMessage retrieves the DeleteMessage from the request, which // allows you handle the request based on the message attributes. 
func (r *Request) GetDeleteMessage() (*DeleteMessage, error) { const op = "gldap.(Request).GetDeleteMessage" m, ok := r.message.(*DeleteMessage) if !ok { return nil, fmt.Errorf("%s: %T not a delete request: %w", op, r.message, ErrInvalidParameter) } return m, nil } // GetUnbindMessage retrieves the UnbindMessage from the request, which // allows you handle the request based on the message attributes. func (r *Request) GetUnbindMessage() (*UnbindMessage, error) { const op = "gldap.(Request).GetUnbindMessage" m, ok := r.message.(*UnbindMessage) if !ok { return nil, fmt.Errorf("%s: %T not an unbind request: %w", op, r.message, ErrInvalidParameter) } return m, nil } // ConvertString will convert an ASN1 BER Octet string into a "native" go // string. Support ber string encoding types: OctetString, GeneralString and // all other types will return an error. func ConvertString(octetString ...string) ([]string, error) { const ( op = "gldap.ConvertOctetString" berTagIdx = 0 startOfDataIdx = 1 ) converted := make([]string, 0, len(octetString)) for _, s := range octetString { data := []byte(s) switch { case ber.Tag(data[berTagIdx]) == ber.TagOctetString, ber.Tag(data[berTagIdx]) == ber.TagGeneralString: _, strDataLen, err := readLength(data[startOfDataIdx:]) if err != nil { return nil, err } converted = append(converted, string(data[(startOfDataIdx+strDataLen):])) default: return nil, fmt.Errorf("%s: unsupported ber encoding type %s: %w", op, string(data[berTagIdx]), ErrInvalidParameter) } } return converted, nil } // readLength(...) 
// jimlambrt: 2/2023 // copied directly from github.com/go-asn1-ber/asn1-ber@v1.5.4/length.go // it has an MIT license: https://github.com/go-asn1-ber/asn1-ber/blob/master/LICENSE func readLength(bytes []byte) (length int, read int, err error) { // length byte b := bytes[0] read++ switch { case b == 0xFF: // Invalid 0xFF (x.600, 8.1.3.5.c) return 0, read, errors.New("invalid length byte 0xff") case b == ber.LengthLongFormBitmask: // Indefinite form, we have to decode packets until we encounter an EOC packet (x.600, 8.1.3.6) length = ber.LengthIndefinite case b&ber.LengthLongFormBitmask == 0: // Short definite form, extract the length from the bottom 7 bits (x.600, 8.1.3.4) length = int(b) & ber.LengthValueBitmask case b&ber.LengthLongFormBitmask != 0: // Long definite form, extract the number of length bytes to follow from the bottom 7 bits (x.600, 8.1.3.5.b) lengthBytes := int(b) & ber.LengthValueBitmask // Protect against overflow // TODO: support big int length? if lengthBytes > 8 { return 0, read, errors.New("long-form length overflow") } // Accumulate into a 64-bit variable var length64 int64 for i := 0; i < lengthBytes; i++ { b = bytes[read] read++ // x.600, 8.1.3.5 length64 <<= 8 length64 |= int64(b) } // Cast to a platform-specific integer length = int(length64) // Ensure we didn't overflow if int64(length) != length64 { return 0, read, errors.New("long-form length overflow") } default: return 0, read, errors.New("invalid length byte") } return length, read, nil } func intPtr(i int) *int { return &i }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "bufio" "fmt" "sync" ber "github.com/go-asn1-ber/asn1-ber" "github.com/hashicorp/go-hclog" ) // ResponseWriter is an ldap request response writer which is used by a // HanderFunc to write responses to client requests. type ResponseWriter struct { writerMu *sync.Mutex // a shared lock across all requests to prevent data races when writing writer *bufio.Writer logger hclog.Logger connID int requestID int } func newResponseWriter(w *bufio.Writer, lock *sync.Mutex, logger hclog.Logger, connID, requestID int) (*ResponseWriter, error) { const op = "gldap.NewResponseWriter" if w == nil { return nil, fmt.Errorf("%s: missing writer: %w", op, ErrInvalidParameter) } if lock == nil { return nil, fmt.Errorf("%s: missing writer lock: %w", op, ErrInvalidParameter) } if logger == nil { return nil, fmt.Errorf("%s: missing logger: %w", op, ErrInvalidParameter) } if connID == 0 { return nil, fmt.Errorf("%s: missing conn ID: %w", op, ErrInvalidParameter) } if requestID == 0 { return nil, fmt.Errorf("%s: missing request ID: %w", op, ErrInvalidParameter) } return &ResponseWriter{ writerMu: lock, writer: w, logger: logger, connID: connID, requestID: requestID, }, nil } // Write will write the response to the client func (rw *ResponseWriter) Write(r Response) error { const op = "gldap.(ResponseWriter).Write" if r == nil { return fmt.Errorf("%s: missing response: %w", op, ErrInvalidParameter) } p := r.packet() if rw.logger.IsDebug() { rw.logger.Debug("response write", "op", op, "conn", rw.connID, "requestID", rw.requestID) p.Log(rw.logger.StandardWriter(&hclog.StandardLoggerOptions{}), 0, false) } rw.writerMu.Lock() defer rw.writerMu.Unlock() if _, err := rw.writer.Write(r.packet().Bytes()); err != nil { return fmt.Errorf("%s: unable to write response: %w", op, err) } if err := rw.writer.Flush(); err != nil { return fmt.Errorf("%s: unable to flush write: %w", op, err) } rw.logger.Debug("finished writing", 
"op", op, "conn", rw.connID, "requestID", rw.requestID) return nil } func beginResponse(messageID int64) *ber.Packet { const op = "gldap.beginResponse" // nolint:unused p := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") p.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")) return p } func addOptionalResponseChildren(bindResponse *ber.Packet, opt ...Option) { const op = "gldap.addOptionalResponseChildren" // nolint:unused opts := getResponseOpts(opt...) bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, opts.withMatchedDN, "matchedDN")) bindResponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, opts.withDiagnosticMessage, "diagnosticMessage")) } // Response represents a response to an ldap request type Response interface { packet() *packet } type baseResponse struct { messageID int64 code int16 diagMessage string matchedDN string } // SetResultCode the result code for a response. func (l *baseResponse) SetResultCode(code int) { l.code = int16(code) } // SetDiagnosticMessage sets the optional diagnostic message for a response. func (l *baseResponse) SetDiagnosticMessage(msg string) { l.diagMessage = msg } // SetMatchedDN sets the optional matched DN for a response. func (l *baseResponse) SetMatchedDN(dn string) { l.matchedDN = dn } // ExtendedResponse represents a response to an extended operation request type ExtendedResponse struct { *baseResponse name ExtendedOperationName } // SetResponseName will set the response name for the extended operation response. 
func (r *ExtendedResponse) SetResponseName(n ExtendedOperationName) { r.name = n } func (r *ExtendedResponse) packet() *packet { replyPacket := beginResponse(r.messageID) // a new packet for the bind response resultPacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ber.Tag(ApplicationExtendedResponse), nil, ApplicationCodeMap[ApplicationExtendedResponse]) // append the result code to the bind response packet resultPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])) // Add optional diagnostic message and matched DN addOptionalResponseChildren(resultPacket, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN)) replyPacket.AppendChild(resultPacket) return &packet{Packet: replyPacket} } // BindResponse represents the response to a bind request type BindResponse struct { *baseResponse controls []Control } // SetControls for bind response func (r *BindResponse) SetControls(controls ...Control) { r.controls = controls } func (r *BindResponse) packet() *packet { replyPacket := beginResponse(r.messageID) // a new packet for the bind response resultPacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ber.Tag(ApplicationBindResponse), nil, ApplicationCodeMap[ApplicationBindResponse]) // append the result code to the bind response packet resultPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])) // Add optional diagnostic message and matched DN addOptionalResponseChildren(resultPacket, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN)) replyPacket.AppendChild(resultPacket) if len(r.controls) > 0 { replyPacket.AppendChild(encodeControls(r.controls)) } return &packet{Packet: replyPacket} } // GeneralResponse represents a general response (non-specific to a request). 
type GeneralResponse struct { *baseResponse applicationCode int } func (r *GeneralResponse) packet() *packet { const op = "gldap.(GeneralResponse).packet" // nolint:unused replyPacket := beginResponse(r.messageID) // a new packet for the bind response resultPacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ber.Tag(r.applicationCode), nil, ApplicationCodeMap[uint8(r.applicationCode)]) // append the result code to the bind response packet resultPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])) // Add optional diagnostic message and matched DN addOptionalResponseChildren(resultPacket, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN)) replyPacket.AppendChild(resultPacket) return &packet{Packet: replyPacket} } // SearchResponseDone represents that handling a search requests is done. type SearchResponseDone struct { *baseResponse controls []Control } // SetControls for the search response func (r *SearchResponseDone) SetControls(controls ...Control) { r.controls = controls } func (r *SearchResponseDone) packet() *packet { const op = "gldap.(SearchDoneResponse).packet" // nolint:unused replyPacket := beginResponse(r.messageID) resultPacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, ApplicationCodeMap[ApplicationSearchResultDone]) resultPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])) // Add optional diagnostic message and matched DN addOptionalResponseChildren(resultPacket, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN)) replyPacket.AppendChild(resultPacket) if len(r.controls) > 0 { replyPacket.AppendChild(encodeControls(r.controls)) } return &packet{Packet: replyPacket} } // SearchResponseEntry is an ldap entry that's part of search response. 
type SearchResponseEntry struct { *baseResponse entry Entry } // AddAttribute will an attributes to the response entry func (r *SearchResponseEntry) AddAttribute(name string, values []string) { r.entry.Attributes = append(r.entry.Attributes, NewEntryAttribute(name, values)) } func (r *SearchResponseEntry) packet() *packet { const op = "gldap.(SearchEntryResponse).packet" // nolint:unused replyPacket := beginResponse(r.messageID) resultPacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, ApplicationCodeMap[ApplicationSearchResultEntry]) resultPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, r.entry.DN, "DN")) attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") for _, a := range r.entry.Attributes { attributesPacket.AppendChild(a.encode()) } resultPacket.AppendChild(attributesPacket) replyPacket.AppendChild(resultPacket) return &packet{Packet: replyPacket} } // ModifyResponse is a response to a modify request. type ModifyResponse struct { *GeneralResponse }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap type responseOptions struct { withDiagnosticMessage string withMatchedDN string withResponseCode *int withApplicationCode *int withAttributes map[string][]string } func responseDefaults() responseOptions { return responseOptions{ withMatchedDN: "Unused", withDiagnosticMessage: "Unused", } } func getResponseOpts(opt ...Option) responseOptions { opts := responseDefaults() applyOpts(&opts, opt...) return opts } // WithDiagnosticMessage provides an optional diagnostic message for the // response. func WithDiagnosticMessage(msg string) Option { return func(o interface{}) { if o, ok := o.(*responseOptions); ok { o.withDiagnosticMessage = msg } } } // WithMatchedDN provides an optional match DN for the response. func WithMatchedDN(dn string) Option { return func(o interface{}) { if o, ok := o.(*responseOptions); ok { o.withMatchedDN = dn } } } // WithResponseCode specifies the ldap response code. For a list of valid codes // see: // https://github.com/go-ldap/ldap/blob/13008e4c5260d08625b65eb1f172ae909152b751/v3/error.go#L11 func WithResponseCode(code int) Option { return func(o interface{}) { if o, ok := o.(*responseOptions); ok { o.withResponseCode = &code } } } // WithApplicationCode specifies the ldap application code. For a list of valid codes // for a list of supported application codes see: // https://github.com/jimlambrt/gldap/blob/8f171b8eb659c76019719382c4daf519dd1281e6/codes.go#L159 func WithApplicationCode(applicationCode int) Option { return func(o interface{}) { if o, ok := o.(*responseOptions); ok { o.withApplicationCode = &applicationCode } } } // WithAttributes specifies optional attributes for a response entry func WithAttributes(attributes map[string][]string) Option { return func(o interface{}) { if o, ok := o.(*responseOptions); ok { o.withAttributes = attributes } } }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "strings" ) // routeOperation represents the ldap operation for a route. type routeOperation string const ( // undefinedRouteOperation is an undefined operation. undefinedRouteOperation routeOperation = "" // nolint:unused // bindRouteOperation is a route supporting the bind operation bindRouteOperation routeOperation = "bind" // searchRouteOperation is a route supporting the search operation searchRouteOperation routeOperation = "search" // extendedRouteOperation is a route supporting an extended operation extendedRouteOperation routeOperation = "extendedOperation" // modifyRouteOperation is a route supporting the modify operation modifyRouteOperation routeOperation = "modify" // addRouteOperation is a route supporting the add operation addRouteOperation routeOperation = "add" // deleteRouteOperation is a route supporting the delete operation deleteRouteOperation routeOperation = "delete" // unbindRouteOperation is a route supporting the unbind operation unbindRouteOperation routeOperation = "unbind" // defaultRouteOperation is a default route which is used when there are no routes // defined for a particular operation defaultRouteOperation routeOperation = "noRoute" // nolint:unused ) // HandlerFunc defines a function for handling an LDAP request. 
type HandlerFunc func(*ResponseWriter, *Request) type route interface { match(req *Request) bool handler() HandlerFunc op() routeOperation } type baseRoute struct { h HandlerFunc routeOp routeOperation label string } func (r *baseRoute) handler() HandlerFunc { return r.h } func (r *baseRoute) op() routeOperation { return r.routeOp } func (r *baseRoute) match(req *Request) bool { return false } type searchRoute struct { *baseRoute basedn string filter string scope Scope } type simpleBindRoute struct { *baseRoute authChoice AuthChoice } type unbindRoute struct { *baseRoute } type extendedRoute struct { *baseRoute extendedName ExtendedOperationName } type modifyRoute struct { *baseRoute } type addRoute struct { *baseRoute } type deleteRoute struct { *baseRoute } func (r *deleteRoute) match(req *Request) bool { if req == nil { return false } if r.op() != req.routeOp { return false } if _, ok := req.message.(*DeleteMessage); !ok { return false } return true } func (r *addRoute) match(req *Request) bool { if req == nil { return false } if r.op() != req.routeOp { return false } if _, ok := req.message.(*AddMessage); !ok { return false } return true } func (r *modifyRoute) match(req *Request) bool { if req == nil { return false } if r.op() != req.routeOp { return false } if _, ok := req.message.(*ModifyMessage); !ok { return false } return true } func (r *simpleBindRoute) match(req *Request) bool { if req == nil { return false } if r.op() != req.routeOp { return false } if m, ok := req.message.(*SimpleBindMessage); ok { if r.authChoice != "" && r.authChoice == m.AuthChoice { return true } } return false } func (r *extendedRoute) match(req *Request) bool { if req == nil { return false } if r.op() != req.routeOp { return false } if r.extendedName != req.extendedName { return false } _, ok := req.message.(*ExtendedOperationMessage) return ok } func (r *searchRoute) match(req *Request) bool { if req == nil { return false } if r.op() != req.routeOp { return false } searchMsg, 
ok := req.message.(*SearchMessage) if !ok { return false } if r.basedn != "" && !strings.EqualFold(searchMsg.BaseDN, r.basedn) { return false } if r.filter != "" && !strings.EqualFold(searchMsg.Filter, r.filter) { return false } if r.scope != 0 && searchMsg.Scope != r.scope { return false } // if it didn't get eliminated by earlier request criteria, then it's a // match. return true }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap type routeOptions struct { withLabel string withBaseDN string withFilter string withScope Scope } func routeDefaults() routeOptions { return routeOptions{} } func getRouteOpts(opt ...Option) routeOptions { opts := routeDefaults() applyOpts(&opts, opt...) return opts } // WithLabel specifies an optional label for the route func WithLabel(l string) Option { return func(o interface{}) { if o, ok := o.(*routeOptions); ok { o.withLabel = l } } } // WithBaseDN specifies an optional base DN to associate with a Search route func WithBaseDN(dn string) Option { return func(o interface{}) { if o, ok := o.(*routeOptions); ok { o.withBaseDN = dn } } } // WithFilter specifies an optional filter to associate with a Search route func WithFilter(filter string) Option { return func(o interface{}) { if o, ok := o.(*routeOptions); ok { o.withFilter = filter } } } // WithScope specifies and optional scope to associate with a Search route func WithScope(s Scope) Option { return func(o interface{}) { if o, ok := o.(*routeOptions); ok { o.withScope = s } } }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "context" "crypto/tls" "fmt" "net" "strings" "sync" "time" "github.com/hashicorp/go-hclog" ) // Server is an ldap server that you can add a mux (multiplexer) router to and // then run it to accept and process requests. type Server struct { mu sync.RWMutex logger hclog.Logger connWg sync.WaitGroup listener net.Listener listenerReady bool router *Mux tlsConfig *tls.Config readTimeout time.Duration writeTimeout time.Duration onCloseHandler OnCloseHandler disablePanicRecovery bool shutdownCancel context.CancelFunc shutdownCtx context.Context } // NewServer creates a new ldap server // // Options supported: // - WithLogger allows you pass a logger with whatever hclog.Level you wish including hclog.Off to turn off all logging // - WithReadTimeout will set a read time out per connection // - WithWriteTimeout will set a write time out per connection // - WithOnClose will define a callback the server will call every time a connection is closed func NewServer(opt ...Option) (*Server, error) { cancelCtx, cancel := context.WithCancel(context.Background()) opts := getConfigOpts(opt...) if opts.withLogger == nil { opts.withLogger = hclog.New(&hclog.LoggerOptions{ Name: "Server-logger", Level: hclog.Error, }) } return &Server{ router: &Mux{}, // TODO: a better default router logger: opts.withLogger, shutdownCancel: cancel, shutdownCtx: cancelCtx, writeTimeout: opts.withWriteTimeout, readTimeout: opts.withReadTimeout, disablePanicRecovery: opts.withDisablePanicRecovery, onCloseHandler: opts.withOnClose, }, nil } // Index of rightmost occurrence of b in s. func last(s string, b byte) int { i := len(s) for i--; i >= 0; i-- { if s[i] == b { break } } return i } // validateAddr will not only validate the addr, but if it's an ipv6 literal without // proper brackets, it will add them. 
func validateAddr(addr string) (string, error) { const op = "gldap.parseAddr" lastColon := last(addr, ':') if lastColon < 0 { return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter) } rawHost := addr[0:lastColon] rawPort := addr[lastColon+1:] if len(rawPort) == 0 { return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter) } if len(rawHost) == 0 { return fmt.Sprintf(":%s", rawPort), nil } // ipv6 literal with proper brackets if rawHost[0] == '[' { // Expect the first ']' just before the last ':'. end := strings.IndexByte(rawHost, ']') if end < 0 { return "", fmt.Errorf("%s: missing ']' in ipv6 address %s : %w", op, addr, ErrInvalidParameter) } trimedIp := strings.Trim(rawHost, "[]") if net.ParseIP(trimedIp) == nil { return "", fmt.Errorf("%s: invalid ipv6 address %s : %w", op, rawHost, ErrInvalidParameter) } // ipv6 literal has enclosing brackets, and it's a valid ipv6 address, so we're good return fmt.Sprintf("%s:%s", rawHost, rawPort), nil } // see if we're dealing with a hostname hostnames, _ := net.LookupHost(rawHost) if len(hostnames) > 0 { if rawHost == "::1" { // special case for localhost return fmt.Sprintf("[%s]:%s", rawHost, rawPort), nil } return fmt.Sprintf("%s:%s", rawHost, rawPort), nil } lastColon = last(rawHost, ':') if lastColon >= 0 { // ipv6 literal without proper brackets ipv6Literal := fmt.Sprintf("[%s]", rawHost) if net.ParseIP(ipv6Literal) == nil { return "", fmt.Errorf("%s: invalid ipv6 address + port %s : %w", op, addr, ErrInvalidParameter) } return fmt.Sprintf("[%s]:%s", ipv6Literal, rawPort), nil } // ipv4 if net.ParseIP(rawHost) == nil { return "", fmt.Errorf("%s: invalid IP address %s : %w", op, rawHost, ErrInvalidParameter) } return fmt.Sprintf("%s:%s", rawHost, rawPort), nil } // Run will run the server which will listen and serve requests. 
// // Options supported: WithTLSConfig func (s *Server) Run(addr string, opt ...Option) error { const op = "gldap.(Server).Run" opts := getConfigOpts(opt...) var err error addr, err = validateAddr(addr) if err != nil { return fmt.Errorf("%s: %w", op, err) } s.mu.Lock() s.listener, err = net.Listen("tcp", addr) s.listenerReady = true s.mu.Unlock() if err != nil { return fmt.Errorf("%s: unable to listen to addr %s: %w", op, addr, err) } if opts.withTLSConfig != nil { s.logger.Debug("setting up TLS listener", "op", op) s.tlsConfig = opts.withTLSConfig s.mu.Lock() s.listener = tls.NewListener(s.listener, s.tlsConfig) s.mu.Unlock() } s.logger.Info("listening", "op", op, "addr", s.listener.Addr()) connID := 0 for { connID++ select { case <-s.shutdownCtx.Done(): return nil default: // need a default to fall through to rest of loop... } c, err := s.listener.Accept() if err != nil { if strings.Contains(err.Error(), "use of closed network connection") { s.logger.Debug("accept on closed conn") return nil } return fmt.Errorf("%s: error accepting conn: %w", op, err) } s.logger.Debug("new connection accepted", "op", op, "conn", connID) conn, err := newConn(s.shutdownCtx, connID, c, s.logger, s.router) if err != nil { return fmt.Errorf("%s: unable to create in-memory conn: %w", op, err) } localConnID := connID s.connWg.Add(1) go func() { defer func() { s.logger.Debug("connWg done", "op", op, "conn", localConnID) s.connWg.Done() err := conn.close() if err != nil { s.logger.Error("error closing conn", "op", op, "conn", localConnID, "conn/req", "err", err) // we are intentionally not returning here; since we still // need to call the onCloseHandler if it's not nil } if s.onCloseHandler != nil { s.onCloseHandler(localConnID) } }() if !s.disablePanicRecovery { // catch and report panics - we don't want it to crash the server if // handling a single conn causes a panic defer func() { if r := recover(); r != nil { s.logger.Error("Caught panic while serving request", "op", op, "conn", 
localConnID, "conn/req", fmt.Sprintf("%+v: %+v", c, r)) } }() } if s.readTimeout != 0 { if err := c.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil { s.logger.Error("unable to set read deadline", "op", op, "err", err.Error()) return } } if s.writeTimeout != 0 { if err := c.SetWriteDeadline(time.Now().Add(s.writeTimeout)); err != nil { s.logger.Error("unable to set write deadline", "op", op, "err", err.Error()) return } } if err := conn.serveRequests(); err != nil { s.logger.Error("error handling conn", "op", op, "conn", localConnID, "err", err.Error()) } }() } } // Ready will return true when the server is ready to accept connection func (s *Server) Ready() bool { s.mu.RLock() defer s.mu.RUnlock() return s.listenerReady } // Stop a running ldap server func (s *Server) Stop() error { const op = "gldap.(Server).Stop" s.mu.RLock() defer s.mu.RUnlock() s.logger.Debug("shutting down") if s.listener == nil && s.shutdownCancel == nil { s.logger.Debug("nothing to do for shutdown") return nil } if s.listener != nil { s.logger.Debug("closing listener") if err := s.listener.Close(); err != nil { switch { case !strings.Contains(err.Error(), "use of closed network connection"): return fmt.Errorf("%s: %w", op, err) default: s.logger.Debug("listener already closed") } } } if s.shutdownCancel != nil { s.logger.Debug("shutdown cancel func") s.shutdownCancel() } s.logger.Debug("waiting on connections to close") s.connWg.Wait() s.logger.Debug("stopped") return nil } // Router sets the mux (multiplexer) router for matching inbound requests // to handlers. func (s *Server) Router(r *Mux) error { const op = "gldap.(Server).HandleRoutes" if r == nil { return fmt.Errorf("%s: missing router: %w", op, ErrInvalidParameter) } s.mu.Lock() defer s.mu.Unlock() s.router = r return nil }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "crypto/tls" "time" "github.com/hashicorp/go-hclog" ) type configOptions struct { withTLSConfig *tls.Config withLogger hclog.Logger withReadTimeout time.Duration withWriteTimeout time.Duration withDisablePanicRecovery bool withOnClose OnCloseHandler } func configDefaults() configOptions { return configOptions{} } // getConfigOpts gets the defaults and applies the opt overrides passed // in. func getConfigOpts(opt ...Option) configOptions { opts := configDefaults() applyOpts(&opts, opt...) return opts } // WithLogger provides the optional logger. func WithLogger(l hclog.Logger) Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { o.withLogger = l } } } // WithTLSConfig provides an optional tls.Config func WithTLSConfig(tc *tls.Config) Option { return func(o interface{}) { switch v := o.(type) { case *configOptions: v.withTLSConfig = tc } } } // WithReadTimeout will set a read time out per connection func WithReadTimeout(d time.Duration) Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { o.withReadTimeout = d } } } // WithWriteTimeout will set a write timeout per connection func WithWriteTimeout(d time.Duration) Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { o.withWriteTimeout = d } } } // WithDisablePanicRecovery will disable recovery from panics which occur when // handling a request. This is helpful for debugging since you'll get the // panic's callstack. func WithDisablePanicRecovery() Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { o.withDisablePanicRecovery = true } } } // OnCloseHandler defines a function for a "on close" callback handler. See: // NewServer(...) and WithOnClose(...) 
option for more information type OnCloseHandler func(connectionID int) // WithOnClose defines a OnCloseHandler that the server will use as a callback // every time a connection to the server is closed. This allows callers to // clean up resources for closed connections (using their ID to determine which // one to clean up) func WithOnClose(handler OnCloseHandler) Option { return func(o interface{}) { if o, ok := o.(*configOptions); ok { o.withOnClose = handler } } }
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT

package gldap

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// SIDBytes encodes a SID with the given revision and identifier authority and
// no sub-authorities. The output is the revision byte, a zero sub-authority
// count byte, then three big-endian uint16 words forming the identifier
// authority, of which only the last (lowest) word is set from
// identifierAuthority.
func SIDBytes(revision uint8, identifierAuthority uint16) ([]byte, error) {
	const op = "gldap.SidBytes"
	const noSubAuthorities = uint8(0)
	authority := [3]uint16{0, 0, identifierAuthority}

	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, revision); err != nil {
		return nil, fmt.Errorf("%s: unable to write revision: %w", op, err)
	}
	if err := binary.Write(&buf, binary.LittleEndian, noSubAuthorities); err != nil {
		return nil, fmt.Errorf("%s: unable to write subauthority count: %w", op, err)
	}
	if err := binary.Write(&buf, binary.BigEndian, authority); err != nil {
		return nil, fmt.Errorf("%s: unable to write authority parts: %w", op, err)
	}
	return buf.Bytes(), nil
}

// SIDBytesToString renders binary SID bytes as "S-<revision>-<authority>"
// followed by "-<subAuthority>" for each little-endian uint32 sub-authority.
// The authority is assembled from three big-endian uint16 words. An error is
// returned when b is too short for the fields it declares.
func SIDBytesToString(b []byte) (string, error) {
	const op = "gldap.sidBytesToString"
	rd := bytes.NewReader(b)

	var revision, subAuthorityCount uint8
	if err := binary.Read(rd, binary.LittleEndian, &revision); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading Revision: %w", op, b, err)
	}
	if err := binary.Read(rd, binary.LittleEndian, &subAuthorityCount); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading SubAuthorityCount: %w", op, b, err)
	}

	var authorityParts [3]uint16
	if err := binary.Read(rd, binary.BigEndian, &authorityParts); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading IdentifierAuthority: %w", op, b, err)
	}
	authority := uint64(authorityParts[0])<<32 | uint64(authorityParts[1])<<16 | uint64(authorityParts[2])

	subAuthorities := make([]uint32, subAuthorityCount)
	if err := binary.Read(rd, binary.LittleEndian, &subAuthorities); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading SubAuthority: %w", op, b, err)
	}

	var out bytes.Buffer
	fmt.Fprintf(&out, "S-%d-%d", revision, authority)
	for _, part := range subAuthorities {
		fmt.Fprintf(&out, "-%d", part)
	}
	return out.String(), nil
}
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package testdirectory import ( "bytes" "crypto/tls" "crypto/x509" "encoding/pem" "fmt" "regexp" "strings" "sync" "time" "github.com/cenkalti/backoff" "github.com/go-ldap/ldap/v3" "github.com/hashicorp/go-hclog" "github.com/jimlambrt/gldap" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" ) const ( // DefaultUserAttr is the "username" attribute of the entry's DN and is // typically either the cn in ActiveDirectory or uid in openLDAP (default: // cn) DefaultUserAttr = "cn" // DefaultGroupAttr for the ClientConfig.GroupAttr DefaultGroupAttr = "cn" // DefaultUserDN defines a default base distinguished name to use when // searching for users for the Directory DefaultUserDN = "ou=people,dc=example,dc=org" // DefaultGroupDN defines a default base distinguished name to use when // searching for groups for the Directory DefaultGroupDN = "ou=groups,dc=example,dc=org" ) // Directory is a local ldap directory that supports test ldap capabilities // which makes writing tests much easier. // // It's important to remember that the Directory is stateful (see any of its // receiver functions that begin with Set*) // // Once you started a Directory with Start(...), the following // test ldap operations are supported: // // - Bind // - StartTLS // - Search // - Modify // - Add // // Making requests to the Directory is facilitated by: // - Directory.Conn() returns a *ldap.Conn connected to the Directory (honors WithMTLS options from start) // - Directory.Cert() returns the pem-encoded CA certificate used by the directory. // - Directory.Port() returns the port the directory is listening on. 
// - Directory.ClientCert() returns a client cert for mtls // - Directory.ClientKey() returns a client private key for mtls type Directory struct { t TestingT s *gldap.Server logger hclog.Logger port int host string useTLS bool client *tls.Config server *tls.Config mu sync.Mutex users []*gldap.Entry groups []*gldap.Entry tokenGroups map[string][]*gldap.Entry // string == SID allowAnonymousBind bool controls []gldap.Control // userDN is the base distinguished name to use when searching for users userDN string // groupDN is the base distinguished name to use when searching for groups groupDN string } // Start creates and starts a running Directory ldap server. // Support options: WithPort, WithMTLS, WithNoTLS, WithDefaults, // WithLogger. // // The Directory will be shutdown when the test and all its // subtests are compted via a registered function with t.Cleanup(...) func Start(t TestingT, opt ...Option) *Directory { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } require := require.New(t) opts := getOpts(t, opt...) if opts.withPort == 0 { opts.withPort = FreePort(t) } d := &Directory{ t: t, logger: opts.withLogger, users: opts.withDefaults.Users, groups: opts.withDefaults.Groups, port: opts.withPort, host: opts.withHost, userDN: opts.withDefaults.UserDN, groupDN: opts.withDefaults.GroupDN, allowAnonymousBind: opts.withDefaults.AllowAnonymousBind, } var err error var srvOpts []gldap.Option if opts.withLogger != nil { srvOpts = append(srvOpts, gldap.WithLogger(opts.withLogger)) } if opts.withDisablePanicRecovery { srvOpts = append(srvOpts, gldap.WithDisablePanicRecovery()) } d.s, err = gldap.NewServer(srvOpts...) 
require.NoError(err) d.logger.Debug("base search DNs", "users", d.userDN, "groups", d.groupDN) mux, err := gldap.NewMux() require.NoError(err) require.NoError(mux.DefaultRoute(d.handleNotFound(t))) require.NoError(mux.Bind(d.handleBind(t))) require.NoError(mux.ExtendedOperation(d.handleStartTLS(t), gldap.ExtendedOperationStartTLS)) require.NoError(mux.Search(d.handleSearchUsers(t), gldap.WithBaseDN(d.userDN), gldap.WithLabel("Search - Users"))) require.NoError(mux.Search(d.handleSearchGroups(t), gldap.WithBaseDN(d.groupDN), gldap.WithLabel("Search - Groups"))) require.NoError(mux.Search(d.handleSearchGeneric(t), gldap.WithLabel("Search - Generic"))) require.NoError(mux.Modify(d.handleModify(t), gldap.WithLabel("Modify"))) require.NoError(mux.Add(d.handleAdd(t), gldap.WithLabel("Add"))) require.NoError(mux.Delete(d.handleDelete(t), gldap.WithLabel("Delete"))) require.NoError(d.s.Router(mux)) serverTLSConfig, clientTLSConfig := GetTLSConfig(t, opt...) d.client = clientTLSConfig d.server = serverTLSConfig var connOpts []gldap.Option if !opts.withNoTLS { d.useTLS = true connOpts = append(connOpts, gldap.WithTLSConfig(d.server)) if opts.withMTLS { d.logger.Debug("using mTLS") } else { d.logger.Debug("using TLS") } } else { d.logger.Debug("not using TLS") } go func() { err := d.s.Run(fmt.Sprintf("%s:%d", opts.withHost, opts.withPort), connOpts...) if err != nil { d.logger.Error("Error during shutdown", "op", "testdirectory.Start", "err", err.Error()) } }() if v, ok := interface{}(t).(CleanupT); ok { v.Cleanup(func() { _ = d.s.Stop() }) } // need a bit of a pause to get the service up and running, otherwise we'll // get a connection error because the service isn't listening yet. for { time.Sleep(100 * time.Nanosecond) if d.s.Ready() { break } } return d } // Stop will stop the Directory if it wasn't started with a *testing.T // if it was started with *testing.T then Stop() is ignored. 
func (d *Directory) Stop() { const op = "testdirectory.(Directory).Stop" if _, ok := interface{}(d.t).(CleanupT); !ok { err := d.s.Stop() if err != nil { d.logger.Error("error stopping directory: %s", "op", op, "err", err) return } } } // handleBind is ONLY supporting simple authentication (no SASL here!) func (d *Directory) handleBind(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) { const op = "testdirectory.(Directory).handleBind" if v, ok := interface{}(t).(HelperT); ok { v.Helper() } return func(w *gldap.ResponseWriter, r *gldap.Request) { d.logger.Debug(op) resp := r.NewBindResponse(gldap.WithResponseCode(gldap.ResultInvalidCredentials)) defer func() { _ = w.Write(resp) }() m, err := r.GetSimpleBindMessage() if err != nil { d.logger.Error("not a simple bind message", "op", op, "err", err) return } if m.AuthChoice != gldap.SimpleAuthChoice { // if it's not a simple auth request, then the bind failed... return } if m.Password == "" && d.allowAnonymousBind { resp.SetResultCode(gldap.ResultSuccess) return } for _, u := range d.users { d.logger.Debug("user", "u.DN", u.DN, "m.UserName", m.UserName) if u.DN == m.UserName { d.logger.Debug("found bind user", "op", op, "DN", u.DN) values := u.GetAttributeValues("password") if len(values) > 0 && string(m.Password) == values[0] { resp.SetResultCode(gldap.ResultSuccess) if d.controls != nil { d.mu.Lock() defer d.mu.Unlock() resp.SetControls(d.controls...) } return } } } // bind failed... 
return //nolint:gosimple // (ignore redundant return) } } func (d *Directory) handleNotFound(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) { const op = "testdirectory.(Directory).handleNotFound" if v, ok := interface{}(t).(HelperT); ok { v.Helper() } return func(w *gldap.ResponseWriter, r *gldap.Request) { d.logger.Debug(op) resp := r.NewResponse(gldap.WithDiagnosticMessage("intentionally not handled")) _ = w.Write(resp) return //nolint:gosimple // (ignore redundant return) } } func (d *Directory) handleStartTLS(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) { const op = "testdirectory.(Directory).handleStartTLS" if v, ok := interface{}(t).(HelperT); ok { v.Helper() } return func(w *gldap.ResponseWriter, r *gldap.Request) { d.logger.Debug(op) res := r.NewExtendedResponse(gldap.WithResponseCode(gldap.ResultSuccess)) res.SetResponseName(gldap.ExtendedOperationStartTLS) err := w.Write(res) if err != nil { d.logger.Error("error writing response: %s", "op", op, "err", err) return } if err := r.StartTLS(d.server); err != nil { d.logger.Error("StartTLS Handshake error", "op", op, "err", err) res.SetDiagnosticMessage(fmt.Sprintf("StartTLS Handshake error : \"%s\"", err.Error())) res.SetResultCode(gldap.ResultOperationsError) err := w.Write(res) if err != nil { d.logger.Error("error writing response: %s", "op", op, "err", err) return } return } d.logger.Debug("StartTLS OK", "op", op) } } func (d *Directory) handleSearchGeneric(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) { const op = "testdirectory.(Directory).handleSearchGeneric" if v, ok := interface{}(t).(HelperT); ok { v.Helper() } return func(w *gldap.ResponseWriter, r *gldap.Request) { d.logger.Debug(op) res := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject)) defer func() { err := w.Write(res) if err != nil { d.logger.Error("error writing response: %s", "op", op, "err", err) return } }() m, err := r.GetSearchMessage() if err != nil { 
d.logger.Error("not a search message: %s", "op", op, "err", err) return } d.logSearchRequest(m) filter := m.Filter // if our search base is the base userDN, we're searching for a single // user, so adjust the filter to match user's entries if strings.Contains(string(m.BaseDN), d.userDN) { filter = fmt.Sprintf("(%s)", m.BaseDN) d.logger.Debug("new filter", "op", op, "value", filter) for _, a := range m.Attributes { d.logger.Debug("attr", "op", op, "value", a) if a == "tokenGroups" { d.logger.Debug("asking for groups", "op", op) } } } var foundEntries int // if our search base is a SID, then we're searching for tokenGroups if len(d.tokenGroups) > 0 && strings.HasPrefix(string(m.BaseDN), "<SID=") { sid := string(m.BaseDN) sid = strings.TrimPrefix(sid, "<SID=") sid = strings.TrimSuffix(sid, ">") for _, g := range d.tokenGroups[sid] { d.logger.Debug("found tokenGroup", "op", op, "group DN", g.DN) result := r.NewSearchResponseEntry(g.DN) for _, attr := range g.Attributes { result.AddAttribute(attr.Name, attr.Values) } foundEntries += 1 err = w.Write(result) if err != nil { d.logger.Error("error writing result: %s", "op", op, "err", err) return } } d.logger.Debug("found entries", "op", op, "count", foundEntries) res.SetResultCode(gldap.ResultSuccess) return } d.logger.Debug("filter", "op", op, "value", filter) var entries []*gldap.Entry for _, e := range d.users { if ok, _ := match(filter, e.DN); !ok { continue } entries = append(entries, e) foundEntries += 1 } for _, e := range d.groups { if ok, _ := match(filter, e.DN); !ok { continue } switch { case slices.Contains(entries, e): continue default: entries = append(entries, e) foundEntries += 1 } } if foundEntries > 0 { d.logger.Debug("found entries", "op", op, "count", foundEntries) for _, e := range entries { result := r.NewSearchResponseEntry(e.DN) for _, attr := range e.Attributes { result.AddAttribute(attr.Name, attr.Values) } foundEntries += 1 err := w.Write(result) if err != nil { d.logger.Error("error writing 
result: %s", "op", op, "err", err)
					return
				}
			}
			if d.controls != nil {
				d.mu.Lock()
				defer d.mu.Unlock()
				res.SetControls(d.controls...)
			}
			res.SetResultCode(gldap.ResultSuccess)
		}
	}
}

// handleSearchGroups returns a search handler that answers with the group
// entries relevant to the request's filter: groups that list a matching member
// (see findMembers) plus groups whose DN matches the filter (see match). Each
// group is returned at most once. When nothing matches, the search completes
// with ResultNoSuchObject.
func (d *Directory) handleSearchGroups(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleSearchGroups"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		// the search-done response is always written last, after any entries
		defer func() {
			err := w.Write(res)
			if err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}()
		m, err := r.GetSearchMessage()
		if err != nil {
			d.logger.Error("not a search message: %s", "op", op, "err", err)
			return
		}
		d.logSearchRequest(m)

		_, entries := d.findMembers(m.Filter)
		for _, e := range d.groups {
			if ok, _ := match(m.Filter, e.DN); !ok {
				continue
			}
			// a group may already be present because one of its members matched
			if !slices.Contains(entries, e) {
				entries = append(entries, e)
			}
		}
		if len(entries) == 0 {
			return
		}
		for _, e := range entries {
			result := r.NewSearchResponseEntry(e.DN)
			for _, attr := range e.Attributes {
				result.AddAttribute(attr.Name, attr.Values)
			}
			if err := w.Write(result); err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}
		d.logger.Debug("found entries", "op", op, "count", len(entries))
		if d.controls != nil {
			d.mu.Lock()
			defer d.mu.Unlock()
			res.SetControls(d.controls...)
		}
		res.SetResultCode(gldap.ResultSuccess)
	}
}

// handleSearchUsers returns a search handler that answers with every user
// entry whose DN matches the request's filter (see match). When nothing
// matches, the search completes with ResultNoSuchObject.
func (d *Directory) handleSearchUsers(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleSearchUsers"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		// the search-done response is always written last, after any entries
		defer func() {
			err := w.Write(res)
			if err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}()
		m, err := r.GetSearchMessage()
		if err != nil {
			d.logger.Error("not a search message: %s", "op", op, "err", err)
			return
		}
		d.logSearchRequest(m)

		_, _, entries := find(d.t, m.Filter, d.users)
		if len(entries) == 0 {
			return
		}
		for _, e := range entries {
			result := r.NewSearchResponseEntry(e.DN)
			for _, attr := range e.Attributes {
				result.AddAttribute(attr.Name, attr.Values)
			}
			if err := w.Write(result); err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}
		d.logger.Debug("found entries", "op", op, "count", len(entries))
		if d.controls != nil {
			d.mu.Lock()
			defer d.mu.Unlock()
			res.SetControls(d.controls...)
		}
		res.SetResultCode(gldap.ResultSuccess)
	}
}

// handleModify returns a modify handler that applies the request's changes to
// the single user or group entry whose DN matches the request DN. It responds
// with ResultNoSuchObject when no entry matches and with
// ResultInappropriateMatching when more than one does.
//
// Change semantics follow RFC 4511 section 4.6:
//   - add: append the values to the attribute, creating it if needed
//   - delete: remove the whole attribute (listed values are not honored)
//   - replace: overwrite the attribute's values, creating it if needed; a
//     replace with no values removes the attribute
func (d *Directory) handleModify(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleModify"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewModifyResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		defer func() {
			err := w.Write(res)
			if err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}()
		m, err := r.GetModifyMessage()
		if err != nil {
			d.logger.Error("not a modify message: %s", "op", op, "err", err)
			return
		}
		d.logger.Info("modify request", "dn", m.DN)

		// match only considers parenthesized filter components, so the DN has
		// to be wrapped for both the user and the group lookup.
		dnFilter := fmt.Sprintf("(%s)", m.DN)
		_, _, entries := find(d.t, dnFilter, d.users)
		if len(entries) == 0 {
			_, _, entries = find(d.t, dnFilter, d.groups)
		}
		if len(entries) == 0 {
			return
		}
		if len(entries) > 1 {
			res.SetResultCode(gldap.ResultInappropriateMatching)
			res.SetDiagnosticMessage(fmt.Sprintf("more than one match: %d entries", len(entries)))
			return
		}

		d.mu.Lock()
		defer d.mu.Unlock()

		e := entries[0]
		if e.Attributes == nil {
			e.Attributes = []*gldap.EntryAttribute{}
		}
		res.SetMatchedDN(e.DN)
		for _, chg := range m.Changes {
			// locate the attribute being changed (last match wins); -1 if absent
			foundAt := -1
			for i, a := range e.Attributes {
				if a.Name == chg.Modification.Type {
					foundAt = i
				}
			}
			switch chg.Operation {
			case gldap.AddAttribute:
				if foundAt >= 0 {
					e.Attributes[foundAt].AddValue(chg.Modification.Vals...)
				} else {
					e.Attributes = append(e.Attributes, gldap.NewEntryAttribute(chg.Modification.Type, chg.Modification.Vals))
				}
			case gldap.DeleteAttribute:
				if foundAt >= 0 {
					e.Attributes = append(e.Attributes[:foundAt], e.Attributes[foundAt+1:]...)
				}
			case gldap.ReplaceAttribute:
				switch {
				case len(chg.Modification.Vals) == 0:
					// an empty replace deletes the attribute if present
					if foundAt >= 0 {
						e.Attributes = append(e.Attributes[:foundAt], e.Attributes[foundAt+1:]...)
					}
				case foundAt >= 0:
					// assign through the slice; reassigning a local copy of the
					// pointer would leave the entry unchanged
					e.Attributes[foundAt] = gldap.NewEntryAttribute(chg.Modification.Type, chg.Modification.Vals)
				default:
					e.Attributes = append(e.Attributes, gldap.NewEntryAttribute(chg.Modification.Type, chg.Modification.Vals))
				}
			}
		}
		res.SetResultCode(gldap.ResultSuccess)
	}
}

// handleAdd returns an add handler that stores the requested entry as a new
// user entry. It responds with ResultEntryAlreadyExists when a user DN already
// matches the request DN. Any failure before the entry is stored leaves the
// response at ResultOperationsError.
func (d *Directory) handleAdd(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleAdd"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewResponse(gldap.WithApplicationCode(gldap.ApplicationAddResponse), gldap.WithResponseCode(gldap.ResultOperationsError))
		defer func() {
			err := w.Write(res)
			if err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}()
		m, err := r.GetAddMessage()
		if err != nil {
			d.logger.Error("not an add message: %s", "op", op, "err", err)
			return
		}
		d.logger.Info("add request", "dn", m.DN)

		if found, _, _ := find(d.t, fmt.Sprintf("(%s)", m.DN), d.users); found {
			res.SetResultCode(gldap.ResultEntryAlreadyExists)
			res.SetDiagnosticMessage(fmt.Sprintf("entry exists for DN: %s", m.DN))
			return
		}
		attrs := map[string][]string{}
		for _, a := range m.Attributes {
			attrs[a.Type] = a.Vals
		}
		newEntry := gldap.NewEntry(m.DN, attrs)
		d.mu.Lock()
		defer d.mu.Unlock()
		d.users = append(d.users, newEntry)
		res.SetResultCode(gldap.ResultSuccess)
	}
}

// handleDelete returns a delete handler that removes the single user entry,
// or failing that the single group entry, whose DN matches the request DN.
// It responds with ResultNoSuchObject when nothing matches and with
// ResultInappropriateMatching when more than one entry matches.
func (d *Directory) handleDelete(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleDelete"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject), gldap.WithApplicationCode(gldap.ApplicationDelResponse))
		defer func() {
			err := w.Write(res)
			if err != nil {
				d.logger.Error("error writing response: %s", "op", op, "err", err)
				return
			}
		}()
		m, err := r.GetDeleteMessage()
		if err != nil {
			d.logger.Error("not a delete message: %s", "op", op, "err", err)
			return
		}
		d.logger.Info("delete request", "dn", m.DN)

		dnFilter := fmt.Sprintf("(%s)", m.DN)
		_, foundAt, _ := find(d.t, dnFilter, d.users)
		if len(foundAt) > 0 {
			if len(foundAt) > 1 {
				res.SetResultCode(gldap.ResultInappropriateMatching)
				res.SetDiagnosticMessage(fmt.Sprintf("more than one match: %d entries", len(foundAt)))
				return
			}
			d.mu.Lock()
			defer d.mu.Unlock()
			d.users = append(d.users[:foundAt[0]], d.users[foundAt[0]+1:]...)
			res.SetResultCode(gldap.ResultSuccess)
			return
		}

		_, foundAt, _ = find(d.t, dnFilter, d.groups)
		if len(foundAt) > 0 {
			if len(foundAt) > 1 {
				res.SetResultCode(gldap.ResultInappropriateMatching)
				res.SetDiagnosticMessage(fmt.Sprintf("more than one match: %d entries", len(foundAt)))
				return
			}
			d.mu.Lock()
			defer d.mu.Unlock()
			d.groups = append(d.groups[:foundAt[0]], d.groups[foundAt[0]+1:]...)
			res.SetResultCode(gldap.ResultSuccess)
		}
	}
}

// findMembers returns the groups that have at least one "member" value
// matching filter (see match). Each group appears at most once. With the
// withFirst option it stops at the first matching group. The bool reports
// whether anything matched.
func (d *Directory) findMembers(filter string, opt ...Option) (bool, []*gldap.Entry) {
	opts := getOpts(d.t, opt...)
	var matches []*gldap.Entry
	for _, e := range d.groups {
		members := e.GetAttributeValues("member")
		for _, m := range members {
			if ok, _ := match(filter, "member="+m); ok {
				matches = append(matches, e)
				if opts.withFirst {
					return true, matches
				}
				// one matching member is enough; don't add the group again
				break
			}
		}
	}
	if len(matches) > 0 {
		return true, matches
	}
	return false, nil
}

// find returns the entries whose DN matches filter (see match) together with
// their indexes in entries. With the withFirst option only the first match is
// returned. The bool reports whether anything matched; on no match the slices
// are nil.
func find(t TestingT, filter string, entries []*gldap.Entry, opt ...Option) (bool, []int, []*gldap.Entry) {
	opts := getOpts(t, opt...)
	var matches []*gldap.Entry
	var matchIndexes []int
	for idx, e := range entries {
		if ok, _ := match(filter, e.DN); ok {
			matches = append(matches, e)
			matchIndexes = append(matchIndexes, idx)
			if opts.withFirst {
				return true, []int{idx}, matches
			}
		}
	}
	if len(matches) > 0 {
		return true, matchIndexes, matches
	}
	return false, nil, nil
}

// filterComponentRe extracts the parenthesized components of a filter string.
// It is compiled once at package init instead of on every match call.
var filterComponentRe = regexp.MustCompile(`\((.*?)\)`)

// match reports whether attr contains any parenthesized component of filter,
// after '*' wildcards and surrounding "|", "(" and ")" characters are removed.
// This is a deliberately loose stand-in for real LDAP filter evaluation: it
// ignores operators and negation, and a component that is empty after
// stripping (e.g. "(*)") matches every attr. The error is always nil.
func match(filter string, attr string) (bool, error) {
	// TODO: make this actually do something more reasonable with the search
	// request filter
	for _, element := range filterComponentRe.FindAllString(filter, -1) {
		element = strings.ReplaceAll(element, "*", "")
		element = strings.Trim(element, "|(")
		element = strings.Trim(element, "(")
		element = strings.Trim(element, ")")
		element = strings.TrimSpace(element)
		if strings.Contains(attr, element) {
			return true, nil
		}
	}
	return false, nil
}

// Conn returns an *ldap.Conn that's connected (using whatever tls.Config is
// appropriate for the directory) and ready to send requests to the directory.
// A failed dial is retried up to five more times, one second apart, before
// the test is failed.
func (d *Directory) Conn() *ldap.Conn {
	if v, ok := interface{}(d.t).(HelperT); ok {
		v.Helper()
	}
	require := require.New(d.t)

	var conn *ldap.Conn
	retryAttempt := 5
	// retryErrFn returns the error unchanged (so backoff.Retry tries again)
	// while attempts remain, then marks it permanent so Retry gives up.
	retryErrFn := func(e error) error {
		if retryAttempt > 0 {
			retryAttempt--
			return e
		}
		return backoff.Permanent(e)
	}
	err := backoff.Retry(func() error {
		var connErr error
		if d.useTLS {
			if conn, connErr = ldap.DialURL(fmt.Sprintf("ldaps://%s:%d", d.Host(), d.Port()), ldap.DialWithTLSConfig(d.client)); connErr != nil {
				return retryErrFn(connErr)
			}
			return nil
		}
		if conn, connErr = ldap.DialURL(fmt.Sprintf("ldap://%s:%d", d.Host(), d.Port())); connErr != nil {
			return retryErrFn(connErr)
		}
		return nil
	}, backoff.NewConstantBackOff(1*time.Second))
	require.NoError(err)
	return conn
}

// Cert returns the pem-encoded certificate used by the Directory.
func (d *Directory) Cert() string { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } require := require.New(d.t) require.NotNil(d.server) require.Len(d.server.Certificates, 1) cert := d.server.Certificates[0] require.NotNil(cert) require.Len(cert.Certificate, 1) var buf bytes.Buffer err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]}) require.NoError(err) return buf.String() } // Port returns the port the directory is listening on func (d *Directory) Port() int { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } return d.port } // Host returns the host the directory is listening on func (d *Directory) Host() string { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } return d.host } // ClientCert returns the pem-encoded certificate which can be used by a client // for mTLS. func (d *Directory) ClientCert() string { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } require := require.New(d.t) require.NotNil(d.client) require.Len(d.client.Certificates, 1) cert := d.client.Certificates[0] require.NotNil(cert) require.Len(cert.Certificate, 1) var buf bytes.Buffer err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]}) require.NoError(err) return buf.String() } // ClientKey returns the pem-encoded private key which can be used by a client // for mTLS. func (d *Directory) ClientKey() string { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } require := require.New(d.t) require.NotNil(d.client) require.Len(d.client.Certificates, 1) privBytes, err := x509.MarshalPKCS8PrivateKey(d.client.Certificates[0].PrivateKey) require.NoError(err) pemKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) require.NotNil(pemKey) return string(pemKey) } // Controls returns all the current bind controls for the Directory func (d *Directory) Controls() []gldap.Control { return d.controls } // SetControls sets the bind controls. 
func (d *Directory) SetControls(controls ...gldap.Control) { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } d.mu.Lock() defer d.mu.Unlock() d.controls = controls } // Users returns all the current user entries in the Directory func (d *Directory) Users() []*gldap.Entry { return d.users } // SetUsers sets the user entries. func (d *Directory) SetUsers(users ...*gldap.Entry) { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } d.mu.Lock() defer d.mu.Unlock() d.users = users } // Groups returns all the current group entries in the Directory func (d *Directory) Groups() []*gldap.Entry { return d.groups } // SetGroups sets the group entries. func (d *Directory) SetGroups(groups ...*gldap.Entry) { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } d.mu.Lock() defer d.mu.Unlock() d.groups = groups } // SetTokenGroups will set the tokenGroup entries. func (d *Directory) SetTokenGroups(tokenGroups map[string][]*gldap.Entry) { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } d.mu.Lock() defer d.mu.Unlock() d.tokenGroups = tokenGroups } // TokenGroups will return the tokenGroup entries func (d *Directory) TokenGroups() map[string][]*gldap.Entry { return d.tokenGroups } // AllowAnonymousBind returns the allow anon bind setting func (d *Directory) AllowAnonymousBind() bool { return d.allowAnonymousBind } // SetAllowAnonymousBind enables/disables anon binds func (d *Directory) SetAllowAnonymousBind(enabled bool) { if v, ok := interface{}(d.t).(HelperT); ok { v.Helper() } d.mu.Lock() defer d.mu.Unlock() d.allowAnonymousBind = enabled } func (d *Directory) logSearchRequest(m *gldap.SearchMessage) { d.logger.Info("search request", "baseDN", m.BaseDN, "scope", m.Scope, "filter", m.Filter, "attributes", m.Attributes, ) }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package testdirectory import ( "strings" "github.com/hashicorp/go-hclog" "github.com/jimlambrt/gldap" ) // Option defines a common functional options type which can be used in a // variadic parameter pattern. type Option func(interface{}) // getOpts gets the defaults and applies the opt overrides passed in func getOpts(t TestingT, opt ...Option) options { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } opts := defaults(t) applyOpts(&opts, opt...) return opts } // applyOpts takes a pointer to the options struct as a set of default options // and applies the slice of opts as overrides. func applyOpts(opts interface{}, opt ...Option) { for _, o := range opt { if o == nil { // ignore any nil Options continue } o(opts) } } // options are the set of available options for test functions type options struct { withPort int withHost string withLogger hclog.Logger withNoTLS bool withMTLS bool withDisablePanicRecovery bool withDefaults *Defaults withMembersOf []string withTokenGroupSIDs [][]byte withFirst bool } func defaults(t TestingT) options { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } debugLogger := hclog.New(&hclog.LoggerOptions{ Name: "testdirectory-default-logger", Level: hclog.Error, }) return options{ withLogger: debugLogger, withHost: "localhost", withDefaults: &Defaults{ UserAttr: DefaultUserAttr, GroupAttr: DefaultGroupAttr, UserDN: DefaultUserDN, GroupDN: DefaultGroupDN, }, } } // Defaults define a type for composing all the defaults for Directory.Start(...) 
type Defaults struct { UserAttr string GroupAttr string // Users configures the user entries which are empty by default Users []*gldap.Entry // Groups configures the group entries which are empty by default Groups []*gldap.Entry // TokenGroups configures the tokenGroup entries which are empty be default TokenGroups map[string][]*gldap.Entry // UserDN is the base distinguished name to use when searching for users // which is "ou=people,dc=example,dc=org" by default UserDN string // GroupDN is the base distinguished name to use when searching for groups // which is "ou=groups,dc=example,dc=org" by default GroupDN string // AllowAnonymousBind determines if anon binds are allowed AllowAnonymousBind bool // UPNDomain is the userPrincipalName domain, which enables a // userPrincipalDomain login with [username]@UPNDomain (optional) UPNDomain string } // WithDefaults provides an option to provide a set of defaults to // Directory.Start(...) which make it much more composable. func WithDefaults(t TestingT, defaults *Defaults) Option { return func(o interface{}) { if o, ok := o.(*options); ok { if defaults != nil { if defaults.AllowAnonymousBind { o.withDefaults.AllowAnonymousBind = true } if defaults.Users != nil { o.withDefaults.Users = defaults.Users } if defaults.Groups != nil { o.withDefaults.Groups = defaults.Groups } if defaults.UserDN != "" { o.withDefaults.UserDN = defaults.UserDN } if defaults.GroupDN != "" { o.withDefaults.GroupDN = defaults.GroupDN } if len(defaults.TokenGroups) > 0 { o.withDefaults.TokenGroups = defaults.TokenGroups } if defaults.UserAttr != "" { o.withDefaults.UserAttr = defaults.UserAttr } if defaults.GroupAttr != "" { o.withDefaults.GroupAttr = defaults.GroupAttr } if defaults.UPNDomain != "" { o.withDefaults.UPNDomain = defaults.UPNDomain } } } } } // WithMTLS provides the option to use mTLS for the directory. 
func WithMTLS(t TestingT) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withMTLS = true
	}
}

// WithNoTLS provides the option to not use TLS for the directory.
func WithNoTLS(t TestingT) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withNoTLS = true
	}
}

// WithLogger provides the optional logger for the directory.
func WithLogger(t TestingT, l hclog.Logger) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withLogger = l
	}
}

// WithPort provides an optional port for the directory. 0 causes a
// started server with a random port. Any other value returns a started server
// on that port.
func WithPort(t TestingT, port int) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withPort = port
	}
}

// WithHost provides an optional hostname for the directory. Surrounding
// whitespace is trimmed from host.
func WithHost(t TestingT, host string) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withHost = strings.TrimSpace(host)
	}
}

// withFirst provides the option to only find the first match.
func withFirst(t TestingT) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withFirst = true
	}
}

// WithMembersOf specifies optional memberOf attribute values for user
// entries.
func WithMembersOf(t TestingT, membersOf ...string) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withMembersOf = membersOf
	}
}

// WithTokenGroups specifies optional test tokenGroups SID attribute values for
// user entries.
func WithTokenGroups(t TestingT, tokenGroupSID ...[]byte) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withTokenGroupSIDs = tokenGroupSID
	}
}

// WithDisablePanicRecovery sets the withDisablePanicRecovery option to
// disable.
func WithDisablePanicRecovery(t TestingT, disable bool) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withDisablePanicRecovery = disable
	}
}
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package testdirectory import ( "bytes" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "math/big" "net" "time" "github.com/jimlambrt/gldap" "github.com/stretchr/testify/require" ) // NewMemberOf creates memberOf attributes which can be assigned to user // entries. Supported Options: WithDefaults func NewMemberOf(t TestingT, groupNames []string, opt ...Option) []string { opts := getOpts(t, opt...) DNs := make([]string, 0, len(groupNames)) for _, n := range groupNames { DNs = append(DNs, fmt.Sprintf("%s=%s,%s", opts.withDefaults.GroupAttr, n, opts.withDefaults.GroupDN)) } return DNs } // NewUsers creates user entries. Options supported: WithDefaults, WithMembersOf func NewUsers(t TestingT, userNames []string, opt ...Option) []*gldap.Entry { opts := getOpts(t, opt...) entries := make([]*gldap.Entry, 0, len(userNames)) for _, n := range userNames { entryAttrs := map[string][]string{ "name": {n}, "email": {fmt.Sprintf("%s@example.com", n)}, "password": {"password"}, } if len(opts.withMembersOf) > 0 { entryAttrs["memberOf"] = opts.withMembersOf } if len(opts.withTokenGroupSIDs) > 0 { groups := make([]string, 0, len(opts.withTokenGroupSIDs)) for _, s := range opts.withTokenGroupSIDs { groups = append(groups, string(s)) } entryAttrs["tokenGroups"] = groups } var DN string switch { case opts.withDefaults.UPNDomain != "": DN = fmt.Sprintf("userPrincipalName=%s@%s,%s", n, opts.withDefaults.UPNDomain, opts.withDefaults.UserDN) default: DN = fmt.Sprintf("%s=%s,%s", opts.withDefaults.UserAttr, n, opts.withDefaults.UserDN) } entries = append(entries, gldap.NewEntry( DN, entryAttrs, ), ) } return entries } // NewGroup creates a group entry. Options supported: WithDefaults func NewGroup(t TestingT, groupName string, memberNames []string, opt ...Option) *gldap.Entry { opts := getOpts(t, opt...) 
members := make([]string, 0, len(memberNames)) for _, n := range memberNames { var DN string switch { case opts.withDefaults.UPNDomain != "": DN = fmt.Sprintf("userPrincipalName=%s@%s,%s", n, opts.withDefaults.UPNDomain, opts.withDefaults.UserDN) default: DN = fmt.Sprintf("%s=%s,%s", opts.withDefaults.UserAttr, n, opts.withDefaults.UserDN) } members = append(members, DN) } return gldap.NewEntry( fmt.Sprintf("%s=%s,%s", opts.withDefaults.GroupAttr, groupName, opts.withDefaults.GroupDN), map[string][]string{ "member": members, }) } // FreePort just returns an available free localhost port func FreePort(t TestingT) int { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } require := require.New(t) addr, err := net.ResolveTCPAddr("tcp", "localhost:0") require.NoError(err) l, err := net.ListenTCP("tcp", addr) require.NoError(err) defer l.Close() return l.Addr().(*net.TCPAddr).Port } // supports WithMTLS func GetTLSConfig(t TestingT, opt ...Option) (s *tls.Config, c *tls.Config) { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } require := require.New(t) certSubject := pkix.Name{ Organization: []string{"Acme, INC."}, Country: []string{"US"}, Province: []string{""}, Locality: []string{"New York"}, StreetAddress: []string{"Empire State Building"}, PostalCode: []string{"10118"}, } // set up our CA certificate ca := &x509.Certificate{ SerialNumber: genSerialNumber(t), Subject: certSubject, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1, 0, 0), IsCA: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } caPriv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) require.NoError(err) caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPriv.PublicKey, caPriv) require.NoError(err) caPEM := new(bytes.Buffer) err = pem.Encode(caPEM, &pem.Block{ Type: "CERTIFICATE", Bytes: caBytes, }) require.NoError(err) privBytes, 
err := x509.MarshalPKCS8PrivateKey(caPriv) require.NoError(err) caPrivKeyPEM := new(bytes.Buffer) err = pem.Encode(caPrivKeyPEM, &pem.Block{ Type: "PRIVATE KEY", Bytes: privBytes, }) require.NoError(err) opts := getOpts(t, opt...) var ipAddrs []net.IP if hostIp := net.ParseIP(opts.withHost); hostIp != nil { ipAddrs = append(ipAddrs, hostIp) } cert := &x509.Certificate{ SerialNumber: genSerialNumber(t), Subject: certSubject, IPAddresses: ipAddrs, DNSNames: []string{opts.withHost}, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1, 0, 0), SubjectKeyId: []byte{1, 2, 3, 4, 6}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, } serverCert := genCert(t, ca, caPriv, cert) certpool := x509.NewCertPool() certpool.AppendCertsFromPEM(caPEM.Bytes()) serverTLSConf := &tls.Config{ Certificates: []tls.Certificate{serverCert}, } clientTLSConf := &tls.Config{ RootCAs: certpool, } if opts.withMTLS { // setup mTLS for certs from the ca serverTLSConf.ClientCAs = certpool serverTLSConf.ClientAuth = tls.RequireAndVerifyClientCert cert := &x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: certSubject, EmailAddresses: []string{"mtls.client@example.com"}, NotBefore: time.Now(), NotAfter: time.Now().AddDate(1, 0, 0), SubjectKeyId: []byte{1, 2, 3, 4, 6}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, } clientCert := genCert(t, ca, caPriv, cert) clientTLSConf.Certificates = []tls.Certificate{clientCert} } return serverTLSConf, clientTLSConf } func genCert(t TestingT, ca *x509.Certificate, caPriv interface{}, certTemplate *x509.Certificate) tls.Certificate { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } require := require.New(t) certPrivKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) require.NoError(err) certBytes, err := 
x509.CreateCertificate(rand.Reader, certTemplate, ca, &certPrivKey.PublicKey, caPriv) require.NoError(err) certPEM := new(bytes.Buffer) err = pem.Encode(certPEM, &pem.Block{ Type: "CERTIFICATE", Bytes: certBytes, }) require.NoError(err) privBytes, err := x509.MarshalPKCS8PrivateKey(certPrivKey) require.NoError(err) certPrivKeyPEM := new(bytes.Buffer) err = pem.Encode(certPrivKeyPEM, &pem.Block{ Type: "PRIVATE KEY", Bytes: privBytes, }) require.NoError(err) newCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes()) require.NoError(err) return newCert } func genSerialNumber(t TestingT) *big.Int { if v, ok := interface{}(t).(HelperT); ok { v.Helper() } require := require.New(t) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) require.NoError(err) return serialNumber }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package testdirectory import ( "errors" "github.com/hashicorp/go-hclog" ) // TestingT defines a very slim interface required by a Directory and any // test functions it uses. type TestingT interface { Errorf(format string, args ...interface{}) FailNow() Log(...interface{}) } // CleanupT defines an single function interface for a testing.Cleanup(func()). type CleanupT interface{ Cleanup(func()) } // HelperT defines a single function interface for a testing.Helper() type HelperT interface{ Helper() } // InfofT defines a single function interface for a Info(format string, args ...interface{}) type InfofT interface { Infof(format string, args ...interface{}) } // Logger defines a logger that will implement the TestingT interface so // it can be used with Directory.Start(...) as its t TestingT parameter. type Logger struct { Logger hclog.Logger } // NewLogger makes a new TestingLogger func NewLogger(logger hclog.Logger) (*Logger, error) { if logger == nil { return nil, errors.New("missing logger") } return &Logger{ Logger: logger, }, nil } // Errorf will output the error to the log func (l *Logger) Errorf(format string, args ...interface{}) { l.Logger.Error(format, args...) } // Infof will output the info to the log func (l *Logger) Infof(format string, args ...interface{}) { l.Logger.Info(format, args...) } // FailNow will panic func (l *Logger) FailNow() { panic("testing.T failed, see logs for output (if any)") } func (l *Logger) Log(i ...interface{}) { l.Logger.StandardLogger(&hclog.StandardLoggerOptions{}).Println(i...) }
// Copyright (c) Jim Lambert // SPDX-License-Identifier: MIT package gldap import ( "net" "os" "strings" "sync" "testing" ber "github.com/go-asn1-ber/asn1-ber" "github.com/go-ldap/ldap/v3" "github.com/stretchr/testify/require" ) type testOptions struct { // test options withDescription string } func testDefaults() testOptions { return testOptions{} } func getTestOpts(opt ...Option) testOptions { opts := testDefaults() applyOpts(&opts, opt...) return opts } // WithDescription allows you to specify an optional description. func WithDescription(desc string) Option { return func(o interface{}) { if o, ok := o.(*testOptions); ok { o.withDescription = desc } } } func freePort(t *testing.T) int { t.Helper() require := require.New(t) addr, err := net.ResolveTCPAddr("tcp", "localhost:0") require.NoError(err) l, err := net.ListenTCP("tcp", addr) require.NoError(err) defer l.Close() return l.Addr().(*net.TCPAddr).Port } func testStartTLSRequestPacket(t *testing.T, messageID int) *packet { t.Helper() envelope := testRequestEnvelope(t, int(messageID)) request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) envelope.AppendChild(request) return &packet{ Packet: envelope, } } func testSearchRequestPacket(t *testing.T, s SearchMessage) *packet { t.Helper() require := require.New(t) envelope := testRequestEnvelope(t, int(s.GetID())) pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN")) pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(s.Scope), "Scope")) pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(s.DerefAliases), "Deref Aliases")) 
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(s.SizeLimit), "Size Limit")) pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(s.TimeLimit), "Time Limit")) pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only")) // compile and encode filter filterPacket, err := ldap.CompileFilter(s.Filter) require.NoError(err) pkt.AppendChild(filterPacket) attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") for _, attribute := range s.Attributes { attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute")) } pkt.AppendChild(attributesPacket) envelope.AppendChild(pkt) if len(s.Controls) > 0 { envelope.AppendChild(encodeControls(s.Controls)) } return &packet{ Packet: envelope, } } func testSimpleBindRequestPacket(t *testing.T, m SimpleBindMessage) *packet { t.Helper() envelope := testRequestEnvelope(t, int(m.GetID())) pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request") pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(3), "Version")) pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.UserName, "User Name")) pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, string(m.Password), "Password")) envelope.AppendChild(pkt) if len(m.Controls) > 0 { envelope.AppendChild(encodeControls(m.Controls)) } return &packet{ Packet: envelope, } } func testUnbindRequestPacket(t *testing.T, m UnbindMessage) *packet { t.Helper() envelope := testRequestEnvelope(t, int(m.GetID())) pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationUnbindRequest, nil, "Unbind Request") envelope.AppendChild(pkt) return &packet{ Packet: envelope, } } func testModifyRequestPacket(t *testing.T, 
m ModifyMessage) *packet { t.Helper() envelope := testRequestEnvelope(t, int(m.GetID())) pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN")) changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes") for _, change := range m.Changes { changes.AppendChild(change.encode()) } pkt.AppendChild(changes) envelope.AppendChild(pkt) if len(m.Controls) > 0 { envelope.AppendChild(encodeControls(m.Controls)) } return &packet{ Packet: envelope, } } func testDeleteRequestPacket(t *testing.T, m DeleteMessage) *packet { t.Helper() envelope := testRequestEnvelope(t, int(m.GetID())) pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationDelRequest, nil, "Delete Request") pkt.Data.Write([]byte(m.DN)) envelope.AppendChild(pkt) if len(m.Controls) > 0 { envelope.AppendChild(encodeControls(m.Controls)) } return &packet{ Packet: envelope, } } func testAddRequestPacket(t *testing.T, m AddMessage) *packet { t.Helper() envelope := testRequestEnvelope(t, int(m.GetID())) pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request") pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN")) attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") for _, attr := range m.Attributes { attributes.AppendChild(attr.encode()) } pkt.AppendChild(attributes) envelope.AppendChild(pkt) if len(m.Controls) > 0 { envelope.AppendChild(encodeControls(m.Controls)) } return &packet{ Packet: envelope, } } func testRequestEnvelope(t *testing.T, messageID int) *ber.Packet { p := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") p.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(messageID), 
"MessageID")) return p } func testControlString(t *testing.T, controlType string, opt ...Option) *ControlString { t.Helper() require := require.New(t) c, err := NewControlString(controlType, opt...) require.NoError(err) return c } func testControlManageDsaIT(t *testing.T, opt ...Option) *ControlManageDsaIT { t.Helper() require := require.New(t) c, err := NewControlManageDsaIT(opt...) require.NoError(err) return c } func testControlMicrosoftNotification(t *testing.T, opt ...Option) *ControlMicrosoftNotification { t.Helper() require := require.New(t) c, err := NewControlMicrosoftNotification(opt...) require.NoError(err) return c } func testControlMicrosoftServerLinkTTL(t *testing.T, opt ...Option) *ControlMicrosoftServerLinkTTL { t.Helper() require := require.New(t) c, err := NewControlMicrosoftServerLinkTTL(opt...) require.NoError(err) return c } func testControlMicrosoftShowDeleted(t *testing.T, opt ...Option) *ControlMicrosoftShowDeleted { t.Helper() require := require.New(t) c, err := NewControlMicrosoftShowDeleted(opt...) require.NoError(err) return c } func testControlPaging(t *testing.T, pagingSize uint32, opt ...Option) *ControlPaging { t.Helper() require := require.New(t) c, err := NewControlPaging(uint32(pagingSize), opt...) require.NoError(err) return c } // TestWithDebug specifies that the test should be run under "debug" mode func TestWithDebug(t *testing.T) bool { t.Helper() return strings.ToLower(os.Getenv("DEBUG")) == "true" } func TestEncodeString(t *testing.T, tag ber.Tag, s string, opt ...Option) string { t.Helper() opts := getTestOpts(opt...) 
pkt := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, tag, s, opts.withDescription) dec, err := ber.DecodePacketErr(pkt.Bytes()) require.NoError(t, err) return string(dec.Bytes()) } type safeBuf struct { buf *strings.Builder mu *sync.Mutex } func testSafeBuf(t *testing.T) *safeBuf { t.Helper() return &safeBuf{ mu: &sync.Mutex{}, buf: &strings.Builder{}, } } func (w *safeBuf) Write(p []byte) (n int, err error) { w.mu.Lock() defer w.mu.Unlock() return w.buf.Write(p) } func (w *safeBuf) String() string { w.mu.Lock() defer w.mu.Unlock() return w.buf.String() }