diff --git a/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx b/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx
index d83e04623d..48d03553be 100644
--- a/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx
+++ b/frontend/components/forms/packs/EditPackForm/EditPackForm.jsx
@@ -22,6 +22,7 @@ class EditPackForm extends Component {
onCancel: PropTypes.func.isRequired,
onFetchTargets: PropTypes.func,
targetsCount: PropTypes.number,
+ isBasicTier: PropTypes.bool,
};
render() {
@@ -32,6 +33,7 @@ class EditPackForm extends Component {
onCancel,
onFetchTargets,
targetsCount,
+ isBasicTier,
} = this.props;
return (
@@ -58,6 +60,7 @@ class EditPackForm extends Component {
onSelect={fields.targets.onChange}
selectedTargets={fields.targets.value}
targetsCount={targetsCount}
+ isBasicTier={isBasicTier}
/>
diff --git a/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx b/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx
index 808b2872ad..42a7070da1 100644
--- a/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx
+++ b/frontend/components/packs/EditPackFormWrapper/EditPackFormWrapper.jsx
@@ -21,6 +21,7 @@ class EditPackFormWrapper extends Component {
pack: packInterface.isRequired,
packTargets: PropTypes.arrayOf(targetInterface),
targetsCount: PropTypes.number,
+ isBasicTier: PropTypes.bool,
};
render() {
@@ -34,6 +35,7 @@ class EditPackFormWrapper extends Component {
pack,
packTargets,
targetsCount,
+ isBasicTier,
} = this.props;
if (isEdit) {
@@ -45,6 +47,7 @@ class EditPackFormWrapper extends Component {
onCancel={onCancelEditPack}
onFetchTargets={onFetchTargets}
targetsCount={targetsCount}
+ isBasicTier={isBasicTier}
/>
);
}
@@ -73,6 +76,7 @@ class EditPackFormWrapper extends Component {
selectedTargets={packTargets}
targetsCount={targetsCount}
disabled
+ isBasicTier={isBasicTier}
className={`${baseClass}__select-targets`}
/>
diff --git a/frontend/fleet/entities/teams.ts b/frontend/fleet/entities/teams.ts
index cf3dadf4c3..5230f61539 100644
--- a/frontend/fleet/entities/teams.ts
+++ b/frontend/fleet/entities/teams.ts
@@ -48,7 +48,7 @@ export default (client: any) => {
page = 0,
perPage = 100,
globalFilter = "",
- }: ITeamSearchOptions) => {
+ }: ITeamSearchOptions = {}) => {
const { TEAMS } = endpoints;
// TODO: add this query param logic to client class
diff --git a/frontend/fleet/helpers.ts b/frontend/fleet/helpers.ts
index d6aedfa4b9..58f3aa3109 100644
--- a/frontend/fleet/helpers.ts
+++ b/frontend/fleet/helpers.ts
@@ -279,6 +279,20 @@ export const formatScheduledQueryForClient = (scheduledQuery: any): any => {
return scheduledQuery;
};
+export const formatTeamForClient = (team: any): any => {
+ if (team.display_text === undefined) {
+ team.display_text = team.name;
+ }
+ return team;
+};
+
+export const formatPackForClient = (pack: any): any => {
+ pack.host_ids ||= [];
+ pack.label_ids ||= [];
+ pack.team_ids ||= [];
+ return pack;
+};
+
const setupData = (formData: any) => {
const orgInfo = pick(formData, ORG_INFO_ATTRS);
const adminInfo = pick(formData, ADMIN_ATTRS);
@@ -355,6 +369,7 @@ export const secondsToHms = (d: number): string => {
const sDisplay = s > 0 ? s + (s === 1 ? " sec " : " secs ") : "";
return hDisplay + mDisplay + sDisplay;
};
+
export default {
addGravatarUrlToResource,
formatConfigDataForServer,
diff --git a/frontend/pages/packs/EditPackPage/EditPackPage.jsx b/frontend/pages/packs/EditPackPage/EditPackPage.jsx
index 702360d393..fe2400d1a7 100644
--- a/frontend/pages/packs/EditPackPage/EditPackPage.jsx
+++ b/frontend/pages/packs/EditPackPage/EditPackPage.jsx
@@ -4,12 +4,15 @@ import { connect } from "react-redux";
import { filter, includes, isEqual, noop, size, find } from "lodash";
import { push } from "react-router-redux";
+import permissionUtils from "utilities/permissions";
import deepDifference from "utilities/deep_difference";
import EditPackFormWrapper from "components/packs/EditPackFormWrapper";
import hostActions from "redux/nodes/entities/hosts/actions";
import hostInterface from "interfaces/host";
import labelActions from "redux/nodes/entities/labels/actions";
+import teamActions from "redux/nodes/entities/teams/actions";
import labelInterface from "interfaces/label";
+import teamInterface from "interfaces/team";
import packActions from "redux/nodes/entities/packs/actions";
import ScheduleQuerySidePanel from "components/side_panels/ScheduleQuerySidePanel";
import packInterface from "interfaces/pack";
@@ -34,7 +37,9 @@ export class EditPackPage extends Component {
packHosts: PropTypes.arrayOf(hostInterface),
packID: PropTypes.string,
packLabels: PropTypes.arrayOf(labelInterface),
+ packTeams: PropTypes.arrayOf(teamInterface),
scheduledQueries: PropTypes.arrayOf(queryInterface),
+ isBasicTier: PropTypes.bool,
};
static defaultProps = {
@@ -60,6 +65,7 @@ export class EditPackPage extends Component {
packHosts,
packID,
packLabels,
+ packTeams,
scheduledQueries,
} = this.props;
const { load } = packActions;
@@ -77,6 +83,10 @@ export class EditPackPage extends Component {
if (!packLabels || packLabels.length !== pack.label_ids.length) {
dispatch(labelActions.loadAll());
}
+
+ if (!packTeams || packTeams.length !== pack.team_ids.length) {
+ dispatch(teamActions.loadAll());
+ }
}
if (!size(scheduledQueries)) {
@@ -90,7 +100,13 @@ export class EditPackPage extends Component {
return false;
}
- componentWillReceiveProps({ dispatch, pack, packHosts, packLabels }) {
+ componentWillReceiveProps({
+ dispatch,
+ pack,
+ packHosts,
+ packLabels,
+ packTeams,
+ }) {
if (!isEqual(pack, this.props.pack)) {
if (!packHosts || packHosts.length !== pack.host_ids.length) {
dispatch(hostActions.loadAll());
@@ -99,6 +115,10 @@ export class EditPackPage extends Component {
if (!packLabels || packLabels.length !== pack.label_ids.length) {
dispatch(labelActions.loadAll());
}
+
+ if (!packTeams || packTeams.length !== pack.team_ids.length) {
+ dispatch(teamActions.loadAll());
+ }
}
return false;
@@ -241,10 +261,12 @@ export class EditPackPage extends Component {
pack,
packHosts,
packLabels,
+ packTeams,
scheduledQueries,
+ isBasicTier,
} = this.props;
- const packTargets = [...packHosts, ...packLabels];
+ const packTargets = [...packHosts, ...packLabels, ...packTeams];
if (!pack || isLoadingPack || isLoadingScheduledQueries) {
return false;
@@ -263,6 +285,7 @@ export class EditPackPage extends Component {
pack={pack}
packTargets={packTargets}
targetsCount={targetsCount}
+ isBasicTier={isBasicTier}
/>
{
return includes(pack.label_ids, label.id);
})
: [];
+ const packTeams = pack
+ ? filter(state.entities.teams.data, (team) => {
+ return includes(pack.team_ids, team.id);
+ })
+ : [];
+ const isBasicTier = permissionUtils.isBasicTier(state.app.config);
return {
allQueries,
@@ -317,7 +346,9 @@ const mapStateToProps = (state, { params, route }) => {
packHosts,
packID,
packLabels,
+ packTeams,
scheduledQueries,
+ isBasicTier,
};
};
diff --git a/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx b/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx
index a3fc8e629c..c76de56da6 100644
--- a/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx
+++ b/frontend/pages/packs/EditPackPage/EditPackPage.tests.jsx
@@ -3,7 +3,12 @@ import { mount } from "enzyme";
import { noop } from "lodash";
import { connectedComponent, reduxMockStore } from "test/helpers";
-import { packStub, queryStub, scheduledQueryStub } from "test/stubs";
+import {
+ packStub,
+ queryStub,
+ scheduledQueryStub,
+ configStub,
+} from "test/stubs";
import ConnectedEditPackPage, {
EditPackPage,
} from "pages/packs/EditPackPage/EditPackPage";
@@ -12,6 +17,7 @@ import labelActions from "redux/nodes/entities/labels/actions";
import packActions from "redux/nodes/entities/packs/actions";
import queryActions from "redux/nodes/entities/queries/actions";
import scheduledQueryActions from "redux/nodes/entities/scheduled_queries/actions";
+import teamActions from "redux/nodes/entities/teams/actions";
describe("EditPackPage - component", () => {
beforeEach(() => {
@@ -21,15 +27,18 @@ describe("EditPackPage - component", () => {
jest.spyOn(labelActions, "loadAll").mockImplementation(() => spyResponse);
jest.spyOn(packActions, "load").mockImplementation(() => spyResponse);
jest.spyOn(queryActions, "loadAll").mockImplementation(() => spyResponse);
+ jest.spyOn(teamActions, "loadAll").mockImplementation(() => spyResponse);
jest
.spyOn(scheduledQueryActions, "loadAll")
.mockImplementation(() => spyResponse);
});
const store = {
+ app: { config: configStub },
entities: {
hosts: { loading: false, data: {} },
labels: { loading: false, data: {} },
+ teams: { loading: false, data: {} },
packs: {
loading: false,
data: {
@@ -49,6 +58,7 @@ describe("EditPackPage - component", () => {
describe("rendering", () => {
it("does not render when packs are loading", () => {
const packsLoadingStore = {
+ app: store.app,
entities: {
...store.entities,
packs: { ...store.entities.packs, loading: true },
@@ -67,6 +77,7 @@ describe("EditPackPage - component", () => {
it("does not render when scheduled queries are loading", () => {
const scheduledQueriesLoadingStore = {
+ app: store.app,
entities: {
...store.entities,
scheduled_queries: {
@@ -88,6 +99,7 @@ describe("EditPackPage - component", () => {
it("does not render when there is no pack", () => {
const noPackStore = {
+ app: store.app,
entities: {
...store.entities,
packs: { data: {}, loading: false },
@@ -130,6 +142,7 @@ describe("EditPackPage - component", () => {
isEdit: false,
packHosts: [],
packLabels: [],
+ packTeams: [],
scheduledQueries: [],
};
@@ -155,6 +168,7 @@ describe("EditPackPage - component", () => {
packHosts: [],
packID: String(packStub.id),
packLabels: [],
+ packTeams: [],
scheduledQueries: [scheduledQuery],
};
@@ -206,6 +220,7 @@ describe("EditPackPage - component", () => {
packHosts: [],
packID: String(packStub.id),
packLabels: [],
+ packTeams: [],
scheduledQueries: [scheduledQuery],
};
const pushAction = {
diff --git a/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx b/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx
index 6d287aa11c..f6ca290414 100644
--- a/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx
+++ b/frontend/pages/packs/PackComposerPage/PackComposerPage.jsx
@@ -8,6 +8,7 @@ import packActions from "redux/nodes/entities/packs/actions";
import PackForm from "components/forms/packs/PackForm";
import PackInfoSidePanel from "components/side_panels/PackInfoSidePanel";
import PATHS from "router/paths";
+import permissionUtils from "utilities/permissions";
const baseClass = "pack-composer";
@@ -17,6 +18,7 @@ export class PackComposerPage extends Component {
serverErrors: PropTypes.shape({
base: PropTypes.string,
}),
+ isBasicTier: PropTypes.bool,
};
static defaultProps = {
@@ -60,7 +62,7 @@ export class PackComposerPage extends Component {
render() {
const { handleSubmit, onFetchTargets } = this;
const { selectedTargetsCount } = this.state;
- const { serverErrors } = this.props;
+ const { serverErrors, isBasicTier } = this.props;
return (
@@ -70,6 +72,7 @@ export class PackComposerPage extends Component {
onFetchTargets={onFetchTargets}
selectedTargetsCount={selectedTargetsCount}
serverErrors={serverErrors}
+ isBasicTier={isBasicTier}
/>
@@ -79,8 +82,9 @@ export class PackComposerPage extends Component {
const mapStateToProps = (state) => {
const { errors: serverErrors } = state.entities.packs;
+ const isBasicTier = permissionUtils.isBasicTier(state.app.config);
- return { serverErrors };
+ return { serverErrors, isBasicTier };
};
export default connect(mapStateToProps)(PackComposerPage);
diff --git a/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx b/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx
index 99be6fc9d5..e351538859 100644
--- a/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx
+++ b/frontend/pages/packs/PackComposerPage/PackComposerPage.tests.jsx
@@ -1,5 +1,6 @@
import { mount } from "enzyme";
+import { configStub } from "test/stubs";
import { connectedComponent, reduxMockStore } from "test/helpers";
import ConnectedPacksComposerPage from "./PackComposerPage";
@@ -8,6 +9,7 @@ describe("PackComposerPage - component", () => {
entities: {
packs: {},
},
+ app: { config: configStub },
});
it("renders", () => {
const page = mount(
diff --git a/frontend/redux/nodes/entities/packs/config.js b/frontend/redux/nodes/entities/packs/config.js
index 8543e12f84..34c439807f 100644
--- a/frontend/redux/nodes/entities/packs/config.js
+++ b/frontend/redux/nodes/entities/packs/config.js
@@ -1,6 +1,7 @@
import Fleet from "fleet";
import Config from "redux/nodes/entities/base/config";
import schemas from "redux/nodes/entities/base/schemas";
+import { formatPackForClient } from "fleet/helpers";
const { PACKS: schema } = schemas;
@@ -10,6 +11,7 @@ export default new Config({
entityName: "packs",
loadAllFunc: Fleet.packs.loadAll,
loadFunc: Fleet.packs.load,
+ parseEntityFunc: formatPackForClient,
schema,
updateFunc: Fleet.packs.update,
});
diff --git a/frontend/redux/nodes/entities/scheduled_queries/config.js b/frontend/redux/nodes/entities/scheduled_queries/config.js
index 68f25100db..67a6445e8b 100644
--- a/frontend/redux/nodes/entities/scheduled_queries/config.js
+++ b/frontend/redux/nodes/entities/scheduled_queries/config.js
@@ -1,4 +1,4 @@
-import helpers from "fleet/helpers";
+import { formatScheduledQueryForClient } from "fleet/helpers";
import Fleet from "fleet";
import Config from "redux/nodes/entities/base/config";
import schemas from "redux/nodes/entities/base/schemas";
@@ -10,7 +10,7 @@ export default new Config({
destroyFunc: Fleet.scheduledQueries.destroy,
entityName: "scheduled_queries",
loadAllFunc: Fleet.scheduledQueries.loadAll,
- parseEntityFunc: helpers.formatScheduledQueryForClient,
+ parseEntityFunc: formatScheduledQueryForClient,
schema,
updateFunc: Fleet.scheduledQueries.update,
});
diff --git a/frontend/redux/nodes/entities/teams/config.ts b/frontend/redux/nodes/entities/teams/config.ts
index ac44b02b3d..d7a461c989 100644
--- a/frontend/redux/nodes/entities/teams/config.ts
+++ b/frontend/redux/nodes/entities/teams/config.ts
@@ -5,6 +5,7 @@ import Fleet from "fleet";
import Config from "redux/nodes/entities/base/config";
// @ts-ignore
import schemas from "redux/nodes/entities/base/schemas";
+import { formatTeamForClient } from "fleet/helpers";
const { TEAMS } = schemas;
@@ -14,6 +15,7 @@ export default new Config({
entityName: "teams",
loadFunc: Fleet.teams.load,
loadAllFunc: Fleet.teams.loadAll,
+ parseEntityFunc: formatTeamForClient,
schema: TEAMS,
updateFunc: Fleet.teams.update,
});
diff --git a/frontend/test/stubs.ts b/frontend/test/stubs.ts
index 5f376d4151..b87dab35c7 100644
--- a/frontend/test/stubs.ts
+++ b/frontend/test/stubs.ts
@@ -123,6 +123,7 @@ export const packStub = {
disabled: false,
host_ids: [],
label_ids: [],
+ team_ids: [],
};
export const queryStub = {
diff --git a/server/datastore/datastore.go b/server/datastore/datastore.go
index 28d74f8532..3c8692c4f9 100644
--- a/server/datastore/datastore.go
+++ b/server/datastore/datastore.go
@@ -26,13 +26,12 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testSaveQuery,
testListQuery,
testDeletePack,
- testNewPack,
+ testSavePack,
testEnrollHost,
testAuthenticateHost,
testAuthenticateHostCaseSensitive,
testLabels,
testSaveLabel,
- testManagingLabelsOnPacks,
testPasswordResetRequests,
testCreateUser,
testSaveUser,
@@ -52,7 +51,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testListHostsFilterAdditional,
testListHostsStatus,
testListHostsQuery,
- testListHostsInPack,
testListPacksForHost,
testHostIDsByName,
testHostByIdentifier,
@@ -69,7 +67,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testCascadingDeletionOfQueries,
testGetPackByName,
testGetQueryByName,
- testAddLabelToPackTwice,
testGenerateHostStatusStatistics,
testMarkHostSeen,
testMarkHostsSeen,
@@ -90,7 +87,6 @@ var TestFunctions = [...]func(*testing.T, fleet.Datastore){
testApplyLabelSpecsRoundtrip,
testGetLabelSpec,
testLabelIDsByName,
- testListLabelsForPack,
testHostAdditional,
testCarveMetadata,
testCarveBlocks,
diff --git a/server/datastore/datastore_labels.go b/server/datastore/datastore_labels.go
index 66572d976e..ae14dfd911 100644
--- a/server/datastore/datastore_labels.go
+++ b/server/datastore/datastore_labels.go
@@ -156,62 +156,6 @@ func testLabels(t *testing.T, db fleet.Datastore) {
assert.Len(t, labels, 1)
}
-func testManagingLabelsOnPacks(t *testing.T, ds fleet.Datastore) {
- pack := &fleet.PackSpec{
- ID: 1,
- Name: "pack1",
- }
- err := ds.ApplyPackSpecs([]*fleet.PackSpec{pack})
- require.Nil(t, err)
-
- labels, err := ds.ListLabelsForPack(pack.ID)
- require.Nil(t, err)
- assert.Len(t, labels, 0)
-
- mysqlLabel := &fleet.LabelSpec{
- ID: 1,
- Name: "MySQL Monitoring",
- Query: "select pid from processes where name = 'mysqld';",
- }
- err = ds.ApplyLabelSpecs([]*fleet.LabelSpec{mysqlLabel})
- require.Nil(t, err)
-
- pack.Targets = fleet.PackSpecTargets{
- Labels: []string{
- mysqlLabel.Name,
- },
- }
- err = ds.ApplyPackSpecs([]*fleet.PackSpec{pack})
- require.Nil(t, err)
-
- labels, err = ds.ListLabelsForPack(pack.ID)
- require.Nil(t, err)
- if assert.Len(t, labels, 1) {
- assert.Equal(t, "MySQL Monitoring", labels[0].Name)
- }
-
- osqueryLabel := &fleet.LabelSpec{
- ID: 2,
- Name: "Osquery Monitoring",
- Query: "select pid from processes where name = 'osqueryd';",
- }
- err = ds.ApplyLabelSpecs([]*fleet.LabelSpec{mysqlLabel, osqueryLabel})
- require.Nil(t, err)
-
- pack.Targets = fleet.PackSpecTargets{
- Labels: []string{
- mysqlLabel.Name,
- osqueryLabel.Name,
- },
- }
- err = ds.ApplyPackSpecs([]*fleet.PackSpec{pack})
- require.Nil(t, err)
-
- labels, err = ds.ListLabelsForPack(pack.ID)
- require.Nil(t, err)
- assert.Len(t, labels, 2)
-}
-
func testSearchLabels(t *testing.T, db fleet.Datastore) {
specs := []*fleet.LabelSpec{
&fleet.LabelSpec{
diff --git a/server/datastore/datastore_packs.go b/server/datastore/datastore_packs.go
index 27db57af35..3a37c4265a 100644
--- a/server/datastore/datastore_packs.go
+++ b/server/datastore/datastore_packs.go
@@ -26,18 +26,38 @@ func testDeletePack(t *testing.T, ds fleet.Datastore) {
assert.NotNil(t, err)
}
-func testNewPack(t *testing.T, ds fleet.Datastore) {
- pack := &fleet.Pack{
- Name: "foo",
+func testSavePack(t *testing.T, ds fleet.Datastore) {
+ expectedPack := &fleet.Pack{
+ Name: "foo",
+ HostIDs: []uint{1},
+ LabelIDs: []uint{1},
+ TeamIDs: []uint{1},
}
- pack, err := ds.NewPack(pack)
+ pack, err := ds.NewPack(expectedPack)
require.NoError(t, err)
assert.NotEqual(t, uint(0), pack.ID)
+ test.EqualSkipTimestampsID(t, expectedPack, pack)
pack, err = ds.Pack(pack.ID)
require.NoError(t, err)
- assert.Equal(t, "foo", pack.Name)
+ test.EqualSkipTimestampsID(t, expectedPack, pack)
+
+ expectedPack = &fleet.Pack{
+ ID: pack.ID,
+ Name: "bar",
+ HostIDs: []uint{3},
+ LabelIDs: []uint{4, 6},
+ TeamIDs: []uint{},
+ }
+
+ err = ds.SavePack(expectedPack)
+ require.NoError(t, err)
+
+ pack, err = ds.Pack(pack.ID)
+ require.NoError(t, err)
+ assert.Equal(t, "bar", pack.Name)
+ test.EqualSkipTimestampsID(t, expectedPack, pack)
}
func testGetPackByName(t *testing.T, ds fleet.Datastore) {
@@ -81,92 +101,8 @@ func testListPacks(t *testing.T, ds fleet.Datastore) {
assert.Len(t, packs, 2)
}
-func testListHostsInPack(t *testing.T, ds fleet.Datastore) {
- if ds.Name() == "inmem" {
- t.Skip("inmem is deprecated")
- }
-
- mockClock := clock.NewMockClock()
-
- l1 := fleet.LabelSpec{
- ID: 1,
- Name: "foo",
- }
- err := ds.ApplyLabelSpecs([]*fleet.LabelSpec{&l1})
- require.Nil(t, err)
-
- p1 := &fleet.PackSpec{
- ID: 1,
- Name: "foo_pack",
- Targets: fleet.PackSpecTargets{
- Labels: []string{
- l1.Name,
- },
- },
- }
- err = ds.ApplyPackSpecs([]*fleet.PackSpec{p1})
- require.Nil(t, err)
-
- h1 := test.NewHost(t, ds, "h1.local", "10.10.10.1", "1", "1", mockClock.Now())
-
- hostsInPack, err := ds.ListHostsInPack(p1.ID, fleet.ListOptions{})
- require.Nil(t, err)
- require.Len(t, hostsInPack, 0)
-
- err = ds.RecordLabelQueryExecutions(
- h1,
- map[uint]bool{l1.ID: true},
- mockClock.Now(),
- )
- require.Nil(t, err)
-
- hostsInPack, err = ds.ListHostsInPack(p1.ID, fleet.ListOptions{})
- require.Nil(t, err)
- require.Len(t, hostsInPack, 1)
-
- explicitHostsInPack, err := ds.ListExplicitHostsInPack(p1.ID, fleet.ListOptions{})
- require.Nil(t, err)
- require.Len(t, explicitHostsInPack, 0)
-
- h2 := test.NewHost(t, ds, "h2.local", "10.10.10.2", "2", "2", mockClock.Now())
-
- err = ds.RecordLabelQueryExecutions(
- h2,
- map[uint]bool{l1.ID: true},
- mockClock.Now(),
- )
- require.Nil(t, err)
-
- hostsInPack, err = ds.ListHostsInPack(p1.ID, fleet.ListOptions{})
- require.Nil(t, err)
- require.Len(t, hostsInPack, 2)
-}
-
-func testAddLabelToPackTwice(t *testing.T, ds fleet.Datastore) {
- l1 := fleet.LabelSpec{
- ID: 1,
- Name: "l1",
- Query: "select 1",
- }
- err := ds.ApplyLabelSpecs([]*fleet.LabelSpec{&l1})
- require.Nil(t, err)
-
- p1 := &fleet.PackSpec{
- ID: 1,
- Name: "pack1",
- Targets: fleet.PackSpecTargets{
- Labels: []string{
- l1.Name,
- l1.Name,
- },
- },
- }
- err = ds.ApplyPackSpecs([]*fleet.PackSpec{p1})
- require.NotNil(t, err)
-}
-
func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
- zwass := test.NewUser(t, ds, "Zach", "zwass", "zwass@fleet.co", true)
+ zwass := test.NewUser(t, ds, "Zach", "zwass", "zwass@example.com", true)
queries := []*fleet.Query{
{Name: "foo", Description: "get the foos", Query: "select * from foo"},
{Name: "bar", Description: "do some bars", Query: "select baz from bar"},
@@ -176,15 +112,15 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
require.Nil(t, err)
labels := []*fleet.LabelSpec{
- &fleet.LabelSpec{
+ {
Name: "foo",
Query: "select * from foo",
},
- &fleet.LabelSpec{
+ {
Name: "bar",
Query: "select * from bar",
},
- &fleet.LabelSpec{
+ {
Name: "bing",
Query: "select * from bing",
},
@@ -193,7 +129,7 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
require.Nil(t, err)
expectedSpecs := []*fleet.PackSpec{
- &fleet.PackSpec{
+ {
ID: 1,
Name: "test_pack",
Targets: fleet.PackSpecTargets{
@@ -204,20 +140,20 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
},
},
Queries: []fleet.PackSpecQuery{
- fleet.PackSpecQuery{
+ {
QueryName: queries[0].Name,
Name: "q0",
Description: "test_foo",
Interval: 42,
},
- fleet.PackSpecQuery{
+ {
QueryName: queries[0].Name,
Name: "foo_snapshot",
Interval: 600,
Snapshot: ptr.Bool(true),
Denylist: ptr.Bool(false),
},
- fleet.PackSpecQuery{
+ {
Name: "q2",
QueryName: queries[1].Name,
Interval: 600,
@@ -229,7 +165,7 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
},
},
},
- &fleet.PackSpec{
+ {
ID: 2,
Name: "test_pack_disabled",
Disabled: true,
@@ -241,19 +177,19 @@ func setupPackSpecsTest(t *testing.T, ds fleet.Datastore) []*fleet.PackSpec {
},
},
Queries: []fleet.PackSpecQuery{
- fleet.PackSpecQuery{
+ {
QueryName: queries[0].Name,
Name: "q0",
Description: "test_foo",
Interval: 42,
},
- fleet.PackSpecQuery{
+ {
QueryName: queries[0].Name,
Name: "foo_snapshot",
Interval: 600,
Snapshot: ptr.Bool(true),
},
- fleet.PackSpecQuery{
+ {
Name: "q2",
QueryName: queries[1].Name,
Interval: 600,
@@ -292,14 +228,14 @@ func testGetPackSpec(t *testing.T, ds fleet.Datastore) {
func testApplyPackSpecMissingQueries(t *testing.T, ds fleet.Datastore) {
// Do not define queries mentioned in spec
specs := []*fleet.PackSpec{
- &fleet.PackSpec{
+ {
ID: 1,
Name: "test_pack",
Targets: fleet.PackSpecTargets{
Labels: []string{},
},
Queries: []fleet.PackSpecQuery{
- fleet.PackSpecQuery{
+ {
QueryName: "bar",
Interval: 600,
},
@@ -318,13 +254,13 @@ func testApplyPackSpecMissingName(t *testing.T, ds fleet.Datastore) {
setupPackSpecsTest(t, ds)
specs := []*fleet.PackSpec{
- &fleet.PackSpec{
+ {
Name: "test2",
Targets: fleet.PackSpecTargets{
Labels: []string{},
},
Queries: []fleet.PackSpecQuery{
- fleet.PackSpecQuery{
+ {
QueryName: "foo",
Interval: 600,
},
@@ -340,67 +276,6 @@ func testApplyPackSpecMissingName(t *testing.T, ds fleet.Datastore) {
assert.Equal(t, "foo", spec.Queries[0].Name)
}
-func testListLabelsForPack(t *testing.T, ds fleet.Datastore) {
- labelSpecs := []*fleet.LabelSpec{
- &fleet.LabelSpec{
- Name: "foo",
- Query: "select * from foo",
- },
- &fleet.LabelSpec{
- Name: "bar",
- Query: "select * from bar",
- },
- &fleet.LabelSpec{
- Name: "bing",
- Query: "select * from bing",
- },
- }
- err := ds.ApplyLabelSpecs(labelSpecs)
- require.Nil(t, err)
-
- specs := []*fleet.PackSpec{
- &fleet.PackSpec{
- ID: 1,
- Name: "test_pack",
- Targets: fleet.PackSpecTargets{
- Labels: []string{
- "foo",
- "bar",
- "bing",
- },
- },
- },
- &fleet.PackSpec{
- ID: 2,
- Name: "test 2",
- Targets: fleet.PackSpecTargets{
- Labels: []string{
- "bing",
- },
- },
- },
- &fleet.PackSpec{
- ID: 3,
- Name: "test 3",
- },
- }
- err = ds.ApplyPackSpecs(specs)
- require.Nil(t, err)
-
- labels, err := ds.ListLabelsForPack(specs[0].ID)
- require.Nil(t, err)
- assert.Len(t, labels, 3)
-
- labels, err = ds.ListLabelsForPack(specs[1].ID)
- require.Nil(t, err)
- assert.Len(t, labels, 1)
- assert.Equal(t, "bing", labels[0].Name)
-
- labels, err = ds.ListLabelsForPack(specs[2].ID)
- require.Nil(t, err)
- assert.Len(t, labels, 0)
-}
-
func testListPacksForHost(t *testing.T, ds fleet.Datastore) {
if ds.Name() == "inmem" {
t.Skip("inmem is deprecated")
@@ -507,40 +382,4 @@ func testListPacksForHost(t *testing.T, ds fleet.Datastore) {
if assert.Len(t, packs, 1) {
assert.Equal(t, "foo_pack", packs[0].Name)
}
-
- // Add host directly to pack
- err = ds.AddHostToPack(h1.ID, p2.ID)
- require.Nil(t, err)
-
- packs, err = ds.ListPacksForHost(h1.ID)
- require.Nil(t, err)
- assert.Len(t, packs, 2)
-
- // Remove label membership for both
- err = ds.RecordLabelQueryExecutions(
- h1,
- map[uint]bool{l2.ID: false, l1.ID: false},
- mockClock.Now(),
- )
- require.Nil(t, err)
-
- err = ds.RecordLabelQueryExecutions(
- h1,
- map[uint]bool{l2.ID: false},
- mockClock.Now(),
- )
- require.Nil(t, err)
- packs, err = ds.ListPacksForHost(h1.ID)
- require.Nil(t, err)
- if assert.Len(t, packs, 1) {
- assert.Equal(t, p2.Name, packs[0].Name)
- }
-
- // Now host is added directly to both packs
- err = ds.AddHostToPack(h1.ID, p1.ID)
- require.Nil(t, err)
-
- packs, err = ds.ListPacksForHost(h1.ID)
- require.Nil(t, err)
- assert.Len(t, packs, 2)
}
diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go
index 52244df0cd..a93f606832 100644
--- a/server/datastore/mysql/labels.go
+++ b/server/datastore/mysql/labels.go
@@ -178,7 +178,6 @@ func (d *Datastore) getLabelHostnames(label *fleet.LabelSpec) error {
// NewLabel creates a new fleet.Label
func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fleet.Label, error) {
- db := d.getTransaction(opts)
query := `
INSERT INTO labels (
name,
@@ -189,7 +188,7 @@ func (d *Datastore) NewLabel(label *fleet.Label, opts ...fleet.OptionalArg) (*fl
label_membership_type
) VALUES ( ?, ?, ?, ?, ?, ?)
`
- result, err := db.Exec(
+ result, err := d.db.Exec(
query,
label.Name,
label.Description,
diff --git a/server/datastore/mysql/migrations/tables/20210617174723_PackTargetForeignKey.go b/server/datastore/mysql/migrations/tables/20210617174723_PackTargetForeignKey.go
new file mode 100644
index 0000000000..1af8e7956e
--- /dev/null
+++ b/server/datastore/mysql/migrations/tables/20210617174723_PackTargetForeignKey.go
@@ -0,0 +1,35 @@
+package tables
+
+import (
+ "database/sql"
+
+ "github.com/pkg/errors"
+)
+
+func init() {
+ MigrationClient.AddMigration(Up_20210617174723, Down_20210617174723)
+}
+
+func Up_20210617174723(tx *sql.Tx) error {
+ sql := `
+ DELETE FROM pack_targets
+ WHERE pack_id NOT IN (SELECT id FROM packs)
+ `
+ if _, err := tx.Exec(sql); err != nil {
+ return errors.Wrap(err, "delete orphaned pack targets")
+ }
+
+ sql = `
+ ALTER TABLE pack_targets
+ ADD FOREIGN KEY (pack_id) REFERENCES packs (id) ON UPDATE CASCADE ON DELETE CASCADE
+ `
+ if _, err := tx.Exec(sql); err != nil {
+ return errors.Wrap(err, "add foreign key on pack_targets pack_id")
+ }
+
+ return nil
+}
+
+func Down_20210617174723(tx *sql.Tx) error {
+ return nil
+}
diff --git a/server/datastore/mysql/mysql.go b/server/datastore/mysql/mysql.go
index 07285ee5d2..e615955a42 100644
--- a/server/datastore/mysql/mysql.go
+++ b/server/datastore/mysql/mysql.go
@@ -46,23 +46,6 @@ type Datastore struct {
config config.MysqlConfig
}
-type dbfunctions interface {
- Exec(query string, args ...interface{}) (sql.Result, error)
- Get(dest interface{}, query string, args ...interface{}) error
- Select(dest interface{}, query string, args ...interface{}) error
-}
-
-func (d *Datastore) getTransaction(opts []fleet.OptionalArg) dbfunctions {
- var result dbfunctions = d.db
- for _, opt := range opts {
- switch t := opt().(type) {
- case dbfunctions:
- result = t
- }
- }
- return result
-}
-
type txFn func(*sqlx.Tx) error
// retryableError determines whether a MySQL error can be retried. By default
diff --git a/server/datastore/mysql/packs.go b/server/datastore/mysql/packs.go
index 56557b34b3..463e7024b5 100644
--- a/server/datastore/mysql/packs.go
+++ b/server/datastore/mysql/packs.go
@@ -3,6 +3,7 @@ package mysql
import (
"database/sql"
"fmt"
+ "strings"
"github.com/fleetdm/fleet/server/fleet"
"github.com/jmoiron/sqlx"
@@ -198,19 +199,22 @@ WHERE pack_id = ?
}
func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.Pack, bool, error) {
- db := d.getTransaction(opts)
sqlStatement := `
SELECT *
FROM packs
WHERE name = ?
`
var pack fleet.Pack
- err := db.Get(&pack, sqlStatement, name)
+ err := d.db.Get(&pack, sqlStatement, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, false, nil
}
- return nil, false, errors.Wrap(err, "fetching packs by name")
+ return nil, false, errors.Wrap(err, "fetch pack by name")
+ }
+
+ if err := d.loadPackTargets(&pack); err != nil {
+ return nil, false, err
}
return &pack, true, nil
@@ -218,44 +222,144 @@ func (d *Datastore) PackByName(name string, opts ...fleet.OptionalArg) (*fleet.P
// NewPack creates a new Pack
func (d *Datastore) NewPack(pack *fleet.Pack, opts ...fleet.OptionalArg) (*fleet.Pack, error) {
- db := d.getTransaction(opts)
+ if err := d.withRetryTxx(func(tx *sqlx.Tx) error {
+ query := `
+ INSERT INTO packs
+ (name, description, platform, disabled)
+ VALUES ( ?, ?, ?, ? )
+ `
+ result, err := d.db.Exec(query, pack.Name, pack.Description, pack.Platform, pack.Disabled)
+ if err != nil {
+ return errors.Wrap(err, "insert pack")
+ }
- query := `
- INSERT INTO packs
- (name, description, platform, disabled)
- VALUES ( ?, ?, ?, ? )
- `
+ id, _ := result.LastInsertId()
+ pack.ID = uint(id)
- result, err := db.Exec(query, pack.Name, pack.Description, pack.Platform, pack.Disabled)
- if err != nil {
- return nil, errors.Wrap(err, "inserting pack")
+ if err := d.replacePackTargets(tx, pack); err != nil {
+ return err
+ }
+
+ return nil
+ }); err != nil {
+ return nil, err
+ }
+ return pack, nil
+}
+
+func (d *Datastore) replacePackTargets(tx *sqlx.Tx, pack *fleet.Pack) error {
+ sql := `DELETE FROM pack_targets WHERE pack_id = ?`
+ if _, err := tx.Exec(sql, pack.ID); err != nil {
+ return errors.Wrap(err, "delete pack targets")
}
- id, _ := result.LastInsertId()
- pack.ID = uint(id)
- return pack, nil
+ // Insert hosts
+ if len(pack.HostIDs) > 0 {
+ var args []interface{}
+ for _, id := range pack.HostIDs {
+ args = append(args, pack.ID, fleet.TargetHost, id)
+ }
+ values := strings.TrimSuffix(
+ strings.Repeat("(?,?,?),", len(pack.HostIDs)),
+ ",",
+ )
+ sql = fmt.Sprintf(`
+ INSERT INTO pack_targets (pack_id, type, target_id)
+ VALUES %s
+ `, values)
+ if _, err := tx.Exec(sql, args...); err != nil {
+ return errors.Wrap(err, "insert host targets")
+ }
+ }
+
+ // Insert labels
+ if len(pack.LabelIDs) > 0 {
+ var args []interface{}
+ for _, id := range pack.LabelIDs {
+ args = append(args, pack.ID, fleet.TargetLabel, id)
+ }
+ values := strings.TrimSuffix(
+ strings.Repeat("(?,?,?),", len(pack.LabelIDs)),
+ ",",
+ )
+ sql = fmt.Sprintf(`
+ INSERT INTO pack_targets (pack_id, type, target_id)
+ VALUES %s
+ `, values)
+ if _, err := tx.Exec(sql, args...); err != nil {
+ return errors.Wrap(err, "insert label targets")
+ }
+ }
+
+ // Insert teams
+ if len(pack.TeamIDs) > 0 {
+ var args []interface{}
+ for _, id := range pack.TeamIDs {
+ args = append(args, pack.ID, fleet.TargetTeam, id)
+ }
+ values := strings.TrimSuffix(
+ strings.Repeat("(?,?,?),", len(pack.TeamIDs)),
+ ",",
+ )
+ sql = fmt.Sprintf(`
+ INSERT INTO pack_targets (pack_id, type, target_id)
+ VALUES %s
+ `, values)
+ if _, err := tx.Exec(sql, args...); err != nil {
+ return errors.Wrap(err, "insert team targets")
+ }
+ }
+
+ return nil
+}
+
+func (d *Datastore) loadPackTargets(pack *fleet.Pack) error {
+ var targets []fleet.PackTarget
+ sql := `SELECT * FROM pack_targets WHERE pack_id = ?`
+ if err := d.db.Select(&targets, sql, pack.ID); err != nil {
+ return errors.Wrap(err, "select pack targets")
+ }
+
+ pack.HostIDs, pack.LabelIDs, pack.TeamIDs = []uint{}, []uint{}, []uint{}
+ for _, target := range targets {
+ switch target.Type {
+ case fleet.TargetHost:
+ pack.HostIDs = append(pack.HostIDs, target.TargetID)
+ case fleet.TargetLabel:
+ pack.LabelIDs = append(pack.LabelIDs, target.TargetID)
+ case fleet.TargetTeam:
+ pack.TeamIDs = append(pack.TeamIDs, target.TargetID)
+ default:
+ return errors.Errorf("unknown target type: %d", target.Type)
+ }
+ }
+
+ return nil
}
// SavePack stores changes to pack
func (d *Datastore) SavePack(pack *fleet.Pack) error {
- query := `
+ return d.withRetryTxx(func(tx *sqlx.Tx) error {
+ query := `
UPDATE packs
SET name = ?, platform = ?, disabled = ?, description = ?
WHERE id = ?
`
- results, err := d.db.Exec(query, pack.Name, pack.Platform, pack.Disabled, pack.Description, pack.ID)
- if err != nil {
- return errors.Wrap(err, "updating pack")
- }
- rowsAffected, err := results.RowsAffected()
- if err != nil {
- return errors.Wrap(err, "rows affected updating packs")
- }
- if rowsAffected == 0 {
- return notFound("Pack").WithID(pack.ID)
- }
- return nil
+ results, err := d.db.Exec(query, pack.Name, pack.Platform, pack.Disabled, pack.Description, pack.ID)
+ if err != nil {
+ return errors.Wrap(err, "updating pack")
+ }
+ rowsAffected, err := results.RowsAffected()
+ if err != nil {
+ return errors.Wrap(err, "rows affected updating packs")
+ }
+ if rowsAffected == 0 {
+ return notFound("Pack").WithID(pack.ID)
+ }
+
+ return d.replacePackTargets(tx, pack)
+ })
}
// DeletePack deletes a fleet.Pack so that it won't show up in results.
@@ -271,7 +375,11 @@ func (d *Datastore) Pack(pid uint) (*fleet.Pack, error) {
if err == sql.ErrNoRows {
return nil, notFound("Pack").WithID(pid)
} else if err != nil {
- return nil, errors.Wrap(err, "getting pack")
+ return nil, errors.Wrap(err, "get pack")
+ }
+
+ if err := d.loadPackTargets(pack); err != nil {
+ return nil, err
}
return pack, nil
@@ -285,107 +393,22 @@ func (d *Datastore) ListPacks(opt fleet.ListOptions) ([]*fleet.Pack, error) {
if err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing packs")
}
+
+ for _, pack := range packs {
+ if err := d.loadPackTargets(pack); err != nil {
+ return nil, err
+ }
+ }
+
return packs, nil
}
-// AddLabelToPack associates a fleet.Label with a fleet.Pack
-func (d *Datastore) AddLabelToPack(lid uint, pid uint, opts ...fleet.OptionalArg) error {
- db := d.getTransaction(opts)
-
- query := `
- INSERT INTO pack_targets ( pack_id, type, target_id )
- VALUES ( ?, ?, ? )
- ON DUPLICATE KEY UPDATE id=id
- `
- _, err := db.Exec(query, pid, fleet.TargetLabel, lid)
- if err != nil {
- return errors.Wrap(err, "adding label to pack")
- }
-
- return nil
-}
-
-// AddHostToPack associates a fleet.Host with a fleet.Pack
-func (d *Datastore) AddHostToPack(hid, pid uint) error {
- query := `
- INSERT INTO pack_targets ( pack_id, type, target_id )
- VALUES ( ?, ?, ? )
- ON DUPLICATE KEY UPDATE id=id
- `
- _, err := d.db.Exec(query, pid, fleet.TargetHost, hid)
- if err != nil {
- return errors.Wrap(err, "adding host to pack")
- }
-
- return nil
-}
-
-// RemoreLabelFromPack will remove the association between a fleet.Label and
-// a fleet.Pack
-func (d *Datastore) RemoveLabelFromPack(lid, pid uint) error {
- query := `
- DELETE FROM pack_targets
- WHERE target_id = ? AND pack_id = ? AND type = ?
- `
- _, err := d.db.Exec(query, lid, pid, fleet.TargetLabel)
- if err == sql.ErrNoRows {
- return notFound("PackTarget").WithMessage(fmt.Sprintf("label ID: %d, pack ID: %d", lid, pid))
- } else if err != nil {
- return errors.Wrap(err, "removing label from pack")
- }
- return nil
-}
-
-// RemoveHostFromPack will remove the association between a fleet.Host and a
-// fleet.Pack
-func (d *Datastore) RemoveHostFromPack(hid, pid uint) error {
- query := `
- DELETE FROM pack_targets
- WHERE target_id = ? AND pack_id = ? AND type = ?
- `
- _, err := d.db.Exec(query, hid, pid, fleet.TargetHost)
- if err == sql.ErrNoRows {
- return notFound("PackTarget").WithMessage(fmt.Sprintf("host ID: %d, pack ID: %d", hid, pid))
- } else if err != nil {
- return errors.Wrap(err, "removing host from pack")
- }
- return nil
-}
-
-// ListLabelsForPack will return a list of fleet.Label records associated with fleet.Pack
-func (d *Datastore) ListLabelsForPack(pid uint) ([]*fleet.Label, error) {
- query := `
- SELECT
- l.id,
- l.created_at,
- l.updated_at,
- l.name
- FROM
- labels l
- JOIN
- pack_targets pt
- ON
- pt.target_id = l.id
- WHERE
- pt.type = ?
- AND
- pt.pack_id = ?
- `
-
- labels := []*fleet.Label{}
-
- if err := d.db.Select(&labels, query, fleet.TargetLabel, pid); err != nil && err != sql.ErrNoRows {
- return nil, errors.Wrap(err, "listing labels for pack")
- }
-
- return labels, nil
-}
-
func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
query := `
SELECT DISTINCT packs.*
FROM
- ((SELECT p.* FROM packs p
+ ((SELECT p.*
+ FROM packs p
JOIN pack_targets pt
JOIN label_membership lm
ON (
@@ -399,55 +422,17 @@ func (d *Datastore) ListPacksForHost(hid uint) ([]*fleet.Pack, error) {
FROM packs p
JOIN pack_targets pt
ON (p.id = pt.pack_id AND pt.type = ? AND pt.target_id = ?))
+ UNION ALL
+ (SELECT p.*
+ FROM packs p
+ JOIN pack_targets pt
+ ON (p.id = pt.pack_id AND pt.type = ? AND pt.target_id = (SELECT team_id FROM hosts WHERE id = ?)))
) packs
`
packs := []*fleet.Pack{}
- if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid); err != nil && err != sql.ErrNoRows {
+ if err := d.db.Select(&packs, query, fleet.TargetLabel, hid, fleet.TargetHost, hid, fleet.TargetTeam, hid); err != nil && err != sql.ErrNoRows {
return nil, errors.Wrap(err, "listing hosts in pack")
}
return packs, nil
}
-
-func (d *Datastore) ListHostsInPack(pid uint, opt fleet.ListOptions) ([]uint, error) {
- query := `
- SELECT DISTINCT h.id
- FROM hosts h
- JOIN pack_targets pt
- JOIN label_membership lm
- ON (
- pt.target_id = lm.label_id
- AND lm.host_id = h.id
- AND pt.type = ?
- ) OR (
- pt.target_id = h.id
- AND pt.type = ?
- )
- WHERE pt.pack_id = ?
- `
-
- hosts := []uint{}
- if err := d.db.Select(&hosts, appendListOptionsToSQL(query, opt), fleet.TargetLabel, fleet.TargetHost, pid); err != nil && err != sql.ErrNoRows {
- return nil, errors.Wrap(err, "listing hosts in pack")
- }
- return hosts, nil
-}
-
-func (d *Datastore) ListExplicitHostsInPack(pid uint, opt fleet.ListOptions) ([]uint, error) {
- query := `
- SELECT DISTINCT h.id
- FROM hosts h
- JOIN pack_targets pt
- ON (
- pt.target_id = h.id
- AND pt.type = ?
- )
- WHERE pt.pack_id = ?
- `
- hosts := []uint{}
- if err := d.db.Select(&hosts, appendListOptionsToSQL(query, opt), fleet.TargetHost, pid); err != nil && err != sql.ErrNoRows {
- return nil, errors.Wrap(err, "listing explicit hosts in pack")
- }
- return hosts, nil
-
-}
diff --git a/server/datastore/mysql/queries.go b/server/datastore/mysql/queries.go
index 5b0a99d166..6f331d9b2f 100644
--- a/server/datastore/mysql/queries.go
+++ b/server/datastore/mysql/queries.go
@@ -65,14 +65,13 @@ func (d *Datastore) ApplyQueries(authorID uint, queries []*fleet.Query) (err err
}
func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.Query, error) {
- db := d.getTransaction(opts)
sqlStatement := `
SELECT *
FROM queries
WHERE name = ?
`
var query fleet.Query
- err := db.Get(&query, sqlStatement, name)
+ err := d.db.Get(&query, sqlStatement, name)
if err != nil {
if err == sql.ErrNoRows {
return nil, notFound("Query").WithName(name)
@@ -89,8 +88,6 @@ func (d *Datastore) QueryByName(name string, opts ...fleet.OptionalArg) (*fleet.
// NewQuery creates a New Query.
func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
- db := d.getTransaction(opts)
-
sqlStatement := `
INSERT INTO queries (
name,
@@ -101,7 +98,7 @@ func (d *Datastore) NewQuery(query *fleet.Query, opts ...fleet.OptionalArg) (*fl
observer_can_run
) VALUES ( ?, ?, ?, ?, ?, ? )
`
- result, err := db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
+ result, err := d.db.Exec(sqlStatement, query.Name, query.Description, query.Query, query.Saved, query.AuthorID, query.ObserverCanRun)
if err != nil && isDuplicate(err) {
return nil, alreadyExists("Query", 0)
diff --git a/server/datastore/mysql/scheduled_queries.go b/server/datastore/mysql/scheduled_queries.go
index ec41e1c7f6..e2f1649c2f 100644
--- a/server/datastore/mysql/scheduled_queries.go
+++ b/server/datastore/mysql/scheduled_queries.go
@@ -40,8 +40,6 @@ func (d *Datastore) ListScheduledQueriesInPack(id uint, opts fleet.ListOptions)
}
func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.OptionalArg) (*fleet.ScheduledQuery, error) {
- db := d.getTransaction(opts)
-
// This query looks up the query name using the ID (for backwards
// compatibility with the UI)
query := `
@@ -61,7 +59,7 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
FROM queries
WHERE id = ?
`
- result, err := db.Exec(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID)
+ result, err := d.db.Exec(query, sq.Name, sq.PackID, sq.Snapshot, sq.Removed, sq.Interval, sq.Platform, sq.Version, sq.Shard, sq.Denylist, sq.QueryID)
if err != nil {
return nil, errors.Wrap(err, "insert scheduled query")
}
@@ -75,7 +73,7 @@ func (d *Datastore) NewScheduledQuery(sq *fleet.ScheduledQuery, opts ...fleet.Op
Name string
}{}
- err = db.Select(&metadata, query, sq.QueryID)
+ err = d.db.Select(&metadata, query, sq.QueryID)
if err != nil && err == sql.ErrNoRows {
return nil, notFound("Query").WithID(sq.QueryID)
} else if err != nil {
diff --git a/server/fleet/packs.go b/server/fleet/packs.go
index 0518626bec..5c4e510393 100644
--- a/server/fleet/packs.go
+++ b/server/fleet/packs.go
@@ -33,33 +33,8 @@ type PackStore interface {
// exists the bool return value is true
PackByName(name string, opts ...OptionalArg) (*Pack, bool, error)
- // AddLabelToPack adds an existing label to an existing pack, both by ID.
- AddLabelToPack(lid, pid uint, opts ...OptionalArg) error
-
- // RemoveLabelFromPack removes an existing label from it's association with
- // an existing pack, both by ID.
- RemoveLabelFromPack(lid, pid uint) error
-
- // ListLabelsForPack lists all labels that are associated with a pack.
- ListLabelsForPack(pid uint) ([]*Label, error)
-
- // AddHostToPack adds an existing host to an existing pack, both by ID.
- AddHostToPack(hid uint, pid uint) error
-
- // RemoveHostFromPack removes an existing host from it's association with
- // an existing pack, both by ID.
- RemoveHostFromPack(hid uint, pid uint) error
-
// ListPacksForHost lists the packs that a host should execute.
ListPacksForHost(hid uint) (packs []*Pack, err error)
-
- // ListHostsInPack lists the IDs of all hosts that are associated with a pack
- // through labels.
- ListHostsInPack(pid uint, opt ListOptions) ([]uint, error)
-
- // ListExplicitHostsInPack lists the IDs of hosts that have been manually
- // associated with a query pack.
- ListExplicitHostsInPack(pid uint, opt ListOptions) ([]uint, error)
}
// PackService is the service interface for managing query packs.
@@ -90,33 +65,8 @@ type PackService interface {
// DeletePackByID is for backwards compatibility with the UI
DeletePackByID(ctx context.Context, id uint) (err error)
- // AddLabelToPack adds an existing label to an existing pack, both by ID.
- AddLabelToPack(ctx context.Context, lid, pid uint) (err error)
-
- // RemoveLabelFromPack removes an existing label from it's association with
- // an existing pack, both by ID.
- RemoveLabelFromPack(ctx context.Context, lid, pid uint) (err error)
-
- // ListLabelsForPack lists all labels that are associated with a pack.
- ListLabelsForPack(ctx context.Context, pid uint) (labels []*Label, err error)
-
- // AddHostToPack adds an existing host to an existing pack, both by ID.
- AddHostToPack(ctx context.Context, hid, pid uint) (err error)
-
- // RemoveHostFromPack removes an existing host from it's association with
- // an existing pack, both by ID.
- RemoveHostFromPack(ctx context.Context, hid, pid uint) (err error)
-
// ListPacksForHost lists the packs that a host should execute.
ListPacksForHost(ctx context.Context, hid uint) (packs []*Pack, err error)
-
- // ListHostsInPack lists the IDs of all hosts that are associated with a pack,
- // both through labels and manual associations.
- ListHostsInPack(ctx context.Context, pid uint, opt ListOptions) (hosts []uint, err error)
-
- // ListExplicitHostsInPack lists the IDs of hosts that have been manually associated
- // with a query pack.
- ListExplicitHostsInPack(ctx context.Context, pid uint, opt ListOptions) (hosts []uint, err error)
}
// Pack is the structure which represents an osquery query pack.
@@ -127,6 +77,9 @@ type Pack struct {
Description string `json:"description,omitempty"`
Platform string `json:"platform,omitempty"`
Disabled bool `json:"disabled,omitempty"`
+ LabelIDs []uint `json:"label_ids"`
+ HostIDs []uint `json:"host_ids"`
+ TeamIDs []uint `json:"team_ids"`
}
func (p Pack) AuthzType() string {
@@ -145,6 +98,7 @@ type PackPayload struct {
Disabled *bool `json:"disabled"`
HostIDs *[]uint `json:"host_ids"`
LabelIDs *[]uint `json:"label_ids"`
+ TeamIDs *[]uint `json:"team_ids"`
}
type PackSpec struct {
@@ -174,10 +128,10 @@ type PackSpecQuery struct {
Denylist *bool `json:"denylist,omitempty"`
}
-// PackTarget associates a pack with either a host or a label
+// PackTarget targets a pack to a host, label, or team.
type PackTarget struct {
- ID uint
- PackID uint
+ ID uint `db:"id"`
+ PackID uint `db:"pack_id"`
Target
}
diff --git a/server/fleet/targets.go b/server/fleet/targets.go
index 35e54c47c5..be4c587a5c 100644
--- a/server/fleet/targets.go
+++ b/server/fleet/targets.go
@@ -78,8 +78,8 @@ const (
)
type Target struct {
- Type TargetType
- TargetID uint
+ Type TargetType `db:"type"`
+ TargetID uint `db:"target_id"`
}
func (t Target) AuthzType() string {
diff --git a/server/service/endpoint_appconfig.go b/server/service/endpoint_appconfig.go
index 37fa579e76..acfd31d48c 100644
--- a/server/service/endpoint_appconfig.go
+++ b/server/service/endpoint_appconfig.go
@@ -22,7 +22,7 @@ type appConfigResponse struct {
SSOSettings *fleet.SSOSettingsPayload `json:"sso_settings,omitempty"`
HostExpirySettings *fleet.HostExpirySettings `json:"host_expiry_settings,omitempty"`
HostSettings *fleet.HostSettings `json:"host_settings,omitempty"`
- AgentOptions *json.RawMessage `json:"agent_options,omitempty"`
+ AgentOptions *json.RawMessage `json:"agent_options,omitempty"`
License *fleet.LicenseInfo `json:"license,omitempty"`
Err error `json:"error,omitempty"`
}
diff --git a/server/service/endpoint_packs.go b/server/service/endpoint_packs.go
index 25c75bdfef..4d600d2011 100644
--- a/server/service/endpoint_packs.go
+++ b/server/service/endpoint_packs.go
@@ -18,6 +18,7 @@ type packResponse struct {
// IDs of hosts which were explicitly selected.
HostIDs []uint `json:"host_ids"`
LabelIDs []uint `json:"label_ids"`
+ TeamIDs []uint `json:"team_ids"`
}
func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack) (*packResponse, error) {
@@ -27,21 +28,11 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack
return nil, err
}
- hosts, err := svc.ListExplicitHostsInPack(ctx, pack.ID, opts)
- if err != nil {
- return nil, err
- }
-
- labels, err := svc.ListLabelsForPack(ctx, pack.ID)
- labelIDs := make([]uint, len(labels))
- for i, label := range labels {
- labelIDs[i] = label.ID
- }
- if err != nil {
- return nil, err
- }
-
- hostMetrics, err := svc.CountHostsInTargets(ctx, nil, fleet.HostTargets{HostIDs: hosts, LabelIDs: labelIDs})
+ hostMetrics, err := svc.CountHostsInTargets(
+ ctx,
+ nil,
+ fleet.HostTargets{HostIDs: pack.HostIDs, LabelIDs: pack.LabelIDs, TeamIDs: pack.TeamIDs},
+ )
if err != nil {
return nil, err
}
@@ -50,8 +41,9 @@ func packResponseForPack(ctx context.Context, svc fleet.Service, pack fleet.Pack
Pack: pack,
QueryCount: uint(len(queries)),
TotalHostsCount: hostMetrics.TotalHosts,
- HostIDs: hosts,
- LabelIDs: labelIDs,
+ HostIDs: pack.HostIDs,
+ LabelIDs: pack.LabelIDs,
+ TeamIDs: pack.TeamIDs,
}, nil
}
diff --git a/server/service/logging_packs.go b/server/service/logging_packs.go
index d994fedf91..9fac2425fe 100644
--- a/server/service/logging_packs.go
+++ b/server/service/logging_packs.go
@@ -117,92 +117,6 @@ func (mw loggingMiddleware) DeletePack(ctx context.Context, name string) error {
return err
}
-func (mw loggingMiddleware) AddLabelToPack(ctx context.Context, lid uint, pid uint) error {
- var (
- err error
- )
-
- defer func(begin time.Time) {
- _ = mw.loggerInfo(err).Log(
- "method", "AddLabelToPack",
- "err", err,
- "took", time.Since(begin),
- )
- }(time.Now())
-
- err = mw.Service.AddLabelToPack(ctx, lid, pid)
- return err
-}
-
-func (mw loggingMiddleware) RemoveLabelFromPack(ctx context.Context, lid uint, pid uint) error {
- var (
- err error
- )
-
- defer func(begin time.Time) {
- _ = mw.loggerInfo(err).Log(
- "method", "RemoveLabelFromPack",
- "err", err,
- "took", time.Since(begin),
- )
- }(time.Now())
-
- err = mw.Service.RemoveLabelFromPack(ctx, lid, pid)
- return err
-}
-
-func (mw loggingMiddleware) ListLabelsForPack(ctx context.Context, pid uint) ([]*fleet.Label, error) {
- var (
- labels []*fleet.Label
- err error
- )
-
- defer func(begin time.Time) {
- _ = mw.loggerDebug(err).Log(
- "method", "ListLabelsForPack",
- "err", err,
- "took", time.Since(begin),
- )
- }(time.Now())
-
- labels, err = mw.Service.ListLabelsForPack(ctx, pid)
- return labels, err
-}
-
-func (mw loggingMiddleware) AddHostToPack(ctx context.Context, hid uint, pid uint) error {
- var (
- err error
- )
-
- defer func(begin time.Time) {
- _ = mw.loggerInfo(err).Log(
- "method", "AddHostToPack",
- "err", err,
- "took", time.Since(begin),
- )
- }(time.Now())
-
- err = mw.Service.AddHostToPack(ctx, hid, pid)
- return err
-}
-
-func (mw loggingMiddleware) RemoveHostFromPack(ctx context.Context, hid uint, pid uint) error {
- var (
- err error
- )
-
- defer func(begin time.Time) {
- _ = mw.loggerInfo(err).Log(
- "method", "RemoveHostFromPack",
- "err", err,
- "took", time.Since(begin),
- )
- }(time.Now())
-
- err = mw.Service.RemoveHostFromPack(ctx, hid, pid)
- return err
-}
-
func (mw loggingMiddleware) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
var (
packs []*fleet.Pack
@@ -221,24 +135,6 @@ func (mw loggingMiddleware) ListPacksForHost(ctx context.Context, hid uint) ([]*
return packs, err
}
-func (mw loggingMiddleware) ListHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) {
- var (
- hosts []uint
- err error
- )
-
- defer func(begin time.Time) {
- _ = mw.loggerDebug(err).Log(
- "method", "ListHostsInPack",
- "err", err,
- "took", time.Since(begin),
- )
- }(time.Now())
-
- hosts, err = mw.Service.ListHostsInPack(ctx, pid, opt)
- return hosts, err
-}
-
func (mw loggingMiddleware) GetPackSpec(ctx context.Context, name string) (spec *fleet.PackSpec, err error) {
defer func(begin time.Time) {
_ = mw.loggerDebug(err).Log(
diff --git a/server/service/service_labels.go b/server/service/service_labels.go
index 3b1c75200d..c5e2dd5670 100644
--- a/server/service/service_labels.go
+++ b/server/service/service_labels.go
@@ -19,7 +19,7 @@ func (svc *Service) ApplyLabelSpecs(ctx context.Context, specs []*fleet.LabelSpe
}
if spec.LabelMembershipType == fleet.LabelMembershipTypeManual && spec.Hosts == nil {
// Hosts list doesn't need to contain anything, but it should at least not be nil.
- return errors.Errorf("label %s is declared as manual but contains not `hosts key`", spec.Name)
+ return errors.Errorf("label %s is declared as manual but contains no `hosts key`", spec.Name)
}
}
return svc.ds.ApplyLabelSpecs(specs)
diff --git a/server/service/service_osquery.go b/server/service/service_osquery.go
index 21b411e22a..8d1cf07948 100644
--- a/server/service/service_osquery.go
+++ b/server/service/service_osquery.go
@@ -209,13 +209,13 @@ func (svc *Service) GetClientConfig(ctx context.Context) (map[string]interface{}
baseConfig, err := svc.AgentOptionsForHost(ctx, &host)
if err != nil {
- return nil, osqueryError{message: "internal error: fetching base config: " + err.Error()}
+ return nil, osqueryError{message: "internal error: fetch base config: " + err.Error()}
}
var config map[string]interface{}
err = json.Unmarshal(baseConfig, &config)
if err != nil {
- return nil, osqueryError{message: "internal error: parsing base configuration: " + err.Error()}
+ return nil, osqueryError{message: "internal error: parse base configuration: " + err.Error()}
}
packs, err := svc.ds.ListPacksForHost(host.ID)
diff --git a/server/service/service_packs.go b/server/service/service_packs.go
index 6b2f0de4aa..e727f63acd 100644
--- a/server/service/service_packs.go
+++ b/server/service/service_packs.go
@@ -74,24 +74,6 @@ func (svc *Service) NewPack(ctx context.Context, p fleet.PackPayload) (*fleet.Pa
return nil, err
}
- if p.HostIDs != nil {
- for _, hostID := range *p.HostIDs {
- err = svc.AddHostToPack(ctx, hostID, pack.ID)
- if err != nil {
- return nil, err
- }
- }
- }
-
- if p.LabelIDs != nil {
- for _, labelID := range *p.LabelIDs {
- err = svc.AddLabelToPack(ctx, labelID, pack.ID)
- if err != nil {
- return nil, err
- }
- }
- }
-
return &pack, nil
}
@@ -121,107 +103,23 @@ func (svc *Service) ModifyPack(ctx context.Context, id uint, p fleet.PackPayload
pack.Disabled = *p.Disabled
}
+ if p.HostIDs != nil {
+ pack.HostIDs = *p.HostIDs
+ }
+
+ if p.LabelIDs != nil {
+ pack.LabelIDs = *p.LabelIDs
+ }
+
+ if p.TeamIDs != nil {
+ pack.TeamIDs = *p.TeamIDs
+ }
+
err = svc.ds.SavePack(pack)
if err != nil {
return nil, err
}
- // we must determine what hosts are attached to this pack. then, given
- // our new set of host_ids, we will mutate the database to reflect the
- // desired state.
- if p.HostIDs != nil {
-
- // first, let's retrieve the total set of hosts
- hosts, err := svc.ListHostsInPack(ctx, pack.ID, fleet.ListOptions{})
- if err != nil {
- return nil, err
- }
-
- // it will be efficient to create a data structure with constant time
- // lookups to determine whether or not a host is already added
- existingHosts := map[uint]bool{}
- for _, host := range hosts {
- existingHosts[host] = true
- }
-
- // we will also make a constant time lookup map for the desired set of
- // hosts as well.
- desiredHosts := map[uint]bool{}
- for _, hostID := range *p.HostIDs {
- desiredHosts[hostID] = true
- }
-
- // if the request declares a host ID but the host is not already
- // associated with the pack, we add it
- for _, hostID := range *p.HostIDs {
- if !existingHosts[hostID] {
- err = svc.AddHostToPack(ctx, hostID, pack.ID)
- if err != nil {
- return nil, err
- }
- }
- }
-
- // if the request does not declare the ID of a host which currently
- // exists, we delete the existing relationship
- for hostID := range existingHosts {
- if !desiredHosts[hostID] {
- err = svc.RemoveHostFromPack(ctx, hostID, pack.ID)
- if err != nil {
- return nil, err
- }
- }
- }
- }
-
- // we must determine what labels are attached to this pack. then, given
- // our new set of label_ids, we will mutate the database to reflect the
- // desired state.
- if p.LabelIDs != nil {
-
- // first, let's retrieve the total set of labels
- labels, err := svc.ListLabelsForPack(ctx, pack.ID)
- if err != nil {
- return nil, err
- }
-
- // it will be efficient to create a data structure with constant time
- // lookups to determine whether or not a label is already added
- existingLabels := map[uint]bool{}
- for _, label := range labels {
- existingLabels[label.ID] = true
- }
-
- // we will also make a constant time lookup map for the desired set of
- // labels as well.
- desiredLabels := map[uint]bool{}
- for _, labelID := range *p.LabelIDs {
- desiredLabels[labelID] = true
- }
-
- // if the request declares a label ID but the label is not already
- // associated with the pack, we add it
- for _, labelID := range *p.LabelIDs {
- if !existingLabels[labelID] {
- err = svc.AddLabelToPack(ctx, labelID, pack.ID)
- if err != nil {
- return nil, err
- }
- }
- }
-
- // if the request does not declare the ID of a label which currently
- // exists, we delete the existing relationship
- for labelID := range existingLabels {
- if !desiredLabels[labelID] {
- err = svc.RemoveLabelFromPack(ctx, labelID, pack.ID)
- if err != nil {
- return nil, err
- }
- }
- }
- }
-
return pack, err
}
@@ -245,62 +143,6 @@ func (svc *Service) DeletePackByID(ctx context.Context, id uint) error {
return svc.ds.DeletePack(pack.Name)
}
-func (svc *Service) AddLabelToPack(ctx context.Context, lid, pid uint) error {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
- return err
- }
-
- return svc.ds.AddLabelToPack(lid, pid)
-}
-
-func (svc *Service) RemoveLabelFromPack(ctx context.Context, lid, pid uint) error {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
- return err
- }
-
- return svc.ds.RemoveLabelFromPack(lid, pid)
-}
-
-func (svc *Service) AddHostToPack(ctx context.Context, hid, pid uint) error {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
- return err
- }
-
- return svc.ds.AddHostToPack(hid, pid)
-}
-
-func (svc *Service) RemoveHostFromPack(ctx context.Context, hid, pid uint) error {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionWrite); err != nil {
- return err
- }
-
- return svc.ds.RemoveHostFromPack(hid, pid)
-}
-
-func (svc *Service) ListLabelsForPack(ctx context.Context, pid uint) ([]*fleet.Label, error) {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
- return nil, err
- }
-
- return svc.ds.ListLabelsForPack(pid)
-}
-
-func (svc *Service) ListHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
- return nil, err
- }
-
- return svc.ds.ListHostsInPack(pid, opt)
-}
-
-func (svc *Service) ListExplicitHostsInPack(ctx context.Context, pid uint, opt fleet.ListOptions) ([]uint, error) {
- if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
- return nil, err
- }
-
- return svc.ds.ListExplicitHostsInPack(pid, opt)
-}
-
func (svc *Service) ListPacksForHost(ctx context.Context, hid uint) ([]*fleet.Pack, error) {
if err := svc.authz.Authorize(ctx, &fleet.Pack{}, fleet.ActionRead); err != nil {
return nil, err
diff --git a/server/test/comparisons.go b/server/test/comparisons.go
index c23f8bcfc5..52265268cb 100644
--- a/server/test/comparisons.go
+++ b/server/test/comparisons.go
@@ -10,9 +10,16 @@ import (
"github.com/stretchr/testify/assert"
)
+type TestingT interface {
+ assert.TestingT
+ Helper()
+}
+
// ElementsMatchSkipID asserts that the elements match, skipping any field with
// name "ID".
-func ElementsMatchSkipID(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
+func ElementsMatchSkipID(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
+ t.Helper()
+
opt := cmp.FilterPath(func(p cmp.Path) bool {
for _, ps := range p {
switch ps := ps.(type) {
@@ -29,7 +36,9 @@ func ElementsMatchSkipID(t assert.TestingT, listA, listB interface{}, msgAndArgs
// ElementsMatchSkipTimestampsID asserts that the elements match, skipping any field with
// name "ID", "CreatedAt", and "UpdatedAt". This is useful for comparing after DB insertion.
-func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
+func ElementsMatchSkipTimestampsID(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
+ t.Helper()
+
opt := cmp.FilterPath(func(p cmp.Path) bool {
for _, ps := range p {
switch ps := ps.(type) {
@@ -45,6 +54,31 @@ func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{},
return ElementsMatchWithOptions(t, listA, listB, []cmp.Option{opt}, msgAndArgs)
}
+// EqualSkipTimestampsID asserts that the structs are equal, skipping any field
+// with name "ID", "CreatedAt", and "UpdatedAt". This is useful for comparing
+// after DB insertion.
+func EqualSkipTimestampsID(t TestingT, a, b interface{}, msgAndArgs ...interface{}) (ok bool) {
+ t.Helper()
+
+ opt := cmp.FilterPath(func(p cmp.Path) bool {
+ for _, ps := range p {
+ switch ps := ps.(type) {
+ case cmp.StructField:
+ switch ps.Name() {
+ case "ID", "UpdateCreateTimestamps", "CreateTimestamp", "UpdateTimestamp", "CreatedAt", "UpdatedAt":
+ return true
+ }
+ }
+ }
+ return false
+ }, cmp.Ignore())
+
+ if !cmp.Equal(a, b, opt) {
+ return assert.Fail(t, cmp.Diff(a, b, opt), msgAndArgs...)
+ }
+ return true
+}
+
// The below functions adapted from
// https://github.com/stretchr/testify/blob/v1.7.0/assert/assertions.go#L895 by
// utilizing the options provided in github.com/google/go-cmp/cmp
@@ -53,7 +87,7 @@ func ElementsMatchSkipTimestampsID(t assert.TestingT, listA, listB interface{},
// additional options as provided by the cmp package. This allows, for example,
// comparing structs while ignoring fields. See assert.ElementsMatch
// documentation for more details.
-func ElementsMatchWithOptions(t assert.TestingT, listA, listB interface{}, opts cmp.Options, msgAndArgs ...interface{}) (ok bool) {
+func ElementsMatchWithOptions(t TestingT, listA, listB interface{}, opts cmp.Options, msgAndArgs ...interface{}) (ok bool) {
if isEmpty(listA) && isEmpty(listB) {
return true
}
@@ -100,7 +134,7 @@ func isEmpty(object interface{}) bool {
}
// isList checks that the provided value is array or slice.
-func isList(t assert.TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) {
+func isList(t TestingT, list interface{}, msgAndArgs ...interface{}) (ok bool) {
kind := reflect.TypeOf(list).Kind()
if kind != reflect.Array && kind != reflect.Slice {
return assert.Fail(t, fmt.Sprintf("%q has an unsupported type %s, expecting array or slice", list, kind),