diff --git a/billing/checkout/service.go b/billing/checkout/service.go index 43b4925bd..7ec7af315 100644 --- a/billing/checkout/service.go +++ b/billing/checkout/service.go @@ -116,6 +116,7 @@ type Service struct { paymentMethodConfig []billing.PaymentMethodConfig syncJob *cron.Cron + syncJobMu sync.Mutex mu sync.Mutex syncDelay time.Duration } @@ -148,6 +149,9 @@ func (s *Service) Init(ctx context.Context) error { if s.syncDelay == time.Duration(0) { return nil } + + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { <-s.syncJob.Stop().Done() } @@ -169,9 +173,10 @@ func (s *Service) Init(ctx context.Context) error { } func (s *Service) Close() error { + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { <-s.syncJob.Stop().Done() - return s.syncJob.Stop().Err() } return nil } diff --git a/billing/checkout/service_concurrent_test.go b/billing/checkout/service_concurrent_test.go new file mode 100644 index 000000000..bc9555a20 --- /dev/null +++ b/billing/checkout/service_concurrent_test.go @@ -0,0 +1,35 @@ +package checkout + +import ( + "context" + "io" + "log/slog" + "sync" + "testing" + "time" +) + +// TestService_InitClose_Concurrent guards against an unsynchronized syncJob field. +// Two goroutines is the minimum needed to surface the race under `go test -race`. +func TestService_InitClose_Concurrent(t *testing.T) { + s := &Service{ + log: slog.New(slog.NewTextHandler(io.Discard, nil)), + syncDelay: time.Hour, + } + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := s.Init(context.Background()); err != nil { + t.Errorf("Init: %v", err) + } + }() + } + wg.Wait() + + if err := s.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} diff --git a/billing/customer/service.go b/billing/customer/service.go index 7a2ed4ec2..e6ed72b68 100644 --- a/billing/customer/service.go +++ b/billing/customer/service.go @@ -45,6 +45,7 @@ type Service struct { creditService CreditService syncJob *cron.Cron + syncJobMu sync.Mutex mu sync.Mutex syncDelay time.Duration } @@ -358,6 +359,9 @@ func (s *Service) Init(ctx context.Context) error { if s.syncDelay == time.Duration(0) { return nil } + + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { <-s.syncJob.Stop().Done() } @@ -378,9 +382,10 @@ func (s *Service) Init(ctx context.Context) error { } func (s *Service) Close() error { + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { <-s.syncJob.Stop().Done() - return s.syncJob.Stop().Err() } return nil } diff --git a/billing/customer/service_concurrent_test.go b/billing/customer/service_concurrent_test.go new file mode 100644 index 000000000..c248b9abf --- /dev/null +++ b/billing/customer/service_concurrent_test.go @@ -0,0 +1,35 @@ +package customer + +import ( + "context" + "io" + "log/slog" + "sync" + "testing" + "time" +) + +// TestService_InitClose_Concurrent guards against an unsynchronized syncJob field. +// Two goroutines is the minimum needed to surface the race under `go test -race`. +func TestService_InitClose_Concurrent(t *testing.T) { + s := &Service{ + log: slog.New(slog.NewTextHandler(io.Discard, nil)), + syncDelay: time.Hour, + } + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := s.Init(context.Background()); err != nil { + t.Errorf("Init: %v", err) + } + }() + } + wg.Wait() + + if err := s.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} diff --git a/billing/invoice/service.go b/billing/invoice/service.go index 3d5eceac0..0f300663a 100644 --- a/billing/invoice/service.go +++ b/billing/invoice/service.go @@ -74,6 +74,7 @@ type Service struct { locker Locker syncJob *cron.Cron + syncJobMu sync.Mutex mu sync.Mutex syncDelay time.Duration @@ -112,23 +113,8 @@ func NewService(logger *slog.Logger, stripeClient *client.API, invoiceRepository } func (s *Service) Init(ctx context.Context) error { - if s.syncDelay != time.Duration(0) { - if s.syncJob != nil { - s.syncJob.Stop() - } - s.syncJob = cron.New(cron.WithChain( - cron.SkipIfStillRunning(cron.DefaultLogger), - cron.Recover(cron.DefaultLogger), - )) - - if _, err := s.syncJob.AddFunc(fmt.Sprintf("@every %s", s.syncDelay.String()), func() { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - s.backgroundSync(ctx) - }); err != nil { - return err - } - s.syncJob.Start() + if err := s.initSyncJob(ctx); err != nil { + return err } if s.creditOverdraftProduct != "" { @@ -156,9 +142,36 @@ func (s *Service) Init(ctx context.Context) error { return nil } +func (s *Service) initSyncJob(ctx context.Context) error { + if s.syncDelay == time.Duration(0) { + return nil + } + + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() + if s.syncJob != nil { + <-s.syncJob.Stop().Done() + } + s.syncJob = cron.New(cron.WithChain( + cron.SkipIfStillRunning(cron.DefaultLogger), + cron.Recover(cron.DefaultLogger), + )) + if _, err := s.syncJob.AddFunc(fmt.Sprintf("@every %s", s.syncDelay.String()), func() { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + s.backgroundSync(ctx) + }); err != nil { + return err + } + s.syncJob.Start() + return nil +} + func (s *Service) Close() error { + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { - return s.syncJob.Stop().Err() + <-s.syncJob.Stop().Done() } return nil } diff --git a/billing/invoice/service_concurrent_test.go b/billing/invoice/service_concurrent_test.go new file mode 100644 index 000000000..fe79c8f0e --- /dev/null +++ b/billing/invoice/service_concurrent_test.go @@ -0,0 +1,35 @@ +package invoice + +import ( + "context" + "io" + "log/slog" + "sync" + "testing" + "time" +) + +// TestService_InitClose_Concurrent guards against an unsynchronized syncJob field. +// Two goroutines is the minimum needed to surface the race under `go test -race`. +func TestService_InitClose_Concurrent(t *testing.T) { + s := &Service{ + log: slog.New(slog.NewTextHandler(io.Discard, nil)), + syncDelay: time.Hour, + } + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := s.Init(context.Background()); err != nil { + t.Errorf("Init: %v", err) + } + }() + } + wg.Wait() + + if err := s.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} diff --git a/billing/subscription/service.go b/billing/subscription/service.go index 72ff19c81..b57d2c51b 100644 --- a/billing/subscription/service.go +++ b/billing/subscription/service.go @@ -77,9 +77,10 @@ type Service struct { productService ProductService creditService CreditService - syncJob *cron.Cron - mu sync.Mutex - config billing.Config + syncJob *cron.Cron + syncJobMu sync.Mutex + mu sync.Mutex + config billing.Config } func NewService(logger *slog.Logger, stripeClient *client.API, config billing.Config, repository Repository, @@ -116,6 +117,9 @@ func (s *Service) Init(ctx context.Context) error { if syncDelay == time.Duration(0) { return nil } + + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { <-s.syncJob.Stop().Done() } @@ -136,9 +140,10 @@ func (s *Service) Init(ctx context.Context) error { } func (s *Service) Close() error { + s.syncJobMu.Lock() + defer s.syncJobMu.Unlock() if s.syncJob != nil { <-s.syncJob.Stop().Done() - return s.syncJob.Stop().Err() } return nil } diff --git a/billing/subscription/service_concurrent_test.go b/billing/subscription/service_concurrent_test.go new file mode 100644 index 000000000..d5a6bd636 --- /dev/null +++ b/billing/subscription/service_concurrent_test.go @@ -0,0 +1,37 @@ +package subscription + +import ( + "context" + "io" + "log/slog" + "sync" + "testing" + "time" + + "github.com/raystack/frontier/billing" +) + +// TestService_InitClose_Concurrent guards against an unsynchronized syncJob field. +// Two goroutines is the minimum needed to surface the race under `go test -race`. +func TestService_InitClose_Concurrent(t *testing.T) { + s := &Service{ + log: slog.New(slog.NewTextHandler(io.Discard, nil)), + config: billing.Config{RefreshInterval: billing.RefreshInterval{Subscription: time.Hour}}, + } + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := s.Init(context.Background()); err != nil { + t.Errorf("Init: %v", err) + } + }() + } + wg.Wait() + + if err := s.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} diff --git a/internal/store/blob/resources_repository.go b/internal/store/blob/resources_repository.go index b8dfbcbd4..c8eb932fc 100644 --- a/internal/store/blob/resources_repository.go +++ b/internal/store/blob/resources_repository.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "sync" "time" @@ -11,8 +12,6 @@ import ( "github.com/raystack/frontier/core/namespace" "github.com/raystack/frontier/core/resource" - "log/slog" - "github.com/ghodss/yaml" "github.com/pkg/errors" "github.com/robfig/cron/v3" @@ -63,7 +62,10 @@ func (repo *ResourcesRepository) GetAll(ctx context.Context) ([]resource.YAML, e } err := repo.refresh(ctx) - return repo.cached, err + repo.mu.Lock() + currentCache = repo.cached + repo.mu.Unlock() + return currentCache, err } func (repo *ResourcesRepository) GetRelationsForNamespace(ctx context.Context, namespaceID string) (map[string]bool, error) {