Datastore cleaning (#262)

This PR reorganizes a bunch of the files in datastore such that all datastore implementations are consistently broken up into multiple files. Additionally, the datastore tests follow a similar pattern and can easily be applied to any complete datastore implementation.
This commit is contained in:
Mike Arpaia 2016-10-04 16:34:36 -04:00 committed by GitHub
parent 59c194a7f4
commit 12f8c0b671
22 changed files with 1323 additions and 1353 deletions

View file

@ -67,11 +67,6 @@ the way that the kolide server works.
if err != nil {
initFatal(err, "initializing datastore")
}
err = ds.Migrate()
if err != nil {
initFatal(err, "initializing datastore")
}
} else {
connString := datastore.GetMysqlConnectionString(config.Mysql)
ds, err = datastore.New("gorm-mysql", connString)

View file

@ -2,9 +2,12 @@
package datastore
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"github.com/kolide/kolide-ose/server/config"
"github.com/kolide/kolide-ose/server/kolide"
)
@ -71,14 +74,38 @@ func New(driver, conn string, opts ...DBOption) (kolide.Datastore, error) {
return ds, nil
case "inmem":
ds := &inmem{
Driver: "inmem",
users: make(map[uint]*kolide.User),
sessions: make(map[uint]*kolide.Session),
passwordResets: make(map[uint]*kolide.PasswordResetRequest),
invites: make(map[uint]*kolide.Invite),
Driver: "inmem",
}
err := ds.Migrate()
if err != nil {
return nil, err
}
return ds, nil
default:
return nil, fmt.Errorf("unsupported datastore driver %s", driver)
}
}
func generateRandomText(keySize int) (string, error) {
key := make([]byte, keySize)
_, err := rand.Read(key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
// GetMysqlConnectionString returns a MySQL connection string using the
// provided configuration.
func GetMysqlConnectionString(conf config.MysqlConfig) string {
return fmt.Sprintf(
"%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local",
conf.Username,
conf.Password,
conf.Address,
conf.Database,
)
}

View file

@ -1,26 +1,12 @@
package datastore
import (
"os"
"testing"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
func TestOrgInfo(t *testing.T) {
var ds kolide.Datastore
address := os.Getenv("MYSQL_ADDR")
if address == "" {
ds = setup(t)
} else {
ds = setupMySQLGORM(t)
defer teardownMySQLGORM(t, ds)
}
testOrgInfo(t, ds)
}
func testOrgInfo(t *testing.T, db kolide.Datastore) {
info := &kolide.OrgInfo{
OrgName: "Kolide",

View file

@ -0,0 +1,83 @@
package datastore
import (
"fmt"
"testing"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
var enrollTests = []struct {
uuid, hostname, ip, platform string
nodeKeySize int
}{
0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26",
hostname: "web.kolide.co",
ip: "172.0.0.1",
platform: "linux",
nodeKeySize: 12,
},
1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B",
hostname: "mail.kolide.co",
ip: "172.0.0.2",
platform: "linux",
nodeKeySize: 10,
},
2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2",
hostname: "home.kolide.co",
ip: "127.0.0.1",
platform: "darwin",
nodeKeySize: 25,
},
3: {uuid: "uuid123",
hostname: "fakehostname",
ip: "192.168.1.1",
platform: "darwin",
nodeKeySize: 1,
},
}
func testEnrollHost(t *testing.T, db kolide.Datastore) {
var hosts []*kolide.Host
for _, tt := range enrollTests {
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
assert.Nil(t, err)
hosts = append(hosts, h)
assert.Equal(t, tt.uuid, h.UUID)
assert.Equal(t, tt.hostname, h.HostName)
assert.Equal(t, tt.ip, h.IPAddress)
assert.Equal(t, tt.platform, h.Platform)
assert.NotEmpty(t, h.NodeKey)
}
for _, enrolled := range hosts {
oldNodeKey := enrolled.NodeKey
newhostname := fmt.Sprintf("changed.%s", enrolled.HostName)
h, err := db.EnrollHost(enrolled.UUID, newhostname, enrolled.IPAddress, enrolled.Platform, 15)
assert.Nil(t, err)
assert.Equal(t, enrolled.UUID, h.UUID)
assert.NotEmpty(t, h.NodeKey)
assert.NotEqual(t, oldNodeKey, h.NodeKey)
}
}
func testAuthenticateHost(t *testing.T, db kolide.Datastore) {
for _, tt := range enrollTests {
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
assert.Nil(t, err)
returned, err := db.AuthenticateHost(h.NodeKey)
assert.Nil(t, err)
assert.Equal(t, h.NodeKey, returned.NodeKey)
}
_, err := db.AuthenticateHost("7B1A9DC9-B042-489F-8D5A-EEC2412C95AA")
assert.NotNil(t, err)
_, err = db.AuthenticateHost("")
assert.NotNil(t, err)
}

View file

@ -1,23 +1,13 @@
package datastore
import (
"os"
"testing"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
func TestCreateInvite(t *testing.T) {
var ds kolide.Datastore
address := os.Getenv("MYSQL_ADDR")
if address == "" {
ds = setup(t)
} else {
ds = setupMySQLGORM(t)
defer teardownMySQLGORM(t, ds)
}
func testCreateInvite(t *testing.T, ds kolide.Datastore) {
invite := &kolide.Invite{}
invite, err := ds.NewInvite(invite)

View file

@ -0,0 +1,221 @@
package datastore
import (
"sort"
"testing"
"time"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
func testLabels(t *testing.T, db kolide.Datastore) {
hosts := []kolide.Host{}
var host *kolide.Host
var err error
for i := 0; i < 10; i++ {
host, err = db.EnrollHost(string(i), "foo", "", "", 10)
assert.Nil(t, err, "enrollment should succeed")
hosts = append(hosts, *host)
}
baseTime := time.Now()
// No queries should be returned before labels or queries added
queries, err := db.LabelQueriesForHost(host, baseTime)
assert.Nil(t, err)
assert.Empty(t, queries)
// No labels should match
labels, err := db.LabelsForHost(host.ID)
assert.Nil(t, err)
assert.Empty(t, labels)
labelQueries := []kolide.Query{
kolide.Query{
Name: "query1",
Query: "query1",
Platform: "darwin",
},
kolide.Query{
Name: "query2",
Query: "query2",
Platform: "darwin",
},
kolide.Query{
Name: "query3",
Query: "query3",
Platform: "darwin",
},
kolide.Query{
Name: "query4",
Query: "query4",
Platform: "darwin",
},
}
for _, query := range labelQueries {
newQuery, err := db.NewQuery(&query)
assert.Nil(t, err)
assert.NotZero(t, newQuery.ID)
}
// this one should not show up
_, err = db.NewQuery(&kolide.Query{
Platform: "not_darwin",
Query: "query5",
})
assert.Nil(t, err)
// No queries should be returned before labels added
queries, err = db.LabelQueriesForHost(host, baseTime)
assert.Nil(t, err)
assert.Empty(t, queries)
newLabels := []kolide.Label{
// Note these are intentionally out of order
kolide.Label{
Name: "label3",
QueryID: 3,
},
kolide.Label{
Name: "label1",
QueryID: 1,
},
kolide.Label{
Name: "label2",
QueryID: 2,
},
kolide.Label{
Name: "label4",
QueryID: 4,
},
}
for _, label := range newLabels {
newLabel, err := db.NewLabel(&label)
assert.Nil(t, err)
assert.NotZero(t, newLabel.ID)
}
expectQueries := map[string]string{
"1": "query3",
"2": "query1",
"3": "query2",
"4": "query4",
}
host.Platform = "darwin"
// Now queries should be returned
queries, err = db.LabelQueriesForHost(host, baseTime)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// No labels should match with no results yet
labels, err = db.LabelsForHost(host.ID)
assert.Nil(t, err)
assert.Empty(t, labels)
// Record a query execution
err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime)
assert.Nil(t, err)
// Use a 10 minute interval, so the query we just added should show up
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.Nil(t, err)
delete(expectQueries, "1")
assert.Equal(t, expectQueries, queries)
// Record an old query execution -- Shouldn't change the return
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour))
assert.Nil(t, err)
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Record a newer execution for that query and another
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime)
assert.Nil(t, err)
// Now these should no longer show up in the necessary to run queries
delete(expectQueries, "2")
delete(expectQueries, "3")
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Now the two matching labels should be returned
labels, err = db.LabelsForHost(host.ID)
assert.Nil(t, err)
if assert.Len(t, labels, 2) {
labelNames := []string{labels[0].Name, labels[1].Name}
sort.Strings(labelNames)
assert.Equal(t, "label2", labelNames[0])
assert.Equal(t, "label3", labelNames[1])
}
// A host that hasn't executed any label queries should still be asked
// to execute those queries
hosts[0].Platform = "darwin"
queries, err = db.LabelQueriesForHost(host, time.Now())
assert.Nil(t, err)
assert.Len(t, queries, 4)
// There should still be no labels returned for a host that never
// executed any label queries
labels, err = db.LabelsForHost(hosts[0].ID)
assert.Nil(t, err)
assert.Empty(t, labels)
}
func testManagingLabelsOnPacks(t *testing.T, ds kolide.Datastore) {
mysqlQuery := &kolide.Query{
Name: "MySQL",
Query: "select pid from processes where name = 'mysqld';",
}
mysqlQuery, err := ds.NewQuery(mysqlQuery)
assert.Nil(t, err)
osqueryRunningQuery := &kolide.Query{
Name: "Is osquery currently running?",
Query: "select pid from processes where name = 'osqueryd';",
}
osqueryRunningQuery, err = ds.NewQuery(osqueryRunningQuery)
assert.Nil(t, err)
monitoringPack := &kolide.Pack{
Name: "monitoring",
}
err = ds.NewPack(monitoringPack)
assert.Nil(t, err)
mysqlLabel := &kolide.Label{
Name: "MySQL Monitoring",
QueryID: mysqlQuery.ID,
}
mysqlLabel, err = ds.NewLabel(mysqlLabel)
assert.Nil(t, err)
err = ds.AddLabelToPack(mysqlLabel.ID, monitoringPack.ID)
assert.Nil(t, err)
labels, err := ds.GetLabelsForPack(monitoringPack)
assert.Nil(t, err)
assert.Len(t, labels, 1)
assert.Equal(t, "MySQL Monitoring", labels[0].Name)
osqueryLabel := &kolide.Label{
Name: "Osquery Monitoring",
QueryID: osqueryRunningQuery.ID,
}
osqueryLabel, err = ds.NewLabel(osqueryLabel)
assert.Nil(t, err)
err = ds.AddLabelToPack(osqueryLabel.ID, monitoringPack.ID)
assert.Nil(t, err)
labels, err = ds.GetLabelsForPack(monitoringPack)
assert.Nil(t, err)
assert.Len(t, labels, 2)
}

View file

@ -0,0 +1,65 @@
package datastore
import (
"testing"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testDeletePack(t *testing.T, ds kolide.Datastore) {
pack := &kolide.Pack{
Name: "foo",
}
err := ds.NewPack(pack)
assert.Nil(t, err)
assert.NotEqual(t, pack.ID, 0)
pack, err = ds.Pack(pack.ID)
require.Nil(t, err)
err = ds.DeletePack(pack.ID)
assert.Nil(t, err)
assert.NotEqual(t, pack.ID, 0)
pack, err = ds.Pack(pack.ID)
assert.NotNil(t, err)
}
func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) {
pack := &kolide.Pack{
Name: "foo",
}
err := ds.NewPack(pack)
assert.Nil(t, err)
q1 := &kolide.Query{
Name: "bar",
Query: "bar",
}
_, err = ds.NewQuery(q1)
assert.Nil(t, err)
err = ds.AddQueryToPack(q1.ID, pack.ID)
assert.Nil(t, err)
q2 := &kolide.Query{
Name: "baz",
Query: "baz",
}
_, err = ds.NewQuery(q2)
assert.Nil(t, err)
err = ds.AddQueryToPack(q2.ID, pack.ID)
assert.Nil(t, err)
queries, err := ds.GetQueriesInPack(pack)
assert.Nil(t, err)
assert.Len(t, queries, 2)
err = ds.RemoveQueryFromPack(q1, pack)
assert.Nil(t, err)
queries, err = ds.GetQueriesInPack(pack)
assert.Nil(t, err)
assert.Len(t, queries, 1)
}

View file

@ -0,0 +1,33 @@
package datastore
import (
"testing"
"time"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
func testPasswordResetRequests(t *testing.T, db kolide.Datastore) {
createTestUsers(t, db)
now := time.Now()
tomorrow := now.Add(time.Hour * 24)
var passwordResetTests = []struct {
userID uint
expires time.Time
token string
}{
{userID: 1, expires: tomorrow, token: "abcd"},
}
for _, tt := range passwordResetTests {
r := &kolide.PasswordResetRequest{
UserID: tt.userID,
ExpiresAt: tt.expires,
Token: tt.token,
}
req, err := db.NewPasswordResetRequest(r)
assert.Nil(t, err)
assert.Equal(t, tt.userID, req.UserID)
}
}

View file

@ -0,0 +1,44 @@
package datastore
import (
"testing"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
func testDeleteQuery(t *testing.T, ds kolide.Datastore) {
query := &kolide.Query{
Name: "foo",
Query: "bar",
}
query, err := ds.NewQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, query.ID, 0)
err = ds.DeleteQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, query.ID, 0)
_, err = ds.Query(query.ID)
assert.NotNil(t, err)
}
func testSaveQuery(t *testing.T, ds kolide.Datastore) {
query := &kolide.Query{
Name: "foo",
Query: "bar",
}
query, err := ds.NewQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, 0, query.ID)
query.Query = "baz"
err = ds.SaveQuery(query)
assert.Nil(t, err)
queryVerify, err := ds.Query(query.ID)
assert.Nil(t, err)
assert.Equal(t, "baz", queryVerify.Query)
}

View file

@ -1,626 +1,35 @@
package datastore
import (
"fmt"
"math/rand"
"sort"
"reflect"
"runtime"
"strings"
"testing"
"time"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const bcryptCost = 6
func TestPasswordResetRequests(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testPasswordResetRequests(t, db)
}
func testPasswordResetRequests(t *testing.T, db kolide.Datastore) {
createTestUsers(t, db)
now := time.Now()
tomorrow := now.Add(time.Hour * 24)
var passwordResetTests = []struct {
userID uint
expires time.Time
token string
}{
{userID: 1, expires: tomorrow, token: "abcd"},
}
for _, tt := range passwordResetTests {
r := &kolide.PasswordResetRequest{
UserID: tt.userID,
ExpiresAt: tt.expires,
Token: tt.token,
}
req, err := db.NewPasswordResetRequest(r)
assert.Nil(t, err)
assert.Equal(t, tt.userID, req.UserID)
}
}
func TestEnrollHost(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testEnrollHost(t, db)
func functionName(f func(*testing.T, kolide.Datastore)) string {
fullName := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
elements := strings.Split(fullName, ".")
return elements[len(elements)-1]
}
func TestAuthenticateHost(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testAuthenticateHost(t, db)
}
var enrollTests = []struct {
uuid, hostname, ip, platform string
nodeKeySize int
}{
0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26",
hostname: "web.kolide.co",
ip: "172.0.0.1",
platform: "linux",
nodeKeySize: 12,
},
1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B",
hostname: "mail.kolide.co",
ip: "172.0.0.2",
platform: "linux",
nodeKeySize: 10,
},
2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2",
hostname: "home.kolide.co",
ip: "127.0.0.1",
platform: "darwin",
nodeKeySize: 25,
},
3: {uuid: "uuid123",
hostname: "fakehostname",
ip: "192.168.1.1",
platform: "darwin",
nodeKeySize: 1,
},
}
func testEnrollHost(t *testing.T, db kolide.HostStore) {
var hosts []*kolide.Host
for _, tt := range enrollTests {
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
assert.Nil(t, err)
hosts = append(hosts, h)
assert.Equal(t, tt.uuid, h.UUID)
assert.Equal(t, tt.hostname, h.HostName)
assert.Equal(t, tt.ip, h.IPAddress)
assert.Equal(t, tt.platform, h.Platform)
assert.NotEmpty(t, h.NodeKey)
}
for _, enrolled := range hosts {
oldNodeKey := enrolled.NodeKey
newhostname := fmt.Sprintf("changed.%s", enrolled.HostName)
h, err := db.EnrollHost(enrolled.UUID, newhostname, enrolled.IPAddress, enrolled.Platform, 15)
assert.Nil(t, err)
assert.Equal(t, enrolled.UUID, h.UUID)
assert.NotEmpty(t, h.NodeKey)
assert.NotEqual(t, oldNodeKey, h.NodeKey)
}
}
func testAuthenticateHost(t *testing.T, db kolide.HostStore) {
for _, tt := range enrollTests {
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
assert.Nil(t, err)
returned, err := db.AuthenticateHost(h.NodeKey)
assert.Nil(t, err)
assert.Equal(t, h.NodeKey, returned.NodeKey)
}
_, err := db.AuthenticateHost("7B1A9DC9-B042-489F-8D5A-EEC2412C95AA")
assert.NotNil(t, err)
_, err = db.AuthenticateHost("")
assert.NotNil(t, err)
}
// TestUser tests the UserStore interface
// this test uses the default testing backend
func TestCreateUser(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testCreateUser(t, db)
}
func TestSaveUser(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testSaveUser(t, db)
}
func testCreateUser(t *testing.T, db kolide.UserStore) {
var createTests = []struct {
username, password, email string
isAdmin, passwordReset bool
}{
{"marpaia", "foobar", "mike@kolide.co", true, false},
{"jason", "foobar", "jason@kolide.co", true, false},
}
for _, tt := range createTests {
u := &kolide.User{
Username: tt.username,
Password: []byte(tt.password),
Admin: tt.isAdmin,
AdminForcedPasswordReset: tt.passwordReset,
Email: tt.email,
}
user, err := db.NewUser(u)
assert.Nil(t, err)
verify, err := db.User(tt.username)
assert.Nil(t, err)
assert.Equal(t, user.ID, verify.ID)
assert.Equal(t, tt.username, verify.Username)
assert.Equal(t, tt.email, verify.Email)
assert.Equal(t, tt.email, verify.Email)
}
}
func TestUserByID(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testUserByID(t, db)
}
func testUserByID(t *testing.T, db kolide.UserStore) {
users := createTestUsers(t, db)
for _, tt := range users {
returned, err := db.UserByID(tt.ID)
assert.Nil(t, err)
assert.Equal(t, tt.ID, returned.ID)
}
// test missing user
_, err := db.UserByID(10000000000)
assert.NotNil(t, err)
}
func createTestUsers(t *testing.T, db kolide.UserStore) []*kolide.User {
var createTests = []struct {
username, password, email string
isAdmin, passwordReset bool
}{
{"marpaia", "foobar", "mike@kolide.co", true, false},
{"jason", "foobar", "jason@kolide.co", false, false},
}
var users []*kolide.User
for _, tt := range createTests {
u := &kolide.User{
Username: tt.username,
Password: []byte(tt.password),
Admin: tt.isAdmin,
AdminForcedPasswordReset: tt.passwordReset,
Email: tt.email,
}
user, err := db.NewUser(u)
assert.Nil(t, err)
users = append(users, user)
}
assert.NotEmpty(t, users)
return users
}
func testSaveUser(t *testing.T, db kolide.UserStore) {
users := createTestUsers(t, db)
testAdminAttribute(t, db, users)
testEmailAttribute(t, db, users)
testPasswordAttribute(t, db, users)
}
func testPasswordAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) {
for _, user := range users {
user.Password = []byte(randomString(8))
err := db.SaveUser(user)
assert.Nil(t, err)
verify, err := db.User(user.Username)
assert.Nil(t, err)
assert.Equal(t, user.Password, verify.Password)
}
}
func testEmailAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) {
for _, user := range users {
user.Email = fmt.Sprintf("test.%s", user.Email)
err := db.SaveUser(user)
assert.Nil(t, err)
verify, err := db.User(user.Username)
assert.Nil(t, err)
assert.Equal(t, user.Email, verify.Email)
}
}
func testAdminAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) {
for _, user := range users {
user.Admin = false
err := db.SaveUser(user)
assert.Nil(t, err)
verify, err := db.User(user.Username)
assert.Nil(t, err)
assert.Equal(t, user.Admin, verify.Admin)
}
}
func TestLabelQueries(t *testing.T) {
db := setup(t)
defer teardown(t, db)
testLabels(t, db)
}
func testLabels(t *testing.T, db kolide.Datastore) {
hosts := []kolide.Host{}
var host *kolide.Host
var err error
for i := 0; i < 10; i++ {
host, err = db.EnrollHost(string(i), "foo", "", "", 10)
assert.Nil(t, err, "enrollment should succeed")
hosts = append(hosts, *host)
}
baseTime := time.Now()
// No queries should be returned before labels or queries added
queries, err := db.LabelQueriesForHost(host, baseTime)
assert.Nil(t, err)
assert.Empty(t, queries)
// No labels should match
labels, err := db.LabelsForHost(host.ID)
assert.Nil(t, err)
assert.Empty(t, labels)
labelQueries := []kolide.Query{
kolide.Query{
Name: "query1",
Query: "query1",
Platform: "darwin",
},
kolide.Query{
Name: "query2",
Query: "query2",
Platform: "darwin",
},
kolide.Query{
Name: "query3",
Query: "query3",
Platform: "darwin",
},
kolide.Query{
Name: "query4",
Query: "query4",
Platform: "darwin",
},
}
for _, query := range labelQueries {
newQuery, err := db.NewQuery(&query)
assert.Nil(t, err)
assert.NotZero(t, newQuery.ID)
}
// this one should not show up
_, err = db.NewQuery(&kolide.Query{
Platform: "not_darwin",
Query: "query5",
})
assert.Nil(t, err)
// No queries should be returned before labels added
queries, err = db.LabelQueriesForHost(host, baseTime)
assert.Nil(t, err)
assert.Empty(t, queries)
newLabels := []kolide.Label{
// Note these are intentionally out of order
kolide.Label{
Name: "label3",
QueryID: 3,
},
kolide.Label{
Name: "label1",
QueryID: 1,
},
kolide.Label{
Name: "label2",
QueryID: 2,
},
kolide.Label{
Name: "label4",
QueryID: 4,
},
}
for _, label := range newLabels {
newLabel, err := db.NewLabel(&label)
assert.Nil(t, err)
assert.NotZero(t, newLabel.ID)
}
expectQueries := map[string]string{
"1": "query3",
"2": "query1",
"3": "query2",
"4": "query4",
}
host.Platform = "darwin"
// Now queries should be returned
queries, err = db.LabelQueriesForHost(host, baseTime)
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// No labels should match with no results yet
labels, err = db.LabelsForHost(host.ID)
assert.Nil(t, err)
assert.Empty(t, labels)
// Record a query execution
err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime)
assert.Nil(t, err)
// Use a 10 minute interval, so the query we just added should show up
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.Nil(t, err)
delete(expectQueries, "1")
assert.Equal(t, expectQueries, queries)
// Record an old query execution -- Shouldn't change the return
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour))
assert.Nil(t, err)
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Record a newer execution for that query and another
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime)
assert.Nil(t, err)
// Now these should no longer show up in the necessary to run queries
delete(expectQueries, "2")
delete(expectQueries, "3")
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
assert.Nil(t, err)
assert.Equal(t, expectQueries, queries)
// Now the two matching labels should be returned
labels, err = db.LabelsForHost(host.ID)
assert.Nil(t, err)
if assert.Len(t, labels, 2) {
labelNames := []string{labels[0].Name, labels[1].Name}
sort.Strings(labelNames)
assert.Equal(t, "label2", labelNames[0])
assert.Equal(t, "label3", labelNames[1])
}
// A host that hasn't executed any label queries should still be asked
// to execute those queries
hosts[0].Platform = "darwin"
queries, err = db.LabelQueriesForHost(host, time.Now())
assert.Nil(t, err)
assert.Len(t, queries, 4)
// There should still be no labels returned for a host that never
// executed any label queries
labels, err = db.LabelsForHost(hosts[0].ID)
assert.Nil(t, err)
assert.Empty(t, labels)
}
// setup creates a datastore for testing
func setup(t *testing.T) kolide.Datastore {
db, err := gorm.Open("sqlite3", ":memory:")
require.Nil(t, err)
ds := gormDB{DB: db, Driver: "sqlite3"}
err = ds.Migrate()
assert.Nil(t, err)
// Log using t.Log so that output only shows up if the test fails
//db.SetLogger(&testLogger{t: t})
//db.LogMode(true)
return ds
}
func teardown(t *testing.T, ds kolide.Datastore) {
err := ds.Drop()
assert.Nil(t, err)
}
type testLogger struct {
t *testing.T
}
func (t *testLogger) Print(v ...interface{}) {
t.t.Log(v...)
}
func (t *testLogger) Write(p []byte) (n int, err error) {
t.t.Log(string(p))
return len(p), nil
}
func randomString(strlen int) string {
rand.Seed(time.Now().UTC().UnixNano())
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
result := make([]byte, strlen)
for i := 0; i < strlen; i++ {
result[i] = chars[rand.Intn(len(chars))]
}
return string(result)
}
func testSaveQuery(t *testing.T, ds kolide.Datastore) {
query := &kolide.Query{
Name: "foo",
Query: "bar",
}
query, err := ds.NewQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, 0, query.ID)
query.Query = "baz"
err = ds.SaveQuery(query)
assert.Nil(t, err)
queryVerify, err := ds.Query(query.ID)
assert.Nil(t, err)
assert.Equal(t, "baz", queryVerify.Query)
}
func testDeleteQuery(t *testing.T, ds kolide.Datastore) {
query := &kolide.Query{
Name: "foo",
Query: "bar",
}
query, err := ds.NewQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, query.ID, 0)
err = ds.DeleteQuery(query)
assert.Nil(t, err)
assert.NotEqual(t, query.ID, 0)
_, err = ds.Query(query.ID)
assert.NotNil(t, err)
}
func testDeletePack(t *testing.T, ds kolide.Datastore) {
pack := &kolide.Pack{
Name: "foo",
}
err := ds.NewPack(pack)
assert.Nil(t, err)
assert.NotEqual(t, pack.ID, 0)
pack, err = ds.Pack(pack.ID)
assert.Nil(t, err)
err = ds.DeletePack(pack.ID)
assert.Nil(t, err)
assert.NotEqual(t, pack.ID, 0)
pack, err = ds.Pack(pack.ID)
assert.NotNil(t, err)
}
func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) {
pack := &kolide.Pack{
Name: "foo",
}
err := ds.NewPack(pack)
assert.Nil(t, err)
q1 := &kolide.Query{
Name: "bar",
Query: "bar",
}
_, err = ds.NewQuery(q1)
assert.Nil(t, err)
err = ds.AddQueryToPack(q1.ID, pack.ID)
assert.Nil(t, err)
q2 := &kolide.Query{
Name: "baz",
Query: "baz",
}
_, err = ds.NewQuery(q2)
assert.Nil(t, err)
err = ds.AddQueryToPack(q2.ID, pack.ID)
assert.Nil(t, err)
queries, err := ds.GetQueriesInPack(pack)
assert.Nil(t, err)
assert.Len(t, queries, 2)
err = ds.RemoveQueryFromPack(q1, pack)
assert.Nil(t, err)
queries, err = ds.GetQueriesInPack(pack)
assert.Nil(t, err)
assert.Len(t, queries, 1)
}
func testManagingLabelsOnPacks(t *testing.T, ds kolide.Datastore) {
mysqlQuery := &kolide.Query{
Name: "MySQL",
Query: "select pid from processes where name = 'mysqld';",
}
mysqlQuery, err := ds.NewQuery(mysqlQuery)
assert.Nil(t, err)
osqueryRunningQuery := &kolide.Query{
Name: "Is osquery currently running?",
Query: "select pid from processes where name = 'osqueryd';",
}
osqueryRunningQuery, err = ds.NewQuery(osqueryRunningQuery)
assert.Nil(t, err)
monitoringPack := &kolide.Pack{
Name: "monitoring",
}
err = ds.NewPack(monitoringPack)
assert.Nil(t, err)
mysqlLabel := &kolide.Label{
Name: "MySQL Monitoring",
QueryID: mysqlQuery.ID,
}
mysqlLabel, err = ds.NewLabel(mysqlLabel)
assert.Nil(t, err)
err = ds.AddLabelToPack(mysqlLabel.ID, monitoringPack.ID)
assert.Nil(t, err)
labels, err := ds.GetLabelsForPack(monitoringPack)
assert.Nil(t, err)
assert.Len(t, labels, 1)
assert.Equal(t, "MySQL Monitoring", labels[0].Name)
osqueryLabel := &kolide.Label{
Name: "Osquery Monitoring",
QueryID: osqueryRunningQuery.ID,
}
osqueryLabel, err = ds.NewLabel(osqueryLabel)
assert.Nil(t, err)
err = ds.AddLabelToPack(osqueryLabel.ID, monitoringPack.ID)
assert.Nil(t, err)
labels, err = ds.GetLabelsForPack(monitoringPack)
assert.Nil(t, err)
assert.Len(t, labels, 2)
var testFunctions = [...]func(*testing.T, kolide.Datastore){
testOrgInfo,
testCreateInvite,
testDeleteQuery,
testSaveQuery,
testDeletePack,
testAddAndRemoveQueryFromPack,
testEnrollHost,
testAuthenticateHost,
testLabels,
testManagingLabelsOnPacks,
testPasswordResetRequests,
testCreateUser,
testSaveUser,
testUserByID,
testPasswordResetRequests,
}

View file

@ -0,0 +1,125 @@
package datastore
import (
"fmt"
"testing"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
)
func testCreateUser(t *testing.T, ds kolide.Datastore) {
var createTests = []struct {
username, password, email string
isAdmin, passwordReset bool
}{
{"marpaia", "foobar", "mike@kolide.co", true, false},
{"jason", "foobar", "jason@kolide.co", true, false},
}
for _, tt := range createTests {
u := &kolide.User{
Username: tt.username,
Password: []byte(tt.password),
Admin: tt.isAdmin,
AdminForcedPasswordReset: tt.passwordReset,
Email: tt.email,
}
user, err := ds.NewUser(u)
assert.Nil(t, err)
verify, err := ds.User(tt.username)
assert.Nil(t, err)
assert.Equal(t, user.ID, verify.ID)
assert.Equal(t, tt.username, verify.Username)
assert.Equal(t, tt.email, verify.Email)
assert.Equal(t, tt.email, verify.Email)
}
}
func testUserByID(t *testing.T, ds kolide.Datastore) {
users := createTestUsers(t, ds)
for _, tt := range users {
returned, err := ds.UserByID(tt.ID)
assert.Nil(t, err)
assert.Equal(t, tt.ID, returned.ID)
}
// test missing user
_, err := ds.UserByID(10000000000)
assert.NotNil(t, err)
}
func createTestUsers(t *testing.T, ds kolide.Datastore) []*kolide.User {
var createTests = []struct {
username, password, email string
isAdmin, passwordReset bool
}{
{"marpaia", "foobar", "mike@kolide.co", true, false},
{"jason", "foobar", "jason@kolide.co", false, false},
}
var users []*kolide.User
for _, tt := range createTests {
u := &kolide.User{
Username: tt.username,
Password: []byte(tt.password),
Admin: tt.isAdmin,
AdminForcedPasswordReset: tt.passwordReset,
Email: tt.email,
}
user, err := ds.NewUser(u)
assert.Nil(t, err)
users = append(users, user)
}
assert.NotEmpty(t, users)
return users
}
func testSaveUser(t *testing.T, ds kolide.Datastore) {
users := createTestUsers(t, ds)
testAdminAttribute(t, ds, users)
testEmailAttribute(t, ds, users)
testPasswordAttribute(t, ds, users)
}
func testPasswordAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) {
for _, user := range users {
randomText, err := generateRandomText(8)
assert.Nil(t, err)
user.Password = []byte(randomText)
err = ds.SaveUser(user)
assert.Nil(t, err)
verify, err := ds.User(user.Username)
assert.Nil(t, err)
assert.Equal(t, user.Password, verify.Password)
}
}
func testEmailAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) {
for _, user := range users {
user.Email = fmt.Sprintf("test.%s", user.Email)
err := ds.SaveUser(user)
assert.Nil(t, err)
verify, err := ds.User(user.Username)
assert.Nil(t, err)
assert.Equal(t, user.Email, verify.Email)
}
}
func testAdminAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) {
for _, user := range users {
user.Admin = false
err := ds.SaveUser(user)
assert.Nil(t, err)
verify, err := ds.User(user.Username)
assert.Nil(t, err)
assert.Equal(t, user.Admin, verify.Admin)
}
}

