mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 01:18:42 +00:00
Allow checking in to license server when license is saved. (#1299)
* Allow checking in to license server when license is saved. Closes #1290 Closes #1277
This commit is contained in:
parent
d10cb6e725
commit
897cb35e24
8 changed files with 86 additions and 38 deletions
|
|
@ -77,7 +77,7 @@ To setup kolide infrastructure, use one of the available commands.
|
|||
Enabled: &enabled,
|
||||
Admin: &isAdmin,
|
||||
}
|
||||
svc, err := service.NewService(ds, pubsub.NewInmemQueryResults(), kitlog.NewNopLogger(), config, nil, clock.C)
|
||||
svc, err := service.NewService(ds, pubsub.NewInmemQueryResults(), kitlog.NewNopLogger(), config, nil, clock.C, nil)
|
||||
if err != nil {
|
||||
initFatal(err, "creating service")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ the way that the kolide server works.
|
|||
redisPool := pubsub.NewRedisPool(config.Redis.Address, config.Redis.Password)
|
||||
resultStore = pubsub.NewRedisQueryResults(redisPool)
|
||||
|
||||
svc, err := service.NewService(ds, resultStore, logger, config, mailService, clock.C)
|
||||
svc, err := service.NewService(ds, resultStore, logger, config, mailService, clock.C, licenseService)
|
||||
if err != nil {
|
||||
initFatal(err, "initializing service")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -125,3 +125,9 @@ func (l *License) Claims() (*Claims, error) {
|
|||
result.HostCount = int(l.HostCount)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// LicenseChecker allows checking that a license is valid by calling in to
|
||||
// a remote URL.
|
||||
type LicenseChecker interface {
|
||||
RunLicenseCheck(ctx context.Context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,15 +2,18 @@ package license
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/kolide/kolide/server/kolide"
|
||||
"github.com/kolide/kolide/server/version"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -122,7 +125,7 @@ func (cc *Checker) Start() error {
|
|||
chk.logger.Log("msg", "finishing")
|
||||
return
|
||||
case <-chk.ticker.Chan():
|
||||
updateLicenseRevocation(&chk)
|
||||
chk.RunLicenseCheck(context.Background())
|
||||
}
|
||||
}
|
||||
}(*cc, &wait)
|
||||
|
|
@ -138,51 +141,73 @@ func (cc *Checker) Stop() {
|
|||
cc.finish = nil
|
||||
}
|
||||
|
||||
func updateLicenseRevocation(chk *Checker) {
|
||||
chk.logger.Log("msg", "begin license check")
|
||||
defer chk.logger.Log("msg", "ending license check")
|
||||
|
||||
license, err := chk.ds.License()
|
||||
// addVersionInfo parses the license URL and adds the current revision of the
|
||||
// kolide binary to the query params. The reported revision is set using
|
||||
// ldflags by the make command, otherwise defaults to 'unknown'.
|
||||
func addVersionInfo(licenseURL string) (*url.URL, error) {
|
||||
ur, err := url.Parse(licenseURL)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "couldn't fetch license", "err", err)
|
||||
return nil, errors.Wrapf(err, "license checker failed to parse URL string %q", licenseURL)
|
||||
}
|
||||
revision := version.Version().Revision
|
||||
q := ur.Query()
|
||||
q.Set("version", revision)
|
||||
ur.RawQuery = q.Encode()
|
||||
return ur, nil
|
||||
}
|
||||
|
||||
func (cc *Checker) RunLicenseCheck(ctx context.Context) {
|
||||
cc.logger.Log("msg", "begin license check")
|
||||
defer cc.logger.Log("msg", "ending license check")
|
||||
|
||||
license, err := cc.ds.License()
|
||||
if err != nil {
|
||||
cc.logger.Log("msg", "couldn't fetch license", "err", err)
|
||||
return
|
||||
}
|
||||
claims, err := license.Claims()
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "fetching claims", "err", err)
|
||||
cc.logger.Log("msg", "fetching claims", "err", err)
|
||||
return
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", chk.url, claims.LicenseUUID)
|
||||
resp, err := chk.client.Get(url)
|
||||
|
||||
licenseURL, err := addVersionInfo(fmt.Sprintf("%s/%s", cc.url, claims.LicenseUUID))
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", fmt.Sprintf("fetching %s", url), "err", err)
|
||||
cc.logger.Log("msg", "adding version information to license", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := cc.client.Get(licenseURL.String())
|
||||
if err != nil {
|
||||
cc.logger.Log("msg", fmt.Sprintf("fetching %s", licenseURL.String()), "err", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
var revInfo revokeInfo
|
||||
err = json.NewDecoder(resp.Body).Decode(&revInfo)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "decoding response", "err", err)
|
||||
cc.logger.Log("msg", "decoding response", "err", err)
|
||||
return
|
||||
}
|
||||
err = chk.ds.RevokeLicense(revInfo.Revoked)
|
||||
err = cc.ds.RevokeLicense(revInfo.Revoked)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "revoke status", "err", err)
|
||||
cc.logger.Log("msg", "revoke status", "err", err)
|
||||
return
|
||||
}
|
||||
// success
|
||||
chk.logger.Log("msg", fmt.Sprintf("license revocation status retrieved succesfully, revoked: %t", revInfo.Revoked))
|
||||
cc.logger.Log("msg", fmt.Sprintf("license revocation status retrieved succesfully, revoked: %t", revInfo.Revoked))
|
||||
case http.StatusNotFound:
|
||||
var revInfo revokeError
|
||||
err = json.NewDecoder(resp.Body).Decode(&revInfo)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "decoding response", "err", err)
|
||||
cc.logger.Log("msg", "decoding response", "err", err)
|
||||
return
|
||||
}
|
||||
chk.logger.Log("msg", "host response", "err", fmt.Sprintf("status: %d error: %s", revInfo.Status, revInfo.Error))
|
||||
cc.logger.Log("msg", "host response", "err", fmt.Sprintf("status: %d error: %s", revInfo.Status, revInfo.Error))
|
||||
default:
|
||||
chk.logger.Log("msg", "host response", "err", fmt.Sprintf("unexpected response status from host, status %s", resp.Status))
|
||||
cc.logger.Log("msg", "host response", "err", fmt.Sprintf("unexpected response status from host, status %s", resp.Status))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ import (
|
|||
"github.com/kolide/kolide/server/kolide"
|
||||
"github.com/kolide/kolide/server/mock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
var tokenString = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IjRkOmM1OmRlOmE1Oj" +
|
||||
|
|
@ -82,6 +84,7 @@ func TestLicenseFound(t *testing.T) {
|
|||
mockTicker(c.NewTicker(time.Millisecond)),
|
||||
)
|
||||
checker.Start()
|
||||
checker.RunLicenseCheck(context.Background())
|
||||
<-time.After(10 * time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
|
|
@ -92,8 +95,8 @@ func TestLicenseFound(t *testing.T) {
|
|||
// use the flags from the mock package to indicate function invocation race detector will
|
||||
// complain
|
||||
|
||||
assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(2), atomic.LoadInt64(&revokeFunInvoked))
|
||||
assert.Equal(t, int64(3), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(3), atomic.LoadInt64(&revokeFunInvoked))
|
||||
}
|
||||
|
||||
func TestLicenseNotFound(t *testing.T) {
|
||||
|
|
@ -134,12 +137,13 @@ func TestLicenseNotFound(t *testing.T) {
|
|||
mockTicker(c.NewTicker(time.Millisecond)),
|
||||
)
|
||||
checker.Start()
|
||||
checker.RunLicenseCheck(context.Background())
|
||||
<-time.After(10 * time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
<-time.After(10 * time.Millisecond)
|
||||
checker.Stop()
|
||||
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked))
|
||||
}
|
||||
|
||||
|
|
@ -208,12 +212,13 @@ func TestLicenseTimeout(t *testing.T) {
|
|||
Logger(logger),
|
||||
)
|
||||
checker.Start()
|
||||
checker.RunLicenseCheck(context.Background())
|
||||
<-time.After(10 * time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
<-time.After(10 * time.Millisecond)
|
||||
checker.Stop()
|
||||
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked))
|
||||
match, _ := regexp.MatchString("(Client.Timeout exceeded while awaiting headers)", logger.read())
|
||||
assert.True(t, match)
|
||||
|
|
@ -222,3 +227,11 @@ func TestLicenseTimeout(t *testing.T) {
|
|||
assert.True(t, match)
|
||||
|
||||
}
|
||||
|
||||
func TestURLAddVersionInfo(t *testing.T) {
|
||||
licenseURL := "https://kolide.co/api/v0/licenses"
|
||||
ur, err := addVersionInfo(licenseURL)
|
||||
require.Nil(t, err)
|
||||
want := licenseURL + "?version=unknown"
|
||||
assert.Equal(t, want, ur.String(), "query params must include version")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
// NewService creates a new service from the config struct
|
||||
func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger kitlog.Logger, kolideConfig config.KolideConfig, mailService kolide.MailService, c clock.Clock) (kolide.Service, error) {
|
||||
func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger kitlog.Logger, kolideConfig config.KolideConfig, mailService kolide.MailService, c clock.Clock, checker kolide.LicenseChecker) (kolide.Service, error) {
|
||||
var svc kolide.Service
|
||||
|
||||
logFile := func(path string) io.Writer {
|
||||
|
|
@ -26,11 +26,12 @@ func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger
|
|||
}
|
||||
|
||||
svc = service{
|
||||
ds: ds,
|
||||
resultStore: resultStore,
|
||||
logger: logger,
|
||||
config: kolideConfig,
|
||||
clock: c,
|
||||
ds: ds,
|
||||
resultStore: resultStore,
|
||||
logger: logger,
|
||||
config: kolideConfig,
|
||||
clock: c,
|
||||
licenseChecker: checker,
|
||||
|
||||
osqueryStatusLogWriter: logFile(kolideConfig.Osquery.StatusLogFile),
|
||||
osqueryResultLogWriter: logFile(kolideConfig.Osquery.ResultLogFile),
|
||||
|
|
@ -41,11 +42,12 @@ func NewService(ds kolide.Datastore, resultStore kolide.QueryResultStore, logger
|
|||
}
|
||||
|
||||
type service struct {
|
||||
ds kolide.Datastore
|
||||
resultStore kolide.QueryResultStore
|
||||
logger kitlog.Logger
|
||||
config config.KolideConfig
|
||||
clock clock.Clock
|
||||
ds kolide.Datastore
|
||||
resultStore kolide.QueryResultStore
|
||||
logger kitlog.Logger
|
||||
config config.KolideConfig
|
||||
clock clock.Clock
|
||||
licenseChecker kolide.LicenseChecker
|
||||
|
||||
osqueryStatusLogWriter io.Writer
|
||||
osqueryResultLogWriter io.Writer
|
||||
|
|
|
|||
|
|
@ -22,5 +22,7 @@ func (svc service) SaveLicense(ctx context.Context, jwtToken string) (*kolide.Li
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// schedule a checkin with the license server.
|
||||
go func() { svc.licenseChecker.RunLicenseCheck(ctx) }()
|
||||
return updated, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@ import (
|
|||
|
||||
func newTestService(ds kolide.Datastore, rs kolide.QueryResultStore) (kolide.Service, error) {
|
||||
mailer := &mockMailService{SendEmailFn: func(e kolide.Email) error { return nil }}
|
||||
return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, clock.C)
|
||||
return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, clock.C, nil)
|
||||
}
|
||||
|
||||
func newTestServiceWithClock(ds kolide.Datastore, rs kolide.QueryResultStore, c clock.Clock) (kolide.Service, error) {
|
||||
mailer := &mockMailService{SendEmailFn: func(e kolide.Email) error { return nil }}
|
||||
return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, c)
|
||||
return NewService(ds, rs, kitlog.NewNopLogger(), config.TestConfig(), mailer, c, nil)
|
||||
}
|
||||
|
||||
func createTestAppConfig(t *testing.T, ds kolide.Datastore) *kolide.AppConfig {
|
||||
|
|
|
|||
Loading…
Reference in a new issue