@@ -3,6 +3,7 @@ package pricing
33import (
44 "context"
55 "fmt"
6+ "slices"
67 "sync"
78 "time"
89
@@ -26,6 +27,8 @@ type PricingManager struct {
2627 pricingData map [string ]configstore.TableModelPricing
2728 mu sync.RWMutex
2829
30+ modelPool map [schemas.ModelProvider ][]string
31+
2932 // Background sync worker
3033 syncTicker * time.Ticker
3134 done chan struct {}
@@ -75,9 +78,12 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
7578 configStore : configStore ,
7679 logger : logger ,
7780 pricingData : make (map [string ]configstore.TableModelPricing ),
81+ modelPool : make (map [schemas.ModelProvider ][]string ),
7882 done : make (chan struct {}),
7983 }
8084
85+ logger .Info ("initializing pricing manager..." )
86+
8187 if configStore != nil {
8288 // Load initial pricing data
8389 if err := pm .loadPricingFromDatabase (ctx ); err != nil {
@@ -88,14 +94,16 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
8894 if err := pm .syncPricing (ctx ); err != nil {
8995 return nil , fmt .Errorf ("failed to sync pricing data: %w" , err )
9096 }
91-
9297 } else {
9398 // Load pricing data from config memory
9499 if err := pm .loadPricingIntoMemory (ctx ); err != nil {
95100 return nil , fmt .Errorf ("failed to load pricing data from config memory: %w" , err )
96101 }
97102 }
98103
104+ // Populate model pool with normalized providers
105+ pm .populateModelPool ()
106+
99107 // Start background sync worker
100108 pm .syncCtx , pm .syncCancel = context .WithCancel (ctx )
101109 pm .startSyncWorker (pm .syncCtx )
@@ -333,6 +341,80 @@ func (pm *PricingManager) CalculateCostFromUsage(provider string, model string,
333341 return totalCost
334342}
335343
344+ // populateModelPool populates the model pool with all available models per provider (thread-safe)
345+ func (pm * PricingManager ) populateModelPool () {
346+ // Acquire write lock for the entire rebuild operation
347+ pm .mu .Lock ()
348+ defer pm .mu .Unlock ()
349+
350+ // Clear existing model pool
351+ pm .modelPool = make (map [schemas.ModelProvider ][]string )
352+
353+ // Map to track unique models per provider
354+ providerModels := make (map [schemas.ModelProvider ]map [string ]bool )
355+
356+ // Iterate through all pricing data to collect models per provider
357+ for _ , pricing := range pm .pricingData {
358+ // Normalize provider before adding to model pool
359+ normalizedProvider := schemas .ModelProvider (normalizeProvider (pricing .Provider ))
360+
361+ // Initialize map for this provider if not exists
362+ if providerModels [normalizedProvider ] == nil {
363+ providerModels [normalizedProvider ] = make (map [string ]bool )
364+ }
365+
366+ // Add model to the provider's model set (using map for deduplication)
367+ providerModels [normalizedProvider ][pricing.Model ] = true
368+ }
369+
370+ // Convert sets to slices and assign to modelPool
371+ for provider , modelSet := range providerModels {
372+ models := make ([]string , 0 , len (modelSet ))
373+ for model := range modelSet {
374+ models = append (models , model )
375+ }
376+ pm .modelPool [provider ] = models
377+ }
378+
379+ // Log the populated model pool for debugging
380+ totalModels := 0
381+ for provider , models := range pm .modelPool {
382+ totalModels += len (models )
383+ pm .logger .Debug ("populated %d models for provider %s" , len (models ), string (provider ))
384+ }
385+ pm .logger .Info ("populated model pool with %d models across %d providers" , totalModels , len (pm .modelPool ))
386+ }
387+
388+ // GetModelsForProvider returns all available models for a given provider (thread-safe)
389+ func (pm * PricingManager ) GetModelsForProvider (provider schemas.ModelProvider ) []string {
390+ pm .mu .RLock ()
391+ defer pm .mu .RUnlock ()
392+
393+ models , exists := pm .modelPool [provider ]
394+ if ! exists {
395+ return []string {}
396+ }
397+
398+ // Return a copy to prevent external modification
399+ result := make ([]string , len (models ))
400+ copy (result , models )
401+ return result
402+ }
403+
404+ // GetProvidersForModel returns all providers for a given model (thread-safe)
405+ func (pm * PricingManager ) GetProvidersForModel (model string ) []schemas.ModelProvider {
406+ pm .mu .RLock ()
407+ defer pm .mu .RUnlock ()
408+
409+ providers := make ([]schemas.ModelProvider , 0 )
410+ for provider , models := range pm .modelPool {
411+ if slices .Contains (models , model ) {
412+ providers = append (providers , provider )
413+ }
414+ }
415+ return providers
416+ }
417+
336418// getPricing returns pricing information for a model (thread-safe)
337419func (pm * PricingManager ) getPricing (model , provider string , requestType schemas.RequestType ) (* configstore.TableModelPricing , bool ) {
338420 pm .mu .RLock ()
0 commit comments