001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.stat.regression;
018    
019    import org.apache.commons.math.MathRuntimeException;
020    import org.apache.commons.math.linear.RealMatrix;
021    import org.apache.commons.math.linear.Array2DRowRealMatrix;
022    import org.apache.commons.math.linear.RealVector;
023    import org.apache.commons.math.linear.ArrayRealVector;
024    
025    /**
026     * Abstract base class for implementations of MultipleLinearRegression.
027     * @version $Revision: 791244 $ $Date: 2009-07-05 09:29:37 -0400 (Sun, 05 Jul 2009) $
028     * @since 2.0
029     */
030    public abstract class AbstractMultipleLinearRegression implements
031            MultipleLinearRegression {
032    
033        /** X sample data. */
034        protected RealMatrix X;
035    
036        /** Y sample data. */
037        protected RealVector Y;
038    
039        /**
040         * Loads model x and y sample data from a flat array of data, overriding any previous sample.
041         * Assumes that rows are concatenated with y values first in each row.
042         * 
043         * @param data input data array
044         * @param nobs number of observations (rows)
045         * @param nvars number of independent variables (columns, not counting y)
046         */
047        public void newSampleData(double[] data, int nobs, int nvars) {
048            double[] y = new double[nobs];
049            double[][] x = new double[nobs][nvars + 1];
050            int pointer = 0;
051            for (int i = 0; i < nobs; i++) {
052                y[i] = data[pointer++];
053                x[i][0] = 1.0d;
054                for (int j = 1; j < nvars + 1; j++) {
055                    x[i][j] = data[pointer++];
056                }
057            }
058            this.X = new Array2DRowRealMatrix(x);
059            this.Y = new ArrayRealVector(y);
060        }
061        
062        /**
063         * Loads new y sample data, overriding any previous sample
064         * 
065         * @param y the [n,1] array representing the y sample
066         */
067        protected void newYSampleData(double[] y) {
068            this.Y = new ArrayRealVector(y);
069        }
070    
071        /**
072         * Loads new x sample data, overriding any previous sample
073         * 
074         * @param x the [n,k] array representing the x sample
075         */
076        protected void newXSampleData(double[][] x) {
077            this.X = new Array2DRowRealMatrix(x);
078        }
079    
080        /**
081         * Validates sample data.
082         * 
083         * @param x the [n,k] array representing the x sample
084         * @param y the [n,1] array representing the y sample
085         * @throws IllegalArgumentException if the x and y array data are not
086         *             compatible for the regression
087         */
088        protected void validateSampleData(double[][] x, double[] y) {
089            if ((x == null) || (y == null) || (x.length != y.length)) {
090                throw MathRuntimeException.createIllegalArgumentException(
091                      "dimension mismatch {0} != {1}",
092                      (x == null) ? 0 : x.length,
093                      (y == null) ? 0 : y.length);
094            } else if ((x.length > 0) && (x[0].length > x.length)) {
095                throw MathRuntimeException.createIllegalArgumentException(
096                      "not enough data ({0} rows) for this many predictors ({1} predictors)",
097                      x.length, x[0].length);
098            }
099        }
100    
101        /**
102         * Validates sample data.
103         * 
104         * @param x the [n,k] array representing the x sample
105         * @param covariance the [n,n] array representing the covariance matrix
106         * @throws IllegalArgumentException if the x sample data or covariance
107         *             matrix are not compatible for the regression
108         */
109        protected void validateCovarianceData(double[][] x, double[][] covariance) {
110            if (x.length != covariance.length) {
111                throw MathRuntimeException.createIllegalArgumentException(
112                     "dimension mismatch {0} != {1}", x.length, covariance.length);
113            }
114            if (covariance.length > 0 && covariance.length != covariance[0].length) {
115                throw MathRuntimeException.createIllegalArgumentException(
116                      "a {0}x{1} matrix was provided instead of a square matrix",
117                      covariance.length, covariance[0].length);
118            }
119        }
120    
121        /**
122         * {@inheritDoc}
123         */
124        public double[] estimateRegressionParameters() {
125            RealVector b = calculateBeta();
126            return b.getData();
127        }
128    
129        /**
130         * {@inheritDoc}
131         */
132        public double[] estimateResiduals() {
133            RealVector b = calculateBeta();
134            RealVector e = Y.subtract(X.operate(b));
135            return e.getData();
136        }
137    
138        /**
139         * {@inheritDoc}
140         */
141        public double[][] estimateRegressionParametersVariance() {
142            return calculateBetaVariance().getData();
143        }
144        
145        /**
146         * {@inheritDoc}
147         */
148        public double[] estimateRegressionParametersStandardErrors() {
149            double[][] betaVariance = estimateRegressionParametersVariance();
150            double sigma = calculateYVariance();
151            int length = betaVariance[0].length;
152            double[] result = new double[length];
153            for (int i = 0; i < length; i++) {
154                result[i] = Math.sqrt(sigma * betaVariance[i][i]);
155            }
156            return result;
157        }
158    
159        /**
160         * {@inheritDoc}
161         */
162        public double estimateRegressandVariance() {
163            return calculateYVariance();
164        }
165    
166        /**
167         * Calculates the beta of multiple linear regression in matrix notation.
168         * 
169         * @return beta
170         */
171        protected abstract RealVector calculateBeta();
172    
173        /**
174         * Calculates the beta variance of multiple linear regression in matrix
175         * notation.
176         * 
177         * @return beta variance
178         */
179        protected abstract RealMatrix calculateBetaVariance();
180    
181        /**
182         * Calculates the Y variance of multiple linear regression.
183         * 
184         * @return Y variance
185         */
186        protected abstract double calculateYVariance();
187    
188        /**
189         * Calculates the residuals of multiple linear regression in matrix
190         * notation.
191         * 
192         * <pre>
193         * u = y - X * b
194         * </pre>
195         * 
196         * @return The residuals [n,1] matrix
197         */
198        protected RealVector calculateResiduals() {
199            RealVector b = calculateBeta();
200            return Y.subtract(X.operate(b));
201        }
202    
203    }