Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions gcp/workers/recoverer/recoverer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ def transaction():
else:
modified = datetime.datetime.now(datetime.UTC)

elif f == 'related':
related_group = osv.RelatedGroup.get_by_id(vuln_id)
if related_group is None:
related = []
related_modified = datetime.datetime.now(datetime.UTC)
else:
related = related_group.related_ids
related_modified = related_group.last_modified
# Only update the modified time if it's actually being modified
if vuln_proto.related != related:
vuln_proto.related[:] = related
if related_modified > modified:
modified = related_modified
else:
modified = datetime.datetime.now(datetime.UTC)

vuln_proto.modified.FromDatetime(modified)
osv.ListedVulnerability.from_vulnerability(vuln_proto).put()
vuln.modified = modified
Expand Down
16 changes: 7 additions & 9 deletions gcp/workers/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,14 +612,6 @@ def xact():
ds_vuln = osv.Vulnerability.get_by_id(vulnerability.id)
is_new_bug = ds_vuln is None

# Compute the related fields here first.
# TODO(michaelkedar): Make a related computation in relations cron
related_raw = vulnerability.related
q = osv.Vulnerability.query(
osv.Vulnerability.related_raw == vulnerability.id)
related = set(vulnerability.related).union(set(r.key.id() for r in q))
vulnerability.related[:] = sorted(related)

old_published = None

# Update the schema version
Expand Down Expand Up @@ -673,13 +665,15 @@ def xact():
new_vulnerability.aliases.clear()
old_vulnerability.upstream.clear()
new_vulnerability.upstream.clear()
old_vulnerability.related.clear()
new_vulnerability.related.clear()

has_changed = old_vulnerability != new_vulnerability

ds_vuln.is_withdrawn = vulnerability.HasField('withdrawn')
ds_vuln.modified_raw = orig_modified_date
ds_vuln.alias_raw = list(vulnerability.aliases)
ds_vuln.related_raw = list(related_raw)
ds_vuln.related_raw = list(vulnerability.related)
ds_vuln.upstream_raw = list(vulnerability.upstream)
# Update the bug entity based on the comparison.
if has_changed:
Expand All @@ -701,6 +695,10 @@ def xact():
if upstream_group:
vulnerability.upstream[:] = sorted(upstream_group.upstream_ids)
ds_vuln.modified = max(upstream_group.last_modified, ds_vuln.modified)
related_group = osv.RelatedGroup.get_by_id(vulnerability.id)
if related_group:
vulnerability.related[:] = sorted(related_group.related_ids)
ds_vuln.modified = max(related_group.modified, ds_vuln.modified)
# Make sure modified date is >= withdrawn date
if ds_vuln.is_withdrawn and vulnerability.withdrawn.ToDatetime(
datetime.UTC) > ds_vuln.modified:
Expand Down
145 changes: 145 additions & 0 deletions go/cmd/relations/related.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package main

import (
"context"
"errors"
"fmt"
"log/slog"
"slices"
"time"

"cloud.google.com/go/datastore"
"github.com/google/osv.dev/go/logger"
"github.com/google/osv.dev/go/osv/models"
"google.golang.org/api/iterator"
)

// computeRelated makes the "related" relation symmetric.
//
// `groups` maps a vuln ID to the IDs it declares as related, and
// `withdrawnVulns` is the set of IDs that have been withdrawn. For every
// non-withdrawn vuln X that lists Y, X is added to Y's list (if not already
// there) and Y's list is re-sorted.
//
// `groups` is updated in place and is also the returned map.
func computeRelated(groups map[string][]string, withdrawnVulns map[string]struct{}) map[string][]string {
	for id, declared := range groups {
		// A withdrawn vuln must not be back-referenced by the vulns it points
		// at, so its own list contributes nothing to the inverse relation:
		//   - A (withdrawn) relates to B (valid): B does NOT list A.
		//   - A (valid) relates to B (withdrawn): B DOES list A.
		_, withdrawn := withdrawnVulns[id]
		if withdrawn {
			continue
		}
		for _, other := range declared {
			inverse := groups[other]
			if !slices.Contains(inverse, id) {
				inverse = append(inverse, id)
				slices.Sort(inverse)
				groups[other] = inverse
			}
		}
	}

	return groups
}

// updateRelated writes the related group for vuln `id` to Datastore and
// publishes the matching Update on `ch`.
//
// When `relatedIDs` is empty the RelatedGroup entity is deleted and an Update
// carrying an empty related list is sent; otherwise the entity is stored with
// the given IDs and the current UTC time as its Modified timestamp, and that
// same timestamp and ID list are sent on `ch`.
//
// Datastore errors are returned unwrapped; no Update is sent when the
// Datastore operation fails.
func updateRelated(ctx context.Context, cl *datastore.Client, id string, relatedIDs []string, ch chan<- Update) error {
	key := datastore.NameKey("RelatedGroup", id, nil)

	if len(relatedIDs) == 0 {
		logger.Info("Deleting related group due to no related vulns", slog.String("id", id))
		if err := cl.Delete(ctx, key); err != nil {
			return err
		}
		// The value must be a typed (nil) []string rather than an untyped nil:
		// the consumer type-asserts Value to []string, and that assertion
		// fails on a nil interface, which would drop the "cleared" update.
		ch <- Update{
			ID:        id,
			Timestamp: time.Now().UTC(),
			Field:     updateFieldRelated,
			Value:     []string(nil),
		}

		return nil
	}

	group := models.RelatedGroup{
		RelatedIDs: relatedIDs,
		Modified:   time.Now().UTC(),
	}
	if _, err := cl.Put(ctx, key, &group); err != nil {
		return err
	}
	ch <- Update{
		ID:        id,
		Timestamp: group.Modified,
		Field:     updateFieldRelated,
		Value:     relatedIDs,
	}

	return nil
}

// ComputeRelatedGroups recomputes the related group of every vulnerability
// and reconciles the result with the RelatedGroup entities in Datastore.
//
// It loads all Vulnerability entities whose related_raw is greater than the
// empty string, builds the symmetric relation with computeRelated, and then
// calls updateRelated for every group that is new or whose IDs differ from
// the stored ones. Stored groups that no longer appear in the computed result
// are removed via updateRelated with no IDs. It returns the first error
// encountered.
func ComputeRelatedGroups(ctx context.Context, cl *datastore.Client, ch chan<- Update) error {
	// Everything is recomputed from scratch: that is simpler than working out
	// which individual groups are affected by a change.
	logger.Info("Retrieving vulns for related computation...")
	rawRelated := make(map[string][]string)
	withdrawnVulns := make(map[string]struct{})
	vulnIt := cl.Run(ctx, datastore.NewQuery("Vulnerability").FilterField("related_raw", ">", ""))
	for {
		var vuln models.Vulnerability
		if _, err := vulnIt.Next(&vuln); err != nil {
			if errors.Is(err, iterator.Done) {
				break
			}

			return fmt.Errorf("failed to iterate vulnerabilities: %w", err)
		}
		name := vuln.Key.Name
		if vuln.IsWithdrawn {
			withdrawnVulns[name] = struct{}{}
		}
		// Work on a sorted, de-duplicated copy so the entity's own slice is
		// left untouched.
		ids := slices.Clone(vuln.RelatedRaw)
		slices.Sort(ids)
		rawRelated[name] = slices.Compact(ids)
	}
	logger.Info("Retrieved vulns with related ids", slog.Int("count", len(rawRelated)))

	logger.Info("Retrieving related groups...")
	relatedGroups := make(map[string]models.RelatedGroup)
	groupIt := cl.Run(ctx, datastore.NewQuery("RelatedGroup"))
	for {
		var stored models.RelatedGroup
		if _, err := groupIt.Next(&stored); err != nil {
			if errors.Is(err, iterator.Done) {
				break
			}

			return fmt.Errorf("failed to iterate related groups: %w", err)
		}
		relatedGroups[stored.Key.Name] = stored
	}
	logger.Info("Related groups successfully retrieved", slog.Int("count", len(relatedGroups)))

	for id, wantIDs := range computeRelated(rawRelated, withdrawnVulns) {
		stored, found := relatedGroups[id]
		// Mark this stored group as still in use, whether or not it changes.
		delete(relatedGroups, id)
		if found && slices.Equal(stored.RelatedIDs, wantIDs) {
			continue
		}
		if err := updateRelated(ctx, cl, id, wantIDs, ch); err != nil {
			return fmt.Errorf("failed to update related group: %w", err)
		}
	}

	// Whatever is left in relatedGroups was not produced by the computation
	// above, so those stale groups are deleted.
	for staleID := range relatedGroups {
		if err := updateRelated(ctx, cl, staleID, nil, ch); err != nil {
			return fmt.Errorf("failed to delete related group: %w", err)
		}
	}

	return nil
}
123 changes: 123 additions & 0 deletions go/cmd/relations/related_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package main

import (
"context"
"slices"
"testing"
"time"

"cloud.google.com/go/datastore"
"github.com/google/go-cmp/cmp"
"github.com/google/osv.dev/go/osv/models"
"github.com/google/osv.dev/go/testutils"
)

