
Add concurrent downloads to downloader. (#363)

* Add concurrent downloads to downloader.

* Moved to Go 1.20

* Close the files channel on the producer side.

* Improve error handling

* New flag to ignore signature check results. Improve docs. Do not use number of CPUs to determine number of download workers.

* Set number of default workers in downloader to two.
Authored by Sascha L. Teichmann on 2023-05-02 10:10:12 +02:00, committed by GitHub
parent 91479c9912
commit f32fba683d
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 175 additions and 87 deletions
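At its core the change is a bounded worker pool: the advisory files are fed into a channel that is closed on the producer side, a fixed number of workers (two by default, no longer derived from the CPU count) drain it, failures are sent to an error channel, and the collected errors are merged with errors.Join from Go 1.20. The following is a minimal, self-contained sketch of that pattern, not the patch itself; the names run, fetch, item and workCh are illustrative only.

package main

import (
    "context"
    "errors"
    "fmt"
    "os"
    "os/signal"
    "sync"
)

// fetch stands in for downloading and checking one advisory.
func fetch(ctx context.Context, item string) error {
    if err := ctx.Err(); err != nil {
        return err
    }
    fmt.Println("processed", item)
    return nil
}

func run(ctx context.Context, items []string, workers int) error {
    var (
        workCh  = make(chan string)
        errorCh = make(chan error)
        errDone = make(chan struct{})
        errs    []error
        wg      sync.WaitGroup
    )

    // Collect errors until errorCh is closed.
    go func() {
        defer close(errDone)
        for err := range errorCh {
            errs = append(errs, err)
        }
    }()

    if workers < 1 {
        workers = 1
    }
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for item := range workCh {
                if err := fetch(ctx, item); err != nil {
                    errorCh <- err
                }
            }
        }()
    }

    // Producer side: feed the workers, close the channel when done
    // or when the context is cancelled (e.g. by an interrupt).
feed:
    for _, item := range items {
        select {
        case workCh <- item:
        case <-ctx.Done():
            break feed
        }
    }
    close(workCh)

    wg.Wait()      // all workers have returned
    close(errorCh) // nothing sends on errorCh anymore
    <-errDone      // collector has drained the channel

    return errors.Join(errs...) // requires Go 1.20
}

func main() {
    ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
    defer stop()

    if err := run(ctx, []string{"a", "b", "c"}, 2); err != nil {
        fmt.Fprintln(os.Stderr, err)
    }
}

Closing the work channel on the producer side is what lets each worker leave its range loop cleanly, and waiting for the error collector before returning ensures no late error is lost; the same ordering (close work channel, wait for workers, close error channel, wait for collector) appears in the patch below.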

View file

@@ -14,7 +14,7 @@ jobs:
     - name: Set up Go
       uses: actions/setup-go@v2
       with:
-        go-version: 1.19.1
+        go-version: 1.20.3
     - name: Build
       run: go build -v ./cmd/...

View file

@@ -9,7 +9,7 @@ jobs:
     - name: Set up Go
       uses: actions/setup-go@v3
       with:
-        go-version: 1.19.1
+        go-version: 1.20.3
     - name: Set up Node.js
       uses: actions/setup-node@v3

View file

@@ -15,7 +15,7 @@ jobs:
     - name: Set up Go
       uses: actions/setup-go@v3
       with:
-        go-version: '^1.19.1'
+        go-version: '^1.20.3'
     - name: Build
       run: make dist

View file

@@ -10,10 +10,12 @@ package main
 import (
     "bytes"
+    "context"
     "crypto/sha256"
     "crypto/sha512"
     "crypto/tls"
     "encoding/json"
+    "errors"
     "fmt"
     "hash"
     "io"
@@ -25,6 +27,7 @@ import (
     "path/filepath"
     "strconv"
     "strings"
+    "sync"
     "time"

     "github.com/ProtonMail/gopenpgp/v2/crypto"
@@ -34,12 +37,12 @@
 )

 type downloader struct {
-    client    util.Client
     opts      *options
     directory string
-    keys      []*crypto.KeyRing
+    keys      *crypto.KeyRing
     eval      *util.PathEval
     validator csaf.RemoteValidator
+    mkdirMu   sync.Mutex
 }

 func newDownloader(opts *options) (*downloader, error) {
@@ -57,6 +60,7 @@ func newDownloader(opts *options) (*downloader, error) {
             return nil, fmt.Errorf(
                 "preparing remote validator failed: %w", err)
         }
+        validator = csaf.SynchronizedRemoteValidator(validator)
     }

     return &downloader{
@@ -75,10 +79,6 @@ func (d *downloader) close() {

 func (d *downloader) httpClient() util.Client {

-    if d.client != nil {
-        return d.client
-    }
-
     hClient := http.Client{}

     var tlsConfig tls.Config
@@ -112,14 +112,14 @@ func (d *downloader) httpClient() util.Client {
         }
     }

-    d.client = client
-    return d.client
+    return client
 }

-func (d *downloader) download(domain string) error {
+func (d *downloader) download(ctx context.Context, domain string) error {
+    client := d.httpClient()

     lpmd := csaf.LoadProviderMetadataForDomain(
-        d.httpClient(), domain, func(format string, args ...any) {
+        client, domain, func(format string, args ...any) {
             log.Printf(
                 "Looking for provider-metadata.json of '"+domain+"': "+format+"\n", args...)
         })
@@ -134,7 +134,7 @@ func (d *downloader) download(domain string) error {
     }

     if err := d.loadOpenPGPKeys(
-        d.httpClient(),
+        client,
         lpmd.Document,
         base,
     ); err != nil {
@@ -142,13 +142,64 @@ func (d *downloader) download(domain string) error {
     }

     afp := csaf.NewAdvisoryFileProcessor(
-        d.httpClient(),
+        client,
         d.eval,
         lpmd.Document,
         base,
         nil)

-    return afp.Process(d.downloadFiles)
+    return afp.Process(func(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
+        return d.downloadFiles(ctx, label, files)
+    })
+}
+
+func (d *downloader) downloadFiles(
+    ctx context.Context,
+    label csaf.TLPLabel,
+    files []csaf.AdvisoryFile,
+) error {
+
+    var (
+        advisoryCh = make(chan csaf.AdvisoryFile)
+        errorCh    = make(chan error)
+        errDone    = make(chan struct{})
+        errs       []error
+        wg         sync.WaitGroup
+    )
+
+    // collect errors
+    go func() {
+        defer close(errDone)
+        for err := range errorCh {
+            errs = append(errs, err)
+        }
+    }()
+
+    var n int
+    if n = d.opts.Worker; n < 1 {
+        n = 1
+    }
+
+    for i := 0; i < n; i++ {
+        wg.Add(1)
+        go d.downloadWorker(ctx, &wg, label, advisoryCh, errorCh)
+    }
+
+allFiles:
+    for _, file := range files {
+        select {
+        case advisoryCh <- file:
+        case <-ctx.Done():
+            break allFiles
+        }
+    }
+    close(advisoryCh)
+    wg.Wait()
+    close(errorCh)
+    <-errDone
+
+    return errors.Join(errs...)
 }

 func (d *downloader) loadOpenPGPKeys(
@@ -213,12 +264,15 @@ func (d *downloader) loadOpenPGPKeys(
                 "Fingerprint of public OpenPGP key %s does not match remotely loaded.", u)
             continue
         }
-        keyring, err := crypto.NewKeyRing(ckey)
-        if err != nil {
-            log.Printf("Creating store for public OpenPGP key %s failed: %v.", u, err)
-            continue
+        if d.keys == nil {
+            if keyring, err := crypto.NewKeyRing(ckey); err != nil {
+                log.Printf("Creating store for public OpenPGP key %s failed: %v.", u, err)
+            } else {
+                d.keys = keyring
+            }
+        } else {
+            d.keys.AddKey(ckey)
         }
-        d.keys = append(d.keys, keyring)
     }
     return nil
 }
@@ -240,21 +294,36 @@ func (d *downloader) logValidationIssues(url string, errors []string, err error) {
     }
 }

-func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
-
-    client := d.httpClient()
-
-    var data bytes.Buffer
-
-    var lastDir string
-
-    lower := strings.ToLower(string(label))
-
-    var initialReleaseDate time.Time
-
-    dateExtract := util.TimeMatcher(&initialReleaseDate, time.RFC3339)
-
-    for _, file := range files {
+func (d *downloader) downloadWorker(
+    ctx context.Context,
+    wg *sync.WaitGroup,
+    label csaf.TLPLabel,
+    files <-chan csaf.AdvisoryFile,
+    errorCh chan<- error,
+) {
+    defer wg.Done()
+
+    var (
+        client             = d.httpClient()
+        data               bytes.Buffer
+        lastDir            string
+        initialReleaseDate time.Time
+        dateExtract        = util.TimeMatcher(&initialReleaseDate, time.RFC3339)
+        lower              = strings.ToLower(string(label))
+    )
+
+nextAdvisory:
+    for {
+        var file csaf.AdvisoryFile
+        var ok bool
+        select {
+        case file, ok = <-files:
+            if !ok {
+                return
+            }
+        case <-ctx.Done():
+            return
+        }

         u, err := url.Parse(file.URL())
         if err != nil {
@@ -297,7 +366,7 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
         )

         // Only hash when we have a remote counter part we can compare it with.
-        if remoteSHA256, s256Data, err = d.loadHash(file.SHA256URL()); err != nil {
+        if remoteSHA256, s256Data, err = loadHash(client, file.SHA256URL()); err != nil {
             if d.opts.Verbose {
                 log.Printf("WARN: cannot fetch %s: %v\n", file.SHA256URL(), err)
             }
@@ -306,7 +375,7 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
             writers = append(writers, s256)
         }

-        if remoteSHA512, s512Data, err = d.loadHash(file.SHA512URL()); err != nil {
+        if remoteSHA512, s512Data, err = loadHash(client, file.SHA512URL()); err != nil {
             if d.opts.Verbose {
                 log.Printf("WARN: cannot fetch %s: %v\n", file.SHA512URL(), err)
             }
@@ -345,9 +414,9 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
         }

         // Only check signature if we have loaded keys.
-        if len(d.keys) > 0 {
+        if d.keys != nil {
             var sign *crypto.PGPSignature
-            sign, signData, err = d.loadSignature(file.SignURL())
+            sign, signData, err = loadSignature(client, file.SignURL())
             if err != nil {
                 if d.opts.Verbose {
                     log.Printf("downloading signature '%s' failed: %v\n",
@@ -355,12 +424,14 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
                 }
             }
             if sign != nil {
-                if !d.checkSignature(data.Bytes(), sign) {
-                    log.Printf("Cannot verify signature for %s\n", file.URL())
-                    continue
+                if err := d.checkSignature(data.Bytes(), sign); err != nil {
+                    log.Printf("Cannot verify signature for %s: %v\n", file.URL(), err)
+                    if !d.opts.IgnoreSignatureCheck {
+                        continue
+                    }
                 }
             }
         }

         // Validate against CSAF schema.
         if errors, err := csaf.ValidateCSAF(doc); err != nil || len(errors) > 0 {
@@ -372,9 +443,10 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
         if d.validator != nil {
             rvr, err := d.validator.Validate(doc)
             if err != nil {
-                return fmt.Errorf(
+                errorCh <- fmt.Errorf(
                     "calling remote validator on %q failed: %w",
                     file.URL(), err)
+                continue
             }
             if !rvr.Valid {
                 log.Printf("Remote validation of %q failed\n", file.URL())
@@ -391,56 +463,51 @@ func (d *downloader) downloadFiles(label csaf.TLPLabel, files []csaf.AdvisoryFile) error {
         newDir := path.Join(d.directory, lower, strconv.Itoa(initialReleaseDate.Year()))
         if newDir != lastDir {
-            if err := os.MkdirAll(newDir, 0755); err != nil {
-                return err
+            if err := d.mkdirAll(newDir, 0755); err != nil {
+                errorCh <- err
+                continue
             }
             lastDir = newDir
         }

         path := filepath.Join(lastDir, filename)

-        if err := os.WriteFile(path, data.Bytes(), 0644); err != nil {
-            return err
-        }
-
-        // Write hash sums.
-        if s256Data != nil {
-            if err := os.WriteFile(path+".sha256", s256Data, 0644); err != nil {
-                return err
-            }
-        }
-
-        if s512Data != nil {
-            if err := os.WriteFile(path+".sha512", s512Data, 0644); err != nil {
-                return err
-            }
-        }
-
-        // Write signature.
-        if signData != nil {
-            if err := os.WriteFile(path+".asc", signData, 0644); err != nil {
-                return err
+        // Write data to disk.
+        for _, x := range []struct {
+            p string
+            d []byte
+        }{
+            {path, data.Bytes()},
+            {path + ".sha256", s256Data},
+            {path + ".sha512", s512Data},
+            {path + ".asc", signData},
+        } {
+            if x.d != nil {
+                if err := os.WriteFile(x.p, x.d, 0644); err != nil {
+                    errorCh <- err
+                    continue nextAdvisory
+                }
             }
         }

         log.Printf("Written advisory '%s'.\n", path)
     }
-
-    return nil
 }

-func (d *downloader) checkSignature(data []byte, sign *crypto.PGPSignature) bool {
+func (d *downloader) mkdirAll(path string, perm os.FileMode) error {
+    d.mkdirMu.Lock()
+    defer d.mkdirMu.Unlock()
+    return os.MkdirAll(path, perm)
+}
+
+func (d *downloader) checkSignature(data []byte, sign *crypto.PGPSignature) error {
     pm := crypto.NewPlainMessage(data)
     t := crypto.GetUnixTime()
-    for _, key := range d.keys {
-        if err := key.VerifyDetached(pm, sign, t); err == nil {
-            return true
-        }
-    }
-    return false
+    return d.keys.VerifyDetached(pm, sign, t)
 }

-func (d *downloader) loadSignature(p string) (*crypto.PGPSignature, []byte, error) {
-    resp, err := d.httpClient().Get(p)
+func loadSignature(client util.Client, p string) (*crypto.PGPSignature, []byte, error) {
+    resp, err := client.Get(p)
     if err != nil {
         return nil, nil, err
     }
@@ -460,8 +527,8 @@ func (d *downloader) loadSignature(p string) (*crypto.PGPSignature, []byte, error) {
     return sign, data, nil
 }

-func (d *downloader) loadHash(p string) ([]byte, []byte, error) {
-    resp, err := d.httpClient().Get(p)
+func loadHash(client util.Client, p string) ([]byte, []byte, error) {
+    resp, err := client.Get(p)
     if err != nil {
         return nil, nil, err
     }
@@ -507,14 +574,14 @@ func (d *downloader) prepareDirectory() error {
 }

 // run performs the downloads for all the given domains.
-func (d *downloader) run(domains []string) error {
+func (d *downloader) run(ctx context.Context, domains []string) error {
     if err := d.prepareDirectory(); err != nil {
         return err
     }

     for _, domain := range domains {
-        if err := d.download(domain); err != nil {
+        if err := d.download(ctx, domain); err != nil {
             return err
         }
     }

View file

@@ -10,21 +10,27 @@
 package main

 import (
+    "context"
     "fmt"
     "log"
     "net/http"
     "os"
+    "os/signal"

     "github.com/csaf-poc/csaf_distribution/util"
     "github.com/jessevdk/go-flags"
 )

+const defaultWorker = 2
+
 type options struct {
     Directory *string `short:"d" long:"directory" description:"DIRectory to store the downloaded files in" value-name:"DIR"`
     Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider"`
+    IgnoreSignatureCheck bool `long:"ignoresigcheck" description:"Ignore signature check results, just warn on mismatch"`
     Version bool `long:"version" description:"Display version of the binary"`
     Verbose bool `long:"verbose" short:"v" description:"Verbose output"`
     Rate *float64 `long:"rate" short:"r" description:"The average upper limit of https operations per second (defaults to unlimited)"`
+    Worker int `long:"worker" short:"w" description:"NUMber of concurrent downloads" value-name:"NUM"`
     ExtraHeader http.Header `long:"header" short:"H" description:"One or more extra HTTP header fields"`
@@ -48,12 +54,20 @@ func run(opts *options, domains []string) error {
         return err
     }
     defer d.close()
-    return d.run(domains)
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
+    defer stop()
+
+    return d.run(ctx, domains)
 }

 func main() {
-    opts := new(options)
+    opts := &options{
+        Worker: defaultWorker,
+    }

     parser := flags.NewParser(opts, flags.Default)
     parser.Usage = "[OPTIONS] domain..."

View file

@@ -9,9 +9,11 @@ csaf_downloader [OPTIONS] domain...
 Application Options:
   -d, --directory=DIR    DIRectory to store the downloaded files in
       --insecure         Do not check TLS certificates from provider
+      --ignoresigcheck   Ignore signature check results, just warn on mismatch
       --version          Display version of the binary
   -v, --verbose          Verbose output
   -r, --rate=            The average upper limit of https operations per second (defaults to unlimited)
+  -w, --worker=NUM       NUMber of concurrent downloads (default: 2)
   -H, --header=          One or more extra HTTP header fields
       --validator=URL    URL to validate documents remotely
       --validatorcache=FILE  FILE to cache remote validations
@@ -24,3 +26,8 @@ Help Options:
 Will download all CSAF documents for the given _domains_, by trying each as a CSAF provider.

 If a _domain_ starts with `https://` it is instead considered a direct URL to the `provider-metadata.json` and downloading proceeds from there.
+
+Increasing the number of workers opens more connections to the web servers
+to download more advisories at once. This may improve the overall download speed.
+However, since it also increases the load on those servers, their administrators may
+have taken countermeasures to limit it.
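For illustration (the domain is a placeholder, not taken from the change itself): an invocation such as `csaf_downloader -w 4 -r 10 example.com` runs four concurrent download workers while capping the client at roughly ten HTTPS operations per second, and adding `--ignoresigcheck` turns a failed OpenPGP signature check into a logged warning instead of a reason to skip the advisory.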

go.mod
View file

@@ -1,6 +1,6 @@
 module github.com/csaf-poc/csaf_distribution

-go 1.19
+go 1.20

 require (
     github.com/BurntSushi/toml v1.2.1