线性回归与逻辑回归的 Java 实现
线性回归和逻辑回归的实现大体一致，因此把公共部分抽象为抽象类 Regression：它包含整体训练流程，并声明了三个抽象方法（PreVal、CostFun、Update），分别在线性回归和逻辑回归中重写。
样本用 Sample 类表示，其特征以 double 数组的形式存储。
1. 样本类Sample
/**
 * A single sample: a fixed-length feature vector plus a regression target value and a class
 * label. LinearRegression trains on {@code value}; LogisticRegression trains on {@code label}.
 */
public class Sample {

    double[] features;  // feature values of the sample
    int feaNum;         // the number of sample's features
    double value;       // value of sample in regression
    int label;          // class of sample

    /**
     * Creates a sample with {@code number} features, all initialized to 0.
     *
     * @param number the number of features
     */
    public Sample(int number) {
        feaNum = number;
        features = new double[feaNum];
    }

    /** Prints the features, the label and the value of this sample to standard output. */
    public void outSample() {
        System.out.println("The sample's features are:");
        for (int i = 0; i < feaNum; i++) {
            System.out.print(features[i] + " ");
        }
        System.out.println();
        System.out.println("The label is: " + label);
        System.out.println("The value is: " + value);
    }
}
2. 抽象类Regression
public abstract class Regression { double[] theta; //parameters int paraNum; //the number of parameters double rate; //learning rate Sample[] sam; // samples int samNum; // the number of samples double th; // threshold value /** * initialize the samples * @param s : training set * @param num : the number of training samples */ public void Initialize(Sample[] s, int num) { samNum = num; sam = new Sample[samNum]; for(int i = 0; i < samNum; i++) { sam[i] = s[i]; } } /** * initialize all parameters * @param para : theta * @param learning_rate * @param threshold */ public void setPara(double[] para, double learning_rate, double threshold) { paraNum = para.length; theta = para; rate = learning_rate; th = threshold; } /** * predicte the value of sample s * @param s : prediction sample * @return : predicted value */ public abstract double PreVal(Sample s); /** * calculate the cost of all samples * @return : the cost */ public abstract double CostFun(); /** * update the theta */ public abstract void Update(); public void OutputTheta() { System.out.println("The parameters are:"); for(int i = 0; i < paraNum; i++) { System.out.print(theta[i] + " "); } System.out.println(CostFun()); } }
3. 线性回归LinearRegression
public class LinearRegression extends Regression{ public double PreVal(Sample s) { double val = 0; for(int i = 0; i < paraNum; i++) { val += theta[i] * s.features[i]; } return val; } public double CostFun() { double sum = 0; for(int i = 0; i < samNum; i++) { double d = PreVal(sam[i]) - sam[i].value; sum += Math.pow(d, 2); } return sum / (2*samNum); } public void Update() { double former = 0; // the cost before update double latter = CostFun(); // the cost after update double d = 0; double[] p = new double[paraNum]; do { former = latter; //update theta for(int i = 0; i < paraNum; i++) { // for theta[i] for(int j = 0; j < samNum; j++) { d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i]; } p[i] -= (rate * d) / samNum; } theta = p; latter = CostFun(); }while(former - latter > th); } }
4. 逻辑回归LogisticRegression
public class LogisticRegression extends Regression{ public double PreVal(Sample s) { double val = 0; for(int i = 0; i < paraNum; i++) { val += theta[i] * s.features[i]; } return 1/(1 + Math.pow(Math.E, -val)); } public double CostFun() { double sum = 0; for(int i = 0; i < samNum; i++) { double p = PreVal(sam[i]); double d = Math.log(p) * sam[i].label + (1 - sam[i].label) * Math.log(1 - p); sum += d; } return -1 * (sum / samNum); } public void Update() { double former = 0; // the cost before update double latter = CostFun(); // the cost after update double d = 0; double[] p = new double[paraNum]; do { former = latter; //update theta for(int i = 0; i < paraNum; i++) { // for theta[i] for(int j = 0; j < samNum; j++) { d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i]; } p[i] -= (rate * d) / samNum; } latter = CostFun(); }while(former - latter > th); theta = p; } }
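5. 使用的线性回归样本（每行依次为特征 x0–x4 和目标值 y，其中 x0 恒为 1，作为截距项；测试程序从 resource/LinearSample.txt 读取，并把每一行都当作数值解析，因此文件中不要包含下面的表头行）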
x0 x1 x2 x3 x4 y
1 2104 5 1 45 460
1 1416 3 2 40 232
1 1534 3 2 30 315
1 852 2 1 36 178
1 1254 3 3 45 321
1 987 2 2 35 241
1 1054 3 2 30 287
1 645 2 3 25 87
1 542 2 1 30 94
1 1065 3 1 25 241
1 2465 7 2 50 687
1 2410 6 1 45 654
1 1987 4 2 45 436
1 457 2 3 35 65
1 587 2 2 25 54
1 468 2 1 40 87
1 1354 3 1 35 215
1 1587 4 1 45 345
1 1789 4 2 35 325
1 2500 8 2 40 720
6. 线性回归测试
import java.io.IOException; import java.io.RandomAccessFile; public class Test { public static void main(String[] args) throws IOException { //read Sample.txt Sample[] sam = new Sample[25]; int w = 0; long filePoint = 0; String s; RandomAccessFile file = new RandomAccessFile("resource//LinearSample.txt", "r"); long fileLength = file.length(); while(filePoint < fileLength) { s = file.readLine(); //s --> sample String[] sub = s.split(" "); sam[w] = new Sample(sub.length - 1); for(int i = 0; i < sub.length; i++) { if(i == sub.length - 1) { sam[w].value = Double.parseDouble(sub[i]); } else { sam[w].features[i] = Double.parseDouble(sub[i]); } }//for w++; filePoint = file.getFilePointer(); }//while read file LinearRegression lr = new LinearRegression(); double[] para = {0,0,0,0,0}; double rate = 0.5; double th = 0.001; lr.Initialize(sam, w); lr.setPara(para, rate, th); lr.Update(); lr.OutputTheta(); } }
7. 使用的逻辑回归样本（每行依次为特征 x0–x2 和类别标签 class（0 或 1），其中 x0 恒为 1；测试程序从 resource/LogisticSample.txt 读取，文件中同样不要包含下面的表头行）
x0 x1 x2 class
1 0.23 0.35 0
1 0.32 0.24 0
1 0.6 0.12 0
1 0.36 0.54 0
1 0.02 0.89 0
1 0.36 -0.12 0
1 -0.45 0.62 0
1 0.56 0.42 0
1 0.4 0.56 0
1 0.46 0.51 0
1 1.2 0.32 1
1 0.6 0.9 1
1 0.32 0.98 1
1 0.2 1.3 1
1 0.15 1.36 1
1 0.54 0.98 1
1 1.36 1.05 1
1 0.22 1.65 1
1 1.65 1.54 1
1 0.25 1.68 1
8. 逻辑回归测试
import java.io.IOException; import java.io.RandomAccessFile; public class Test { public static void main(String[] args) throws IOException { //read Sample.txt Sample[] sam = new Sample[25]; int w = 0; long filePoint = 0; String s; RandomAccessFile file = new RandomAccessFile("resource//LogisticSample.txt", "r"); long fileLength = file.length(); while(filePoint < fileLength) { s = file.readLine(); //s --> sample String[] sub = s.split(" "); sam[w] = new Sample(sub.length - 1); for(int i = 0; i < sub.length; i++) { if(i == sub.length - 1) { sam[w].label = Integer.parseInt(sub[i]); } else { sam[w].features[i] = Double.parseDouble(sub[i]); } }//for //sam[w].outSample(); w++; filePoint = file.getFilePointer(); }//while read file LogisticRegression lr = new LogisticRegression(); double[] para = {0,0,0}; double rate = 0.5; double th = 0.001; lr.Initialize(sam, w); lr.setPara(para, rate, th); lr.Update(); lr.OutputTheta(); } }