// TestComputeRelated checks that computeRelated adds the inverse of every
// declared relation, and that relations declared by withdrawn vulns are not
// inverted.
func TestComputeRelated(t *testing.T) {
	tests := []struct {
		name      string
		groups    map[string][]string
		withdrawn map[string]struct{}
		want      map[string][]string
	}{
		{
			name:   "Unrelated groups",
			groups: map[string][]string{"A": {"B"}, "C": {"D"}},
			want:   map[string][]string{"A": {"B"}, "B": {"A"}, "C": {"D"}, "D": {"C"}},
		},
		{
			name:   "Related groups",
			groups: map[string][]string{"A": {"B", "C"}, "B": {"A"}},
			want:   map[string][]string{"A": {"B", "C"}, "B": {"A"}, "C": {"A"}},
		},
		{
			name:   "Already computed",
			groups: map[string][]string{"A": {"B"}, "B": {"A"}},
			want:   map[string][]string{"A": {"B"}, "B": {"A"}},
		},
		{
			name:   "Circular",
			groups: map[string][]string{"A": {"B"}, "B": {"C"}, "C": {"A"}},
			want:   map[string][]string{"A": {"B", "C"}, "B": {"A", "C"}, "C": {"A", "B"}},
		},
		{
			name:   "Empty",
			groups: map[string][]string{},
			want:   map[string][]string{},
		},
		{
			// A withdrawn vuln must not be listed by the vulns it points at.
			name:      "Withdrawn source is not back-referenced",
			groups:    map[string][]string{"A": {"B"}},
			withdrawn: map[string]struct{}{"A": {}},
			want:      map[string][]string{"A": {"B"}},
		},
		{
			// A withdrawn vuln still lists the valid vulns that point at it,
			// but its own relations are not inverted (C gets no entry).
			name:      "Withdrawn target lists valid source",
			groups:    map[string][]string{"A": {"B"}, "B": {"C"}},
			withdrawn: map[string]struct{}{"B": {}},
			want:      map[string][]string{"A": {"B"}, "B": {"A", "C"}},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := computeRelated(tt.groups, tt.withdrawn)
			if diff := cmp.Diff(tt.want, got); diff != "" {
				t.Errorf("computeRelated() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

// TestComputeRelatedGroups stores a small set of vulnerabilities in the test
// Datastore, runs ComputeRelatedGroups, and compares the resulting
// RelatedGroup entities against the expected symmetric relation.
func TestComputeRelatedGroups(t *testing.T) {
	ctx := context.Background()
	dsClient := testutils.MustNewDatastoreClientForTesting(t)

	// newVuln builds a Vulnerability fixture keyed by id.
	newVuln := func(id string, withdrawn bool, related ...string) *models.Vulnerability {
		return &models.Vulnerability{
			Key:         datastore.NameKey("Vulnerability", id, nil),
			RelatedRaw:  related,
			Modified:    time.Now().UTC(),
			IsWithdrawn: withdrawn,
		}
	}

	// Setup Datastore
	vulns := []*models.Vulnerability{
		newVuln("A", false, "B"),
		newVuln("B", false, "A"),
		newVuln("C", false, "A", "D"),
		// D is withdrawn: it keeps its own related IDs, but E must not end up
		// with a group that lists D.
		newVuln("D", true, "E"),
	}
	keys := make([]*datastore.Key, 0, len(vulns))
	for _, v := range vulns {
		keys = append(keys, v.Key)
	}

	if _, err := dsClient.PutMulti(ctx, keys, vulns); err != nil {
		t.Fatalf("failed to put vulns: %v", err)
	}

	// The channel is buffered so the sends do not block without a consumer.
	updates := make(chan Update, 100)
	err := ComputeRelatedGroups(ctx, dsClient, updates)
	if err != nil {
		t.Fatalf("ComputeRelatedGroups failed: %v", err)
	}
	close(updates)

	// Check results
	var stored []models.RelatedGroup
	if _, err := dsClient.GetAll(ctx, datastore.NewQuery("RelatedGroup"), &stored); err != nil {
		t.Fatalf("failed to get related groups: %v", err)
	}

	got := make(map[string][]string)
	for _, group := range stored {
		ids := group.RelatedIDs
		slices.Sort(ids)
		got[group.Key.Name] = ids
	}

	expected := map[string][]string{
		"A": {"B", "C"},
		"B": {"A"},
		"C": {"A", "D"},
		"D": {"C", "E"},
	}
	if diff := cmp.Diff(expected, got); diff != "" {
		t.Errorf("RelatedGroups mismatch (-want +got):\n%s", diff)
	}
}
5 changes: 5 additions & 0 deletions go/cmd/relations/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ func main() {
logger.Error("failed to compute upstream groups", slog.Any("err", err))
}
})
wg.Go(func() {
if err := ComputeRelatedGroups(ctx, gc.datastoreClient, updater.Ch); err != nil {
logger.Error("failed to compute related groups", slog.Any("err", err))
}
})
wg.Wait()
updater.Finish()
}
Expand Down
15 changes: 15 additions & 0 deletions go/cmd/relations/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,21 @@ func (u *Updater) run(ctx context.Context) {
if u.Timestamp.After(modified) {
modified = u.Timestamp
}
case updateFieldRelated:
val, ok := u.Value.([]string)
if !ok {
logger.Error("updated related are not []string", slog.String("id", id))
continue
}
if slices.Compare(v.GetRelated(), val) == 0 {
continue
}
hasUpdates = true
v.Related = val
updatedFields = append(updatedFields, "related")
if u.Timestamp.After(modified) {
modified = u.Timestamp
}
default:
logger.Error("unsupported update field", slog.Any("updateField", u.Field), slog.String("id", id))
}
Expand Down
1 change: 1 addition & 0 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
cloud.google.com/go/pubsub/v2 v2.3.0
cloud.google.com/go/storage v1.59.1
github.com/charmbracelet/lipgloss v1.1.0
github.com/google/go-cmp v0.7.0
github.com/ossf/osv-schema/bindings/go v0.0.0-20260128001339-9d03e8f4632b
golang.org/x/sync v0.19.0
google.golang.org/api v0.263.0
Expand Down
Loading
Loading