diff --git a/changes/issue-2450-fix-search-by-host b/changes/issue-2450-fix-search-by-host new file mode 100644 index 0000000000..b3390e2bee --- /dev/null +++ b/changes/issue-2450-fix-search-by-host @@ -0,0 +1 @@ +* Fix SearchHosts to match for one-char and two-chars queries. diff --git a/server/datastore/mysql/fulltext.go b/server/datastore/mysql/fulltext.go index a490941660..d9122d400a 100644 --- a/server/datastore/mysql/fulltext.go +++ b/server/datastore/mysql/fulltext.go @@ -7,6 +7,11 @@ import ( var mysqlFTSSymbolRegexp = regexp.MustCompile("[-+]+") +// queryMinLength returns true if the query argument is longer than a "short" word. +// What defines a "short" word is MySQL's "ft_min_word_len" VARIABLE, generally set +// to 4 by default in Fleet deployments. +// +// TODO(lucas): Remove this method on #2627. func queryMinLength(query string) bool { return countLongestTerm(query) >= 3 } @@ -24,7 +29,11 @@ func countLongestTerm(query string) int { // transformQuery replaces occurrences of characters that are treated specially // by the MySQL FTS engine to try to make the search more user-friendly func transformQuery(query string) string { + return transformQueryWithSuffix(query, "*") +} + +func transformQueryWithSuffix(query, suffix string) string { return strings.TrimSpace( mysqlFTSSymbolRegexp.ReplaceAllLiteralString(query, " "), - ) + "*" + ) + suffix } diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 88db5feace..0bfa1f1d35 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -554,7 +554,6 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin ) VALUES (?, ?, ?, ?, ?, ?, ?) ` result, err := tx.ExecContext(ctx, sqlInsert, zeroTime, zeroTime, zeroTime, osqueryHostID, time.Now().UTC(), nodeKey, teamID) - if err != nil { return errors.Wrap(err, "insert host") } @@ -578,7 +577,6 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin WHERE osquery_host_id = ? ` _, err := tx.ExecContext(ctx, sqlUpdate, nodeKey, teamID, osqueryHostID) - if err != nil { return errors.Wrap(err, "update host") } @@ -599,7 +597,6 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin return nil }) - if err != nil { return nil, err } @@ -667,104 +664,53 @@ func (d *Datastore) MarkHostsSeen(ctx context.Context, hostIDs []uint, t time.Ti return nil } -func (d *Datastore) searchHostsWithOmits(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error) { - hostQuery := transformQuery(query) - ipQuery := `"` + query + `"` +// SearchHosts performs a search on the hosts table using the following criteria: +// - Use the provided team filter. +// - Full-text search with the "query" argument (if query == "", then no fulltext matching is executed). +// Full-text search is used even if "query" is a short or stopword. +// (what defines a short word is the "ft_min_word_len" VARIABLE, set to 4 by default in Fleet deployments). +// - An optional list of IDs to omit from the search. +func (d *Datastore) SearchHosts(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error) { + var sqlb strings.Builder + sqlb.WriteString("SELECT * FROM hosts WHERE") - sql := fmt.Sprintf(` - SELECT DISTINCT * - FROM hosts - WHERE - ( + var args []interface{} + if len(query) > 0 { + sqlb.WriteString(` ( MATCH (hostname, uuid) AGAINST (? IN BOOLEAN MODE) OR MATCH (primary_ip, primary_mac) AGAINST (? IN BOOLEAN MODE) - ) - AND id NOT IN (?) AND %s - LIMIT 10 - `, d.whereFilterHostsByTeams(filter, "hosts"), - ) - - sql, args, err := sqlx.In(sql, hostQuery, ipQuery, omit) - if err != nil { - return nil, errors.Wrap(err, "searching hosts") + ) AND`) + // Transform query argument and append the truncation operator "*" for MATCH. + // From Oracle docs: "If a word is specified with the truncation operator, it is not + // stripped from a boolean query, even if it is too short or a stopword." + hostQuery := transformQueryWithSuffix(query, "*") + // Needs quotes to avoid each "." marking a word boundary. + // TODO(lucas): Currently matching the primary_mac doesn't work, see #1959. + ipQuery := `"` + query + `"` + args = append(args, hostQuery, ipQuery) } - sql = d.reader.Rebind(sql) - - hosts := []*fleet.Host{} - - err = sqlx.SelectContext(ctx, d.reader, &hosts, sql, args...) - if err != nil { - return nil, errors.Wrap(err, "searching hosts rebound") - } - - return hosts, nil -} - -func (d *Datastore) searchHostsDefault(ctx context.Context, filter fleet.TeamFilter, omit ...uint) ([]*fleet.Host, error) { - sql := fmt.Sprintf(` - SELECT * FROM hosts - WHERE id NOT in (?) AND %s - ORDER BY seen_time DESC - LIMIT 5 - `, d.whereFilterHostsByTeams(filter, "hosts"), - ) - var in interface{} - { - // use -1 if there are no values to omit. - // Avoids empty args error for `sqlx.In` - in = omit - if len(omit) == 0 { - in = -1 - } + // use -1 if there are no values to omit. + // Avoids empty args error for `sqlx.In` + in = omit + if len(omit) == 0 { + in = -1 } + args = append(args, in) + sqlb.WriteString(" id NOT IN (?) AND ") + sqlb.WriteString(d.whereFilterHostsByTeams(filter, "hosts")) + sqlb.WriteString(` ORDER BY seen_time DESC LIMIT 10`) - var hosts []*fleet.Host - sql, args, err := sqlx.In(sql, in) + sql, args, err := sqlx.In(sqlb.String(), args...) if err != nil { return nil, errors.Wrap(err, "searching default hosts") } sql = d.reader.Rebind(sql) - err = sqlx.SelectContext(ctx, d.reader, &hosts, sql, args...) - if err != nil { - return nil, errors.Wrap(err, "searching default hosts rebound") - } - return hosts, nil -} - -// SearchHosts find hosts by query containing an IP address, a host name or UUID. -// Optionally pass a list of IDs to omit from the search -func (d *Datastore) SearchHosts(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error) { - hostQuery := transformQuery(query) - if !queryMinLength(hostQuery) { - return d.searchHostsDefault(ctx, filter, omit...) - } - if len(omit) > 0 { - return d.searchHostsWithOmits(ctx, filter, query, omit...) - } - - // Needs quotes to avoid each . marking a word boundary - ipQuery := `"` + query + `"` - - sql := fmt.Sprintf(` - SELECT DISTINCT * - FROM hosts - WHERE - ( - MATCH (hostname, uuid) AGAINST (? IN BOOLEAN MODE) - OR MATCH (primary_ip, primary_mac) AGAINST (? IN BOOLEAN MODE) - ) AND %s - LIMIT 10 - `, d.whereFilterHostsByTeams(filter, "hosts"), - ) - hosts := []*fleet.Host{} - if err := sqlx.SelectContext(ctx, d.reader, &hosts, sql, hostQuery, ipQuery); err != nil { + if err := sqlx.SelectContext(ctx, d.reader, &hosts, sql, args...); err != nil { return nil, errors.Wrap(err, "searching hosts") } - return hosts, nil - } func (d *Datastore) HostIDsByName(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { @@ -789,7 +735,6 @@ func (d *Datastore) HostIDsByName(ctx context.Context, filter fleet.TeamFilter, } return hostIDs, nil - } func (d *Datastore) HostByIdentifier(ctx context.Context, identifier string) (*fleet.Host, error) { diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 65c16be215..1986b801b8 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -24,22 +24,26 @@ import ( var enrollTests = []struct { uuid, hostname, platform, nodeKey string }{ - 0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26", + 0: { + uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26", hostname: "web.fleet.co", platform: "linux", nodeKey: "key0", }, - 1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B", + 1: { + uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B", hostname: "mail.fleet.co", platform: "linux", nodeKey: "key1", }, - 2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2", + 2: { + uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2", hostname: "home.fleet.co", platform: "darwin", nodeKey: "key2", }, - 3: {uuid: "uuid123", + 3: { + uuid: "uuid123", hostname: "fakehostname", platform: "darwin", nodeKey: "key3", @@ -848,6 +852,26 @@ func testHostsSearch(t *testing.T, ds *Datastore) { hits, err = ds.SearchHosts(context.Background(), filter, "99.100.101", h3.ID) require.NoError(t, err) assert.Equal(t, 1, len(hits)) + + hits, err = ds.SearchHosts(context.Background(), filter, "f") + require.NoError(t, err) + assert.Equal(t, 2, len(hits)) + + hits, err = ds.SearchHosts(context.Background(), filter, "f", h3.ID) + require.NoError(t, err) + assert.Equal(t, 1, len(hits)) + + hits, err = ds.SearchHosts(context.Background(), filter, "fx") + require.NoError(t, err) + assert.Equal(t, 0, len(hits)) + + hits, err = ds.SearchHosts(context.Background(), filter, "x") + require.NoError(t, err) + assert.Equal(t, 0, len(hits)) + + hits, err = ds.SearchHosts(context.Background(), filter, "x", h3.ID) + require.NoError(t, err) + assert.Equal(t, 0, len(hits)) } func testHostsSearchLimit(t *testing.T, ds *Datastore) {