diff --git a/openmeter/billing/adapter/invoice.go b/openmeter/billing/adapter/invoice.go index 0942402a4e..2bc108ea86 100644 --- a/openmeter/billing/adapter/invoice.go +++ b/openmeter/billing/adapter/invoice.go @@ -13,6 +13,7 @@ import ( "github.com/openmeterio/openmeter/api" "github.com/openmeterio/openmeter/openmeter/app" "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoice" "github.com/openmeterio/openmeter/openmeter/ent/db/billinginvoiceline" @@ -606,33 +607,56 @@ func (a *adapter) UpdateInvoice(ctx context.Context, in billing.UpdateInvoiceAda func (a *adapter) GetInvoiceOwnership(ctx context.Context, in billing.GetInvoiceOwnershipAdapterInput) (billing.GetOwnershipAdapterResponse, error) { if err := in.Validate(); err != nil { - return billing.GetOwnershipAdapterResponse{}, billing.ValidationError{ + return nil, billing.ValidationError{ Err: err, } } return entutils.TransactingRepo(ctx, a, func(ctx context.Context, tx *adapter) (billing.GetOwnershipAdapterResponse, error) { - dbInvoice, err := tx.db.BillingInvoice.Query(). - Where(billinginvoice.ID(in.ID)). - Where(billinginvoice.Namespace(in.Namespace)). - First(ctx) + dbInvoices, err := tx.db.BillingInvoice.Query(). + Where( + billinginvoice.IDIn( + lo.Map( + in.InvoiceIDs, + func(invoiceID billing.InvoiceID, _ int) string { + return invoiceID.ID + }, + )..., + ), + ). + All(ctx) if err != nil { - if db.IsNotFound(err) { - return billing.GetOwnershipAdapterResponse{}, billing.NotFoundError{ - Entity: billing.EntityInvoice, - ID: in.ID, - Err: err, + return nil, err + } + + invoiceToCustomerID := lo.SliceToMap(dbInvoices, func(dbInvoice *db.BillingInvoice) (billing.InvoiceID, customer.CustomerID) { + return billing.InvoiceID{ + Namespace: dbInvoice.Namespace, + ID: dbInvoice.ID, + }, customer.CustomerID{ + Namespace: dbInvoice.Namespace, + ID: dbInvoice.CustomerID, } + }) + + // Let's validate if we got all the invoices (and most importantly look up invoices with + // namespaceID, to prevent looking up invoices with different than expected namespace ID) + var notFoundErrs []error + for _, invoiceID := range in.InvoiceIDs { + if _, found := invoiceToCustomerID[invoiceID]; !found { + notFoundErrs = append(notFoundErrs, billing.NotFoundError{ + Entity: billing.EntityInvoice, + ID: invoiceID.ID, + Err: fmt.Errorf("invoice not found: %s", invoiceID.ID), + }) } + } - return billing.GetOwnershipAdapterResponse{}, err + if len(notFoundErrs) > 0 { + return nil, errors.Join(notFoundErrs...) } - return billing.GetOwnershipAdapterResponse{ - Namespace: dbInvoice.Namespace, - InvoiceID: dbInvoice.ID, - CustomerID: dbInvoice.CustomerID, - }, nil + return invoiceToCustomerID, nil }) } diff --git a/openmeter/billing/invoice.go b/openmeter/billing/invoice.go index 3be3116ae9..77ecb08477 100644 --- a/openmeter/billing/invoice.go +++ b/openmeter/billing/invoice.go @@ -867,14 +867,18 @@ type ( type UpdateInvoiceAdapterInput = Invoice -type GetInvoiceOwnershipAdapterInput = InvoiceID +type GetInvoiceOwnershipAdapterInput struct { + InvoiceIDs []InvoiceID +} -type GetOwnershipAdapterResponse struct { - Namespace string - InvoiceID string - CustomerID string +func (i GetInvoiceOwnershipAdapterInput) Validate() error { + if len(i.InvoiceIDs) == 0 { + return errors.New("invoice IDs are required") + } } +type GetOwnershipAdapterResponse map[InvoiceID]customer.CustomerID + type DeleteInvoiceInput = InvoiceID type UpdateInvoiceLinesInternalInput struct { @@ -1061,3 +1065,19 @@ func (i UpdateInvoiceFieldsInput) Validate() error { } type RecalculateGatheringInvoicesInput = customer.CustomerID + +type StandardImmutableInvoiceUpdate struct { + UpsertValidationIssues mo.Option[ValidationIssues] +} + +type BulkUpdateInvoicesInput struct { + IncludeDeletedLines bool + Invoices []InvoiceID + GatheringInvoiceEditFunction func(*Invoice) error + StandardMutableInvoiceEditFunction func(*Invoice) error + StandardImmutableInvoiceEditFunction func(*Invoice) (StandardImmutableInvoiceUpdate, error) +} + +type BulkUpdateInvoicesResult struct { + InvoicesByID map[InvoiceID]Invoice +} diff --git a/openmeter/billing/service.go b/openmeter/billing/service.go index 574265d3d7..6be8edacf5 100644 --- a/openmeter/billing/service.go +++ b/openmeter/billing/service.go @@ -70,6 +70,7 @@ type InvoiceService interface { DeleteInvoice(ctx context.Context, input DeleteInvoiceInput) error // UpdateInvoice updates an invoice as a whole UpdateInvoice(ctx context.Context, input UpdateInvoiceInput) (Invoice, error) + BulkUpdateInvoices(ctx context.Context, input BulkUpdateInvoicesInput) (BulkUpdateInvoicesResult, error) // SimulateInvoice generates an invoice based on the provided input, but does not persist it // can be used to execute the invoice generation logic without actually creating an invoice in the database diff --git a/openmeter/billing/service/invoicebulk.go b/openmeter/billing/service/invoicebulk.go new file mode 100644 index 0000000000..0a216dcd51 --- /dev/null +++ b/openmeter/billing/service/invoicebulk.go @@ -0,0 +1,53 @@ +package billingservice + +import ( + "context" + "fmt" + + "github.com/openmeterio/openmeter/openmeter/billing" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/pkg/framework/transaction" +) + +func (s *Service) BulkUpdateInvoices(ctx context.Context, input billing.BulkUpdateInvoicesInput) (billing.BulkUpdateInvoicesResult, error) { + if err := input.Validate(); err != nil { + return billing.BulkUpdateInvoicesResult{}, billing.ValidationError{ + Err: err, + } + } + + return transaction.Run(ctx, s.adapter, func(ctx context.Context) (billing.BulkUpdateInvoicesResult, error) { + invoiceToCustomerID, err := s.adapter.GetInvoiceOwnership(ctx, billing.GetInvoiceOwnershipAdapterInput{ + InvoiceIDs: input.Invoices, + }) + if err != nil { + return billing.BulkUpdateInvoicesResult{}, fmt.Errorf("getting invoice ownership: %w", err) + } + + invoiceIDsByCustomerID := map[customer.CustomerID][]billing.InvoiceID{} + for invoiceID, customerID := range invoiceToCustomerID { + invoiceIDsByCustomerID[customerID] = append(invoiceIDsByCustomerID[customerID], invoiceID) + } + + for customerID, invoiceIDs := range invoiceIDsByCustomerID { + xxx, err := transcationForInvoiceManipulation(ctx, s, customerID, func(ctx context.Context) (billing.BulkUpdateInvoicesResult, error) { + invoices := make([]*billing.Invoice, 0, len(invoiceIDs)) + for _, invoiceID := range invoiceIDs { + invoice, err := s.GetInvoiceByID(ctx, billing.GetInvoiceByIdInput{ + Invoice: invoiceID, + Expand: billing.InvoiceExpandAll. + SetDeletedLines(input.IncludeDeletedLines), + }) + if err != nil { + return billing.BulkUpdateInvoicesResult{}, fmt.Errorf("getting invoice[%s]: %w", invoiceID.ID, err) + } + + invoices = append(invoices, &invoice) + } + }) + if err != nil { + return billing.BulkUpdateInvoicesResult{}, fmt.Errorf("updating invoices: %w", err) + } + } + }) +} diff --git a/openmeter/billing/service/lineservice/linebase.go b/openmeter/billing/service/lineservice/linebase.go index 64998cd444..a42f53c334 100644 --- a/openmeter/billing/service/lineservice/linebase.go +++ b/openmeter/billing/service/lineservice/linebase.go @@ -106,6 +106,13 @@ func (l lineBase) Validate(ctx context.Context, invoice *billing.Invoice) error } } + // Expanding the split lines are mandatory for the lineservice to work properly. + if l.line.SplitLineGroupID != nil && l.line.SplitLineHierarchy == nil { + return billing.ValidationError{ + Err: fmt.Errorf("split line group[%s] has no expanded hierarchy, while being part of a split line group", *l.line.SplitLineGroupID), + } + } + return nil } @@ -114,10 +121,6 @@ func (l lineBase) IsLastInPeriod() bool { return true } - if l.line.SplitLineHierarchy == nil { - return true - } - if l.line.SplitLineHierarchy.Group.ServicePeriod.End.Equal(l.line.Period.End) { return true } @@ -130,10 +133,6 @@ func (l lineBase) IsFirstInPeriod() bool { return true } - if l.line.SplitLineHierarchy == nil { - return true - } - if l.line.SplitLineHierarchy.Group.ServicePeriod.Start.Equal(l.line.Period.Start) { return true }