1
0
Fork 0
mirror of https://github.com/gocsaf/csaf.git synced 2025-12-22 18:15:42 +01:00

Merge pull request #625 from gocsaf/close-body-downloader

Move advisory downloading to download context method
This commit is contained in:
JanHoefelmeyer 2025-03-17 11:59:52 +01:00 committed by GitHub
commit cf4cf7c6c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -417,81 +417,74 @@ func (d *downloader) logValidationIssues(url string, errors []string, err error)
} }
} }
func (d *downloader) downloadWorker( // downloadContext stores the common context of a downloader.
ctx context.Context, type downloadContext struct {
wg *sync.WaitGroup, d *downloader
label csaf.TLPLabel, client util.Client
files <-chan csaf.AdvisoryFile,
errorCh chan<- error,
) {
defer wg.Done()
var (
client = d.httpClient()
data bytes.Buffer data bytes.Buffer
lastDir string lastDir string
initialReleaseDate time.Time initialReleaseDate time.Time
dateExtract = util.TimeMatcher(&initialReleaseDate, time.RFC3339) dateExtract func(any) error
lower = strings.ToLower(string(label)) lower string
stats = stats{} stats stats
expr = util.NewPathEval() expr *util.PathEval
) }
// Add collected stats back to total. func newDownloadContext(d *downloader, label csaf.TLPLabel) *downloadContext {
defer d.addStats(&stats) dc := &downloadContext{
d: d,
nextAdvisory: client: d.httpClient(),
for { lower: strings.ToLower(string(label)),
var file csaf.AdvisoryFile expr: util.NewPathEval(),
var ok bool
select {
case file, ok = <-files:
if !ok {
return
}
case <-ctx.Done():
return
} }
dc.dateExtract = util.TimeMatcher(&dc.initialReleaseDate, time.RFC3339)
return dc
}
func (dc *downloadContext) downloadAdvisory(
file csaf.AdvisoryFile,
errorCh chan<- error,
) error {
u, err := url.Parse(file.URL()) u, err := url.Parse(file.URL())
if err != nil { if err != nil {
stats.downloadFailed++ dc.stats.downloadFailed++
slog.Warn("Ignoring invalid URL", slog.Warn("Ignoring invalid URL",
"url", file.URL(), "url", file.URL(),
"error", err) "error", err)
continue return nil
} }
if d.cfg.ignoreURL(file.URL()) { if dc.d.cfg.ignoreURL(file.URL()) {
slog.Debug("Ignoring URL", "url", file.URL()) slog.Debug("Ignoring URL", "url", file.URL())
continue return nil
} }
// Ignore not conforming filenames. // Ignore not conforming filenames.
filename := filepath.Base(u.Path) filename := filepath.Base(u.Path)
if !util.ConformingFileName(filename) { if !util.ConformingFileName(filename) {
stats.filenameFailed++ dc.stats.filenameFailed++
slog.Warn("Ignoring none conforming filename", slog.Warn("Ignoring none conforming filename",
"filename", filename) "filename", filename)
continue return nil
} }
resp, err := client.Get(file.URL()) resp, err := dc.client.Get(file.URL())
if err != nil { if err != nil {
stats.downloadFailed++ dc.stats.downloadFailed++
slog.Warn("Cannot GET", slog.Warn("Cannot GET",
"url", file.URL(), "url", file.URL(),
"error", err) "error", err)
continue return nil
} }
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
stats.downloadFailed++ dc.stats.downloadFailed++
slog.Warn("Cannot load", slog.Warn("Cannot load",
"url", file.URL(), "url", file.URL(),
"status", resp.Status, "status", resp.Status,
"status_code", resp.StatusCode) "status_code", resp.StatusCode)
continue return nil
} }
// Warn if we do not get JSON. // Warn if we do not get JSON.
@ -515,7 +508,7 @@ nextAdvisory:
url: file.SHA512URL(), url: file.SHA512URL(),
warn: true, warn: true,
hashType: algSha512, hashType: algSha512,
preferred: strings.EqualFold(string(d.cfg.PreferredHash), string(algSha512)), preferred: strings.EqualFold(string(dc.d.cfg.PreferredHash), string(algSha512)),
}) })
} else { } else {
slog.Info("SHA512 not present") slog.Info("SHA512 not present")
@ -525,7 +518,7 @@ nextAdvisory:
url: file.SHA256URL(), url: file.SHA256URL(),
warn: true, warn: true,
hashType: algSha256, hashType: algSha256,
preferred: strings.EqualFold(string(d.cfg.PreferredHash), string(algSha256)), preferred: strings.EqualFold(string(dc.d.cfg.PreferredHash), string(algSha256)),
}) })
} else { } else {
slog.Info("SHA256 not present") slog.Info("SHA256 not present")
@ -536,7 +529,7 @@ nextAdvisory:
} }
} }
remoteSHA256, s256Data, remoteSHA512, s512Data = loadHashes(client, hashToFetch) remoteSHA256, s256Data, remoteSHA512, s512Data = loadHashes(dc.client, hashToFetch)
if remoteSHA512 != nil { if remoteSHA512 != nil {
s512 = sha512.New() s512 = sha512.New()
writers = append(writers, s512) writers = append(writers, s512)
@ -547,30 +540,28 @@ nextAdvisory:
} }
// Remember the data as we need to store it to file later. // Remember the data as we need to store it to file later.
data.Reset() dc.data.Reset()
writers = append(writers, &data) writers = append(writers, &dc.data)
// Download the advisory and hash it. // Download the advisory and hash it.
hasher := io.MultiWriter(writers...) hasher := io.MultiWriter(writers...)
var doc any var doc any
if err := func() error {
defer resp.Body.Close()
tee := io.TeeReader(resp.Body, hasher) tee := io.TeeReader(resp.Body, hasher)
return json.NewDecoder(tee).Decode(&doc)
}(); err != nil { if err := json.NewDecoder(tee).Decode(&doc); err != nil {
stats.downloadFailed++ dc.stats.downloadFailed++
slog.Warn("Downloading failed", slog.Warn("Downloading failed",
"url", file.URL(), "url", file.URL(),
"error", err) "error", err)
continue return nil
} }
// Compare the checksums. // Compare the checksums.
s256Check := func() error { s256Check := func() error {
if s256 != nil && !bytes.Equal(s256.Sum(nil), remoteSHA256) { if s256 != nil && !bytes.Equal(s256.Sum(nil), remoteSHA256) {
stats.sha256Failed++ dc.stats.sha256Failed++
return fmt.Errorf("SHA256 checksum of %s does not match", file.URL()) return fmt.Errorf("SHA256 checksum of %s does not match", file.URL())
} }
return nil return nil
@ -578,7 +569,7 @@ nextAdvisory:
s512Check := func() error { s512Check := func() error {
if s512 != nil && !bytes.Equal(s512.Sum(nil), remoteSHA512) { if s512 != nil && !bytes.Equal(s512.Sum(nil), remoteSHA512) {
stats.sha512Failed++ dc.stats.sha512Failed++
return fmt.Errorf("SHA512 checksum of %s does not match", file.URL()) return fmt.Errorf("SHA512 checksum of %s does not match", file.URL())
} }
return nil return nil
@ -587,20 +578,20 @@ nextAdvisory:
// Validate OpenPGP signature. // Validate OpenPGP signature.
keysCheck := func() error { keysCheck := func() error {
// Only check signature if we have loaded keys. // Only check signature if we have loaded keys.
if d.keys == nil { if dc.d.keys == nil {
return nil return nil
} }
var sign *crypto.PGPSignature var sign *crypto.PGPSignature
sign, signData, err = loadSignature(client, file.SignURL()) sign, signData, err = loadSignature(dc.client, file.SignURL())
if err != nil { if err != nil {
slog.Warn("Downloading signature failed", slog.Warn("Downloading signature failed",
"url", file.SignURL(), "url", file.SignURL(),
"error", err) "error", err)
} }
if sign != nil { if sign != nil {
if err := d.checkSignature(data.Bytes(), sign); err != nil { if err := dc.d.checkSignature(dc.data.Bytes(), sign); err != nil {
if !d.cfg.IgnoreSignatureCheck { if !dc.d.cfg.IgnoreSignatureCheck {
stats.signatureFailed++ dc.stats.signatureFailed++
return fmt.Errorf("cannot verify signature for %s: %v", file.URL(), err) return fmt.Errorf("cannot verify signature for %s: %v", file.URL(), err)
} }
} }
@ -611,8 +602,8 @@ nextAdvisory:
// Validate against CSAF schema. // Validate against CSAF schema.
schemaCheck := func() error { schemaCheck := func() error {
if errors, err := csaf.ValidateCSAF(doc); err != nil || len(errors) > 0 { if errors, err := csaf.ValidateCSAF(doc); err != nil || len(errors) > 0 {
stats.schemaFailed++ dc.stats.schemaFailed++
d.logValidationIssues(file.URL(), errors, err) dc.d.logValidationIssues(file.URL(), errors, err)
return fmt.Errorf("schema validation for %q failed", file.URL()) return fmt.Errorf("schema validation for %q failed", file.URL())
} }
return nil return nil
@ -620,8 +611,8 @@ nextAdvisory:
// Validate if filename is conforming. // Validate if filename is conforming.
filenameCheck := func() error { filenameCheck := func() error {
if err := util.IDMatchesFilename(expr, doc, filename); err != nil { if err := util.IDMatchesFilename(dc.expr, doc, filename); err != nil {
stats.filenameFailed++ dc.stats.filenameFailed++
return fmt.Errorf("filename not conforming %s: %s", file.URL(), err) return fmt.Errorf("filename not conforming %s: %s", file.URL(), err)
} }
return nil return nil
@ -629,10 +620,10 @@ nextAdvisory:
// Validate against remote validator. // Validate against remote validator.
remoteValidatorCheck := func() error { remoteValidatorCheck := func() error {
if d.validator == nil { if dc.d.validator == nil {
return nil return nil
} }
rvr, err := d.validator.Validate(doc) rvr, err := dc.d.validator.Validate(doc)
if err != nil { if err != nil {
errorCh <- fmt.Errorf( errorCh <- fmt.Errorf(
"calling remote validator on %q failed: %w", "calling remote validator on %q failed: %w",
@ -640,7 +631,7 @@ nextAdvisory:
return nil return nil
} }
if !rvr.Valid { if !rvr.Valid {
stats.remoteFailed++ dc.stats.remoteFailed++
return fmt.Errorf("remote validation of %q failed", file.URL()) return fmt.Errorf("remote validation of %q failed", file.URL())
} }
return nil return nil
@ -659,71 +650,71 @@ nextAdvisory:
if err := check(); err != nil { if err := check(); err != nil {
slog.Error("Validation check failed", "error", err) slog.Error("Validation check failed", "error", err)
valStatus.update(invalidValidationStatus) valStatus.update(invalidValidationStatus)
if d.cfg.ValidationMode == validationStrict { if dc.d.cfg.ValidationMode == validationStrict {
continue nextAdvisory return nil
} }
} }
} }
valStatus.update(validValidationStatus) valStatus.update(validValidationStatus)
// Send to forwarder // Send to forwarder
if d.forwarder != nil { if dc.d.forwarder != nil {
d.forwarder.forward( dc.d.forwarder.forward(
filename, data.String(), filename, dc.data.String(),
valStatus, valStatus,
string(s256Data), string(s256Data),
string(s512Data)) string(s512Data))
} }
if d.cfg.NoStore { if dc.d.cfg.NoStore {
// Do not write locally. // Do not write locally.
if valStatus == validValidationStatus { if valStatus == validValidationStatus {
stats.succeeded++ dc.stats.succeeded++
} }
continue return nil
} }
if err := expr.Extract( if err := dc.expr.Extract(
`$.document.tracking.initial_release_date`, dateExtract, false, doc, `$.document.tracking.initial_release_date`, dc.dateExtract, false, doc,
); err != nil { ); err != nil {
slog.Warn("Cannot extract initial_release_date from advisory", slog.Warn("Cannot extract initial_release_date from advisory",
"url", file.URL()) "url", file.URL())
initialReleaseDate = time.Now() dc.initialReleaseDate = time.Now()
} }
initialReleaseDate = initialReleaseDate.UTC() dc.initialReleaseDate = dc.initialReleaseDate.UTC()
// Advisories that failed validation are stored in a special folder. // Advisories that failed validation are stored in a special folder.
var newDir string var newDir string
if valStatus != validValidationStatus { if valStatus != validValidationStatus {
newDir = path.Join(d.cfg.Directory, failedValidationDir) newDir = path.Join(dc.d.cfg.Directory, failedValidationDir)
} else { } else {
newDir = d.cfg.Directory newDir = dc.d.cfg.Directory
} }
// Do we have a configured destination folder? // Do we have a configured destination folder?
if d.cfg.Folder != "" { if dc.d.cfg.Folder != "" {
newDir = path.Join(newDir, d.cfg.Folder) newDir = path.Join(newDir, dc.d.cfg.Folder)
} else { } else {
newDir = path.Join(newDir, lower, strconv.Itoa(initialReleaseDate.Year())) newDir = path.Join(newDir, dc.lower, strconv.Itoa(dc.initialReleaseDate.Year()))
} }
if newDir != lastDir { if newDir != dc.lastDir {
if err := d.mkdirAll(newDir, 0755); err != nil { if err := dc.d.mkdirAll(newDir, 0755); err != nil {
errorCh <- err errorCh <- err
continue return nil
} }
lastDir = newDir dc.lastDir = newDir
} }
// Write advisory to file // Write advisory to file
path := filepath.Join(lastDir, filename) path := filepath.Join(dc.lastDir, filename)
// Write data to disk. // Write data to disk.
for _, x := range []struct { for _, x := range []struct {
p string p string
d []byte d []byte
}{ }{
{path, data.Bytes()}, {path, dc.data.Bytes()},
{path + ".sha256", s256Data}, {path + ".sha256", s256Data},
{path + ".sha512", s512Data}, {path + ".sha512", s512Data},
{path + ".asc", signData}, {path + ".asc", signData},
@ -731,13 +722,45 @@ nextAdvisory:
if x.d != nil { if x.d != nil {
if err := os.WriteFile(x.p, x.d, 0644); err != nil { if err := os.WriteFile(x.p, x.d, 0644); err != nil {
errorCh <- err errorCh <- err
continue nextAdvisory return nil
} }
} }
} }
stats.succeeded++ dc.stats.succeeded++
slog.Info("Written advisory", "path", path) slog.Info("Written advisory", "path", path)
return nil
}
func (d *downloader) downloadWorker(
ctx context.Context,
wg *sync.WaitGroup,
label csaf.TLPLabel,
files <-chan csaf.AdvisoryFile,
errorCh chan<- error,
) {
defer wg.Done()
dc := newDownloadContext(d, label)
// Add collected stats back to total.
defer d.addStats(&dc.stats)
for {
var file csaf.AdvisoryFile
var ok bool
select {
case file, ok = <-files:
if !ok {
return
}
case <-ctx.Done():
return
}
if err := dc.downloadAdvisory(file, errorCh); err != nil {
slog.Error("download terminated", "error", err)
return
}
} }
} }