Implement osquery datastore methods for inmem datastore (#255)

- Implement osquery datastore methods
- Update tests for compatibility with inmem

Closes #255
This commit is contained in:
Zachary Wasserman 2016-09-29 18:19:51 -07:00 committed by GitHub
parent 09e988626f
commit 74aaa14dde
12 changed files with 423 additions and 60 deletions

View file

@ -3,6 +3,7 @@ package datastore
import (
"fmt"
"math/rand"
"sort"
"testing"
"time"
@ -325,18 +326,21 @@ func testLabels(t *testing.T, db kolide.Datastore) {
}
for _, query := range labelQueries {
assert.Nil(t, db.NewQuery(&query))
newQuery, err := db.NewQuery(&query)
assert.Nil(t, err)
assert.NotZero(t, newQuery.ID)
}
// this one should not show up
assert.NoError(t, db.NewQuery(&kolide.Query{
_, err = db.NewQuery(&kolide.Query{
Platform: "not_darwin",
Query: "query5",
}))
})
assert.Nil(t, err)
// No queries should be returned before labels added
queries, err = db.LabelQueriesForHost(host, baseTime)
assert.NoError(t, err)
assert.Nil(t, err)
assert.Empty(t, queries)
newLabels := []kolide.Label{
@ -360,7 +364,9 @@ func testLabels(t *testing.T, db kolide.Datastore) {
}
for _, label := range newLabels {
assert.Nil(t, db.NewLabel(&label))
newLabel, err := db.NewLabel(&label)
assert.Nil(t, err)
assert.NotZero(t, newLabel.ID)
}
expectQueries := map[string]string{
@ -374,7 +380,7 @@ func testLabels(t *testing.T, db kolide.Datastore) {
// Now queries should be returned
queries, err = db.LabelQueriesForHost(host, baseTime)
assert.NoError(t, err)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// No labels should match with no results yet
@ -384,38 +390,40 @@ func testLabels(t *testing.T, db kolide.Datastore) {
// Record a query execution
err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime)
assert.NoError(t, err)
assert.Nil(t, err)
// Use a 10 minute interval, so the query we just added should show up
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.NoError(t, err)
assert.Nil(t, err)
delete(expectQueries, "1")
assert.Equal(t, expectQueries, queries)
// Record an old query execution -- Shouldn't change the return
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour))
assert.NoError(t, err)
assert.Nil(t, err)
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.NoError(t, err)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Record a newer execution for that query and another
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime)
assert.NoError(t, err)
assert.Nil(t, err)
// Now these should no longer show up in the necessary to run queries
delete(expectQueries, "2")
delete(expectQueries, "3")
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.NoError(t, err)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Now the two matching labels should be returned
labels, err = db.LabelsForHost(host)
assert.Nil(t, err)
if assert.Len(t, labels, 2) {
assert.Equal(t, "label3", labels[0].Name)
assert.Equal(t, "label2", labels[1].Name)
labelNames := []string{labels[0].Name, labels[1].Name}
sort.Strings(labelNames)
assert.Equal(t, "label2", labelNames[0])
assert.Equal(t, "label3", labelNames[1])
}
// A host that hasn't executed any label queries should still be asked
@ -438,6 +446,7 @@ func setup(t *testing.T) kolide.Datastore {
require.Nil(t, err)
ds := gormDB{DB: db, Driver: "sqlite3"}
err = ds.Migrate()
assert.Nil(t, err)
// Log using t.Log so that output only shows up if the test fails
@ -475,16 +484,16 @@ func randomString(strlen int) string {
}
func testSaveQuery(t *testing.T, ds kolide.Datastore) {
query := kolide.Query{
query := &kolide.Query{
Name: "foo",
Query: "bar",
}
err := ds.SaveQuery(&query)
query, err := ds.NewQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, 0, query.ID)
query.Query = "baz"
err = ds.SaveQuery(&query)
err = ds.SaveQuery(query)
assert.Nil(t, err)
queryVerify, err := ds.Query(query.ID)
@ -493,15 +502,15 @@ func testSaveQuery(t *testing.T, ds kolide.Datastore) {
}
func testDeleteQuery(t *testing.T, ds kolide.Datastore) {
query := kolide.Query{
query := &kolide.Query{
Name: "foo",
Query: "bar",
}
err := ds.SaveQuery(&query)
query, err := ds.NewQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, query.ID, 0)
err = ds.DeleteQuery(&query)
err = ds.DeleteQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, query.ID, 0)
@ -539,7 +548,7 @@ func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) {
Name: "bar",
Query: "bar",
}
err = ds.NewQuery(q1)
_, err = ds.NewQuery(q1)
assert.Nil(t, err)
err = ds.AddQueryToPack(q1, pack)
assert.Nil(t, err)
@ -548,7 +557,7 @@ func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) {
Name: "baz",
Query: "baz",
}
err = ds.NewQuery(q2)
_, err = ds.NewQuery(q2)
assert.Nil(t, err)
err = ds.AddQueryToPack(q2, pack)
assert.Nil(t, err)

View file

@ -228,23 +228,26 @@ func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) {
}
func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error {
updateTime := time.Now()
err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error
if err != nil {
return errors.DatabaseError(err)
}
host.UpdatedAt = updateTime
host.UpdatedAt = t
return nil
}
func (orm gormDB) NewQuery(query *kolide.Query) error {
func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) {
if query == nil {
return errors.New(
return nil, errors.New(
"error creating query",
"nil pointer passed to NewQuery",
)
}
return orm.DB.Create(query).Error
err := orm.DB.Create(query).Error
if err != nil {
return nil, err
}
return query, nil
}
func (orm gormDB) SaveQuery(query *kolide.Query) error {
@ -284,14 +287,18 @@ func (orm gormDB) Queries() ([]*kolide.Query, error) {
return queries, err
}
func (orm gormDB) NewLabel(label *kolide.Label) error {
func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) {
if label == nil {
return errors.New(
return nil, errors.New(
"error creating label",
"nil pointer passed to NewLabel",
)
}
return orm.DB.Create(label).Error
err := orm.DB.Create(label).Error
if err != nil {
return nil, err
}
return label, nil
}
func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) {

View file

@ -8,13 +8,17 @@ import (
type inmem struct {
kolide.Datastore
Driver string
mtx sync.RWMutex
users map[uint]*kolide.User
sessions map[uint]*kolide.Session
passwordResets map[uint]*kolide.PasswordResetRequest
invites map[uint]*kolide.Invite
orginfo *kolide.OrgInfo
Driver string
mtx sync.RWMutex
users map[uint]*kolide.User
sessions map[uint]*kolide.Session
passwordResets map[uint]*kolide.PasswordResetRequest
invites map[uint]*kolide.Invite
labels map[uint]*kolide.Label
labelQueryExecutions map[uint]*kolide.LabelQueryExecution
queries map[uint]*kolide.Query
hosts map[uint]*kolide.Host
orginfo *kolide.OrgInfo
}
func (orm *inmem) Name() string {
@ -22,14 +26,19 @@ func (orm *inmem) Name() string {
}
func (orm *inmem) Migrate() error {
return nil
}
func (orm *inmem) Drop() error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
orm.users = make(map[uint]*kolide.User)
orm.sessions = make(map[uint]*kolide.Session)
orm.passwordResets = make(map[uint]*kolide.PasswordResetRequest)
orm.invites = make(map[uint]*kolide.Invite)
orm.labels = make(map[uint]*kolide.Label)
orm.labelQueryExecutions = make(map[uint]*kolide.LabelQueryExecution)
orm.queries = make(map[uint]*kolide.Query)
orm.hosts = make(map[uint]*kolide.Host)
return nil
}
func (orm *inmem) Drop() error {
return orm.Migrate()
}

View file

@ -0,0 +1,140 @@
package datastore
import (
"errors"
"time"
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm *inmem) NewHost(host *kolide.Host) (*kolide.Host, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
for _, h := range orm.hosts {
if host.NodeKey == h.NodeKey || host.UUID == h.UUID {
return nil, ErrExists
}
}
host.ID = uint(len(orm.hosts) + 1)
orm.hosts[host.ID] = host
return host, nil
}
func (orm *inmem) SaveHost(host *kolide.Host) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if _, ok := orm.hosts[host.ID]; !ok {
return ErrNotFound
}
orm.hosts[host.ID] = host
return nil
}
func (orm *inmem) DeleteHost(host *kolide.Host) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if _, ok := orm.hosts[host.ID]; !ok {
return ErrNotFound
}
delete(orm.hosts, host.ID)
return nil
}
func (orm *inmem) Host(id uint) (*kolide.Host, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
host, ok := orm.hosts[id]
if !ok {
return nil, ErrNotFound
}
return host, nil
}
func (orm *inmem) Hosts() ([]*kolide.Host, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
hosts := []*kolide.Host{}
for _, host := range orm.hosts {
hosts = append(hosts, host)
}
return hosts, nil
}
func (orm *inmem) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if uuid == "" {
return nil, errors.New("missing uuid for host enrollment")
}
host := kolide.Host{UUID: uuid}
for _, h := range orm.hosts {
if h.UUID == uuid {
host = *h
break
}
}
var err error
host.NodeKey, err = generateRandomText(nodeKeySize)
if err != nil {
return nil, err
}
if hostname != "" {
host.HostName = hostname
}
if ip != "" {
host.IPAddress = ip
}
if platform != "" {
host.Platform = platform
}
if host.ID == 0 {
host.ID = uint(len(orm.hosts) + 1)
}
orm.hosts[host.ID] = &host
return &host, nil
}
func (orm *inmem) AuthenticateHost(nodeKey string) (*kolide.Host, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
for _, host := range orm.hosts {
if host.NodeKey == nodeKey {
return host, nil
}
}
return nil, ErrNotFound
}
func (orm *inmem) MarkHostSeen(host *kolide.Host, t time.Time) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
host.UpdatedAt = t
for _, h := range orm.hosts {
if h.ID == host.ID {
h.UpdatedAt = t
break
}
}
return nil
}

View file

@ -0,0 +1,126 @@
package datastore
import (
"errors"
"strconv"
"time"
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm *inmem) NewLabel(label *kolide.Label) (*kolide.Label, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
newLabel := *label
for _, l := range orm.labels {
if l.Name == label.Name {
return nil, ErrExists
}
}
newLabel.ID = uint(len(orm.labels) + 1)
orm.labels[newLabel.ID] = &newLabel
return &newLabel, nil
}
func (orm *inmem) LabelsForHost(host *kolide.Host) ([]kolide.Label, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
// First get IDs of label executions for the host
resLabels := []kolide.Label{}
for _, lqe := range orm.labelQueryExecutions {
if lqe.HostID == host.ID && lqe.Matches {
if label := orm.labels[lqe.LabelID]; label != nil {
resLabels = append(resLabels, *label)
}
}
}
return resLabels, nil
}
func (orm *inmem) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
// Get post-cutoff executions for host
execedQueryIDs := map[uint]uint{} // Map queryID -> labelID
for _, lqe := range orm.labelQueryExecutions {
if lqe.HostID == host.ID && (lqe.UpdatedAt == cutoff || lqe.UpdatedAt.After(cutoff)) {
label := orm.labels[lqe.LabelID]
execedQueryIDs[label.QueryID] = label.ID
}
}
queryToLabel := map[uint]uint{} // Map queryID -> labelID
for _, label := range orm.labels {
queryToLabel[label.QueryID] = label.ID
}
resQueries := map[string]string{}
for _, query := range orm.queries {
_, execed := execedQueryIDs[query.ID]
labelID := queryToLabel[query.ID]
if query.Platform == host.Platform && !execed {
resQueries[strconv.Itoa(int(labelID))] = query.Query
}
}
return resQueries, nil
}
func (orm *inmem) getLabelByIDString(id string) (*kolide.Label, error) {
labelID, err := strconv.Atoi(id)
if err != nil {
return nil, errors.New("non-int label ID")
}
label, ok := orm.labels[uint(labelID)]
if !ok {
return nil, errors.New("label ID not found: " + string(labelID))
}
return label, nil
}
func (orm *inmem) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
// Record executions
for strLabelID, matches := range results {
label, err := orm.getLabelByIDString(strLabelID)
if err != nil {
return err
}
updated := false
for _, lqe := range orm.labelQueryExecutions {
if lqe.LabelID == label.ID && lqe.HostID == host.ID {
// Update existing execution values
lqe.UpdatedAt = t
lqe.Matches = matches
updated = true
break
}
}
if !updated {
// Create new execution
lqe := kolide.LabelQueryExecution{
ID: uint(len(orm.labelQueryExecutions) + 1),
HostID: host.ID,
LabelID: label.ID,
UpdatedAt: t,
Matches: matches,
}
orm.labelQueryExecutions[lqe.ID] = &lqe
}
}
return nil
}

