Skip to content

Commit 026a460

Browse files
committed
feat: add backend for oidc consent
1 parent abb47a8 commit 026a460

17 files changed

Lines changed: 262 additions & 64 deletions

internal/bootstrap/app_bootstrap.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type Services struct {
4747
type BootstrapApp struct {
4848
config model.Config
4949
runtime model.RuntimeConfig
50+
helpers model.RuntimeHelpers
5051
services Services
5152
log *logger.Logger
5253
ctx context.Context
@@ -185,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
185186
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough
186187

187188
app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
188-
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
189-
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
190189
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
190+
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)
191191

192192
// database
193193
store, err := app.SetupStore()
@@ -264,6 +264,9 @@ func (app *BootstrapApp) Setup() error {
264264
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
265265
}
266266

267+
// runtime helpers
268+
app.helpers.GetCookieDomain = app.getCookieDomain
269+
267270
// setup router
268271
err = app.setupRouter()
269272

internal/bootstrap/app_helpers.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package bootstrap
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
8+
"github.com/tinyauthapp/tinyauth/internal/utils"
9+
)
10+
11+
// Not really the best place for the helpers to be but it works because bootstrap app provides
12+
// them with everything they need
13+
14+
func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
15+
cookieDomain := app.runtime.CookieDomain
16+
17+
if app.isTailscaleRequest(ctx, ip) {
18+
if app.services.tailscaleService == nil {
19+
return "", errors.New("tailscale service is not configured")
20+
}
21+
22+
tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))
23+
24+
if err != nil {
25+
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
26+
}
27+
28+
cookieDomain = tsCookieDomain
29+
}
30+
31+
if app.config.Auth.SubdomainsEnabled {
32+
cookieDomain = "." + cookieDomain
33+
}
34+
35+
return cookieDomain, nil
36+
}
37+
38+
func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
39+
if app.services.tailscaleService == nil {
40+
return false
41+
}
42+
43+
whois, err := app.services.tailscaleService.Whois(ctx, ip)
44+
45+
if err != nil {
46+
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
47+
return false
48+
}
49+
50+
if whois == nil {
51+
return false
52+
}
53+
54+
return true
55+
}

internal/bootstrap/router_bootstrap.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ func (app *BootstrapApp) setupRouter() error {
5858
apiRouter := engine.Group("/api")
5959

6060
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
61-
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
62-
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter, &engine.RouterGroup)
61+
controller.NewOAuthController(app.log, app.config, app.runtime, app.helpers, apiRouter, app.services.authService)
62+
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, app.helpers, app.config, apiRouter, &engine.RouterGroup)
6363
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService, app.services.policyEngine)
6464
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
6565
controller.NewResourcesController(app.config, &engine.RouterGroup)

internal/bootstrap/service_bootstrap.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (app *BootstrapApp) setupServices() error {
4242
oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx)
4343
app.services.oauthBrokerService = oauthBrokerService
4444

45-
authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
45+
authService := service.NewAuthService(app.log, app.config, app.runtime, app.helpers, app.ctx, app.ding, app.services.ldapService, app.queries, app.services.oauthBrokerService, app.services.tailscaleService, app.services.policyEngine)
4646
app.services.authService = authService
4747

4848
oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ding)

internal/controller/oauth_controller.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,23 @@ type OAuthController struct {
2424
log *logger.Logger
2525
config model.Config
2626
runtime model.RuntimeConfig
27+
helpers model.RuntimeHelpers
2728
auth *service.AuthService
2829
}
2930

3031
func NewOAuthController(
3132
log *logger.Logger,
3233
config model.Config,
3334
runtimeConfig model.RuntimeConfig,
35+
helpers model.RuntimeHelpers,
3436
router *gin.RouterGroup,
3537
auth *service.AuthService,
3638
) *OAuthController {
3739
controller := &OAuthController{
3840
log: log,
3941
config: config,
4042
runtime: runtimeConfig,
43+
helpers: helpers,
4144
auth: auth,
4245
}
4346

@@ -105,7 +108,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
105108
return
106109
}
107110

108-
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
111+
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
112+
113+
if err != nil {
114+
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
115+
c.JSON(500, gin.H{
116+
"status": 500,
117+
"message": "Internal Server Error",
118+
})
119+
return
120+
}
121+
122+
c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)
109123

110124
c.JSON(200, gin.H{
111125
"status": 200,
@@ -135,7 +149,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
135149
return
136150
}
137151

138-
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
152+
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())
153+
154+
if err != nil {
155+
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
156+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
157+
return
158+
}
159+
160+
c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)
139161

140162
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
141163

@@ -252,7 +274,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
252274

253275
controller.log.App.Debug().Msg("Creating session cookie for user")
254276

255-
cookie, err := controller.auth.CreateSession(c, sessionCookie)
277+
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())
256278

257279
if err != nil {
258280
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -298,10 +320,3 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
298320
func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackParams) bool {
299321
return params.LoginFor == string(FrontendLoginForOIDC)
300322
}
301-
302-
func (controller *OAuthController) getCookieDomain() string {
303-
if controller.config.Auth.SubdomainsEnabled {
304-
return "." + controller.runtime.CookieDomain
305-
}
306-
return controller.runtime.CookieDomain
307-
}

internal/controller/oidc_controller.go

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package controller
22

