Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion interfaces/openapi-to-go-server/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ def _generate_apis(
"fmt",
"go.opentelemetry.io/otel",
"go.opentelemetry.io/otel/trace",
"github.com/interuss/stacktrace",
]
),
)
+ '\n dsserr "github.com/interuss/dss/pkg/errors"',
"<API_PACKAGE>": api_package,
"<ROUTES>": "\n".join(routes),
"<ROUTING>": "\n".join(rendering.routing(api, api_package)),
Expand Down
56 changes: 38 additions & 18 deletions interfaces/openapi-to-go-server/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,9 @@ def routes(

body: List[str] = []

# Create object to hold the processed input to the operation
# Create object to hold the processed input and output to the operation
body.append("var req {}".format(operation.request_type_name))
body.append("")

# Attempt to authorize access to the operation
body.extend(comment(["Authorize request"]))
body.append(
"req.Auth = s.Authorizer.Authorize(w, r, {}Security)".format(
operation.interface_name
)
)
body.append("var response {}".format(operation.response_type_name))
body.append("")

# Parse any path parameters
Expand Down Expand Up @@ -385,20 +377,48 @@ def routes(

# Actually invoke the API Implementation with the processed request to obtain the response
imports.add("context")
body.extend(comment(["Call implementation"]))
body.append("ctx, cancel := context.WithCancel(r.Context())")
body.append("defer cancel()")
body.append(
"response := s.Implementation.{}(ctx, &req)".format(
operation.interface_name
)

body.append("")

call_block = []

call_block.extend(comment(["Call implementation"]))
call_block.append("ctx, cancel := context.WithCancel(r.Context())")
call_block.append("defer cancel()")
call_block.append(
"response = s.Implementation.{}(ctx, &req)".format(operation.interface_name)
)

if operation.security.options:
# Authorize & verify the call
body.extend(comment(["Authorize request"]))
body.append(
"req.Auth = s.Authorizer.Authorize(w, r, {}Security)".format(
operation.interface_name
)
)

body.extend(comment(["Verify authorization"]))
body.append("if req.Auth.Error != nil {")
body.extend(
indent(
[
'setAuthError(r.Context(), stacktrace.Propagate(req.Auth.Error, "Auth failed"), &response.Response401, &response.Response403, &response.Response500)'
],
1,
)
)
body.append("} else {")
body.extend(indent(call_block, 1))
body.append("}")
else:
body.extend(call_block)
body.append("")

# Write the first populated response discovered and finish the handler
body.extend(comment(["Write response to client"]))
responses = [r for r in operation.responses]
if ensure_500 and 500 not in {r.code for r in responses}:
if ensure_500 and 500 not in {r.code for r in operation.responses}:
responses.append(
operations.Response(
code=500,
Expand Down
17 changes: 17 additions & 0 deletions interfaces/openapi-to-go-server/templates/server.go.template
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ func (s *APIRouter) Handle(w http.ResponseWriter, r *http.Request) bool {
return false
}


func setAuthError(ctx context.Context, authErr error, resp401 **ErrorResponse, resp403 **ErrorResponse, resp500 **api.InternalServerErrorBody) {
switch stacktrace.GetCode(authErr) {
case dsserr.Unauthenticated:
*resp401 = &ErrorResponse{Message: dsserr.Handle(ctx, stacktrace.Propagate(authErr, "Authentication failed"))}
case dsserr.PermissionDenied:
*resp403 = &ErrorResponse{Message: dsserr.Handle(ctx, stacktrace.Propagate(authErr, "Authorization failed"))}
default:

if authErr == nil {
authErr = stacktrace.NewError("Unknown error")
}

*resp500 = &api.InternalServerErrorBody{ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(authErr, "Could not perform authorization"))}
}
}

<ROUTES>

func MakeAPIRouter(impl Implementation, auth <API_PACKAGE>.Authorizer) APIRouter {
Expand Down
119 changes: 78 additions & 41 deletions pkg/api/auxv1/server.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"fmt"
"github.com/interuss/dss/pkg/api"
dsserr "github.com/interuss/dss/pkg/errors"
"github.com/interuss/stacktrace"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"net/http"
Expand Down Expand Up @@ -42,16 +44,30 @@ func (s *APIRouter) Handle(w http.ResponseWriter, r *http.Request) bool {
return false
}

func setAuthError(ctx context.Context, authErr error, resp401 **ErrorResponse, resp403 **ErrorResponse, resp500 **api.InternalServerErrorBody) {
switch stacktrace.GetCode(authErr) {
case dsserr.Unauthenticated:
*resp401 = &ErrorResponse{Message: dsserr.Handle(ctx, stacktrace.Propagate(authErr, "Authentication failed"))}
case dsserr.PermissionDenied:
*resp403 = &ErrorResponse{Message: dsserr.Handle(ctx, stacktrace.Propagate(authErr, "Authorization failed"))}
default:

if authErr == nil {
authErr = stacktrace.NewError("Unknown error")
}

*resp500 = &api.InternalServerErrorBody{ErrorMessage: *dsserr.Handle(ctx, stacktrace.Propagate(authErr, "Could not perform authorization"))}
}
}

func (s *APIRouter) GetVersion(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req GetVersionRequest

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, GetVersionSecurity)
var response GetVersionResponseSet

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.GetVersion(ctx, &req)
response = s.Implementation.GetVersion(ctx, &req)

// Write response to client
if response.Response200 != nil {
Expand All @@ -67,9 +83,7 @@ func (s *APIRouter) GetVersion(exp *regexp.Regexp, w http.ResponseWriter, r *htt

func (s *APIRouter) ValidateOauth(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req ValidateOauthRequest

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, ValidateOauthSecurity)
var response ValidateOauthResponseSet

// Copy query parameters
query := r.URL.Query()
Expand All @@ -78,10 +92,17 @@ func (s *APIRouter) ValidateOauth(exp *regexp.Regexp, w http.ResponseWriter, r *
req.Owner = &v
}

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.ValidateOauth(ctx, &req)
// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, ValidateOauthSecurity)
// Verify authorization
if req.Auth.Error != nil {
setAuthError(r.Context(), stacktrace.Propagate(req.Auth.Error, "Auth failed"), &response.Response401, &response.Response403, &response.Response500)
} else {
// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response = s.Implementation.ValidateOauth(ctx, &req)
}

// Write response to client
if response.Response200 != nil {
Expand All @@ -105,14 +126,19 @@ func (s *APIRouter) ValidateOauth(exp *regexp.Regexp, w http.ResponseWriter, r *

func (s *APIRouter) GetPool(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req GetPoolRequest
var response GetPoolResponseSet

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, GetPoolSecurity)

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.GetPool(ctx, &req)
// Verify authorization
if req.Auth.Error != nil {
setAuthError(r.Context(), stacktrace.Propagate(req.Auth.Error, "Auth failed"), &response.Response401, &response.Response403, &response.Response500)
} else {
// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response = s.Implementation.GetPool(ctx, &req)
}

// Write response to client
if response.Response200 != nil {
Expand Down Expand Up @@ -140,14 +166,19 @@ func (s *APIRouter) GetPool(exp *regexp.Regexp, w http.ResponseWriter, r *http.R

func (s *APIRouter) GetDSSInstances(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req GetDSSInstancesRequest
var response GetDSSInstancesResponseSet

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, GetDSSInstancesSecurity)

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.GetDSSInstances(ctx, &req)
// Verify authorization
if req.Auth.Error != nil {
setAuthError(r.Context(), stacktrace.Propagate(req.Auth.Error, "Auth failed"), &response.Response401, &response.Response403, &response.Response500)
} else {
// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response = s.Implementation.GetDSSInstances(ctx, &req)
}

// Write response to client
if response.Response200 != nil {
Expand Down Expand Up @@ -175,9 +206,7 @@ func (s *APIRouter) GetDSSInstances(exp *regexp.Regexp, w http.ResponseWriter, r

func (s *APIRouter) PutDSSInstancesHeartbeat(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req PutDSSInstancesHeartbeatRequest

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, PutDSSInstancesHeartbeatSecurity)
var response PutDSSInstancesHeartbeatResponseSet

// Copy query parameters
query := r.URL.Query()
Expand All @@ -194,10 +223,17 @@ func (s *APIRouter) PutDSSInstancesHeartbeat(exp *regexp.Regexp, w http.Response
req.NextHeartbeatExpectedBefore = &v
}

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.PutDSSInstancesHeartbeat(ctx, &req)
// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, PutDSSInstancesHeartbeatSecurity)
// Verify authorization
if req.Auth.Error != nil {
setAuthError(r.Context(), stacktrace.Propagate(req.Auth.Error, "Auth failed"), &response.Response401, &response.Response403, &response.Response500)
} else {
// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response = s.Implementation.PutDSSInstancesHeartbeat(ctx, &req)
}

// Write response to client
if response.Response201 != nil {
Expand Down Expand Up @@ -229,14 +265,12 @@ func (s *APIRouter) PutDSSInstancesHeartbeat(exp *regexp.Regexp, w http.Response

func (s *APIRouter) GetAcceptedCAs(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req GetAcceptedCAsRequest

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, GetAcceptedCAsSecurity)
var response GetAcceptedCAsResponseSet

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.GetAcceptedCAs(ctx, &req)
response = s.Implementation.GetAcceptedCAs(ctx, &req)

// Write response to client
if response.Response200 != nil {
Expand All @@ -256,14 +290,12 @@ func (s *APIRouter) GetAcceptedCAs(exp *regexp.Regexp, w http.ResponseWriter, r

func (s *APIRouter) GetInstanceCAs(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req GetInstanceCAsRequest

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, GetInstanceCAsSecurity)
var response GetInstanceCAsResponseSet

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.GetInstanceCAs(ctx, &req)
response = s.Implementation.GetInstanceCAs(ctx, &req)

// Write response to client
if response.Response200 != nil {
Expand All @@ -283,14 +315,19 @@ func (s *APIRouter) GetInstanceCAs(exp *regexp.Regexp, w http.ResponseWriter, r

func (s *APIRouter) GetScdLockMode(exp *regexp.Regexp, w http.ResponseWriter, r *http.Request) {
var req GetScdLockModeRequest
var response GetScdLockModeResponseSet

// Authorize request
req.Auth = s.Authorizer.Authorize(w, r, GetScdLockModeSecurity)

// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response := s.Implementation.GetScdLockMode(ctx, &req)
// Verify authorization
if req.Auth.Error != nil {
setAuthError(r.Context(), stacktrace.Propagate(req.Auth.Error, "Auth failed"), &response.Response401, &response.Response403, &response.Response500)
} else {
// Call implementation
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
response = s.Implementation.GetScdLockMode(ctx, &req)
}

// Write response to client
if response.Response200 != nil {
Expand Down
Loading
Loading