diff --git a/changes/10787-collation-fix b/changes/10787-collation-fix
new file mode 100644
index 0000000000..c0fed55da7
--- /dev/null
+++ b/changes/10787-collation-fix
@@ -0,0 +1 @@
+* Fixed a migration that was causing `fleet prepare db` to fail due to changes in the collation of the tables. IMPORTANT: please make sure to have a database backup before running migrations.
diff --git a/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation.go b/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation.go
index e28bf82263..7cd4b59d0e 100644
--- a/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation.go
+++ b/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation.go
@@ -3,6 +3,7 @@ package tables
 import (
 	"bytes"
 	"database/sql"
+	"encoding/json"
 	"fmt"
 	"text/template"
 
@@ -13,6 +14,204 @@ func init() {
 	MigrationClient.AddMigration(Up_20230315104937, Down_20230315104937)
 }
 
+// fixupSoftware removes rows from the software table that would collide once
+// the table is converted to the given collation. Rows are grouped by
+// (version, release, name, source, vendor, arch) compared under collation; for
+// every group with more than one row the first id listed is kept and the
+// others are deleted along with their rows in software_cve,
+// software_host_counts and host_software.
+//
+// NOTE(review): GROUP_CONCAT output is limited by group_concat_max_len, so a
+// very large duplicate group could be truncated and fail JSON decoding.
+func fixupSoftware(tx *sql.Tx, collation string) error {
+	//nolint:gosec // string formatting must be used here, but input is not user-controllable
+	rows, err := tx.Query(`
+		SELECT
+			COUNT(*) as total,
+			CONCAT('[', GROUP_CONCAT(id SEPARATOR ','), ']') as ids
+		FROM software
+		GROUP BY ` +
+		fmt.Sprintf("`version` COLLATE %s,", collation) +
+		fmt.Sprintf("`release` COLLATE %s,", collation) +
+		fmt.Sprintf(`name COLLATE %s,
+			source COLLATE %s,
+			vendor COLLATE %s,
+			arch COLLATE %s
+		HAVING total > 1
+		COLLATE %s`, collation, collation, collation, collation, collation))
+
+	if err != nil {
+		return fmt.Errorf("aggregating dupes: %w", err)
+	}
+
+	defer rows.Close()
+	var idGroups [][]uint
+	for rows.Next() {
+		var rawIDs json.RawMessage
+		var total int
+		if err := rows.Scan(&total, &rawIDs); err != nil {
+			return fmt.Errorf("scanning values: %w", err)
+		}
+		var ids []uint
+		if err := json.Unmarshal(rawIDs, &ids); err != nil {
+			return fmt.Errorf("unmarshalling keys: %w", err)
+		}
+		idGroups = append(idGroups, ids)
+	}
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("iterating dupes: %w", err)
+	}
+
+	if len(idGroups) > 0 {
+		fmt.Printf("INFO: found %d duplicate software entries: %v\n", len(idGroups), idGroups)
+	}
+
+	for _, ids := range idGroups {
+		for i := 1; i < len(ids); i++ {
+			if _, err := tx.Exec("DELETE FROM software_cve WHERE software_id = ?", ids[i]); err != nil {
+				return fmt.Errorf("deleting duplicate software with id %d from software_cve: %w", ids[i], err)
+			}
+			if _, err := tx.Exec("DELETE FROM software_host_counts WHERE software_id = ?", ids[i]); err != nil {
+				return fmt.Errorf("deleting duplicate software with id %d from software_host_counts: %w", ids[i], err)
+			}
+			if _, err := tx.Exec("DELETE FROM host_software WHERE software_id = ?", ids[i]); err != nil {
+				return fmt.Errorf("deleting duplicate software with id %d from host_software: %w", ids[i], err)
+			}
+			if _, err := tx.Exec("DELETE FROM software WHERE id = ?", ids[i]); err != nil {
+				return fmt.Errorf("deleting duplicate software with id %d: %w", ids[i], err)
+			}
+		}
+	}
+
+	return nil
+}
+
+// fixupHostUsers removes rows from host_users that would collide once the
+// table is converted to the given collation. Rows are grouped by
+// (host_id, uid, username) with username compared under collation; for every
+// group with more than one row the first entry listed is kept and the others
+// are deleted by their (host_id, uid, username) key.
+func fixupHostUsers(tx *sql.Tx, collation string) error {
+	//nolint:gosec // string formatting must be used here, but input is not user-controllable
+	rows, err := tx.Query(fmt.Sprintf(`
+		SELECT
+			COUNT(*) as total,
+			CONCAT('[', GROUP_CONCAT(JSON_OBJECT('username', username, 'host_id', host_id, 'uid', uid) SEPARATOR ","), ']') as ids
+		FROM host_users
+		GROUP BY
+			host_id,
+			uid,
+			username COLLATE %s
+		HAVING total > 1
+		COLLATE %s`, collation, collation))
+	if err != nil {
+		return fmt.Errorf("aggregating dupes: %w", err)
+	}
+
+	type hostUser struct {
+		Username string
+		HostID   uint `json:"host_id"`
+		UID      uint
+	}
+
+	defer rows.Close()
+	var keyGroups [][]hostUser
+	for rows.Next() {
+		var raw json.RawMessage
+		var total int
+		if err := rows.Scan(&total, &raw); err != nil {
+			return fmt.Errorf("scanning dupe results: %w", err)
+		}
+
+		var hu []hostUser
+		if err := json.Unmarshal(raw, &hu); err != nil {
+			return fmt.Errorf("unmarshalling dupe results: %w", err)
+		}
+		keyGroups = append(keyGroups, hu)
+	}
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("iterating dupes: %w", err)
+	}
+
+	if len(keyGroups) > 0 {
+		fmt.Printf("INFO: found %d duplicate host_users entries: %v\n", len(keyGroups), keyGroups)
+	}
+
+	for _, keys := range keyGroups {
+		for i := 1; i < len(keys); i++ {
+			if _, err := tx.Exec("DELETE FROM host_users WHERE host_id = ? AND uid = ? AND username = ?", keys[i].HostID, keys[i].UID, keys[i].Username); err != nil {
+				return fmt.Errorf("deleting duplicate entries with key (host_id=%d, uid=%d, username=%s) from host_users: %w", keys[i].HostID, keys[i].UID, keys[i].Username, err)
+			}
+		}
+	}
+	return nil
+}
+
+// fixupOS removes rows from operating_systems that would collide once the
+// table is converted to the given collation. Rows are grouped by
+// (version, name, arch, kernel_version, platform) compared under collation;
+// for every group with more than one row the first entry listed is kept and
+// the others are deleted by those five columns.
+func fixupOS(tx *sql.Tx, collation string) error {
+	//nolint:gosec // string formatting must be used here, but input is not user-controllable
+	rows, err := tx.Query(fmt.Sprintf(`
+		SELECT
+			COUNT(*) as total,
+			CONCAT('[', GROUP_CONCAT(JSON_OBJECT('name', name, 'version', version, 'arch', arch, 'kernel_version', kernel_version, 'platform', platform) SEPARATOR ","), ']') as ids
+		FROM operating_systems
+		GROUP BY `+
+		fmt.Sprintf("`version` COLLATE %s,", collation)+
+		`name COLLATE %s,
+			arch COLLATE %s,
+			kernel_version COLLATE %s,
+			platform COLLATE %s
+		HAVING total > 1
+		COLLATE %s`, collation, collation, collation, collation, collation))
+	if err != nil {
+		return fmt.Errorf("aggregating dupes: %w", err)
+	}
+
+	type os struct {
+		Name          string
+		Version       string
+		Arch          string
+		KernelVersion string `json:"kernel_version"`
+		Platform      string
+	}
+
+	defer rows.Close()
+	var keyGroups [][]os
+	for rows.Next() {
+		var raw json.RawMessage
+		var total int
+		if err := rows.Scan(&total, &raw); err != nil {
+			return fmt.Errorf("scanning dupes: %w", err)
+		}
+
+		var o []os
+		if err := json.Unmarshal(raw, &o); err != nil {
+			return fmt.Errorf("unmarshalling dupes: %w", err)
+		}
+		keyGroups = append(keyGroups, o)
+	}
+	if err := rows.Err(); err != nil {
+		return fmt.Errorf("iterating dupes: %w", err)
+	}
+
+	if len(keyGroups) > 0 {
+		fmt.Printf("INFO: found %d duplicate operating_system entries: %v\n", len(keyGroups), keyGroups)
+	}
+
+	for _, keys := range keyGroups {
+		for i := 1; i < len(keys); i++ {
+			if _, err := tx.Exec("DELETE FROM operating_systems WHERE name = ? AND version = ? AND arch = ? AND kernel_version = ? AND platform = ?", keys[i].Name, keys[i].Version, keys[i].Arch, keys[i].KernelVersion, keys[i].Platform); err != nil {
+				return fmt.Errorf("deleting dupes with key (name=%s, version=%s, arch=%s, kernel_version=%s, platform=%s): %w", keys[i].Name, keys[i].Version, keys[i].Arch, keys[i].KernelVersion, keys[i].Platform, err)
+			}
+		}
+	}
+	return nil
+}
+
 // changeCollation changes the default collation set of the database and all
 // table to the provided collation
 //
