Merge branch 'main' into feat-labels-scoped-software

This commit is contained in:
Gabriel Hernandez 2024-12-16 12:35:18 -06:00
commit 12bf9880ad
20 changed files with 664 additions and 125 deletions

View file

@ -0,0 +1 @@
* On policy deletion, any associated pending software installs or script executions are now deleted.

View file

@ -0,0 +1 @@
* Changed the script upload endpoint (`POST /api/v1/fleet/scripts`) to automatically convert CRLF line endings to LF.

View file

@ -0,0 +1 @@
Fixed potential deadlocks when deploying Apple configuration profiles.

View file

@ -2050,7 +2050,7 @@ conjunction with an STS role ARN to ensure that only the intended AWS account ca
### s3_software_installers_endpoint_url
AWS S3 Endpoint URL. Override when using a different S3 compatible object storage backend (such as Minio),
or running s3 locally with localstack. Leave this blank to use the default S3 service endpoint.
or running S3 locally with localstack. Leave this blank to use the default S3 service endpoint.
- Default value: none
- Environment variable: `FLEET_S3_SOFTWARE_INSTALLERS_ENDPOINT_URL`

View file

@ -2814,11 +2814,16 @@ func (ds *Datastore) UpdateOrDeleteHostMDMAppleProfile(ctx context.Context, prof
status = &fleet.MDMDeliveryVerified
}
_, err := ds.writer(ctx).ExecContext(ctx, `
// We need to run with retry due to potential deadlocks with BulkSetPendingMDMHostProfiles.
// Deadlock seen in 2024/12/12 loadtest: https://docs.google.com/document/d/1-Q6qFTd7CDm-lh7MVRgpNlNNJijk6JZ4KO49R1fp80U
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `
UPDATE host_mdm_apple_profiles
SET status = ?, operation_type = ?, detail = ?
WHERE host_uuid = ? AND command_uuid = ?
`, status, profile.OperationType, detail, profile.HostUUID, profile.CommandUUID)
return err
})
return err
}

View file

@ -0,0 +1,93 @@
package common_mysql
import (
"context"
"database/sql"
"errors"
"time"
"github.com/VividCortex/mysqlerr"
"github.com/cenkalti/backoff/v4"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/go-kit/log"
"github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
)
// DoRetryErr is a sentinel error: retryableError reports any error matching it
// (via errors.Is) as retryable, so a TxFn can return it to ask WithRetryTxx to
// run the transaction again.
var DoRetryErr = errors.New("fleet datastore retry")
// TxFn is the signature of a function that runs inside a database transaction.
// WithRetryTxx rolls the transaction back when a TxFn returns a non-nil error
// and commits it otherwise.
type TxFn func(tx sqlx.ExtContext) error
// WithRetryTxx provides a common way to commit/rollback a TxFn wrapped in a
// retry with exponential backoff.
//
// Every attempt begins a fresh transaction on db, runs fn inside it and then
// commits. Errors are classified as follows:
//   - an error from fn or from the commit that retryableError accepts is
//     returned as-is, so the attempt is retried;
//   - any other error from fn or the commit is marked permanent and ends the
//     retry loop;
//   - a rollback failure (other than the transaction already being done) is
//     always permanent;
//   - a failure to begin the transaction is not marked permanent, so it is
//     retried as well.
//
// If fn panics, the transaction is rolled back (a rollback failure is reported
// through logger) and the panic is re-raised.
//
// The retry loop is bounded by 5 retries and a maximum elapsed time of one
// minute.
func WithRetryTxx(ctx context.Context, db *sqlx.DB, fn TxFn, logger log.Logger) error {
	operation := func() error {
		tx, err := db.BeginTxx(ctx, nil)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "create transaction")
		}

		// Make sure a panic inside fn does not leave the transaction open.
		defer func() {
			if p := recover(); p != nil {
				if err := tx.Rollback(); err != nil {
					logger.Log("err", err, "msg", "error encountered during transaction panic rollback")
				}
				panic(p)
			}
		}()

		if err := fn(tx); err != nil {
			rbErr := tx.Rollback()
			// errors.Is (rather than !=) also recognizes a wrapped sql.ErrTxDone.
			if rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
				// Consider rollback errors to be non-retryable
				return backoff.Permanent(ctxerr.Wrapf(ctx, err, "got err '%s' rolling back after err", rbErr.Error()))
			}

			if retryableError(err) {
				return err
			}

			// Consider any other errors to be non-retryable
			return backoff.Permanent(err)
		}

		if err := tx.Commit(); err != nil {
			err = ctxerr.Wrap(ctx, err, "commit transaction")

			if retryableError(err) {
				return err
			}

			return backoff.Permanent(err)
		}

		return nil
	}

	expBo := backoff.NewExponentialBackOff()
	// MySQL innodb_lock_wait_timeout default is 50 seconds, so transaction can be waiting for a lock for several seconds.
	// Setting a higher MaxElapsedTime to increase probability that transaction will be retried.
	// This will reduce the number of retryable 'Deadlock found' errors. However, with a loaded DB, we will still see
	// 'Context cancelled' errors when the server drops long-lasting connections.
	expBo.MaxElapsedTime = 1 * time.Minute
	bo := backoff.WithMaxRetries(expBo, 5)
	return backoff.Retry(operation, bo)
}
// retryableError reports whether err is worth retrying. Errors are treated as
// non-retryable unless they are explicitly recognized here: a MySQL deadlock
// or lock-wait-timeout error, or an error matching DoRetryErr.
func retryableError(err error) bool {
	// Lock related MySQL errors have a chance of succeeding on a new attempt.
	if mysqlErr, isMySQL := ctxerr.Cause(err).(*mysql.MySQLError); isMySQL {
		if mysqlErr.Number == mysqlerr.ER_LOCK_DEADLOCK || mysqlErr.Number == mysqlerr.ER_LOCK_WAIT_TIMEOUT {
			return true
		}
	}

	// Otherwise only an explicit retry request qualifies.
	return errors.Is(err, DoRetryErr)
}

View file

@ -14,15 +14,14 @@ import (
"sync"
"time"
"github.com/VividCortex/mysqlerr"
"github.com/WatchBeam/clock"
"github.com/XSAM/otelsql"
"github.com/cenkalti/backoff/v4"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
"github.com/fleetdm/fleet/v4/server/config"
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
"github.com/fleetdm/fleet/v4/server/fleet"
@ -175,8 +174,6 @@ func (ds *Datastore) NewSCEPDepot() (scep_depot.Depot, error) {
return newSCEPDepot(ds.primary.DB, ds)
}
type txFn func(tx sqlx.ExtContext) error
type entity struct {
name string
}
@ -190,88 +187,12 @@ var (
usersTable = entity{"users"}
)
var doRetryErr = errors.New("fleet datastore retry")
// retryableError determines whether a MySQL error can be retried. By default
// errors are considered non-retryable. Only errors that we know have a
// possibility of succeeding on a retry should return true in this function.
func retryableError(err error) bool {
base := ctxerr.Cause(err)
if b, ok := base.(*mysql.MySQLError); ok {
switch b.Number {
// Consider lock related errors to be retryable
case mysqlerr.ER_LOCK_DEADLOCK, mysqlerr.ER_LOCK_WAIT_TIMEOUT:
return true
}
}
if errors.Is(err, doRetryErr) {
return true
}
return false
}
func (ds *Datastore) withRetryTxx(ctx context.Context, fn txFn) (err error) {
return withRetryTxx(ctx, ds.writer(ctx), fn, ds.logger)
}
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
func withRetryTxx(ctx context.Context, db *sqlx.DB, fn txFn, logger log.Logger) (err error) {
operation := func() error {
tx, err := db.BeginTxx(ctx, nil)
if err != nil {
return ctxerr.Wrap(ctx, err, "create transaction")
}
defer func() {
if p := recover(); p != nil {
if err := tx.Rollback(); err != nil {
logger.Log("err", err, "msg", "error encountered during transaction panic rollback")
}
panic(p)
}
}()
if err := fn(tx); err != nil {
rbErr := tx.Rollback()
if rbErr != nil && rbErr != sql.ErrTxDone {
// Consider rollback errors to be non-retryable
return backoff.Permanent(ctxerr.Wrapf(ctx, err, "got err '%s' rolling back after err", rbErr.Error()))
}
if retryableError(err) {
return err
}
// Consider any other errors to be non-retryable
return backoff.Permanent(err)
}
if err := tx.Commit(); err != nil {
err = ctxerr.Wrap(ctx, err, "commit transaction")
if retryableError(err) {
return err
}
return backoff.Permanent(err)
}
return nil
}
expBo := backoff.NewExponentialBackOff()
// MySQL innodb_lock_wait_timeout default is 50 seconds, so transaction can be waiting for a lock for several seconds.
// Setting a higher MaxElapsedTime to increase probability that transaction will be retried.
// This will reduce the number of retryable 'Deadlock found' errors. However, with a loaded DB, we will still see
// 'Context cancelled' errors when the server drops long-lasting connections.
expBo.MaxElapsedTime = 1 * time.Minute
bo := backoff.WithMaxRetries(expBo, 5)
return backoff.Retry(operation, bo)
func (ds *Datastore) withRetryTxx(ctx context.Context, fn common_mysql.TxFn) (err error) {
return common_mysql.WithRetryTxx(ctx, ds.writer(ctx), fn, ds.logger)
}
// withTx provides a common way to commit/rollback a txFn
func (ds *Datastore) withTx(ctx context.Context, fn txFn) (err error) {
func (ds *Datastore) withTx(ctx context.Context, fn common_mysql.TxFn) (err error) {
tx, err := ds.writer(ctx).BeginTxx(ctx, nil)
if err != nil {
return ctxerr.Wrap(ctx, err, "create transaction")

View file

@ -10,6 +10,7 @@ import (
abmctx "github.com/fleetdm/fleet/v4/server/contexts/apple_bm"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/assets"
nanodep_client "github.com/fleetdm/fleet/v4/server/mdm/nanodep/client"
@ -125,7 +126,7 @@ func (s *NanoMDMStorage) EnqueueDeviceLockCommand(
cmd *mdm.Command,
pin string,
) error {
return withRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
return common_mysql.WithRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
if err := enqueueCommandDB(ctx, tx, []string{host.UUID}, cmd); err != nil {
return err
}
@ -154,7 +155,7 @@ func (s *NanoMDMStorage) EnqueueDeviceLockCommand(
// EnqueueDeviceWipeCommand enqueues a EraseDevice command for the given host.
func (s *NanoMDMStorage) EnqueueDeviceWipeCommand(ctx context.Context, host *fleet.Host, cmd *mdm.Command) error {
return withRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
return common_mysql.WithRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
if err := enqueueCommandDB(ctx, tx, []string{host.UUID}, cmd); err != nil {
return err
}

View file

@ -6,6 +6,7 @@ import (
"errors"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/jmoiron/sqlx"
)
@ -93,7 +94,7 @@ func newOperatingSystemDB(ctx context.Context, tx sqlx.ExtContext, hostOS fleet.
case err == nil:
return storedOS, nil
case errors.Is(err, sql.ErrNoRows):
return nil, doRetryErr
return nil, common_mysql.DoRetryErr
default:
return nil, ctxerr.Wrap(ctx, err, "get new operating system")
}

View file

@ -13,6 +13,7 @@ import (
"golang.org/x/text/unicode/norm"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
kitlog "github.com/go-kit/log"
@ -238,7 +239,7 @@ func cleanupPolicy(
}
if _, isDB := extContext.(*sqlx.DB); isDB {
// wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job
err = withRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger)
err = common_mysql.WithRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger)
} else {
err = fn(extContext)
}
@ -583,6 +584,15 @@ func (ds *Datastore) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*fl
}
// DeleteGlobalPolicies deletes the global policies identified by ids.
//
// Before the policies themselves are removed, the pending software installs
// and pending host script executions associated with each policy are deleted.
// The first failure aborts the operation and is returned wrapped. The result
// is whatever deletePolicyDB returns.
func (ds *Datastore) DeleteGlobalPolicies(ctx context.Context, ids []uint) ([]uint, error) {
	for i := range ids {
		policyID := ids[i]

		// A nil team ID selects the global scope.
		err := ds.deletePendingSoftwareInstallsForPolicy(ctx, nil, policyID)
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "delete pending software installs for policy")
		}

		err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, nil, policyID)
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
		}
	}

	deleted, err := deletePolicyDB(ctx, ds.writer(ctx), ids, nil)
	return deleted, err
}
@ -736,6 +746,15 @@ func (ds *Datastore) ListMergedTeamPolicies(ctx context.Context, teamID uint, op
}
// DeleteTeamPolicies deletes the policies identified by ids that belong to
// the team teamID.
//
// Before the policies themselves are removed, the pending software installs
// and pending host script executions associated with each policy are deleted.
// The first failure aborts the operation and is returned wrapped. The result
// is whatever deletePolicyDB returns.
func (ds *Datastore) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {
	for i := range ids {
		policyID := ids[i]

		err := ds.deletePendingSoftwareInstallsForPolicy(ctx, &teamID, policyID)
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "delete pending software installs for policy")
		}

		err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &teamID, policyID)
		if err != nil {
			return nil, ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
		}
	}

	deleted, err := deletePolicyDB(ctx, ds.writer(ctx), ids, &teamID)
	return deleted, err
}

View file

@ -32,10 +32,12 @@ func TestPolicies(t *testing.T) {
}{
{"NewGlobalPolicyLegacy", testPoliciesNewGlobalPolicyLegacy},
{"NewGlobalPolicyProprietary", testPoliciesNewGlobalPolicyProprietary},
{"GlobalPolicyPendingScriptsAndInstalls", testGlobalPolicyPendingScriptsAndInstalls},
{"MembershipViewDeferred", func(t *testing.T, ds *Datastore) { testPoliciesMembershipView(true, t, ds) }},
{"MembershipViewNotDeferred", func(t *testing.T, ds *Datastore) { testPoliciesMembershipView(false, t, ds) }},
{"TeamPolicyLegacy", testTeamPolicyLegacy},
{"TeamPolicyProprietary", testTeamPolicyProprietary},
{"TeamPolicyPendingScriptsAndInstalls", testTeamPolicyPendingScriptsAndInstalls},
{"ListMergedTeamPolicies", testListMergedTeamPolicies},
{"PolicyQueriesForHost", testPolicyQueriesForHost},
{"PolicyQueriesForHostPlatforms", testPolicyQueriesForHostPlatforms},
@ -219,6 +221,106 @@ func testPoliciesNewGlobalPolicyProprietary(t *testing.T, ds *Datastore) {
assert.Equal(t, user1.ID, *p3.AuthorID)
}
// testGlobalPolicyPendingScriptsAndInstalls verifies that deleting a global
// policy also removes the pending script executions and the pending software
// installs that were requested on behalf of that policy.
func testGlobalPolicyPendingScriptsAndInstalls(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
	host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())

	// create a new script and associate with global policy
	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo",
		TeamID:         nil,
	})
	require.NoError(t, err)

	policy1, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
		Name:        "query1",
		Query:       "select 1;",
		Description: "query1 desc",
		Resolution:  "query1 resolution",
	})
	require.NoError(t, err)

	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		_, err := q.ExecContext(ctx, "UPDATE policies SET script_id = ?", script.ID)
		return err
	})

	policies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, policies, 1)

	// create pending script execution
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         host1.ID,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &policy1.ID,
		SyncRequest:    true,
		ScriptID:       &script.ID,
	})
	require.NoError(t, err)

	// ListPendingHostScriptExecutions is keyed by host ID, so query by the
	// host the execution was requested for (not by the policy ID).
	pendingScripts, err := ds.ListPendingHostScriptExecutions(ctx, host1.ID)
	require.NoError(t, err)
	require.Len(t, pendingScripts, 1)

	// delete the policy
	_, err = ds.DeleteGlobalPolicies(ctx, []uint{policy1.ID})
	require.NoError(t, err)

	pendingScripts, err = ds.ListPendingHostScriptExecutions(ctx, host1.ID)
	require.NoError(t, err)
	require.Empty(t, pendingScripts)

	// create a new installer and associate with global policy
	host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())

	tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		UninstallScript:   "goodbye",
		InstallerFile:     tfr1,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user.ID,
	})
	require.NoError(t, err)

	policy2, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
		Name:        "query2",
		Query:       "select 1;",
		Description: "query2 desc",
		Resolution:  "query2 resolution",
	})
	require.NoError(t, err)

	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		_, err := q.ExecContext(ctx, "UPDATE policies SET software_installer_id = ?", installerID)
		return err
	})

	policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
	require.NoError(t, err)
	require.Len(t, policies, 1)

	// create a pending software install request
	_, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID, false, &policy2.ID)
	require.NoError(t, err)

	pendingInstalls, err := ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Len(t, pendingInstalls, 1)

	// delete the policy
	_, err = ds.DeleteGlobalPolicies(ctx, []uint{policy2.ID})
	require.NoError(t, err)

	pendingInstalls, err = ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Empty(t, pendingInstalls)
}
func testPoliciesListOptions(t *testing.T, ds *Datastore) {
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
ctx := context.Background()
@ -717,6 +819,99 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
require.Equal(t, user1.ID, *team2Policies[0].AuthorID)
}
// testTeamPolicyPendingScriptsAndInstalls verifies that deleting a team
// policy also removes the pending script executions and the pending software
// installs that were requested on behalf of that policy.
func testTeamPolicyPendingScriptsAndInstalls(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	// create a script and associate it with a team policy
	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)
	host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())

	script, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		ScriptContents: "echo",
		TeamID:         &team1.ID,
	})
	require.NoError(t, err)

	policy1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:        "team1",
		Query:       "select 1;",
		Description: "description",
		Resolution:  "resolution",
		ScriptID:    &script.ID,
	})
	require.NoError(t, err)

	// create pending script execution
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         host1.ID,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &policy1.ID,
		SyncRequest:    true,
		ScriptID:       &script.ID,
	})
	require.NoError(t, err)

	// ListPendingHostScriptExecutions is keyed by host ID, so query by the
	// host the execution was requested for (not by the policy ID).
	pendingScripts, err := ds.ListPendingHostScriptExecutions(ctx, host1.ID)
	require.NoError(t, err)
	require.Len(t, pendingScripts, 1)

	// delete the policy
	_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{policy1.ID})
	require.NoError(t, err)

	pendingScripts, err = ds.ListPendingHostScriptExecutions(ctx, host1.ID)
	require.NoError(t, err)
	require.Empty(t, pendingScripts)

	// create a software install and associate it with a team policy
	team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
	require.NoError(t, err)
	host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())

	tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
	require.NoError(t, err)
	installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
		InstallScript:     "hello",
		PreInstallQuery:   "SELECT 1",
		PostInstallScript: "world",
		UninstallScript:   "goodbye",
		InstallerFile:     tfr1,
		StorageID:         "storage1",
		Filename:          "file1",
		Title:             "file1",
		Version:           "1.0",
		Source:            "apps",
		UserID:            user.ID,
		TeamID:            &team2.ID,
	})
	require.NoError(t, err)

	policy2, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{
		Name:                "team2",
		Query:               "select 1;",
		Description:         "description2",
		Resolution:          "resolution2",
		SoftwareInstallerID: &installerID,
	})
	require.NoError(t, err)

	// create a pending software install request
	_, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID, false, &policy2.ID)
	require.NoError(t, err)

	pendingInstalls, err := ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Len(t, pendingInstalls, 1)

	// delete the policy
	_, err = ds.DeleteTeamPolicies(ctx, team2.ID, []uint{policy2.ID})
	require.NoError(t, err)

	pendingInstalls, err = ds.ListPendingSoftwareInstalls(ctx, host2.ID)
	require.NoError(t, err)
	require.Empty(t, pendingInstalls)
}
func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
ctx := context.Background()
gpol, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{

View file

@ -18,6 +18,15 @@ import (
"github.com/jmoiron/sqlx"
)
// whereFilterPendingScript is a SQL WHERE-clause fragment that matches pending
// script executions: rows with no exit code that are either async requests or
// sync requests created within a recent interval. It contains exactly one `?`
// placeholder, which must be bound to that interval in seconds.
const whereFilterPendingScript = `
exit_code IS NULL
-- async requests + sync requests created within the given interval
AND (
sync_request = 0
OR created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
)
`
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
var res *fleet.HostScriptResult
return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
@ -204,24 +213,20 @@ func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *f
}
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint) ([]*fleet.HostScriptResult, error) {
const listStmt = `
SELECT
id,
host_id,
execution_id,
script_id
FROM
host_script_results
WHERE
host_id = ? AND
exit_code IS NULL
-- async requests + sync requests created within the given interval
AND (
sync_request = 0
OR created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
)
ORDER BY
created_at ASC`
listStmt := fmt.Sprintf(`
SELECT
id,
host_id,
execution_id,
script_id
FROM
host_script_results
WHERE
host_id = ? AND
%s
ORDER BY
created_at ASC
`, whereFilterPendingScript)
var results []*fleet.HostScriptResult
seconds := int(constants.MaxServerWaitTime.Seconds())
@ -471,6 +476,33 @@ func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
})
}
// deletePendingHostScriptExecutionsForPolicy removes the pending script
// executions recorded for policyID. It should be called when a policy is
// deleted.
//
// Only rows whose script belongs to the given scope are affected: teamID
// selects a team, and a nil teamID is matched against global_or_team_id = 0.
// "Pending" is defined by whereFilterPendingScript.
func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
	scopeID := uint(0)
	if teamID != nil {
		scopeID = *teamID
	}

	stmt := fmt.Sprintf(`
DELETE FROM
host_script_results
WHERE
policy_id = ? AND
script_id IN (
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
) AND
%s
`, whereFilterPendingScript)

	// The last placeholder belongs to whereFilterPendingScript (interval in seconds).
	maxWaitSeconds := int(constants.MaxServerWaitTime.Seconds())
	if _, err := ds.writer(ctx).ExecContext(ctx, stmt, policyID, scopeID, maxWaitSeconds); err != nil {
		return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
	}
	return nil
}
func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
var scripts []*fleet.Script

View file

@ -38,6 +38,7 @@ func TestScripts(t *testing.T) {
{"TestCleanupUnusedScriptContents", testCleanupUnusedScriptContents},
{"TestGetAnyScriptContents", testGetAnyScriptContents},
{"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy},
{"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -1422,3 +1423,110 @@ func testDeleteScriptsAssignedToPolicy(t *testing.T, ds *Datastore) {
err = ds.DeleteScript(ctx, script.ID)
require.NoError(t, err)
}
// testDeletePendingHostScriptExecutionsForPolicy checks that
// deletePendingHostScriptExecutionsForPolicy removes pending executions of the
// given policy only: executions of another policy and executions that already
// have an exit code must be left untouched.
func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore) {
	ctx := context.Background()
	user := test.NewUser(t, ds, "Alice", "alice@example.com", true)

	team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
	require.NoError(t, err)

	script1, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script1.sh",
		TeamID:         &team1.ID,
		ScriptContents: "hello world",
	})
	require.NoError(t, err)

	script2, err := ds.NewScript(ctx, &fleet.Script{
		Name:           "script2.sh",
		TeamID:         &team1.ID,
		ScriptContents: "hello world",
	})
	require.NoError(t, err)

	p1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:     "p1",
		Query:    "SELECT 1;",
		ScriptID: &script1.ID,
	})
	require.NoError(t, err)

	p2, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
		Name:     "p2",
		Query:    "SELECT 2;",
		ScriptID: &script2.ID,
	})
	require.NoError(t, err)

	// pending host script execution for correct policy
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &p1.ID,
		SyncRequest:    true,
		ScriptID:       &script1.ID,
	})
	require.NoError(t, err)

	pending, err := ds.ListPendingHostScriptExecutions(ctx, 1)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
	require.NoError(t, err)

	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
	require.NoError(t, err)
	require.Empty(t, pending)

	// test pending host script execution for incorrect policy
	_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &p2.ID,
		SyncRequest:    true,
		ScriptID:       &script2.ID,
	})
	require.NoError(t, err)

	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
	require.NoError(t, err)

	pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
	require.NoError(t, err)
	require.Len(t, pending, 1)

	// test not pending host script execution for correct policy
	scriptExecution, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
		HostID:         1,
		ScriptContents: "echo",
		UserID:         &user.ID,
		PolicyID:       &p1.ID,
		SyncRequest:    true,
		ScriptID:       &script1.ID,
	})
	require.NoError(t, err)

	// Give the execution an exit code so that it no longer counts as pending.
	ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
		_, err := q.ExecContext(ctx, `UPDATE host_script_results SET exit_code = 1 WHERE id = ?`, scriptExecution.ID)
		return err
	})

	err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
	require.NoError(t, err)

	var count int
	err = sqlx.GetContext(
		ctx,
		ds.reader(ctx),
		&count,
		"SELECT count(1) FROM host_script_results WHERE id = ?",
		scriptExecution.ID,
	)
	// Without this check a failed query would surface as a confusing count mismatch.
	require.NoError(t, err)
	require.Equal(t, 1, count)
}

View file

@ -461,6 +461,31 @@ func (ds *Datastore) DeleteSoftwareInstaller(ctx context.Context, id uint) error
})
}
// deletePendingSoftwareInstallsForPolicy removes the software installs in
// pending status recorded for policyID. It should be called when a policy is
// deleted (DeleteGlobalPolicies and DeleteTeamPolicies call it before removing
// the policy itself).
//
// Only rows whose software installer belongs to the given scope are affected:
// teamID selects a team, and a nil teamID is matched against
// global_or_team_id = 0.
func (ds *Datastore) deletePendingSoftwareInstallsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
	scopeID := uint(0)
	if teamID != nil {
		scopeID = *teamID
	}

	const stmt = `
DELETE FROM
host_software_installs
WHERE
policy_id = ? AND
status = ? AND
software_installer_id IN (
SELECT id FROM software_installers WHERE global_or_team_id = ?
)
`

	if _, err := ds.writer(ctx).ExecContext(ctx, stmt, policyID, fleet.SoftwareInstallPending, scopeID); err != nil {
		return ctxerr.Wrap(ctx, err, "delete pending software installs for policy")
	}
	return nil
}
func (ds *Datastore) InsertSoftwareInstallRequest(ctx context.Context, hostID uint, softwareInstallerID uint, selfService bool, policyID *uint) (string, error) {
const (
getInstallerStmt = `SELECT filename, "version", title_id, COALESCE(st.name, '[deleted title]') title_name

View file

@ -34,6 +34,7 @@ func TestSoftwareInstallers(t *testing.T) {
{"GetSoftwareInstallerMetadataByTeamAndTitleID", testGetSoftwareInstallerMetadataByTeamAndTitleID},
{"HasSelfServiceSoftwareInstallers", testHasSelfServiceSoftwareInstallers},
{"DeleteSoftwareInstallers", testDeleteSoftwareInstallers},
{"testDeletePendingSoftwareInstallsForPolicy", testDeletePendingSoftwareInstallsForPolicy},
{"GetHostLastInstallData", testGetHostLastInstallData},
{"GetOrGenerateSoftwareInstallerTitleID", testGetOrGenerateSoftwareInstallerTitleID},
}
@ -1137,6 +1138,120 @@ func testDeleteSoftwareInstallers(t *testing.T, ds *Datastore) {
require.ErrorAs(t, err, &nfe)
}
// testDeletePendingSoftwareInstallsForPolicy checks that
// deletePendingSoftwareInstallsForPolicy deletes only pending installs that
// belong to the given policy: a pending install of another policy and a
// non-pending install of the same policy must both survive.
func testDeletePendingSoftwareInstallsForPolicy(t *testing.T, ds *Datastore) {
ctx := context.Background()
host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
// store an installer file on disk and wrap it in a temp file reader for upload
dir := t.TempDir()
store, err := filesystem.NewSoftwareInstallerStore(dir)
require.NoError(t, err)
ins0 := "installer.pkg"
ins0File := bytes.NewReader([]byte("installer0"))
err = store.Put(ctx, ins0, ins0File)
require.NoError(t, err)
// rewind the reader after store.Put before handing it to NewTempFileReader
_, _ = ins0File.Seek(0, 0)
tfr0, err := fleet.NewTempFileReader(ins0File, t.TempDir)
require.NoError(t, err)
// first installer, attached to policy1
installerID1, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
InstallScript: "install",
InstallerFile: tfr0,
StorageID: ins0,
Filename: "installer.pkg",
Title: "ins0",
Source: "apps",
Platform: "darwin",
TeamID: &team1.ID,
UserID: user1.ID,
})
require.NoError(t, err)
policy1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "p1",
Query: "SELECT 1;",
SoftwareInstallerID: &installerID1,
})
require.NoError(t, err)
// second installer (different title, same file), attached to policy2
installerID2, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
InstallScript: "install",
InstallerFile: tfr0,
StorageID: ins0,
Filename: "installer.pkg",
Title: "ins1",
Source: "apps",
Platform: "darwin",
TeamID: &team1.ID,
UserID: user1.ID,
})
require.NoError(t, err)
policy2, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "p2",
Query: "SELECT 2;",
SoftwareInstallerID: &installerID2,
})
require.NoError(t, err)
// counts host_software_installs rows by status and execution ID
const hostSoftwareInstallsCount = "SELECT count(1) FROM host_software_installs WHERE status = ? and execution_id = ?"
var count int
// install for correct policy & correct status
executionID, err := ds.InsertSoftwareInstallRequest(ctx, host1.ID, installerID1, false, &policy1.ID)
require.NoError(t, err)
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
require.NoError(t, err)
require.Equal(t, 1, count)
err = ds.deletePendingSoftwareInstallsForPolicy(ctx, &team1.ID, policy1.ID)
require.NoError(t, err)
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
require.NoError(t, err)
require.Equal(t, 0, count)
// install for different policy & correct status
executionID, err = ds.InsertSoftwareInstallRequest(ctx, host1.ID, installerID2, false, &policy2.ID)
require.NoError(t, err)
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
require.NoError(t, err)
require.Equal(t, 1, count)
err = ds.deletePendingSoftwareInstallsForPolicy(ctx, &team1.ID, policy1.ID)
require.NoError(t, err)
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
require.NoError(t, err)
require.Equal(t, 1, count)
// install for correct policy & incorrect status
executionID, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID1, false, &policy1.ID)
require.NoError(t, err)
// record an install result so the row is no longer expected to be pending
err = ds.SetHostSoftwareInstallResult(ctx, &fleet.HostSoftwareInstallResultPayload{
HostID: host2.ID,
InstallUUID: executionID,
InstallScriptExitCode: ptr.Int(0),
})
require.NoError(t, err)
err = ds.deletePendingSoftwareInstallsForPolicy(ctx, &team1.ID, policy1.ID)
require.NoError(t, err)
// the row must still exist, regardless of its status
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, `SELECT count(1) FROM host_software_installs WHERE execution_id = ?`, executionID)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func testGetHostLastInstallData(t *testing.T, ds *Datastore) {
ctx := context.Background()

View file

@ -3,6 +3,7 @@ package apple_mdm
import (
"context"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm"
)
@ -118,11 +119,15 @@ func HandleHostMDMProfileInstallResult(ctx context.Context, ds fleet.ProfileVeri
}
// otherwise update status and detail as usual
return ds.UpdateOrDeleteHostMDMAppleProfile(ctx, &fleet.HostMDMAppleProfile{
err := ds.UpdateOrDeleteHostMDMAppleProfile(ctx, &fleet.HostMDMAppleProfile{
CommandUUID: cmdUUID,
HostUUID: hostUUID,
Status: status,
Detail: detail,
OperationType: fleet.MDMOperationTypeInstall,
})
if err != nil {
return ctxerr.Wrap(ctx, err, "updating host MDM Apple profile install result")
}
return nil
}

View file

@ -10,6 +10,7 @@ import (
"os"
"time"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
@ -113,7 +114,7 @@ func New(opts ...Option) (*MySQLStorage, error) {
mysqlStore := &MySQLStorage{db: cfg.db, logger: cfg.logger, rm: cfg.rm}
if cfg.reader == nil {
mysqlStore.reader = func(ctx context.Context) fleet.DBReader {
return sqlx.NewDb(mysqlStore.db, "mysql")
return sqlx.NewDb(mysqlStore.db, "")
}
} else {
mysqlStore.reader = cfg.reader
@ -337,7 +338,10 @@ func (s *MySQLStorage) updateLastSeenBatch(ctx context.Context, ids []string) {
return
}
_, err = s.db.ExecContext(ctx, stmt, args...)
err = common_mysql.WithRetryTxx(ctx, sqlx.NewDb(s.db, ""), func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, stmt, args...)
return err
}, loggerWrapper{s.logger})
if err != nil {
s.logger.Info("msg", "error batch updating nano_enrollments.last_seen_at", "err", err)
}

View file

@ -8,11 +8,14 @@ import (
"strings"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/micromdm/nanolib/log"
)
func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) error {
func enqueue(ctx context.Context, tx sqlx.ExtContext, ids []string, cmd *mdm.Command) error {
if len(ids) < 1 {
return errors.New("no id(s) supplied to queue command to")
}
@ -50,18 +53,22 @@ func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) er
return nil
}
// loggerWrapper holds a log.Logger and exposes it through a
// Log(keyvals...) error method. In this package it is the value passed as the
// logger argument to common_mysql.WithRetryTxx.
type loggerWrapper struct {
	logger log.Logger
}

// Log hands keyvals, unchanged, to the wrapped logger's Info method. It never
// fails: the returned error is always nil.
func (w loggerWrapper) Log(keyvals ...interface{}) error {
	w.logger.Info(keyvals...)
	return nil
}
// EnqueueCommand queues cmd for the given ids by running enqueue inside a
// database transaction. The map result is always nil on every return path
// shown here; only the error result carries information.
//
// NOTE(review): this span comes from a diff rendering and shows two bodies
// back to back. The first (BeginTx / Rollback / Commit) is the previous
// implementation; the second (common_mysql.WithRetryTxx) is its replacement.
// Only one of them is expected to exist in the real file.
func (m *MySQLStorage) EnqueueCommand(ctx context.Context, ids []string, cmd *mdm.Command) (map[string]error, error) {
// Previous implementation: a single, non-retried transaction managed by hand.
tx, err := m.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
if err = enqueue(ctx, tx, ids, cmd); err != nil {
// Roll back on failure; if the rollback itself fails, report both errors,
// wrapping the rollback error and including the original one as text.
if rbErr := tx.Rollback(); rbErr != nil {
return nil, fmt.Errorf("rollback error: %w; while trying to handle error: %v", rbErr, err)
}
return nil, err
}
return nil, tx.Commit()
// Replacement implementation starts here.
// We need to retry because this transaction may deadlock with updates to nano_enrollments.last_seen_at
// Deadlock seen in 2024/12/12 loadtest: https://docs.google.com/document/d/1-Q6qFTd7CDm-lh7MVRgpNlNNJijk6JZ4KO49R1fp80U
// NOTE(review): sqlx.NewDb is given an empty driver name; presumably this only
// affects sqlx bind-variable handling, which enqueue may not rely on — confirm.
err := common_mysql.WithRetryTxx(ctx, sqlx.NewDb(m.db, ""), func(tx sqlx.ExtContext) error {
return enqueue(ctx, tx, ids, cmd)
}, loggerWrapper{m.logger})
return nil, err
}
func (m *MySQLStorage) deleteCommand(ctx context.Context, tx *sql.Tx, id, uuid string) error {

View file

@ -12,6 +12,7 @@ import (
"time"
"github.com/docker/go-units"
"github.com/fleetdm/fleet/v4/pkg/file"
"github.com/fleetdm/fleet/v4/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
@ -507,7 +508,7 @@ func (svc *Service) NewScript(ctx context.Context, teamID *uint, name string, r
script := &fleet.Script{
TeamID: teamID,
Name: name,
ScriptContents: string(b),
ScriptContents: file.Dos2UnixNewlines(string(b)),
}
if err := script.ValidateNewScript(); err != nil {
return nil, fleet.NewInvalidArgumentError("script", err.Error())

View file

@ -498,10 +498,14 @@ func TestSavedScripts(t *testing.T) {
license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)}
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true})
withLFContents := "echo\necho"
withCRLFContents := "echo\r\necho"
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
ds.NewScriptFunc = func(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
require.Equal(t, withLFContents, script.ScriptContents)
newScript := *script
newScript.ID = 1
return &newScript, nil
@ -669,7 +673,7 @@ func TestSavedScripts(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ctx = viewer.NewContext(ctx, viewer.Viewer{User: tt.user})
_, err := svc.NewScript(ctx, nil, "test.sh", strings.NewReader("echo"))
_, err := svc.NewScript(ctx, nil, "test.ps1", strings.NewReader(withCRLFContents))
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
err = svc.DeleteScript(ctx, noTeamScriptID)
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
@ -680,7 +684,7 @@ func TestSavedScripts(t *testing.T) {
_, _, err = svc.GetScript(ctx, noTeamScriptID, true)
checkAuthErr(t, tt.shouldFailGlobalRead, err)
_, err = svc.NewScript(ctx, ptr.Uint(1), "test.sh", strings.NewReader("echo"))
_, err = svc.NewScript(ctx, ptr.Uint(1), "test.sh", strings.NewReader(withLFContents))
checkAuthErr(t, tt.shouldFailTeamWrite, err)
err = svc.DeleteScript(ctx, team1ScriptID)
checkAuthErr(t, tt.shouldFailTeamWrite, err)