diff --git a/gcp/workers/recoverer/recoverer.py b/gcp/workers/recoverer/recoverer.py index 3ceee794f5d..5d8f0aa996f 100644 --- a/gcp/workers/recoverer/recoverer.py +++ b/gcp/workers/recoverer/recoverer.py @@ -187,6 +187,23 @@ def transaction(): else: modified = datetime.datetime.now(datetime.UTC) + elif f == 'related': + related_group = osv.RelatedGroup.get_by_id(vuln_id) + if related_group is None: + related = [] + related_modified = datetime.datetime.now(datetime.UTC) + else: + related = related_group.related_ids + # RelatedGroup's timestamp field is `modified`, not `last_modified` + related_modified = related_group.modified + # Only update the modified time if it's actually being modified + if vuln_proto.related != related: + vuln_proto.related[:] = related + if related_modified > modified: + modified = related_modified + else: + modified = datetime.datetime.now(datetime.UTC) + vuln_proto.modified.FromDatetime(modified) osv.ListedVulnerability.from_vulnerability(vuln_proto).put() vuln.modified = modified diff --git a/gcp/workers/worker/worker.py b/gcp/workers/worker/worker.py index 5ab8e058541..648cb8d5d3e 100644 --- a/gcp/workers/worker/worker.py +++ b/gcp/workers/worker/worker.py @@ -612,14 +612,6 @@ def xact(): ds_vuln = osv.Vulnerability.get_by_id(vulnerability.id) is_new_bug = ds_vuln is None - # Compute the related fields here first. 
- # TODO(michaelkedar): Make a related computation in relations cron - related_raw = vulnerability.related - q = osv.Vulnerability.query( - osv.Vulnerability.related_raw == vulnerability.id) - related = set(vulnerability.related).union(set(r.key.id() for r in q)) - vulnerability.related[:] = sorted(related) - old_published = None # Update the schema version @@ -673,13 +665,15 @@ def xact(): new_vulnerability.aliases.clear() old_vulnerability.upstream.clear() new_vulnerability.upstream.clear() + old_vulnerability.related.clear() + new_vulnerability.related.clear() has_changed = old_vulnerability != new_vulnerability ds_vuln.is_withdrawn = vulnerability.HasField('withdrawn') ds_vuln.modified_raw = orig_modified_date ds_vuln.alias_raw = list(vulnerability.aliases) - ds_vuln.related_raw = list(related_raw) + ds_vuln.related_raw = list(vulnerability.related) ds_vuln.upstream_raw = list(vulnerability.upstream) # Update the bug entity based on the comparison. if has_changed: @@ -701,6 +695,10 @@ def xact(): if upstream_group: vulnerability.upstream[:] = sorted(upstream_group.upstream_ids) ds_vuln.modified = max(upstream_group.last_modified, ds_vuln.modified) + related_group = osv.RelatedGroup.get_by_id(vulnerability.id) + if related_group: + vulnerability.related[:] = sorted(related_group.related_ids) + ds_vuln.modified = max(related_group.modified, ds_vuln.modified) # Make sure modified date is >= withdrawn date if ds_vuln.is_withdrawn and vulnerability.withdrawn.ToDatetime( datetime.UTC) > ds_vuln.modified: diff --git a/go/cmd/relations/related.go b/go/cmd/relations/related.go new file mode 100644 index 00000000000..fb3e04fbac8 --- /dev/null +++ b/go/cmd/relations/related.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "errors" + "fmt" + "log/slog" + "slices" + "time" + + "cloud.google.com/go/datastore" + "github.com/google/osv.dev/go/logger" + "github.com/google/osv.dev/go/osv/models" + "google.golang.org/api/iterator" +) + +// computeRelated computes all 
related groups for the given vulns. +// `groups` is a map of vuln IDs to their related IDs. +// `withdrawnVulns` is a map of withdrawn vulns. +// Returns a map of vuln IDs to their related IDs, with the inverse relation added. +// `groups` is modified in place. +func computeRelated(groups map[string][]string, withdrawnVulns map[string]struct{}) map[string][]string { + // Add the inverse relation of the groups to the map + for id, group := range groups { + if _, ok := withdrawnVulns[id]; ok { + // We want to prevent withdrawn vulns IDs from being added to related groups, + // if the withdrawn vuln itself references other non-withdrawn vulns. + // For example: + // - If A (withdrawn) relates to B (valid), B should NOT list A. + // - If A (valid) relates to B (withdrawn), B SHOULD list A. + continue + } + for _, related := range group { + if slices.Contains(groups[related], id) { + continue + } + groups[related] = append(groups[related], id) + slices.Sort(groups[related]) + } + } + + return groups +} + +func updateRelated(ctx context.Context, cl *datastore.Client, id string, relatedIDs []string, ch chan<- Update) error { + if len(relatedIDs) == 0 { + logger.Info("Deleting related group due to no related vulns", slog.String("id", id)) + if err := cl.Delete(ctx, datastore.NameKey("RelatedGroup", id, nil)); err != nil { + return err + } + ch <- Update{ + ID: id, + Timestamp: time.Now().UTC(), + Field: updateFieldRelated, + Value: nil, + } + + return nil + } + + group := models.RelatedGroup{ + RelatedIDs: relatedIDs, + Modified: time.Now().UTC(), + } + if _, err := cl.Put(ctx, datastore.NameKey("RelatedGroup", id, nil), &group); err != nil { + return err + } + ch <- Update{ + ID: id, + Timestamp: group.Modified, + Field: updateFieldRelated, + Value: relatedIDs, + } + + return nil +} + +func ComputeRelatedGroups(ctx context.Context, cl *datastore.Client, ch chan<- Update) error { + // Query for all vulns that have related. 
+ // It's easier to recompute all groups than to try and figure out which ones + // need to be recomputed. + logger.Info("Retrieving vulns for related computation...") + q := datastore.NewQuery("Vulnerability").FilterField("related_raw", ">", "") + + rawRelated := make(map[string][]string) + withdrawnVulns := make(map[string]struct{}) + it := cl.Run(ctx, q) + for { + var v models.Vulnerability + _, err := it.Next(&v) + if errors.Is(err, iterator.Done) { + break + } + if err != nil { + return fmt.Errorf("failed to iterate vulnerabilities: %w", err) + } + if v.IsWithdrawn { + withdrawnVulns[v.Key.Name] = struct{}{} + } + related := slices.Clone(v.RelatedRaw) + slices.Sort(related) + related = slices.Compact(related) + rawRelated[v.Key.Name] = related + } + logger.Info("Retrieved vulns with related ids", slog.Int("count", len(rawRelated))) + + logger.Info("Retrieving related groups...") + q = datastore.NewQuery("RelatedGroup") + it = cl.Run(ctx, q) + relatedGroups := make(map[string]models.RelatedGroup) + for { + var group models.RelatedGroup + _, err := it.Next(&group) + if errors.Is(err, iterator.Done) { + break + } + if err != nil { + return fmt.Errorf("failed to iterate related groups: %w", err) + } + relatedGroups[group.Key.Name] = group + } + logger.Info("Related groups successfully retrieved", slog.Int("count", len(relatedGroups))) + + related := computeRelated(rawRelated, withdrawnVulns) + + for id, relatedIDs := range related { + g, ok := relatedGroups[id] + delete(relatedGroups, id) + if !ok || !slices.Equal(g.RelatedIDs, relatedIDs) { + if err := updateRelated(ctx, cl, id, relatedIDs, ch); err != nil { + return fmt.Errorf("failed to update related group: %w", err) + } + } + } + + // The remaining groups in relatedGroups are the ones that are no longer + // present in the vulns, so we delete them. 
+ for id := range relatedGroups { + if err := updateRelated(ctx, cl, id, nil, ch); err != nil { + return fmt.Errorf("failed to delete related group: %w", err) + } + } + + return nil +} diff --git a/go/cmd/relations/related_test.go b/go/cmd/relations/related_test.go new file mode 100644 index 00000000000..ba07606367b --- /dev/null +++ b/go/cmd/relations/related_test.go @@ -0,0 +1,123 @@ +package main + +import ( + "context" + "slices" + "testing" + "time" + + "cloud.google.com/go/datastore" + "github.com/google/go-cmp/cmp" + "github.com/google/osv.dev/go/osv/models" + "github.com/google/osv.dev/go/testutils" +) + +func TestComputeRelated(t *testing.T) { + tests := []struct { + name string + groups map[string][]string + want map[string][]string + }{ + { + name: "Unrelated groups", + groups: map[string][]string{"A": {"B"}, "C": {"D"}}, + want: map[string][]string{"A": {"B"}, "B": {"A"}, "C": {"D"}, "D": {"C"}}, + }, + { + name: "Related groups", + groups: map[string][]string{"A": {"B", "C"}, "B": {"A"}}, + want: map[string][]string{"A": {"B", "C"}, "B": {"A"}, "C": {"A"}}, + }, + { + name: "Already computed", + groups: map[string][]string{"A": {"B"}, "B": {"A"}}, + want: map[string][]string{"A": {"B"}, "B": {"A"}}, + }, + { + name: "Circular", + groups: map[string][]string{"A": {"B"}, "B": {"C"}, "C": {"A"}}, + want: map[string][]string{"A": {"B", "C"}, "B": {"A", "C"}, "C": {"A", "B"}}, + }, + { + name: "Empty", + groups: map[string][]string{}, + want: map[string][]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := computeRelated(tt.groups, map[string]struct{}{}) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("computeRelated() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestComputeRelatedGroups(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + + // Setup Datastore + vulns := []*models.Vulnerability{ + { + Key: 
datastore.NameKey("Vulnerability", "A", nil), + RelatedRaw: []string{"B"}, + Modified: time.Now().UTC(), + }, + { + Key: datastore.NameKey("Vulnerability", "B", nil), + RelatedRaw: []string{"A"}, + Modified: time.Now().UTC(), + }, + { + Key: datastore.NameKey("Vulnerability", "C", nil), + RelatedRaw: []string{"A", "D"}, + Modified: time.Now().UTC(), + }, + { + Key: datastore.NameKey("Vulnerability", "D", nil), + RelatedRaw: []string{"E"}, // Withdrawn, should be ignored + Modified: time.Now().UTC(), + IsWithdrawn: true, + }, + } + keys := make([]*datastore.Key, len(vulns)) + for i, v := range vulns { + keys[i] = v.Key + } + + if _, err := dsClient.PutMulti(ctx, keys, vulns); err != nil { + t.Fatalf("failed to put vulns: %v", err) + } + + ch := make(chan Update, 100) + if err := ComputeRelatedGroups(ctx, dsClient, ch); err != nil { + t.Fatalf("ComputeRelatedGroups failed: %v", err) + } + close(ch) + + // Check results + var groups []models.RelatedGroup + if _, err := dsClient.GetAll(ctx, datastore.NewQuery("RelatedGroup"), &groups); err != nil { + t.Fatalf("failed to get related groups: %v", err) + } + + expected := map[string][]string{ + "A": {"B", "C"}, + "B": {"A"}, + "C": {"A", "D"}, + "D": {"C", "E"}, + } + + got := make(map[string][]string) + for _, g := range groups { + slices.Sort(g.RelatedIDs) + got[g.Key.Name] = g.RelatedIDs + } + + if diff := cmp.Diff(expected, got); diff != "" { + t.Errorf("RelatedGroups mismatch (-want +got):\n%s", diff) + } +} diff --git a/go/cmd/relations/relations.go b/go/cmd/relations/relations.go index ab96d2afeab..8ef20389bed 100644 --- a/go/cmd/relations/relations.go +++ b/go/cmd/relations/relations.go @@ -61,6 +61,11 @@ func main() { logger.Error("failed to compute upstream groups", slog.Any("err", err)) } }) + wg.Go(func() { + if err := ComputeRelatedGroups(ctx, gc.datastoreClient, updater.Ch); err != nil { + logger.Error("failed to compute related groups", slog.Any("err", err)) + } + }) wg.Wait() updater.Finish() } diff --git 
a/go/cmd/relations/update.go b/go/cmd/relations/update.go index 49b36d2fce3..bce6a1aa4cf 100644 --- a/go/cmd/relations/update.go +++ b/go/cmd/relations/update.go @@ -160,6 +160,21 @@ func (u *Updater) run(ctx context.Context) { if u.Timestamp.After(modified) { modified = u.Timestamp } + case updateFieldRelated: + val, ok := u.Value.([]string) + if !ok { + logger.Error("updated related are not []string", slog.String("id", id)) + continue + } + if slices.Compare(v.GetRelated(), val) == 0 { + continue + } + hasUpdates = true + v.Related = val + updatedFields = append(updatedFields, "related") + if u.Timestamp.After(modified) { + modified = u.Timestamp + } default: logger.Error("unsupported update field", slog.Any("updateField", u.Field), slog.String("id", id)) } diff --git a/go/go.mod b/go/go.mod index ef3d97efcae..fd5126be6a0 100644 --- a/go/go.mod +++ b/go/go.mod @@ -8,6 +8,7 @@ require ( cloud.google.com/go/pubsub/v2 v2.3.0 cloud.google.com/go/storage v1.59.1 github.com/charmbracelet/lipgloss v1.1.0 + github.com/google/go-cmp v0.7.0 github.com/ossf/osv-schema/bindings/go v0.0.0-20260128001339-9d03e8f4632b golang.org/x/sync v0.19.0 google.golang.org/api v0.263.0 diff --git a/go/osv/models/internal/validate/validate.go b/go/osv/models/internal/validate/validate.go index df6007dd48b..854aee8493f 100644 --- a/go/osv/models/internal/validate/validate.go +++ b/go/osv/models/internal/validate/validate.go @@ -75,6 +75,14 @@ func readRecords(ctx context.Context, client *datastore.Client) { fmt.Printf("(Go) Failed getting ListedVulnerability: %v\n", err) os.Exit(1) } + + fmt.Println("(Go) Getting RelatedGroup") + key = datastore.NameKey("RelatedGroup", "CVE-123-456", nil) + var relatedGroup models.RelatedGroup + if err := client.Get(ctx, key, &relatedGroup); err != nil { + fmt.Printf("(Go) Failed getting RelatedGroup: %v\n", err) + os.Exit(1) + } } func writeRecords(ctx context.Context, client *datastore.Client) { @@ -155,4 +163,15 @@ func writeRecords(ctx context.Context, 
client *datastore.Client) { fmt.Printf("(Go) Failed writing ListedVulnerability %v: %v\n", key, err) os.Exit(1) } + + fmt.Println("(Go) Writing RelatedGroup") + key = datastore.NameKey("RelatedGroup", "CVE-987-654", nil) + relatedGroup := models.RelatedGroup{ + RelatedIDs: []string{"R-1", "R-2"}, + Modified: time.Date(2025, time.January, 1, 1, 1, 1, 1, time.UTC), + } + if _, err := client.Put(ctx, key, &relatedGroup); err != nil { + fmt.Printf("(Go) Failed writing RelatedGroup %v: %v\n", key, err) + os.Exit(1) + } } diff --git a/go/osv/models/internal/validate/validate.py b/go/osv/models/internal/validate/validate.py index 6ebac2f40ef..c03f775632e 100644 --- a/go/osv/models/internal/validate/validate.py +++ b/go/osv/models/internal/validate/validate.py @@ -20,7 +20,8 @@ import osv.tests from osv import Vulnerability, AliasGroup, AliasAllowListEntry, \ - AliasDenyListEntry, ListedVulnerability, Severity, UpstreamGroup + AliasDenyListEntry, ListedVulnerability, Severity, UpstreamGroup, \ + RelatedGroup def main() -> int: @@ -83,6 +84,13 @@ def main() -> int: search_indices=['cve-123-456', 'stdlib', 'requests'], ).put() + print('(Python) Putting RelatedGroup') + RelatedGroup( + id='CVE-123-456', + related_ids=['R-1', 'R-2'], + modified=datetime.datetime(2025, 6, 7, 8, 9, 10, tzinfo=datetime.UTC), + ).put() + # Run Go program to read the Python-created entities in Go. # And write Go entities. 
result = subprocess.run(['go', 'run', './validate.go'], check=False, cwd='.') @@ -108,6 +116,9 @@ def main() -> int: print('(Python) Getting ListedVulnerability') if ListedVulnerability.get_by_id('CVE-987-654') is None: return 1 + print('(Python) Getting RelatedGroup') + if RelatedGroup.get_by_id('CVE-987-654') is None: + return 1 return 0 diff --git a/go/osv/models/models.go b/go/osv/models/models.go index 38a5515ee23..176f1e45d8c 100644 --- a/go/osv/models/models.go +++ b/go/osv/models/models.go @@ -45,6 +45,12 @@ type UpstreamGroup struct { UpstreamHierarchy []byte `datastore:"upstream_hierarchy,noindex"` } +type RelatedGroup struct { + Key *datastore.Key `datastore:"__key__"` + RelatedIDs []string `datastore:"related_ids"` + Modified time.Time `datastore:"modified"` +} + type AliasAllowListEntry struct { VulnID string `datastore:"bug_id"` } diff --git a/osv/models.py b/osv/models.py index 0473bc13fa3..a7e1f106f60 100644 --- a/osv/models.py +++ b/osv/models.py @@ -1560,6 +1560,15 @@ class UpstreamGroup(ndb.Model): last_modified: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC) +class RelatedGroup(ndb.Model): + """Related group for storing related ids of a Vulnerability""" + # Key is Vuln ID + # List of related ids + related_ids: list[str] = ndb.StringProperty(repeated=True) + # Date when group was last modified + modified: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC) + + # --- ImportFinding --- # TODO(gongh@): redesign this to make it easy to scale. class ImportFindings(enum.IntEnum):