View file

@ -0,0 +1,69 @@
package datastore
import "github.com/kolide/kolide-ose/server/kolide"
func (orm *inmem) NewQuery(query *kolide.Query) (*kolide.Query, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
newQuery := *query
for _, q := range orm.queries {
if query.Name == q.Name {
return nil, ErrExists
}
}
newQuery.ID = uint(len(orm.queries) + 1)
orm.queries[newQuery.ID] = &newQuery
return &newQuery, nil
}
func (orm *inmem) SaveQuery(query *kolide.Query) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if _, ok := orm.queries[query.ID]; !ok {
return ErrNotFound
}
orm.queries[query.ID] = query
return nil
}
func (orm *inmem) DeleteQuery(query *kolide.Query) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if _, ok := orm.queries[query.ID]; !ok {
return ErrNotFound
}
delete(orm.queries, query.ID)
return nil
}
func (orm *inmem) Query(id uint) (*kolide.Query, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
query, ok := orm.queries[id]
if !ok {
return nil, ErrNotFound
}
return query, nil
}
func (orm *inmem) Queries() ([]*kolide.Query, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
queries := []*kolide.Query{}
for _, query := range orm.queries {
queries = append(queries, query)
}
return queries, nil
}

View file

@ -22,7 +22,7 @@ type OsqueryStore interface {
RecordLabelQueryExecutions(host *Host, results map[string]bool, t time.Time) error
// NewLabel saves a new label.
NewLabel(label *Label) error
NewLabel(label *Label) (*Label, error)
// LabelsForHost returns the labels that the given host is in.
LabelsForHost(host *Host) ([]Label, error)

View file

@ -8,7 +8,7 @@ import (
type QueryStore interface {
// Query methods
NewQuery(query *Query) error
NewQuery(query *Query) (*Query, error)
SaveQuery(query *Query) error
DeleteQuery(query *Query) error
Query(id uint) (*Query, error)

View file

@ -151,16 +151,18 @@ func TestGetDistributedQueries(t *testing.T) {
expectQueries := make(map[string]string)
for _, query := range labelQueries {
assert.NoError(t, ds.NewQuery(query))
_, err := ds.NewQuery(query)
assert.Nil(t, err)
expectQueries[fmt.Sprintf("kolide_label_query_%d", query.ID)] = query.Query
}
// this one should not show up
assert.NoError(t, ds.NewQuery(&kolide.Query{
_, err = ds.NewQuery(&kolide.Query{
ID: 4,
Name: "query4",
Platform: "not_darwin",
Query: "query4",
}))
})
assert.Nil(t, err)
labels := []*kolide.Label{
&kolide.Label{
@ -182,7 +184,8 @@ func TestGetDistributedQueries(t *testing.T) {
}
for _, label := range labels {
assert.NoError(t, ds.NewLabel(label))
_, err := ds.NewLabel(label)
assert.Nil(t, err)
}
// Now we should get the label queries
@ -193,7 +196,7 @@ func TestGetDistributedQueries(t *testing.T) {
// Record a query execution
err = ds.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, mockClock.Now())
assert.NoError(t, err)
assert.Nil(t, err)
// Now that query should not be returned
queries, err = svc.GetDistributedQueries(ctx)
@ -212,19 +215,19 @@ func TestGetDistributedQueries(t *testing.T) {
// Record an old query execution -- Shouldn't change the return
err = ds.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, mockClock.Now().Add(-10*time.Hour))
assert.NoError(t, err)
assert.Nil(t, err)
queries, err = svc.GetDistributedQueries(ctx)
assert.NoError(t, err)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Record a newer execution for that query and another
err = ds.RecordLabelQueryExecutions(host, map[string]bool{"2": true, "3": false}, mockClock.Now().Add(-1*time.Minute))
assert.NoError(t, err)
assert.Nil(t, err)
// Now these should no longer show up in the necessary to run queries
delete(expectQueries, "kolide_label_query_2")
delete(expectQueries, "kolide_label_query_3")
queries, err = svc.GetDistributedQueries(ctx)
assert.NoError(t, err)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
}

View file

@ -146,7 +146,7 @@ func TestAddQueryToPack(t *testing.T) {
Name: "bar",
Query: "select * from time;",
}
err = ds.NewQuery(query)
_, err = ds.NewQuery(query)
assert.Nil(t, err)
assert.NotZero(t, query.ID)
@ -182,7 +182,7 @@ func TestGetQueriesInPack(t *testing.T) {
Name: "bar",
Query: "select * from time;",
}
err = ds.NewQuery(query)
_, err = ds.NewQuery(query)
assert.Nil(t, err)
assert.NotZero(t, query.ID)
@ -214,7 +214,7 @@ func TestRemoveQueryFromPack(t *testing.T) {
Name: "bar",
Query: "select * from time;",
}
err = ds.NewQuery(query)
_, err = ds.NewQuery(query)
assert.Nil(t, err)
assert.NotZero(t, query.ID)

View file

@ -44,7 +44,7 @@ func (svc service) NewQuery(ctx context.Context, p kolide.QueryPayload) (*kolide
query.Version = *p.Version
}
err := svc.ds.NewQuery(&query)
_, err := svc.ds.NewQuery(&query)
if err != nil {
return nil, err
}

View file

@ -22,7 +22,7 @@ func TestGetAllQueries(t *testing.T) {
assert.Nil(t, err)
assert.Len(t, queries, 0)
err = ds.NewQuery(&kolide.Query{
_, err = ds.NewQuery(&kolide.Query{
Name: "foo",
Query: "select * from time;",
})
@ -46,7 +46,7 @@ func TestGetQuery(t *testing.T) {
Name: "foo",
Query: "select * from time;",
}
err = ds.NewQuery(query)
_, err = ds.NewQuery(query)
assert.Nil(t, err)
assert.NotZero(t, query.ID)
@ -92,7 +92,7 @@ func TestModifyQuery(t *testing.T) {
Name: "foo",
Query: "select * from time;",
}
err = ds.NewQuery(query)
_, err = ds.NewQuery(query)
assert.Nil(t, err)
assert.NotZero(t, query.ID)
@ -119,7 +119,7 @@ func TestDeleteQuery(t *testing.T) {
Name: "foo",
Query: "select * from time;",
}
err = ds.NewQuery(query)
_, err = ds.NewQuery(query)
assert.Nil(t, err)
assert.NotZero(t, query.ID)