From 31fe9d17b97c407f41e295876b4831475494c9ba Mon Sep 17 00:00:00 2001 From: Lucas Manuel Rodriguez Date: Fri, 22 Mar 2024 11:20:18 -0300 Subject: [PATCH] More fixes to support users with hosts in same team and hosts in different teams (#17789) #17441 --- cmd/fleet/calendar_cron.go | 60 +++++-- server/datastore/mysql/calendar_events.go | 24 +++ server/datastore/mysql/policies.go | 21 ++- server/datastore/mysql/policies_test.go | 200 ++++++++++++++++++++++ server/fleet/datastore.go | 3 +- server/mock/datastore_mock.go | 24 ++- 6 files changed, 306 insertions(+), 26 deletions(-) diff --git a/cmd/fleet/calendar_cron.go b/cmd/fleet/calendar_cron.go index fa63b487dc..17962f1847 100644 --- a/cmd/fleet/calendar_cron.go +++ b/cmd/fleet/calendar_cron.go @@ -125,7 +125,7 @@ func cronCalendarEventsForTeam( for _, policy := range policies { policyIDs = append(policyIDs, policy.ID) } - hosts, err := ds.GetHostsPolicyMemberships(ctx, domain, policyIDs) + hosts, err := ds.GetTeamHostsPolicyMemberships(ctx, domain, team.ID, policyIDs) if err != nil { return fmt.Errorf("get team hosts failing policies: %w", err) } @@ -150,22 +150,28 @@ func cronCalendarEventsForTeam( } level.Debug(logger).Log( "msg", "summary", + "team_id", team.ID, "passing_hosts", len(passingHosts), "failing_hosts", len(failingHosts), "failing_hosts_without_associated_email", len(failingHostsWithoutAssociatedEmail), ) + // Remove calendar events from hosts that are passing the calendar policies. + // + // We execute this first to remove any calendar events for a user that is now passing + // policies on one of its hosts, and possibly create a new calendar event if they have + // another failing host on the same team. + if err := removeCalendarEventsFromPassingHosts(ctx, ds, calendar, passingHosts); err != nil { + level.Info(logger).Log("msg", "removing calendar events from passing hosts", "err", err) + } + + // Process hosts that are failing calendar policies. if err := processCalendarFailingHosts( ctx, ds, calendar, orgName, failingHosts, logger, ); err != nil { level.Info(logger).Log("msg", "processing failing hosts", "err", err) } - // Remove calendar events from hosts that are passing the policies. - if err := removeCalendarEventsFromPassingHosts(ctx, ds, calendar, passingHosts); err != nil { - level.Info(logger).Log("msg", "removing calendar events from passing hosts", "err", err) - } - // At last we want to log the hosts that are failing and don't have an associated email. logHostsWithoutAssociatedEmail( domain, @@ -184,14 +190,26 @@ func processCalendarFailingHosts( hosts []fleet.HostPolicyMembershipData, logger kitlog.Logger, ) error { + hosts = filterHostsWithSameEmail(hosts) + for _, host := range hosts { logger := log.With(logger, "host_id", host.HostID) - hostCalendarEvent, calendarEvent, err := ds.GetHostCalendarEvent(ctx, host.HostID) + hostCalendarEvent, calendarEvent, err := ds.GetHostCalendarEventByEmail(ctx, host.Email) expiredEvent := false webhookAlreadyFiredThisMonth := false if err == nil { + if hostCalendarEvent.HostID != host.HostID { + // This calendar event belongs to another host with this associated email, + // thus we skip this entry. + continue // continue with next host + } + if hostCalendarEvent.WebhookStatus == fleet.CalendarWebhookStatusPending { + // This can happen if the host went offline (and never returned results) + // after setting the webhook as pending. + continue // continue with next host + } now := time.Now() webhookAlreadyFired := hostCalendarEvent.WebhookStatus == fleet.CalendarWebhookStatusSent if webhookAlreadyFired && sameDate(now, calendarEvent.StartTime) { @@ -200,7 +218,7 @@ func processCalendarFailingHosts( continue // continue with next host } webhookAlreadyFiredThisMonth = webhookAlreadyFired && sameMonth(now, calendarEvent.StartTime) - if calendarEvent.EndTime.Before(time.Now()) { + if calendarEvent.EndTime.Before(now) { expiredEvent = true } } @@ -232,6 +250,25 @@ func processCalendarFailingHosts( return nil } +func filterHostsWithSameEmail(hosts []fleet.HostPolicyMembershipData) []fleet.HostPolicyMembershipData { + minHostPerEmail := make(map[string]fleet.HostPolicyMembershipData) + for _, host := range hosts { + minHost, ok := minHostPerEmail[host.Email] + if !ok { + minHostPerEmail[host.Email] = host + continue + } + if host.HostID < minHost.HostID { + minHostPerEmail[host.Email] = host + } + } + filtered := make([]fleet.HostPolicyMembershipData, 0, len(minHostPerEmail)) + for _, host := range minHostPerEmail { + filtered = append(filtered, host) + } + return filtered +} + func processFailingHostExistingCalendarEvent( ctx context.Context, ds fleet.Datastore, @@ -416,10 +453,13 @@ func removeCalendarEventsFromPassingHosts( hosts []fleet.HostPolicyMembershipData, ) error { for _, host := range hosts { - calendarEvent, err := ds.GetCalendarEvent(ctx, host.Email) + hostCalendarEvent, calendarEvent, err := ds.GetHostCalendarEventByEmail(ctx, host.Email) switch { case err == nil: - // OK + if hostCalendarEvent.HostID != host.HostID { + // This calendar event belongs to another host, thus we skip this entry. + continue + } case fleet.IsNotFound(err): continue default: diff --git a/server/datastore/mysql/calendar_events.go b/server/datastore/mysql/calendar_events.go index 5ffc0f77f3..45d8d88331 100644 --- a/server/datastore/mysql/calendar_events.go +++ b/server/datastore/mysql/calendar_events.go @@ -167,6 +167,30 @@ func (ds *Datastore) GetHostCalendarEvent(ctx context.Context, hostID uint) (*fl return &hostCalendarEvent, &calendarEvent, nil } +func (ds *Datastore) GetHostCalendarEventByEmail(ctx context.Context, email string) (*fleet.HostCalendarEvent, *fleet.CalendarEvent, error) { + const calendarEventsQuery = ` + SELECT * FROM calendar_events WHERE email = ? + ` + var calendarEvent fleet.CalendarEvent + if err := sqlx.GetContext(ctx, ds.reader(ctx), &calendarEvent, calendarEventsQuery, email); err != nil { + if err == sql.ErrNoRows { + return nil, nil, ctxerr.Wrap(ctx, notFound("CalendarEvent").WithMessage(fmt.Sprintf("email: %s", email))) + } + return nil, nil, ctxerr.Wrap(ctx, err, "get calendar event") + } + const hostCalendarEventsQuery = ` + SELECT * FROM host_calendar_events WHERE calendar_event_id = ? + ` + var hostCalendarEvent fleet.HostCalendarEvent + if err := sqlx.GetContext(ctx, ds.reader(ctx), &hostCalendarEvent, hostCalendarEventsQuery, calendarEvent.ID); err != nil { + if err == sql.ErrNoRows { + return nil, nil, ctxerr.Wrap(ctx, notFound("HostCalendarEvent").WithID(calendarEvent.ID)) + } + return nil, nil, ctxerr.Wrap(ctx, err, "get host calendar event") + } + return &hostCalendarEvent, &calendarEvent, nil +} + func (ds *Datastore) UpdateHostCalendarWebhookStatus(ctx context.Context, hostID uint, status fleet.CalendarWebhookStatus) error { const calendarEventsQuery = ` UPDATE host_calendar_events SET diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index b711c8b932..71530961de 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -1172,8 +1172,12 @@ func (ds *Datastore) GetCalendarPolicies(ctx context.Context, teamID uint) ([]fl } // TODO(lucas): Must be tested at scale. -// TODO(lucas): Filter out hosts with team_id == NULL -func (ds *Datastore) GetHostsPolicyMemberships(ctx context.Context, domain string, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) { +func (ds *Datastore) GetTeamHostsPolicyMemberships( + ctx context.Context, + domain string, + teamID uint, + policyIDs []uint, +) ([]fleet.HostPolicyMembershipData, error) { query := ` SELECT COALESCE(sh.email, '') AS email, @@ -1188,18 +1192,17 @@ func (ds *Datastore) GetHostsPolicyMemberships(ctx context.Context, domain strin GROUP BY host_id ) pm LEFT JOIN ( - SELECT MIN(h.host_id) as host_id, h.email as email - FROM ( - SELECT host_id, MIN(email) AS email - FROM host_emails WHERE email LIKE CONCAT('%@', ?) - GROUP BY host_id - ) h GROUP BY h.email + SELECT host_id, MIN(email) AS email + FROM host_emails + JOIN hosts ON host_emails.host_id=hosts.id + WHERE email LIKE CONCAT('%@', ?) AND team_id = ? + GROUP BY host_id ) sh ON sh.host_id = pm.host_id JOIN hosts h ON h.id = pm.host_id LEFT JOIN host_display_names hdn ON hdn.host_id = pm.host_id; ` - query, args, err := sqlx.In(query, policyIDs, domain) + query, args, err := sqlx.In(query, policyIDs, domain, teamID) if err != nil { return nil, ctxerr.Wrapf(ctx, err, "build select get team hosts policy memberships query") } diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 514de6dd38..15ebeee171 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -60,6 +60,7 @@ func TestPolicies(t *testing.T) { {"TestPoliciesNameEmoji", testPoliciesNameEmoji}, {"TestPoliciesNameSort", testPoliciesNameSort}, {"TestGetCalendarPolicies", testGetCalendarPolicies}, + {"GetTeamHostsPolicyMemberships", testGetTeamHostsPolicyMemberships}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -2860,3 +2861,202 @@ func testGetCalendarPolicies(t *testing.T, ds *Datastore) { require.Equal(t, calendarPolicies[0].ID, teamPolicy2.ID) require.Equal(t, calendarPolicies[1].ID, teamPolicy3.ID) } + +func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) { + ctx := context.Background() + + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"}) + require.NoError(t, err) + team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) + require.NoError(t, err) + + team1Policy1, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{ + Name: "Team 1 Policy 1", + Query: "SELECT * FROM osquery_info;", + CalendarEventsEnabled: true, + }) + require.NoError(t, err) + team1Policy2, err := ds.NewTeamPolicy(ctx, team1.ID, nil, fleet.PolicyPayload{ + Name: "Team 1 Policy 2", + Query: "SELECT * FROM system_info;", + CalendarEventsEnabled: false, + }) + require.NoError(t, err) + team2Policy1, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{ + Name: "Team 2 Policy 1", + Query: "SELECT * FROM os_version;", + CalendarEventsEnabled: true, + }) + require.NoError(t, err) + team2Policy2, err := ds.NewTeamPolicy(ctx, team2.ID, nil, fleet.PolicyPayload{ + Name: "Team 2 Policy 2", + Query: "SELECT * FROM processes;", + CalendarEventsEnabled: true, + }) + require.NoError(t, err) + + // Empty teams. + hostsTeam1, err := ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policy1.ID, team1Policy2.ID}) + require.NoError(t, err) + require.Len(t, hostsTeam1, 0) + + host1, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("host1"), + NodeKey: ptr.String("host1"), + HardwareSerial: "serial1", + ComputerName: "display_name1", + TeamID: &team1.ID, + }) + require.NoError(t, err) + host2, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("host2"), + NodeKey: ptr.String("host2"), + HardwareSerial: "serial2", + ComputerName: "display_name2", + TeamID: &team2.ID, + }) + require.NoError(t, err) + host3, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("host3"), + NodeKey: ptr.String("host3"), + HardwareSerial: "serial3", + ComputerName: "display_name3", + TeamID: &team2.ID, + }) + require.NoError(t, err) + host4, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("host4"), + NodeKey: ptr.String("host4"), + HardwareSerial: "serial4", + ComputerName: "display_name4", + }) + require.NoError(t, err) + host5, err := ds.NewHost(ctx, &fleet.Host{ + OsqueryHostID: ptr.String("host5"), + NodeKey: ptr.String("host5"), + HardwareSerial: "serial5", + ComputerName: "display_name5", + TeamID: &team1.ID, + }) + require.NoError(t, err) + + // No policy results yet. + hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policy1.ID, team1Policy2.ID}) + require.NoError(t, err) + require.Len(t, hostsTeam1, 0) + + err = ds.ReplaceHostDeviceMapping(ctx, host1.ID, []*fleet.HostDeviceMapping{ + {HostID: host1.ID, Email: "foo@example.com", Source: "google_chrome_profiles"}, + }, "google_chrome_profiles") + require.NoError(t, err) + err = ds.ReplaceHostDeviceMapping(ctx, host1.ID, []*fleet.HostDeviceMapping{ + {HostID: host1.ID, Email: "zoo@example.com", Source: "custom"}, + }, "custom") + require.NoError(t, err) + err = ds.ReplaceHostDeviceMapping(ctx, host2.ID, []*fleet.HostDeviceMapping{ + {HostID: host2.ID, Email: "foo@example.com", Source: "custom"}, + }, "custom") + require.NoError(t, err) + err = ds.ReplaceHostDeviceMapping(ctx, host2.ID, []*fleet.HostDeviceMapping{ + {HostID: host2.ID, Email: "foo@other.com", Source: "google_chrome_profiles"}, + }, "google_chrome_profiles") + require.NoError(t, err) + err = ds.ReplaceHostDeviceMapping(ctx, host3.ID, []*fleet.HostDeviceMapping{ + {HostID: host3.ID, Email: "zoo@example.com", Source: "google_chrome_profiles"}, + }, "google_chrome_profiles") + require.NoError(t, err) + err = ds.ReplaceHostDeviceMapping(ctx, host4.ID, []*fleet.HostDeviceMapping{ + {HostID: host4.ID, Email: "foo@example.com", Source: "google_chrome_profiles"}, + }, "google_chrome_profiles") + require.NoError(t, err) + err = ds.ReplaceHostDeviceMapping(ctx, host5.ID, []*fleet.HostDeviceMapping{ + {HostID: host5.ID, Email: "foo@other.com", Source: "google_chrome_profiles"}, + }, "google_chrome_profiles") + require.NoError(t, err) + + err = ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{ + team1Policy1.ID: ptr.Bool(true), + team1Policy2.ID: ptr.Bool(false), + }, time.Now(), false) + require.NoError(t, err) + + err = ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{ + team2Policy1.ID: ptr.Bool(false), + team2Policy2.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + + err = ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{ + team2Policy1.ID: ptr.Bool(true), + team2Policy2.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + + err = ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{ + team1Policy1.ID: ptr.Bool(false), + team1Policy2.ID: ptr.Bool(false), + }, time.Now(), false) + require.NoError(t, err) + + team1Policies, err := ds.GetCalendarPolicies(ctx, team1.ID) + require.NoError(t, err) + require.Len(t, team1Policies, 1) + team2Policies, err := ds.GetCalendarPolicies(ctx, team2.ID) + require.NoError(t, err) + require.Len(t, team2Policies, 2) + + hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}) + require.NoError(t, err) + require.Len(t, hostsTeam1, 2) + require.Equal(t, host1.ID, hostsTeam1[0].HostID) + require.Equal(t, "foo@example.com", hostsTeam1[0].Email) + require.True(t, hostsTeam1[0].Passing) + require.Equal(t, "serial1", hostsTeam1[0].HostHardwareSerial) + require.Equal(t, "display_name1", hostsTeam1[0].HostDisplayName) + require.Equal(t, host5.ID, hostsTeam1[1].HostID) + require.Empty(t, hostsTeam1[1].Email) + require.False(t, hostsTeam1[1].Passing) + require.Equal(t, "serial5", hostsTeam1[1].HostHardwareSerial) + require.Equal(t, "display_name5", hostsTeam1[1].HostDisplayName) + + err = ds.AddHostsToTeam(ctx, &team1.ID, []uint{host4.ID}) + require.NoError(t, err) + err = ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{ + team1Policy1.ID: ptr.Bool(false), + team1Policy2.ID: ptr.Bool(false), + }, time.Now(), false) + require.NoError(t, err) + + hostsTeam1, err = ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team1.ID, []uint{team1Policies[0].ID}) + require.NoError(t, err) + require.Len(t, hostsTeam1, 3) + require.Equal(t, host1.ID, hostsTeam1[0].HostID) + require.Equal(t, "foo@example.com", hostsTeam1[0].Email) + require.True(t, hostsTeam1[0].Passing) + require.Equal(t, "serial1", hostsTeam1[0].HostHardwareSerial) + require.Equal(t, "display_name1", hostsTeam1[0].HostDisplayName) + require.Equal(t, host4.ID, hostsTeam1[1].HostID) + require.Equal(t, "foo@example.com", hostsTeam1[1].Email) + require.False(t, hostsTeam1[1].Passing) + require.Equal(t, "serial4", hostsTeam1[1].HostHardwareSerial) + require.Equal(t, "display_name4", hostsTeam1[1].HostDisplayName) + require.Equal(t, host5.ID, hostsTeam1[2].HostID) + require.Empty(t, hostsTeam1[2].Email) + require.False(t, hostsTeam1[2].Passing) + require.Equal(t, "serial5", hostsTeam1[2].HostHardwareSerial) + require.Equal(t, "display_name5", hostsTeam1[2].HostDisplayName) + + hostsTeam2, err := ds.GetTeamHostsPolicyMemberships(ctx, "example.com", team2.ID, []uint{team2Policies[0].ID, team2Policies[1].ID}) + require.NoError(t, err) + require.Len(t, hostsTeam2, 2) + require.Equal(t, host2.ID, hostsTeam2[0].HostID) + require.Equal(t, "foo@example.com", hostsTeam2[0].Email) + require.False(t, hostsTeam2[0].Passing) + require.Equal(t, "serial2", hostsTeam2[0].HostHardwareSerial) + require.Equal(t, "display_name2", hostsTeam2[0].HostDisplayName) + require.Equal(t, host3.ID, hostsTeam2[1].HostID) + require.Equal(t, "zoo@example.com", hostsTeam2[1].Email) + require.True(t, hostsTeam2[1].Passing) + require.Equal(t, "serial3", hostsTeam2[1].HostHardwareSerial) + require.Equal(t, "display_name3", hostsTeam2[1].HostDisplayName) +} diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 65098efdfc..a2f8bf6cdd 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -594,7 +594,7 @@ type Datastore interface { PolicyQueriesForHost(ctx context.Context, host *Host) (map[string]string, error) - GetHostsPolicyMemberships(ctx context.Context, domain string, policyIDs []uint) ([]HostPolicyMembershipData, error) + GetTeamHostsPolicyMemberships(ctx context.Context, domain string, teamID uint, policyIDs []uint) ([]HostPolicyMembershipData, error) GetCalendarPolicies(ctx context.Context, teamID uint) ([]PolicyCalendarData, error) // Methods used for async processing of host policy query results. @@ -624,6 +624,7 @@ type Datastore interface { DeleteCalendarEvent(ctx context.Context, calendarEventID uint) error UpdateCalendarEvent(ctx context.Context, calendarEventID uint, startTime time.Time, endTime time.Time, data []byte) error GetHostCalendarEvent(ctx context.Context, hostID uint) (*HostCalendarEvent, *CalendarEvent, error) + GetHostCalendarEventByEmail(ctx context.Context, email string) (*HostCalendarEvent, *CalendarEvent, error) UpdateHostCalendarWebhookStatus(ctx context.Context, hostID uint, status CalendarWebhookStatus) error ListCalendarEvents(ctx context.Context, teamID *uint) ([]*CalendarEvent, error) ListOutOfDateCalendarEvents(ctx context.Context, t time.Time) ([]*CalendarEvent, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 1b77b29cbe..425e0945e7 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -440,7 +440,7 @@ type UpdateHostPolicyCountsFunc func(ctx context.Context) error type PolicyQueriesForHostFunc func(ctx context.Context, host *fleet.Host) (map[string]string, error) -type GetHostsPolicyMembershipsFunc func(ctx context.Context, domain string, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) +type GetTeamHostsPolicyMembershipsFunc func(ctx context.Context, domain string, teamID uint, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) type GetCalendarPoliciesFunc func(ctx context.Context, teamID uint) ([]fleet.PolicyCalendarData, error) @@ -472,6 +472,8 @@ type UpdateCalendarEventFunc func(ctx context.Context, calendarEventID uint, sta type GetHostCalendarEventFunc func(ctx context.Context, hostID uint) (*fleet.HostCalendarEvent, *fleet.CalendarEvent, error) +type GetHostCalendarEventByEmailFunc func(ctx context.Context, email string) (*fleet.HostCalendarEvent, *fleet.CalendarEvent, error) + type UpdateHostCalendarWebhookStatusFunc func(ctx context.Context, hostID uint, status fleet.CalendarWebhookStatus) error type ListCalendarEventsFunc func(ctx context.Context, teamID *uint) ([]*fleet.CalendarEvent, error) @@ -1512,8 +1514,8 @@ type DataStore struct { PolicyQueriesForHostFunc PolicyQueriesForHostFunc PolicyQueriesForHostFuncInvoked bool - GetHostsPolicyMembershipsFunc GetHostsPolicyMembershipsFunc - GetHostsPolicyMembershipsFuncInvoked bool + GetTeamHostsPolicyMembershipsFunc GetTeamHostsPolicyMembershipsFunc + GetTeamHostsPolicyMembershipsFuncInvoked bool GetCalendarPoliciesFunc GetCalendarPoliciesFunc GetCalendarPoliciesFuncInvoked bool @@ -1560,6 +1562,9 @@ type DataStore struct { GetHostCalendarEventFunc GetHostCalendarEventFunc GetHostCalendarEventFuncInvoked bool + GetHostCalendarEventByEmailFunc GetHostCalendarEventByEmailFunc + GetHostCalendarEventByEmailFuncInvoked bool + UpdateHostCalendarWebhookStatusFunc UpdateHostCalendarWebhookStatusFunc UpdateHostCalendarWebhookStatusFuncInvoked bool @@ -3649,11 +3654,11 @@ func (s *DataStore) PolicyQueriesForHost(ctx context.Context, host *fleet.Host) return s.PolicyQueriesForHostFunc(ctx, host) } -func (s *DataStore) GetHostsPolicyMemberships(ctx context.Context, domain string, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) { +func (s *DataStore) GetTeamHostsPolicyMemberships(ctx context.Context, domain string, teamID uint, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) { s.mu.Lock() - s.GetHostsPolicyMembershipsFuncInvoked = true + s.GetTeamHostsPolicyMembershipsFuncInvoked = true s.mu.Unlock() - return s.GetHostsPolicyMembershipsFunc(ctx, domain, policyIDs) + return s.GetTeamHostsPolicyMembershipsFunc(ctx, domain, teamID, policyIDs) } func (s *DataStore) GetCalendarPolicies(ctx context.Context, teamID uint) ([]fleet.PolicyCalendarData, error) { @@ -3761,6 +3766,13 @@ func (s *DataStore) GetHostCalendarEvent(ctx context.Context, hostID uint) (*fle return s.GetHostCalendarEventFunc(ctx, hostID) } +func (s *DataStore) GetHostCalendarEventByEmail(ctx context.Context, email string) (*fleet.HostCalendarEvent, *fleet.CalendarEvent, error) { + s.mu.Lock() + s.GetHostCalendarEventByEmailFuncInvoked = true + s.mu.Unlock() + return s.GetHostCalendarEventByEmailFunc(ctx, email) +} + func (s *DataStore) UpdateHostCalendarWebhookStatus(ctx context.Context, hostID uint, status fleet.CalendarWebhookStatus) error { s.mu.Lock() s.UpdateHostCalendarWebhookStatusFuncInvoked = true