Skip to content

Commit 49b5874

Browse files
committed
organize to base_prompt_scorer
1 parent d4a3ad1 commit 49b5874

File tree

3 files changed

+129
-42
lines changed

3 files changed

+129
-42
lines changed

pkg/scorers/api_scorers/prompt_scorer/base_prompt_scorer.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package prompt_scorer
22

33
import (
44
"fmt"
5+
"strconv"
56

67
"github.com/JudgmentLabs/judgeval-go/pkg/data"
78
"github.com/JudgmentLabs/judgeval-go/pkg/env"
@@ -18,6 +19,25 @@ type BasePromptScorer struct {
1819
organizationID string
1920
}
2021

22+
type ScorerOptions struct {
23+
APIKey string
24+
OrganizationID string
25+
}
26+
27+
type ScorerOption func(*ScorerOptions)
28+
29+
func WithAPIKey(apiKey string) ScorerOption {
30+
return func(opts *ScorerOptions) {
31+
opts.APIKey = apiKey
32+
}
33+
}
34+
35+
func WithOrganizationID(orgID string) ScorerOption {
36+
return func(opts *ScorerOptions) {
37+
opts.OrganizationID = orgID
38+
}
39+
}
40+
2141
func NewBasePromptScorer(
2242
scoreType data.APIScorerType,
2343
name string,
@@ -38,6 +58,26 @@ func NewBasePromptScorer(
3858
}
3959
}
4060

61+
func parseScorerOptions(options interface{}) map[string]float64 {
62+
result := make(map[string]float64)
63+
if options == nil {
64+
return result
65+
}
66+
67+
if optionsMap, ok := options.(map[string]interface{}); ok {
68+
for k, v := range optionsMap {
69+
if num, ok := v.(float64); ok {
70+
result[k] = num
71+
} else if str, ok := v.(string); ok {
72+
if num, err := strconv.ParseFloat(str, 64); err == nil {
73+
result[k] = num
74+
}
75+
}
76+
}
77+
}
78+
return result
79+
}
80+
4181
func ScorerExists(name, judgmentAPIKey, organizationID string) (bool, error) {
4282
client := api.NewClient(env.JudgmentAPIURL, judgmentAPIKey, organizationID)
4383
request := &models.ScorerExistsRequest{
@@ -119,3 +159,22 @@ func (bps *BasePromptScorer) GetOptions() map[string]float64 {
119159
}
120160
return result
121161
}
162+
163+
func (bps *BasePromptScorer) GetScorerConfig() models.ScorerConfig {
164+
config := bps.APIScorer.GetScorerConfig()
165+
166+
kwargs := make(map[string]interface{})
167+
kwargs["prompt"] = bps.GetPrompt()
168+
if bps.GetOptions() != nil {
169+
kwargs["options"] = bps.GetOptions()
170+
}
171+
172+
if bps.APIScorer.AdditionalProperties != nil {
173+
for k, v := range bps.APIScorer.AdditionalProperties {
174+
kwargs[k] = v
175+
}
176+
}
177+
178+
config.Kwargs = kwargs
179+
return config
180+
}

pkg/scorers/api_scorers/prompt_scorer/prompt_scorer.go

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@ package prompt_scorer
22

33
import (
44
"fmt"
5-
"strconv"
65

76
"github.com/JudgmentLabs/judgeval-go/pkg/data"
87
"github.com/JudgmentLabs/judgeval-go/pkg/env"
9-
"github.com/JudgmentLabs/judgeval-go/pkg/internal/api/models"
108
)
119

1210
type PromptScorer struct {
1311
*BasePromptScorer
1412
}
1513

16-
func Get(name string) (*PromptScorer, error) {
17-
return GetWithCredentials(name, env.JudgmentAPIKey, env.JudgmentOrgID)
18-
}
14+
func Get(name string, opts ...ScorerOption) (*PromptScorer, error) {
15+
options := &ScorerOptions{
16+
APIKey: env.JudgmentAPIKey,
17+
OrganizationID: env.JudgmentOrgID,
18+
}
19+
20+
for _, opt := range opts {
21+
opt(options)
22+
}
1923

20-
func GetWithCredentials(name, judgmentAPIKey, organizationID string) (*PromptScorer, error) {
21-
scorerConfig, err := FetchPromptScorer(name, judgmentAPIKey, organizationID)
24+
scorerConfig, err := FetchPromptScorer(name, options.APIKey, options.OrganizationID)
2225
if err != nil {
2326
return nil, err
2427
}
@@ -27,21 +30,7 @@ func GetWithCredentials(name, judgmentAPIKey, organizationID string) (*PromptSco
2730
return nil, fmt.Errorf("scorer with name %s is not a PromptScorer", name)
2831
}
2932

30-
options := make(map[string]float64)
31-
if scorerConfig.Options != nil {
32-
if optionsMap, ok := scorerConfig.Options.(map[string]interface{}); ok {
33-
for k, v := range optionsMap {
34-
if num, ok := v.(float64); ok {
35-
options[k] = num
36-
} else if str, ok := v.(string); ok {
37-
if num, err := strconv.ParseFloat(str, 64); err == nil {
38-
options[k] = num
39-
}
40-
}
41-
}
42-
}
43-
}
44-
33+
scorerOptions := parseScorerOptions(scorerConfig.Options)
4534
threshold := 0.5
4635
if scorerConfig.Threshold != 0 {
4736
threshold = scorerConfig.Threshold
@@ -53,28 +42,13 @@ func GetWithCredentials(name, judgmentAPIKey, organizationID string) (*PromptSco
5342
name,
5443
scorerConfig.Prompt,
5544
threshold,
56-
options,
57-
judgmentAPIKey,
58-
organizationID,
45+
scorerOptions,
46+
options.APIKey,
47+
options.OrganizationID,
5948
),
6049
}, nil
6150
}
6251

63-
func (ps *PromptScorer) GetScorerConfig() models.ScorerConfig {
64-
config := ps.BasePromptScorer.APIScorer.GetScorerConfig()
65-
66-
kwargs := make(map[string]interface{})
67-
kwargs["prompt"] = ps.GetPrompt()
68-
if ps.GetOptions() != nil {
69-
kwargs["options"] = ps.GetOptions()
70-
}
71-
72-
if ps.APIScorer.AdditionalProperties != nil {
73-
for k, v := range ps.APIScorer.AdditionalProperties {
74-
kwargs[k] = v
75-
}
76-
}
77-
78-
config.Kwargs = kwargs
79-
return config
52+
func (ps *PromptScorer) IsTrace() bool {
53+
return false
8054
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package prompt_scorer
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/JudgmentLabs/judgeval-go/pkg/data"
7+
"github.com/JudgmentLabs/judgeval-go/pkg/env"
8+
)
9+
10+
type TracePromptScorer struct {
11+
*BasePromptScorer
12+
}
13+
14+
func GetTrace(name string, opts ...ScorerOption) (*TracePromptScorer, error) {
15+
options := &ScorerOptions{
16+
APIKey: env.JudgmentAPIKey,
17+
OrganizationID: env.JudgmentOrgID,
18+
}
19+
20+
for _, opt := range opts {
21+
opt(options)
22+
}
23+
24+
scorerConfig, err := FetchPromptScorer(name, options.APIKey, options.OrganizationID)
25+
if err != nil {
26+
return nil, err
27+
}
28+
29+
if !scorerConfig.IsTrace {
30+
return nil, fmt.Errorf("scorer with name %s is not a TracePromptScorer", name)
31+
}
32+
33+
scorerOptions := parseScorerOptions(scorerConfig.Options)
34+
threshold := 0.5
35+
if scorerConfig.Threshold != 0 {
36+
threshold = scorerConfig.Threshold
37+
}
38+
39+
return &TracePromptScorer{
40+
BasePromptScorer: NewBasePromptScorer(
41+
data.TracePromptScorer,
42+
name,
43+
scorerConfig.Prompt,
44+
threshold,
45+
scorerOptions,
46+
options.APIKey,
47+
options.OrganizationID,
48+
),
49+
}, nil
50+
}
51+
52+
func (tps *TracePromptScorer) IsTrace() bool {
53+
return true
54+
}

0 commit comments

Comments
 (0)