From 07ab770a35341ff1634859d60d76c18829e3040d Mon Sep 17 00:00:00 2001 From: "Sascha L. Teichmann" Date: Mon, 30 May 2022 23:12:08 +0200 Subject: [PATCH] Factored throttling client out of aggregator. --- cmd/csaf_aggregator/client.go | 44 +--------------------- cmd/csaf_aggregator/config.go | 9 +++-- cmd/csaf_aggregator/processor.go | 2 +- util/client.go | 63 ++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 47 deletions(-) create mode 100644 util/client.go diff --git a/cmd/csaf_aggregator/client.go b/cmd/csaf_aggregator/client.go index e2e0c2f..506b525 100644 --- a/cmd/csaf_aggregator/client.go +++ b/cmd/csaf_aggregator/client.go @@ -9,56 +9,16 @@ package main import ( - "context" "errors" "io" "net/http" - "net/url" - "golang.org/x/time/rate" + "github.com/csaf-poc/csaf_distribution/util" ) -type client interface { - Do(req *http.Request) (*http.Response, error) - Get(url string) (*http.Response, error) - Head(url string) (*http.Response, error) - Post(url, contentType string, body io.Reader) (*http.Response, error) - PostForm(url string, data url.Values) (*http.Response, error) -} - -type limitingClient struct { - client - limiter *rate.Limiter -} - -func (lc *limitingClient) Do(req *http.Request) (*http.Response, error) { - lc.limiter.Wait(context.Background()) - return lc.client.Do(req) -} - -func (lc *limitingClient) Get(url string) (*http.Response, error) { - lc.limiter.Wait(context.Background()) - return lc.client.Get(url) -} - -func (lc *limitingClient) Head(url string) (*http.Response, error) { - lc.limiter.Wait(context.Background()) - return lc.client.Head(url) -} - -func (lc *limitingClient) Post(url, contentType string, body io.Reader) (*http.Response, error) { - lc.limiter.Wait(context.Background()) - return lc.client.Post(url, contentType, body) -} - -func (lc *limitingClient) PostForm(url string, data url.Values) (*http.Response, error) { - lc.limiter.Wait(context.Background()) - return lc.client.PostForm(url, data) -} - var errNotFound = errors.New("not found") -func downloadJSON(c client, url string, found func(io.Reader) error) error { +func downloadJSON(c util.Client, url string, found func(io.Reader) error) error { res, err := c.Get(url) if err != nil || res.StatusCode != http.StatusOK || res.Header.Get("Content-Type") != "application/json" { diff --git a/cmd/csaf_aggregator/config.go b/cmd/csaf_aggregator/config.go index 611809a..1159977 100644 --- a/cmd/csaf_aggregator/config.go +++ b/cmd/csaf_aggregator/config.go @@ -21,6 +21,7 @@ import ( "github.com/BurntSushi/toml" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/csaf-poc/csaf_distribution/csaf" + "github.com/csaf-poc/csaf_distribution/util" "golang.org/x/time/rate" ) @@ -105,7 +106,7 @@ func (c *config) cryptoKey() (*crypto.Key, error) { return c.key, c.keyErr } -func (c *config) httpClient(p *provider) client { +func (c *config) httpClient(p *provider) util.Client { client := http.Client{} if p.Insecure != nil && *p.Insecure || c.Insecure != nil && *c.Insecure { @@ -126,9 +127,9 @@ func (c *config) httpClient(p *provider) client { if p.Rate != nil { r = *p.Rate } - return &limitingClient{ - client: &client, - limiter: rate.NewLimiter(rate.Limit(r), 1), + return &util.LimitingClient{ + Client: &client, + Limiter: rate.NewLimiter(rate.Limit(r), 1), } } diff --git a/cmd/csaf_aggregator/processor.go b/cmd/csaf_aggregator/processor.go index 977e5c0..e5c8b52 100644 --- a/cmd/csaf_aggregator/processor.go +++ b/cmd/csaf_aggregator/processor.go @@ -39,7 +39,7 @@ type worker struct { cfg *config signRing *crypto.KeyRing - client client // client per provider + client util.Client // client per provider provider *provider // current provider metadataProvider interface{} // current metadata provider loc string // URL of current provider-metadata.json diff --git a/util/client.go b/util/client.go new file mode 100644 index 0000000..73edd5b --- /dev/null +++ b/util/client.go @@ -0,0 +1,63 @@ +// This file is Free Software under the MIT License +// without warranty, see README.md and LICENSES/MIT.txt for details. +// +// SPDX-License-Identifier: MIT +// +// SPDX-FileCopyrightText: 2022 German Federal Office for Information Security (BSI) +// Software-Engineering: 2022 Intevation GmbH + +package util + +import ( + "context" + "io" + "net/http" + "net/url" + + "golang.org/x/time/rate" +) + +// Client is an interface to abstract http.Client. +type Client interface { + Do(req *http.Request) (*http.Response, error) + Get(url string) (*http.Response, error) + Head(url string) (*http.Response, error) + Post(url, contentType string, body io.Reader) (*http.Response, error) + PostForm(url string, data url.Values) (*http.Response, error) +} + +// LimitingClient is a Client implementing rate throttling. +type LimitingClient struct { + Client + Limiter *rate.Limiter +} + +// Do implements the respective method of the Client interface. +func (lc *LimitingClient) Do(req *http.Request) (*http.Response, error) { + lc.Limiter.Wait(context.Background()) + return lc.Client.Do(req) +} + +// Get implements the respective method of the Client interface. +func (lc *LimitingClient) Get(url string) (*http.Response, error) { + lc.Limiter.Wait(context.Background()) + return lc.Client.Get(url) +} + +// Head implements the respective method of the Client interface. +func (lc *LimitingClient) Head(url string) (*http.Response, error) { + lc.Limiter.Wait(context.Background()) + return lc.Client.Head(url) +} + +// Post implements the respective method of the Client interface. +func (lc *LimitingClient) Post(url, contentType string, body io.Reader) (*http.Response, error) { + lc.Limiter.Wait(context.Background()) + return lc.Client.Post(url, contentType, body) +} + +// PostForm implements the respective method of the Client interface. +func (lc *LimitingClient) PostForm(url string, data url.Values) (*http.Response, error) { + lc.Limiter.Wait(context.Background()) + return lc.Client.PostForm(url, data) +}