Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot Graph #20

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions regression.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import (
"strconv"
"strings"

"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"

"gonum.org/v1/gonum/mat"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
)

var (
Expand Down Expand Up @@ -68,25 +73,25 @@ func (r *Regression) Predict(vars []float64) (float64, error) {
return p, nil
}

// Set the name of the observed value
// SetObserved sets the name of the observed value
func (r *Regression) SetObserved(name string) {
r.names.obs = name
}

// Get the name of the observed value
// GetObserved gets the name of the observed value
func (r *Regression) GetObserved() string {
return r.names.obs
}

// Set the name of variable i
// SetVar sets the name of variable i
func (r *Regression) SetVar(i int, name string) {
if len(r.names.vars) == 0 {
r.names.vars = make(map[int]string, 5)
}
r.names.vars[i] = name
}

// Get the name of variable i
// GetVar gets the name of variable i
func (r *Regression) GetVar(i int) string {
x := r.names.vars[i]
if x == "" {
Expand All @@ -96,7 +101,7 @@ func (r *Regression) GetVar(i int) string {
return x
}

// Registers a feature cross to be applied to the data points.
// AddCross registers a feature cross to be applied to the data points.
func (r *Regression) AddCross(cross featureCross) {
r.crosses = append(r.crosses, cross)
}
Expand Down Expand Up @@ -200,7 +205,7 @@ func (r *Regression) Run() error {
return nil
}

// Return the calulated coefficient for variable i
// Coeff returns the calulated coefficient for variable i
func (r *Regression) Coeff(i int) float64 {
if len(r.coeff) == 0 {
return 0
Expand Down Expand Up @@ -254,6 +259,39 @@ func (r *Regression) calcResiduals() string {
return str
}

// Plot draws a graph for observerd points and saves the output as an image to given input savePath.
func (r *Regression) Plot(savePath string) error {
if !r.initialised {
return errNotEnoughData
}
p, err := plot.New()
if err != nil {
return err
}
p.X.Label.Text = "X"
p.Y.Label.Text = "Y"
//Each X variables have its own plotter.XYs
plotters := make([]plotter.XYs, len(r.names.vars))
for i := 0; i < len(r.names.vars); i++ {
plotters[i] = make(plotter.XYs, len(r.data))
}
for i, points := range r.data {
for j, variable := range points.Variables {
plotters[j][i].X = variable
plotters[j][i].Y = points.Observed
}
}
for i, name := range r.names.vars {
if err := plotutil.AddLinePoints(p, name, plotters[i]); err != nil {
return err
}
}
if err := p.Save(4*vg.Inch, 4*vg.Inch, savePath); err != nil {
return err
}
return nil
}

// Display a dataPoint as a string
func (d *dataPoint) String() string {
str := fmt.Sprintf("%.2f", d.Observed)
Expand Down