Skip to content

Commit 53a442a

Browse files
authored
fix(codeagent): get default branch accordingly (#266)
1 parent 8fc5d0e commit 53a442a

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

internal/github/client.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ func (c *Client) CreatePullRequest(workspace *models.Workspace) (*github.PullReq
125125
return nil, fmt.Errorf("invalid repository URL: %s", workspace.Repository)
126126
}
127127

128+
// 获取仓库的默认分支
129+
defaultBranch, err := c.getDefaultBranch(repoOwner, repoName)
130+
if err != nil {
131+
log.Errorf("Failed to get default branch for %s/%s, using 'main' as fallback: %v", repoOwner, repoName, err)
132+
defaultBranch = "main"
133+
}
134+
log.Infof("Using default branch '%s' for repository %s/%s", defaultBranch, repoOwner, repoName)
135+
128136
// 创建 PR
129137
prTitle := fmt.Sprintf("实现 Issue #%d: %s", workspace.Issue.GetNumber(), workspace.Issue.GetTitle())
130138
prBody := fmt.Sprintf(`## 实现计划
@@ -149,7 +157,7 @@ func (c *Client) CreatePullRequest(workspace *models.Workspace) (*github.PullReq
149157
Title: &prTitle,
150158
Body: &prBody,
151159
Head: &workspace.Branch,
152-
Base: github.String("main"), // 假设主分支是 main
160+
Base: &defaultBranch,
153161
}
154162

155163
pr, _, err := c.client.PullRequests.Create(context.Background(), repoOwner, repoName, newPR)
@@ -161,6 +169,21 @@ func (c *Client) CreatePullRequest(workspace *models.Workspace) (*github.PullReq
161169
return pr, nil
162170
}
163171

172+
// getDefaultBranch 获取仓库的默认分支
173+
func (c *Client) getDefaultBranch(owner, repo string) (string, error) {
174+
repository, _, err := c.client.Repositories.Get(context.Background(), owner, repo)
175+
if err != nil {
176+
return "", fmt.Errorf("failed to get repository info: %w", err)
177+
}
178+
179+
defaultBranch := repository.GetDefaultBranch()
180+
if defaultBranch == "" {
181+
return "", fmt.Errorf("repository has no default branch")
182+
}
183+
184+
return defaultBranch, nil
185+
}
186+
164187
// CommitAndPush 检测文件变更并提交推送
165188
func (c *Client) CommitAndPush(workspace *models.Workspace, result *models.ExecutionResult, codeClient code.Code) error {
166189
// 检查是否有文件变更

0 commit comments

Comments
 (0)