mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 13:37:30 +00:00
API endpoints initial models (#42881)
**Related issue:** Resolves #42881 - Added user_api_endpoints table to track per user API endpoint permissions. - Added service/api_endpoints, used to handle service/api_endpoints.yml artifact. - Added check on server start that makes sure that service/apin_endpoints.yml is a subset of router routes.
This commit is contained in:
parent
07df99daa7
commit
3df6449426
7 changed files with 304 additions and 2 deletions
2
changes/42881-api-endpoints-initial-models
Normal file
2
changes/42881-api-endpoints-initial-models
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
* Added user_api_endpoints table to track per-user API endpoint permissions.
|
||||
* Added startup validation that panics if any route declared in service/api_endpoints.yml is not registered in the router.
|
||||
|
|
@ -1483,6 +1483,10 @@ func runServeCmd(cmd *cobra.Command, configManager configpkg.Manager, debug, dev
|
|||
apiHandler = service.MakeHandler(svc, config, httpLogger, limiterStore, redisPool, carveStore,
|
||||
[]endpointer.HandlerRoutesFunc{android_service.GetRoutes(svc, androidSvc), activityRoutes, acmeRoutes}, extra...)
|
||||
|
||||
if err := service.ValidateAPIEndpoints(apiHandler); err != nil {
|
||||
panic(fmt.Sprintf("invalid api_endpoints.yml: %v", err))
|
||||
}
|
||||
|
||||
if serveCSP {
|
||||
// Only injecting this if CSP is turned on since the default security headers add some overhead to each request
|
||||
apiHandler = endpointer.BrowserSecurityHeadersHandler(serveCSP, apiHandler)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package tables
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
MigrationClient.AddMigration(Up_20260406114157, Down_20260406114157)
|
||||
}
|
||||
|
||||
func Up_20260406114157(tx *sql.Tx) error {
|
||||
if !tableExists(tx, "user_api_endpoints") {
|
||||
_, err := tx.Exec(`
|
||||
CREATE TABLE user_api_endpoints (
|
||||
user_id INT UNSIGNED NOT NULL,
|
||||
|
||||
path VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL,
|
||||
method VARCHAR(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL,
|
||||
|
||||
is_allowed BOOLEAN DEFAULT TRUE NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
created_by_id INT UNSIGNED,
|
||||
|
||||
PRIMARY KEY (user_id, path, method),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (created_by_id) REFERENCES users(id) ON DELETE SET NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Down_20260406114157(tx *sql.Tx) error {
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
111
server/service/api_endpoints.go
Normal file
111
server/service/api_endpoints.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
//go:embed api_endpoints.yml
|
||||
var apiEndpointsYAML []byte
|
||||
|
||||
var apiEndpoints = mustParseAPIEndpoints()
|
||||
|
||||
// APIEndpoint represents an API endpoint that we can attach permissions to.
|
||||
type APIEndpoint struct {
|
||||
Method string `yaml:"method"`
|
||||
Path string `yaml:"path"`
|
||||
Name string `yaml:"name"`
|
||||
Deprecated bool `yaml:"deprecated"`
|
||||
}
|
||||
|
||||
var validHTTPMethods = map[string]struct{}{
|
||||
http.MethodGet: {},
|
||||
http.MethodPost: {},
|
||||
http.MethodPut: {},
|
||||
http.MethodPatch: {},
|
||||
http.MethodDelete: {},
|
||||
}
|
||||
|
||||
func (e APIEndpoint) validate() error {
|
||||
if strings.TrimSpace(e.Name) == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
if _, ok := validHTTPMethods[strings.ToUpper(e.Method)]; !ok {
|
||||
return fmt.Errorf("invalid HTTP method %q", e.Method)
|
||||
}
|
||||
if strings.TrimSpace(e.Path) == "" {
|
||||
return errors.New("path is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustParseAPIEndpoints() []APIEndpoint {
|
||||
var routes []APIEndpoint
|
||||
if err := yaml.Unmarshal(apiEndpointsYAML, &routes); err != nil {
|
||||
panic(fmt.Sprintf("api_endpoints.yml: failed to parse: %v", err))
|
||||
}
|
||||
for i, r := range routes {
|
||||
if err := r.validate(); err != nil {
|
||||
panic(fmt.Sprintf("api_endpoints.yml: entry %d: %v", i, err))
|
||||
}
|
||||
// Normalise method to upper-case so callers don't have to.
|
||||
routes[i].Method = strings.ToUpper(r.Method)
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetAPIEndpoints returns all routes defined in api_endpoints.yml.
|
||||
func GetAPIEndpoints() []APIEndpoint {
|
||||
return apiEndpoints
|
||||
}
|
||||
|
||||
// versionSegmentRe matches the gorilla/mux version segment that attachFleetAPIRoutes
|
||||
// inserts in place of /_version_/ (e.g. /{fleetversion:(?:v1|2022-04|latest)}/).
|
||||
var versionSegmentRe = regexp.MustCompile(`/\{fleetversion:[^}]+\}/`)
|
||||
|
||||
// ValidateAPIEndpoints checks that every route declared in api_endpoints.yml is
|
||||
// registered in h.
|
||||
func ValidateAPIEndpoints(h http.Handler) error {
|
||||
r, ok := h.(*mux.Router)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected *mux.Router, got %T", h)
|
||||
}
|
||||
|
||||
registered := make(map[string]struct{})
|
||||
_ = r.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
|
||||
tpl, err := route.GetPathTemplate()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
meths, err := route.GetMethods()
|
||||
if err != nil || len(meths) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := versionSegmentRe.ReplaceAllString(tpl, "/_version_/")
|
||||
for _, m := range meths {
|
||||
registered[m+":"+normalized] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
var missing []string
|
||||
for _, route := range GetAPIEndpoints() {
|
||||
key := route.Method + ":" + route.Path
|
||||
if _, ok := registered[key]; !ok {
|
||||
missing = append(missing, route.Method+" "+route.Path)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return fmt.Errorf("the following API endpoints are missing: %v", missing)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
4
server/service/api_endpoints.yml
Normal file
4
server/service/api_endpoints.yml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
- method: POST
|
||||
path: "/api/_version_/fleet/trigger"
|
||||
name: "Some wild description goes here"
|
||||
deprecated: false
|
||||
127
server/service/api_endpoints_test.go
Normal file
127
server/service/api_endpoints_test.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAPIEndpoints(t *testing.T) {
|
||||
routes := GetAPIEndpoints()
|
||||
require.NotEmpty(t, routes)
|
||||
for _, r := range routes {
|
||||
require.NotEmpty(t, r.Method, "route method should not be empty")
|
||||
require.NotEmpty(t, r.Path, "route path should not be empty")
|
||||
require.NotEmpty(t, r.Name, "route name should not be empty")
|
||||
require.True(t, strings.HasPrefix(r.Path, "/"), "route path should start with /")
|
||||
_, validMethod := validHTTPMethods[r.Method]
|
||||
require.True(t, validMethod, "route method %q should be a valid HTTP method", r.Method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIEndpointValidate(t *testing.T) {
|
||||
base := APIEndpoint{Method: "GET", Path: "/api/_version_/fleet/foo", Name: "foo"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modify func(APIEndpoint) APIEndpoint
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid endpoint",
|
||||
modify: func(e APIEndpoint) APIEndpoint { return e },
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
modify: func(e APIEndpoint) APIEndpoint { e.Name = ""; return e },
|
||||
wantErr: "name is required",
|
||||
},
|
||||
{
|
||||
name: "whitespace name",
|
||||
modify: func(e APIEndpoint) APIEndpoint { e.Name = " "; return e },
|
||||
wantErr: "name is required",
|
||||
},
|
||||
{
|
||||
name: "invalid method",
|
||||
modify: func(e APIEndpoint) APIEndpoint { e.Method = "GTE"; return e },
|
||||
wantErr: "invalid HTTP method",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
modify: func(e APIEndpoint) APIEndpoint { e.Path = " "; return e },
|
||||
wantErr: "path is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.modify(base).validate()
|
||||
if tt.wantErr == "" {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAPIEndpoints(t *testing.T) {
|
||||
allRoutes := GetAPIEndpoints()
|
||||
|
||||
routerWithRoutes := func(routes []APIEndpoint) *mux.Router {
|
||||
r := mux.NewRouter()
|
||||
for _, route := range routes {
|
||||
path := strings.Replace(route.Path, "/_version_/", "/{fleetversion:(?:v1|latest)}/", 1)
|
||||
r.Handle(path, http.NotFoundHandler()).Methods(route.Method)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "all routes present",
|
||||
handler: routerWithRoutes(allRoutes),
|
||||
},
|
||||
{
|
||||
name: "no routes registered",
|
||||
handler: mux.NewRouter(),
|
||||
wantErr: "the following API endpoints are missing",
|
||||
},
|
||||
{
|
||||
name: "non-mux handler returns error",
|
||||
handler: http.NewServeMux(),
|
||||
wantErr: "expected *mux.Router, got *http.ServeMux",
|
||||
},
|
||||
}
|
||||
|
||||
if len(allRoutes) >= 2 {
|
||||
last := allRoutes[len(allRoutes)-1]
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
wantErr string
|
||||
}{
|
||||
name: "last route missing",
|
||||
handler: routerWithRoutes(allRoutes[:len(allRoutes)-1]),
|
||||
wantErr: last.Method + " " + last.Path,
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateAPIEndpoints(tt.handler)
|
||||
if tt.wantErr == "" {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue