mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 17:08:53 +00:00
License revocation checker (#1170)
This commit is contained in:
parent
a74063c1d1
commit
e9c4760979
7 changed files with 443 additions and 0 deletions
|
|
@ -41,6 +41,12 @@ func testLicense(t *testing.T, ds kolide.Datastore) {
|
|||
require.NotNil(t, license.Token)
|
||||
assert.Equal(t, token, *license.Token)
|
||||
|
||||
err = ds.RevokeLicense(!license.Revoked)
|
||||
require.Nil(t, err)
|
||||
changedLicense, err := ds.License()
|
||||
require.Nil(t, err)
|
||||
assert.NotEqual(t, license.Revoked, changedLicense.Revoked)
|
||||
|
||||
// screw around with the token in random ways and make sure that A) it doesn't
|
||||
// panic and B) returns an error
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
|
|
|||
|
|
@ -13,3 +13,7 @@ func (ds *Datastore) License() (*kolide.License, error) {
|
|||
func (ds *Datastore) LicensePublicKey(string) (string, error) {
|
||||
panic("inmem is being deprecated")
|
||||
}
|
||||
|
||||
func (ds *Datastore) RevokeLicense(revoked bool) error {
|
||||
panic("inmem is being deprecated")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,19 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (ds *Datastore) RevokeLicense(revoked bool) error {
|
||||
sql := `
|
||||
UPDATE licenses SET
|
||||
revoked = ?
|
||||
WHERE id = 1
|
||||
`
|
||||
_, err := ds.db.Exec(sql, revoked)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "updating license revoked")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LicensePublicKey will insure that a jwt token is signed properly and that we
|
||||
// have the public key we need to validate it. The public key string is returned
|
||||
// on success
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ type LicenseStore interface {
|
|||
License() (*License, error)
|
||||
// LicensePublicKey gets the public key associated with this license
|
||||
LicensePublicKey(tokenString string) (string, error)
|
||||
// RevokeLicense sets revoked status of license
|
||||
RevokeLicense(revoked bool) error
|
||||
}
|
||||
|
||||
type LicenseService interface {
|
||||
|
|
|
|||
184
server/license-checker/checker.go
Normal file
184
server/license-checker/checker.go
Normal file
|
|
@ -0,0 +1,184 @@
|
|||
package license
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/go-kit/kit/log"
|
||||
"github.com/kolide/kolide/server/kolide"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPollFrequency = time.Hour
|
||||
defaultHttpClientTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type timer struct {
|
||||
*time.Ticker
|
||||
}
|
||||
|
||||
func (t *timer) Chan() <-chan time.Time {
|
||||
return t.C
|
||||
}
|
||||
|
||||
type revokeInfo struct {
|
||||
UUID string `json:"uuid"`
|
||||
Revoked bool `json:"revoked"`
|
||||
}
|
||||
|
||||
type revokeError struct {
|
||||
Status int `json:"status"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// Checker checks remote kolide/cloud app for license revocation
|
||||
// status
|
||||
type Checker struct {
|
||||
ds kolide.Datastore
|
||||
logger log.Logger
|
||||
url string
|
||||
pollFrequency time.Duration
|
||||
ticker clock.Ticker
|
||||
client *http.Client
|
||||
finish chan struct{}
|
||||
}
|
||||
|
||||
type Option func(chk *Checker)
|
||||
|
||||
// Logger set the logger that will be used by the Checker
|
||||
func Logger(logger log.Logger) Option {
|
||||
return func(chk *Checker) {
|
||||
chk.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPClient supply your own http client
|
||||
func HTTPClient(client *http.Client) Option {
|
||||
return func(chk *Checker) {
|
||||
chk.client = client
|
||||
}
|
||||
}
|
||||
|
||||
func PollFrequency(freq time.Duration) Option {
|
||||
ticker := &timer{
|
||||
Ticker: time.NewTicker(freq),
|
||||
}
|
||||
return func(chk *Checker) {
|
||||
chk.ticker = ticker
|
||||
}
|
||||
}
|
||||
|
||||
// NewChecker instantiates a service that will check periodically to see if a license
|
||||
// is revoked. licenseEndpointURL is the root url for kolide/cloud server. For example
|
||||
// https://cloud.kolide.co/api/v0/licenses
|
||||
// You may optionally set a logger, and/or supply a polling frequency that defines
|
||||
// how often we check for revocation.
|
||||
func NewChecker(ds kolide.Datastore, licenseEndpointURL string, opts ...Option) *Checker {
|
||||
defaultTicker := &timer{
|
||||
Ticker: time.NewTicker(defaultPollFrequency),
|
||||
}
|
||||
response := &Checker{
|
||||
logger: log.NewNopLogger(),
|
||||
ds: ds,
|
||||
client: &http.Client{Timeout: defaultHttpClientTimeout},
|
||||
url: licenseEndpointURL,
|
||||
ticker: defaultTicker,
|
||||
finish: make(chan struct{}),
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(response)
|
||||
}
|
||||
|
||||
response.logger = log.NewContext(response.logger).With("component", "license-checker")
|
||||
return response
|
||||
}
|
||||
|
||||
var wait sync.WaitGroup
|
||||
|
||||
// Start begins checking for license revocation. Note that start can only
|
||||
// be called once. If Stop is called you must create a new checker to use
|
||||
// it again.
|
||||
func (cc *Checker) Start() error {
|
||||
if cc.finish == nil {
|
||||
return errors.New("start called on stopped checker")
|
||||
}
|
||||
// pass in copy of receiver to avoid race conditions
|
||||
go func(chk Checker, wait *sync.WaitGroup) {
|
||||
wait.Add(1)
|
||||
defer wait.Done()
|
||||
chk.logger.Log("msg", "starting")
|
||||
for {
|
||||
select {
|
||||
case <-chk.finish:
|
||||
chk.logger.Log("msg", "finishing")
|
||||
return
|
||||
case <-chk.ticker.Chan():
|
||||
updateLicenseRevocation(&chk)
|
||||
}
|
||||
}
|
||||
}(*cc, &wait)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop ends checking for license revocation.
|
||||
func (cc *Checker) Stop() {
|
||||
cc.ticker.Stop()
|
||||
close(cc.finish)
|
||||
wait.Wait()
|
||||
cc.finish = nil
|
||||
}
|
||||
|
||||
func updateLicenseRevocation(chk *Checker) {
|
||||
chk.logger.Log("msg", "begin license check")
|
||||
defer chk.logger.Log("msg", "ending license check")
|
||||
|
||||
license, err := chk.ds.License()
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "couldn't fetch license", "err", err)
|
||||
return
|
||||
}
|
||||
claims, err := license.Claims()
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "fetching claims", "err", err)
|
||||
return
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s", chk.url, claims.LicenseUUID)
|
||||
resp, err := chk.client.Get(url)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", fmt.Sprintf("fetching %s", url), "err", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
var revInfo revokeInfo
|
||||
err = json.NewDecoder(resp.Body).Decode(&revInfo)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "decoding response", "err", err)
|
||||
return
|
||||
}
|
||||
err = chk.ds.RevokeLicense(revInfo.Revoked)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "revoke status", "err", err)
|
||||
return
|
||||
}
|
||||
// success
|
||||
chk.logger.Log("msg", fmt.Sprintf("license revocation status retrieved succesfully, revoked: %t", revInfo.Revoked))
|
||||
case http.StatusNotFound:
|
||||
var revInfo revokeError
|
||||
err = json.NewDecoder(resp.Body).Decode(&revInfo)
|
||||
if err != nil {
|
||||
chk.logger.Log("msg", "decoding response", "err", err)
|
||||
return
|
||||
}
|
||||
chk.logger.Log("msg", "host response", "err", fmt.Sprintf("status: %d error: %s", revInfo.Status, revInfo.Error))
|
||||
default:
|
||||
chk.logger.Log("msg", "host response", "err", fmt.Sprintf("unexpected response status from host, status %s", resp.Status))
|
||||
}
|
||||
}
|
||||
224
server/license-checker/checker_test.go
Normal file
224
server/license-checker/checker_test.go
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
package license
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/WatchBeam/clock"
|
||||
"github.com/kolide/kolide/server/kolide"
|
||||
"github.com/kolide/kolide/server/mock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var tokenString = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IjRkOmM1OmRlOmE1Oj" +
|
||||
"czOmUxOmE4OjI4OmU2OmEyOjMwOmI4OmI1OjBmOjg4OjQ0In0.eyJsaWNlbnNlX3V1aWQiOiIyZD" +
|
||||
"gwMmEyYS1hZjRjLTQ5ZjItYWRlNC0zOGJmNjBmMmQxZjYiLCJvcmdhbml6YXRpb25fbmFtZSI6Il" +
|
||||
"BoYW50YXNtLCBJbmMuIiwib3JnYW5pemF0aW9uX3V1aWQiOiI5ZmFiNjdiMy0wZWFjLTRhODMtOTI" +
|
||||
"wNS04MjkyMWIwNDJmODYiLCJob3N0X2xpbWl0IjowLCJldmFsdWF0aW9uIjp0cnVlLCJleHBpcmV" +
|
||||
"zX2F0IjoiMjAxNy0wMy0wNFQxNDozODo1NyswMDowMCJ9.DRFQIUDFXT0bDdya0IJKvATKCJjv3Mv" +
|
||||
"w5gMxHNzby_L80muoe-36DoRxBAJZHL7dOfQDU8NRK2Mt64ozThrhWVl8wJlD9mk5ABe3tNw3LJRl" +
|
||||
"2mHvOLmk37_AIHp5AEKZ6cWMPa9zf8hWf6bAv_0rOJf5wgyE81pfqRFtO0OnkGO3WLcP66L0AIntq" +
|
||||
"IzAE_vWmizcUvUOCWDqwcBlT-P1mZnWJFCaSBpmpQoi3KEKJDx0wMjLiRNLX9R9dr3v3ojccoYuxR" +
|
||||
"qAws-OHv3VzcuGdn3Pt9WBDr4cXdtqxaGtxJb6-BDvp8QQk69ACZXrZJ8NhZAL0EVlviRRw8bbEYchZQ"
|
||||
|
||||
var publicKey = `-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0ZhY7r6HmifXPtServt4
|
||||
D3MSi8Awe9u132vLf8yzlknvnq+8CSnOPSSbCD+HajvZ6dnNJXjdcAhuZ32ShrH8
|
||||
rEQACEUS8Mh4z8Mo5Nlq1ou0s2JzWCx049kA34jP3u6AiPgpWUf8JRGstTlisxMn
|
||||
H6B7miDs1038gVbN5rk+j+3ALYzllaTnCX3Y0C7f6IW7BjNO/tvFB84/95xfOLEz
|
||||
o2MeFMqkD29hvcrUW+8+fQGJaVLvcEqBDnIEVbCCk8Wnoi48dUE06WHUl6voJecD
|
||||
dW1E6jHcq8PQFK+4bI1gKZVbV4dFGSSMUyD7ov77aWHjxdQe6YEGcSXKzfyMaUtQ
|
||||
vQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
`
|
||||
|
||||
func mockTicker(ticker clock.Ticker) Option {
|
||||
return func(chk *Checker) {
|
||||
chk.ticker = ticker
|
||||
}
|
||||
}
|
||||
|
||||
func TestLicenseFound(t *testing.T) {
|
||||
var licFunInvoked int64
|
||||
var revokeFunInvoked int64
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := revokeInfo{
|
||||
UUID: "DEADBEEF",
|
||||
Revoked: true,
|
||||
}
|
||||
json.NewEncoder(w).Encode(response)
|
||||
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
ds := new(mock.Store)
|
||||
ds.LicenseFunc = func() (*kolide.License, error) {
|
||||
atomic.AddInt64(&licFunInvoked, 1)
|
||||
result := &kolide.License{
|
||||
UpdateTimestamp: kolide.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
Token: &tokenString,
|
||||
PublicKey: publicKey,
|
||||
Revoked: false,
|
||||
ID: 1,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
ds.RevokeLicenseFunc = func(revoked bool) error {
|
||||
atomic.AddInt64(&revokeFunInvoked, 1)
|
||||
return nil
|
||||
}
|
||||
c := clock.NewMockClock()
|
||||
checker := NewChecker(ds, ts.URL,
|
||||
mockTicker(c.NewTicker(time.Millisecond)),
|
||||
)
|
||||
checker.Start()
|
||||
<-time.After(10 * time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
<-time.After(10 * time.Millisecond)
|
||||
checker.Stop()
|
||||
|
||||
// verify muliple checks occurred, we have to use atomic because if we
|
||||
// use the flags from the mock package to indicate function invocation race detector will
|
||||
// complain
|
||||
|
||||
assert.Equal(t, int64(2), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(2), atomic.LoadInt64(&revokeFunInvoked))
|
||||
}
|
||||
|
||||
func TestLicenseNotFound(t *testing.T) {
|
||||
var licFunInvoked int64
|
||||
var revokeFunInvoked int64
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
response := revokeError{
|
||||
Status: 404,
|
||||
Error: "not found",
|
||||
}
|
||||
json.NewEncoder(w).Encode(response)
|
||||
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
ds := new(mock.Store)
|
||||
ds.LicenseFunc = func() (*kolide.License, error) {
|
||||
atomic.AddInt64(&licFunInvoked, 1)
|
||||
result := &kolide.License{
|
||||
UpdateTimestamp: kolide.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
Token: &tokenString,
|
||||
PublicKey: publicKey,
|
||||
Revoked: false,
|
||||
ID: 1,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
ds.RevokeLicenseFunc = func(revoked bool) error {
|
||||
atomic.AddInt64(&revokeFunInvoked, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
c := clock.NewMockClock()
|
||||
checker := NewChecker(ds, ts.URL,
|
||||
mockTicker(c.NewTicker(time.Millisecond)),
|
||||
)
|
||||
checker.Start()
|
||||
<-time.After(10 * time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
<-time.After(10 * time.Millisecond)
|
||||
checker.Stop()
|
||||
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked))
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
logContent string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (tl *testLogger) Log(keyVals ...interface{}) error {
|
||||
tl.lock.Lock()
|
||||
defer tl.lock.Unlock()
|
||||
tl.logContent += fmt.Sprint(keyVals...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tl *testLogger) read() string {
|
||||
var buff []byte
|
||||
tl.lock.Lock()
|
||||
buff = make([]byte, len(tl.logContent))
|
||||
copy(buff, tl.logContent)
|
||||
tl.lock.Unlock()
|
||||
return string(buff)
|
||||
}
|
||||
|
||||
func TestLicenseTimeout(t *testing.T) {
|
||||
var licFunInvoked int64
|
||||
var revokeFunInvoked int64
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-time.After(300 * time.Millisecond)
|
||||
response := revokeInfo{
|
||||
UUID: "DEADBEEF",
|
||||
Revoked: true,
|
||||
}
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
ds := new(mock.Store)
|
||||
ds.LicenseFunc = func() (*kolide.License, error) {
|
||||
atomic.AddInt64(&licFunInvoked, 1)
|
||||
result := &kolide.License{
|
||||
UpdateTimestamp: kolide.UpdateTimestamp{
|
||||
UpdatedAt: time.Now().Add(-5 * time.Minute),
|
||||
},
|
||||
Token: &tokenString,
|
||||
PublicKey: publicKey,
|
||||
Revoked: false,
|
||||
ID: 1,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
ds.RevokeLicenseFunc = func(revoked bool) error {
|
||||
atomic.AddInt64(&revokeFunInvoked, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// inject our custom logger so we can get log without breaking race
|
||||
// detection
|
||||
logger := &testLogger{}
|
||||
c := clock.NewMockClock()
|
||||
|
||||
checker := NewChecker(ds, ts.URL,
|
||||
mockTicker(c.NewTicker(time.Millisecond)),
|
||||
HTTPClient(&http.Client{Timeout: 2 * time.Millisecond}),
|
||||
Logger(logger),
|
||||
)
|
||||
checker.Start()
|
||||
<-time.After(10 * time.Millisecond)
|
||||
c.AddTime(time.Millisecond)
|
||||
<-time.After(10 * time.Millisecond)
|
||||
checker.Stop()
|
||||
|
||||
assert.Equal(t, int64(1), atomic.LoadInt64(&licFunInvoked))
|
||||
assert.Equal(t, int64(0), atomic.LoadInt64(&revokeFunInvoked))
|
||||
match, _ := regexp.MatchString("(Client.Timeout exceeded while awaiting headers)", logger.read())
|
||||
assert.True(t, match)
|
||||
// check to make sure things cleanly shut down.
|
||||
match, _ = regexp.MatchString("finishing", logger.read())
|
||||
assert.True(t, match)
|
||||
|
||||
}
|
||||
|
|
@ -12,6 +12,8 @@ type LicenseFunc func() (*kolide.License, error)
|
|||
|
||||
type LicensePublicKeyFunc func(tokenString string) (string, error)
|
||||
|
||||
type RevokeLicenseFunc func(revoked bool) error
|
||||
|
||||
type LicenseStore struct {
|
||||
SaveLicenseFunc SaveLicenseFunc
|
||||
SaveLicenseFuncInvoked bool
|
||||
|
|
@ -21,6 +23,9 @@ type LicenseStore struct {
|
|||
|
||||
LicensePublicKeyFunc LicensePublicKeyFunc
|
||||
LicensePublicKeyFuncInvoked bool
|
||||
|
||||
RevokeLicenseFunc RevokeLicenseFunc
|
||||
RevokeLicenseFuncInvoked bool
|
||||
}
|
||||
|
||||
func (s *LicenseStore) SaveLicense(tokenString string, publicKey string) (*kolide.License, error) {
|
||||
|
|
@ -37,3 +42,8 @@ func (s *LicenseStore) LicensePublicKey(tokenString string) (string, error) {
|
|||
s.LicensePublicKeyFuncInvoked = true
|
||||
return s.LicensePublicKeyFunc(tokenString)
|
||||
}
|
||||
|
||||
func (s *LicenseStore) RevokeLicense(revoked bool) error {
|
||||
s.RevokeLicenseFuncInvoked = true
|
||||
return s.RevokeLicenseFunc(revoked)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue