From c64183717bbb04a0147d3088146a3518e0b0b3a1 Mon Sep 17 00:00:00 2001 From: Cayde6 Date: Thu, 11 Sep 2025 23:05:57 +0900 Subject: [PATCH] fix: replace grpc.NewClient (#19653) (#24188) Signed-off-by: Jack-R-lantern --- .../commands/argocd_git_ask_pass.go | 2 +- cmpserver/apiclient/clientset.go | 2 +- commitserver/apiclient/clientset.go | 4 +- pkg/apiclient/apiclient.go | 2 +- reposerver/apiclient/clientset.go | 3 +- server/server.go | 4 +- util/grpc/grpc.go | 101 ++++++++---------- util/grpc/grpc_test.go | 2 +- 8 files changed, 50 insertions(+), 70 deletions(-) diff --git a/cmd/argocd-git-ask-pass/commands/argocd_git_ask_pass.go b/cmd/argocd-git-ask-pass/commands/argocd_git_ask_pass.go index b4780e2515..01c7c95f99 100644 --- a/cmd/argocd-git-ask-pass/commands/argocd_git_ask_pass.go +++ b/cmd/argocd-git-ask-pass/commands/argocd_git_ask_pass.go @@ -35,7 +35,7 @@ func NewCommand() *cobra.Command { if nonce == "" { errors.CheckError(fmt.Errorf("%s is not set", askpass.ASKPASS_NONCE_ENV)) } - conn, err := grpc_util.BlockingDial(ctx, "unix", askpass.SocketPath, nil, grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc_util.BlockingNewClient(ctx, "unix", askpass.SocketPath, nil, grpc.WithTransportCredentials(insecure.NewCredentials())) errors.CheckError(err) defer utilio.Close(conn) client := askpass.NewAskPassServiceClient(conn) diff --git a/cmpserver/apiclient/clientset.go b/cmpserver/apiclient/clientset.go index 1e3bd68c6f..21cb77f6b2 100644 --- a/cmpserver/apiclient/clientset.go +++ b/cmpserver/apiclient/clientset.go @@ -52,7 +52,7 @@ func NewConnection(address string) (*grpc.ClientConn, error) { } dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) - conn, err := grpc_util.BlockingDial(context.Background(), "unix", address, nil, dialOpts...) + conn, err := grpc_util.BlockingNewClient(context.Background(), "unix", address, nil, dialOpts...) if err != nil { log.Errorf("Unable to connect to config management plugin service with address %s", address) return nil, err diff --git a/commitserver/apiclient/clientset.go b/commitserver/apiclient/clientset.go index a89efc3eb8..be748d9384 100644 --- a/commitserver/apiclient/clientset.go +++ b/commitserver/apiclient/clientset.go @@ -40,9 +40,7 @@ func NewConnection(address string) (*grpc.ClientConn, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - // TODO: switch to grpc.NewClient. - //nolint:staticcheck - conn, err := grpc.Dial(address, opts...) + conn, err := grpc.NewClient(address, opts...) if err != nil { log.Errorf("Unable to connect to commit service with address %s", address) return nil, err diff --git a/pkg/apiclient/apiclient.go b/pkg/apiclient/apiclient.go index 5f0a4b2950..43d14494d4 100644 --- a/pkg/apiclient/apiclient.go +++ b/pkg/apiclient/apiclient.go @@ -542,7 +542,7 @@ func (c *client) newConn() (*grpc.ClientConn, io.Closer, error) { if c.UserAgent != "" { dialOpts = append(dialOpts, grpc.WithUserAgent(c.UserAgent)) } - conn, e := grpc_util.BlockingDial(ctx, network, serverAddr, creds, dialOpts...) + conn, e := grpc_util.BlockingNewClient(ctx, network, serverAddr, creds, dialOpts...) closers = append(closers, conn) return conn, utilio.NewCloser(func() error { var firstErr error diff --git a/reposerver/apiclient/clientset.go b/reposerver/apiclient/clientset.go index def437e878..9c6b5c09eb 100644 --- a/reposerver/apiclient/clientset.go +++ b/reposerver/apiclient/clientset.go @@ -82,8 +82,7 @@ func NewConnection(address string, timeoutSeconds int, tlsConfig *TLSConfigurati opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } - //nolint:staticcheck - conn, err := grpc.Dial(address, opts...) + conn, err := grpc.NewClient(address, opts...) if err != nil { log.Errorf("Unable to connect to repository service with address %s", address) return nil, err diff --git a/server/server.go b/server/server.go index 124b08eaac..afce751e82 100644 --- a/server/server.go +++ b/server/server.go @@ -534,8 +534,8 @@ func (server *ArgoCDServer) Listen() (*Listeners, error) { } else { dOpts = append(dOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) } - //nolint:staticcheck - conn, err := grpc.Dial(fmt.Sprintf("localhost:%d", server.ListenPort), dOpts...) + + conn, err := grpc.NewClient(fmt.Sprintf("localhost:%d", server.ListenPort), dOpts...) if err != nil { utilio.Close(mainLn) utilio.Close(metricsLn) diff --git a/util/grpc/grpc.go b/util/grpc/grpc.go index 28265b7388..13ab17012a 100644 --- a/util/grpc/grpc.go +++ b/util/grpc/grpc.go @@ -14,6 +14,7 @@ import ( "golang.org/x/net/proxy" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -30,73 +31,55 @@ func LoggerRecoveryHandler(log *logrus.Entry) recovery.RecoveryHandlerFunc { } } -// BlockingDial is a helper method to dial the given address, using optional TLS credentials, +// BlockingNewClient is a helper method to dial the given address, using optional TLS credentials, // and blocking until the returned connection is ready. If the given credentials are nil, the // connection will be insecure (plain-text). // Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go -func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { - // grpc.Dial doesn't provide any information on permanent connection errors (like - // TLS handshake failures). So in order to provide good error messages, we need a - // custom dialer that can provide that info. That means we manage the TLS handshake. - result := make(chan any, 1) - writeResult := func(res any) { - // non-blocking write: we only need the first result - select { - case result <- res: - default: - } +func BlockingNewClient(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + proxyDialer := proxy.FromEnvironment() + rawConn, err := proxyDialer.Dial(network, address) + if err != nil { + return nil, fmt.Errorf("error dial proxy: %w", err) } - dialer := func(ctx context.Context, address string) (net.Conn, error) { - proxyDialer := proxy.FromEnvironment() - conn, err := proxyDialer.Dial(network, address) + if creds != nil { + rawConn, _, err = creds.ClientHandshake(ctx, address, rawConn) if err != nil { - writeResult(err) - return nil, fmt.Errorf("error dial proxy: %w", err) + return nil, fmt.Errorf("error creating connection: %w", err) } - if creds != nil { - conn, _, err = creds.ClientHandshake(ctx, address, conn) - if err != nil { - writeResult(err) - return nil, fmt.Errorf("error creating connection: %w", err) - } - } - return conn, nil + } + customDialer := func(_ context.Context, _ string) (net.Conn, error) { + return rawConn, nil } - // Even with grpc.FailOnNonTempDialError, this call will usually timeout in - // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to - // know when we're done. So we run it in a goroutine and then use result - // channel to either get the channel or fail-fast. - go func() { - opts = append(opts, - //nolint:staticcheck - grpc.WithBlock(), - //nolint:staticcheck - grpc.FailOnNonTempDialError(true), - grpc.WithContextDialer(dialer), - grpc.WithTransportCredentials(insecure.NewCredentials()), // we are handling TLS, so tell grpc not to - grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: common.GetGRPCKeepAliveTime()}), - ) - //nolint:staticcheck - conn, err := grpc.DialContext(ctx, address, opts...) - var res any - if err != nil { - res = err - } else { - res = conn - } - writeResult(res) - }() + opts = append(opts, + grpc.WithContextDialer(customDialer), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithKeepaliveParams(keepalive.ClientParameters{Time: common.GetGRPCKeepAliveTime()}), + ) - select { - case res := <-result: - if conn, ok := res.(*grpc.ClientConn); ok { - return conn, nil + conn, err := grpc.NewClient("passthrough:"+address, opts...) + if err != nil { + return nil, fmt.Errorf("grpc.NewClient failed: %w", err) + } + + conn.Connect() + if err := waitForReady(ctx, conn); err != nil { + return nil, fmt.Errorf("gRPC connection not ready: %w", err) + } + + return conn, nil +} + +func waitForReady(ctx context.Context, conn *grpc.ClientConn) error { + for { + state := conn.GetState() + if state == connectivity.Ready { + return nil + } + if !conn.WaitForStateChange(ctx, state) { + return ctx.Err() // context timeout or cancellation } - return nil, res.(error) - case <-ctx.Done(): - return nil, ctx.Err() } } @@ -120,7 +103,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) { ctx, cancel := context.WithTimeout(context.Background(), dialTime) defer cancel() - conn, err := BlockingDial(ctx, "tcp", address, creds) + conn, err := BlockingNewClient(ctx, "tcp", address, creds) if err == nil { _ = conn.Close() testResult.TLS = true @@ -128,7 +111,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) { ctx, cancel := context.WithTimeout(context.Background(), dialTime) defer cancel() - conn, err := BlockingDial(ctx, "tcp", address, creds) + conn, err := BlockingNewClient(ctx, "tcp", address, creds) if err == nil { _ = conn.Close() } else { @@ -143,7 +126,7 @@ func TestTLS(address string, dialTime time.Duration) (*TLSTestResult, error) { // refused). Test if server accepts plain-text connections ctx, cancel = context.WithTimeout(context.Background(), dialTime) defer cancel() - conn, err = BlockingDial(ctx, "tcp", address, nil) + conn, err = BlockingNewClient(ctx, "tcp", address, nil) if err == nil { _ = conn.Close() testResult.TLS = false diff --git a/util/grpc/grpc_test.go b/util/grpc/grpc_test.go index a42b47dfca..f2d6e48cce 100644 --- a/util/grpc/grpc_test.go +++ b/util/grpc/grpc_test.go @@ -93,7 +93,7 @@ func TestBlockingDial_ProxyEnvironmentHandling(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - conn, err := BlockingDial(ctx, "tcp", tt.address, nil) + conn, err := BlockingNewClient(ctx, "tcp", tt.address, nil) if tt.expectError { require.Error(t, err)