1
0
Fork 0
mirror of https://github.com/gocsaf/csaf.git synced 2025-12-22 11:55:40 +01:00

Refactor downloader to allow usage as library

- Allow to specify custom logger
- Add callback for downloaded documents
- Separate config of downloader CLI and library
- Remove `Forwarder` tests that test the handler
This commit is contained in:
koplas 2024-06-19 15:55:05 +02:00 committed by koplas
parent fae4fdeabe
commit 513282a7a8
No known key found for this signature in database
6 changed files with 514 additions and 481 deletions

View file

@ -0,0 +1,253 @@
// This file is Free Software under the Apache-2.0 License
// without warranty, see README.md and LICENSES/Apache-2.0.txt for details.
//
// SPDX-License-Identifier: Apache-2.0
//
// SPDX-FileCopyrightText: 2022 German Federal Office for Information Security (BSI) <https://www.bsi.bund.de>
// Software-Engineering: 2022 Intevation GmbH <https://intevation.de>
package main
import (
"crypto/tls"
"io"
"log"
"log/slog"
"net/http"
"os"
"path/filepath"
"time"
"github.com/csaf-poc/csaf_distribution/v3/internal/certs"
"github.com/csaf-poc/csaf_distribution/v3/internal/filter"
"github.com/csaf-poc/csaf_distribution/v3/internal/models"
"github.com/csaf-poc/csaf_distribution/v3/internal/options"
"github.com/csaf-poc/csaf_distribution/v3/lib/downloader"
)
const (
defaultWorker = 2
defaultPreset = "mandatory"
defaultForwardQueue = 5
defaultValidationMode = downloader.ValidationStrict
defaultLogFile = "downloader.log"
defaultLogLevel = slog.LevelInfo
)
// configPaths are the potential file locations of the Config file.
var configPaths = []string{
"~/.config/csaf/downloader.toml",
"~/.csaf_downloader.toml",
"csaf_downloader.toml",
}
type config struct {
Directory string `short:"d" long:"directory" description:"DIRectory to store the downloaded files in" value-name:"DIR" toml:"directory"`
Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider" toml:"insecure"`
IgnoreSignatureCheck bool `long:"ignore_sigcheck" description:"Ignore signature check results, just warn on mismatch" toml:"ignore_sigcheck"`
ClientCert *string `long:"client_cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE" toml:"client_cert"`
ClientKey *string `long:"client_key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE" toml:"client_key"`
ClientPassphrase *string `long:"client_passphrase" description:"Optional passphrase for the client cert (limited, experimental, see doc)" value-name:"PASSPHRASE" toml:"client_passphrase"`
Version bool `long:"version" description:"Display version of the binary" toml:"-"`
NoStore bool `long:"no_store" short:"n" description:"Do not store files" toml:"no_store"`
Rate *float64 `long:"rate" short:"r" description:"The average upper limit of https operations per second (defaults to unlimited)" toml:"rate"`
Worker int `long:"worker" short:"w" description:"NUMber of concurrent downloads" value-name:"NUM" toml:"worker"`
Range *models.TimeRange `long:"time_range" short:"t" description:"RANGE of time from which advisories to download" value-name:"RANGE" toml:"time_range"`
Folder string `long:"folder" short:"f" description:"Download into a given subFOLDER" value-name:"FOLDER" toml:"folder"`
IgnorePattern []string `long:"ignore_pattern" short:"i" description:"Do not download files if their URLs match any of the given PATTERNs" value-name:"PATTERN" toml:"ignore_pattern"`
ExtraHeader http.Header `long:"header" short:"H" description:"One or more extra HTTP header fields" toml:"header"`
EnumeratePMDOnly bool `long:"enumerate_pmd_only" description:"If this flag is set to true, the downloader will only enumerate valid provider metadata files, but not download documents" toml:"enumerate_pmd_only"`
RemoteValidator string `long:"validator" description:"URL to validate documents remotely" value-name:"URL" toml:"validator"`
RemoteValidatorCache string `long:"validator_cache" description:"FILE to cache remote validations" value-name:"FILE" toml:"validator_cache"`
RemoteValidatorPresets []string `long:"validator_preset" description:"One or more PRESETS to validate remotely" value-name:"PRESETS" toml:"validator_preset"`
//lint:ignore SA5008 We are using choice twice: strict, unsafe.
ValidationMode downloader.ValidationMode `long:"validation_mode" short:"m" choice:"strict" choice:"unsafe" value-name:"MODE" description:"MODE how strict the validation is" toml:"validation_mode"`
ForwardURL string `long:"forward_url" description:"URL of HTTP endpoint to forward downloads to" value-name:"URL" toml:"forward_url"`
ForwardHeader http.Header `long:"forward_header" description:"One or more extra HTTP header fields used by forwarding" toml:"forward_header"`
ForwardQueue int `long:"forward_queue" description:"Maximal queue LENGTH before forwarder" value-name:"LENGTH" toml:"forward_queue"`
ForwardInsecure bool `long:"forward_insecure" description:"Do not check TLS certificates from forward endpoint" toml:"forward_insecure"`
LogFile *string `long:"log_file" description:"FILE to log downloading to" value-name:"FILE" toml:"log_file"`
//lint:ignore SA5008 We are using choice or than once: debug, info, warn, error
LogLevel *options.LogLevel `long:"log_level" description:"LEVEL of logging details" value-name:"LEVEL" choice:"debug" choice:"info" choice:"warn" choice:"error" toml:"log_level"`
Config string `short:"c" long:"config" description:"Path to config TOML file" value-name:"TOML-FILE" toml:"-"`
clientCerts []tls.Certificate
ignorePattern filter.PatternMatcher
logger *slog.Logger
}
// parseArgsConfig parses the command line and if needed a config file.
func parseArgsConfig() ([]string, *config, error) {
var (
logFile = defaultLogFile
logLevel = &options.LogLevel{Level: defaultLogLevel}
)
p := options.Parser[config]{
DefaultConfigLocations: configPaths,
ConfigLocation: func(cfg *config) string { return cfg.Config },
Usage: "[OPTIONS] domain...",
HasVersion: func(cfg *config) bool { return cfg.Version },
SetDefaults: func(cfg *config) {
cfg.Worker = defaultWorker
cfg.RemoteValidatorPresets = []string{defaultPreset}
cfg.ValidationMode = defaultValidationMode
cfg.ForwardQueue = defaultForwardQueue
cfg.LogFile = &logFile
cfg.LogLevel = logLevel
},
// Re-establish default values if not set.
EnsureDefaults: func(cfg *config) {
if cfg.Worker == 0 {
cfg.Worker = defaultWorker
}
if cfg.RemoteValidatorPresets == nil {
cfg.RemoteValidatorPresets = []string{defaultPreset}
}
switch cfg.ValidationMode {
case downloader.ValidationStrict, downloader.ValidationUnsafe:
default:
cfg.ValidationMode = downloader.ValidationStrict
}
if cfg.LogFile == nil {
cfg.LogFile = &logFile
}
if cfg.LogLevel == nil {
cfg.LogLevel = logLevel
}
},
}
return p.Parse()
}
// prepareDirectory ensures that the working directory
// exists and is setup properly.
func (cfg *config) prepareDirectory() error {
// If not given use current working directory.
if cfg.Directory == "" {
dir, err := os.Getwd()
if err != nil {
return err
}
cfg.Directory = dir
return nil
}
// Use given directory
if _, err := os.Stat(cfg.Directory); err != nil {
// If it does not exist create it.
if os.IsNotExist(err) {
if err = os.MkdirAll(cfg.Directory, 0755); err != nil {
return err
}
} else {
return err
}
}
return nil
}
// dropSubSeconds drops all parts below resolution of seconds.
func dropSubSeconds(_ []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
t := a.Value.Time()
a.Value = slog.TimeValue(t.Truncate(time.Second))
}
return a
}
// prepareLogging sets up the structured logging.
func (cfg *config) prepareLogging() error {
var w io.Writer
if cfg.LogFile == nil || *cfg.LogFile == "" {
log.Println("using STDERR for logging")
w = os.Stderr
} else {
var fname string
// We put the log inside the download folder
// if it is not absolute.
if filepath.IsAbs(*cfg.LogFile) {
fname = *cfg.LogFile
} else {
fname = filepath.Join(cfg.Directory, *cfg.LogFile)
}
f, err := os.OpenFile(fname, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
return err
}
log.Printf("using %q for logging\n", fname)
w = f
}
ho := slog.HandlerOptions{
// AddSource: true,
Level: cfg.LogLevel.Level,
ReplaceAttr: dropSubSeconds,
}
handler := slog.NewJSONHandler(w, &ho)
cfg.logger = slog.New(handler)
return nil
}
// compileIgnorePatterns compiles the configure patterns to be ignored.
func (cfg *config) compileIgnorePatterns() error {
pm, err := filter.NewPatternMatcher(cfg.IgnorePattern)
if err != nil {
return err
}
cfg.ignorePattern = pm
return nil
}
// prepareCertificates loads the client side certificates used by the HTTP client.
func (cfg *config) prepareCertificates() error {
cert, err := certs.LoadCertificate(
cfg.ClientCert, cfg.ClientKey, cfg.ClientPassphrase)
if err != nil {
return err
}
cfg.clientCerts = cert
return nil
}
// Prepare prepares internal state of a loaded configuration.
func (cfg *config) GetDownloadConfig() (*downloader.Config, error) {
for _, prepare := range []func(*config) error{
(*config).prepareDirectory,
(*config).prepareLogging,
(*config).prepareCertificates,
(*config).compileIgnorePatterns,
} {
if err := prepare(cfg); err != nil {
return nil, err
}
}
dCfg := &downloader.Config{
Insecure: cfg.Insecure,
IgnoreSignatureCheck: cfg.IgnoreSignatureCheck,
ClientCerts: cfg.clientCerts,
ClientKey: cfg.ClientKey,
ClientPassphrase: cfg.ClientPassphrase,
Rate: cfg.Rate,
Worker: cfg.Worker,
Range: cfg.Range,
IgnorePattern: cfg.ignorePattern,
ExtraHeader: cfg.ExtraHeader,
RemoteValidator: cfg.RemoteValidator,
RemoteValidatorCache: cfg.RemoteValidatorCache,
RemoteValidatorPresets: cfg.RemoteValidatorPresets,
ValidationMode: cfg.ValidationMode,
ForwardURL: cfg.ForwardURL,
ForwardHeader: cfg.ForwardHeader,
ForwardQueue: cfg.ForwardQueue,
ForwardInsecure: cfg.ForwardInsecure,
Logger: cfg.logger,
}
return dCfg, nil
}

View file

@ -11,75 +11,40 @@ package main
import ( import (
"context" "context"
"github.com/csaf-poc/csaf_distribution/v3/lib/downloader"
"log/slog" "log/slog"
"os" "os"
"os/signal" "os/signal"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"github.com/csaf-poc/csaf_distribution/v3/internal/options" "github.com/csaf-poc/csaf_distribution/v3/internal/options"
"github.com/csaf-poc/csaf_distribution/v3/lib/downloader"
) )
const ( // failedForwardDir is the name of the special sub folder
defaultWorker = 2 // where advisories get stored which fail forwarding.
defaultPreset = "mandatory" const failedForwardDir = "failed_forward"
defaultForwardQueue = 5
defaultValidationMode = downloader.ValidationStrict
defaultLogFile = "downloader.log"
defaultLogLevel = slog.LevelInfo
)
// configPaths are the potential file locations of the Config file. // failedValidationDir is the name of the sub folder
var configPaths = []string{ // where advisories are stored that fail validation in
"~/.config/csaf/downloader.toml", // unsafe mode.
"~/.csaf_downloader.toml", const failedValidationDir = "failed_validation"
"csaf_downloader.toml",
}
// parseArgsConfig parses the command line and if needed a config file. var mkdirMu sync.Mutex
func parseArgsConfig() ([]string, *downloader.Config, error) {
var ( func run(cfg *config, domains []string) error {
logFile = defaultLogFile dCfg, err := cfg.GetDownloadConfig()
logLevel = &options.LogLevel{Level: defaultLogLevel} if err != nil {
) return err
p := options.Parser[downloader.Config]{
DefaultConfigLocations: configPaths,
ConfigLocation: func(cfg *downloader.Config) string { return cfg.Config },
Usage: "[OPTIONS] domain...",
HasVersion: func(cfg *downloader.Config) bool { return cfg.Version },
SetDefaults: func(cfg *downloader.Config) {
cfg.Worker = defaultWorker
cfg.RemoteValidatorPresets = []string{defaultPreset}
cfg.ValidationMode = defaultValidationMode
cfg.ForwardQueue = defaultForwardQueue
cfg.LogFile = &logFile
cfg.LogLevel = logLevel
},
// Re-establish default values if not set.
EnsureDefaults: func(cfg *downloader.Config) {
if cfg.Worker == 0 {
cfg.Worker = defaultWorker
}
if cfg.RemoteValidatorPresets == nil {
cfg.RemoteValidatorPresets = []string{defaultPreset}
}
switch cfg.ValidationMode {
case downloader.ValidationStrict, downloader.ValidationUnsafe:
default:
cfg.ValidationMode = downloader.ValidationStrict
}
if cfg.LogFile == nil {
cfg.LogFile = &logFile
}
if cfg.LogLevel == nil {
cfg.LogLevel = logLevel
}
},
} }
return p.Parse()
}
func run(cfg *downloader.Config, domains []string) error { dCfg.DownloadHandler = downloadHandler(cfg)
d, err := downloader.NewDownloader(cfg) dCfg.FailedForwardHandler = storeFailedAdvisory(cfg)
d, err := downloader.NewDownloader(dCfg)
if err != nil { if err != nil {
return err return err
} }
@ -91,7 +56,7 @@ func run(cfg *downloader.Config, domains []string) error {
defer stop() defer stop()
if cfg.ForwardURL != "" { if cfg.ForwardURL != "" {
f := downloader.NewForwarder(cfg) f := downloader.NewForwarder(dCfg)
go f.Run() go f.Run()
defer func() { defer func() {
f.Log() f.Log()
@ -108,11 +73,103 @@ func run(cfg *downloader.Config, domains []string) error {
return d.Run(ctx, domains) return d.Run(ctx, domains)
} }
func main() { func mkdirAll(path string, perm os.FileMode) error {
mkdirMu.Lock()
defer mkdirMu.Unlock()
return os.MkdirAll(path, perm)
}
func downloadHandler(cfg *config) func(d downloader.DownloadedDocument) error {
return func(d downloader.DownloadedDocument) error {
if cfg.NoStore {
// Do not write locally.
if d.ValStatus == downloader.ValidValidationStatus {
return nil
}
}
var lastDir string
// Advisories that failed validation are stored in a special folder.
var newDir string
if d.ValStatus != downloader.ValidValidationStatus {
newDir = path.Join(cfg.Directory, failedValidationDir)
} else {
newDir = cfg.Directory
}
lower := strings.ToLower(string(d.Label))
// Do we have a configured destination folder?
if cfg.Folder != "" {
newDir = path.Join(newDir, cfg.Folder)
} else {
newDir = path.Join(newDir, lower, strconv.Itoa(d.InitialReleaseDate.Year()))
}
if newDir != lastDir {
if err := mkdirAll(newDir, 0755); err != nil {
return err
}
lastDir = newDir
}
// Write advisory to file
filePath := filepath.Join(lastDir, d.Filename)
for _, x := range []struct {
p string
d []byte
}{
{filePath, d.Data.Bytes()},
{filePath + ".sha256", d.S256Data},
{filePath + ".sha512", d.S512Data},
{filePath + ".asc", d.SignData},
} {
if x.d != nil {
if err := os.WriteFile(x.p, x.d, 0644); err != nil {
return err
}
}
}
slog.Info("Written advisory", "path", filePath)
return nil
}
}
// storeFailedAdvisory stores an advisory in a special folder
// in case the forwarding failed.
func storeFailedAdvisory(cfg *config) func(filename, doc, sha256, sha512 string) error {
return func(filename, doc, sha256, sha512 string) error {
// Create special folder if it does not exist.
dir := filepath.Join(cfg.Directory, failedForwardDir)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
// Store parts which are not empty.
for _, x := range []struct {
p string
d string
}{
{filename, doc},
{filename + ".sha256", sha256},
{filename + ".sha512", sha512},
} {
if len(x.d) != 0 {
path := filepath.Join(dir, x.p)
if err := os.WriteFile(path, []byte(x.d), 0644); err != nil {
return err
}
}
}
return nil
}
}
func main() {
domains, cfg, err := parseArgsConfig() domains, cfg, err := parseArgsConfig()
options.ErrorCheck(err) options.ErrorCheck(err)
options.ErrorCheck(cfg.Prepare())
if len(domains) == 0 { if len(domains) == 0 {
slog.Warn("No domains given.") slog.Warn("No domains given.")

View file

@ -11,18 +11,11 @@ package downloader
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io"
"log"
"log/slog" "log/slog"
"net/http" "net/http"
"os"
"path/filepath"
"time"
"github.com/csaf-poc/csaf_distribution/v3/internal/certs"
"github.com/csaf-poc/csaf_distribution/v3/internal/filter" "github.com/csaf-poc/csaf_distribution/v3/internal/filter"
"github.com/csaf-poc/csaf_distribution/v3/internal/models" "github.com/csaf-poc/csaf_distribution/v3/internal/models"
"github.com/csaf-poc/csaf_distribution/v3/internal/options"
) )
// ValidationMode specifies the strict the validation is. // ValidationMode specifies the strict the validation is.
@ -37,43 +30,33 @@ const (
// Config provides the download configuration. // Config provides the download configuration.
type Config struct { type Config struct {
Directory string `short:"d" long:"directory" description:"DIRectory to store the downloaded files in" value-name:"DIR" toml:"directory"` Insecure bool
Insecure bool `long:"insecure" description:"Do not check TLS certificates from provider" toml:"insecure"` IgnoreSignatureCheck bool
IgnoreSignatureCheck bool `long:"ignore_sigcheck" description:"Ignore signature check results, just warn on mismatch" toml:"ignore_sigcheck"` ClientCerts []tls.Certificate
ClientCert *string `long:"client_cert" description:"TLS client certificate file (PEM encoded data)" value-name:"CERT-FILE" toml:"client_cert"` ClientKey *string
ClientKey *string `long:"client_key" description:"TLS client private key file (PEM encoded data)" value-name:"KEY-FILE" toml:"client_key"` ClientPassphrase *string
ClientPassphrase *string `long:"client_passphrase" description:"Optional passphrase for the client cert (limited, experimental, see doc)" value-name:"PASSPHRASE" toml:"client_passphrase"` Rate *float64
Version bool `long:"version" description:"Display version of the binary" toml:"-"` Worker int
NoStore bool `long:"no_store" short:"n" description:"Do not store files" toml:"no_store"` Range *models.TimeRange
Rate *float64 `long:"rate" short:"r" description:"The average upper limit of https operations per second (defaults to unlimited)" toml:"rate"` IgnorePattern filter.PatternMatcher
Worker int `long:"worker" short:"w" description:"NUMber of concurrent downloads" value-name:"NUM" toml:"worker"` ExtraHeader http.Header
Range *models.TimeRange `long:"time_range" short:"t" description:"RANGE of time from which advisories to download" value-name:"RANGE" toml:"time_range"`
Folder string `long:"folder" short:"f" description:"Download into a given subFOLDER" value-name:"FOLDER" toml:"folder"`
IgnorePattern []string `long:"ignore_pattern" short:"i" description:"Do not download files if their URLs match any of the given PATTERNs" value-name:"PATTERN" toml:"ignore_pattern"`
ExtraHeader http.Header `long:"header" short:"H" description:"One or more extra HTTP header fields" toml:"header"`
EnumeratePMDOnly bool `long:"enumerate_pmd_only" description:"If this flag is set to true, the downloader will only enumerate valid provider metadata files, but not download documents" toml:"enumerate_pmd_only"` RemoteValidator string
// CLI only?
RemoteValidatorCache string
RemoteValidatorPresets []string
RemoteValidator string `long:"validator" description:"URL to validate documents remotely" value-name:"URL" toml:"validator"` ValidationMode ValidationMode
RemoteValidatorCache string `long:"validator_cache" description:"FILE to cache remote validations" value-name:"FILE" toml:"validator_cache"`
RemoteValidatorPresets []string `long:"validator_preset" description:"One or more PRESETS to validate remotely" value-name:"PRESETS" toml:"validator_preset"`
//lint:ignore SA5008 We are using choice twice: strict, unsafe. ForwardURL string
ValidationMode ValidationMode `long:"validation_mode" short:"m" choice:"strict" choice:"unsafe" value-name:"MODE" description:"MODE how strict the validation is" toml:"validation_mode"` ForwardHeader http.Header
ForwardQueue int
ForwardInsecure bool
ForwardURL string `long:"forward_url" description:"URL of HTTP endpoint to forward downloads to" value-name:"URL" toml:"forward_url"` DownloadHandler func(DownloadedDocument) error
ForwardHeader http.Header `long:"forward_header" description:"One or more extra HTTP header fields used by forwarding" toml:"forward_header"` FailedForwardHandler func(filename, doc, sha256, sha512 string) error
ForwardQueue int `long:"forward_queue" description:"Maximal queue LENGTH before forwarder" value-name:"LENGTH" toml:"forward_queue"`
ForwardInsecure bool `long:"forward_insecure" description:"Do not check TLS certificates from forward endpoint" toml:"forward_insecure"`
LogFile *string `long:"log_file" description:"FILE to log downloading to" value-name:"FILE" toml:"log_file"` Logger *slog.Logger
//lint:ignore SA5008 We are using choice or than once: debug, info, warn, error
LogLevel *options.LogLevel `long:"log_level" description:"LEVEL of logging details" value-name:"LEVEL" choice:"debug" choice:"info" choice:"warn" choice:"error" toml:"log_level"`
Config string `short:"c" long:"config" description:"Path to config TOML file" value-name:"TOML-FILE" toml:"-"`
clientCerts []tls.Certificate
ignorePattern filter.PatternMatcher
} }
// UnmarshalText implements [encoding.TextUnmarshaler]. // UnmarshalText implements [encoding.TextUnmarshaler].
@ -99,114 +82,10 @@ func (vm *ValidationMode) UnmarshalFlag(value string) error {
// ignoreFile returns true if the given URL should not be downloaded. // ignoreFile returns true if the given URL should not be downloaded.
func (cfg *Config) ignoreURL(u string) bool { func (cfg *Config) ignoreURL(u string) bool {
return cfg.ignorePattern.Matches(u) return cfg.IgnorePattern.Matches(u)
} }
// verbose is considered a log level equal or less debug. // verbose is considered a log level equal or less debug.
func (cfg *Config) verbose() bool { func (cfg *Config) verbose() bool {
return cfg.LogLevel.Level <= slog.LevelDebug return cfg.Logger.Enabled(nil, slog.LevelDebug)
}
// prepareDirectory ensures that the working directory
// exists and is setup properly.
func (cfg *Config) prepareDirectory() error {
// If not given use current working directory.
if cfg.Directory == "" {
dir, err := os.Getwd()
if err != nil {
return err
}
cfg.Directory = dir
return nil
}
// Use given directory
if _, err := os.Stat(cfg.Directory); err != nil {
// If it does not exist create it.
if os.IsNotExist(err) {
if err = os.MkdirAll(cfg.Directory, 0755); err != nil {
return err
}
} else {
return err
}
}
return nil
}
// dropSubSeconds drops all parts below resolution of seconds.
func dropSubSeconds(_ []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey {
t := a.Value.Time()
a.Value = slog.TimeValue(t.Truncate(time.Second))
}
return a
}
// prepareLogging sets up the structured logging.
func (cfg *Config) prepareLogging() error {
var w io.Writer
if cfg.LogFile == nil || *cfg.LogFile == "" {
log.Println("using STDERR for logging")
w = os.Stderr
} else {
var fname string
// We put the log inside the download folder
// if it is not absolute.
if filepath.IsAbs(*cfg.LogFile) {
fname = *cfg.LogFile
} else {
fname = filepath.Join(cfg.Directory, *cfg.LogFile)
}
f, err := os.OpenFile(fname, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
return err
}
log.Printf("using %q for logging\n", fname)
w = f
}
ho := slog.HandlerOptions{
//AddSource: true,
Level: cfg.LogLevel.Level,
ReplaceAttr: dropSubSeconds,
}
handler := slog.NewJSONHandler(w, &ho)
logger := slog.New(handler)
slog.SetDefault(logger)
return nil
}
// compileIgnorePatterns compiles the configure patterns to be ignored.
func (cfg *Config) compileIgnorePatterns() error {
pm, err := filter.NewPatternMatcher(cfg.IgnorePattern)
if err != nil {
return err
}
cfg.ignorePattern = pm
return nil
}
// prepareCertificates loads the client side certificates used by the HTTP client.
func (cfg *Config) prepareCertificates() error {
cert, err := certs.LoadCertificate(
cfg.ClientCert, cfg.ClientKey, cfg.ClientPassphrase)
if err != nil {
return err
}
cfg.clientCerts = cert
return nil
}
// Prepare prepares internal state of a loaded configuration.
func (cfg *Config) Prepare() error {
for _, prepare := range []func(*Config) error{
(*Config).prepareDirectory,
(*Config).prepareLogging,
(*Config).prepareCertificates,
(*Config).compileIgnorePatterns,
} {
if err := prepare(cfg); err != nil {
return err
}
}
return nil
} }

View file

@ -22,10 +22,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path"
"path/filepath" "path/filepath"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -48,6 +45,18 @@ type Downloader struct {
stats stats stats stats
} }
// DownloadedDocument contains the document data with additional metadata.
type DownloadedDocument struct {
Data bytes.Buffer
S256Data []byte
S512Data []byte
SignData []byte
InitialReleaseDate time.Time
Filename string
ValStatus ValidationStatus
Label csaf.TLPLabel
}
// failedValidationDir is the name of the sub folder // failedValidationDir is the name of the sub folder
// where advisories are stored that fail validation in // where advisories are stored that fail validation in
// unsafe mode. // unsafe mode.
@ -94,15 +103,17 @@ func (d *Downloader) addStats(o *stats) {
} }
// logRedirect logs redirects of the http client. // logRedirect logs redirects of the http client.
func logRedirect(req *http.Request, via []*http.Request) error { func logRedirect(logger *slog.Logger) func(req *http.Request, via []*http.Request) error {
vs := make([]string, len(via)) return func(req *http.Request, via []*http.Request) error {
for i, v := range via { vs := make([]string, len(via))
vs[i] = v.URL.String() for i, v := range via {
vs[i] = v.URL.String()
}
logger.Debug("Redirecting",
"to", req.URL.String(),
"via", strings.Join(vs, " -> "))
return nil
} }
slog.Debug("Redirecting",
"to", req.URL.String(),
"via", strings.Join(vs, " -> "))
return nil
} }
func (d *Downloader) httpClient() util.Client { func (d *Downloader) httpClient() util.Client {
@ -110,7 +121,7 @@ func (d *Downloader) httpClient() util.Client {
hClient := http.Client{} hClient := http.Client{}
if d.cfg.verbose() { if d.cfg.verbose() {
hClient.CheckRedirect = logRedirect hClient.CheckRedirect = logRedirect(d.cfg.Logger)
} }
var tlsConfig tls.Config var tlsConfig tls.Config
@ -118,8 +129,8 @@ func (d *Downloader) httpClient() util.Client {
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
} }
if len(d.cfg.clientCerts) != 0 { if len(d.cfg.ClientCerts) != 0 {
tlsConfig.Certificates = d.cfg.clientCerts tlsConfig.Certificates = d.cfg.ClientCerts
} }
hClient.Transport = &http.Transport{ hClient.Transport = &http.Transport{
@ -140,7 +151,7 @@ func (d *Downloader) httpClient() util.Client {
if d.cfg.verbose() { if d.cfg.verbose() {
client = &util.LoggingClient{ client = &util.LoggingClient{
Client: client, Client: client,
Log: httpLog("downloader"), Log: httpLog("downloader", d.cfg.Logger),
} }
} }
@ -156,9 +167,9 @@ func (d *Downloader) httpClient() util.Client {
} }
// httpLog does structured logging in a [util.LoggingClient]. // httpLog does structured logging in a [util.LoggingClient].
func httpLog(who string) func(string, string) { func httpLog(who string, logger *slog.Logger) func(string, string) {
return func(method, url string) { return func(method, url string) {
slog.Debug("http", logger.Debug("http",
"who", who, "who", who,
"method", method, "method", method,
"url", url) "url", url)
@ -176,7 +187,7 @@ func (d *Downloader) enumerate(domain string) error {
for _, pmd := range lpmd { for _, pmd := range lpmd {
if d.cfg.verbose() { if d.cfg.verbose() {
for i := range pmd.Messages { for i := range pmd.Messages {
slog.Debug("Enumerating provider-metadata.json", d.cfg.Logger.Debug("Enumerating provider-metadata.json",
"domain", domain, "domain", domain,
"message", pmd.Messages[i].Message) "message", pmd.Messages[i].Message)
} }
@ -188,7 +199,7 @@ func (d *Downloader) enumerate(domain string) error {
// print the results // print the results
doc, err := json.MarshalIndent(docs, "", " ") doc, err := json.MarshalIndent(docs, "", " ")
if err != nil { if err != nil {
slog.Error("Couldn't marshal PMD document json") d.cfg.Logger.Error("Couldn't marshal PMD document json")
} }
fmt.Println(string(doc)) fmt.Println(string(doc))
@ -211,7 +222,7 @@ func (d *Downloader) download(ctx context.Context, domain string) error {
return fmt.Errorf("no valid provider-metadata.json found for '%s'", domain) return fmt.Errorf("no valid provider-metadata.json found for '%s'", domain)
} else if d.cfg.verbose() { } else if d.cfg.verbose() {
for i := range lpmd.Messages { for i := range lpmd.Messages {
slog.Debug("Loading provider-metadata.json", d.cfg.Logger.Debug("Loading provider-metadata.json",
"domain", domain, "domain", domain,
"message", lpmd.Messages[i].Message) "message", lpmd.Messages[i].Message)
} }
@ -241,7 +252,7 @@ func (d *Downloader) download(ctx context.Context, domain string) error {
// Do we need time range based filtering? // Do we need time range based filtering?
if d.cfg.Range != nil { if d.cfg.Range != nil {
slog.Debug("Setting up filter to accept advisories within", d.cfg.Logger.Debug("Setting up filter to accept advisories within",
"timerange", d.cfg.Range) "timerange", d.cfg.Range)
afp.AgeAccept = d.cfg.Range.Contains afp.AgeAccept = d.cfg.Range.Contains
} }
@ -306,7 +317,6 @@ func (d *Downloader) loadOpenPGPKeys(
base *url.URL, base *url.URL,
expr *util.PathEval, expr *util.PathEval,
) error { ) error {
src, err := expr.Eval("$.public_openpgp_keys", doc) src, err := expr.Eval("$.public_openpgp_keys", doc)
if err != nil { if err != nil {
// no keys. // no keys.
@ -331,7 +341,7 @@ func (d *Downloader) loadOpenPGPKeys(
} }
up, err := url.Parse(*key.URL) up, err := url.Parse(*key.URL)
if err != nil { if err != nil {
slog.Warn("Invalid URL", d.cfg.Logger.Warn("Invalid URL",
"url", *key.URL, "url", *key.URL,
"error", err) "error", err)
continue continue
@ -341,14 +351,14 @@ func (d *Downloader) loadOpenPGPKeys(
res, err := client.Get(u) res, err := client.Get(u)
if err != nil { if err != nil {
slog.Warn( d.cfg.Logger.Warn(
"Fetching public OpenPGP key failed", "Fetching public OpenPGP key failed",
"url", u, "url", u,
"error", err) "error", err)
continue continue
} }
if res.StatusCode != http.StatusOK { if res.StatusCode != http.StatusOK {
slog.Warn( d.cfg.Logger.Warn(
"Fetching public OpenPGP key failed", "Fetching public OpenPGP key failed",
"url", u, "url", u,
"status_code", res.StatusCode, "status_code", res.StatusCode,
@ -362,7 +372,7 @@ func (d *Downloader) loadOpenPGPKeys(
}() }()
if err != nil { if err != nil {
slog.Warn( d.cfg.Logger.Warn(
"Reading public OpenPGP key failed", "Reading public OpenPGP key failed",
"url", u, "url", u,
"error", err) "error", err)
@ -370,14 +380,14 @@ func (d *Downloader) loadOpenPGPKeys(
} }
if !strings.EqualFold(ckey.GetFingerprint(), string(key.Fingerprint)) { if !strings.EqualFold(ckey.GetFingerprint(), string(key.Fingerprint)) {
slog.Warn( d.cfg.Logger.Warn(
"Fingerprint of public OpenPGP key does not match remotely loaded", "Fingerprint of public OpenPGP key does not match remotely loaded",
"url", u) "url", u)
continue continue
} }
if d.keys == nil { if d.keys == nil {
if keyring, err := crypto.NewKeyRing(ckey); err != nil { if keyring, err := crypto.NewKeyRing(ckey); err != nil {
slog.Warn( d.cfg.Logger.Warn(
"Creating store for public OpenPGP key failed", "Creating store for public OpenPGP key failed",
"url", u, "url", u,
"error", err) "error", err)
@ -394,18 +404,18 @@ func (d *Downloader) loadOpenPGPKeys(
// logValidationIssues logs the issues reported by the advisory schema validation. // logValidationIssues logs the issues reported by the advisory schema validation.
func (d *Downloader) logValidationIssues(url string, errors []string, err error) { func (d *Downloader) logValidationIssues(url string, errors []string, err error) {
if err != nil { if err != nil {
slog.Error("Failed to validate", d.cfg.Logger.Error("Failed to validate",
"url", url, "url", url,
"error", err) "error", err)
return return
} }
if len(errors) > 0 { if len(errors) > 0 {
if d.cfg.verbose() { if d.cfg.verbose() {
slog.Error("CSAF file has validation errors", d.cfg.Logger.Error("CSAF file has validation errors",
"url", url, "url", url,
"error", strings.Join(errors, ", ")) "error", strings.Join(errors, ", "))
} else { } else {
slog.Error("CSAF file has validation errors", d.cfg.Logger.Error("CSAF file has validation errors",
"url", url, "url", url,
"count", len(errors)) "count", len(errors))
} }
@ -424,10 +434,8 @@ func (d *Downloader) downloadWorker(
var ( var (
client = d.httpClient() client = d.httpClient()
data bytes.Buffer data bytes.Buffer
lastDir string
initialReleaseDate time.Time initialReleaseDate time.Time
dateExtract = util.TimeMatcher(&initialReleaseDate, time.RFC3339) dateExtract = util.TimeMatcher(&initialReleaseDate, time.RFC3339)
lower = strings.ToLower(string(label))
stats = stats{} stats = stats{}
expr = util.NewPathEval() expr = util.NewPathEval()
) )
@ -451,14 +459,14 @@ nextAdvisory:
u, err := url.Parse(file.URL()) u, err := url.Parse(file.URL())
if err != nil { if err != nil {
stats.downloadFailed++ stats.downloadFailed++
slog.Warn("Ignoring invalid URL", d.cfg.Logger.Warn("Ignoring invalid URL",
"url", file.URL(), "url", file.URL(),
"error", err) "error", err)
continue continue
} }
if d.cfg.ignoreURL(file.URL()) { if d.cfg.ignoreURL(file.URL()) {
slog.Debug("Ignoring URL", "url", file.URL()) d.cfg.Logger.Debug("Ignoring URL", "url", file.URL())
continue continue
} }
@ -466,7 +474,7 @@ nextAdvisory:
filename := filepath.Base(u.Path) filename := filepath.Base(u.Path)
if !util.ConformingFileName(filename) { if !util.ConformingFileName(filename) {
stats.filenameFailed++ stats.filenameFailed++
slog.Warn("Ignoring none conforming filename", d.cfg.Logger.Warn("Ignoring none conforming filename",
"filename", filename) "filename", filename)
continue continue
} }
@ -474,7 +482,7 @@ nextAdvisory:
resp, err := client.Get(file.URL()) resp, err := client.Get(file.URL())
if err != nil { if err != nil {
stats.downloadFailed++ stats.downloadFailed++
slog.Warn("Cannot GET", d.cfg.Logger.Warn("Cannot GET",
"url", file.URL(), "url", file.URL(),
"error", err) "error", err)
continue continue
@ -482,7 +490,7 @@ nextAdvisory:
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
stats.downloadFailed++ stats.downloadFailed++
slog.Warn("Cannot load", d.cfg.Logger.Warn("Cannot load",
"url", file.URL(), "url", file.URL(),
"status", resp.Status, "status", resp.Status,
"status_code", resp.StatusCode) "status_code", resp.StatusCode)
@ -491,7 +499,7 @@ nextAdvisory:
// Warn if we do not get JSON. // Warn if we do not get JSON.
if ct := resp.Header.Get("Content-Type"); ct != "application/json" { if ct := resp.Header.Get("Content-Type"); ct != "application/json" {
slog.Warn("Content type is not 'application/json'", d.cfg.Logger.Warn("Content type is not 'application/json'",
"url", file.URL(), "url", file.URL(),
"content_type", ct) "content_type", ct)
} }
@ -506,7 +514,7 @@ nextAdvisory:
// Only hash when we have a remote counter part we can compare it with. // Only hash when we have a remote counter part we can compare it with.
if remoteSHA256, s256Data, err = loadHash(client, file.SHA256URL()); err != nil { if remoteSHA256, s256Data, err = loadHash(client, file.SHA256URL()); err != nil {
slog.Warn("Cannot fetch SHA256", d.cfg.Logger.Warn("Cannot fetch SHA256",
"url", file.SHA256URL(), "url", file.SHA256URL(),
"error", err) "error", err)
} else { } else {
@ -515,7 +523,7 @@ nextAdvisory:
} }
if remoteSHA512, s512Data, err = loadHash(client, file.SHA512URL()); err != nil { if remoteSHA512, s512Data, err = loadHash(client, file.SHA512URL()); err != nil {
slog.Warn("Cannot fetch SHA512", d.cfg.Logger.Warn("Cannot fetch SHA512",
"url", file.SHA512URL(), "url", file.SHA512URL(),
"error", err) "error", err)
} else { } else {
@ -538,7 +546,7 @@ nextAdvisory:
return json.NewDecoder(tee).Decode(&doc) return json.NewDecoder(tee).Decode(&doc)
}(); err != nil { }(); err != nil {
stats.downloadFailed++ stats.downloadFailed++
slog.Warn("Downloading failed", d.cfg.Logger.Warn("Downloading failed",
"url", file.URL(), "url", file.URL(),
"error", err) "error", err)
continue continue
@ -570,7 +578,7 @@ nextAdvisory:
var sign *crypto.PGPSignature var sign *crypto.PGPSignature
sign, signData, err = loadSignature(client, file.SignURL()) sign, signData, err = loadSignature(client, file.SignURL())
if err != nil { if err != nil {
slog.Warn("Downloading signature failed", d.cfg.Logger.Warn("Downloading signature failed",
"url", file.SignURL(), "url", file.SignURL(),
"error", err) "error", err)
} }
@ -624,7 +632,7 @@ nextAdvisory:
} }
// Run all the validations. // Run all the validations.
valStatus := notValidatedValidationStatus valStatus := NotValidatedValidationStatus
for _, check := range []func() error{ for _, check := range []func() error{
s256Check, s256Check,
s512Check, s512Check,
@ -634,14 +642,14 @@ nextAdvisory:
remoteValidatorCheck, remoteValidatorCheck,
} { } {
if err := check(); err != nil { if err := check(); err != nil {
slog.Error("Validation check failed", "error", err) d.cfg.Logger.Error("Validation check failed", "error", err)
valStatus.update(invalidValidationStatus) valStatus.update(InvalidValidationStatus)
if d.cfg.ValidationMode == ValidationStrict { if d.cfg.ValidationMode == ValidationStrict {
continue nextAdvisory continue nextAdvisory
} }
} }
} }
valStatus.update(validValidationStatus) valStatus.update(ValidValidationStatus)
// Send to Forwarder // Send to Forwarder
if d.Forwarder != nil { if d.Forwarder != nil {
@ -651,15 +659,6 @@ nextAdvisory:
string(s256Data), string(s256Data),
string(s512Data)) string(s512Data))
} }
if d.cfg.NoStore {
// Do not write locally.
if valStatus == validValidationStatus {
stats.succeeded++
}
continue
}
if err := expr.Extract( if err := expr.Extract(
`$.document.tracking.initial_release_date`, dateExtract, false, doc, `$.document.tracking.initial_release_date`, dateExtract, false, doc,
); err != nil { ); err != nil {
@ -669,61 +668,26 @@ nextAdvisory:
} }
initialReleaseDate = initialReleaseDate.UTC() initialReleaseDate = initialReleaseDate.UTC()
// Advisories that failed validation are stored in a special folder. download := DownloadedDocument{
var newDir string Data: data,
if valStatus != validValidationStatus { S256Data: s256Data,
newDir = path.Join(d.cfg.Directory, failedValidationDir) S512Data: s512Data,
SignData: signData,
InitialReleaseDate: initialReleaseDate,
Filename: filename,
ValStatus: valStatus,
Label: label,
}
err = d.cfg.DownloadHandler(download)
if err != nil {
errorCh <- err
} else { } else {
newDir = d.cfg.Directory stats.succeeded++
} }
// Do we have a configured destination folder?
if d.cfg.Folder != "" {
newDir = path.Join(newDir, d.cfg.Folder)
} else {
newDir = path.Join(newDir, lower, strconv.Itoa(initialReleaseDate.Year()))
}
if newDir != lastDir {
if err := d.mkdirAll(newDir, 0755); err != nil {
errorCh <- err
continue
}
lastDir = newDir
}
// Write advisory to file
filePath := filepath.Join(lastDir, filename)
// Write data to disk.
for _, x := range []struct {
p string
d []byte
}{
{filePath, data.Bytes()},
{filePath + ".sha256", s256Data},
{filePath + ".sha512", s512Data},
{filePath + ".asc", signData},
} {
if x.d != nil {
if err := os.WriteFile(x.p, x.d, 0644); err != nil {
errorCh <- err
continue nextAdvisory
}
}
}
stats.succeeded++
slog.Info("Written advisory", "path", filePath)
} }
} }
func (d *Downloader) mkdirAll(path string, perm os.FileMode) error {
d.mkdirMu.Lock()
defer d.mkdirMu.Unlock()
return os.MkdirAll(path, perm)
}
func (d *Downloader) checkSignature(data []byte, sign *crypto.PGPSignature) error { func (d *Downloader) checkSignature(data []byte, sign *crypto.PGPSignature) error {
pm := crypto.NewPlainMessage(data) pm := crypto.NewPlainMessage(data)
t := crypto.GetUnixTime() t := crypto.GetUnixTime()

View file

@ -12,10 +12,8 @@ import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"io" "io"
"log/slog"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -27,19 +25,22 @@ import (
// where advisories get stored which fail forwarding. // where advisories get stored which fail forwarding.
const failedForwardDir = "failed_forward" const failedForwardDir = "failed_forward"
// validationStatus represents the validation status // ValidationStatus represents the validation status
// known to the HTTP endpoint. // known to the HTTP endpoint.
type validationStatus string type ValidationStatus string
const ( const (
validValidationStatus = validationStatus("valid") // ValidValidationStatus represents a valid document.
invalidValidationStatus = validationStatus("invalid") ValidValidationStatus = ValidationStatus("valid")
notValidatedValidationStatus = validationStatus("not_validated") // InvalidValidationStatus represents an invalid document.
InvalidValidationStatus = ValidationStatus("invalid")
// NotValidatedValidationStatus represents a not validated document.
NotValidatedValidationStatus = ValidationStatus("not_validated")
) )
func (vs *validationStatus) update(status validationStatus) { func (vs *ValidationStatus) update(status ValidationStatus) {
// Cannot heal after it fails at least once. // Cannot heal after it fails at least once.
if *vs != invalidValidationStatus { if *vs != InvalidValidationStatus {
*vs = status *vs = status
} }
} }
@ -69,7 +70,7 @@ func NewForwarder(cfg *Config) *Forwarder {
// Run runs the Forwarder. Meant to be used in a Go routine. // Run runs the Forwarder. Meant to be used in a Go routine.
func (f *Forwarder) Run() { func (f *Forwarder) Run() {
defer slog.Debug("Forwarder done") defer f.cfg.Logger.Debug("Forwarder done")
for cmd := range f.cmds { for cmd := range f.cmds {
cmd(f) cmd(f)
@ -84,7 +85,7 @@ func (f *Forwarder) Close() {
// Log logs the current statistics. // Log logs the current statistics.
func (f *Forwarder) Log() { func (f *Forwarder) Log() {
f.cmds <- func(f *Forwarder) { f.cmds <- func(f *Forwarder) {
slog.Info("Forward statistics", f.cfg.Logger.Info("Forward statistics",
"succeeded", f.succeeded, "succeeded", f.succeeded,
"failed", f.failed) "failed", f.failed)
} }
@ -122,7 +123,7 @@ func (f *Forwarder) httpClient() util.Client {
if f.cfg.verbose() { if f.cfg.verbose() {
client = &util.LoggingClient{ client = &util.LoggingClient{
Client: client, Client: client,
Log: httpLog("Forwarder"), Log: httpLog("Forwarder", f.cfg.Logger),
} }
} }
@ -139,7 +140,7 @@ func replaceExt(fname, nExt string) string {
// buildRequest creates an HTTP request suited to forward the given advisory. // buildRequest creates an HTTP request suited to forward the given advisory.
func (f *Forwarder) buildRequest( func (f *Forwarder) buildRequest(
filename, doc string, filename, doc string,
status validationStatus, status ValidationStatus,
sha256, sha512 string, sha256, sha512 string,
) (*http.Request, error) { ) (*http.Request, error) {
body := new(bytes.Buffer) body := new(bytes.Buffer)
@ -187,38 +188,11 @@ func (f *Forwarder) buildRequest(
return req, nil return req, nil
} }
// storeFailedAdvisory stores an advisory in a special folder
// in case the forwarding failed.
func (f *Forwarder) storeFailedAdvisory(filename, doc, sha256, sha512 string) error {
// Create special folder if it does not exist.
dir := filepath.Join(f.cfg.Directory, failedForwardDir)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
// Store parts which are not empty.
for _, x := range []struct {
p string
d string
}{
{filename, doc},
{filename + ".sha256", sha256},
{filename + ".sha512", sha512},
} {
if len(x.d) != 0 {
path := filepath.Join(dir, x.p)
if err := os.WriteFile(path, []byte(x.d), 0644); err != nil {
return err
}
}
}
return nil
}
// storeFailed is a logging wrapper around storeFailedAdvisory. // storeFailed is a logging wrapper around storeFailedAdvisory.
func (f *Forwarder) storeFailed(filename, doc, sha256, sha512 string) { func (f *Forwarder) storeFailed(filename, doc, sha256, sha512 string) {
f.failed++ f.failed++
if err := f.storeFailedAdvisory(filename, doc, sha256, sha512); err != nil { if err := f.cfg.FailedForwardHandler(filename, doc, sha256, sha512); err != nil {
slog.Error("Storing advisory failed forwarding failed", f.cfg.Logger.Error("Storing advisory failed forwarding failed",
"error", err) "error", err)
} }
} }
@ -241,21 +215,21 @@ func limitedString(r io.Reader, max int) (string, error) {
// till the configured queue size is filled. // till the configured queue size is filled.
func (f *Forwarder) forward( func (f *Forwarder) forward(
filename, doc string, filename, doc string,
status validationStatus, status ValidationStatus,
sha256, sha512 string, sha256, sha512 string,
) { ) {
// Run this in the main loop of the Forwarder. // Run this in the main loop of the Forwarder.
f.cmds <- func(f *Forwarder) { f.cmds <- func(f *Forwarder) {
req, err := f.buildRequest(filename, doc, status, sha256, sha512) req, err := f.buildRequest(filename, doc, status, sha256, sha512)
if err != nil { if err != nil {
slog.Error("building forward Request failed", f.cfg.Logger.Error("building forward Request failed",
"error", err) "error", err)
f.storeFailed(filename, doc, sha256, sha512) f.storeFailed(filename, doc, sha256, sha512)
return return
} }
res, err := f.httpClient().Do(req) res, err := f.httpClient().Do(req)
if err != nil { if err != nil {
slog.Error("sending forward request failed", f.cfg.Logger.Error("sending forward request failed",
"error", err) "error", err)
f.storeFailed(filename, doc, sha256, sha512) f.storeFailed(filename, doc, sha256, sha512)
return return
@ -263,10 +237,10 @@ func (f *Forwarder) forward(
if res.StatusCode != http.StatusCreated { if res.StatusCode != http.StatusCreated {
defer res.Body.Close() defer res.Body.Close()
if msg, err := limitedString(res.Body, 512); err != nil { if msg, err := limitedString(res.Body, 512); err != nil {
slog.Error("reading forward result failed", f.cfg.Logger.Error("reading forward result failed",
"error", err) "error", err)
} else { } else {
slog.Error("forwarding failed", f.cfg.Logger.Error("forwarding failed",
"filename", filename, "filename", filename,
"body", msg, "body", msg,
"status_code", res.StatusCode) "status_code", res.StatusCode)
@ -274,7 +248,7 @@ func (f *Forwarder) forward(
f.storeFailed(filename, doc, sha256, sha512) f.storeFailed(filename, doc, sha256, sha512)
} else { } else {
f.succeeded++ f.succeeded++
slog.Debug( f.cfg.Logger.Debug(
"forwarding succeeded", "forwarding succeeded",
"filename", filename) "filename", filename)
} }

View file

@ -19,26 +19,24 @@ import (
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"github.com/csaf-poc/csaf_distribution/v3/internal/options"
"github.com/csaf-poc/csaf_distribution/v3/util" "github.com/csaf-poc/csaf_distribution/v3/util"
) )
func TestValidationStatusUpdate(t *testing.T) { func TestValidationStatusUpdate(t *testing.T) {
sv := validValidationStatus sv := ValidValidationStatus
sv.update(invalidValidationStatus) sv.update(InvalidValidationStatus)
sv.update(validValidationStatus) sv.update(ValidValidationStatus)
if sv != invalidValidationStatus { if sv != InvalidValidationStatus {
t.Fatalf("got %q expected %q", sv, invalidValidationStatus) t.Fatalf("got %q expected %q", sv, InvalidValidationStatus)
} }
sv = notValidatedValidationStatus sv = NotValidatedValidationStatus
sv.update(validValidationStatus) sv.update(ValidValidationStatus)
sv.update(notValidatedValidationStatus) sv.update(NotValidatedValidationStatus)
if sv != notValidatedValidationStatus { if sv != NotValidatedValidationStatus {
t.Fatalf("got %q expected %q", sv, notValidatedValidationStatus) t.Fatalf("got %q expected %q", sv, NotValidatedValidationStatus)
} }
} }
@ -51,9 +49,10 @@ func TestForwarderLogStats(t *testing.T) {
Level: slog.LevelInfo, Level: slog.LevelInfo,
}) })
lg := slog.New(h) lg := slog.New(h)
slog.SetDefault(lg)
cfg := &Config{} cfg := &Config{
Logger: lg,
}
fw := NewForwarder(cfg) fw := NewForwarder(cfg)
fw.failed = 11 fw.failed = 11
fw.succeeded = 13 fw.succeeded = 13
@ -100,7 +99,7 @@ func TestForwarderHTTPClient(t *testing.T) {
ForwardHeader: http.Header{ ForwardHeader: http.Header{
"User-Agent": []string{"curl/7.55.1"}, "User-Agent": []string{"curl/7.55.1"},
}, },
LogLevel: &options.LogLevel{Level: slog.LevelDebug}, Logger: slog.Default(),
} }
fw := NewForwarder(cfg) fw := NewForwarder(cfg)
if c1, c2 := fw.httpClient(), fw.httpClient(); c1 != c2 { if c1, c2 := fw.httpClient(), fw.httpClient(); c1 != c2 {
@ -122,7 +121,6 @@ func TestForwarderReplaceExtension(t *testing.T) {
} }
func TestForwarderBuildRequest(t *testing.T) { func TestForwarderBuildRequest(t *testing.T) {
// Good case ... // Good case ...
cfg := &Config{ cfg := &Config{
ForwardURL: "https://example.com", ForwardURL: "https://example.com",
@ -131,10 +129,9 @@ func TestForwarderBuildRequest(t *testing.T) {
req, err := fw.buildRequest( req, err := fw.buildRequest(
"test.json", "{}", "test.json", "{}",
invalidValidationStatus, InvalidValidationStatus,
"256", "256",
"512") "512")
if err != nil { if err != nil {
t.Fatalf("buildRequest failed: %v", err) t.Fatalf("buildRequest failed: %v", err)
} }
@ -175,9 +172,9 @@ func TestForwarderBuildRequest(t *testing.T) {
} }
foundAdvisory = true foundAdvisory = true
case contains("validation_status"): case contains("validation_status"):
if vs := validationStatus(data); vs != invalidValidationStatus { if vs := ValidationStatus(data); vs != InvalidValidationStatus {
t.Fatalf("validation_status: got %q expected %q", t.Fatalf("validation_status: got %q expected %q",
vs, invalidValidationStatus) vs, InvalidValidationStatus)
} }
foundValidationStatus = true foundValidationStatus = true
case contains("hash-256"): case contains("hash-256"):
@ -209,7 +206,7 @@ func TestForwarderBuildRequest(t *testing.T) {
if _, err := fw.buildRequest( if _, err := fw.buildRequest(
"test.json", "{}", "test.json", "{}",
invalidValidationStatus, InvalidValidationStatus,
"256", "256",
"512", "512",
); err == nil { ); err == nil {
@ -241,101 +238,6 @@ func TestLimitedString(t *testing.T) {
} }
} }
func TestStoreFailedAdvisory(t *testing.T) {
dir, err := os.MkdirTemp("", "storeFailedAdvisory")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
cfg := &Config{Directory: dir}
fw := NewForwarder(cfg)
badDir := filepath.Join(dir, failedForwardDir)
if err := os.WriteFile(badDir, []byte("test"), 0664); err != nil {
t.Fatal(err)
}
if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err == nil {
t.Fatal("if the destination exists as a file an error should occur")
}
if err := os.Remove(badDir); err != nil {
t.Fatal(err)
}
if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err != nil {
t.Fatal(err)
}
sha256Path := filepath.Join(dir, failedForwardDir, "advisory.json.sha256")
// Write protect advisory.
if err := os.Chmod(sha256Path, 0); err != nil {
t.Fatal(err)
}
if err := fw.storeFailedAdvisory("advisory.json", "{}", "256", "512"); err == nil {
t.Fatal("expected to fail with an error")
}
if err := os.Chmod(sha256Path, 0644); err != nil {
t.Fatal(err)
}
}
func TestStoredFailed(t *testing.T) {
dir, err := os.MkdirTemp("", "storeFailed")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dir)
orig := slog.Default()
defer slog.SetDefault(orig)
var buf bytes.Buffer
h := slog.NewJSONHandler(&buf, &slog.HandlerOptions{
Level: slog.LevelError,
})
lg := slog.New(h)
slog.SetDefault(lg)
cfg := &Config{Directory: dir}
fw := NewForwarder(cfg)
// An empty filename should lead to an error.
fw.storeFailed("", "{}", "256", "512")
if fw.failed != 1 {
t.Fatalf("got %d expected 1", fw.failed)
}
type entry struct {
Msg string `json:"msg"`
Level string `json:"level"`
}
sc := bufio.NewScanner(bytes.NewReader(buf.Bytes()))
found := false
for sc.Scan() {
var e entry
if err := json.Unmarshal(sc.Bytes(), &e); err != nil {
t.Fatalf("JSON parsing log failed: %v", err)
}
if e.Msg == "Storing advisory failed forwarding failed" && e.Level == "ERROR" {
found = true
break
}
}
if err := sc.Err(); err != nil {
t.Fatalf("scanning log failed: %v", err)
}
if !found {
t.Fatal("Cannot error logging statistics in log")
}
}
type fakeClient struct { type fakeClient struct {
util.Client util.Client
state int state int
@ -383,11 +285,15 @@ func TestForwarderForward(t *testing.T) {
// in the other test cases. // in the other test cases.
h := slog.NewJSONHandler(io.Discard, nil) h := slog.NewJSONHandler(io.Discard, nil)
lg := slog.New(h) lg := slog.New(h)
slog.SetDefault(lg)
failedHandler := func(filename, doc, sha256, sha512 string) error {
return nil
}
cfg := &Config{ cfg := &Config{
ForwardURL: "http://example.com", ForwardURL: "http://example.com",
Directory: dir, Logger: lg,
FailedForwardHandler: failedHandler,
} }
fw := NewForwarder(cfg) fw := NewForwarder(cfg)
@ -405,7 +311,7 @@ func TestForwarderForward(t *testing.T) {
for i := 0; i <= 3; i++ { for i := 0; i <= 3; i++ {
fw.forward( fw.forward(
"test.json", "{}", "test.json", "{}",
invalidValidationStatus, InvalidValidationStatus,
"256", "256",
"512") "512")
} }
@ -419,7 +325,7 @@ func TestForwarderForward(t *testing.T) {
<-wait <-wait
fw.forward( fw.forward(
"test.json", "{}", "test.json", "{}",
invalidValidationStatus, InvalidValidationStatus,
"256", "256",
"512") "512")