diff --git a/gss/apcera.go b/gss/apcera.go index ab10358..7806dbd 100644 --- a/gss/apcera.go +++ b/gss/apcera.go @@ -17,6 +17,54 @@ import ( "github.com/openshift/gssapi" ) +func generate(lib *gssapi.Lib, ctx *gssapi.CtxId, msg []byte) ([]byte, error) { + message, err := lib.MakeBufferBytes(msg) + if err != nil { + return nil, err + } + + defer func() { + err = multierror.Append(err, message.Release()).ErrorOrNil() + }() + + token, err := ctx.GetMIC(gssapi.GSS_C_QOP_DEFAULT, message) + if err != nil { + return nil, err + } + + defer func() { + err = multierror.Append(err, token.Release()).ErrorOrNil() + }() + + return token.Bytes(), nil +} + +func verify(lib *gssapi.Lib, ctx *gssapi.CtxId, stripped, mac []byte) error { + message, err := lib.MakeBufferBytes(stripped) + if err != nil { + return err + } + + defer func() { + err = multierror.Append(err, message.Release()).ErrorOrNil() + }() + + token, err := lib.MakeBufferBytes(mac) + if err != nil { + return err + } + + defer func() { + err = multierror.Append(err, token.Release()).ErrorOrNil() + }() + + if _, err = ctx.VerifyMIC(message, token); err != nil { + return err + } + + return nil +} + // Client maps the TKEY name to the context that negotiated it as // well as any other internal state. type Client struct { @@ -28,16 +76,14 @@ type Client struct { } // WithConfig sets the Kerberos configuration used. -func WithConfig(_ string) func(*Client) error { - return func(c *Client) error { - return errNotSupported - } +func WithConfig[T Client](_ string) Option[T] { + return unsupportedOption[T] } // NewClient performs any library initialization necessary. // It returns a context handle for any further functions along with any error // that occurred. -func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, error) { +func NewClient(dnsClient *dns.Client, options ...Option[Client]) (*Client, error) { client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -57,8 +103,10 @@ func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, logger: logr.Discard(), } - if err := c.setOption(options...); err != nil { - return nil, multierror.Append(err, c.lib.Unload()) + for _, option := range options { + if err := option(c); err != nil { + return nil, multierror.Append(err, c.lib.Unload()) + } } return c, nil @@ -72,54 +120,11 @@ func (c *Client) Close() error { } func (c *Client) generate(ctx *gssapi.CtxId, msg []byte) ([]byte, error) { - message, err := c.lib.MakeBufferBytes(msg) - if err != nil { - return nil, err - } - - defer func() { - err = multierror.Append(err, message.Release()).ErrorOrNil() - }() - - token, err := ctx.GetMIC(gssapi.GSS_C_QOP_DEFAULT, message) - if err != nil { - return nil, err - } - - defer func() { - err = multierror.Append(err, token.Release()).ErrorOrNil() - }() - - return token.Bytes(), nil + return generate(c.lib, ctx, msg) } func (c *Client) verify(ctx *gssapi.CtxId, stripped, mac []byte) error { - // Turn the TSIG-stripped message bytes into a *gssapi.Buffer - message, err := c.lib.MakeBufferBytes(stripped) - if err != nil { - return err - } - - defer func() { - err = multierror.Append(err, message.Release()).ErrorOrNil() - }() - - // Turn the TSIG MAC bytes into a *gssapi.Buffer - token, err := c.lib.MakeBufferBytes(mac) - if err != nil { - return err - } - - defer func() { - err = multierror.Append(err, token.Release()).ErrorOrNil() - }() - - // This is the actual verification bit - if _, err = ctx.VerifyMIC(message, token); err != nil { - return err - } - - return nil + return verify(c.lib, ctx, stripped, mac) } // NegotiateContext exchanges RFC 2930 TKEY records with the indicated DNS @@ -260,3 +265,150 @@ func (c *Client) DeleteContext(keyname string) error { return nil } + +// Server maps the TKEY name to the context that negotiated it as +// well as any other internal state. +type Server struct { + m sync.RWMutex + lib *gssapi.Lib + ctx map[string]*gssapi.CtxId + logger logr.Logger +} + +// NewServer performs any library initialization necessary. +// It returns a context handle for any further functions along with any error +// that occurred. +func NewServer(options ...func(*Server) error) (*Server, error) { + lib, err := gssapi.Load(nil) + if err != nil { + return nil, err + } + + s := &Server{ + lib: lib, + ctx: make(map[string]*gssapi.CtxId), + logger: logr.Discard(), + } + + for _, option := range options { + if err := option(s); err != nil { + return nil, multierror.Append(err, s.lib.Unload()) + } + } + + return s, nil +} + +// Close deletes any active contexts and unloads any underlying libraries as +// necessary. +// It returns any error that occurred. +func (s *Server) Close() error { + return multierror.Append(s.close(true), s.lib.Unload()).ErrorOrNil() +} + +func (s *Server) newContext() (*gssapi.CtxId, error) { + //nolint:nilnil + return nil, nil +} + +//nolint:funlen +func (s *Server) update(ctx *gssapi.CtxId, input []byte) (*gssapi.CtxId, []byte, error) { + /*var cred *gssapi.CredId + + // equivalent of GSSAPIStrictAcceptorCheck + if s.strict { //nolint:nestif + hostname, err := osHostname() + if err != nil { + return nil, "", false, err + } + + buffer, err := s.lib.MakeBufferString("host@" + hostname) + if err != nil { + return nil, "", false, err + } + + defer func() { + err = multierror.Append(err, buffer.Release()).ErrorOrNil() + }() + + service, err := buffer.Name(s.lib.GSS_C_NT_HOSTBASED_SERVICE) + if err != nil { + return nil, "", false, err + } + + defer func() { + err = multierror.Append(err, service.Release()).ErrorOrNil() + }() + + oids, err := s.lib.MakeOIDSet(s.lib.GSS_MECH_KRB5) + if err != nil { + return nil, "", false, err + } + + defer func() { + err = multierror.Append(err, oids.Release()).ErrorOrNil() + }() + + cred, _, _, err = s.lib.AcquireCred(service, gssapi.GSS_C_INDEFINITE, oids, gssapi.GSS_C_ACCEPT) + if err != nil { + return nil, "", false, err + } + + defer func() { + err = multierror.Append(err, cred.Release()).ErrorOrNil() + }() + } else {*/ + cred := s.lib.GSS_C_NO_CREDENTIAL + //} + + token, err := s.lib.MakeBufferBytes(input) + if err != nil { + return nil, nil, err + } + + defer func() { + err = multierror.Append(err, token.Release()).ErrorOrNil() + }() + + //nolint:dogsled + nctx, _, _, output, _, _, _, err := s.lib.AcceptSecContext(ctx, cred, token, s.lib.GSS_C_NO_CHANNEL_BINDINGS) + if err != nil && !s.lib.LastStatus.Major.ContinueNeeded() { + return nil, nil, err + } + + defer func() { + err = multierror.Append(err, output.Release()).ErrorOrNil() + }() + + return nctx, output.Bytes(), nil +} + +func (s *Server) generate(ctx *gssapi.CtxId, msg []byte) ([]byte, error) { + return generate(s.lib, ctx, msg) +} + +func (s *Server) verify(ctx *gssapi.CtxId, stripped, mac []byte) error { + return verify(s.lib, ctx, stripped, mac) +} + +func (s *Server) established(ctx *gssapi.CtxId) (established bool, err error) { + if ctx != nil { + _, _, _, _, _, _, established, err = ctx.InquireContext() + } + + return +} + +func (s *Server) expired(ctx *gssapi.CtxId) (expired bool, err error) { + if ctx != nil { + var duration time.Duration + _, _, duration, _, _, _, _, err = ctx.InquireContext() + expired = duration <= 0 + } + + return +} + +func (s *Server) delete(ctx *gssapi.CtxId) error { + return ctx.DeleteSecContext() +} diff --git a/gss/apcera_test.go b/gss/apcera_test.go index f7bd073..4779fe0 100644 --- a/gss/apcera_test.go +++ b/gss/apcera_test.go @@ -29,3 +29,11 @@ func TestNewClientWithConfig(t *testing.T) { _, err := gss.NewClient(new(dns.Client), gss.WithConfig("")) assert.NotNil(t, err) } + +func TestNewServer(t *testing.T) { + t.Parallel() + + if err := testNewServer(t); err != nil { + t.Fatal(err) + } +} diff --git a/gss/client.go b/gss/client.go index 69398a6..4cbaa4e 100644 --- a/gss/client.go +++ b/gss/client.go @@ -68,12 +68,12 @@ func (c *Client) close() error { c.m.RUnlock() - var errs error + var err *multierror.Error for _, k := range keys { - errs = multierror.Append(errs, c.DeleteContext(k)) + err = multierror.Append(err, c.DeleteContext(k)) } - return errs + return err.ErrorOrNil() } func (c *Client) setOption(options ...func(*Client) error) error { @@ -87,20 +87,15 @@ func (c *Client) setOption(options ...func(*Client) error) error { } // SetConfig sets the Kerberos configuration used by c. +// +// Deprecated: FIXME. func (c *Client) SetConfig(config string) error { return c.setOption(WithConfig(config)) } -// WithLogger sets the logger used. -func WithLogger(logger logr.Logger) func(*Client) error { - return func(c *Client) error { - c.logger = logger.WithName("client") - - return nil - } -} - // SetLogger sets the logger used by c. +// +// Deprecated: FIXME. func (c *Client) SetLogger(logger logr.Logger) error { - return c.setOption(WithLogger(logger)) + return c.setOption(WithLogger[Client](logger)) } diff --git a/gss/gokrb5.go b/gss/gokrb5.go index 8dfcbad..24828a9 100644 --- a/gss/gokrb5.go +++ b/gss/gokrb5.go @@ -28,9 +28,11 @@ type Client struct { } // WithConfig sets the Kerberos configuration used. -func WithConfig(config string) func(*Client) error { - return func(c *Client) error { - c.config = config +func WithConfig[T Client](config string) Option[T] { + return func(a *T) error { + if x, ok := any(a).(*Client); ok { + x.config = config + } return nil } @@ -39,7 +41,7 @@ func WithConfig(config string) func(*Client) error { // NewClient performs any library initialization necessary. // It returns a context handle for any further functions along with any error // that occurred. -func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, error) { +func NewClient(dnsClient *dns.Client, options ...Option[Client]) (*Client, error) { client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -53,8 +55,10 @@ func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, logger: logr.Discard(), } - if err := c.setOption(options...); err != nil { - return nil, err + for _, option := range options { + if err := option(c); err != nil { + return nil, err + } } return c, nil @@ -195,3 +199,66 @@ func (c *Client) DeleteContext(keyname string) error { return nil } + +// Server maps the TKEY name to the context that negotiated it as +// well as any other internal state. +type Server struct { + m sync.RWMutex + ctx map[string]*wrapper.Acceptor + logger logr.Logger +} + +// NewServer performs any library initialization necessary. +// It returns a context handle for any further functions along with any error +// that occurred. +func NewServer(options ...func(*Server) error) (*Server, error) { + s := &Server{ + ctx: make(map[string]*wrapper.Acceptor), + logger: logr.Discard(), + } + + for _, option := range options { + if err := option(s); err != nil { + return nil, err + } + } + + return s, nil +} + +// Close deletes any active contexts and unloads any underlying libraries as +// necessary. +// It returns any error that occurred. +func (s *Server) Close() error { + return s.close(true) +} + +func (s *Server) newContext() (*wrapper.Acceptor, error) { + return wrapper.NewAcceptor(wrapper.WithLogger[wrapper.Acceptor](s.logger)) +} + +func (s *Server) update(ctx *wrapper.Acceptor, input []byte) (*wrapper.Acceptor, []byte, error) { + output, _, err := ctx.Accept(input) + + return ctx, output, err +} + +func (s *Server) generate(ctx *wrapper.Acceptor, msg []byte) ([]byte, error) { + return ctx.MakeSignature(msg) +} + +func (s *Server) verify(ctx *wrapper.Acceptor, stripped, mac []byte) error { + return ctx.VerifySignature(stripped, mac) +} + +func (s *Server) established(ctx *wrapper.Acceptor) (bool, error) { + return ctx.Established(), nil +} + +func (s *Server) expired(ctx *wrapper.Acceptor) (bool, error) { + return ctx.Expiry().Before(time.Now()), nil +} + +func (s *Server) delete(ctx *wrapper.Acceptor) error { + return ctx.Close() +} diff --git a/gss/gokrb5_test.go b/gss/gokrb5_test.go index 04e0de8..1e83463 100644 --- a/gss/gokrb5_test.go +++ b/gss/gokrb5_test.go @@ -29,3 +29,11 @@ func TestNewClientWithConfig(t *testing.T) { _, err := gss.NewClient(new(dns.Client), gss.WithConfig("")) assert.Nil(t, err) } + +func TestNewServer(t *testing.T) { + t.Parallel() + + if err := testNewServer(t); err != nil { + t.Fatal(err) + } +} diff --git a/gss/gss_test.go b/gss/gss_test.go index 305b87d..9a98e6f 100644 --- a/gss/gss_test.go +++ b/gss/gss_test.go @@ -3,8 +3,10 @@ package gss_test import ( "fmt" "net" + "net/netip" "os" "runtime" + "strconv" "testing" "time" @@ -95,7 +97,7 @@ func testExchange(t *testing.T) (err error) { dnsClient := new(dns.Client) dnsClient.Net = dnsClientTransport - gssClient, err := gss.NewClient(dnsClient, gss.WithLogger(testr.New(t))) + gssClient, err := gss.NewClient(dnsClient, gss.WithLogger[gss.Client](testr.New(t))) if err != nil { return err } @@ -180,7 +182,7 @@ func testExchangeKeytab(t *testing.T) (err error) { dnsClient := new(dns.Client) dnsClient.Net = dnsClientTransport - gssClient, err := gss.NewClient(dnsClient, gss.WithLogger(testr.New(t))) + gssClient, err := gss.NewClient(dnsClient, gss.WithLogger[gss.Client](testr.New(t))) if err != nil { return err } @@ -206,6 +208,101 @@ func TestExchange(t *testing.T) { func TestNewClientWithLogger(t *testing.T) { t.Parallel() - _, err := gss.NewClient(new(dns.Client), gss.WithLogger(logr.Discard())) + _, err := gss.NewClient(new(dns.Client), gss.WithLogger[gss.Client](logr.Discard())) assert.Nil(t, err) } + +func newServer(t *testing.T, hostname string) (string, func() error) { + t.Helper() + + gssServer, err := gss.NewServer(gss.WithLogger[gss.Server](testr.New(t))) + if err != nil { + t.Fatal(err) + } + + server := &dns.Server{ + Addr: net.JoinHostPort(hostname, "0"), + Net: "tcp4", + TsigProvider: gssServer, + Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) { + switch r.Question[0].Qtype { + case dns.TypeTKEY: + gssServer.ServeDNS(w, r) + case dns.TypeA: + m := new(dns.Msg) + if rr := r.IsTsig(); rr != nil && w.TsigStatus() == nil { + m.SetReply(r) + m.SetTsig(rr.Header().Name, tsig.GSS, 300, time.Now().Unix()) + } else { + m.SetRcode(r, dns.RcodeNotAuth) + } + _ = w.WriteMsg(m) + } + }), + MsgAcceptFunc: func(dh dns.Header) dns.MsgAcceptAction { + return dns.MsgAccept + }, + } + + //nolint:errcheck + go server.ListenAndServe() + + for server.Listener == nil { + time.Sleep(10 * time.Millisecond) + } + + return strconv.FormatUint(uint64(netip.MustParseAddrPort(server.Listener.Addr().String()).Port()), 10), func() error { + return multierror.Append(server.Shutdown(), gssServer.Close()).ErrorOrNil() + } +} + +func testNewServer(t *testing.T) (err error) { + t.Helper() + + if testing.Short() { + t.Skip("skipping integration test") + } + + //nolint:dogsled + host, _, _, _, _, _ := testEnvironmentVariables(t) + + port, teardown := newServer(t, host) + + defer func() { + err = multierror.Append(err, teardown()).ErrorOrNil() + }() + + dnsClient := new(dns.Client) + dnsClient.Net = dnsClientTransport + + gssClient, err := gss.NewClient(dnsClient, gss.WithLogger[gss.Client](testr.New(t))) + if err != nil { + return err + } + + defer func() { + err = multierror.Append(err, gssClient.Close()).ErrorOrNil() + }() + + keyname, _, err := gssClient.NegotiateContext(net.JoinHostPort(host, port)) + if err != nil { + return err + } + + dnsClient.TsigProvider = gssClient + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn("test.example.com"), dns.TypeA) + msg.SetTsig(keyname, tsig.GSS, 300, time.Now().Unix()) + + rr, _, err := dnsClient.Exchange(msg, net.JoinHostPort(host, port)) + if err != nil { + return err + } + + if rr.Rcode != dns.RcodeSuccess { + return fmt.Errorf("DNS error: %s (%d)", dns.RcodeToString[rr.Rcode], rr.Rcode) + } + + return gssClient.DeleteContext(keyname) +} diff --git a/gss/options.go b/gss/options.go new file mode 100644 index 0000000..c475df5 --- /dev/null +++ b/gss/options.go @@ -0,0 +1,25 @@ +package gss + +import "github.com/go-logr/logr" + +// Option is the signature for all constructor options. +type Option[T Client | Server] func(*T) error + +// WithLogger sets the logger used. +func WithLogger[T Client | Server](logger logr.Logger) Option[T] { + return func(a *T) error { + switch x := any(a).(type) { + case *Client: + x.logger = logger.WithName("client") + case *Server: + x.logger = logger.WithName("server") + } + + return nil + } +} + +//nolint:nolintlint,unused +func unsupportedOption[T Client | Server](_ *T) error { + return errNotSupported +} diff --git a/gss/server.go b/gss/server.go new file mode 100644 index 0000000..293bfb9 --- /dev/null +++ b/gss/server.go @@ -0,0 +1,280 @@ +package gss + +import ( + "encoding/hex" + "time" + + "github.com/bodgit/tsig" + "github.com/bodgit/tsig/internal/util" + multierror "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" +) + +var ( + _ dns.TsigProvider = new(Server) + _ dns.Handler = new(Server) +) + +func (s *Server) close(all bool) error { + s.m.Lock() + defer s.m.Unlock() + + var errs *multierror.Error + + for keyname, ctx := range s.ctx { + switch { + case !all: + expired, err := s.expired(ctx) + if err != nil { + errs = multierror.Append(errs, err) + + continue + } + + if !expired { + continue + } + + fallthrough + default: + if err := s.delete(ctx); err != nil { + errs = multierror.Append(errs, err) + } + + delete(s.ctx, keyname) + } + } + + return errs.ErrorOrNil() +} + +// Generate generates the TSIG MAC based on the established context. +// It is called with the bytes of the DNS message, and the partial TSIG +// record containing the algorithm and name which is the negotiated TKEY +// for this context. +// It returns the bytes for the TSIG MAC and any error that occurred. +func (s *Server) Generate(msg []byte, t *dns.TSIG) ([]byte, error) { + if err := s.close(false); err != nil { + return nil, err + } + + if dns.CanonicalName(t.Algorithm) != tsig.GSS { + return nil, dns.ErrKeyAlg + } + + s.m.RLock() + defer s.m.RUnlock() + + ctx, ok := s.ctx[t.Hdr.Name] + if !ok { // || !ctx.Established() { + return nil, dns.ErrSecret + } + + return s.generate(ctx, msg) +} + +// Verify verifies the TSIG MAC based on the established context. +// It is called with the bytes of the DNS message, and the TSIG record +// containing the algorithm, MAC, and name which is the negotiated TKEY +// for this context. +// It returns any error that occurred. +func (s *Server) Verify(stripped []byte, t *dns.TSIG) error { + if err := s.close(false); err != nil { + return err + } + + if dns.CanonicalName(t.Algorithm) != tsig.GSS { + return dns.ErrKeyAlg + } + + s.m.RLock() + defer s.m.RUnlock() + + ctx, ok := s.ctx[t.Hdr.Name] + if !ok { + return dns.ErrSecret + } + + mac, err := hex.DecodeString(t.MAC) + if err != nil { + return err + } + + return s.verify(ctx, stripped, mac) +} + +func extractTKEY(r *dns.Msg) *dns.TKEY { + if len(r.Question) != 1 || r.Question[0].Qtype != dns.TypeTKEY || r.Question[0].Qclass != dns.ClassANY { + return nil + } + + if len(r.Extra) != 1 { + return nil + } + + tkey, ok := r.Extra[0].(*dns.TKEY) + if !ok { + return nil + } + + if tkey.Hdr.Name != r.Question[0].Name || tkey.Hdr.Rrtype != dns.TypeTKEY || + tkey.Hdr.Class != dns.ClassANY || tkey.Hdr.Ttl != 0 { + return nil + } + + return tkey +} + +// ServeDNS satisfies the dns.Handler interface. It only handles queries for +// TKEY records and will refuse anything else. +// +//nolint:cyclop,funlen +func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + s.logger.Info("incoming", "request", r.String()) + + m := new(dns.Msg) + defer func() { + if err := w.WriteMsg(m); err != nil { + s.logger.Error(err, "write error") + } + }() + + if err := s.close(false); err != nil { + m.SetRcode(r, dns.RcodeServerFailure) + + return + } + + tkey := extractTKEY(r) + if tkey == nil { + m.SetRcode(r, dns.RcodeRefused) + + return + } + + keyname := tkey.Hdr.Name + + input, err := hex.DecodeString(tkey.Key) + if err != nil { + s.logger.Error(err, "unable to decode key") + m.SetRcode(r, dns.RcodeFormatError) + + return + } + + m.SetReply(r) + + rr := &dns.TKEY{ + Hdr: dns.RR_Header{ + Name: keyname, + Rrtype: dns.TypeTKEY, + Class: dns.ClassANY, + }, + Algorithm: tkey.Algorithm, + Mode: tkey.Mode, + Inception: tkey.Inception, + Expiration: tkey.Expiration, + KeySize: tkey.KeySize, + Key: tkey.Key, + } + m.Answer = append(m.Answer, rr) + + if tkey.Algorithm != tsig.GSS { + rr.Error = dns.RcodeBadAlg + + return + } + + s.m.Lock() + defer s.m.Unlock() + + ctx, ok := s.ctx[keyname] + + switch tkey.Mode { + case util.TkeyModeGSS: + var ( + established bool + expired bool + ) + + if ok { + if established, err = s.established(ctx); err != nil { + rr.Error = dns.RcodeServerFailure + + return + } + + if expired, err = s.expired(ctx); err != nil { + rr.Error = dns.RcodeServerFailure + + return + } + } + + switch { + case ok && established && !expired: + rr.Error = dns.RcodeBadName + + return + case ok && established && expired: + delete(s.ctx, keyname) + + fallthrough + case !ok: + if ctx, err = s.newContext(); err != nil { + s.logger.Error(err, "unable to create acceptor") + + rr.Error = dns.RcodeServerFailure + + return + } + } + + ctx, output, err := s.update(ctx, input) + if err != nil { + s.logger.Error(err, "unable to accept") + + rr.Error = dns.RcodeServerFailure + + return + } + + s.ctx[keyname] = ctx + + rr.KeySize = uint16(len(output)) + rr.Key = hex.EncodeToString(output) + + if established, err = s.established(ctx); err != nil { + rr.Error = dns.RcodeServerFailure + + return + } + + if established { + m.SetTsig(keyname, tsig.GSS, 300, time.Now().Unix()) + } + + s.logger.Info("outgoing", "response", m.String()) + case util.TkeyModeDelete: //nolint:wsl + /* + switch { + case !ok: + rr.Error = dns.RcodeBadName + case r.IsTsig() != nil && w.TsigStatus() == nil: + if err := s.delete(ctx); err != nil { + rr.Error = dns.RcodeServerFailure + + return + } + + delete(s.ctx, keyname) + default: + rr.Error = dns.RcodeNotAuth + } + */ + + fallthrough + default: + rr.Error = dns.RcodeBadMode + } +} diff --git a/gss/sspi.go b/gss/sspi.go index d9d0dc8..0f727ed 100644 --- a/gss/sspi.go +++ b/gss/sspi.go @@ -28,16 +28,14 @@ type Client struct { } // WithConfig sets the Kerberos configuration used. -func WithConfig(_ string) func(*Client) error { - return func(c *Client) error { - return errNotSupported - } +func WithConfig[T Client](_ string) Option[T] { + return unsupportedOption[T] } // NewClient performs any library initialization necessary. // It returns a context handle for any further functions along with any error // that occurred. -func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, error) { +func NewClient(dnsClient *dns.Client, options ...Option[Client]) (*Client, error) { client, err := util.CopyDNSClient(dnsClient) if err != nil { return nil, err @@ -51,8 +49,10 @@ func NewClient(dnsClient *dns.Client, options ...func(*Client) error) (*Client, logger: logr.Discard(), } - if err := c.setOption(options...); err != nil { - return nil, err + for _, option := range options { + if err := option(c); err != nil { + return nil, err + } } return c, nil