From 897cb35e24e1faf9185d34aeb0f8e5d1c4178929 Mon Sep 17 00:00:00 2001 From: Victor Vrantchan Date: Fri, 24 Feb 2017 19:33:42 -0500 Subject: [PATCH] Allow checking in to license server when license is saved. (#1299) * Allow checking in to license server when license is saved. Closes #1290 Closes #1277 --- cli/prepare.go | 2 +- cli/serve.go | 2 +- server/kolide/licenses.go | 6 +++ server/license/checker.go | 63 +++++++++++++++++++++--------- server/license/checker_test.go | 21 ++++++++-- server/service/service.go | 24 ++++++------ server/service/service_licenses.go | 2 + server/service/util_test.go | 4 +- 8 files changed, 86 insertions(+), 38 deletions(-) diff --git a/cli/prepare.go b/cli/prepare.go index 3e411aa6fd..9b11b7c54d 100644 --- a/cli/prepare.go +++ b/cli/prepare.go @@ -77,7 +77,7 @@ To setup kolide infrastructure, use one of the available commands. Enabled: &enabled, Admin: &isAdmin, } - svc, err := service.NewService(ds, pubsub.NewInmemQueryResults(), kitlog.NewNopLogger(), config, nil, clock.C) + svc, err := service.NewService(ds, pubsub.NewInmemQueryResults(), kitlog.NewNopLogger(), config, nil, clock.C, nil) if err != nil { initFatal(err, "creating service") } diff --git a/cli/serve.go b/cli/serve.go index 3489f43992..d8dd624ad8 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -88,7 +88,7 @@ the way that the kolide server works. redisPool := pubsub.NewRedisPool(config.Redis.Address, config.Redis.Password) resultStore = pubsub.NewRedisQueryResults(redisPool) - svc, err := service.NewService(ds, resultStore, logger, config, mailService, clock.C) + svc, err := service.NewService(ds, resultStore, logger, config, mailService, clock.C, licenseService) if err != nil { initFatal(err, "initializing service") } diff --git a/server/kolide/licenses.go b/server/kolide/licenses.go index c0660b346b..95e558c1e7 100644 --- a/server/kolide/licenses.go +++ b/server/kolide/licenses.go @@ -125,3 +125,9 @@ func (l *License) Claims() (*Claims, error) { result.HostCount = int(l.HostCount) return &result, nil } + +// LicenseChecker allows checking that a license is valid by calling in to +// a remote URL. +type LicenseChecker interface { + RunLicenseCheck(ctx context.Context) +} diff --git a/server/license/checker.go b/server/license/checker.go index 9c9e938c43..29abc21f56 100644 --- a/server/license/checker.go +++ b/server/license/checker.go @@ -2,15 +2,18 @@ package license import ( "encoding/json" - "errors" "fmt" "net/http" + "net/url" "sync" "time" "github.com/WatchBeam/clock" "github.com/go-kit/kit/log" "github.com/kolide/kolide/server/kolide" + "github.com/kolide/kolide/server/version" + "github.com/pkg/errors" + "golang.org/x/net/context" ) const ( @@ -122,7 +125,7 @@ func (cc *Checker) Start() error { chk.logger.Log("msg", "finishing") return case <-chk.ticker.Chan(): - updateLicenseRevocation(&chk) + chk.RunLicenseCheck(context.Background()) } } }(*cc, &wait) @@ -138,51 +141,73 @@ func (cc *Checker) Stop() { cc.finish = nil } -func updateLicenseRevocation(chk *Checker) { - chk.logger.Log("msg", "begin license check") - defer chk.logger.Log("msg", "ending license check") - - license, err := chk.ds.License() +// addVersionInfo parses the license URL and adds the current revision of the +// kolide binary to the query params. The reported revision is set using +// ldflags by the make command, otherwise defaults to 'unknown'. +func addVersionInfo(licenseURL string) (*url.URL, error) { + ur, err := url.Parse(licenseURL) if err != nil { - chk.logger.Log("msg", "couldn't fetch license", "err", err) + return nil, errors.Wrapf(err, "license checker failed to parse URL string %q", licenseURL) + } + revision := version.Version().Revision + q := ur.Query() + q.Set("version", revision) + ur.RawQuery = q.Encode() + return ur, nil +} + +func (cc *Checker) RunLicenseCheck(ctx context.Context) { + cc.logger.Log("msg", "begin license check") + defer cc.logger.Log("msg", "ending license check") + + license, err := cc.ds.License() + if err != nil { + cc.logger.Log("msg", "couldn't fetch license", "err", err) return } claims, err := license.Claims() if err != nil { - chk.logger.Log("msg", "fetching claims", "err", err) + cc.logger.Log("msg", "fetching claims", "err", err) return } - url := fmt.Sprintf("%s/%s", chk.url, claims.LicenseUUID) - resp, err := chk.client.Get(url) + + licenseURL, err := addVersionInfo(fmt.Sprintf("%s/%s", cc.url, claims.LicenseUUID)) if err != nil { - chk.logger.Log("msg", fmt.Sprintf("fetching %s", url), "err", err) + cc.logger.Log("msg", "adding version information to license", "err", err) + return + } + + resp, err := cc.client.Get(licenseURL.String()) + if err != nil { + cc.logger.Log("msg", fmt.Sprintf("fetching %s", licenseURL.String()), "err", err) return } defer resp.Body.Close() + switch resp.StatusCode { case http.StatusOK: var revInfo revokeInfo err = json.NewDecoder(resp.Body).Decode(&revInfo) if err != nil { - chk.logger.Log("msg", "decoding response", "err", err) + cc.logger.Log("msg", "decoding response", "err", err) return } - err = chk.ds.RevokeLicense(revInfo.Revoked) + err = cc.ds.RevokeLicense(revInfo.Revoked) if err != nil { - chk.logger.Log("msg", "revoke status", "err", err) + cc.logger.Log("msg", "revoke status", "err", err) return } // success - chk.logger.Log("msg", fmt.Sprintf("license revocation status retrieved succesfully, revoked: %t", revInfo.Revoked)) + cc.logger.Log("msg", fmt.Sprintf("license revocation status retrieved succesfully, revoked: %t", revInfo.Revoked)) case http.StatusNotFound: var revInfo revokeError err = json.NewDecoder(resp.Body).Decode(&revInfo) if err != nil { - chk.logger.Log("msg", "decoding response", "err", err) + cc.logger.Log("msg", "decoding response", "err", err) return } - chk.logger.Log("msg", "host response", "err", fmt.Sprintf("status: %d error: %s", revInfo.Status, revInfo.Error)) + cc.logger.Log("msg", "host response", "err", fmt.Sprintf("status: %d error: %s", revInfo.Status, revInfo.Error)) default: - chk.logger.Log("msg", "host response", "err", fmt.Sprintf("unexpected response status from host, status %s", resp.Status)) + cc.logger.Log("msg", "host response", "err", fmt.Sprintf("unexpected response status from host, status %s", resp.Status)) } } diff --git a/server/license/checker_test.go b/server/license/checker_test.go index 244d1a712e..82c0880e1b 100644 --- a/server/license/checker_test.go +++ b/server/license/checker_test.go @@ -15,6 +15,8 @@ import ( "github.com/kolide/kolide/server/kolide" "github.com/kolide/kolide/server/mock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" ) var tokenString = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IjRkOmM1OmRlOmE1Oj" + @@ -82,6 +84,7 @@ func TestLicenseFound(t *testing.T) { mockTicker(c.NewTicker(time.Millisecond)), ) checker.Start() + checker.RunLicenseCheck(context.Background()) <-time.After(10 * time.Millisecond) c.AddTime(time.Millisecond) c.AddTime(time.Millisecond) @@ -92,8 +95,8 @@ func TestLicenseFound(t *testing.T) { // use the flags from the mock package to indicate function invocation race detector will // complain - assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked)) - assert.Equal(t, int64(2), atomic.LoadInt64(&revokeFunInvoked)) + assert.Equal(t, int64(3), atomic.LoadInt64(&licFunInvoked)) + assert.Equal(t, int64(3), atomic.LoadInt64(&revokeFunInvoked)) } func TestLicenseNotFound(t *testing.T) { @@ -134,12 +137,13 @@ func TestLicenseNotFound(t *testing.T) { mockTicker(c.NewTicker(time.Millisecond)), ) checker.Start() + checker.RunLicenseCheck(context.Background()) <-time.After(10 * time.Millisecond) c.AddTime(time.Millisecond) <-time.After(10 * time.Millisecond) checker.Stop() - assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked)) + assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked)) assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked)) } @@ -208,12 +212,13 @@ func TestLicenseTimeout(t *testing.T) { Logger(logger), ) checker.Start() + checker.RunLicenseCheck(context.Background()) <-time.After(10 * time.Millisecond) c.AddTime(time.Millisecond) <-time.After(10 * time.Millisecond) checker.Stop() - assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked)) + assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked)) assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked)) match, _ := regexp.MatchString("(Client.Timeout exceeded while awaiting headers)", logger.read()) assert.True(t, match) @@ -222,3 +227,11 @@ func TestLicenseTimeout(t *testing.T) { assert.True(t, match) } + +func TestURLAddVersionInfo(t *testing.T) { + licenseURL := "https://kolide.co/api/v0/licenses" + ur, err := addVersionInfo(licenseURL) + require.Nil(t, err) + want := licenseURL + "?version=unknown" + assert.Equal(t, want, ur.String(), "query params must include version") +} diff --git a/server/service/service.go b/server/service/service.go index aab68949c2..fc7c311ab7 100644 --- a/server/service/service.go +++ b/server/service/service.go @@ -13,7 +13,7 @@ import ( ) // NewService creates a new service from the config struct -func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger kitlog.Logger, kolideConfig config.KolideConfig, mailService kolide.MailService, c clock.Clock) (kolide.Service, error) { +func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger kitlog.Logger, kolideConfig config.KolideConfig, mailService kolide.MailService, c clock.Clock, checker kolide.LicenseChecker) (kolide.Service, error) { var svc kolide.Service logFile := func(path string) io.Writer { @@ -26,11 +26,12 @@ func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger } svc = service{ - ds: ds, - resultStore: resultStore, - logger: logger, - config: kolideConfig, - clock: c, + ds: ds, + resultStore: resultStore, + logger: logger, + config: kolideConfig, + clock: c, + licenseChecker: checker, osqueryStatusLogWriter: logFile(kolideConfig.Osquery.StatusLogFile), osqueryResultLogWriter: logFile(kolideConfig.Osquery.ResultLogFile), @@ -41,11 +42,12 @@ func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger } type service struct { - ds kolide.Datastore - resultStore kolide.QueryResultStore - logger kitlog.Logger - config config.KolideConfig - clock clock.Clock + ds kolide.Datastore + resultStore kolide.QueryResultStore + logger kitlog.Logger + config config.KolideConfig + clock clock.Clock + licenseChecker kolide.LicenseChecker osqueryStatusLogWriter io.Writer osqueryResultLogWriter io.Writer diff --git a/server/service/service_licenses.go b/server/service/service_licenses.go index 0abd19c773..2b17488384 100644 --- a/server/service/service_licenses.go +++ b/server/service/service_licenses.go @@ -22,5 +22,7 @@ func (svc service) SaveLicense(ctx context.Context, jwtToken string) (*kolide.Li if err != nil { return nil, err } + // schedule a checkin with the license server. + go func() { svc.licenseChecker.RunLicenseCheck(ctx) }() return updated, nil } diff --git a/server/service/util_test.go b/server/service/util_test.go index c6461dd123..1fabd85adc 100644 --- a/server/service/util_test.go +++ b/server/service/util_test.go @@ -12,12 +12,12 @@ import ( func newTestService(ds kolide.Datastore, rs kolide.QueryResultStore) (kolide.Service, error) { mailer := &mockMailService{SendEmailFn: func(e kolide.Email) error { return nil }} - return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, clock.C) + return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, clock.C, nil) } func newTestServiceWithClock(ds kolide.Datastore, rs kolide.QueryResultStore, c clock.Clock) (kolide.Service, error) { mailer := &mockMailService{SendEmailFn: func(e kolide.Email) error { return nil }} - return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, c) + return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, c, nil) } func createTestAppConfig(t *testing.T, ds kolide.Datastore) *kolide.AppConfig {