// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
// AddMessage is an add request message. It carries the DN of the entry to
// create, the attributes that entry should have, and any controls sent along
// with the request.
type AddMessage struct {
	// baseMessage is embedded and holds the LDAP message id.
	baseMessage
	// DN identifies the entry being added
	DN string
	// Attributes list the attributes of the new entry
	Attributes []Attribute
	// Controls hold optional controls to send with the request
	Controls []Control
}
// Attribute represents an LDAP attribute within AddMessage: an attribute
// name together with the values assigned to it.
type Attribute struct {
	// Type is the name of the LDAP attribute
	Type string
	// Vals are the LDAP attribute values
	Vals []string
}

// encode builds the BER form of the attribute: a universal SEQUENCE whose
// first child is the attribute name (OCTET STRING) and whose second child is
// a SET holding one OCTET STRING per entry of Vals, in order.
func (a *Attribute) encode() *ber.Packet {
	values := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
	for i := range a.Vals {
		values.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Vals[i], "Vals"))
	}

	attr := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
	attr.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type"))
	attr.AppendChild(values)
	return attr
}
// decodeAttribute converts a BER packet into an Attribute.
//
// The packet is expected to be a universal, constructed SEQUENCE. Child 0 is
// the attribute type (a primitive OCTET STRING). Child 1 is a constructed SET
// whose children are primitive OCTET STRING values.
//
// Errors:
//   - a nil packet, or a failed check on the SEQUENCE, the type, or the SET,
//     returns an error wrapping ErrInvalidParameter;
//   - a malformed individual value returns an error wrapping that value's
//     assertion error.
func decodeAttribute(berPacket *ber.Packet) (*Attribute, error) {
	const op = "gldap.decodeAttribute"
	// positions of the fields inside the attribute SEQUENCE
	const (
		typeIdx = 0
		valsIdx = 1
	)
	if berPacket == nil {
		return nil, fmt.Errorf("%s: missing ber packet: %w", op, ErrInvalidParameter)
	}

	seq := &packet{Packet: berPacket}
	// The structural checks run in order; the first failing one wins.
	switch {
	case seq.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence)) != nil:
		return nil, fmt.Errorf("%s: missing/invalid attributes ber packet: %w", op, ErrInvalidParameter)
	case seq.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(typeIdx)) != nil:
		return nil, fmt.Errorf("%s: missing/invalid attributes type: %w", op, ErrInvalidParameter)
	case seq.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSet), withAssertChild(valsIdx)) != nil:
		return nil, fmt.Errorf("%s: missing/invalid attributes values: %w", op, ErrInvalidParameter)
	}

	vals := &packet{Packet: seq.Children[valsIdx]}
	attr := &Attribute{
		Type: seq.Children[typeIdx].Data.String(),
		Vals: make([]string, 0, len(vals.Children)),
	}
	for i, child := range vals.Children {
		if err := vals.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(i)); err != nil {
			return nil, fmt.Errorf("%s: invalid attribute values packet: %w", op, err)
		}
		attr.Vals = append(attr.Vals, child.Data.String())
	}
	return attr, nil
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/hashicorp/go-hclog"
)
// conn is a connection to an ldap client. It wraps the accepted net.Conn in a
// buffered reader and writer, and tracks the goroutines serving its requests
// so that close can wait for them.
type conn struct {
	// mu guards reading from reader (readPacket) and swapping netConn,
	// reader and writer (initConn).
	mu sync.Mutex // mutex for the conn
	// connID is the non-zero id of this connection, used in logs and errors.
	connID int
	// netConn is the underlying client connection.
	netConn net.Conn
	logger  hclog.Logger
	// router dispatches decoded requests to handlers.
	router *Mux
	// shutdownCtx is checked before each read; once it is done,
	// serveRequests sends a "server stopping" response and returns.
	shutdownCtx context.Context
	// requestsWg counts request goroutines that are still running.
	requestsWg sync.WaitGroup
	reader     *bufio.Reader
	writer     *bufio.Writer
	writerMu   sync.Mutex // shared lock across all ResponseWriters to prevent write data races
}
// newConn will create a new Conn from an accepted net.Conn which will be used
// to serve requests to an ldap client.
//
// Every argument is required. A nil context, zero connID, nil net.Conn, nil
// logger or nil router returns an error wrapping ErrInvalidParameter. The
// returned conn already has its buffered reader and writer set up over
// netConn.
func newConn(shutdownCtx context.Context, connID int, netConn net.Conn, logger hclog.Logger, router *Mux) (*conn, error) {
	const op = "gldap.NewConn"
	switch {
	case shutdownCtx == nil:
		return nil, fmt.Errorf("%s: missing shutdown context: %w", op, ErrInvalidParameter)
	case connID == 0:
		return nil, fmt.Errorf("%s: missing connection id: %w", op, ErrInvalidParameter)
	case netConn == nil:
		return nil, fmt.Errorf("%s: missing connection: %w", op, ErrInvalidParameter)
	case logger == nil:
		return nil, fmt.Errorf("%s: missing logger: %w", op, ErrInvalidParameter)
	case router == nil:
		return nil, fmt.Errorf("%s: missing router: %w", op, ErrInvalidParameter)
	}

	c := &conn{
		shutdownCtx: shutdownCtx,
		connID:      connID,
		netConn:     netConn,
		logger:      logger,
		router:      router,
	}
	if err := c.initConn(netConn); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	return c, nil
}
// serveRequests until the connection is closed or the shutdownCtx is cancelled
// as the server stops.
//
// Each loop iteration creates a fresh ResponseWriter, which shares c.writer
// and c.writerMu, with an incrementing requestID. It then reads one request
// and dispatches it:
//   - UnbindRequest: run the optional unbind route, if registered, and
//     return nil, which ends the loop.
//   - StartTLS extended operation: served synchronously on this goroutine.
//   - anything else: served on a new goroutine tracked by c.requestsWg.
//
// It returns nil when the client disconnects (EOF-style read error), unbinds,
// or the shutdown notice is written successfully. Any other failure returns
// a wrapped error.
func (c *conn) serveRequests() error {
	const op = "gldap.serveRequests"
	requestID := 0
	for {
		requestID++
		w, err := newResponseWriter(c.writer, &c.writerMu, c.logger, c.connID, requestID)
		if err != nil {
			return fmt.Errorf("%s: %w", op, err)
		}
		select {
		case <-c.shutdownCtx.Done():
			c.logger.Debug("received shutdown cancellation", "op", op, "conn", c.connID, "requestID", w.requestID)
			// build a request by hand, since this is not a normal situation
			// where we've read a request... and we need to make this check
			// before blocking on reading the next request.
			// The message id is 0 and the route is ExtendedOperationDisconnection;
			// NOTE(review): this appears to be the unsolicited Notice of
			// Disconnection — confirm.
			req := &Request{
				ID:           w.requestID,
				conn:         c,
				message:      &ExtendedOperationMessage{baseMessage: baseMessage{id: 0}},
				routeOp:      routeOperation(ExtendedOperationDisconnection),
				extendedName: ExtendedOperationDisconnection,
			}
			resp := req.NewResponse(WithResponseCode(ResultUnwillingToPerform), WithDiagnosticMessage("server stopping"))
			if err := w.Write(resp); err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
			// Arm a read deadline 1ms in the future. NOTE(review): presumably
			// so that any read still pending on netConn fails promptly during
			// shutdown — confirm intent.
			if err := c.netConn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
				return fmt.Errorf("%s: %w", op, err)
			}
			return nil
		default:
			// need a default to fall through to rest of loop...
		}
		r, err := c.readRequest(w.requestID)
		if err != nil {
			// EOF-style errors mean the client went away. The substring match
			// presumably catches errors whose chain does not wrap
			// io.ErrUnexpectedEOF — confirm.
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || strings.Contains(err.Error(), "unexpected EOF") {
				return nil // connection is closed
			}
			return fmt.Errorf("%s: error reading request: %w", op, err)
		}
		switch {
		// TODO: rate limit in-flight requests per conn and send a
		// BusyResponse when the limit is reached. This limit per conn
		// should be configurable
		case r.routeOp == unbindRouteOperation:
			// support an optional unbind route
			if c.router.unbindRoute != nil {
				c.router.unbindRoute.handler()(w, r)
			}
			// stop serving requests when UnbindRequest is received
			return nil
		// If it's a StartTLS request, then we can't dispatch it concurrently,
		// since the conn needs to complete its TLS negotiation before handling
		// any other requests.
		// see: https://datatracker.ietf.org/doc/html/rfc4511#section-4.14.1
		case r.extendedName == ExtendedOperationStartTLS:
			c.router.serve(w, r)
		default:
			// w and r are declared inside the loop body, so each goroutine
			// captures its own values. requestsWg lets close wait for it.
			c.requestsWg.Add(1)
			go func() {
				defer func() {
					c.logger.Debug("requestsWg done", "op", op, "conn", c.connID, "requestID", w.requestID)
					c.requestsWg.Done()
				}()
				c.router.serve(w, r)
			}()
		}
	}
}
// readRequest reads the next packet from the client and turns it into an
// in-memory Request identified by requestID. Both failure paths include the
// "connID/requestID" pair in the error and wrap the underlying error.
func (c *conn) readRequest(requestID int) (*Request, error) {
	const op = "gldap.(Conn).readRequest"
	var (
		p   *packet
		req *Request
		err error
	)
	if p, err = c.readPacket(requestID); err != nil {
		return nil, fmt.Errorf("%s: error reading packet for %d/%d: %w", op, c.connID, requestID, err)
	}
	if req, err = newRequest(requestID, c, p); err != nil {
		return nil, fmt.Errorf("%s: unable to create new in-memory request for %d/%d: %w", op, c.connID, requestID, err)
	}
	return req, nil
}
// readPacket reads the next BER packet from the client and runs basic LDAP
// message validation on it.
//
// c.mu is held only while reading from c.reader, so the read cannot race
// with initConn swapping the reader. When debug logging is enabled, the
// decoded packet is dumped to the logger.
//
// Read errors are wrapped with %w, so callers can still detect io.EOF, and
// carry the op prefix exactly once. An "invalid character for IA5String at
// pos 2" read error gets an extra hint that the client may be attempting TLS
// against a non-TLS server.
func (c *conn) readPacket(requestID int) (*packet, error) {
	const op = "gldap.readPacket"
	// read a request; only the read itself needs the lock
	berPacket, err := func() (*ber.Packet, error) {
		c.mu.Lock()
		defer c.mu.Unlock()
		return ber.ReadPacket(c.reader)
	}()
	switch {
	case err != nil && strings.Contains(err.Error(), "invalid character for IA5String at pos 2"):
		return nil, fmt.Errorf("%s: error reading ber packet for %d/%d (possible attempt to use TLS with a non-TLS server): %w", op, c.connID, requestID, err)
	case err != nil:
		return nil, fmt.Errorf("%s: error reading ber packet for %d/%d: %w", op, c.connID, requestID, err)
	}

	p := &packet{Packet: berPacket}
	if c.logger.IsDebug() {
		c.logger.Debug("packet read", "op", op, "conn", c.connID, "requestID", requestID)
		p.Log(c.logger.StandardWriter(&hclog.StandardLoggerOptions{}), 0, false)
	}
	// Simple header is first... let's make sure it's an ldap packet with 2
	// children containing:
	// [0] is a message ID
	// [1] is a request header
	if err := p.basicValidation(); err != nil {
		return nil, fmt.Errorf("%s: failed validation: %w", op, err)
	}
	return p, nil
}
// initConn binds c to netConn and wraps it in a new buffered reader and
// writer. The three fields are assigned while c.mu is held, the same lock
// readPacket holds while reading from c.reader. A nil netConn returns an
// error wrapping ErrInvalidParameter.
func (c *conn) initConn(netConn net.Conn) error {
	const op = "gldap.(Conn).initConn"
	if netConn == nil {
		return fmt.Errorf("%s: missing net conn: %w", op, ErrInvalidParameter)
	}
	r, w := bufio.NewReader(netConn), bufio.NewWriter(netConn)

	c.mu.Lock()
	c.netConn, c.reader, c.writer = netConn, r, w
	c.mu.Unlock()
	return nil
}
// close blocks until every request goroutine tracked by requestsWg has
// finished, then closes the underlying network connection. A failure to
// close is returned wrapped.
func (c *conn) close() error {
	const op = "gldap.(Conn).close"
	c.requestsWg.Wait()
	err := c.netConn.Close()
	if err == nil {
		return nil
	}
	return fmt.Errorf("%s: error closing conn: %w", op, err)
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"fmt"
"strconv"
ber "github.com/go-asn1-ber/asn1-ber"
)
// OIDs of the LDAP controls this package knows how to describe, encode or
// decode.
const (
	// ControlTypePaging - https://www.ietf.org/rfc/rfc2696.txt
	ControlTypePaging = "1.2.840.113556.1.4.319"
	// ControlTypeBeheraPasswordPolicy - https://tools.ietf.org/html/draft-behera-ldap-password-policy-10
	ControlTypeBeheraPasswordPolicy = "1.3.6.1.4.1.42.2.27.8.5.1"
	// ControlTypeVChuPasswordMustChange - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
	ControlTypeVChuPasswordMustChange = "2.16.840.1.113730.3.4.4"
	// ControlTypeVChuPasswordWarning - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
	ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5"
	// ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296
	ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2"
	// ControlTypeWhoAmI - https://tools.ietf.org/html/rfc4532
	// NOTE(review): RFC 4532 describes "Who am I?" as an extended operation
	// rather than a control, and decodeControl has no case for this OID —
	// confirm whether it belongs in this list.
	ControlTypeWhoAmI = "1.3.6.1.4.1.4203.1.11.3"
	// ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
	ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528"
	// ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx
	ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417"
	// ControlTypeMicrosoftServerLinkTTL - https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
	ControlTypeMicrosoftServerLinkTTL = "1.2.840.113556.1.4.2309"
)
// ControlTypeMap maps controls to text descriptions. It supplies the text
// inside the "Control Type (...)" packet descriptions and the String() output
// of the controls in this package. OIDs without an entry map to the empty
// string: ControlTypeVChuPasswordMustChange, ControlTypeVChuPasswordWarning
// and ControlTypeWhoAmI.
var ControlTypeMap = map[string]string{
	ControlTypePaging:                 "Paging",
	ControlTypeBeheraPasswordPolicy:   "Password Policy - Behera Draft",
	ControlTypeManageDsaIT:            "Manage DSA IT",
	ControlTypeMicrosoftNotification:  "Change Notification - Microsoft",
	ControlTypeMicrosoftShowDeleted:   "Show Deleted Objects - Microsoft",
	ControlTypeMicrosoftServerLinkTTL: "Return TTL-DNs for link values with associated expiry times - Microsoft",
}
// Ldap Behera Password Policy Draft 10 (https://tools.ietf.org/html/draft-behera-ldap-password-policy-10)
// error codes: the values of the PasswordPolicyResponse error enumeration.
// Only 0 through 8 are defined; decodeControl rejects any other value
// received on the wire.
const (
	BeheraPasswordExpired             = 0
	BeheraAccountLocked               = 1
	BeheraChangeAfterReset            = 2
	BeheraPasswordModNotAllowed       = 3
	BeheraMustSupplyOldPassword       = 4
	BeheraInsufficientPasswordQuality = 5
	BeheraPasswordTooShort            = 6
	BeheraPasswordTooYoung            = 7
	BeheraPasswordInHistory           = 8
)
// BeheraPasswordPolicyErrorMap contains human readable descriptions of Behera Password Policy error codes.
// It is keyed by int8 to match the error code stored in ControlBeheraPasswordPolicy.
var BeheraPasswordPolicyErrorMap = map[int8]string{
	BeheraPasswordExpired:             "Password expired",
	BeheraAccountLocked:               "Account locked",
	BeheraChangeAfterReset:            "Password must be changed",
	BeheraPasswordModNotAllowed:       "Policy prevents password modification",
	BeheraMustSupplyOldPassword:       "Policy requires old password in order to change password",
	BeheraInsufficientPasswordQuality: "Password fails quality checks",
	BeheraPasswordTooShort:            "Password is too short for policy",
	BeheraPasswordTooYoung:            "Password has been changed too recently",
	BeheraPasswordInHistory:           "New password is in list of old passwords",
}
// Control defines a common interface for all ldap controls. encodeControls
// serializes a slice of Controls via Encode, and decodeControl produces
// Controls from received packets.
type Control interface {
	// GetControlType returns the OID
	GetControlType() string
	// Encode returns the ber packet representation
	Encode() *ber.Packet
	// String returns a human-readable description
	String() string
}
// encodeControls wraps the given controls in a context-specific, constructed
// packet with tag 0, described as "Controls". Each control's own encoding is
// appended in order. A nil or empty slice yields an empty Controls packet.
// NOTE(review): presumably this is the `controls [0]` field of an
// LDAPMessage — confirm.
func encodeControls(controls []Control) *ber.Packet {
	wrapper := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls")
	for i := range controls {
		wrapper.AppendChild(controls[i].Encode())
	}
	return wrapper
}
// decodeControl converts a BER control packet into a Control.
//
// The packet's children are read as the fields of an LDAP Control:
//   - child 0 is the control type OID (a string);
//   - an optional criticality (bool) and/or control value follow.
//
// With exactly two children, the second is treated as the criticality when
// its value is a bool and as the control value otherwise. Known OIDs decode
// into their dedicated types; any other OID becomes a *ControlString.
//
// As a side effect, the Description of the packet's children is filled in
// for readable packet dumps. For some controls, Value and children are also
// rewritten.
//
// The packet normally comes straight off the wire, so malformed input returns
// an error wrapping ErrInvalidParameter instead of panicking. Malformed input
// includes a non-string type, a non-bool criticality, missing paging or
// warning children, or an out-of-range paging size.
func decodeControl(packet *ber.Packet) (Control, error) {
	const op = "gldap.decodeControl"
	if packet == nil {
		return nil, fmt.Errorf("%s: packet is nil: %w", op, ErrInvalidParameter)
	}
	switch n := len(packet.Children); {
	case n == 0:
		// at least one child is required for a control type
		return nil, fmt.Errorf("%s: at least one child is required for control type", op)
	case n > 3:
		// more than 3 children is invalid
		return nil, fmt.Errorf("%s: more than 3 children is invalid for controls", op)
	}

	// Children[0] is always the control type. Read it before building the
	// description; otherwise ControlTypeMap is indexed with "".
	controlType, ok := packet.Children[0].Value.(string)
	if !ok {
		return nil, fmt.Errorf("%s: control type must be a string: %w", op, ErrInvalidParameter)
	}
	packet.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"

	var (
		criticality bool
		value       *ber.Packet
	)
	switch len(packet.Children) {
	case 2:
		// Children[1] could be criticality or value (both are optional)
		// duck-type on whether this is a boolean
		if crit, isBool := packet.Children[1].Value.(bool); isBool {
			packet.Children[1].Description = "Criticality"
			criticality = crit
		} else {
			packet.Children[1].Description = "Control Value"
			value = packet.Children[1]
		}
	case 3:
		crit, isBool := packet.Children[1].Value.(bool)
		if !isBool {
			return nil, fmt.Errorf("%s: control criticality must be a boolean: %w", op, ErrInvalidParameter)
		}
		packet.Children[1].Description = "Criticality"
		criticality = crit
		packet.Children[2].Description = "Control Value"
		value = packet.Children[2]
	}

	switch controlType {
	case ControlTypeManageDsaIT:
		return NewControlManageDsaIT(WithCriticality(criticality))

	case ControlTypePaging:
		if value == nil {
			return new(ControlPaging), nil
		}
		value.Description += " (Paging)"
		// A value read off the wire is an OCTET STRING whose bytes hold
		// the encoded search control value; decode it into a child.
		if value.Value != nil {
			valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
			if err != nil {
				return nil, fmt.Errorf("%s: failed to decode data bytes: %w", op, err)
			}
			value.Data.Truncate(0)
			value.Value = nil
			value.AppendChild(valueChildren)
		}
		if len(value.Children) < 1 {
			return nil, fmt.Errorf("%s: paging control value must have a least 1 child: %w", op, ErrInvalidParameter)
		}
		value = value.Children[0]
		value.Description = "Search Control Value"
		if len(value.Children) < 2 {
			return nil, fmt.Errorf("%s: paging control value must contain a size and a cookie: %w", op, ErrInvalidParameter)
		}
		value.Children[0].Description = "Paging Size"
		value.Children[1].Description = "Cookie"
		size, ok := value.Children[0].Value.(int64)
		// reject sizes that would silently wrap when converted to uint32
		if !ok || size < 0 || size > int64(^uint32(0)) {
			return nil, fmt.Errorf("%s: invalid paging size: %w", op, ErrInvalidParameter)
		}
		c := &ControlPaging{
			PagingSize: uint32(size),
			Cookie:     value.Children[1].Data.Bytes(),
		}
		value.Children[1].Value = c.Cookie
		return c, nil

	case ControlTypeBeheraPasswordPolicy:
		c, err := NewControlBeheraPasswordPolicy()
		if err != nil {
			return nil, fmt.Errorf("%s: failed to create behera password control: %w", op, err)
		}
		if value == nil {
			return c, nil
		}
		value.Description += " (Password Policy - Behera)"
		if value.Value != nil {
			valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
			if err != nil {
				return nil, fmt.Errorf("%s: failed to decode data bytes: %w", op, err)
			}
			value.Data.Truncate(0)
			value.Value = nil
			value.AppendChild(valueChildren)
		}
		if len(value.Children) == 0 {
			return nil, fmt.Errorf("%s: behera control value must have a least 1 child: %w", op, ErrInvalidParameter)
		}
		sequence := value.Children[0]
		for _, child := range sequence.Children {
			switch child.Tag {
			case 0:
				// Warning: wraps timeBeforeExpiration [0] or
				// graceAuthNsRemaining [1]
				if len(child.Children) == 0 {
					return nil, fmt.Errorf("%s: behera warning must have a child: %w", op, ErrInvalidParameter)
				}
				warningPacket := child.Children[0]
				val, err := ber.ParseInt64(warningPacket.Data.Bytes())
				if err != nil {
					return nil, fmt.Errorf("%s: failed to decode data bytes: %w", op, err)
				}
				switch warningPacket.Tag {
				case 0:
					// timeBeforeExpiration
					c.expire = val
					warningPacket.Value = c.expire
				case 1:
					// graceAuthNsRemaining
					c.grace = val
					warningPacket.Value = c.grace
				}
			case 1:
				// Error: a single-byte enumeration in the range 0-8
				bs := child.Data.Bytes()
				if len(bs) != 1 || bs[0] > 8 {
					return nil, fmt.Errorf("%s: failed to decode data bytes: invalid PasswordPolicyResponse enum value: %w", op, ErrInvalidParameter)
				}
				c.error = int8(bs[0])
				child.Value = c.error
				c.errorString = BeheraPasswordPolicyErrorMap[c.error]
			}
		}
		return c, nil

	case ControlTypeVChuPasswordMustChange:
		return &ControlVChuPasswordMustChange{MustChange: true}, nil

	case ControlTypeVChuPasswordWarning:
		c := &ControlVChuPasswordWarning{Expire: -1}
		if value == nil {
			return c, nil
		}
		// the value is the expiry in seconds, as a decimal string
		expire, err := strconv.ParseInt(ber.DecodeString(value.Data.Bytes()), 10, 64)
		if err != nil {
			return nil, fmt.Errorf("%s: failed to parse value as int: %w", op, err)
		}
		c.Expire = expire
		value.Value = c.Expire
		return c, nil

	case ControlTypeMicrosoftNotification:
		return NewControlMicrosoftNotification()
	case ControlTypeMicrosoftShowDeleted:
		return NewControlMicrosoftShowDeleted()
	case ControlTypeMicrosoftServerLinkTTL:
		return NewControlMicrosoftServerLinkTTL()

	default:
		c := &ControlString{ControlType: controlType, Criticality: criticality}
		if value != nil {
			s, ok := value.Value.(string)
			if !ok {
				return nil, fmt.Errorf("%s: control value for %q must be a string: %w", op, controlType, ErrInvalidParameter)
			}
			c.ControlValue = s
		}
		return c, nil
	}
}
// ControlString implements the Control interface for simple controls whose
// optional value is carried as a plain string.
type ControlString struct {
	// ControlType is the control's OID
	ControlType string
	// Criticality indicates if this control is required
	Criticality bool
	// ControlValue is the control's value; empty means no value
	ControlValue string
}

// GetControlType returns the OID
func (c *ControlString) GetControlType() string {
	return c.ControlType
}

// Encode returns the ber packet representation: a SEQUENCE holding the OID,
// then a BOOLEAN criticality only when it is true, then the value only when
// it is non-empty.
func (c *ControlString) Encode() *ber.Packet {
	typeDesc := "Control Type (" + ControlTypeMap[c.ControlType] + ")"
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlType, typeDesc))
	if c.Criticality {
		control.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, true, "Criticality"))
	}
	if len(c.ControlValue) > 0 {
		control.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlValue, "Control Value"))
	}
	return control
}

// String returns a human-readable description
func (c *ControlString) String() string {
	name := ControlTypeMap[c.ControlType]
	return fmt.Sprintf("Control Type: %s (%q) Criticality: %t Control Value: %s", name, c.ControlType, c.Criticality, c.ControlValue)
}

// NewControlString returns a generic control. Options supported:
// WithCriticality and WithControlValue. An empty controlType returns an
// error wrapping ErrInvalidParameter.
func NewControlString(controlType string, opt ...Option) (*ControlString, error) {
	const op = "gldap.NewControlString"
	if len(controlType) == 0 {
		return nil, fmt.Errorf("%s: missing control type: %w", op, ErrInvalidParameter)
	}
	opts := getControlOpts(opt...)
	c := &ControlString{ControlType: controlType}
	c.Criticality = opts.withCriticality
	c.ControlValue = opts.withControlValue
	return c, nil
}
// ControlManageDsaIT implements the control described in https://tools.ietf.org/html/rfc3296
type ControlManageDsaIT struct {
	// Criticality indicates if this control is required
	Criticality bool
}

