-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathhandler.http.mw.go
152 lines (131 loc) · 4.24 KB
/
handler.http.mw.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package main
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"strings"

	jwtgo "github.com/dgrijalva/jwt-go"
)
// CtxKeyRetries is the request-context key under which checkRetries stores
// the retry configuration (an hfsmws value) for handlers further down the chain.
const CtxKeyRetries = ctxKey("_retry_")
// checkRetries returns middleware that attaches v to the request context under
// CtxKeyRetries and then calls the next handler. The value is attached
// unconditionally on every request that passes through.
func checkRetries(v hfsmws) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			withRetries := context.WithValue(r.Context(), CtxKeyRetries, v)
			next.ServeHTTP(w, r.WithContext(withRetries))
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}
// checkBasicAuth is middleware that performs an HTTP Basic Auth check against
// config.BasicAuth. Any failure is returned as a 401 error (Ext401Error or
// ErrDecodeBase64.F401) to WriteError, and the next handler is not called.
//
// When config.BasicAuth.Relm is set, a WWW-Authenticate challenge is added to
// every failed response so that clients know credentials are required.
// Credentials are compared in constant time to avoid leaking, via response
// timing, how much of the username/password matched.
// The notfound handler is currently unused.
func checkBasicAuth(config ConfigHTTP, notfound http.HandlerFunc) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(WriteError(func(w http.ResponseWriter, r *http.Request) error {
			// unauthorized sets the realm challenge (if configured) before the error
			// is handed to WriteError, so the header is sent with the 401 response.
			unauthorized := func(err error) error {
				if realm := config.BasicAuth.Relm; realm != "" {
					w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Basic realm="%s"`, realm))
				}
				return err
			}

			authStrs := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
			if len(authStrs) != 2 {
				return unauthorized(Ext401Error{fmt.Errorf("auth header is not two parts")})
			}
			// The auth scheme token is case-insensitive (RFC 7617).
			if !strings.EqualFold(authStrs[0], "Basic") {
				return unauthorized(Ext401Error{fmt.Errorf("auth scheme is not basic")})
			}
			b, err := base64.StdEncoding.DecodeString(authStrs[1])
			if err != nil {
				return unauthorized(ErrDecodeBase64.F401(err))
			}
			userpass := strings.SplitN(string(b), ":", 2)
			if len(userpass) != 2 {
				return unauthorized(Ext401Error{fmt.Errorf("username/password is not two parts")})
			}
			// Evaluate both comparisons unconditionally (no short-circuit) so the
			// time taken does not reveal which of the two was wrong.
			userOK := subtle.ConstantTimeCompare([]byte(userpass[0]), []byte(config.BasicAuth.User))
			passOK := subtle.ConstantTimeCompare([]byte(userpass[1]), []byte(config.BasicAuth.Pass))
			if userOK&passOK != 1 {
				return unauthorized(Ext401Error{fmt.Errorf("bad username/password")})
			}
			next.ServeHTTP(w, r)
			return nil
		}))
	}
}
// checkRequestJWT is middleware that checks an incoming JWT auth against values that it should contain.
//
// The token is decoded with decodeJWT. A WarnError from decoding is tolerated
// and the returned token is still checked; any other decode error is returned
// wrapped by ErrMarshalJWT.F. Every key configured in req.JWT.KeyVals must be
// present in the token's map claims as a string equal to the configured value;
// a missing, non-string or different claim yields ErrInvalidJWTClaim. On
// success the token is stored in the request context under CtxKeyJWTToken and
// the next handler is called. The notfound handler is currently unused.
func checkRequestJWT(req RequestHTTP, notfound http.HandlerFunc) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return WriteError(func(w http.ResponseWriter, r *http.Request) error {
			token, err := decodeJWT(w, r, req.JWT)
			if err != nil && !errors.As(err, &WarnError{}) {
				return ErrMarshalJWT.F(err)
			}

			if len(req.JWT.KeyVals) > 0 {
				// Iterate the required keys, not the token's claims, so a token that
				// simply omits a required claim cannot bypass the check.
				claims, ok := token.Claims.(jwtgo.MapClaims)
				if !ok {
					return ErrInvalidJWTClaim
				}
				for k, reqv := range req.JWT.KeyVals {
					got, ok := claims[k].(string) // false when the claim is missing or not a string
					if !ok {
						return ErrInvalidJWTClaim
					}
					// NOTE(review): the second result of Value (evaluation diagnostics)
					// is ignored here, as in the original code.
					want, _ := reqv.Expr.Value(nil)
					if got != want.AsString() {
						return ErrInvalidJWTClaim
					}
				}
			}

			ctx := context.WithValue(r.Context(), CtxKeyJWTToken, token)
			next.ServeHTTP(w, r.WithContext(ctx))
			return nil
		})
	}
}
// checkRequestHeader is middleware that checks incoming header values against
// the values configured in req.Headers.Data.
//
// For every configured header, the request must carry exactly as many values
// as are configured. Each configured value must also match a distinct incoming
// value, where "*" matches any single value and order is ignored. A mismatch
// is returned as ErrFilterFailed.F404. The next handler is called exactly once,
// and only after every configured header has passed. The _nf handler is
// currently unused.
func checkRequestHeader(req RequestHTTP, _nf http.HandlerFunc) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return WriteError(func(w http.ResponseWriter, r *http.Request) error {
			for k, vals := range req.Headers.Data {
				values := r.Header.Values(k)
				if len(vals) != len(values) {
					return ErrFilterFailed.F404("header", "unequal lengths")
				}
				want := make([]string, len(vals))
				for i, val := range vals {
					want[i] = val.AsString()
				}
				if !headerValuesMatch(want, values) {
					return ErrFilterFailed.F404("header", "did not find a value")
				}
			}
			// Serve only once all headers have been checked, so a later failing
			// header cannot follow an already-written response.
			next.ServeHTTP(w, r)
			return nil
		})
	}
}

// headerValuesMatch reports whether every non-"*" value in want can be paired
// with a distinct, equal value in got. Callers ensure len(want) == len(got), so
// the "*" entries then cover exactly the remaining unpaired values in got.
func headerValuesMatch(want, got []string) bool {
	remaining := make(map[string]int, len(got))
	for _, v := range got {
		remaining[v]++
	}
	for _, v := range want {
		if v == "*" {
			continue
		}
		if remaining[v] == 0 {
			return false
		}
		remaining[v]--
	}
	return true
}
// checkRequestPost is middleware that compares the request's posted form values
// against req.Posted. A configured value of "*" accepts any posted value for
// that key. If the form cannot be parsed, the error is returned wrapped by
// ErrParseForm.F. On the first mismatching key, notfound handles the request
// and the next handler is not called. Otherwise the next handler is called.
func checkRequestPost(req RequestHTTP, notfound http.HandlerFunc) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return WriteError(func(w http.ResponseWriter, r *http.Request) error {
			if parseErr := r.ParseForm(); parseErr != nil {
				return ErrParseForm.F(parseErr)
			}
			for key, want := range req.Posted {
				// Short-circuit keeps wildcard keys from touching the form at all.
				if want != "*" && r.PostFormValue(key) != want {
					notfound(w, r)
					return nil
				}
			}
			next.ServeHTTP(w, r)
			return nil
		})
	}
}