fix: address feedback

This commit is contained in:
Jahziel Villasana-Espinoza 2024-05-23 20:55:36 -04:00
parent ef52ff8f70
commit 2d8038ddd0
4 changed files with 45 additions and 28 deletions

View file

@ -4153,19 +4153,17 @@ SELECT
FROM
mdm_config_assets
WHERE
name IN (%s)
name IN (?)
AND deletion_uuid = ''
`
var b []any
var p strings.Builder
for _, an := range assetNames {
b = append(b, an)
p.WriteString("?,")
stmt, args, err := sqlx.In(stmt, assetNames)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "sqlx.In GetMDMConfigAssetsByName")
}
stmt = fmt.Sprintf(stmt, strings.TrimSuffix(p.String(), ","))
var res []fleet.MDMConfigAsset
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &res, stmt, b...); err != nil {
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &res, stmt, args...); err != nil {
return nil, ctxerr.Wrap(ctx, err, "get mdm config assets by name")
}

View file

@ -689,10 +689,10 @@ type Service interface {
GetAppleBM(ctx context.Context) (*AppleBM, error)
RequestMDMAppleCSR(ctx context.Context, email, org string) (*AppleCSR, error)
// GetMDMAppleCSR returns a signed CSR as a base64 encoded string for Apple MDM. The first time
// GetMDMAppleCSR returns a signed CSR as base64 encoded bytes for Apple MDM. The first time
// this method is called, it will create a SCEP certificate, a SCEP key, and an APNS key and
// write these to the DB. On subsequent calls, it will use the saved APNS key for generating the CSR.
GetMDMAppleCSR(ctx context.Context) (string, error)
GetMDMAppleCSR(ctx context.Context) ([]byte, error)
// GetHostDEPAssignment retrieves the host DEP assignment for the specified host.
GetHostDEPAssignment(ctx context.Context, host *Host) (*HostDEPAssignment, error)

View file

@ -60,9 +60,13 @@ func GenerateAPNSCSRKey(email, org string) (*x509.CertificateRequest, *rsa.Priva
return certReq, key, nil
}
func GenerateAPNSCSR(org string, key *rsa.PrivateKey) (*x509.CertificateRequest, error) {
func GenerateAPNSCSR(org, email string, key *rsa.PrivateKey) (*x509.CertificateRequest, error) {
subj := pkix.Name{
Organization: []string{org},
ExtraNames: []pkix.AttributeTypeAndValue{{
Type: emailAddressOID,
Value: email,
}},
}
template := &x509.CertificateRequest{
Subject: subj,
@ -142,8 +146,12 @@ func GetSignedAPNSCSR(client *http.Client, csr *x509.CertificateRequest) error {
return nil
}
type WebsiteResponse struct {
CSR []byte `json:"csr"`
}
// GetSignedAPNSCSRNoEmail makes a request to the fleetdm.com API to get a signed APNs
// CSR and returns the signed CSR.
// CSR and returns the signed CSR directly.
func GetSignedAPNSCSRNoEmail(client *http.Client, csr *x509.CertificateRequest) ([]byte, error) {
csrPEM := EncodeCertRequestPEM(csr)
@ -179,7 +187,12 @@ func GetSignedAPNSCSRNoEmail(client *http.Client, csr *x509.CertificateRequest)
return nil, FleetWebsiteError{Status: resp.StatusCode, message: string(respBytes)}
}
return respBytes, nil
var csrResp WebsiteResponse
if err := json.Unmarshal(respBytes, &csrResp); err != nil {
return nil, err
}
return csrResp.CSR, nil
}
// NewSCEPCACertKey creates a self-signed CA certificate for use with SCEP and

View file

@ -2120,7 +2120,7 @@ func (svc *Service) ResendHostMDMProfile(ctx context.Context, hostID uint, profi
type getMDMAppleCSRRequest struct{}
type getMDMAppleCSRResponse struct {
CSR string `json:"csr"` // base64 encoded
CSR []byte `json:"csr"` // base64 encoded
Err error `json:"error,omitempty"`
}
@ -2135,28 +2135,33 @@ func getMDMAppleCSREndpoint(ctx context.Context, request interface{}, svc fleet.
return &getMDMAppleCSRResponse{CSR: signedCSRB64}, nil
}
func (svc *Service) GetMDMAppleCSR(ctx context.Context) (string, error) {
func (svc *Service) GetMDMAppleCSR(ctx context.Context) ([]byte, error) {
if err := svc.authz.Authorize(ctx, &fleet.AppleCSR{}, fleet.ActionWrite); err != nil {
return "", err
return nil, ctxerr.Wrap(ctx, err)
}
vc, ok := viewer.FromContext(ctx)
if !ok {
return nil, fleet.ErrNoContext
}
// Check if we have existing certs and keys
var apnsKey *rsa.PrivateKey
savedAssets, err := svc.ds.GetMDMConfigAssetsByName(ctx, []fleet.MDMAssetName{fleet.MDMAssetCACert, fleet.MDMAssetCAKey, fleet.MDMAssetAPNSKey})
if err != nil {
return "", ctxerr.Wrap(ctx, err, "checking asset existence")
return nil, ctxerr.Wrap(ctx, err, "checking asset existence")
}
if len(savedAssets) == 0 {
// Then we should create them
scepCert, scepKey, err := apple_mdm.NewSCEPCACertKey()
if err != nil {
return "", ctxerr.Wrap(ctx, err, "generate SCEP cert and key")
return nil, ctxerr.Wrap(ctx, err, "generate SCEP cert and key")
}
apnsKey, err = apple_mdm.NewPrivateKey()
if err != nil {
return "", ctxerr.Wrap(ctx, err, "generate new apns private key")
return nil, ctxerr.Wrap(ctx, err, "generate new apns private key")
}
// Store our config assets
@ -2173,7 +2178,7 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) (string, error) {
}
if err := svc.ds.InsertMDMConfigAssets(ctx, assets); err != nil {
return "", ctxerr.Wrap(ctx, err, "inserting mdm config assets")
return nil, ctxerr.Wrap(ctx, err, "inserting mdm config assets")
}
} else {
for _, a := range savedAssets {
@ -2181,7 +2186,7 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) (string, error) {
block, _ := pem.Decode(a.Value)
apnsKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "unmarshaling saved apns key")
return nil, ctxerr.Wrap(ctx, err, "unmarshaling saved apns key")
}
}
}
@ -2190,12 +2195,12 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) (string, error) {
// Generate new APNS CSR every time this is called
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "get app config")
return nil, ctxerr.Wrap(ctx, err, "get app config")
}
apnsCSR, err := apple_mdm.GenerateAPNSCSR(appConfig.OrgInfo.OrgName, apnsKey)
apnsCSR, err := apple_mdm.GenerateAPNSCSR(appConfig.OrgInfo.OrgName, vc.Email(), apnsKey)
if err != nil {
return "", ctxerr.Wrap(ctx, err, "generate APNS cert and key")
return nil, ctxerr.Wrap(ctx, err, "generate APNS cert and key")
}
// Submit CSR to fleetdm.com for signing
@ -2203,8 +2208,9 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) (string, error) {
signedCSRB64, err := apple_mdm.GetSignedAPNSCSRNoEmail(websiteClient, apnsCSR)
if err != nil {
if _, ok := err.(apple_mdm.FleetWebsiteError); ok {
return "", ctxerr.Wrap(
var fwe apple_mdm.FleetWebsiteError
if errors.As(err, &fwe) {
return nil, ctxerr.Wrap(
ctx,
fleet.NewUserMessageError(
fmt.Errorf("FleetDM CSR request failed: %w", err),
@ -2212,9 +2218,9 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) (string, error) {
),
)
}
return "", ctxerr.Wrap(ctx, err, "get signed CSR")
return nil, ctxerr.Wrap(ctx, err, "get signed CSR")
}
// Return signed CSR; these bytes are already base64 encoded
return string(signedCSRB64), nil
return signedCSRB64, nil
}