// Encode returns the ber packet representation: a SEQUENCE with the OID and,
// only when Criticality is true, a BOOLEAN criticality.
func (c *ControlManageDsaIT) Encode() *ber.Packet {
	// FIXME (inherited from the original source; the reason is not recorded)
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeManageDsaIT, "Control Type ("+ControlTypeMap[ControlTypeManageDsaIT]+")"))
	if !c.Criticality {
		return control
	}
	control.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, true, "Criticality"))
	return control
}

// GetControlType returns the OID
func (c *ControlManageDsaIT) GetControlType() string {
	return ControlTypeManageDsaIT
}

// String returns a human-readable description
func (c *ControlManageDsaIT) String() string {
	name := ControlTypeMap[ControlTypeManageDsaIT]
	return fmt.Sprintf("Control Type: %s (%q) Criticality: %t", name, ControlTypeManageDsaIT, c.Criticality)
}

// NewControlManageDsaIT returns a ControlManageDsaIT control. Supported
// options: WithCriticality
func NewControlManageDsaIT(opt ...Option) (*ControlManageDsaIT, error) {
	c := &ControlManageDsaIT{}
	c.Criticality = getControlOpts(opt...).withCriticality
	return c, nil
}
// ControlMicrosoftNotification implements the control described in https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
// It has no fields; its encoding carries only the OID.
type ControlMicrosoftNotification struct{}

// GetControlType returns the OID
func (c *ControlMicrosoftNotification) GetControlType() string {
	return ControlTypeMicrosoftNotification
}

// Encode returns the ber packet representation: a SEQUENCE whose only child
// is the control type OID.
func (c *ControlMicrosoftNotification) Encode() *ber.Packet {
	oid := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftNotification, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftNotification]+")")
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(oid)
	return control
}

// String returns a human-readable description
func (c *ControlMicrosoftNotification) String() string {
	name := ControlTypeMap[ControlTypeMicrosoftNotification]
	return fmt.Sprintf("Control Type: %s (%q)", name, ControlTypeMicrosoftNotification)
}

// NewControlMicrosoftNotification returns a ControlMicrosoftNotification
// control. No options are currently supported.
func NewControlMicrosoftNotification(_ ...Option) (*ControlMicrosoftNotification, error) {
	return new(ControlMicrosoftNotification), nil
}
// ControlMicrosoftServerLinkTTL implements the control described in https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/f4f523a8-abc0-4b3a-a471-6b2fef135481?redirectedfrom=MSDN
// It has no fields; its encoding carries only the OID.
type ControlMicrosoftServerLinkTTL struct{}

// GetControlType returns the OID
func (c *ControlMicrosoftServerLinkTTL) GetControlType() string {
	return ControlTypeMicrosoftServerLinkTTL
}

// Encode returns the ber packet representation: a SEQUENCE whose only child
// is the control type OID.
func (c *ControlMicrosoftServerLinkTTL) Encode() *ber.Packet {
	oid := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftServerLinkTTL, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftServerLinkTTL]+")")
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(oid)
	return control
}

// String returns a human-readable description
func (c *ControlMicrosoftServerLinkTTL) String() string {
	name := ControlTypeMap[ControlTypeMicrosoftServerLinkTTL]
	return fmt.Sprintf("Control Type: %s (%q)", name, ControlTypeMicrosoftServerLinkTTL)
}

// NewControlMicrosoftServerLinkTTL returns a ControlMicrosoftServerLinkTTL
// control. No options are currently supported.
func NewControlMicrosoftServerLinkTTL(_ ...Option) (*ControlMicrosoftServerLinkTTL, error) {
	return new(ControlMicrosoftServerLinkTTL), nil
}
// ControlMicrosoftShowDeleted implements the control described in https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx
// It has no fields; its encoding carries only the OID.
type ControlMicrosoftShowDeleted struct{}

// GetControlType returns the OID
func (c *ControlMicrosoftShowDeleted) GetControlType() string {
	return ControlTypeMicrosoftShowDeleted
}

// Encode returns the ber packet representation: a SEQUENCE whose only child
// is the control type OID.
func (c *ControlMicrosoftShowDeleted) Encode() *ber.Packet {
	oid := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftShowDeleted, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftShowDeleted]+")")
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(oid)
	return control
}

// String returns a human-readable description
func (c *ControlMicrosoftShowDeleted) String() string {
	name := ControlTypeMap[ControlTypeMicrosoftShowDeleted]
	return fmt.Sprintf("Control Type: %s (%q)", name, ControlTypeMicrosoftShowDeleted)
}

// NewControlMicrosoftShowDeleted returns a ControlMicrosoftShowDeleted control.
// No options are currently supported.
func NewControlMicrosoftShowDeleted(_ ...Option) (*ControlMicrosoftShowDeleted, error) {
	return new(ControlMicrosoftShowDeleted), nil
}
// ControlBeheraPasswordPolicy implements the control described in https://tools.ietf.org/html/draft-behera-ldap-password-policy-10
//
// NewControlBeheraPasswordPolicy treats -1 as "not set" for expire, grace and
// error, and allows at most one of them to be set. NOTE(review): presumably
// getControlOpts defaults those options to -1 — confirm. A zero-value struct
// has all three at 0 instead.
type ControlBeheraPasswordPolicy struct {
	// expire contains the number of seconds before a password will expire
	expire int64
	// grace indicates the remaining number of times a user will be allowed to authenticate with an expired password
	grace int64
	// error indicates the error code
	error int8
	// errorString is a human readable error
	errorString string
}

// Grace returns the remaining number of times a user will be allowed to
// authenticate with an expired password. A value of -1 indicates it hasn't been
// set.
func (c *ControlBeheraPasswordPolicy) Grace() int {
	return int(c.grace)
}

// Expire contains the number of seconds before a password will expire. A value
// of -1 indicates it hasn't been set.
func (c *ControlBeheraPasswordPolicy) Expire() int {
	return int(c.expire)
}

// ErrorCode is the error code and a human readable string. A value of -1 and
// empty string indicates it hasn't been set.
func (c *ControlBeheraPasswordPolicy) ErrorCode() (int, string) {
	return int(c.error), c.errorString
}

// GetControlType returns the OID
func (c *ControlBeheraPasswordPolicy) GetControlType() string {
	return ControlTypeBeheraPasswordPolicy
}

// Encode returns the ber packet representation.
//
// The control value is an OCTET STRING containing a SEQUENCE with one
// element, chosen in priority order:
//   - grace >= 0: warning [0] wrapping graceAuthNsRemaining [1];
//   - expire >= 0: warning [0] wrapping timeBeforeExpiration [0];
//   - error >= 0: error [1].
//
// When none is set, only the control type is encoded.
func (c *ControlBeheraPasswordPolicy) Encode() *ber.Packet {
	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeBeheraPasswordPolicy, "Control Type ("+ControlTypeMap[ControlTypeBeheraPasswordPolicy]+")"))

	// appendValue wraps element in SEQUENCE and OCTET STRING and adds the
	// result as the control value.
	appendValue := func(element *ber.Packet) {
		sequencePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "")
		sequencePacket.AppendChild(element)
		valuePacket := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "")
		valuePacket.AppendChild(sequencePacket)
		packet.AppendChild(valuePacket)
	}
	// warning builds the context-specific [0] warning choice holding a
	// single integer tagged [0] (timeBeforeExpiration) or [1]
	// (graceAuthNsRemaining).
	warning := func(tag ber.Tag, v int64) *ber.Packet {
		contextPacket := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0x00, nil, "")
		contextPacket.AppendChild(ber.NewInteger(ber.ClassContext, ber.TypePrimitive, tag, v, ""))
		return contextPacket
	}

	// Only one of grace/expire/error is encoded.
	switch {
	case c.grace >= 0:
		appendValue(warning(0x01, c.grace))
	case c.expire >= 0:
		appendValue(warning(0x00, c.expire))
	case c.error >= 0:
		appendValue(ber.NewInteger(ber.ClassContext, ber.TypePrimitive, 0x01, c.error, ""))
	}
	return packet
}

// String returns a human-readable description
func (c *ControlBeheraPasswordPolicy) String() string {
	return fmt.Sprintf(
		"Control Type: %s (%q) Criticality: %t Expire: %d Grace: %d Error: %d, ErrorString: %s",
		ControlTypeMap[ControlTypeBeheraPasswordPolicy],
		ControlTypeBeheraPasswordPolicy,
		false,
		c.expire,
		c.grace,
		c.error,
		c.errorString)
}

// NewControlBeheraPasswordPolicy returns a ControlBeheraPasswordPolicy.
// Options supported: WithExpire, WithGrace, WithErrorCode.
//
// At most one of the three may be set; -1 means "not set". The error code
// must be -1 or between 0 and 8. Violations return an error wrapping
// ErrInvalidParameter.
func NewControlBeheraPasswordPolicy(opt ...Option) (*ControlBeheraPasswordPolicy, error) {
	const op = "gldap.NewControlBeheraPasswordPolicy"
	opts := getControlOpts(opt...)
	switch {
	case opts.withGrace != -1 && opts.withExpire != -1:
		return nil, fmt.Errorf("%s: behera policies cannot have both grace and expire set: %w", op, ErrInvalidParameter)
	case opts.withGrace != -1 && opts.withErrorCode != -1:
		return nil, fmt.Errorf("%s: behera policies cannot have both grace and error codes set: %w", op, ErrInvalidParameter)
	case opts.withExpire != -1 && opts.withErrorCode != -1:
		return nil, fmt.Errorf("%s: behera policies cannot have both expire and error codes set: %w", op, ErrInvalidParameter)
	case opts.withErrorCode < -1 || opts.withErrorCode > 8:
		// reject anything that is neither "unset" nor a defined code
		return nil, fmt.Errorf("%s: %d is not a valid behera policy error code (must be between 0-8): %w", op, opts.withErrorCode, ErrInvalidParameter)
	}
	c := &ControlBeheraPasswordPolicy{
		expire: int64(opts.withExpire),
		grace:  int64(opts.withGrace),
		error:  int8(opts.withErrorCode),
	}
	if opts.withErrorCode != -1 {
		c.errorString = BeheraPasswordPolicyErrorMap[int8(opts.withErrorCode)]
	}
	return c, nil
}
// ControlVChuPasswordMustChange implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
type ControlVChuPasswordMustChange struct {
	// MustChange indicates if the password is required to be changed
	MustChange bool
}

// GetControlType returns the OID
func (c *ControlVChuPasswordMustChange) GetControlType() string {
	return ControlTypeVChuPasswordMustChange
}

// Encode returns the ber packet representation: a SEQUENCE holding only the
// control type OID. NOTE(review): the original author believed neither
// criticality nor a value is required by the draft — confirm.
func (c *ControlVChuPasswordMustChange) Encode() *ber.Packet {
	oid := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeVChuPasswordMustChange, "Control Type ("+ControlTypeMap[ControlTypeVChuPasswordMustChange]+")")
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(oid)
	return control
}

// String returns a human-readable description; criticality is always
// reported as false.
func (c *ControlVChuPasswordMustChange) String() string {
	name := ControlTypeMap[ControlTypeVChuPasswordMustChange]
	return fmt.Sprintf("Control Type: %s (%q) Criticality: %t MustChange: %v", name, ControlTypeVChuPasswordMustChange, false, c.MustChange)
}
// ControlVChuPasswordWarning implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
type ControlVChuPasswordWarning struct {
	// Expire indicates the time in seconds until the password expires
	Expire int64
}

// GetControlType returns the OID
func (c *ControlVChuPasswordWarning) GetControlType() string {
	return ControlTypeVChuPasswordWarning
}

// Encode returns the ber packet representation: a SEQUENCE with the OID
// followed by Expire rendered as a base-10 string inside an OCTET STRING.
// NOTE(review): the original author believed the draft specifies a string
// here — confirm.
func (c *ControlVChuPasswordWarning) Encode() *ber.Packet {
	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeVChuPasswordWarning, "Control Type ("+ControlTypeMap[ControlTypeVChuPasswordWarning]+")"))
	value := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, strconv.FormatInt(c.Expire, 10), "Control Value")
	control.AppendChild(value)
	return control
}

// String returns a human-readable description; criticality is always
// reported as false.
func (c *ControlVChuPasswordWarning) String() string {
	name := ControlTypeMap[ControlTypeVChuPasswordWarning]
	return fmt.Sprintf("Control Type: %s (%q) Criticality: %t Expire: %d", name, ControlTypeVChuPasswordWarning, false, c.Expire)
}
// ControlPaging implements the paging control described in https://www.ietf.org/rfc/rfc2696.txt
type ControlPaging struct {
	// PagingSize indicates the page size
	PagingSize uint32
	// Cookie is an opaque value returned by the server to track a paging cursor
	Cookie []byte
}

// GetControlType returns the OID
func (c *ControlPaging) GetControlType() string {
	return ControlTypePaging
}

// Encode returns the ber packet representation. The layout is a SEQUENCE of
// the OID and an OCTET STRING control value. That value wraps a SEQUENCE of
// the page size (INTEGER) and the cookie (OCTET STRING).
func (c *ControlPaging) Encode() *ber.Packet {
	cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie")
	cookie.Value = c.Cookie
	cookie.Data.Write(c.Cookie)

	searchValue := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value")
	searchValue.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.PagingSize), "Paging Size"))
	searchValue.AppendChild(cookie)

	controlValue := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)")
	controlValue.AppendChild(searchValue)

	control := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
	control.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")"))
	control.AppendChild(controlValue)
	return control
}

// String returns a human-readable description; criticality is always
// reported as false.
func (c *ControlPaging) String() string {
	name := ControlTypeMap[ControlTypePaging]
	return fmt.Sprintf("Control Type: %s (%q) Criticality: %t PagingSize: %d Cookie: %q", name, ControlTypePaging, false, c.PagingSize, c.Cookie)
}

// SetCookie stores the given cookie in the paging control; the slice is
// kept as-is, not copied.
func (c *ControlPaging) SetCookie(cookie []byte) {
	c.Cookie = cookie
}

// NewControlPaging returns a paging control with the given page size and no
// cookie. No options are currently supported.
func NewControlPaging(pagingSize uint32, _ ...Option) (*ControlPaging, error) {
	c := new(ControlPaging)
	c.PagingSize = pagingSize
	return c, nil
}
// addControlDescriptions annotates a decoded "Controls" packet in place with
// human-readable descriptions. It describes the packet itself, every control
// in it, and each control's type, criticality and value children.
//
// For paging and Behera password policy controls it also describes the
// fields of the control value. When such a value still holds its raw encoded
// bytes (value.Value != nil), those bytes are BER-decoded and attached as a
// child of the value packet.
//
// It returns an error wrapping ErrInvalidParameter when packet is nil or a
// control does not have the expected structure, instead of panicking on
// malformed input. BER decoding errors are wrapped and returned as well.
func addControlDescriptions(packet *ber.Packet) error {
	const op = "gldap.addControlDescriptions"
	if packet == nil {
		return fmt.Errorf("%s: missing packet: %w", op, ErrInvalidParameter)
	}
	packet.Description = "Controls"
	for _, child := range packet.Children {
		child.Description = "Control"
		// Control ::= SEQUENCE { controlType, criticality (optional),
		// controlValue (optional) }
		switch n := len(child.Children); {
		case n == 0:
			return fmt.Errorf("%s: at least one child is required for a control type: %w", op, ErrInvalidParameter)
		case n > 3:
			return fmt.Errorf("%s: more than 3 children for control packet found: %w", op, ErrInvalidParameter)
		}
		controlType, ok := child.Children[0].Value.(string)
		if !ok {
			return fmt.Errorf("%s: control type is not a string: %w", op, ErrInvalidParameter)
		}
		child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"

		var value *ber.Packet
		switch len(child.Children) {
		case 2:
			// Children[1] could be criticality or value (both are optional),
			// so duck-type on whether this is a boolean
			if _, isCriticality := child.Children[1].Value.(bool); isCriticality {
				child.Children[1].Description = "Criticality"
			} else {
				child.Children[1].Description = "Control Value"
				value = child.Children[1]
			}
		case 3:
			// criticality and value present
			child.Children[1].Description = "Criticality"
			child.Children[2].Description = "Control Value"
			value = child.Children[2]
		}
		if value == nil {
			continue
		}

		switch controlType {
		case ControlTypePaging:
			value.Description += " (Paging)"
			if value.Value != nil {
				valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
				if err != nil {
					return fmt.Errorf("%s: failed to decode data bytes: %w", op, err)
				}
				// realSearchControlValue ::= SEQUENCE { size, cookie }
				if len(valueChildren.Children) < 2 {
					return fmt.Errorf("%s: paging control value requires a size and a cookie: %w", op, ErrInvalidParameter)
				}
				value.Data.Truncate(0)
				value.Value = nil
				// expose the opaque cookie as raw bytes rather than a string
				valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes()
				value.AppendChild(valueChildren)
			}
			if len(value.Children) == 0 || len(value.Children[0].Children) < 2 {
				return fmt.Errorf("%s: paging control value requires a size and a cookie: %w", op, ErrInvalidParameter)
			}
			value.Children[0].Description = "Real Search Control Value"
			value.Children[0].Children[0].Description = "Paging Size"
			value.Children[0].Children[1].Description = "Cookie"

		case ControlTypeBeheraPasswordPolicy:
			value.Description += " (Password Policy - Behera Draft)"
			if value.Value != nil {
				valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
				if err != nil {
					return fmt.Errorf("%s: failed to decode data bytes: %w", op, err)
				}
				value.Data.Truncate(0)
				value.Value = nil
				value.AppendChild(valueChildren)
			}
			if len(value.Children) == 0 {
				return fmt.Errorf("%s: password policy control value is missing its sequence: %w", op, ErrInvalidParameter)
			}
			for _, elem := range value.Children[0].Children {
				switch elem.Tag {
				case 0:
					// Warning: a CHOICE wrapping a single tagged integer
					if len(elem.Children) == 0 {
						return fmt.Errorf("%s: password policy warning is missing its value: %w", op, ErrInvalidParameter)
					}
					warningPacket := elem.Children[0]
					val, err := ber.ParseInt64(warningPacket.Data.Bytes())
					if err != nil {
						return fmt.Errorf("%s: failed to decode data bytes: %w", op, err)
					}
					switch warningPacket.Tag {
					case 0:
						// timeBeforeExpiration
						value.Description += " (TimeBeforeExpiration)"
						warningPacket.Value = val
					case 1:
						// graceAuthNsRemaining
						value.Description += " (GraceAuthNsRemaining)"
						warningPacket.Value = val
					}
				case 1:
					// Error: a single-byte enum in the range 0..8
					bs := elem.Data.Bytes()
					if len(bs) != 1 || bs[0] > 8 {
						return fmt.Errorf("%s: failed to decode data bytes: invalid PasswordPolicyResponse enum value: %w", op, ErrInvalidParameter)
					}
					elem.Description = "Error"
					elem.Value = int8(bs[0])
				}
			}
		}
	}
	return nil
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
// controlOptions holds the settings collected from control Options.
// controlDefaults initializes the integer settings to -1, presumably as a
// "not set" sentinel (TODO confirm against the control constructors).
type controlOptions struct {
	withGrace        int
	withExpire       int
	withErrorCode    int
	withCriticality  bool
	withControlValue string

	// test-only options
	withTestType     string
	withTestToString string
}

// controlDefaults returns controlOptions with withGrace, withExpire and
// withErrorCode set to -1 and every other setting at its zero value.
func controlDefaults() controlOptions {
	var defaults controlOptions
	defaults.withGrace, defaults.withExpire, defaults.withErrorCode = -1, -1, -1
	return defaults
}
// getControlOpts starts from controlDefaults and applies each of opt in order,
// returning the resulting settings.
func getControlOpts(opt ...Option) controlOptions {
	result := controlDefaults()
	applyOpts(&result, opt...)
	return result
}
// WithGraceAuthNsRemaining specifies the number of grace authentications
// remaining. The value is stored as an int.
func WithGraceAuthNsRemaining(remaining uint) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withGrace = int(remaining)
	}
}

// WithSecondsBeforeExpiration specifies the number of seconds before a
// password will expire. The value is stored as an int.
func WithSecondsBeforeExpiration(seconds uint) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withExpire = int(seconds)
	}
}

// WithErrorCode specifies the error code. The value is stored as an int.
func WithErrorCode(code uint) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withErrorCode = int(code)
	}
}

// WithCriticality specifies the criticality.
func WithCriticality(criticality bool) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withCriticality = criticality
	}
}

// WithControlValue specifies the control value.
func WithControlValue(value string) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withControlValue = value
	}
}

// withTestType is a test-only option that sets withTestType.
func withTestType(s string) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withTestType = s
	}
}

// withTestToString is a test-only option that sets withTestToString.
func withTestToString(s string) Option {
	return func(o interface{}) {
		opts, ok := o.(*controlOptions)
		if !ok {
			return
		}
		opts.withTestToString = s
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"fmt"
"os"
"sort"
"strings"
ber "github.com/go-asn1-ber/asn1-ber"
)
// Entry represents an ldap entry
type Entry struct {
	// DN is the distinguished name of the entry
	DN string
	// Attributes are the returned attributes for the entry
	Attributes []*EntryAttribute
}

// GetAttributeValues returns the Values of the first attribute whose Name is
// exactly attribute (case-sensitive). When no attribute matches, it returns a
// non-nil empty slice.
func (e *Entry) GetAttributeValues(attribute string) []string {
	for i := range e.Attributes {
		if a := e.Attributes[i]; a.Name == attribute {
			return a.Values
		}
	}
	return []string{}
}

// NewEntry returns an Entry with the given DN and one attribute per key of
// attributes. Keys are visited in sorted order, so the same input map always
// produces the same attribute order.
func NewEntry(dn string, attributes map[string][]string) *Entry {
	names := make([]string, 0, len(attributes))
	for name := range attributes {
		names = append(names, name)
	}
	sort.Strings(names)

	// left nil when attributes is empty
	var attrs []*EntryAttribute
	for _, name := range names {
		attrs = append(attrs, NewEntryAttribute(name, attributes[name]))
	}
	return &Entry{DN: dn, Attributes: attrs}
}
// PrettyPrint writes a human-readable description of the entry: the DN,
// indented by indent spaces, followed by each attribute indented two more
// spaces. Supported options: WithWriter (defaults to os.Stdout).
func (e *Entry) PrettyPrint(indent int, opt ...Option) {
	out := getGeneralOpts(opt...).withWriter
	if out == nil {
		out = os.Stdout
	}
	pad := strings.Repeat(" ", indent)
	fmt.Fprintf(out, "%sDN: %s\n", pad, e.DN)
	for _, a := range e.Attributes {
		a.PrettyPrint(indent+2, opt...)
	}
}

// PrettyPrint writes the attribute's name and values on one line, indented by
// indent spaces. Supported options: WithWriter (defaults to os.Stdout).
func (e *EntryAttribute) PrettyPrint(indent int, opt ...Option) {
	out := getGeneralOpts(opt...).withWriter
	if out == nil {
		out = os.Stdout
	}
	pad := strings.Repeat(" ", indent)
	fmt.Fprintf(out, "%s%s: %s\n", pad, e.Name, e.Values)
}
// EntryAttribute holds a single attribute
type EntryAttribute struct {
	// Name is the name of the attribute
	Name string
	// Values contain the string values of the attribute
	Values []string
	// ByteValues contain the raw values of the attribute
	ByteValues [][]byte
}

// NewEntryAttribute returns an EntryAttribute named name. Its Values is the
// values slice itself (not a copy), and ByteValues holds a []byte copy of each
// value. ByteValues is nil when values is empty.
func NewEntryAttribute(name string, values []string) *EntryAttribute {
	attr := &EntryAttribute{Name: name, Values: values}
	for _, v := range values {
		attr.ByteValues = append(attr.ByteValues, []byte(v))
	}
	return attr
}

// AddValue appends each given value to both Values and ByteValues.
func (e *EntryAttribute) AddValue(value ...string) {
	for i := range value {
		e.Values = append(e.Values, value[i])
		e.ByteValues = append(e.ByteValues, []byte(value[i]))
	}
}
// encode returns the BER form of the attribute: a sequence holding the Name
// as an octet string, followed by a set with one octet string per entry in
// Values. ByteValues is not consulted.
func (e *EntryAttribute) encode() *ber.Packet {
	vals := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
	for i := range e.Values {
		vals.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, e.Values[i], "Vals"))
	}
	attr := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
	attr.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, e.Name, "Type"))
	attr.AppendChild(vals)
	return attr
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"fmt"
)
// Scope represents the scope of a search (see: https://ldap.com/the-ldap-search-operation/)
type Scope int64

const (
	// BaseObject (often called "base") considers only the entry named as the
	// search base; none of its subordinates are considered.
	BaseObject Scope = iota
	// SingleLevel (often called "one") considers only the immediate children
	// of the search base entry. Neither the base entry itself nor any deeper
	// descendants are considered.
	SingleLevel
	// WholeSubtree (often called "sub") considers the search base entry and
	// all of its subordinates to any depth. In the special case where the
	// search base DN is the null DN, the root DSE should not be considered in
	// a wholeSubtree search.
	WholeSubtree
)
// AuthChoice names the authentication choice carried by a bind message.
type AuthChoice string

const (
	// SimpleAuthChoice is the simple (user/password) authentication choice
	// for the bind message.
	SimpleAuthChoice AuthChoice = "simple"
)
// requestType identifies the kind of LDAP request held by a packet; the empty
// string is used for an unknown request.
type requestType string

const (
	unknownRequestType  requestType = ""
	bindRequestType     requestType = "bind"
	searchRequestType   requestType = "search"
	extendedRequestType requestType = "extended"
	modifyRequestType   requestType = "modify"
	addRequestType      requestType = "add"
	deleteRequestType   requestType = "delete"
	unbindRequestType   requestType = "unbind"
)
// Message is the interface shared by every decoded request message.
type Message interface {
	// GetID returns the message ID
	GetID() int64
}

// baseMessage carries the fields shared by all messages and is typically
// embedded in the concrete message types.
type baseMessage struct {
	id int64
}

