From 736dbbc6005bc40eaba536ed078c2e8b80824688 Mon Sep 17 00:00:00 2001 From: Zach Wasserman Date: Mon, 22 Feb 2021 14:52:21 -0800 Subject: [PATCH] Add update runner Runner handles watching for updates in the background and exiting the Orbit process after a successful update. --- cmd/orbit/orbit.go | 12 +++ pkg/packaging/macos.go | 6 +- pkg/packaging/macos_templates.go | 7 +- pkg/update/runner.go | 155 +++++++++++++++++++++++++++++++ pkg/update/update.go | 32 +++---- 5 files changed, 194 insertions(+), 18 deletions(-) create mode 100644 pkg/update/runner.go diff --git a/cmd/orbit/orbit.go b/cmd/orbit/orbit.go index f532938635..c042cf5978 100644 --- a/cmd/orbit/orbit.go +++ b/cmd/orbit/orbit.go @@ -141,6 +141,18 @@ func main() { } var g run.Group + + updateRunner, err := update.NewRunner(updater, update.RunnerOptions{ + CheckInterval: 10 * time.Second, + Targets: map[string]string{ + "osqueryd": c.String("osquery-version"), + }, + }) + if err != nil { + return err + } + g.Add(updateRunner.Execute, updateRunner.Interrupt) + var options []func(*osquery.Runner) error options = append(options, osquery.WithDataPath(c.String("root-dir"))) diff --git a/pkg/packaging/macos.go b/pkg/packaging/macos.go index 2481d55a0b..304af9cd2d 100644 --- a/pkg/packaging/macos.go +++ b/pkg/packaging/macos.go @@ -83,7 +83,11 @@ func BuildPkg(opt Options) error { return errors.Wrap(err, "write launchd") } } - if err := copyFile("./orbit", filepath.Join(orbitRoot, "orbit"), 0755); err != nil { + if err := copyFile( + "./orbit", + filepath.Join(orbitRoot, "bin", "orbit", "macos", "current", "orbit"), + 0755, + ); err != nil { return errors.Wrap(err, "write orbit") } diff --git a/pkg/packaging/macos_templates.go b/pkg/packaging/macos_templates.go index 8f656b4959..ebf451a850 100644 --- a/pkg/packaging/macos_templates.go +++ b/pkg/packaging/macos_templates.go @@ -46,6 +46,11 @@ launchctl load /Library/LaunchDaemons/com.fleetdm.orbit.plist `)) // TODO set Nice? +// +//Note it's important not to start the orbit binary in +// `/usr/local/bin/orbit` because this is a path that users usually have write +// access to, and running that binary with launchd can become a privilege +// escalation vector. var macosLaunchdTemplate = template.Must(template.New("").Option("missingkey=error").Parse( ` @@ -55,7 +60,7 @@ var macosLaunchdTemplate = template.Must(template.New("").Option("missingkey=err com.fleetdm.orbit ProgramArguments - /var/lib/orbit/orbit + /var/lib/orbit/bin/orbit/macos/current/orbit StandardOutPath /var/log/orbit/orbit.stdout.log diff --git a/pkg/update/runner.go b/pkg/update/runner.go new file mode 100644 index 0000000000..489e9330d6 --- /dev/null +++ b/pkg/update/runner.go @@ -0,0 +1,155 @@ +package update + +import ( + "bytes" + "os" + "path/filepath" + "time" + + "github.com/pkg/errors" + "github.com/rs/zerolog/log" +) + +// RunnerOptions is options provided for the update runner. +type RunnerOptions struct { + // CheckInterval is the interval to check for updates. + CheckInterval time.Duration + // Targets is the names of the artifacts to watch for updates. + Targets map[string]string +} + +// Runner is a specialized runner for the updater. It is designed with Execute and +// Interrupt functions to be compatible with oklog/run. +type Runner struct { + client *Updater + opt RunnerOptions + cancel chan struct{} + hashCache map[string][]byte +} + +// NewRunner creates a new runner with the provided options. The runner must be +// started with Execute. +func NewRunner(client *Updater, opt RunnerOptions) (*Runner, error) { + if opt.CheckInterval <= 0 { + return nil, errors.New("Runner must be configured with interval greater than 0") + } + if len(opt.Targets) == 0 { + return nil, errors.New("Runner must have nonempty subscriptions") + } + + // Initialize hash cache + cache := make(map[string][]byte) + for target, channel := range opt.Targets { + meta, err := client.Lookup(target, channel) + if err != nil { + return nil, errors.Wrap(err, "initialize update cache") + } + + _, hash, err := selectHashFunction(meta) + if err != nil { + return nil, errors.Wrap(err, "select hash for cache") + } + cache[target] = hash + } + + return &Runner{ + client: client, + opt: opt, + + // chan gets capacity of 1 so we don't end up hung if Interrupt is + // called after Execute has already returned. + cancel: make(chan struct{}, 1), + hashCache: cache, + }, nil +} + +// Execute begins a loop checking for updates. +func (r *Runner) Execute() error { + ticker := time.NewTicker(r.opt.CheckInterval) + defer ticker.Stop() + + // Run until cancel or returning an error + for { + select { + case <-r.cancel: + return nil + + case <-ticker.C: + // On each tick, check for updates + didUpdate, err := r.updateAction() + if err != nil { + log.Info().Err(err).Msg("update failed") + } + if didUpdate { + log.Info().Msg("exiting due to successful update") + return nil + } + } + } +} + +func (r *Runner) updateAction() (bool, error) { + var didUpdate bool + if err := r.client.UpdateMetadata(); err != nil { + // Consider this a non-fatal error since it will be common to be offline + // or otherwise unable to retrieve the metadata. + return didUpdate, errors.Wrap(err, "update metadata") + } + + for target, channel := range r.opt.Targets { + meta, err := r.client.Lookup(target, channel) + if err != nil { + return didUpdate, errors.Wrapf(err, "lookup failed") + } + + // Check whether the hash has changed + _, hash, err := selectHashFunction(meta) + if err != nil { + return didUpdate, errors.Wrap(err, "select hash for cache") + } + + if !bytes.Equal(r.hashCache[target], hash) { + // Update detected + log.Info().Str("target", target).Str("channel", channel).Msg("update detected") + if err := r.updateTarget(target, channel); err != nil { + return didUpdate, errors.Wrapf(err, "update %s@%s", target, channel) + } + log.Info().Str("target", target).Str("channel", channel).Msg("update completed") + didUpdate = true + } else { + log.Debug().Str("target", target).Str("channel", channel).Msg("no update") + } + } + + return didUpdate, nil +} + +func (r *Runner) updateTarget(target, channel string) error { + path, err := r.client.Get(target, channel) + if err != nil { + return errors.Wrap(err, "get binary") + } + + // Replace file/link + currentPath := r.client.LocalPath(target, "current") + if err := os.Remove(currentPath); err != nil && !os.IsNotExist(err) { + return errors.Wrap(err, "remove old current") + } + + if err := os.MkdirAll(filepath.Dir(currentPath), 0755); err != nil { + return errors.Wrap(err, "mkdir for symlink") + } + if err := os.Symlink(path, currentPath); err != nil { + return errors.Wrap(err, "symlink current") + } + + // TODO signal a restart? + + return nil +} + +func (r *Runner) Interrupt(err error) { + r.cancel <- struct{}{} + log.Debug().Msg("interrupt updater") + return +} diff --git a/pkg/update/update.go b/pkg/update/update.go index 6dd63b4d1c..2fd31de1e2 100644 --- a/pkg/update/update.go +++ b/pkg/update/update.go @@ -20,7 +20,7 @@ import ( const ( binDir = "bin" - defaultRootKeys = `[{"keytype":"ed25519","scheme":"ed25519","keyid_hash_algorithms":["sha256","sha512"],"keyval":{"public":"c5008789635b7ac63228d80eec24edbfb8b3bddfd2121b08456de201ec24df7a"}}]` + defaultRootKeys = `[{"keytype":"ed25519","scheme":"ed25519","keyid_hash_algorithms":["sha256","sha512"],"keyval":{"public":"037b475337c1acdafe20cff4fee6308209bc4ba23a2439a1f7be85131794cae1"}}]` ) // Updater is responsible for managing update state. @@ -43,7 +43,7 @@ type Options struct { RootKeys string // LocalStore is the local metadata store. LocalStore client.LocalStore - // Platform is the name of the platform to update for. In the default + // Platform is the target of the platform to update for. In the default // options this is the current platform. Platform string } @@ -115,23 +115,23 @@ func (u *Updater) UpdateMetadata() error { return nil } -func (u *Updater) RepoPath(name, channel string) string { - return path.Join(name, u.opt.Platform, channel, name+constant.ExecutableExtension(u.opt.Platform)) +func (u *Updater) RepoPath(target, channel string) string { + return path.Join(target, u.opt.Platform, channel, target+constant.ExecutableExtension(u.opt.Platform)) } -func (u *Updater) LocalPath(name, channel string) string { - return u.pathFromRoot(filepath.Join(binDir, name, u.opt.Platform, channel, name+constant.ExecutableExtension(u.opt.Platform))) +func (u *Updater) LocalPath(target, channel string) string { + return u.pathFromRoot(filepath.Join(binDir, target, u.opt.Platform, channel, target+constant.ExecutableExtension(u.opt.Platform))) } // Lookup looks up the provided target in the local target metadata. This should // be called after UpdateMetadata. -func (u *Updater) Lookup(name, channel string) (*data.TargetFileMeta, error) { - target, err := u.client.Target(u.RepoPath(name, channel)) +func (u *Updater) Lookup(target, channel string) (*data.TargetFileMeta, error) { + t, err := u.client.Target(u.RepoPath(target, channel)) if err != nil { - return nil, errors.Wrapf(err, "lookup %s@%s", name, channel) + return nil, errors.Wrapf(err, "lookup %s@%s", target, channel) } - return &target, nil + return &t, nil } // Targets gets all of the known targets @@ -146,9 +146,9 @@ func (u *Updater) Targets() (data.TargetFiles, error) { // Get returns the local path to the specified target. The target is downloaded // if it does not yet exist locally or the hash does not match. -func (u *Updater) Get(name, channel string) (string, error) { - localPath := u.LocalPath(name, channel) - repoPath := u.RepoPath(name, channel) +func (u *Updater) Get(target, channel string) (string, error) { + localPath := u.LocalPath(target, channel) + repoPath := u.RepoPath(target, channel) stat, err := os.Stat(localPath) if err != nil { log.Debug().Err(err).Msg("stat file") @@ -158,17 +158,17 @@ func (u *Updater) Get(name, channel string) (string, error) { return "", errors.Errorf("expected %s to be regular file", localPath) } - meta, err := u.Lookup(name, channel) + meta, err := u.Lookup(target, channel) if err != nil { return "", err } if err := CheckFileHash(meta, localPath); err != nil { - log.Debug().Err(err).Msg("will redownload due to error checking hash") + log.Debug().Err(err).Msg("will redownload") return localPath, u.Download(repoPath, localPath) } - log.Debug().Str("path", localPath).Msg("found expected channel locally") + log.Debug().Str("path", localPath).Str("target", target).Str("channel", channel).Msg("found expected target locally") return localPath, nil }