Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add regression plots [WIP] #14

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions regression.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"strings"

"gonum.org/v1/gonum/mat"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)

var (
Expand Down Expand Up @@ -283,6 +287,41 @@ func (r *Regression) String() string {
return str
}

//Plot regression and return file path where its saved
func (r *Regression) Plot() string {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should:

  • return the err instead of panicking.
  • allow the user to specify where to save the file
  • needs to be more configurable in general. e.g. only the first two vars are plotted. What if there are more?

if !r.initialised {
return errNotEnoughData.Error()
}

p, err := plot.New()
if err != nil {
panic(err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return err instead of panic

}
p.Title.Text = r.GetObserved()
p.X.Label.Text = r.GetVar(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The var numbers are hardcoded to 0 and 1, but what if there are > 2 vars?

p.Y.Label.Text = r.GetVar(0)

observed := make(plotter.XYs, len(r.data))
predicted := make(plotter.XYs, len(r.data))
for i, d := range r.data {
observed[i].Y = d.Observed
observed[i].X = float64(i)
predicted[i].Y = d.Predicted
predicted[i].X = float64(i)
}

err = plotutil.AddLinePoints(p,
"Observed", observed,
"Predicted", predicted)
if err != nil {
panic(err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return err instead of panic

}
if err := p.Save(4*vg.Inch, 4*vg.Inch, "points.png"); err != nil {
panic(err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return err instead of panic

}
return "points.png"
}

// MakeDataPoints makes a `[]*dataPoint` from a `[][]float64`. The expected fomat for the input is a row-major [][]float64.
// That is to say the first slice represents a row, and the second represents the cols.
// Furthermore it is expected that all the col slices are of the same length.
Expand Down