diff --git a/cli/serve.go b/cli/serve.go index 081b73d860..c202263400 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -67,11 +67,6 @@ the way that the kolide server works. if err != nil { initFatal(err, "initializing datastore") } - - err = ds.Migrate() - if err != nil { - initFatal(err, "initializing datastore") - } } else { connString := datastore.GetMysqlConnectionString(config.Mysql) ds, err = datastore.New("gorm-mysql", connString) diff --git a/server/datastore/datastore.go b/server/datastore/datastore.go index a30cf49313..6c0dbde5a7 100644 --- a/server/datastore/datastore.go +++ b/server/datastore/datastore.go @@ -2,9 +2,12 @@ package datastore import ( + "crypto/rand" + "encoding/base64" "errors" "fmt" + "github.com/kolide/kolide-ose/server/config" "github.com/kolide/kolide-ose/server/kolide" ) @@ -71,14 +74,38 @@ func New(driver, conn string, opts ...DBOption) (kolide.Datastore, error) { return ds, nil case "inmem": ds := &inmem{ - Driver: "inmem", - users: make(map[uint]*kolide.User), - sessions: make(map[uint]*kolide.Session), - passwordResets: make(map[uint]*kolide.PasswordResetRequest), - invites: make(map[uint]*kolide.Invite), + Driver: "inmem", } + + err := ds.Migrate() + if err != nil { + return nil, err + } + return ds, nil default: return nil, fmt.Errorf("unsupported datastore driver %s", driver) } } + +func generateRandomText(keySize int) (string, error) { + key := make([]byte, keySize) + _, err := rand.Read(key) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(key), nil +} + +// GetMysqlConnectionString returns a MySQL connection string using the +// provided configuration. +func GetMysqlConnectionString(conf config.MysqlConfig) string { + return fmt.Sprintf( + "%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", + conf.Username, + conf.Password, + conf.Address, + conf.Database, + ) +} diff --git a/server/datastore/gorm_app_test.go b/server/datastore/datastore_app_test.go similarity index 76% rename from server/datastore/gorm_app_test.go rename to server/datastore/datastore_app_test.go index 0ef2d9f29f..0ca3e79f5b 100644 --- a/server/datastore/gorm_app_test.go +++ b/server/datastore/datastore_app_test.go @@ -1,26 +1,12 @@ package datastore import ( - "os" "testing" "github.com/kolide/kolide-ose/server/kolide" "github.com/stretchr/testify/assert" ) -func TestOrgInfo(t *testing.T) { - var ds kolide.Datastore - address := os.Getenv("MYSQL_ADDR") - if address == "" { - ds = setup(t) - } else { - ds = setupMySQLGORM(t) - defer teardownMySQLGORM(t, ds) - } - - testOrgInfo(t, ds) -} - func testOrgInfo(t *testing.T, db kolide.Datastore) { info := &kolide.OrgInfo{ OrgName: "Kolide", diff --git a/server/datastore/datastore_hosts_test.go b/server/datastore/datastore_hosts_test.go new file mode 100644 index 0000000000..80d7be1c53 --- /dev/null +++ b/server/datastore/datastore_hosts_test.go @@ -0,0 +1,83 @@ +package datastore + +import ( + "fmt" + "testing" + + "github.com/kolide/kolide-ose/server/kolide" + "github.com/stretchr/testify/assert" +) + +var enrollTests = []struct { + uuid, hostname, ip, platform string + nodeKeySize int +}{ + 0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26", + hostname: "web.kolide.co", + ip: "172.0.0.1", + platform: "linux", + nodeKeySize: 12, + }, + 1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B", + hostname: "mail.kolide.co", + ip: "172.0.0.2", + platform: "linux", + nodeKeySize: 10, + }, + 2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2", + hostname: "home.kolide.co", + ip: "127.0.0.1", + platform: "darwin", + nodeKeySize: 25, + }, + 3: {uuid: "uuid123", + hostname: "fakehostname", + ip: "192.168.1.1", + platform: "darwin", + nodeKeySize: 1, + }, +} + +func testEnrollHost(t *testing.T, db kolide.Datastore) { + var hosts []*kolide.Host + for _, tt := range enrollTests { + h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize) + assert.Nil(t, err) + + hosts = append(hosts, h) + assert.Equal(t, tt.uuid, h.UUID) + assert.Equal(t, tt.hostname, h.HostName) + assert.Equal(t, tt.ip, h.IPAddress) + assert.Equal(t, tt.platform, h.Platform) + assert.NotEmpty(t, h.NodeKey) + } + + for _, enrolled := range hosts { + oldNodeKey := enrolled.NodeKey + newhostname := fmt.Sprintf("changed.%s", enrolled.HostName) + + h, err := db.EnrollHost(enrolled.UUID, newhostname, enrolled.IPAddress, enrolled.Platform, 15) + assert.Nil(t, err) + assert.Equal(t, enrolled.UUID, h.UUID) + assert.NotEmpty(t, h.NodeKey) + assert.NotEqual(t, oldNodeKey, h.NodeKey) + } + +} + +func testAuthenticateHost(t *testing.T, db kolide.Datastore) { + for _, tt := range enrollTests { + h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize) + assert.Nil(t, err) + + returned, err := db.AuthenticateHost(h.NodeKey) + assert.Nil(t, err) + assert.Equal(t, h.NodeKey, returned.NodeKey) + } + + _, err := db.AuthenticateHost("7B1A9DC9-B042-489F-8D5A-EEC2412C95AA") + assert.NotNil(t, err) + + _, err = db.AuthenticateHost("") + assert.NotNil(t, err) +} diff --git a/server/datastore/gorm_invite_test.go b/server/datastore/datastore_invite_test.go similarity index 63% rename from server/datastore/gorm_invite_test.go rename to server/datastore/datastore_invite_test.go index de527cca11..fc567125a9 100644 --- a/server/datastore/gorm_invite_test.go +++ b/server/datastore/datastore_invite_test.go @@ -1,23 +1,13 @@ package datastore import ( - "os" "testing" "github.com/kolide/kolide-ose/server/kolide" "github.com/stretchr/testify/assert" ) -func TestCreateInvite(t *testing.T) { - var ds kolide.Datastore - address := os.Getenv("MYSQL_ADDR") - if address == "" { - ds = setup(t) - } else { - ds = setupMySQLGORM(t) - defer teardownMySQLGORM(t, ds) - } - +func testCreateInvite(t *testing.T, ds kolide.Datastore) { invite := &kolide.Invite{} invite, err := ds.NewInvite(invite) diff --git a/server/datastore/datastore_labels_test.go b/server/datastore/datastore_labels_test.go new file mode 100644 index 0000000000..8fac03bd44 --- /dev/null +++ b/server/datastore/datastore_labels_test.go @@ -0,0 +1,221 @@ +package datastore + +import ( + "sort" + "testing" + "time" + + "github.com/kolide/kolide-ose/server/kolide" + "github.com/stretchr/testify/assert" +) + +func testLabels(t *testing.T, db kolide.Datastore) { + hosts := []kolide.Host{} + var host *kolide.Host + var err error + for i := 0; i < 10; i++ { + host, err = db.EnrollHost(string(i), "foo", "", "", 10) + assert.Nil(t, err, "enrollment should succeed") + hosts = append(hosts, *host) + } + + baseTime := time.Now() + + // No queries should be returned before labels or queries added + queries, err := db.LabelQueriesForHost(host, baseTime) + assert.Nil(t, err) + assert.Empty(t, queries) + + // No labels should match + labels, err := db.LabelsForHost(host.ID) + assert.Nil(t, err) + assert.Empty(t, labels) + + labelQueries := []kolide.Query{ + kolide.Query{ + Name: "query1", + Query: "query1", + Platform: "darwin", + }, + kolide.Query{ + Name: "query2", + Query: "query2", + Platform: "darwin", + }, + kolide.Query{ + Name: "query3", + Query: "query3", + Platform: "darwin", + }, + kolide.Query{ + Name: "query4", + Query: "query4", + Platform: "darwin", + }, + } + + for _, query := range labelQueries { + newQuery, err := db.NewQuery(&query) + assert.Nil(t, err) + assert.NotZero(t, newQuery.ID) + } + + // this one should not show up + _, err = db.NewQuery(&kolide.Query{ + Platform: "not_darwin", + Query: "query5", + }) + assert.Nil(t, err) + + // No queries should be returned before labels added + queries, err = db.LabelQueriesForHost(host, baseTime) + assert.Nil(t, err) + assert.Empty(t, queries) + + newLabels := []kolide.Label{ + // Note these are intentionally out of order + kolide.Label{ + Name: "label3", + QueryID: 3, + }, + kolide.Label{ + Name: "label1", + QueryID: 1, + }, + kolide.Label{ + Name: "label2", + QueryID: 2, + }, + kolide.Label{ + Name: "label4", + QueryID: 4, + }, + } + + for _, label := range newLabels { + newLabel, err := db.NewLabel(&label) + assert.Nil(t, err) + assert.NotZero(t, newLabel.ID) + } + + expectQueries := map[string]string{ + "1": "query3", + "2": "query1", + "3": "query2", + "4": "query4", + } + + host.Platform = "darwin" + + // Now queries should be returned + queries, err = db.LabelQueriesForHost(host, baseTime) + assert.Nil(t, err) + assert.Equal(t, expectQueries, queries) + + // No labels should match with no results yet + labels, err = db.LabelsForHost(host.ID) + assert.Nil(t, err) + assert.Empty(t, labels) + + // Record a query execution + err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime) + assert.Nil(t, err) + + // Use a 10 minute interval, so the query we just added should show up + queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute))) + assert.Nil(t, err) + delete(expectQueries, "1") + assert.Equal(t, expectQueries, queries) + + // Record an old query execution -- Shouldn't change the return + err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour)) + assert.Nil(t, err) + queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute))) + assert.Nil(t, err) + assert.Equal(t, expectQueries, queries) + + // Record a newer execution for that query and another + err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime) + assert.Nil(t, err) + + // Now these should no longer show up in the necessary to run queries + delete(expectQueries, "2") + delete(expectQueries, "3") + queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute))) + assert.Nil(t, err) + assert.Equal(t, expectQueries, queries) + + // Now the two matching labels should be returned + labels, err = db.LabelsForHost(host.ID) + assert.Nil(t, err) + if assert.Len(t, labels, 2) { + labelNames := []string{labels[0].Name, labels[1].Name} + sort.Strings(labelNames) + assert.Equal(t, "label2", labelNames[0]) + assert.Equal(t, "label3", labelNames[1]) + } + + // A host that hasn't executed any label queries should still be asked + // to execute those queries + hosts[0].Platform = "darwin" + queries, err = db.LabelQueriesForHost(host, time.Now()) + assert.Nil(t, err) + assert.Len(t, queries, 4) + + // There should still be no labels returned for a host that never + // executed any label queries + labels, err = db.LabelsForHost(hosts[0].ID) + assert.Nil(t, err) + assert.Empty(t, labels) +} + +func testManagingLabelsOnPacks(t *testing.T, ds kolide.Datastore) { + mysqlQuery := &kolide.Query{ + Name: "MySQL", + Query: "select pid from processes where name = 'mysqld';", + } + mysqlQuery, err := ds.NewQuery(mysqlQuery) + assert.Nil(t, err) + + osqueryRunningQuery := &kolide.Query{ + Name: "Is osquery currently running?", + Query: "select pid from processes where name = 'osqueryd';", + } + osqueryRunningQuery, err = ds.NewQuery(osqueryRunningQuery) + assert.Nil(t, err) + + monitoringPack := &kolide.Pack{ + Name: "monitoring", + } + err = ds.NewPack(monitoringPack) + assert.Nil(t, err) + + mysqlLabel := &kolide.Label{ + Name: "MySQL Monitoring", + QueryID: mysqlQuery.ID, + } + mysqlLabel, err = ds.NewLabel(mysqlLabel) + assert.Nil(t, err) + + err = ds.AddLabelToPack(mysqlLabel.ID, monitoringPack.ID) + assert.Nil(t, err) + + labels, err := ds.GetLabelsForPack(monitoringPack) + assert.Nil(t, err) + assert.Len(t, labels, 1) + assert.Equal(t, "MySQL Monitoring", labels[0].Name) + + osqueryLabel := &kolide.Label{ + Name: "Osquery Monitoring", + QueryID: osqueryRunningQuery.ID, + } + osqueryLabel, err = ds.NewLabel(osqueryLabel) + assert.Nil(t, err) + + err = ds.AddLabelToPack(osqueryLabel.ID, monitoringPack.ID) + assert.Nil(t, err) + + labels, err = ds.GetLabelsForPack(monitoringPack) + assert.Nil(t, err) + assert.Len(t, labels, 2) +} diff --git a/server/datastore/datastore_packs_test.go b/server/datastore/datastore_packs_test.go new file mode 100644 index 0000000000..eec1211066 --- /dev/null +++ b/server/datastore/datastore_packs_test.go @@ -0,0 +1,65 @@ +package datastore + +import ( + "testing" + + "github.com/kolide/kolide-ose/server/kolide" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testDeletePack(t *testing.T, ds kolide.Datastore) { + pack := &kolide.Pack{ + Name: "foo", + } + err := ds.NewPack(pack) + assert.Nil(t, err) + assert.NotEqual(t, pack.ID, 0) + + pack, err = ds.Pack(pack.ID) + require.Nil(t, err) + + err = ds.DeletePack(pack.ID) + assert.Nil(t, err) + + assert.NotEqual(t, pack.ID, 0) + pack, err = ds.Pack(pack.ID) + assert.NotNil(t, err) +} + +func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) { + pack := &kolide.Pack{ + Name: "foo", + } + err := ds.NewPack(pack) + assert.Nil(t, err) + + q1 := &kolide.Query{ + Name: "bar", + Query: "bar", + } + _, err = ds.NewQuery(q1) + assert.Nil(t, err) + err = ds.AddQueryToPack(q1.ID, pack.ID) + assert.Nil(t, err) + + q2 := &kolide.Query{ + Name: "baz", + Query: "baz", + } + _, err = ds.NewQuery(q2) + assert.Nil(t, err) + err = ds.AddQueryToPack(q2.ID, pack.ID) + assert.Nil(t, err) + + queries, err := ds.GetQueriesInPack(pack) + assert.Nil(t, err) + assert.Len(t, queries, 2) + + err = ds.RemoveQueryFromPack(q1, pack) + assert.Nil(t, err) + + queries, err = ds.GetQueriesInPack(pack) + assert.Nil(t, err) + assert.Len(t, queries, 1) +} diff --git a/server/datastore/datastore_password_reset_test.go b/server/datastore/datastore_password_reset_test.go new file mode 100644 index 0000000000..861b4ed1ae --- /dev/null +++ b/server/datastore/datastore_password_reset_test.go @@ -0,0 +1,33 @@ +package datastore + +import ( + "testing" + "time" + + "github.com/kolide/kolide-ose/server/kolide" + "github.com/stretchr/testify/assert" +) + +func testPasswordResetRequests(t *testing.T, db kolide.Datastore) { + createTestUsers(t, db) + now := time.Now() + tomorrow := now.Add(time.Hour * 24) + var passwordResetTests = []struct { + userID uint + expires time.Time + token string + }{ + {userID: 1, expires: tomorrow, token: "abcd"}, + } + + for _, tt := range passwordResetTests { + r := &kolide.PasswordResetRequest{ + UserID: tt.userID, + ExpiresAt: tt.expires, + Token: tt.token, + } + req, err := db.NewPasswordResetRequest(r) + assert.Nil(t, err) + assert.Equal(t, tt.userID, req.UserID) + } +} diff --git a/server/datastore/datastore_queries_test.go b/server/datastore/datastore_queries_test.go new file mode 100644 index 0000000000..9390c656ec --- /dev/null +++ b/server/datastore/datastore_queries_test.go @@ -0,0 +1,44 @@ +package datastore + +import ( + "testing" + + "github.com/kolide/kolide-ose/server/kolide" + "github.com/stretchr/testify/assert" +) + +func testDeleteQuery(t *testing.T, ds kolide.Datastore) { + query := &kolide.Query{ + Name: "foo", + Query: "bar", + } + query, err := ds.NewQuery(query) + assert.Nil(t, err) + assert.NotEqual(t, query.ID, 0) + + err = ds.DeleteQuery(query) + assert.Nil(t, err) + + assert.NotEqual(t, query.ID, 0) + _, err = ds.Query(query.ID) + assert.NotNil(t, err) +} + +func testSaveQuery(t *testing.T, ds kolide.Datastore) { + query := &kolide.Query{ + Name: "foo", + Query: "bar", + } + query, err := ds.NewQuery(query) + assert.Nil(t, err) + assert.NotEqual(t, 0, query.ID) + + query.Query = "baz" + err = ds.SaveQuery(query) + + assert.Nil(t, err) + + queryVerify, err := ds.Query(query.ID) + assert.Nil(t, err) + assert.Equal(t, "baz", queryVerify.Query) +} diff --git a/server/datastore/datastore_test.go b/server/datastore/datastore_test.go index e517f48c6f..397ad7b895 100644 --- a/server/datastore/datastore_test.go +++ b/server/datastore/datastore_test.go @@ -1,626 +1,35 @@ package datastore import ( - "fmt" - "math/rand" - "sort" + "reflect" + "runtime" + "strings" "testing" - "time" - "github.com/jinzhu/gorm" "github.com/kolide/kolide-ose/server/kolide" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -const bcryptCost = 6 - -func TestPasswordResetRequests(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testPasswordResetRequests(t, db) -} - -func testPasswordResetRequests(t *testing.T, db kolide.Datastore) { - createTestUsers(t, db) - now := time.Now() - tomorrow := now.Add(time.Hour * 24) - var passwordResetTests = []struct { - userID uint - expires time.Time - token string - }{ - {userID: 1, expires: tomorrow, token: "abcd"}, - } - - for _, tt := range passwordResetTests { - r := &kolide.PasswordResetRequest{ - UserID: tt.userID, - ExpiresAt: tt.expires, - Token: tt.token, - } - req, err := db.NewPasswordResetRequest(r) - assert.Nil(t, err) - assert.Equal(t, tt.userID, req.UserID) - } -} - -func TestEnrollHost(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testEnrollHost(t, db) +func functionName(f func(*testing.T, kolide.Datastore)) string { + fullName := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() + elements := strings.Split(fullName, ".") + return elements[len(elements)-1] } -func TestAuthenticateHost(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testAuthenticateHost(t, db) - -} - -var enrollTests = []struct { - uuid, hostname, ip, platform string - nodeKeySize int -}{ - 0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26", - hostname: "web.kolide.co", - ip: "172.0.0.1", - platform: "linux", - nodeKeySize: 12, - }, - 1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B", - hostname: "mail.kolide.co", - ip: "172.0.0.2", - platform: "linux", - nodeKeySize: 10, - }, - 2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2", - hostname: "home.kolide.co", - ip: "127.0.0.1", - platform: "darwin", - nodeKeySize: 25, - }, - 3: {uuid: "uuid123", - hostname: "fakehostname", - ip: "192.168.1.1", - platform: "darwin", - nodeKeySize: 1, - }, -} - -func testEnrollHost(t *testing.T, db kolide.HostStore) { - var hosts []*kolide.Host - for _, tt := range enrollTests { - h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize) - assert.Nil(t, err) - - hosts = append(hosts, h) - assert.Equal(t, tt.uuid, h.UUID) - assert.Equal(t, tt.hostname, h.HostName) - assert.Equal(t, tt.ip, h.IPAddress) - assert.Equal(t, tt.platform, h.Platform) - assert.NotEmpty(t, h.NodeKey) - } - - for _, enrolled := range hosts { - oldNodeKey := enrolled.NodeKey - newhostname := fmt.Sprintf("changed.%s", enrolled.HostName) - - h, err := db.EnrollHost(enrolled.UUID, newhostname, enrolled.IPAddress, enrolled.Platform, 15) - assert.Nil(t, err) - assert.Equal(t, enrolled.UUID, h.UUID) - assert.NotEmpty(t, h.NodeKey) - assert.NotEqual(t, oldNodeKey, h.NodeKey) - } - -} - -func testAuthenticateHost(t *testing.T, db kolide.HostStore) { - for _, tt := range enrollTests { - h, err := db.EnrollHost(tt.uuid, tt.hostname, tt.ip, tt.platform, tt.nodeKeySize) - assert.Nil(t, err) - - returned, err := db.AuthenticateHost(h.NodeKey) - assert.Nil(t, err) - assert.Equal(t, h.NodeKey, returned.NodeKey) - } - - _, err := db.AuthenticateHost("7B1A9DC9-B042-489F-8D5A-EEC2412C95AA") - assert.NotNil(t, err) - - _, err = db.AuthenticateHost("") - assert.NotNil(t, err) -} - -// TestUser tests the UserStore interface -// this test uses the default testing backend -func TestCreateUser(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testCreateUser(t, db) -} - -func TestSaveUser(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testSaveUser(t, db) -} - -func testCreateUser(t *testing.T, db kolide.UserStore) { - var createTests = []struct { - username, password, email string - isAdmin, passwordReset bool - }{ - {"marpaia", "foobar", "mike@kolide.co", true, false}, - {"jason", "foobar", "jason@kolide.co", true, false}, - } - - for _, tt := range createTests { - u := &kolide.User{ - Username: tt.username, - Password: []byte(tt.password), - Admin: tt.isAdmin, - AdminForcedPasswordReset: tt.passwordReset, - Email: tt.email, - } - user, err := db.NewUser(u) - assert.Nil(t, err) - - verify, err := db.User(tt.username) - assert.Nil(t, err) - - assert.Equal(t, user.ID, verify.ID) - assert.Equal(t, tt.username, verify.Username) - assert.Equal(t, tt.email, verify.Email) - assert.Equal(t, tt.email, verify.Email) - } -} - -func TestUserByID(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testUserByID(t, db) -} - -func testUserByID(t *testing.T, db kolide.UserStore) { - users := createTestUsers(t, db) - for _, tt := range users { - returned, err := db.UserByID(tt.ID) - assert.Nil(t, err) - assert.Equal(t, tt.ID, returned.ID) - } - - // test missing user - _, err := db.UserByID(10000000000) - assert.NotNil(t, err) -} - -func createTestUsers(t *testing.T, db kolide.UserStore) []*kolide.User { - var createTests = []struct { - username, password, email string - isAdmin, passwordReset bool - }{ - {"marpaia", "foobar", "mike@kolide.co", true, false}, - {"jason", "foobar", "jason@kolide.co", false, false}, - } - - var users []*kolide.User - for _, tt := range createTests { - u := &kolide.User{ - Username: tt.username, - Password: []byte(tt.password), - Admin: tt.isAdmin, - AdminForcedPasswordReset: tt.passwordReset, - Email: tt.email, - } - - user, err := db.NewUser(u) - assert.Nil(t, err) - - users = append(users, user) - } - assert.NotEmpty(t, users) - return users -} - -func testSaveUser(t *testing.T, db kolide.UserStore) { - users := createTestUsers(t, db) - testAdminAttribute(t, db, users) - testEmailAttribute(t, db, users) - testPasswordAttribute(t, db, users) -} - -func testPasswordAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) { - for _, user := range users { - user.Password = []byte(randomString(8)) - err := db.SaveUser(user) - assert.Nil(t, err) - - verify, err := db.User(user.Username) - assert.Nil(t, err) - assert.Equal(t, user.Password, verify.Password) - } -} - -func testEmailAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) { - for _, user := range users { - user.Email = fmt.Sprintf("test.%s", user.Email) - err := db.SaveUser(user) - assert.Nil(t, err) - - verify, err := db.User(user.Username) - assert.Nil(t, err) - assert.Equal(t, user.Email, verify.Email) - } -} - -func testAdminAttribute(t *testing.T, db kolide.UserStore, users []*kolide.User) { - for _, user := range users { - user.Admin = false - err := db.SaveUser(user) - assert.Nil(t, err) - - verify, err := db.User(user.Username) - assert.Nil(t, err) - assert.Equal(t, user.Admin, verify.Admin) - } -} - -func TestLabelQueries(t *testing.T) { - db := setup(t) - defer teardown(t, db) - - testLabels(t, db) -} - -func testLabels(t *testing.T, db kolide.Datastore) { - hosts := []kolide.Host{} - var host *kolide.Host - var err error - for i := 0; i < 10; i++ { - host, err = db.EnrollHost(string(i), "foo", "", "", 10) - assert.Nil(t, err, "enrollment should succeed") - hosts = append(hosts, *host) - } - - baseTime := time.Now() - - // No queries should be returned before labels or queries added - queries, err := db.LabelQueriesForHost(host, baseTime) - assert.Nil(t, err) - assert.Empty(t, queries) - - // No labels should match - labels, err := db.LabelsForHost(host.ID) - assert.Nil(t, err) - assert.Empty(t, labels) - - labelQueries := []kolide.Query{ - kolide.Query{ - Name: "query1", - Query: "query1", - Platform: "darwin", - }, - kolide.Query{ - Name: "query2", - Query: "query2", - Platform: "darwin", - }, - kolide.Query{ - Name: "query3", - Query: "query3", - Platform: "darwin", - }, - kolide.Query{ - Name: "query4", - Query: "query4", - Platform: "darwin", - }, - } - - for _, query := range labelQueries { - newQuery, err := db.NewQuery(&query) - assert.Nil(t, err) - assert.NotZero(t, newQuery.ID) - } - - // this one should not show up - _, err = db.NewQuery(&kolide.Query{ - Platform: "not_darwin", - Query: "query5", - }) - assert.Nil(t, err) - - // No queries should be returned before labels added - queries, err = db.LabelQueriesForHost(host, baseTime) - assert.Nil(t, err) - assert.Empty(t, queries) - - newLabels := []kolide.Label{ - // Note these are intentionally out of order - kolide.Label{ - Name: "label3", - QueryID: 3, - }, - kolide.Label{ - Name: "label1", - QueryID: 1, - }, - kolide.Label{ - Name: "label2", - QueryID: 2, - }, - kolide.Label{ - Name: "label4", - QueryID: 4, - }, - } - - for _, label := range newLabels { - newLabel, err := db.NewLabel(&label) - assert.Nil(t, err) - assert.NotZero(t, newLabel.ID) - } - - expectQueries := map[string]string{ - "1": "query3", - "2": "query1", - "3": "query2", - "4": "query4", - } - - host.Platform = "darwin" - - // Now queries should be returned - queries, err = db.LabelQueriesForHost(host, baseTime) - assert.Nil(t, err) - assert.Equal(t, expectQueries, queries) - - // No labels should match with no results yet - labels, err = db.LabelsForHost(host.ID) - assert.Nil(t, err) - assert.Empty(t, labels) - - // Record a query execution - err = db.RecordLabelQueryExecutions(host, map[string]bool{"1": true}, baseTime) - assert.Nil(t, err) - - // Use a 10 minute interval, so the query we just added should show up - queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute))) - assert.Nil(t, err) - delete(expectQueries, "1") - assert.Equal(t, expectQueries, queries) - - // Record an old query execution -- Shouldn't change the return - err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": true}, baseTime.Add(-1*time.Hour)) - assert.Nil(t, err) - queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute))) - assert.Nil(t, err) - assert.Equal(t, expectQueries, queries) - - // Record a newer execution for that query and another - err = db.RecordLabelQueryExecutions(host, map[string]bool{"2": false, "3": true}, baseTime) - assert.Nil(t, err) - - // Now these should no longer show up in the necessary to run queries - delete(expectQueries, "2") - delete(expectQueries, "3") - queries, err = db.LabelQueriesForHost(host, time.Now().Add(-(10 * time.Minute))) - assert.Nil(t, err) - assert.Equal(t, expectQueries, queries) - - // Now the two matching labels should be returned - labels, err = db.LabelsForHost(host.ID) - assert.Nil(t, err) - if assert.Len(t, labels, 2) { - labelNames := []string{labels[0].Name, labels[1].Name} - sort.Strings(labelNames) - assert.Equal(t, "label2", labelNames[0]) - assert.Equal(t, "label3", labelNames[1]) - } - - // A host that hasn't executed any label queries should still be asked - // to execute those queries - hosts[0].Platform = "darwin" - queries, err = db.LabelQueriesForHost(host, time.Now()) - assert.Nil(t, err) - assert.Len(t, queries, 4) - - // There should still be no labels returned for a host that never - // executed any label queries - labels, err = db.LabelsForHost(hosts[0].ID) - assert.Nil(t, err) - assert.Empty(t, labels) -} - -// setup creates a datastore for testing -func setup(t *testing.T) kolide.Datastore { - db, err := gorm.Open("sqlite3", ":memory:") - require.Nil(t, err) - - ds := gormDB{DB: db, Driver: "sqlite3"} - - err = ds.Migrate() - assert.Nil(t, err) - // Log using t.Log so that output only shows up if the test fails - //db.SetLogger(&testLogger{t: t}) - //db.LogMode(true) - return ds -} - -func teardown(t *testing.T, ds kolide.Datastore) { - err := ds.Drop() - assert.Nil(t, err) -} - -type testLogger struct { - t *testing.T -} - -func (t *testLogger) Print(v ...interface{}) { - t.t.Log(v...) -} - -func (t *testLogger) Write(p []byte) (n int, err error) { - t.t.Log(string(p)) - return len(p), nil -} - -func randomString(strlen int) string { - rand.Seed(time.Now().UTC().UnixNano()) - const chars = "abcdefghijklmnopqrstuvwxyz0123456789" - result := make([]byte, strlen) - for i := 0; i < strlen; i++ { - result[i] = chars[rand.Intn(len(chars))] - } - return string(result) -} - -func testSaveQuery(t *testing.T, ds kolide.Datastore) { - query := &kolide.Query{ - Name: "foo", - Query: "bar", - } - query, err := ds.NewQuery(query) - assert.Nil(t, err) - assert.NotEqual(t, 0, query.ID) - - query.Query = "baz" - err = ds.SaveQuery(query) - assert.Nil(t, err) - - queryVerify, err := ds.Query(query.ID) - assert.Nil(t, err) - assert.Equal(t, "baz", queryVerify.Query) -} - -func testDeleteQuery(t *testing.T, ds kolide.Datastore) { - query := &kolide.Query{ - Name: "foo", - Query: "bar", - } - query, err := ds.NewQuery(query) - assert.Nil(t, err) - assert.NotEqual(t, query.ID, 0) - - err = ds.DeleteQuery(query) - assert.Nil(t, err) - - assert.NotEqual(t, query.ID, 0) - _, err = ds.Query(query.ID) - assert.NotNil(t, err) -} - -func testDeletePack(t *testing.T, ds kolide.Datastore) { - pack := &kolide.Pack{ - Name: "foo", - } - err := ds.NewPack(pack) - assert.Nil(t, err) - assert.NotEqual(t, pack.ID, 0) - - pack, err = ds.Pack(pack.ID) - assert.Nil(t, err) - - err = ds.DeletePack(pack.ID) - assert.Nil(t, err) - - assert.NotEqual(t, pack.ID, 0) - pack, err = ds.Pack(pack.ID) - assert.NotNil(t, err) -} - -func testAddAndRemoveQueryFromPack(t *testing.T, ds kolide.Datastore) { - pack := &kolide.Pack{ - Name: "foo", - } - err := ds.NewPack(pack) - assert.Nil(t, err) - - q1 := &kolide.Query{ - Name: "bar", - Query: "bar", - } - _, err = ds.NewQuery(q1) - assert.Nil(t, err) - err = ds.AddQueryToPack(q1.ID, pack.ID) - assert.Nil(t, err) - - q2 := &kolide.Query{ - Name: "baz", - Query: "baz", - } - _, err = ds.NewQuery(q2) - assert.Nil(t, err) - err = ds.AddQueryToPack(q2.ID, pack.ID) - assert.Nil(t, err) - - queries, err := ds.GetQueriesInPack(pack) - assert.Nil(t, err) - assert.Len(t, queries, 2) - - err = ds.RemoveQueryFromPack(q1, pack) - assert.Nil(t, err) - - queries, err = ds.GetQueriesInPack(pack) - assert.Nil(t, err) - assert.Len(t, queries, 1) -} - -func testManagingLabelsOnPacks(t *testing.T, ds kolide.Datastore) { - mysqlQuery := &kolide.Query{ - Name: "MySQL", - Query: "select pid from processes where name = 'mysqld';", - } - mysqlQuery, err := ds.NewQuery(mysqlQuery) - assert.Nil(t, err) - - osqueryRunningQuery := &kolide.Query{ - Name: "Is osquery currently running?", - Query: "select pid from processes where name = 'osqueryd';", - } - osqueryRunningQuery, err = ds.NewQuery(osqueryRunningQuery) - assert.Nil(t, err) - - monitoringPack := &kolide.Pack{ - Name: "monitoring", - } - err = ds.NewPack(monitoringPack) - assert.Nil(t, err) - - mysqlLabel := &kolide.Label{ - Name: "MySQL Monitoring", - QueryID: mysqlQuery.ID, - } - mysqlLabel, err = ds.NewLabel(mysqlLabel) - assert.Nil(t, err) - - err = ds.AddLabelToPack(mysqlLabel.ID, monitoringPack.ID) - assert.Nil(t, err) - - labels, err := ds.GetLabelsForPack(monitoringPack) - assert.Nil(t, err) - assert.Len(t, labels, 1) - assert.Equal(t, "MySQL Monitoring", labels[0].Name) - - osqueryLabel := &kolide.Label{ - Name: "Osquery Monitoring", - QueryID: osqueryRunningQuery.ID, - } - osqueryLabel, err = ds.NewLabel(osqueryLabel) - assert.Nil(t, err) - - err = ds.AddLabelToPack(osqueryLabel.ID, monitoringPack.ID) - assert.Nil(t, err) - - labels, err = ds.GetLabelsForPack(monitoringPack) - assert.Nil(t, err) - assert.Len(t, labels, 2) +var testFunctions = [...]func(*testing.T, kolide.Datastore){ + testOrgInfo, + testCreateInvite, + testDeleteQuery, + testSaveQuery, + testDeletePack, + testAddAndRemoveQueryFromPack, + testEnrollHost, + testAuthenticateHost, + testLabels, + testManagingLabelsOnPacks, + testPasswordResetRequests, + testCreateUser, + testSaveUser, + testUserByID, + testPasswordResetRequests, } diff --git a/server/datastore/datastore_users_test.go b/server/datastore/datastore_users_test.go new file mode 100644 index 0000000000..fa81eedba2 --- /dev/null +++ b/server/datastore/datastore_users_test.go @@ -0,0 +1,125 @@ +package datastore + +import ( + "fmt" + "testing" + + "github.com/kolide/kolide-ose/server/kolide" + "github.com/stretchr/testify/assert" +) + +func testCreateUser(t *testing.T, ds kolide.Datastore) { + var createTests = []struct { + username, password, email string + isAdmin, passwordReset bool + }{ + {"marpaia", "foobar", "mike@kolide.co", true, false}, + {"jason", "foobar", "jason@kolide.co", true, false}, + } + + for _, tt := range createTests { + u := &kolide.User{ + Username: tt.username, + Password: []byte(tt.password), + Admin: tt.isAdmin, + AdminForcedPasswordReset: tt.passwordReset, + Email: tt.email, + } + user, err := ds.NewUser(u) + assert.Nil(t, err) + + verify, err := ds.User(tt.username) + assert.Nil(t, err) + + assert.Equal(t, user.ID, verify.ID) + assert.Equal(t, tt.username, verify.Username) + assert.Equal(t, tt.email, verify.Email) + assert.Equal(t, tt.email, verify.Email) + } +} + +func testUserByID(t *testing.T, ds kolide.Datastore) { + users := createTestUsers(t, ds) + for _, tt := range users { + returned, err := ds.UserByID(tt.ID) + assert.Nil(t, err) + assert.Equal(t, tt.ID, returned.ID) + } + + // test missing user + _, err := ds.UserByID(10000000000) + assert.NotNil(t, err) +} + +func createTestUsers(t *testing.T, ds kolide.Datastore) []*kolide.User { + var createTests = []struct { + username, password, email string + isAdmin, passwordReset bool + }{ + {"marpaia", "foobar", "mike@kolide.co", true, false}, + {"jason", "foobar", "jason@kolide.co", false, false}, + } + + var users []*kolide.User + for _, tt := range createTests { + u := &kolide.User{ + Username: tt.username, + Password: []byte(tt.password), + Admin: tt.isAdmin, + AdminForcedPasswordReset: tt.passwordReset, + Email: tt.email, + } + + user, err := ds.NewUser(u) + assert.Nil(t, err) + + users = append(users, user) + } + assert.NotEmpty(t, users) + return users +} + +func testSaveUser(t *testing.T, ds kolide.Datastore) { + users := createTestUsers(t, ds) + testAdminAttribute(t, ds, users) + testEmailAttribute(t, ds, users) + testPasswordAttribute(t, ds, users) +} + +func testPasswordAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) { + for _, user := range users { + randomText, err := generateRandomText(8) + assert.Nil(t, err) + user.Password = []byte(randomText) + err = ds.SaveUser(user) + assert.Nil(t, err) + + verify, err := ds.User(user.Username) + assert.Nil(t, err) + assert.Equal(t, user.Password, verify.Password) + } +} + +func testEmailAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) { + for _, user := range users { + user.Email = fmt.Sprintf("test.%s", user.Email) + err := ds.SaveUser(user) + assert.Nil(t, err) + + verify, err := ds.User(user.Username) + assert.Nil(t, err) + assert.Equal(t, user.Email, verify.Email) + } +} + +func testAdminAttribute(t *testing.T, ds kolide.Datastore, users []*kolide.User) { + for _, user := range users { + user.Admin = false + err := ds.SaveUser(user) + assert.Nil(t, err) + + verify, err := ds.User(user.Username) + assert.Nil(t, err) + assert.Equal(t, user.Admin, verify.Admin) + } +} diff --git a/server/datastore/gorm.go b/server/datastore/gorm.go index 7180367b2a..b7d300682c 100644 --- a/server/datastore/gorm.go +++ b/server/datastore/gorm.go @@ -1,12 +1,7 @@ package datastore import ( - "bytes" - "crypto/rand" - "encoding/base64" "fmt" - "net/http" - "strings" "time" _ "github.com/go-sql-driver/mysql" // db driver @@ -14,7 +9,6 @@ import ( "github.com/jinzhu/gorm" "github.com/kolide/kolide-ose/server/config" - "github.com/kolide/kolide-ose/server/errors" "github.com/kolide/kolide-ose/server/kolide" ) @@ -43,18 +37,6 @@ type gormDB struct { config config.KolideConfig } -// GetMysqlConnectionString returns a MySQL connection string using the -// provided configuration. -func GetMysqlConnectionString(conf config.MysqlConfig) string { - return fmt.Sprintf( - "%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", - conf.Username, - conf.Password, - conf.Address, - conf.Database, - ) -} - func (orm gormDB) Name() string { return "gorm" } @@ -102,579 +84,3 @@ func openGORM(driver, conn string, maxAttempts int) (*gorm.DB, error) { } return db, nil } - -func generateRandomText(keySize int) (string, error) { - key := make([]byte, keySize) - _, err := rand.Read(key) - if err != nil { - return "", err - } - - return base64.StdEncoding.EncodeToString(key), nil -} - -func (orm gormDB) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) { - if uuid == "" { - return nil, errors.New("missing uuid for host enrollment", "programmer error?") - } - host := kolide.Host{UUID: uuid} - err := orm.DB.Where(&host).First(&host).Error - if err != nil { - switch err { - case gorm.ErrRecordNotFound: - // Create new Host - host = kolide.Host{ - UUID: uuid, - HostName: hostname, - IPAddress: ip, - Platform: platform, - } - - default: - return nil, err - } - } - - // Generate a new key each enrollment - host.NodeKey, err = generateRandomText(nodeKeySize) - if err != nil { - return nil, err - } - - // Update these fields if provided - if hostname != "" { - host.HostName = hostname - } - if ip != "" { - host.IPAddress = ip - } - if platform != "" { - host.Platform = platform - } - - if err := orm.DB.Save(&host).Error; err != nil { - return nil, err - } - - return &host, nil -} - -func (orm gormDB) AuthenticateHost(nodeKey string) (*kolide.Host, error) { - host := kolide.Host{NodeKey: nodeKey} - err := orm.DB.Where("node_key = ?", host.NodeKey).First(&host).Error - if err != nil { - switch err { - case gorm.ErrRecordNotFound: - e := errors.NewFromError( - err, - http.StatusUnauthorized, - "invalid node key", - ) - // osqueryd expects the literal string "true" here - e.Extra = map[string]interface{}{"node_invalid": "true"} - return nil, e - default: - return nil, errors.DatabaseError(err) - } - } - - return &host, nil -} - -func (orm gormDB) SaveHost(host *kolide.Host) error { - if err := orm.DB.Save(host).Error; err != nil { - return errors.DatabaseError(err) - } - return nil -} - -func (orm gormDB) DeleteHost(host *kolide.Host) error { - return orm.DB.Delete(host).Error -} - -func (orm gormDB) Host(id uint) (*kolide.Host, error) { - host := &kolide.Host{ - ID: id, - } - err := orm.DB.Where(host).First(host).Error - if err != nil { - return nil, err - } - return host, nil -} - -func (orm gormDB) Hosts() ([]*kolide.Host, error) { - var hosts []*kolide.Host - err := orm.DB.Find(&hosts).Error - if err != nil { - return nil, err - } - return hosts, nil -} - -func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) { - if host == nil { - return nil, errors.New( - "error creating host", - "nil pointer passed to NewHost", - ) - } - err := orm.DB.Create(host).Error - if err != nil { - return nil, err - } - return host, err -} - -func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error { - err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error - if err != nil { - return errors.DatabaseError(err) - } - host.UpdatedAt = t - return nil -} - -func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) { - if query == nil { - return nil, errors.New( - "error creating query", - "nil pointer passed to NewQuery", - ) - } - err := orm.DB.Create(query).Error - if err != nil { - return nil, err - } - return query, nil -} - -func (orm gormDB) SaveQuery(query *kolide.Query) error { - if query == nil { - return errors.New( - "error saving query", - "nil pointer passed to SaveQuery", - ) - } - return orm.DB.Save(query).Error -} - -func (orm gormDB) DeleteQuery(query *kolide.Query) error { - if query == nil { - return errors.New( - "error deleting query", - "nil pointer passed to DeleteQuery", - ) - } - return orm.DB.Delete(query).Error -} - -func (orm gormDB) Query(id uint) (*kolide.Query, error) { - query := &kolide.Query{ - ID: id, - } - err := orm.DB.Where(query).First(query).Error - if err != nil { - return nil, err - } - return query, nil -} - -func (orm gormDB) Queries() ([]*kolide.Query, error) { - var queries []*kolide.Query - err := orm.DB.Find(&queries).Error - return queries, err -} - -func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) { - if label == nil { - return nil, errors.New( - "error creating label", - "nil pointer passed to NewLabel", - ) - } - err := orm.DB.Create(label).Error - if err != nil { - return nil, err - } - return label, nil -} - -func (orm gormDB) SaveLabel(label *kolide.Label) error { - if label == nil { - return errors.New( - "error saving label", - "nil pointer passed to SaveLabel", - ) - } - return orm.DB.Save(label).Error -} - -func (orm gormDB) DeleteLabel(lid uint) error { - err := orm.DB.Where("id = ?", lid).Delete(&kolide.Label{}).Error - if err != nil { - return err - } - - return orm.DB.Where("target_id = ? and type = ?", lid, kolide.TargetLabel).Delete(&kolide.PackTarget{}).Error -} - -func (orm gormDB) Label(lid uint) (*kolide.Label, error) { - label := &kolide.Label{ - ID: lid, - } - err := orm.DB.Where("id = ?", label.ID).First(&label).Error - if err != nil { - return nil, err - } - return label, nil -} - -func (orm gormDB) Labels() ([]*kolide.Label, error) { - var labels []*kolide.Label - err := orm.DB.Find(&labels).Error - return labels, err -} - -func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) { - if host == nil { - return nil, errors.New( - "error finding host queries", - "nil pointer passed to LabelQueriesForHost", - ) - } - rows, err := orm.DB.Raw(` -SELECT l.id, q.query -FROM labels l JOIN queries q -ON l.query_id = q.id -WHERE q.platform = ? -AND q.id NOT IN /* subtract the set of executions that are recent enough */ -( - SELECT l.query_id - FROM labels l - JOIN label_query_executions lqe - ON lqe.label_id = l.id - WHERE lqe.host_id = ? AND lqe.updated_at > ? -)`, host.Platform, host.ID, cutoff).Rows() - if err != nil && err != gorm.ErrRecordNotFound { - return nil, errors.DatabaseError(err) - } - defer rows.Close() - - results := make(map[string]string) - for rows.Next() { - var id, query string - err = rows.Scan(&id, &query) - if err != nil { - return nil, errors.DatabaseError(err) - } - results[id] = query - } - - return results, nil -} - -func (orm gormDB) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error { - if host == nil { - return errors.New( - "error recording host label query execution", - "nil pointer passed to RecordLabelQueryExecutions", - ) - } - - insert := new(bytes.Buffer) - switch orm.Driver { - case "mysql": - insert.WriteString("INSERT ") - case "sqlite3": - insert.WriteString("REPLACE ") - default: - return errors.New( - "Unknown DB driver", - "Tried to use unknown DB driver in RecordLabelQueryExecutions: "+orm.Driver, - ) - } - - insert.WriteString( - "INTO label_query_executions (updated_at, matches, label_id, host_id) VALUES", - ) - - // Build up all the values and the query string - vals := []interface{}{} - for labelId, res := range results { - insert.WriteString("(?,?,?,?),") - vals = append(vals, t, res, labelId, host.ID) - } - - queryString := insert.String() - queryString = strings.TrimSuffix(queryString, ",") - - switch orm.Driver { - case "mysql": - queryString += ` -ON DUPLICATE KEY UPDATE -updated_at = VALUES(updated_at), -matches = VALUES(matches) -` - } - - if err := orm.DB.Exec(queryString, vals...).Error; err != nil { - return errors.DatabaseError(err) - } - - return nil -} - -func (orm gormDB) LabelsForHost(hid uint) ([]kolide.Label, error) { - results := []kolide.Label{} - err := orm.DB.Raw(` -SELECT labels.* from labels, label_query_executions lqe -WHERE lqe.host_id = ? -AND lqe.label_id = labels.id -AND lqe.matches -`, hid).Scan(&results).Error - - if err != nil && err != gorm.ErrRecordNotFound { - return nil, errors.DatabaseError(err) - } - - return results, nil -} - -func (orm gormDB) NewPack(pack *kolide.Pack) error { - if pack == nil { - return errors.New( - "error creating pack", - "nil pointer passed to NewPack", - ) - } - return orm.DB.Create(pack).Error -} - -func (orm gormDB) SavePack(pack *kolide.Pack) error { - if pack == nil { - return errors.New( - "error saving pack", - "nil pointer passed to SavePack", - ) - } - return orm.DB.Save(pack).Error -} - -func (orm gormDB) DeletePack(pid uint) error { - err := orm.DB.Where("id = ?", pid).Delete(&kolide.Pack{}).Error - if err != nil { - return err - } - - err = orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackQuery{}).Error - if err != nil { - return err - } - return orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackTarget{}).Error -} - -func (orm gormDB) Pack(pid uint) (*kolide.Pack, error) { - pack := &kolide.Pack{ - ID: pid, - } - err := orm.DB.Where(pack).First(pack).Error - if err != nil { - return nil, err - } - return pack, nil -} - -func (orm gormDB) Packs() ([]*kolide.Pack, error) { - var packs []*kolide.Pack - err := orm.DB.Find(&packs).Error - return packs, err -} - -func (orm gormDB) AddQueryToPack(qid uint, pid uint) error { - pq := &kolide.PackQuery{ - QueryID: qid, - PackID: pid, - } - return orm.DB.Create(pq).Error -} - -func (orm gormDB) GetQueriesInPack(pack *kolide.Pack) ([]*kolide.Query, error) { - var queries []*kolide.Query - if pack == nil { - return nil, errors.New( - "error getting queries in pack", - "nil pointer passed to GetQueriesInPack", - ) - } - - rows, err := orm.DB.Raw(` -SELECT - q.id, - q.created_at, - q.updated_at, - q.name, - q.query, - q.interval, - q.snapshot, - q.differential, - q.platform, - q.version -FROM - queries q -JOIN - pack_queries pq -ON - pq.query_id = q.id -AND - pq.pack_id = ?; -`, pack.ID).Rows() - if err != nil && err != gorm.ErrRecordNotFound { - return nil, errors.DatabaseError(err) - } - defer rows.Close() - - for rows.Next() { - query := new(kolide.Query) - err = rows.Scan( - &query.ID, - &query.CreatedAt, - &query.UpdatedAt, - &query.Name, - &query.Query, - &query.Interval, - &query.Snapshot, - &query.Differential, - &query.Platform, - &query.Version, - ) - if err != nil { - return nil, err - } - queries = append(queries, query) - } - - return queries, nil -} - -func (orm gormDB) RemoveQueryFromPack(query *kolide.Query, pack *kolide.Pack) error { - if query == nil || pack == nil { - return errors.New( - "error removing query from pack", - "nil pointer passed to RemoveQueryFromPack", - ) - } - pq := &kolide.PackQuery{ - QueryID: query.ID, - PackID: pack.ID, - } - return orm.DB.Where(pq).Delete(pq).Error -} - -func (orm gormDB) AddLabelToPack(lid uint, pid uint) error { - pt := &kolide.PackTarget{ - Type: kolide.TargetLabel, - PackID: pid, - TargetID: lid, - } - - return orm.DB.Create(pt).Error -} - -func (orm gormDB) ActivePacksForHost(hid uint) ([]*kolide.Pack, error) { - packs := []*kolide.Pack{} - - // we will need to give some subset of packs to this host based on the - // labels which this host is known to belong to - allPacks, err := orm.Packs() - if err != nil { - return nil, err - } - - // pull the labels that this host belongs to - labels, err := orm.LabelsForHost(hid) - if err != nil { - return nil, err - } - - // in order to use o(1) array indexing in an o(n) loop vs a o(n^2) double - // for loop iteration, we must create the array which may be indexed below - labelIDs := map[uint]bool{} - for _, label := range labels { - labelIDs[label.ID] = true - } - - for _, pack := range allPacks { - // for each pack, we must know what labels have been assigned to that - // pack - labelsForPack, err := orm.GetLabelsForPack(pack) - if err != nil { - return nil, err - } - - // o(n) iteration to determine whether or not a pack is enabled - // in this case, n is len(labelsForPack) - for _, label := range labelsForPack { - if labelIDs[label.ID] { - packs = append(packs, pack) - break - } - } - } - - return packs, nil -} - -func (orm gormDB) GetLabelsForPack(pack *kolide.Pack) ([]*kolide.Label, error) { - if pack == nil { - return nil, errors.New( - "error getting labels for pack", - "nil pointer passed to GetLabelsForPack", - ) - } - - results := []*kolide.Label{} - err := orm.DB.Raw(` -SELECT - l.id, - l.created_at, - l.updated_at, - l.name, - l.query_id -FROM - labels l -JOIN - pack_targets pt -ON - pt.target_id = l.id -WHERE - pt.type = ? - AND - pt.pack_id = ?; - -`, - kolide.TargetLabel, pack.ID).Scan(&results).Error - - if err != nil && err != gorm.ErrRecordNotFound { - return nil, errors.DatabaseError(err) - } - - return results, nil -} - -func (orm gormDB) RemoveLabelFromPack(label *kolide.Label, pack *kolide.Pack) error { - if label == nil || pack == nil { - return errors.New( - "error removing label from pack", - "nil pointer passed to RemoveLabelFromPack", - ) - } - - pt := &kolide.PackTarget{ - Type: kolide.TargetLabel, - PackID: pack.ID, - TargetID: label.ID, - } - - return orm.DB.Delete(pt).Error -} diff --git a/server/datastore/gorm_hosts.go b/server/datastore/gorm_hosts.go new file mode 100644 index 0000000000..4b2fcc181d --- /dev/null +++ b/server/datastore/gorm_hosts.go @@ -0,0 +1,132 @@ +package datastore + +import ( + "net/http" + "time" + + "github.com/jinzhu/gorm" + "github.com/kolide/kolide-ose/server/errors" + "github.com/kolide/kolide-ose/server/kolide" +) + +func (orm gormDB) EnrollHost(uuid, hostname, ip, platform string, nodeKeySize int) (*kolide.Host, error) { + if uuid == "" { + return nil, errors.New("missing uuid for host enrollment", "programmer error?") + } + host := kolide.Host{UUID: uuid} + err := orm.DB.Where(&host).First(&host).Error + if err != nil { + switch err { + case gorm.ErrRecordNotFound: + // Create new Host + host = kolide.Host{ + UUID: uuid, + HostName: hostname, + IPAddress: ip, + Platform: platform, + } + + default: + return nil, err + } + } + + // Generate a new key each enrollment + host.NodeKey, err = generateRandomText(nodeKeySize) + if err != nil { + return nil, err + } + + // Update these fields if provided + if hostname != "" { + host.HostName = hostname + } + if ip != "" { + host.IPAddress = ip + } + if platform != "" { + host.Platform = platform + } + + if err := orm.DB.Save(&host).Error; err != nil { + return nil, err + } + + return &host, nil +} + +func (orm gormDB) AuthenticateHost(nodeKey string) (*kolide.Host, error) { + host := kolide.Host{NodeKey: nodeKey} + err := orm.DB.Where("node_key = ?", host.NodeKey).First(&host).Error + if err != nil { + switch err { + case gorm.ErrRecordNotFound: + e := errors.NewFromError( + err, + http.StatusUnauthorized, + "invalid node key", + ) + // osqueryd expects the literal string "true" here + e.Extra = map[string]interface{}{"node_invalid": "true"} + return nil, e + default: + return nil, errors.DatabaseError(err) + } + } + + return &host, nil +} + +func (orm gormDB) SaveHost(host *kolide.Host) error { + if err := orm.DB.Save(host).Error; err != nil { + return errors.DatabaseError(err) + } + return nil +} + +func (orm gormDB) DeleteHost(host *kolide.Host) error { + return orm.DB.Delete(host).Error +} + +func (orm gormDB) Host(id uint) (*kolide.Host, error) { + host := &kolide.Host{ + ID: id, + } + err := orm.DB.Where(host).First(host).Error + if err != nil { + return nil, err + } + return host, nil +} + +func (orm gormDB) Hosts() ([]*kolide.Host, error) { + var hosts []*kolide.Host + err := orm.DB.Find(&hosts).Error + if err != nil { + return nil, err + } + return hosts, nil +} + +func (orm gormDB) NewHost(host *kolide.Host) (*kolide.Host, error) { + if host == nil { + return nil, errors.New( + "error creating host", + "nil pointer passed to NewHost", + ) + } + err := orm.DB.Create(host).Error + if err != nil { + return nil, err + } + return host, err +} + +func (orm gormDB) MarkHostSeen(host *kolide.Host, t time.Time) error { + err := orm.DB.Exec("UPDATE hosts SET updated_at=? WHERE node_key=?", t, host.NodeKey).Error + if err != nil { + return errors.DatabaseError(err) + } + host.UpdatedAt = t + return nil +} diff --git a/server/datastore/gorm_labels.go b/server/datastore/gorm_labels.go new file mode 100644 index 0000000000..55bd971d3f --- /dev/null +++ b/server/datastore/gorm_labels.go @@ -0,0 +1,166 @@ +package datastore + +import ( + "bytes" + "strings" + "time" + + "github.com/jinzhu/gorm" + "github.com/kolide/kolide-ose/server/errors" + "github.com/kolide/kolide-ose/server/kolide" +) + +func (orm gormDB) NewLabel(label *kolide.Label) (*kolide.Label, error) { + if label == nil { + return nil, errors.New( + "error creating label", + "nil pointer passed to NewLabel", + ) + } + err := orm.DB.Create(label).Error + if err != nil { + return nil, err + } + return label, nil +} + +func (orm gormDB) SaveLabel(label *kolide.Label) error { + if label == nil { + return errors.New( + "error saving label", + "nil pointer passed to SaveLabel", + ) + } + return orm.DB.Save(label).Error +} + +func (orm gormDB) DeleteLabel(lid uint) error { + err := orm.DB.Where("id = ?", lid).Delete(&kolide.Label{}).Error + if err != nil { + return err + } + + return orm.DB.Where("target_id = ? and type = ?", lid, kolide.TargetLabel).Delete(&kolide.PackTarget{}).Error +} + +func (orm gormDB) Label(lid uint) (*kolide.Label, error) { + label := &kolide.Label{ + ID: lid, + } + err := orm.DB.Where("id = ?", label.ID).First(&label).Error + if err != nil { + return nil, err + } + return label, nil +} + +func (orm gormDB) Labels() ([]*kolide.Label, error) { + var labels []*kolide.Label + err := orm.DB.Find(&labels).Error + return labels, err +} + +func (orm gormDB) LabelQueriesForHost(host *kolide.Host, cutoff time.Time) (map[string]string, error) { + if host == nil { + return nil, errors.New( + "error finding host queries", + "nil pointer passed to LabelQueriesForHost", + ) + } + rows, err := orm.DB.Raw(` +SELECT l.id, q.query +FROM labels l JOIN queries q +ON l.query_id = q.id +WHERE q.platform = ? +AND q.id NOT IN /* subtract the set of executions that are recent enough */ +( + SELECT l.query_id + FROM labels l + JOIN label_query_executions lqe + ON lqe.label_id = l.id + WHERE lqe.host_id = ? AND lqe.updated_at > ? +)`, host.Platform, host.ID, cutoff).Rows() + if err != nil && err != gorm.ErrRecordNotFound { + return nil, errors.DatabaseError(err) + } + defer rows.Close() + + results := make(map[string]string) + for rows.Next() { + var id, query string + err = rows.Scan(&id, &query) + if err != nil { + return nil, errors.DatabaseError(err) + } + results[id] = query + } + + return results, nil +} + +func (orm gormDB) RecordLabelQueryExecutions(host *kolide.Host, results map[string]bool, t time.Time) error { + if host == nil { + return errors.New( + "error recording host label query execution", + "nil pointer passed to RecordLabelQueryExecutions", + ) + } + + insert := new(bytes.Buffer) + switch orm.Driver { + case "mysql": + insert.WriteString("INSERT ") + case "sqlite3": + insert.WriteString("REPLACE ") + default: + return errors.New( + "Unknown DB driver", + "Tried to use unknown DB driver in RecordLabelQueryExecutions: "+orm.Driver, + ) + } + + insert.WriteString( + "INTO label_query_executions (updated_at, matches, label_id, host_id) VALUES", + ) + + // Build up all the values and the query string + vals := []interface{}{} + for labelId, res := range results { + insert.WriteString("(?,?,?,?),") + vals = append(vals, t, res, labelId, host.ID) + } + + queryString := insert.String() + queryString = strings.TrimSuffix(queryString, ",") + + switch orm.Driver { + case "mysql": + queryString += ` +ON DUPLICATE KEY UPDATE +updated_at = VALUES(updated_at), +matches = VALUES(matches) +` + } + + if err := orm.DB.Exec(queryString, vals...).Error; err != nil { + return errors.DatabaseError(err) + } + + return nil +} + +func (orm gormDB) LabelsForHost(hid uint) ([]kolide.Label, error) { + results := []kolide.Label{} + err := orm.DB.Raw(` +SELECT labels.* from labels, label_query_executions lqe +WHERE lqe.host_id = ? +AND lqe.label_id = labels.id +AND lqe.matches +`, hid).Scan(&results).Error + + if err != nil && err != gorm.ErrRecordNotFound { + return nil, errors.DatabaseError(err) + } + + return results, nil +} diff --git a/server/datastore/gorm_packs.go b/server/datastore/gorm_packs.go new file mode 100644 index 0000000000..ac5acb4ea5 --- /dev/null +++ b/server/datastore/gorm_packs.go @@ -0,0 +1,245 @@ +package datastore + +import ( + "github.com/jinzhu/gorm" + "github.com/kolide/kolide-ose/server/errors" + "github.com/kolide/kolide-ose/server/kolide" +) + +func (orm gormDB) NewPack(pack *kolide.Pack) error { + if pack == nil { + return errors.New( + "error creating pack", + "nil pointer passed to NewPack", + ) + } + return orm.DB.Create(pack).Error +} + +func (orm gormDB) SavePack(pack *kolide.Pack) error { + if pack == nil { + return errors.New( + "error saving pack", + "nil pointer passed to SavePack", + ) + } + return orm.DB.Save(pack).Error +} + +func (orm gormDB) DeletePack(pid uint) error { + err := orm.DB.Where("id = ?", pid).Delete(&kolide.Pack{}).Error + if err != nil { + return err + } + + err = orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackQuery{}).Error + if err != nil { + return err + } + return orm.DB.Where("pack_id = ?", pid).Delete(&kolide.PackTarget{}).Error +} + +func (orm gormDB) Pack(pid uint) (*kolide.Pack, error) { + pack := &kolide.Pack{ + ID: pid, + } + err := orm.DB.Where(pack).First(pack).Error + if err != nil { + return nil, err + } + return pack, nil +} + +func (orm gormDB) Packs() ([]*kolide.Pack, error) { + var packs []*kolide.Pack + err := orm.DB.Find(&packs).Error + return packs, err +} + +func (orm gormDB) AddQueryToPack(qid uint, pid uint) error { + pq := &kolide.PackQuery{ + QueryID: qid, + PackID: pid, + } + return orm.DB.Create(pq).Error +} + +func (orm gormDB) GetQueriesInPack(pack *kolide.Pack) ([]*kolide.Query, error) { + var queries []*kolide.Query + if pack == nil { + return nil, errors.New( + "error getting queries in pack", + "nil pointer passed to GetQueriesInPack", + ) + } + + rows, err := orm.DB.Raw(` +SELECT + q.id, + q.created_at, + q.updated_at, + q.name, + q.query, + q.interval, + q.snapshot, + q.differential, + q.platform, + q.version +FROM + queries q +JOIN + pack_queries pq +ON + pq.query_id = q.id +AND + pq.pack_id = ?; +`, pack.ID).Rows() + if err != nil && err != gorm.ErrRecordNotFound { + return nil, errors.DatabaseError(err) + } + defer rows.Close() + + for rows.Next() { + query := new(kolide.Query) + err = rows.Scan( + &query.ID, + &query.CreatedAt, + &query.UpdatedAt, + &query.Name, + &query.Query, + &query.Interval, + &query.Snapshot, + &query.Differential, + &query.Platform, + &query.Version, + ) + if err != nil { + return nil, err + } + queries = append(queries, query) + } + + return queries, nil +} + +func (orm gormDB) RemoveQueryFromPack(query *kolide.Query, pack *kolide.Pack) error { + if query == nil || pack == nil { + return errors.New( + "error removing query from pack", + "nil pointer passed to RemoveQueryFromPack", + ) + } + pq := &kolide.PackQuery{ + QueryID: query.ID, + PackID: pack.ID, + } + return orm.DB.Where(pq).Delete(pq).Error +} + +func (orm gormDB) AddLabelToPack(lid uint, pid uint) error { + pt := &kolide.PackTarget{ + Type: kolide.TargetLabel, + PackID: pid, + TargetID: lid, + } + + return orm.DB.Create(pt).Error +} + +func (orm gormDB) ActivePacksForHost(hid uint) ([]*kolide.Pack, error) { + packs := []*kolide.Pack{} + + // we will need to give some subset of packs to this host based on the + // labels which this host is known to belong to + allPacks, err := orm.Packs() + if err != nil { + return nil, err + } + + // pull the labels that this host belongs to + labels, err := orm.LabelsForHost(hid) + if err != nil { + return nil, err + } + + // in order to use o(1) array indexing in an o(n) loop vs a o(n^2) double + // for loop iteration, we must create the array which may be indexed below + labelIDs := map[uint]bool{} + for _, label := range labels { + labelIDs[label.ID] = true + } + + for _, pack := range allPacks { + // for each pack, we must know what labels have been assigned to that + // pack + labelsForPack, err := orm.GetLabelsForPack(pack) + if err != nil { + return nil, err + } + + // o(n) iteration to determine whether or not a pack is enabled + // in this case, n is len(labelsForPack) + for _, label := range labelsForPack { + if labelIDs[label.ID] { + packs = append(packs, pack) + break + } + } + } + + return packs, nil +} + +func (orm gormDB) GetLabelsForPack(pack *kolide.Pack) ([]*kolide.Label, error) { + if pack == nil { + return nil, errors.New( + "error getting labels for pack", + "nil pointer passed to GetLabelsForPack", + ) + } + + results := []*kolide.Label{} + err := orm.DB.Raw(` +SELECT + l.id, + l.created_at, + l.updated_at, + l.name, + l.query_id +FROM + labels l +JOIN + pack_targets pt +ON + pt.target_id = l.id +WHERE + pt.type = ? + AND + pt.pack_id = ?; + +`, + kolide.TargetLabel, pack.ID).Scan(&results).Error + + if err != nil && err != gorm.ErrRecordNotFound { + return nil, errors.DatabaseError(err) + } + + return results, nil +} + +func (orm gormDB) RemoveLabelFromPack(label *kolide.Label, pack *kolide.Pack) error { + if label == nil || pack == nil { + return errors.New( + "error removing label from pack", + "nil pointer passed to RemoveLabelFromPack", + ) + } + + pt := &kolide.PackTarget{ + Type: kolide.TargetLabel, + PackID: pack.ID, + TargetID: label.ID, + } + + return orm.DB.Delete(pt).Error +} diff --git a/server/datastore/gorm_queries.go b/server/datastore/gorm_queries.go new file mode 100644 index 0000000000..4ff80ef889 --- /dev/null +++ b/server/datastore/gorm_queries.go @@ -0,0 +1,57 @@ +package datastore + +import ( + "github.com/kolide/kolide-ose/server/errors" + "github.com/kolide/kolide-ose/server/kolide" +) + +func (orm gormDB) NewQuery(query *kolide.Query) (*kolide.Query, error) { + if query == nil { + return nil, errors.New( + "error creating query", + "nil pointer passed to NewQuery", + ) + } + err := orm.DB.Create(query).Error + if err != nil { + return nil, err + } + return query, nil +} + +func (orm gormDB) SaveQuery(query *kolide.Query) error { + if query == nil { + return errors.New( + "error saving query", + "nil pointer passed to SaveQuery", + ) + } + return orm.DB.Save(query).Error +} + +func (orm gormDB) DeleteQuery(query *kolide.Query) error { + if query == nil { + return errors.New( + "error deleting query", + "nil pointer passed to DeleteQuery", + ) + } + return orm.DB.Delete(query).Error +} + +func (orm gormDB) Query(id uint) (*kolide.Query, error) { + query := &kolide.Query{ + ID: id, + } + err := orm.DB.Where(query).First(query).Error + if err != nil { + return nil, err + } + return query, nil +} + +func (orm gormDB) Queries() ([]*kolide.Query, error) { + var queries []*kolide.Query + err := orm.DB.Find(&queries).Error + return queries, err +} diff --git a/server/datastore/gorm_test.go b/server/datastore/gorm_test.go index f734ca9ecf..26c9f8f723 100644 --- a/server/datastore/gorm_test.go +++ b/server/datastore/gorm_test.go @@ -1,124 +1,36 @@ package datastore import ( - "fmt" - "os" "testing" - _ "github.com/jinzhu/gorm/dialects/mysql" - _ "github.com/jinzhu/gorm/dialects/sqlite" + "github.com/jinzhu/gorm" "github.com/kolide/kolide-ose/server/kolide" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func setupMySQLGORM(t *testing.T) kolide.Datastore { - user := "kolide" - password := "kolide" - dbName := "kolide" +func setupGorm(t *testing.T) kolide.Datastore { + db, err := gorm.Open("sqlite3", ":memory:") + require.Nil(t, err) - // try container first - host := os.Getenv("MYSQL_PORT_3306_TCP_ADDR") - if host == "" { - host = "127.0.0.1" - } - host = fmt.Sprintf("%s:3306", host) + ds := gormDB{DB: db, Driver: "sqlite3"} - conn := fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local", user, password, host, dbName) - db, err := New("gorm-mysql", conn, LimitAttempts(1)) + err = ds.Migrate() assert.Nil(t, err) - - backend := db.(gormDB) - err = backend.Migrate() - assert.Nil(t, err) - - return db + return ds } -func teardownMySQLGORM(t *testing.T, db kolide.Datastore) { - err := db.Drop() +func teardownGorm(t *testing.T, ds kolide.Datastore) { + err := ds.Drop() assert.Nil(t, err) } -func TestEnrollHostMySQLGORM(t *testing.T) { - address := os.Getenv("MYSQL_ADDR") - if address == "" { - t.SkipNow() +func TestGorm(t *testing.T) { + for _, f := range testFunctions { + t.Run(functionName(f), func(t *testing.T) { + ds := setupGorm(t) + defer teardownGorm(t, ds) + f(t, ds) + }) } - db := setupMySQLGORM(t) - defer teardownMySQLGORM(t, db) - - testEnrollHost(t, db) -} - -func TestAuthenticateHostMySQLGORM(t *testing.T) { - address := os.Getenv("MYSQL_ADDR") - if address == "" { - t.SkipNow() - } - db := setup(t) - defer teardown(t, db) - - testAuthenticateHost(t, db) -} - -func TestUserByIDMySQLGORM(t *testing.T) { - address := os.Getenv("MYSQL_ADDR") - if address == "" { - t.SkipNow() - } - db := setup(t) - defer teardown(t, db) - - testUserByID(t, db) -} - -// TestCreateUser tests the UserStore interface -// this test uses the MySQL GORM backend -func TestCreateUserMySQLGORM(t *testing.T) { - address := os.Getenv("MYSQL_ADDR") - if address == "" { - t.SkipNow() - } - - db := setupMySQLGORM(t) - defer teardownMySQLGORM(t, db) - - testCreateUser(t, db) -} - -func TestSaveUserMySQLGORM(t *testing.T) { - address := os.Getenv("MYSQL_ADDR") - if address == "" { - t.SkipNow() - } - - db := setupMySQLGORM(t) - defer teardownMySQLGORM(t, db) - - testSaveUser(t, db) -} - -func TestSaveQuery(t *testing.T) { - ds := setup(t) - testSaveQuery(t, ds) -} - -func TestDeleteQuery(t *testing.T) { - ds := setup(t) - testDeleteQuery(t, ds) -} - -func TestDeletePack(t *testing.T) { - ds := setup(t) - testDeletePack(t, ds) -} - -func TestAddAndRemoveQueryFromPack(t *testing.T) { - ds := setup(t) - testAddAndRemoveQueryFromPack(t, ds) -} - -func TestManagingLabelsOnPacks(t *testing.T) { - ds := setup(t) - testManagingLabelsOnPacks(t, ds) } diff --git a/server/datastore/gorm_users.go b/server/datastore/gorm_users.go index 150bd8533e..4676dc3367 100644 --- a/server/datastore/gorm_users.go +++ b/server/datastore/gorm_users.go @@ -1,8 +1,9 @@ package datastore -import "github.com/kolide/kolide-ose/server/kolide" +import ( + "github.com/kolide/kolide-ose/server/kolide" +) -// NewUser creates a new user in the gorm backend func (orm gormDB) NewUser(user *kolide.User) (*kolide.User, error) { err := orm.DB.Create(user).Error if err != nil { @@ -11,7 +12,6 @@ func (orm gormDB) NewUser(user *kolide.User) (*kolide.User, error) { return user, nil } -// User returns a specific user in the gorm backend func (orm gormDB) User(username string) (*kolide.User, error) { user := &kolide.User{ Username: username, @@ -43,7 +43,6 @@ func (orm gormDB) UserByEmail(email string) (*kolide.User, error) { return user, nil } -// UserByID returns a datastore user given a user ID func (orm gormDB) UserByID(id uint) (*kolide.User, error) { user := &kolide.User{ID: id} err := orm.DB.Where(user).First(user).Error diff --git a/server/datastore/inmem.go b/server/datastore/inmem.go index d53a773f32..f1abe18504 100644 --- a/server/datastore/inmem.go +++ b/server/datastore/inmem.go @@ -17,6 +17,7 @@ type inmem struct { labels map[uint]*kolide.Label labelQueryExecutions map[uint]*kolide.LabelQueryExecution queries map[uint]*kolide.Query + packs map[uint]*kolide.Pack hosts map[uint]*kolide.Host orginfo *kolide.OrgInfo } @@ -35,6 +36,7 @@ func (orm *inmem) Migrate() error { orm.labels = make(map[uint]*kolide.Label) orm.labelQueryExecutions = make(map[uint]*kolide.LabelQueryExecution) orm.queries = make(map[uint]*kolide.Query) + orm.packs = make(map[uint]*kolide.Pack) orm.hosts = make(map[uint]*kolide.Host) return nil } diff --git a/server/datastore/inmem_osquery.go b/server/datastore/inmem_labels.go similarity index 100% rename from server/datastore/inmem_osquery.go rename to server/datastore/inmem_labels.go diff --git a/server/datastore/inmem_packs.go b/server/datastore/inmem_packs.go new file mode 100644 index 0000000000..65ea47fb07 --- /dev/null +++ b/server/datastore/inmem_packs.go @@ -0,0 +1,71 @@ +package datastore + +import ( + "github.com/kolide/kolide-ose/server/kolide" +) + +func (orm *inmem) NewPack(pack *kolide.Pack) error { + orm.mtx.Lock() + defer orm.mtx.Unlock() + + newPack := *pack + + for _, q := range orm.packs { + if pack.Name == q.Name { + return ErrExists + } + } + + newPack.ID = uint(len(orm.packs) + 1) + orm.packs[newPack.ID] = &newPack + + return nil +} + +func (orm *inmem) SavePack(pack *kolide.Pack) error { + orm.mtx.Lock() + defer orm.mtx.Unlock() + + if _, ok := orm.packs[pack.ID]; !ok { + return ErrNotFound + } + + orm.packs[pack.ID] = pack + return nil +} + +func (orm *inmem) DeletePack(pid uint) error { + orm.mtx.Lock() + defer orm.mtx.Unlock() + + if _, ok := orm.packs[pid]; !ok { + return ErrNotFound + } + + delete(orm.packs, pid) + return nil +} + +func (orm *inmem) Pack(id uint) (*kolide.Pack, error) { + orm.mtx.Lock() + defer orm.mtx.Unlock() + + pack, ok := orm.packs[id] + if !ok { + return nil, ErrNotFound + } + + return pack, nil +} + +func (orm *inmem) Packs() ([]*kolide.Pack, error) { + orm.mtx.Lock() + defer orm.mtx.Unlock() + + packs := []*kolide.Pack{} + for _, pack := range orm.packs { + packs = append(packs, pack) + } + + return packs, nil +} diff --git a/server/datastore/inmem_queries.go b/server/datastore/inmem_queries.go index cfac80095d..d654ae42ac 100644 --- a/server/datastore/inmem_queries.go +++ b/server/datastore/inmem_queries.go @@ -1,6 +1,8 @@ package datastore -import "github.com/kolide/kolide-ose/server/kolide" +import ( + "github.com/kolide/kolide-ose/server/kolide" +) func (orm *inmem) NewQuery(query *kolide.Query) (*kolide.Query, error) { orm.mtx.Lock()