1
0
Fork 0
mirror of https://github.com/gocsaf/csaf.git synced 2025-12-22 11:55:40 +01:00

Add 'Rate' config option for download throttling (Checker)

This commit is contained in:
Fadi Abbud 2022-05-30 13:38:29 +02:00
parent 3a2c4f8b22
commit a1036c3847
2 changed files with 45 additions and 15 deletions

View file

@ -26,12 +26,13 @@ import (
var reportHTML string var reportHTML string
type options struct { type options struct {
Output string `short:"o" long:"output" description:"File name of the generated report" value-name:"REPORT-FILE"` Output string `short:"o" long:"output" description:"File name of the generated report" value-name:"REPORT-FILE"`
Format string `short:"f" long:"format" choice:"json" choice:"html" description:"Format of report" default:"json"` Format string `short:"f" long:"format" choice:"json" choice:"html" description:"Format of report" default:"json"`
Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider"` Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider"`
ClientCert *string `long:"client-cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE"` ClientCert *string `long:"client-cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE"`
ClientKey *string `long:"client-key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE"` ClientKey *string `long:"client-key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE"`
Version bool `long:"version" description:"Display version of the binary"` Version bool `long:"version" description:"Display version of the binary"`
Rate *float64 `long:"rate" short:"t"`
} }
func errCheck(err error) { func errCheck(err error) {

View file

@ -11,6 +11,7 @@ package main
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"crypto/tls" "crypto/tls"
@ -30,6 +31,7 @@ import (
"time" "time"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
"golang.org/x/time/rate"
"github.com/csaf-poc/csaf_distribution/csaf" "github.com/csaf-poc/csaf_distribution/csaf"
"github.com/csaf-poc/csaf_distribution/util" "github.com/csaf-poc/csaf_distribution/util"
@ -38,9 +40,22 @@ import (
// topicMessages stores the collected topicMessages for a specific topic. // topicMessages stores the collected topicMessages for a specific topic.
type topicMessages []string type topicMessages []string
type client interface {
Get(url string) (*http.Response, error)
}
type limitingClient struct {
client
limiter *rate.Limiter
}
func (lc *limitingClient) Get(url string) (*http.Response, error) {
lc.limiter.Wait(context.Background())
return lc.client.Get(url)
}
type processor struct { type processor struct {
opts *options opts *options
client *http.Client client client
redirects map[string]string redirects map[string]string
noneTLS map[string]struct{} noneTLS map[string]struct{}
@ -263,15 +278,14 @@ func (p *processor) checkRedirect(r *http.Request, via []*http.Request) error {
return nil return nil
} }
func (p *processor) httpClient() *http.Client { func (p *processor) httpClient() client {
if p.client != nil { if p.client != nil {
return p.client return p.client
} }
p.client = &http.Client{ client := http.Client{}
CheckRedirect: p.checkRedirect, client.CheckRedirect = p.checkRedirect
}
var tlsConfig tls.Config var tlsConfig tls.Config
if p.opts.Insecure { if p.opts.Insecure {
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
@ -283,10 +297,26 @@ func (p *processor) httpClient() *http.Client {
} }
tlsConfig.Certificates = []tls.Certificate{cert} tlsConfig.Certificates = []tls.Certificate{cert}
} }
p.client.Transport = &http.Transport{ client.Transport = &http.Transport{
TLSClientConfig: &tlsConfig, TLSClientConfig: &tlsConfig,
} }
p.client = &client
if p.opts.Rate == nil {
return &client
}
var r float64
if p.opts.Rate != nil {
r = *p.opts.Rate
}
p.client = &limitingClient{
client: &client,
limiter: rate.NewLimiter(rate.Limit(r), 1),
}
return p.client return p.client
} }
var yearFromURL = regexp.MustCompile(`.*/(\d{4})/[^/]+$`) var yearFromURL = regexp.MustCompile(`.*/(\d{4})/[^/]+$`)
@ -458,7 +488,6 @@ func (p *processor) integrity(
} }
func (p *processor) processROLIEFeed(feed string) error { func (p *processor) processROLIEFeed(feed string) error {
client := p.httpClient() client := p.httpClient()
res, err := client.Get(feed) res, err := client.Get(feed)
if err != nil { if err != nil {
@ -531,6 +560,7 @@ func (p *processor) processROLIEFeed(feed string) error {
// It returns error if fetching/reading the file(s) fails, otherwise nil. // It returns error if fetching/reading the file(s) fails, otherwise nil.
func (p *processor) checkIndex(base string, mask whereType) error { func (p *processor) checkIndex(base string, mask whereType) error {
client := p.httpClient() client := p.httpClient()
index := base + "/index.txt" index := base + "/index.txt"
p.checkTLS(index) p.checkTLS(index)
@ -795,10 +825,10 @@ func (p *processor) locateProviderMetadata(
) error { ) error {
client := p.httpClient() client := p.httpClient()
tryURL := func(url string) (bool, error) { tryURL := func(url string) (bool, error) {
log.Printf("Trying: %v\n", url) log.Printf("Trying: %v\n", url)
res, err := client.Get(url) res, err := client.Get(url)
if err != nil || res.StatusCode != http.StatusOK || if err != nil || res.StatusCode != http.StatusOK ||
res.Header.Get("Content-Type") != "application/json" { res.Header.Get("Content-Type") != "application/json" {
// ignore this as it is expected. // ignore this as it is expected.
@ -943,7 +973,6 @@ func (p *processor) checkProviderMetadata(domain string) error {
func (p *processor) checkSecurity(domain string) error { func (p *processor) checkSecurity(domain string) error {
client := p.httpClient() client := p.httpClient()
p.badSecurity.use() p.badSecurity.use()
path := "https://" + domain + "/.well-known/security.txt" path := "https://" + domain + "/.well-known/security.txt"