diff --git a/changes/13287-filter-extensions-by-labels b/changes/13287-filter-extensions-by-labels new file mode 100644 index 0000000000..4ee9737ed7 --- /dev/null +++ b/changes/13287-filter-extensions-by-labels @@ -0,0 +1 @@ +* Add `labels` to the fleetd extensions feature to allow deploying extensions to hosts that belong to certain labels. diff --git a/ee/server/service/teams.go b/ee/server/service/teams.go index b8014bb8ec..33a777323f 100644 --- a/ee/server/service/teams.go +++ b/ee/server/service/teams.go @@ -259,7 +259,7 @@ func (svc *Service) ModifyTeamAgentOptions(ctx context.Context, teamID uint, tea } if teamOptions != nil { - if err := fleet.ValidateJSONAgentOptions(teamOptions); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, teamOptions, true); err != nil { err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOptions.Force && !applyOptions.DryRun { level.Info(svc.logger).Log("err", err, "msg", "force-apply team agent options with validation errors") @@ -694,7 +694,7 @@ func (svc *Service) ApplyTeamSpecs(ctx context.Context, specs []*fleet.TeamSpec, } if len(spec.AgentOptions) > 0 && !bytes.Equal(spec.AgentOptions, jsonNull) { - if err := fleet.ValidateJSONAgentOptions(spec.AgentOptions); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, spec.AgentOptions, true); err != nil { err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOpts.Force && !applyOpts.DryRun { level.Info(svc.logger).Log("err", err, "msg", "force-apply team agent options with validation errors") diff --git a/orbit/pkg/update/flag_runner.go b/orbit/pkg/update/flag_runner.go index 20134effcb..dbf8d85273 100644 --- a/orbit/pkg/update/flag_runner.go +++ b/orbit/pkg/update/flag_runner.go @@ -233,6 +233,8 @@ func (r *ExtensionRunner) DoExtensionConfigUpdate() (bool, error) { } } + log.Debug().Str("extensions", string(config.Extensions)).Msg("received extensions configuration") + var extensions fleet.Extensions err = json.Unmarshal(config.Extensions, &extensions) if err != nil { diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go index eb24b52b89..34e461e027 100644 --- a/server/datastore/mysql/labels.go +++ b/server/datastore/mysql/labels.go @@ -928,3 +928,31 @@ func (ds *Datastore) LabelsSummary(ctx context.Context) ([]*fleet.LabelSummary, } return labelsSummary, nil } + +// HostMemberOfAllLabels returns whether the given host is a member of all the provided labels. +// If the labels do not exist, then the host is considered not a member of the provided labels. +// A host will always be a member of an empty label set, so this method returns (true, nil) +// if labelNames is empty. +func (ds *Datastore) HostMemberOfAllLabels(ctx context.Context, hostID uint, labelNames []string) (bool, error) { + if len(labelNames) == 0 { + return true, nil + } + + sqlStatement := ` + SELECT COUNT(*) = ? FROM labels l + LEFT JOIN (SELECT label_id FROM label_membership WHERE host_id = ?) lm + ON l.id = lm.label_id + WHERE l.name IN (?) AND lm.label_id IS NOT NULL; + ` + sql, args, err := sqlx.In(sqlStatement, len(labelNames), hostID, labelNames) + if err != nil { + return false, ctxerr.Wrap(ctx, err, "building query to get label IDs") + } + + var ok bool + if err := sqlx.GetContext(ctx, ds.reader(ctx), &ok, sql, args...); err != nil { + return false, ctxerr.Wrap(ctx, err, "get label IDs") + } + + return ok, nil +} diff --git a/server/datastore/mysql/labels_test.go b/server/datastore/mysql/labels_test.go index 86cedd44e9..2d904e19d6 100644 --- a/server/datastore/mysql/labels_test.go +++ b/server/datastore/mysql/labels_test.go @@ -65,6 +65,7 @@ func TestLabels(t *testing.T) { {"LabelsSummary", testLabelsSummary}, {"ListHostsInLabelFailingPolicies", testListHostsInLabelFailingPolicies}, {"ListHostsInLabelDiskEncryptionStatus", testListHostsInLabelDiskEncryptionStatus}, + {"HostMemberOfAllLabels", testHostMemberOfAllLabels}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -1142,3 +1143,185 @@ func testListHostsInLabelDiskEncryptionStatus(t *testing.T, ds *Datastore) { listHostsCheckCount(t, ds, fleet.TeamFilter{User: test.UserAdmin}, fleet.HostListOptions{MacOSSettingsDiskEncryptionFilter: fleet.DiskEncryptionFailed}, 2) listHostsCheckCount(t, ds, fleet.TeamFilter{User: test.UserAdmin}, fleet.HostListOptions{MacOSSettingsDiskEncryptionFilter: fleet.DiskEncryptionRemovingEnforcement}, 1) } + +func testHostMemberOfAllLabels(t *testing.T, ds *Datastore) { + ctx := context.Background() + + // + // Setup test + // - h1 member of 'All hosts', 'Foobar' and 'Zoobar' + // - h2 member of 'All hosts' and 'Foobar' + // - h3 member of 'All hosts' and 'Zoobar' + // - h4 member of 'All hosts' + // - h5 member of no labels + // + + allHostsLabel, err := ds.NewLabel(ctx, + &fleet.Label{ + Name: "All hosts", + Query: "SELECT 1", + LabelType: fleet.LabelTypeBuiltIn, + LabelMembershipType: fleet.LabelMembershipTypeDynamic, + }, + ) + require.NoError(t, err) + foobarLabel, err := ds.NewLabel(ctx, &fleet.Label{ + Name: "Foobar", + Query: "SELECT 1;", + LabelType: fleet.LabelTypeRegular, + LabelMembershipType: fleet.LabelMembershipTypeDynamic, + }) + require.NoError(t, err) + zoobarLabel, err := ds.NewLabel(ctx, &fleet.Label{ + Name: "Zoobar", + Query: "SELECT 2;", + LabelType: fleet.LabelTypeRegular, + LabelMembershipType: fleet.LabelMembershipTypeDynamic, + }) + require.NoError(t, err) + + newHostFunc := func(name string) *fleet.Host { + h, err := ds.NewHost(ctx, &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + OsqueryHostID: ptr.String(name), + NodeKey: ptr.String(name), + UUID: name, + Hostname: "foo.local" + name, + }) + require.NoError(t, err) + return h + } + + h1 := newHostFunc("h1") + h2 := newHostFunc("h2") + h3 := newHostFunc("h3") + h4 := newHostFunc("h4") + h5 := newHostFunc("h5") + _ = h5 + + err = ds.RecordLabelQueryExecutions(ctx, h1, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + foobarLabel.ID: ptr.Bool(true), + zoobarLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + err = ds.RecordLabelQueryExecutions(ctx, h2, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + foobarLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + err = ds.RecordLabelQueryExecutions(ctx, h3, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + zoobarLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + err = ds.RecordLabelQueryExecutions(ctx, h4, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + + // + // Run tests for HostMemberOfAllLabels + // + + for _, tc := range []struct { + name string + hostID uint + labelNames []string + expectedResult bool + }{ + { + name: "nonexistent host", + hostID: 999, + labelNames: []string{allHostsLabel.Name}, + expectedResult: false, + }, + { + name: "h1 does not belong to nonexistent label", + hostID: h1.ID, + labelNames: []string{"Non existent label"}, + expectedResult: false, + }, + { + name: "h1 does not belong to All hosts + nonexistent label", + hostID: h1.ID, + labelNames: []string{allHostsLabel.Name, "Non existent label"}, + expectedResult: false, + }, + { + name: "h1 belongs to the given subset of labels", + hostID: h1.ID, + labelNames: []string{allHostsLabel.Name, foobarLabel.Name}, + expectedResult: true, + }, + { + name: "h1 belongs to all the given labels", + hostID: h1.ID, + labelNames: []string{allHostsLabel.Name, foobarLabel.Name, zoobarLabel.Name}, + expectedResult: true, + }, + { + name: "h1 member of empty label set", + hostID: h1.ID, + labelNames: []string{}, + expectedResult: true, + }, + { + name: "h2 belongs to all the given labels", + hostID: h2.ID, + labelNames: []string{allHostsLabel.Name, foobarLabel.Name}, + expectedResult: true, + }, + { + name: "h2 does not belongs to all the given labels", + hostID: h2.ID, + labelNames: []string{allHostsLabel.Name, foobarLabel.Name, zoobarLabel.Name}, + expectedResult: false, + }, + { + name: "h2 belongs to the given label", + hostID: h2.ID, + labelNames: []string{foobarLabel.Name}, + expectedResult: true, + }, + { + name: "h2 does not belong to the given label", + hostID: h2.ID, + labelNames: []string{zoobarLabel.Name}, + expectedResult: false, + }, + { + name: "h3 belongs to all the given labels", + hostID: h3.ID, + labelNames: []string{allHostsLabel.Name, zoobarLabel.Name}, + expectedResult: true, + }, + { + name: "h4 belongs to all the given labels", + hostID: h4.ID, + labelNames: []string{allHostsLabel.Name}, + expectedResult: true, + }, + { + name: "h4 does not belong to the given labels", + hostID: h4.ID, + labelNames: []string{foobarLabel.Name}, + expectedResult: false, + }, + { + name: "h5 does not belong to the given labels", + hostID: h5.ID, + labelNames: []string{allHostsLabel.Name}, + expectedResult: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + v, err := ds.HostMemberOfAllLabels(ctx, tc.hostID, tc.labelNames) + require.NoError(t, err) + require.Equal(t, tc.expectedResult, v) + }) + } +} diff --git a/server/fleet/agent_options.go b/server/fleet/agent_options.go index 6c37d879a8..cf085209cf 100644 --- a/server/fleet/agent_options.go +++ b/server/fleet/agent_options.go @@ -2,6 +2,7 @@ package fleet import ( "bytes" + "context" "encoding/json" "fmt" "strings" @@ -37,7 +38,7 @@ func (o *AgentOptions) ForPlatform(platform string) json.RawMessage { // Options payload. It ensures that all fields are known and have valid values. // The validation always uses the most recent Osquery version that is available // at the time of the Fleet release. -func ValidateJSONAgentOptions(rawJSON json.RawMessage) error { +func ValidateJSONAgentOptions(ctx context.Context, ds Datastore, rawJSON json.RawMessage, isPremium bool) error { var opts AgentOptions if err := JSONStrictDecode(bytes.NewReader(rawJSON), &opts); err != nil { return err @@ -55,6 +56,7 @@ func ValidateJSONAgentOptions(rawJSON json.RawMessage) error { return fmt.Errorf("common config: %w", err) } } + for platform, platformOpts := range opts.Overrides.Platforms { if len(platformOpts) > 0 { if err := validateJSONAgentOptionsSet(platformOpts); err != nil { @@ -62,6 +64,38 @@ func ValidateJSONAgentOptions(rawJSON json.RawMessage) error { } } } + + if len(opts.Extensions) > 0 { + if err := validateJSONAgentOptionsExtensions(ctx, ds, opts.Extensions, isPremium); err != nil { + return err + } + } + + return nil +} + +func validateJSONAgentOptionsExtensions(ctx context.Context, ds Datastore, optsExtensions json.RawMessage, isPremium bool) error { + var extensions map[string]ExtensionInfo + if err := json.Unmarshal(optsExtensions, &extensions); err != nil { + return fmt.Errorf("unmarshal extensions: %w", err) + } + for _, extensionInfo := range extensions { + if !isPremium && len(extensionInfo.Labels) != 0 { + // Setting labels settings in the extensions config is premium only. + return ErrMissingLicense + } + for _, labelName := range extensionInfo.Labels { + switch _, err := ds.GetLabelSpec(ctx, labelName); { + case err == nil: + // OK + case IsNotFound(err): + // Label does not exist, fail the request. + return fmt.Errorf("Label %q does not exist", labelName) + default: + return fmt.Errorf("get label by name: %w", err) + } + } + } return nil } diff --git a/server/fleet/agent_options_test.go b/server/fleet/agent_options_test.go index fdbae8a6e1..522e956705 100644 --- a/server/fleet/agent_options_test.go +++ b/server/fleet/agent_options_test.go @@ -1,6 +1,7 @@ package fleet import ( + "context" "errors" "testing" @@ -140,7 +141,7 @@ func TestValidateAgentOptions(t *testing.T) { for _, c := range cases { t.Run(c.desc, func(t *testing.T) { - err := ValidateJSONAgentOptions([]byte(c.in)) + err := ValidateJSONAgentOptions(context.Background(), nil, []byte(c.in), true) t.Logf("%T", errors.Unwrap(err)) if c.wantErr != "" { require.ErrorContains(t, err, c.wantErr) diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 1a8a533d98..55573e39b4 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -230,6 +230,12 @@ type Datastore interface { // HostIDsByOSID retrieves the IDs of all host for the given OS ID HostIDsByOSID(ctx context.Context, osID uint, offset int, limit int) ([]uint, error) + // HostMemberOfAllLabels returns whether the given host is a member of all the provided labels. + // If a label name does not exist, then the host is considered not a member of the provided label. + // A host will always be a member of an empty label set, so this method returns (true, nil) + // if labelNames is empty. + HostMemberOfAllLabels(ctx context.Context, hostID uint, labelNames []string) (bool, error) + // TODO JUAN: Refactor this to use the Operating System type instead. // HostIDsByOSVersion retrieves the IDs of all host matching osVersion HostIDsByOSVersion(ctx context.Context, osVersion OSVersion, offset int, limit int) ([]uint, error) diff --git a/server/fleet/orbit.go b/server/fleet/orbit.go index d4c640efc2..77ecd683f0 100644 --- a/server/fleet/orbit.go +++ b/server/fleet/orbit.go @@ -57,6 +57,8 @@ type ExtensionInfo struct { Platform string `json:"platform"` // Channel is the select TUF channel to listen for updates. Channel string `json:"channel"` + // Labels are the label names the host must be member of to run this extension. + Labels []string `json:"labels,omitempty"` } // Extensions holds a set of extensions to apply to an Orbit client. diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index c3b29b222e..28968f2336 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -174,6 +174,8 @@ type HostIDsByNameFunc func(ctx context.Context, filter fleet.TeamFilter, hostna type HostIDsByOSIDFunc func(ctx context.Context, osID uint, offset int, limit int) ([]uint, error) +type HostMemberOfAllLabelsFunc func(ctx context.Context, hostID uint, labelNames []string) (bool, error) + type HostIDsByOSVersionFunc func(ctx context.Context, osVersion fleet.OSVersion, offset int, limit int) ([]uint, error) type HostByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.Host, error) @@ -913,6 +915,9 @@ type DataStore struct { HostIDsByOSIDFunc HostIDsByOSIDFunc HostIDsByOSIDFuncInvoked bool + HostMemberOfAllLabelsFunc HostMemberOfAllLabelsFunc + HostMemberOfAllLabelsFuncInvoked bool + HostIDsByOSVersionFunc HostIDsByOSVersionFunc HostIDsByOSVersionFuncInvoked bool @@ -2218,6 +2223,13 @@ func (s *DataStore) HostIDsByOSID(ctx context.Context, osID uint, offset int, li return s.HostIDsByOSIDFunc(ctx, osID, offset, limit) } +func (s *DataStore) HostMemberOfAllLabels(ctx context.Context, hostID uint, labelNames []string) (bool, error) { + s.mu.Lock() + s.HostMemberOfAllLabelsFuncInvoked = true + s.mu.Unlock() + return s.HostMemberOfAllLabelsFunc(ctx, hostID, labelNames) +} + func (s *DataStore) HostIDsByOSVersion(ctx context.Context, osVersion fleet.OSVersion, offset int, limit int) ([]uint, error) { s.mu.Lock() s.HostIDsByOSVersionFuncInvoked = true diff --git a/server/service/appconfig.go b/server/service/appconfig.go index 7e8f94529c..0b279e5b8a 100644 --- a/server/service/appconfig.go +++ b/server/service/appconfig.go @@ -329,7 +329,7 @@ func (svc *Service) ModifyAppConfig(ctx context.Context, p []byte, applyOpts fle if newAppConfig.AgentOptions != nil { // if there were Agent Options in the new app config, then it replaced the // agent options in the resulting app config, so validate those. - if err := fleet.ValidateJSONAgentOptions(*appConfig.AgentOptions); err != nil { + if err := fleet.ValidateJSONAgentOptions(ctx, svc.ds, *appConfig.AgentOptions, license.IsPremium()); err != nil { err = fleet.NewUserMessageError(err, http.StatusBadRequest) if applyOpts.Force && !applyOpts.DryRun { level.Info(svc.logger).Log("err", err, "msg", "force-apply appConfig agent options with validation errors") diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 569a1652b0..f7243a1e6f 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -7610,6 +7610,122 @@ func (s *integrationTestSuite) TestDirectIngestSoftwareWithInvalidFields() { require.NotZero(t, wiresharkSoftware.ID) } +func (s *integrationTestSuite) TestOrbitConfigExtensions() { + t := s.T() + ctx := context.Background() + + appCfg, err := s.ds.AppConfig(ctx) + require.NoError(t, err) + defer func() { + err = s.ds.SaveAppConfig(ctx, appCfg) + require.NoError(t, err) + }() + + // Orbit client gets no extensions if extensions are not configured. + orbitLinuxClient := createOrbitEnrolledHost(t, "linux", "foobar1", s.ds) + resp := orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp) + require.Empty(t, resp.Extensions) + + // Attempt to add extensions (should succeed). + s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{ + "agent_options": { + "config": { + "options": { + "pack_delimiter": "/", + "logger_tls_period": 10, + "distributed_plugin": "tls", + "disable_distributed": false, + "logger_tls_endpoint": "/api/osquery/log", + "distributed_interval": 10, + "distributed_tls_max_attempts": 3 + } + }, + "extensions": { + "hello_world_linux": { + "channel": "stable", + "platform": "linux" + }, + "hello_mars_linux": { + "channel": "stable", + "platform": "linux" + }, + "hello_world_macos": { + "channel": "stable", + "platform": "macos" + } + } + } +}`), http.StatusOK) + + // Attempt to add labels to extensions (only available on premium). + s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{ + "agent_options": { + "config": { + "options": { + "pack_delimiter": "/", + "logger_tls_period": 10, + "distributed_plugin": "tls", + "disable_distributed": false, + "logger_tls_endpoint": "/api/osquery/log", + "distributed_interval": 10, + "distributed_tls_max_attempts": 3 + } + }, + "extensions": { + "hello_world_linux": { + "channel": "stable", + "platform": "linux" + }, + "hello_world_macos": { + "labels": [ + "All hosts", + "Some label" + ], + "channel": "stable", + "platform": "macos" + }, + "hello_world_windows": { + "channel": "stable", + "platform": "windows" + } + } + } +}`), http.StatusBadRequest) + + // Orbit client gets extensions configured for its platform. + orbitDarwinClient := createOrbitEnrolledHost(t, "darwin", "foobar2", s.ds) + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitDarwinClient.OrbitNodeKey)), http.StatusOK, &resp) + require.JSONEq(t, `{ + "hello_world_macos": { + "platform": "macos", + "channel": "stable" + } + }`, string(resp.Extensions)) + + orbitWindowsClient := createOrbitEnrolledHost(t, "windows", "foobar3", s.ds) + + // Orbit client gets no extensions if none of the platforms target it. + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitWindowsClient.OrbitNodeKey)), http.StatusOK, &resp) + require.Empty(t, resp.Extensions) + + // Orbit client gets the two extensions configured for its platform. + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp) + require.JSONEq(t, `{ + "hello_world_linux": { + "channel": "stable", + "platform": "linux" + }, + "hello_mars_linux": { + "channel": "stable", + "platform": "linux" + } + }`, string(resp.Extensions)) +} + func (s *integrationTestSuite) TestHostsReportWithPolicyResults() { t := s.T() ctx := context.Background() diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index a4c0c9cd62..168567157e 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -3924,3 +3924,185 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() { errMsg = extractServerErrorText(res.Body) require.Contains(t, errMsg, fleet.RunScriptHostOfflineErrMsg) } + +func (s *integrationEnterpriseTestSuite) TestOrbitConfigExtensions() { + t := s.T() + ctx := context.Background() + + appCfg, err := s.ds.AppConfig(ctx) + require.NoError(t, err) + defer func() { + err = s.ds.SaveAppConfig(ctx, appCfg) + require.NoError(t, err) + }() + + foobarLabel, err := s.ds.NewLabel(ctx, &fleet.Label{ + Name: "Foobar", + Query: "SELECT 1;", + }) + require.NoError(t, err) + zoobarLabel, err := s.ds.NewLabel(ctx, &fleet.Label{ + Name: "Zoobar", + Query: "SELECT 1;", + }) + require.NoError(t, err) + allHostsLabel, err := s.ds.GetLabelSpec(ctx, "All hosts") + require.NoError(t, err) + + orbitDarwinClient := createOrbitEnrolledHost(t, "darwin", "foobar1", s.ds) + orbitLinuxClient := createOrbitEnrolledHost(t, "linux", "foobar2", s.ds) + orbitWindowsClient := createOrbitEnrolledHost(t, "windows", "foobar3", s.ds) + + // orbitDarwinClient is member of 'All hosts' and 'Zoobar' labels. + err = s.ds.RecordLabelQueryExecutions(ctx, orbitDarwinClient, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + zoobarLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + // orbitLinuxClient is member of 'All hosts' and 'Foobar' labels. + err = s.ds.RecordLabelQueryExecutions(ctx, orbitLinuxClient, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + foobarLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + // orbitWindowsClient is member of the 'All hosts' label only. + err = s.ds.RecordLabelQueryExecutions(ctx, orbitWindowsClient, map[uint]*bool{ + allHostsLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + + // Attempt to add labels to extensions. + s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{ + "agent_options": { + "config": { + "options": { + "pack_delimiter": "/", + "logger_tls_period": 10, + "distributed_plugin": "tls", + "disable_distributed": false, + "logger_tls_endpoint": "/api/osquery/log", + "distributed_interval": 10, + "distributed_tls_max_attempts": 3 + } + }, + "extensions": { + "hello_world_linux": { + "labels": [ + "All hosts", + "Foobar" + ], + "channel": "stable", + "platform": "linux" + }, + "hello_world_macos": { + "labels": [ + "All hosts", + "Foobar" + ], + "channel": "stable", + "platform": "macos" + }, + "hello_mars_macos": { + "labels": [ + "All hosts", + "Zoobar" + ], + "channel": "stable", + "platform": "macos" + }, + "hello_world_windows": { + "labels": [ + "Zoobar" + ], + "channel": "stable", + "platform": "windows" + }, + "hello_mars_windows": { + "labels": [ + "Foobar" + ], + "channel": "stable", + "platform": "windows" + } + } + } +}`), http.StatusOK) + + resp := orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitDarwinClient.OrbitNodeKey)), http.StatusOK, &resp) + require.JSONEq(t, `{ + "hello_mars_macos": { + "channel": "stable", + "platform": "macos" + } + }`, string(resp.Extensions)) + + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp) + require.JSONEq(t, `{ + "hello_world_linux": { + "channel": "stable", + "platform": "linux" + } + }`, string(resp.Extensions)) + + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitWindowsClient.OrbitNodeKey)), http.StatusOK, &resp) + require.Empty(t, string(resp.Extensions)) + + // orbitDarwinClient is now also a member of the 'Foobar' label. + err = s.ds.RecordLabelQueryExecutions(ctx, orbitDarwinClient, map[uint]*bool{ + foobarLabel.ID: ptr.Bool(true), + }, time.Now(), false) + require.NoError(t, err) + + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitDarwinClient.OrbitNodeKey)), http.StatusOK, &resp) + require.JSONEq(t, `{ + "hello_world_macos": { + "channel": "stable", + "platform": "macos" + }, + "hello_mars_macos": { + "channel": "stable", + "platform": "macos" + } + }`, string(resp.Extensions)) + + // orbitLinuxClient is no longer a member of the 'Foobar' label. + err = s.ds.RecordLabelQueryExecutions(ctx, orbitLinuxClient, map[uint]*bool{ + foobarLabel.ID: nil, + }, time.Now(), false) + require.NoError(t, err) + + resp = orbitGetConfigResponse{} + s.DoJSON("POST", "/api/fleet/orbit/config", json.RawMessage(fmt.Sprintf(`{"orbit_node_key": %q}`, *orbitLinuxClient.OrbitNodeKey)), http.StatusOK, &resp) + require.Empty(t, string(resp.Extensions)) + + // Attempt to set non-existent labels in the config. + s.DoRaw("PATCH", "/api/latest/fleet/config", []byte(`{ + "agent_options": { + "config": { + "options": { + "pack_delimiter": "/", + "logger_tls_period": 10, + "distributed_plugin": "tls", + "disable_distributed": false, + "logger_tls_endpoint": "/api/osquery/log", + "distributed_interval": 10, + "distributed_tls_max_attempts": 3 + } + }, + "extensions": { + "hello_world_linux": { + "labels": [ + "All hosts", + "Doesn't exist" + ], + "channel": "stable", + "platform": "linux" + } + } + } +}`), http.StatusBadRequest) +} diff --git a/server/service/orbit.go b/server/service/orbit.go index a630344bfa..28df989cd1 100644 --- a/server/service/orbit.go +++ b/server/service/orbit.go @@ -11,6 +11,7 @@ import ( "github.com/fleetdm/fleet/v4/server/config" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" hostctx "github.com/fleetdm/fleet/v4/server/contexts/host" + "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/go-kit/kit/log/level" @@ -168,8 +169,6 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro // this is not a user-authenticated endpoint svc.authz.SkipAuthorization(ctx) - var notifs fleet.OrbitConfigNotifications - host, ok := hostctx.FromContext(ctx) if !ok { return fleet.OrbitConfig{}, fleet.OrbitError{Message: "internal error: missing host from request context"} @@ -181,6 +180,7 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro } // set the host's orbit notifications for macOS MDM + var notifs fleet.OrbitConfigNotifications if appConfig.MDM.EnabledAndConfigured && host.IsOsqueryEnrolled() { // TODO(mna): all those notifications implied a macos hosts, but none of // the checks enforce that (only indirectly in some cases, like @@ -250,7 +250,7 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro } } - extensionsFiltered, err := filterExtensionsByPlatform(opts.Extensions, host.Platform) + extensionsFiltered, err := svc.filterExtensionsForHost(ctx, opts.Extensions, host) if err != nil { return fleet.OrbitConfig{}, err } @@ -286,7 +286,7 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro } } - extensionsFiltered, err := filterExtensionsByPlatform(opts.Extensions, host.Platform) + extensionsFiltered, err := svc.filterExtensionsForHost(ctx, opts.Extensions, host) if err != nil { return fleet.OrbitConfig{}, err } @@ -308,20 +308,45 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro }, nil } -// filterExtensionsByPlatform filters a extensions configuration depending on the host platform. -// (to not send extensions targeted to other operating systems). -func filterExtensionsByPlatform(extensions json.RawMessage, hostPlatform string) (json.RawMessage, error) { +// filterExtensionsForHost filters a extensions configuration depending on the host platform and label membership. +// +// If all extensions are filtered, then it returns (nil, nil) (Orbit expects empty extensions if there +// are no extensions for the host.) +func (svc *Service) filterExtensionsForHost(ctx context.Context, extensions json.RawMessage, host *fleet.Host) (json.RawMessage, error) { if len(extensions) == 0 { - return extensions, nil + return nil, nil } var extensionsInfo fleet.Extensions if err := json.Unmarshal(extensions, &extensionsInfo); err != nil { - return nil, err + return nil, ctxerr.Wrap(ctx, err, "unmarshal extensions config") + } + + // Filter the extensions by platform. + extensionsInfo.FilterByHostPlatform(host.Platform) + + // Filter the extensions by labels (premium only feature). + if license, _ := license.FromContext(ctx); license != nil && license.IsPremium() { + for extensionName, extensionInfo := range extensionsInfo { + hostIsMemberOfAllLabels, err := svc.ds.HostMemberOfAllLabels(ctx, host.ID, extensionInfo.Labels) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "check host labels") + } + if hostIsMemberOfAllLabels { + // Do not filter out, but there's no need to send the label names to the devices. + extensionInfo.Labels = nil + extensionsInfo[extensionName] = extensionInfo + } else { + delete(extensionsInfo, extensionName) + } + } + } + // Orbit expects empty message if no extensions apply. + if len(extensionsInfo) == 0 { + return nil, nil } - extensionsInfo.FilterByHostPlatform(hostPlatform) extensionsFiltered, err := json.Marshal(extensionsInfo) if err != nil { - return nil, err + return nil, ctxerr.Wrap(ctx, err, "marshal extensions config") } return extensionsFiltered, nil }