diff --git a/regression.go b/regression.go index 3b9a227..5a01b1e 100644 --- a/regression.go +++ b/regression.go @@ -8,6 +8,10 @@ import ( "strings" "gonum.org/v1/gonum/mat" + "gonum.org/v1/plot" + "gonum.org/v1/plot/plotter" + "gonum.org/v1/plot/plotutil" + "gonum.org/v1/plot/vg" ) var ( @@ -283,6 +287,41 @@ func (r *Regression) String() string { return str } +//Plot regression and return file path where its saved +func (r *Regression) Plot() string { + if !r.initialised { + return errNotEnoughData.Error() + } + + p, err := plot.New() + if err != nil { + panic(err) + } + p.Title.Text = r.GetObserved() + p.X.Label.Text = r.GetVar(1) + p.Y.Label.Text = r.GetVar(0) + + observed := make(plotter.XYs, len(r.data)) + predicted := make(plotter.XYs, len(r.data)) + for i, d := range r.data { + observed[i].Y = d.Observed + observed[i].X = float64(i) + predicted[i].Y = d.Predicted + predicted[i].X = float64(i) + } + + err = plotutil.AddLinePoints(p, + "Observed", observed, + "Predicted", predicted) + if err != nil { + panic(err) + } + if err := p.Save(4*vg.Inch, 4*vg.Inch, "points.png"); err != nil { + panic(err) + } + return "points.png" +} + // MakeDataPoints makes a `[]*dataPoint` from a `[][]float64`. The expected fomat for the input is a row-major [][]float64. // That is to say the first slice represents a row, and the second represents the cols. // Furthermore it is expected that all the col slices are of the same length.