Support restricting API token access based on IP address

This commit is contained in:
MaysWind 2026-03-04 23:46:02 +08:00
parent f0f3143605
commit 404cd62d7b
4 changed files with 84 additions and 21 deletions

View file

@ -316,6 +316,7 @@ func startWebServer(c *core.CliContext) error {
apiV1Route := apiRoute.Group("/v1")
apiV1Route.Use(bindMiddleware(middlewares.JWTAuthorization(config)))
apiV1Route.Use(bindMiddleware(middlewares.APITokenIpLimit(config)))
{
// Tokens
apiV1Route.GET("/tokens/list.json", bindApi(api.Tokens.TokenListHandler))

View file

@ -296,6 +296,9 @@ password_reset_token_expired_time = 3600
# Set to true to enable API token generation
enable_api_token = false
# Allowed remote IPs for using the API token, a comma-separated list of allowed remote IPs (asterisk * for any addresses, e.g. 192.168.1.* means any IPs in the 192.168.1.x subnet), leave blank to allow all remote IPs
api_token_allowed_remote_ips =
# Maximum count of password / token check failures (0 - 4294967295) per IP per minute (use the above duplicate checker), default is 5, set to 0 to disable
max_failures_per_ip_per_minute = 5

View file

@ -0,0 +1,39 @@
package middlewares
import (
"github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/settings"
"github.com/mayswind/ezbookkeeping/pkg/utils"
)
// APITokenIpLimit limits API token access based on IP address
func APITokenIpLimit(config *settings.Config) core.MiddlewareHandlerFunc {
return func(c *core.WebContext) {
claims := c.GetTokenClaims()
if claims == nil {
c.Next()
return
}
if claims.Type != core.USER_TOKEN_TYPE_API {
c.Next()
return
}
if len(config.APITokenAllowedRemoteIPs) < 1 {
c.Next()
return
}
for i := 0; i < len(config.APITokenAllowedRemoteIPs); i++ {
if config.APITokenAllowedRemoteIPs[i].Match(c.ClientIP()) {
c.Next()
return
}
}
utils.PrintJsonErrorResult(c, errs.ErrIPForbidden)
}
}

View file

@ -370,6 +370,7 @@ type Config struct {
PasswordResetTokenExpiredTime uint32
PasswordResetTokenExpiredTimeDuration time.Duration
EnableAPIToken bool
APITokenAllowedRemoteIPs []*core.IPPattern
MaxFailuresPerIpPerMinute uint32
MaxFailuresPerUserPerMinute uint32
@ -667,29 +668,13 @@ func loadServerConfiguration(config *Config, configFile *ini.File, sectionName s
}
func loadMCPServerConfiguration(config *Config, configFile *ini.File, sectionName string) error {
var err error
config.EnableMCPServer = getConfigItemBoolValue(configFile, sectionName, "enable_mcp", false)
mcpAllowedRemoteIps := getConfigItemStringValue(configFile, sectionName, "mcp_allowed_remote_ips", "")
config.MCPAllowedRemoteIPs, err = getIPPatterns(configFile, sectionName, "mcp_allowed_remote_ips", "")
if mcpAllowedRemoteIps != "" {
remoteIPs := strings.Split(mcpAllowedRemoteIps, ",")
config.MCPAllowedRemoteIPs = make([]*core.IPPattern, 0, len(remoteIPs))
for i := 0; i < len(remoteIPs); i++ {
ip := strings.TrimSpace(remoteIPs[i])
pattern, err := core.ParseIPPattern(ip)
if err != nil {
return err
}
if pattern == nil {
continue
}
config.MCPAllowedRemoteIPs = append(config.MCPAllowedRemoteIPs, pattern)
}
} else {
config.MCPAllowedRemoteIPs = nil
if err != nil {
return err
}
return nil
@ -976,6 +961,8 @@ func loadCronConfiguration(config *Config, configFile *ini.File, sectionName str
}
func loadSecurityConfiguration(config *Config, configFile *ini.File, sectionName string) error {
var err error
config.SecretKeyNoSet = !getConfigItemIsSet(configFile, sectionName, "secret_key")
config.SecretKey = getConfigItemStringValue(configFile, sectionName, "secret_key", defaultSecretKey)
@ -1018,6 +1005,11 @@ func loadSecurityConfiguration(config *Config, configFile *ini.File, sectionName
config.PasswordResetTokenExpiredTimeDuration = time.Duration(config.PasswordResetTokenExpiredTime) * time.Second
config.EnableAPIToken = getConfigItemBoolValue(configFile, sectionName, "enable_api_token", false)
config.APITokenAllowedRemoteIPs, err = getIPPatterns(configFile, sectionName, "api_token_allowed_remote_ips", "")
if err != nil {
return err
}
config.MaxFailuresPerIpPerMinute = getConfigItemUint32Value(configFile, sectionName, "max_failures_per_ip_per_minute", defaultMaxFailuresPerIpPerMinute)
config.MaxFailuresPerUserPerMinute = getConfigItemUint32Value(configFile, sectionName, "max_failures_per_user_per_minute", defaultMaxFailuresPerUserPerMinute)
@ -1260,6 +1252,34 @@ func getFinalPath(workingPath, p string) (string, error) {
return p, err
}
func getIPPatterns(configFile *ini.File, sectionName string, itemName string, defaultValue string) ([]*core.IPPattern, error) {
configValue := getConfigItemStringValue(configFile, sectionName, itemName, defaultValue)
if configValue == "" {
return nil, nil
}
remoteIPs := strings.Split(configValue, ",")
ipPatterns := make([]*core.IPPattern, 0, len(remoteIPs))
for i := 0; i < len(remoteIPs); i++ {
ip := strings.TrimSpace(remoteIPs[i])
pattern, err := core.ParseIPPattern(ip)
if err != nil {
return nil, err
}
if pattern == nil {
continue
}
ipPatterns = append(ipPatterns, pattern)
}
return ipPatterns, nil
}
func getMultiLanguageContentConfig(configFile *ini.File, sectionName string, enableKey string, contentKey string) MultiLanguageContentConfig {
config := MultiLanguageContentConfig{
Enabled: getConfigItemBoolValue(configFile, sectionName, enableKey, false),