27905: Specify TLS Server cert path when running orbit shell (#30329)

For #27905 

Provide TLS Server cert path via --tls_server_certs flag when running orbit shell.
This commit is contained in:
Juan Fernandez 2025-07-01 14:36:52 -04:00 committed by GitHub
parent 3eadc66bf7
commit c0dae08549
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 119 additions and 1 deletions

View file

@ -0,0 +1 @@
* Added new flag `--fleet-certificate` to `sudo orbit shell` command (which sets osquery's `--tls_server_certs` flag).

View file

@ -915,7 +915,6 @@ func main() {
log.Info().Msg("No cert chain available. Relying on system store.")
}
}
}
fleetClientCertPath := filepath.Join(c.String("root-dir"), constant.FleetTLSClientCertificateFileName)

View file

@ -3,6 +3,8 @@ package main
import (
"context"
"fmt"
"github.com/fleetdm/fleet/v4/pkg/certificate"
"github.com/fleetdm/fleet/v4/pkg/file"
"os"
"path/filepath"
"runtime"
@ -36,6 +38,11 @@ var shellCommand = &cli.Command{
Usage: "Enable debug logging",
EnvVars: []string{"ORBIT_DEBUG"},
},
&cli.StringFlag{
Name: "fleet-certificate",
Usage: "Path to the Fleet server certificate chain",
EnvVars: []string{"ORBIT_FLEET_CERTIFICATE"},
},
},
Action: func(c *cli.Context) error {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
@ -98,6 +105,15 @@ var shellCommand = &cli.Command{
osquery.WithFlags([]string{"--database_path", osqueryDB}),
}
certPath, err := getCertPath(
c.String("root-dir"),
c.String("fleet-certificate"),
)
if err != nil {
return err
}
opts = append(opts, osquery.WithFlags([]string{"--tls_server_certs", certPath}))
// Detect if the additional arguments have a positional argument.
//
// osqueryi/osqueryd has the following usage:
@ -140,3 +156,24 @@ var shellCommand = &cli.Command{
return nil
},
}
func getCertPath(rootDir, fleetCertPath string) (string, error) {
certPath := filepath.Join(rootDir, "certs.pem")
if fleetCertPath != "" {
certPath = fleetCertPath
}
exists, err := file.Exists(certPath)
switch {
case err != nil:
return "", fmt.Errorf("failed to check if cert exists %s: %w", certPath, err)
case !exists:
return "", fmt.Errorf("cert not found at %s", certPath)
default:
if _, err := certificate.LoadPEM(certPath); err != nil {
return "", fmt.Errorf("invalid PEM format %s: %w", certPath, err)
}
}
return certPath, nil
}

View file

@ -0,0 +1,81 @@
package main
import (
"fmt"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetCertPath(t *testing.T) {
validRoot := t.TempDir()
invalidRoot := t.TempDir()
srcCertPath := filepath.Join("..", "..", "pkg", "cryptoinfo", "testdata", "test_crt.pem")
srcCert, err := os.ReadFile(srcCertPath)
require.NoError(t, err)
validCertPath := filepath.Join(validRoot, "certs.pem")
require.NoError(t, os.WriteFile(validCertPath, srcCert, 0644))
invalidCertPath := filepath.Join(invalidRoot, "invalid_cert.pem")
require.NoError(t, os.WriteFile(invalidCertPath, []byte(`INVALID_CERT_CONTENT`), 0644))
tests := []struct {
name string
rootDir string
fleetCert string
expectedPath string
expectError error
}{
{
name: "Default cert path exists",
rootDir: validRoot,
fleetCert: "",
expectedPath: validCertPath,
},
{
name: "Provided cert path exists",
rootDir: validRoot,
fleetCert: srcCertPath,
expectedPath: srcCertPath,
},
{
name: "Default cert does not exist",
rootDir: invalidRoot,
fleetCert: "",
expectedPath: "",
expectError: fmt.Errorf("cert not found at %s", filepath.Join(invalidRoot, "certs.pem")),
},
{
name: "Invalid cert path provided",
rootDir: "",
fleetCert: filepath.Join(validRoot, "blah.pem"),
expectedPath: "",
expectError: fmt.Errorf("cert not found at %s", filepath.Join(validRoot, "blah.pem")),
},
{
name: "Invalid PEM format",
rootDir: "",
fleetCert: invalidCertPath,
expectedPath: "",
expectError: fmt.Errorf("invalid PEM format %s: no valid certificates found in %s", invalidCertPath, invalidCertPath),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path, err := getCertPath(tt.rootDir, tt.fleetCert)
if tt.expectError != nil {
require.Error(t, err)
require.Empty(t, path)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedPath, path)
}
})
}
}