Fix SearchHosts to match for one-char and two-chars queries (#2590)

* Fix SearchHosts to match for one-char and two-chars queries

* Add issue number for future reference
This commit is contained in:
Lucas Manuel Rodriguez 2021-10-21 17:46:21 -03:00 committed by GitHub
parent bcf6697741
commit c84cbb1679
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 92 deletions

View file

@ -0,0 +1 @@
* Fix SearchHosts to return matches for one-character and two-character queries.

View file

@ -7,6 +7,11 @@ import (
// mysqlFTSSymbolRegexp matches one or more consecutive '-' or '+' characters.
// It is compiled once at package scope.
var mysqlFTSSymbolRegexp = regexp.MustCompile("[-+]+")
// queryMinLength reports whether the longest term in query, as measured by
// countLongestTerm, has a length of at least 3.
//
// The notion of a "short" word comes from MySQL's "ft_min_word_len" variable,
// generally set to 4 by default in Fleet deployments.
//
// TODO(lucas): Remove this method on #2627.
func queryMinLength(query string) bool {
	const minTermLength = 3
	longest := countLongestTerm(query)
	return longest >= minTermLength
}
@ -24,7 +29,11 @@ func countLongestTerm(query string) int {
// transformQuery makes a search string friendlier for the MySQL FTS engine by
// replacing characters the engine treats specially. It is shorthand for
// transformQueryWithSuffix(query, "*").
func transformQuery(query string) string {
	const wildcard = "*"
	return transformQueryWithSuffix(query, wildcard)
}
func transformQueryWithSuffix(query, suffix string) string {
return strings.TrimSpace(
mysqlFTSSymbolRegexp.ReplaceAllLiteralString(query, " "),
) + "*"
) + suffix
}

View file

@ -554,7 +554,6 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin
) VALUES (?, ?, ?, ?, ?, ?, ?)
`
result, err := tx.ExecContext(ctx, sqlInsert, zeroTime, zeroTime, zeroTime, osqueryHostID, time.Now().UTC(), nodeKey, teamID)
if err != nil {
return errors.Wrap(err, "insert host")
}
@ -578,7 +577,6 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin
WHERE osquery_host_id = ?
`
_, err := tx.ExecContext(ctx, sqlUpdate, nodeKey, teamID, osqueryHostID)
if err != nil {
return errors.Wrap(err, "update host")
}
@ -599,7 +597,6 @@ func (d *Datastore) EnrollHost(ctx context.Context, osqueryHostID, nodeKey strin
return nil
})
if err != nil {
return nil, err
}
@ -667,104 +664,53 @@ func (d *Datastore) MarkHostsSeen(ctx context.Context, hostIDs []uint, t time.Ti
return nil
}
// NOTE(review): the lines from here through the end of SearchHosts appear to
// interleave the pre-change and post-change versions of the host search code
// (diff markers look stripped) — confirm which lines are live before editing.
//
// searchHostsWithOmits looks like the pre-change helper: it runs the
// full-text MATCH query with "id NOT IN (?)" expanded through sqlx.In so the
// IDs in omit are excluded, and caps the result at 10 rows.
func (d *Datastore) searchHostsWithOmits(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error) {
hostQuery := transformQuery(query)
ipQuery := `"` + query + `"`
// SearchHosts performs a search on the hosts table using the following criteria:
// - Use the provided team filter.
// - Full-text search with the "query" argument (if query == "", then no full-text matching is executed).
// Full-text search is used even if "query" is a short word or a stopword
// (what defines a short word is MySQL's "ft_min_word_len" variable, set to 4 by default in Fleet deployments).
// - An optional list of IDs to omit from the search.
//
// Results are ordered by seen_time, most recent first, and limited to 10 rows.
func (d *Datastore) SearchHosts(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error) {
var sqlb strings.Builder
sqlb.WriteString("SELECT * FROM hosts WHERE")
sql := fmt.Sprintf(`
SELECT DISTINCT *
FROM hosts
WHERE
(
var args []interface{}
if len(query) > 0 {
sqlb.WriteString(` (
MATCH (hostname, uuid) AGAINST (? IN BOOLEAN MODE)
OR MATCH (primary_ip, primary_mac) AGAINST (? IN BOOLEAN MODE)
)
AND id NOT IN (?) AND %s
LIMIT 10
`, d.whereFilterHostsByTeams(filter, "hosts"),
)
sql, args, err := sqlx.In(sql, hostQuery, ipQuery, omit)
if err != nil {
return nil, errors.Wrap(err, "searching hosts")
) AND`)
// Transform query argument and append the truncation operator "*" for MATCH.
// From Oracle docs: "If a word is specified with the truncation operator, it is not
// stripped from a boolean query, even if it is too short or a stopword."
hostQuery := transformQueryWithSuffix(query, "*")
// Needs quotes to avoid each "." marking a word boundary.
// TODO(lucas): Currently matching the primary_mac doesn't work, see #1959.
ipQuery := `"` + query + `"`
args = append(args, hostQuery, ipQuery)
}
sql = d.reader.Rebind(sql)
hosts := []*fleet.Host{}
err = sqlx.SelectContext(ctx, d.reader, &hosts, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "searching hosts rebound")
}
return hosts, nil
}
func (d *Datastore) searchHostsDefault(ctx context.Context, filter fleet.TeamFilter, omit ...uint) ([]*fleet.Host, error) {
sql := fmt.Sprintf(`
SELECT * FROM hosts
WHERE id NOT in (?) AND %s
ORDER BY seen_time DESC
LIMIT 5
`, d.whereFilterHostsByTeams(filter, "hosts"),
)
var in interface{}
{
// use -1 if there are no values to omit.
// Avoids empty args error for `sqlx.In`
in = omit
if len(omit) == 0 {
in = -1
}
// use -1 if there are no values to omit.
// Avoids empty args error for `sqlx.In`
in = omit
if len(omit) == 0 {
in = -1
}
args = append(args, in)
sqlb.WriteString(" id NOT IN (?) AND ")
sqlb.WriteString(d.whereFilterHostsByTeams(filter, "hosts"))
sqlb.WriteString(` ORDER BY seen_time DESC LIMIT 10`)
var hosts []*fleet.Host
sql, args, err := sqlx.In(sql, in)
sql, args, err := sqlx.In(sqlb.String(), args...)
if err != nil {
return nil, errors.Wrap(err, "searching default hosts")
}
sql = d.reader.Rebind(sql)
err = sqlx.SelectContext(ctx, d.reader, &hosts, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "searching default hosts rebound")
}
return hosts, nil
}
// SearchHosts find hosts by query containing an IP address, a host name or UUID.
// Optionally pass a list of IDs to omit from the search
func (d *Datastore) SearchHosts(ctx context.Context, filter fleet.TeamFilter, query string, omit ...uint) ([]*fleet.Host, error) {
hostQuery := transformQuery(query)
if !queryMinLength(hostQuery) {
return d.searchHostsDefault(ctx, filter, omit...)
}
if len(omit) > 0 {
return d.searchHostsWithOmits(ctx, filter, query, omit...)
}
// Needs quotes to avoid each . marking a word boundary
ipQuery := `"` + query + `"`
sql := fmt.Sprintf(`
SELECT DISTINCT *
FROM hosts
WHERE
(
MATCH (hostname, uuid) AGAINST (? IN BOOLEAN MODE)
OR MATCH (primary_ip, primary_mac) AGAINST (? IN BOOLEAN MODE)
) AND %s
LIMIT 10
`, d.whereFilterHostsByTeams(filter, "hosts"),
)
hosts := []*fleet.Host{}
if err := sqlx.SelectContext(ctx, d.reader, &hosts, sql, hostQuery, ipQuery); err != nil {
if err := sqlx.SelectContext(ctx, d.reader, &hosts, sql, args...); err != nil {
return nil, errors.Wrap(err, "searching hosts")
}
return hosts, nil
}
func (d *Datastore) HostIDsByName(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) {
@ -789,7 +735,6 @@ func (d *Datastore) HostIDsByName(ctx context.Context, filter fleet.TeamFilter,
}
return hostIDs, nil
}
func (d *Datastore) HostByIdentifier(ctx context.Context, identifier string) (*fleet.Host, error) {

View file

@ -24,22 +24,26 @@ import (
var enrollTests = []struct {
uuid, hostname, platform, nodeKey string
}{
0: {uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26",
0: {
uuid: "6D14C88F-8ECF-48D5-9197-777647BF6B26",
hostname: "web.fleet.co",
platform: "linux",
nodeKey: "key0",
},
1: {uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B",
1: {
uuid: "B998C0EB-38CE-43B1-A743-FBD7A5C9513B",
hostname: "mail.fleet.co",
platform: "linux",
nodeKey: "key1",
},
2: {uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2",
2: {
uuid: "008F0688-5311-4C59-86EE-00C2D6FC3EC2",
hostname: "home.fleet.co",
platform: "darwin",
nodeKey: "key2",
},
3: {uuid: "uuid123",
3: {
uuid: "uuid123",
hostname: "fakehostname",
platform: "darwin",
nodeKey: "key3",
@ -848,6 +852,26 @@ func testHostsSearch(t *testing.T, ds *Datastore) {
hits, err = ds.SearchHosts(context.Background(), filter, "99.100.101", h3.ID)
require.NoError(t, err)
assert.Equal(t, 1, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "f")
require.NoError(t, err)
assert.Equal(t, 2, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "f", h3.ID)
require.NoError(t, err)
assert.Equal(t, 1, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "fx")
require.NoError(t, err)
assert.Equal(t, 0, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "x")
require.NoError(t, err)
assert.Equal(t, 0, len(hits))
hits, err = ds.SearchHosts(context.Background(), filter, "x", h3.ID)
require.NoError(t, err)
assert.Equal(t, 0, len(hits))
}
func testHostsSearchLimit(t *testing.T, ds *Datastore) {