mirror of
https://github.com/fleetdm/fleet
synced 2026-05-17 05:58:40 +00:00
> Related issue: #17374 # Checklist for submitter If some of the following don't apply, delete the relevant line. <!-- Note that API documentation changes are now addressed by the product design team. --> - [x] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] Added/updated tests - [x] If database migrations are included, checked table schema to confirm autoupdate - For database migrations: - [x] Checked schema for all modified table for columns that will auto-update timestamps during migration. - [x] Confirmed that updating the timestamps is acceptable, and will not cause unwanted side effects. - [x] Manual QA for all new/changed functionality
1090 lines
31 KiB
Go
1090 lines
31 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/md5" //nolint:gosec
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/fleetdm/fleet/v4/pkg/scripts"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
|
|
var res *fleet.HostScriptResult
|
|
return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
if request.ScriptContentID == 0 {
|
|
// then we are doing a sync execution, so create the contents first
|
|
scRes, err := insertScriptContents(ctx, request.ScriptContents, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, _ := scRes.LastInsertId()
|
|
request.ScriptContentID = uint(id)
|
|
}
|
|
res, err = newHostScriptExecutionRequest(ctx, request, tx)
|
|
return err
|
|
})
|
|
}
|
|
|
|
func newHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload, tx sqlx.ExtContext) (*fleet.HostScriptResult, error) {
|
|
const (
|
|
insStmt = `INSERT INTO host_script_results (host_id, execution_id, script_content_id, output, script_id, user_id, sync_request) VALUES (?, ?, ?, '', ?, ?, ?)`
|
|
getStmt = `SELECT hsr.id, hsr.host_id, hsr.execution_id, hsr.created_at, hsr.script_id, hsr.user_id, hsr.sync_request, sc.contents as script_contents FROM host_script_results hsr JOIN script_contents sc WHERE sc.id = hsr.script_content_id AND hsr.id = ?`
|
|
)
|
|
|
|
execID := uuid.New().String()
|
|
result, err := tx.ExecContext(ctx, insStmt,
|
|
request.HostID,
|
|
execID,
|
|
request.ScriptContentID,
|
|
request.ScriptID,
|
|
request.UserID,
|
|
request.SyncRequest,
|
|
)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "new host script execution request")
|
|
}
|
|
|
|
var script fleet.HostScriptResult
|
|
id, _ := result.LastInsertId()
|
|
err = sqlx.GetContext(ctx, tx, &script, getStmt, id)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "getting the created host script result to return")
|
|
}
|
|
|
|
return &script, nil
|
|
}
|
|
|
|
func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload) (*fleet.HostScriptResult, error) {
|
|
const resultExistsStmt = `
|
|
SELECT
|
|
1
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
host_id = ? AND
|
|
execution_id = ? AND
|
|
exit_code IS NOT NULL
|
|
`
|
|
|
|
const updStmt = `
|
|
UPDATE host_script_results SET
|
|
output = ?,
|
|
runtime = ?,
|
|
exit_code = ?
|
|
WHERE
|
|
host_id = ? AND
|
|
execution_id = ?`
|
|
|
|
const hostMDMActionsStmt = `
|
|
SELECT
|
|
CASE
|
|
WHEN lock_ref = ? THEN 'lock_ref'
|
|
WHEN unlock_ref = ? THEN 'unlock_ref'
|
|
WHEN wipe_ref = ? THEN 'wipe_ref'
|
|
ELSE ''
|
|
END AS ref_col
|
|
FROM
|
|
host_mdm_actions
|
|
WHERE
|
|
host_id = ?
|
|
`
|
|
|
|
const maxOutputRuneLen = 10000
|
|
output := result.Output
|
|
if len(output) > utf8.UTFMax*maxOutputRuneLen {
|
|
// truncate the bytes as we know the output is too long, no point
|
|
// converting more bytes than needed to runes.
|
|
output = output[len(output)-(utf8.UTFMax*maxOutputRuneLen):]
|
|
}
|
|
if utf8.RuneCountInString(output) > maxOutputRuneLen {
|
|
outputRunes := []rune(output)
|
|
output = string(outputRunes[len(outputRunes)-maxOutputRuneLen:])
|
|
}
|
|
|
|
var hsr *fleet.HostScriptResult
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var resultExists bool
|
|
err := sqlx.GetContext(ctx, tx, &resultExists, resultExistsStmt, result.HostID, result.ExecutionID)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return ctxerr.Wrap(ctx, err, "check if host script result exists")
|
|
}
|
|
if resultExists {
|
|
// succeed but leave hsr nil
|
|
return nil
|
|
}
|
|
|
|
res, err := tx.ExecContext(ctx, updStmt,
|
|
output,
|
|
result.Runtime,
|
|
result.ExitCode,
|
|
result.HostID,
|
|
result.ExecutionID,
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host script result")
|
|
}
|
|
|
|
if n, _ := res.RowsAffected(); n > 0 {
|
|
// it did update, so return the updated result
|
|
hsr, err = ds.getHostScriptExecutionResultDB(ctx, tx, result.ExecutionID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "load updated host script result")
|
|
}
|
|
|
|
// look up if that script was a lock/unlock/wipe script for that host,
|
|
// and if so update the host_mdm_actions table accordingly.
|
|
var refCol string
|
|
err = sqlx.GetContext(ctx, tx, &refCol, hostMDMActionsStmt, result.ExecutionID, result.ExecutionID, result.ExecutionID, result.HostID)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore ErrNoRows, refCol will be empty
|
|
return ctxerr.Wrap(ctx, err, "lookup host script corresponding mdm action")
|
|
}
|
|
if refCol != "" {
|
|
err = updateHostLockWipeStatusFromResult(ctx, tx, result.HostID, refCol, result.ExitCode == 0)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host mdm action based on script result")
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return hsr, nil
|
|
}
|
|
|
|
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint) ([]*fleet.HostScriptResult, error) {
|
|
const listStmt = `
|
|
SELECT
|
|
id,
|
|
host_id,
|
|
execution_id,
|
|
script_id
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
host_id = ? AND
|
|
exit_code IS NULL
|
|
-- async requests + sync requests created within the given interval
|
|
AND (
|
|
sync_request = 0
|
|
OR created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
|
|
)
|
|
ORDER BY
|
|
created_at ASC`
|
|
|
|
var results []*fleet.HostScriptResult
|
|
seconds := int(scripts.MaxServerWaitTime.Seconds())
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID, seconds); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "list pending host script executions")
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (ds *Datastore) IsExecutionPendingForHost(ctx context.Context, hostID uint, scriptID uint) ([]*uint, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
1
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
host_id = ? AND
|
|
script_id = ? AND
|
|
exit_code IS NULL
|
|
`
|
|
|
|
var results []*uint
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, getStmt, hostID, scriptID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "is execution pending for host")
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
|
|
return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID)
|
|
}
|
|
|
|
func (ds *Datastore) getHostScriptExecutionResultDB(ctx context.Context, q sqlx.QueryerContext, execID string) (*fleet.HostScriptResult, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
hsr.id,
|
|
hsr.host_id,
|
|
hsr.execution_id,
|
|
sc.contents as script_contents,
|
|
hsr.script_id,
|
|
hsr.output,
|
|
hsr.runtime,
|
|
hsr.exit_code,
|
|
hsr.created_at,
|
|
hsr.user_id,
|
|
hsr.sync_request
|
|
FROM
|
|
host_script_results hsr
|
|
JOIN
|
|
script_contents sc
|
|
WHERE
|
|
hsr.execution_id = ?
|
|
AND
|
|
hsr.script_content_id = sc.id
|
|
`
|
|
|
|
var result fleet.HostScriptResult
|
|
if err := sqlx.GetContext(ctx, q, &result, getStmt, execID); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID))
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get host script result")
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
|
|
var res sql.Result
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
|
|
// first insert script contents
|
|
scRes, err := insertScriptContents(ctx, script.ScriptContents, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
id, _ := scRes.LastInsertId()
|
|
|
|
// then create the script entity
|
|
res, err = insertScript(ctx, script, uint(id), tx)
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, _ := res.LastInsertId()
|
|
return ds.getScriptDB(ctx, ds.writer(ctx), uint(id))
|
|
}
|
|
|
|
func insertScript(ctx context.Context, script *fleet.Script, scriptContentsID uint, tx sqlx.ExtContext) (sql.Result, error) {
|
|
const insertStmt = `
|
|
INSERT INTO
|
|
scripts (
|
|
team_id, global_or_team_id, name, script_content_id
|
|
)
|
|
VALUES
|
|
(?, ?, ?, ?)
|
|
`
|
|
var globalOrTeamID uint
|
|
if script.TeamID != nil {
|
|
globalOrTeamID = *script.TeamID
|
|
}
|
|
res, err := tx.ExecContext(ctx, insertStmt,
|
|
script.TeamID, globalOrTeamID, script.Name, scriptContentsID)
|
|
if err != nil {
|
|
if isDuplicate(err) {
|
|
// name already exists for this team/global
|
|
err = alreadyExists("Script", script.Name)
|
|
} else if isChildForeignKeyError(err) {
|
|
// team does not exist
|
|
err = foreignKey("scripts", fmt.Sprintf("team_id=%v", script.TeamID))
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "insert script")
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func insertScriptContents(ctx context.Context, contents string, tx sqlx.ExtContext) (sql.Result, error) {
|
|
const insertStmt = `
|
|
INSERT INTO
|
|
script_contents (
|
|
md5_checksum, contents
|
|
)
|
|
VALUES (UNHEX(?),?)
|
|
ON DUPLICATE KEY UPDATE
|
|
id=LAST_INSERT_ID(id)
|
|
`
|
|
|
|
md5Checksum := md5ChecksumScriptContent(contents)
|
|
res, err := tx.ExecContext(ctx, insertStmt, md5Checksum, contents)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "insert script contents")
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// md5ChecksumScriptContent returns the MD5 checksum of s as an upper-case
// hexadecimal string.
func md5ChecksumScriptContent(s string) string {
	digest := md5.Sum([]byte(s)) //nolint:gosec
	lowerHex := hex.EncodeToString(digest[:])
	return strings.ToUpper(lowerHex)
}
|
|
|
|
func (ds *Datastore) Script(ctx context.Context, id uint) (*fleet.Script, error) {
|
|
return ds.getScriptDB(ctx, ds.reader(ctx), id)
|
|
}
|
|
|
|
func (ds *Datastore) getScriptDB(ctx context.Context, q sqlx.QueryerContext, id uint) (*fleet.Script, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
id,
|
|
team_id,
|
|
name,
|
|
created_at,
|
|
updated_at,
|
|
script_content_id
|
|
FROM
|
|
scripts
|
|
WHERE
|
|
id = ?
|
|
`
|
|
var script fleet.Script
|
|
if err := sqlx.GetContext(ctx, q, &script, getStmt, id); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get script")
|
|
}
|
|
return &script, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetScriptContents(ctx context.Context, id uint) ([]byte, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
sc.contents
|
|
FROM
|
|
script_contents sc
|
|
JOIN scripts s
|
|
WHERE
|
|
s.script_content_id = sc.id
|
|
AND s.id = ?;
|
|
`
|
|
var contents []byte
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get script contents")
|
|
}
|
|
return contents, nil
|
|
}
|
|
|
|
func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
|
|
_, err := ds.writer(ctx).ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "delete script")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
|
|
var scripts []*fleet.Script
|
|
|
|
const selectStmt = `
|
|
SELECT
|
|
s.id,
|
|
s.team_id,
|
|
s.name,
|
|
s.created_at,
|
|
s.updated_at
|
|
FROM
|
|
scripts s
|
|
WHERE
|
|
s.global_or_team_id = ?
|
|
`
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
args := []any{globalOrTeamID}
|
|
stmt, args := appendListOptionsWithCursorToSQL(selectStmt, args, &opt)
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &scripts, stmt, args...); err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "select scripts")
|
|
}
|
|
|
|
var metaData *fleet.PaginationMetadata
|
|
if opt.IncludeMetadata {
|
|
metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
|
|
if len(scripts) > int(opt.PerPage) {
|
|
metaData.HasNextResults = true
|
|
scripts = scripts[:len(scripts)-1]
|
|
}
|
|
}
|
|
return scripts, metaData, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetScriptIDByName(ctx context.Context, name string, teamID *uint) (uint, error) {
|
|
const selectStmt = `
|
|
SELECT
|
|
id
|
|
FROM
|
|
scripts
|
|
WHERE
|
|
global_or_team_id = ?
|
|
AND name = ?
|
|
`
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
var id uint
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &id, selectStmt, globalOrTeamID, name); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return 0, notFound("Script").WithName(name)
|
|
}
|
|
return 0, ctxerr.Wrap(ctx, err, "get script by name")
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetHostScriptDetails(ctx context.Context, hostID uint, teamID *uint, opt fleet.ListOptions, hostPlatform string) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) {
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
var extension string
|
|
switch {
|
|
case hostPlatform == "windows":
|
|
// filter by .ps1 extension
|
|
extension = `%.ps1`
|
|
case fleet.IsUnixLike(hostPlatform):
|
|
// filter by .sh extension
|
|
extension = `%.sh`
|
|
default:
|
|
// no extension filter
|
|
}
|
|
|
|
type row struct {
|
|
ScriptID uint `db:"script_id"`
|
|
Name string `db:"name"`
|
|
HSRID *uint `db:"hsr_id"`
|
|
ExecutionID *string `db:"execution_id"`
|
|
ExecutedAt *time.Time `db:"executed_at"`
|
|
ExitCode *int64 `db:"exit_code"`
|
|
}
|
|
|
|
sql := `
|
|
SELECT
|
|
s.id AS script_id,
|
|
s.name,
|
|
hsr.id AS hsr_id,
|
|
hsr.created_at AS executed_at,
|
|
hsr.execution_id,
|
|
hsr.exit_code
|
|
FROM
|
|
scripts s
|
|
LEFT JOIN (
|
|
SELECT
|
|
id,
|
|
host_id,
|
|
script_id,
|
|
execution_id,
|
|
created_at,
|
|
exit_code
|
|
FROM
|
|
host_script_results r
|
|
WHERE
|
|
host_id = ?
|
|
AND NOT EXISTS (
|
|
SELECT
|
|
1
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
host_id = ?
|
|
AND id != r.id
|
|
AND script_id = r.script_id
|
|
AND(created_at > r.created_at
|
|
OR(created_at = r.created_at
|
|
AND id > r.id)))) hsr
|
|
ON s.id = hsr.script_id
|
|
WHERE
|
|
(hsr.host_id IS NULL OR hsr.host_id = ?)
|
|
AND s.global_or_team_id = ?
|
|
`
|
|
|
|
args := []any{hostID, hostID, hostID, globalOrTeamID}
|
|
if len(extension) > 0 {
|
|
args = append(args, extension)
|
|
sql += `
|
|
AND s.name LIKE ?
|
|
`
|
|
}
|
|
stmt, args := appendListOptionsWithCursorToSQL(sql, args, &opt)
|
|
|
|
var rows []*row
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &rows, stmt, args...); err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get host script details")
|
|
}
|
|
|
|
var metaData *fleet.PaginationMetadata
|
|
if opt.IncludeMetadata {
|
|
metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
|
|
if len(rows) > int(opt.PerPage) {
|
|
metaData.HasNextResults = true
|
|
rows = rows[:len(rows)-1]
|
|
}
|
|
}
|
|
|
|
results := make([]*fleet.HostScriptDetail, 0, len(rows))
|
|
for _, r := range rows {
|
|
results = append(results, fleet.NewHostScriptDetail(hostID, r.ScriptID, r.Name, r.ExecutionID, r.ExecutedAt, r.ExitCode, r.HSRID))
|
|
}
|
|
|
|
return results, metaData, nil
|
|
}
|
|
|
|
func (ds *Datastore) BatchSetScripts(ctx context.Context, tmID *uint, scripts []*fleet.Script) error {
|
|
const loadExistingScripts = `
|
|
SELECT
|
|
name
|
|
FROM
|
|
scripts
|
|
WHERE
|
|
global_or_team_id = ? AND
|
|
name IN (?)
|
|
`
|
|
const deleteAllScriptsInTeam = `
|
|
DELETE FROM
|
|
scripts
|
|
WHERE
|
|
global_or_team_id = ?
|
|
`
|
|
|
|
const deleteScriptsNotInList = `
|
|
DELETE FROM
|
|
scripts
|
|
WHERE
|
|
global_or_team_id = ? AND
|
|
name NOT IN (?)
|
|
`
|
|
|
|
const insertNewOrEditedScript = `
|
|
INSERT INTO
|
|
scripts (
|
|
team_id, global_or_team_id, name, script_content_id
|
|
)
|
|
VALUES
|
|
(?, ?, ?, ?)
|
|
ON DUPLICATE KEY UPDATE
|
|
script_content_id = VALUES(script_content_id)
|
|
`
|
|
|
|
// use a team id of 0 if no-team
|
|
var globalOrTeamID uint
|
|
if tmID != nil {
|
|
globalOrTeamID = *tmID
|
|
}
|
|
|
|
// build a list of names for the incoming scripts, will keep the
|
|
// existing ones if there's a match and no change
|
|
incomingNames := make([]string, len(scripts))
|
|
// at the same time, index the incoming scripts keyed by name for ease
|
|
// of processing
|
|
incomingScripts := make(map[string]*fleet.Script, len(scripts))
|
|
for i, p := range scripts {
|
|
incomingNames[i] = p.Name
|
|
incomingScripts[p.Name] = p
|
|
}
|
|
|
|
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var existingScripts []*fleet.Script
|
|
|
|
if len(incomingNames) > 0 {
|
|
// load existing scripts that match the incoming scripts by names
|
|
stmt, args, err := sqlx.In(loadExistingScripts, globalOrTeamID, incomingNames)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "build query to load existing scripts")
|
|
}
|
|
if err := sqlx.SelectContext(ctx, tx, &existingScripts, stmt, args...); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "load existing scripts")
|
|
}
|
|
}
|
|
|
|
// figure out if we need to delete any scripts
|
|
keepNames := make([]string, 0, len(incomingNames))
|
|
for _, p := range existingScripts {
|
|
if newS := incomingScripts[p.Name]; newS != nil {
|
|
keepNames = append(keepNames, p.Name)
|
|
}
|
|
}
|
|
|
|
var (
|
|
stmt string
|
|
args []any
|
|
err error
|
|
)
|
|
if len(keepNames) > 0 {
|
|
// delete the obsolete scripts
|
|
stmt, args, err = sqlx.In(deleteScriptsNotInList, globalOrTeamID, keepNames)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "build statement to delete obsolete scripts")
|
|
}
|
|
} else {
|
|
stmt = deleteAllScriptsInTeam
|
|
args = []any{globalOrTeamID}
|
|
}
|
|
if _, err := tx.ExecContext(ctx, stmt, args...); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "delete obsolete scripts")
|
|
}
|
|
|
|
// insert the new scripts and the ones that have changed
|
|
for _, s := range incomingScripts {
|
|
scRes, err := insertScriptContents(ctx, s.ScriptContents, tx)
|
|
if err != nil {
|
|
return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name)
|
|
}
|
|
id, _ := scRes.LastInsertId()
|
|
if _, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(id)); err != nil {
|
|
return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (ds *Datastore) GetHostLockWipeStatus(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
|
|
const stmt = `
|
|
SELECT
|
|
lock_ref,
|
|
wipe_ref,
|
|
unlock_ref,
|
|
unlock_pin,
|
|
fleet_platform
|
|
FROM
|
|
host_mdm_actions
|
|
WHERE
|
|
host_id = ?
|
|
`
|
|
|
|
var mdmActions struct {
|
|
LockRef *string `db:"lock_ref"`
|
|
WipeRef *string `db:"wipe_ref"`
|
|
UnlockRef *string `db:"unlock_ref"`
|
|
UnlockPIN *string `db:"unlock_pin"`
|
|
FleetPlatform string `db:"fleet_platform"`
|
|
}
|
|
fleetPlatform := host.FleetPlatform()
|
|
status := &fleet.HostLockWipeStatus{
|
|
HostFleetPlatform: fleetPlatform,
|
|
}
|
|
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &mdmActions, stmt, host.ID); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
// do not return a Not Found error, return the zero-value status, which
|
|
// will report the correct states.
|
|
return status, nil
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get host lock/wipe status")
|
|
}
|
|
|
|
// if we have a fleet platform stored in host_mdm_actions, use it instead of
|
|
// the host.FleetPlatform() because the platform can be overwritten with an
|
|
// unknown OS name when a Wipe gets executed.
|
|
if mdmActions.FleetPlatform != "" {
|
|
fleetPlatform = mdmActions.FleetPlatform
|
|
status.HostFleetPlatform = fleetPlatform
|
|
}
|
|
|
|
switch fleetPlatform {
|
|
case "darwin":
|
|
if mdmActions.UnlockPIN != nil {
|
|
status.UnlockPIN = *mdmActions.UnlockPIN
|
|
}
|
|
if mdmActions.UnlockRef != nil {
|
|
var err error
|
|
status.UnlockRequestedAt, err = time.Parse(time.DateTime, *mdmActions.UnlockRef)
|
|
if err != nil {
|
|
// if the format is unexpected but there's something in UnlockRef, just
|
|
// replace it with the current timestamp, it should still indicate that
|
|
// an unlock was requested (e.g. in case someone plays with the data
|
|
// directly in the DB and messes up the format).
|
|
status.UnlockRequestedAt = time.Now().UTC()
|
|
}
|
|
}
|
|
|
|
if mdmActions.LockRef != nil {
|
|
// the lock reference is an MDM command
|
|
cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.LockRef, host.UUID)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get lock reference")
|
|
}
|
|
status.LockMDMCommand = cmd
|
|
status.LockMDMCommandResult = cmdRes
|
|
}
|
|
|
|
if mdmActions.WipeRef != nil {
|
|
// the wipe reference is an MDM command
|
|
cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.WipeRef, host.UUID)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
|
|
}
|
|
status.WipeMDMCommand = cmd
|
|
status.WipeMDMCommandResult = cmdRes
|
|
}
|
|
|
|
case "windows", "linux":
|
|
// lock and unlock references are scripts
|
|
if mdmActions.LockRef != nil {
|
|
hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.LockRef)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get lock reference script result")
|
|
}
|
|
status.LockScript = hsr
|
|
}
|
|
|
|
if mdmActions.UnlockRef != nil {
|
|
hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.UnlockRef)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get unlock reference script result")
|
|
}
|
|
status.UnlockScript = hsr
|
|
}
|
|
|
|
// wipe is an MDM command on Windows, a script on Linux
|
|
if mdmActions.WipeRef != nil {
|
|
if fleetPlatform == "windows" {
|
|
cmd, cmdRes, err := ds.getHostMDMWindowsCommand(ctx, *mdmActions.WipeRef, host.UUID)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
|
|
}
|
|
status.WipeMDMCommand = cmd
|
|
status.WipeMDMCommandResult = cmdRes
|
|
} else {
|
|
hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.WipeRef)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get wipe reference script result")
|
|
}
|
|
status.WipeScript = hsr
|
|
}
|
|
}
|
|
}
|
|
|
|
return status, nil
|
|
}
|
|
|
|
func (ds *Datastore) getHostMDMWindowsCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
|
|
cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command")
|
|
}
|
|
|
|
// get the MDM command result, which may be not found (indicating the command
|
|
// is pending). Note that it doesn't return ErrNoRows if not found, it
|
|
// returns success and an empty cmdRes slice.
|
|
cmdResults, err := ds.GetMDMWindowsCommandResults(ctx, cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command result")
|
|
}
|
|
|
|
// each item in the slice returned by GetMDMWindowsCommandResults is
|
|
// potentially a result for a different host, we need to find the one for
|
|
// that specific host.
|
|
var cmdRes *fleet.MDMCommandResult
|
|
for _, r := range cmdResults {
|
|
if r.HostUUID != hostUUID {
|
|
continue
|
|
}
|
|
// all statuses for Windows indicate end of processing of the command
|
|
// (there is no equivalent of "NotNow" or "Idle" as for Apple).
|
|
cmdRes = r
|
|
break
|
|
}
|
|
return cmd, cmdRes, nil
|
|
}
|
|
|
|
func (ds *Datastore) getHostMDMAppleCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
|
|
cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command")
|
|
}
|
|
|
|
// get the MDM command result, which may be not found (indicating the command
|
|
// is pending). Note that it doesn't return ErrNoRows if not found, it
|
|
// returns success and an empty cmdRes slice.
|
|
cmdResults, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command result")
|
|
}
|
|
|
|
// each item in the slice returned by GetMDMAppleCommandResults is
|
|
// potentially a result for a different host, we need to find the one for
|
|
// that specific host.
|
|
var cmdRes *fleet.MDMCommandResult
|
|
for _, r := range cmdResults {
|
|
if r.HostUUID != hostUUID {
|
|
continue
|
|
}
|
|
if r.Status == fleet.MDMAppleStatusAcknowledged || r.Status == fleet.MDMAppleStatusError || r.Status == fleet.MDMAppleStatusCommandFormatError {
|
|
cmdRes = r
|
|
break
|
|
}
|
|
}
|
|
return cmd, cmdRes, nil
|
|
}
|
|
|
|
// LockHostViaScript will create the script execution request and update
|
|
// host_mdm_actions in a single transaction.
|
|
func (ds *Datastore) LockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
|
|
var res *fleet.HostScriptResult
|
|
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
|
|
scRes, err := insertScriptContents(ctx, request.ScriptContents, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, _ := scRes.LastInsertId()
|
|
request.ScriptContentID = uint(id)
|
|
|
|
res, err = newHostScriptExecutionRequest(ctx, request, tx)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "lock host via script create execution")
|
|
}
|
|
|
|
// on duplicate we don't clear any other existing state because at this
|
|
// point in time, this is just a request to lock the host that is recorded,
|
|
// it is pending execution. The host's state should be updated to "locked"
|
|
// only when the script execution is successfully completed, and then any
|
|
// unlock or wipe references should be cleared.
|
|
const stmt = `
|
|
INSERT INTO host_mdm_actions
|
|
(
|
|
host_id,
|
|
lock_ref,
|
|
fleet_platform
|
|
)
|
|
VALUES (?,?,?)
|
|
ON DUPLICATE KEY UPDATE
|
|
lock_ref = VALUES(lock_ref)
|
|
`
|
|
|
|
_, err = tx.ExecContext(ctx, stmt,
|
|
request.HostID,
|
|
res.ExecutionID,
|
|
hostFleetPlatform,
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "lock host via script update mdm actions")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// UnlockHostViaScript will create the script execution request and update
|
|
// host_mdm_actions in a single transaction.
|
|
func (ds *Datastore) UnlockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
|
|
var res *fleet.HostScriptResult
|
|
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
|
|
scRes, err := insertScriptContents(ctx, request.ScriptContents, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, _ := scRes.LastInsertId()
|
|
request.ScriptContentID = uint(id)
|
|
|
|
res, err = newHostScriptExecutionRequest(ctx, request, tx)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "unlock host via script create execution")
|
|
}
|
|
|
|
// on duplicate we don't clear any other existing state because at this
|
|
// point in time, this is just a request to unlock the host that is
|
|
// recorded, it is pending execution. The host's state should be updated to
|
|
// "unlocked" only when the script execution is successfully completed, and
|
|
// then any lock or wipe references should be cleared.
|
|
const stmt = `
|
|
INSERT INTO host_mdm_actions
|
|
(
|
|
host_id,
|
|
unlock_ref,
|
|
fleet_platform
|
|
)
|
|
VALUES (?,?,?)
|
|
ON DUPLICATE KEY UPDATE
|
|
unlock_ref = VALUES(unlock_ref),
|
|
unlock_pin = NULL
|
|
`
|
|
|
|
_, err = tx.ExecContext(ctx, stmt,
|
|
request.HostID,
|
|
res.ExecutionID,
|
|
hostFleetPlatform,
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "unlock host via script update mdm actions")
|
|
}
|
|
|
|
return err
|
|
})
|
|
}
|
|
|
|
// WipeHostViaScript creates the script execution request and updates the
|
|
// host_mdm_actions table in a single transaction.
|
|
func (ds *Datastore) WipeHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
|
|
var res *fleet.HostScriptResult
|
|
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
|
|
scRes, err := insertScriptContents(ctx, request.ScriptContents, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, _ := scRes.LastInsertId()
|
|
request.ScriptContentID = uint(id)
|
|
|
|
res, err = newHostScriptExecutionRequest(ctx, request, tx)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "wipe host via script create execution")
|
|
}
|
|
|
|
// on duplicate we don't clear any other existing state because at this
|
|
// point in time, this is just a request to wipe the host that is recorded,
|
|
// it is pending execution, so if it was locked, it is still locked (so the
|
|
// lock_ref info must still be there).
|
|
const stmt = `
|
|
INSERT INTO host_mdm_actions
|
|
(
|
|
host_id,
|
|
wipe_ref,
|
|
fleet_platform
|
|
)
|
|
VALUES (?,?,?)
|
|
ON DUPLICATE KEY UPDATE
|
|
wipe_ref = VALUES(wipe_ref)
|
|
`
|
|
|
|
_, err = tx.ExecContext(ctx, stmt,
|
|
request.HostID,
|
|
res.ExecutionID,
|
|
hostFleetPlatform,
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "wipe host via script update mdm actions")
|
|
}
|
|
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (ds *Datastore) UnlockHostManually(ctx context.Context, hostID uint, hostFleetPlatform string, ts time.Time) error {
|
|
const stmt = `
|
|
INSERT INTO host_mdm_actions
|
|
(
|
|
host_id,
|
|
unlock_ref,
|
|
fleet_platform
|
|
)
|
|
VALUES (?, ?, ?)
|
|
ON DUPLICATE KEY UPDATE
|
|
-- do not overwrite if a value is already set
|
|
unlock_ref = IF(unlock_ref IS NULL, VALUES(unlock_ref), unlock_ref)
|
|
`
|
|
// for macOS, the unlock_ref is just the timestamp at which the user first
|
|
// requested to unlock the host. This then indicates in the host's status
|
|
// that it's pending an unlock (which requires manual intervention by
|
|
// entering a PIN on the device). The /unlock endpoint can be called multiple
|
|
// times, so we record the timestamp of the first time it was requested and
|
|
// from then on, the host is marked as "pending unlock" until the device is
|
|
// actually unlocked with the PIN. The actual unlocking happens when the
|
|
// device sends an Idle MDM request.
|
|
unlockRef := ts.Format(time.DateTime)
|
|
_, err := ds.writer(ctx).ExecContext(ctx, stmt, hostID, unlockRef, hostFleetPlatform)
|
|
return ctxerr.Wrap(ctx, err, "record manual unlock host request")
|
|
}
|
|
|
|
// buildHostLockWipeStatusUpdateStmt returns the "UPDATE host_mdm_actions ...
// SET ..." part of the statement (without WHERE clause) that applies the
// outcome of a lock, unlock or wipe action to host_mdm_actions.
//
// refCol is the reference column of the action ("lock_ref", "unlock_ref" or
// "wipe_ref"). If joinPart is not empty, host_mdm_actions is aliased as "hma",
// joinPart is inserted after the alias and the columns are prefixed with
// "hma.". If the action succeeded, the columns made obsolete by the new state
// are set to NULL (nothing is set for an unknown refCol); if it failed, only
// refCol itself is set to NULL.
func buildHostLockWipeStatusUpdateStmt(refCol string, succeeded bool, joinPart string) string {
	var alias string

	stmt := `UPDATE host_mdm_actions `
	if joinPart != "" {
		stmt += `hma ` + joinPart
		alias = "hma."
	}
	stmt += ` SET `

	// columns to set to NULL
	var cols []string
	if succeeded {
		switch refCol {
		case "lock_ref":
			// Note that this must not clear the unlock_pin, because recording the
			// lock request does generate the PIN and store it there to be used by an
			// eventual unlock.
			cols = []string{"unlock_ref", "wipe_ref"}
		case "unlock_ref":
			// a successful unlock clears itself as well as the lock ref, because
			// unlock is the default state so we don't need to keep its unlock_ref
			// around once it's confirmed.
			cols = []string{"lock_ref", "unlock_ref", "unlock_pin", "wipe_ref"}
		case "wipe_ref":
			cols = []string{"lock_ref", "unlock_ref", "unlock_pin"}
		}
	} else {
		// if the action failed, then we clear the reference to that action itself so
		// the host stays in the previous state (it doesn't transition to the new
		// state).
		cols = []string{refCol}
	}

	// Plain concatenation is used on purpose: refCol must never be part of a
	// fmt format string, where any '%' it contains would be read as a verb.
	for i, col := range cols {
		if i > 0 {
			stmt += ", "
		}
		stmt += alias + col + " = NULL"
	}
	return stmt
}
|
|
|
|
func (ds *Datastore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Context, hostUUID, cmdUUID, requestType string, succeeded bool) error {
|
|
// a bit of MDM protocol leaking in the mysql layer, but it's either that or
|
|
// the other way around (MDM protocol would translate to database column)
|
|
var refCol string
|
|
switch requestType {
|
|
case "EraseDevice":
|
|
refCol = "wipe_ref"
|
|
case "DeviceLock":
|
|
refCol = "lock_ref"
|
|
default:
|
|
return nil
|
|
}
|
|
return updateHostLockWipeStatusFromResultAndHostUUID(ctx, ds.writer(ctx), hostUUID, refCol, cmdUUID, succeeded)
|
|
}
|
|
|
|
func updateHostLockWipeStatusFromResultAndHostUUID(ctx context.Context, tx sqlx.ExtContext, hostUUID, refCol, cmdUUID string, succeeded bool) error {
|
|
stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, `JOIN hosts h ON hma.host_id = h.id`)
|
|
stmt += ` WHERE h.uuid = ? AND hma.` + refCol + ` = ?`
|
|
_, err := tx.ExecContext(ctx, stmt, hostUUID, cmdUUID)
|
|
return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result via host uuid")
|
|
}
|
|
|
|
func updateHostLockWipeStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, refCol string, succeeded bool) error {
|
|
stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, "")
|
|
stmt += ` WHERE host_id = ?`
|
|
_, err := tx.ExecContext(ctx, stmt, hostID)
|
|
return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result")
|
|
}
|
|
|
|
func (ds *Datastore) CleanupUnusedScriptContents(ctx context.Context) error {
|
|
deleteStmt := `
|
|
DELETE FROM
|
|
script_contents
|
|
WHERE
|
|
NOT EXISTS (
|
|
SELECT 1 FROM host_script_results WHERE script_content_id = script_contents.id)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM scripts WHERE script_content_id = script_contents.id)
|
|
`
|
|
_, err := ds.writer(ctx).ExecContext(ctx, deleteStmt)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "cleaning up unused script contents")
|
|
}
|
|
return nil
|
|
}
|