From 9c25ea16414dc7637b655d458faafe2d125e523e Mon Sep 17 00:00:00 2001 From: Lucas Manuel Rodriguez Date: Mon, 14 Feb 2022 12:13:38 -0300 Subject: [PATCH] Prepare `LoadHostByNodeKey` query once (#4128) * Prepare LoadHostByNodeKey query once * Use a protected map for storing statements * Add proposed test --- go.mod | 1 + server/datastore/mysql/hosts.go | 18 +++++++- server/datastore/mysql/hosts_test.go | 53 ++++++++++++++++++++++++ server/datastore/mysql/mysql.go | 61 ++++++++++++++++++++++++++-- 4 files changed, 127 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 4b0ee67568..92b4ae206b 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 github.com/gosuri/uilive v0.0.4 + github.com/hashicorp/go-multierror v1.0.0 // indirect github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 github.com/igm/sockjs-go/v3 v3.0.0 github.com/jinzhu/copier v0.3.2 diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 1d0e8a6aa3..f98f7f44f8 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -709,12 +709,25 @@ func (ds *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey stri return &host, nil } +// GetContextTryStmt will attempt to run sqlx.GetContext on a cached statement if available, resorting to ds.reader. +func (ds *Datastore) GetContextTryStmt(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + var err error + //nolint the statements are closed in Datastore.Close. + if stmt := ds.loadOrPrepareStmt(ctx, query); stmt != nil { + err = stmt.GetContext(ctx, dest, args...) + } else { + err = sqlx.GetContext(ctx, ds.reader, dest, query, args...) + } + return err +} + // LoadHostByNodeKey loads the whole host identified by the node key. // If the node key is invalid it returns a NotFoundError. func (ds *Datastore) LoadHostByNodeKey(ctx context.Context, nodeKey string) (*fleet.Host, error) { - sqlStatement := `SELECT * FROM hosts WHERE node_key = ?` + query := `SELECT * FROM hosts WHERE node_key = ?` + var host fleet.Host - switch err := sqlx.GetContext(ctx, ds.reader, &host, sqlStatement, nodeKey); { + switch err := ds.GetContextTryStmt(ctx, &host, query, nodeKey); { case err == nil: return &host, nil case errors.Is(err, sql.ErrNoRows): @@ -1264,6 +1277,7 @@ func (ds *Datastore) GetMDM(ctx context.Context, hostID uint) (bool, string, boo } return dest.Enrolled, dest.ServerURL, dest.InstalledFromDep, nil } + func (ds *Datastore) AggregatedMunkiVersion(ctx context.Context, teamID *uint) ([]fleet.AggregatedMunkiVersion, time.Time, error) { id := uint(0) diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 1646daaa99..8d3f882b93 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -95,6 +95,7 @@ func TestHosts(t *testing.T) { {"SaveTonsOfUsers", testHostsSaveTonsOfUsers}, {"SavePackStatsConcurrent", testHostsSavePackStatsConcurrent}, {"LoadHostByNodeKeyLoadsDisk", testLoadHostByNodeKeyLoadsDisk}, + {"LoadHostByNodeKeyUsesStmt", testLoadHostByNodeKeyUsesStmt}, {"HostsListBySoftware", testHostsListBySoftware}, {"HostsListFailingPolicies", printReadsInTest(testHostsListFailingPolicies)}, {"HostsExpiration", testHostsExpiration}, @@ -1421,6 +1422,58 @@ func testLoadHostByNodeKeyLoadsDisk(t *testing.T, ds *Datastore) { assert.NotZero(t, h.PercentDiskSpaceAvailable) } +func testLoadHostByNodeKeyUsesStmt(t *testing.T, ds *Datastore) { + _, err := ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + OsqueryHostID: "foobar", + NodeKey: "nodekey", + UUID: "uuid", + Hostname: "foobar.local", + }) + require.NoError(t, err) + _, err = ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + OsqueryHostID: "foobar2", + NodeKey: "nodekey2", + UUID: "uuid2", + Hostname: "foobar2.local", + }) + require.NoError(t, err) + + err = ds.closeStmts() + require.NoError(t, err) + + ds.stmtCacheMu.Lock() + require.Len(t, ds.stmtCache, 0) + ds.stmtCacheMu.Unlock() + + h, err := ds.LoadHostByNodeKey(context.Background(), "nodekey") + require.NoError(t, err) + require.Equal(t, "foobar.local", h.Hostname) + + ds.stmtCacheMu.Lock() + require.Len(t, ds.stmtCache, 1) + ds.stmtCacheMu.Unlock() + + h, err = ds.LoadHostByNodeKey(context.Background(), "nodekey") + require.NoError(t, err) + require.Equal(t, "foobar.local", h.Hostname) + + ds.stmtCacheMu.Lock() + require.Len(t, ds.stmtCache, 1) + ds.stmtCacheMu.Unlock() + + h, err = ds.LoadHostByNodeKey(context.Background(), "nodekey2") + require.NoError(t, err) + require.Equal(t, "foobar2.local", h.Hostname) +} + func testHostsAdditional(t *testing.T, ds *Datastore) { h, err := ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index e96106c350..c95f6caeab 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -11,6 +11,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/VividCortex/mysqlerr" @@ -27,6 +28,7 @@ import ( "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/go-sql-driver/mysql" + "github.com/hashicorp/go-multierror" "github.com/jmoiron/sqlx" "github.com/ngrok/sqlmw" ) @@ -42,6 +44,7 @@ var columnCharsRegexp = regexp.MustCompile(`[^\w-.]`) // dbReader is an interface that defines the methods required for reads. type dbReader interface { sqlx.QueryerContext + sqlx.PreparerContext Close() error Rebind(string) string @@ -61,6 +64,36 @@ type Datastore struct { readReplicaConfig *config.MysqlConfig writeCh chan itemToWrite + + // stmtCacheMu protects access to stmtCache. + stmtCacheMu sync.Mutex + // stmtCache holds statements for queries. + stmtCache map[string]*sqlx.Stmt +} + +// loadOrPrepareStmt will load a statement from the statements cache. +// If not available, it will attempt to prepare (create) it. +// +// Returns nil if it failed to prepare a statement. +func (ds *Datastore) loadOrPrepareStmt(ctx context.Context, query string) *sqlx.Stmt { + ds.stmtCacheMu.Lock() + defer ds.stmtCacheMu.Unlock() + + stmt, ok := ds.stmtCache[query] + if !ok { + var err error + stmt, err = sqlx.PreparexContext(ctx, ds.reader, query) + if err != nil { + level.Error(ds.logger).Log( + "msg", "failed to prepare statement", + "query", query, + "err", err, + ) + return nil + } + ds.stmtCache[query] = stmt + } + return stmt } type txFn func(sqlx.ExtContext) error @@ -217,6 +250,7 @@ func New(config config.MysqlConfig, c clock.Clock, opts ...DBOption) (*Datastore config: config, readReplicaConfig: options.replicaConfig, writeCh: make(chan itemToWrite), + stmtCache: make(map[string]*sqlx.Stmt), } go ds.writeChanLoop() @@ -482,13 +516,32 @@ func (ds *Datastore) HealthCheck() error { return nil } +func (ds *Datastore) closeStmts() error { + ds.stmtCacheMu.Lock() + defer ds.stmtCacheMu.Unlock() + + var err error + for query, stmt := range ds.stmtCache { + if errClose := stmt.Close(); errClose != nil { + err = multierror.Append(err, errClose) + } + delete(ds.stmtCache, query) + } + return err +} + // Close frees resources associated with underlying mysql connection func (ds *Datastore) Close() error { - err := ds.writer.Close() + var err error + if errStmt := ds.closeStmts(); errStmt != nil { + err = multierror.Append(err, errStmt) + } + if errWriter := ds.writer.Close(); errWriter != nil { + err = multierror.Append(err, errWriter) + } if ds.readReplicaConfig != nil { - errRead := ds.reader.Close() - if err == nil { - err = errRead + if errRead := ds.reader.Close(); errRead != nil { + err = multierror.Append(err, errRead) } } return err