diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 43997ae987..acbd88746c 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -470,7 +470,7 @@ func (d *Datastore) AuthenticateHost(nodeKey string) (*kolide.Host, error) { if err := d.db.Get(host, sqlStatement, nodeKey); err != nil { switch err { case sql.ErrNoRows: - return nil, errors.Wrap(err, "host not found") + return nil, notFound("Host") default: return nil, errors.New("finding host") } diff --git a/server/mock/datastore_hosts.go b/server/mock/datastore_hosts.go index 20998970ba..f8a9231e9c 100644 --- a/server/mock/datastore_hosts.go +++ b/server/mock/datastore_hosts.go @@ -21,7 +21,6 @@ type HostFunc func(id uint) (*kolide.Host, error) type ListHostsFunc func(opt kolide.ListOptions) ([]*kolide.Host, error) type EnrollHostFunc func(osqueryHostId string, nodeKeySize int) (*kolide.Host, error) - type AuthenticateHostFunc func(nodeKey string) (*kolide.Host, error) type MarkHostSeenFunc func(host *kolide.Host, t time.Time) error diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go index d4547d9c07..9745480fa4 100644 --- a/server/service/service_osquery.go +++ b/server/service/service_osquery.go @@ -39,9 +39,16 @@ func (svc service) AuthenticateHost(ctx context.Context, nodeKey string) (*kolid host, err := svc.ds.AuthenticateHost(nodeKey) if err != nil { - return nil, osqueryError{ - message: "authentication error: " + err.Error(), - nodeInvalid: true, + switch err.(type) { + case kolide.NotFoundError: + return nil, osqueryError{ + message: "authentication error: invalid node key: " + nodeKey, + nodeInvalid: true, + } + default: + return nil, osqueryError{ + message: "authentication error: " + err.Error(), + } } } diff --git a/server/service/service_osquery_test.go b/server/service/service_osquery_test.go index f33dcfd758..def882b41a 100644 --- a/server/service/service_osquery_test.go +++ b/server/service/service_osquery_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "strings" @@ -933,3 +934,52 @@ func setupOsqueryTests(t *testing.T) (kolide.Datastore, kolide.Service, *clock.M return ds, svc, mockClock } + +type notFoundError struct{} + +func (e notFoundError) Error() string { + return "not found" +} + +func (e notFoundError) IsNotFound() bool { + return true +} + +func TestAuthenticationErrors(t *testing.T) { + ms := new(mock.Store) + ms.MarkHostSeenFunc = func(*kolide.Host, time.Time) error { + return nil + } + ms.AuthenticateHostFunc = func(nodeKey string) (*kolide.Host, error) { + return nil, nil + } + + svc, err := newTestService(ms, nil) + require.Nil(t, err) + ctx := context.Background() + + _, err = svc.AuthenticateHost(ctx, "") + require.NotNil(t, err) + require.True(t, err.(osqueryError).NodeInvalid()) + + _, err = svc.AuthenticateHost(ctx, "foo") + require.Nil(t, err) + + // return not found error + ms.AuthenticateHostFunc = func(nodeKey string) (*kolide.Host, error) { + return nil, notFoundError{} + } + + _, err = svc.AuthenticateHost(ctx, "foo") + require.NotNil(t, err) + require.True(t, err.(osqueryError).NodeInvalid()) + + // return other error + ms.AuthenticateHostFunc = func(nodeKey string) (*kolide.Host, error) { + return nil, errors.New("foo") + } + + _, err = svc.AuthenticateHost(ctx, "foo") + require.NotNil(t, err) + require.False(t, err.(osqueryError).NodeInvalid()) +}