mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 08:58:41 +00:00
Merge branch 'main' into feat-labels-scoped-software
This commit is contained in:
commit
12bf9880ad
20 changed files with 664 additions and 125 deletions
1
changes/23886-remove-associations-on-policy-delete
Normal file
1
changes/23886-remove-associations-on-policy-delete
Normal file
|
|
@ -0,0 +1 @@
|
|||
* On policy deletion any associated pending software installer or scripts are deleted.
|
||||
1
changes/24166-script-line-endings
Normal file
1
changes/24166-script-line-endings
Normal file
|
|
@ -0,0 +1 @@
|
|||
* Changed script upload endpoint (`POST /api/v1/fleet/scripts`) to automatically switch CRLF line endings to LF
|
||||
1
changes/24771-mdm-deadlock-fixes
Normal file
1
changes/24771-mdm-deadlock-fixes
Normal file
|
|
@ -0,0 +1 @@
|
|||
Fixed potential deadlocks when deploying Apple configuration profiles.
|
||||
|
|
@ -2050,7 +2050,7 @@ conjunction with an STS role ARN to ensure that only the intended AWS account ca
|
|||
### s3_software_installers_endpoint_url
|
||||
|
||||
AWS S3 Endpoint URL. Override when using a different S3 compatible object storage backend (such as Minio),
|
||||
or running s3 locally with localstack. Leave this blank to use the default S3 service endpoint.
|
||||
or running S3 locally with localstack. Leave this blank to use the default S3 service endpoint.
|
||||
|
||||
- Default value: none
|
||||
- Environment variable: `FLEET_S3_SOFTWARE_INSTALLERS_ENDPOINT_URL`
|
||||
|
|
|
|||
|
|
@ -2814,11 +2814,16 @@ func (ds *Datastore) UpdateOrDeleteHostMDMAppleProfile(ctx context.Context, prof
|
|||
status = &fleet.MDMDeliveryVerified
|
||||
}
|
||||
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, `
|
||||
// We need to run with retry due to potential deadlocks with BulkSetPendingMDMHostProfiles.
|
||||
// Deadlock seen in 2024/12/12 loadtest: https://docs.google.com/document/d/1-Q6qFTd7CDm-lh7MVRgpNlNNJijk6JZ4KO49R1fp80U
|
||||
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
UPDATE host_mdm_apple_profiles
|
||||
SET status = ?, operation_type = ?, detail = ?
|
||||
WHERE host_uuid = ? AND command_uuid = ?
|
||||
`, status, profile.OperationType, detail, profile.HostUUID, profile.CommandUUID)
|
||||
return err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
93
server/datastore/mysql/common_mysql/retry.go
Normal file
93
server/datastore/mysql/common_mysql/retry.go
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
package common_mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/VividCortex/mysqlerr"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/go-kit/log"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
var DoRetryErr = errors.New("fleet datastore retry")
|
||||
|
||||
type TxFn func(tx sqlx.ExtContext) error
|
||||
|
||||
// WithRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
|
||||
func WithRetryTxx(ctx context.Context, db *sqlx.DB, fn TxFn, logger log.Logger) error {
|
||||
operation := func() error {
|
||||
tx, err := db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "create transaction")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if err := tx.Rollback(); err != nil {
|
||||
logger.Log("err", err, "msg", "error encountered during transaction panic rollback")
|
||||
}
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
rbErr := tx.Rollback()
|
||||
if rbErr != nil && rbErr != sql.ErrTxDone {
|
||||
// Consider rollback errors to be non-retryable
|
||||
return backoff.Permanent(ctxerr.Wrapf(ctx, err, "got err '%s' rolling back after err", rbErr.Error()))
|
||||
}
|
||||
|
||||
if retryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Consider any other errors to be non-retryable
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
err = ctxerr.Wrap(ctx, err, "commit transaction")
|
||||
|
||||
if retryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
expBo := backoff.NewExponentialBackOff()
|
||||
// MySQL innodb_lock_wait_timeout default is 50 seconds, so transaction can be waiting for a lock for several seconds.
|
||||
// Setting a higher MaxElapsedTime to increase probability that transaction will be retried.
|
||||
// This will reduce the number of retryable 'Deadlock found' errors. However, with a loaded DB, we will still see
|
||||
// 'Context cancelled' errors when the server drops long-lasting connections.
|
||||
expBo.MaxElapsedTime = 1 * time.Minute
|
||||
bo := backoff.WithMaxRetries(expBo, 5)
|
||||
return backoff.Retry(operation, bo)
|
||||
}
|
||||
|
||||
// retryableError determines whether a MySQL error can be retried. By default
|
||||
// errors are considered non-retryable. Only errors that we know have a
|
||||
// possibility of succeeding on a retry should return true in this function.
|
||||
func retryableError(err error) bool {
|
||||
base := ctxerr.Cause(err)
|
||||
if b, ok := base.(*mysql.MySQLError); ok {
|
||||
switch b.Number {
|
||||
// Consider lock related errors to be retryable
|
||||
case mysqlerr.ER_LOCK_DEADLOCK, mysqlerr.ER_LOCK_WAIT_TIMEOUT:
|
||||
return true
|
||||
}
|
||||
}
|
||||
if errors.Is(err, DoRetryErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -14,15 +14,14 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/VividCortex/mysqlerr"
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/XSAM/otelsql"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/doug-martin/goqu/v9/exp"
|
||||
"github.com/fleetdm/fleet/v4/server/config"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxdb"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/data"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/migrations/tables"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
|
|
@ -175,8 +174,6 @@ func (ds *Datastore) NewSCEPDepot() (scep_depot.Depot, error) {
|
|||
return newSCEPDepot(ds.primary.DB, ds)
|
||||
}
|
||||
|
||||
type txFn func(tx sqlx.ExtContext) error
|
||||
|
||||
type entity struct {
|
||||
name string
|
||||
}
|
||||
|
|
@ -190,88 +187,12 @@ var (
|
|||
usersTable = entity{"users"}
|
||||
)
|
||||
|
||||
var doRetryErr = errors.New("fleet datastore retry")
|
||||
|
||||
// retryableError determines whether a MySQL error can be retried. By default
|
||||
// errors are considered non-retryable. Only errors that we know have a
|
||||
// possibility of succeeding on a retry should return true in this function.
|
||||
func retryableError(err error) bool {
|
||||
base := ctxerr.Cause(err)
|
||||
if b, ok := base.(*mysql.MySQLError); ok {
|
||||
switch b.Number {
|
||||
// Consider lock related errors to be retryable
|
||||
case mysqlerr.ER_LOCK_DEADLOCK, mysqlerr.ER_LOCK_WAIT_TIMEOUT:
|
||||
return true
|
||||
}
|
||||
}
|
||||
if errors.Is(err, doRetryErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (ds *Datastore) withRetryTxx(ctx context.Context, fn txFn) (err error) {
|
||||
return withRetryTxx(ctx, ds.writer(ctx), fn, ds.logger)
|
||||
}
|
||||
|
||||
// withRetryTxx provides a common way to commit/rollback a txFn wrapped in a retry with exponential backoff
|
||||
func withRetryTxx(ctx context.Context, db *sqlx.DB, fn txFn, logger log.Logger) (err error) {
|
||||
operation := func() error {
|
||||
tx, err := db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "create transaction")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if err := tx.Rollback(); err != nil {
|
||||
logger.Log("err", err, "msg", "error encountered during transaction panic rollback")
|
||||
}
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
rbErr := tx.Rollback()
|
||||
if rbErr != nil && rbErr != sql.ErrTxDone {
|
||||
// Consider rollback errors to be non-retryable
|
||||
return backoff.Permanent(ctxerr.Wrapf(ctx, err, "got err '%s' rolling back after err", rbErr.Error()))
|
||||
}
|
||||
|
||||
if retryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Consider any other errors to be non-retryable
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
err = ctxerr.Wrap(ctx, err, "commit transaction")
|
||||
|
||||
if retryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
expBo := backoff.NewExponentialBackOff()
|
||||
// MySQL innodb_lock_wait_timeout default is 50 seconds, so transaction can be waiting for a lock for several seconds.
|
||||
// Setting a higher MaxElapsedTime to increase probability that transaction will be retried.
|
||||
// This will reduce the number of retryable 'Deadlock found' errors. However, with a loaded DB, we will still see
|
||||
// 'Context cancelled' errors when the server drops long-lasting connections.
|
||||
expBo.MaxElapsedTime = 1 * time.Minute
|
||||
bo := backoff.WithMaxRetries(expBo, 5)
|
||||
return backoff.Retry(operation, bo)
|
||||
func (ds *Datastore) withRetryTxx(ctx context.Context, fn common_mysql.TxFn) (err error) {
|
||||
return common_mysql.WithRetryTxx(ctx, ds.writer(ctx), fn, ds.logger)
|
||||
}
|
||||
|
||||
// withTx provides a common way to commit/rollback a txFn
|
||||
func (ds *Datastore) withTx(ctx context.Context, fn txFn) (err error) {
|
||||
func (ds *Datastore) withTx(ctx context.Context, fn common_mysql.TxFn) (err error) {
|
||||
tx, err := ds.writer(ctx).BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "create transaction")
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
abmctx "github.com/fleetdm/fleet/v4/server/contexts/apple_bm"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/assets"
|
||||
nanodep_client "github.com/fleetdm/fleet/v4/server/mdm/nanodep/client"
|
||||
|
|
@ -125,7 +126,7 @@ func (s *NanoMDMStorage) EnqueueDeviceLockCommand(
|
|||
cmd *mdm.Command,
|
||||
pin string,
|
||||
) error {
|
||||
return withRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
|
||||
return common_mysql.WithRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
|
||||
if err := enqueueCommandDB(ctx, tx, []string{host.UUID}, cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -154,7 +155,7 @@ func (s *NanoMDMStorage) EnqueueDeviceLockCommand(
|
|||
|
||||
// EnqueueDeviceWipeCommand enqueues a EraseDevice command for the given host.
|
||||
func (s *NanoMDMStorage) EnqueueDeviceWipeCommand(ctx context.Context, host *fleet.Host, cmd *mdm.Command) error {
|
||||
return withRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
|
||||
return common_mysql.WithRetryTxx(ctx, s.db, func(tx sqlx.ExtContext) error {
|
||||
if err := enqueueCommandDB(ctx, tx, []string{host.UUID}, cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
|
@ -93,7 +94,7 @@ func newOperatingSystemDB(ctx context.Context, tx sqlx.ExtContext, hostOS fleet.
|
|||
case err == nil:
|
||||
return storedOS, nil
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
return nil, doRetryErr
|
||||
return nil, common_mysql.DoRetryErr
|
||||
default:
|
||||
return nil, ctxerr.Wrap(ctx, err, "get new operating system")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"golang.org/x/text/unicode/norm"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
kitlog "github.com/go-kit/log"
|
||||
|
|
@ -238,7 +239,7 @@ func cleanupPolicy(
|
|||
}
|
||||
if _, isDB := extContext.(*sqlx.DB); isDB {
|
||||
// wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job
|
||||
err = withRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger)
|
||||
err = common_mysql.WithRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger)
|
||||
} else {
|
||||
err = fn(extContext)
|
||||
}
|
||||
|
|
@ -583,6 +584,15 @@ func (ds *Datastore) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*fl
|
|||
}
|
||||
|
||||
func (ds *Datastore) DeleteGlobalPolicies(ctx context.Context, ids []uint) ([]uint, error) {
|
||||
for _, id := range ids {
|
||||
if err := ds.deletePendingSoftwareInstallsForPolicy(ctx, nil, id); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "delete pending software installs for policy")
|
||||
}
|
||||
if err := ds.deletePendingHostScriptExecutionsForPolicy(ctx, nil, id); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
|
||||
}
|
||||
}
|
||||
|
||||
return deletePolicyDB(ctx, ds.writer(ctx), ids, nil)
|
||||
}
|
||||
|
||||
|
|
@ -736,6 +746,15 @@ func (ds *Datastore) ListMergedTeamPolicies(ctx context.Context, teamID uint, op
|
|||
}
|
||||
|
||||
func (ds *Datastore) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {
|
||||
for _, id := range ids {
|
||||
if err := ds.deletePendingSoftwareInstallsForPolicy(ctx, &teamID, id); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "delete pending software installs for policy")
|
||||
}
|
||||
if err := ds.deletePendingHostScriptExecutionsForPolicy(ctx, &teamID, id); err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
|
||||
}
|
||||
}
|
||||
|
||||
return deletePolicyDB(ctx, ds.writer(ctx), ids, &teamID)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -32,10 +32,12 @@ func TestPolicies(t *testing.T) {
|
|||
}{
|
||||
{"NewGlobalPolicyLegacy", testPoliciesNewGlobalPolicyLegacy},
|
||||
{"NewGlobalPolicyProprietary", testPoliciesNewGlobalPolicyProprietary},
|
||||
{"GlobalPolicyPendingScriptsAndInstalls", testGlobalPolicyPendingScriptsAndInstalls},
|
||||
{"MembershipViewDeferred", func(t *testing.T, ds *Datastore) { testPoliciesMembershipView(true, t, ds) }},
|
||||
{"MembershipViewNotDeferred", func(t *testing.T, ds *Datastore) { testPoliciesMembershipView(false, t, ds) }},
|
||||
{"TeamPolicyLegacy", testTeamPolicyLegacy},
|
||||
{"TeamPolicyProprietary", testTeamPolicyProprietary},
|
||||
{"TeamPolicyPendingScriptsAndInstalls", testTeamPolicyPendingScriptsAndInstalls},
|
||||
{"ListMergedTeamPolicies", testListMergedTeamPolicies},
|
||||
{"PolicyQueriesForHost", testPolicyQueriesForHost},
|
||||
{"PolicyQueriesForHostPlatforms", testPolicyQueriesForHostPlatforms},
|
||||
|
|
@ -219,6 +221,106 @@ func testPoliciesNewGlobalPolicyProprietary(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, user1.ID, *p3.AuthorID)
|
||||
}
|
||||
|
||||
func testGlobalPolicyPendingScriptsAndInstalls(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
||||
host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
|
||||
|
||||
// create a new script and associate with global policy
|
||||
script, err := ds.NewScript(ctx, &fleet.Script{
|
||||
Name: "script1.sh",
|
||||
ScriptContents: "echo",
|
||||
TeamID: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
policy1, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
|
||||
Name: "query1",
|
||||
Query: "select 1;",
|
||||
Description: "query1 desc",
|
||||
Resolution: "query1 resolution",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
||||
_, err := q.ExecContext(ctx, "UPDATE policies SET script_id = ?", script.ID)
|
||||
return err
|
||||
})
|
||||
policies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
|
||||
// create pending script execution
|
||||
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
||||
HostID: host1.ID,
|
||||
ScriptContents: "echo",
|
||||
UserID: &user.ID,
|
||||
PolicyID: &policy1.ID,
|
||||
SyncRequest: true,
|
||||
ScriptID: &script.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
pendingScripts, err := ds.ListPendingHostScriptExecutions(ctx, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pendingScripts))
|
||||
|
||||
// delete the policy
|
||||
_, err = ds.DeleteGlobalPolicies(ctx, []uint{policy1.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingScripts, err = ds.ListPendingHostScriptExecutions(ctx, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(pendingScripts))
|
||||
|
||||
// create a new installer and associate with global policy
|
||||
host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())
|
||||
tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
|
||||
require.NoError(t, err)
|
||||
installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
|
||||
InstallScript: "hello",
|
||||
PreInstallQuery: "SELECT 1",
|
||||
PostInstallScript: "world",
|
||||
UninstallScript: "goodbye",
|
||||
InstallerFile: tfr1,
|
||||
StorageID: "storage1",
|
||||
Filename: "file1",
|
||||
Title: "file1",
|
||||
Version: "1.0",
|
||||
Source: "apps",
|
||||
UserID: user.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
policy2, err := ds.NewGlobalPolicy(ctx, &user.ID, fleet.PolicyPayload{
|
||||
Name: "query2",
|
||||
Query: "select 1;",
|
||||
Description: "query2 desc",
|
||||
Resolution: "query2 resolution",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
||||
_, err := q.ExecContext(ctx, "UPDATE policies SET software_installer_id = ?", installerID)
|
||||
return err
|
||||
})
|
||||
policies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
|
||||
// create a pending software install request
|
||||
_, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID, false, &policy2.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingInstalls, err := ds.ListPendingSoftwareInstalls(ctx, host2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pendingInstalls))
|
||||
|
||||
// delete the policy
|
||||
_, err = ds.DeleteGlobalPolicies(ctx, []uint{policy2.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingInstalls, err = ds.ListPendingSoftwareInstalls(ctx, host2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(pendingInstalls))
|
||||
}
|
||||
|
||||
func testPoliciesListOptions(t *testing.T, ds *Datastore) {
|
||||
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
||||
ctx := context.Background()
|
||||
|
|
@ -717,6 +819,99 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
|
|||
require.Equal(t, user1.ID, *team2Policies[0].AuthorID)
|
||||
}
|
||||
|
||||
func testTeamPolicyPendingScriptsAndInstalls(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
||||
|
||||
// create a script and associate it with a team policy
|
||||
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
||||
require.NoError(t, err)
|
||||
host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
|
||||
|
||||
script, err := ds.NewScript(ctx, &fleet.Script{
|
||||
Name: "script1.sh",
|
||||
ScriptContents: "echo",
|
||||
TeamID: &team1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
policy1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
|
||||
Name: "team1",
|
||||
Query: "select 1;",
|
||||
Description: "description",
|
||||
Resolution: "resolution",
|
||||
ScriptID: &script.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// create pending script execution
|
||||
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
||||
HostID: host1.ID,
|
||||
ScriptContents: "echo",
|
||||
UserID: &user.ID,
|
||||
PolicyID: &policy1.ID,
|
||||
SyncRequest: true,
|
||||
ScriptID: &script.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
pendingScripts, err := ds.ListPendingHostScriptExecutions(ctx, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pendingScripts))
|
||||
|
||||
// delete the policy
|
||||
_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{policy1.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingScripts, err = ds.ListPendingHostScriptExecutions(ctx, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(pendingScripts))
|
||||
|
||||
// create a software install and associate it with a team policy
|
||||
team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
|
||||
require.NoError(t, err)
|
||||
host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())
|
||||
tfr1, err := fleet.NewTempFileReader(strings.NewReader("hello"), t.TempDir)
|
||||
require.NoError(t, err)
|
||||
installerID, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
|
||||
InstallScript: "hello",
|
||||
PreInstallQuery: "SELECT 1",
|
||||
PostInstallScript: "world",
|
||||
UninstallScript: "goodbye",
|
||||
InstallerFile: tfr1,
|
||||
StorageID: "storage1",
|
||||
Filename: "file1",
|
||||
Title: "file1",
|
||||
Version: "1.0",
|
||||
Source: "apps",
|
||||
UserID: user.ID,
|
||||
TeamID: &team2.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
policy2, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{
|
||||
Name: "team2",
|
||||
Query: "select 1;",
|
||||
Description: "description2",
|
||||
Resolution: "resolution2",
|
||||
SoftwareInstallerID: &installerID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// create a pending software install request
|
||||
_, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID, false, &policy2.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingInstalls, err := ds.ListPendingSoftwareInstalls(ctx, host2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pendingInstalls))
|
||||
|
||||
// delete the policy
|
||||
_, err = ds.DeleteTeamPolicies(ctx, team2.ID, []uint{policy2.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingInstalls, err = ds.ListPendingSoftwareInstalls(ctx, host2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(pendingInstalls))
|
||||
}
|
||||
|
||||
func testListMergedTeamPolicies(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
gpol, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{
|
||||
|
|
|
|||
|
|
@ -18,6 +18,15 @@ import (
|
|||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const whereFilterPendingScript = `
|
||||
exit_code IS NULL
|
||||
-- async requests + sync requests created within the given interval
|
||||
AND (
|
||||
sync_request = 0
|
||||
OR created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
|
||||
)
|
||||
`
|
||||
|
||||
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
|
||||
var res *fleet.HostScriptResult
|
||||
return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
|
|
@ -204,24 +213,20 @@ func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *f
|
|||
}
|
||||
|
||||
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint) ([]*fleet.HostScriptResult, error) {
|
||||
const listStmt = `
|
||||
SELECT
|
||||
id,
|
||||
host_id,
|
||||
execution_id,
|
||||
script_id
|
||||
FROM
|
||||
host_script_results
|
||||
WHERE
|
||||
host_id = ? AND
|
||||
exit_code IS NULL
|
||||
-- async requests + sync requests created within the given interval
|
||||
AND (
|
||||
sync_request = 0
|
||||
OR created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
|
||||
)
|
||||
ORDER BY
|
||||
created_at ASC`
|
||||
listStmt := fmt.Sprintf(`
|
||||
SELECT
|
||||
id,
|
||||
host_id,
|
||||
execution_id,
|
||||
script_id
|
||||
FROM
|
||||
host_script_results
|
||||
WHERE
|
||||
host_id = ? AND
|
||||
%s
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
`, whereFilterPendingScript)
|
||||
|
||||
var results []*fleet.HostScriptResult
|
||||
seconds := int(constants.MaxServerWaitTime.Seconds())
|
||||
|
|
@ -471,6 +476,33 @@ func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
|
|||
})
|
||||
}
|
||||
|
||||
// deletePendingHostScriptExecutionsForPolicy should be called when a policy is deleted to remove any pending script executions
|
||||
func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
|
||||
var globalOrTeamID uint
|
||||
if teamID != nil {
|
||||
globalOrTeamID = *teamID
|
||||
}
|
||||
|
||||
deleteStmt := fmt.Sprintf(`
|
||||
DELETE FROM
|
||||
host_script_results
|
||||
WHERE
|
||||
policy_id = ? AND
|
||||
script_id IN (
|
||||
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
||||
) AND
|
||||
%s
|
||||
`, whereFilterPendingScript)
|
||||
|
||||
seconds := int(constants.MaxServerWaitTime.Seconds())
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, deleteStmt, policyID, globalOrTeamID, seconds)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
|
||||
var scripts []*fleet.Script
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ func TestScripts(t *testing.T) {
|
|||
{"TestCleanupUnusedScriptContents", testCleanupUnusedScriptContents},
|
||||
{"TestGetAnyScriptContents", testGetAnyScriptContents},
|
||||
{"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy},
|
||||
{"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
|
|
@ -1422,3 +1423,110 @@ func testDeleteScriptsAssignedToPolicy(t *testing.T, ds *Datastore) {
|
|||
err = ds.DeleteScript(ctx, script.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
|
||||
user := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
||||
team1, _ := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
||||
|
||||
script1, err := ds.NewScript(ctx, &fleet.Script{
|
||||
Name: "script1.sh",
|
||||
TeamID: &team1.ID,
|
||||
ScriptContents: "hello world",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
script2, err := ds.NewScript(ctx, &fleet.Script{
|
||||
Name: "script2.sh",
|
||||
TeamID: &team1.ID,
|
||||
ScriptContents: "hello world",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
p1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
|
||||
Name: "p1",
|
||||
Query: "SELECT 1;",
|
||||
ScriptID: &script1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
p2, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{
|
||||
Name: "p2",
|
||||
Query: "SELECT 2;",
|
||||
ScriptID: &script2.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// pending host script execution for correct policy
|
||||
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
||||
HostID: 1,
|
||||
ScriptContents: "echo",
|
||||
UserID: &user.ID,
|
||||
PolicyID: &p1.ID,
|
||||
SyncRequest: true,
|
||||
ScriptID: &script1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pending, err := ds.ListPendingHostScriptExecutions(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pending))
|
||||
|
||||
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(pending))
|
||||
|
||||
// test pending host script execution for incorrect policy
|
||||
_, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
||||
HostID: 1,
|
||||
ScriptContents: "echo",
|
||||
UserID: &user.ID,
|
||||
PolicyID: &p2.ID,
|
||||
SyncRequest: true,
|
||||
ScriptID: &script2.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pending))
|
||||
|
||||
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(pending))
|
||||
|
||||
// test not pending host script execution for correct policy
|
||||
scriptExecution, err := ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
|
||||
HostID: 1,
|
||||
ScriptContents: "echo",
|
||||
UserID: &user.ID,
|
||||
PolicyID: &p1.ID,
|
||||
SyncRequest: true,
|
||||
ScriptID: &script1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
|
||||
_, err = q.ExecContext(ctx, `UPDATE host_script_results SET exit_code = 1 WHERE id = ?`, scriptExecution.ID)
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
err = ds.deletePendingHostScriptExecutionsForPolicy(ctx, &team1.ID, p1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int
|
||||
err = sqlx.GetContext(
|
||||
ctx,
|
||||
ds.reader(ctx),
|
||||
&count,
|
||||
"SELECT count(1) FROM host_script_results WHERE id = ?",
|
||||
scriptExecution.ID,
|
||||
)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -461,6 +461,31 @@ func (ds *Datastore) DeleteSoftwareInstaller(ctx context.Context, id uint) error
|
|||
})
|
||||
}
|
||||
|
||||
// deletePendingSoftwareInstallsForPolicy should be called after a policy is deleted to remove any pending software installs
|
||||
func (ds *Datastore) deletePendingSoftwareInstallsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
|
||||
var globalOrTeamID uint
|
||||
if teamID != nil {
|
||||
globalOrTeamID = *teamID
|
||||
}
|
||||
|
||||
const deleteStmt = `
|
||||
DELETE FROM
|
||||
host_software_installs
|
||||
WHERE
|
||||
policy_id = ? AND
|
||||
status = ? AND
|
||||
software_installer_id IN (
|
||||
SELECT id FROM software_installers WHERE global_or_team_id = ?
|
||||
)
|
||||
`
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, deleteStmt, policyID, fleet.SoftwareInstallPending, globalOrTeamID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "delete pending software installs for policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) InsertSoftwareInstallRequest(ctx context.Context, hostID uint, softwareInstallerID uint, selfService bool, policyID *uint) (string, error) {
|
||||
const (
|
||||
getInstallerStmt = `SELECT filename, "version", title_id, COALESCE(st.name, '[deleted title]') title_name
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ func TestSoftwareInstallers(t *testing.T) {
|
|||
{"GetSoftwareInstallerMetadataByTeamAndTitleID", testGetSoftwareInstallerMetadataByTeamAndTitleID},
|
||||
{"HasSelfServiceSoftwareInstallers", testHasSelfServiceSoftwareInstallers},
|
||||
{"DeleteSoftwareInstallers", testDeleteSoftwareInstallers},
|
||||
{"testDeletePendingSoftwareInstallsForPolicy", testDeletePendingSoftwareInstallsForPolicy},
|
||||
{"GetHostLastInstallData", testGetHostLastInstallData},
|
||||
{"GetOrGenerateSoftwareInstallerTitleID", testGetOrGenerateSoftwareInstallerTitleID},
|
||||
}
|
||||
|
|
@ -1137,6 +1138,120 @@ func testDeleteSoftwareInstallers(t *testing.T, ds *Datastore) {
|
|||
require.ErrorAs(t, err, &nfe)
|
||||
}
|
||||
|
||||
func testDeletePendingSoftwareInstallsForPolicy(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
|
||||
host1 := test.NewHost(t, ds, "host1", "1", "host1key", "host1uuid", time.Now())
|
||||
host2 := test.NewHost(t, ds, "host2", "2", "host2key", "host2uuid", time.Now())
|
||||
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
|
||||
|
||||
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
dir := t.TempDir()
|
||||
store, err := filesystem.NewSoftwareInstallerStore(dir)
|
||||
require.NoError(t, err)
|
||||
ins0 := "installer.pkg"
|
||||
ins0File := bytes.NewReader([]byte("installer0"))
|
||||
err = store.Put(ctx, ins0, ins0File)
|
||||
require.NoError(t, err)
|
||||
_, _ = ins0File.Seek(0, 0)
|
||||
|
||||
tfr0, err := fleet.NewTempFileReader(ins0File, t.TempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
installerID1, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
|
||||
InstallScript: "install",
|
||||
InstallerFile: tfr0,
|
||||
StorageID: ins0,
|
||||
Filename: "installer.pkg",
|
||||
Title: "ins0",
|
||||
Source: "apps",
|
||||
Platform: "darwin",
|
||||
TeamID: &team1.ID,
|
||||
UserID: user1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy1, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
|
||||
Name: "p1",
|
||||
Query: "SELECT 1;",
|
||||
SoftwareInstallerID: &installerID1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
installerID2, _, err := ds.MatchOrCreateSoftwareInstaller(ctx, &fleet.UploadSoftwareInstallerPayload{
|
||||
InstallScript: "install",
|
||||
InstallerFile: tfr0,
|
||||
StorageID: ins0,
|
||||
Filename: "installer.pkg",
|
||||
Title: "ins1",
|
||||
Source: "apps",
|
||||
Platform: "darwin",
|
||||
TeamID: &team1.ID,
|
||||
UserID: user1.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy2, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
|
||||
Name: "p2",
|
||||
Query: "SELECT 2;",
|
||||
SoftwareInstallerID: &installerID2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
const hostSoftwareInstallsCount = "SELECT count(1) FROM host_software_installs WHERE status = ? and execution_id = ?"
|
||||
var count int
|
||||
|
||||
// install for correct policy & correct status
|
||||
executionID, err := ds.InsertSoftwareInstallRequest(ctx, host1.ID, installerID1, false, &policy1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
err = ds.deletePendingSoftwareInstallsForPolicy(ctx, &team1.ID, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
// install for different policy & correct status
|
||||
executionID, err = ds.InsertSoftwareInstallRequest(ctx, host1.ID, installerID2, false, &policy2.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
err = ds.deletePendingSoftwareInstallsForPolicy(ctx, &team1.ID, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, hostSoftwareInstallsCount, fleet.SoftwareInstallPending, executionID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
// install for correct policy & incorrect status
|
||||
executionID, err = ds.InsertSoftwareInstallRequest(ctx, host2.ID, installerID1, false, &policy1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.SetHostSoftwareInstallResult(ctx, &fleet.HostSoftwareInstallResultPayload{
|
||||
HostID: host2.ID,
|
||||
InstallUUID: executionID,
|
||||
InstallScriptExitCode: ptr.Int(0),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ds.deletePendingSoftwareInstallsForPolicy(ctx, &team1.ID, policy1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sqlx.GetContext(ctx, ds.reader(ctx), &count, `SELECT count(1) FROM host_software_installs WHERE execution_id = ?`, executionID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func testGetHostLastInstallData(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package apple_mdm
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm"
|
||||
)
|
||||
|
|
@ -118,11 +119,15 @@ func HandleHostMDMProfileInstallResult(ctx context.Context, ds fleet.ProfileVeri
|
|||
}
|
||||
|
||||
// otherwise update status and detail as usual
|
||||
return ds.UpdateOrDeleteHostMDMAppleProfile(ctx, &fleet.HostMDMAppleProfile{
|
||||
err := ds.UpdateOrDeleteHostMDMAppleProfile(ctx, &fleet.HostMDMAppleProfile{
|
||||
CommandUUID: cmdUUID,
|
||||
HostUUID: hostUUID,
|
||||
Status: status,
|
||||
Detail: detail,
|
||||
OperationType: fleet.MDMOperationTypeInstall,
|
||||
})
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "updating host MDM Apple profile install result")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/cryptoutil"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
||||
|
|
@ -113,7 +114,7 @@ func New(opts ...Option) (*MySQLStorage, error) {
|
|||
mysqlStore := &MySQLStorage{db: cfg.db, logger: cfg.logger, rm: cfg.rm}
|
||||
if cfg.reader == nil {
|
||||
mysqlStore.reader = func(ctx context.Context) fleet.DBReader {
|
||||
return sqlx.NewDb(mysqlStore.db, "mysql")
|
||||
return sqlx.NewDb(mysqlStore.db, "")
|
||||
}
|
||||
} else {
|
||||
mysqlStore.reader = cfg.reader
|
||||
|
|
@ -337,7 +338,10 @@ func (s *MySQLStorage) updateLastSeenBatch(ctx context.Context, ids []string) {
|
|||
return
|
||||
}
|
||||
|
||||
_, err = s.db.ExecContext(ctx, stmt, args...)
|
||||
err = common_mysql.WithRetryTxx(ctx, sqlx.NewDb(s.db, ""), func(tx sqlx.ExtContext) error {
|
||||
_, err := tx.ExecContext(ctx, stmt, args...)
|
||||
return err
|
||||
}, loggerWrapper{s.logger})
|
||||
if err != nil {
|
||||
s.logger.Info("msg", "error batch updating nano_enrollments.last_seen_at", "err", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,11 +8,14 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/datastore/mysql/common_mysql"
|
||||
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/micromdm/nanolib/log"
|
||||
)
|
||||
|
||||
func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) error {
|
||||
func enqueue(ctx context.Context, tx sqlx.ExtContext, ids []string, cmd *mdm.Command) error {
|
||||
if len(ids) < 1 {
|
||||
return errors.New("no id(s) supplied to queue command to")
|
||||
}
|
||||
|
|
@ -50,18 +53,22 @@ func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) er
|
|||
return nil
|
||||
}
|
||||
|
||||
type loggerWrapper struct {
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func (l loggerWrapper) Log(keyvals ...interface{}) error {
|
||||
l.logger.Info(keyvals...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQLStorage) EnqueueCommand(ctx context.Context, ids []string, cmd *mdm.Command) (map[string]error, error) {
|
||||
tx, err := m.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = enqueue(ctx, tx, ids, cmd); err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return nil, fmt.Errorf("rollback error: %w; while trying to handle error: %v", rbErr, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return nil, tx.Commit()
|
||||
// We need to retry because this transaction may deadlock with updates to nano_enrollment.last_seen_at
|
||||
// Deadlock seen in 2024/12/12 loadtest: https://docs.google.com/document/d/1-Q6qFTd7CDm-lh7MVRgpNlNNJijk6JZ4KO49R1fp80U
|
||||
err := common_mysql.WithRetryTxx(ctx, sqlx.NewDb(m.db, ""), func(tx sqlx.ExtContext) error {
|
||||
return enqueue(ctx, tx, ids, cmd)
|
||||
}, loggerWrapper{m.logger})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (m *MySQLStorage) deleteCommand(ctx context.Context, tx *sql.Tx, id, uuid string) error {
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/docker/go-units"
|
||||
"github.com/fleetdm/fleet/v4/pkg/file"
|
||||
"github.com/fleetdm/fleet/v4/pkg/scripts"
|
||||
"github.com/fleetdm/fleet/v4/server/authz"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
|
|
@ -507,7 +508,7 @@ func (svc *Service) NewScript(ctx context.Context, teamID *uint, name string, r
|
|||
script := &fleet.Script{
|
||||
TeamID: teamID,
|
||||
Name: name,
|
||||
ScriptContents: string(b),
|
||||
ScriptContents: file.Dos2UnixNewlines(string(b)),
|
||||
}
|
||||
if err := script.ValidateNewScript(); err != nil {
|
||||
return nil, fleet.NewInvalidArgumentError("script", err.Error())
|
||||
|
|
|
|||
|
|
@ -498,10 +498,14 @@ func TestSavedScripts(t *testing.T) {
|
|||
license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)}
|
||||
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{License: license, SkipCreateTestUsers: true})
|
||||
|
||||
withLFContents := "echo\necho"
|
||||
withCRLFContents := "echo\r\necho"
|
||||
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
return &fleet.AppConfig{}, nil
|
||||
}
|
||||
ds.NewScriptFunc = func(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
|
||||
require.Equal(t, withLFContents, script.ScriptContents)
|
||||
newScript := *script
|
||||
newScript.ID = 1
|
||||
return &newScript, nil
|
||||
|
|
@ -669,7 +673,7 @@ func TestSavedScripts(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx = viewer.NewContext(ctx, viewer.Viewer{User: tt.user})
|
||||
|
||||
_, err := svc.NewScript(ctx, nil, "test.sh", strings.NewReader("echo"))
|
||||
_, err := svc.NewScript(ctx, nil, "test.ps1", strings.NewReader(withCRLFContents))
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
err = svc.DeleteScript(ctx, noTeamScriptID)
|
||||
checkAuthErr(t, tt.shouldFailGlobalWrite, err)
|
||||
|
|
@ -680,7 +684,7 @@ func TestSavedScripts(t *testing.T) {
|
|||
_, _, err = svc.GetScript(ctx, noTeamScriptID, true)
|
||||
checkAuthErr(t, tt.shouldFailGlobalRead, err)
|
||||
|
||||
_, err = svc.NewScript(ctx, ptr.Uint(1), "test.sh", strings.NewReader("echo"))
|
||||
_, err = svc.NewScript(ctx, ptr.Uint(1), "test.sh", strings.NewReader(withLFContents))
|
||||
checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
err = svc.DeleteScript(ctx, team1ScriptID)
|
||||
checkAuthErr(t, tt.shouldFailTeamWrite, err)
|
||||
|
|
|
|||
Loading…
Reference in a new issue