View file

@ -1,12 +1,7 @@
package datastore
import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
_ "github.com/go-sql-driver/mysql" // db driver
@ -14,7 +9,6 @@ import (
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/server/config"
"github.com/kolide/kolide-ose/server/errors"
"github.com/kolide/kolide-ose/server/kolide"
)
@ -43,18 +37,6 @@ type gormDB struct {
config config.KolideConfig
}
// GetMysqlConnectionString returns a MySQL connection string using the
// provided configuration.
func GetMysqlConnectionString(conf config.MysqlConfig) string {
return fmt.Sprintf(
"%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local",
conf.Username,
conf.Password,
conf.Address,
conf.Database,
)
}
func (orm gormDB) Name() string {
return "gorm"
}
@ -102,579 +84,3 @@ func openGORM(driver, conn string, maxAttempts int) (*gorm.DB, error) {
}
return db, nil
}
func generateRandomText(keySize int) (string, error) {
key := make([]byte, keySize)
_, err := rand.Read(key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
func (orm gormDB) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) {
if uuid == "" {
return nil, errors.New("missing uuid for host enrollment", "programmer error?")
}
host := kolide.Host{UUID: uuid}
err := orm.DB.Where(&host).First(&host).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
// Create new Host
host = kolide.Host{
UUID: uuid,
HostName: hostname,
IPAddress: ip,
Platform: platform,
}
default:
return nil, err
}
}
// Generate a new key each enrollment
host.NodeKey, err = generateRandomText(nodeKeySize)
if err != nil {
return nil, err
}
// Update these fields if provided
if hostname != "" {
host.HostName = hostname
}
if ip != "" {
host.IPAddress = ip
}
if platform != "" {
host.Platform = platform
}
if err := orm.DB.Save(&host).Error; err != nil {
return nil, err
}
return &host, nil
}
func (orm gormDB) AuthenticateHost(nodeKey string) (*kolide.Host, error) {
host := kolide.Host{NodeKey: nodeKey}
err := orm.DB.Where("node_key = ?", host.NodeKey).First(&host).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
e := errors.NewFromError(
err,
http.StatusUnauthorized,
"invalid node key",
)
// osqueryd expects the literal string "true" here
e.Extra = map[string]interface{}{"node_invalid": "true"}
return nil, e
default:
return nil, errors.DatabaseError(err)
}
}
return &host, nil
}
func (orm gormDB) SaveHost(host *kolide.Host) error {
if err := orm.DB.Save(host).Error; err != nil {
return errors.DatabaseError(err)
}
return nil
}
func (orm gormDB) DeleteHost(host *kolide.Host) error {
return orm.DB.Delete(host).Error
}
func (orm gormDB) Host(id uint) (*kolide.Host, error) {
host := &kolide.Host{
ID: id,
}
err := orm.DB.Where(host).First(host).Error
if err != nil {
return nil, err
}
return host, nil
}
func (orm gormDB) Hosts() ([]*kolide.Host, error) {
var hosts []*kolide.Host
err := orm.DB.Find(&hosts).Error
if err != nil {
return nil, err
}
return hosts, nil
}
func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) {
if host == nil {
return nil, errors.New(
"error creating host",
"nil pointer passed to NewHost",
)
}
err := orm.DB.Create(host).Error
if err != nil {
return nil, err
}
return host, err
}
func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error {
err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error
if err != nil {
return errors.DatabaseError(err)
}
host.UpdatedAt = t
return nil
}
func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) {
if query == nil {
return nil, errors.New(
"error creating query",
"nil pointer passed to NewQuery",
)
}
err := orm.DB.Create(query).Error
if err != nil {
return nil, err
}
return query, nil
}
func (orm gormDB) SaveQuery(query *kolide.Query) error {
if query == nil {
return errors.New(
"error saving query",
"nil pointer passed to SaveQuery",
)
}
return orm.DB.Save(query).Error
}
func (orm gormDB) DeleteQuery(query *kolide.Query) error {
if query == nil {
return errors.New(
"error deleting query",
"nil pointer passed to DeleteQuery",
)
}
return orm.DB.Delete(query).Error
}
func (orm gormDB) Query(id uint) (*kolide.Query, error) {
query := &kolide.Query{
ID: id,
}
err := orm.DB.Where(query).First(query).Error
if err != nil {
return nil, err
}
return query, nil
}
func (orm gormDB) Queries() ([]*kolide.Query, error) {
var queries []*kolide.Query
err := orm.DB.Find(&queries).Error
return queries, err
}
func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) {
if label == nil {
return nil, errors.New(
"error creating label",
"nil pointer passed to NewLabel",
)
}
err := orm.DB.Create(label).Error
if err != nil {
return nil, err
}
return label, nil
}
func (orm gormDB) SaveLabel(label *kolide.Label) error {
if label == nil {
return errors.New(
"error saving label",
"nil pointer passed to SaveLabel",
)
}
return orm.DB.Save(label).Error
}
func (orm gormDB) DeleteLabel(lid uint) error {
err := orm.DB.Where("id = ?", lid).Delete(&kolide.Label{}).Error
if err != nil {
return err
}
return orm.DB.Where("target_id = ? and type = ?", lid, kolide.TargetLabel).Delete(&kolide.PackTarget{}).Error
}
func (orm gormDB) Label(lid uint) (*kolide.Label, error) {
label := &kolide.Label{
ID: lid,
}
err := orm.DB.Where("id = ?", label.ID).First(&label).Error
if err != nil {
return nil, err
}
return label, nil
}
func (orm gormDB) Labels() ([]*kolide.Label, error) {
var labels []*kolide.Label
err := orm.DB.Find(&labels).Error
return labels, err
}
func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) {
if host == nil {
return nil, errors.New(
"error finding host queries",
"nil pointer passed to LabelQueriesForHost",
)
}
rows, err := orm.DB.Raw(`
SELECT l.id, q.query
FROM labels l JOIN queries q
ON l.query_id = q.id
WHERE q.platform = ?
AND q.id NOT IN /* subtract the set of executions that are recent enough */
(
SELECT l.query_id
FROM labels l
JOIN label_query_executions lqe
ON lqe.label_id = l.id
WHERE lqe.host_id = ? AND lqe.updated_at > ?
)`, host.Platform, host.ID, cutoff).Rows()
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
defer rows.Close()
results := make(map[string]string)
for rows.Next() {
var id, query string
err = rows.Scan(&id, &query)
if err != nil {
return nil, errors.DatabaseError(err)
}
results[id] = query
}
return results, nil
}
func (orm gormDB) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error {
if host == nil {
return errors.New(
"error recording host label query execution",
"nil pointer passed to RecordLabelQueryExecutions",
)
}
insert := new(bytes.Buffer)
switch orm.Driver {
case "mysql":
insert.WriteString("INSERT ")
case "sqlite3":
insert.WriteString("REPLACE ")
default:
return errors.New(
"Unknown DB driver",
"Tried to use unknown DB driver in RecordLabelQueryExecutions: "+orm.Driver,
)
}
insert.WriteString(
"INTO label_query_executions (updated_at, matches, label_id, host_id) VALUES",
)
// Build up all the values and the query string
vals := []interface{}{}
for labelId, res := range results {
insert.WriteString("(?,?,?,?),")
vals = append(vals, t, res, labelId, host.ID)
}
queryString := insert.String()
queryString = strings.TrimSuffix(queryString, ",")
switch orm.Driver {
case "mysql":
queryString += `
ON DUPLICATE KEY UPDATE
updated_at = VALUES(updated_at),
matches = VALUES(matches)
`
}
if err := orm.DB.Exec(queryString, vals...).Error; err != nil {
return errors.DatabaseError(err)
}
return nil
}
func (orm gormDB) LabelsForHost(hid uint) ([]kolide.Label, error) {
results := []kolide.Label{}
err := orm.DB.Raw(`
SELECT labels.* from labels, label_query_executions lqe
WHERE lqe.host_id = ?
AND lqe.label_id = labels.id
AND lqe.matches
`, hid).Scan(&results).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
return results, nil
}
func (orm gormDB) NewPack(pack *kolide.Pack) error {
if pack == nil {
return errors.New(
"error creating pack",
"nil pointer passed to NewPack",
)
}
return orm.DB.Create(pack).Error
}
func (orm gormDB) SavePack(pack *kolide.Pack) error {
if pack == nil {
return errors.New(
"error saving pack",
"nil pointer passed to SavePack",
)
}
return orm.DB.Save(pack).Error
}
func (orm gormDB) DeletePack(pid uint) error {
err := orm.DB.Where("id = ?", pid).Delete(&kolide.Pack{}).Error
if err != nil {
return err
}
err = orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackQuery{}).Error
if err != nil {
return err
}
return orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackTarget{}).Error
}
func (orm gormDB) Pack(pid uint) (*kolide.Pack, error) {
pack := &kolide.Pack{
ID: pid,
}
err := orm.DB.Where(pack).First(pack).Error
if err != nil {
return nil, err
}
return pack, nil
}
func (orm gormDB) Packs() ([]*kolide.Pack, error) {
var packs []*kolide.Pack
err := orm.DB.Find(&packs).Error
return packs, err
}
func (orm gormDB) AddQueryToPack(qid uint, pid uint) error {
pq := &kolide.PackQuery{
QueryID: qid,
PackID: pid,
}
return orm.DB.Create(pq).Error
}
func (orm gormDB) GetQueriesInPack(pack *kolide.Pack) ([]*kolide.Query, error) {
var queries []*kolide.Query
if pack == nil {
return nil, errors.New(
"error getting queries in pack",
"nil pointer passed to GetQueriesInPack",
)
}
rows, err := orm.DB.Raw(`
SELECT
q.id,
q.created_at,
q.updated_at,
q.name,
q.query,
q.interval,
q.snapshot,
q.differential,
q.platform,
q.version
FROM
queries q
JOIN
pack_queries pq
ON
pq.query_id = q.id
AND
pq.pack_id = ?;
`, pack.ID).Rows()
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
defer rows.Close()
for rows.Next() {
query := new(kolide.Query)
err = rows.Scan(
&query.ID,
&query.CreatedAt,
&query.UpdatedAt,
&query.Name,
&query.Query,
&query.Interval,
&query.Snapshot,
&query.Differential,
&query.Platform,
&query.Version,
)
if err != nil {
return nil, err
}
queries = append(queries, query)
}
return queries, nil
}
func (orm gormDB) RemoveQueryFromPack(query *kolide.Query, pack *kolide.Pack) error {
if query == nil || pack == nil {
return errors.New(
"error removing query from pack",
"nil pointer passed to RemoveQueryFromPack",
)
}
pq := &kolide.PackQuery{
QueryID: query.ID,
PackID: pack.ID,
}
return orm.DB.Where(pq).Delete(pq).Error
}
func (orm gormDB) AddLabelToPack(lid uint, pid uint) error {
pt := &kolide.PackTarget{
Type: kolide.TargetLabel,
PackID: pid,
TargetID: lid,
}
return orm.DB.Create(pt).Error
}
func (orm gormDB) ActivePacksForHost(hid uint) ([]*kolide.Pack, error) {
packs := []*kolide.Pack{}
// we will need to give some subset of packs to this host based on the
// labels which this host is known to belong to
allPacks, err := orm.Packs()
if err != nil {
return nil, err
}
// pull the labels that this host belongs to
labels, err := orm.LabelsForHost(hid)
if err != nil {
return nil, err
}
// in order to use o(1) array indexing in an o(n) loop vs a o(n^2) double
// for loop iteration, we must create the array which may be indexed below
labelIDs := map[uint]bool{}
for _, label := range labels {
labelIDs[label.ID] = true
}
for _, pack := range allPacks {
// for each pack, we must know what labels have been assigned to that
// pack
labelsForPack, err := orm.GetLabelsForPack(pack)
if err != nil {
return nil, err
}
// o(n) iteration to determine whether or not a pack is enabled
// in this case, n is len(labelsForPack)
for _, label := range labelsForPack {
if labelIDs[label.ID] {
packs = append(packs, pack)
break
}
}
}
return packs, nil
}
func (orm gormDB) GetLabelsForPack(pack *kolide.Pack) ([]*kolide.Label, error) {
if pack == nil {
return nil, errors.New(
"error getting labels for pack",
"nil pointer passed to GetLabelsForPack",
)
}
results := []*kolide.Label{}
err := orm.DB.Raw(`
SELECT
l.id,
l.created_at,
l.updated_at,
l.name,
l.query_id
FROM
labels l
JOIN
pack_targets pt
ON
pt.target_id = l.id
WHERE
pt.type = ?
AND
pt.pack_id = ?;
`,
kolide.TargetLabel, pack.ID).Scan(&results).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
return results, nil
}
func (orm gormDB) RemoveLabelFromPack(label *kolide.Label, pack *kolide.Pack) error {
if label == nil || pack == nil {
return errors.New(
"error removing label from pack",
"nil pointer passed to RemoveLabelFromPack",
)
}
pt := &kolide.PackTarget{
Type: kolide.TargetLabel,
PackID: pack.ID,
TargetID: label.ID,
}
return orm.DB.Delete(pt).Error
}

