diff --git a/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx b/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx index d83e04623d..48d03553be 100644 --- a/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx +++ b/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx @@ -22,6 +22,7 @@ class EditPackForm extends Component { onCancel: PropTypes.func.isRequired, onFetchTargets: PropTypes.func, targetsCount: PropTypes.number, + isBasicTier: PropTypes.bool, }; render() { @@ -32,6 +33,7 @@ class EditPackForm extends Component { onCancel, onFetchTargets, targetsCount, + isBasicTier, } = this.props; return ( @@ -58,6 +60,7 @@ class EditPackForm extends Component { onSelect={fields.targets.onChange} selectedTargets={fields.targets.value} targetsCount={targetsCount} + isBasicTier={isBasicTier} />
diff --git a/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx b/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx index 808b2872ad..42a7070da1 100644 --- a/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx +++ b/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx @@ -21,6 +21,7 @@ class EditPackFormWrapper extends Component { pack: packInterface.isRequired, packTargets: PropTypes.arrayOf(targetInterface), targetsCount: PropTypes.number, + isBasicTier: PropTypes.bool, }; render() { @@ -34,6 +35,7 @@ class EditPackFormWrapper extends Component { pack, packTargets, targetsCount, + isBasicTier, } = this.props; if (isEdit) { @@ -45,6 +47,7 @@ class EditPackFormWrapper extends Component { onCancel={onCancelEditPack} onFetchTargets={onFetchTargets} targetsCount={targetsCount} + isBasicTier={isBasicTier} /> ); } @@ -73,6 +76,7 @@ class EditPackFormWrapper extends Component { selectedTargets={packTargets} targetsCount={targetsCount} disabled + isBasicTier={isBasicTier} className={`${baseClass}__select-targets`} />
diff --git a/frontend/fleet/entities/teams.ts b/frontend/fleet/entities/teams.ts index cf3dadf4c3..5230f61539 100644 --- a/frontend/fleet/entities/teams.ts +++ b/frontend/fleet/entities/teams.ts @@ -48,7 +48,7 @@ export default (client: any) => { page = 0, perPage = 100, globalFilter = "", - }: ITeamSearchOptions) => { + }: ITeamSearchOptions = {}) => { const { TEAMS } = endpoints; // TODO: add this query param logic to client class diff --git a/frontend/fleet/helpers.ts b/frontend/fleet/helpers.ts index d6aedfa4b9..58f3aa3109 100644 --- a/frontend/fleet/helpers.ts +++ b/frontend/fleet/helpers.ts @@ -279,6 +279,20 @@ export const formatScheduledQueryForClient = (scheduledQuery: any): any => { return scheduledQuery; }; +export const formatTeamForClient = (team: any): any => { + if (team.display_text === undefined) { + team.display_text = team.name; + } + return team; +}; + +export const formatPackForClient = (pack: any): any => { + pack.host_ids ||= []; + pack.label_ids ||= []; + pack.team_ids ||= []; + return pack; +}; + const setupData = (formData: any) => { const orgInfo = pick(formData, ORG_INFO_ATTRS); const adminInfo = pick(formData, ADMIN_ATTRS); @@ -355,6 +369,7 @@ export const secondsToHms = (d: number): string => { const sDisplay = s > 0 ? s + (s === 1 ? " sec " : " secs ") : ""; return hDisplay + mDisplay + sDisplay; }; + export default { addGravatarUrlToResource, formatConfigDataForServer, diff --git a/frontend/pages/packs/EditPackPage/EditPackPage.jsx b/frontend/pages/packs/EditPackPage/EditPackPage.jsx index 702360d393..fe2400d1a7 100644 --- a/frontend/pages/packs/EditPackPage/EditPackPage.jsx +++ b/frontend/pages/packs/EditPackPage/EditPackPage.jsx @@ -4,12 +4,15 @@ import { connect } from "react-redux"; import { filter, includes, isEqual, noop, size, find } from "lodash"; import { push } from "react-router-redux"; +import permissionUtils from "utilities/permissions"; import deepDifference from "utilities/deep_difference"; import EditPackFormWrapper from "components/packs/EditPackFormWrapper"; import hostActions from "redux/nodes/entities/hosts/actions"; import hostInterface from "interfaces/host"; import labelActions from "redux/nodes/entities/labels/actions"; +import teamActions from "redux/nodes/entities/teams/actions"; import labelInterface from "interfaces/label"; +import teamInterface from "interfaces/team"; import packActions from "redux/nodes/entities/packs/actions"; import ScheduleQuerySidePanel from "components/side_panels/ScheduleQuerySidePanel"; import packInterface from "interfaces/pack"; @@ -34,7 +37,9 @@ export class EditPackPage extends Component { packHosts: PropTypes.arrayOf(hostInterface), packID: PropTypes.string, packLabels: PropTypes.arrayOf(labelInterface), + packTeams: PropTypes.arrayOf(teamInterface), scheduledQueries: PropTypes.arrayOf(queryInterface), + isBasicTier: PropTypes.bool, }; static defaultProps = { @@ -60,6 +65,7 @@ export class EditPackPage extends Component { packHosts, packID, packLabels, + packTeams, scheduledQueries, } = this.props; const { load } = packActions; @@ -77,6 +83,10 @@ export class EditPackPage extends Component { if (!packLabels || packLabels.length !== pack.label_ids.length) { dispatch(labelActions.loadAll()); } + + if (!packTeams || packTeams.length !== pack.team_ids.length) { + dispatch(teamActions.loadAll()); + } } if (!size(scheduledQueries)) { @@ -90,7 +100,13 @@ export class EditPackPage extends Component { return false; } - componentWillReceiveProps({ dispatch, pack, packHosts, packLabels }) { + componentWillReceiveProps({ + dispatch, + pack, + packHosts, + packLabels, + packTeams, + }) { if (!isEqual(pack, this.props.pack)) { if (!packHosts || packHosts.length !== pack.host_ids.length) { dispatch(hostActions.loadAll()); @@ -99,6 +115,10 @@ export class EditPackPage extends Component { if (!packLabels || packLabels.length !== pack.label_ids.length) { dispatch(labelActions.loadAll()); } + + if (!packTeams || packTeams.length !== pack.team_ids.length) { + dispatch(teamActions.loadAll()); + } } return false; @@ -241,10 +261,12 @@ export class EditPackPage extends Component { pack, packHosts, packLabels, + packTeams, scheduledQueries, + isBasicTier, } = this.props; - const packTargets = [...packHosts, ...packLabels]; + const packTargets = [...packHosts, ...packLabels, ...packTeams]; if (!pack || isLoadingPack || isLoadingScheduledQueries) { return false; @@ -263,6 +285,7 @@ export class EditPackPage extends Component { pack={pack} packTargets={packTargets} targetsCount={targetsCount} + isBasicTier={isBasicTier} /> { return includes(pack.label_ids, label.id); }) : []; + const packTeams = pack + ? filter(state.entities.teams.data, (team) => { + return includes(pack.team_ids, team.id); + }) + : []; + const isBasicTier = permissionUtils.isBasicTier(state.app.config); return { allQueries, @@ -317,7 +346,9 @@ const mapStateToProps = (state, { params, route }) => { packHosts, packID, packLabels, + packTeams, scheduledQueries, + isBasicTier, }; }; diff --git a/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx b/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx index a3fc8e629c..c76de56da6 100644 --- a/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx +++ b/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx @@ -3,7 +3,12 @@ import { mount } from "enzyme"; import { noop } from "lodash"; import { connectedComponent, reduxMockStore } from "test/helpers"; -import { packStub, queryStub, scheduledQueryStub } from "test/stubs"; +import { + packStub, + queryStub, + scheduledQueryStub, + configStub, +} from "test/stubs"; import ConnectedEditPackPage, { EditPackPage, } from "pages/packs/EditPackPage/EditPackPage"; @@ -12,6 +17,7 @@ import labelActions from "redux/nodes/entities/labels/actions"; import packActions from "redux/nodes/entities/packs/actions"; import queryActions from "redux/nodes/entities/queries/actions"; import scheduledQueryActions from "redux/nodes/entities/scheduled_queries/actions"; +import teamActions from "redux/nodes/entities/teams/actions"; describe("EditPackPage - component", () => { beforeEach(() => { @@ -21,15 +27,18 @@ describe("EditPackPage - component", () => { jest.spyOn(labelActions, "loadAll").mockImplementation(() => spyResponse); jest.spyOn(packActions, "load").mockImplementation(() => spyResponse); jest.spyOn(queryActions, "loadAll").mockImplementation(() => spyResponse); + jest.spyOn(teamActions, "loadAll").mockImplementation(() => spyResponse); jest .spyOn(scheduledQueryActions, "loadAll") .mockImplementation(() => spyResponse); }); const store = { + app: { config: configStub }, entities: { hosts: { loading: false, data: {} }, labels: { loading: false, data: {} }, + teams: { loading: false, data: {} }, packs: { loading: false, data: { @@ -49,6 +58,7 @@ describe("EditPackPage - component", () => { describe("rendering", () => { it("does not render when packs are loading", () => { const packsLoadingStore = { + app: store.app, entities: { ...store.entities, packs: { ...store.entities.packs, loading: true }, @@ -67,6 +77,7 @@ describe("EditPackPage - component", () => { it("does not render when scheduled queries are loading", () => { const scheduledQueriesLoadingStore = { + app: store.app, entities: { ...store.entities, scheduled_queries: { @@ -88,6 +99,7 @@ describe("EditPackPage - component", () => { it("does not render when there is no pack", () => { const noPackStore = { + app: store.app, entities: { ...store.entities, packs: { data: {}, loading: false }, @@ -130,6 +142,7 @@ describe("EditPackPage - component", () => { isEdit: false, packHosts: [], packLabels: [], + packTeams: [], scheduledQueries: [], }; @@ -155,6 +168,7 @@ describe("EditPackPage - component", () => { packHosts: [], packID: String(packStub.id), packLabels: [], + packTeams: [], scheduledQueries: [scheduledQuery], }; @@ -206,6 +220,7 @@ describe("EditPackPage - component", () => { packHosts: [], packID: String(packStub.id), packLabels: [], + packTeams: [], scheduledQueries: [scheduledQuery], }; const pushAction = { diff --git a/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx b/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx index 6d287aa11c..f6ca290414 100644 --- a/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx +++ b/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx @@ -8,6 +8,7 @@ import packActions from "redux/nodes/entities/packs/actions"; import PackForm from "components/forms/packs/PackForm"; import PackInfoSidePanel from "components/side_panels/PackInfoSidePanel"; import PATHS from "router/paths"; +import permissionUtils from "utilities/permissions"; const baseClass = "pack-composer"; @@ -17,6 +18,7 @@ export class PackComposerPage extends Component { serverErrors: PropTypes.shape({ base: PropTypes.string, }), + isBasicTier: PropTypes.bool, }; static defaultProps = { @@ -60,7 +62,7 @@ export class PackComposerPage extends Component { render() { const { handleSubmit, onFetchTargets } = this; const { selectedTargetsCount } = this.state; - const { serverErrors } = this.props; + const { serverErrors, isBasicTier } = this.props; return (
@@ -70,6 +72,7 @@ export class PackComposerPage extends Component { onFetchTargets={onFetchTargets} selectedTargetsCount={selectedTargetsCount} serverErrors={serverErrors} + isBasicTier={isBasicTier} />
@@ -79,8 +82,9 @@ export class PackComposerPage extends Component { const mapStateToProps = (state) => { const { errors: serverErrors } = state.entities.packs; + const isBasicTier = permissionUtils.isBasicTier(state.app.config); - return { serverErrors }; + return { serverErrors, isBasicTier }; }; export default connect(mapStateToProps)(PackComposerPage); diff --git a/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx b/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx index 99be6fc9d5..e351538859 100644 --- a/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx +++ b/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx @@ -1,5 +1,6 @@ import { mount } from "enzyme"; +import { configStub } from "test/stubs"; import { connectedComponent, reduxMockStore } from "test/helpers"; import ConnectedPacksComposerPage from "./PackComposerPage"; @@ -8,6 +9,7 @@ describe("PackComposerPage - component", () => { entities: { packs: {}, }, + app: { config: configStub }, }); it("renders", () => { const page = mount( diff --git a/frontend/redux/nodes/entities/packs/config.js b/frontend/redux/nodes/entities/packs/config.js index 8543e12f84..34c439807f 100644 --- a/frontend/redux/nodes/entities/packs/config.js +++ b/frontend/redux/nodes/entities/packs/config.js @@ -1,6 +1,7 @@ import Fleet from "fleet"; import Config from "redux/nodes/entities/base/config"; import schemas from "redux/nodes/entities/base/schemas"; +import { formatPackForClient } from "fleet/helpers"; const { PACKS: schema } = schemas; @@ -10,6 +11,7 @@ export default new Config({ entityName: "packs", loadAllFunc: Fleet.packs.loadAll, loadFunc: Fleet.packs.load, + parseEntityFunc: formatPackForClient, schema, updateFunc: Fleet.packs.update, }); diff --git a/frontend/redux/nodes/entities/scheduled_queries/config.js b/frontend/redux/nodes/entities/scheduled_queries/config.js index 68f25100db..67a6445e8b 100644 --- a/frontend/redux/nodes/entities/scheduled_queries/config.js +++ b/frontend/redux/nodes/entities/scheduled_queries/config.js @@ -1,4 +1,4 @@ -import helpers from "fleet/helpers"; +import { formatScheduledQueryForClient } from "fleet/helpers"; import Fleet from "fleet"; import Config from "redux/nodes/entities/base/config"; import schemas from "redux/nodes/entities/base/schemas"; @@ -10,7 +10,7 @@ export default new Config({ destroyFunc: Fleet.scheduledQueries.destroy, entityName: "scheduled_queries", loadAllFunc: Fleet.scheduledQueries.loadAll, - parseEntityFunc: helpers.formatScheduledQueryForClient, + parseEntityFunc: formatScheduledQueryForClient, schema, updateFunc: Fleet.scheduledQueries.update, }); diff --git a/frontend/redux/nodes/entities/teams/config.ts b/frontend/redux/nodes/entities/teams/config.ts index ac44b02b3d..d7a461c989 100644 --- a/frontend/redux/nodes/entities/teams/config.ts +++ b/frontend/redux/nodes/entities/teams/config.ts @@ -5,6 +5,7 @@ import Fleet from "fleet"; import Config from "redux/nodes/entities/base/config"; // @ts-ignore import schemas from "redux/nodes/entities/base/schemas"; +import { formatTeamForClient } from "fleet/helpers"; const { TEAMS } = schemas; @@ -14,6 +15,7 @@ export default new Config({ entityName: "teams", loadFunc: Fleet.teams.load, loadAllFunc: Fleet.teams.loadAll, + parseEntityFunc: formatTeamForClient, schema: TEAMS, updateFunc: Fleet.teams.update, }); diff --git a/frontend/test/stubs.ts b/frontend/test/stubs.ts index 5f376d4151..b87dab35c7 100644 --- a/frontend/test/stubs.ts +++ b/frontend/test/stubs.ts @@ -123,6 +123,7 @@ export const packStub = { disabled: false, host_ids: [], label_ids: [], + team_ids: [], }; export const queryStub = { diff --git a/server/datastore/datastore.go b/server/datastore/datastore.go index 28d74f8532..3c8692c4f9 100644 --- a/server/datastore/datastore.go +++ b/server/datastore/datastore.go @@ -26,13 +26,12 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){ testSaveQuery, testListQuery, testDeletePack, - testNewPack, + testSavePack, testEnrollHost, testAuthenticateHost, testAuthenticateHostCaseSensitive, testLabels, testSaveLabel, - testManagingLabelsOnPacks, testPasswordResetRequests, testCreateUser, testSaveUser, @@ -52,7 +51,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){ testListHostsFilterAdditional, testListHostsStatus, testListHostsQuery, - testListHostsInPack, testListPacksForHost, testHostIDsByName, testHostByIdentifier, @@ -69,7 +67,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){ testCascadingDeletionOfQueries, testGetPackByName, testGetQueryByName, - testAddLabelToPackTwice, testGenerateHostStatusStatistics, testMarkHostSeen, testMarkHostsSeen, @@ -90,7 +87,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){ testApplyLabelSpecsRoundtrip, testGetLabelSpec, testLabelIDsByName, - testListLabelsForPack, testHostAdditional, testCarveMetadata, testCarveBlocks, diff --git a/server/datastore/datastore_labels.go b/server/datastore/datastore_labels.go index 66572d976e..ae14dfd911 100644 --- a/server/datastore/datastore_labels.go +++ b/server/datastore/datastore_labels.go @@ -156,62 +156,6 @@ func testLabels(t *testing.T, db fleet.Datastore) { assert.Len(t, labels, 1) } -func testManagingLabelsOnPacks(t *testing.T, ds fleet.Datastore) { - pack := &fleet.PackSpec{ - ID: 1, - Name: "pack1", - } - err := ds.ApplyPackSpecs([]*fleet.PackSpec{pack}) - require.Nil(t, err) - - labels, err := ds.ListLabelsForPack(pack.ID) - require.Nil(t, err) - assert.Len(t, labels, 0) - - mysqlLabel := &fleet.LabelSpec{ - ID: 1, - Name: "MySQL Monitoring", - Query: "select pid from processes where name = 'mysqld';", - } - err = ds.ApplyLabelSpecs([]*fleet.LabelSpec{mysqlLabel}) - require.Nil(t, err) - - pack.Targets = fleet.PackSpecTargets{ - Labels: []string{ - mysqlLabel.Name, - }, - } - err = ds.ApplyPackSpecs([]*fleet.PackSpec{pack}) - require.Nil(t, err) - - labels, err = ds.ListLabelsForPack(pack.ID) - require.Nil(t, err) - if assert.Len(t, labels, 1) { - assert.Equal(t, "MySQL Monitoring", labels[0].Name) - } - - osqueryLabel := &fleet.LabelSpec{ - ID: 2, - Name: "Osquery Monitoring", - Query: "select pid from processes where name = 'osqueryd';", - } - err = ds.ApplyLabelSpecs([]*fleet.LabelSpec{mysqlLabel, osqueryLabel}) - require.Nil(t, err) - - pack.Targets = fleet.PackSpecTargets{ - Labels: []string{ - mysqlLabel.Name, - osqueryLabel.Name, - }, - } - err = ds.ApplyPackSpecs([]*fleet.PackSpec{pack}) - require.Nil(t, err) - - labels, err = ds.ListLabelsForPack(pack.ID) - require.Nil(t, err) - assert.Len(t, labels, 2) -} - func testSearchLabels(t *testing.T, db fleet.Datastore) { specs := []*fleet.LabelSpec{ &fleet.LabelSpec{ diff --git a/server/datastore/datastore_packs.go b/server/datastore/datastore_packs.go index 27db57af35..3a37c4265a 100644 --- a/server/datastore/datastore_packs.go +++ b/server/datastore/datastore_packs.go @@ -26,18 +26,38 @@ func testDeletePack(t *testing.T, ds fleet.Datastore) { assert.NotNil(t, err) } -func testNewPack(t *testing.T, ds fleet.Datastore) { - pack := &fleet.Pack{ - Name: "foo", +func testSavePack(t *testing.T, ds fleet.Datastore) { + expectedPack := &fleet.Pack{ + Name: "foo", + HostIDs: []uint{1}, + LabelIDs: []uint{1}, + TeamIDs: []uint{1}, } - pack, err := ds.NewPack(pack) + pack, err := ds.NewPack(expectedPack) require.NoError(t, err) assert.NotEqual(t, uint(0), pack.ID) + test.EqualSkipTimestampsID(t, expectedPack, pack) pack, err = ds.Pack(pack.ID) require.NoError(t, err) - assert.Equal(t, "foo", pack.Name) + test.EqualSkipTimestampsID(t, expectedPack, pack) + + expectedPack = &fleet.Pack{ + ID: pack.ID, + Name: "bar", + HostIDs: []uint{3}, + LabelIDs: []uint{4, 6}, + TeamIDs: []uint{}, + } + + err = ds.SavePack(expectedPack) + require.NoError(t, err) + + pack, err = ds.Pack(pack.ID) + require.NoError(t, err) + assert.Equal(t, "bar", pack.Name) + test.EqualSkipTimestampsID(t, expectedPack, pack) } func testGetPackByName(t *testing.T, ds fleet.Datastore) { @@ -81,92 +101,8 @@ func testListPacks(t *testing.T, ds fleet.Datastore) { assert.Len(t, packs, 2) } -func testListHostsInPack(t *testing.T, ds fleet.Datastore) { - if ds.Name() == "inmem" { - t.Skip("inmem is deprecated") - } - - mockClock := clock.NewMockClock() - - l1 := fleet.LabelSpec{ - ID: 1, - Name: "foo", - } - err := ds.ApplyLabelSpecs([]*fleet.LabelSpec{&l1}) - require.Nil(t, err) - - p1 := &fleet.PackSpec{ - ID: 1, - Name: "foo_pack", - Targets: fleet.PackSpecTargets{ - Labels: []string{ - l1.Name, - }, - }, - } - err = ds.ApplyPackSpecs([]*fleet.PackSpec{p1}) - require.Nil(t, err) - - h1 := test.NewHost(t, ds, "h1.local", "10.10.10.1", "1", "1", mockClock.Now()) - - hostsInPack, err := ds.ListHostsInPack(p1.ID, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, hostsInPack, 0) - - err = ds.RecordLabelQueryExecutions( - h1, - map[uint]bool{l1.ID: true}, - mockClock.Now(), - ) - require.Nil(t, err) - - hostsInPack, err = ds.ListHostsInPack(p1.ID, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, hostsInPack, 1) - - explicitHostsInPack, err := ds.ListExplicitHostsInPack(p1.ID, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, explicitHostsInPack, 0) - - h2 := test.NewHost(t, ds, "h2.local", "10.10.10.2", "2", "2", mockClock.Now()) - - err = ds.RecordLabelQueryExecutions( - h2, - map[uint]bool{l1.ID: true}, - mockClock.Now(), - ) - require.Nil(t, err) - - hostsInPack, err = ds.ListHostsInPack(p1.ID, fleet.ListOptions{}) - require.Nil(t, err) - require.Len(t, hostsInPack, 2) -} - -func testAddLabelToPackTwice(t *testing.T, ds fleet.Datastore) { - l1 := fleet.LabelSpec{ - ID: 1, - Name: "l1", - Query: "select 1", - } - err := ds.ApplyLabelSpecs([]*fleet.LabelSpec{&l1}) - require.Nil(t, err) - - p1 := &fleet.PackSpec{ - ID: 1, - Name: "pack1", - Targets: fleet.PackSpecTargets{ - Labels: []string{ - l1.Name, - l1.Name, - }, - }, - } - err = ds.ApplyPackSpecs([]*fleet.PackSpec{p1}) - require.NotNil(t, err) -} - func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec { - zwass := test.NewUser(t, ds, "Zach", "zwass", "zwass@fleet.co", true) + zwass := test.NewUser(t, ds, "Zach", "zwass", "zwass@example.com", true) queries := []*fleet.Query{ {Name: "foo", Description: "get the foos", Query: "select * from foo"}, {Name: "bar", Description: "do some bars", Query: "select baz from bar"}, @@ -176,15 +112,15 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec { require.Nil(t, err) labels := []*fleet.LabelSpec{ - &fleet.LabelSpec{ + { Name: "foo", Query: "select * from foo", }, - &fleet.LabelSpec{ + { Name: "bar", Query: "select * from bar", }, - &fleet.LabelSpec{ + { Name: "bing", Query: "select * from bing", }, @@ -193,7 +129,7 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec { require.Nil(t, err) expectedSpecs := []*fleet.PackSpec{ - &fleet.PackSpec{ + { ID: 1, Name: "test_pack", Targets: fleet.PackSpecTargets{ @@ -204,20 +140,20 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec { }, }, Queries: []fleet.PackSpecQuery{ - fleet.PackSpecQuery{ + { QueryName: queries[0].Name, Name: "q0", Description: "test_foo", Interval: 42, }, - fleet.PackSpecQuery{ + { QueryName: queries[0].Name, Name: "foo_snapshot", Interval: 600, Snapshot: ptr.Bool(true), Denylist: ptr.Bool(false), }, - fleet.PackSpecQuery{ + { Name: "q2", QueryName: queries[1].Name, Interval: 600, @@ -229,7 +165,7 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec { }, }, }, - &fleet.PackSpec{ + { ID: 2, Name: "test_pack_disabled", Disabled: true, @@ -241,19 +177,19 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec { }, }, Queries: []fleet.PackSpecQuery{ - fleet.PackSpecQuery{ + { QueryName: queries[0].Name, Name: "q0", Description: "test_foo", Interval: 42, }, - fleet.PackSpecQuery{ + { QueryName: queries[0].Name, Name: "foo_snapshot", Interval: 600, Snapshot: ptr.Bool(true), }, - fleet.PackSpecQuery{ + { Name: "q2", QueryName: queries[1].Name, Interval: 600, @@ -292,14 +228,14 @@ func testGetPackSpec(t *testing.T, ds fleet.Datastore) { func testApplyPackSpecMissingQueries(t *testing.T, ds fleet.Datastore) { // Do not define queries mentioned in spec specs := []*fleet.PackSpec{ - &fleet.PackSpec{ + { ID: 1, Name: "test_pack", Targets: fleet.PackSpecTargets{ Labels: []string{}, }, Queries: []fleet.PackSpecQuery{ - fleet.PackSpecQuery{ + { QueryName: "bar", Interval: 600, }, @@ -318,13 +254,13 @@ func testApplyPackSpecMissingName(t *testing.T, ds fleet.Datastore) { setupPackSpecsTest(t, ds) specs := []*fleet.PackSpec{ - &fleet.PackSpec{ + { Name: "test2", Targets: fleet.PackSpecTargets{ Labels: []string{}, }, Queries: []fleet.PackSpecQuery{ - fleet.PackSpecQuery{ + { QueryName: "foo", Interval: 600, }, @@ -340,67 +276,6 @@ func testApplyPackSpecMissingName(t *testing.T, ds fleet.Datastore) { assert.Equal(t, "foo", spec.Queries[0].Name) } -func testListLabelsForPack(t *testing.T, ds fleet.Datastore) { - labelSpecs := []*fleet.LabelSpec{ - &fleet.LabelSpec{ - Name: "foo", - Query: "select * from foo", - }, - &fleet.LabelSpec{ - Name: "bar", - Query: "select * from bar", - }, - &fleet.LabelSpec{ - Name: "bing", - Query: "select * from bing", - }, - } - err := ds.ApplyLabelSpecs(labelSpecs) - require.Nil(t, err) - - specs := []*fleet.PackSpec{ - &fleet.PackSpec{ - ID: 1, - Name: "test_pack", - Targets: fleet.PackSpecTargets{ - Labels: []string{ - "foo", - "bar", - "bing", - }, - }, - }, - &fleet.PackSpec{ - ID: 2, - Name: "test 2", - Targets: fleet.PackSpecTargets{ - Labels: []string{ - "bing", - }, - }, - }, - &fleet.PackSpec{ - ID: 3, - Name: "test 3", - }, - } - err = ds.ApplyPackSpecs(specs) - require.Nil(t, err) - - labels, err := ds.ListLabelsForPack(specs[0].ID) - require.Nil(t, err) - assert.Len(t, labels, 3) - - labels, err = ds.ListLabelsForPack(specs[1].ID) - require.Nil(t, err) - assert.Len(t, labels, 1) - assert.Equal(t, "bing", labels[0].Name) - - labels, err = ds.ListLabelsForPack(specs[2].ID) - require.Nil(t, err) - assert.Len(t, labels, 0) -} - func testListPacksForHost(t *testing.T, ds fleet.Datastore) { if ds.Name() == "inmem" { t.Skip("inmem is deprecated") @@ -507,40 +382,4 @@ func testListPacksForHost(t *testing.T, ds fleet.Datastore) { if assert.Len(t, packs, 1) { assert.Equal(t, "foo_pack", packs[0].Name) } - - // Add host directly to pack - err = ds.AddHostToPack(h1.ID, p2.ID) - require.Nil(t, err) - - packs, err = ds.ListPacksForHost(h1.ID) - require.Nil(t, err) - assert.Len(t, packs, 2) - - // Remove label membership for both - err = ds.RecordLabelQueryExecutions( - h1, - map[uint]bool{l2.ID: false, l1.ID: false}, - mockClock.Now(), - ) - require.Nil(t, err) - - err = ds.RecordLabelQueryExecutions( - h1, - map[uint]bool{l2.ID: false}, - mockClock.Now(), - ) - require.Nil(t, err) - packs, err = ds.ListPacksForHost(h1.ID) - require.Nil(t, err) - if assert.Len(t, packs, 1) { - assert.Equal(t, p2.Name, packs[0].Name) - } - - // Now host is added directly to both packs - err = ds.AddHostToPack(h1.ID, p1.ID) - require.Nil(t, err) - - packs, err = ds.ListPacksForHost(h1.ID) - require.Nil(t, err) - assert.Len(t, packs, 2) } diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go index 52244df0cd..a93f606832 100644 --- a/server/datastore/mysql/labels.go +++ b/server/datastore/mysql/labels.go @@ -178,7 +178,6 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error { // NewLabel creates a new fleet.Label func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fleet.Label, error) { - db := d.getTransaction(opts) query := ` INSERT INTO labels ( name, @@ -189,7 +188,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl label_membership_type ) VALUES ( ?, ?, ?, ?, ?, ?) ` - result, err := db.Exec( + result, err := d.db.Exec( query, label.Name, label.Description, diff --git a/server/datastore/mysql/migrations/tables/20210617174723_PackTargetForeignKey.go b/server/datastore/mysql/migrations/tables/20210617174723_PackTargetForeignKey.go new file mode 100644 index 0000000000..1af8e7956e --- /dev/null +++ b/server/datastore/mysql/migrations/tables/20210617174723_PackTargetForeignKey.go @@ -0,0 +1,35 @@ +package tables + +import ( + "database/sql" + + "github.com/pkg/errors" +) + +func init() { + MigrationClient.AddMigration(Up_20210617174723, Down_20210617174723) +} + +func Up_20210617174723(tx *sql.Tx) error { + sql := ` + DELETE FROM pack_targets + WHERE pack_id NOT IN (SELECT id FROM packs) + ` + if _, err := tx.Exec(sql); err != nil { + return errors.Wrap(err, "delete orphaned pack targets") + } + + sql = ` + ALTER TABLE pack_targets + ADD FOREIGN KEY (pack_id) REFERENCES packs (id) ON UPDATE CASCADE ON DELETE CASCADE + ` + if _, err := tx.Exec(sql); err != nil { + return errors.Wrap(err, "add foreign key on pack_targets pack_id") + } + + return nil +} + +func Down_20210617174723(tx *sql.Tx) error { + return nil +} diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go index 07285ee5d2..e615955a42 100644 --- a/server/datastore/mysql/mysql.go +++ b/server/datastore/mysql/mysql.go @@ -46,23 +46,6 @@ type Datastore struct { config config.MysqlConfig } -type dbfunctions interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Get(dest interface{}, query string, args ...interface{}) error - Select(dest interface{}, query string, args ...interface{}) error -} - -func (d *Datastore) getTransaction(opts []fleet.OptionalArg) dbfunctions { - var result dbfunctions = d.db - for _, opt := range opts { - switch t := opt().(type) { - case dbfunctions: - result = t - } - } - return result -} - type txFn func(*sqlx.Tx) error // retryableError determines whether a MySQL error can be retried. By default diff --git a/server/datastore/mysql/packs.go b/server/datastore/mysql/packs.go index 56557b34b3..463e7024b5 100644 --- a/server/datastore/mysql/packs.go +++ b/server/datastore/mysql/packs.go @@ -3,6 +3,7 @@ package mysql import ( "database/sql" "fmt" + "strings" "github.com/fleetdm/fleet/server/fleet" "github.com/jmoiron/sqlx" @@ -198,19 +199,22 @@ WHERE pack_id = ? } func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.Pack, bool, error) { - db := d.getTransaction(opts) sqlStatement := ` SELECT * FROM packs WHERE name = ? ` var pack fleet.Pack - err := db.Get(&pack, sqlStatement, name) + err := d.db.Get(&pack, sqlStatement, name) if err != nil { if err == sql.ErrNoRows { return nil, false, nil } - return nil, false, errors.Wrap(err, "fetching packs by name") + return nil, false, errors.Wrap(err, "fetch pack by name") + } + + if err := d.loadPackTargets(&pack); err != nil { + return nil, false, err } return &pack, true, nil @@ -218,44 +222,144 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P // NewPack creates a new Pack func (d *Datastore) NewPack(pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) { - db := d.getTransaction(opts) + if err := d.withRetryTxx(func(tx *sqlx.Tx) error { + query := ` + INSERT INTO packs + (name, description, platform, disabled) + VALUES ( ?, ?, ?, ? ) + ` + result, err := d.db.Exec(query, pack.Name, pack.Description, pack.Platform, pack.Disabled) + if err != nil { + return errors.Wrap(err, "insert pack") + } - query := ` - INSERT INTO packs - (name, description, platform, disabled) - VALUES ( ?, ?, ?, ? ) - ` + id, _ := result.LastInsertId() + pack.ID = uint(id) - result, err := db.Exec(query, pack.Name, pack.Description, pack.Platform, pack.Disabled) - if err != nil { - return nil, errors.Wrap(err, "inserting pack") + if err := d.replacePackTargets(tx, pack); err != nil { + return err + } + + return nil + }); err != nil { + return nil, err + } + return pack, nil +} + +func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error { + sql := `DELETE FROM pack_targets WHERE pack_id = ?` + if _, err := tx.Exec(sql, pack.ID); err != nil { + return errors.Wrap(err, "delete pack targets") } - id, _ := result.LastInsertId() - pack.ID = uint(id) - return pack, nil + // Insert hosts + if len(pack.HostIDs) > 0 { + var args []interface{} + for _, id := range pack.HostIDs { + args = append(args, pack.ID, fleet.TargetHost, id) + } + values := strings.TrimSuffix( + strings.Repeat("(?,?,?),", len(pack.HostIDs)), + ",", + ) + sql = fmt.Sprintf(` + INSERT INTO pack_targets (pack_id, type, target_id) + VALUES %s + `, values) + if _, err := tx.Exec(sql, args...); err != nil { + return errors.Wrap(err, "insert host targets") + } + } + + // Insert labels + if len(pack.LabelIDs) > 0 { + var args []interface{} + for _, id := range pack.LabelIDs { + args = append(args, pack.ID, fleet.TargetLabel, id) + } + values := strings.TrimSuffix( + strings.Repeat("(?,?,?),", len(pack.LabelIDs)), + ",", + ) + sql = fmt.Sprintf(` + INSERT INTO pack_targets (pack_id, type, target_id) + VALUES %s + `, values) + if _, err := tx.Exec(sql, args...); err != nil { + return errors.Wrap(err, "insert label targets") + } + } + + // Insert teams + if len(pack.TeamIDs) > 0 { + var args []interface{} + for _, id := range pack.TeamIDs { + args = append(args, pack.ID, fleet.TargetTeam, id) + } + values := strings.TrimSuffix( + strings.Repeat("(?,?,?),", len(pack.TeamIDs)), + ",", + ) + sql = fmt.Sprintf(` + INSERT INTO pack_targets (pack_id, type, target_id) + VALUES %s + `, values) + if _, err := tx.Exec(sql, args...); err != nil { + return errors.Wrap(err, "insert team targets") + } + } + + return nil +} + +func (d *Datastore) loadPackTargets(pack *fleet.Pack) error { + var targets []fleet.PackTarget + sql := `SELECT * FROM pack_targets WHERE pack_id = ?` + if err := d.db.Select(&targets, sql, pack.ID); err != nil { + return errors.Wrap(err, "select pack targets") + } + + pack.HostIDs, pack.LabelIDs, pack.TeamIDs = []uint{}, []uint{}, []uint{} + for _, target := range targets { + switch target.Type { + case fleet.TargetHost: + pack.HostIDs = append(pack.HostIDs, target.TargetID) + case fleet.TargetLabel: + pack.LabelIDs = append(pack.LabelIDs, target.TargetID) + case fleet.TargetTeam: + pack.TeamIDs = append(pack.TeamIDs, target.TargetID) + default: + return errors.Errorf("unknown target type: %d", target.Type) + } + } + + return nil } // SavePack stores changes to pack func (d *Datastore) SavePack(pack *fleet.Pack) error { - query := ` + return d.withRetryTxx(func(tx *sqlx.Tx) error { + query := ` UPDATE packs SET name = ?, platform = ?, disabled = ?, description = ? WHERE id = ? ` - results, err := d.db.Exec(query, pack.Name, pack.Platform, pack.Disabled, pack.Description, pack.ID) - if err != nil { - return errors.Wrap(err, "updating pack") - } - rowsAffected, err := results.RowsAffected() - if err != nil { - return errors.Wrap(err, "rows affected updating packs") - } - if rowsAffected == 0 { - return notFound("Pack").WithID(pack.ID) - } - return nil + results, err := d.db.Exec(query, pack.Name, pack.Platform, pack.Disabled, pack.Description, pack.ID) + if err != nil { + return errors.Wrap(err, "updating pack") + } + rowsAffected, err := results.RowsAffected() + if err != nil { + return errors.Wrap(err, "rows affected updating packs") + } + if rowsAffected == 0 { + return notFound("Pack").WithID(pack.ID) + } + + return d.replacePackTargets(tx, pack) + }) } // DeletePack deletes a fleet.Pack so that it won't show up in results. @@ -271,7 +375,11 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) { if err == sql.ErrNoRows { return nil, notFound("Pack").WithID(pid) } else if err != nil { - return nil, errors.Wrap(err, "getting pack") + return nil, errors.Wrap(err, "get pack") + } + + if err := d.loadPackTargets(pack); err != nil { + return nil, err } return pack, nil @@ -285,107 +393,22 @@ func (d *Datastore) ListPacks(opt fleet.ListOptions) ([]*fleet.Pack, error) { if err != nil && err != sql.ErrNoRows { return nil, errors.Wrap(err, "listing packs") } + + for _, pack := range packs { + if err := d.loadPackTargets(pack); err != nil { + return nil, err + } + } + return packs, nil } -// AddLabelToPack associates a fleet.Label with a fleet.Pack -func (d *Datastore) AddLabelToPack(lid uint, pid uint, opts ...fleet.OptionalArg) error { - db := d.getTransaction(opts) - - query := ` - INSERT INTO pack_targets ( pack_id, type, target_id ) - VALUES ( ?, ?, ? ) - ON DUPLICATE KEY UPDATE id=id - ` - _, err := db.Exec(query, pid, fleet.TargetLabel, lid) - if err != nil { - return errors.Wrap(err, "adding label to pack") - } - - return nil -} - -// AddHostToPack associates a fleet.Host with a fleet.Pack -func (d *Datastore) AddHostToPack(hid, pid uint) error { - query := ` - INSERT INTO pack_targets ( pack_id, type, target_id ) - VALUES ( ?, ?, ? ) - ON DUPLICATE KEY UPDATE id=id - ` - _, err := d.db.Exec(query, pid, fleet.TargetHost, hid) - if err != nil { - return errors.Wrap(err, "adding host to pack") - } - - return nil -} - -// RemoreLabelFromPack will remove the association between a fleet.Label and -// a fleet.Pack -func (d *Datastore) RemoveLabelFromPack(lid, pid uint) error { - query := ` - DELETE FROM pack_targets - WHERE target_id = ? AND pack_id = ? AND type = ? - ` - _, err := d.db.Exec(query, lid, pid, fleet.TargetLabel) - if err == sql.ErrNoRows { - return notFound("PackTarget").WithMessage(fmt.Sprintf("label ID: %d, pack ID: %d", lid, pid)) - } else if err != nil { - return errors.Wrap(err, "removing label from pack") - } - return nil -} - -// RemoveHostFromPack will remove the association between a fleet.Host and a -// fleet.Pack -func (d *Datastore) RemoveHostFromPack(hid, pid uint) error { - query := ` - DELETE FROM pack_targets - WHERE target_id = ? AND pack_id = ? AND type = ? - ` - _, err := d.db.Exec(query, hid, pid, fleet.TargetHost) - if err == sql.ErrNoRows { - return notFound("PackTarget").WithMessage(fmt.Sprintf("host ID: %d, pack ID: %d", hid, pid)) - } else if err != nil { - return errors.Wrap(err, "removing host from pack") - } - return nil -} - -// ListLabelsForPack will return a list of fleet.Label records associated with fleet.Pack -func (d *Datastore) ListLabelsForPack(pid uint) ([]*fleet.Label, error) { - query := ` - SELECT - l.id, - l.created_at, - l.updated_at, - l.name - FROM - labels l - JOIN - pack_targets pt - ON - pt.target_id = l.id - WHERE - pt.type = ? - AND - pt.pack_id = ? - ` - - labels := []*fleet.Label{} - - if err := d.db.Select(&labels, query, fleet.TargetLabel, pid); err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "listing labels for pack") - } - - return labels, nil -} - func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) { query := ` SELECT DISTINCT packs.* FROM - ((SELECT p.* FROM packs p + ((SELECT p.* + FROM packs p JOIN pack_targets pt JOIN label_membership lm ON ( @@ -399,55 +422,17 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) { FROM packs p JOIN pack_targets pt ON (p.id = pt.pack_id AND pt.type = ? AND pt.target_id = ?)) + UNION ALL + (SELECT p.* + FROM packs p + JOIN pack_targets pt + ON (p.id = pt.pack_id AND pt.type = ? AND pt.target_id = (SELECT team_id FROM hosts WHERE id = ?))) ) packs ` packs := []*fleet.Pack{} - if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid); err != nil && err != sql.ErrNoRows { + if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows { return nil, errors.Wrap(err, "listing hosts in pack") } return packs, nil } - -func (d *Datastore) ListHostsInPack(pid uint, opt fleet.ListOptions) ([]uint, error) { - query := ` - SELECT DISTINCT h.id - FROM hosts h - JOIN pack_targets pt - JOIN label_membership lm - ON ( - pt.target_id = lm.label_id - AND lm.host_id = h.id - AND pt.type = ? - ) OR ( - pt.target_id = h.id - AND pt.type = ? - ) - WHERE pt.pack_id = ? - ` - - hosts := []uint{} - if err := d.db.Select(&hosts, appendListOptionsToSQL(query, opt), fleet.TargetLabel, fleet.TargetHost, pid); err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "listing hosts in pack") - } - return hosts, nil -} - -func (d *Datastore) ListExplicitHostsInPack(pid uint, opt fleet.ListOptions) ([]uint, error) { - query := ` - SELECT DISTINCT h.id - FROM hosts h - JOIN pack_targets pt - ON ( - pt.target_id = h.id - AND pt.type = ? - ) - WHERE pt.pack_id = ? - ` - hosts := []uint{} - if err := d.db.Select(&hosts, appendListOptionsToSQL(query, opt), fleet.TargetHost, pid); err != nil && err != sql.ErrNoRows { - return nil, errors.Wrap(err, "listing explicit hosts in pack") - } - return hosts, nil - -} diff --git a/server/datastore/mysql/queries.go b/server/datastore/mysql/queries.go index 5b0a99d166..6f331d9b2f 100644 --- a/server/datastore/mysql/queries.go +++ b/server/datastore/mysql/queries.go @@ -65,14 +65,13 @@ func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err err } func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.Query, error) { - db := d.getTransaction(opts) sqlStatement := ` SELECT * FROM queries WHERE name = ? ` var query fleet.Query - err := db.Get(&query, sqlStatement, name) + err := d.db.Get(&query, sqlStatement, name) if err != nil { if err == sql.ErrNoRows { return nil, notFound("Query").WithName(name) @@ -89,8 +88,6 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet. // NewQuery creates a New Query. func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { - db := d.getTransaction(opts) - sqlStatement := ` INSERT INTO queries ( name, @@ -101,7 +98,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl observer_can_run ) VALUES ( ?, ?, ?, ?, ?, ? ) ` - result, err := db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun) + result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun) if err != nil && isDuplicate(err) { return nil, alreadyExists("Query", 0) diff --git a/server/datastore/mysql/scheduled_queries.go b/server/datastore/mysql/scheduled_queries.go index ec41e1c7f6..e2f1649c2f 100644 --- a/server/datastore/mysql/scheduled_queries.go +++ b/server/datastore/mysql/scheduled_queries.go @@ -40,8 +40,6 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions) } func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.OptionalArg) (*fleet.ScheduledQuery, error) { - db := d.getTransaction(opts) - // This query looks up the query name using the ID (for backwards // compatibility with the UI) query := ` @@ -61,7 +59,7 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op FROM queries WHERE id = ? ` - result, err := db.Exec(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID) + result, err := d.db.Exec(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID) if err != nil { return nil, errors.Wrap(err, "insert scheduled query") } @@ -75,7 +73,7 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op Name string }{} - err = db.Select(&metadata, query, sq.QueryID) + err = d.db.Select(&metadata, query, sq.QueryID) if err != nil && err == sql.ErrNoRows { return nil, notFound("Query").WithID(sq.QueryID) } else if err != nil { diff --git a/server/fleet/packs.go b/server/fleet/packs.go index 0518626bec..5c4e510393 100644 --- a/server/fleet/packs.go +++ b/server/fleet/packs.go @@ -33,33 +33,8 @@ type PackStore interface { // exists the bool return value is true PackByName(name string, opts ...OptionalArg) (*Pack, bool, error) - // AddLabelToPack adds an existing label to an existing pack, both by ID. - AddLabelToPack(lid, pid uint, opts ...OptionalArg) error - - // RemoveLabelFromPack removes an existing label from it's association with - // an existing pack, both by ID. - RemoveLabelFromPack(lid, pid uint) error - - // ListLabelsForPack lists all labels that are associated with a pack. - ListLabelsForPack(pid uint) ([]*Label, error) - - // AddHostToPack adds an existing host to an existing pack, both by ID. - AddHostToPack(hid uint, pid uint) error - - // RemoveHostFromPack removes an existing host from it's association with - // an existing pack, both by ID. - RemoveHostFromPack(hid uint, pid uint) error - // ListPacksForHost lists the packs that a host should execute. ListPacksForHost(hid uint) (packs []*Pack, err error) - - // ListHostsInPack lists the IDs of all hosts that are associated with a pack - // through labels. - ListHostsInPack(pid uint, opt ListOptions) ([]uint, error) - - // ListExplicitHostsInPack lists the IDs of hosts that have been manually - // associated with a query pack. - ListExplicitHostsInPack(pid uint, opt ListOptions) ([]uint, error) } // PackService is the service interface for managing query packs. @@ -90,33 +65,8 @@ type PackService interface { // DeletePackByID is for backwards compatibility with the UI DeletePackByID(ctx context.Context, id uint) (err error) - // AddLabelToPack adds an existing label to an existing pack, both by ID. - AddLabelToPack(ctx context.Context, lid, pid uint) (err error) - - // RemoveLabelFromPack removes an existing label from it's association with - // an existing pack, both by ID. - RemoveLabelFromPack(ctx context.Context, lid, pid uint) (err error) - - // ListLabelsForPack lists all labels that are associated with a pack. - ListLabelsForPack(ctx context.Context, pid uint) (labels []*Label, err error) - - // AddHostToPack adds an existing host to an existing pack, both by ID. - AddHostToPack(ctx context.Context, hid, pid uint) (err error) - - // RemoveHostFromPack removes an existing host from it's association with - // an existing pack, both by ID. - RemoveHostFromPack(ctx context.Context, hid, pid uint) (err error) - // ListPacksForHost lists the packs that a host should execute. ListPacksForHost(ctx context.Context, hid uint) (packs []*Pack, err error) - - // ListHostsInPack lists the IDs of all hosts that are associated with a pack, - // both through labels and manual associations. - ListHostsInPack(ctx context.Context, pid uint, opt ListOptions) (hosts []uint, err error) - - // ListExplicitHostsInPack lists the IDs of hosts that have been manually associated - // with a query pack. - ListExplicitHostsInPack(ctx context.Context, pid uint, opt ListOptions) (hosts []uint, err error) } // Pack is the structure which represents an osquery query pack. @@ -127,6 +77,9 @@ type Pack struct { Description string `json:"description,omitempty"` Platform string `json:"platform,omitempty"` Disabled bool `json:"disabled,omitempty"` + LabelIDs []uint `json:"label_ids"` + HostIDs []uint `json:"host_ids"` + TeamIDs []uint `json:"team_ids"` } func (p Pack) AuthzType() string { @@ -145,6 +98,7 @@ type PackPayload struct { Disabled *bool `json:"disabled"` HostIDs *[]uint `json:"host_ids"` LabelIDs *[]uint `json:"label_ids"` + TeamIDs *[]uint `json:"team_ids"` } type PackSpec struct { @@ -174,10 +128,10 @@ type PackSpecQuery struct { Denylist *bool `json:"denylist,omitempty"` } -// PackTarget associates a pack with either a host or a label +// PackTarget targets a pack to a host, label, or team. type PackTarget struct { - ID uint - PackID uint + ID uint `db:"id"` + PackID uint `db:"pack_id"` Target } diff --git a/server/fleet/targets.go b/server/fleet/targets.go index 35e54c47c5..be4c587a5c 100644 --- a/server/fleet/targets.go +++ b/server/fleet/targets.go @@ -78,8 +78,8 @@ const ( ) type Target struct { - Type TargetType - TargetID uint + Type TargetType `db:"type"` + TargetID uint `db:"target_id"` } func (t Target) AuthzType() string { diff --git a/server/service/endpoint_appconfig.go b/server/service/endpoint_appconfig.go index 37fa579e76..acfd31d48c 100644 --- a/server/service/endpoint_appconfig.go +++ b/server/service/endpoint_appconfig.go @@ -22,7 +22,7 @@ type appConfigResponse struct { SSOSettings *fleet.SSOSettingsPayload `json:"sso_settings,omitempty"` HostExpirySettings *fleet.HostExpirySettings `json:"host_expiry_settings,omitempty"` HostSettings *fleet.HostSettings `json:"host_settings,omitempty"` - AgentOptions *json.RawMessage `json:"agent_options,omitempty"` + AgentOptions *json.RawMessage `json:"agent_options,omitempty"` License *fleet.LicenseInfo `json:"license,omitempty"` Err error `json:"error,omitempty"` } diff --git a/server/service/endpoint_packs.go b/server/service/endpoint_packs.go index 25c75bdfef..4d600d2011 100644 --- a/server/service/endpoint_packs.go +++ b/server/service/endpoint_packs.go @@ -18,6 +18,7 @@ type packResponse struct { // IDs of hosts which were explicitly selected. HostIDs []uint `json:"host_ids"` LabelIDs []uint `json:"label_ids"` + TeamIDs []uint `json:"team_ids"` } func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) { @@ -27,21 +28,11 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack return nil, err } - hosts, err := svc.ListExplicitHostsInPack(ctx, pack.ID, opts) - if err != nil { - return nil, err - } - - labels, err := svc.ListLabelsForPack(ctx, pack.ID) - labelIDs := make([]uint, len(labels)) - for i, label := range labels { - labelIDs[i] = label.ID - } - if err != nil { - return nil, err - } - - hostMetrics, err := svc.CountHostsInTargets(ctx, nil, fleet.HostTargets{HostIDs: hosts, LabelIDs: labelIDs}) + hostMetrics, err := svc.CountHostsInTargets( + ctx, + nil, + fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs}, + ) if err != nil { return nil, err } @@ -50,8 +41,9 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack Pack: pack, QueryCount: uint(len(queries)), TotalHostsCount: hostMetrics.TotalHosts, - HostIDs: hosts, - LabelIDs: labelIDs, + HostIDs: pack.HostIDs, + LabelIDs: pack.LabelIDs, + TeamIDs: pack.TeamIDs, }, nil } diff --git a/server/service/logging_packs.go b/server/service/logging_packs.go index d994fedf91..9fac2425fe 100644 --- a/server/service/logging_packs.go +++ b/server/service/logging_packs.go @@ -117,92 +117,6 @@ func (mw loggingMiddleware) DeletePack(ctx context.Context, name string) error { return err } -func (mw loggingMiddleware) AddLabelToPack(ctx context.Context, lid uint, pid uint) error { - var ( - err error - ) - - defer func(begin time.Time) { - _ = mw.loggerInfo(err).Log( - "method", "AddLabelToPack", - "err", err, - "took", time.Since(begin), - ) - }(time.Now()) - - err = mw.Service.AddLabelToPack(ctx, lid, pid) - return err -} - -func (mw loggingMiddleware) RemoveLabelFromPack(ctx context.Context, lid uint, pid uint) error { - var ( - err error - ) - - defer func(begin time.Time) { - _ = mw.loggerInfo(err).Log( - "method", "RemoveLabelFromPack", - "err", err, - "took", time.Since(begin), - ) - }(time.Now()) - - err = mw.Service.RemoveLabelFromPack(ctx, lid, pid) - return err -} - -func (mw loggingMiddleware) ListLabelsForPack(ctx context.Context, pid uint) ([]*fleet.Label, error) { - var ( - labels []*fleet.Label - err error - ) - - defer func(begin time.Time) { - _ = mw.loggerDebug(err).Log( - "method", "ListLabelsForPack", - "err", err, - "took", time.Since(begin), - ) - }(time.Now()) - - labels, err = mw.Service.ListLabelsForPack(ctx, pid) - return labels, err -} - -func (mw loggingMiddleware) AddHostToPack(ctx context.Context, hid uint, pid uint) error { - var ( - err error - ) - - defer func(begin time.Time) { - _ = mw.loggerInfo(err).Log( - "method", "AddHostToPack", - "err", err, - "took", time.Since(begin), - ) - }(time.Now()) - - err = mw.Service.AddHostToPack(ctx, hid, pid) - return err -} - -func (mw loggingMiddleware) RemoveHostFromPack(ctx context.Context, hid uint, pid uint) error { - var ( - err error - ) - - defer func(begin time.Time) { - _ = mw.loggerInfo(err).Log( - "method", "RemoveHostFromPack", - "err", err, - "took", time.Since(begin), - ) - }(time.Now()) - - err = mw.Service.RemoveHostFromPack(ctx, hid, pid) - return err -} - func (mw loggingMiddleware) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) { var ( packs []*fleet.Pack @@ -221,24 +135,6 @@ func (mw loggingMiddleware) ListPacksForHost(ctx context.Context, hid uint) ([]* return packs, err } -func (mw loggingMiddleware) ListHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) { - var ( - hosts []uint - err error - ) - - defer func(begin time.Time) { - _ = mw.loggerDebug(err).Log( - "method", "ListHostsInPack", - "err", err, - "took", time.Since(begin), - ) - }(time.Now()) - - hosts, err = mw.Service.ListHostsInPack(ctx, pid, opt) - return hosts, err -} - func (mw loggingMiddleware) GetPackSpec(ctx context.Context, name string) (spec *fleet.PackSpec, err error) { defer func(begin time.Time) { _ = mw.loggerDebug(err).Log( diff --git a/server/service/service_labels.go b/server/service/service_labels.go index 3b1c75200d..c5e2dd5670 100644 --- a/server/service/service_labels.go +++ b/server/service/service_labels.go @@ -19,7 +19,7 @@ func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpe } if spec.LabelMembershipType == fleet.LabelMembershipTypeManual && spec.Hosts == nil { // Hosts list doesn't need to contain anything, but it should at least not be nil. - return errors.Errorf("label %s is declared as manual but contains not `hosts key`", spec.Name) + return errors.Errorf("label %s is declared as manual but contains no `hosts key`", spec.Name) } } return svc.ds.ApplyLabelSpecs(specs) diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go index 21b411e22a..8d1cf07948 100644 --- a/server/service/service_osquery.go +++ b/server/service/service_osquery.go @@ -209,13 +209,13 @@ func (svc *Service) GetClientConfig(ctx context.Context) (map[string]interface{} baseConfig, err := svc.AgentOptionsForHost(ctx, &host) if err != nil { - return nil, osqueryError{message: "internal error: fetching base config: " + err.Error()} + return nil, osqueryError{message: "internal error: fetch base config: " + err.Error()} } var config map[string]interface{} err = json.Unmarshal(baseConfig, &config) if err != nil { - return nil, osqueryError{message: "internal error: parsing base configuration: " + err.Error()} + return nil, osqueryError{message: "internal error: parse base configuration: " + err.Error()} } packs, err := svc.ds.ListPacksForHost(host.ID) diff --git a/server/service/service_packs.go b/server/service/service_packs.go index 6b2f0de4aa..e727f63acd 100644 --- a/server/service/service_packs.go +++ b/server/service/service_packs.go @@ -74,24 +74,6 @@ func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pa return nil, err } - if p.HostIDs != nil { - for _, hostID := range *p.HostIDs { - err = svc.AddHostToPack(ctx, hostID, pack.ID) - if err != nil { - return nil, err - } - } - } - - if p.LabelIDs != nil { - for _, labelID := range *p.LabelIDs { - err = svc.AddLabelToPack(ctx, labelID, pack.ID) - if err != nil { - return nil, err - } - } - } - return &pack, nil } @@ -121,107 +103,23 @@ func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload pack.Disabled = *p.Disabled } + if p.HostIDs != nil { + pack.HostIDs = *p.HostIDs + } + + if p.LabelIDs != nil { + pack.LabelIDs = *p.LabelIDs + } + + if p.TeamIDs != nil { + pack.TeamIDs = *p.TeamIDs + } + err = svc.ds.SavePack(pack) if err != nil { return nil, err } - // we must determine what hosts are attached to this pack. then, given - // our new set of host_ids, we will mutate the database to reflect the - // desired state. - if p.HostIDs != nil { - - // first, let's retrieve the total set of hosts - hosts, err := svc.ListHostsInPack(ctx, pack.ID, fleet.ListOptions{}) - if err != nil { - return nil, err - } - - // it will be efficient to create a data structure with constant time - // lookups to determine whether or not a host is already added - existingHosts := map[uint]bool{} - for _, host := range hosts { - existingHosts[host] = true - } - - // we will also make a constant time lookup map for the desired set of - // hosts as well. - desiredHosts := map[uint]bool{} - for _, hostID := range *p.HostIDs { - desiredHosts[hostID] = true - } - - // if the request declares a host ID but the host is not already - // associated with the pack, we add it - for _, hostID := range *p.HostIDs { - if !existingHosts[hostID] { - err = svc.AddHostToPack(ctx, hostID, pack.ID) - if err != nil { - return nil, err - } - } - } - - // if the request does not declare the ID of a host which currently - // exists, we delete the existing relationship - for hostID := range existingHosts { - if !desiredHosts[hostID] { - err = svc.RemoveHostFromPack(ctx, hostID, pack.ID) - if err != nil { - return nil, err - } - } - } - } - - // we must determine what labels are attached to this pack. then, given - // our new set of label_ids, we will mutate the database to reflect the - // desired state. - if p.LabelIDs != nil { - - // first, let's retrieve the total set of labels - labels, err := svc.ListLabelsForPack(ctx, pack.ID) - if err != nil { - return nil, err - } - - // it will be efficient to create a data structure with constant time - // lookups to determine whether or not a label is already added - existingLabels := map[uint]bool{} - for _, label := range labels { - existingLabels[label.ID] = true - } - - // we will also make a constant time lookup map for the desired set of - // labels as well. - desiredLabels := map[uint]bool{} - for _, labelID := range *p.LabelIDs { - desiredLabels[labelID] = true - } - - // if the request declares a label ID but the label is not already - // associated with the pack, we add it - for _, labelID := range *p.LabelIDs { - if !existingLabels[labelID] { - err = svc.AddLabelToPack(ctx, labelID, pack.ID) - if err != nil { - return nil, err - } - } - } - - // if the request does not declare the ID of a label which currently - // exists, we delete the existing relationship - for labelID := range existingLabels { - if !desiredLabels[labelID] { - err = svc.RemoveLabelFromPack(ctx, labelID, pack.ID) - if err != nil { - return nil, err - } - } - } - } - return pack, err } @@ -245,62 +143,6 @@ func (svc *Service) DeletePackByID(ctx context.Context, id uint) error { return svc.ds.DeletePack(pack.Name) } -func (svc *Service) AddLabelToPack(ctx context.Context, lid, pid uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.AddLabelToPack(lid, pid) -} - -func (svc *Service) RemoveLabelFromPack(ctx context.Context, lid, pid uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.RemoveLabelFromPack(lid, pid) -} - -func (svc *Service) AddHostToPack(ctx context.Context, hid, pid uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.AddHostToPack(hid, pid) -} - -func (svc *Service) RemoveHostFromPack(ctx context.Context, hid, pid uint) error { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil { - return err - } - - return svc.ds.RemoveHostFromPack(hid, pid) -} - -func (svc *Service) ListLabelsForPack(ctx context.Context, pid uint) ([]*fleet.Label, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListLabelsForPack(pid) -} - -func (svc *Service) ListHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListHostsInPack(pid, opt) -} - -func (svc *Service) ListExplicitHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) { - if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { - return nil, err - } - - return svc.ds.ListExplicitHostsInPack(pid, opt) -} - func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) { if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil { return nil, err diff --git a/server/test/comparisons.go b/server/test/comparisons.go index c23f8bcfc5..52265268cb 100644 --- a/server/test/comparisons.go +++ b/server/test/comparisons.go @@ -10,9 +10,16 @@ import ( "github.com/stretchr/testify/assert" ) +type TestingT interface { + assert.TestingT + Helper() +} + // ElementsMatchSkipID asserts that the elements match, skipping any field with // name "ID". -func ElementsMatchSkipID(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { +func ElementsMatchSkipID(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + t.Helper() + opt := cmp.FilterPath(func(p cmp.Path) bool { for _, ps := range p { switch ps := ps.(type) { @@ -29,7 +36,9 @@ func ElementsMatchSkipID(t assert.TestingT, listA, listB interface{}, msgAndArgs // ElementsMatchSkipTimestampsID asserts that the elements match, skipping any field with // name "ID", "CreatedAt", and "UpdatedAt". This is useful for comparing after DB insertion. -func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { +func ElementsMatchSkipTimestampsID(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + t.Helper() + opt := cmp.FilterPath(func(p cmp.Path) bool { for _, ps := range p { switch ps := ps.(type) { @@ -45,6 +54,31 @@ func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{}, return ElementsMatchWithOptions(t, listA, listB, []cmp.Option{opt}, msgAndArgs) } +// EqualSkipTimestampsID asserts that the structs are equal, skipping any field +// with name "ID", "CreatedAt", and "UpdatedAt". This is useful for comparing +// after DB insertion. +func EqualSkipTimestampsID(t TestingT, a, b interface{}, msgAndArgs ...interface{}) (ok bool) { + t.Helper() + + opt := cmp.FilterPath(func(p cmp.Path) bool { + for _, ps := range p { + switch ps := ps.(type) { + case cmp.StructField: + switch ps.Name() { + case "ID", "UpdateCreateTimestamps", "CreateTimestamp", "UpdateTimestamp", "CreatedAt", "UpdatedAt": + return true + } + } + } + return false + }, cmp.Ignore()) + + if !cmp.Equal(a, b, opt) { + return assert.Fail(t, cmp.Diff(a, b, opt), msgAndArgs...) + } + return true +} + // The below functions adapted from // https://github.com/stretchr/testify/blob/v1.7.0/assert/assertions.go#L895 by // utilizing the options provided in github.com/google/go-cmp/cmp @@ -53,7 +87,7 @@ func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{}, // additional options as provided by the cmp package. This allows, for example, // comparing structs while ignoring fields. See assert.ElementsMatch // documentation for more details. -func ElementsMatchWithOptions(t assert.TestingT, listA, listB interface{}, opts cmp.Options, msgAndArgs ...interface{}) (ok bool) { +func ElementsMatchWithOptions(t TestingT, listA, listB interface{}, opts cmp.Options, msgAndArgs ...interface{}) (ok bool) { if isEmpty(listA) && isEmpty(listB) { return true } @@ -100,7 +134,7 @@ func isEmpty(object interface{}) bool { } // isList checks that the provided value is array or slice. -func isList(t assert.TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) { +func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) { kind := reflect.TypeOf(list).Kind() if kind != reflect.Array && kind != reflect.Slice { return assert.Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind),