diff --git a/regression.go b/regression.go index 9fcd421..a5c346d 100644 --- a/regression.go +++ b/regression.go @@ -7,7 +7,12 @@ import ( "strconv" "strings" + "gonum.org/v1/plot/plotutil" + "gonum.org/v1/plot/vg" + "gonum.org/v1/gonum/mat" + "gonum.org/v1/plot" + "gonum.org/v1/plot/plotter" ) var ( @@ -68,17 +73,17 @@ func (r *Regression) Predict(vars []float64) (float64, error) { return p, nil } -// Set the name of the observed value +// SetObserved sets the name of the observed value func (r *Regression) SetObserved(name string) { r.names.obs = name } -// Get the name of the observed value +// GetObserved gets the name of the observed value func (r *Regression) GetObserved() string { return r.names.obs } -// Set the name of variable i +// SetVar sets the name of variable i func (r *Regression) SetVar(i int, name string) { if len(r.names.vars) == 0 { r.names.vars = make(map[int]string, 5) @@ -86,7 +91,7 @@ func (r *Regression) SetVar(i int, name string) { r.names.vars[i] = name } -// Get the name of variable i +// GetVar gets the name of variable i func (r *Regression) GetVar(i int) string { x := r.names.vars[i] if x == "" { @@ -96,7 +101,7 @@ func (r *Regression) GetVar(i int) string { return x } -// Registers a feature cross to be applied to the data points. +// AddCross registers a feature cross to be applied to the data points. func (r *Regression) AddCross(cross featureCross) { r.crosses = append(r.crosses, cross) } @@ -200,7 +205,7 @@ func (r *Regression) Run() error { return nil } -// Return the calulated coefficient for variable i +// Coeff returns the calulated coefficient for variable i func (r *Regression) Coeff(i int) float64 { if len(r.coeff) == 0 { return 0 @@ -254,6 +259,39 @@ func (r *Regression) calcResiduals() string { return str } +// Plot draws a graph for observerd points and saves the output as an image to given input savePath. +func (r *Regression) Plot(savePath string) error { + if !r.initialised { + return errNotEnoughData + } + p, err := plot.New() + if err != nil { + return err + } + p.X.Label.Text = "X" + p.Y.Label.Text = "Y" + //Each X variables have its own plotter.XYs + plotters := make([]plotter.XYs, len(r.names.vars)) + for i := 0; i < len(r.names.vars); i++ { + plotters[i] = make(plotter.XYs, len(r.data)) + } + for i, points := range r.data { + for j, variable := range points.Variables { + plotters[j][i].X = variable + plotters[j][i].Y = points.Observed + } + } + for i, name := range r.names.vars { + if err := plotutil.AddLinePoints(p, name, plotters[i]); err != nil { + return err + } + } + if err := p.Save(4*vg.Inch, 4*vg.Inch, savePath); err != nil { + return err + } + return nil +} + // Display a dataPoint as a string func (d *dataPoint) String() string { str := fmt.Sprintf("%.2f", d.Observed)