fleet/server/mdm/microsoft/wstep.go
Ian Littman 2f25580c3a
Only allow FLEET_DEV_* env vars when --dev is passed, allow overriding configs one at a time in dev (#38652)
Resolves #38484. This includes a CI job change to make sure we don't
introduce any more env vars that don't get proxied (and thus turned off
outside `--dev`).

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/guides/committing-changes.md#changes-files)
for more information.

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)

## Testing

- [x] Added/updated automated tests

Manual QA touched hot paths, but did _not_ manually test every
FLEET_DEV_* environment variable change.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Centralized dev-mode environment management for consistent FLEET_DEV_*
handling and test-friendly overrides.
* Dev-mode allows targeted overrides for certain dev-only configuration
when running with --dev.

* **Chores**
* Migrated environment access to the centralized dev-mode helper across
the codebase.
  * Added CI checks to enforce proper usage of FLEET_DEV_* variables.

* **Documentation**
  * Added guidance on dev-mode environment variable rules and overrides.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Victor Lyuboslavsky <2685025+getvictor@users.noreply.github.com>
2026-01-27 14:32:56 -06:00

448 lines
14 KiB
Go

package microsoft_mdm
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha1" //nolint:gosec
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"math/big"
"strconv"
"strings"
"time"
"github.com/MicahParks/jwkset"
"github.com/fleetdm/fleet/v4/server"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/dev_mode"
"github.com/fleetdm/fleet/v4/server/mdm/microsoft/syncml"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil"
"github.com/golang-jwt/jwt/v4"
"github.com/smallstep/pkcs7"
)
// CertManager performs the certificate-management tasks needed by Microsoft MDM enrollment:
// exposing the identity (depot) certificate, signing device CSRs, and issuing and validating
// STS auth tokens.
type CertManager interface {
	// IdentityCert returns the identity certificate of the depot (by value).
	IdentityCert() x509.Certificate
	// IdentityFingerprint returns the uppercase, hex-encoded SHA-1 fingerprint of the identity
	// certificate.
	IdentityFingerprint() string

	// SignClientCSR signs clientCSR and returns the DER-encoded certificate together with its
	// uppercase, hex-encoded SHA-1 fingerprint. subject is used as the CommonName of the signed
	// certificate.
	SignClientCSR(ctx context.Context, subject string, clientCSR *x509.CertificateRequest) ([]byte, string, error)

	// NewSTSAuthToken returns an STS auth token carrying the given UPN claim.
	NewSTSAuthToken(upn string) (string, error)
	// GetSTSAuthTokenUPNClaim validates token and returns its UPN claim.
	GetSTSAuthTokenUPNClaim(token string) (string, error)

	// TODO: implement other methods as needed:
	// - verify certificate-device association
	// - certificate lifecycle management (e.g., renewal, revocation)
}
// CertStore is the persistence layer for MS-WSTEP certificate issuance in the MS-MDE2
// protocol. It is implemented by fleet.Datastore.
type CertStore interface {
	// WSTEPNewSerial allocates the serial number for a new client certificate.
	WSTEPNewSerial(ctx context.Context) (*big.Int, error)
	// WSTEPStoreCertificate persists an issued client certificate under name (SignClientCSR
	// passes the certificate subject).
	WSTEPStoreCertificate(ctx context.Context, name string, crt *x509.Certificate) error
	// WSTEPAssociateCertHash associates a certificate hash with the given device UUID.
	WSTEPAssociateCertHash(ctx context.Context, deviceUUID string, hash string) error
}
// STSClaims is the claim set of the STS auth tokens created by (*manager).NewSTSAuthToken and
// parsed by (*manager).GetSTSAuthTokenUPNClaim: a UPN claim plus the standard JWT registered
// claims.
type STSClaims struct {
	// UPN is serialized as the "upn" claim.
	UPN string `json:"upn"`
	jwt.RegisteredClaims
}
// AzureData holds the claims extracted from a validated Azure AD token by
// GetAzureAuthTokenClaims.
type AzureData struct {
	// UPN is the "upn" claim; GetAzureAuthTokenClaims rejects tokens where it is empty.
	UPN string
	// TenantID is the "tid" claim; GetAzureAuthTokenClaims rejects tokens where it is empty.
	TenantID string
	// UniqueName is the "unique_name" claim; it must be a string but may be empty.
	UniqueName string
	// SCP is the "scp" claim; GetAzureAuthTokenClaims only accepts "mdm_delegation".
	SCP string
}
// manager is the CertManager implementation returned by NewCertManager. All of its methods
// accept a nil *manager receiver and then return zero values or a "not configured" error.
type manager struct {
	// store persists issued client certificates and allocates their serial numbers.
	store CertStore
	// identityCert holds the identity certificate of the depot. It is the parent of issued
	// client certificates, and its public key verifies STS tokens.
	identityCert *x509.Certificate
	// identityPrivateKey holds the private key of the depot. It signs client certificates and
	// STS tokens.
	identityPrivateKey *rsa.PrivateKey
	// identityFingerprint holds the uppercase, hex-encoded SHA-1 fingerprint of identityCert.
	identityFingerprint string
	// maxSerialNumber holds an upper bound for certificate serial numbers (newManager sets it to
	// 2^128).
	// NOTE(review): RFC 5280 caps serial numbers at 20 octets (not "2^160" as previously
	// stated). This field is not read by any code visible in this file; serials come from
	// CertStore.WSTEPNewSerial. Confirm whether it is still needed.
	maxSerialNumber *big.Int
}
// NewCertManager builds a CertManager from a PEM-encoded identity certificate and private key.
// store is used to allocate serial numbers and persist issued certificates.
//
// When err is non-nil, the returned CertManager is not a nil interface: it wraps a nil
// *manager. The manager's methods handle a nil receiver by returning zero values or a
// "not configured" error.
func NewCertManager(store CertStore, certPEM []byte, privKeyPEM []byte) (CertManager, error) {
	mgr, err := newManager(store, certPEM, privKeyPEM)
	var cm CertManager = mgr // intentionally keeps the typed-nil *manager on error
	return cm, err
}
// newManager decodes the PEM-encoded identity certificate and private key and returns a
// manager that issues client certificates and STS tokens with that keypair.
//
// If either PEM input cannot be decoded, the wrapped decoding error is returned with a nil
// manager.
func newManager(store CertStore, certPEM []byte, privKeyPEM []byte) (*manager, error) {
	crt, err := cryptoutil.DecodePEMCertificate(certPEM)
	if err != nil {
		return nil, fmt.Errorf("decode certificate: %w", err)
	}
	key, err := server.DecodePrivateKeyPEM(privKeyPEM)
	if err != nil {
		return nil, fmt.Errorf("decode private key: %w", err)
	}
	return &manager{
		store:               store,
		identityCert:        crt,
		identityPrivateKey:  key,
		identityFingerprint: CertFingerprintHexStr(crt),
		maxSerialNumber:     new(big.Int).Lsh(big.NewInt(1), 128), // 2^128
	}, nil
}
// IdentityFingerprint returns the uppercase, hex-encoded SHA-1 fingerprint of the identity
// certificate, or "" when the manager is nil.
func (m *manager) IdentityFingerprint() string {
	if m != nil {
		return m.identityFingerprint
	}
	return ""
}
// IdentityCert returns a shallow copy of the identity certificate of the depot. It returns the
// zero x509.Certificate when the manager, or its identity certificate, is nil. Previously a
// manager with a nil identityCert caused a nil-pointer panic.
func (m *manager) IdentityCert() x509.Certificate {
	if m == nil || m.identityCert == nil {
		return x509.Certificate{}
	}
	return *m.identityCert
}
// SignClientCSR issues a client certificate for clientCSR, signed by the identity keypair.
//
// subject is the DeviceID of the device being MDM-enrolled. It becomes the certificate's
// CommonName and the name under which the certificate is stored. A new serial number is
// obtained from the store, and the signed certificate is persisted via
// CertStore.WSTEPStoreCertificate before returning.
//
// It returns the DER-encoded certificate and its uppercase, hex-encoded SHA-1 fingerprint.
func (m *manager) SignClientCSR(ctx context.Context, subject string, clientCSR *x509.CertificateRequest) ([]byte, string, error) {
	if m == nil {
		return nil, "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return nil, "", errors.New("invalid identity certificate or private key")
	}
	// Reject a missing CSR before allocating a serial number. populateClientCert would
	// otherwise dereference it and panic after the serial was already consumed.
	if clientCSR == nil {
		return nil, "", errors.New("missing client certificate signing request")
	}

	// The serial number uniquely identifies the issued certificate.
	sn, err := m.store.WSTEPNewSerial(ctx)
	if err != nil {
		return nil, "", fmt.Errorf("failed to generate serial number: %w", err)
	}

	tmpl, err := populateClientCert(sn, subject, m.identityCert, clientCSR)
	if err != nil {
		return nil, "", fmt.Errorf("failed to populate client certificate: %w", err)
	}

	rawSignedDER, err := x509.CreateCertificate(rand.Reader, tmpl, m.identityCert, clientCSR.PublicKey, m.identityPrivateKey)
	if err != nil {
		return nil, "", fmt.Errorf("failed to sign client certificate: %w", err)
	}

	// Re-parse the DER so the stored certificate and its fingerprint reflect exactly what was
	// signed.
	signedCert, err := x509.ParseCertificate(rawSignedDER)
	if err != nil {
		return nil, "", fmt.Errorf("failed to parse client certificate: %w", err)
	}
	if err := m.store.WSTEPStoreCertificate(ctx, subject, signedCert); err != nil {
		return nil, "", fmt.Errorf("failed to store client certificate: %w", err)
	}
	return rawSignedDER, CertFingerprintHexStr(signedCert), nil
}
// NewSTSAuthToken returns an RS256-signed STS auth token carrying the given UPN claim.
//
// The token is signed with the identity private key, has subject "STSAuthToken", and is valid
// from issuance for 10 minutes. It returns an error if the manager is not configured or upn is
// empty.
func (m *manager) NewSTSAuthToken(upn string) (string, error) {
	if m == nil {
		return "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return "", errors.New("invalid identity certificate or private key")
	}
	if len(upn) == 0 {
		return "", errors.New("invalid upn field")
	}

	const stsTokenLifetime = 10 * time.Minute
	// Read the clock once so iat, nbf and exp all derive from the same instant. Separate
	// time.Now() calls can straddle a second boundary, because NumericDate truncates.
	now := time.Now()
	claims := STSClaims{
		UPN: upn,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(stsTokenLifetime)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
			Subject:   "STSAuthToken",
		},
	}

	signedToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(m.identityPrivateKey)
	if err != nil {
		return "", fmt.Errorf("failed to sign STS token: %w", err)
	}
	return signedToken, nil
}
// GetSTSAuthTokenUPNClaim validates an STS auth token produced by NewSTSAuthToken and returns
// its UPN claim.
//
// The token must be signed with RS256 and verify against the identity certificate's public
// key. Its registered time claims are checked by the JWT parser. It returns an error if the
// manager is not configured, the token is empty or invalid, or the UPN claim is empty.
func (m *manager) GetSTSAuthTokenUPNClaim(tokenStr string) (string, error) {
	if m == nil {
		return "", errors.New("windows mdm identity keypair was not configured")
	}
	if m.identityCert == nil || m.identityPrivateKey == nil {
		return "", errors.New("invalid identity certificate or private key")
	}
	if len(tokenStr) == 0 {
		return "", errors.New("invalid STS token")
	}

	// Tokens are signed with the identity private key, so verify with the certificate's public
	// key. Only accept the algorithm NewSTSAuthToken uses, so the token header cannot select a
	// different verification scheme (e.g. "none" or an HMAC variant).
	token, err := jwt.ParseWithClaims(tokenStr, &STSClaims{}, func(token *jwt.Token) (any, error) {
		if token.Method == nil || token.Method.Alg() != jwt.SigningMethodRS256.Alg() {
			return nil, fmt.Errorf("unexpected STS token signing method: %v", token.Header["alg"])
		}
		return m.identityCert.PublicKey, nil
	})
	if err != nil {
		return "", fmt.Errorf("there was an error parsing the STS token claims: %w", err)
	}

	claims, ok := token.Claims.(*STSClaims)
	if !ok || !token.Valid {
		return "", errors.New("issue with STS token validation")
	}
	if len(claims.UPN) == 0 {
		return "", errors.New("issue with UPN token claim")
	}
	return claims.UPN, nil
}
// GetAzureAuthTokenClaims validates a base64-encoded Azure AD JWT and returns its UPN, TenantID,
// UniqueName and SCP claims.
//
// tokenStr is the standard-base64 encoding of a compact JWT. The signature is verified against
// the JWKS at Microsoft's common discovery endpoint. In dev mode this can be overridden with
// FLEET_DEV_AZURE_JWT_JWKS_URI. Registered time claims are validated by jwt.Parse.
//
// The token must carry non-empty "upn" and "tid" claims, a string "unique_name" claim, and an
// "scp" claim equal to "mdm_delegation".
//
// NOTE(review): the audience ("aud") and issuer ("iss") claims are not validated here. Confirm
// this is intended.
func GetAzureAuthTokenClaims(ctx context.Context, tokenStr string) (AzureData, error) {
	if len(tokenStr) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid STS token")
	}

	tokenBytes, err := base64.StdEncoding.DecodeString(tokenStr)
	if err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "invalid Azure JWT token")
	}

	// Cheap structural check (header.payload.signature) before doing any network work.
	if parts := bytes.Split(tokenBytes, []byte(".")); len(parts) != 3 {
		return AzureData{}, ctxerr.New(ctx, "invalid Azure JWT format")
	}

	jwksURI := "https://login.microsoftonline.com/common/discovery/v2.0/keys"
	if devJWKSURI := dev_mode.Env("FLEET_DEV_AZURE_JWT_JWKS_URI"); devJWKSURI != "" {
		jwksURI = devJWKSURI
	}

	// The jwkset HTTP client starts a background refresh goroutine that lives as long as the
	// context it is given. NewDefaultHTTPClient uses context.Background(), which leaks one
	// goroutine per call here. Bind it to a context cancelled when this function returns.
	keysCtx, cancelKeys := context.WithCancel(ctx)
	defer cancelKeys()
	keys, err := jwkset.NewDefaultHTTPClientCtx(keysCtx, []string{jwksURI})
	if err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "failed to retrieve Azure JWT signing keys")
	}

	token, err := jwt.Parse(string(tokenBytes), func(token *jwt.Token) (any, error) {
		tokenAlg, ok := token.Header["alg"]
		if !ok {
			return nil, errors.New("Azure JWT missing alg header")
		}
		tokenAlgStr, ok := tokenAlg.(string)
		if !ok {
			return nil, errors.New("invalid alg header in Azure JWT")
		}
		kid, ok := token.Header["kid"]
		if !ok {
			return nil, errors.New("Azure JWT missing kid header")
		}
		kidStr, ok := kid.(string)
		if !ok {
			return nil, errors.New("invalid kid header in Azure JWT")
		}
		key, err := keys.KeyRead(ctx, kidStr)
		if err != nil {
			if errors.Is(err, jwkset.ErrKeyNotFound) {
				return nil, fmt.Errorf("Azure JWT signed by unknown key: %w", err)
			}
			return nil, fmt.Errorf("failed to retrieve Azure JWT signing key: %w", err)
		}
		// Alg is optional in the JWK, but if present it must match the token's alg header.
		keyAlg := key.Marshal().ALG.String()
		if keyAlg != "" && keyAlg != tokenAlgStr {
			return nil, fmt.Errorf("Azure JWT signing key algorithm mismatch: expected %s from key, got %s", keyAlg, tokenAlgStr)
		}
		return key.Key(), nil
	})
	if err != nil {
		return AzureData{}, ctxerr.Wrap(ctx, err, "parse error Azure JWT content")
	}

	// jwt.Parse decodes into MapClaims; use comma-ok so an unexpected type is an error, not a
	// panic.
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return AzureData{}, ctxerr.New(ctx, "invalid Azure JWT claims")
	}

	upnClaim, ok := claims["upn"].(string)
	if !ok || len(upnClaim) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid UPN claim")
	}
	tenantIDClaim, ok := claims["tid"].(string)
	if !ok || len(tenantIDClaim) == 0 {
		return AzureData{}, ctxerr.New(ctx, "invalid TenantID claim")
	}
	uniqueNameClaim, ok := claims["unique_name"].(string)
	if !ok {
		return AzureData{}, ctxerr.New(ctx, "invalid UniqueName claim")
	}
	azureSCPClaim, ok := claims["scp"].(string)
	if !ok || azureSCPClaim != "mdm_delegation" {
		return AzureData{}, ctxerr.New(ctx, "invalid SCP claim")
	}

	return AzureData{
		UPN:        upnClaim,
		TenantID:   tenantIDClaim,
		UniqueName: uniqueNameClaim,
		SCP:        azureSCPClaim,
	}, nil
}
// populateClientCert builds the x509 template for a client certificate issued from csr.
//
// sn becomes the serial number. subject becomes the CommonName, under the organizational unit
// syncml.DocProvisioningAppProviderID. Validity starts syncml.PolicyCertRenewalPeriodInSecs
// seconds before now and lasts 365 days from that start. The result is a non-CA, client-auth
// certificate allowing digital signature and key encipherment, signed with SHA256WithRSA. The
// CSR's public key, extensions and subject alternative names are copied over.
//
// NOTE(review): per the x509.CreateCertificate docs, the template's Issuer, Version,
// PublicKey, PublicKeyAlgorithm, Signature and Extensions fields are not used; the issuer comes
// from the parent certificate and the key from the pub argument. The SANs copied from the CSR
// are chosen by the enrolling device — confirm that is intended.
func populateClientCert(sn *big.Int, subject string, issuerCert *x509.Certificate, csr *x509.CertificateRequest) (*x509.Certificate, error) {
	renewalSecs, err := strconv.Atoi(syncml.PolicyCertRenewalPeriodInSecs)
	if err != nil {
		return nil, fmt.Errorf("invalid renewal time: %w", err)
	}

	const validity = 365 * 24 * time.Hour
	notBefore := time.Now().Add(-time.Duration(renewalSecs) * time.Second)

	return &x509.Certificate{
		SerialNumber: sn,
		Subject: pkix.Name{
			OrganizationalUnit: []string{syncml.DocProvisioningAppProviderID},
			CommonName:         subject,
		},
		Issuer:    issuerCert.Issuer,
		NotBefore: notBefore,
		NotAfter:  notBefore.Add(validity),

		// Values carried over from the CSR.
		Version:            csr.Version,
		PublicKey:          csr.PublicKey,
		PublicKeyAlgorithm: csr.PublicKeyAlgorithm,
		Signature:          csr.Signature,
		Extensions:         csr.Extensions,
		ExtraExtensions:    csr.ExtraExtensions,
		IPAddresses:        csr.IPAddresses,
		EmailAddresses:     csr.EmailAddresses,
		DNSNames:           csr.DNSNames,
		URIs:               csr.URIs,

		// Usage constraints for an end-entity client certificate.
		SignatureAlgorithm:    x509.SHA256WithRSA,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		BasicConstraintsValid: true,
		IsCA:                  false,
	}, nil
}
// GetClientCSR decodes and validates the client certificate signing request carried in an
// enrollment BinarySecurityToken.
//
// binSecTokenData is the standard-base64 token value. tokenType is its declared value type and
// must be syncml.EnrollReqTypePKCS10 or syncml.EnrollReqTypePKCS7.
//
// For PKCS#7 tokens, the envelope signatures are verified, there must be exactly one signer,
// and that signer's certificate must be within its validity window at the current time. For
// both types, the decoded bytes are then parsed with ParseCertificateRequestFromWindowsDevice,
// the CSR self-signature is checked, and the CSR must have a public key and a non-empty
// subject.
func GetClientCSR(binSecTokenData string, tokenType string) (*x509.CertificateRequest, error) {
	if tokenType != syncml.EnrollReqTypePKCS10 && tokenType != syncml.EnrollReqTypePKCS7 {
		return nil, fmt.Errorf("token type is not valid for MDM enrollment: %s", tokenType)
	}

	rawCSR, err := base64.StdEncoding.DecodeString(binSecTokenData)
	if err != nil {
		return nil, fmt.Errorf("decoding the binary security token: %w", err)
	}

	if tokenType == syncml.EnrollReqTypePKCS7 {
		p7, err := pkcs7.Parse(rawCSR)
		if err != nil {
			return nil, fmt.Errorf("parsing the binary security token: %w", err)
		}
		if err := p7.Verify(); err != nil {
			return nil, fmt.Errorf("verifying CSR data: %w", err)
		}
		// GetOnlySigner returns nil unless the envelope has exactly one signer whose certificate
		// is present, so check it before dereferencing.
		signer := p7.GetOnlySigner()
		if signer == nil {
			return nil, errors.New("invalid CSR signer: expected exactly one signer certificate")
		}
		// This checks the signer certificate's validity window, not a PKCS#7 signing-time
		// attribute.
		now := time.Now()
		if now.Before(signer.NotBefore) || now.After(signer.NotAfter) {
			return nil, fmt.Errorf("invalid CSR signing time: signer certificate is not valid at %s", now.UTC().Format(time.RFC3339))
		}
	}

	certCSR, err := ParseCertificateRequestFromWindowsDevice(rawCSR)
	if err != nil {
		return nil, fmt.Errorf("parsing CSR data: %w", err)
	}
	if err := certCSR.CheckSignature(); err != nil {
		return nil, fmt.Errorf("CSR signature: %w", err)
	}
	// The checks below previously formatted an always-nil err, producing "...: <nil>".
	if certCSR.PublicKey == nil {
		return nil, errors.New("CSR public key: missing")
	}
	if len(certCSR.Subject.String()) == 0 {
		return nil, errors.New("CSR subject: empty")
	}
	return certCSR, nil
}
// CertFingerprintHexStr returns the SHA-1 digest of cert.Raw (the certificate's complete DER
// encoding), hex-encoded with uppercase letters.
//
// The Windows Certificate Store identifies certificates by this value (the certificate
// "thumbprint"), which is why SHA-1 is used. See also:
// https://security.stackexchange.com/questions/14330/what-is-the-actual-value-of-a-certificate-fingerprint
// https://www.thesslstore.com/blog/ssl-certificate-still-sha-1-thumbprint/
func CertFingerprintHexStr(cert *x509.Certificate) string {
	digest := sha1.Sum(cert.Raw) //nolint:gosec
	lowerHex := hex.EncodeToString(digest[:])
	return strings.ToUpper(lowerHex)
}