This repository has been archived by the owner on Oct 26, 2023. It is now read-only.
forked from NouamaneTazi/bloomz.cpp
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathoptions.go
117 lines (99 loc) · 2.46 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package bloomz
import "runtime"
// ModelOptions collects the model-level settings produced by NewModelOptions.
type ModelOptions struct {
	// ContextSize is the context size; SetContext assigns it.
	ContextSize int

	// F16Memory is switched on by EnableF16Memory.
	// NOTE(review): presumably selects 16-bit float memory in the backend — confirm.
	F16Memory bool
}
// PredictOptions collects the per-prediction settings produced by
// NewPredictOptions.
//
// The field order is observable through unkeyed composite literals, so it is
// kept exactly as originally declared.
type PredictOptions struct {
	// Seed is the random seed for sampling text generation.
	Seed int
	// Threads is the number of threads to use for text generation.
	Threads int
	// Tokens is the number of tokens to generate.
	Tokens int
	// TopK is the value for top-K sampling.
	TopK int
	// Repeat is assigned by SetRepeat.
	// NOTE(review): presumably the repetition window/count consumed by the
	// backend — confirm against the code that reads this field.
	Repeat int

	// TopP is the value for nucleus sampling.
	TopP float64
	// Temperature is the temperature value for text generation.
	Temperature float64
	// Penalty is the repetition penalty for text generation.
	Penalty float64
}
// PredictOption is a function that adjusts a PredictOptions value in place.
// NewPredictOptions applies a list of them to a copy of DefaultOptions.
type PredictOption func(opts *PredictOptions)
// ModelOption is a function that adjusts a ModelOptions value in place.
// NewModelOptions applies a list of them to a copy of DefaultModelOptions.
type ModelOption func(opts *ModelOptions)
// DefaultModelOptions is the baseline that NewModelOptions copies before
// applying any ModelOption: a context size of 512 with F16Memory off.
var DefaultModelOptions = ModelOptions{
	F16Memory:   false,
	ContextSize: 512,
}
// DefaultOptions is the baseline that NewPredictOptions copies before applying
// any PredictOption. Threads is taken from runtime.NumCPU() once, when the
// package-level variable is initialized.
//
// NOTE(review): Seed is -1 by default; presumably the backend treats a
// negative seed as "pick one at random" — confirm against the consumer.
var DefaultOptions = PredictOptions{
	// Integer settings.
	Seed:    -1,
	Threads: runtime.NumCPU(),
	Tokens:  128,
	TopK:    10000,
	Repeat:  64,

	// Floating-point sampling settings.
	TopP:        0.90,
	Temperature: 0.96,
	Penalty:     1,
}
// SetContext returns a ModelOption that stores c as the context size in
// ModelOptions.ContextSize.
func SetContext(c int) ModelOption {
	return func(opts *ModelOptions) { opts.ContextSize = c }
}
// EnableF16Memory is a ModelOption that sets ModelOptions.F16Memory to true.
// The explicit ModelOption type keeps the variable usable wherever a
// ModelOption is expected.
var EnableF16Memory ModelOption = func(opts *ModelOptions) { opts.F16Memory = true }
// NewModelOptions returns a copy of DefaultModelOptions with every option in
// opts applied in the order given, so a later option overrides an earlier one
// that writes the same field.
//
// Nil entries in opts are skipped; calling them would panic. The function
// works on a local copy and does not itself write to DefaultModelOptions.
func NewModelOptions(opts ...ModelOption) ModelOptions {
	p := DefaultModelOptions // struct assignment copies the defaults
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(&p)
	}
	return p
}
// SetSeed returns a PredictOption that stores seed, the random seed for
// sampling text generation, in PredictOptions.Seed.
func SetSeed(seed int) PredictOption {
	return func(opts *PredictOptions) { opts.Seed = seed }
}
// SetThreads returns a PredictOption that stores threads, the number of
// threads to use for text generation, in PredictOptions.Threads.
func SetThreads(threads int) PredictOption {
	return func(opts *PredictOptions) { opts.Threads = threads }
}
// SetTokens returns a PredictOption that stores tokens, the number of tokens
// to generate, in PredictOptions.Tokens.
func SetTokens(tokens int) PredictOption {
	return func(opts *PredictOptions) { opts.Tokens = tokens }
}
// SetTopK returns a PredictOption that stores topk, the value for top-K
// sampling, in PredictOptions.TopK.
func SetTopK(topk int) PredictOption {
	return func(opts *PredictOptions) { opts.TopK = topk }
}
// SetTopP returns a PredictOption that stores topp, the value for nucleus
// sampling, in PredictOptions.TopP.
func SetTopP(topp float64) PredictOption {
	return func(opts *PredictOptions) { opts.TopP = topp }
}
// SetTemperature returns a PredictOption that stores temp, the temperature
// value for text generation, in PredictOptions.Temperature.
func SetTemperature(temp float64) PredictOption {
	return func(opts *PredictOptions) { opts.Temperature = temp }
}
// SetPenalty returns a PredictOption that stores penalty, the repetition
// penalty for text generation, in PredictOptions.Penalty.
func SetPenalty(penalty float64) PredictOption {
	return func(opts *PredictOptions) { opts.Penalty = penalty }
}
// SetRepeat returns a PredictOption that stores repeat in
// PredictOptions.Repeat.
//
// NOTE(review): the previous comment described this as "the number of times
// to repeat text generation"; how the backend interprets the value is not
// visible here — confirm against the code that reads PredictOptions.Repeat.
func SetRepeat(repeat int) PredictOption {
	return func(opts *PredictOptions) { opts.Repeat = repeat }
}
// NewPredictOptions returns a copy of DefaultOptions with every option in
// opts applied in the order given, so a later option overrides an earlier one
// that writes the same field.
//
// Nil entries in opts are skipped; calling them would panic. The function
// works on a local copy and does not itself write to DefaultOptions.
func NewPredictOptions(opts ...PredictOption) PredictOptions {
	p := DefaultOptions // struct assignment copies the defaults
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(&p)
	}
	return p
}