diff --git a/ee/server/service/software_installers.go b/ee/server/service/software_installers.go index 5826f488a1..1bbccd4ab6 100644 --- a/ee/server/service/software_installers.go +++ b/ee/server/service/software_installers.go @@ -1223,7 +1223,7 @@ func UninstallSoftwareMigration( // Update $PACKAGE_ID in uninstall script preProcessUninstallScript(&payload) - // Update the package_id in the software installer and the uninstall script + // Update the package_id and extension in the software installer and the uninstall script if err := ds.UpdateSoftwareInstallerWithoutPackageIDs(ctx, id, payload); err != nil { return ctxerr.Wrap(ctx, err, "updating package_id in software installer") } diff --git a/server/datastore/mysql/software_installers.go b/server/datastore/mysql/software_installers.go index ee314da7d2..5a7273a74c 100644 --- a/server/datastore/mysql/software_installers.go +++ b/server/datastore/mysql/software_installers.go @@ -959,10 +959,10 @@ func (ds *Datastore) UpdateSoftwareInstallerWithoutPackageIDs(ctx context.Contex } query := ` UPDATE software_installers - SET package_ids = ?, uninstall_script_content_id = ? + SET package_ids = ?, uninstall_script_content_id = ?, extension = ? WHERE id = ? ` - _, err = ds.writer(ctx).ExecContext(ctx, query, strings.Join(payload.PackageIDs, ","), uninstallScriptID, id) + _, err = ds.writer(ctx).ExecContext(ctx, query, strings.Join(payload.PackageIDs, ","), uninstallScriptID, payload.Extension, id) if err != nil { return ctxerr.Wrap(ctx, err, "update software installer without package ID") } diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index d871544aeb..9e1cf62a04 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -10584,14 +10584,22 @@ func (s *integrationEnterpriseTestSuite) TestSoftwareInstallerUploadDownloadAndD installerID, titleID := checkSoftwareInstaller(t, payload) var origPackageIDs string - // Update DB by clearing package id + var origExtension string + // Update DB by clearing package id and tweaking extension mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { if err := sqlx.GetContext(context.Background(), q, &origPackageIDs, `SELECT package_ids FROM software_installers WHERE id = ?`, installerID); err != nil { return err } require.NotEmpty(t, origPackageIDs) - if _, err = q.ExecContext(context.Background(), `UPDATE software_installers SET package_ids = '' WHERE id = ?`, + + if err := sqlx.GetContext(context.Background(), q, &origExtension, `SELECT extension FROM software_installers WHERE id = ?`, + installerID); err != nil { + return err + } + require.NotEmpty(t, origExtension) + + if _, err = q.ExecContext(context.Background(), `UPDATE software_installers SET package_ids = '', extension = 'rb' WHERE id = ?`, installerID); err != nil { return err } @@ -10610,7 +10618,7 @@ func (s *integrationEnterpriseTestSuite) TestSoftwareInstallerUploadDownloadAndD err = eeservice.UninstallSoftwareMigration(context.Background(), s.ds, s.softwareInstallStore, logger) require.NoError(t, err) - // Check package ID + // Check package ID and extension mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { var packageIDs string if err := sqlx.GetContext(context.Background(), q, &packageIDs, `SELECT package_ids FROM software_installers WHERE id = ?`, @@ -10618,6 +10626,14 @@ func (s *integrationEnterpriseTestSuite) TestSoftwareInstallerUploadDownloadAndD return err } assert.Equal(t, origPackageIDs, packageIDs) + + var extension string + if err := sqlx.GetContext(context.Background(), q, &extension, `SELECT extension FROM software_installers WHERE id = ?`, + installerID); err != nil { + return err + } + assert.Equal(t, origExtension, extension) + return nil })