1
0
Fork 0
mirror of https://github.com/gocsaf/csaf.git synced 2025-12-22 05:40:11 +01:00

Add concurrent downloads to downloader. (#363)

* Add concurrent downloads to downloader.

* Moved to Go 1.20

* Close files channel on producer side.

* Improve error handling

* New flag to ignore signature check results. Improve docs. Do not use number of CPUs to determine number of download workers.

* Set default number of workers in downloader to two.
This commit is contained in:
Sascha L. Teichmann 2023-05-02 10:10:12 +02:00 committed by GitHub
parent 91479c9912
commit f32fba683d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 175 additions and 87 deletions

View file

@ -10,10 +10,12 @@ package main
import (
"bytes"
"context"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
@ -25,6 +27,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/ProtonMail/gopenpgp/v2/crypto"
@ -34,12 +37,12 @@ import (
)
type downloader struct {
client util.Client
opts *options
directory string
keys []*crypto.KeyRing
keys *crypto.KeyRing
eval *util.PathEval
validator csaf.RemoteValidator
mkdirMu sync.Mutex
}
func newDownloader(opts *options) (*downloader, error) {
@ -57,6 +60,7 @@ func newDownloader(opts *options) (*downloader, error) {
return nil, fmt.Errorf(
"preparing remote validator failed: %w", err)
}
validator = csaf.SynchronizedRemoteValidator(validator)
}
return &downloader{
@ -75,10 +79,6 @@ func (d *downloader) close() {
func (d *downloader) httpClient() util.Client {
if d.client != nil {
return d.client
}
hClient := http.Client{}
var tlsConfig tls.Config
@ -112,14 +112,14 @@ func (d *downloader) httpClient() util.Client {
}
}
d.client = client
return d.client
return client
}
func (d *downloader) download(domain string) error {
func (d *downloader) download(ctx context.Context, domain string) error {
client := d.httpClient()
lpmd := csaf.LoadProviderMetadataForDomain(
d.httpClient(), domain, func(format string, args ...any) {
client, domain, func(format string, args ...any) {
log.Printf(
"Looking for provider-metadata.json of '"+domain+"': "+format+"\n", args...)
})
@ -134,7 +134,7 @@ func (d *downloader) download(domain string) error {
}
if err := d.loadOpenPGPKeys(
d.httpClient(),
client,
lpmd.Document,
base,
); err != nil {
@ -142,13 +142,64 @@ func (d *downloader) download(domain string) error {
}
afp := csaf.NewAdvisoryFileProcessor(
d.httpClient(),
client,
d.eval,
lpmd.Document,
base,
nil)
return afp.Process(d.downloadFiles)
return afp.Process(func(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
return d.downloadFiles(ctx, label, files)
})
}
// downloadFiles hands the given advisory files to a set of concurrently
// running download workers and waits until all of them have finished.
//
// The number of workers is taken from d.opts.Worker; values below one are
// treated as one. Files are fed to the workers until either all of them
// have been handed over or ctx is done. Errors reported by the workers are
// gathered and returned as a single error built with errors.Join, which
// is nil when no worker reported an error.
func (d *downloader) downloadFiles(
	ctx context.Context,
	label csaf.TLPLabel,
	files []csaf.AdvisoryFile,
) error {
	jobs := make(chan csaf.AdvisoryFile)
	failures := make(chan error)
	collected := make(chan struct{})

	var (
		problems []error
		workers  sync.WaitGroup
	)

	// A single goroutine owns the problems slice while failures is open;
	// closing collected signals that it is safe to read the slice again.
	go func() {
		defer close(collected)
		for err := range failures {
			problems = append(problems, err)
		}
	}()

	// Always run at least one worker.
	count := d.opts.Worker
	if count < 1 {
		count = 1
	}
	workers.Add(count)
	for remaining := count; remaining > 0; remaining-- {
		go d.downloadWorker(ctx, &workers, label, jobs, failures)
	}

	// Feed the workers. The jobs channel is closed on this (the producer)
	// side once everything was handed over or the context was cancelled.
	func() {
		defer close(jobs)
		for _, file := range files {
			select {
			case jobs <- file:
			case <-ctx.Done():
				return
			}
		}
	}()

	// Only after every worker has returned can no more errors be sent,
	// so failures may be closed and the collector be awaited.
	workers.Wait()
	close(failures)
	<-collected

	return errors.Join(problems...)
}
func (d *downloader) loadOpenPGPKeys(
@ -213,12 +264,15 @@ func (d *downloader) loadOpenPGPKeys(
"Fingerprint of public OpenPGP key %s does not match remotely loaded.", u)
continue
}
keyring, err := crypto.NewKeyRing(ckey)
if err != nil {
log.Printf("Creating store for public OpenPGP key %s failed: %v.", u, err)
continue
if d.keys == nil {
if keyring, err := crypto.NewKeyRing(ckey); err != nil {
log.Printf("Creating store for public OpenPGP key %s failed: %v.", u, err)
} else {
d.keys = keyring
}
} else {
d.keys.AddKey(ckey)
}
d.keys = append(d.keys, keyring)
}
return nil
}
@ -240,21 +294,36 @@ func (d *downloader) logValidationIssues(url string, errors []string, err error)
}
}
func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
func (d *downloader) downloadWorker(
ctx context.Context,
wg *sync.WaitGroup,
label csaf.TLPLabel,
files <-chan csaf.AdvisoryFile,
errorCh chan<- error,
) {
defer wg.Done()
client := d.httpClient()
var (
client = d.httpClient()
data bytes.Buffer
lastDir string
initialReleaseDate time.Time
dateExtract = util.TimeMatcher(&initialReleaseDate, time.RFC3339)
lower = strings.ToLower(string(label))
)
var data bytes.Buffer
var lastDir string
lower := strings.ToLower(string(label))
var initialReleaseDate time.Time
dateExtract := util.TimeMatcher(&initialReleaseDate, time.RFC3339)
for _, file := range files {
nextAdvisory:
for {
var file csaf.AdvisoryFile
var ok bool
select {
case file, ok = <-files:
if !ok {
return
}
case <-ctx.Done():
return
}
u, err := url.Parse(file.URL())
if err != nil {
@ -297,7 +366,7 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFil
)
// Only hash when we have a remote counter part we can compare it with.
if remoteSHA256, s256Data, err = d.loadHash(file.SHA256URL()); err != nil {
if remoteSHA256, s256Data, err = loadHash(client, file.SHA256URL()); err != nil {
if d.opts.Verbose {
log.Printf("WARN: cannot fetch %s: %v\n", file.SHA256URL(), err)
}
@ -306,7 +375,7 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFil
writers = append(writers, s256)
}
if remoteSHA512, s512Data, err = d.loadHash(file.SHA512URL()); err != nil {
if remoteSHA512, s512Data, err = loadHash(client, file.SHA512URL()); err != nil {
if d.opts.Verbose {
log.Printf("WARN: cannot fetch %s: %v\n", file.SHA512URL(), err)
}
@ -345,9 +414,9 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFil
}
// Only check signature if we have loaded keys.
if len(d.keys) > 0 {
if d.keys != nil {
var sign *crypto.PGPSignature
sign, signData, err = d.loadSignature(file.SignURL())
sign, signData, err = loadSignature(client, file.SignURL())
if err != nil {
if d.opts.Verbose {
log.Printf("downloading signature '%s' failed: %v\n",
@ -355,9 +424,11 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFil
}
}
if sign != nil {
if !d.checkSignature(data.Bytes(), sign) {
log.Printf("Cannot verify signature for %s\n", file.URL())
continue
if err := d.checkSignature(data.Bytes(), sign); err != nil {
log.Printf("Cannot verify signature for %s: %v\n", file.URL(), err)
if !d.opts.IgnoreSignatureCheck {
continue
}
}
}
}
@ -372,9 +443,10 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFil
if d.validator != nil {
rvr, err := d.validator.Validate(doc)
if err != nil {
return fmt.Errorf(
errorCh <- fmt.Errorf(
"calling remote validator on %q failed: %w",
file.URL(), err)
continue
}
if !rvr.Valid {
log.Printf("Remote validation of %q failed\n", file.URL())
@ -391,56 +463,51 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFil
newDir := path.Join(d.directory, lower, strconv.Itoa(initialReleaseDate.Year()))
if newDir != lastDir {
if err := os.MkdirAll(newDir, 0755); err != nil {
return err
if err := d.mkdirAll(newDir, 0755); err != nil {
errorCh <- err
continue
}
lastDir = newDir
}
path := filepath.Join(lastDir, filename)
if err := os.WriteFile(path, data.Bytes(), 0644); err != nil {
return err
}
// Write hash sums.
if s256Data != nil {
if err := os.WriteFile(path+".sha256", s256Data, 0644); err != nil {
return err
}
}
if s512Data != nil {
if err := os.WriteFile(path+".sha512", s512Data, 0644); err != nil {
return err
}
}
// Write signature.
if signData != nil {
if err := os.WriteFile(path+".asc", signData, 0644); err != nil {
return err
// Write data to disk.
for _, x := range []struct {
p string
d []byte
}{
{path, data.Bytes()},
{path + ".sha256", s256Data},
{path + ".sha512", s512Data},
{path + ".asc", signData},
} {
if x.d != nil {
if err := os.WriteFile(x.p, x.d, 0644); err != nil {
errorCh <- err
continue nextAdvisory
}
}
}
log.Printf("Written advisory '%s'.\n", path)
}
return nil
}
func (d *downloader) checkSignature(data []byte, sign *crypto.PGPSignature) bool {
// mkdirAll creates the directory path (including missing parents) with the
// given permissions by calling os.MkdirAll while holding d.mkdirMu, so that
// only one such call made through this downloader runs at a time.
// NOTE(review): presumably this guards against several download workers
// creating the same directory tree simultaneously — confirm against callers.
func (d *downloader) mkdirAll(path string, perm os.FileMode) error {
	mu := &d.mkdirMu
	mu.Lock()
	defer mu.Unlock()

	return os.MkdirAll(path, perm)
}
func (d *downloader) checkSignature(data []byte, sign *crypto.PGPSignature) error {
pm := crypto.NewPlainMessage(data)
t := crypto.GetUnixTime()
for _, key := range d.keys {
if err := key.VerifyDetached(pm, sign, t); err == nil {
return true
}
}
return false
return d.keys.VerifyDetached(pm, sign, t)
}
func (d *downloader) loadSignature(p string) (*crypto.PGPSignature, []byte, error) {
resp, err := d.httpClient().Get(p)
func loadSignature(client util.Client, p string) (*crypto.PGPSignature, []byte, error) {
resp, err := client.Get(p)
if err != nil {
return nil, nil, err
}
@ -460,8 +527,8 @@ func (d *downloader) loadSignature(p string) (*crypto.PGPSignature, []byte, erro
return sign, data, nil
}
func (d *downloader) loadHash(p string) ([]byte, []byte, error) {
resp, err := d.httpClient().Get(p)
func loadHash(client util.Client, p string) ([]byte, []byte, error) {
resp, err := client.Get(p)
if err != nil {
return nil, nil, err
}
@ -507,14 +574,14 @@ func (d *downloader) prepareDirectory() error {
}
// run performs the downloads for all the given domains.
func (d *downloader) run(domains []string) error {
func (d *downloader) run(ctx context.Context, domains []string) error {
if err := d.prepareDirectory(); err != nil {
return err
}
for _, domain := range domains {
if err := d.download(domain); err != nil {
if err := d.download(ctx, domain); err != nil {
return err
}
}