Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sso_proxy: reduce direct calls to ValidateGroup() and clean up logic #275

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ func TestSignIn(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
auth, err := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail, nil)}),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setMockRedirectURL(),
Expand Down Expand Up @@ -565,7 +565,7 @@ func TestSignOutPage(t *testing.T) {
provider.RevokeError = tc.RevokeError

p, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]options.Validator{options.NewMockValidator(true, nil)}),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setTestProvider(provider),
Expand Down Expand Up @@ -942,7 +942,7 @@ func TestGetProfile(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
p, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]options.Validator{options.NewMockValidator(true, nil)}),
)
u, _ := url.Parse("http://example.com")
testProvider := providers.NewTestProvider(u)
Expand Down Expand Up @@ -1044,7 +1044,7 @@ func TestRedeemCode(t *testing.T) {
config := testConfiguration(t)

proxy, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]options.Validator{options.NewMockValidator(true, nil)}),
)

testURL, err := url.Parse("example.com")
Expand Down Expand Up @@ -1433,7 +1433,7 @@ func TestOAuthCallback(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
proxy, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail, nil)}),
setMockCSRFStore(tc.csrfResp),
setMockSessionStore(tc.sessionStore),
)
Expand Down Expand Up @@ -1554,7 +1554,7 @@ func TestOAuthStart(t *testing.T) {
provider := providers.NewTestProvider(nil)
proxy, _ := NewAuthenticator(config,
setTestProvider(provider),
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]options.Validator{options.NewMockValidator(true, nil)}),
setMockRedirectURL(),
setMockCSRFStore(&sessions.MockCSRFStore{}),
)
Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/options/email_group_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (v EmailGroupValidator) Validate(session *sessions.SessionState) error {
func (v EmailGroupValidator) validate(session *sessions.SessionState) error {
matchedGroups, valid, err := v.Provider.ValidateGroup(session.Email, v.AllowedGroups, session.AccessToken)
if err != nil {
return ErrValidationError
return err
}

if valid {
Expand Down
10 changes: 9 additions & 1 deletion internal/pkg/options/mock_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,26 @@ var (

type MockValidator struct {
Result bool
Err error
}

func NewMockValidator(result bool) MockValidator {
func NewMockValidator(result bool, err error) MockValidator {
return MockValidator{
Result: result,
Err: err,
}
}

func (v MockValidator) Validate(session *sessions.SessionState) error {
// if we pass in a specific error, return it
if v.Err != nil {
return v.Err
}
// if result is true, return nil
if v.Result {
return nil
}

// otherwise, return generic mock validator error
return errors.New("MockValidator error")
}
1 change: 0 additions & 1 deletion internal/pkg/options/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ var (
// These error message should be formatted in such a way that is appropriate
// for display to the end user.
ErrInvalidEmailAddress = errors.New("Invalid Email Address In Session State")
ErrValidationError = errors.New("Error during validation")
)

type Validator interface {
Expand Down
8 changes: 8 additions & 0 deletions internal/pkg/sessions/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}

// IsWithinGracePeriod returns true if the session is still within the grace period
func (s *SessionState) IsWithinGracePeriod(gracePeriodTTL time.Duration) bool {
if s.GracePeriodStart.IsZero() {
s.GracePeriodStart = time.Now()
}
return s.GracePeriodStart.Add(gracePeriodTTL).After(time.Now())
}

func isExpired(t time.Time) bool {
if t.Before(time.Now()) {
return true
Expand Down
15 changes: 12 additions & 3 deletions internal/pkg/sessions/session_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,29 @@ func TestSessionStateExpirations(t *testing.T) {
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
GracePeriodStart: time.Now().Add(-2 * time.Minute),

Email: "[email protected]",
User: "user",
}

if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}

if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}

if !session.ValidationPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}

if session.IsWithinGracePeriod(1 * time.Minute) {
t.Errorf("expected session to be outside of grace period")
}

if !session.IsWithinGracePeriod(3 * time.Minute) {
t.Errorf("expected session to be inside of grace period")
}
}
85 changes: 57 additions & 28 deletions internal/proxy/oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i
p.templates.ExecuteTemplate(rw, "error.html", t)
}

// IsWhitelistedRequest cheks that proxy host exists and checks the SkipAuthRegex
// IsWhitelistedRequest checks that proxy host exists and checks the SkipAuthRegex
func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) bool {
if p.skipAuthPreflight && req.Method == "OPTIONS" {
return true
Expand All @@ -375,6 +375,27 @@ func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) bool {
return false
}

// runValidatorsWithGracePeriod runs all validators and upon finding errors, checks to see if the
// auth provider is explicity denying authentication or if it's merely unavailable. If it's unavailable,
// we check whether the session is within the grace period or not to determine the specific error we return.
func (p *OAuthProxy) runValidatorsWithGracePeriod(session *sessions.SessionState) (err error) {
logger := log.NewLogEntry()
errors := options.RunValidators(p.Validators, session)
if len(errors) == len(p.Validators) {
for _, err := range errors {
// Check to see if the auth provider is explicity denying authentication, or if it is merely unavailable.
if err == providers.ErrAuthProviderUnavailable && session.IsWithinGracePeriod(p.provider.Data().GracePeriodTTL) {
return err
}
}
allowedGroups := p.upstreamConfig.AllowedGroups
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put the allowed groups in the validator error message instead of pulling them out this way? I think it could make the logline easier to understand which set of validators rejected a user?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the allowed groups into the group validator error message. Below is an example of the error message displayed to users (the same formatting is used for the log lines).

Screen Shot 2020-01-28 at 12 11 50

Also removed a chunk of formatting logic around the errors as it was becoming over-engineered and unnecessary. We could also add some extra context to errors coming from the domain/email address validators, but as is the error message would become pretty bloated if multiple validators returned errors - so perhaps worth addressing that separately?

logger.WithUser(session.Email).WithAllowedGroups(allowedGroups).Error(errors,
"no longer authorized after validation period")
return ErrUserNotAuthorized
}
return nil
}

func (p *OAuthProxy) isXHR(req *http.Request) bool {
return req.Header.Get("X-Requested-With") == "XMLHttpRequest"
}
Expand Down Expand Up @@ -693,8 +714,6 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
remoteAddr := getRemoteAddr(req)
tags := []string{"action:authenticate"}

allowedGroups := p.upstreamConfig.AllowedGroups

// Clear the session cookie if anything goes wrong.
defer func() {
if err != nil {
Expand All @@ -705,7 +724,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
session, err := p.sessionStore.LoadSession(req)
if err != nil {
// We loaded a cookie but it wasn't valid, clear it, and reject the request
logger.Error(err, "error authenticating user")
logger.Error(err, "invalid session loaded")
return err
}

Expand All @@ -728,14 +747,17 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
} else if session.RefreshPeriodExpired() {
// Refresh period is the period in which the access token is valid. This is ultimately
// controlled by the upstream provider and tends to be around 1 hour.
ok, err := p.provider.RefreshSession(session, allowedGroups)
// If it has expired we:
// - attempt to refresh the session
// - run email domain, email address, and email group validations against the session (if defined).

ok, err := p.provider.RefreshSession(session)
// We failed to refresh the session successfully
// clear the cookie and reject the request
if err != nil {
logger.WithUser(session.Email).Error(err, "refreshing session failed")
return err
}

if !ok {
// User is not authorized after refresh
// clear the cookie and reject the request
Expand All @@ -744,6 +766,18 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
return ErrUserNotAuthorized
}

err = p.runValidatorsWithGracePeriod(session)
if err != nil {
switch err {
case providers.ErrAuthProviderUnavailable:
tags = append(tags, "action:refresh_session", "error:validation_failed")
p.StatsdClient.Incr("provider_error_fallback", tags, 1.0)
session.RefreshDeadline = sessions.ExtendDeadline(p.provider.Data().SessionValidTTL)
default:
return ErrUserNotAuthorized
}
}

err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
// We refreshed the session successfully, but failed to save it.
Expand All @@ -757,9 +791,11 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
} else if session.ValidationPeriodExpired() {
// Validation period has expired, this is the shortest interval we use to
// check for valid requests. This should be set to something like a minute.
// This calls up the provider chain to validate this user is still active
// and hasn't been de-authorized.
ok := p.provider.ValidateSessionState(session, allowedGroups)
// In this case we:
// - call up the provider chain to validate this user is still active and hasn't been de-authorized.
// - run any defined email domain, email address, and email group validators against the session

ok := p.provider.ValidateSessionToken(session)
if !ok {
// This user is now no longer authorized, or we failed to
// validate the user.
Expand All @@ -769,6 +805,18 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
return ErrUserNotAuthorized
}

err = p.runValidatorsWithGracePeriod(session)
if err != nil {
switch err {
case providers.ErrAuthProviderUnavailable:
tags = append(tags, "action:validate_session", "error:validation_failed")
p.StatsdClient.Incr("provider_error_fallback", tags, 1.0)
session.ValidDeadline = sessions.ExtendDeadline(p.provider.Data().SessionValidTTL)
default:
return ErrUserNotAuthorized
}
}

err = p.sessionStore.SaveSession(rw, req, session)
if err != nil {
// We validated the session successfully, but failed to save it.
Expand All @@ -781,25 +829,6 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er
}
}

// We revalidate group membership whenever the session is refreshed or revalidated
// just above in the call to ValidateSessionState and RefreshSession.
// To reduce strain on upstream identity providers we only revalidate email domains and
// addresses on each request here.
for _, v := range p.Validators {
_, EmailGroupValidator := v.(options.EmailGroupValidator)

if !EmailGroupValidator {
err := v.Validate(session)
if err != nil {
tags = append(tags, "error:validation_failed")
p.StatsdClient.Incr("application_error", tags, 1.0)
logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
fmt.Sprintf("permission denied: unauthorized: %q", err))
return ErrUserNotAuthorized
}
}
}

logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info(
fmt.Sprintf("authentication: user validated"))

Expand Down
Loading