View file

@ -0,0 +1,132 @@
package datastore
import (
"net/http"
"time"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/server/errors"
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm gormDB) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) {
if uuid == "" {
return nil, errors.New("missing uuid for host enrollment", "programmer error?")
}
host := kolide.Host{UUID: uuid}
err := orm.DB.Where(&host).First(&host).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
// Create new Host
host = kolide.Host{
UUID: uuid,
HostName: hostname,
IPAddress: ip,
Platform: platform,
}
default:
return nil, err
}
}
// Generate a new key each enrollment
host.NodeKey, err = generateRandomText(nodeKeySize)
if err != nil {
return nil, err
}
// Update these fields if provided
if hostname != "" {
host.HostName = hostname
}
if ip != "" {
host.IPAddress = ip
}
if platform != "" {
host.Platform = platform
}
if err := orm.DB.Save(&host).Error; err != nil {
return nil, err
}
return &host, nil
}
func (orm gormDB) AuthenticateHost(nodeKey string) (*kolide.Host, error) {
host := kolide.Host{NodeKey: nodeKey}
err := orm.DB.Where("node_key = ?", host.NodeKey).First(&host).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
e := errors.NewFromError(
err,
http.StatusUnauthorized,
"invalid node key",
)
// osqueryd expects the literal string "true" here
e.Extra = map[string]interface{}{"node_invalid": "true"}
return nil, e
default:
return nil, errors.DatabaseError(err)
}
}
return &host, nil
}
func (orm gormDB) SaveHost(host *kolide.Host) error {
if err := orm.DB.Save(host).Error; err != nil {
return errors.DatabaseError(err)
}
return nil
}
func (orm gormDB) DeleteHost(host *kolide.Host) error {
return orm.DB.Delete(host).Error
}
func (orm gormDB) Host(id uint) (*kolide.Host, error) {
host := &kolide.Host{
ID: id,
}
err := orm.DB.Where(host).First(host).Error
if err != nil {
return nil, err
}
return host, nil
}
func (orm gormDB) Hosts() ([]*kolide.Host, error) {
var hosts []*kolide.Host
err := orm.DB.Find(&hosts).Error
if err != nil {
return nil, err
}
return hosts, nil
}
func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) {
if host == nil {
return nil, errors.New(
"error creating host",
"nil pointer passed to NewHost",
)
}
err := orm.DB.Create(host).Error
if err != nil {
return nil, err
}
return host, err
}
func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error {
err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error
if err != nil {
return errors.DatabaseError(err)
}
host.UpdatedAt = t
return nil
}

View file

@ -0,0 +1,166 @@
package datastore
import (
"bytes"
"strings"
"time"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/server/errors"
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) {
if label == nil {
return nil, errors.New(
"error creating label",
"nil pointer passed to NewLabel",
)
}
err := orm.DB.Create(label).Error
if err != nil {
return nil, err
}
return label, nil
}
func (orm gormDB) SaveLabel(label *kolide.Label) error {
if label == nil {
return errors.New(
"error saving label",
"nil pointer passed to SaveLabel",
)
}
return orm.DB.Save(label).Error
}
func (orm gormDB) DeleteLabel(lid uint) error {
err := orm.DB.Where("id = ?", lid).Delete(&kolide.Label{}).Error
if err != nil {
return err
}
return orm.DB.Where("target_id = ? and type = ?", lid, kolide.TargetLabel).Delete(&kolide.PackTarget{}).Error
}
func (orm gormDB) Label(lid uint) (*kolide.Label, error) {
label := &kolide.Label{
ID: lid,
}
err := orm.DB.Where("id = ?", label.ID).First(&label).Error
if err != nil {
return nil, err
}
return label, nil
}
func (orm gormDB) Labels() ([]*kolide.Label, error) {
var labels []*kolide.Label
err := orm.DB.Find(&labels).Error
return labels, err
}
func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) {
if host == nil {
return nil, errors.New(
"error finding host queries",
"nil pointer passed to LabelQueriesForHost",
)
}
rows, err := orm.DB.Raw(`
SELECT l.id, q.query
FROM labels l JOIN queries q
ON l.query_id = q.id
WHERE q.platform = ?
AND q.id NOT IN /* subtract the set of executions that are recent enough */
(
SELECT l.query_id
FROM labels l
JOIN label_query_executions lqe
ON lqe.label_id = l.id
WHERE lqe.host_id = ? AND lqe.updated_at > ?
)`, host.Platform, host.ID, cutoff).Rows()
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
defer rows.Close()
results := make(map[string]string)
for rows.Next() {
var id, query string
err = rows.Scan(&id, &query)
if err != nil {
return nil, errors.DatabaseError(err)
}
results[id] = query
}
return results, nil
}
func (orm gormDB) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error {
if host == nil {
return errors.New(
"error recording host label query execution",
"nil pointer passed to RecordLabelQueryExecutions",
)
}
insert := new(bytes.Buffer)
switch orm.Driver {
case "mysql":
insert.WriteString("INSERT ")
case "sqlite3":
insert.WriteString("REPLACE ")
default:
return errors.New(
"Unknown DB driver",
"Tried to use unknown DB driver in RecordLabelQueryExecutions: "+orm.Driver,
)
}
insert.WriteString(
"INTO label_query_executions (updated_at, matches, label_id, host_id) VALUES",
)
// Build up all the values and the query string
vals := []interface{}{}
for labelId, res := range results {
insert.WriteString("(?,?,?,?),")
vals = append(vals, t, res, labelId, host.ID)
}
queryString := insert.String()
queryString = strings.TrimSuffix(queryString, ",")
switch orm.Driver {
case "mysql":
queryString += `
ON DUPLICATE KEY UPDATE
updated_at = VALUES(updated_at),
matches = VALUES(matches)
`
}
if err := orm.DB.Exec(queryString, vals...).Error; err != nil {
return errors.DatabaseError(err)
}
return nil
}
func (orm gormDB) LabelsForHost(hid uint) ([]kolide.Label, error) {
results := []kolide.Label{}
err := orm.DB.Raw(`
SELECT labels.* from labels, label_query_executions lqe
WHERE lqe.host_id = ?
AND lqe.label_id = labels.id
AND lqe.matches
`, hid).Scan(&results).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
return results, nil
}

View file

@ -0,0 +1,245 @@
package datastore
import (
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/server/errors"
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm gormDB) NewPack(pack *kolide.Pack) error {
if pack == nil {
return errors.New(
"error creating pack",
"nil pointer passed to NewPack",
)
}
return orm.DB.Create(pack).Error
}
func (orm gormDB) SavePack(pack *kolide.Pack) error {
if pack == nil {
return errors.New(
"error saving pack",
"nil pointer passed to SavePack",
)
}
return orm.DB.Save(pack).Error
}
func (orm gormDB) DeletePack(pid uint) error {
err := orm.DB.Where("id = ?", pid).Delete(&kolide.Pack{}).Error
if err != nil {
return err
}
err = orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackQuery{}).Error
if err != nil {
return err
}
return orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackTarget{}).Error
}
func (orm gormDB) Pack(pid uint) (*kolide.Pack, error) {
pack := &kolide.Pack{
ID: pid,
}
err := orm.DB.Where(pack).First(pack).Error
if err != nil {
return nil, err
}
return pack, nil
}
func (orm gormDB) Packs() ([]*kolide.Pack, error) {
var packs []*kolide.Pack
err := orm.DB.Find(&packs).Error
return packs, err
}
func (orm gormDB) AddQueryToPack(qid uint, pid uint) error {
pq := &kolide.PackQuery{
QueryID: qid,
PackID: pid,
}
return orm.DB.Create(pq).Error
}
func (orm gormDB) GetQueriesInPack(pack *kolide.Pack) ([]*kolide.Query, error) {
var queries []*kolide.Query
if pack == nil {
return nil, errors.New(
"error getting queries in pack",
"nil pointer passed to GetQueriesInPack",
)
}
rows, err := orm.DB.Raw(`
SELECT
q.id,
q.created_at,
q.updated_at,
q.name,
q.query,
q.interval,
q.snapshot,
q.differential,
q.platform,
q.version
FROM
queries q
JOIN
pack_queries pq
ON
pq.query_id = q.id
AND
pq.pack_id = ?;
`, pack.ID).Rows()
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
defer rows.Close()
for rows.Next() {
query := new(kolide.Query)
err = rows.Scan(
&query.ID,
&query.CreatedAt,
&query.UpdatedAt,
&query.Name,
&query.Query,
&query.Interval,
&query.Snapshot,
&query.Differential,
&query.Platform,
&query.Version,
)
if err != nil {
return nil, err
}
queries = append(queries, query)
}
return queries, nil
}
func (orm gormDB) RemoveQueryFromPack(query *kolide.Query, pack *kolide.Pack) error {
if query == nil || pack == nil {
return errors.New(
"error removing query from pack",
"nil pointer passed to RemoveQueryFromPack",
)
}
pq := &kolide.PackQuery{
QueryID: query.ID,
PackID: pack.ID,
}
return orm.DB.Where(pq).Delete(pq).Error
}
func (orm gormDB) AddLabelToPack(lid uint, pid uint) error {
pt := &kolide.PackTarget{
Type: kolide.TargetLabel,
PackID: pid,
TargetID: lid,
}
return orm.DB.Create(pt).Error
}
func (orm gormDB) ActivePacksForHost(hid uint) ([]*kolide.Pack, error) {
packs := []*kolide.Pack{}
// we will need to give some subset of packs to this host based on the
// labels which this host is known to belong to
allPacks, err := orm.Packs()
if err != nil {
return nil, err
}
// pull the labels that this host belongs to
labels, err := orm.LabelsForHost(hid)
if err != nil {
return nil, err
}
// in order to use o(1) array indexing in an o(n) loop vs a o(n^2) double
// for loop iteration, we must create the array which may be indexed below
labelIDs := map[uint]bool{}
for _, label := range labels {
labelIDs[label.ID] = true
}
for _, pack := range allPacks {
// for each pack, we must know what labels have been assigned to that
// pack
labelsForPack, err := orm.GetLabelsForPack(pack)
if err != nil {
return nil, err
}
// o(n) iteration to determine whether or not a pack is enabled
// in this case, n is len(labelsForPack)
for _, label := range labelsForPack {
if labelIDs[label.ID] {
packs = append(packs, pack)
break
}
}
}
return packs, nil
}
func (orm gormDB) GetLabelsForPack(pack *kolide.Pack) ([]*kolide.Label, error) {
if pack == nil {
return nil, errors.New(
"error getting labels for pack",
"nil pointer passed to GetLabelsForPack",
)
}
results := []*kolide.Label{}
err := orm.DB.Raw(`
SELECT
l.id,
l.created_at,
l.updated_at,
l.name,
l.query_id
FROM
labels l
JOIN
pack_targets pt
ON
pt.target_id = l.id
WHERE
pt.type = ?
AND
pt.pack_id = ?;
`,
kolide.TargetLabel, pack.ID).Scan(&results).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, errors.DatabaseError(err)
}
return results, nil
}
func (orm gormDB) RemoveLabelFromPack(label *kolide.Label, pack *kolide.Pack) error {
if label == nil || pack == nil {
return errors.New(
"error removing label from pack",
"nil pointer passed to RemoveLabelFromPack",
)
}
pt := &kolide.PackTarget{
Type: kolide.TargetLabel,
PackID: pack.ID,
TargetID: label.ID,
}
return orm.DB.Delete(pt).Error
}

