mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 08:58:41 +00:00
Datastore cleaning (#262)
This PR reorganizes a bunch of the files in datastore such that all datastore implementations are consistently broken up into multiple files. Additionally, the datastore tests follow a similar pattern and can easily be applied to any complete datastore implementation.
This commit is contained in:
parent
59c194a7f4
commit
12f8c0b671
22 changed files with 1323 additions and 1353 deletions
|
|
@ -67,11 +67,6 @@ the way that the kolide server works.
|
|||
if err != nil {
|
||||
initFatal(err, "initializing datastore")
|
||||
}
|
||||
|
||||
err = ds.Migrate()
|
||||
if err != nil {
|
||||
initFatal(err, "initializing datastore")
|
||||
}
|
||||
} else {
|
||||
connString := datastore.GetMysqlConnectionString(config.Mysql)
|
||||
ds, err = datastore.New("gorm-mysql", connString)
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/config"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
|
|
@ -71,14 +74,38 @@ func New(driver, conn string, opts ...DBOption) (kolide.Datastore, error) {
|
|||
return ds, nil
|
||||
case "inmem":
|
||||
ds := &inmem{
|
||||
Driver: "inmem",
|
||||
users: make(map[uint]*kolide.User),
|
||||
sessions: make(map[uint]*kolide.Session),
|
||||
passwordResets: make(map[uint]*kolide.PasswordResetRequest),
|
||||
invites: make(map[uint]*kolide.Invite),
|
||||
Driver: "inmem",
|
||||
}
|
||||
|
||||
err := ds.Migrate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ds, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported datastore driver %s", driver)
|
||||
}
|
||||
}
|
||||
|
||||
func generateRandomText(keySize int) (string, error) {
|
||||
key := make([]byte, keySize)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// GetMysqlConnectionString returns a MySQL connection string using the
|
||||
// provided configuration.
|
||||
func GetMysqlConnectionString(conf config.MysqlConfig) string {
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local",
|
||||
conf.Username,
|
||||
conf.Password,
|
||||
conf.Address,
|
||||
conf.Database,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,26 +1,12 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestOrgInfo(t *testing.T) {
|
||||
var ds kolide.Datastore
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
ds = setup(t)
|
||||
} else {
|
||||
ds = setupMySQLGORM(t)
|
||||
defer teardownMySQLGORM(t, ds)
|
||||
}
|
||||
|
||||
testOrgInfo(t, ds)
|
||||
}
|
||||
|
||||
func testOrgInfo(t *testing.T, db kolide.Datastore) {
|
||||
info := &kolide.OrgInfo{
|
||||
OrgName: "Kolide",
|
||||
83
server/datastore/datastore_hosts_test.go
Normal file
83
server/datastore/datastore_hosts_test.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var enrollTests = []struct {
|
||||
uuid, hostname, ip, platform string
|
||||
nodeKeySize int
|
||||
}{
|
||||
0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26",
|
||||
hostname: "web.kolide.co",
|
||||
ip: "172.0.0.1",
|
||||
platform: "linux",
|
||||
nodeKeySize: 12,
|
||||
},
|
||||
1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B",
|
||||
hostname: "mail.kolide.co",
|
||||
ip: "172.0.0.2",
|
||||
platform: "linux",
|
||||
nodeKeySize: 10,
|
||||
},
|
||||
2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2",
|
||||
hostname: "home.kolide.co",
|
||||
ip: "127.0.0.1",
|
||||
platform: "darwin",
|
||||
nodeKeySize: 25,
|
||||
},
|
||||
3: {uuid: "uuid123",
|
||||
hostname: "fakehostname",
|
||||
ip: "192.168.1.1",
|
||||
platform: "darwin",
|
||||
nodeKeySize: 1,
|
||||
},
|
||||
}
|
||||
|
||||
func testEnrollHost(t *testing.T, db kolide.Datastore) {
|
||||
var hosts []*kolide.Host
|
||||
for _, tt := range enrollTests {
|
||||
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
|
||||
assert.Nil(t, err)
|
||||
|
||||
hosts = append(hosts, h)
|
||||
assert.Equal(t, tt.uuid, h.UUID)
|
||||
assert.Equal(t, tt.hostname, h.HostName)
|
||||
assert.Equal(t, tt.ip, h.IPAddress)
|
||||
assert.Equal(t, tt.platform, h.Platform)
|
||||
assert.NotEmpty(t, h.NodeKey)
|
||||
}
|
||||
|
||||
for _, enrolled := range hosts {
|
||||
oldNodeKey := enrolled.NodeKey
|
||||
newhostname := fmt.Sprintf("changed.%s", enrolled.HostName)
|
||||
|
||||
h, err := db.EnrollHost(enrolled.UUID, newhostname, enrolled.IPAddress, enrolled.Platform, 15)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, enrolled.UUID, h.UUID)
|
||||
assert.NotEmpty(t, h.NodeKey)
|
||||
assert.NotEqual(t, oldNodeKey, h.NodeKey)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func testAuthenticateHost(t *testing.T, db kolide.Datastore) {
|
||||
for _, tt := range enrollTests {
|
||||
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
|
||||
assert.Nil(t, err)
|
||||
|
||||
returned, err := db.AuthenticateHost(h.NodeKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, h.NodeKey, returned.NodeKey)
|
||||
}
|
||||
|
||||
_, err := db.AuthenticateHost("7B1A9DC9-B042-489F-8D5A-EEC2412C95AA")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
_, err = db.AuthenticateHost("")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
|
@ -1,23 +1,13 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateInvite(t *testing.T) {
|
||||
var ds kolide.Datastore
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
ds = setup(t)
|
||||
} else {
|
||||
ds = setupMySQLGORM(t)
|
||||
defer teardownMySQLGORM(t, ds)
|
||||
}
|
||||
|
||||
func testCreateInvite(t *testing.T, ds kolide.Datastore) {
|
||||
invite := &kolide.Invite{}
|
||||
|
||||
invite, err := ds.NewInvite(invite)
|
||||
221
server/datastore/datastore_labels_test.go
Normal file
221
server/datastore/datastore_labels_test.go
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testLabels(t *testing.T, db kolide.Datastore) {
|
||||
hosts := []kolide.Host{}
|
||||
var host *kolide.Host
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
host, err = db.EnrollHost(string(i), "foo", "", "", 10)
|
||||
assert.Nil(t, err, "enrollment should succeed")
|
||||
hosts = append(hosts, *host)
|
||||
}
|
||||
|
||||
baseTime := time.Now()
|
||||
|
||||
// No queries should be returned before labels or queries added
|
||||
queries, err := db.LabelQueriesForHost(host, baseTime)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, queries)
|
||||
|
||||
// No labels should match
|
||||
labels, err := db.LabelsForHost(host.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, labels)
|
||||
|
||||
labelQueries := []kolide.Query{
|
||||
kolide.Query{
|
||||
Name: "query1",
|
||||
Query: "query1",
|
||||
Platform: "darwin",
|
||||
},
|
||||
kolide.Query{
|
||||
Name: "query2",
|
||||
Query: "query2",
|
||||
Platform: "darwin",
|
||||
},
|
||||
kolide.Query{
|
||||
Name: "query3",
|
||||
Query: "query3",
|
||||
Platform: "darwin",
|
||||
},
|
||||
kolide.Query{
|
||||
Name: "query4",
|
||||
Query: "query4",
|
||||
Platform: "darwin",
|
||||
},
|
||||
}
|
||||
|
||||
for _, query := range labelQueries {
|
||||
newQuery, err := db.NewQuery(&query)
|
||||
assert.Nil(t, err)
|
||||
assert.NotZero(t, newQuery.ID)
|
||||
}
|
||||
|
||||
// this one should not show up
|
||||
_, err = db.NewQuery(&kolide.Query{
|
||||
Platform: "not_darwin",
|
||||
Query: "query5",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// No queries should be returned before labels added
|
||||
queries, err = db.LabelQueriesForHost(host, baseTime)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, queries)
|
||||
|
||||
newLabels := []kolide.Label{
|
||||
// Note these are intentionally out of order
|
||||
kolide.Label{
|
||||
Name: "label3",
|
||||
QueryID: 3,
|
||||
},
|
||||
kolide.Label{
|
||||
Name: "label1",
|
||||
QueryID: 1,
|
||||
},
|
||||
kolide.Label{
|
||||
Name: "label2",
|
||||
QueryID: 2,
|
||||
},
|
||||
kolide.Label{
|
||||
Name: "label4",
|
||||
QueryID: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, label := range newLabels {
|
||||
newLabel, err := db.NewLabel(&label)
|
||||
assert.Nil(t, err)
|
||||
assert.NotZero(t, newLabel.ID)
|
||||
}
|
||||
|
||||
expectQueries := map[string]string{
|
||||
"1": "query3",
|
||||
"2": "query1",
|
||||
"3": "query2",
|
||||
"4": "query4",
|
||||
}
|
||||
|
||||
host.Platform = "darwin"
|
||||
|
||||
// Now queries should be returned
|
||||
queries, err = db.LabelQueriesForHost(host, baseTime)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// No labels should match with no results yet
|
||||
labels, err = db.LabelsForHost(host.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, labels)
|
||||
|
||||
// Record a query execution
|
||||
err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Use a 10 minute interval, so the query we just added should show up
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
|
||||
assert.Nil(t, err)
|
||||
delete(expectQueries, "1")
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// Record an old query execution -- Shouldn't change the return
|
||||
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour))
|
||||
assert.Nil(t, err)
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// Record a newer execution for that query and another
|
||||
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Now these should no longer show up in the necessary to run queries
|
||||
delete(expectQueries, "2")
|
||||
delete(expectQueries, "3")
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// Now the two matching labels should be returned
|
||||
labels, err = db.LabelsForHost(host.ID)
|
||||
assert.Nil(t, err)
|
||||
if assert.Len(t, labels, 2) {
|
||||
labelNames := []string{labels[0].Name, labels[1].Name}
|
||||
sort.Strings(labelNames)
|
||||
assert.Equal(t, "label2", labelNames[0])
|
||||
assert.Equal(t, "label3", labelNames[1])
|
||||
}
|
||||
|
||||
// A host that hasn't executed any label queries should still be asked
|
||||
// to execute those queries
|
||||
hosts[0].Platform = "darwin"
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 4)
|
||||
|
||||
// There should still be no labels returned for a host that never
|
||||
// executed any label queries
|
||||
labels, err = db.LabelsForHost(hosts[0].ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, labels)
|
||||
}
|
||||
|
||||
func testManagingLabelsOnPacks(t *testing.T, ds kolide.Datastore) {
|
||||
mysqlQuery := &kolide.Query{
|
||||
Name: "MySQL",
|
||||
Query: "select pid from processes where name = 'mysqld';",
|
||||
}
|
||||
mysqlQuery, err := ds.NewQuery(mysqlQuery)
|
||||
assert.Nil(t, err)
|
||||
|
||||
osqueryRunningQuery := &kolide.Query{
|
||||
Name: "Is osquery currently running?",
|
||||
Query: "select pid from processes where name = 'osqueryd';",
|
||||
}
|
||||
osqueryRunningQuery, err = ds.NewQuery(osqueryRunningQuery)
|
||||
assert.Nil(t, err)
|
||||
|
||||
monitoringPack := &kolide.Pack{
|
||||
Name: "monitoring",
|
||||
}
|
||||
err = ds.NewPack(monitoringPack)
|
||||
assert.Nil(t, err)
|
||||
|
||||
mysqlLabel := &kolide.Label{
|
||||
Name: "MySQL Monitoring",
|
||||
QueryID: mysqlQuery.ID,
|
||||
}
|
||||
mysqlLabel, err = ds.NewLabel(mysqlLabel)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = ds.AddLabelToPack(mysqlLabel.ID, monitoringPack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
labels, err := ds.GetLabelsForPack(monitoringPack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, labels, 1)
|
||||
assert.Equal(t, "MySQL Monitoring", labels[0].Name)
|
||||
|
||||
osqueryLabel := &kolide.Label{
|
||||
Name: "Osquery Monitoring",
|
||||
QueryID: osqueryRunningQuery.ID,
|
||||
}
|
||||
osqueryLabel, err = ds.NewLabel(osqueryLabel)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = ds.AddLabelToPack(osqueryLabel.ID, monitoringPack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
labels, err = ds.GetLabelsForPack(monitoringPack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, labels, 2)
|
||||
}
|
||||
65
server/datastore/datastore_packs_test.go
Normal file
65
server/datastore/datastore_packs_test.go
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testDeletePack(t *testing.T, ds kolide.Datastore) {
|
||||
pack := &kolide.Pack{
|
||||
Name: "foo",
|
||||
}
|
||||
err := ds.NewPack(pack)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, pack.ID, 0)
|
||||
|
||||
pack, err = ds.Pack(pack.ID)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = ds.DeletePack(pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.NotEqual(t, pack.ID, 0)
|
||||
pack, err = ds.Pack(pack.ID)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) {
|
||||
pack := &kolide.Pack{
|
||||
Name: "foo",
|
||||
}
|
||||
err := ds.NewPack(pack)
|
||||
assert.Nil(t, err)
|
||||
|
||||
q1 := &kolide.Query{
|
||||
Name: "bar",
|
||||
Query: "bar",
|
||||
}
|
||||
_, err = ds.NewQuery(q1)
|
||||
assert.Nil(t, err)
|
||||
err = ds.AddQueryToPack(q1.ID, pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
q2 := &kolide.Query{
|
||||
Name: "baz",
|
||||
Query: "baz",
|
||||
}
|
||||
_, err = ds.NewQuery(q2)
|
||||
assert.Nil(t, err)
|
||||
err = ds.AddQueryToPack(q2.ID, pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
queries, err := ds.GetQueriesInPack(pack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 2)
|
||||
|
||||
err = ds.RemoveQueryFromPack(q1, pack)
|
||||
assert.Nil(t, err)
|
||||
|
||||
queries, err = ds.GetQueriesInPack(pack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 1)
|
||||
}
|
||||
33
server/datastore/datastore_password_reset_test.go
Normal file
33
server/datastore/datastore_password_reset_test.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testPasswordResetRequests(t *testing.T, db kolide.Datastore) {
|
||||
createTestUsers(t, db)
|
||||
now := time.Now()
|
||||
tomorrow := now.Add(time.Hour * 24)
|
||||
var passwordResetTests = []struct {
|
||||
userID uint
|
||||
expires time.Time
|
||||
token string
|
||||
}{
|
||||
{userID: 1, expires: tomorrow, token: "abcd"},
|
||||
}
|
||||
|
||||
for _, tt := range passwordResetTests {
|
||||
r := &kolide.PasswordResetRequest{
|
||||
UserID: tt.userID,
|
||||
ExpiresAt: tt.expires,
|
||||
Token: tt.token,
|
||||
}
|
||||
req, err := db.NewPasswordResetRequest(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.userID, req.UserID)
|
||||
}
|
||||
}
|
||||
44
server/datastore/datastore_queries_test.go
Normal file
44
server/datastore/datastore_queries_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testDeleteQuery(t *testing.T, ds kolide.Datastore) {
|
||||
query := &kolide.Query{
|
||||
Name: "foo",
|
||||
Query: "bar",
|
||||
}
|
||||
query, err := ds.NewQuery(query)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, query.ID, 0)
|
||||
|
||||
err = ds.DeleteQuery(query)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.NotEqual(t, query.ID, 0)
|
||||
_, err = ds.Query(query.ID)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func testSaveQuery(t *testing.T, ds kolide.Datastore) {
|
||||
query := &kolide.Query{
|
||||
Name: "foo",
|
||||
Query: "bar",
|
||||
}
|
||||
query, err := ds.NewQuery(query)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, 0, query.ID)
|
||||
|
||||
query.Query = "baz"
|
||||
err = ds.SaveQuery(query)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
queryVerify, err := ds.Query(query.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "baz", queryVerify.Query)
|
||||
}
|
||||
|
|
@ -1,626 +1,35 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const bcryptCost = 6
|
||||
|
||||
func TestPasswordResetRequests(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testPasswordResetRequests(t, db)
|
||||
}
|
||||
|
||||
func testPasswordResetRequests(t *testing.T, db kolide.Datastore) {
|
||||
createTestUsers(t, db)
|
||||
now := time.Now()
|
||||
tomorrow := now.Add(time.Hour * 24)
|
||||
var passwordResetTests = []struct {
|
||||
userID uint
|
||||
expires time.Time
|
||||
token string
|
||||
}{
|
||||
{userID: 1, expires: tomorrow, token: "abcd"},
|
||||
}
|
||||
|
||||
for _, tt := range passwordResetTests {
|
||||
r := &kolide.PasswordResetRequest{
|
||||
UserID: tt.userID,
|
||||
ExpiresAt: tt.expires,
|
||||
Token: tt.token,
|
||||
}
|
||||
req, err := db.NewPasswordResetRequest(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.userID, req.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnrollHost(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testEnrollHost(t, db)
|
||||
func functionName(f func(*testing.T, kolide.Datastore)) string {
|
||||
fullName := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
|
||||
elements := strings.Split(fullName, ".")
|
||||
return elements[len(elements)-1]
|
||||
|
||||
}
|
||||
|
||||
func TestAuthenticateHost(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testAuthenticateHost(t, db)
|
||||
|
||||
}
|
||||
|
||||
var enrollTests = []struct {
|
||||
uuid, hostname, ip, platform string
|
||||
nodeKeySize int
|
||||
}{
|
||||
0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26",
|
||||
hostname: "web.kolide.co",
|
||||
ip: "172.0.0.1",
|
||||
platform: "linux",
|
||||
nodeKeySize: 12,
|
||||
},
|
||||
1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B",
|
||||
hostname: "mail.kolide.co",
|
||||
ip: "172.0.0.2",
|
||||
platform: "linux",
|
||||
nodeKeySize: 10,
|
||||
},
|
||||
2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2",
|
||||
hostname: "home.kolide.co",
|
||||
ip: "127.0.0.1",
|
||||
platform: "darwin",
|
||||
nodeKeySize: 25,
|
||||
},
|
||||
3: {uuid: "uuid123",
|
||||
hostname: "fakehostname",
|
||||
ip: "192.168.1.1",
|
||||
platform: "darwin",
|
||||
nodeKeySize: 1,
|
||||
},
|
||||
}
|
||||
|
||||
func testEnrollHost(t *testing.T, db kolide.HostStore) {
|
||||
var hosts []*kolide.Host
|
||||
for _, tt := range enrollTests {
|
||||
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
|
||||
assert.Nil(t, err)
|
||||
|
||||
hosts = append(hosts, h)
|
||||
assert.Equal(t, tt.uuid, h.UUID)
|
||||
assert.Equal(t, tt.hostname, h.HostName)
|
||||
assert.Equal(t, tt.ip, h.IPAddress)
|
||||
assert.Equal(t, tt.platform, h.Platform)
|
||||
assert.NotEmpty(t, h.NodeKey)
|
||||
}
|
||||
|
||||
for _, enrolled := range hosts {
|
||||
oldNodeKey := enrolled.NodeKey
|
||||
newhostname := fmt.Sprintf("changed.%s", enrolled.HostName)
|
||||
|
||||
h, err := db.EnrollHost(enrolled.UUID, newhostname, enrolled.IPAddress, enrolled.Platform, 15)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, enrolled.UUID, h.UUID)
|
||||
assert.NotEmpty(t, h.NodeKey)
|
||||
assert.NotEqual(t, oldNodeKey, h.NodeKey)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func testAuthenticateHost(t *testing.T, db kolide.HostStore) {
|
||||
for _, tt := range enrollTests {
|
||||
h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize)
|
||||
assert.Nil(t, err)
|
||||
|
||||
returned, err := db.AuthenticateHost(h.NodeKey)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, h.NodeKey, returned.NodeKey)
|
||||
}
|
||||
|
||||
_, err := db.AuthenticateHost("7B1A9DC9-B042-489F-8D5A-EEC2412C95AA")
|
||||
assert.NotNil(t, err)
|
||||
|
||||
_, err = db.AuthenticateHost("")
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
// TestUser tests the UserStore interface
|
||||
// this test uses the default testing backend
|
||||
func TestCreateUser(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testCreateUser(t, db)
|
||||
}
|
||||
|
||||
func TestSaveUser(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testSaveUser(t, db)
|
||||
}
|
||||
|
||||
func testCreateUser(t *testing.T, db kolide.UserStore) {
|
||||
var createTests = []struct {
|
||||
username, password, email string
|
||||
isAdmin, passwordReset bool
|
||||
}{
|
||||
{"marpaia", "foobar", "mike@kolide.co", true, false},
|
||||
{"jason", "foobar", "jason@kolide.co", true, false},
|
||||
}
|
||||
|
||||
for _, tt := range createTests {
|
||||
u := &kolide.User{
|
||||
Username: tt.username,
|
||||
Password: []byte(tt.password),
|
||||
Admin: tt.isAdmin,
|
||||
AdminForcedPasswordReset: tt.passwordReset,
|
||||
Email: tt.email,
|
||||
}
|
||||
user, err := db.NewUser(u)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := db.User(tt.username)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, user.ID, verify.ID)
|
||||
assert.Equal(t, tt.username, verify.Username)
|
||||
assert.Equal(t, tt.email, verify.Email)
|
||||
assert.Equal(t, tt.email, verify.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserByID(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testUserByID(t, db)
|
||||
}
|
||||
|
||||
func testUserByID(t *testing.T, db kolide.UserStore) {
|
||||
users := createTestUsers(t, db)
|
||||
for _, tt := range users {
|
||||
returned, err := db.UserByID(tt.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.ID, returned.ID)
|
||||
}
|
||||
|
||||
// test missing user
|
||||
_, err := db.UserByID(10000000000)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func createTestUsers(t *testing.T, db kolide.UserStore) []*kolide.User {
|
||||
var createTests = []struct {
|
||||
username, password, email string
|
||||
isAdmin, passwordReset bool
|
||||
}{
|
||||
{"marpaia", "foobar", "mike@kolide.co", true, false},
|
||||
{"jason", "foobar", "jason@kolide.co", false, false},
|
||||
}
|
||||
|
||||
var users []*kolide.User
|
||||
for _, tt := range createTests {
|
||||
u := &kolide.User{
|
||||
Username: tt.username,
|
||||
Password: []byte(tt.password),
|
||||
Admin: tt.isAdmin,
|
||||
AdminForcedPasswordReset: tt.passwordReset,
|
||||
Email: tt.email,
|
||||
}
|
||||
|
||||
user, err := db.NewUser(u)
|
||||
assert.Nil(t, err)
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
assert.NotEmpty(t, users)
|
||||
return users
|
||||
}
|
||||
|
||||
func testSaveUser(t *testing.T, db kolide.UserStore) {
|
||||
users := createTestUsers(t, db)
|
||||
testAdminAttribute(t, db, users)
|
||||
testEmailAttribute(t, db, users)
|
||||
testPasswordAttribute(t, db, users)
|
||||
}
|
||||
|
||||
func testPasswordAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) {
|
||||
for _, user := range users {
|
||||
user.Password = []byte(randomString(8))
|
||||
err := db.SaveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := db.User(user.Username)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.Password, verify.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func testEmailAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) {
|
||||
for _, user := range users {
|
||||
user.Email = fmt.Sprintf("test.%s", user.Email)
|
||||
err := db.SaveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := db.User(user.Username)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.Email, verify.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func testAdminAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) {
|
||||
for _, user := range users {
|
||||
user.Admin = false
|
||||
err := db.SaveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := db.User(user.Username)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.Admin, verify.Admin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLabelQueries(t *testing.T) {
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testLabels(t, db)
|
||||
}
|
||||
|
||||
func testLabels(t *testing.T, db kolide.Datastore) {
|
||||
hosts := []kolide.Host{}
|
||||
var host *kolide.Host
|
||||
var err error
|
||||
for i := 0; i < 10; i++ {
|
||||
host, err = db.EnrollHost(string(i), "foo", "", "", 10)
|
||||
assert.Nil(t, err, "enrollment should succeed")
|
||||
hosts = append(hosts, *host)
|
||||
}
|
||||
|
||||
baseTime := time.Now()
|
||||
|
||||
// No queries should be returned before labels or queries added
|
||||
queries, err := db.LabelQueriesForHost(host, baseTime)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, queries)
|
||||
|
||||
// No labels should match
|
||||
labels, err := db.LabelsForHost(host.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, labels)
|
||||
|
||||
labelQueries := []kolide.Query{
|
||||
kolide.Query{
|
||||
Name: "query1",
|
||||
Query: "query1",
|
||||
Platform: "darwin",
|
||||
},
|
||||
kolide.Query{
|
||||
Name: "query2",
|
||||
Query: "query2",
|
||||
Platform: "darwin",
|
||||
},
|
||||
kolide.Query{
|
||||
Name: "query3",
|
||||
Query: "query3",
|
||||
Platform: "darwin",
|
||||
},
|
||||
kolide.Query{
|
||||
Name: "query4",
|
||||
Query: "query4",
|
||||
Platform: "darwin",
|
||||
},
|
||||
}
|
||||
|
||||
for _, query := range labelQueries {
|
||||
newQuery, err := db.NewQuery(&query)
|
||||
assert.Nil(t, err)
|
||||
assert.NotZero(t, newQuery.ID)
|
||||
}
|
||||
|
||||
// this one should not show up
|
||||
_, err = db.NewQuery(&kolide.Query{
|
||||
Platform: "not_darwin",
|
||||
Query: "query5",
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// No queries should be returned before labels added
|
||||
queries, err = db.LabelQueriesForHost(host, baseTime)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, queries)
|
||||
|
||||
newLabels := []kolide.Label{
|
||||
// Note these are intentionally out of order
|
||||
kolide.Label{
|
||||
Name: "label3",
|
||||
QueryID: 3,
|
||||
},
|
||||
kolide.Label{
|
||||
Name: "label1",
|
||||
QueryID: 1,
|
||||
},
|
||||
kolide.Label{
|
||||
Name: "label2",
|
||||
QueryID: 2,
|
||||
},
|
||||
kolide.Label{
|
||||
Name: "label4",
|
||||
QueryID: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, label := range newLabels {
|
||||
newLabel, err := db.NewLabel(&label)
|
||||
assert.Nil(t, err)
|
||||
assert.NotZero(t, newLabel.ID)
|
||||
}
|
||||
|
||||
expectQueries := map[string]string{
|
||||
"1": "query3",
|
||||
"2": "query1",
|
||||
"3": "query2",
|
||||
"4": "query4",
|
||||
}
|
||||
|
||||
host.Platform = "darwin"
|
||||
|
||||
// Now queries should be returned
|
||||
queries, err = db.LabelQueriesForHost(host, baseTime)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// No labels should match with no results yet
|
||||
labels, err = db.LabelsForHost(host.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, labels)
|
||||
|
||||
// Record a query execution
|
||||
err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Use a 10 minute interval, so the query we just added should show up
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
|
||||
assert.Nil(t, err)
|
||||
delete(expectQueries, "1")
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// Record an old query execution -- Shouldn't change the return
|
||||
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour))
|
||||
assert.Nil(t, err)
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// Record a newer execution for that query and another
|
||||
err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Now these should no longer show up in the necessary to run queries
|
||||
delete(expectQueries, "2")
|
||||
delete(expectQueries, "3")
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute)))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectQueries, queries)
|
||||
|
||||
// Now the two matching labels should be returned
|
||||
labels, err = db.LabelsForHost(host.ID)
|
||||
assert.Nil(t, err)
|
||||
if assert.Len(t, labels, 2) {
|
||||
labelNames := []string{labels[0].Name, labels[1].Name}
|
||||
sort.Strings(labelNames)
|
||||
assert.Equal(t, "label2", labelNames[0])
|
||||
assert.Equal(t, "label3", labelNames[1])
|
||||
}
|
||||
|
||||
// A host that hasn't executed any label queries should still be asked
|
||||
// to execute those queries
|
||||
hosts[0].Platform = "darwin"
|
||||
queries, err = db.LabelQueriesForHost(host, time.Now())
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 4)
|
||||
|
||||
// There should still be no labels returned for a host that never
|
||||
// executed any label queries
|
||||
labels, err = db.LabelsForHost(hosts[0].ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, labels)
|
||||
}
|
||||
|
||||
// setup creates a datastore for testing
|
||||
func setup(t *testing.T) kolide.Datastore {
|
||||
db, err := gorm.Open("sqlite3", ":memory:")
|
||||
require.Nil(t, err)
|
||||
|
||||
ds := gormDB{DB: db, Driver: "sqlite3"}
|
||||
|
||||
err = ds.Migrate()
|
||||
assert.Nil(t, err)
|
||||
// Log using t.Log so that output only shows up if the test fails
|
||||
//db.SetLogger(&testLogger{t: t})
|
||||
//db.LogMode(true)
|
||||
return ds
|
||||
}
|
||||
|
||||
func teardown(t *testing.T, ds kolide.Datastore) {
|
||||
err := ds.Drop()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
type testLogger struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (t *testLogger) Print(v ...interface{}) {
|
||||
t.t.Log(v...)
|
||||
}
|
||||
|
||||
func (t *testLogger) Write(p []byte) (n int, err error) {
|
||||
t.t.Log(string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func randomString(strlen int) string {
|
||||
rand.Seed(time.Now().UTC().UnixNano())
|
||||
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
result := make([]byte, strlen)
|
||||
for i := 0; i < strlen; i++ {
|
||||
result[i] = chars[rand.Intn(len(chars))]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func testSaveQuery(t *testing.T, ds kolide.Datastore) {
|
||||
query := &kolide.Query{
|
||||
Name: "foo",
|
||||
Query: "bar",
|
||||
}
|
||||
query, err := ds.NewQuery(query)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, 0, query.ID)
|
||||
|
||||
query.Query = "baz"
|
||||
err = ds.SaveQuery(query)
|
||||
assert.Nil(t, err)
|
||||
|
||||
queryVerify, err := ds.Query(query.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "baz", queryVerify.Query)
|
||||
}
|
||||
|
||||
func testDeleteQuery(t *testing.T, ds kolide.Datastore) {
|
||||
query := &kolide.Query{
|
||||
Name: "foo",
|
||||
Query: "bar",
|
||||
}
|
||||
query, err := ds.NewQuery(query)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, query.ID, 0)
|
||||
|
||||
err = ds.DeleteQuery(query)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.NotEqual(t, query.ID, 0)
|
||||
_, err = ds.Query(query.ID)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func testDeletePack(t *testing.T, ds kolide.Datastore) {
|
||||
pack := &kolide.Pack{
|
||||
Name: "foo",
|
||||
}
|
||||
err := ds.NewPack(pack)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEqual(t, pack.ID, 0)
|
||||
|
||||
pack, err = ds.Pack(pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = ds.DeletePack(pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.NotEqual(t, pack.ID, 0)
|
||||
pack, err = ds.Pack(pack.ID)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) {
|
||||
pack := &kolide.Pack{
|
||||
Name: "foo",
|
||||
}
|
||||
err := ds.NewPack(pack)
|
||||
assert.Nil(t, err)
|
||||
|
||||
q1 := &kolide.Query{
|
||||
Name: "bar",
|
||||
Query: "bar",
|
||||
}
|
||||
_, err = ds.NewQuery(q1)
|
||||
assert.Nil(t, err)
|
||||
err = ds.AddQueryToPack(q1.ID, pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
q2 := &kolide.Query{
|
||||
Name: "baz",
|
||||
Query: "baz",
|
||||
}
|
||||
_, err = ds.NewQuery(q2)
|
||||
assert.Nil(t, err)
|
||||
err = ds.AddQueryToPack(q2.ID, pack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
queries, err := ds.GetQueriesInPack(pack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 2)
|
||||
|
||||
err = ds.RemoveQueryFromPack(q1, pack)
|
||||
assert.Nil(t, err)
|
||||
|
||||
queries, err = ds.GetQueriesInPack(pack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, queries, 1)
|
||||
}
|
||||
|
||||
func testManagingLabelsOnPacks(t *testing.T, ds kolide.Datastore) {
|
||||
mysqlQuery := &kolide.Query{
|
||||
Name: "MySQL",
|
||||
Query: "select pid from processes where name = 'mysqld';",
|
||||
}
|
||||
mysqlQuery, err := ds.NewQuery(mysqlQuery)
|
||||
assert.Nil(t, err)
|
||||
|
||||
osqueryRunningQuery := &kolide.Query{
|
||||
Name: "Is osquery currently running?",
|
||||
Query: "select pid from processes where name = 'osqueryd';",
|
||||
}
|
||||
osqueryRunningQuery, err = ds.NewQuery(osqueryRunningQuery)
|
||||
assert.Nil(t, err)
|
||||
|
||||
monitoringPack := &kolide.Pack{
|
||||
Name: "monitoring",
|
||||
}
|
||||
err = ds.NewPack(monitoringPack)
|
||||
assert.Nil(t, err)
|
||||
|
||||
mysqlLabel := &kolide.Label{
|
||||
Name: "MySQL Monitoring",
|
||||
QueryID: mysqlQuery.ID,
|
||||
}
|
||||
mysqlLabel, err = ds.NewLabel(mysqlLabel)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = ds.AddLabelToPack(mysqlLabel.ID, monitoringPack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
labels, err := ds.GetLabelsForPack(monitoringPack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, labels, 1)
|
||||
assert.Equal(t, "MySQL Monitoring", labels[0].Name)
|
||||
|
||||
osqueryLabel := &kolide.Label{
|
||||
Name: "Osquery Monitoring",
|
||||
QueryID: osqueryRunningQuery.ID,
|
||||
}
|
||||
osqueryLabel, err = ds.NewLabel(osqueryLabel)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = ds.AddLabelToPack(osqueryLabel.ID, monitoringPack.ID)
|
||||
assert.Nil(t, err)
|
||||
|
||||
labels, err = ds.GetLabelsForPack(monitoringPack)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, labels, 2)
|
||||
var testFunctions = [...]func(*testing.T, kolide.Datastore){
|
||||
testOrgInfo,
|
||||
testCreateInvite,
|
||||
testDeleteQuery,
|
||||
testSaveQuery,
|
||||
testDeletePack,
|
||||
testAddAndRemoveQueryFromPack,
|
||||
testEnrollHost,
|
||||
testAuthenticateHost,
|
||||
testLabels,
|
||||
testManagingLabelsOnPacks,
|
||||
testPasswordResetRequests,
|
||||
testCreateUser,
|
||||
testSaveUser,
|
||||
testUserByID,
|
||||
testPasswordResetRequests,
|
||||
}
|
||||
|
|
|
|||
125
server/datastore/datastore_users_test.go
Normal file
125
server/datastore/datastore_users_test.go
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testCreateUser(t *testing.T, ds kolide.Datastore) {
|
||||
var createTests = []struct {
|
||||
username, password, email string
|
||||
isAdmin, passwordReset bool
|
||||
}{
|
||||
{"marpaia", "foobar", "mike@kolide.co", true, false},
|
||||
{"jason", "foobar", "jason@kolide.co", true, false},
|
||||
}
|
||||
|
||||
for _, tt := range createTests {
|
||||
u := &kolide.User{
|
||||
Username: tt.username,
|
||||
Password: []byte(tt.password),
|
||||
Admin: tt.isAdmin,
|
||||
AdminForcedPasswordReset: tt.passwordReset,
|
||||
Email: tt.email,
|
||||
}
|
||||
user, err := ds.NewUser(u)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := ds.User(tt.username)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, user.ID, verify.ID)
|
||||
assert.Equal(t, tt.username, verify.Username)
|
||||
assert.Equal(t, tt.email, verify.Email)
|
||||
assert.Equal(t, tt.email, verify.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func testUserByID(t *testing.T, ds kolide.Datastore) {
|
||||
users := createTestUsers(t, ds)
|
||||
for _, tt := range users {
|
||||
returned, err := ds.UserByID(tt.ID)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, tt.ID, returned.ID)
|
||||
}
|
||||
|
||||
// test missing user
|
||||
_, err := ds.UserByID(10000000000)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func createTestUsers(t *testing.T, ds kolide.Datastore) []*kolide.User {
|
||||
var createTests = []struct {
|
||||
username, password, email string
|
||||
isAdmin, passwordReset bool
|
||||
}{
|
||||
{"marpaia", "foobar", "mike@kolide.co", true, false},
|
||||
{"jason", "foobar", "jason@kolide.co", false, false},
|
||||
}
|
||||
|
||||
var users []*kolide.User
|
||||
for _, tt := range createTests {
|
||||
u := &kolide.User{
|
||||
Username: tt.username,
|
||||
Password: []byte(tt.password),
|
||||
Admin: tt.isAdmin,
|
||||
AdminForcedPasswordReset: tt.passwordReset,
|
||||
Email: tt.email,
|
||||
}
|
||||
|
||||
user, err := ds.NewUser(u)
|
||||
assert.Nil(t, err)
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
assert.NotEmpty(t, users)
|
||||
return users
|
||||
}
|
||||
|
||||
func testSaveUser(t *testing.T, ds kolide.Datastore) {
|
||||
users := createTestUsers(t, ds)
|
||||
testAdminAttribute(t, ds, users)
|
||||
testEmailAttribute(t, ds, users)
|
||||
testPasswordAttribute(t, ds, users)
|
||||
}
|
||||
|
||||
func testPasswordAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) {
|
||||
for _, user := range users {
|
||||
randomText, err := generateRandomText(8)
|
||||
assert.Nil(t, err)
|
||||
user.Password = []byte(randomText)
|
||||
err = ds.SaveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := ds.User(user.Username)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.Password, verify.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func testEmailAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) {
|
||||
for _, user := range users {
|
||||
user.Email = fmt.Sprintf("test.%s", user.Email)
|
||||
err := ds.SaveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := ds.User(user.Username)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.Email, verify.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func testAdminAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) {
|
||||
for _, user := range users {
|
||||
user.Admin = false
|
||||
err := ds.SaveUser(user)
|
||||
assert.Nil(t, err)
|
||||
|
||||
verify, err := ds.User(user.Username)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, user.Admin, verify.Admin)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,7 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // db driver
|
||||
|
|
@ -14,7 +9,6 @@ import (
|
|||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/server/config"
|
||||
"github.com/kolide/kolide-ose/server/errors"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
|
|
@ -43,18 +37,6 @@ type gormDB struct {
|
|||
config config.KolideConfig
|
||||
}
|
||||
|
||||
// GetMysqlConnectionString returns a MySQL connection string using the
|
||||
// provided configuration.
|
||||
func GetMysqlConnectionString(conf config.MysqlConfig) string {
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local",
|
||||
conf.Username,
|
||||
conf.Password,
|
||||
conf.Address,
|
||||
conf.Database,
|
||||
)
|
||||
}
|
||||
|
||||
func (orm gormDB) Name() string {
|
||||
return "gorm"
|
||||
}
|
||||
|
|
@ -102,579 +84,3 @@ func openGORM(driver, conn string, maxAttempts int) (*gorm.DB, error) {
|
|||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func generateRandomText(keySize int) (string, error) {
|
||||
key := make([]byte, keySize)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func (orm gormDB) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) {
|
||||
if uuid == "" {
|
||||
return nil, errors.New("missing uuid for host enrollment", "programmer error?")
|
||||
}
|
||||
host := kolide.Host{UUID: uuid}
|
||||
err := orm.DB.Where(&host).First(&host).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
// Create new Host
|
||||
host = kolide.Host{
|
||||
UUID: uuid,
|
||||
HostName: hostname,
|
||||
IPAddress: ip,
|
||||
Platform: platform,
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new key each enrollment
|
||||
host.NodeKey, err = generateRandomText(nodeKeySize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update these fields if provided
|
||||
if hostname != "" {
|
||||
host.HostName = hostname
|
||||
}
|
||||
if ip != "" {
|
||||
host.IPAddress = ip
|
||||
}
|
||||
if platform != "" {
|
||||
host.Platform = platform
|
||||
}
|
||||
|
||||
if err := orm.DB.Save(&host).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) AuthenticateHost(nodeKey string) (*kolide.Host, error) {
|
||||
host := kolide.Host{NodeKey: nodeKey}
|
||||
err := orm.DB.Where("node_key = ?", host.NodeKey).First(&host).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
e := errors.NewFromError(
|
||||
err,
|
||||
http.StatusUnauthorized,
|
||||
"invalid node key",
|
||||
)
|
||||
// osqueryd expects the literal string "true" here
|
||||
e.Extra = map[string]interface{}{"node_invalid": "true"}
|
||||
return nil, e
|
||||
default:
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) SaveHost(host *kolide.Host) error {
|
||||
if err := orm.DB.Save(host).Error; err != nil {
|
||||
return errors.DatabaseError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm gormDB) DeleteHost(host *kolide.Host) error {
|
||||
return orm.DB.Delete(host).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Host(id uint) (*kolide.Host, error) {
|
||||
host := &kolide.Host{
|
||||
ID: id,
|
||||
}
|
||||
err := orm.DB.Where(host).First(host).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Hosts() ([]*kolide.Host, error) {
|
||||
var hosts []*kolide.Host
|
||||
err := orm.DB.Find(&hosts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) {
|
||||
if host == nil {
|
||||
return nil, errors.New(
|
||||
"error creating host",
|
||||
"nil pointer passed to NewHost",
|
||||
)
|
||||
}
|
||||
err := orm.DB.Create(host).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host, err
|
||||
}
|
||||
|
||||
func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error {
|
||||
err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error
|
||||
if err != nil {
|
||||
return errors.DatabaseError(err)
|
||||
}
|
||||
host.UpdatedAt = t
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) {
|
||||
if query == nil {
|
||||
return nil, errors.New(
|
||||
"error creating query",
|
||||
"nil pointer passed to NewQuery",
|
||||
)
|
||||
}
|
||||
err := orm.DB.Create(query).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) SaveQuery(query *kolide.Query) error {
|
||||
if query == nil {
|
||||
return errors.New(
|
||||
"error saving query",
|
||||
"nil pointer passed to SaveQuery",
|
||||
)
|
||||
}
|
||||
return orm.DB.Save(query).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) DeleteQuery(query *kolide.Query) error {
|
||||
if query == nil {
|
||||
return errors.New(
|
||||
"error deleting query",
|
||||
"nil pointer passed to DeleteQuery",
|
||||
)
|
||||
}
|
||||
return orm.DB.Delete(query).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Query(id uint) (*kolide.Query, error) {
|
||||
query := &kolide.Query{
|
||||
ID: id,
|
||||
}
|
||||
err := orm.DB.Where(query).First(query).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Queries() ([]*kolide.Query, error) {
|
||||
var queries []*kolide.Query
|
||||
err := orm.DB.Find(&queries).Error
|
||||
return queries, err
|
||||
}
|
||||
|
||||
func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) {
|
||||
if label == nil {
|
||||
return nil, errors.New(
|
||||
"error creating label",
|
||||
"nil pointer passed to NewLabel",
|
||||
)
|
||||
}
|
||||
err := orm.DB.Create(label).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return label, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) SaveLabel(label *kolide.Label) error {
|
||||
if label == nil {
|
||||
return errors.New(
|
||||
"error saving label",
|
||||
"nil pointer passed to SaveLabel",
|
||||
)
|
||||
}
|
||||
return orm.DB.Save(label).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) DeleteLabel(lid uint) error {
|
||||
err := orm.DB.Where("id = ?", lid).Delete(&kolide.Label{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return orm.DB.Where("target_id = ? and type = ?", lid, kolide.TargetLabel).Delete(&kolide.PackTarget{}).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Label(lid uint) (*kolide.Label, error) {
|
||||
label := &kolide.Label{
|
||||
ID: lid,
|
||||
}
|
||||
err := orm.DB.Where("id = ?", label.ID).First(&label).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return label, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Labels() ([]*kolide.Label, error) {
|
||||
var labels []*kolide.Label
|
||||
err := orm.DB.Find(&labels).Error
|
||||
return labels, err
|
||||
}
|
||||
|
||||
func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) {
|
||||
if host == nil {
|
||||
return nil, errors.New(
|
||||
"error finding host queries",
|
||||
"nil pointer passed to LabelQueriesForHost",
|
||||
)
|
||||
}
|
||||
rows, err := orm.DB.Raw(`
|
||||
SELECT l.id, q.query
|
||||
FROM labels l JOIN queries q
|
||||
ON l.query_id = q.id
|
||||
WHERE q.platform = ?
|
||||
AND q.id NOT IN /* subtract the set of executions that are recent enough */
|
||||
(
|
||||
SELECT l.query_id
|
||||
FROM labels l
|
||||
JOIN label_query_executions lqe
|
||||
ON lqe.label_id = l.id
|
||||
WHERE lqe.host_id = ? AND lqe.updated_at > ?
|
||||
)`, host.Platform, host.ID, cutoff).Rows()
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make(map[string]string)
|
||||
for rows.Next() {
|
||||
var id, query string
|
||||
err = rows.Scan(&id, &query)
|
||||
if err != nil {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
results[id] = query
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error {
|
||||
if host == nil {
|
||||
return errors.New(
|
||||
"error recording host label query execution",
|
||||
"nil pointer passed to RecordLabelQueryExecutions",
|
||||
)
|
||||
}
|
||||
|
||||
insert := new(bytes.Buffer)
|
||||
switch orm.Driver {
|
||||
case "mysql":
|
||||
insert.WriteString("INSERT ")
|
||||
case "sqlite3":
|
||||
insert.WriteString("REPLACE ")
|
||||
default:
|
||||
return errors.New(
|
||||
"Unknown DB driver",
|
||||
"Tried to use unknown DB driver in RecordLabelQueryExecutions: "+orm.Driver,
|
||||
)
|
||||
}
|
||||
|
||||
insert.WriteString(
|
||||
"INTO label_query_executions (updated_at, matches, label_id, host_id) VALUES",
|
||||
)
|
||||
|
||||
// Build up all the values and the query string
|
||||
vals := []interface{}{}
|
||||
for labelId, res := range results {
|
||||
insert.WriteString("(?,?,?,?),")
|
||||
vals = append(vals, t, res, labelId, host.ID)
|
||||
}
|
||||
|
||||
queryString := insert.String()
|
||||
queryString = strings.TrimSuffix(queryString, ",")
|
||||
|
||||
switch orm.Driver {
|
||||
case "mysql":
|
||||
queryString += `
|
||||
ON DUPLICATE KEY UPDATE
|
||||
updated_at = VALUES(updated_at),
|
||||
matches = VALUES(matches)
|
||||
`
|
||||
}
|
||||
|
||||
if err := orm.DB.Exec(queryString, vals...).Error; err != nil {
|
||||
return errors.DatabaseError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm gormDB) LabelsForHost(hid uint) ([]kolide.Label, error) {
|
||||
results := []kolide.Label{}
|
||||
err := orm.DB.Raw(`
|
||||
SELECT labels.* from labels, label_query_executions lqe
|
||||
WHERE lqe.host_id = ?
|
||||
AND lqe.label_id = labels.id
|
||||
AND lqe.matches
|
||||
`, hid).Scan(&results).Error
|
||||
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) NewPack(pack *kolide.Pack) error {
|
||||
if pack == nil {
|
||||
return errors.New(
|
||||
"error creating pack",
|
||||
"nil pointer passed to NewPack",
|
||||
)
|
||||
}
|
||||
return orm.DB.Create(pack).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) SavePack(pack *kolide.Pack) error {
|
||||
if pack == nil {
|
||||
return errors.New(
|
||||
"error saving pack",
|
||||
"nil pointer passed to SavePack",
|
||||
)
|
||||
}
|
||||
return orm.DB.Save(pack).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) DeletePack(pid uint) error {
|
||||
err := orm.DB.Where("id = ?", pid).Delete(&kolide.Pack{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackQuery{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackTarget{}).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Pack(pid uint) (*kolide.Pack, error) {
|
||||
pack := &kolide.Pack{
|
||||
ID: pid,
|
||||
}
|
||||
err := orm.DB.Where(pack).First(pack).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Packs() ([]*kolide.Pack, error) {
|
||||
var packs []*kolide.Pack
|
||||
err := orm.DB.Find(&packs).Error
|
||||
return packs, err
|
||||
}
|
||||
|
||||
func (orm gormDB) AddQueryToPack(qid uint, pid uint) error {
|
||||
pq := &kolide.PackQuery{
|
||||
QueryID: qid,
|
||||
PackID: pid,
|
||||
}
|
||||
return orm.DB.Create(pq).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) GetQueriesInPack(pack *kolide.Pack) ([]*kolide.Query, error) {
|
||||
var queries []*kolide.Query
|
||||
if pack == nil {
|
||||
return nil, errors.New(
|
||||
"error getting queries in pack",
|
||||
"nil pointer passed to GetQueriesInPack",
|
||||
)
|
||||
}
|
||||
|
||||
rows, err := orm.DB.Raw(`
|
||||
SELECT
|
||||
q.id,
|
||||
q.created_at,
|
||||
q.updated_at,
|
||||
q.name,
|
||||
q.query,
|
||||
q.interval,
|
||||
q.snapshot,
|
||||
q.differential,
|
||||
q.platform,
|
||||
q.version
|
||||
FROM
|
||||
queries q
|
||||
JOIN
|
||||
pack_queries pq
|
||||
ON
|
||||
pq.query_id = q.id
|
||||
AND
|
||||
pq.pack_id = ?;
|
||||
`, pack.ID).Rows()
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
query := new(kolide.Query)
|
||||
err = rows.Scan(
|
||||
&query.ID,
|
||||
&query.CreatedAt,
|
||||
&query.UpdatedAt,
|
||||
&query.Name,
|
||||
&query.Query,
|
||||
&query.Interval,
|
||||
&query.Snapshot,
|
||||
&query.Differential,
|
||||
&query.Platform,
|
||||
&query.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queries = append(queries, query)
|
||||
}
|
||||
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) RemoveQueryFromPack(query *kolide.Query, pack *kolide.Pack) error {
|
||||
if query == nil || pack == nil {
|
||||
return errors.New(
|
||||
"error removing query from pack",
|
||||
"nil pointer passed to RemoveQueryFromPack",
|
||||
)
|
||||
}
|
||||
pq := &kolide.PackQuery{
|
||||
QueryID: query.ID,
|
||||
PackID: pack.ID,
|
||||
}
|
||||
return orm.DB.Where(pq).Delete(pq).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) AddLabelToPack(lid uint, pid uint) error {
|
||||
pt := &kolide.PackTarget{
|
||||
Type: kolide.TargetLabel,
|
||||
PackID: pid,
|
||||
TargetID: lid,
|
||||
}
|
||||
|
||||
return orm.DB.Create(pt).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) ActivePacksForHost(hid uint) ([]*kolide.Pack, error) {
|
||||
packs := []*kolide.Pack{}
|
||||
|
||||
// we will need to give some subset of packs to this host based on the
|
||||
// labels which this host is known to belong to
|
||||
allPacks, err := orm.Packs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// pull the labels that this host belongs to
|
||||
labels, err := orm.LabelsForHost(hid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// in order to use o(1) array indexing in an o(n) loop vs a o(n^2) double
|
||||
// for loop iteration, we must create the array which may be indexed below
|
||||
labelIDs := map[uint]bool{}
|
||||
for _, label := range labels {
|
||||
labelIDs[label.ID] = true
|
||||
}
|
||||
|
||||
for _, pack := range allPacks {
|
||||
// for each pack, we must know what labels have been assigned to that
|
||||
// pack
|
||||
labelsForPack, err := orm.GetLabelsForPack(pack)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// o(n) iteration to determine whether or not a pack is enabled
|
||||
// in this case, n is len(labelsForPack)
|
||||
for _, label := range labelsForPack {
|
||||
if labelIDs[label.ID] {
|
||||
packs = append(packs, pack)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return packs, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) GetLabelsForPack(pack *kolide.Pack) ([]*kolide.Label, error) {
|
||||
if pack == nil {
|
||||
return nil, errors.New(
|
||||
"error getting labels for pack",
|
||||
"nil pointer passed to GetLabelsForPack",
|
||||
)
|
||||
}
|
||||
|
||||
results := []*kolide.Label{}
|
||||
err := orm.DB.Raw(`
|
||||
SELECT
|
||||
l.id,
|
||||
l.created_at,
|
||||
l.updated_at,
|
||||
l.name,
|
||||
l.query_id
|
||||
FROM
|
||||
labels l
|
||||
JOIN
|
||||
pack_targets pt
|
||||
ON
|
||||
pt.target_id = l.id
|
||||
WHERE
|
||||
pt.type = ?
|
||||
AND
|
||||
pt.pack_id = ?;
|
||||
|
||||
`,
|
||||
kolide.TargetLabel, pack.ID).Scan(&results).Error
|
||||
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) RemoveLabelFromPack(label *kolide.Label, pack *kolide.Pack) error {
|
||||
if label == nil || pack == nil {
|
||||
return errors.New(
|
||||
"error removing label from pack",
|
||||
"nil pointer passed to RemoveLabelFromPack",
|
||||
)
|
||||
}
|
||||
|
||||
pt := &kolide.PackTarget{
|
||||
Type: kolide.TargetLabel,
|
||||
PackID: pack.ID,
|
||||
TargetID: label.ID,
|
||||
}
|
||||
|
||||
return orm.DB.Delete(pt).Error
|
||||
}
|
||||
|
|
|
|||
132
server/datastore/gorm_hosts.go
Normal file
132
server/datastore/gorm_hosts.go
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/server/errors"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
func (orm gormDB) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) {
|
||||
if uuid == "" {
|
||||
return nil, errors.New("missing uuid for host enrollment", "programmer error?")
|
||||
}
|
||||
host := kolide.Host{UUID: uuid}
|
||||
err := orm.DB.Where(&host).First(&host).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
// Create new Host
|
||||
host = kolide.Host{
|
||||
UUID: uuid,
|
||||
HostName: hostname,
|
||||
IPAddress: ip,
|
||||
Platform: platform,
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new key each enrollment
|
||||
host.NodeKey, err = generateRandomText(nodeKeySize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update these fields if provided
|
||||
if hostname != "" {
|
||||
host.HostName = hostname
|
||||
}
|
||||
if ip != "" {
|
||||
host.IPAddress = ip
|
||||
}
|
||||
if platform != "" {
|
||||
host.Platform = platform
|
||||
}
|
||||
|
||||
if err := orm.DB.Save(&host).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) AuthenticateHost(nodeKey string) (*kolide.Host, error) {
|
||||
host := kolide.Host{NodeKey: nodeKey}
|
||||
err := orm.DB.Where("node_key = ?", host.NodeKey).First(&host).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
e := errors.NewFromError(
|
||||
err,
|
||||
http.StatusUnauthorized,
|
||||
"invalid node key",
|
||||
)
|
||||
// osqueryd expects the literal string "true" here
|
||||
e.Extra = map[string]interface{}{"node_invalid": "true"}
|
||||
return nil, e
|
||||
default:
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
}
|
||||
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) SaveHost(host *kolide.Host) error {
|
||||
if err := orm.DB.Save(host).Error; err != nil {
|
||||
return errors.DatabaseError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm gormDB) DeleteHost(host *kolide.Host) error {
|
||||
return orm.DB.Delete(host).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Host(id uint) (*kolide.Host, error) {
|
||||
host := &kolide.Host{
|
||||
ID: id,
|
||||
}
|
||||
err := orm.DB.Where(host).First(host).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Hosts() ([]*kolide.Host, error) {
|
||||
var hosts []*kolide.Host
|
||||
err := orm.DB.Find(&hosts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) {
|
||||
if host == nil {
|
||||
return nil, errors.New(
|
||||
"error creating host",
|
||||
"nil pointer passed to NewHost",
|
||||
)
|
||||
}
|
||||
err := orm.DB.Create(host).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host, err
|
||||
}
|
||||
|
||||
func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error {
|
||||
err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error
|
||||
if err != nil {
|
||||
return errors.DatabaseError(err)
|
||||
}
|
||||
host.UpdatedAt = t
|
||||
return nil
|
||||
}
|
||||
166
server/datastore/gorm_labels.go
Normal file
166
server/datastore/gorm_labels.go
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/server/errors"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) {
|
||||
if label == nil {
|
||||
return nil, errors.New(
|
||||
"error creating label",
|
||||
"nil pointer passed to NewLabel",
|
||||
)
|
||||
}
|
||||
err := orm.DB.Create(label).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return label, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) SaveLabel(label *kolide.Label) error {
|
||||
if label == nil {
|
||||
return errors.New(
|
||||
"error saving label",
|
||||
"nil pointer passed to SaveLabel",
|
||||
)
|
||||
}
|
||||
return orm.DB.Save(label).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) DeleteLabel(lid uint) error {
|
||||
err := orm.DB.Where("id = ?", lid).Delete(&kolide.Label{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return orm.DB.Where("target_id = ? and type = ?", lid, kolide.TargetLabel).Delete(&kolide.PackTarget{}).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Label(lid uint) (*kolide.Label, error) {
|
||||
label := &kolide.Label{
|
||||
ID: lid,
|
||||
}
|
||||
err := orm.DB.Where("id = ?", label.ID).First(&label).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return label, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Labels() ([]*kolide.Label, error) {
|
||||
var labels []*kolide.Label
|
||||
err := orm.DB.Find(&labels).Error
|
||||
return labels, err
|
||||
}
|
||||
|
||||
func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) {
|
||||
if host == nil {
|
||||
return nil, errors.New(
|
||||
"error finding host queries",
|
||||
"nil pointer passed to LabelQueriesForHost",
|
||||
)
|
||||
}
|
||||
rows, err := orm.DB.Raw(`
|
||||
SELECT l.id, q.query
|
||||
FROM labels l JOIN queries q
|
||||
ON l.query_id = q.id
|
||||
WHERE q.platform = ?
|
||||
AND q.id NOT IN /* subtract the set of executions that are recent enough */
|
||||
(
|
||||
SELECT l.query_id
|
||||
FROM labels l
|
||||
JOIN label_query_executions lqe
|
||||
ON lqe.label_id = l.id
|
||||
WHERE lqe.host_id = ? AND lqe.updated_at > ?
|
||||
)`, host.Platform, host.ID, cutoff).Rows()
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
results := make(map[string]string)
|
||||
for rows.Next() {
|
||||
var id, query string
|
||||
err = rows.Scan(&id, &query)
|
||||
if err != nil {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
results[id] = query
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error {
|
||||
if host == nil {
|
||||
return errors.New(
|
||||
"error recording host label query execution",
|
||||
"nil pointer passed to RecordLabelQueryExecutions",
|
||||
)
|
||||
}
|
||||
|
||||
insert := new(bytes.Buffer)
|
||||
switch orm.Driver {
|
||||
case "mysql":
|
||||
insert.WriteString("INSERT ")
|
||||
case "sqlite3":
|
||||
insert.WriteString("REPLACE ")
|
||||
default:
|
||||
return errors.New(
|
||||
"Unknown DB driver",
|
||||
"Tried to use unknown DB driver in RecordLabelQueryExecutions: "+orm.Driver,
|
||||
)
|
||||
}
|
||||
|
||||
insert.WriteString(
|
||||
"INTO label_query_executions (updated_at, matches, label_id, host_id) VALUES",
|
||||
)
|
||||
|
||||
// Build up all the values and the query string
|
||||
vals := []interface{}{}
|
||||
for labelId, res := range results {
|
||||
insert.WriteString("(?,?,?,?),")
|
||||
vals = append(vals, t, res, labelId, host.ID)
|
||||
}
|
||||
|
||||
queryString := insert.String()
|
||||
queryString = strings.TrimSuffix(queryString, ",")
|
||||
|
||||
switch orm.Driver {
|
||||
case "mysql":
|
||||
queryString += `
|
||||
ON DUPLICATE KEY UPDATE
|
||||
updated_at = VALUES(updated_at),
|
||||
matches = VALUES(matches)
|
||||
`
|
||||
}
|
||||
|
||||
if err := orm.DB.Exec(queryString, vals...).Error; err != nil {
|
||||
return errors.DatabaseError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm gormDB) LabelsForHost(hid uint) ([]kolide.Label, error) {
|
||||
results := []kolide.Label{}
|
||||
err := orm.DB.Raw(`
|
||||
SELECT labels.* from labels, label_query_executions lqe
|
||||
WHERE lqe.host_id = ?
|
||||
AND lqe.label_id = labels.id
|
||||
AND lqe.matches
|
||||
`, hid).Scan(&results).Error
|
||||
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
245
server/datastore/gorm_packs.go
Normal file
245
server/datastore/gorm_packs.go
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/server/errors"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
func (orm gormDB) NewPack(pack *kolide.Pack) error {
|
||||
if pack == nil {
|
||||
return errors.New(
|
||||
"error creating pack",
|
||||
"nil pointer passed to NewPack",
|
||||
)
|
||||
}
|
||||
return orm.DB.Create(pack).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) SavePack(pack *kolide.Pack) error {
|
||||
if pack == nil {
|
||||
return errors.New(
|
||||
"error saving pack",
|
||||
"nil pointer passed to SavePack",
|
||||
)
|
||||
}
|
||||
return orm.DB.Save(pack).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) DeletePack(pid uint) error {
|
||||
err := orm.DB.Where("id = ?", pid).Delete(&kolide.Pack{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackQuery{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackTarget{}).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Pack(pid uint) (*kolide.Pack, error) {
|
||||
pack := &kolide.Pack{
|
||||
ID: pid,
|
||||
}
|
||||
err := orm.DB.Where(pack).First(pack).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Packs() ([]*kolide.Pack, error) {
|
||||
var packs []*kolide.Pack
|
||||
err := orm.DB.Find(&packs).Error
|
||||
return packs, err
|
||||
}
|
||||
|
||||
func (orm gormDB) AddQueryToPack(qid uint, pid uint) error {
|
||||
pq := &kolide.PackQuery{
|
||||
QueryID: qid,
|
||||
PackID: pid,
|
||||
}
|
||||
return orm.DB.Create(pq).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) GetQueriesInPack(pack *kolide.Pack) ([]*kolide.Query, error) {
|
||||
var queries []*kolide.Query
|
||||
if pack == nil {
|
||||
return nil, errors.New(
|
||||
"error getting queries in pack",
|
||||
"nil pointer passed to GetQueriesInPack",
|
||||
)
|
||||
}
|
||||
|
||||
rows, err := orm.DB.Raw(`
|
||||
SELECT
|
||||
q.id,
|
||||
q.created_at,
|
||||
q.updated_at,
|
||||
q.name,
|
||||
q.query,
|
||||
q.interval,
|
||||
q.snapshot,
|
||||
q.differential,
|
||||
q.platform,
|
||||
q.version
|
||||
FROM
|
||||
queries q
|
||||
JOIN
|
||||
pack_queries pq
|
||||
ON
|
||||
pq.query_id = q.id
|
||||
AND
|
||||
pq.pack_id = ?;
|
||||
`, pack.ID).Rows()
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
query := new(kolide.Query)
|
||||
err = rows.Scan(
|
||||
&query.ID,
|
||||
&query.CreatedAt,
|
||||
&query.UpdatedAt,
|
||||
&query.Name,
|
||||
&query.Query,
|
||||
&query.Interval,
|
||||
&query.Snapshot,
|
||||
&query.Differential,
|
||||
&query.Platform,
|
||||
&query.Version,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queries = append(queries, query)
|
||||
}
|
||||
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) RemoveQueryFromPack(query *kolide.Query, pack *kolide.Pack) error {
|
||||
if query == nil || pack == nil {
|
||||
return errors.New(
|
||||
"error removing query from pack",
|
||||
"nil pointer passed to RemoveQueryFromPack",
|
||||
)
|
||||
}
|
||||
pq := &kolide.PackQuery{
|
||||
QueryID: query.ID,
|
||||
PackID: pack.ID,
|
||||
}
|
||||
return orm.DB.Where(pq).Delete(pq).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) AddLabelToPack(lid uint, pid uint) error {
|
||||
pt := &kolide.PackTarget{
|
||||
Type: kolide.TargetLabel,
|
||||
PackID: pid,
|
||||
TargetID: lid,
|
||||
}
|
||||
|
||||
return orm.DB.Create(pt).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) ActivePacksForHost(hid uint) ([]*kolide.Pack, error) {
|
||||
packs := []*kolide.Pack{}
|
||||
|
||||
// we will need to give some subset of packs to this host based on the
|
||||
// labels which this host is known to belong to
|
||||
allPacks, err := orm.Packs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// pull the labels that this host belongs to
|
||||
labels, err := orm.LabelsForHost(hid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// in order to use o(1) array indexing in an o(n) loop vs a o(n^2) double
|
||||
// for loop iteration, we must create the array which may be indexed below
|
||||
labelIDs := map[uint]bool{}
|
||||
for _, label := range labels {
|
||||
labelIDs[label.ID] = true
|
||||
}
|
||||
|
||||
for _, pack := range allPacks {
|
||||
// for each pack, we must know what labels have been assigned to that
|
||||
// pack
|
||||
labelsForPack, err := orm.GetLabelsForPack(pack)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// o(n) iteration to determine whether or not a pack is enabled
|
||||
// in this case, n is len(labelsForPack)
|
||||
for _, label := range labelsForPack {
|
||||
if labelIDs[label.ID] {
|
||||
packs = append(packs, pack)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return packs, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) GetLabelsForPack(pack *kolide.Pack) ([]*kolide.Label, error) {
|
||||
if pack == nil {
|
||||
return nil, errors.New(
|
||||
"error getting labels for pack",
|
||||
"nil pointer passed to GetLabelsForPack",
|
||||
)
|
||||
}
|
||||
|
||||
results := []*kolide.Label{}
|
||||
err := orm.DB.Raw(`
|
||||
SELECT
|
||||
l.id,
|
||||
l.created_at,
|
||||
l.updated_at,
|
||||
l.name,
|
||||
l.query_id
|
||||
FROM
|
||||
labels l
|
||||
JOIN
|
||||
pack_targets pt
|
||||
ON
|
||||
pt.target_id = l.id
|
||||
WHERE
|
||||
pt.type = ?
|
||||
AND
|
||||
pt.pack_id = ?;
|
||||
|
||||
`,
|
||||
kolide.TargetLabel, pack.ID).Scan(&results).Error
|
||||
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, errors.DatabaseError(err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) RemoveLabelFromPack(label *kolide.Label, pack *kolide.Pack) error {
|
||||
if label == nil || pack == nil {
|
||||
return errors.New(
|
||||
"error removing label from pack",
|
||||
"nil pointer passed to RemoveLabelFromPack",
|
||||
)
|
||||
}
|
||||
|
||||
pt := &kolide.PackTarget{
|
||||
Type: kolide.TargetLabel,
|
||||
PackID: pack.ID,
|
||||
TargetID: label.ID,
|
||||
}
|
||||
|
||||
return orm.DB.Delete(pt).Error
|
||||
}
|
||||
57
server/datastore/gorm_queries.go
Normal file
57
server/datastore/gorm_queries.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"github.com/kolide/kolide-ose/server/errors"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) {
|
||||
if query == nil {
|
||||
return nil, errors.New(
|
||||
"error creating query",
|
||||
"nil pointer passed to NewQuery",
|
||||
)
|
||||
}
|
||||
err := orm.DB.Create(query).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) SaveQuery(query *kolide.Query) error {
|
||||
if query == nil {
|
||||
return errors.New(
|
||||
"error saving query",
|
||||
"nil pointer passed to SaveQuery",
|
||||
)
|
||||
}
|
||||
return orm.DB.Save(query).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) DeleteQuery(query *kolide.Query) error {
|
||||
if query == nil {
|
||||
return errors.New(
|
||||
"error deleting query",
|
||||
"nil pointer passed to DeleteQuery",
|
||||
)
|
||||
}
|
||||
return orm.DB.Delete(query).Error
|
||||
}
|
||||
|
||||
func (orm gormDB) Query(id uint) (*kolide.Query, error) {
|
||||
query := &kolide.Query{
|
||||
ID: id,
|
||||
}
|
||||
err := orm.DB.Where(query).First(query).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func (orm gormDB) Queries() ([]*kolide.Query, error) {
|
||||
var queries []*kolide.Query
|
||||
err := orm.DB.Find(&queries).Error
|
||||
return queries, err
|
||||
}
|
||||
|
|
@ -1,124 +1,36 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupMySQLGORM(t *testing.T) kolide.Datastore {
|
||||
user := "kolide"
|
||||
password := "kolide"
|
||||
dbName := "kolide"
|
||||
func setupGorm(t *testing.T) kolide.Datastore {
|
||||
db, err := gorm.Open("sqlite3", ":memory:")
|
||||
require.Nil(t, err)
|
||||
|
||||
// try container first
|
||||
host := os.Getenv("MYSQL_PORT_3306_TCP_ADDR")
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
host = fmt.Sprintf("%s:3306", host)
|
||||
ds := gormDB{DB: db, Driver: "sqlite3"}
|
||||
|
||||
conn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", user, password, host, dbName)
|
||||
db, err := New("gorm-mysql", conn, LimitAttempts(1))
|
||||
err = ds.Migrate()
|
||||
assert.Nil(t, err)
|
||||
|
||||
backend := db.(gormDB)
|
||||
err = backend.Migrate()
|
||||
assert.Nil(t, err)
|
||||
|
||||
return db
|
||||
return ds
|
||||
}
|
||||
|
||||
func teardownMySQLGORM(t *testing.T, db kolide.Datastore) {
|
||||
err := db.Drop()
|
||||
func teardownGorm(t *testing.T, ds kolide.Datastore) {
|
||||
err := ds.Drop()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestEnrollHostMySQLGORM(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
func TestGorm(t *testing.T) {
|
||||
for _, f := range testFunctions {
|
||||
t.Run(functionName(f), func(t *testing.T) {
|
||||
ds := setupGorm(t)
|
||||
defer teardownGorm(t, ds)
|
||||
f(t, ds)
|
||||
})
|
||||
}
|
||||
db := setupMySQLGORM(t)
|
||||
defer teardownMySQLGORM(t, db)
|
||||
|
||||
testEnrollHost(t, db)
|
||||
}
|
||||
|
||||
func TestAuthenticateHostMySQLGORM(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testAuthenticateHost(t, db)
|
||||
}
|
||||
|
||||
func TestUserByIDMySQLGORM(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
db := setup(t)
|
||||
defer teardown(t, db)
|
||||
|
||||
testUserByID(t, db)
|
||||
}
|
||||
|
||||
// TestCreateUser tests the UserStore interface
|
||||
// this test uses the MySQL GORM backend
|
||||
func TestCreateUserMySQLGORM(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
db := setupMySQLGORM(t)
|
||||
defer teardownMySQLGORM(t, db)
|
||||
|
||||
testCreateUser(t, db)
|
||||
}
|
||||
|
||||
func TestSaveUserMySQLGORM(t *testing.T) {
|
||||
address := os.Getenv("MYSQL_ADDR")
|
||||
if address == "" {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
db := setupMySQLGORM(t)
|
||||
defer teardownMySQLGORM(t, db)
|
||||
|
||||
testSaveUser(t, db)
|
||||
}
|
||||
|
||||
func TestSaveQuery(t *testing.T) {
|
||||
ds := setup(t)
|
||||
testSaveQuery(t, ds)
|
||||
}
|
||||
|
||||
func TestDeleteQuery(t *testing.T) {
|
||||
ds := setup(t)
|
||||
testDeleteQuery(t, ds)
|
||||
}
|
||||
|
||||
func TestDeletePack(t *testing.T) {
|
||||
ds := setup(t)
|
||||
testDeletePack(t, ds)
|
||||
}
|
||||
|
||||
func TestAddAndRemoveQueryFromPack(t *testing.T) {
|
||||
ds := setup(t)
|
||||
testAddAndRemoveQueryFromPack(t, ds)
|
||||
}
|
||||
|
||||
func TestManagingLabelsOnPacks(t *testing.T) {
|
||||
ds := setup(t)
|
||||
testManagingLabelsOnPacks(t, ds)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
package datastore
|
||||
|
||||
import "github.com/kolide/kolide-ose/server/kolide"
|
||||
import (
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
// NewUser creates a new user in the gorm backend
|
||||
func (orm gormDB) NewUser(user *kolide.User) (*kolide.User, error) {
|
||||
err := orm.DB.Create(user).Error
|
||||
if err != nil {
|
||||
|
|
@ -11,7 +12,6 @@ func (orm gormDB) NewUser(user *kolide.User) (*kolide.User, error) {
|
|||
return user, nil
|
||||
}
|
||||
|
||||
// User returns a specific user in the gorm backend
|
||||
func (orm gormDB) User(username string) (*kolide.User, error) {
|
||||
user := &kolide.User{
|
||||
Username: username,
|
||||
|
|
@ -43,7 +43,6 @@ func (orm gormDB) UserByEmail(email string) (*kolide.User, error) {
|
|||
return user, nil
|
||||
}
|
||||
|
||||
// UserByID returns a datastore user given a user ID
|
||||
func (orm gormDB) UserByID(id uint) (*kolide.User, error) {
|
||||
user := &kolide.User{ID: id}
|
||||
err := orm.DB.Where(user).First(user).Error
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ type inmem struct {
|
|||
labels map[uint]*kolide.Label
|
||||
labelQueryExecutions map[uint]*kolide.LabelQueryExecution
|
||||
queries map[uint]*kolide.Query
|
||||
packs map[uint]*kolide.Pack
|
||||
hosts map[uint]*kolide.Host
|
||||
orginfo *kolide.OrgInfo
|
||||
}
|
||||
|
|
@ -35,6 +36,7 @@ func (orm *inmem) Migrate() error {
|
|||
orm.labels = make(map[uint]*kolide.Label)
|
||||
orm.labelQueryExecutions = make(map[uint]*kolide.LabelQueryExecution)
|
||||
orm.queries = make(map[uint]*kolide.Query)
|
||||
orm.packs = make(map[uint]*kolide.Pack)
|
||||
orm.hosts = make(map[uint]*kolide.Host)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
71
server/datastore/inmem_packs.go
Normal file
71
server/datastore/inmem_packs.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package datastore
|
||||
|
||||
import (
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
func (orm *inmem) NewPack(pack *kolide.Pack) error {
|
||||
orm.mtx.Lock()
|
||||
defer orm.mtx.Unlock()
|
||||
|
||||
newPack := *pack
|
||||
|
||||
for _, q := range orm.packs {
|
||||
if pack.Name == q.Name {
|
||||
return ErrExists
|
||||
}
|
||||
}
|
||||
|
||||
newPack.ID = uint(len(orm.packs) + 1)
|
||||
orm.packs[newPack.ID] = &newPack
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm *inmem) SavePack(pack *kolide.Pack) error {
|
||||
orm.mtx.Lock()
|
||||
defer orm.mtx.Unlock()
|
||||
|
||||
if _, ok := orm.packs[pack.ID]; !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
orm.packs[pack.ID] = pack
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm *inmem) DeletePack(pid uint) error {
|
||||
orm.mtx.Lock()
|
||||
defer orm.mtx.Unlock()
|
||||
|
||||
if _, ok := orm.packs[pid]; !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
delete(orm.packs, pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (orm *inmem) Pack(id uint) (*kolide.Pack, error) {
|
||||
orm.mtx.Lock()
|
||||
defer orm.mtx.Unlock()
|
||||
|
||||
pack, ok := orm.packs[id]
|
||||
if !ok {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
return pack, nil
|
||||
}
|
||||
|
||||
func (orm *inmem) Packs() ([]*kolide.Pack, error) {
|
||||
orm.mtx.Lock()
|
||||
defer orm.mtx.Unlock()
|
||||
|
||||
packs := []*kolide.Pack{}
|
||||
for _, pack := range orm.packs {
|
||||
packs = append(packs, pack)
|
||||
}
|
||||
|
||||
return packs, nil
|
||||
}
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
package datastore
|
||||
|
||||
import "github.com/kolide/kolide-ose/server/kolide"
|
||||
import (
|
||||
"github.com/kolide/kolide-ose/server/kolide"
|
||||
)
|
||||
|
||||
func (orm *inmem) NewQuery(query *kolide.Query) (*kolide.Query, error) {
|
||||
orm.mtx.Lock()
|
||||
|
|
|
|||
Loading…
Reference in a new issue