diff --git a/server/datastore/datastore_licenses_test.go b/server/datastore/datastore_licenses_test.go index 86d1aca8a8..9ed6f250fc 100644 --- a/server/datastore/datastore_licenses_test.go +++ b/server/datastore/datastore_licenses_test.go @@ -41,6 +41,12 @@ func testLicense(t *testing.T, ds kolide.Datastore) { require.NotNil(t, license.Token) assert.Equal(t, token, *license.Token) + err = ds.RevokeLicense(!license.Revoked) + require.Nil(t, err) + changedLicense, err := ds.License() + require.Nil(t, err) + assert.NotEqual(t, license.Revoked, changedLicense.Revoked) + // screw around with the token in random ways and make sure that A) it doesn't // panic and B) returns an error r := rand.New(rand.NewSource(time.Now().UnixNano())) diff --git a/server/datastore/inmem/licensure.go b/server/datastore/inmem/licensure.go index 6d407c7d65..cc178e39c9 100644 --- a/server/datastore/inmem/licensure.go +++ b/server/datastore/inmem/licensure.go @@ -13,3 +13,7 @@ func (ds *Datastore) License() (*kolide.License, error) { func (ds *Datastore) LicensePublicKey(string) (string, error) { panic("inmem is being deprecated") } + +func (ds *Datastore) RevokeLicense(revoked bool) error { + panic("inmem is being deprecated") +} diff --git a/server/datastore/mysql/licenses.go b/server/datastore/mysql/licenses.go index 991ba73f4a..913e9ed4f4 100644 --- a/server/datastore/mysql/licenses.go +++ b/server/datastore/mysql/licenses.go @@ -8,6 +8,19 @@ import ( "github.com/pkg/errors" ) +func (ds *Datastore) RevokeLicense(revoked bool) error { + sql := ` + UPDATE licenses SET + revoked = ? + WHERE id = 1 + ` + _, err := ds.db.Exec(sql, revoked) + if err != nil { + return errors.Wrap(err, "updating license revoked") + } + return nil +} + // LicensePublicKey will insure that a jwt token is signed properly and that we // have the public key we need to validate it. The public key string is returned // on success diff --git a/server/kolide/licenses.go b/server/kolide/licenses.go index 438a816eb7..45ccfe53e9 100644 --- a/server/kolide/licenses.go +++ b/server/kolide/licenses.go @@ -21,6 +21,8 @@ type LicenseStore interface { License() (*License, error) // LicensePublicKey gets the public key associated with this license LicensePublicKey(tokenString string) (string, error) + // RevokeLicense sets revoked status of license + RevokeLicense(revoked bool) error } type LicenseService interface { diff --git a/server/license-checker/checker.go b/server/license-checker/checker.go new file mode 100644 index 0000000000..0a6045494b --- /dev/null +++ b/server/license-checker/checker.go @@ -0,0 +1,184 @@ +package license + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/WatchBeam/clock" + "github.com/go-kit/kit/log" + "github.com/kolide/kolide/server/kolide" +) + +const ( + defaultPollFrequency = time.Hour + defaultHttpClientTimeout = 10 * time.Second +) + +type timer struct { + *time.Ticker +} + +func (t *timer) Chan() <-chan time.Time { + return t.C +} + +type revokeInfo struct { + UUID string `json:"uuid"` + Revoked bool `json:"revoked"` +} + +type revokeError struct { + Status int `json:"status"` + Error string `json:"error"` +} + +// Checker checks remote kolide/cloud app for license revocation +// status +type Checker struct { + ds kolide.Datastore + logger log.Logger + url string + pollFrequency time.Duration + ticker clock.Ticker + client *http.Client + finish chan struct{} +} + +type Option func(chk *Checker) + +// Logger set the logger that will be used by the Checker +func Logger(logger log.Logger) Option { + return func(chk *Checker) { + chk.logger = logger + } +} + +// HTTPClient supply your own http client +func HTTPClient(client *http.Client) Option { + return func(chk *Checker) { + chk.client = client + } +} + +func PollFrequency(freq time.Duration) Option { + ticker := &timer{ + Ticker: time.NewTicker(freq), + } + return func(chk *Checker) { + chk.ticker = ticker + } +} + +// NewChecker instantiates a service that will check periodically to see if a license +// is revoked. licenseEndpointURL is the root url for kolide/cloud server. For example +// https://cloud.kolide.co/api/v0/licenses +// You may optionally set a logger, and/or supply a polling frequency that defines +// how often we check for revocation. +func NewChecker(ds kolide.Datastore, licenseEndpointURL string, opts ...Option) *Checker { + defaultTicker := &timer{ + Ticker: time.NewTicker(defaultPollFrequency), + } + response := &Checker{ + logger: log.NewNopLogger(), + ds: ds, + client: &http.Client{Timeout: defaultHttpClientTimeout}, + url: licenseEndpointURL, + ticker: defaultTicker, + finish: make(chan struct{}), + } + for _, o := range opts { + o(response) + } + + response.logger = log.NewContext(response.logger).With("component", "license-checker") + return response +} + +var wait sync.WaitGroup + +// Start begins checking for license revocation. Note that start can only +// be called once. If Stop is called you must create a new checker to use +// it again. +func (cc *Checker) Start() error { + if cc.finish == nil { + return errors.New("start called on stopped checker") + } + // pass in copy of receiver to avoid race conditions + go func(chk Checker, wait *sync.WaitGroup) { + wait.Add(1) + defer wait.Done() + chk.logger.Log("msg", "starting") + for { + select { + case <-chk.finish: + chk.logger.Log("msg", "finishing") + return + case <-chk.ticker.Chan(): + updateLicenseRevocation(&chk) + } + } + }(*cc, &wait) + + return nil +} + +// Stop ends checking for license revocation. +func (cc *Checker) Stop() { + cc.ticker.Stop() + close(cc.finish) + wait.Wait() + cc.finish = nil +} + +func updateLicenseRevocation(chk *Checker) { + chk.logger.Log("msg", "begin license check") + defer chk.logger.Log("msg", "ending license check") + + license, err := chk.ds.License() + if err != nil { + chk.logger.Log("msg", "couldn't fetch license", "err", err) + return + } + claims, err := license.Claims() + if err != nil { + chk.logger.Log("msg", "fetching claims", "err", err) + return + } + url := fmt.Sprintf("%s/%s", chk.url, claims.LicenseUUID) + resp, err := chk.client.Get(url) + if err != nil { + chk.logger.Log("msg", fmt.Sprintf("fetching %s", url), "err", err) + return + } + defer resp.Body.Close() + switch resp.StatusCode { + case http.StatusOK: + var revInfo revokeInfo + err = json.NewDecoder(resp.Body).Decode(&revInfo) + if err != nil { + chk.logger.Log("msg", "decoding response", "err", err) + return + } + err = chk.ds.RevokeLicense(revInfo.Revoked) + if err != nil { + chk.logger.Log("msg", "revoke status", "err", err) + return + } + // success + chk.logger.Log("msg", fmt.Sprintf("license revocation status retrieved succesfully, revoked: %t", revInfo.Revoked)) + case http.StatusNotFound: + var revInfo revokeError + err = json.NewDecoder(resp.Body).Decode(&revInfo) + if err != nil { + chk.logger.Log("msg", "decoding response", "err", err) + return + } + chk.logger.Log("msg", "host response", "err", fmt.Sprintf("status: %d error: %s", revInfo.Status, revInfo.Error)) + default: + chk.logger.Log("msg", "host response", "err", fmt.Sprintf("unexpected response status from host, status %s", resp.Status)) + } +} diff --git a/server/license-checker/checker_test.go b/server/license-checker/checker_test.go new file mode 100644 index 0000000000..244d1a712e --- /dev/null +++ b/server/license-checker/checker_test.go @@ -0,0 +1,224 @@ +package license + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "regexp" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/WatchBeam/clock" + "github.com/kolide/kolide/server/kolide" + "github.com/kolide/kolide/server/mock" + "github.com/stretchr/testify/assert" +) + +var tokenString = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IjRkOmM1OmRlOmE1Oj" + + "czOmUxOmE4OjI4OmU2OmEyOjMwOmI4OmI1OjBmOjg4OjQ0In0.eyJsaWNlbnNlX3V1aWQiOiIyZD" + + "gwMmEyYS1hZjRjLTQ5ZjItYWRlNC0zOGJmNjBmMmQxZjYiLCJvcmdhbml6YXRpb25fbmFtZSI6Il" + + "BoYW50YXNtLCBJbmMuIiwib3JnYW5pemF0aW9uX3V1aWQiOiI5ZmFiNjdiMy0wZWFjLTRhODMtOTI" + + "wNS04MjkyMWIwNDJmODYiLCJob3N0X2xpbWl0IjowLCJldmFsdWF0aW9uIjp0cnVlLCJleHBpcmV" + + "zX2F0IjoiMjAxNy0wMy0wNFQxNDozODo1NyswMDowMCJ9.DRFQIUDFXT0bDdya0IJKvATKCJjv3Mv" + + "w5gMxHNzby_L80muoe-36DoRxBAJZHL7dOfQDU8NRK2Mt64ozThrhWVl8wJlD9mk5ABe3tNw3LJRl" + + "2mHvOLmk37_AIHp5AEKZ6cWMPa9zf8hWf6bAv_0rOJf5wgyE81pfqRFtO0OnkGO3WLcP66L0AIntq" + + "IzAE_vWmizcUvUOCWDqwcBlT-P1mZnWJFCaSBpmpQoi3KEKJDx0wMjLiRNLX9R9dr3v3ojccoYuxR" + + "qAws-OHv3VzcuGdn3Pt9WBDr4cXdtqxaGtxJb6-BDvp8QQk69ACZXrZJ8NhZAL0EVlviRRw8bbEYchZQ" + +var publicKey = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0ZhY7r6HmifXPtServt4 +D3MSi8Awe9u132vLf8yzlknvnq+8CSnOPSSbCD+HajvZ6dnNJXjdcAhuZ32ShrH8 +rEQACEUS8Mh4z8Mo5Nlq1ou0s2JzWCx049kA34jP3u6AiPgpWUf8JRGstTlisxMn +H6B7miDs1038gVbN5rk+j+3ALYzllaTnCX3Y0C7f6IW7BjNO/tvFB84/95xfOLEz +o2MeFMqkD29hvcrUW+8+fQGJaVLvcEqBDnIEVbCCk8Wnoi48dUE06WHUl6voJecD +dW1E6jHcq8PQFK+4bI1gKZVbV4dFGSSMUyD7ov77aWHjxdQe6YEGcSXKzfyMaUtQ +vQIDAQAB +-----END PUBLIC KEY----- +` + +func mockTicker(ticker clock.Ticker) Option { + return func(chk *Checker) { + chk.ticker = ticker + } +} + +func TestLicenseFound(t *testing.T) { + var licFunInvoked int64 + var revokeFunInvoked int64 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := revokeInfo{ + UUID: "DEADBEEF", + Revoked: true, + } + json.NewEncoder(w).Encode(response) + + })) + defer ts.Close() + + ds := new(mock.Store) + ds.LicenseFunc = func() (*kolide.License, error) { + atomic.AddInt64(&licFunInvoked, 1) + result := &kolide.License{ + UpdateTimestamp: kolide.UpdateTimestamp{ + UpdatedAt: time.Now().Add(-5 * time.Minute), + }, + Token: &tokenString, + PublicKey: publicKey, + Revoked: false, + ID: 1, + } + return result, nil + } + ds.RevokeLicenseFunc = func(revoked bool) error { + atomic.AddInt64(&revokeFunInvoked, 1) + return nil + } + c := clock.NewMockClock() + checker := NewChecker(ds, ts.URL, + mockTicker(c.NewTicker(time.Millisecond)), + ) + checker.Start() + <-time.After(10 * time.Millisecond) + c.AddTime(time.Millisecond) + c.AddTime(time.Millisecond) + <-time.After(10 * time.Millisecond) + checker.Stop() + + // verify muliple checks occurred, we have to use atomic because if we + // use the flags from the mock package to indicate function invocation race detector will + // complain + + assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked)) + assert.Equal(t, int64(2), atomic.LoadInt64(&revokeFunInvoked)) +} + +func TestLicenseNotFound(t *testing.T) { + var licFunInvoked int64 + var revokeFunInvoked int64 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + response := revokeError{ + Status: 404, + Error: "not found", + } + json.NewEncoder(w).Encode(response) + + })) + defer ts.Close() + + ds := new(mock.Store) + ds.LicenseFunc = func() (*kolide.License, error) { + atomic.AddInt64(&licFunInvoked, 1) + result := &kolide.License{ + UpdateTimestamp: kolide.UpdateTimestamp{ + UpdatedAt: time.Now().Add(-5 * time.Minute), + }, + Token: &tokenString, + PublicKey: publicKey, + Revoked: false, + ID: 1, + } + return result, nil + } + ds.RevokeLicenseFunc = func(revoked bool) error { + atomic.AddInt64(&revokeFunInvoked, 1) + return nil + } + + c := clock.NewMockClock() + checker := NewChecker(ds, ts.URL, + mockTicker(c.NewTicker(time.Millisecond)), + ) + checker.Start() + <-time.After(10 * time.Millisecond) + c.AddTime(time.Millisecond) + <-time.After(10 * time.Millisecond) + checker.Stop() + + assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked)) + assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked)) +} + +type testLogger struct { + logContent string + lock sync.Mutex +} + +func (tl *testLogger) Log(keyVals ...interface{}) error { + tl.lock.Lock() + defer tl.lock.Unlock() + tl.logContent += fmt.Sprint(keyVals...) + return nil +} + +func (tl *testLogger) read() string { + var buff []byte + tl.lock.Lock() + buff = make([]byte, len(tl.logContent)) + copy(buff, tl.logContent) + tl.lock.Unlock() + return string(buff) +} + +func TestLicenseTimeout(t *testing.T) { + var licFunInvoked int64 + var revokeFunInvoked int64 + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-time.After(300 * time.Millisecond) + response := revokeInfo{ + UUID: "DEADBEEF", + Revoked: true, + } + json.NewEncoder(w).Encode(response) + })) + defer ts.Close() + + ds := new(mock.Store) + ds.LicenseFunc = func() (*kolide.License, error) { + atomic.AddInt64(&licFunInvoked, 1) + result := &kolide.License{ + UpdateTimestamp: kolide.UpdateTimestamp{ + UpdatedAt: time.Now().Add(-5 * time.Minute), + }, + Token: &tokenString, + PublicKey: publicKey, + Revoked: false, + ID: 1, + } + return result, nil + } + ds.RevokeLicenseFunc = func(revoked bool) error { + atomic.AddInt64(&revokeFunInvoked, 1) + return nil + } + + // inject our custom logger so we can get log without breaking race + // detection + logger := &testLogger{} + c := clock.NewMockClock() + + checker := NewChecker(ds, ts.URL, + mockTicker(c.NewTicker(time.Millisecond)), + HTTPClient(&http.Client{Timeout: 2 * time.Millisecond}), + Logger(logger), + ) + checker.Start() + <-time.After(10 * time.Millisecond) + c.AddTime(time.Millisecond) + <-time.After(10 * time.Millisecond) + checker.Stop() + + assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked)) + assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked)) + match, _ := regexp.MatchString("(Client.Timeout exceeded while awaiting headers)", logger.read()) + assert.True(t, match) + // check to make sure things cleanly shut down. + match, _ = regexp.MatchString("finishing", logger.read()) + assert.True(t, match) + +} diff --git a/server/mock/datastore_licenses.go b/server/mock/datastore_licenses.go index ab70a41f9c..fb94a47200 100644 --- a/server/mock/datastore_licenses.go +++ b/server/mock/datastore_licenses.go @@ -12,6 +12,8 @@ type LicenseFunc func() (*kolide.License, error) type LicensePublicKeyFunc func(tokenString string) (string, error) +type RevokeLicenseFunc func(revoked bool) error + type LicenseStore struct { SaveLicenseFunc SaveLicenseFunc SaveLicenseFuncInvoked bool @@ -21,6 +23,9 @@ type LicenseStore struct { LicensePublicKeyFunc LicensePublicKeyFunc LicensePublicKeyFuncInvoked bool + + RevokeLicenseFunc RevokeLicenseFunc + RevokeLicenseFuncInvoked bool } func (s *LicenseStore) SaveLicense(tokenString string, publicKey string) (*kolide.License, error) { @@ -37,3 +42,8 @@ func (s *LicenseStore) LicensePublicKey(tokenString string) (string, error) { s.LicensePublicKeyFuncInvoked = true return s.LicensePublicKeyFunc(tokenString) } + +func (s *LicenseStore) RevokeLicense(revoked bool) error { + s.RevokeLicenseFuncInvoked = true + return s.RevokeLicenseFunc(revoked) +}