mirror of
https://github.com/gocsaf/csaf.git
synced 2025-12-22 18:15:42 +01:00
Merge pull request #153 from csaf-poc/checker-throttling
Checker throttling
This commit is contained in:
commit
562538122a
6 changed files with 101 additions and 62 deletions
|
|
@ -9,56 +9,16 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
"github.com/csaf-poc/csaf_distribution/util"
|
||||
)
|
||||
|
||||
type client interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
Get(url string) (*http.Response, error)
|
||||
Head(url string) (*http.Response, error)
|
||||
Post(url, contentType string, body io.Reader) (*http.Response, error)
|
||||
PostForm(url string, data url.Values) (*http.Response, error)
|
||||
}
|
||||
|
||||
type limitingClient struct {
|
||||
client
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (lc *limitingClient) Do(req *http.Request) (*http.Response, error) {
|
||||
lc.limiter.Wait(context.Background())
|
||||
return lc.client.Do(req)
|
||||
}
|
||||
|
||||
func (lc *limitingClient) Get(url string) (*http.Response, error) {
|
||||
lc.limiter.Wait(context.Background())
|
||||
return lc.client.Get(url)
|
||||
}
|
||||
|
||||
func (lc *limitingClient) Head(url string) (*http.Response, error) {
|
||||
lc.limiter.Wait(context.Background())
|
||||
return lc.client.Head(url)
|
||||
}
|
||||
|
||||
func (lc *limitingClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
||||
lc.limiter.Wait(context.Background())
|
||||
return lc.client.Post(url, contentType, body)
|
||||
}
|
||||
|
||||
func (lc *limitingClient) PostForm(url string, data url.Values) (*http.Response, error) {
|
||||
lc.limiter.Wait(context.Background())
|
||||
return lc.client.PostForm(url, data)
|
||||
}
|
||||
|
||||
var errNotFound = errors.New("not found")
|
||||
|
||||
func downloadJSON(c client, url string, found func(io.Reader) error) error {
|
||||
func downloadJSON(c util.Client, url string, found func(io.Reader) error) error {
|
||||
res, err := c.Get(url)
|
||||
if err != nil || res.StatusCode != http.StatusOK ||
|
||||
res.Header.Get("Content-Type") != "application/json" {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/BurntSushi/toml"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/csaf-poc/csaf_distribution/csaf"
|
||||
"github.com/csaf-poc/csaf_distribution/util"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
|
|
@ -105,7 +106,7 @@ func (c *config) cryptoKey() (*crypto.Key, error) {
|
|||
return c.key, c.keyErr
|
||||
}
|
||||
|
||||
func (c *config) httpClient(p *provider) client {
|
||||
func (c *config) httpClient(p *provider) util.Client {
|
||||
|
||||
client := http.Client{}
|
||||
if p.Insecure != nil && *p.Insecure || c.Insecure != nil && *c.Insecure {
|
||||
|
|
@ -126,9 +127,9 @@ func (c *config) httpClient(p *provider) client {
|
|||
if p.Rate != nil {
|
||||
r = *p.Rate
|
||||
}
|
||||
return &limitingClient{
|
||||
client: &client,
|
||||
limiter: rate.NewLimiter(rate.Limit(r), 1),
|
||||
return &util.LimitingClient{
|
||||
Client: &client,
|
||||
Limiter: rate.NewLimiter(rate.Limit(r), 1),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ type worker struct {
|
|||
cfg *config
|
||||
signRing *crypto.KeyRing
|
||||
|
||||
client client // client per provider
|
||||
client util.Client // client per provider
|
||||
provider *provider // current provider
|
||||
metadataProvider interface{} // current metadata provider
|
||||
loc string // URL of current provider-metadata.json
|
||||
|
|
|
|||
|
|
@ -26,12 +26,13 @@ import (
|
|||
var reportHTML string
|
||||
|
||||
type options struct {
|
||||
Output string `short:"o" long:"output" description:"File name of the generated report" value-name:"REPORT-FILE"`
|
||||
Format string `short:"f" long:"format" choice:"json" choice:"html" description:"Format of report" default:"json"`
|
||||
Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider"`
|
||||
ClientCert *string `long:"client-cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE"`
|
||||
ClientKey *string `long:"client-key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE"`
|
||||
Version bool `long:"version" description:"Display version of the binary"`
|
||||
Output string `short:"o" long:"output" description:"File name of the generated report" value-name:"REPORT-FILE"`
|
||||
Format string `short:"f" long:"format" choice:"json" choice:"html" description:"Format of report" default:"json"`
|
||||
Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider"`
|
||||
ClientCert *string `long:"client-cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE"`
|
||||
ClientKey *string `long:"client-key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE"`
|
||||
Version bool `long:"version" description:"Display version of the binary"`
|
||||
Rate *float64 `long:"rate" short:"r" description:"The average upper limit of https operations per second"`
|
||||
}
|
||||
|
||||
func errCheck(err error) {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/csaf-poc/csaf_distribution/csaf"
|
||||
"github.com/csaf-poc/csaf_distribution/util"
|
||||
|
|
@ -40,7 +41,7 @@ type topicMessages []string
|
|||
|
||||
type processor struct {
|
||||
opts *options
|
||||
client *http.Client
|
||||
client util.Client
|
||||
|
||||
redirects map[string]string
|
||||
noneTLS map[string]struct{}
|
||||
|
|
@ -263,19 +264,21 @@ func (p *processor) checkRedirect(r *http.Request, via []*http.Request) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *processor) httpClient() *http.Client {
|
||||
func (p *processor) httpClient() util.Client {
|
||||
|
||||
if p.client != nil {
|
||||
return p.client
|
||||
}
|
||||
|
||||
p.client = &http.Client{
|
||||
CheckRedirect: p.checkRedirect,
|
||||
}
|
||||
client := http.Client{}
|
||||
|
||||
client.CheckRedirect = p.checkRedirect
|
||||
|
||||
var tlsConfig tls.Config
|
||||
if p.opts.Insecure {
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
if p.opts.ClientCert != nil && p.opts.ClientKey != nil {
|
||||
cert, err := tls.LoadX509KeyPair(*p.opts.ClientCert, *p.opts.ClientKey)
|
||||
if err != nil {
|
||||
|
|
@ -283,9 +286,21 @@ func (p *processor) httpClient() *http.Client {
|
|||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
p.client.Transport = &http.Transport{
|
||||
|
||||
client.Transport = &http.Transport{
|
||||
TLSClientConfig: &tlsConfig,
|
||||
}
|
||||
|
||||
if p.opts.Rate == nil {
|
||||
p.client = &client
|
||||
return &client
|
||||
}
|
||||
|
||||
p.client = &util.LimitingClient{
|
||||
Client: &client,
|
||||
Limiter: rate.NewLimiter(rate.Limit(*p.opts.Rate), 1),
|
||||
}
|
||||
|
||||
return p.client
|
||||
}
|
||||
|
||||
|
|
@ -458,7 +473,6 @@ func (p *processor) integrity(
|
|||
}
|
||||
|
||||
func (p *processor) processROLIEFeed(feed string) error {
|
||||
|
||||
client := p.httpClient()
|
||||
res, err := client.Get(feed)
|
||||
if err != nil {
|
||||
|
|
@ -531,6 +545,7 @@ func (p *processor) processROLIEFeed(feed string) error {
|
|||
// It returns error if fetching/reading the file(s) fails, otherwise nil.
|
||||
func (p *processor) checkIndex(base string, mask whereType) error {
|
||||
client := p.httpClient()
|
||||
|
||||
index := base + "/index.txt"
|
||||
p.checkTLS(index)
|
||||
|
||||
|
|
@ -795,10 +810,10 @@ func (p *processor) locateProviderMetadata(
|
|||
) error {
|
||||
|
||||
client := p.httpClient()
|
||||
|
||||
tryURL := func(url string) (bool, error) {
|
||||
log.Printf("Trying: %v\n", url)
|
||||
res, err := client.Get(url)
|
||||
|
||||
if err != nil || res.StatusCode != http.StatusOK ||
|
||||
res.Header.Get("Content-Type") != "application/json" {
|
||||
// ignore this as it is expected.
|
||||
|
|
@ -943,7 +958,6 @@ func (p *processor) checkProviderMetadata(domain string) error {
|
|||
func (p *processor) checkSecurity(domain string) error {
|
||||
|
||||
client := p.httpClient()
|
||||
|
||||
p.badSecurity.use()
|
||||
|
||||
path := "https://" + domain + "/.well-known/security.txt"
|
||||
|
|
|
|||
63
util/client.go
Normal file
63
util/client.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
// This file is Free Software under the MIT License
|
||||
// without warranty, see README.md and LICENSES/MIT.txt for details.
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
// SPDX-FileCopyrightText: 2022 German Federal Office for Information Security (BSI) <https://www.bsi.bund.de>
|
||||
// Software-Engineering: 2022 Intevation GmbH <https://intevation.de>
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Client is an interface to abstract http.Client.
|
||||
type Client interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
Get(url string) (*http.Response, error)
|
||||
Head(url string) (*http.Response, error)
|
||||
Post(url, contentType string, body io.Reader) (*http.Response, error)
|
||||
PostForm(url string, data url.Values) (*http.Response, error)
|
||||
}
|
||||
|
||||
// LimitingClient is a Client implementing rate throttling.
|
||||
type LimitingClient struct {
|
||||
Client
|
||||
Limiter *rate.Limiter
|
||||
}
|
||||
|
||||
// Do implements the respective method of the Client interface.
|
||||
func (lc *LimitingClient) Do(req *http.Request) (*http.Response, error) {
|
||||
lc.Limiter.Wait(context.Background())
|
||||
return lc.Client.Do(req)
|
||||
}
|
||||
|
||||
// Get implements the respective method of the Client interface.
|
||||
func (lc *LimitingClient) Get(url string) (*http.Response, error) {
|
||||
lc.Limiter.Wait(context.Background())
|
||||
return lc.Client.Get(url)
|
||||
}
|
||||
|
||||
// Head implements the respective method of the Client interface.
|
||||
func (lc *LimitingClient) Head(url string) (*http.Response, error) {
|
||||
lc.Limiter.Wait(context.Background())
|
||||
return lc.Client.Head(url)
|
||||
}
|
||||
|
||||
// Post implements the respective method of the Client interface.
|
||||
func (lc *LimitingClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
||||
lc.Limiter.Wait(context.Background())
|
||||
return lc.Client.Post(url, contentType, body)
|
||||
}
|
||||
|
||||
// PostForm implements the respective method of the Client interface.
|
||||
func (lc *LimitingClient) PostForm(url string, data url.Values) (*http.Response, error) {
|
||||
lc.Limiter.Wait(context.Background())
|
||||
return lc.Client.PostForm(url, data)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue