fleet/server/mdm/microsoft/wstep.go
Konstantin Sykulev ac16eb234c
Verifying jwt signing algo to prevent vulnerability (#43474)
Related to a vulnerability found when working on
https://github.com/fleetdm/fleet/pull/43295
https://github.com/fleetdm/fleet/pull/43295#discussion_r3065433754

`golang-jwt/jwt/v5` library already mitigates this, however, we are
using `v4` which does not include this check.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Enforced RSA-only validation for JWTs used in authentication; tokens
signed with non-RSA algorithms are now rejected.
* **Tests**
* Added tests to verify that non-RSA and unsigned JWTs are rejected and
produce the expected error.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
2026-04-13 19:11:55 -05:00

583 lines
19 KiB
Go

package microsoft_mdm
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha1" //nolint:gosec
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"math/big"
"strconv"
"strings"
"time"
"github.com/MicahParks/jwkset"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/dev_mode"
"github.com/fleetdm/fleet/v4/server/mdm/microsoft/syncml"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/smallstep/pkcs7"
)
// CertManager is an interface for certificate management tasks associated with Microsoft MDM (e.g.,
// signing CSRs and issuing/validating the Fleet-signed JWTs used during enrollment).
type CertManager interface {
	// IdentityFingerprint returns the hex-encoded, uppercased sha1 fingerprint of the identity certificate.
	IdentityFingerprint() string
	// SignClientCSR signs a client CSR and returns the signed, DER-encoded certificate bytes and
	// its uppercased, hex-encoded sha1 fingerprint. The subject passed is set as the common name of
	// the signed certificate.
	SignClientCSR(ctx context.Context, subject string, clientCSR *x509.CertificateRequest) ([]byte, string, error)
	// IdentityCert returns the identity certificate of the depot.
	IdentityCert() x509.Certificate
	// NewSTSAuthToken returns an STS auth token for the given UPN claim.
	NewSTSAuthToken(upn string) (string, error)
	// NewEUAToken returns a Fleet-signed JWT for the given UPN and Windows MDM
	// device ID. Used to pass end-user authentication context to the orbit
	// installer so the user is not prompted twice.
	NewEUAToken(upn string, deviceID string) (string, error)
	// GetSTSAuthTokenUPNClaim validates the given STS token and returns its UPN claim.
	GetSTSAuthTokenUPNClaim(token string) (string, error)
	// GetEUATokenClaims validates the given EUA token and returns the parsed claims.
	GetEUATokenClaims(token string) (*EUATokenClaims, error)
	// TODO: implement other methods as needed:
	// - verify certificate-device association
	// - certificate lifecycle management (e.g., renewal, revocation)
}
// CertStore abstracts the persistence needed by the MS-WSTEP certificate
// enrollment messages of the MS-MDE2 protocol. fleet.Datastore implements it.
type CertStore interface {
	// WSTEPAssociateCertHash links a certificate hash to the given device UUID.
	WSTEPAssociateCertHash(ctx context.Context, deviceUUID string, hash string) error
	// WSTEPNewSerial returns the serial number for the next certificate to be issued.
	WSTEPNewSerial(ctx context.Context) (*big.Int, error)
	// WSTEPStoreCertificate persists an issued client certificate under the given name.
	WSTEPStoreCertificate(ctx context.Context, name string, crt *x509.Certificate) error
}
// STSClaims is the claim set of the STS auth tokens issued by NewSTSAuthToken
// and verified by GetSTSAuthTokenUPNClaim.
type STSClaims struct {
	// UPN is the user principal name the token was issued for (JSON key "upn").
	UPN string `json:"upn"`
	// Standard JWT claims (exp, iat, nbf, sub, ...).
	jwt.RegisteredClaims
}
// euaJWTClaims is the internal JWT struct for signing/parsing EUA tokens
// (see NewEUAToken and GetEUATokenClaims). Callers only ever receive the
// validated EUATokenClaims.
type euaJWTClaims struct {
	// UPN is the end user's principal name (JSON key "upn").
	UPN string `json:"upn"`
	// DeviceID is the Windows MDM device ID the token was minted for (JSON key "device_id").
	DeviceID string `json:"device_id"`
	// Standard JWT claims (exp, iat, nbf, sub, ...).
	jwt.RegisteredClaims
}
// EUATokenClaims carries the claims extracted from an EUA token once
// GetEUATokenClaims has verified its signature and validity.
type EUATokenClaims struct {
	UPN, DeviceID string // user principal name and Windows MDM device ID
}
// AzureData holds the claims extracted from a validated Azure AD token by
// GetAzureAuthTokenClaims.
type AzureData struct {
	UPN                       string   // "upn" claim
	Audience                  []string // "aud" claim
	TenantID, UniqueName, SCP string   // "tid", "unique_name" and "scp" claims
}
// manager is the CertManager implementation. It holds the depot identity
// certificate and RSA private key used to sign client certificates and to
// sign/verify Fleet-issued JWTs.
type manager struct {
	// store persists issued certificates and hands out serial numbers.
	store CertStore
	// identityCert holds the identity certificate of the depot.
	identityCert *x509.Certificate
	// identityPrivateKey holds the private key of the depot.
	identityPrivateKey *rsa.PrivateKey
	// identityFingerprint holds the uppercased, hex-encoded sha1 fingerprint of the identity certificate.
	identityFingerprint string
	// maxSerialNumber holds the maximum serial number. RFC 5280 caps serial numbers at 20
	// octets, and the value may be limited further if required (newManager uses 2^128).
	// NOTE(review): no code in this part of the file reads this field — confirm it is still used.
	maxSerialNumber *big.Int
}
// NewCertManager builds a CertManager from a PEM-encoded identity certificate
// and its PEM-encoded private key; issued certificates are persisted via store.
//
// On error the returned CertManager wraps a nil *manager (it is a non-nil
// interface value), so callers must check err rather than comparing to nil.
func NewCertManager(store CertStore, certPEM []byte, privKeyPEM []byte) (CertManager, error) {
	mgr, err := newManager(store, certPEM, privKeyPEM)
	return mgr, err
}
// newManager decodes the PEM-encoded identity certificate and private key and
// returns a manager that uses them to sign client certificates and JWTs.
//
// It returns an error if either PEM input cannot be decoded.
func newManager(store CertStore, certPEM []byte, privKeyPEM []byte) (*manager, error) {
	crt, err := cryptoutil.DecodePEMCertificate(certPEM)
	if err != nil {
		return nil, fmt.Errorf("decode certificate: %w", err)
	}

	key, err := server.DecodePrivateKeyPEM(privKeyPEM)
	if err != nil {
		return nil, fmt.Errorf("decode private key: %w", err)
	}

	return &manager{
		store:               store,
		identityCert:        crt,
		identityPrivateKey:  key,
		identityFingerprint: CertFingerprintHexStr(crt),
		// 2^128: comfortably below the 20-octet limit RFC 5280 places on serial numbers.
		maxSerialNumber: new(big.Int).Lsh(big.NewInt(1), 128),
	}, nil
}
// IdentityFingerprint reports the uppercased, hex-encoded sha1 fingerprint of
// the identity certificate, or "" when called on a nil manager.
func (m *manager) IdentityFingerprint() string {
	var fp string
	if m != nil {
		fp = m.identityFingerprint
	}
	return fp
}
// IdentityCert returns a (shallow) copy of the depot identity certificate.
//
// It returns the zero x509.Certificate when the manager is nil or has no
// identity certificate, matching the nil guards used by the other manager
// methods instead of panicking on a nil dereference.
func (m *manager) IdentityCert() x509.Certificate {
	if m == nil || m.identityCert == nil {
		return x509.Certificate{}
	}
	return *m.identityCert
}
// SignClientCSR issues a client certificate for clientCSR, signed by the
// identity certificate and private key, and persists it via the CertStore.
//
// subject is the DeviceID of the device about to be MDM enrolled; it is used as
// the CommonName of the certificate. clientCSR is the client certificate
// signing request and must be non-nil.
//
// It returns the DER-encoded certificate and its uppercased, hex-encoded sha1
// fingerprint.
func (m *manager) SignClientCSR(ctx context.Context, subject string, clientCSR *x509.CertificateRequest) ([]byte, string, error) {
	if m == nil {
		return nil, "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return nil, "", errors.New("invalid identity certificate or private key")
	}
	// Check before allocating a serial number: populateClientCert and
	// x509.CreateCertificate both dereference the CSR.
	if clientCSR == nil {
		return nil, "", errors.New("invalid client certificate signing request")
	}

	// serial number is used to uniquely identify the certificate
	sn, err := m.store.WSTEPNewSerial(ctx)
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate serial number: %w", err)
	}

	// populate the client certificate template
	tmpl, err := populateClientCert(sn, subject, m.identityCert, clientCSR)
	if err != nil {
		return nil, "", fmt.Errorf("failed to populate client certificate: %w", err)
	}

	rawSignedDER, err := x509.CreateCertificate(rand.Reader, tmpl, m.identityCert, clientCSR.PublicKey, m.identityPrivateKey)
	if err != nil {
		return nil, "", fmt.Errorf("failed to sign client certificate: %w", err)
	}

	signedCert, err := x509.ParseCertificate(rawSignedDER)
	if err != nil {
		return nil, "", fmt.Errorf("failed to parse client certificate: %w", err)
	}

	if err := m.store.WSTEPStoreCertificate(ctx, subject, signedCert); err != nil {
		return nil, "", fmt.Errorf("failed to store client certificate: %w", err)
	}

	return rawSignedDER, CertFingerprintHexStr(signedCert), nil
}
// NewSTSAuthToken returns an RS256-signed STS auth token carrying the given UPN
// claim.
//
// The token is valid for 10 minutes from issuance. Its subject is
// "STSAuthToken", which distinguishes it from other JWTs signed with the same
// identity key (e.g. EUA tokens).
func (m *manager) NewSTSAuthToken(upn string) (string, error) {
	if m == nil {
		return "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return "", errors.New("invalid identity certificate or private key")
	}
	if len(upn) == 0 {
		return "", errors.New("invalid upn field")
	}

	// Read the clock once so that iat, nbf and exp are mutually consistent.
	now := time.Now()
	claims := STSClaims{
		UPN: upn,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Subject:   "STSAuthToken",
		},
	}

	// Sign with the identity private key; verification uses the certificate's public key.
	signedToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(m.identityPrivateKey)
	if err != nil {
		return "", fmt.Errorf("failed to sign STS token: %w", err)
	}
	return signedToken, nil
}
// NewEUAToken returns a Fleet-signed (RS256) JWT for the given UPN and Windows
// MDM device ID.
//
// The token is valid for one hour from issuance. Its subject is "EUAToken",
// which distinguishes it from other JWTs signed with the same identity key
// (e.g. STS auth tokens).
func (m *manager) NewEUAToken(upn string, deviceID string) (string, error) {
	if m == nil {
		return "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return "", errors.New("invalid identity certificate or private key")
	}
	if len(upn) == 0 {
		return "", errors.New("invalid upn field")
	}
	if len(deviceID) == 0 {
		return "", errors.New("invalid device_id field")
	}

	// Read the clock once so that iat, nbf and exp are mutually consistent.
	now := time.Now()
	claims := euaJWTClaims{
		UPN:      upn,
		DeviceID: deviceID,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(1 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Subject:   "EUAToken",
		},
	}

	signedToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(m.identityPrivateKey)
	if err != nil {
		return "", fmt.Errorf("failed to sign EUA token: %w", err)
	}
	return signedToken, nil
}
// GetEUATokenClaims validates the given EUA token and returns the parsed claims.
//
// The token must be RSA-signed by the identity key, currently valid, carry
// non-empty "upn" and "device_id" claims, and have the subject "EUAToken".
// Checking the subject matters because STS and EUA tokens are signed with the
// same key; without it, a token of one kind could be replayed as the other.
func (m *manager) GetEUATokenClaims(tokenStr string) (*EUATokenClaims, error) {
	if m == nil {
		return nil, errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return nil, errors.New("invalid identity certificate or private key")
	}
	if len(tokenStr) == 0 {
		return nil, errors.New("invalid EUA token")
	}

	token, err := jwt.ParseWithClaims(tokenStr, &euaJWTClaims{}, func(token *jwt.Token) (any, error) {
		// golang-jwt/jwt/v4 does not restrict the algorithm itself; only accept
		// RSA so the certificate's public key is never used with another scheme.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.identityCert.PublicKey, nil
	})
	if err != nil {
		return nil, fmt.Errorf("there was an error parsing the EUA token claims: %w", err)
	}

	claims, ok := token.Claims.(*euaJWTClaims)
	if !ok || !token.Valid {
		return nil, errors.New("issue with EUA token validation")
	}
	if len(claims.UPN) == 0 {
		return nil, errors.New("issue with UPN token claim")
	}
	if len(claims.DeviceID) == 0 {
		return nil, errors.New("issue with device_id token claim")
	}
	if claims.Subject != "EUAToken" {
		return nil, errors.New("issue with EUA token subject claim")
	}
	return &EUATokenClaims{UPN: claims.UPN, DeviceID: claims.DeviceID}, nil
}
// GetSTSAuthTokenUPNClaim validates the given STS token and returns its UPN claim.
//
// The token must be RSA-signed by the identity key, currently valid, carry a
// non-empty "upn" claim, and have the subject "STSAuthToken". Checking the
// subject matters because other Fleet JWTs (e.g. EUA tokens, which also carry
// a "upn" claim) are signed with the same key and would otherwise be accepted
// as STS tokens.
func (m *manager) GetSTSAuthTokenUPNClaim(tokenStr string) (string, error) {
	if m == nil {
		return "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return "", errors.New("invalid identity certificate or private key")
	}
	if len(tokenStr) == 0 {
		return "", errors.New("invalid STS token")
	}

	// Since we used the private key to sign the tokens, we use the public counterpart to verify the signature
	token, err := jwt.ParseWithClaims(tokenStr, &STSClaims{}, func(token *jwt.Token) (any, error) {
		// golang-jwt/jwt/v4 does not restrict the algorithm itself; only accept RSA.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return m.identityCert.PublicKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("there was an error parsing the STS token claims: %w", err)
	}

	claims, ok := token.Claims.(*STSClaims)
	if !ok || !token.Valid {
		return "", errors.New("issue with STS token validation")
	}
	if len(claims.UPN) == 0 {
		return "", errors.New("issue with UPN token claim")
	}
	if claims.Subject != "STSAuthToken" {
		return "", errors.New("issue with STS token subject claim")
	}
	return claims.UPN, nil
}
// GetAzureAuthTokenClaims validates the given Azure AD token and returns the
// UPN, audience, tenant ID, unique name and SCP claims.
//
// tokenStr is the standard-base64 encoding of a compact JWT. Its signature is
// verified against Microsoft's published JWKS, or against the JWKS at
// FLEET_DEV_AZURE_JWT_JWKS_URI when that dev-mode variable is set. The token
// must carry a non-empty "upn", a UUID "tid", an "iss" starting with one of the
// tenant's issuer URLs, a "unique_name", and an "scp" equal to "mdm_delegation".
func GetAzureAuthTokenClaims(ctx context.Context, tokenStr string) (AzureData, error) {
	if len(tokenStr) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid STS token")
	}

	// Decode base64 token
	tokenBytes, err := base64.StdEncoding.DecodeString(tokenStr)
	if err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "invalid Azure JWT token")
	}

	// Cheap structural check (header.payload.signature) before any network access.
	if parts := bytes.Split(tokenBytes, []byte(".")); len(parts) != 3 {
		return AzureData{}, ctxerr.New(ctx, "invalid Azure JWT format")
	}

	jwksURI := "https://login.microsoftonline.com/common/discovery/v2.0/keys"
	if devJWKSURI := dev_mode.Env("FLEET_DEV_AZURE_JWT_JWKS_URI"); devJWKSURI != "" {
		jwksURI = devJWKSURI
	}
	// NOTE(review): a new JWKS client is created on every call. Depending on the
	// jwkset version, NewDefaultHTTPClient may also start a periodic background
	// refresh tied to context.Background(); confirm this does not accumulate
	// goroutines, and consider a shared, long-lived client.
	keys, err := jwkset.NewDefaultHTTPClient([]string{jwksURI})
	if err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "failed to retrieve Azure JWT signing keys")
	}

	token, err := jwt.Parse(string(tokenBytes), func(token *jwt.Token) (any, error) {
		tokenAlg, ok := token.Header["alg"]
		if !ok {
			return nil, errors.New("Azure JWT missing alg header")
		}
		tokenAlgStr, ok := tokenAlg.(string)
		if !ok {
			return nil, errors.New("invalid alg header in Azure JWT")
		}
		kid, ok := token.Header["kid"]
		if !ok {
			return nil, errors.New("Azure JWT missing kid header")
		}
		kidStr, ok := kid.(string)
		if !ok {
			return nil, errors.New("invalid kid header in Azure JWT")
		}
		key, err := keys.KeyRead(ctx, kidStr)
		if err != nil {
			if errors.Is(err, jwkset.ErrKeyNotFound) {
				return nil, fmt.Errorf("Azure JWT signed by unknown key: %w", err)
			}
			return nil, fmt.Errorf("failed to retrieve Azure JWT signing key: %w", err)
		}
		// Alg is optional in the JWK but if present must match the token
		keyAlg := key.Marshal().ALG.String()
		if keyAlg != "" && keyAlg != tokenAlgStr {
			return nil, fmt.Errorf("Azure JWT signing key algorithm mismatch: expected %s from key, got %s", keyAlg, tokenAlgStr)
		}
		return key.Key(), nil
	})
	if err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "parse error Azure JWT content")
	}

	// jwt.Parse produces MapClaims; check the assertion rather than risk a panic.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return AzureData{}, ctxerr.New(ctx, "invalid Azure JWT claims")
	}

	// Get UPN claim
	upnClaim, ok := claims["upn"].(string)
	if !ok || len(upnClaim) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid UPN claim")
	}

	// Get TenantID claim
	tenantIDClaim, ok := claims["tid"].(string)
	if !ok || len(tenantIDClaim) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid TenantID claim")
	}

	// Validate that tenant ID is a UUID and matches the issuer
	if _, err := uuid.Parse(tenantIDClaim); err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "invalid TenantID claim format")
	}
	issuer, ok := claims["iss"].(string)
	if !ok || len(issuer) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid Issuer claim")
	}
	// Depending on exactly how the Azure AD app is configured, the issuer claim
	// may vary. Validate that the issuer contains the tenant ID.
	issuerMatchesTenant := false
	for _, expectedIssuer := range []string{fmt.Sprintf("https://sts.windows.net/%s/", tenantIDClaim), fmt.Sprintf("https://login.microsoftonline.com/%s/", tenantIDClaim)} {
		if strings.HasPrefix(issuer, expectedIssuer) {
			issuerMatchesTenant = true
			break
		}
	}
	if !issuerMatchesTenant {
		return AzureData{}, ctxerr.New(ctx, "issuer claim does not match tenant ID")
	}

	audience := azureAudienceClaim(claims)

	// Get UniqueName claim
	uniqueNameClaim, ok := claims["unique_name"].(string)
	if !ok {
		return AzureData{}, ctxerr.New(ctx, "invalid UniqueName claim")
	}

	// Get SCP claim
	azureSCPClaim, ok := claims["scp"].(string)
	if !ok || azureSCPClaim != "mdm_delegation" {
		return AzureData{}, ctxerr.New(ctx, "invalid SCP claim")
	}

	return AzureData{
		UPN:        upnClaim,
		TenantID:   tenantIDClaim,
		UniqueName: uniqueNameClaim,
		SCP:        azureSCPClaim,
		Audience:   audience,
	}, nil
}

