@@ -19,11 +19,13 @@ import (
19
19
// Returns a function that must be invoked by the caller to wait for the SAML
20
20
// response and the server to shutdown.
21
21
func Start (ctx context.Context , listen , url string ) (string , func () (string , error )) { //nolint:cyclop,funlen
22
- // Channels for asynchronously communicating the SAML response string or
23
- // any errors that are encountered.
24
- responseChan := make (chan string )
22
+ // Channels for asynchronously communicating any error that is encountered.
25
23
errChan := make (chan error )
26
24
25
+ // The SAML response body that is received by the callback handler and
26
+ // passed on to AWS.
27
+ var samlResponse string
28
+
27
29
// sendError is a helper function for sending errors to the error channel
28
30
// in a non-blocking fashion.
29
31
sendError := func (err error ) {
@@ -71,13 +73,32 @@ func Start(ctx context.Context, listen, url string) (string, func() (string, err
71
73
}
72
74
73
75
go func () {
74
- // Write the SAML response string to our channel .
75
- responseChan <- request .FormValue ("SAMLResponse" )
76
+ // Read the SAML response string.
77
+ samlResponse = request .FormValue ("SAMLResponse" )
76
78
}()
77
79
78
80
http .Redirect (writer , request , "/" , http .StatusFound )
79
81
})
80
82
83
+ // The /response route serves the formatted SAML assertion.
84
+ mux .HandleFunc ("/response" , func (writer http.ResponseWriter , request * http.Request ) {
85
+ if request .Method != http .MethodGet {
86
+ http .Error (writer , "method not allowed" , http .StatusMethodNotAllowed )
87
+
88
+ return
89
+ }
90
+
91
+ if len (samlResponse ) == 0 {
92
+ http .NotFound (writer , request )
93
+
94
+ return
95
+ }
96
+
97
+ if err := formatSAMLResponse (samlResponse , writer ); err != nil {
98
+ http .Error (writer , err .Error (), http .StatusInternalServerError )
99
+ }
100
+ })
101
+
81
102
// The /callback route is called by the user to terminate this server.
82
103
mux .HandleFunc ("/shutdown" , func (writer http.ResponseWriter , request * http.Request ) {
83
104
if request .Method != http .MethodPost {
@@ -114,12 +135,6 @@ func Start(ctx context.Context, listen, url string) (string, func() (string, err
114
135
return fmt .Sprintf ("http://%s/login" , listen ), func () (string , error ) {
115
136
defer server .Shutdown (ctx ) //nolint:errcheck
116
137
117
- var samlResponse string
118
- // Wait for the SAML response.
119
- go func () {
120
- samlResponse = <- responseChan
121
- }()
122
-
123
138
// Wait for an error.
124
139
err := <- errChan
125
140
switch err {
0 commit comments