// GetID returns the message ID.
func (m baseMessage) GetID() int64 {
	return m.id
}
// SearchMessage is a search request message.
type SearchMessage struct {
	baseMessage
	// BaseDN is the DN at which the search starts
	BaseDN string
	// Scope of the search
	Scope Scope
	// DerefAliases for the request
	DerefAliases int
	// TimeLimit is the max time in seconds to spend processing
	TimeLimit int64
	// SizeLimit is the max number of results to return
	SizeLimit int64
	// TypesOnly is true if the client only expects type info
	TypesOnly bool
	// Filter for the request
	Filter string
	// Attributes requested
	Attributes []string
	// Controls requested
	Controls []Control
}

// SimpleBindMessage is a simple bind request message.
type SimpleBindMessage struct {
	baseMessage
	// AuthChoice for the request (SimpleAuthChoice)
	AuthChoice AuthChoice
	// UserName for the bind request
	UserName string
	// Password for the bind request
	Password Password
	// Controls are optional controls for the bind request
	Controls []Control
}

// ExtendedOperationMessage is an extended operation request message.
type ExtendedOperationMessage struct {
	baseMessage
	// Name of the extended operation
	Name ExtendedOperationName
	// Value of the extended operation. NOTE(review): newMessage in this file
	// only sets Name, so Value is currently left empty on decode.
	Value string
}

// DeleteMessage is a delete request message.
type DeleteMessage struct {
	baseMessage
	// DN identifies the entry being deleted
	DN string
	// Controls hold optional controls to send with the request
	Controls []Control
}

// UnbindMessage is an unbind request message; it carries only the message ID.
type UnbindMessage struct {
	baseMessage
}
// newMessage decodes p into the concrete Message for its request type. The
// result is one of *UnbindMessage, *SimpleBindMessage, *SearchMessage,
// *ExtendedOperationMessage, *ModifyMessage, *AddMessage or *DeleteMessage.
// Errors from determining the request type, reading the message ID or parsing
// the request's parameters are wrapped with this function's op and returned.
func newMessage(p *packet) (Message, error) {
	const op = "gldap.newMessage"
	reqType, err := p.requestType()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	msgID, err := p.requestMessageID()
	if err != nil {
		return nil, fmt.Errorf("%s: unable to get message id: %w", op, err)
	}
	base := baseMessage{id: msgID}
	switch reqType {
	case unbindRequestType:
		return &UnbindMessage{baseMessage: base}, nil
	case bindRequestType:
		u, pass, controls, err := p.simpleBindParameters()
		if err != nil {
			return nil, fmt.Errorf("%s: invalid bind message: %w", op, err)
		}
		return &SimpleBindMessage{
			baseMessage: base,
			UserName:    u,
			Password:    pass,
			AuthChoice:  SimpleAuthChoice,
			Controls:    controls,
		}, nil
	case searchRequestType:
		parameters, err := p.searchParmeters()
		if err != nil {
			return nil, fmt.Errorf("%s: invalid search message: %w", op, err)
		}
		return &SearchMessage{
			baseMessage:  base,
			BaseDN:       parameters.baseDN,
			Scope:        Scope(parameters.scope),
			DerefAliases: int(parameters.derefAliases),
			SizeLimit:    parameters.sizeLimit,
			TimeLimit:    parameters.timeLimit,
			TypesOnly:    parameters.typesOnly,
			Filter:       parameters.filter,
			Attributes:   parameters.attributes,
			Controls:     parameters.controls,
		}, nil
	case extendedRequestType:
		opName, err := p.extendedOperationName()
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		return &ExtendedOperationMessage{
			baseMessage: base,
			Name:        opName,
		}, nil
	case modifyRequestType:
		parameters, err := p.modifyParameters()
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		return &ModifyMessage{
			baseMessage: base,
			DN:          parameters.dn,
			Changes:     parameters.changes,
			Controls:    parameters.controls,
		}, nil
	case addRequestType:
		parameters, err := p.addParameters()
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		return &AddMessage{
			baseMessage: base,
			DN:          parameters.dn,
			Attributes:  parameters.attributes,
			Controls:    parameters.controls,
		}, nil
	case deleteRequestType:
		dn, controls, err := p.deleteParameters()
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		return &DeleteMessage{
			baseMessage: base,
			DN:          dn,
			Controls:    controls,
		}, nil
	default:
		// defensive: requestType returns an error for unhandled tags, so this
		// branch is not expected to be reached
		return &ExtendedOperationMessage{
			baseMessage: base,
			Name:        ExtendedOperationUnknown,
		}, nil
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import ber "github.com/go-asn1-ber/asn1-ber"
// messageOptions holds the settings set by the with* message options below.
// A nil pointer field means the setting was not supplied. NOTE(review):
// presumably consumed by (packet).assert — confirm.
type messageOptions struct {
	withMinChildren *int
	withLenChildren *int
	withAssertChild *int
	withTag         *ber.Tag
}

// messageDefaults returns messageOptions with every setting unset (nil).
func messageDefaults() messageOptions {
	var defaults messageOptions
	return defaults
}

// getMessageOpts starts from messageDefaults and applies each of opt in
// order, returning the resulting settings.
func getMessageOpts(opt ...Option) messageOptions {
	result := messageDefaults()
	applyOpts(&result, opt...)
	return result
}
// withMinChildren sets the minimum number of children setting.
func withMinChildren(minimum int) Option {
	return func(o interface{}) {
		opts, ok := o.(*messageOptions)
		if !ok {
			return
		}
		opts.withMinChildren = &minimum
	}
}

// we'll see if we start using this again in the near future,
// but for now ignore the warning
//
//nolint:unused
func withLenChildren(length int) Option {
	return func(o interface{}) {
		opts, ok := o.(*messageOptions)
		if !ok {
			return
		}
		opts.withLenChildren = &length
	}
}

// withAssertChild sets the index of the child packet to assert against.
func withAssertChild(idx int) Option {
	return func(o interface{}) {
		opts, ok := o.(*messageOptions)
		if !ok {
			return
		}
		opts.withAssertChild = &idx
	}
}

// withTag sets the expected BER tag.
func withTag(t ber.Tag) Option {
	return func(o interface{}) {
		opts, ok := o.(*messageOptions)
		if !ok {
			return
		}
		opts.withTag = &t
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import ber "github.com/go-asn1-ber/asn1-ber"
// Change operation choices (untyped integer constants 0 through 3).
const (
	AddAttribute       = iota // 0
	DeleteAttribute           // 1
	ReplaceAttribute          // 2
	IncrementAttribute        // 3 (https://tools.ietf.org/html/rfc4525)
)
// ModifyMessage is a modify request message as defined in
// https://tools.ietf.org/html/rfc4511
type ModifyMessage struct {
	baseMessage
	// DN identifies the entry being modified
	DN string
	// Changes are the modifications to apply
	Changes []Change
	// Controls hold optional controls sent with the request
	Controls []Control
}

// Change is a single change within a ModifyMessage as defined in
// https://tools.ietf.org/html/rfc4511
type Change struct {
	// Operation is the type of change to be made (AddAttribute,
	// DeleteAttribute, ReplaceAttribute or IncrementAttribute)
	Operation int64
	// Modification is the attribute to be modified
	Modification PartialAttribute
}

// PartialAttribute is the attribute portion of a Change as defined in
// https://tools.ietf.org/html/rfc4511
type PartialAttribute struct {
	// Type is the type of the partial attribute
	Type string
	// Vals are the values of the partial attribute
	Vals []string
}
// encode returns the BER form of the change: a sequence holding the operation
// as an ENUMERATED value followed by the encoded modification.
func (c *Change) encode() *ber.Packet {
	operation := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, c.Operation, "Operation")
	modification := c.Modification.encode()

	change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
	change.AppendChild(operation)
	change.AppendChild(modification)
	return change
}

// encode returns the BER form of the partial attribute: a sequence holding
// Type as an octet string, followed by a set with one octet string per value.
func (p *PartialAttribute) encode() *ber.Packet {
	vals := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
	for i := range p.Vals {
		vals.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.Vals[i], "Vals"))
	}
	attr := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute")
	attr.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.Type, "Type"))
	attr.AppendChild(vals)
	return attr
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"fmt"
"sync"
)
// Mux is an ldap request multiplexer. It matches each inbound request against
// its registered routes in the order they were added, and dispatches the
// request to the first match only. The registration methods hold mu while
// updating the route fields.
type Mux struct {
	mu           sync.Mutex
	routes       []route
	defaultRoute route
	unbindRoute  route
}

