Do host user inserts one by one to not lock the whole database (#1884)

This commit is contained in:
Tomas Touceda 2021-09-01 11:39:23 -03:00 committed by GitHub
parent 8daa5da84a
commit 79b5330a43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 158 additions and 11 deletions

View file

@ -834,10 +834,15 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
}
incomingUsers := make(map[uint]bool)
var insertArgs []interface{}
for _, u := range host.Users {
insertArgs = append(insertArgs, host.ID, u.Uid, u.Username, u.Type, u.GroupName)
incomingUsers[u.Uid] = true
if _, err := d.db.Exec(
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES (?, ?, ?, ?, ?)`,
host.ID, u.Uid, u.Username, u.Type, u.GroupName,
); err != nil {
return errors.Wrap(err, "insert users")
}
}
var removedArgs []interface{}
@ -847,15 +852,6 @@ func (d *Datastore) SaveHostUsers(host *fleet.Host) error {
}
}
insertValues := strings.TrimSuffix(strings.Repeat("(?, ?, ?, ?, ?),", len(host.Users)), ",")
insertSql := fmt.Sprintf(
`INSERT IGNORE INTO host_users (host_id, uid, username, user_type, groupname) VALUES %s`,
insertValues,
)
if _, err := d.db.Exec(insertSql, insertArgs...); err != nil {
return errors.Wrap(err, "insert users")
}
if len(removedArgs) == 0 {
return nil
}

View file

@ -1,11 +1,13 @@
package mysql
import (
"context"
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@ -1286,3 +1288,152 @@ func TestListHostsByPolicy(t *testing.T) {
require.NoError(t, err)
require.Len(t, hosts, 8)
}
func TestSaveTonsOfUsers(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
host1, err := ds.NewHost(&fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: "1",
UUID: "1",
Hostname: "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
OsqueryHostID: "1",
})
require.NoError(t, err)
require.NotNil(t, host1)
host2, err := ds.NewHost(&fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: "2",
UUID: "2",
Hostname: "foo2.local",
PrimaryIP: "192.168.1.2",
PrimaryMac: "30-65-EC-6F-C4-58",
OsqueryHostID: "2",
})
require.NoError(t, err)
require.NotNil(t, host2)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errCh := make(chan error)
var count1 int32
var count2 int32
go func() {
for {
host1, err := ds.Host(host1.ID)
if err != nil {
errCh <- err
return
}
u1 := fleet.HostUser{
Uid: 42,
Username: "user",
Type: "aaa",
GroupName: "group",
}
u2 := fleet.HostUser{
Uid: 43,
Username: "user2",
Type: "aaa",
GroupName: "group",
}
host1.Users = []fleet.HostUser{u1, u2}
host1.SeenTime = time.Now()
host1.Modified = true
soft := fleet.HostSoftware{
Modified: true,
Software: []fleet.Software{
{Name: "foo", Version: "0.0.1", Source: "chrome_extensions"},
{Name: "foo", Version: "0.0.3", Source: "chrome_extensions"},
},
}
host1.HostSoftware = soft
additional := json.RawMessage(`{"some":"thing"}`)
host1.Additional = &additional
err = ds.SaveHost(host1)
if err != nil {
errCh <- err
return
}
atomic.AddInt32(&count1, 1)
select {
case <-ctx.Done():
return
default:
}
}
}()
go func() {
for {
host2, err := ds.Host(host2.ID)
if err != nil {
errCh <- err
return
}
u1 := fleet.HostUser{
Uid: 99,
Username: "user",
Type: "aaa",
GroupName: "group",
}
u2 := fleet.HostUser{
Uid: 98,
Username: "user2",
Type: "aaa",
GroupName: "group",
}
host2.Users = []fleet.HostUser{u1, u2}
host2.SeenTime = time.Now()
host2.Modified = true
soft := fleet.HostSoftware{
Modified: true,
Software: []fleet.Software{
{Name: "foo", Version: "0.0.1", Source: "chrome_extensions"},
{Name: "foo4", Version: "0.0.3", Source: "chrome_extensions"},
},
}
host2.HostSoftware = soft
additional := json.RawMessage(`{"some":"thing"}`)
host2.Additional = &additional
err = ds.SaveHost(host2)
if err != nil {
errCh <- err
return
}
atomic.AddInt32(&count2, 1)
select {
case <-ctx.Done():
return
default:
}
}
}()
ticker := time.NewTicker(10 * time.Second)
select {
case err := <-errCh:
require.NoError(t, err)
cancelFunc()
case <-ticker.C:
}
fmt.Println("Count1", atomic.LoadInt32(&count1))
fmt.Println("Count2", atomic.LoadInt32(&count2))
}