From 02990204e1b763291fe70c5fb2fecc31a0465122 Mon Sep 17 00:00:00 2001 From: Stephan Dilly Date: Tue, 12 Jan 2021 15:49:28 +0100 Subject: [PATCH] fix crash when no remote called 'origin' present (#486) --- asyncgit/src/sync/branch.rs | 7 ++-- asyncgit/src/sync/cred.rs | 33 ++++++++--------- asyncgit/src/sync/mod.rs | 4 +-- asyncgit/src/sync/remotes.rs | 69 ++++++++++++++++++++++++++++++++---- src/components/push.rs | 21 +++++------ 5 files changed, 92 insertions(+), 42 deletions(-) diff --git a/asyncgit/src/sync/branch.rs b/asyncgit/src/sync/branch.rs index 871958bc..67455254 100644 --- a/asyncgit/src/sync/branch.rs +++ b/asyncgit/src/sync/branch.rs @@ -1,5 +1,6 @@ //! +use super::{remotes::get_first_remote_in_repo, utils::bytes2string}; use crate::{ error::{Error, Result}, sync::{utils, CommitId}, @@ -8,8 +9,6 @@ use git2::{BranchType, Repository}; use scopetime::scope_time; use utils::get_head_repo; -use super::utils::bytes2string; - /// returns the branch-name head is currently pointing to /// this might be expensive, see `cached::BranchName` pub(crate) fn get_branch_name(repo_path: &str) -> Result { @@ -98,8 +97,8 @@ pub(crate) fn branch_set_upstream( repo.find_branch(branch_name, BranchType::Local)?; if branch.upstream().is_err() { - //TODO: what about other remote names - let upstream_name = format!("origin/{}", branch_name); + let remote = get_first_remote_in_repo(repo)?; + let upstream_name = format!("{}/{}", remote, branch_name); branch.set_upstream(Some(upstream_name.as_str()))?; } diff --git a/asyncgit/src/sync/cred.rs b/asyncgit/src/sync/cred.rs index fbd23044..8d8706a7 100644 --- a/asyncgit/src/sync/cred.rs +++ b/asyncgit/src/sync/cred.rs @@ -5,6 +5,8 @@ use git2::{Config, CredentialHelper}; use crate::error::{Error, Result}; use crate::CWD; +use super::remotes::get_first_remote_in_repo; + /// basic Authentication Credentials #[derive(Debug, Clone, Default, PartialEq)] pub struct BasicAuthCredential { @@ -29,10 +31,10 @@ impl BasicAuthCredential { } /// know if username and password are needed for this url -pub fn need_username_password(remote: &str) -> Result { +pub fn need_username_password() -> Result { let repo = crate::sync::utils::repo(CWD)?; let url = repo - .find_remote(remote)? + .find_remote(&get_first_remote_in_repo(&repo)?)? .url() .ok_or(Error::UnknownRemote)? .to_owned(); @@ -41,12 +43,10 @@ pub fn need_username_password(remote: &str) -> Result { } /// extract username and password -pub fn extract_username_password( - remote: &str, -) -> Result { +pub fn extract_username_password() -> Result { let repo = crate::sync::utils::repo(CWD)?; let url = repo - .find_remote(remote)? + .find_remote(&get_first_remote_in_repo(&repo)?)? .url() .ok_or(Error::UnknownRemote)? .to_owned(); @@ -86,10 +86,11 @@ mod tests { need_username_password, BasicAuthCredential, }; use crate::sync::tests::repo_init; - use crate::sync::DEFAULT_REMOTE_NAME; use serial_test::serial; use std::env; + const DEFAULT_REMOTE_NAME: &str = "origin"; + #[test] fn test_credential_complete() { assert_eq!( @@ -164,10 +165,7 @@ mod tests { repo.remote(DEFAULT_REMOTE_NAME, "http://user@github.com") .unwrap(); - assert_eq!( - need_username_password(DEFAULT_REMOTE_NAME).unwrap(), - true - ); + assert_eq!(need_username_password().unwrap(), true); } #[test] @@ -181,10 +179,7 @@ mod tests { repo.remote(DEFAULT_REMOTE_NAME, "git@github.com:user/repo") .unwrap(); - assert_eq!( - need_username_password(DEFAULT_REMOTE_NAME).unwrap(), - false - ); + assert_eq!(need_username_password().unwrap(), false); } #[test] @@ -198,7 +193,7 @@ mod tests { env::set_current_dir(repo_path).unwrap(); - need_username_password(DEFAULT_REMOTE_NAME).unwrap(); + need_username_password().unwrap(); } #[test] @@ -216,7 +211,7 @@ mod tests { .unwrap(); assert_eq!( - extract_username_password(DEFAULT_REMOTE_NAME).unwrap(), + extract_username_password().unwrap(), BasicAuthCredential::new( Some("user".to_owned()), Some("pass".to_owned()) @@ -236,7 +231,7 @@ mod tests { .unwrap(); assert_eq!( - extract_username_password(DEFAULT_REMOTE_NAME).unwrap(), + extract_username_password().unwrap(), BasicAuthCredential::new(Some("user".to_owned()), None) ); } @@ -252,6 +247,6 @@ mod tests { env::set_current_dir(repo_path).unwrap(); - extract_username_password(DEFAULT_REMOTE_NAME).unwrap(); + extract_username_password().unwrap(); } } diff --git a/asyncgit/src/sync/mod.rs b/asyncgit/src/sync/mod.rs index 653248c8..618039b8 100644 --- a/asyncgit/src/sync/mod.rs +++ b/asyncgit/src/sync/mod.rs @@ -41,8 +41,8 @@ pub use hunks::{reset_hunk, stage_hunk, unstage_hunk}; pub use ignore::add_to_ignore; pub use logwalker::LogWalker; pub use remotes::{ - fetch_origin, get_remotes, push, ProgressNotification, - DEFAULT_REMOTE_NAME, + fetch_origin, get_first_remote, get_remotes, push, + ProgressNotification, }; pub use reset::{reset_stage, reset_workdir}; pub use stash::{get_stashes, stash_apply, stash_drop, stash_save}; diff --git a/asyncgit/src/sync/remotes.rs b/asyncgit/src/sync/remotes.rs index c2134b0e..86d59a07 100644 --- a/asyncgit/src/sync/remotes.rs +++ b/asyncgit/src/sync/remotes.rs @@ -2,12 +2,14 @@ use super::{branch::branch_set_upstream, CommitId}; use crate::{ - error::Result, sync::cred::BasicAuthCredential, sync::utils, + error::{Error, Result}, + sync::cred::BasicAuthCredential, + sync::utils, }; use crossbeam_channel::Sender; use git2::{ Cred, Error as GitError, FetchOptions, PackBuilderStage, - PushOptions, RemoteCallbacks, + PushOptions, RemoteCallbacks, Repository, }; use scopetime::scope_time; @@ -52,9 +54,6 @@ pub enum ProgressNotification { Done, } -/// -pub const DEFAULT_REMOTE_NAME: &str = "origin"; - /// pub fn get_remotes(repo_path: &str) -> Result> { scope_time!("get_remotes"); @@ -67,12 +66,37 @@ pub fn get_remotes(repo_path: &str) -> Result> { Ok(remotes) } +/// +pub fn get_first_remote(repo_path: &str) -> Result { + let repo = utils::repo(repo_path)?; + get_first_remote_in_repo(&repo) +} + +/// +pub(crate) fn get_first_remote_in_repo( + repo: &Repository, +) -> Result { + scope_time!("get_remotes"); + + let remotes = repo.remotes()?; + + let first_remote = remotes + .iter() + .next() + .flatten() + .map(String::from) + .ok_or_else(|| Error::Generic("no remote found".into()))?; + + Ok(first_remote) +} + /// pub fn fetch_origin(repo_path: &str, branch: &str) -> Result { scope_time!("fetch_origin"); let repo = utils::repo(repo_path)?; - let mut remote = repo.find_remote(DEFAULT_REMOTE_NAME)?; + let mut remote = + repo.find_remote(&get_first_remote_in_repo(&repo)?)?; let mut options = FetchOptions::new(); options.remote_callbacks(remote_callbacks(None, None)); @@ -247,8 +271,39 @@ mod tests { let remotes = get_remotes(repo_path).unwrap(); - assert_eq!(remotes, vec![String::from(DEFAULT_REMOTE_NAME)]); + assert_eq!(remotes, vec![String::from("origin")]); fetch_origin(repo_path, "master").unwrap(); } + + #[test] + fn test_first_remote() { + let td = TempDir::new().unwrap(); + + debug_cmd_print( + td.path().as_os_str().to_str().unwrap(), + "git clone https://github.com/extrawurst/brewdump.git", + ); + + debug_cmd_print( + td.path().as_os_str().to_str().unwrap(), + "cd brewdump && git remote add second https://github.com/extrawurst/brewdump.git", + ); + + let repo_path = td.path().join("brewdump"); + let repo_path = repo_path.as_os_str().to_str().unwrap(); + + let remotes = get_remotes(repo_path).unwrap(); + + assert_eq!( + remotes, + vec![String::from("origin"), String::from("second")] + ); + + let first = get_first_remote_in_repo( + &utils::repo(repo_path).unwrap(), + ) + .unwrap(); + assert_eq!(first, String::from("origin")); + } } diff --git a/src/components/push.rs b/src/components/push.rs index 21f183f1..41ec9cb5 100644 --- a/src/components/push.rs +++ b/src/components/push.rs @@ -10,13 +10,15 @@ use crate::{ }; use anyhow::Result; use asyncgit::{ - sync::cred::{ - extract_username_password, need_username_password, - BasicAuthCredential, + sync::{ + cred::{ + extract_username_password, need_username_password, + BasicAuthCredential, + }, + get_first_remote, }, - sync::DEFAULT_REMOTE_NAME, AsyncNotification, AsyncPush, PushProgress, PushProgressState, - PushRequest, + PushRequest, CWD, }; use crossbeam_channel::Sender; use crossterm::event::Event; @@ -69,9 +71,9 @@ impl PushComponent { pub fn push(&mut self, branch: String) -> Result<()> { self.branch = branch; self.show()?; - if need_username_password(DEFAULT_REMOTE_NAME)? { - let cred = extract_username_password(DEFAULT_REMOTE_NAME) - .unwrap_or_else(|_| { + if need_username_password()? { + let cred = + extract_username_password().unwrap_or_else(|_| { BasicAuthCredential::new(None, None) }); if cred.is_complete() { @@ -92,8 +94,7 @@ impl PushComponent { self.pending = true; self.progress = None; self.git_push.request(PushRequest { - //TODO: find tracking branch name - remote: String::from(DEFAULT_REMOTE_NAME), + remote: get_first_remote(CWD)?, branch: self.branch.clone(), basic_credential: cred, })?;