// NewMux creates a new multiplexer with no routes registered. opt is
// currently ignored and the returned error is always nil.
func NewMux(opt ...Option) (*Mux, error) {
	m := &Mux{routes: []route{}}
	return m, nil
}
// Bind appends a route which handles simple bind requests with bindFn. It
// returns an error wrapping ErrInvalidParameter if bindFn is nil.
// Options supported: WithLabel
func (m *Mux) Bind(bindFn HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).Bind"
	if bindFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	label := getRouteOpts(opt...).withLabel
	base := &baseRoute{h: bindFn, routeOp: bindRouteOperation, label: label}
	r := &simpleBindRoute{baseRoute: base, authChoice: SimpleAuthChoice}

	m.mu.Lock()
	m.routes = append(m.routes, r)
	m.mu.Unlock()
	return nil
}
// Unbind registers bindFn as the unbind handler, replacing any unbind handler
// registered earlier. Registering an unbind handler is optional: whether or
// not one is defined, the server will stop serving requests for a connection
// after an unbind request is received. It returns an error wrapping
// ErrInvalidParameter if bindFn is nil. Options supported: WithLabel
func (m *Mux) Unbind(bindFn HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).Unbind"
	if bindFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	label := getRouteOpts(opt...).withLabel
	// NOTE(review): the unbind route is recorded with bindRouteOperation;
	// confirm whether a dedicated unbind operation was intended.
	base := &baseRoute{h: bindFn, routeOp: bindRouteOperation, label: label}
	r := &unbindRoute{baseRoute: base}

	m.mu.Lock()
	m.unbindRoute = r
	m.mu.Unlock()
	return nil
}
// Search appends a route which handles search requests with searchFn. The
// route's base DN, filter and scope are taken from the route options. It
// returns an error wrapping ErrInvalidParameter if searchFn is nil.
// Options supported: WithLabel, WithBaseDN, WithScope (the filter is also
// read from the route options).
func (m *Mux) Search(searchFn HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).Search"
	if searchFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	opts := getRouteOpts(opt...)
	base := &baseRoute{h: searchFn, routeOp: searchRouteOperation, label: opts.withLabel}
	r := &searchRoute{
		baseRoute: base,
		basedn:    opts.withBaseDN,
		filter:    opts.withFilter,
		scope:     opts.withScope,
	}

	m.mu.Lock()
	m.routes = append(m.routes, r)
	m.mu.Unlock()
	return nil
}
// ExtendedOperation appends a route which handles extended operation requests
// with operationFn. The route is created with exName as its extended
// operation name. It returns an error wrapping ErrInvalidParameter if
// operationFn is nil. Options supported: WithLabel
func (m *Mux) ExtendedOperation(operationFn HandlerFunc, exName ExtendedOperationName, opt ...Option) error {
	const op = "gldap.(Mux).ExtendedOperation"
	if operationFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	opts := getRouteOpts(opt...)
	r := &extendedRoute{
		baseRoute: &baseRoute{
			h:       operationFn,
			routeOp: extendedRouteOperation,
			label:   opts.withLabel,
		},
		extendedName: exName,
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	m.routes = append(m.routes, r)
	return nil
}
// Modify appends a route which handles modify requests with modifyFn. It
// returns an error wrapping ErrInvalidParameter if modifyFn is nil.
// Options supported: WithLabel
func (m *Mux) Modify(modifyFn HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).Modify"
	if modifyFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	label := getRouteOpts(opt...).withLabel
	r := &modifyRoute{
		baseRoute: &baseRoute{h: modifyFn, routeOp: modifyRouteOperation, label: label},
	}

	m.mu.Lock()
	m.routes = append(m.routes, r)
	m.mu.Unlock()
	return nil
}
// Add appends a route which handles add requests with addFn. It returns an
// error wrapping ErrInvalidParameter if addFn is nil.
// Options supported: WithLabel
func (m *Mux) Add(addFn HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).Add"
	if addFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	label := getRouteOpts(opt...).withLabel
	r := &addRoute{
		baseRoute: &baseRoute{h: addFn, routeOp: addRouteOperation, label: label},
	}

	m.mu.Lock()
	m.routes = append(m.routes, r)
	m.mu.Unlock()
	return nil
}
// Delete appends a route which handles delete requests with modifyFn (the
// delete handler). It returns an error wrapping ErrInvalidParameter if
// modifyFn is nil. Options supported: WithLabel
func (m *Mux) Delete(modifyFn HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).Delete"
	if modifyFn == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	label := getRouteOpts(opt...).withLabel
	r := &deleteRoute{
		baseRoute: &baseRoute{h: modifyFn, routeOp: deleteRouteOperation, label: label},
	}

	m.mu.Lock()
	m.routes = append(m.routes, r)
	m.mu.Unlock()
	return nil
}
// DefaultRoute registers noRouteFN as the handler for requests that no other
// registered route matches, replacing any previously registered default. It
// returns an error wrapping ErrInvalidParameter if noRouteFN is nil. No
// options are currently supported.
func (m *Mux) DefaultRoute(noRouteFN HandlerFunc, opt ...Option) error {
	const op = "gldap.(Mux).DefaultRoute"
	if noRouteFN == nil {
		return fmt.Errorf("%s: missing HandlerFunc: %w", op, ErrInvalidParameter)
	}
	r := &baseRoute{
		h: noRouteFN,
		// NOTE(review): serve never calls match() on the default route, so
		// this routeOp is informational only; confirm it is the intended value.
		routeOp: bindRouteOperation,
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	m.defaultRoute = r
	return nil
}
// serve dispatches req to the handler of the first registered route that
// matches it. If no route matches, the default route's handler is used when
// one is registered. Otherwise an UnwillingToPerform response is written to
// w. Handlers are invoked without holding m.mu.
//
// serve panics if w is nil, since it can then neither log nor respond.
func (m *Mux) serve(w *ResponseWriter, req *Request) {
	const op = "gldap.(Mux).serve"
	if w == nil {
		// this should be unreachable, and if it is then we'll just panic; the
		// message must not dereference w, which is nil here
		panic(fmt.Errorf("%s: missing response writer: %w", op, ErrInternal).Error())
	}
	defer func() {
		w.logger.Debug("finished serving request", "op", op, "connID", w.connID, "requestID", w.requestID)
	}()
	if req == nil {
		w.logger.Error("missing request", "op", op, "connID", w.connID, "requestID", w.requestID)
		return
	}

	// snapshot the routes under the same lock the registration methods hold,
	// so dispatch does not race with concurrent route registration
	m.mu.Lock()
	routes := m.routes
	defaultRoute := m.defaultRoute
	m.mu.Unlock()

	// find the first matching route to dispatch the request to and then return
	for _, r := range routes {
		if !r.match(req) {
			continue
		}
		h := r.handler()
		if h == nil {
			w.logger.Error("route is missing handler", "op", op, "connID", w.connID, "requestID", w.requestID, "route", r.op())
			return
		}
		// the handler intentionally doesn't return errors, since we want the
		// handler to respond to the connection's client with errors.
		h(w, req)
		return
	}
	if defaultRoute != nil {
		if h := defaultRoute.handler(); h != nil {
			h(w, req)
			return
		}
	}
	w.logger.Error("no matching handler found for request and returning internal error", "op", op, "connID", w.connID, "requestID", w.requestID, "routeOp", req.routeOp)
	resp := req.NewResponse(WithResponseCode(ResultUnwillingToPerform), WithDiagnosticMessage("No matching handler found"))
	_ = w.Write(resp)
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"io"
"reflect"
)
// Option is a functional option used in a variadic parameter pattern. The
// options in this package type-assert their argument and leave option structs
// of other types untouched.
type Option func(interface{})

// applyOpts calls each non-nil option in opt, in order, with opts. opts
// should be a pointer to an options struct already holding its defaults, so
// each option overrides them.
func applyOpts(opts interface{}, opt ...Option) {
	for i := range opt {
		if apply := opt[i]; apply != nil {
			apply(opts)
		}
	}
}
// generalOptions holds settings shared by general-purpose helpers; withWriter
// is nil unless WithWriter supplied a non-nil writer.
type generalOptions struct {
	withWriter io.Writer
}

// generalDefaults returns generalOptions with every setting at its zero value.
func generalDefaults() generalOptions {
	var defaults generalOptions
	return defaults
}

// getGeneralOpts starts from generalDefaults and applies each of opt in
// order, returning the resulting settings.
func getGeneralOpts(opt ...Option) generalOptions {
	result := generalDefaults()
	applyOpts(&result, opt...)
	return result
}
// WithWriter specifies an optional writer. A nil writer, including an
// interface holding a typed nil, is ignored.
func WithWriter(w io.Writer) Option {
	return func(o interface{}) {
		opts, ok := o.(*generalOptions)
		if !ok || isNil(w) {
			return
		}
		opts.withWriter = w
	}
}
// isNil reports whether i is nil. This includes an interface holding a nil
// value of a nilable kind: pointer, map, channel, slice, function, interface
// or unsafe pointer. Values of any other kind (arrays, structs, numbers, ...)
// can never be nil and always report false.
func isNil(i interface{}) bool {
	if i == nil {
		return true
	}
	switch reflect.TypeOf(i).Kind() {
	case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice, reflect.Func, reflect.Interface, reflect.UnsafePointer:
		// reflect.Value.IsNil panics for every other kind (notably Array), so
		// only call it for kinds that can actually hold nil
		return reflect.ValueOf(i).IsNil()
	}
	return false
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"fmt"
"io"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/go-ldap/ldap/v3"
"github.com/hashicorp/go-hclog"
)
// packet wraps a *ber.Packet with helpers for decoding LDAP requests.
// validated is set once basicValidation has succeeded, so later calls to
// basicValidation return immediately.
type packet struct {
	*ber.Packet
	validated bool
}
// basicValidation checks that p is a universal constructed SEQUENCE with at
// least two children: [0] the message ID and [1] the request. A successful
// result is cached in p.validated.
func (p *packet) basicValidation() error {
	const (
		op = "gldap.(packet).basicValidation"
		// messageID packet + Request packet
		childMinChildren = 2
	)
	if !p.validated {
		err := p.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withMinChildren(childMinChildren))
		if err != nil {
			return fmt.Errorf("%s: invalid ldap packet 0: %w", op, ErrInvalidParameter)
		}
		p.validated = true
	}
	return nil
}
// requestMessageID returns the message ID held by the first child of the
// message packet, which must be a universal primitive INTEGER decoded as
// int64. Errors wrap the underlying assertion error or ErrInvalidParameter.
func (p *packet) requestMessageID() (int64, error) {
	const (
		op             = "gldap.(packet).requestMessageID"
		childMessageID = 0
	)
	if err := p.basicValidation(); err != nil {
		return 0, fmt.Errorf("%s: %w", op, err)
	}
	msgIDPacket := &packet{Packet: p.Children[childMessageID]}
	// assert it's capable of holding the message ID
	if err := msgIDPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger)); err != nil {
		return 0, fmt.Errorf("%s: missing/invalid packet: %w", op, err)
	}
	id, ok := msgIDPacket.Value.(int64)
	if !ok {
		// %T reports the unexpected type (the value is not a bool, so %t is wrong)
		return 0, fmt.Errorf("%s: expected int64 message ID and got %T: %w", op, msgIDPacket.Value, ErrInvalidParameter)
	}
	return id, nil
}
// controlPacket returns the optional controls child (index 2) of the message
// packet. It returns nil, nil when the message has no controls, and an error
// wrapping ErrInvalidParameter when that child is not a context-specific
// constructed packet.
func (p *packet) controlPacket() (*packet, error) {
	const (
		op           = "gldap.(packet).controlPacket"
		childControl = 2
	)
	if len(p.Children) < childControl+1 {
		// no control packet
		return nil, nil
	}
	controls := &packet{Packet: p.Children[childControl]}
	if err := controls.assert(ber.ClassContext, ber.TypeConstructed); err != nil {
		return nil, fmt.Errorf("%s: invalid control packet: %w", op, ErrInvalidParameter)
	}
	return controls, nil
}
// requestPacket returns the request child (index 1) of the message packet
// after basic validation and the application-request assertion. For bind
// requests it also requires the first child of the request to be an INTEGER
// LDAP version equal to 3. A wrong version is reported as an error wrapping
// ErrInvalidParameter rather than panicking.
func (p *packet) requestPacket() (*packet, error) {
	const (
		op                      = "gldap.(packet).requestPacket"
		childApplicationRequest = 1
		childVersionNumber      = 0 // first child of the app request packet
	)
	if err := p.basicValidation(); err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if err := p.assertApplicationRequest(); err != nil {
		return nil, fmt.Errorf("%s: missing request child packet: %w", op, err)
	}
	requestPacket := &packet{Packet: p.Children[childApplicationRequest]}
	switch requestPacket.Packet.Tag {
	case ApplicationBindRequest:
		// assert it's ldap v3
		if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger), withAssertChild(childVersionNumber)); err != nil {
			return nil, fmt.Errorf("%s: missing/invalid packet: %w", op, err)
		}
		ldapVersion, ok := requestPacket.Packet.Children[childVersionNumber].Value.(int64)
		if !ok {
			return nil, fmt.Errorf("%s: %v is not the expected int64 type: %w", op, requestPacket.Packet.Children[childVersionNumber].Value, ErrInvalidParameter)
		}
		if ldapVersion != 3 {
			// report the decoded version; the request packet's own Value is
			// not the version and asserting it to int64 would panic
			return nil, fmt.Errorf("%s: incorrect ldap version, expected 3 but got %v: %w", op, ldapVersion, ErrInvalidParameter)
		}
	default:
		// nothing to do or see here, move along please... :)
	}
	return requestPacket, nil
}
// requestType maps the tag of the message's request packet to a requestType.
// Unhandled tags yield unknownRequestType and an error wrapping ErrInternal.
func (p *packet) requestType() (requestType, error) {
	const op = "gldap.(Packet).requestType"
	req, err := p.requestPacket()
	if err != nil {
		return unknownRequestType, fmt.Errorf("%s: %w", op, err)
	}
	var rt requestType
	switch req.Tag {
	case ApplicationBindRequest:
		rt = bindRequestType
	case ApplicationSearchRequest:
		rt = searchRequestType
	case ApplicationExtendedRequest:
		rt = extendedRequestType
	case ApplicationModifyRequest:
		rt = modifyRequestType
	case ApplicationAddRequest:
		rt = addRequestType
	case ApplicationDelRequest:
		rt = deleteRequestType
	case ApplicationUnbindRequest:
		rt = unbindRequestType
	default:
		return unknownRequestType, fmt.Errorf("%s: unhandled request type %d: %w", op, req.Tag, ErrInternal)
	}
	return rt, nil
}
// modifyParameters holds the fields decoded from a modify request: the
// target DN, the requested changes and any request controls.
type modifyParameters struct {
	dn       string
	changes  []Change
	controls []Control
}
// return the DN, changes, and controls
func (p *packet) modifyParameters() (*modifyParameters, error) {
const (
op = "gldap.(packet).modifyParameters"
childDN = 0
childChanges = 1
childOperation = 0
childModification = 1
childModificationType = 0
childModificationValues = 1
childControls = 2
)
requestPacket, err := p.requestPacket()
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
if requestPacket.Packet.Tag != ApplicationModifyRequest {
return nil, fmt.Errorf("%s: not an modify request, expected tag %d and got %d: %w", op, ApplicationModifyRequest, requestPacket.Tag, ErrInvalidParameter)
}
var parameters modifyParameters
if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childDN)); err != nil {
return nil, fmt.Errorf("%s: modify dn packet: %w", op, ErrInvalidParameter)
}
parameters.dn = requestPacket.Children[childDN].Data.String()
// assert changes packet
if err := requestPacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childChanges)); err != nil {
return nil, fmt.Errorf("%s: modify changes packet: %w", op, ErrInvalidParameter)
}
changesPacket := requestPacket.Children[childChanges]
parameters.changes = make([]Change, 0, len(changesPacket.Children))
for _, c := range changesPacket.Children {
changePacket := packet{Packet: c}
// assert this is a "Change" packet
if err := changePacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence)); err != nil {
return nil, fmt.Errorf("%s: modify changes child packet: %w", op, ErrInvalidParameter)
}
// assert the change operation child
if err := changePacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagEnumerated), withAssertChild(childOperation)); err != nil {
return nil, fmt.Errorf("%s: modify changes child operation packet: %w", op, ErrInvalidParameter)
}
var ok bool
var chg Change
if chg.Operation, ok = changePacket.Children[childOperation].Value.(int64); !ok {
return nil, fmt.Errorf("%s: change operation is not an int64: %t", op, changePacket.Children[childOperation].Value)
}
// assert the change modification child
if err := changePacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childModification)); err != nil {
return nil, fmt.Errorf("%s: change modification child packet: %w", op, ErrInvalidParameter)
}
// get the modification type
modificationPacket := packet{Packet: changePacket.Children[childModification]}
if err := modificationPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childModificationType)); err != nil {
return nil, fmt.Errorf("%s: modification type packet: %w", op, ErrInvalidParameter)
}
chg.Modification.Type = modificationPacket.Children[childModificationType].Data.String()
// get the modification values
if len(modificationPacket.Children) < childModificationValues+1 {
return nil, fmt.Errorf("%s: missing modification values packet: %w", op, ErrInvalidParameter)
}
chg.Modification.Vals = make([]string, 0, len(modificationPacket.Children)-1)
for _, value := range modificationPacket.Children[1:] {
chg.Modification.Vals = append(chg.Modification.Vals, value.Data.String())
}
parameters.changes = append(parameters.changes, chg)
}
controlPacket, err := p.controlPacket()
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
if controlPacket != nil {
parameters.controls = make([]Control, 0, len(controlPacket.Children))
for _, c := range controlPacket.Children {
ctrl, err := decodeControl(c)
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
parameters.controls = append(parameters.controls, ctrl)
}
}
return ¶meters, nil
}
// extendedOperationName returns the name (requestName) of an extended
// operation request. It returns an error wrapping ErrInvalidParameter when the
// request isn't an extended request, or when its first child isn't a
// context-specific, primitive packet with tag 0.
func (p *packet) extendedOperationName() (ExtendedOperationName, error) {
	const (
		op                         = "gldap.(Packet).extendedOperationName"
		childExtendedOperationName = 0
	)
	requestPacket, err := p.requestPacket()
	if err != nil {
		return "", fmt.Errorf("%s: %w", op, err)
	}
	if requestPacket.Packet.Tag != ApplicationExtendedRequest {
		return "", fmt.Errorf("%s: not an extended operation request, expected tag %d and got %d: %w", op, ApplicationExtendedRequest, requestPacket.Tag, ErrInvalidParameter)
	}
	if err := requestPacket.assert(ber.ClassContext, ber.TypePrimitive, withTag(0), withAssertChild(childExtendedOperationName)); err != nil {
		return "", fmt.Errorf("%s: missing/invalid extended operation name packet: %w", op, ErrInvalidParameter)
	}
	n := requestPacket.Children[childExtendedOperationName].Data.String()
	return ExtendedOperationName(n), nil
}
// Password is a simple bind request password, decoded from the bind request's
// simple authentication packet (context-specific tag 0).
type Password string
// simpleBindParameters decodes a simple bind request, returning the bind name
// (user name), the password and any request controls.
//
// Child 1 of the request must be the name (OCTET STRING). Child 2, when
// present, must be the simple authentication packet (context-specific,
// primitive, tag 0). A request without child 2 yields an empty password, the
// same as an empty simple password. Controls are decoded in both cases.
// Child 0 is not inspected here (presumably the protocol version).
func (p *packet) simpleBindParameters() (string, Password, []Control, error) {
	const (
		op                = "gldap.(Packet).simpleBindParameters"
		childBindUserName = 1
		childBindPassword = 2
	)
	requestPacket, err := p.requestPacket()
	if err != nil {
		return "", "", nil, fmt.Errorf("%s: %w", op, err)
	}
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childBindUserName)); err != nil {
		return "", "", nil, fmt.Errorf("%s: missing/invalid username packet: %w", op, ErrInvalidParameter)
	}
	userName := requestPacket.Children[childBindUserName].Data.String()

	// Only decode a password when its packet is actually present in the
	// request. Control decoding below must not be skipped either way.
	var password string
	if len(requestPacket.Children) > childBindPassword {
		if err := requestPacket.assert(ber.ClassContext, ber.TypePrimitive, withTag(0), withAssertChild(childBindPassword)); err != nil {
			return "", "", nil, fmt.Errorf("%s: missing/invalid password packet: %w", op, ErrInvalidParameter)
		}
		password = requestPacket.Children[childBindPassword].Data.String()
	}

	var controls []Control
	controlPacket, err := p.controlPacket()
	if err != nil {
		return "", "", nil, fmt.Errorf("%s: %w", op, err)
	}
	if controlPacket != nil {
		controls = make([]Control, 0, len(controlPacket.Children))
		for _, c := range controlPacket.Children {
			ctrl, err := decodeControl(c)
			if err != nil {
				return "", "", nil, fmt.Errorf("%s: %w", op, err)
			}
			controls = append(controls, ctrl)
		}
	}
	return userName, Password(password), controls, nil
}
// addParameters holds the values decoded from an add request packet by
// (*packet).addParameters.
type addParameters struct {
// dn is the DN of the entry being added.
dn string
// attributes are the decoded attributes of the new entry, in request order.
attributes []Attribute
// controls are the controls decoded from the message's control packet, if any.
controls []Control
}
// addParameters decodes an add request: the DN of the new entry (child 0),
// its attribute list (child 1, a SEQUENCE of attributes) and any controls
// attached to the message. Malformed DN/attribute packets produce an error
// wrapping ErrInvalidParameter.
func (p *packet) addParameters() (*addParameters, error) {
	const (
		op              = "gldap.(Packet).addParameters"
		childDN         = 0
		childAttributes = 1
	)
	req, err := p.requestPacket()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	// only add requests can be decoded here
	if req.Packet.Tag != ApplicationAddRequest {
		return nil, fmt.Errorf("%s: not an add request, expected tag %d and got %d: %w", op, ApplicationAddRequest, req.Tag, ErrInvalidParameter)
	}
	if err := req.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childDN)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid DN: %w", op, ErrInvalidParameter)
	}
	if err := req.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childAttributes)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid attributes: %w", op, ErrInvalidParameter)
	}

	result := &addParameters{
		dn: req.Children[childDN].Data.String(),
	}
	for _, attrPacket := range req.Children[childAttributes].Children {
		decoded, err := decodeAttribute(attrPacket)
		if err != nil {
			return nil, fmt.Errorf("%s: failed to decode attribute packet: %w", op, err)
		}
		result.attributes = append(result.attributes, *decoded)
	}

	ctrlPacket, err := p.controlPacket()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if ctrlPacket == nil {
		return result, nil
	}
	result.controls = make([]Control, 0, len(ctrlPacket.Children))
	for _, rawCtrl := range ctrlPacket.Children {
		ctrl, err := decodeControl(rawCtrl)
		if err != nil {
			return nil, fmt.Errorf("%s: %w", op, err)
		}
		result.controls = append(result.controls, ctrl)
	}
	return result, nil
}
// searchParameters holds the values decoded from a search request packet by
// (*packet).searchParmeters.
type searchParameters struct {
// baseDN is the base object DN of the search.
baseDN string
// scope is the decoded ENUMERATED scope value.
scope int64
// derefAliases is the decoded ENUMERATED alias dereferencing value.
derefAliases int64
// sizeLimit is the requested size limit.
sizeLimit int64
// timeLimit is the requested time limit.
timeLimit int64
// typesOnly is the decoded types-only flag.
typesOnly bool
// filter is the search filter decompiled to its string form.
filter string
// attributes are the requested attribute names; nil when the request had no
// attribute selection packet.
attributes []string
// controls are the controls decoded from the message's control packet, if any.
controls []Control
}
// searchParmeters decodes a search request. It returns the base DN, scope,
// deref aliases, size and time limits, the types-only flag, the filter
// (decompiled to its string form), the optional attribute selection and any
// request controls.
//
// All fields up to and including the filter are required. The attribute
// selection is treated as optional. When it's absent, attributes stays nil
// and request controls are still decoded.
func (p *packet) searchParmeters() (*searchParameters, error) {
	const op = "gldap.(Packet).searchParmeters"
	const (
		childBaseDN       = 0
		childScope        = 1
		childDerefAliases = 2
		childSizeLimit    = 3
		childTimeLimit    = 4
		childTypesOnly    = 5
		childFilter       = 6
		childAttributes   = 7
	)
	var ok bool
	var searchFor searchParameters
	requestPacket, err := p.requestPacket()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	// validate that it's a search request
	if requestPacket.Packet.Tag != ApplicationSearchRequest {
		return nil, fmt.Errorf("%s: not an search request, expected tag %d and got %d: %w", op, ApplicationSearchRequest, requestPacket.Tag, ErrInvalidParameter)
	}
	// baseDN child
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(childBaseDN)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid baseDN: %w", op, ErrInvalidParameter)
	}
	searchFor.baseDN = requestPacket.Children[childBaseDN].Data.String()
	// scope child
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagEnumerated), withAssertChild(childScope)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid scope: %w", op, ErrInvalidParameter)
	}
	if searchFor.scope, ok = requestPacket.Children[childScope].Value.(int64); !ok {
		return nil, fmt.Errorf("%s: scope is not an int64", op)
	}
	// deref aliases
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagEnumerated), withAssertChild(childDerefAliases)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid deref aliases: %w", op, ErrInvalidParameter)
	}
	if searchFor.derefAliases, ok = requestPacket.Children[childDerefAliases].Value.(int64); !ok {
		return nil, fmt.Errorf("%s: deref aliases is not an int64", op)
	}
	// size limit
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger), withAssertChild(childSizeLimit)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid size limit: %w", op, ErrInvalidParameter)
	}
	if searchFor.sizeLimit, ok = requestPacket.Children[childSizeLimit].Value.(int64); !ok {
		return nil, fmt.Errorf("%s: size limit is not an int64", op)
	}
	// time limit
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagInteger), withAssertChild(childTimeLimit)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid time limit: %w", op, ErrInvalidParameter)
	}
	if searchFor.timeLimit, ok = requestPacket.Children[childTimeLimit].Value.(int64); !ok {
		return nil, fmt.Errorf("%s: time limit is not an int64", op)
	}
	// types only
	if err := requestPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagBoolean), withAssertChild(childTypesOnly)); err != nil {
		return nil, fmt.Errorf("%s: missing/invalid types only: %w", op, ErrInvalidParameter)
	}
	if searchFor.typesOnly, ok = requestPacket.Children[childTypesOnly].Value.(bool); !ok {
		return nil, fmt.Errorf("%s: types only is not a bool", op)
	}
	if len(requestPacket.Children) < childFilter+1 {
		return nil, fmt.Errorf("%s: missing filter: %w", op, ErrInvalidParameter)
	}
	filter, err := ldap.DecompileFilter(requestPacket.Children[childFilter])
	if err != nil {
		return nil, fmt.Errorf("%s: unable to decompile filter: %w", op, err)
	}
	searchFor.filter = filter

	// The attribute selection is optional here. Request controls must still
	// be decoded when it's absent, so don't return early.
	if len(requestPacket.Children) > childAttributes {
		if err := requestPacket.assert(ber.ClassUniversal, ber.TypeConstructed, withTag(ber.TagSequence), withAssertChild(childAttributes)); err != nil {
			return nil, fmt.Errorf("%s: invalid attributes: %w", op, err)
		}
		attributesPacket := packet{
			Packet: requestPacket.Children[childAttributes],
		}
		searchFor.attributes = make([]string, 0, len(attributesPacket.Children))
		for idx, attribute := range attributesPacket.Children {
			if err := attributesPacket.assert(ber.ClassUniversal, ber.TypePrimitive, withTag(ber.TagOctetString), withAssertChild(idx)); err != nil {
				return nil, fmt.Errorf("%s: invalid attribute child packet: %w", op, err)
			}
			searchFor.attributes = append(searchFor.attributes, attribute.Data.String())
		}
	}

	controlPacket, err := p.controlPacket()
	if err != nil {
		return nil, fmt.Errorf("%s: %w", op, err)
	}
	if controlPacket != nil {
		searchFor.controls = make([]Control, 0, len(controlPacket.Children))
		for _, c := range controlPacket.Children {
			ctrl, err := decodeControl(c)
			if err != nil {
				return nil, fmt.Errorf("%s: %w", op, err)
			}
			searchFor.controls = append(searchFor.controls, ctrl)
		}
	}
	return &searchFor, nil
}
// assert verifies the class and type, and optionally the tag, of a packet.
//
// The packet itself is checked unless the withAssertChild option names a
// child index, in which case that child is checked instead. The
// withLenChildren and withMinChildren options validate the packet's number of
// children before any class/type/tag checks.
func (p *packet) assert(cl ber.Class, ty ber.Type, opt ...Option) error {
	const op = "gldap.assert"
	opts := getMessageOpts(opt...)
	childCount := len(p.Children)

	if want := opts.withLenChildren; want != nil && childCount != *want {
		return fmt.Errorf("%s: not the correct number of children packets, expected %d but got %d", op, *want, childCount)
	}
	if least := opts.withMinChildren; least != nil && childCount < *least {
		return fmt.Errorf("%s: not enough children packets, expected %d but got %d", op, *least, childCount)
	}

	target := p.Packet
	if idx := opts.withAssertChild; idx != nil {
		if *idx+1 > childCount {
			return fmt.Errorf("%s: missing asserted child %d, but there are only %d", op, *idx, childCount)
		}
		target = p.Packet.Children[*idx]
	}

	switch {
	case target.ClassType != cl:
		return fmt.Errorf("%s: incorrect class, expected %v but got %v", op, cl, target.ClassType)
	case target.TagType != ty:
		return fmt.Errorf("%s: incorrect type, expected %v but got %v", op, ty, target.TagType)
	case opts.withTag != nil && target.Tag != *opts.withTag:
		return fmt.Errorf("%s: incorrect tag, expected %v but got %v", op, *opts.withTag, target.Tag)
	}
	return nil
}
// assertApplicationRequest verifies that the packet's second child (index 1)
// is an application-class request. Constructed requests are accepted with any
// tag. Primitive requests are accepted only when they are delete or unbind
// requests.
func (p *packet) assertApplicationRequest() error {
	const (
		op                      = "gldap.(packet).assertApplicationRequest"
		childApplicationRequest = 1
	)
	if len(p.Children) < childApplicationRequest+1 {
		return fmt.Errorf("%s: missing asserted application request child, but there are only %d", op, len(p.Children))
	}
	chkPacket := p.Packet.Children[childApplicationRequest]
	if chkPacket.ClassType != ber.ClassApplication {
		return fmt.Errorf("%s: incorrect class, expected %v (ber.ClassApplication) but got %v", op, ber.ClassApplication, chkPacket.ClassType)
	}
	switch chkPacket.TagType {
	case ber.TypePrimitive:
		// These values are integers: %d prints them as numbers, whereas %q
		// would render them as Go character literals (e.g. '\n').
		if chkPacket.Tag != ApplicationDelRequest && chkPacket.Tag != ApplicationUnbindRequest {
			return fmt.Errorf("%s: incorrect type, primitive %d must be a delete request %d or an unbind request %d, but got %d", op, ber.TypePrimitive, ApplicationDelRequest, ApplicationUnbindRequest, chkPacket.Tag)
		}
	case ber.TypeConstructed:
		// any constructed application request is accepted
	default:
		return fmt.Errorf("%s: incorrect type, expected ber.TypeConstructed %d but got %v", op, ber.TypeConstructed, chkPacket.TagType)
	}
	return nil
}
// debug pretty-prints the packet tree (without raw bytes) through a
// debug-level hclog logger named "debug-logger".
func (p *packet) debug() {
	loggerOpts := &hclog.LoggerOptions{
		Name:  "debug-logger",
		Level: hclog.Debug,
	}
	out := hclog.New(loggerOpts).StandardWriter(&hclog.StandardLoggerOptions{})
	p.Log(out, 0, false)
}
// Log pretty prints the packet and, recursively, its children to out. Each
// nesting level is indented by one more space than its parent, starting at
// indent spaces. Negative indents are treated as zero. When printBytes is
// true, each packet's encoded bytes are also dumped via ber.PrintBytes.
func (p *packet) Log(out io.Writer, indent int, printBytes bool) {
	// a negative indent would otherwise never terminate the padding loop
	if indent < 0 {
		indent = 0
	}
	indentStr := fmt.Sprintf("%*s", indent, "")

	classStr := ber.ClassMap[p.ClassType]
	tagtypeStr := ber.TypeMap[p.TagType]
	tagStr := fmt.Sprintf("0x%02X", p.Tag)
	if p.ClassType == ber.ClassUniversal {
		// keep the hex form for universal tags without a known name
		if name, ok := tagMap[p.Tag]; ok {
			tagStr = name
		}
	}

	value := fmt.Sprint(p.Value)
	description := ""
	if p.Description != "" {
		description = p.Description + ": "
	}

	fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagtypeStr, tagStr, p.Data.Len(), value)
	if printBytes {
		ber.PrintBytes(out, p.Bytes(), indentStr)
	}
	for _, child := range p.Children {
		childPacket := packet{Packet: child}
		childPacket.Log(out, indent+1, printBytes)
	}
}
// deleteParameters decodes a delete request. It returns the DN of the entry
// to delete, read directly from the request packet's data, and any request
// controls.
func (p *packet) deleteParameters() (string, []Control, error) {
	const op = "gldap.(packet).deleteParameters"
	requestPacket, err := p.requestPacket()
	if err != nil {
		return "", nil, fmt.Errorf("%s: %w", op, err)
	}
	if requestPacket.Packet.Tag != ApplicationDelRequest {
		return "", nil, fmt.Errorf("%s: not a delete request, expected tag %d and got %d: %w", op, ApplicationDelRequest, requestPacket.Tag, ErrInvalidParameter)
	}
	dn := requestPacket.Data.String()

	controlPacket, err := p.controlPacket()
	if err != nil {
		return "", nil, fmt.Errorf("%s: %w", op, err)
	}
	var controls []Control
	if controlPacket != nil {
		controls = make([]Control, 0, len(controlPacket.Children))
		for _, c := range controlPacket.Children {
			ctrl, err := decodeControl(c)
			if err != nil {
				return "", nil, fmt.Errorf("%s: %w", op, err)
			}
			controls = append(controls, ctrl)
		}
	}
	return dn, controls, nil
}
// tagMap maps universal-class BER tags to human-readable names. (*packet).Log
// uses it to describe universal packets.
var tagMap = map[ber.Tag]string{
ber.TagEOC: "EOC (End-of-Content)",
ber.TagBoolean: "Boolean",
ber.TagInteger: "Integer",
ber.TagBitString: "Bit String",
ber.TagOctetString: "Octet String",
ber.TagNULL: "NULL",
ber.TagObjectIdentifier: "Object Identifier",
ber.TagObjectDescriptor: "Object Descriptor",
ber.TagExternal: "External",
ber.TagRealFloat: "Real (float)",
ber.TagEnumerated: "Enumerated",
ber.TagEmbeddedPDV: "Embedded PDV",
ber.TagUTF8String: "UTF8 String",
ber.TagRelativeOID: "Relative-OID",
ber.TagSequence: "Sequence and Sequence of",
ber.TagSet: "Set and Set OF",
ber.TagNumericString: "Numeric String",
ber.TagPrintableString: "Printable String",
ber.TagT61String: "T61 String",
ber.TagVideotexString: "Videotex String",
ber.TagIA5String: "IA5 String",
ber.TagUTCTime: "UTC Time",
ber.TagGeneralizedTime: "Generalized Time",
ber.TagGraphicString: "Graphic String",
ber.TagVisibleString: "Visible String",
ber.TagGeneralString: "General String",
ber.TagUniversalString: "Universal String",
ber.TagCharacterString: "Character String",
ber.TagBMPString: "BMP String",
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"crypto/tls"
"errors"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
// ExtendedOperationName is an extended operation request/response name,
// normally an LDAP OID.
type ExtendedOperationName string

// Extended operation response/request names
const (
	// ExtendedOperationDisconnection is the Notice of Disconnection
	// unsolicited notification OID (RFC 4511, section 4.4.1).
	ExtendedOperationDisconnection ExtendedOperationName = "1.3.6.1.4.1.1466.20036"
	// ExtendedOperationCancel is the Cancel operation OID (RFC 3909).
	ExtendedOperationCancel ExtendedOperationName = "1.3.6.1.1.8"
	// ExtendedOperationStartTLS is the StartTLS operation OID (RFC 4511,
	// section 4.14).
	ExtendedOperationStartTLS ExtendedOperationName = "1.3.6.1.4.1.1466.20037"
	// ExtendedOperationWhoAmI is the "Who am I?" operation OID (RFC 4532).
	ExtendedOperationWhoAmI ExtendedOperationName = "1.3.6.1.4.1.4203.1.11.3"
	// ExtendedOperationGetConnectionID is the get connection ID operation OID.
	ExtendedOperationGetConnectionID ExtendedOperationName = "1.3.6.1.4.1.26027.1.6.2"
	// ExtendedOperationPasswordModify is the Password Modify operation OID
	// (RFC 3062).
	ExtendedOperationPasswordModify ExtendedOperationName = "1.3.6.1.4.1.4203.1.11.1"
	// ExtendedOperationUnknown is the placeholder name "Unknown"; presumably
	// used for operations whose name isn't recognized.
	ExtendedOperationUnknown ExtendedOperationName = "Unknown"
)
// Request represents an ldap request received on a connection. It holds the
// decoded message and the route operation used to dispatch it.
type Request struct {
// ID is the request number for a specific connection. Every connection has
// its own request counter which starts at 1.
ID int
// conn is the connection the request arrived on; it is needed for
// cancellation, among other things.
conn *conn
// message is the decoded request message.
message Message
// routeOp is the route operation selected from the message's concrete type.
routeOp routeOperation
// extendedName is the extended operation name; newRequest sets it only for
// extended operation requests.
extendedName ExtendedOperationName
}
// newRequest decodes packet p into a Message and wraps it in a Request with
// the given per-connection id for connection c. The message's concrete type
// selects the route operation. For extended operations, the operation name is
// recorded as well.
func newRequest(id int, c *conn, p *packet) (*Request, error) {
	const op = "gldap.newRequest"
	switch {
	case c == nil:
		return nil, fmt.Errorf("%s: missing connection: %w", op, ErrInvalidParameter)
	case p == nil:
		return nil, fmt.Errorf("%s: missing packet: %w", op, ErrInvalidParameter)
	}

	msg, err := newMessage(p)
	if err != nil {
		return nil, fmt.Errorf("%s: unable to build message for request %d: %w", op, id, err)
	}

	req := &Request{
		ID:      id,
		conn:    c,
		message: msg,
	}
	switch m := msg.(type) {
	case *SimpleBindMessage:
		req.routeOp = bindRouteOperation
	case *SearchMessage:
		req.routeOp = searchRouteOperation
	case *ExtendedOperationMessage:
		req.routeOp = extendedRouteOperation
		req.extendedName = m.Name
	case *ModifyMessage:
		req.routeOp = modifyRouteOperation
	case *AddMessage:
		req.routeOp = addRouteOperation
	case *DeleteMessage:
		req.routeOp = deleteRouteOperation
	case *UnbindMessage:
		req.routeOp = unbindRouteOperation
	default:
		// should be unreachable: newMessage falls back to returning an
		// *ExtendedOperationMessage
		return nil, fmt.Errorf("%s: %v is an unsupported route operation: %w", op, m, ErrInternal)
	}
	return req, nil
}
// ConnectionID returns the ID of the connection that issued the request. It
// tells a handler "who" (i.e. which connection) made a request. For example,
// a handler can use it to ensure that a connection performing a search has
// already completed a successful bind.
func (r *Request) ConnectionID() int {
	origin := r.conn
	return origin.connID
}
// NewModifyResponse creates a modify response.
// Supported options: WithResponseCode, WithDiagnosticMessage, WithMatchedDN
//
// When WithResponseCode is not supplied, the response code falls back to
// NewResponse's default rather than dereferencing a nil code.
func (r *Request) NewModifyResponse(opt ...Option) *ModifyResponse {
	opts := getResponseOpts(opt...)
	respOpts := []Option{
		WithApplicationCode(ApplicationModifyResponse),
		WithDiagnosticMessage(opts.withDiagnosticMessage),
		WithMatchedDN(opts.withMatchedDN),
	}
	// only forward a response code that was actually given
	if opts.withResponseCode != nil {
		respOpts = append(respOpts, WithResponseCode(*opts.withResponseCode))
	}
	return &ModifyResponse{
		GeneralResponse: r.NewResponse(respOpts...),
	}
}
// StartTLS upgrades the request's existing network connection to TLS. It
// performs a server-side TLS handshake using tlsconfig and then
// re-initializes the connection around the TLS conn. A nil tlsconfig is
// rejected with ErrInvalidParameter.
func (r *Request) StartTLS(tlsconfig *tls.Config) error {
	const op = "gldap.(Message).StartTLS"
	if tlsconfig == nil {
		return fmt.Errorf("%s: missing tls configuration: %w", op, ErrInvalidParameter)
	}
	serverSide := tls.Server(r.conn.netConn, tlsconfig)
	if err := serverSide.Handshake(); err != nil {
		return fmt.Errorf("%s: handshake error: %w", op, err)
	}
	err := r.conn.initConn(serverSide)
	if err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	return nil
}
// NewResponse creates a general response. It is not necessarily a response
// to any specific request type, because the application code can be set with
// WithApplicationCode. The response code defaults to ResultUnwillingToPerform
// and the application code to ApplicationExtendedResponse.
// Supported options: WithResponseCode, WithApplicationCode,
// WithDiagnosticMessage, WithMatchedDN
func (r *Request) NewResponse(opt ...Option) *GeneralResponse {
	const op = "gldap.NewResponse" // nolint:unused
	opts := getResponseOpts(opt...)

	resultCode := ResultUnwillingToPerform
	if opts.withResponseCode != nil {
		resultCode = *opts.withResponseCode
	}
	appCode := ApplicationExtendedResponse
	if opts.withApplicationCode != nil {
		appCode = *opts.withApplicationCode
	}

	base := &baseResponse{
		messageID:   r.message.GetID(),
		code:        int16(resultCode),
		diagMessage: opts.withDiagnosticMessage,
		matchedDN:   opts.withMatchedDN,
	}
	return &GeneralResponse{baseResponse: base, applicationCode: appCode}
}
// NewExtendedResponse creates a new extended response for the request's
// message ID. The result code stays zero unless WithResponseCode is given.
// Supported options: WithResponseCode
func (r *Request) NewExtendedResponse(opt ...Option) *ExtendedResponse {
	const op = "gldap.NewExtendedResponse" // nolint:unused
	var resultCode int16
	if opts := getResponseOpts(opt...); opts.withResponseCode != nil {
		resultCode = int16(*opts.withResponseCode)
	}
	return &ExtendedResponse{
		baseResponse: &baseResponse{
			messageID: r.message.GetID(),
			code:      resultCode,
		},
	}
}
// NewBindResponse creates a new bind response for the request's message ID.
// The result code stays zero unless WithResponseCode is given.
// Supported options: WithResponseCode
func (r *Request) NewBindResponse(opt ...Option) *BindResponse {
	const op = "gldap.NewBindResponse" // nolint:unused
	var resultCode int16
	if opts := getResponseOpts(opt...); opts.withResponseCode != nil {
		resultCode = int16(*opts.withResponseCode)
	}
	return &BindResponse{
		baseResponse: &baseResponse{
			messageID: r.message.GetID(),
			code:      resultCode,
		},
	}
}
// GetSimpleBindMessage returns the request's message as a *SimpleBindMessage
// so a handler can act on its attributes. It returns an error wrapping
// ErrInvalidParameter when the request is not a simple bind.
func (r *Request) GetSimpleBindMessage() (*SimpleBindMessage, error) {
	const op = "gldap.(Request).GetSimpleBindMessage"
	if bind, ok := r.message.(*SimpleBindMessage); ok {
		return bind, nil
	}
	return nil, fmt.Errorf("%s: %T not a simple bind request: %w", op, r.message, ErrInvalidParameter)
}
// NewSearchDoneResponse creates a new search done response for the request's
// message ID. If no results were found, set the response code by adding the
// option WithResponseCode(ResultNoSuchObject). Otherwise the code stays zero
// unless WithResponseCode is given.
//
// Supported options: WithResponseCode
func (r *Request) NewSearchDoneResponse(opt ...Option) *SearchResponseDone {
	const op = "gldap.(Request).NewSearchDoneResponse" // nolint:unused
	opts := getResponseOpts(opt...)
	base := &baseResponse{messageID: r.message.GetID()}
	if rc := opts.withResponseCode; rc != nil {
		base.code = int16(*rc)
	}
	return &SearchResponseDone{baseResponse: base}
}
// GetSearchMessage returns the request's message as a *SearchMessage so a
// handler can act on its attributes. It returns an error wrapping
// ErrInvalidParameter when the request is not a search.
func (r *Request) GetSearchMessage() (*SearchMessage, error) {
	const op = "gldap.(Request).GetSearchMessage"
	if search, ok := r.message.(*SearchMessage); ok {
		return search, nil
	}
	return nil, fmt.Errorf("%s: %T not a search request: %w", op, r.message, ErrInvalidParameter)
}
// NewSearchResponseEntry creates a search response entry for entryDN, using
// the request's message ID. Each name/values pair supplied via WithAttributes
// becomes one EntryAttribute. The attribute order follows iteration over that
// option's value (presumably a map, so the order is not deterministic).
// Supported options: WithAttributes
func (r *Request) NewSearchResponseEntry(entryDN string, opt ...Option) *SearchResponseEntry {
	attrs := getResponseOpts(opt...).withAttributes
	entryAttrs := make([]*EntryAttribute, 0, len(attrs))
	for attrName, attrVals := range attrs {
		entryAttrs = append(entryAttrs, NewEntryAttribute(attrName, attrVals))
	}
	entry := Entry{
		DN:         entryDN,
		Attributes: entryAttrs,
	}
	return &SearchResponseEntry{
		baseResponse: &baseResponse{messageID: r.message.GetID()},
		entry:        entry,
	}
}
// GetModifyMessage returns the request's message as a *ModifyMessage so a
// handler can act on its attributes. It returns an error wrapping
// ErrInvalidParameter when the request is not a modify.
func (r *Request) GetModifyMessage() (*ModifyMessage, error) {
	const op = "gldap.(Request).GetModifyMessage"
	if modify, ok := r.message.(*ModifyMessage); ok {
		return modify, nil
	}
	return nil, fmt.Errorf("%s: %T not a modify request: %w", op, r.message, ErrInvalidParameter)
}
// GetAddMessage returns the request's message as an *AddMessage so a handler
// can act on its attributes. It returns an error wrapping ErrInvalidParameter
// when the request is not an add.
func (r *Request) GetAddMessage() (*AddMessage, error) {
	const op = "gldap.(Request).GetAddMessage"
	if add, ok := r.message.(*AddMessage); ok {
		return add, nil
	}
	return nil, fmt.Errorf("%s: %T not a add request: %w", op, r.message, ErrInvalidParameter)
}
// GetDeleteMessage returns the request's message as a *DeleteMessage so a
// handler can act on its attributes. It returns an error wrapping
// ErrInvalidParameter when the request is not a delete.
func (r *Request) GetDeleteMessage() (*DeleteMessage, error) {
	const op = "gldap.(Request).GetDeleteMessage"
	if del, ok := r.message.(*DeleteMessage); ok {
		return del, nil
	}
	return nil, fmt.Errorf("%s: %T not a delete request: %w", op, r.message, ErrInvalidParameter)
}
// GetUnbindMessage returns the request's message as an *UnbindMessage so a
// handler can act on its attributes. It returns an error wrapping
// ErrInvalidParameter when the request is not an unbind.
func (r *Request) GetUnbindMessage() (*UnbindMessage, error) {
	const op = "gldap.(Request).GetUnbindMessage"
	if unbind, ok := r.message.(*UnbindMessage); ok {
		return unbind, nil
	}
	return nil, fmt.Errorf("%s: %T not an unbind request: %w", op, r.message, ErrInvalidParameter)
}
// ConvertString will convert an ASN1 BER Octet string into a "native" go
// string. Support ber string encoding types: OctetString, GeneralString and
// all other types will return an error.
func ConvertString(octetString ...string) ([]string, error) {
const (
op = "gldap.ConvertOctetString"
berTagIdx = 0
startOfDataIdx = 1
)
converted := make([]string, 0, len(octetString))
for _, s := range octetString {
data := []byte(s)
switch {
case
ber.Tag(data[berTagIdx]) == ber.TagOctetString,
ber.Tag(data[berTagIdx]) == ber.TagGeneralString:
_, strDataLen, err := readLength(data[startOfDataIdx:])
if err != nil {
return nil, err
}
converted = append(converted, string(data[(startOfDataIdx+strDataLen):]))
default:
return nil, fmt.Errorf("%s: unsupported ber encoding type %s: %w", op, string(data[berTagIdx]), ErrInvalidParameter)
}
}
return converted, nil
}
// readLength(...)
// jimlambrt: 2/2023
// copied from github.com/go-asn1-ber/asn1-ber@v1.5.4/length.go, with bounds
// checks added for slice input.
// it has an MIT license: https://github.com/go-asn1-ber/asn1-ber/blob/master/LICENSE
//
// readLength decodes a BER length from the start of bytes. It returns the
// decoded length (ber.LengthIndefinite for the indefinite form) and the number
// of bytes consumed. It indexes the slice directly, so empty or truncated
// input is reported as an error rather than causing a panic.
func readLength(bytes []byte) (length int, read int, err error) {
	if len(bytes) == 0 {
		return 0, 0, errors.New("missing length byte")
	}
	// length byte
	b := bytes[0]
	read++
	switch {
	case b == 0xFF:
		// Invalid 0xFF (x.600, 8.1.3.5.c)
		return 0, read, errors.New("invalid length byte 0xff")
	case b == ber.LengthLongFormBitmask:
		// Indefinite form, we have to decode packets until we encounter an EOC packet (x.600, 8.1.3.6)
		length = ber.LengthIndefinite
	case b&ber.LengthLongFormBitmask == 0:
		// Short definite form, extract the length from the bottom 7 bits (x.600, 8.1.3.4)
		length = int(b) & ber.LengthValueBitmask
	case b&ber.LengthLongFormBitmask != 0:
		// Long definite form, extract the number of length bytes to follow from the bottom 7 bits (x.600, 8.1.3.5.b)
		lengthBytes := int(b) & ber.LengthValueBitmask
		// Protect against overflow
		// TODO: support big int length?
		if lengthBytes > 8 {
			return 0, read, errors.New("long-form length overflow")
		}
		// every announced length byte must actually be present
		if len(bytes)-read < lengthBytes {
			return 0, read, errors.New("truncated long-form length")
		}
		// Accumulate into a 64-bit variable
		var length64 int64
		for i := 0; i < lengthBytes; i++ {
			b = bytes[read]
			read++
			// x.600, 8.1.3.5
			length64 <<= 8
			length64 |= int64(b)
		}
		// 8 length bytes with the high bit set wrap around to a negative value
		if length64 < 0 {
			return 0, read, errors.New("long-form length overflow")
		}
		// Cast to a platform-specific integer
		length = int(length64)
		// Ensure we didn't overflow
		if int64(length) != length64 {
			return 0, read, errors.New("long-form length overflow")
		}
	default:
		return 0, read, errors.New("invalid length byte")
	}
	return length, read, nil
}
// intPtr returns a pointer to a newly allocated copy of i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"bufio"
"fmt"
"sync"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/hashicorp/go-hclog"
)
// ResponseWriter is an ldap request response writer which is used by a
// HandlerFunc to write responses to client requests.
type ResponseWriter struct {
writerMu *sync.Mutex // a shared lock across all requests to prevent data races when writing
// writer is the buffered writer responses are written (and flushed) to.
writer *bufio.Writer
// logger receives debug output about each write.
logger hclog.Logger
// connID identifies the connection; it is included in log output.
connID int
// requestID identifies the request; it is included in log output.
requestID int
}
// newResponseWriter builds a ResponseWriter for one request on one
// connection. Every argument is required. A nil writer, lock or logger, and a
// zero connID or requestID, are rejected with ErrInvalidParameter.
func newResponseWriter(w *bufio.Writer, lock *sync.Mutex, logger hclog.Logger, connID, requestID int) (*ResponseWriter, error) {
	const op = "gldap.NewResponseWriter"
	switch {
	case w == nil:
		return nil, fmt.Errorf("%s: missing writer: %w", op, ErrInvalidParameter)
	case lock == nil:
		return nil, fmt.Errorf("%s: missing writer lock: %w", op, ErrInvalidParameter)
	case logger == nil:
		return nil, fmt.Errorf("%s: missing logger: %w", op, ErrInvalidParameter)
	case connID == 0:
		return nil, fmt.Errorf("%s: missing conn ID: %w", op, ErrInvalidParameter)
	case requestID == 0:
		return nil, fmt.Errorf("%s: missing request ID: %w", op, ErrInvalidParameter)
	}
	rw := &ResponseWriter{
		writer:    w,
		writerMu:  lock,
		logger:    logger,
		connID:    connID,
		requestID: requestID,
	}
	return rw, nil
}
// Write encodes the response, writes it to the client and flushes the
// buffered writer. At debug level the encoded packet is logged first. The
// write and flush happen while holding the writer mutex, which is shared
// across requests. A nil response is rejected with ErrInvalidParameter.
func (rw *ResponseWriter) Write(r Response) error {
	const op = "gldap.(ResponseWriter).Write"
	if r == nil {
		return fmt.Errorf("%s: missing response: %w", op, ErrInvalidParameter)
	}
	// Encode once: the packet that's logged is then exactly the packet
	// that's written, and the response isn't encoded twice.
	p := r.packet()
	if rw.logger.IsDebug() {
		rw.logger.Debug("response write", "op", op, "conn", rw.connID, "requestID", rw.requestID)
		p.Log(rw.logger.StandardWriter(&hclog.StandardLoggerOptions{}), 0, false)
	}
	rw.writerMu.Lock()
	defer rw.writerMu.Unlock()
	if _, err := rw.writer.Write(p.Bytes()); err != nil {
		return fmt.Errorf("%s: unable to write response: %w", op, err)
	}
	if err := rw.writer.Flush(); err != nil {
		return fmt.Errorf("%s: unable to flush write: %w", op, err)
	}
	rw.logger.Debug("finished writing", "op", op, "conn", rw.connID, "requestID", rw.requestID)
	return nil
}
// beginResponse starts an LDAP response envelope: a universal SEQUENCE whose
// first child is the integer message ID. Callers append the response packet
// (and optionally controls) to it.
func beginResponse(messageID int64) *ber.Packet {
	const op = "gldap.beginResponse" // nolint:unused
	idPacket := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID")
	envelope := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
	envelope.AppendChild(idPacket)
	return envelope
}
// addOptionalResponseChildren appends the matchedDN and diagnosticMessage
// octet strings, in that order, to a response packet. The values come from
// the WithMatchedDN and WithDiagnosticMessage options. Both children are
// always appended, even when their values are empty.
func addOptionalResponseChildren(resultPacket *ber.Packet, opt ...Option) {
	const op = "gldap.addOptionalResponseChildren" // nolint:unused
	opts := getResponseOpts(opt...)
	children := []struct{ value, desc string }{
		{opts.withMatchedDN, "matchedDN"},
		{opts.withDiagnosticMessage, "diagnosticMessage"},
	}
	for _, child := range children {
		resultPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, child.value, child.desc))
	}
}
// Response represents a response to an ldap request. Because its only method
// is unexported, in practice only types in this package implement it.
type Response interface {
// packet returns the encoded LDAP message for the response.
packet() *packet
}
// baseResponse holds the fields shared by the response types in this file.
type baseResponse struct {
// messageID is the ID of the request message being responded to.
messageID int64
// code is the result code, encoded as an ENUMERATED.
code int16
// diagMessage is the optional diagnostic message.
diagMessage string
// matchedDN is the optional matched DN.
matchedDN string
}
// SetResultCode sets the result code for a response. The code is stored as
// an int16, so values outside that range wrap on conversion.
func (l *baseResponse) SetResultCode(code int) {
	narrowed := int16(code)
	l.code = narrowed
}
// SetDiagnosticMessage stores the optional diagnostic message sent with a
// response, replacing any previous value.
func (l *baseResponse) SetDiagnosticMessage(msg string) { l.diagMessage = msg }
// SetMatchedDN stores the optional matched DN sent with a response, replacing
// any previous value.
func (l *baseResponse) SetMatchedDN(dn string) { l.matchedDN = dn }
// ExtendedResponse represents a response to an extended operation request
type ExtendedResponse struct {
*baseResponse
// name is the response name set via SetResponseName.
name ExtendedOperationName
}
// SetResponseName sets the response name of the extended operation response,
// replacing any previous name.
func (r *ExtendedResponse) SetResponseName(n ExtendedOperationName) { r.name = n }
// packet encodes the extended response. It writes the result fields (result
// code, matched DN, diagnostic message), followed by the optional responseName
// element when a name has been set with SetResponseName. The responseName is a
// context-specific, primitive [10] string.
func (r *ExtendedResponse) packet() *packet {
	// context-specific tag of ExtendedResponse.responseName (RFC 4511, 4.12)
	const responseNameTag = 10
	replyPacket := beginResponse(r.messageID)
	// a new packet for the extended response
	resultPacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ber.Tag(ApplicationExtendedResponse), nil, ApplicationCodeMap[ApplicationExtendedResponse])
	// append the result code to the extended response packet
	resultPacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)]))
	// Add optional diagnostic message and matched DN
	addOptionalResponseChildren(resultPacket, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN))
	// Encode the name so SetResponseName has an effect on the wire. The
	// "Unknown" placeholder is not an OID, so it is never sent.
	if r.name != "" && r.name != ExtendedOperationUnknown {
		resultPacket.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, responseNameTag, string(r.name), "responseName"))
	}
	replyPacket.AppendChild(resultPacket)
	return &packet{Packet: replyPacket}
}
// BindResponse represents the response to a bind request
type BindResponse struct {
*baseResponse
// controls are appended to the response message when non-empty.
controls []Control
}
// SetControls replaces the controls sent with the bind response; calling it
// with no arguments clears them.
func (r *BindResponse) SetControls(controls ...Control) { r.controls = controls }
// packet encodes the bind response as an LDAPMessage. The message holds a
// BindResponse protocolOp (result code plus the optional diagnostic message
// and matched DN) and, only when at least one was set, the encoded controls.
func (r *BindResponse) packet() *packet {
	envelope := beginResponse(r.messageID)

	bindOp := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ber.Tag(ApplicationBindResponse), nil, ApplicationCodeMap[ApplicationBindResponse])
	resultCode := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])
	bindOp.AppendChild(resultCode)
	addOptionalResponseChildren(bindOp, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN))
	envelope.AppendChild(bindOp)

	// controls live alongside the protocolOp in the LDAPMessage envelope
	if len(r.controls) != 0 {
		envelope.AppendChild(encodeControls(r.controls))
	}
	return &packet{Packet: envelope}
}
// GeneralResponse represents a general response (non-specific to a request).
// When encoded, applicationCode is used as the application tag of the
// protocolOp, which carries the result code, diagnostic message and matched
// DN from the embedded baseResponse.
type GeneralResponse struct {
*baseResponse
// applicationCode is the LDAP application (protocolOp) tag used when encoding.
applicationCode int
}
// packet encodes the general response as an LDAPMessage whose protocolOp is
// tagged with r.applicationCode and contains the result code plus the
// optional diagnostic message and matched DN.
func (r *GeneralResponse) packet() *packet {
	envelope := beginResponse(r.messageID)

	tag := ber.Tag(r.applicationCode)
	desc := ApplicationCodeMap[uint8(r.applicationCode)]
	protocolOp := ber.Encode(ber.ClassApplication, ber.TypeConstructed, tag, nil, desc)

	resultCode := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])
	protocolOp.AppendChild(resultCode)
	addOptionalResponseChildren(protocolOp, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN))

	envelope.AppendChild(protocolOp)
	return &packet{Packet: envelope}
}
// SearchResponseDone represents that handling a search request is done. It
// may carry optional controls, set with SetControls.
type SearchResponseDone struct {
	*baseResponse
	controls []Control
}