// azureAudienceClaim returns the "aud" claim as a slice of strings.
//
// Per RFC 7519 the claim is either a single string or an array of strings.
// Because MapClaims is filled by encoding/json, an array arrives as []any,
// not []string, so both shapes are handled here. Non-string array elements
// are skipped. A missing or otherwise-typed claim yields an empty, non-nil slice.
func azureAudienceClaim(claims jwt.MapClaims) []string {
	switch aud := claims["aud"].(type) {
	case string:
		return []string{aud}
	case []any:
		audience := make([]string, 0, len(aud))
		for _, v := range aud {
			if s, ok := v.(string); ok {
				audience = append(audience, s)
			}
		}
		return audience
	case []string:
		// Only reachable if the claims map was built programmatically.
		return aud
	default:
		return []string{}
	}
}
// populateClientCert builds the x509 template for a client certificate issued
// in response to csr.
//
// The certificate uses serial number sn and CommonName subject, with the Fleet
// provisioning app ID as its OU. Its validity starts one renewal period
// (syncml.PolicyCertRenewalPeriodInSecs) before now and lasts 365 days from
// that backdated start. It is restricted to client authentication.
// NOTE(review): x509.CreateCertificate does not read Issuer, Version, Signature
// or Extensions from the template (see its docs); they are kept for parity.
func populateClientCert(sn *big.Int, subject string, issuerCert *x509.Certificate, csr *x509.CertificateRequest) (*x509.Certificate, error) {
	renewalSecs, err := strconv.Atoi(syncml.PolicyCertRenewalPeriodInSecs)
	if err != nil {
		return nil, fmt.Errorf("invalid renewal time: %w", err)
	}

	validFrom := time.Now().Add(-time.Duration(renewalSecs) * time.Second)
	validUntil := validFrom.Add(365 * 24 * time.Hour)

	return &x509.Certificate{
		SerialNumber: sn,
		Version:      csr.Version,
		Subject: pkix.Name{
			CommonName:         subject,
			OrganizationalUnit: []string{syncml.DocProvisioningAppProviderID},
		},
		Issuer:    issuerCert.Issuer,
		NotBefore: validFrom,
		NotAfter:  validUntil,

		PublicKey:          csr.PublicKey,
		PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
		Signature:          csr.Signature,
		SignatureAlgorithm: x509.SHA256WithRSA,

		Extensions:      csr.Extensions,
		ExtraExtensions: csr.ExtraExtensions,
		DNSNames:        csr.DNSNames,
		EmailAddresses:  csr.EmailAddresses,
		IPAddresses:     csr.IPAddresses,
		URIs:            csr.URIs,

		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		BasicConstraintsValid: true,
		IsCA:                  false,
	}, nil
}
// GetClientCSR returns the client certificate signing request carried in a
// WS-Trust BinarySecurityToken.
//
// binSecTokenData is the standard-base64 token payload. tokenType must be
// syncml.EnrollReqTypePKCS10 or syncml.EnrollReqTypePKCS7. For PKCS#7 tokens,
// the envelope's signatures are verified and the current time must fall
// within the single signer certificate's validity window. In all cases, the
// CSR must parse, have a valid self-signature, a public key and a non-empty
// subject.
func GetClientCSR(binSecTokenData string, tokenType string) (*x509.CertificateRequest, error) {
	// Checking if this is a valid enroll security token (CSR)
	if tokenType != syncml.EnrollReqTypePKCS10 && tokenType != syncml.EnrollReqTypePKCS7 {
		return nil, fmt.Errorf("token type is not valid for MDM enrollment: %s", tokenType)
	}

	// Decoding the Base64 encoded binary security token to obtain the client CSR bytes
	rawCSR, err := base64.StdEncoding.DecodeString(binSecTokenData)
	if err != nil {
		return nil, fmt.Errorf("decoding the binary security token: %w", err)
	}

	// Extra sanity checks on the PKCS#7 envelope
	if tokenType == syncml.EnrollReqTypePKCS7 {
		pk7CSR, err := pkcs7.Parse(rawCSR)
		if err != nil {
			return nil, fmt.Errorf("parsing the binary security token: %w", err)
		}

		// Verify the signatures of the CSR PKCS7 object
		if err := pk7CSR.Verify(); err != nil {
			return nil, fmt.Errorf("verifying CSR data: %w", err)
		}

		// GetOnlySigner can return nil (e.g. when the envelope does not have
		// exactly one signer); the input is untrusted, so check before use.
		signer := pk7CSR.GetOnlySigner()
		if signer == nil {
			return nil, errors.New("verifying CSR data: expected exactly one signer certificate")
		}

		// The signer certificate must be valid right now.
		now := time.Now()
		if now.Before(signer.NotBefore) || now.After(signer.NotAfter) {
			return nil, errors.New("invalid CSR signing time")
		}
	}

	// Decode and verify CSR
	certCSR, err := ParseCertificateRequestFromWindowsDevice(rawCSR)
	if err != nil {
		return nil, fmt.Errorf("parsing CSR data: %w", err)
	}
	if err := certCSR.CheckSignature(); err != nil {
		return nil, fmt.Errorf("CSR signature: %w", err)
	}
	if certCSR.PublicKey == nil {
		return nil, errors.New("CSR public key: missing")
	}
	if len(certCSR.Subject.String()) == 0 {
		return nil, errors.New("CSR subject: empty")
	}

	return certCSR, nil
}
// CertFingerprintHexStr computes the SHA-1 digest of the certificate's raw DER
// bytes and returns it as an uppercase hexadecimal string.
//
// The Windows Certificate Store identifies certificates by this "thumbprint",
// which is exactly the SHA-1 fingerprint. See also:
// https://security.stackexchange.com/questions/14330/what-is-the-actual-value-of-a-certificate-fingerprint
// https://www.thesslstore.com/blog/ssl-certificate-still-sha-1-thumbprint/
func CertFingerprintHexStr(cert *x509.Certificate) string {
	digest := sha1.Sum(cert.Raw) //nolint:gosec
	lowerHex := hex.EncodeToString(digest[:])
	return strings.ToUpper(lowerHex)
}