33
import (
4+
"database/sql"
45
"encoding/json"
56
"errors"
67
"fmt"
78
"net/http"
89
"slices"
910
"strings"
11+
"time"
1012

1113
"github.com/gin-gonic/gin"
1214
"github.com/gin-gonic/gin/binding"
@@ -31,6 +33,8 @@ type OIDCController struct {
3133
log *logger.Logger
3234
oidc *service.OIDCService
3335
runtime model.RuntimeConfig
36+
helpers model.RuntimeHelpers
37+
config model.Config
3438
}
3539

3640
type AuthorizeCallback struct {
@@ -68,10 +72,11 @@ type ClientCredentials struct {
6872
}
6973

7074
type AuthorizeScreenParams struct {
71-
LoginFor FrontendLoginFor `url:"login_for"`
72-
OIDCTicket string `url:"oidc_ticket"`
73-
OIDCScope string `url:"oidc_scope"`
74-
OIDCName string `url:"oidc_name"`
75+
LoginFor FrontendLoginFor `url:"login_for"`
76+
OIDCTicket string `url:"oidc_ticket"`
77+
OIDCScope string `url:"oidc_scope"`
78+
OIDCName string `url:"oidc_name"`
79+
OIDCShowConsent bool `url:"oidc_show_consent"`
7580
}
7681

7782
type AuthorizeCompleteRequest struct {
@@ -82,12 +87,16 @@ func NewOIDCController(
8287
log *logger.Logger,
8388
oidcService *service.OIDCService,
8489
runtimeConfig model.RuntimeConfig,
90+
helpers model.RuntimeHelpers,
91+
config model.Config,
8592
router *gin.RouterGroup,
8693
mainRouter *gin.RouterGroup) *OIDCController {
8794
controller := &OIDCController{
8895
log: log,
8996
oidc: oidcService,
9097
runtime: runtimeConfig,
98+
helpers: helpers,
99+
config: config,
91100
}
92101

93102
mainRouter.POST("/authorize", controller.authorize)
@@ -163,11 +172,31 @@ func (controller *OIDCController) authorize(c *gin.Context) {
163172

164173
ticket := controller.oidc.CreateAuthorizeRequestTicket(*req)
165174

175+
// Check if we have consented before for this client and scope
176+
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)
177+
178+
showConsent := true
179+
180+
if err == nil {
181+
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)
182+
183+
if err == nil && consentEntry != nil {
184+
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
185+
showConsent = false
186+
}
187+
} else {
188+
if !errors.Is(err, sql.ErrNoRows) {
189+
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
190+
}
191+
}
192+
}
193+
166194
queries, err := query.Values(AuthorizeScreenParams{
167-
LoginFor: FrontendLoginForOIDC,
168-
OIDCTicket: ticket,
169-
OIDCScope: req.Scope,
170-
OIDCName: client.Name,
195+
LoginFor: FrontendLoginForOIDC,
196+
OIDCTicket: ticket,
197+
OIDCScope: req.Scope,
198+
OIDCName: client.Name,
199+
OIDCShowConsent: showConsent,
171200
})
172201

173202
if err != nil {
@@ -289,6 +318,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
289318
return
290319
}
291320

321+
// Just before returning let's set the consent cookie
322+
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)
323+
324+
// If we fail to create the consent entry, we don't want to block the authorization flow,
325+
// but we log the error and move on without setting the cookie
326+
if err == nil {
327+
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())
328+
329+
if err == nil {
330+
cookie := &http.Cookie{
331+
Name: controller.runtime.ConsentCookieName,
332+
Value: consnetUUID,
333+
Path: "/",
334+
Domain: cookieDomain,
335+
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
336+
Secure: controller.config.Auth.SecureCookie,
337+
HttpOnly: true,
338+
SameSite: http.SameSiteLaxMode,
339+
}
340+
http.SetCookie(c.Writer, cookie)
341+
} else {
342+
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
343+
}
344+
} else {
345+
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
346+
}
347+
292348
c.JSON(200, gin.H{
293349
"status": 200,
294350
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),

internal/controller/oidc_controller_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ func TestOIDCController(t *testing.T) {
3030

3131
cfg, runtime := test.CreateTestConfigs(t)
3232

33+
helpers := test.CreateTestHelpers()
34+
3335
ctx := context.TODO()
3436
dg := ding.New(ctx)
3537

@@ -831,7 +833,7 @@ func TestOIDCController(t *testing.T) {
831833
svc = nil
832834
}
833835

834-
controller.NewOIDCController(log, svc, runtime, group, &router.RouterGroup)
836+
controller.NewOIDCController(log, svc, runtime, helpers, cfg, group, &router.RouterGroup)
835837

836838
recorder := httptest.NewRecorder()
837839

internal/controller/proxy_controller_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ func TestProxyController(t *testing.T) {
2424

2525
cfg, runtime := test.CreateTestConfigs(t)
2626

27+
helpers := test.CreateTestHelpers()
28+
2729
const browserUserAgent = `
2830
Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36`
2931

@@ -395,7 +397,7 @@ func TestProxyController(t *testing.T) {
395397
Log: log,
396398
})
397399

398-
authService := service.NewAuthService(log, cfg, runtime, ctx, dg, nil, store, broker, nil, policyEngine)
400+
authService := service.NewAuthService(log, cfg, runtime, helpers, ctx, dg, nil, store, broker, nil, policyEngine)
399401

400402
for _, test := range tests {
401403
t.Run(test.description, func(t *testing.T) {

0 commit comments

Comments
 (0)