View file

@ -0,0 +1,57 @@
package datastore
import (
"github.com/kolide/kolide-ose/server/errors"
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) {
if query == nil {
return nil, errors.New(
"error creating query",
"nil pointer passed to NewQuery",
)
}
err := orm.DB.Create(query).Error
if err != nil {
return nil, err
}
return query, nil
}
func (orm gormDB) SaveQuery(query *kolide.Query) error {
if query == nil {
return errors.New(
"error saving query",
"nil pointer passed to SaveQuery",
)
}
return orm.DB.Save(query).Error
}
func (orm gormDB) DeleteQuery(query *kolide.Query) error {
if query == nil {
return errors.New(
"error deleting query",
"nil pointer passed to DeleteQuery",
)
}
return orm.DB.Delete(query).Error
}
func (orm gormDB) Query(id uint) (*kolide.Query, error) {
query := &kolide.Query{
ID: id,
}
err := orm.DB.Where(query).First(query).Error
if err != nil {
return nil, err
}
return query, nil
}
func (orm gormDB) Queries() ([]*kolide.Query, error) {
var queries []*kolide.Query
err := orm.DB.Find(&queries).Error
return queries, err
}

View file

@ -1,124 +1,36 @@
package datastore
import (
"fmt"
"os"
"testing"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/server/kolide"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupMySQLGORM(t *testing.T) kolide.Datastore {
user := "kolide"
password := "kolide"
dbName := "kolide"
func setupGorm(t *testing.T) kolide.Datastore {
db, err := gorm.Open("sqlite3", ":memory:")
require.Nil(t, err)
// try container first
host := os.Getenv("MYSQL_PORT_3306_TCP_ADDR")
if host == "" {
host = "127.0.0.1"
}
host = fmt.Sprintf("%s:3306", host)
ds := gormDB{DB: db, Driver: "sqlite3"}
conn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", user, password, host, dbName)
db, err := New("gorm-mysql", conn, LimitAttempts(1))
err = ds.Migrate()
assert.Nil(t, err)
backend := db.(gormDB)
err = backend.Migrate()
assert.Nil(t, err)
return db
return ds
}
func teardownMySQLGORM(t *testing.T, db kolide.Datastore) {
err := db.Drop()
func teardownGorm(t *testing.T, ds kolide.Datastore) {
err := ds.Drop()
assert.Nil(t, err)
}
func TestEnrollHostMySQLGORM(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
func TestGorm(t *testing.T) {
for _, f := range testFunctions {
t.Run(functionName(f), func(t *testing.T) {
ds := setupGorm(t)
defer teardownGorm(t, ds)
f(t, ds)
})
}
db := setupMySQLGORM(t)
defer teardownMySQLGORM(t, db)
testEnrollHost(t, db)
}
func TestAuthenticateHostMySQLGORM(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
db := setup(t)
defer teardown(t, db)
testAuthenticateHost(t, db)
}
func TestUserByIDMySQLGORM(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
db := setup(t)
defer teardown(t, db)
testUserByID(t, db)
}
// TestCreateUser tests the UserStore interface
// this test uses the MySQL GORM backend
func TestCreateUserMySQLGORM(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
db := setupMySQLGORM(t)
defer teardownMySQLGORM(t, db)
testCreateUser(t, db)
}
func TestSaveUserMySQLGORM(t *testing.T) {
address := os.Getenv("MYSQL_ADDR")
if address == "" {
t.SkipNow()
}
db := setupMySQLGORM(t)
defer teardownMySQLGORM(t, db)
testSaveUser(t, db)
}
func TestSaveQuery(t *testing.T) {
ds := setup(t)
testSaveQuery(t, ds)
}
func TestDeleteQuery(t *testing.T) {
ds := setup(t)
testDeleteQuery(t, ds)
}
func TestDeletePack(t *testing.T) {
ds := setup(t)
testDeletePack(t, ds)
}
func TestAddAndRemoveQueryFromPack(t *testing.T) {
ds := setup(t)
testAddAndRemoveQueryFromPack(t, ds)
}
func TestManagingLabelsOnPacks(t *testing.T) {
ds := setup(t)
testManagingLabelsOnPacks(t, ds)
}

View file

@ -1,8 +1,9 @@
package datastore
import "github.com/kolide/kolide-ose/server/kolide"
import (
"github.com/kolide/kolide-ose/server/kolide"
)
// NewUser creates a new user in the gorm backend
func (orm gormDB) NewUser(user *kolide.User) (*kolide.User, error) {
err := orm.DB.Create(user).Error
if err != nil {
@ -11,7 +12,6 @@ func (orm gormDB) NewUser(user *kolide.User) (*kolide.User, error) {
return user, nil
}
// User returns a specific user in the gorm backend
func (orm gormDB) User(username string) (*kolide.User, error) {
user := &kolide.User{
Username: username,
@ -43,7 +43,6 @@ func (orm gormDB) UserByEmail(email string) (*kolide.User, error) {
return user, nil
}
// UserByID returns a datastore user given a user ID
func (orm gormDB) UserByID(id uint) (*kolide.User, error) {
user := &kolide.User{ID: id}
err := orm.DB.Where(user).First(user).Error

View file

@ -17,6 +17,7 @@ type inmem struct {
labels map[uint]*kolide.Label
labelQueryExecutions map[uint]*kolide.LabelQueryExecution
queries map[uint]*kolide.Query
packs map[uint]*kolide.Pack
hosts map[uint]*kolide.Host
orginfo *kolide.OrgInfo
}
@ -35,6 +36,7 @@ func (orm *inmem) Migrate() error {
orm.labels = make(map[uint]*kolide.Label)
orm.labelQueryExecutions = make(map[uint]*kolide.LabelQueryExecution)
orm.queries = make(map[uint]*kolide.Query)
orm.packs = make(map[uint]*kolide.Pack)
orm.hosts = make(map[uint]*kolide.Host)
return nil
}

View file

@ -0,0 +1,71 @@
package datastore
import (
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm *inmem) NewPack(pack *kolide.Pack) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
newPack := *pack
for _, q := range orm.packs {
if pack.Name == q.Name {
return ErrExists
}
}
newPack.ID = uint(len(orm.packs) + 1)
orm.packs[newPack.ID] = &newPack
return nil
}
func (orm *inmem) SavePack(pack *kolide.Pack) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if _, ok := orm.packs[pack.ID]; !ok {
return ErrNotFound
}
orm.packs[pack.ID] = pack
return nil
}
func (orm *inmem) DeletePack(pid uint) error {
orm.mtx.Lock()
defer orm.mtx.Unlock()
if _, ok := orm.packs[pid]; !ok {
return ErrNotFound
}
delete(orm.packs, pid)
return nil
}
func (orm *inmem) Pack(id uint) (*kolide.Pack, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
pack, ok := orm.packs[id]
if !ok {
return nil, ErrNotFound
}
return pack, nil
}
func (orm *inmem) Packs() ([]*kolide.Pack, error) {
orm.mtx.Lock()
defer orm.mtx.Unlock()
packs := []*kolide.Pack{}
for _, pack := range orm.packs {
packs = append(packs, pack)
}
return packs, nil
}

View file

@ -1,6 +1,8 @@
package datastore
import "github.com/kolide/kolide-ose/server/kolide"
import (
"github.com/kolide/kolide-ose/server/kolide"
)
func (orm *inmem) NewQuery(query *kolide.Query) (*kolide.Query, error) {
orm.mtx.Lock()