// SetControls for the search response. The given controls replace any that
// were previously set.
func (r *SearchResponseDone) SetControls(ctrls ...Control) {
	r.controls = ctrls
}
// packet encodes the search-done response as an LDAPMessage. The message
// holds a SearchResultDone protocolOp (result code plus the optional
// diagnostic message and matched DN) and, only when set, the controls.
func (r *SearchResponseDone) packet() *packet {
	envelope := beginResponse(r.messageID)

	doneOp := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, ApplicationCodeMap[ApplicationSearchResultDone])
	resultCode := ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, r.code, ResultCodeMap[uint16(r.code)])
	doneOp.AppendChild(resultCode)
	addOptionalResponseChildren(doneOp, WithDiagnosticMessage(r.diagMessage), WithMatchedDN(r.matchedDN))
	envelope.AppendChild(doneOp)

	if len(r.controls) != 0 {
		envelope.AppendChild(encodeControls(r.controls))
	}
	return &packet{Packet: envelope}
}
// SearchResponseEntry is an ldap entry that's part of a search response.
type SearchResponseEntry struct {
	*baseResponse
	entry Entry
}

// AddAttribute appends an attribute with the given name and values to the
// response entry.
func (r *SearchResponseEntry) AddAttribute(name string, values []string) {
	attr := NewEntryAttribute(name, values)
	r.entry.Attributes = append(r.entry.Attributes, attr)
}
// packet encodes the entry as an LDAPMessage holding a SearchResultEntry
// protocolOp: the entry's DN followed by a sequence of its encoded
// attributes, in the order they were added.
func (r *SearchResponseEntry) packet() *packet {
	envelope := beginResponse(r.messageID)

	entryOp := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, ApplicationCodeMap[ApplicationSearchResultEntry])
	dn := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, r.entry.DN, "DN")
	entryOp.AppendChild(dn)

	attrSeq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
	for _, attr := range r.entry.Attributes {
		attrSeq.AppendChild(attr.encode())
	}
	entryOp.AppendChild(attrSeq)

	envelope.AppendChild(entryOp)
	return &packet{Packet: envelope}
}
// ModifyResponse is a response to a modify request. It embeds
// *GeneralResponse, so it's encoded exactly like a GeneralResponse.
type ModifyResponse struct {
*GeneralResponse
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
// responseOptions collects the optional settings used when building a
// response. Each field is populated by the matching With* Option.
type responseOptions struct {
	withDiagnosticMessage string
	withMatchedDN         string
	withResponseCode      *int
	withApplicationCode   *int
	withAttributes        map[string][]string
}

// responseDefaults returns the options in effect before any Option is
// applied: the matched DN and diagnostic message both default to "Unused";
// everything else is unset.
func responseDefaults() responseOptions {
	var defaults responseOptions
	defaults.withMatchedDN = "Unused"
	defaults.withDiagnosticMessage = "Unused"
	return defaults
}

// getResponseOpts starts from responseDefaults and applies opt in order.
func getResponseOpts(opt ...Option) responseOptions {
	result := responseDefaults()
	applyOpts(&result, opt...)
	return result
}
// WithDiagnosticMessage provides an optional diagnostic message for the
// response.
func WithDiagnosticMessage(msg string) Option {
	return func(o interface{}) {
		ro, ok := o.(*responseOptions)
		if !ok {
			return
		}
		ro.withDiagnosticMessage = msg
	}
}

// WithMatchedDN provides an optional matched DN for the response.
func WithMatchedDN(dn string) Option {
	return func(o interface{}) {
		ro, ok := o.(*responseOptions)
		if !ok {
			return
		}
		ro.withMatchedDN = dn
	}
}

// WithResponseCode specifies the ldap response code. For a list of valid codes
// see:
// https://github.com/go-ldap/ldap/blob/13008e4c5260d08625b65eb1f172ae909152b751/v3/error.go#L11
func WithResponseCode(code int) Option {
	return func(o interface{}) {
		ro, ok := o.(*responseOptions)
		if !ok {
			return
		}
		ro.withResponseCode = &code
	}
}

// WithApplicationCode specifies the ldap application code. For a list of
// supported application codes see:
// https://github.com/jimlambrt/gldap/blob/8f171b8eb659c76019719382c4daf519dd1281e6/codes.go#L159
func WithApplicationCode(applicationCode int) Option {
	return func(o interface{}) {
		ro, ok := o.(*responseOptions)
		if !ok {
			return
		}
		ro.withApplicationCode = &applicationCode
	}
}

// WithAttributes specifies optional attributes for a response entry. The map
// is stored as-is (not copied).
func WithAttributes(attributes map[string][]string) Option {
	return func(o interface{}) {
		ro, ok := o.(*responseOptions)
		if !ok {
			return
		}
		ro.withAttributes = attributes
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"strings"
)
// routeOperation represents the ldap operation for a route. A route only
// matches requests whose routeOp equals its own (see the match methods
// below).
type routeOperation string
const (
// undefinedRouteOperation is an undefined operation (the zero value).
undefinedRouteOperation routeOperation = "" // nolint:unused
// bindRouteOperation is a route supporting the bind operation
bindRouteOperation routeOperation = "bind"
// searchRouteOperation is a route supporting the search operation
searchRouteOperation routeOperation = "search"
// extendedRouteOperation is a route supporting an extended operation
extendedRouteOperation routeOperation = "extendedOperation"
// modifyRouteOperation is a route supporting the modify operation
modifyRouteOperation routeOperation = "modify"
// addRouteOperation is a route supporting the add operation
addRouteOperation routeOperation = "add"
// deleteRouteOperation is a route supporting the delete operation
deleteRouteOperation routeOperation = "delete"
// unbindRouteOperation is a route supporting the unbind operation
unbindRouteOperation routeOperation = "unbind"
// defaultRouteOperation is a default route which is used when there are no routes
// defined for a particular operation
defaultRouteOperation routeOperation = "noRoute" // nolint:unused
)
// HandlerFunc defines a function for handling an LDAP request. The handler
// receives the request and a *ResponseWriter to write its response(s) with.
type HandlerFunc func(*ResponseWriter, *Request)
// route is the common interface of the route types in this file.
type route interface {
// match reports whether the route should handle req.
match(req *Request) bool
// handler returns the function to invoke for a matching request.
handler() HandlerFunc
// op returns the ldap operation the route serves.
op() routeOperation
}
// baseRoute holds what every route has in common: its handler, the
// operation it serves and an optional label. The concrete route types embed
// it.
type baseRoute struct {
	h       HandlerFunc
	routeOp routeOperation
	label   string
}

// handler returns the function registered to handle requests on this route.
func (r *baseRoute) handler() HandlerFunc { return r.h }

// op returns the ldap operation this route serves.
func (r *baseRoute) op() routeOperation { return r.routeOp }

// match is the fallback for routes that don't define their own match; it
// never matches.
func (r *baseRoute) match(_ *Request) bool { return false }
// searchRoute matches search requests. Each non-zero criterion narrows the
// match: basedn and filter are compared case-insensitively with the request's
// BaseDN and Filter, and scope must equal the request's Scope.
type searchRoute struct {
*baseRoute
basedn string
filter string
scope Scope
}
// simpleBindRoute matches simple bind requests whose AuthChoice equals
// authChoice; an empty authChoice never matches.
type simpleBindRoute struct {
*baseRoute
authChoice AuthChoice
}
// unbindRoute is a route for unbind requests.
type unbindRoute struct {
*baseRoute
}
// extendedRoute matches extended operation requests whose name equals
// extendedName.
type extendedRoute struct {
*baseRoute
extendedName ExtendedOperationName
}
// modifyRoute matches requests carrying a *ModifyMessage.
type modifyRoute struct {
*baseRoute
}
// addRoute matches requests carrying an *AddMessage.
type addRoute struct {
*baseRoute
}
// deleteRoute matches requests carrying a *DeleteMessage.
type deleteRoute struct {
*baseRoute
}
// match reports whether req is a delete request for this route's operation.
func (r *deleteRoute) match(req *Request) bool {
	if req == nil || req.routeOp != r.op() {
		return false
	}
	_, isDelete := req.message.(*DeleteMessage)
	return isDelete
}

// match reports whether req is an add request for this route's operation.
func (r *addRoute) match(req *Request) bool {
	if req == nil || req.routeOp != r.op() {
		return false
	}
	_, isAdd := req.message.(*AddMessage)
	return isAdd
}

// match reports whether req is a modify request for this route's operation.
func (r *modifyRoute) match(req *Request) bool {
	if req == nil || req.routeOp != r.op() {
		return false
	}
	_, isModify := req.message.(*ModifyMessage)
	return isModify
}
// match reports whether req is a simple bind request for this route's
// operation whose AuthChoice equals the route's (non-empty) authChoice.
func (r *simpleBindRoute) match(req *Request) bool {
	if req == nil || req.routeOp != r.op() {
		return false
	}
	bindMsg, isBind := req.message.(*SimpleBindMessage)
	return isBind && r.authChoice != "" && bindMsg.AuthChoice == r.authChoice
}
// match reports whether req is an extended operation request for this
// route's operation with the same extended operation name.
func (r *extendedRoute) match(req *Request) bool {
	if req == nil || req.routeOp != r.op() || req.extendedName != r.extendedName {
		return false
	}
	_, isExtended := req.message.(*ExtendedOperationMessage)
	return isExtended
}
// match reports whether req is a search request for this route's operation
// that satisfies every criterion the route sets. An empty basedn/filter or a
// zero scope means "any"; basedn and filter are compared case-insensitively.
func (r *searchRoute) match(req *Request) bool {
	if req == nil || req.routeOp != r.op() {
		return false
	}
	msg, isSearch := req.message.(*SearchMessage)
	switch {
	case !isSearch:
		return false
	case r.basedn != "" && !strings.EqualFold(msg.BaseDN, r.basedn):
		return false
	case r.filter != "" && !strings.EqualFold(msg.Filter, r.filter):
		return false
	case r.scope != 0 && msg.Scope != r.scope:
		return false
	default:
		// nothing eliminated it, so it's a match
		return true
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
// routeOptions collects the optional settings for a route, populated by
// WithLabel, WithBaseDN, WithFilter and WithScope.
type routeOptions struct {
	withLabel  string
	withBaseDN string
	withFilter string
	withScope  Scope
}

// routeDefaults returns zero-valued routeOptions: no label and no search
// criteria.
func routeDefaults() routeOptions {
	var defaults routeOptions
	return defaults
}

// getRouteOpts starts from routeDefaults and applies opt in order.
func getRouteOpts(opt ...Option) routeOptions {
	result := routeDefaults()
	applyOpts(&result, opt...)
	return result
}
// WithLabel specifies an optional label for the route.
func WithLabel(l string) Option {
	return func(o interface{}) {
		ro, ok := o.(*routeOptions)
		if !ok {
			return
		}
		ro.withLabel = l
	}
}

// WithBaseDN specifies an optional base DN to associate with a Search route.
func WithBaseDN(dn string) Option {
	return func(o interface{}) {
		ro, ok := o.(*routeOptions)
		if !ok {
			return
		}
		ro.withBaseDN = dn
	}
}

// WithFilter specifies an optional filter to associate with a Search route.
func WithFilter(filter string) Option {
	return func(o interface{}) {
		ro, ok := o.(*routeOptions)
		if !ok {
			return
		}
		ro.withFilter = filter
	}
}

// WithScope specifies an optional scope to associate with a Search route.
func WithScope(s Scope) Option {
	return func(o interface{}) {
		ro, ok := o.(*routeOptions)
		if !ok {
			return
		}
		ro.withScope = s
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/hashicorp/go-hclog"
)
// Server is an ldap server that you can add a mux (multiplexer) router to and
// then run it to accept and process requests.
type Server struct {
// mu is held when listener, listenerReady and router are assigned, and
// (read-locked) by Ready and Stop.
mu sync.RWMutex
logger hclog.Logger
// connWg counts the per-connection goroutines started by Run; Stop waits on it.
connWg sync.WaitGroup
listener net.Listener
// listenerReady is set by Run once it has attempted to listen.
listenerReady bool
router *Mux
tlsConfig *tls.Config
// readTimeout/writeTimeout, when non-zero, set a deadline on each accepted conn.
readTimeout time.Duration
writeTimeout time.Duration
// onCloseHandler, when set, is called with the conn ID after a conn is closed.
onCloseHandler OnCloseHandler
disablePanicRecovery bool
// shutdownCancel cancels shutdownCtx, which is passed to every new conn.
shutdownCancel context.CancelFunc
shutdownCtx context.Context
}
// NewServer creates a new ldap server
//
// Options supported:
//   - WithLogger allows you to pass a logger with whatever hclog.Level you wish, including hclog.Off to turn off all logging
//   - WithReadTimeout will set a read time out per connection
//   - WithWriteTimeout will set a write time out per connection
//   - WithOnClose will define a callback the server will call every time a connection is closed
//   - WithDisablePanicRecovery will disable recovering from panics while serving a connection
//
// Without WithLogger, an hclog logger named "Server-logger" at the Error
// level is used. The returned error is currently always nil.
func NewServer(opt ...Option) (*Server, error) {
	opts := getConfigOpts(opt...)

	logger := opts.withLogger
	if logger == nil {
		logger = hclog.New(&hclog.LoggerOptions{
			Name:  "Server-logger",
			Level: hclog.Error,
		})
	}

	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
	srv := &Server{
		router:               &Mux{}, // TODO: a better default router
		logger:               logger,
		shutdownCancel:       shutdownCancel,
		shutdownCtx:          shutdownCtx,
		readTimeout:          opts.withReadTimeout,
		writeTimeout:         opts.withWriteTimeout,
		disablePanicRecovery: opts.withDisablePanicRecovery,
		onCloseHandler:       opts.withOnClose,
	}
	return srv, nil
}
// last returns the index of the rightmost occurrence of b in s, or -1 when b
// doesn't occur in s.
func last(s string, b byte) int {
	return strings.LastIndexByte(s, b)
}
// validateAddr validates addr (a "host:port" string) and normalizes it for
// net.Listen. An empty host (":port") is allowed. IPv6 literals are returned
// wrapped in brackets whether or not the caller supplied them (e.g. both
// "::1:389" and "[::1]:389" become "[::1]:389").
//
// It returns an error wrapping ErrInvalidParameter when the port is missing,
// when a bracketed host isn't a valid IP, or when the host is neither a
// resolvable hostname nor a valid IP address.
func validateAddr(addr string) (string, error) {
	const op = "gldap.parseAddr"
	lastColon := last(addr, ':')
	if lastColon < 0 {
		return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
	}
	rawHost := addr[0:lastColon]
	rawPort := addr[lastColon+1:]
	if len(rawPort) == 0 {
		return "", fmt.Errorf("%s: missing port in addr %s : %w", op, addr, ErrInvalidParameter)
	}
	if len(rawHost) == 0 {
		return fmt.Sprintf(":%s", rawPort), nil
	}
	// ipv6 literal with proper brackets
	if rawHost[0] == '[' {
		// Expect the first ']' just before the last ':'.
		end := strings.IndexByte(rawHost, ']')
		if end < 0 {
			return "", fmt.Errorf("%s: missing ']' in ipv6 address %s : %w", op, addr, ErrInvalidParameter)
		}
		trimedIp := strings.Trim(rawHost, "[]")
		if net.ParseIP(trimedIp) == nil {
			return "", fmt.Errorf("%s: invalid ipv6 address %s : %w", op, rawHost, ErrInvalidParameter)
		}
		// ipv6 literal has enclosing brackets, and it's a valid ipv6 address, so we're good
		return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
	}
	// net.LookupHost returns IP literals unchanged (without a DNS query), so
	// this branch handles both resolvable hostnames and unbracketed IP
	// literals. net.JoinHostPort adds brackets when the host contains a colon,
	// i.e. for every ipv6 literal, not just "::1".
	if hostnames, _ := net.LookupHost(rawHost); len(hostnames) > 0 {
		return net.JoinHostPort(rawHost, rawPort), nil
	}
	if strings.Contains(rawHost, ":") {
		// ipv6 literal without proper brackets; validate the bare address
		// (a bracketed string never parses) and bracket it exactly once
		if net.ParseIP(rawHost) == nil {
			return "", fmt.Errorf("%s: invalid ipv6 address + port %s : %w", op, addr, ErrInvalidParameter)
		}
		return net.JoinHostPort(rawHost, rawPort), nil
	}
	// ipv4
	if net.ParseIP(rawHost) == nil {
		return "", fmt.Errorf("%s: invalid IP address %s : %w", op, rawHost, ErrInvalidParameter)
	}
	return fmt.Sprintf("%s:%s", rawHost, rawPort), nil
}
// Run will run the server which will listen and serve requests. It returns
// nil once the server is shut down (see Stop), or an error if validating the
// address, listening, accepting, or creating a connection fails. Each
// accepted connection is served on its own goroutine.
//
// Options supported: WithTLSConfig
func (s *Server) Run(addr string, opt ...Option) error {
	const op = "gldap.(Server).Run"
	opts := getConfigOpts(opt...)
	var err error
	addr, err = validateAddr(addr)
	if err != nil {
		return fmt.Errorf("%s: %w", op, err)
	}
	s.mu.Lock()
	s.listener, err = net.Listen("tcp", addr)
	// NOTE(review): listenerReady is set even when Listen fails, presumably so
	// callers polling Ready() don't wait forever; the error is returned below.
	s.listenerReady = true
	s.mu.Unlock()
	if err != nil {
		return fmt.Errorf("%s: unable to listen to addr %s: %w", op, addr, err)
	}
	if opts.withTLSConfig != nil {
		s.logger.Debug("setting up TLS listener", "op", op)
		s.tlsConfig = opts.withTLSConfig
		s.mu.Lock()
		s.listener = tls.NewListener(s.listener, s.tlsConfig)
		s.mu.Unlock()
	}
	s.logger.Info("listening", "op", op, "addr", s.listener.Addr())
	connID := 0
	for {
		connID++
		select {
		case <-s.shutdownCtx.Done():
			return nil
		default:
			// need a default to fall through to rest of loop...
		}
		c, err := s.listener.Accept()
		if err != nil {
			// Stop() closes the listener, which surfaces here as this error
			if strings.Contains(err.Error(), "use of closed network connection") {
				s.logger.Debug("accept on closed conn")
				return nil
			}
			return fmt.Errorf("%s: error accepting conn: %w", op, err)
		}
		s.logger.Debug("new connection accepted", "op", op, "conn", connID)
		conn, err := newConn(s.shutdownCtx, connID, c, s.logger, s.router)
		if err != nil {
			// close the accepted socket so it isn't leaked; any error from a
			// redundant close is irrelevant since we're already failing
			_ = c.Close()
			return fmt.Errorf("%s: unable to create in-memory conn: %w", op, err)
		}
		localConnID := connID
		s.connWg.Add(1)
		go func() {
			defer func() {
				// close the conn before signalling connWg so that once Stop()
				// returns from connWg.Wait() every conn has been closed
				if err := conn.close(); err != nil {
					s.logger.Error("error closing conn", "op", op, "conn", localConnID, "err", err)
					// we are intentionally not returning here; since we still
					// need to call the onCloseHandler if it's not nil
				}
				s.logger.Debug("connWg done", "op", op, "conn", localConnID)
				s.connWg.Done()
				// the handler runs after Done so a handler that calls Stop()
				// can't deadlock waiting on its own connection
				if s.onCloseHandler != nil {
					s.onCloseHandler(localConnID)
				}
			}()
			if !s.disablePanicRecovery {
				// catch and report panics - we don't want it to crash the server if
				// handling a single conn causes a panic
				defer func() {
					if r := recover(); r != nil {
						s.logger.Error("Caught panic while serving request", "op", op, "conn", localConnID, "conn/req", fmt.Sprintf("%+v: %+v", c, r))
					}
				}()
			}
			if s.readTimeout != 0 {
				if err := c.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
					s.logger.Error("unable to set read deadline", "op", op, "err", err.Error())
					return
				}
			}
			if s.writeTimeout != 0 {
				if err := c.SetWriteDeadline(time.Now().Add(s.writeTimeout)); err != nil {
					s.logger.Error("unable to set write deadline", "op", op, "err", err.Error())
					return
				}
			}
			if err := conn.serveRequests(); err != nil {
				s.logger.Error("error handling conn", "op", op, "conn", localConnID, "err", err.Error())
			}
		}()
	}
}
// Ready reports whether Run has reached the point of listening (it becomes
// true once Run has attempted net.Listen).
func (s *Server) Ready() bool {
	s.mu.RLock()
	ready := s.listenerReady
	s.mu.RUnlock()
	return ready
}
// Stop a running ldap server. It closes the listener (so Run's Accept fails
// and Run returns), cancels the shutdown context handed to every connection,
// and then blocks until all connection goroutines started by Run finish.
// A listener that's already closed is not treated as an error.
func (s *Server) Stop() error {
	const op = "gldap.(Server).Stop"
	s.mu.RLock()
	defer s.mu.RUnlock()
	s.logger.Debug("shutting down")

	if s.shutdownCancel == nil && s.listener == nil {
		s.logger.Debug("nothing to do for shutdown")
		return nil
	}

	if s.listener != nil {
		s.logger.Debug("closing listener")
		if closeErr := s.listener.Close(); closeErr != nil {
			if !strings.Contains(closeErr.Error(), "use of closed network connection") {
				return fmt.Errorf("%s: %w", op, closeErr)
			}
			s.logger.Debug("listener already closed")
		}
	}

	if s.shutdownCancel != nil {
		s.logger.Debug("shutdown cancel func")
		s.shutdownCancel()
	}

	s.logger.Debug("waiting on connections to close")
	s.connWg.Wait()
	s.logger.Debug("stopped")
	return nil
}
// Router sets the mux (multiplexer) router for matching inbound requests
// to handlers. A nil router is rejected with ErrInvalidParameter.
func (s *Server) Router(r *Mux) error {
	const op = "gldap.(Server).HandleRoutes"
	if r != nil {
		s.mu.Lock()
		s.router = r
		s.mu.Unlock()
		return nil
	}
	return fmt.Errorf("%s: missing router: %w", op, ErrInvalidParameter)
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"crypto/tls"
"time"
"github.com/hashicorp/go-hclog"
)
// configOptions collects the optional server settings, populated by
// WithTLSConfig, WithLogger, WithReadTimeout, WithWriteTimeout,
// WithDisablePanicRecovery and WithOnClose.
type configOptions struct {
	withTLSConfig            *tls.Config
	withLogger               hclog.Logger
	withReadTimeout          time.Duration
	withWriteTimeout         time.Duration
	withDisablePanicRecovery bool
	withOnClose              OnCloseHandler
}

// configDefaults returns zero-valued configOptions: every setting is unset
// until an Option changes it.
func configDefaults() configOptions {
	var defaults configOptions
	return defaults
}

// getConfigOpts gets the defaults and applies the opt overrides passed
// in, in order.
func getConfigOpts(opt ...Option) configOptions {
	cfg := configDefaults()
	applyOpts(&cfg, opt...)
	return cfg
}
// WithLogger provides the optional logger.
func WithLogger(l hclog.Logger) Option {
	return func(o interface{}) {
		cfg, ok := o.(*configOptions)
		if !ok {
			return
		}
		cfg.withLogger = l
	}
}

// WithTLSConfig provides an optional tls.Config.
func WithTLSConfig(tc *tls.Config) Option {
	return func(o interface{}) {
		cfg, ok := o.(*configOptions)
		if !ok {
			return
		}
		cfg.withTLSConfig = tc
	}
}

// WithReadTimeout will set a read time out per connection.
func WithReadTimeout(d time.Duration) Option {
	return func(o interface{}) {
		cfg, ok := o.(*configOptions)
		if !ok {
			return
		}
		cfg.withReadTimeout = d
	}
}

// WithWriteTimeout will set a write timeout per connection.
func WithWriteTimeout(d time.Duration) Option {
	return func(o interface{}) {
		cfg, ok := o.(*configOptions)
		if !ok {
			return
		}
		cfg.withWriteTimeout = d
	}
}

// WithDisablePanicRecovery will disable recovery from panics which occur when
// handling a request. This is helpful for debugging since you'll get the
// panic's callstack.
func WithDisablePanicRecovery() Option {
	return func(o interface{}) {
		cfg, ok := o.(*configOptions)
		if !ok {
			return
		}
		cfg.withDisablePanicRecovery = true
	}
}

// OnCloseHandler defines a function for an "on close" callback handler. See:
// NewServer(...) and WithOnClose(...) option for more information.
type OnCloseHandler func(connectionID int)

// WithOnClose defines an OnCloseHandler that the server will use as a
// callback every time a connection to the server is closed. This allows
// callers to clean up resources for closed connections (using their ID to
// determine which one to clean up).
func WithOnClose(handler OnCloseHandler) Option {
	return func(o interface{}) {
		cfg, ok := o.(*configOptions)
		if !ok {
			return
		}
		cfg.withOnClose = handler
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"bytes"
"encoding/binary"
"fmt"
)
// SIDBytes creates the binary form of a SID with the given revision and
// identifier authority and no sub-authorities. The layout is the revision
// byte, a sub-authority count byte (always 0 here), then the 48-bit
// identifier authority as three big-endian uint16s with the upper two zero.
func SIDBytes(revision uint8, identifierAuthority uint16) ([]byte, error) {
	const op = "gldap.SidBytes"
	authority := [3]uint16{0, 0, identifierAuthority}

	fields := []struct {
		order binary.ByteOrder
		value interface{}
		desc  string
	}{
		{binary.LittleEndian, revision, "revision"},
		{binary.LittleEndian, uint8(0), "subauthority count"},
		{binary.BigEndian, authority, "authority parts"},
	}

	var buf bytes.Buffer
	for _, f := range fields {
		if err := binary.Write(&buf, f.order, f.value); err != nil {
			return nil, fmt.Errorf("%s: unable to write %s: %w", op, f.desc, err)
		}
	}
	return buf.Bytes(), nil
}
// SIDBytesToString converts the binary form of a SID to its string form
// "S-<revision>-<authority>[-<subauthority>...]". Sub-authorities are read
// as little-endian uint32s; bytes after the last sub-authority are ignored.
// It returns an error if b is too short for the fields it declares.
func SIDBytesToString(b []byte) (string, error) {
	const op = "gldap.sidBytesToString"
	r := bytes.NewReader(b)
	var (
		revision          uint8
		subAuthorityCount uint8
		authorityParts    [3]uint16
	)
	if err := binary.Read(r, binary.LittleEndian, &revision); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading Revision: %w", op, b, err)
	}
	if err := binary.Read(r, binary.LittleEndian, &subAuthorityCount); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading SubAuthorityCount: %w", op, b, err)
	}
	if err := binary.Read(r, binary.BigEndian, &authorityParts); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading IdentifierAuthority: %w", op, b, err)
	}
	// the identifier authority is a 48-bit value split across three uint16s
	authority := uint64(authorityParts[0])<<32 | uint64(authorityParts[1])<<16 | uint64(authorityParts[2])

	subAuthorities := make([]uint32, subAuthorityCount)
	if err := binary.Read(r, binary.LittleEndian, &subAuthorities); err != nil {
		return "", fmt.Errorf("%s: SID %#v convert failed reading SubAuthority: %w", op, b, err)
	}

	var out bytes.Buffer
	fmt.Fprintf(&out, "S-%d-%d", revision, authority)
	for _, part := range subAuthorities {
		fmt.Fprintf(&out, "-%d", part)
	}
	return out.String(), nil
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package testdirectory
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff"
"github.com/go-ldap/ldap/v3"
"github.com/hashicorp/go-hclog"
"github.com/jimlambrt/gldap"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
// Defaults used by the test Directory.
const (
// DefaultUserAttr is the "username" attribute of the entry's DN and is
// typically either the cn in ActiveDirectory or uid in openLDAP (default:
// cn)
DefaultUserAttr = "cn"
// DefaultGroupAttr is the default for ClientConfig.GroupAttr
DefaultGroupAttr = "cn"
// DefaultUserDN defines a default base distinguished name to use when
// searching for users for the Directory
DefaultUserDN = "ou=people,dc=example,dc=org"
// DefaultGroupDN defines a default base distinguished name to use when
// searching for groups for the Directory
DefaultGroupDN = "ou=groups,dc=example,dc=org"
)
// Directory is a local ldap directory that supports test ldap capabilities
// which makes writing tests much easier.
//
// It's important to remember that the Directory is stateful (see any of its
// receiver functions that begin with Set*)
//
// Once you've started a Directory with Start(...), the following
// test ldap operations are supported:
//
// - Bind
// - StartTLS
// - Search
// - Modify
// - Add
//
// Making requests to the Directory is facilitated by:
// - Directory.Conn() returns a *ldap.Conn connected to the Directory (honors WithMTLS options from start)
// - Directory.Cert() returns the pem-encoded CA certificate used by the directory.
// - Directory.Port() returns the port the directory is listening on.
// - Directory.ClientCert() returns a client cert for mtls
// - Directory.ClientKey() returns a client private key for mtls
type Directory struct {
// t is the test the Directory was started for.
t TestingT
// s is the gldap server serving the directory.
s *gldap.Server
logger hclog.Logger
port int
host string
// useTLS is true unless the Directory was started with WithNoTLS.
useTLS bool
// client and server are the TLS configs for clients and the server.
client *tls.Config
server *tls.Config
// mu is taken by the handlers while reading controls.
mu sync.Mutex
users []*gldap.Entry
groups []*gldap.Entry
tokenGroups map[string][]*gldap.Entry // string == SID
// allowAnonymousBind lets a bind with an empty password succeed.
allowAnonymousBind bool
// controls, when non-nil, are attached to successful bind and search responses.
controls []gldap.Control
// userDN is the base distinguished name to use when searching for users
userDN string
// groupDN is the base distinguished name to use when searching for groups
groupDN string
}
// Start creates and starts a running Directory ldap server.
// Supported options: WithPort, WithMTLS, WithNoTLS, WithDefaults,
// WithLogger.
//
// When t supports Cleanup (e.g. *testing.T), the Directory will be shut down
// when the test and all its subtests have completed, via a function
// registered with t.Cleanup(...).
//
// Start returns once the server reports Ready().
func Start(t TestingT, opt ...Option) *Directory {
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	require := require.New(t)
	opts := getOpts(t, opt...)
	if opts.withPort == 0 {
		opts.withPort = FreePort(t)
	}
	d := &Directory{
		t:                  t,
		logger:             opts.withLogger,
		users:              opts.withDefaults.Users,
		groups:             opts.withDefaults.Groups,
		port:               opts.withPort,
		host:               opts.withHost,
		userDN:             opts.withDefaults.UserDN,
		groupDN:            opts.withDefaults.GroupDN,
		allowAnonymousBind: opts.withDefaults.AllowAnonymousBind,
	}
	var err error
	var srvOpts []gldap.Option
	if opts.withLogger != nil {
		srvOpts = append(srvOpts, gldap.WithLogger(opts.withLogger))
	}
	if opts.withDisablePanicRecovery {
		srvOpts = append(srvOpts, gldap.WithDisablePanicRecovery())
	}
	d.s, err = gldap.NewServer(srvOpts...)
	require.NoError(err)
	d.logger.Debug("base search DNs", "users", d.userDN, "groups", d.groupDN)
	mux, err := gldap.NewMux()
	require.NoError(err)
	require.NoError(mux.DefaultRoute(d.handleNotFound(t)))
	require.NoError(mux.Bind(d.handleBind(t)))
	require.NoError(mux.ExtendedOperation(d.handleStartTLS(t), gldap.ExtendedOperationStartTLS))
	require.NoError(mux.Search(d.handleSearchUsers(t), gldap.WithBaseDN(d.userDN), gldap.WithLabel("Search - Users")))
	require.NoError(mux.Search(d.handleSearchGroups(t), gldap.WithBaseDN(d.groupDN), gldap.WithLabel("Search - Groups")))
	require.NoError(mux.Search(d.handleSearchGeneric(t), gldap.WithLabel("Search - Generic")))
	require.NoError(mux.Modify(d.handleModify(t), gldap.WithLabel("Modify")))
	require.NoError(mux.Add(d.handleAdd(t), gldap.WithLabel("Add")))
	require.NoError(mux.Delete(d.handleDelete(t), gldap.WithLabel("Delete")))
	require.NoError(d.s.Router(mux))
	serverTLSConfig, clientTLSConfig := GetTLSConfig(t, opt...)
	d.client = clientTLSConfig
	d.server = serverTLSConfig
	var connOpts []gldap.Option
	if !opts.withNoTLS {
		d.useTLS = true
		connOpts = append(connOpts, gldap.WithTLSConfig(d.server))
		if opts.withMTLS {
			d.logger.Debug("using mTLS")
		} else {
			d.logger.Debug("using TLS")
		}
	} else {
		d.logger.Debug("not using TLS")
	}
	go func() {
		err := d.s.Run(fmt.Sprintf("%s:%d", opts.withHost, opts.withPort), connOpts...)
		if err != nil {
			d.logger.Error("Error during shutdown", "op", "testdirectory.Start", "err", err.Error())
		}
	}()
	if v, ok := interface{}(t).(CleanupT); ok {
		v.Cleanup(func() { _ = d.s.Stop() })
	}
	// Wait until the server is listening, otherwise callers would get a
	// connection error. Poll with a short sleep rather than spinning the CPU.
	for !d.s.Ready() {
		time.Sleep(time.Millisecond)
	}
	return d
}
// Stop will stop the Directory if it wasn't started with a TestingT that
// supports Cleanup. When it was (e.g. *testing.T), Start registered a cleanup
// that stops the server, so Stop() does nothing. Errors are logged, not
// returned.
func (d *Directory) Stop() {
	const op = "testdirectory.(Directory).Stop"
	if _, ok := interface{}(d.t).(CleanupT); ok {
		return
	}
	if err := d.s.Stop(); err != nil {
		// hclog messages aren't printf-formatted; the error goes under "err"
		d.logger.Error("error stopping directory", "op", op, "err", err)
	}
}
// handleBind returns the Directory's bind handler. It is ONLY supporting
// simple authentication (no SASL here!).
//
// The response defaults to ResultInvalidCredentials and is always written
// when the handler returns. It becomes ResultSuccess for an anonymous bind
// (empty password, when allowed) or when the bind name equals a user's DN
// and the password equals that user's first "password" attribute value; in
// the latter case any configured controls are attached.
func (d *Directory) handleBind(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleBind"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		resp := r.NewBindResponse(gldap.WithResponseCode(gldap.ResultInvalidCredentials))
		defer func() {
			_ = w.Write(resp)
		}()
		m, err := r.GetSimpleBindMessage()
		if err != nil {
			d.logger.Error("not a simple bind message", "op", op, "err", err)
			return
		}
		if m.AuthChoice != gldap.SimpleAuthChoice {
			// if it's not a simple auth request, then the bind failed...
			return
		}
		if m.Password == "" && d.allowAnonymousBind {
			resp.SetResultCode(gldap.ResultSuccess)
			return
		}
		for _, u := range d.users {
			d.logger.Debug("user", "u.DN", u.DN, "m.UserName", m.UserName)
			if u.DN != m.UserName {
				continue
			}
			d.logger.Debug("found bind user", "op", op, "DN", u.DN)
			values := u.GetAttributeValues("password")
			if len(values) > 0 && string(m.Password) == values[0] {
				resp.SetResultCode(gldap.ResultSuccess)
				// check d.controls only while holding d.mu; checking it before
				// taking the lock races with whatever writes it (presumably a
				// Set* method holding d.mu — confirm)
				d.mu.Lock()
				if d.controls != nil {
					resp.SetControls(d.controls...)
				}
				d.mu.Unlock()
				return
			}
		}
		// bind failed... the deferred write sends ResultInvalidCredentials
	}
}
// handleNotFound returns the Directory's default-route handler. It answers
// every request it receives with a general response whose diagnostic message
// is "intentionally not handled"; write errors are ignored.
func (d *Directory) handleNotFound(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleNotFound"
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	handler := func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		_ = w.Write(r.NewResponse(gldap.WithDiagnosticMessage("intentionally not handled")))
	}
	return handler
}
// handleStartTLS returns the Directory's StartTLS extended operation handler.
// It first writes a successful extended response naming the StartTLS
// operation, then upgrades the connection with the Directory's server TLS
// config. If the handshake fails, it writes the response again with
// ResultOperationsError and a diagnostic message describing the failure.
func (d *Directory) handleStartTLS(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleStartTLS"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewExtendedResponse(gldap.WithResponseCode(gldap.ResultSuccess))
		res.SetResponseName(gldap.ExtendedOperationStartTLS)
		// hclog messages aren't printf-formatted, so errors go under "err"
		if err := w.Write(res); err != nil {
			d.logger.Error("error writing response", "op", op, "err", err)
			return
		}
		if err := r.StartTLS(d.server); err != nil {
			d.logger.Error("StartTLS Handshake error", "op", op, "err", err)
			res.SetDiagnosticMessage(fmt.Sprintf("StartTLS Handshake error : \"%s\"", err.Error()))
			res.SetResultCode(gldap.ResultOperationsError)
			if err := w.Write(res); err != nil {
				d.logger.Error("error writing response", "op", op, "err", err)
			}
			return
		}
		d.logger.Debug("StartTLS OK", "op", op)
	}
}
// handleSearchGeneric returns the Directory's generic search handler (Start
// registers it without a base DN, after the users and groups search routes).
//
// The handler resolves a search as follows:
//   - with tokenGroups configured, a base DN of the form "<SID=...>" returns
//     the groups stored for that SID;
//   - otherwise the request filter (or, when the base DN contains the users
//     DN, a filter built from the base DN) is matched against every user and
//     group DN, and each distinct matching entry is returned.
//
// The search done response defaults to ResultNoSuchObject, is always written
// when the handler returns, and is set to ResultSuccess after a SID lookup or
// when at least one entry matched.
func (d *Directory) handleSearchGeneric(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleSearchGeneric"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		// hclog messages aren't printf-formatted, so errors go under "err"
		defer func() {
			if err := w.Write(res); err != nil {
				d.logger.Error("error writing response", "op", op, "err", err)
			}
		}()
		m, err := r.GetSearchMessage()
		if err != nil {
			d.logger.Error("not a search message", "op", op, "err", err)
			return
		}
		d.logSearchRequest(m)
		filter := m.Filter
		// if our search base is the base userDN, we're searching for a single
		// user, so adjust the filter to match user's entries
		if strings.Contains(string(m.BaseDN), d.userDN) {
			filter = fmt.Sprintf("(%s)", m.BaseDN)
			d.logger.Debug("new filter", "op", op, "value", filter)
			for _, a := range m.Attributes {
				d.logger.Debug("attr", "op", op, "value", a)
				if a == "tokenGroups" {
					d.logger.Debug("asking for groups", "op", op)
				}
			}
		}
		var foundEntries int
		// if our search base is a SID, then we're searching for tokenGroups
		if len(d.tokenGroups) > 0 && strings.HasPrefix(string(m.BaseDN), "<SID=") {
			sid := strings.TrimSuffix(strings.TrimPrefix(string(m.BaseDN), "<SID="), ">")
			for _, g := range d.tokenGroups[sid] {
				d.logger.Debug("found tokenGroup", "op", op, "group DN", g.DN)
				result := r.NewSearchResponseEntry(g.DN)
				for _, attr := range g.Attributes {
					result.AddAttribute(attr.Name, attr.Values)
				}
				foundEntries++
				if err := w.Write(result); err != nil {
					d.logger.Error("error writing result", "op", op, "err", err)
					return
				}
			}
			d.logger.Debug("found entries", "op", op, "count", foundEntries)
			res.SetResultCode(gldap.ResultSuccess)
			return
		}
		d.logger.Debug("filter", "op", op, "value", filter)
		var entries []*gldap.Entry
		for _, e := range d.users {
			if ok, _ := match(filter, e.DN); ok {
				entries = append(entries, e)
			}
		}
		for _, e := range d.groups {
			// skip groups that are also in the user results
			if ok, _ := match(filter, e.DN); ok && !slices.Contains(entries, e) {
				entries = append(entries, e)
			}
		}
		foundEntries = len(entries)
		if foundEntries == 0 {
			return
		}
		d.logger.Debug("found entries", "op", op, "count", foundEntries)
		for _, e := range entries {
			result := r.NewSearchResponseEntry(e.DN)
			for _, attr := range e.Attributes {
				result.AddAttribute(attr.Name, attr.Values)
			}
			if err := w.Write(result); err != nil {
				d.logger.Error("error writing result", "op", op, "err", err)
				return
			}
		}
		// check d.controls only while holding d.mu (checking before locking races)
		d.mu.Lock()
		if d.controls != nil {
			res.SetControls(d.controls...)
		}
		d.mu.Unlock()
		res.SetResultCode(gldap.ResultSuccess)
	}
}
// handleSearchGroups returns a search handler for groups. It returns the
// groups that have a member matching the filter (via findMembers), followed
// by any other groups whose own DN matches the filter.
//
// The search-done response defaults to ResultNoSuchObject and is written by a
// deferred call. It becomes ResultSuccess, carrying the directory's current
// controls, only when at least one entry was written.
func (d *Directory) handleSearchGroups(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleSearchGroups"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		defer func() {
			if err := w.Write(res); err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
			}
		}()
		m, err := r.GetSearchMessage()
		if err != nil {
			d.logger.Error("not a search message: %s", "op", op, "err", err)
			return
		}
		d.logSearchRequest(m)
		_, entries := d.findMembers(m.Filter)
		for _, e := range d.groups {
			if ok, _ := match(m.Filter, e.DN); !ok {
				continue
			}
			if slices.Contains(entries, e) {
				continue
			}
			entries = append(entries, e)
		}
		if len(entries) == 0 {
			return
		}
		for _, e := range entries {
			result := r.NewSearchResponseEntry(e.DN)
			for _, attr := range e.Attributes {
				result.AddAttribute(attr.Name, attr.Values)
			}
			if err := w.Write(result); err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}
		// The count is exactly the number of entries returned. It is not
		// incremented per write.
		d.logger.Debug("found entries", "op", op, "count", len(entries))
		// Read the controls under the same mutex SetControls uses.
		d.mu.Lock()
		controls := d.controls
		d.mu.Unlock()
		if controls != nil {
			res.SetControls(controls...)
		}
		res.SetResultCode(gldap.ResultSuccess)
	}
}
// handleSearchUsers returns a search handler that returns every user entry
// whose DN matches the request filter.
//
// The search-done response defaults to ResultNoSuchObject and is written by a
// deferred call. It becomes ResultSuccess, carrying the directory's current
// controls, once all matching entries have been written.
func (d *Directory) handleSearchUsers(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleSearchUsers"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		defer func() {
			if err := w.Write(res); err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
			}
		}()
		m, err := r.GetSearchMessage()
		if err != nil {
			d.logger.Error("not a search message: %s", "op", op, "err", err)
			return
		}
		d.logSearchRequest(m)
		_, _, entries := find(d.t, m.Filter, d.users)
		if len(entries) == 0 {
			return
		}
		for _, e := range entries {
			result := r.NewSearchResponseEntry(e.DN)
			for _, attr := range e.Attributes {
				result.AddAttribute(attr.Name, attr.Values)
			}
			if err := w.Write(result); err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}
		d.logger.Debug("found entries", "op", op, "count", len(entries))
		// Read the controls under the same mutex SetControls uses. Diagnostics
		// go through d.logger, never stdout.
		d.mu.Lock()
		controls := d.controls
		d.mu.Unlock()
		if controls != nil {
			res.SetControls(controls...)
		}
		res.SetResultCode(gldap.ResultSuccess)
	}
}
// handleModify returns a handler for modify requests.
//
// The target DN is looked up among the users first and then among the groups.
// No match leaves the default ResultNoSuchObject. More than one match yields
// ResultInappropriateMatching. Otherwise each change is applied to the entry
// in place, under the directory mutex:
//   - add: append values to an existing attribute, or create it;
//   - delete: remove the whole attribute if present;
//   - replace: replace the attribute, creating it when missing. A replace
//     with no values removes the attribute (RFC 4511 section 4.6).
func (d *Directory) handleModify(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleModify"
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		res := r.NewModifyResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
		defer func() {
			err := w.Write(res)
			if err != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", err)
				return
			}
		}()
		m, err := r.GetModifyMessage()
		if err != nil {
			d.logger.Error("not a modify message: %s", "op", op, "err", err)
			return
		}
		d.logger.Info("modify request", "dn", m.DN)
		// match() only considers parenthesized filter terms, so the DN must be
		// wrapped for the group lookup as well as the user lookup.
		dnFilter := fmt.Sprintf("(%s)", m.DN)
		_, _, entries := find(d.t, dnFilter, d.users)
		if len(entries) == 0 {
			_, _, entries = find(d.t, dnFilter, d.groups)
		}
		if len(entries) == 0 {
			return
		}
		if len(entries) > 1 {
			res.SetResultCode(gldap.ResultInappropriateMatching)
			res.SetDiagnosticMessage(fmt.Sprintf("more than one match: %d entries", len(entries)))
			return
		}
		d.mu.Lock()
		defer d.mu.Unlock()
		e := entries[0]
		if e.Attributes == nil {
			e.Attributes = []*gldap.EntryAttribute{}
		}
		res.SetMatchedDN(e.DN)
		for _, chg := range m.Changes {
			// find specific attr (the last one with a matching name wins)
			var foundAttr *gldap.EntryAttribute
			var foundAt int
			for i, a := range e.Attributes {
				if a.Name == chg.Modification.Type {
					foundAttr = a
					foundAt = i
				}
			}
			// then apply operation
			switch chg.Operation {
			case gldap.AddAttribute:
				if foundAttr != nil {
					foundAttr.AddValue(chg.Modification.Vals...)
				} else {
					e.Attributes = append(e.Attributes, gldap.NewEntryAttribute(chg.Modification.Type, chg.Modification.Vals))
				}
			case gldap.DeleteAttribute:
				if foundAttr != nil {
					// slice out the deleted attribute
					copy(e.Attributes[foundAt:], e.Attributes[foundAt+1:])
					e.Attributes = e.Attributes[:len(e.Attributes)-1]
				}
			case gldap.ReplaceAttribute:
				// The entry's slice must be updated. Reassigning a local
				// pointer would leave the entry unchanged.
				switch {
				case foundAttr != nil && len(chg.Modification.Vals) == 0:
					copy(e.Attributes[foundAt:], e.Attributes[foundAt+1:])
					e.Attributes = e.Attributes[:len(e.Attributes)-1]
				case foundAttr != nil:
					e.Attributes[foundAt] = gldap.NewEntryAttribute(chg.Modification.Type, chg.Modification.Vals)
				case len(chg.Modification.Vals) > 0:
					e.Attributes = append(e.Attributes, gldap.NewEntryAttribute(chg.Modification.Type, chg.Modification.Vals))
				}
			}
		}
		res.SetResultCode(gldap.ResultSuccess)
	}
}
// handleAdd returns a handler for add requests. A DN that already matches a
// user yields ResultEntryAlreadyExists. Otherwise a new entry built from the
// request attributes is appended to the users under the directory mutex.
// The response defaults to ResultOperationsError and is written by a deferred
// call.
func (d *Directory) handleAdd(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleAdd"
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		resp := r.NewResponse(gldap.WithApplicationCode(gldap.ApplicationAddResponse), gldap.WithResponseCode(gldap.ResultOperationsError))
		defer func() {
			if writeErr := w.Write(resp); writeErr != nil {
				d.logger.Error("error writing result: %s", "op", op, "err", writeErr)
			}
		}()
		msg, err := r.GetAddMessage()
		if err != nil {
			d.logger.Error("not an add message: %s", "op", op, "err", err)
			return
		}
		d.logger.Info("add request", "dn", msg.DN)
		exists, _, _ := find(d.t, fmt.Sprintf("(%s)", msg.DN), d.users)
		if exists {
			resp.SetResultCode(gldap.ResultEntryAlreadyExists)
			resp.SetDiagnosticMessage(fmt.Sprintf("entry exists for DN: %s", msg.DN))
			return
		}
		attrs := make(map[string][]string, len(msg.Attributes))
		for _, attr := range msg.Attributes {
			attrs[attr.Type] = attr.Vals
		}
		entry := gldap.NewEntry(msg.DN, attrs)
		d.mu.Lock()
		defer d.mu.Unlock()
		d.users = append(d.users, entry)
		resp.SetResultCode(gldap.ResultSuccess)
	}
}
// handleDelete returns a handler for delete requests. The DN is looked up
// among the users first and then among the groups, and the first collection
// with a match decides the outcome. A single match is removed under the
// directory mutex. More than one match yields ResultInappropriateMatching.
// No match leaves the default ResultNoSuchObject.
func (d *Directory) handleDelete(t TestingT) func(w *gldap.ResponseWriter, r *gldap.Request) {
	const op = "testdirectory.(Directory).handleDelete"
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	return func(w *gldap.ResponseWriter, r *gldap.Request) {
		d.logger.Debug(op)
		resp := r.NewResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject), gldap.WithApplicationCode(gldap.ApplicationDelResponse))
		defer func() {
			if writeErr := w.Write(resp); writeErr != nil {
				d.logger.Error("error writing response: %s", "op", op, "err", writeErr)
			}
		}()
		msg, err := r.GetDeleteMessage()
		if err != nil {
			d.logger.Error("not a delete message: %s", "op", op, "err", err)
			return
		}
		d.logger.Info("delete request", "dn", msg.DN)
		dnFilter := fmt.Sprintf("(%s)", msg.DN)
		for _, collection := range []*[]*gldap.Entry{&d.users, &d.groups} {
			_, idx, _ := find(d.t, dnFilter, *collection)
			switch {
			case len(idx) == 0:
				continue
			case len(idx) > 1:
				resp.SetResultCode(gldap.ResultInappropriateMatching)
				resp.SetDiagnosticMessage(fmt.Sprintf("more than one match: %d entries", len(idx)))
				return
			}
			d.mu.Lock()
			*collection = append((*collection)[:idx[0]], (*collection)[idx[0]+1:]...)
			d.mu.Unlock()
			resp.SetResultCode(gldap.ResultSuccess)
			return
		}
	}
}
// findMembers returns the groups with at least one "member" value for which
// "member=<value>" matches filter. A group is appended once per matching
// member. With withFirst it stops at the first hit. The bool reports whether
// anything matched, and the slice is nil when nothing did.
func (d *Directory) findMembers(filter string, opt ...Option) (bool, []*gldap.Entry) {
	opts := getOpts(d.t, opt...)
	var groups []*gldap.Entry
	for _, g := range d.groups {
		for _, member := range g.GetAttributeValues("member") {
			matched, _ := match(filter, "member="+member)
			if !matched {
				continue
			}
			groups = append(groups, g)
			if opts.withFirst {
				return true, groups
			}
		}
	}
	return len(groups) > 0, groups
}
// find returns the entries whose DN matches filter, together with their
// indexes in entries. With withFirst only the first match is returned.
// When nothing matches it returns false and nil slices.
func find(t TestingT, filter string, entries []*gldap.Entry, opt ...Option) (bool, []int, []*gldap.Entry) {
	opts := getOpts(t, opt...)
	var (
		found   []*gldap.Entry
		indexes []int
	)
	for i, entry := range entries {
		if matched, _ := match(filter, entry.DN); !matched {
			continue
		}
		found = append(found, entry)
		if opts.withFirst {
			return true, []int{i}, found
		}
		indexes = append(indexes, i)
	}
	if len(found) == 0 {
		return false, nil, nil
	}
	return true, indexes, found
}
// filterTermRe matches each parenthesized term of a search filter. It is
// compiled once at package scope because match runs once per entry inside
// search loops.
var filterTermRe = regexp.MustCompile(`\((.*?)\)`)