@@ -24,8 +223,19 @@ func changeCollation(tx *sql.Tx, charset string, collation string) (err error) {
 		return fmt.Errorf("alter database: %w", err)
 	}
 
-	txx := sqlx.Tx{Tx: tx}
+	if err := fixupSoftware(tx, collation); err != nil {
+		return fmt.Errorf("fixing software table: %w", err)
+	}
 
+	if err := fixupHostUsers(tx, collation); err != nil {
+		return fmt.Errorf("fixing host_users table: %w", err)
+	}
+
+	if err := fixupOS(tx, collation); err != nil {
+		return fmt.Errorf("fixing operating_systems table: %w", err)
+	}
+
+	txx := sqlx.Tx{Tx: tx}
 	var names []string
 	err = txx.Select(&names, `
 	          SELECT table_name
diff --git a/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation_test.go b/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation_test.go
index 78c8882912..a1c1f71893 100644
--- a/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation_test.go
+++ b/server/datastore/mysql/migrations/tables/20230315104937_EnsureUniformCollation_test.go
@@ -1,6 +1,7 @@
 package tables
 
 import (
+	"strings"
 	"testing"
 
 	"github.com/jmoiron/sqlx"
@@ -35,6 +36,55 @@ func TestUp_20230315104937(t *testing.T) {
 	err = sqlx.Get(db, &c, "SELECT COUNT(*) FROM host_mdm_apple_profiles hmap JOIN hosts h WHERE h.uuid = hmap.host_uuid AND hmap.status = 'failed'")
 	require.ErrorContains(t, err, "Error 1267")
 
+	var mysqlVersion string
+	err = sqlx.Get(db, &mysqlVersion, "SELECT VERSION()")
+	require.NoError(t, err)
+
+	// this test can only be replicated in MySQL 8 because for prior
+	// versions all collations are padded.
+	if strings.HasPrefix(mysqlVersion, "8") {
+		// ensure software is using a different collation
+		_, err = db.Exec("ALTER TABLE `software` CONVERT TO CHARACTER SET `utf8mb4` COLLATE `utf8mb4_0900_ai_ci`")
+		require.NoError(t, err)
+
+		// insert two software records
+		insertSoftwareStmt := `INSERT INTO software (name, version, source, bundle_identifier, vendor, arch) VALUES (?, '1.2.1', 'rpm_packages', '', ?, 'x86_64')`
+		_, err = db.Exec(insertSoftwareStmt, "zchunk-libs", "vendor")
+		require.NoError(t, err)
+		_, err = db.Exec(insertSoftwareStmt, "zchunk-libs", "vendor ")
+		require.NoError(t, err)
+		_, err = db.Exec(insertSoftwareStmt, "vim", "vendor")
+		require.NoError(t, err)
+		_, err = db.Exec(insertSoftwareStmt, "vim", "vendor ")
+		require.NoError(t, err)
+
+		// insert host_users
+		_, err = db.Exec("ALTER TABLE `host_users` CONVERT TO CHARACTER SET `utf8mb4` COLLATE `utf8mb4_0900_ai_ci`")
+		require.NoError(t, err)
+		insertHostUsersStmt := `INSERT INTO host_users (host_id, uid, username) VALUES (?, 1, ?)`
+		_, err = db.Exec(insertHostUsersStmt, 1, "username")
+		require.NoError(t, err)
+		_, err = db.Exec(insertHostUsersStmt, 1, "username ")
+		require.NoError(t, err)
+		_, err = db.Exec(insertHostUsersStmt, 2, "username")
+		require.NoError(t, err)
+		_, err = db.Exec(insertHostUsersStmt, 2, "username ")
+		require.NoError(t, err)
+
+		// insert operating_systems
+		_, err = db.Exec("ALTER TABLE `operating_systems` CONVERT TO CHARACTER SET `utf8mb4` COLLATE `utf8mb4_0900_ai_ci`")
+		require.NoError(t, err)
+		insertOSStmt := `INSERT INTO operating_systems (name,version,arch,kernel_version,platform) VALUES (?, '12.1', 'arch', 'kernel', ?)`
+		_, err = db.Exec(insertOSStmt, "macOS", "darwin")
+		require.NoError(t, err)
+		_, err = db.Exec(insertOSStmt, "macOS", "darwin ")
+		require.NoError(t, err)
+		_, err = db.Exec(insertOSStmt, "arch", "linux")
+		require.NoError(t, err)
+		_, err = db.Exec(insertOSStmt, "arch", "linux ")
+		require.NoError(t, err)
+	}
+
 	applyNext(t, db)
 
 	err = sqlx.Get(db, &c, "SELECT COUNT(*) FROM host_mdm_apple_profiles hmap JOIN hosts h WHERE h.uuid = hmap.host_uuid AND hmap.status = 'failed'")
@@ -57,4 +107,22 @@ func TestUp_20230315104937(t *testing.T) {
 	          WHERE collation_name != "utf8mb4_unicode_ci" AND table_schema = (SELECT database())`)
 	require.NoError(t, err)
 	require.Equal(t, []string{"secret", "node_key", "orbit_node_key"}, columns)
+
+	if strings.HasPrefix(mysqlVersion, "8") {
+		// verify that duplicate rows have been removed
+		c = 0
+		err = sqlx.Get(db, &c, "SELECT COUNT(*) FROM software")
+		require.NoError(t, err)
+		require.Equal(t, 2, c)
+
+		c = 0
+		err = sqlx.Get(db, &c, "SELECT COUNT(*) FROM host_users")
+		require.NoError(t, err)
+		require.Equal(t, 2, c)
+
+		c = 0
+		err = sqlx.Get(db, &c, "SELECT COUNT(*) FROM operating_systems")
+		require.NoError(t, err)
+		require.Equal(t, 2, c)
+	}
 }