diff --git a/pkg/database/report.go b/pkg/database/report.go index a18fa2b..4b882ba 100644 --- a/pkg/database/report.go +++ b/pkg/database/report.go @@ -66,7 +66,7 @@ func (r *ReportClient) GetExpiredIPsFromReport(reportID uint) ([]*IP, error) { func (r *ReportClient) FindById(reportID uint) (*Report, error) { var report Report - result := r.db.Preload("IPs").Preload("StatsReport").First(&report, reportID) + result := r.db.Preload("StatsReport").First(&report, reportID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -74,6 +74,11 @@ func (r *ReportClient) FindById(reportID uint) (*Report, error) { return nil, result.Error } + err := r.db.Model(&report).Association("IPs").Find(&report.IPs) + if err != nil { + return nil, err + } + return &report, nil } @@ -85,13 +90,19 @@ func (r *ReportClient) FindByHash(filepath string) (*Report, error) { return nil, err } - result := r.db.Preload("IPs").Preload("StatsReport").Where("file_hash = ?", hash).First(&report) + result := r.db.Preload("StatsReport").Where("file_hash = ?", hash).First(&report) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil } return nil, result.Error } + + err = r.db.Model(&report).Association("IPs").Find(&report.IPs) + if err != nil { + return nil, err + } + return &report, nil } @@ -131,10 +142,21 @@ func (r *ReportClient) Find(reportID string) (*Report, error) { func (r *ReportClient) FindAll() ([]*Report, error) { reports := []*Report{} - result := r.db.Preload("IPs").Preload("StatsReport").Find(&reports) + + // First, get all reports with StatsReport only + result := r.db.Preload("StatsReport").Find(&reports) if result.Error != nil { return nil, result.Error } + + // Then load IPs using Association API for each report (GORM handles batching internally) + for _, report := range reports { + err := r.db.Model(report).Association("IPs").Find(&report.IPs) + if err != nil { + return nil, err + } + } + return reports, nil } @@ -158,7 +180,7 @@ func (r *ReportClient) DeleteExpiredSince(expirationDate time.Time) error { func (r *ReportClient) FilePathExist(filePath string) (*Report, bool, error) { var reports []Report - result := r.db.Model(&Report{}).Preload("IPs").Where("file_path = ?", filePath).Find(&reports) + result := r.db.Model(&Report{}).Where("file_path = ?", filePath).Find(&reports) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, false, nil @@ -168,5 +190,11 @@ func (r *ReportClient) FilePathExist(filePath string) (*Report, bool, error) { if len(reports) == 0 { return nil, false, nil } + + err := r.db.Model(&reports[0]).Association("IPs").Find(&reports[0].IPs) + if err != nil { + return nil, false, err + } + return &reports[0], true, nil }