// match reports whether attr contains any term of filter. Each parenthesized
// term is stripped of '*' wildcards, surrounding "|", "(" and ")" characters
// and whitespace, and then tested as a plain substring of attr. A filter with
// no parenthesized term never matches. The error is always nil.
func match(filter string, attr string) (bool, error) {
	// TODO: make this actually do something more reasonable with the search
	// request filter
	for _, element := range filterTermRe.FindAllString(filter, -1) {
		element = strings.ReplaceAll(element, "*", "")
		element = strings.Trim(element, "|(")
		element = strings.Trim(element, "(")
		element = strings.Trim(element, ")")
		element = strings.TrimSpace(element)
		if strings.Contains(attr, element) {
			return true, nil
		}
	}
	return false, nil
}
// Conn returns an *ldap.Conn that's connected (using whatever tls.Config is
// appropriate for the directory) and ready to send requests to the directory.
// A failed dial is retried once per second, up to five retries, before the
// test is failed via require.
func (d *Directory) Conn() *ldap.Conn {
	if v, ok := interface{}(d.t).(HelperT); ok {
		v.Helper()
	}
	require := require.New(d.t)
	var conn *ldap.Conn
	// retriesLeft bounds how many dial failures are retried. Returning the
	// plain error lets backoff.Retry try again. Once retries are exhausted,
	// the error is wrapped as permanent so Retry stops.
	retriesLeft := 5
	retryErrFn := func(e error) error {
		if retriesLeft > 0 {
			retriesLeft--
			return e
		}
		return backoff.Permanent(e)
	}
	err := backoff.Retry(func() error {
		var connErr error
		if d.useTLS {
			if conn, connErr = ldap.DialURL(fmt.Sprintf("ldaps://%s:%d", d.Host(), d.Port()), ldap.DialWithTLSConfig(d.client)); connErr != nil {
				return retryErrFn(connErr)
			}
			return nil
		}
		if conn, connErr = ldap.DialURL(fmt.Sprintf("ldap://%s:%d", d.Host(), d.Port())); connErr != nil {
			return retryErrFn(connErr)
		}
		return nil
	}, backoff.NewConstantBackOff(1*time.Second))
	require.NoError(err)
	return conn
}
// Cert returns the PEM-encoded leaf certificate from the directory's server
// TLS config. It fails the test unless that config holds exactly one
// certificate with exactly one DER block.
func (d *Directory) Cert() string {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	req := require.New(d.t)
	req.NotNil(d.server)
	req.Len(d.server.Certificates, 1)
	leaf := d.server.Certificates[0]
	req.NotNil(leaf)
	req.Len(leaf.Certificate, 1)
	block := &pem.Block{Type: "CERTIFICATE", Bytes: leaf.Certificate[0]}
	var out bytes.Buffer
	req.NoError(pem.Encode(&out, block))
	return out.String()
}
// Port reports the TCP port the directory is listening on.
func (d *Directory) Port() int {
	if helper, isHelper := interface{}(d.t).(HelperT); isHelper {
		helper.Helper()
	}
	port := d.port
	return port
}
// Host reports the host name or address the directory is listening on.
func (d *Directory) Host() string {
	if helper, isHelper := interface{}(d.t).(HelperT); isHelper {
		helper.Helper()
	}
	host := d.host
	return host
}
// ClientCert returns the PEM-encoded certificate from the directory's client
// TLS config, suitable for a client doing mTLS. It fails the test unless that
// config holds exactly one certificate with exactly one DER block.
func (d *Directory) ClientCert() string {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	req := require.New(d.t)
	req.NotNil(d.client)
	req.Len(d.client.Certificates, 1)
	leaf := d.client.Certificates[0]
	req.NotNil(leaf)
	req.Len(leaf.Certificate, 1)
	block := &pem.Block{Type: "CERTIFICATE", Bytes: leaf.Certificate[0]}
	var out bytes.Buffer
	req.NoError(pem.Encode(&out, block))
	return out.String()
}
// ClientKey returns the private key of the directory's client certificate as
// a PKCS#8 "PRIVATE KEY" PEM block, suitable for a client doing mTLS. It
// fails the test unless the client TLS config holds exactly one certificate.
func (d *Directory) ClientKey() string {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	req := require.New(d.t)
	req.NotNil(d.client)
	req.Len(d.client.Certificates, 1)
	der, err := x509.MarshalPKCS8PrivateKey(d.client.Certificates[0].PrivateKey)
	req.NoError(err)
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der})
	req.NotNil(keyPEM)
	return string(keyPEM)
}
// Controls returns all the current bind controls for the Directory. The read
// takes the same mutex that SetControls and the request handlers use, so it
// does not race with concurrent updates.
func (d *Directory) Controls() []gldap.Control {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.controls
}
// SetControls replaces the directory's bind controls while holding the
// directory mutex.
func (d *Directory) SetControls(controls ...gldap.Control) {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	d.mu.Lock()
	d.controls = controls
	d.mu.Unlock()
}
// Users returns all the current user entries in the Directory. The read
// takes the same mutex that SetUsers and the add/delete handlers use, so it
// does not race with concurrent updates.
func (d *Directory) Users() []*gldap.Entry {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.users
}
// SetUsers replaces the directory's user entries while holding the directory
// mutex.
func (d *Directory) SetUsers(users ...*gldap.Entry) {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	d.mu.Lock()
	d.users = users
	d.mu.Unlock()
}
// Groups returns all the current group entries in the Directory. The read
// takes the same mutex that SetGroups and the delete handler use, so it does
// not race with concurrent updates.
func (d *Directory) Groups() []*gldap.Entry {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.groups
}
// SetGroups replaces the directory's group entries while holding the
// directory mutex.
func (d *Directory) SetGroups(groups ...*gldap.Entry) {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	d.mu.Lock()
	d.groups = groups
	d.mu.Unlock()
}
// SetTokenGroups replaces the directory's tokenGroup entries (keyed by SID)
// while holding the directory mutex.
func (d *Directory) SetTokenGroups(tokenGroups map[string][]*gldap.Entry) {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	d.mu.Lock()
	d.tokenGroups = tokenGroups
	d.mu.Unlock()
}
// TokenGroups returns the tokenGroup entries, keyed by SID. The read takes
// the same mutex that SetTokenGroups uses, so it does not race with
// concurrent updates.
func (d *Directory) TokenGroups() map[string][]*gldap.Entry {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.tokenGroups
}
// AllowAnonymousBind returns the allow anon bind setting. The read takes the
// same mutex that SetAllowAnonymousBind uses, so it does not race with
// concurrent updates.
func (d *Directory) AllowAnonymousBind() bool {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.allowAnonymousBind
}
// SetAllowAnonymousBind turns anonymous binds on or off while holding the
// directory mutex.
func (d *Directory) SetAllowAnonymousBind(enabled bool) {
	if h, ok := interface{}(d.t).(HelperT); ok {
		h.Helper()
	}
	d.mu.Lock()
	d.allowAnonymousBind = enabled
	d.mu.Unlock()
}
// logSearchRequest logs the base DN, scope, filter and requested attributes
// of a search at info level.
func (d *Directory) logSearchRequest(m *gldap.SearchMessage) {
	fields := []interface{}{
		"baseDN", m.BaseDN,
		"scope", m.Scope,
		"filter", m.Filter,
		"attributes", m.Attributes,
	}
	d.logger.Info("search request", fields...)
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package testdirectory
import (
"strings"
"github.com/hashicorp/go-hclog"
"github.com/jimlambrt/gldap"
)
// Option defines a common functional options type which can be used in a
// variadic parameter pattern. Each Option receives a pointer to an options
// struct and type-asserts it. The options in this package do nothing when
// handed any other type.
type Option func(interface{})
// getOpts builds the default options and then applies each override in
// order, returning the result.
func getOpts(t TestingT, opt ...Option) options {
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	result := defaults(t)
	applyOpts(&result, opt...)
	return result
}
// applyOpts calls every non-nil Option, in order, with opts. opts is
// expected to be a pointer to an options struct holding the defaults.
func applyOpts(opts interface{}, opt ...Option) {
	for i := range opt {
		if fn := opt[i]; fn != nil {
			fn(opts)
		}
	}
}
// options are the set of available options for test functions
type options struct {
	// withPort is the listening port. 0 requests a random port (see WithPort).
	withPort int
	// withHost is the listening host. defaults sets it to "localhost".
	withHost string
	// withLogger is the logger handed to the directory (see WithLogger)
	withLogger hclog.Logger
	// withNoTLS disables TLS for the directory (see WithNoTLS)
	withNoTLS bool
	// withMTLS enables mutual TLS (see WithMTLS and GetTLSConfig)
	withMTLS bool
	// withDisablePanicRecovery is set by WithDisablePanicRecovery
	withDisablePanicRecovery bool
	// withDefaults holds the composable directory defaults (see WithDefaults)
	withDefaults *Defaults
	// withMembersOf becomes the memberOf attribute of users built by NewUsers
	withMembersOf []string
	// withTokenGroupSIDs become tokenGroups values of users built by NewUsers
	withTokenGroupSIDs [][]byte
	// withFirst makes find and findMembers stop at the first match
	withFirst bool
}
// defaults returns the baseline options: an error-level hclog logger, host
// "localhost", and Defaults populated from the package's Default* constants.
func defaults(t TestingT) options {
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	logger := hclog.New(&hclog.LoggerOptions{
		Name:  "testdirectory-default-logger",
		Level: hclog.Error,
	})
	var opts options
	opts.withLogger = logger
	opts.withHost = "localhost"
	opts.withDefaults = &Defaults{
		UserAttr:  DefaultUserAttr,
		GroupAttr: DefaultGroupAttr,
		UserDN:    DefaultUserDN,
		GroupDN:   DefaultGroupDN,
	}
	return opts
}
// Defaults define a type for composing all the defaults for Directory.Start(...)
type Defaults struct {
	// UserAttr is the RDN attribute name used to build user DNs, as in
	// "<UserAttr>=<name>,<UserDN>" (see NewUsers and NewGroup)
	UserAttr string
	// GroupAttr is the RDN attribute name used to build group DNs, as in
	// "<GroupAttr>=<name>,<GroupDN>" (see NewGroup and NewMemberOf)
	GroupAttr string
	// Users configures the user entries which are empty by default
	Users []*gldap.Entry
	// Groups configures the group entries which are empty by default
	Groups []*gldap.Entry
	// TokenGroups configures the tokenGroup entries which are empty by default
	TokenGroups map[string][]*gldap.Entry
	// UserDN is the base distinguished name to use when searching for users
	// which is "ou=people,dc=example,dc=org" by default
	UserDN string
	// GroupDN is the base distinguished name to use when searching for groups
	// which is "ou=groups,dc=example,dc=org" by default
	GroupDN string
	// AllowAnonymousBind determines if anon binds are allowed
	AllowAnonymousBind bool
	// UPNDomain is the userPrincipalName domain, which enables a
	// userPrincipalDomain login with [username]@UPNDomain (optional)
	UPNDomain string
}
// WithDefaults provides an option to provide a set of defaults to
// Directory.Start(...) which make it much more composable. Only fields set in
// defaults override the current values: non-empty strings, non-nil
// Users/Groups, a non-empty TokenGroups map, and AllowAnonymousBind when true.
// A nil defaults is ignored.
func WithDefaults(t TestingT, defaults *Defaults) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok || defaults == nil {
			return
		}
		dst := opts.withDefaults
		if defaults.AllowAnonymousBind {
			dst.AllowAnonymousBind = true
		}
		if defaults.Users != nil {
			dst.Users = defaults.Users
		}
		if defaults.Groups != nil {
			dst.Groups = defaults.Groups
		}
		if defaults.UserDN != "" {
			dst.UserDN = defaults.UserDN
		}
		if defaults.GroupDN != "" {
			dst.GroupDN = defaults.GroupDN
		}
		if len(defaults.TokenGroups) > 0 {
			dst.TokenGroups = defaults.TokenGroups
		}
		if defaults.UserAttr != "" {
			dst.UserAttr = defaults.UserAttr
		}
		if defaults.GroupAttr != "" {
			dst.GroupAttr = defaults.GroupAttr
		}
		if defaults.UPNDomain != "" {
			dst.UPNDomain = defaults.UPNDomain
		}
	}
}
// WithMTLS provides the option to use mTLS for the directory.
func WithMTLS(t TestingT) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withMTLS = true
	}
}
// WithNoTLS provides the option to not use TLS for the directory.
func WithNoTLS(t TestingT) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withNoTLS = true
	}
}
// WithLogger provides the optional logger for the directory. The value is
// stored as given.
func WithLogger(t TestingT, l hclog.Logger) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withLogger = l
	}
}
// WithPort provides an optional port for the directory. 0 causes a
// started server with a random port. Any other value returns a started server
// on that port.
func WithPort(t TestingT, port int) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withPort = port
	}
}
// WithHost provides an optional hostname for the directory. Surrounding
// whitespace is trimmed.
func WithHost(t TestingT, host string) Option {
	trimmed := strings.TrimSpace(host)
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withHost = trimmed
	}
}
// withFirst provides the option to only find the first match.
func withFirst(t TestingT) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withFirst = true
	}
}
// WithMembersOf specifies optional memberOf attributes for user
// entries
func WithMembersOf(t TestingT, membersOf ...string) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withMembersOf = membersOf
	}
}
// WithTokenGroups specifies optional test tokenGroups SID attributes for user
// entries. NewUsers stores each SID, converted to a string, as a tokenGroups
// value.
func WithTokenGroups(t TestingT, tokenGroupSID ...[]byte) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withTokenGroupSIDs = tokenGroupSID
	}
}
// WithDisablePanicRecovery records whether panic recovery should be disabled
// for the directory. This option only stores the flag. It is consumed by the
// directory start-up code.
func WithDisablePanicRecovery(t TestingT, disable bool) Option {
	return func(o interface{}) {
		opts, ok := o.(*options)
		if !ok {
			return
		}
		opts.withDisablePanicRecovery = disable
	}
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package testdirectory
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"time"
"github.com/jimlambrt/gldap"
"github.com/stretchr/testify/require"
)
// NewMemberOf creates memberOf attributes which can be assigned to user
// entries: one "<GroupAttr>=<name>,<GroupDN>" value per group name.
// Supported Options: WithDefaults
func NewMemberOf(t TestingT, groupNames []string, opt ...Option) []string {
	cfg := getOpts(t, opt...).withDefaults
	memberOf := make([]string, len(groupNames))
	for i, name := range groupNames {
		memberOf[i] = fmt.Sprintf("%s=%s,%s", cfg.GroupAttr, name, cfg.GroupDN)
	}
	return memberOf
}
// NewUsers creates user entries. Options supported: WithDefaults, WithMembersOf
// and WithTokenGroups.
//
// Every user gets "name", "email" (<name>@example.com) and "password"
// attributes, plus "memberOf" and "tokenGroups" when those options are set.
// The DN is "userPrincipalName=<name>@<UPNDomain>,<UserDN>" when a UPNDomain
// is configured, and "<UserAttr>=<name>,<UserDN>" otherwise.
func NewUsers(t TestingT, userNames []string, opt ...Option) []*gldap.Entry {
	opts := getOpts(t, opt...)
	cfg := opts.withDefaults
	users := make([]*gldap.Entry, 0, len(userNames))
	for _, name := range userNames {
		attrs := map[string][]string{
			"name":     {name},
			"email":    {fmt.Sprintf("%s@example.com", name)},
			"password": {"password"},
		}
		if len(opts.withMembersOf) > 0 {
			attrs["memberOf"] = opts.withMembersOf
		}
		if count := len(opts.withTokenGroupSIDs); count > 0 {
			// a fresh slice per user so entries never share backing storage
			sids := make([]string, 0, count)
			for _, sid := range opts.withTokenGroupSIDs {
				sids = append(sids, string(sid))
			}
			attrs["tokenGroups"] = sids
		}
		dn := fmt.Sprintf("%s=%s,%s", cfg.UserAttr, name, cfg.UserDN)
		if cfg.UPNDomain != "" {
			dn = fmt.Sprintf("userPrincipalName=%s@%s,%s", name, cfg.UPNDomain, cfg.UserDN)
		}
		users = append(users, gldap.NewEntry(dn, attrs))
	}
	return users
}
// NewGroup creates a group entry. Options supported: WithDefaults
//
// The group DN is "<GroupAttr>=<groupName>,<GroupDN>". Each member name is
// turned into a user DN the same way NewUsers builds one, and the DNs are
// stored in the "member" attribute.
func NewGroup(t TestingT, groupName string, memberNames []string, opt ...Option) *gldap.Entry {
	cfg := getOpts(t, opt...).withDefaults
	memberDNs := make([]string, 0, len(memberNames))
	for _, name := range memberNames {
		dn := fmt.Sprintf("%s=%s,%s", cfg.UserAttr, name, cfg.UserDN)
		if cfg.UPNDomain != "" {
			dn = fmt.Sprintf("userPrincipalName=%s@%s,%s", name, cfg.UPNDomain, cfg.UserDN)
		}
		memberDNs = append(memberDNs, dn)
	}
	groupDN := fmt.Sprintf("%s=%s,%s", cfg.GroupAttr, groupName, cfg.GroupDN)
	return gldap.NewEntry(groupDN, map[string][]string{"member": memberDNs})
}
// FreePort returns a localhost TCP port that was free at the time of the
// call. It binds to port 0, reads back the port the OS chose, and closes the
// listener.
func FreePort(t TestingT) int {
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	req := require.New(t)
	tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	req.NoError(err)
	listener, err := net.ListenTCP("tcp", tcpAddr)
	req.NoError(err)
	defer listener.Close()
	bound := listener.Addr().(*net.TCPAddr)
	return bound.Port
}
// GetTLSConfig generates a throwaway CA and a CA-signed server certificate,
// and returns a server tls.Config plus a client tls.Config that trusts the
// CA. The server certificate covers the host from WithHost as a DNS SAN, and
// also as an IP SAN when the host parses as an IP address. With WithMTLS, the
// server requires and verifies client certificates issued by the same CA, and
// the client config carries one.
//
// Supported options: WithMTLS, WithHost
func GetTLSConfig(t TestingT, opt ...Option) (s *tls.Config, c *tls.Config) {
	if v, ok := interface{}(t).(HelperT); ok {
		v.Helper()
	}
	require := require.New(t)
	certSubject := pkix.Name{
		Organization:  []string{"Acme, INC."},
		Country:       []string{"US"},
		Province:      []string{""},
		Locality:      []string{"New York"},
		StreetAddress: []string{"Empire State Building"},
		PostalCode:    []string{"10118"},
	}
	// set up our CA certificate
	ca := &x509.Certificate{
		SerialNumber:          genSerialNumber(t),
		Subject:               certSubject,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(1, 0, 0),
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	caPriv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	require.NoError(err)
	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPriv.PublicKey, caPriv)
	require.NoError(err)
	caPEM := new(bytes.Buffer)
	err = pem.Encode(caPEM, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	})
	require.NoError(err)

	opts := getOpts(t, opt...)
	var ipAddrs []net.IP
	if hostIp := net.ParseIP(opts.withHost); hostIp != nil {
		ipAddrs = append(ipAddrs, hostIp)
	}
	cert := &x509.Certificate{
		SerialNumber:          genSerialNumber(t),
		Subject:               certSubject,
		IPAddresses:           ipAddrs,
		DNSNames:              []string{opts.withHost},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(1, 0, 0),
		SubjectKeyId:          []byte{1, 2, 3, 4, 6},
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature,
		BasicConstraintsValid: true,
	}
	serverCert := genCert(t, ca, caPriv, cert)

	certpool := x509.NewCertPool()
	require.True(certpool.AppendCertsFromPEM(caPEM.Bytes()), "unable to add the generated CA certificate to the pool")
	serverTLSConf := &tls.Config{
		Certificates: []tls.Certificate{serverCert},
	}
	clientTLSConf := &tls.Config{
		RootCAs: certpool,
	}
	if opts.withMTLS {
		// setup mTLS for certs from the ca
		serverTLSConf.ClientCAs = certpool
		serverTLSConf.ClientAuth = tls.RequireAndVerifyClientCert
		cert := &x509.Certificate{
			// random serial, like the CA and server certificates
			SerialNumber:          genSerialNumber(t),
			Subject:               certSubject,
			EmailAddresses:        []string{"mtls.client@example.com"},
			NotBefore:             time.Now(),
			NotAfter:              time.Now().AddDate(1, 0, 0),
			SubjectKeyId:          []byte{1, 2, 3, 4, 6},
			ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
			KeyUsage:              x509.KeyUsageDigitalSignature,
			BasicConstraintsValid: true,
		}
		clientCert := genCert(t, ca, caPriv, cert)
		clientTLSConf.Certificates = []tls.Certificate{clientCert}
	}
	return serverTLSConf, clientTLSConf
}
// genCert creates a fresh P-521 key, signs certTemplate with the CA's key,
// and returns the result as a tls.Certificate. The pair is round-tripped
// through PEM (PKCS#8 for the key) and tls.X509KeyPair.
func genCert(t TestingT, ca *x509.Certificate, caPriv interface{}, certTemplate *x509.Certificate) tls.Certificate {
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	req := require.New(t)
	key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	req.NoError(err)
	der, err := x509.CreateCertificate(rand.Reader, certTemplate, ca, &key.PublicKey, caPriv)
	req.NoError(err)

	var certBuf bytes.Buffer
	req.NoError(pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: der}))

	keyDER, err := x509.MarshalPKCS8PrivateKey(key)
	req.NoError(err)
	var keyBuf bytes.Buffer
	req.NoError(pem.Encode(&keyBuf, &pem.Block{Type: "PRIVATE KEY", Bytes: keyDER}))

	pair, err := tls.X509KeyPair(certBuf.Bytes(), keyBuf.Bytes())
	req.NoError(err)
	return pair
}
// genSerialNumber returns a random certificate serial number drawn uniformly
// from [0, 2^128).
func genSerialNumber(t TestingT) *big.Int {
	if h, ok := interface{}(t).(HelperT); ok {
		h.Helper()
	}
	req := require.New(t)
	limit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, limit)
	req.NoError(err)
	return serial
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package testdirectory
import (
	"errors"
	"fmt"

	"github.com/hashicorp/go-hclog"
)
// TestingT defines a very slim interface required by a Directory and any
// test functions it uses. *testing.T satisfies it, and so does *Logger in
// this package for use outside of `go test`.
type TestingT interface {
	// Errorf reports a formatted error.
	Errorf(format string, args ...interface{})
	// FailNow stops the current test or operation.
	FailNow()
	// Log records its arguments.
	Log(...interface{})
}
// CleanupT defines a single function interface for a testing.Cleanup(func()).
type CleanupT interface{ Cleanup(func()) }
// HelperT defines a single function interface for a testing.Helper(). Code in
// this package type-asserts a TestingT to HelperT and calls Helper only when
// it is implemented.
type HelperT interface{ Helper() }
// InfofT defines a single function interface for an Infof(format string, args ...interface{})
type InfofT interface {
	Infof(format string, args ...interface{})
}
// Logger defines a logger that will implement the TestingT interface so
// it can be used with Directory.Start(...) as its t TestingT parameter.
type Logger struct {
	// Logger is the underlying hclog logger that all output is sent to.
	Logger hclog.Logger
}
// NewLogger wraps an hclog logger in a *Logger. It returns an error when
// logger is nil.
func NewLogger(logger hclog.Logger) (*Logger, error) {
	if logger != nil {
		return &Logger{Logger: logger}, nil
	}
	return nil, errors.New("missing logger")
}
// Errorf formats the message like fmt.Sprintf and logs it at error level.
// hclog treats extra arguments as key/value pairs rather than format
// arguments, so the message is formatted here first.
func (l *Logger) Errorf(format string, args ...interface{}) {
	l.Logger.Error(fmt.Sprintf(format, args...))
}
// Infof formats the message like fmt.Sprintf and logs it at info level.
// hclog treats extra arguments as key/value pairs rather than format
// arguments, so the message is formatted here first.
func (l *Logger) Infof(format string, args ...interface{}) {
	l.Logger.Info(fmt.Sprintf(format, args...))
}
// FailNow aborts by panicking, since outside of `go test` there is no test to
// mark as failed.
func (l *Logger) FailNow() {
	const msg = "testing.T failed, see logs for output (if any)"
	panic(msg)
}
// Log writes its arguments, Println-style, through a standard-library logger
// adapter over the wrapped hclog logger.
func (l *Logger) Log(i ...interface{}) {
	std := l.Logger.StandardLogger(&hclog.StandardLoggerOptions{})
	std.Println(i...)
}
// Copyright (c) Jim Lambert
// SPDX-License-Identifier: MIT
package gldap
import (
"net"
"os"
"strings"
"sync"
"testing"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/go-ldap/ldap/v3"
"github.com/stretchr/testify/require"
)
// testOptions holds the options understood by the package's test helpers.
type testOptions struct {
	// test options
	// withDescription is an optional free-form description (see WithDescription)
	withDescription string
}
// testDefaults returns the zero-valued test options.
func testDefaults() testOptions {
	var opts testOptions
	return opts
}
// getTestOpts starts from testDefaults and applies each option in order.
func getTestOpts(opt ...Option) testOptions {
	result := testDefaults()
	applyOpts(&result, opt...)
	return result
}
// WithDescription allows you to specify an optional description.
func WithDescription(desc string) Option {
	return func(o interface{}) {
		opts, ok := o.(*testOptions)
		if !ok {
			return
		}
		opts.withDescription = desc
	}
}
// freePort returns a localhost TCP port that was free at the time of the
// call. It binds to port 0, reads back the port the OS chose, and closes the
// listener.
func freePort(t *testing.T) int {
	t.Helper()
	req := require.New(t)
	tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:0")
	req.NoError(err)
	listener, err := net.ListenTCP("tcp", tcpAddr)
	req.NoError(err)
	defer listener.Close()
	bound := listener.Addr().(*net.TCPAddr)
	return bound.Port
}
// testStartTLSRequestPacket builds a StartTLS extended request (OID
// 1.3.6.1.4.1.1466.20037) wrapped in a request envelope carrying messageID.
func testStartTLSRequestPacket(t *testing.T, messageID int) *packet {
	t.Helper()
	envelope := testRequestEnvelope(t, messageID)
	startTLS := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
	oid := ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")
	startTLS.AppendChild(oid)
	envelope.AppendChild(startTLS)
	return &packet{Packet: envelope}
}
// testSearchRequestPacket encodes s as a search request inside a request
// envelope. The filter is compiled with go-ldap's CompileFilter (failing the
// test if it does not parse), and controls are appended when present.
func testSearchRequestPacket(t *testing.T, s SearchMessage) *packet {
	t.Helper()
	envelope := testRequestEnvelope(t, int(s.GetID()))
	search := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
	header := []*ber.Packet{
		ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN"),
		ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(s.Scope), "Scope"),
		ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(s.DerefAliases), "Deref Aliases"),
		ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(s.SizeLimit), "Size Limit"),
		ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(s.TimeLimit), "Time Limit"),
		ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only"),
	}
	for _, child := range header {
		search.AppendChild(child)
	}
	filter, err := ldap.CompileFilter(s.Filter)
	require.NoError(t, err)
	search.AppendChild(filter)

	attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
	for _, name := range s.Attributes {
		attrs.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute"))
	}
	search.AppendChild(attrs)

	envelope.AppendChild(search)
	if len(s.Controls) > 0 {
		envelope.AppendChild(encodeControls(s.Controls))
	}
	return &packet{Packet: envelope}
}
// testSimpleBindRequestPacket encodes m as a simple BindRequest with protocol
// version 3, m.UserName as an OCTET STRING, and m.Password as a
// context-specific primitive with tag 0. m.Controls are appended to the
// envelope only when non-empty.
func testSimpleBindRequestPacket(t *testing.T, m SimpleBindMessage) *packet {
	t.Helper()
	const protocolVersion = 3

	bind := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
	bind.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(protocolVersion), "Version"))
	bind.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.UserName, "User Name"))
	bind.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, string(m.Password), "Password"))

	envelope := testRequestEnvelope(t, int(m.GetID()))
	envelope.AppendChild(bind)
	if len(m.Controls) != 0 {
		envelope.AppendChild(encodeControls(m.Controls))
	}
	return &packet{Packet: envelope}
}
// testUnbindRequestPacket wraps an UnbindRequest with no children in an LDAP
// request envelope for m's message ID. This helper does not encode controls.
func testUnbindRequestPacket(t *testing.T, m UnbindMessage) *packet {
	t.Helper()
	unbind := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationUnbindRequest, nil, "Unbind Request")
	envelope := testRequestEnvelope(t, int(m.GetID()))
	envelope.AppendChild(unbind)
	return &packet{Packet: envelope}
}
// testModifyRequestPacket encodes m as a ModifyRequest containing m.DN followed
// by a SEQUENCE of every entry of m.Changes, each encoded with its own encode
// method. m.Controls are appended to the envelope only when non-empty.
func testModifyRequestPacket(t *testing.T, m ModifyMessage) *packet {
	t.Helper()
	changeSeq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes")
	for _, c := range m.Changes {
		changeSeq.AppendChild(c.encode())
	}

	modify := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
	modify.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN"))
	modify.AppendChild(changeSeq)

	envelope := testRequestEnvelope(t, int(m.GetID()))
	envelope.AppendChild(modify)
	if len(m.Controls) != 0 {
		envelope.AppendChild(encodeControls(m.Controls))
	}
	return &packet{Packet: envelope}
}
// testDeleteRequestPacket encodes m as a DelRequest whose raw data bytes are
// m.DN. m.Controls are appended to the envelope only when non-empty.
//
// NOTE(review): RFC 4511 defines DelRequest as a primitive LDAPDN, but this
// helper tags it constructed. Presumably that matches what the decoder under
// test expects; confirm before changing it.
func testDeleteRequestPacket(t *testing.T, m DeleteMessage) *packet {
	t.Helper()
	del := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationDelRequest, nil, "Delete Request")
	del.Data.WriteString(m.DN)

	envelope := testRequestEnvelope(t, int(m.GetID()))
	envelope.AppendChild(del)
	if len(m.Controls) != 0 {
		envelope.AppendChild(encodeControls(m.Controls))
	}
	return &packet{Packet: envelope}
}
// testAddRequestPacket encodes m as an AddRequest containing m.DN followed by a
// SEQUENCE of every entry of m.Attributes, each produced by Attribute.encode.
// m.Controls are appended to the envelope only when non-empty.
func testAddRequestPacket(t *testing.T, m AddMessage) *packet {
	t.Helper()
	attrSeq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
	for _, a := range m.Attributes {
		attrSeq.AppendChild(a.encode())
	}

	add := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
	add.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, m.DN, "DN"))
	add.AppendChild(attrSeq)

	envelope := testRequestEnvelope(t, int(m.GetID()))
	envelope.AppendChild(add)
	if len(m.Controls) != 0 {
		envelope.AppendChild(encodeControls(m.Controls))
	}
	return &packet{Packet: envelope}
}
// testRequestEnvelope returns the outer LDAP message SEQUENCE with messageID
// already appended as its first INTEGER child. Callers append the protocol
// operation (and optionally controls) after it.
func testRequestEnvelope(t *testing.T, messageID int) *ber.Packet {
	// Mark as a helper like every other test builder here, so any future
	// failure inside is attributed to the calling test line.
	t.Helper()
	p := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
	p.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(messageID), "MessageID"))
	return p
}
// testControlString returns a ControlString for controlType built with opt.
// The test fails immediately if NewControlString returns an error.
func testControlString(t *testing.T, controlType string, opt ...Option) *ControlString {
	t.Helper()
	ctrl, err := NewControlString(controlType, opt...)
	require.NoError(t, err)
	return ctrl
}
// testControlManageDsaIT returns a ControlManageDsaIT built with opt.
// The test fails immediately if NewControlManageDsaIT returns an error.
func testControlManageDsaIT(t *testing.T, opt ...Option) *ControlManageDsaIT {
	t.Helper()
	ctrl, err := NewControlManageDsaIT(opt...)
	require.NoError(t, err)
	return ctrl
}
// testControlMicrosoftNotification returns a ControlMicrosoftNotification
// built with opt. The test fails immediately if the constructor returns an error.
func testControlMicrosoftNotification(t *testing.T, opt ...Option) *ControlMicrosoftNotification {
	t.Helper()
	ctrl, err := NewControlMicrosoftNotification(opt...)
	require.NoError(t, err)
	return ctrl
}
// testControlMicrosoftServerLinkTTL returns a ControlMicrosoftServerLinkTTL
// built with opt. The test fails immediately if the constructor returns an error.
func testControlMicrosoftServerLinkTTL(t *testing.T, opt ...Option) *ControlMicrosoftServerLinkTTL {
	t.Helper()
	ctrl, err := NewControlMicrosoftServerLinkTTL(opt...)
	require.NoError(t, err)
	return ctrl
}
// testControlMicrosoftShowDeleted returns a ControlMicrosoftShowDeleted built
// with opt. The test fails immediately if the constructor returns an error.
func testControlMicrosoftShowDeleted(t *testing.T, opt ...Option) *ControlMicrosoftShowDeleted {
	t.Helper()
	ctrl, err := NewControlMicrosoftShowDeleted(opt...)
	require.NoError(t, err)
	return ctrl
}
// testControlPaging returns a ControlPaging for pagingSize built with opt.
// The test fails immediately if NewControlPaging returns an error.
func testControlPaging(t *testing.T, pagingSize uint32, opt ...Option) *ControlPaging {
	t.Helper()
	// pagingSize is already a uint32; the previous uint32(...) wrap was a no-op.
	c, err := NewControlPaging(pagingSize, opt...)
	require.NoError(t, err)
	return c
}
// TestWithDebug reports whether tests should run in "debug" mode. That is the
// case when the DEBUG environment variable equals "true", compared
// case-insensitively (so "TRUE" and "True" also enable it).
func TestWithDebug(t *testing.T) bool {
	t.Helper()
	// EqualFold compares case-insensitively without allocating the lower-cased
	// copy that strings.ToLower would produce.
	return strings.EqualFold(os.Getenv("DEBUG"), "true")
}
// TestEncodeString BER-encodes s as a universal, primitive string packet with
// the given tag and the description carried by opt (read via getTestOpts).
// It decodes those bytes back into a packet and returns that decoded packet's
// encoded bytes as a string. The test fails immediately if decoding errors.
func TestEncodeString(t *testing.T, tag ber.Tag, s string, opt ...Option) string {
	t.Helper()
	desc := getTestOpts(opt...).withDescription
	encoded := ber.NewString(ber.ClassUniversal, ber.TypePrimitive, tag, s, desc).Bytes()
	decoded, err := ber.DecodePacketErr(encoded)
	require.NoError(t, err)
	return string(decoded.Bytes())
}
// safeBuf is a strings.Builder guarded by a mutex, so concurrent Write and
// String calls do not race. Create one with testSafeBuf. The zero value has
// nil fields and panics on use.
type safeBuf struct {
	buf *strings.Builder
	mu  *sync.Mutex
}

// testSafeBuf returns an empty, ready-to-use safeBuf.
func testSafeBuf(t *testing.T) *safeBuf {
	t.Helper()
	sb := new(safeBuf)
	sb.buf = new(strings.Builder)
	sb.mu = new(sync.Mutex)
	return sb
}

// Write appends p to the underlying builder while holding the lock. It
// forwards strings.Builder.Write's results: len(p) and a nil error.
func (sb *safeBuf) Write(p []byte) (int, error) {
	sb.mu.Lock()
	defer sb.mu.Unlock()
	return sb.buf.Write(p)
}

// String returns everything written so far, read while holding the lock.
func (sb *safeBuf) String() string {
	sb.mu.Lock()
	defer sb.mu.Unlock()
	return sb.buf.String()
}