Allow Packs to be targeted to Teams (#1130)

- Add additional target type for packs.
- Refactor pack target datastore.
- Fixes for frontend target selector tier logic on packs page.
This commit is contained in:
Zach Wasserman 2021-06-18 09:43:16 -07:00 committed by GitHub
parent 2ad557e3b3
commit 19e8da177f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 396 additions and 820 deletions

View file

@ -22,6 +22,7 @@ class EditPackForm extends Component {
onCancel: PropTypes.func.isRequired,
onFetchTargets: PropTypes.func,
targetsCount: PropTypes.number,
isBasicTier: PropTypes.bool,
};
render() {
@ -32,6 +33,7 @@ class EditPackForm extends Component {
onCancel,
onFetchTargets,
targetsCount,
isBasicTier,
} = this.props;
return (
@ -58,6 +60,7 @@ class EditPackForm extends Component {
onSelect={fields.targets.onChange}
selectedTargets={fields.targets.value}
targetsCount={targetsCount}
isBasicTier={isBasicTier}
/>
<div className={`${baseClass}__pack-buttons`}>
<Button onClick={onCancel} type="button" variant="inverse">

View file

@ -24,6 +24,7 @@ class PackForm extends Component {
handleSubmit: PropTypes.func,
onFetchTargets: PropTypes.func,
selectedTargetsCount: PropTypes.number,
isBasicTier: PropTypes.bool,
};
render() {
@ -34,6 +35,7 @@ class PackForm extends Component {
handleSubmit,
onFetchTargets,
selectedTargetsCount,
isBasicTier,
} = this.props;
const packFormClass = classnames(baseClass, className);
@ -63,6 +65,7 @@ class PackForm extends Component {
onFetchTargets={onFetchTargets}
selectedTargets={fields.targets.value}
targetsCount={selectedTargetsCount}
isBasicTier={isBasicTier}
/>
</div>
<div className={`${baseClass}__pack-buttons`}>

View file

@ -21,6 +21,7 @@ class EditPackFormWrapper extends Component {
pack: packInterface.isRequired,
packTargets: PropTypes.arrayOf(targetInterface),
targetsCount: PropTypes.number,
isBasicTier: PropTypes.bool,
};
render() {
@ -34,6 +35,7 @@ class EditPackFormWrapper extends Component {
pack,
packTargets,
targetsCount,
isBasicTier,
} = this.props;
if (isEdit) {
@ -45,6 +47,7 @@ class EditPackFormWrapper extends Component {
onCancel={onCancelEditPack}
onFetchTargets={onFetchTargets}
targetsCount={targetsCount}
isBasicTier={isBasicTier}
/>
);
}
@ -73,6 +76,7 @@ class EditPackFormWrapper extends Component {
selectedTargets={packTargets}
targetsCount={targetsCount}
disabled
isBasicTier={isBasicTier}
className={`${baseClass}__select-targets`}
/>
</div>

View file

@ -48,7 +48,7 @@ export default (client: any) => {
page = 0,
perPage = 100,
globalFilter = "",
}: ITeamSearchOptions) => {
}: ITeamSearchOptions = {}) => {
const { TEAMS } = endpoints;
// TODO: add this query param logic to client class

View file

@ -279,6 +279,20 @@ export const formatScheduledQueryForClient = (scheduledQuery: any): any => {
return scheduledQuery;
};
export const formatTeamForClient = (team: any): any => {
if (team.display_text === undefined) {
team.display_text = team.name;
}
return team;
};
export const formatPackForClient = (pack: any): any => {
pack.host_ids ||= [];
pack.label_ids ||= [];
pack.team_ids ||= [];
return pack;
};
const setupData = (formData: any) => {
const orgInfo = pick(formData, ORG_INFO_ATTRS);
const adminInfo = pick(formData, ADMIN_ATTRS);
@ -355,6 +369,7 @@ export const secondsToHms = (d: number): string => {
const sDisplay = s > 0 ? s + (s === 1 ? " sec " : " secs ") : "";
return hDisplay + mDisplay + sDisplay;
};
export default {
addGravatarUrlToResource,
formatConfigDataForServer,

View file

@ -4,12 +4,15 @@ import { connect } from "react-redux";
import { filter, includes, isEqual, noop, size, find } from "lodash";
import { push } from "react-router-redux";
import permissionUtils from "utilities/permissions";
import deepDifference from "utilities/deep_difference";
import EditPackFormWrapper from "components/packs/EditPackFormWrapper";
import hostActions from "redux/nodes/entities/hosts/actions";
import hostInterface from "interfaces/host";
import labelActions from "redux/nodes/entities/labels/actions";
import teamActions from "redux/nodes/entities/teams/actions";
import labelInterface from "interfaces/label";
import teamInterface from "interfaces/team";
import packActions from "redux/nodes/entities/packs/actions";
import ScheduleQuerySidePanel from "components/side_panels/ScheduleQuerySidePanel";
import packInterface from "interfaces/pack";
@ -34,7 +37,9 @@ export class EditPackPage extends Component {
packHosts: PropTypes.arrayOf(hostInterface),
packID: PropTypes.string,
packLabels: PropTypes.arrayOf(labelInterface),
packTeams: PropTypes.arrayOf(teamInterface),
scheduledQueries: PropTypes.arrayOf(queryInterface),
isBasicTier: PropTypes.bool,
};
static defaultProps = {
@ -60,6 +65,7 @@ export class EditPackPage extends Component {
packHosts,
packID,
packLabels,
packTeams,
scheduledQueries,
} = this.props;
const { load } = packActions;
@ -77,6 +83,10 @@ export class EditPackPage extends Component {
if (!packLabels || packLabels.length !== pack.label_ids.length) {
dispatch(labelActions.loadAll());
}
if (!packTeams || packTeams.length !== pack.team_ids.length) {
dispatch(teamActions.loadAll());
}
}
if (!size(scheduledQueries)) {
@ -90,7 +100,13 @@ export class EditPackPage extends Component {
return false;
}
componentWillReceiveProps({ dispatch, pack, packHosts, packLabels }) {
componentWillReceiveProps({
dispatch,
pack,
packHosts,
packLabels,
packTeams,
}) {
if (!isEqual(pack, this.props.pack)) {
if (!packHosts || packHosts.length !== pack.host_ids.length) {
dispatch(hostActions.loadAll());
@ -99,6 +115,10 @@ export class EditPackPage extends Component {
if (!packLabels || packLabels.length !== pack.label_ids.length) {
dispatch(labelActions.loadAll());
}
if (!packTeams || packTeams.length !== pack.team_ids.length) {
dispatch(teamActions.loadAll());
}
}
return false;
@ -241,10 +261,12 @@ export class EditPackPage extends Component {
pack,
packHosts,
packLabels,
packTeams,
scheduledQueries,
isBasicTier,
} = this.props;
const packTargets = [...packHosts, ...packLabels];
const packTargets = [...packHosts, ...packLabels, ...packTeams];
if (!pack || isLoadingPack || isLoadingScheduledQueries) {
return false;
@ -263,6 +285,7 @@ export class EditPackPage extends Component {
pack={pack}
packTargets={packTargets}
targetsCount={targetsCount}
isBasicTier={isBasicTier}
/>
<ScheduledQueriesListWrapper
onRemoveScheduledQueries={handleRemoveScheduledQueries}
@ -307,6 +330,12 @@ const mapStateToProps = (state, { params, route }) => {
return includes(pack.label_ids, label.id);
})
: [];
const packTeams = pack
? filter(state.entities.teams.data, (team) => {
return includes(pack.team_ids, team.id);
})
: [];
const isBasicTier = permissionUtils.isBasicTier(state.app.config);
return {
allQueries,
@ -317,7 +346,9 @@ const mapStateToProps = (state, { params, route }) => {
packHosts,
packID,
packLabels,
packTeams,
scheduledQueries,
isBasicTier,
};
};

View file

@ -3,7 +3,12 @@ import { mount } from "enzyme";
import { noop } from "lodash";
import { connectedComponent, reduxMockStore } from "test/helpers";
import { packStub, queryStub, scheduledQueryStub } from "test/stubs";
import {
packStub,
queryStub,
scheduledQueryStub,
configStub,
} from "test/stubs";
import ConnectedEditPackPage, {
EditPackPage,
} from "pages/packs/EditPackPage/EditPackPage";
@ -12,6 +17,7 @@ import labelActions from "redux/nodes/entities/labels/actions";
import packActions from "redux/nodes/entities/packs/actions";
import queryActions from "redux/nodes/entities/queries/actions";
import scheduledQueryActions from "redux/nodes/entities/scheduled_queries/actions";
import teamActions from "redux/nodes/entities/teams/actions";
describe("EditPackPage - component", () => {
beforeEach(() => {
@ -21,15 +27,18 @@ describe("EditPackPage - component", () => {
jest.spyOn(labelActions, "loadAll").mockImplementation(() => spyResponse);
jest.spyOn(packActions, "load").mockImplementation(() => spyResponse);
jest.spyOn(queryActions, "loadAll").mockImplementation(() => spyResponse);
jest.spyOn(teamActions, "loadAll").mockImplementation(() => spyResponse);
jest
.spyOn(scheduledQueryActions, "loadAll")
.mockImplementation(() => spyResponse);
});
const store = {
app: { config: configStub },
entities: {
hosts: { loading: false, data: {} },
labels: { loading: false, data: {} },
teams: { loading: false, data: {} },
packs: {
loading: false,
data: {
@ -49,6 +58,7 @@ describe("EditPackPage - component", () => {
describe("rendering", () => {
it("does not render when packs are loading", () => {
const packsLoadingStore = {
app: store.app,
entities: {
...store.entities,
packs: { ...store.entities.packs, loading: true },
@ -67,6 +77,7 @@ describe("EditPackPage - component", () => {
it("does not render when scheduled queries are loading", () => {
const scheduledQueriesLoadingStore = {
app: store.app,
entities: {
...store.entities,
scheduled_queries: {
@ -88,6 +99,7 @@ describe("EditPackPage - component", () => {
it("does not render when there is no pack", () => {
const noPackStore = {
app: store.app,
entities: {
...store.entities,
packs: { data: {}, loading: false },
@ -130,6 +142,7 @@ describe("EditPackPage - component", () => {
isEdit: false,
packHosts: [],
packLabels: [],
packTeams: [],
scheduledQueries: [],
};
@ -155,6 +168,7 @@ describe("EditPackPage - component", () => {
packHosts: [],
packID: String(packStub.id),
packLabels: [],
packTeams: [],
scheduledQueries: [scheduledQuery],
};
@ -206,6 +220,7 @@ describe("EditPackPage - component", () => {
packHosts: [],
packID: String(packStub.id),
packLabels: [],
packTeams: [],
scheduledQueries: [scheduledQuery],
};
const pushAction = {

View file

@ -8,6 +8,7 @@ import packActions from "redux/nodes/entities/packs/actions";
import PackForm from "components/forms/packs/PackForm";
import PackInfoSidePanel from "components/side_panels/PackInfoSidePanel";
import PATHS from "router/paths";
import permissionUtils from "utilities/permissions";
const baseClass = "pack-composer";
@ -17,6 +18,7 @@ export class PackComposerPage extends Component {
serverErrors: PropTypes.shape({
base: PropTypes.string,
}),
isBasicTier: PropTypes.bool,
};
static defaultProps = {
@ -60,7 +62,7 @@ export class PackComposerPage extends Component {
render() {
const { handleSubmit, onFetchTargets } = this;
const { selectedTargetsCount } = this.state;
const { serverErrors } = this.props;
const { serverErrors, isBasicTier } = this.props;
return (
<div className="has-sidebar">
@ -70,6 +72,7 @@ export class PackComposerPage extends Component {
onFetchTargets={onFetchTargets}
selectedTargetsCount={selectedTargetsCount}
serverErrors={serverErrors}
isBasicTier={isBasicTier}
/>
<PackInfoSidePanel />
</div>
@ -79,8 +82,9 @@ export class PackComposerPage extends Component {
const mapStateToProps = (state) => {
const { errors: serverErrors } = state.entities.packs;
const isBasicTier = permissionUtils.isBasicTier(state.app.config);
return { serverErrors };
return { serverErrors, isBasicTier };
};
export default connect(mapStateToProps)(PackComposerPage);

View file

@ -1,5 +1,6 @@
import { mount } from "enzyme";
import { configStub } from "test/stubs";
import { connectedComponent, reduxMockStore } from "test/helpers";
import ConnectedPacksComposerPage from "./PackComposerPage";
@ -8,6 +9,7 @@ describe("PackComposerPage - component", () => {
entities: {
packs: {},
},
app: { config: configStub },
});
it("renders", () => {
const page = mount(

View file

@ -1,6 +1,7 @@
import Fleet from "fleet";
import Config from "redux/nodes/entities/base/config";
import schemas from "redux/nodes/entities/base/schemas";
import { formatPackForClient } from "fleet/helpers";
const { PACKS: schema } = schemas;
@ -10,6 +11,7 @@ export default new Config({
entityName: "packs",
loadAllFunc: Fleet.packs.loadAll,
loadFunc: Fleet.packs.load,
parseEntityFunc: formatPackForClient,
schema,
updateFunc: Fleet.packs.update,
});

View file

@ -1,4 +1,4 @@
import helpers from "fleet/helpers";
import { formatScheduledQueryForClient } from "fleet/helpers";
import Fleet from "fleet";
import Config from "redux/nodes/entities/base/config";
import schemas from "redux/nodes/entities/base/schemas";
@ -10,7 +10,7 @@ export default new Config({
destroyFunc: Fleet.scheduledQueries.destroy,
entityName: "scheduled_queries",
loadAllFunc: Fleet.scheduledQueries.loadAll,
parseEntityFunc: helpers.formatScheduledQueryForClient,
parseEntityFunc: formatScheduledQueryForClient,
schema,
updateFunc: Fleet.scheduledQueries.update,
});

View file

@ -5,6 +5,7 @@ import Fleet from "fleet";
import Config from "redux/nodes/entities/base/config";
// @ts-ignore
import schemas from "redux/nodes/entities/base/schemas";
import { formatTeamForClient } from "fleet/helpers";
const { TEAMS } = schemas;
@ -14,6 +15,7 @@ export default new Config({
entityName: "teams",
loadFunc: Fleet.teams.load,
loadAllFunc: Fleet.teams.loadAll,
parseEntityFunc: formatTeamForClient,
schema: TEAMS,
updateFunc: Fleet.teams.update,
});

View file

@ -123,6 +123,7 @@ export const packStub = {
disabled: false,
host_ids: [],
label_ids: [],
team_ids: [],
};
export const queryStub = {

View file

@ -26,13 +26,12 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testSaveQuery,
testListQuery,
testDeletePack,
testNewPack,
testSavePack,
testEnrollHost,
testAuthenticateHost,
testAuthenticateHostCaseSensitive,
testLabels,
testSaveLabel,
testManagingLabelsOnPacks,
testPasswordResetRequests,
testCreateUser,
testSaveUser,
@ -52,7 +51,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testListHostsFilterAdditional,
testListHostsStatus,
testListHostsQuery,
testListHostsInPack,
testListPacksForHost,
testHostIDsByName,
testHostByIdentifier,
@ -69,7 +67,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testCascadingDeletionOfQueries,
testGetPackByName,
testGetQueryByName,
testAddLabelToPackTwice,
testGenerateHostStatusStatistics,
testMarkHostSeen,
testMarkHostsSeen,
@ -90,7 +87,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testApplyLabelSpecsRoundtrip,
testGetLabelSpec,
testLabelIDsByName,
testListLabelsForPack,
testHostAdditional,
testCarveMetadata,
testCarveBlocks,

View file

@ -156,62 +156,6 @@ func testLabels(t *testing.T, db fleet.Datastore) {
assert.Len(t, labels, 1)
}
func testManagingLabelsOnPacks(t *testing.T, ds fleet.Datastore) {
pack := &fleet.PackSpec{
ID: 1,
Name: "pack1",
}
err := ds.ApplyPackSpecs([]*fleet.PackSpec{pack})
require.Nil(t, err)
labels, err := ds.ListLabelsForPack(pack.ID)
require.Nil(t, err)
assert.Len(t, labels, 0)
mysqlLabel := &fleet.LabelSpec{
ID: 1,
Name: "MySQL Monitoring",
Query: "select pid from processes where name = 'mysqld';",
}
err = ds.ApplyLabelSpecs([]*fleet.LabelSpec{mysqlLabel})
require.Nil(t, err)
pack.Targets = fleet.PackSpecTargets{
Labels: []string{
mysqlLabel.Name,
},
}
err = ds.ApplyPackSpecs([]*fleet.PackSpec{pack})
require.Nil(t, err)
labels, err = ds.ListLabelsForPack(pack.ID)
require.Nil(t, err)
if assert.Len(t, labels, 1) {
assert.Equal(t, "MySQL Monitoring", labels[0].Name)
}
osqueryLabel := &fleet.LabelSpec{
ID: 2,
Name: "Osquery Monitoring",
Query: "select pid from processes where name = 'osqueryd';",
}
err = ds.ApplyLabelSpecs([]*fleet.LabelSpec{mysqlLabel, osqueryLabel})
require.Nil(t, err)
pack.Targets = fleet.PackSpecTargets{
Labels: []string{
mysqlLabel.Name,
osqueryLabel.Name,
},
}
err = ds.ApplyPackSpecs([]*fleet.PackSpec{pack})
require.Nil(t, err)
labels, err = ds.ListLabelsForPack(pack.ID)
require.Nil(t, err)
assert.Len(t, labels, 2)
}
func testSearchLabels(t *testing.T, db fleet.Datastore) {
specs := []*fleet.LabelSpec{
&fleet.LabelSpec{

View file

@ -26,18 +26,38 @@ func testDeletePack(t *testing.T, ds fleet.Datastore) {
assert.NotNil(t, err)
}
func testNewPack(t *testing.T, ds fleet.Datastore) {
pack := &fleet.Pack{
Name: "foo",
func testSavePack(t *testing.T, ds fleet.Datastore) {
expectedPack := &fleet.Pack{
Name: "foo",
HostIDs: []uint{1},
LabelIDs: []uint{1},
TeamIDs: []uint{1},
}
pack, err := ds.NewPack(pack)
pack, err := ds.NewPack(expectedPack)
require.NoError(t, err)
assert.NotEqual(t, uint(0), pack.ID)
test.EqualSkipTimestampsID(t, expectedPack, pack)
pack, err = ds.Pack(pack.ID)
require.NoError(t, err)
assert.Equal(t, "foo", pack.Name)
test.EqualSkipTimestampsID(t, expectedPack, pack)
expectedPack = &fleet.Pack{
ID: pack.ID,
Name: "bar",
HostIDs: []uint{3},
LabelIDs: []uint{4, 6},
TeamIDs: []uint{},
}
err = ds.SavePack(expectedPack)
require.NoError(t, err)
pack, err = ds.Pack(pack.ID)
require.NoError(t, err)
assert.Equal(t, "bar", pack.Name)
test.EqualSkipTimestampsID(t, expectedPack, pack)
}
func testGetPackByName(t *testing.T, ds fleet.Datastore) {
@ -81,92 +101,8 @@ func testListPacks(t *testing.T, ds fleet.Datastore) {
assert.Len(t, packs, 2)
}
func testListHostsInPack(t *testing.T, ds fleet.Datastore) {
if ds.Name() == "inmem" {
t.Skip("inmem is deprecated")
}
mockClock := clock.NewMockClock()
l1 := fleet.LabelSpec{
ID: 1,
Name: "foo",
}
err := ds.ApplyLabelSpecs([]*fleet.LabelSpec{&l1})
require.Nil(t, err)
p1 := &fleet.PackSpec{
ID: 1,
Name: "foo_pack",
Targets: fleet.PackSpecTargets{
Labels: []string{
l1.Name,
},
},
}
err = ds.ApplyPackSpecs([]*fleet.PackSpec{p1})
require.Nil(t, err)
h1 := test.NewHost(t, ds, "h1.local", "10.10.10.1", "1", "1", mockClock.Now())
hostsInPack, err := ds.ListHostsInPack(p1.ID, fleet.ListOptions{})
require.Nil(t, err)
require.Len(t, hostsInPack, 0)
err = ds.RecordLabelQueryExecutions(
h1,
map[uint]bool{l1.ID: true},
mockClock.Now(),
)
require.Nil(t, err)
hostsInPack, err = ds.ListHostsInPack(p1.ID, fleet.ListOptions{})
require.Nil(t, err)
require.Len(t, hostsInPack, 1)
explicitHostsInPack, err := ds.ListExplicitHostsInPack(p1.ID, fleet.ListOptions{})
require.Nil(t, err)
require.Len(t, explicitHostsInPack, 0)
h2 := test.NewHost(t, ds, "h2.local", "10.10.10.2", "2", "2", mockClock.Now())
err = ds.RecordLabelQueryExecutions(
h2,
map[uint]bool{l1.ID: true},
mockClock.Now(),
)
require.Nil(t, err)
hostsInPack, err = ds.ListHostsInPack(p1.ID, fleet.ListOptions{})
require.Nil(t, err)
require.Len(t, hostsInPack, 2)
}
func testAddLabelToPackTwice(t *testing.T, ds fleet.Datastore) {
l1 := fleet.LabelSpec{
ID: 1,
Name: "l1",
Query: "select 1",
}
err := ds.ApplyLabelSpecs([]*fleet.LabelSpec{&l1})
require.Nil(t, err)
p1 := &fleet.PackSpec{
ID: 1,
Name: "pack1",
Targets: fleet.PackSpecTargets{
Labels: []string{
l1.Name,
l1.Name,
},
},
}
err = ds.ApplyPackSpecs([]*fleet.PackSpec{p1})
require.NotNil(t, err)
}
func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
zwass := test.NewUser(t, ds, "Zach", "zwass", "zwass@fleet.co", true)
zwass := test.NewUser(t, ds, "Zach", "zwass", "zwass@example.com", true)
queries := []*fleet.Query{
{Name: "foo", Description: "get the foos", Query: "select * from foo"},
{Name: "bar", Description: "do some bars", Query: "select baz from bar"},
@ -176,15 +112,15 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
require.Nil(t, err)
labels := []*fleet.LabelSpec{
&fleet.LabelSpec{
{
Name: "foo",
Query: "select * from foo",
},
&fleet.LabelSpec{
{
Name: "bar",
Query: "select * from bar",
},
&fleet.LabelSpec{
{
Name: "bing",
Query: "select * from bing",
},
@ -193,7 +129,7 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
require.Nil(t, err)
expectedSpecs := []*fleet.PackSpec{
&fleet.PackSpec{
{
ID: 1,
Name: "test_pack",
Targets: fleet.PackSpecTargets{
@ -204,20 +140,20 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
},
},
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
QueryName: queries[0].Name,
Name: "q0",
Description: "test_foo",
Interval: 42,
},
fleet.PackSpecQuery{
{
QueryName: queries[0].Name,
Name: "foo_snapshot",
Interval: 600,
Snapshot: ptr.Bool(true),
Denylist: ptr.Bool(false),
},
fleet.PackSpecQuery{
{
Name: "q2",
QueryName: queries[1].Name,
Interval: 600,
@ -229,7 +165,7 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
},
},
},
&fleet.PackSpec{
{
ID: 2,
Name: "test_pack_disabled",
Disabled: true,
@ -241,19 +177,19 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
},
},
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
QueryName: queries[0].Name,
Name: "q0",
Description: "test_foo",
Interval: 42,
},
fleet.PackSpecQuery{
{
QueryName: queries[0].Name,
Name: "foo_snapshot",
Interval: 600,
Snapshot: ptr.Bool(true),
},
fleet.PackSpecQuery{
{
Name: "q2",
QueryName: queries[1].Name,
Interval: 600,
@ -292,14 +228,14 @@ func testGetPackSpec(t *testing.T, ds fleet.Datastore) {
func testApplyPackSpecMissingQueries(t *testing.T, ds fleet.Datastore) {
// Do not define queries mentioned in spec
specs := []*fleet.PackSpec{
&fleet.PackSpec{
{
ID: 1,
Name: "test_pack",
Targets: fleet.PackSpecTargets{
Labels: []string{},
},
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
QueryName: "bar",
Interval: 600,
},
@ -318,13 +254,13 @@ func testApplyPackSpecMissingName(t *testing.T, ds fleet.Datastore) {
setupPackSpecsTest(t, ds)
specs := []*fleet.PackSpec{
&fleet.PackSpec{
{
Name: "test2",
Targets: fleet.PackSpecTargets{
Labels: []string{},
},
Queries: []fleet.PackSpecQuery{
fleet.PackSpecQuery{
{
QueryName: "foo",
Interval: 600,
},
@ -340,67 +276,6 @@ func testApplyPackSpecMissingName(t *testing.T, ds fleet.Datastore) {
assert.Equal(t, "foo", spec.Queries[0].Name)
}
func testListLabelsForPack(t *testing.T, ds fleet.Datastore) {
labelSpecs := []*fleet.LabelSpec{
&fleet.LabelSpec{
Name: "foo",
Query: "select * from foo",
},
&fleet.LabelSpec{
Name: "bar",
Query: "select * from bar",
},
&fleet.LabelSpec{
Name: "bing",
Query: "select * from bing",
},
}
err := ds.ApplyLabelSpecs(labelSpecs)
require.Nil(t, err)
specs := []*fleet.PackSpec{
&fleet.PackSpec{
ID: 1,
Name: "test_pack",
Targets: fleet.PackSpecTargets{
Labels: []string{
"foo",
"bar",
"bing",
},
},
},
&fleet.PackSpec{
ID: 2,
Name: "test 2",
Targets: fleet.PackSpecTargets{
Labels: []string{
"bing",
},
},
},
&fleet.PackSpec{
ID: 3,
Name: "test 3",
},
}
err = ds.ApplyPackSpecs(specs)
require.Nil(t, err)
labels, err := ds.ListLabelsForPack(specs[0].ID)
require.Nil(t, err)
assert.Len(t, labels, 3)
labels, err = ds.ListLabelsForPack(specs[1].ID)
require.Nil(t, err)
assert.Len(t, labels, 1)
assert.Equal(t, "bing", labels[0].Name)
labels, err = ds.ListLabelsForPack(specs[2].ID)
require.Nil(t, err)
assert.Len(t, labels, 0)
}
func testListPacksForHost(t *testing.T, ds fleet.Datastore) {
if ds.Name() == "inmem" {
t.Skip("inmem is deprecated")
@ -507,40 +382,4 @@ func testListPacksForHost(t *testing.T, ds fleet.Datastore) {
if assert.Len(t, packs, 1) {
assert.Equal(t, "foo_pack", packs[0].Name)
}
// Add host directly to pack
err = ds.AddHostToPack(h1.ID, p2.ID)
require.Nil(t, err)
packs, err = ds.ListPacksForHost(h1.ID)
require.Nil(t, err)
assert.Len(t, packs, 2)
// Remove label membership for both
err = ds.RecordLabelQueryExecutions(
h1,
map[uint]bool{l2.ID: false, l1.ID: false},
mockClock.Now(),
)
require.Nil(t, err)
err = ds.RecordLabelQueryExecutions(
h1,
map[uint]bool{l2.ID: false},
mockClock.Now(),
)
require.Nil(t, err)
packs, err = ds.ListPacksForHost(h1.ID)
require.Nil(t, err)
if assert.Len(t, packs, 1) {
assert.Equal(t, p2.Name, packs[0].Name)
}
// Now host is added directly to both packs
err = ds.AddHostToPack(h1.ID, p1.ID)
require.Nil(t, err)
packs, err = ds.ListPacksForHost(h1.ID)
require.Nil(t, err)
assert.Len(t, packs, 2)
}

View file

@ -178,7 +178,6 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
// NewLabel creates a new fleet.Label
func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fleet.Label, error) {
db := d.getTransaction(opts)
query := `
INSERT INTO labels (
name,
@ -189,7 +188,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
label_membership_type
) VALUES ( ?, ?, ?, ?, ?, ?)
`
result, err := db.Exec(
result, err := d.db.Exec(
query,
label.Name,
label.Description,

View file

@ -0,0 +1,35 @@
package tables
import (
"database/sql"
"github.com/pkg/errors"
)
func init() {
MigrationClient.AddMigration(Up_20210617174723, Down_20210617174723)
}
func Up_20210617174723(tx *sql.Tx) error {
sql := `
DELETE FROM pack_targets
WHERE pack_id NOT IN (SELECT id FROM packs)
`
if _, err := tx.Exec(sql); err != nil {
return errors.Wrap(err, "delete orphaned pack targets")
}
sql = `
ALTER TABLE pack_targets
ADD FOREIGN KEY (pack_id) REFERENCES packs (id) ON UPDATE CASCADE ON DELETE CASCADE
`
if _, err := tx.Exec(sql); err != nil {
return errors.Wrap(err, "add foreign key on pack_targets pack_id")
}
return nil
}
func Down_20210617174723(tx *sql.Tx) error {
return nil
}

View file

@ -46,23 +46,6 @@ type Datastore struct {
config config.MysqlConfig
}
type dbfunctions interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Get(dest interface{}, query string, args ...interface{}) error
Select(dest interface{}, query string, args ...interface{}) error
}
func (d *Datastore) getTransaction(opts []fleet.OptionalArg) dbfunctions {
var result dbfunctions = d.db
for _, opt := range opts {
switch t := opt().(type) {
case dbfunctions:
result = t
}
}
return result
}
type txFn func(*sqlx.Tx) error
// retryableError determines whether a MySQL error can be retried. By default

View file

@ -3,6 +3,7 @@ package mysql
import (
"database/sql"
"fmt"
"strings"
"github.com/fleetdm/fleet/server/fleet"
"github.com/jmoiron/sqlx"
@ -198,19 +199,22 @@ WHERE pack_id = ?
}
func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.Pack, bool, error) {
db := d.getTransaction(opts)
sqlStatement := `
SELECT *
FROM packs
WHERE name = ?
`
var pack fleet.Pack
err := db.Get(&pack, sqlStatement, name)
err := d.db.Get(&pack, sqlStatement, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, false, nil
}
return nil, false, errors.Wrap(err, "fetching packs by name")
return nil, false, errors.Wrap(err, "fetch pack by name")
}
if err := d.loadPackTargets(&pack); err != nil {
return nil, false, err
}
return &pack, true, nil
@ -218,44 +222,144 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
// NewPack creates a new Pack
func (d *Datastore) NewPack(pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) {
db := d.getTransaction(opts)
if err := d.withRetryTxx(func(tx *sqlx.Tx) error {
query := `
INSERT INTO packs
(name, description, platform, disabled)
VALUES ( ?, ?, ?, ? )
`
result, err := d.db.Exec(query, pack.Name, pack.Description, pack.Platform, pack.Disabled)
if err != nil {
return errors.Wrap(err, "insert pack")
}
query := `
INSERT INTO packs
(name, description, platform, disabled)
VALUES ( ?, ?, ?, ? )
`
id, _ := result.LastInsertId()
pack.ID = uint(id)
result, err := db.Exec(query, pack.Name, pack.Description, pack.Platform, pack.Disabled)
if err != nil {
return nil, errors.Wrap(err, "inserting pack")
if err := d.replacePackTargets(tx, pack); err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}
return pack, nil
}
func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
sql := `DELETE FROM pack_targets WHERE pack_id = ?`
if _, err := tx.Exec(sql, pack.ID); err != nil {
return errors.Wrap(err, "delete pack targets")
}
id, _ := result.LastInsertId()
pack.ID = uint(id)
return pack, nil
// Insert hosts
if len(pack.HostIDs) > 0 {
var args []interface{}
for _, id := range pack.HostIDs {
args = append(args, pack.ID, fleet.TargetHost, id)
}
values := strings.TrimSuffix(
strings.Repeat("(?,?,?),", len(pack.HostIDs)),
",",
)
sql = fmt.Sprintf(`
INSERT INTO pack_targets (pack_id, type, target_id)
VALUES %s
`, values)
if _, err := tx.Exec(sql, args...); err != nil {
return errors.Wrap(err, "insert host targets")
}
}
// Insert labels
if len(pack.LabelIDs) > 0 {
var args []interface{}
for _, id := range pack.LabelIDs {
args = append(args, pack.ID, fleet.TargetLabel, id)
}
values := strings.TrimSuffix(
strings.Repeat("(?,?,?),", len(pack.LabelIDs)),
",",
)
sql = fmt.Sprintf(`
INSERT INTO pack_targets (pack_id, type, target_id)
VALUES %s
`, values)
if _, err := tx.Exec(sql, args...); err != nil {
return errors.Wrap(err, "insert label targets")
}
}
// Insert teams
if len(pack.TeamIDs) > 0 {
var args []interface{}
for _, id := range pack.TeamIDs {
args = append(args, pack.ID, fleet.TargetTeam, id)
}
values := strings.TrimSuffix(
strings.Repeat("(?,?,?),", len(pack.TeamIDs)),
",",
)
sql = fmt.Sprintf(`
INSERT INTO pack_targets (pack_id, type, target_id)
VALUES %s
`, values)
if _, err := tx.Exec(sql, args...); err != nil {
return errors.Wrap(err, "insert team targets")
}
}
return nil
}
func (d *Datastore) loadPackTargets(pack *fleet.Pack) error {
var targets []fleet.PackTarget
sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
if err := d.db.Select(&targets, sql, pack.ID); err != nil {
return errors.Wrap(err, "select pack targets")
}
pack.HostIDs, pack.LabelIDs, pack.TeamIDs = []uint{}, []uint{}, []uint{}
for _, target := range targets {
switch target.Type {
case fleet.TargetHost:
pack.HostIDs = append(pack.HostIDs, target.TargetID)
case fleet.TargetLabel:
pack.LabelIDs = append(pack.LabelIDs, target.TargetID)
case fleet.TargetTeam:
pack.TeamIDs = append(pack.TeamIDs, target.TargetID)
default:
return errors.Errorf("unknown target type: %d", target.Type)
}
}
return nil
}
// SavePack stores changes to pack
func (d *Datastore) SavePack(pack *fleet.Pack) error {
query := `
return d.withRetryTxx(func(tx *sqlx.Tx) error {
query := `
UPDATE packs
SET name = ?, platform = ?, disabled = ?, description = ?
WHERE id = ?
`
results, err := d.db.Exec(query, pack.Name, pack.Platform, pack.Disabled, pack.Description, pack.ID)
if err != nil {
return errors.Wrap(err, "updating pack")
}
rowsAffected, err := results.RowsAffected()
if err != nil {
return errors.Wrap(err, "rows affected updating packs")
}
if rowsAffected == 0 {
return notFound("Pack").WithID(pack.ID)
}
return nil
results, err := d.db.Exec(query, pack.Name, pack.Platform, pack.Disabled, pack.Description, pack.ID)
if err != nil {
return errors.Wrap(err, "updating pack")
}
rowsAffected, err := results.RowsAffected()
if err != nil {
return errors.Wrap(err, "rows affected updating packs")
}
if rowsAffected == 0 {
return notFound("Pack").WithID(pack.ID)
}
return d.replacePackTargets(tx, pack)
})
}
// DeletePack deletes a fleet.Pack so that it won't show up in results.
@ -271,7 +375,11 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
if err == sql.ErrNoRows {
return nil, notFound("Pack").WithID(pid)
} else if err != nil {
return nil, errors.Wrap(err, "getting pack")
return nil, errors.Wrap(err, "get pack")
}
if err := d.loadPackTargets(pack); err != nil {
return nil, err
}
return pack, nil
@ -285,107 +393,22 @@ func (d *Datastore) ListPacks(opt fleet.ListOptions) ([]*fleet.Pack, error) {
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing packs")
}
for _, pack := range packs {
if err := d.loadPackTargets(pack); err != nil {
return nil, err
}
}
return packs, nil
}
// AddLabelToPack associates a fleet.Label with a fleet.Pack
func (d *Datastore) AddLabelToPack(lid uint, pid uint, opts ...fleet.OptionalArg) error {
db := d.getTransaction(opts)
query := `
INSERT INTO pack_targets ( pack_id, type, target_id )
VALUES ( ?, ?, ? )
ON DUPLICATE KEY UPDATE id=id
`
_, err := db.Exec(query, pid, fleet.TargetLabel, lid)
if err != nil {
return errors.Wrap(err, "adding label to pack")
}
return nil
}
// AddHostToPack associates a fleet.Host with a fleet.Pack
func (d *Datastore) AddHostToPack(hid, pid uint) error {
query := `
INSERT INTO pack_targets ( pack_id, type, target_id )
VALUES ( ?, ?, ? )
ON DUPLICATE KEY UPDATE id=id
`
_, err := d.db.Exec(query, pid, fleet.TargetHost, hid)
if err != nil {
return errors.Wrap(err, "adding host to pack")
}
return nil
}
// RemoreLabelFromPack will remove the association between a fleet.Label and
// a fleet.Pack
func (d *Datastore) RemoveLabelFromPack(lid, pid uint) error {
query := `
DELETE FROM pack_targets
WHERE target_id = ? AND pack_id = ? AND type = ?
`
_, err := d.db.Exec(query, lid, pid, fleet.TargetLabel)
if err == sql.ErrNoRows {
return notFound("PackTarget").WithMessage(fmt.Sprintf("label ID: %d, pack ID: %d", lid, pid))
} else if err != nil {
return errors.Wrap(err, "removing label from pack")
}
return nil
}
// RemoveHostFromPack will remove the association between a fleet.Host and a
// fleet.Pack
func (d *Datastore) RemoveHostFromPack(hid, pid uint) error {
query := `
DELETE FROM pack_targets
WHERE target_id = ? AND pack_id = ? AND type = ?
`
_, err := d.db.Exec(query, hid, pid, fleet.TargetHost)
if err == sql.ErrNoRows {
return notFound("PackTarget").WithMessage(fmt.Sprintf("host ID: %d, pack ID: %d", hid, pid))
} else if err != nil {
return errors.Wrap(err, "removing host from pack")
}
return nil
}
// ListLabelsForPack will return a list of fleet.Label records associated with fleet.Pack
func (d *Datastore) ListLabelsForPack(pid uint) ([]*fleet.Label, error) {
query := `
SELECT
l.id,
l.created_at,
l.updated_at,
l.name
FROM
labels l
JOIN
pack_targets pt
ON
pt.target_id = l.id
WHERE
pt.type = ?
AND
pt.pack_id = ?
`
labels := []*fleet.Label{}
if err := d.db.Select(&labels, query, fleet.TargetLabel, pid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing labels for pack")
}
return labels, nil
}
func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
query := `
SELECT DISTINCT packs.*
FROM
((SELECT p.* FROM packs p
((SELECT p.*
FROM packs p
JOIN pack_targets pt
JOIN label_membership lm
ON (
@ -399,55 +422,17 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
FROM packs p
JOIN pack_targets pt
ON (p.id = pt.pack_id AND pt.type = ? AND pt.target_id = ?))
UNION ALL
(SELECT p.*
FROM packs p
JOIN pack_targets pt
ON (p.id = pt.pack_id AND pt.type = ? AND pt.target_id = (SELECT team_id FROM hosts WHERE id = ?)))
) packs
`
packs := []*fleet.Pack{}
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid); err != nil && err != sql.ErrNoRows {
if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing hosts in pack")
}
return packs, nil
}
func (d *Datastore) ListHostsInPack(pid uint, opt fleet.ListOptions) ([]uint, error) {
query := `
SELECT DISTINCT h.id
FROM hosts h
JOIN pack_targets pt
JOIN label_membership lm
ON (
pt.target_id = lm.label_id
AND lm.host_id = h.id
AND pt.type = ?
) OR (
pt.target_id = h.id
AND pt.type = ?
)
WHERE pt.pack_id = ?
`
hosts := []uint{}
if err := d.db.Select(&hosts, appendListOptionsToSQL(query, opt), fleet.TargetLabel, fleet.TargetHost, pid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing hosts in pack")
}
return hosts, nil
}
func (d *Datastore) ListExplicitHostsInPack(pid uint, opt fleet.ListOptions) ([]uint, error) {
query := `
SELECT DISTINCT h.id
FROM hosts h
JOIN pack_targets pt
ON (
pt.target_id = h.id
AND pt.type = ?
)
WHERE pt.pack_id = ?
`
hosts := []uint{}
if err := d.db.Select(&hosts, appendListOptionsToSQL(query, opt), fleet.TargetHost, pid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing explicit hosts in pack")
}
return hosts, nil
}

View file

@ -65,14 +65,13 @@ func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err err
}
func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.Query, error) {
db := d.getTransaction(opts)
sqlStatement := `
SELECT *
FROM queries
WHERE name = ?
`
var query fleet.Query
err := db.Get(&query, sqlStatement, name)
err := d.db.Get(&query, sqlStatement, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, notFound("Query").WithName(name)
@ -89,8 +88,6 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
// NewQuery creates a New Query.
func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
db := d.getTransaction(opts)
sqlStatement := `
INSERT INTO queries (
name,
@ -101,7 +98,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
observer_can_run
) VALUES ( ?, ?, ?, ?, ?, ? )
`
result, err := db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
if err != nil && isDuplicate(err) {
return nil, alreadyExists("Query", 0)

View file

@ -40,8 +40,6 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
}
func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.OptionalArg) (*fleet.ScheduledQuery, error) {
db := d.getTransaction(opts)
// This query looks up the query name using the ID (for backwards
// compatibility with the UI)
query := `
@ -61,7 +59,7 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
FROM queries
WHERE id = ?
`
result, err := db.Exec(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID)
result, err := d.db.Exec(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID)
if err != nil {
return nil, errors.Wrap(err, "insert scheduled query")
}
@ -75,7 +73,7 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
Name string
}{}
err = db.Select(&metadata, query, sq.QueryID)
err = d.db.Select(&metadata, query, sq.QueryID)
if err != nil && err == sql.ErrNoRows {
return nil, notFound("Query").WithID(sq.QueryID)
} else if err != nil {

View file

@ -33,33 +33,8 @@ type PackStore interface {
// exists the bool return value is true
PackByName(name string, opts ...OptionalArg) (*Pack, bool, error)
// AddLabelToPack adds an existing label to an existing pack, both by ID.
AddLabelToPack(lid, pid uint, opts ...OptionalArg) error
// RemoveLabelFromPack removes an existing label from it's association with
// an existing pack, both by ID.
RemoveLabelFromPack(lid, pid uint) error
// ListLabelsForPack lists all labels that are associated with a pack.
ListLabelsForPack(pid uint) ([]*Label, error)
// AddHostToPack adds an existing host to an existing pack, both by ID.
AddHostToPack(hid uint, pid uint) error
// RemoveHostFromPack removes an existing host from it's association with
// an existing pack, both by ID.
RemoveHostFromPack(hid uint, pid uint) error
// ListPacksForHost lists the packs that a host should execute.
ListPacksForHost(hid uint) (packs []*Pack, err error)
// ListHostsInPack lists the IDs of all hosts that are associated with a pack
// through labels.
ListHostsInPack(pid uint, opt ListOptions) ([]uint, error)
// ListExplicitHostsInPack lists the IDs of hosts that have been manually
// associated with a query pack.
ListExplicitHostsInPack(pid uint, opt ListOptions) ([]uint, error)
}
// PackService is the service interface for managing query packs.
@ -90,33 +65,8 @@ type PackService interface {
// DeletePackByID is for backwards compatibility with the UI
DeletePackByID(ctx context.Context, id uint) (err error)
// AddLabelToPack adds an existing label to an existing pack, both by ID.
AddLabelToPack(ctx context.Context, lid, pid uint) (err error)
// RemoveLabelFromPack removes an existing label from it's association with
// an existing pack, both by ID.
RemoveLabelFromPack(ctx context.Context, lid, pid uint) (err error)
// ListLabelsForPack lists all labels that are associated with a pack.
ListLabelsForPack(ctx context.Context, pid uint) (labels []*Label, err error)
// AddHostToPack adds an existing host to an existing pack, both by ID.
AddHostToPack(ctx context.Context, hid, pid uint) (err error)
// RemoveHostFromPack removes an existing host from it's association with
// an existing pack, both by ID.
RemoveHostFromPack(ctx context.Context, hid, pid uint) (err error)
// ListPacksForHost lists the packs that a host should execute.
ListPacksForHost(ctx context.Context, hid uint) (packs []*Pack, err error)
// ListHostsInPack lists the IDs of all hosts that are associated with a pack,
// both through labels and manual associations.
ListHostsInPack(ctx context.Context, pid uint, opt ListOptions) (hosts []uint, err error)
// ListExplicitHostsInPack lists the IDs of hosts that have been manually associated
// with a query pack.
ListExplicitHostsInPack(ctx context.Context, pid uint, opt ListOptions) (hosts []uint, err error)
}
// Pack is the structure which represents an osquery query pack.
@ -127,6 +77,9 @@ type Pack struct {
Description string `json:"description,omitempty"`
Platform string `json:"platform,omitempty"`
Disabled bool `json:"disabled,omitempty"`
LabelIDs []uint `json:"label_ids"`
HostIDs []uint `json:"host_ids"`
TeamIDs []uint `json:"team_ids"`
}
func (p Pack) AuthzType() string {
@ -145,6 +98,7 @@ type PackPayload struct {
Disabled *bool `json:"disabled"`
HostIDs *[]uint `json:"host_ids"`
LabelIDs *[]uint `json:"label_ids"`
TeamIDs *[]uint `json:"team_ids"`
}
type PackSpec struct {
@ -174,10 +128,10 @@ type PackSpecQuery struct {
Denylist *bool `json:"denylist,omitempty"`
}
// PackTarget associates a pack with either a host or a label
// PackTarget targets a pack to a host, label, or team.
type PackTarget struct {
ID uint
PackID uint
ID uint `db:"id"`
PackID uint `db:"pack_id"`
Target
}

View file

@ -78,8 +78,8 @@ const (
)
type Target struct {
Type TargetType
TargetID uint
Type TargetType `db:"type"`
TargetID uint `db:"target_id"`
}
func (t Target) AuthzType() string {

View file

@ -22,7 +22,7 @@ type appConfigResponse struct {
SSOSettings *fleet.SSOSettingsPayload `json:"sso_settings,omitempty"`
HostExpirySettings *fleet.HostExpirySettings `json:"host_expiry_settings,omitempty"`
HostSettings *fleet.HostSettings `json:"host_settings,omitempty"`
AgentOptions *json.RawMessage `json:"agent_options,omitempty"`
AgentOptions *json.RawMessage `json:"agent_options,omitempty"`
License *fleet.LicenseInfo `json:"license,omitempty"`
Err error `json:"error,omitempty"`
}

View file

@ -18,6 +18,7 @@ type packResponse struct {
// IDs of hosts which were explicitly selected.
HostIDs []uint `json:"host_ids"`
LabelIDs []uint `json:"label_ids"`
TeamIDs []uint `json:"team_ids"`
}
func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) {
@ -27,21 +28,11 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack
return nil, err
}
hosts, err := svc.ListExplicitHostsInPack(ctx, pack.ID, opts)
if err != nil {
return nil, err
}
labels, err := svc.ListLabelsForPack(ctx, pack.ID)
labelIDs := make([]uint, len(labels))
for i, label := range labels {
labelIDs[i] = label.ID
}
if err != nil {
return nil, err
}
hostMetrics, err := svc.CountHostsInTargets(ctx, nil, fleet.HostTargets{HostIDs: hosts, LabelIDs: labelIDs})
hostMetrics, err := svc.CountHostsInTargets(
ctx,
nil,
fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs},
)
if err != nil {
return nil, err
}
@ -50,8 +41,9 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack
Pack: pack,
QueryCount: uint(len(queries)),
TotalHostsCount: hostMetrics.TotalHosts,
HostIDs: hosts,
LabelIDs: labelIDs,
HostIDs: pack.HostIDs,
LabelIDs: pack.LabelIDs,
TeamIDs: pack.TeamIDs,
}, nil
}

View file

@ -117,92 +117,6 @@ func (mw loggingMiddleware) DeletePack(ctx context.Context, name string) error {
return err
}
func (mw loggingMiddleware) AddLabelToPack(ctx context.Context, lid uint, pid uint) error {
var (
err error
)
defer func(begin time.Time) {
_ = mw.loggerInfo(err).Log(
"method", "AddLabelToPack",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
err = mw.Service.AddLabelToPack(ctx, lid, pid)
return err
}
func (mw loggingMiddleware) RemoveLabelFromPack(ctx context.Context, lid uint, pid uint) error {
var (
err error
)
defer func(begin time.Time) {
_ = mw.loggerInfo(err).Log(
"method", "RemoveLabelFromPack",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
err = mw.Service.RemoveLabelFromPack(ctx, lid, pid)
return err
}
func (mw loggingMiddleware) ListLabelsForPack(ctx context.Context, pid uint) ([]*fleet.Label, error) {
var (
labels []*fleet.Label
err error
)
defer func(begin time.Time) {
_ = mw.loggerDebug(err).Log(
"method", "ListLabelsForPack",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
labels, err = mw.Service.ListLabelsForPack(ctx, pid)
return labels, err
}
func (mw loggingMiddleware) AddHostToPack(ctx context.Context, hid uint, pid uint) error {
var (
err error
)
defer func(begin time.Time) {
_ = mw.loggerInfo(err).Log(
"method", "AddHostToPack",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
err = mw.Service.AddHostToPack(ctx, hid, pid)
return err
}
func (mw loggingMiddleware) RemoveHostFromPack(ctx context.Context, hid uint, pid uint) error {
var (
err error
)
defer func(begin time.Time) {
_ = mw.loggerInfo(err).Log(
"method", "RemoveHostFromPack",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
err = mw.Service.RemoveHostFromPack(ctx, hid, pid)
return err
}
func (mw loggingMiddleware) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
var (
packs []*fleet.Pack
@ -221,24 +135,6 @@ func (mw loggingMiddleware) ListPacksForHost(ctx context.Context, hid uint) ([]*
return packs, err
}
func (mw loggingMiddleware) ListHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) {
var (
hosts []uint
err error
)
defer func(begin time.Time) {
_ = mw.loggerDebug(err).Log(
"method", "ListHostsInPack",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
hosts, err = mw.Service.ListHostsInPack(ctx, pid, opt)
return hosts, err
}
func (mw loggingMiddleware) GetPackSpec(ctx context.Context, name string) (spec *fleet.PackSpec, err error) {
defer func(begin time.Time) {
_ = mw.loggerDebug(err).Log(

View file

@ -19,7 +19,7 @@ func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpe
}
if spec.LabelMembershipType == fleet.LabelMembershipTypeManual && spec.Hosts == nil {
// Hosts list doesn't need to contain anything, but it should at least not be nil.
return errors.Errorf("label %s is declared as manual but contains not `hosts key`", spec.Name)
return errors.Errorf("label %s is declared as manual but contains no `hosts key`", spec.Name)
}
}
return svc.ds.ApplyLabelSpecs(specs)

View file

@ -209,13 +209,13 @@ func (svc *Service) GetClientConfig(ctx context.Context) (map[string]interface{}
baseConfig, err := svc.AgentOptionsForHost(ctx, &host)
if err != nil {
return nil, osqueryError{message: "internal error: fetching base config: " + err.Error()}
return nil, osqueryError{message: "internal error: fetch base config: " + err.Error()}
}
var config map[string]interface{}
err = json.Unmarshal(baseConfig, &config)
if err != nil {
return nil, osqueryError{message: "internal error: parsing base configuration: " + err.Error()}
return nil, osqueryError{message: "internal error: parse base configuration: " + err.Error()}
}
packs, err := svc.ds.ListPacksForHost(host.ID)

View file

@ -74,24 +74,6 @@ func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pa
return nil, err
}
if p.HostIDs != nil {
for _, hostID := range *p.HostIDs {
err = svc.AddHostToPack(ctx, hostID, pack.ID)
if err != nil {
return nil, err
}
}
}
if p.LabelIDs != nil {
for _, labelID := range *p.LabelIDs {
err = svc.AddLabelToPack(ctx, labelID, pack.ID)
if err != nil {
return nil, err
}
}
}
return &pack, nil
}
@ -121,107 +103,23 @@ func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload
pack.Disabled = *p.Disabled
}
if p.HostIDs != nil {
pack.HostIDs = *p.HostIDs
}
if p.LabelIDs != nil {
pack.LabelIDs = *p.LabelIDs
}
if p.TeamIDs != nil {
pack.TeamIDs = *p.TeamIDs
}
err = svc.ds.SavePack(pack)
if err != nil {
return nil, err
}
// we must determine what hosts are attached to this pack. then, given
// our new set of host_ids, we will mutate the database to reflect the
// desired state.
if p.HostIDs != nil {
// first, let's retrieve the total set of hosts
hosts, err := svc.ListHostsInPack(ctx, pack.ID, fleet.ListOptions{})
if err != nil {
return nil, err
}
// it will be efficient to create a data structure with constant time
// lookups to determine whether or not a host is already added
existingHosts := map[uint]bool{}
for _, host := range hosts {
existingHosts[host] = true
}
// we will also make a constant time lookup map for the desired set of
// hosts as well.
desiredHosts := map[uint]bool{}
for _, hostID := range *p.HostIDs {
desiredHosts[hostID] = true
}
// if the request declares a host ID but the host is not already
// associated with the pack, we add it
for _, hostID := range *p.HostIDs {
if !existingHosts[hostID] {
err = svc.AddHostToPack(ctx, hostID, pack.ID)
if err != nil {
return nil, err
}
}
}
// if the request does not declare the ID of a host which currently
// exists, we delete the existing relationship
for hostID := range existingHosts {
if !desiredHosts[hostID] {
err = svc.RemoveHostFromPack(ctx, hostID, pack.ID)
if err != nil {
return nil, err
}
}
}
}
// we must determine what labels are attached to this pack. then, given
// our new set of label_ids, we will mutate the database to reflect the
// desired state.
if p.LabelIDs != nil {
// first, let's retrieve the total set of labels
labels, err := svc.ListLabelsForPack(ctx, pack.ID)
if err != nil {
return nil, err
}
// it will be efficient to create a data structure with constant time
// lookups to determine whether or not a label is already added
existingLabels := map[uint]bool{}
for _, label := range labels {
existingLabels[label.ID] = true
}
// we will also make a constant time lookup map for the desired set of
// labels as well.
desiredLabels := map[uint]bool{}
for _, labelID := range *p.LabelIDs {
desiredLabels[labelID] = true
}
// if the request declares a label ID but the label is not already
// associated with the pack, we add it
for _, labelID := range *p.LabelIDs {
if !existingLabels[labelID] {
err = svc.AddLabelToPack(ctx, labelID, pack.ID)
if err != nil {
return nil, err
}
}
}
// if the request does not declare the ID of a label which currently
// exists, we delete the existing relationship
for labelID := range existingLabels {
if !desiredLabels[labelID] {
err = svc.RemoveLabelFromPack(ctx, labelID, pack.ID)
if err != nil {
return nil, err
}
}
}
}
return pack, err
}
@ -245,62 +143,6 @@ func (svc *Service) DeletePackByID(ctx context.Context, id uint) error {
return svc.ds.DeletePack(pack.Name)
}
func (svc *Service) AddLabelToPack(ctx context.Context, lid, pid uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.AddLabelToPack(lid, pid)
}
func (svc *Service) RemoveLabelFromPack(ctx context.Context, lid, pid uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.RemoveLabelFromPack(lid, pid)
}
func (svc *Service) AddHostToPack(ctx context.Context, hid, pid uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.AddHostToPack(hid, pid)
}
func (svc *Service) RemoveHostFromPack(ctx context.Context, hid, pid uint) error {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
return err
}
return svc.ds.RemoveHostFromPack(hid, pid)
}
func (svc *Service) ListLabelsForPack(ctx context.Context, pid uint) ([]*fleet.Label, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListLabelsForPack(pid)
}
func (svc *Service) ListHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListHostsInPack(pid, opt)
}
func (svc *Service) ListExplicitHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
}
return svc.ds.ListExplicitHostsInPack(pid, opt)
}
func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err

View file

@ -10,9 +10,16 @@ import (
"github.com/stretchr/testify/assert"
)
type TestingT interface {
assert.TestingT
Helper()
}
// ElementsMatchSkipID asserts that the elements match, skipping any field with
// name "ID".
func ElementsMatchSkipID(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
func ElementsMatchSkipID(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
t.Helper()
opt := cmp.FilterPath(func(p cmp.Path) bool {
for _, ps := range p {
switch ps := ps.(type) {
@ -29,7 +36,9 @@ func ElementsMatchSkipID(t assert.TestingT, listA, listB interface{}, msgAndArgs
// ElementsMatchSkipTimestampsID asserts that the elements match, skipping any field with
// name "ID", "CreatedAt", and "UpdatedAt". This is useful for comparing after DB insertion.
func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
func ElementsMatchSkipTimestampsID(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
t.Helper()
opt := cmp.FilterPath(func(p cmp.Path) bool {
for _, ps := range p {
switch ps := ps.(type) {
@ -45,6 +54,31 @@ func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{},
return ElementsMatchWithOptions(t, listA, listB, []cmp.Option{opt}, msgAndArgs)
}
// EqualSkipTimestampsID asserts that the structs are equal, skipping any field
// with name "ID", "CreatedAt", and "UpdatedAt". This is useful for comparing
// after DB insertion.
func EqualSkipTimestampsID(t TestingT, a, b interface{}, msgAndArgs ...interface{}) (ok bool) {
t.Helper()
opt := cmp.FilterPath(func(p cmp.Path) bool {
for _, ps := range p {
switch ps := ps.(type) {
case cmp.StructField:
switch ps.Name() {
case "ID", "UpdateCreateTimestamps", "CreateTimestamp", "UpdateTimestamp", "CreatedAt", "UpdatedAt":
return true
}
}
}
return false
}, cmp.Ignore())
if !cmp.Equal(a, b, opt) {
return assert.Fail(t, cmp.Diff(a, b, opt), msgAndArgs...)
}
return true
}
// The below functions adapted from
// https://github.com/stretchr/testify/blob/v1.7.0/assert/assertions.go#L895 by
// utilizing the options provided in github.com/google/go-cmp/cmp
@ -53,7 +87,7 @@ func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{},
// additional options as provided by the cmp package. This allows, for example,
// comparing structs while ignoring fields. See assert.ElementsMatch
// documentation for more details.
func ElementsMatchWithOptions(t assert.TestingT, listA, listB interface{}, opts cmp.Options, msgAndArgs ...interface{}) (ok bool) {
func ElementsMatchWithOptions(t TestingT, listA, listB interface{}, opts cmp.Options, msgAndArgs ...interface{}) (ok bool) {
if isEmpty(listA) && isEmpty(listB) {
return true
}
@ -100,7 +134,7 @@ func isEmpty(object interface{}) bool {
}
// isList checks that the provided value is array or slice.
func isList(t assert.TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) {
func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) {
kind := reflect.TypeOf(list).Kind()
if kind != reflect.Array && kind != reflect.Slice {
return assert.Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind),