Skip to content

Commit 0d08b50

Browse files
committed
Use common interface to fetch secrets in HTTP client config
Signed-off-by: Daniel Hrabovcak <thespiritxiii@gmail.com>
1 parent 1d8c672 commit 0d08b50

File tree

2 files changed

+132
-151
lines changed

2 files changed

+132
-151
lines changed

config/http_config.go

Lines changed: 105 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"crypto/tls"
2121
"crypto/x509"
2222
"encoding/json"
23+
"errors"
2324
"fmt"
2425
"net"
2526
"net/http"
@@ -29,6 +30,7 @@ import (
2930
"strings"
3031
"sync"
3132
"time"
33+
"unsafe"
3234

3335
"github.com/mwitkow/go-conntrack"
3436
"golang.org/x/net/http/httpproxy"
@@ -546,21 +548,17 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
546548

547549
// If a authorization_credentials is provided, create a round tripper that will set the
548550
// Authorization header correctly on each request.
549-
if cfg.Authorization != nil && len(cfg.Authorization.Credentials) > 0 {
550-
rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, cfg.Authorization.Credentials, rt)
551-
} else if cfg.Authorization != nil && len(cfg.Authorization.CredentialsFile) > 0 {
552-
rt = NewAuthorizationCredentialsFileRoundTripper(cfg.Authorization.Type, cfg.Authorization.CredentialsFile, rt)
551+
if cfg.Authorization != nil && (len(cfg.Authorization.Credentials) > 0 || len(cfg.Authorization.CredentialsFile) > 0) {
552+
rt = NewAuthorizationCredentialsRoundTripper(cfg.Authorization.Type, secretFrom(cfg.Authorization.Credentials, cfg.Authorization.CredentialsFile), rt)
553553
}
554554
// Backwards compatibility, be nice with importers who would not have
555555
// called Validate().
556-
if len(cfg.BearerToken) > 0 {
557-
rt = NewAuthorizationCredentialsRoundTripper("Bearer", cfg.BearerToken, rt)
558-
} else if len(cfg.BearerTokenFile) > 0 {
559-
rt = NewAuthorizationCredentialsFileRoundTripper("Bearer", cfg.BearerTokenFile, rt)
556+
if len(cfg.BearerToken) > 0 || len(cfg.BearerTokenFile) > 0 {
557+
rt = NewAuthorizationCredentialsRoundTripper("Bearer", secretFrom(cfg.BearerToken, cfg.BearerTokenFile), rt)
560558
}
561559

562560
if cfg.BasicAuth != nil {
563-
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.UsernameFile, cfg.BasicAuth.PasswordFile, rt)
561+
rt = NewBasicAuthRoundTripper(secretFrom(Secret(cfg.BasicAuth.Username), cfg.BasicAuth.UsernameFile), secretFrom(cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile), rt)
564562
}
565563

566564
if cfg.OAuth2 != nil {
@@ -587,52 +585,67 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
587585
return NewTLSRoundTripper(tlsConfig, cfg.TLSConfig.roundTripperSettings(), newRT)
588586
}
589587

590-
type authorizationCredentialsRoundTripper struct {
591-
authType string
592-
authCredentials Secret
593-
rt http.RoundTripper
588+
type secret interface {
589+
fetch() (string, error)
594590
}
595591

596-
// NewAuthorizationCredentialsRoundTripper adds the provided credentials to a
597-
// request unless the authorization header has already been set.
598-
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials Secret, rt http.RoundTripper) http.RoundTripper {
599-
return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
592+
type inlineSecret struct {
593+
text string
600594
}
601595

602-
func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
603-
if len(req.Header.Get("Authorization")) == 0 {
604-
req = cloneRequest(req)
605-
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, string(rt.authCredentials)))
596+
func (s *inlineSecret) fetch() (string, error) {
597+
return s.text, nil
598+
}
599+
600+
type fileSecret struct {
601+
file string
602+
}
603+
604+
func (s *fileSecret) fetch() (string, error) {
605+
fileBytes, err := os.ReadFile(s.file)
606+
if err != nil {
607+
return "", fmt.Errorf("unable to read file %s: %w", s.file, err)
606608
}
607-
return rt.rt.RoundTrip(req)
609+
return strings.TrimSpace(string(fileBytes)), nil
608610
}
609611

610-
func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
611-
if ci, ok := rt.rt.(closeIdler); ok {
612-
ci.CloseIdleConnections()
612+
func secretFrom(text Secret, file string) secret {
613+
if text != "" {
614+
return &inlineSecret{
615+
text: string(text),
616+
}
617+
}
618+
if file != "" {
619+
return &fileSecret{
620+
file: file,
621+
}
613622
}
623+
return nil
614624
}
615625

616-
type authorizationCredentialsFileRoundTripper struct {
617-
authType string
618-
authCredentialsFile string
619-
rt http.RoundTripper
626+
type authorizationCredentialsRoundTripper struct {
627+
authType string
628+
authCredentials secret
629+
rt http.RoundTripper
620630
}
621631

622-
// NewAuthorizationCredentialsFileRoundTripper adds the authorization
623-
// credentials read from the provided file to a request unless the authorization
624-
// header has already been set. This file is read for every request.
625-
func NewAuthorizationCredentialsFileRoundTripper(authType, authCredentialsFile string, rt http.RoundTripper) http.RoundTripper {
626-
return &authorizationCredentialsFileRoundTripper{authType, authCredentialsFile, rt}
632+
// NewAuthorizationCredentialsRoundTripper adds the authorization credentials
633+
// read from the provided secret to a request unless the authorization header
634+
// has already been set.
635+
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials secret, rt http.RoundTripper) http.RoundTripper {
636+
return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
627637
}
628638

629-
func (rt *authorizationCredentialsFileRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
639+
func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
630640
if len(req.Header.Get("Authorization")) == 0 {
631-
b, err := os.ReadFile(rt.authCredentialsFile)
632-
if err != nil {
633-
return nil, fmt.Errorf("unable to read authorization credentials file %s: %s", rt.authCredentialsFile, err)
641+
var authCredentials string
642+
if rt.authCredentials != nil {
643+
var err error
644+
authCredentials, err = rt.authCredentials.fetch()
645+
if err != nil {
646+
return nil, fmt.Errorf("unable to get authorization credentials: %w", err)
647+
}
634648
}
635-
authCredentials := strings.TrimSpace(string(b))
636649

637650
req = cloneRequest(req)
638651
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))
@@ -641,49 +654,43 @@ func (rt *authorizationCredentialsFileRoundTripper) RoundTrip(req *http.Request)
641654
return rt.rt.RoundTrip(req)
642655
}
643656

644-
func (rt *authorizationCredentialsFileRoundTripper) CloseIdleConnections() {
657+
func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
645658
if ci, ok := rt.rt.(closeIdler); ok {
646659
ci.CloseIdleConnections()
647660
}
648661
}
649662

650663
type basicAuthRoundTripper struct {
651-
username string
652-
password Secret
653-
usernameFile string
654-
passwordFile string
655-
rt http.RoundTripper
664+
username secret
665+
password secret
666+
rt http.RoundTripper
656667
}
657668

658669
// NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has
659670
// already been set.
660-
func NewBasicAuthRoundTripper(username string, password Secret, usernameFile, passwordFile string, rt http.RoundTripper) http.RoundTripper {
661-
return &basicAuthRoundTripper{username, password, usernameFile, passwordFile, rt}
671+
func NewBasicAuthRoundTripper(username secret, password secret, rt http.RoundTripper) http.RoundTripper {
672+
return &basicAuthRoundTripper{username, password, rt}
662673
}
663674

664675
func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
665-
var username string
666-
var password string
667676
if len(req.Header.Get("Authorization")) != 0 {
668677
return rt.rt.RoundTrip(req)
669678
}
670-
if rt.usernameFile != "" {
671-
usernameBytes, err := os.ReadFile(rt.usernameFile)
679+
var username string
680+
var password string
681+
if rt.username != nil {
682+
var err error
683+
username, err = rt.username.fetch()
672684
if err != nil {
673-
return nil, fmt.Errorf("unable to read basic auth username file %s: %s", rt.usernameFile, err)
685+
return nil, fmt.Errorf("unable to get basic auth username: %w", err)
674686
}
675-
username = strings.TrimSpace(string(usernameBytes))
676-
} else {
677-
username = rt.username
678687
}
679-
if rt.passwordFile != "" {
680-
passwordBytes, err := os.ReadFile(rt.passwordFile)
688+
if rt.password != nil {
689+
var err error
690+
password, err = rt.password.fetch()
681691
if err != nil {
682-
return nil, fmt.Errorf("unable to read basic auth password file %s: %s", rt.passwordFile, err)
692+
return nil, fmt.Errorf("unable to get basic auth password: %w", err)
683693
}
684-
password = strings.TrimSpace(string(passwordBytes))
685-
} else {
686-
password = string(rt.password)
687694
}
688695
req = cloneRequest(req)
689696
req.SetBasicAuth(username, password)
@@ -697,20 +704,22 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() {
697704
}
698705

699706
type oauth2RoundTripper struct {
700-
config *OAuth2
701-
rt http.RoundTripper
702-
next http.RoundTripper
703-
secret string
704-
mtx sync.RWMutex
705-
opts *httpClientOptions
706-
client *http.Client
707+
config *OAuth2
708+
clientSecret secret
709+
rt http.RoundTripper
710+
next http.RoundTripper
711+
secret string
712+
mtx sync.RWMutex
713+
opts *httpClientOptions
714+
client *http.Client
707715
}
708716

709717
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
710718
return &oauth2RoundTripper{
711-
config: config,
712-
next: next,
713-
opts: opts,
719+
config: config,
720+
clientSecret: secretFrom(config.ClientSecret, config.ClientSecretFile),
721+
next: next,
722+
opts: opts,
714723
}
715724
}
716725

@@ -720,22 +729,18 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
720729
changed bool
721730
)
722731

723-
if rt.config.ClientSecretFile != "" {
724-
data, err := os.ReadFile(rt.config.ClientSecretFile)
732+
if rt.clientSecret != nil {
733+
var err error
734+
secret, err = rt.clientSecret.fetch()
725735
if err != nil {
726-
return nil, fmt.Errorf("unable to read oauth2 client secret file %s: %s", rt.config.ClientSecretFile, err)
736+
return nil, fmt.Errorf("unable to get oauth2 client secret: %w", err)
727737
}
728-
secret = strings.TrimSpace(string(data))
729-
rt.mtx.RLock()
730-
changed = secret != rt.secret
731-
rt.mtx.RUnlock()
732738
}
739+
rt.mtx.RLock()
740+
changed = secret != rt.secret
741+
rt.mtx.RUnlock()
733742

734743
if changed || rt.rt == nil {
735-
if rt.config.ClientSecret != "" {
736-
secret = string(rt.config.ClientSecret)
737-
}
738-
739744
config := &clientcredentials.Config{
740745
ClientID: rt.config.ClientID,
741746
ClientSecret: secret,
@@ -852,17 +857,14 @@ func NewTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
852857

853858
// If a CA cert is provided then let's read it in so we can validate the
854859
// scrape target's certificate properly.
855-
if len(cfg.CA) > 0 {
856-
if !updateRootCA(tlsConfig, []byte(cfg.CA)) {
857-
return nil, fmt.Errorf("unable to use inline CA cert")
858-
}
859-
} else if len(cfg.CAFile) > 0 {
860-
b, err := readCAFile(cfg.CAFile)
860+
caSecret := secretFrom(Secret(cfg.CA), cfg.CAFile)
861+
if caSecret != nil {
862+
ca, err := caSecret.fetch()
861863
if err != nil {
862-
return nil, err
864+
return nil, fmt.Errorf("unable to get CA cert: %w", err)
863865
}
864-
if !updateRootCA(tlsConfig, b) {
865-
return nil, fmt.Errorf("unable to use specified CA cert %s", cfg.CAFile)
866+
if !updateRootCA(tlsConfig, []byte(ca)) {
867+
return nil, errors.New("unable to use CA cert")
866868
}
867869
}
868870

@@ -970,45 +972,36 @@ func (c *TLSConfig) roundTripperSettings() TLSRoundTripperSettings {
970972
// getClientCertificate reads the pair of client cert and key from disk and returns a tls.Certificate.
971973
func (c *TLSConfig) getClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
972974
var (
973-
certData, keyData []byte
975+
certData, keyData string
974976
err error
975977
)
976978

977-
if c.CertFile != "" {
978-
certData, err = os.ReadFile(c.CertFile)
979+
certSecret := secretFrom(Secret(c.Cert), c.CertFile)
980+
if certSecret != nil {
981+
certData, err = certSecret.fetch()
979982
if err != nil {
980-
return nil, fmt.Errorf("unable to read specified client cert (%s): %s", c.CertFile, err)
983+
return nil, fmt.Errorf("unable to get client cert: %w", err)
981984
}
982-
} else {
983-
certData = []byte(c.Cert)
984985
}
985986

986-
if c.KeyFile != "" {
987-
keyData, err = os.ReadFile(c.KeyFile)
987+
keySecret := secretFrom(Secret(c.Key), c.KeyFile)
988+
if keySecret != nil {
989+
keyData, err = keySecret.fetch()
988990
if err != nil {
989-
return nil, fmt.Errorf("unable to read specified client key (%s): %s", c.KeyFile, err)
991+
return nil, fmt.Errorf("unable to get client key: %w", err)
990992
}
991-
} else {
992-
keyData = []byte(c.Key)
993993
}
994994

995-
cert, err := tls.X509KeyPair(certData, keyData)
995+
certStr := unsafe.Slice(unsafe.StringData(certData), len(certData))
996+
keyStr := unsafe.Slice(unsafe.StringData(keyData), len(keyData))
997+
cert, err := tls.X509KeyPair(certStr, keyStr)
996998
if err != nil {
997999
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %s", c.CertFile, c.KeyFile, err)
9981000
}
9991001

10001002
return &cert, nil
10011003
}
10021004

1003-
// readCAFile reads the CA cert file from disk.
1004-
func readCAFile(f string) ([]byte, error) {
1005-
data, err := os.ReadFile(f)
1006-
if err != nil {
1007-
return nil, fmt.Errorf("unable to load specified CA cert %s: %s", f, err)
1008-
}
1009-
return data, nil
1010-
}
1011-
10121005
// updateRootCA parses the given byte slice as a series of PEM encoded certificates and updates tls.Config.RootCAs.
10131006
func updateRootCA(cfg *tls.Config, b []byte) bool {
10141007
caCertPool := x509.NewCertPool()

0 commit comments

Comments
 (0)