diff --git a/server/pubsub/inmem_query_results.go b/server/pubsub/inmem_query_results.go index cd4ba29c80..4ccdd427a9 100644 --- a/server/pubsub/inmem_query_results.go +++ b/server/pubsub/inmem_query_results.go @@ -49,13 +49,13 @@ func (im *inmemQueryResults) WriteResult(result kolide.DistributedQueryResult) e return nil } -func (im *inmemQueryResults) ReadChannel(ctx context.Context, query kolide.DistributedQueryCampaign) (<-chan interface{}, error) { - channel := im.getChannel(query.ID) +func (im *inmemQueryResults) ReadChannel(ctx context.Context, campaign kolide.DistributedQueryCampaign) (<-chan interface{}, error) { + channel := im.getChannel(campaign.ID) go func() { <-ctx.Done() close(channel) im.channelMutex.Lock() - delete(im.resultChannels, query.ID) + delete(im.resultChannels, campaign.ID) im.channelMutex.Unlock() }() return channel, nil diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go index b7fe36aa96..53e9ddfea2 100644 --- a/server/service/service_osquery.go +++ b/server/service/service_osquery.go @@ -603,10 +603,10 @@ func (svc service) ingestDistributedQuery(host kolide.Host, name string, rows [] return osqueryError{message: "loading orphaned campaign: " + err.Error()} } - if campaign.CreatedAt.Before(svc.clock.Now().Add(5 * time.Second)) { + if campaign.CreatedAt.After(svc.clock.Now().Add(-5 * time.Second)) { // Give the client 5 seconds to connect before considering the // campaign orphaned - return osqueryError{message: "campaign waiting for listener"} + return osqueryError{message: "campaign waiting for listener (please retry)"} } if campaign.Status != kolide.QueryComplete { @@ -619,6 +619,9 @@ func (svc service) ingestDistributedQuery(host kolide.Host, name string, rows [] if err := svc.liveQueryStore.StopQuery(strconv.Itoa(int(campaignID))); err != nil { return osqueryError{message: "stopping orphaned campaign: " + err.Error()} } + + // No need to record query completion in this case + return nil } err = svc.liveQueryStore.QueryCompletedByHost(strconv.Itoa(int(campaignID)), host.ID) diff --git a/server/service/service_osquery_test.go b/server/service/service_osquery_test.go index 10c156a462..f00a9645f7 100644 --- a/server/service/service_osquery_test.go +++ b/server/service/service_osquery_test.go @@ -1073,6 +1073,253 @@ func TestDistributedQueryResults(t *testing.T) { require.Nil(t, err) } +func TestIngestDistributedQueryParseIdError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + host := kolide.Host{ID: 1} + err := svc.ingestDistributedQuery(host, "bad_name", []map[string]string{}, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to parse campaign") +} + +func TestIngestDistributedQueryOrphanedCampaignLoadError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + ds.DistributedQueryCampaignFunc = func(id uint) (*kolide.DistributedQueryCampaign, error) { + return nil, fmt.Errorf("missing campaign") + } + + host := kolide.Host{ID: 1} + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "loading orphaned campaign") +} + +func TestIngestDistributedQueryOrphanedCampaignWaitListener(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &kolide.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: kolide.UpdateCreateTimestamps{ + CreateTimestamp: kolide.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-1 * time.Second), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(id uint) (*kolide.DistributedQueryCampaign, error) { + return campaign, nil + } + + host := kolide.Host{ID: 1} + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "campaign waiting for listener") +} + +func TestIngestDistributedQueryOrphanedCloseError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &kolide.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: kolide.UpdateCreateTimestamps{ + CreateTimestamp: kolide.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-30 * time.Second), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(id uint) (*kolide.DistributedQueryCampaign, error) { + return campaign, nil + } + ds.SaveDistributedQueryCampaignFunc = func(campaign *kolide.DistributedQueryCampaign) error { + return fmt.Errorf("failed save") + } + + host := kolide.Host{ID: 1} + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "closing orphaned campaign") +} + +func TestIngestDistributedQueryOrphanedStopError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &kolide.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: kolide.UpdateCreateTimestamps{ + CreateTimestamp: kolide.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-30 * time.Second), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(id uint) (*kolide.DistributedQueryCampaign, error) { + return campaign, nil + } + ds.SaveDistributedQueryCampaignFunc = func(campaign *kolide.DistributedQueryCampaign) error { + return nil + } + lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(fmt.Errorf("failed")) + + host := kolide.Host{ID: 1} + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "stopping orphaned campaign") +} + +func TestIngestDistributedQueryOrphanedStop(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &kolide.DistributedQueryCampaign{ + ID: 42, + UpdateCreateTimestamps: kolide.UpdateCreateTimestamps{ + CreateTimestamp: kolide.CreateTimestamp{ + CreatedAt: mockClock.Now().Add(-30 * time.Second), + }, + }, + } + + ds.DistributedQueryCampaignFunc = func(id uint) (*kolide.DistributedQueryCampaign, error) { + return campaign, nil + } + ds.SaveDistributedQueryCampaignFunc = func(campaign *kolide.DistributedQueryCampaign) error { + return nil + } + lq.On("StopQuery", strconv.Itoa(int(campaign.ID))).Return(nil) + + host := kolide.Host{ID: 1} + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.NoError(t, err) + lq.AssertExpectations(t) +} + +func TestIngestDistributedQueryRecordCompletionError(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &kolide.DistributedQueryCampaign{ID: 42} + host := kolide.Host{ID: 1} + + lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(fmt.Errorf("fail")) + + go func() { + ch, err := rs.ReadChannel(context.Background(), *campaign) + require.NoError(t, err) + <-ch + }() + time.Sleep(10 * time.Millisecond) + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "record query completion") + lq.AssertExpectations(t) +} + +func TestIngestDistributedQuery(t *testing.T) { + mockClock := clock.NewMockClock() + ds := new(mock.Store) + rs := pubsub.NewInmemQueryResults() + lq := new(live_query.MockLiveQuery) + svc := service{ + ds: ds, + resultStore: rs, + liveQueryStore: lq, + logger: log.NewNopLogger(), + clock: mockClock, + } + + campaign := &kolide.DistributedQueryCampaign{ID: 42} + host := kolide.Host{ID: 1} + + lq.On("QueryCompletedByHost", strconv.Itoa(int(campaign.ID)), host.ID).Return(nil) + + go func() { + ch, err := rs.ReadChannel(context.Background(), *campaign) + require.NoError(t, err) + <-ch + }() + time.Sleep(10 * time.Millisecond) + + err := svc.ingestDistributedQuery(host, "kolide_distributed_query_42", []map[string]string{}, false) + require.NoError(t, err) + lq.AssertExpectations(t) +} + func TestUpdateHostIntervals(t *testing.T) { ds := new(mock.Store)