#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""merge_pr.py - Merge Apache Zeppelin pull requests via the GitHub API.

Optionally cherry-picks into release branches and resolves JIRA issues.
No external dependencies — uses only Python 3 built-in libraries.

Usage:
    python3 dev/merge_pr.py --pr 5167 --dry-run
    python3 dev/merge_pr.py --pr 5167 --resolve-jira --fix-versions 0.13.0
    python3 dev/merge_pr.py --pr 5167 --resolve-jira --release-branches branch-0.12
"""

import argparse
import json
import os
import re
import subprocess
import sys
import urllib.error
import urllib.request

GITHUB_API_BASE = "https://api.github.com/repos/apache/zeppelin"
JIRA_API_BASE = "https://issues.apache.org/jira/rest/api/2"

DEFAULT_BRANCH = "master"
DEFAULT_REMOTE = "apache"
JIRA_RESOLVE_TRANSITION = "Resolve Issue"
JIRA_CLOSED_STATUSES = frozenset(("Resolved", "Closed"))

# Upper bound for a single HTTP request so a stalled connection cannot hang
# the tool forever.
HTTP_TIMEOUT_SECONDS = 60

JIRA_ID_RE = re.compile(r"ZEPPELIN-\d{3,6}")
TITLE_FORMATTED_RE = re.compile(r"^\[ZEPPELIN-\d{3,6}](\[[A-Z0-9_\s,]+] )+\S+")
TITLE_REF_RE = re.compile(r"(?i)(ZEPPELIN[-\s]*\d{3,6})")
COMPONENT_RE = re.compile(r"(?i)(\[[\w\s,.\-]+])")
WHITESPACE_RE = re.compile(r"\s+")
LEADING_NON_WORD_RE = re.compile(r"^\W+")
SEMANTIC_VER_RE = re.compile(r"^\d+\.\d+\.\d+$")


class MergePR:
    """Merge one pull request, optionally cherry-pick it and resolve JIRA.

    Configuration is read from the parsed command-line ``args``; the push
    remote and both tokens fall back to the ``PUSH_REMOTE_NAME``,
    ``GITHUB_OAUTH_KEY`` and ``JIRA_ACCESS_TOKEN`` environment variables.
    """

    def __init__(self, args):
        self.pr = args.pr
        self.target = args.target or ""
        self.fix_versions = _parse_csv(args.fix_versions) if args.fix_versions else []
        self.release_branches = _parse_csv(args.release_branches) if args.release_branches else []
        self.resolve_jira = args.resolve_jira
        self.dry_run = args.dry_run
        self.push_remote = args.push_remote or os.environ.get("PUSH_REMOTE_NAME", DEFAULT_REMOTE)
        self.github_token = args.github_token or os.environ.get("GITHUB_OAUTH_KEY", "")
        self.jira_token = args.jira_token or os.environ.get("JIRA_ACCESS_TOKEN", "")

    # ── Git ──────────────────────────────────────────────────────────────

    def _git(self, *args):
        """Run ``git <args>`` and return its stripped stdout.

        Raises RuntimeError when git exits non-zero or cannot be started, so
        callers only ever need to handle one exception type.
        """
        try:
            result = subprocess.run(
                ["git", *args],
                capture_output=True, text=True,
            )
        except OSError as e:
            # e.g. git is not installed / not on PATH.
            raise RuntimeError(f"git {' '.join(args)} failed:\n{e}") from e
        if result.returncode != 0:
            output = (result.stdout + result.stderr).strip()
            raise RuntimeError(f"git {' '.join(args)} failed:\n{output}")
        return result.stdout.strip()

    def _git_current_ref(self):
        """Return the current branch name, or the commit hash on a detached HEAD."""
        ref = self._git("rev-parse", "--abbrev-ref", "HEAD")
        return self._git("rev-parse", "HEAD") if ref == "HEAD" else ref

    # ── HTTP ─────────────────────────────────────────────────────────────

    def _http(self, method, url, payload=None, auth=""):
        """Send a JSON request and return ``(status_code, decoded_body)``.

        HTTP error statuses are returned, not raised, so callers can inspect
        the code. An empty body (e.g. a 204 No Content reply) decodes to
        ``{}``. Transport-level failures (unreachable host, timeout) raise
        RuntimeError.
        """
        data = json.dumps(payload).encode() if payload is not None else None
        req = urllib.request.Request(url, data=data, method=method)
        req.add_header("Content-Type", "application/json")
        req.add_header("Accept", "application/json")
        if auth:
            req.add_header("Authorization", auth)
        try:
            with urllib.request.urlopen(req, timeout=HTTP_TIMEOUT_SECONDS) as resp:
                return resp.status, _decode_json_body(resp.read())
        except urllib.error.HTTPError as e:
            return e.code, _decode_json_body(e.read() if e.fp else b"")
        except OSError as e:
            # URLError and socket timeouts are OSError subclasses; surface them
            # as RuntimeError so the callers' existing handlers apply.
            raise RuntimeError(f"{method} {url} failed: {getattr(e, 'reason', e)}") from e

    # ── GitHub ───────────────────────────────────────────────────────────

    def _gh_auth(self):
        """Return the GitHub Authorization header value, or "" without a token."""
        return f"token {self.github_token}" if self.github_token else ""

    def _gh_get_pr(self, num):
        """Fetch pull request *num*; raise RuntimeError unless HTTP 200."""
        code, data = self._http("GET", f"{GITHUB_API_BASE}/pulls/{num}", auth=self._gh_auth())
        if code != 200:
            raise RuntimeError(f"GET PR #{num}: HTTP {code}")
        return data

    def _gh_merge_pr(self, num, title, msg):
        """Squash-merge pull request *num*; raise RuntimeError unless HTTP 200."""
        payload = {"commit_title": title, "commit_message": msg, "merge_method": "squash"}
        code, data = self._http("PUT", f"{GITHUB_API_BASE}/pulls/{num}/merge", payload, self._gh_auth())
        if code == 405:
            raise RuntimeError(f"Merge PR #{num} is not allowed")
        if code != 200:
            raise RuntimeError(f"Merge PR #{num}: HTTP {code}")
        return data

    def _gh_comment_pr(self, num, comment):
        """Post *comment* on pull request *num*.

        Returns True when GitHub answered 201; otherwise prints a warning and
        returns False.
        """
        code, _ = self._http("POST", f"{GITHUB_API_BASE}/issues/{num}/comments",
                             {"body": comment}, self._gh_auth())
        if code != 201:
            print(f"Warning: comment PR #{num}: HTTP {code}", file=sys.stderr)
            return False
        return True

    # ── JIRA ─────────────────────────────────────────────────────────────

    def _jira_auth(self):
        """Return the JIRA Authorization header value, or "" without a token."""
        return f"Bearer {self.jira_token}" if self.jira_token else ""

    def _jira_get_issue(self, key):
        """Fetch JIRA issue *key*; raise RuntimeError unless HTTP 200."""
        code, data = self._http("GET", f"{JIRA_API_BASE}/issue/{key}", auth=self._jira_auth())
        if code != 200:
            raise RuntimeError(f"GET {key}: HTTP {code}")
        return data

    def _jira_unreleased_versions(self):
        """Return unreleased, unarchived X.Y.Z versions, newest first.

        Each entry is ``{"id": str, "name": str}``.
        """
        code, data = self._http("GET", f"{JIRA_API_BASE}/project/ZEPPELIN/versions", auth=self._jira_auth())
        if code != 200:
            raise RuntimeError(f"GET versions: HTTP {code}")
        versions = []
        for v in data:
            name = v.get("name", "")
            if not v.get("released") and not v.get("archived") and SEMANTIC_VER_RE.match(name):
                versions.append({"id": str(v["id"]), "name": name})
        versions.sort(key=lambda v: _ver_tuple(v["name"]), reverse=True)
        return versions

    def _jira_transitions(self, key):
        """Return the transitions available for issue *key* as ``{"id", "name"}`` dicts."""
        code, data = self._http("GET", f"{JIRA_API_BASE}/issue/{key}/transitions", auth=self._jira_auth())
        if code != 200:
            raise RuntimeError(f"GET transitions {key}: HTTP {code}")
        return [{"id": t["id"], "name": t["name"]} for t in data.get("transitions", [])]

    def _jira_resolve(self, key, transition_id, fix_ver, comment):
        """Apply *transition_id* to issue *key*, adding *comment* and fix versions.

        Raises RuntimeError unless JIRA answers 204.
        """
        payload = {
            "transition": {"id": transition_id},
            "update": {
                "comment": [{"add": {"body": comment}}],
                "fixVersions": [{"add": {"id": fv["id"], "name": fv["name"]}} for fv in fix_ver],
            },
        }
        code, _ = self._http("POST", f"{JIRA_API_BASE}/issue/{key}/transitions", payload, self._jira_auth())
        if code != 204:
            raise RuntimeError(f"Resolve {key}: HTTP {code}")

    # ── Fix version resolution ───────────────────────────────────────────

    def _resolve_fix_versions(self, branches, versions):
        """Resolve fix version objects from explicit --fix-versions and branch inference.

        *versions* must be sorted newest first (as returned by
        ``_jira_unreleased_versions``).

        Returns a list of version dicts ({"id": ..., "name": ...}).
        Raises RuntimeError if an explicit fix version is not found.
        """
        vm = {v["name"]: v for v in versions}
        fix_ver, seen = [], set()

        for fv in self.fix_versions:
            if fv not in vm:
                raise RuntimeError(f'fix version "{fv}" not found')
            fix_ver.append(vm[fv])
            seen.add(fv)

        # The default branch is only inferred when nothing was given explicitly.
        infer_master = not self.fix_versions
        latest = versions[0]["name"] if versions else None
        names = []
        for branch in branches:
            if branch == DEFAULT_BRANCH:
                if infer_master and latest is not None and latest not in seen:
                    names.append(latest)
                    seen.add(latest)
            else:
                prefix = branch[len("branch-"):] if branch.startswith("branch-") else branch
                found = [v["name"] for v in versions if v["name"].startswith(prefix + ".") or v["name"] == prefix]
                if found:
                    pick = found[-1]  # smallest matching (list is desc-sorted)
                    if pick not in seen:
                        names.append(pick)
                        seen.add(pick)
                else:
                    print(f"Warning: no version found for {branch}, skipping", file=sys.stderr)

        # Remove redundant X.Y.0 when X.(Y-1).0 is also present
        filtered = []
        for v in names:
            parts = v.split(".")
            if len(parts) == 3 and parts[2] == "0":
                minor = int(parts[1])
                if minor > 0 and f"{parts[0]}.{minor - 1}.0" in seen:
                    continue
            filtered.append(v)

        inferred = [vm[n] for n in filtered if n in vm]
        if inferred:
            print(f"Auto-inferred fix version(s): {', '.join(filtered)}")
        fix_ver.extend(inferred)
        return fix_ver

    # ── Effective command ────────────────────────────────────────────────

    def _print_effective_command(self, target_branch, fix_ver):
        """Print the command line equivalent to the resolved settings (dry-run output)."""
        parts = ["python3 dev/merge_pr.py", f"--pr {self.pr}"]
        if target_branch and target_branch != DEFAULT_BRANCH:
            parts.append(f"--target {target_branch}")
        if self.release_branches:
            parts.append(f"--release-branches {','.join(self.release_branches)}")
        if self.resolve_jira:
            parts.append("--resolve-jira")
        if fix_ver:
            parts.append(f"--fix-versions {','.join(fv['name'] for fv in fix_ver)}")
        if self.push_remote != DEFAULT_REMOTE:
            parts.append(f"--push-remote {self.push_remote}")
        print(f"[dry-run] Effective command:\n  {' '.join(parts)}")

    # ── Main flow ────────────────────────────────────────────────────────

    def run(self):
        """Merge the PR, cherry-pick into release branches, comment, resolve JIRA.

        Failures before or during the merge raise RuntimeError. Once the merge
        has succeeded, every later step is best-effort: a failure prints a
        warning and the remaining steps still run.
        """
        original_head = self._git_current_ref()

        pr_data = self._gh_get_pr(self.pr)
        if not pr_data.get("mergeable"):
            raise RuntimeError(f"PR #{self.pr} is not mergeable")
        pr_title = pr_data["title"]
        if "[WIP]" in pr_title:
            print(f"WARNING: PR title contains [WIP]: {pr_title}", file=sys.stderr)

        target_branch = self.target or pr_data["base"]["ref"]
        title = _standardize_title(pr_title)
        src = f"{pr_data['user']['login']}/{pr_data['head']['ref']}"
        pr_body = pr_data.get("body", "") or ""

        print(f"=== Pull Request #{self.pr} ===")
        print(f"title: {title}")
        print(f"source: {src}")
        print(f"target: {target_branch}")
        print(f"url: {pr_data['url']}")
        if self.release_branches:
            print(f"release-branches: {', '.join(self.release_branches)}")

        # Resolve fix versions once (used for both dry-run display and actual JIRA resolution)
        fix_ver = []
        if self.resolve_jira and self.jira_token and JIRA_ID_RE.search(title):
            try:
                versions = self._jira_unreleased_versions()
                if versions:
                    branches = [target_branch] + self.release_branches
                    fix_ver = self._resolve_fix_versions(branches, versions)
            except RuntimeError as e:
                print(f"Warning: failed to resolve fix versions: {e}", file=sys.stderr)

        if self.dry_run:
            print()
            self._print_effective_command(target_branch, fix_ver)
            return

        # Merge
        body = pr_body.replace("@", "")  # drop "@" from the PR body before using it as the commit message
        try:
            name = self._git("config", "--get", "user.name")
        except RuntimeError:
            name = ""
        try:
            email = self._git("config", "--get", "user.email")
        except RuntimeError:
            email = ""
        msg = f"{body}\n\nCloses #{self.pr} from {src}.\n\nSigned-off-by: {name} <{email}>"

        merge_data = self._gh_merge_pr(self.pr, title, msg)
        sha = merge_data["sha"]
        print(f"\nPR #{self.pr} merged! (hash: {_short_sha(sha)})")

        # The merge commit must exist locally before it can be cherry-picked.
        try:
            self._git("fetch", self.push_remote, target_branch)
        except RuntimeError as e:
            print(f"Warning: fetch {target_branch} failed: {e}", file=sys.stderr)

        # Cherry-pick into release branches
        merged = [target_branch]
        for branch in self.release_branches:
            if self._cherry_pick_into(branch, sha, original_head):
                merged.append(branch)

        self._comment_merge_summary(merged, sha)

        if self.resolve_jira:
            try:
                self._do_resolve_jira(title, fix_ver)
            except RuntimeError as e:
                print(f"Warning: JIRA resolution failed: {e}", file=sys.stderr)

    def _cherry_pick_into(self, branch, sha, original_head):
        """Cherry-pick *sha* onto *branch* through a temporary local branch.

        Returns True when the pick was pushed; failures are reported as
        warnings and return False. The original checkout is restored and the
        temporary branch deleted in either case.
        """
        pick = _pick_branch_name(self.pr, branch)
        try:
            self._git("fetch", self.push_remote, f"{branch}:{pick}")
        except RuntimeError as e:
            print(f"Warning: fetch {branch} failed: {e}", file=sys.stderr)
            return False
        try:
            self._git("checkout", pick)
            self._git("cherry-pick", "-sx", sha)
            self._git("push", self.push_remote, f"{pick}:{branch}")
            h = self._git("rev-parse", pick)
            print(f"Picked into {branch} (hash: {_short_sha(h)})")
            return True
        except RuntimeError as e:
            print(f"Warning: cherry-pick/push into {branch} failed: {e}", file=sys.stderr)
            try:
                self._git("cherry-pick", "--abort")
            except RuntimeError:
                pass  # no cherry-pick in progress (e.g. checkout or push failed)
            return False
        finally:
            self._restore_checkout(original_head, pick)

    def _restore_checkout(self, original_head, pick):
        """Check out *original_head* and delete temporary branch *pick*, warning on failure."""
        for git_args in (("checkout", original_head), ("branch", "-D", pick)):
            try:
                self._git(*git_args)
            except RuntimeError as e:
                print(f"Warning: cleanup failed: {e}", file=sys.stderr)

    def _comment_merge_summary(self, merged, sha):
        """Comment on the PR listing the merge target and cherry-picked branches."""
        lines = [f"Merged into {merged[0]} ({_short_sha(sha)})."]
        for branch in merged[1:]:
            lines.append(f"Cherry-picked into {branch}.")
        try:
            # Only claim success when the comment was actually created.
            if self._gh_comment_pr(self.pr, "\n".join(lines)):
                print("Commented on PR with merge summary.")
        except RuntimeError as e:
            print(f"Warning: failed to comment on PR: {e}", file=sys.stderr)

    def _do_resolve_jira(self, title, fix_ver):
        """Resolve every JIRA issue referenced in *title*.

        Raises RuntimeError when no JIRA token is configured. Per-issue
        failures are reported as warnings and do not stop the other issues.
        """
        if not self.jira_token:
            raise RuntimeError("JIRA_ACCESS_TOKEN is not set")

        ids = JIRA_ID_RE.findall(title)
        if not ids:
            print("No JIRA ID found in PR title, skipping.")
            return

        for jira_id in ids:
            try:
                issue = self._jira_get_issue(jira_id)
            except RuntimeError as e:
                print(f"Warning: get {jira_id}: {e}", file=sys.stderr)
                continue
            status = issue.get("fields", {}).get("status", {}).get("name", "")
            if status in JIRA_CLOSED_STATUSES:
                print(f'JIRA {jira_id} already "{status}", skipping.')
                continue

            print(f"=== JIRA {jira_id} ===")
            print(f"Summary: {issue.get('fields', {}).get('summary', '')}")
            print(f"Status: {status}")

            try:
                transitions = self._jira_transitions(jira_id)
            except RuntimeError as e:
                print(f"Warning: transitions {jira_id}: {e}", file=sys.stderr)
                continue
            resolve_id = next((t["id"] for t in transitions if t["name"] == JIRA_RESOLVE_TRANSITION), None)
            if not resolve_id:
                print(f"Warning: no '{JIRA_RESOLVE_TRANSITION}' transition for {jira_id}", file=sys.stderr)
                continue

            jira_comment = (
                f"Issue resolved by pull request {self.pr}"
                f"\n[https://github.com/apache/zeppelin/pull/{self.pr}]"
            )
            try:
                self._jira_resolve(jira_id, resolve_id, fix_ver, jira_comment)
                print(f"Resolved {jira_id}!")
            except RuntimeError as e:
                print(f"Warning: resolve {jira_id}: {e}", file=sys.stderr)


# ── Module-level utilities ───────────────────────────────────────────────

def _decode_json_body(raw):
    """Decode an HTTP response body.

    Returns ``{}`` for an empty body, the parsed JSON when the body is valid
    JSON, and ``{"error": <text>}`` otherwise.
    """
    text = raw.decode(errors="replace").strip() if raw else ""
    if not text:
        return {}
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {"error": text}


def _parse_csv(value):
    """Split a comma-separated string into stripped, non-empty items."""
    return [s.strip() for s in value.split(",") if s.strip()] if value else []


def _ver_tuple(v):
    """Turn "X.Y.Z" into a tuple of ints so versions sort numerically."""
    return tuple(int(x) for x in v.split("."))


def _short_sha(sha):
    """Return the first 8 characters of *sha*."""
    return sha[:8]


def _pick_branch_name(pr_num, branch):
    """Name of the temporary local branch used to cherry-pick a PR onto *branch*."""
    return f"PR_TOOL_PICK_PR_{pr_num}_{branch.upper()}"


def _standardize_title(text):
    """Normalize a PR title to ``[ZEPPELIN-N][COMPONENT] summary``.

    Revert titles and titles already in the standard form are returned as-is
    (minus trailing periods). JIRA references written loosely
    ("zeppelin 1234", "ZEPPELIN - 1234", "ZEPPELIN1234") are rewritten to
    "ZEPPELIN-1234" so JIRA_ID_RE recognizes them; a reference repeated in
    the title is emitted once.
    """
    text = text.rstrip(".")
    if text.startswith('Revert "') and text.endswith('"'):
        return text
    if TITLE_FORMATTED_RE.match(text):
        return text

    jira_refs = []
    for m in TITLE_REF_RE.finditer(text):
        ref = m.group(1)
        # A match is "ZEPPELIN" + separators + digits; keep only the digits.
        number = "".join(ch for ch in ref if ch.isdigit())
        tag = f"[ZEPPELIN-{number}]"
        if tag not in jira_refs:
            jira_refs.append(tag)
        text = text.replace(ref, "")

    components = []
    for m in COMPONENT_RE.finditer(text):
        comp = m.group(1)
        components.append(comp.upper())
        text = text.replace(comp, "")

    text = LEADING_NON_WORD_RE.sub("", text)
    result = "".join(jira_refs) + "".join(components) + " " + text
    return WHITESPACE_RE.sub(" ", result.strip())


# ── Entry point ──────────────────────────────────────────────────────────

def main():
    """Parse command-line flags and run the merge; exit 1 with a message on failure."""
    parser = argparse.ArgumentParser(
        description="Merge Apache Zeppelin pull requests",
        usage="python3 dev/merge_pr.py [flags]",
    )
    parser.add_argument("--pr", type=int, required=True, help="Pull request number")
    parser.add_argument("--target", default="", help="Target branch (default: PR base branch)")
    parser.add_argument("--fix-versions", default="", help="JIRA fix version(s), comma-separated")
    parser.add_argument("--release-branches", default="", help="Release branch(es) to cherry-pick into, comma-separated")
    parser.add_argument("--resolve-jira", action="store_true", help="Resolve associated JIRA issue(s)")
    parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")
    parser.add_argument("--push-remote", default="", help="Git remote for pushing (default: apache)")
    parser.add_argument("--github-token", default="", help="GitHub OAuth token (env: GITHUB_OAUTH_KEY)")
    parser.add_argument("--jira-token", default="", help="JIRA access token (env: JIRA_ACCESS_TOKEN)")

    args = parser.parse_args()
    try:
        MergePR(args).run()
    except RuntimeError as e:
        # Expected operational failures get a clean message, not a traceback.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()