Java OpenCV 人工智能01.3 机器学习 深度学习
Java OpenCV-4.X 人工智能01.3 机器学习 深度学习：1 ANN_MLP 人工神经网络-多层感知器；2 SVM 支持向量机；3 Java OpenCV 全部测试代码
·
Java OpenCV 人工智能01.3 机器学习 深度学习
OpenCV ML 模块包含的算法
序号 | 模块 | 中文 | 英文 |
---|---|---|---|
1 | ANN_MLP | 人工神经网络-多层感知器 | Artificial Neural Network - Multi-Layer Perceptron |
2 | DTrees | 决策树 | Decision Tree |
3 | EM | EM算法 | Expectation Maximization |
4 | KNearest | K-近邻算法 | K-Nearest Neighbour Classifier |
5 | LogisticRegression | 逻辑回归 | Logistic Regression |
6 | NormalBayesClassifier | 贝叶斯分类 | Normal Bayes Classifier |
7 | RTrees | 随机森林 | Random Forest |
8 | SVM | 支持向量机 | Support Vector Machines |
9 | SVMSGD | 随机梯度下降SVM分类器 | Stochastic Gradient Descent SVM Classifier |
1 ANN_MLP 人工神经网络-多层感知器
package com.xu.opencv;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Boost;
import org.opencv.ml.DTrees;
import org.opencv.ml.EM;
import org.opencv.ml.KNearest;
import org.opencv.ml.LogisticRegression;
import org.opencv.ml.Ml;
import org.opencv.ml.NormalBayesClassifier;
import org.opencv.ml.RTrees;
import org.opencv.ml.SVM;
import org.opencv.ml.SVMSGD;
import org.opencv.ml.TrainData;
/**
 * Trains an OpenCV {@code ANN_MLP} (multi-layer perceptron) to recognise handwritten digits
 * and saves the trained model to disk.
 *
 * <p>The training file is read line by line. Each non-blank line must have the form
 * {@code label->p0,p1,...}: the digit label (0-9), the separator {@code ->}, and at least
 * 28*28 comma-separated integer pixel values (only the first 28*28 are used).
 *
 * @author Administrator
 */
public class ML {

    /** Directory holding the training data and receiving the saved model. */
    private static final String DATA_DIR = "D:\\OneDrive\\桌面\\";

    /** Number of input features per sample: one per pixel of a 28x28 image. */
    private static final int FEATURES = 28 * 28;

    /** Number of output classes: the digits 0-9. */
    private static final int CLASSES = 10;

    /** Maximum number of samples loaded by {@link #read(String, int, int)}. */
    private static final int MAX_SAMPLES = 600;

    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /** Entry point: trains and saves the ANN_MLP model. */
    public static void main(String[] args) throws Exception {
        ANN_MLP();
    }

    /**
     * Artificial Neural Networks -- Multi-Layer Perceptron.
     *
     * <p>Builds a 784-512-256-10 network with the symmetric-sigmoid activation. It trains the
     * network with back-propagation on {@code train.txt} and saves it as {@code ann.xml}.
     *
     * @throws Exception if the training file cannot be read or parsed, or training fails
     */
    public static void ANN_MLP() throws Exception {
        ANN_MLP ann = ANN_MLP.create();
        // Input layer: one neuron per pixel (28*28).
        // Hidden layers: two, with 512 and 256 neurons.
        // Output layer: one neuron per digit 0-9, matching the one-hot labels built by read().
        int[] layer = {FEATURES, 512, 256, CLASSES};
        // OpenCV copies the layer sizes into an integer vector, so the Mat must be CV_32S;
        // a CV_32F Mat does not match that element type.
        Mat layerSizes = new Mat(1, layer.length, CvType.CV_32SC1);
        layerSizes.put(0, 0, layer);
        ann.setLayerSizes(layerSizes);
        ann.setBackpropWeightScale(0.1);
        ann.setBackpropMomentumScale(0.1);
        ann.setTrainMethod(ANN_MLP.BACKPROP);
        ann.setActivationFunction(ANN_MLP.SIGMOID_SYM);
        ann.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, 100000, 0.000001));
        Map<String, Mat> map = read(DATA_DIR + "train.txt", CvType.CV_32FC1, CvType.CV_32FC1);
        TrainData data = TrainData.create(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
        ensureTrained(ann.train(data), "ANN_MLP");
        ann.save(DATA_DIR + "ann.xml");
        ann.clear();
    }

    /**
     * Reads up to {@value #MAX_SAMPLES} digit samples from {@code path}; blank lines are skipped.
     *
     * <p>The returned map holds two matrices with one row per sample actually read:
     * <ul>
     *   <li>{@code "train"}: {@value #FEATURES} pixel columns of element type {@code train_type};
     *   <li>{@code "label"}: {@value #CLASSES} columns of element type {@code label_type}. Each row
     *       is one-hot: 1 in the column of the sample's digit, 0 everywhere else.
     * </ul>
     *
     * @param path       text file with one {@code label->pixels} sample per line
     * @param train_type OpenCV element type of the sample matrix, e.g. {@code CvType.CV_32FC1}
     * @param label_type OpenCV element type of the label matrix, e.g. {@code CvType.CV_32FC1}
     * @return map with the {@code "train"} and {@code "label"} matrices
     * @throws java.io.IOException      if the file cannot be read
     * @throws IllegalArgumentException if a line is malformed, a label is outside 0-9,
     *                                  or the file contains no samples
     */
    public static Map<String, Mat> read(String path, int train_type, int label_type) throws Exception {
        // A freshly constructed Mat is uninitialised; every row kept below is written in full,
        // and the unused tail is cut off with rowRange.
        Mat train = new Mat(MAX_SAMPLES, FEATURES, train_type);
        Mat label = new Mat(MAX_SAMPLES, CLASSES, label_type);
        int index = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            int lineNumber = 0;
            while (index < MAX_SAMPLES && (line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] parts = splitSample(line, lineNumber);
                int digit = Integer.parseInt(parts[0].trim());
                if (digit < 0 || digit >= CLASSES) {
                    throw new IllegalArgumentException("Line " + lineNumber + ": label " + digit
                            + " is not a digit 0-" + (CLASSES - 1));
                }
                // One JNI call per row instead of one per pixel.
                train.put(index, 0, parsePixels(parts[1], lineNumber));
                double[] oneHot = new double[CLASSES];
                oneHot[digit] = 1.0;
                label.put(index, 0, oneHot);
                index++;
            }
        }
        if (index == 0) {
            throw new IllegalArgumentException("No samples found in " + path);
        }
        Map<String, Mat> map = new HashMap<>(4);
        map.put("train", train.rowRange(0, index));
        map.put("label", label.rowRange(0, index));
        return map;
    }

    /** Splits a {@code label->pixels} line into its two halves, rejecting malformed lines. */
    private static String[] splitSample(String line, int lineNumber) {
        String[] parts = line.split("->");
        if (parts.length != 2) {
            throw new IllegalArgumentException(
                    "Line " + lineNumber + ": expected 'label->pixels' but got: " + line);
        }
        return parts;
    }

    /** Parses the first {@value #FEATURES} comma-separated integer pixel values of a sample. */
    private static double[] parsePixels(String csv, int lineNumber) {
        double[] pixels = Arrays.stream(csv.split(","))
                .limit(FEATURES)
                .mapToDouble(value -> Integer.parseInt(value.trim()))
                .toArray();
        if (pixels.length != FEATURES) {
            throw new IllegalArgumentException("Line " + lineNumber + ": expected " + FEATURES
                    + " pixel values but found " + pixels.length);
        }
        return pixels;
    }

    /** Fails fast when OpenCV reports that training did not succeed, instead of saving an untrained model. */
    private static void ensureTrained(boolean trained, String model) {
        if (!trained) {
            throw new IllegalStateException(model + " training failed");
        }
    }
}
2 SVM 支持向量机
package com.xu.opencv;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Boost;
import org.opencv.ml.DTrees;
import org.opencv.ml.EM;
import org.opencv.ml.KNearest;
import org.opencv.ml.LogisticRegression;
import org.opencv.ml.Ml;
import org.opencv.ml.NormalBayesClassifier;
import org.opencv.ml.RTrees;
import org.opencv.ml.SVM;
import org.opencv.ml.SVMSGD;
import org.opencv.ml.TrainData;
/**
 * Trains an OpenCV linear {@code SVM} on digit samples and saves the trained model to disk.
 *
 * <p>The data file is read line by line. Each non-blank line must have the form
 * {@code label->p0,p1,...}: an integer class label, the separator {@code ->}, and at least
 * 28*28 comma-separated integer pixel values (only the first 28*28 are used).
 *
 * @author Administrator
 */
public class ML {

    /** Directory holding the data file and receiving the saved model. */
    private static final String DATA_DIR = "D:\\OneDrive\\桌面\\";

    /** Number of input features per sample: one per pixel of a 28x28 image. */
    private static final int FEATURES = 28 * 28;

    /** Maximum number of samples loaded by {@link #read(String)}. */
    private static final int MAX_SAMPLES = 60000;

    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    /** Entry point: trains and saves the SVM model. */
    public static void main(String[] args) throws Exception {
        SVM();
    }

    /**
     * Support Vector Machines.
     *
     * <p>Trains a C-support-vector classifier with a linear kernel on {@code data.txt}
     * and saves it as {@code svm.xml}.
     *
     * @throws Exception if the data file cannot be read or parsed, or training fails
     */
    public static void SVM() throws Exception {
        SVM svm = SVM.create();
        svm.setC(1);
        svm.setP(0);
        svm.setNu(0);
        svm.setCoef0(0);
        svm.setGamma(1);
        svm.setDegree(0);
        // Per the OpenCV SVM docs, P / Nu / Coef0 / Gamma / Degree only apply to other SVM
        // types or kernels; they are set explicitly here but do not affect C_SVC + LINEAR.
        svm.setType(SVM.C_SVC);
        svm.setKernel(SVM.LINEAR);
        svm.setTermCriteria(new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 1000, 0));
        Map<String, Mat> map = read(DATA_DIR + "data.txt");
        ensureTrained(svm.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label")), "SVM");
        svm.save(DATA_DIR + "svm.xml");
    }

    /**
     * Reads up to {@value #MAX_SAMPLES} samples from {@code path}; blank lines are skipped.
     *
     * <p>The returned map holds two matrices with one row per sample actually read:
     * <ul>
     *   <li>{@code "train"}: CV_32FC1, {@value #FEATURES} pixel columns;
     *   <li>{@code "label"}: CV_32SC1, a single column holding each sample's integer label.
     * </ul>
     *
     * @param path text file with one {@code label->pixels} sample per line
     * @return map with the {@code "train"} and {@code "label"} matrices
     * @throws java.io.IOException      if the file cannot be read
     * @throws IllegalArgumentException if a line is malformed or the file contains no samples
     */
    public static Map<String, Mat> read(String path) throws Exception {
        // A freshly constructed Mat is uninitialised; every row kept below is written in full,
        // and the unused tail is cut off with rowRange.
        Mat train = new Mat(MAX_SAMPLES, FEATURES, CvType.CV_32FC1);
        // Column vector: put(index, 0, ...) must address sample `index`. With a 1 x N layout,
        // row `index` does not exist for index > 0 and the write is silently dropped.
        Mat label = new Mat(MAX_SAMPLES, 1, CvType.CV_32SC1);
        int index = 0;
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            int lineNumber = 0;
            while (index < MAX_SAMPLES && (line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] parts = splitSample(line, lineNumber);
                // One JNI call per row instead of one per pixel.
                train.put(index, 0, parsePixels(parts[1], lineNumber));
                label.put(index, 0, Integer.parseInt(parts[0].trim()));
                index++;
            }
        }
        if (index == 0) {
            throw new IllegalArgumentException("No samples found in " + path);
        }
        Map<String, Mat> map = new HashMap<>(4);
        map.put("train", train.rowRange(0, index));
        map.put("label", label.rowRange(0, index));
        return map;
    }

    /** Splits a {@code label->pixels} line into its two halves, rejecting malformed lines. */
    private static String[] splitSample(String line, int lineNumber) {
        String[] parts = line.split("->");
        if (parts.length != 2) {
            throw new IllegalArgumentException(
                    "Line " + lineNumber + ": expected 'label->pixels' but got: " + line);
        }
        return parts;
    }

    /** Parses the first {@value #FEATURES} comma-separated integer pixel values of a sample. */
    private static double[] parsePixels(String csv, int lineNumber) {
        double[] pixels = Arrays.stream(csv.split(","))
                .limit(FEATURES)
                .mapToDouble(value -> Integer.parseInt(value.trim()))
                .toArray();
        if (pixels.length != FEATURES) {
            throw new IllegalArgumentException("Line " + lineNumber + ": expected " + FEATURES
                    + " pixel values but found " + pixels.length);
        }
        return pixels;
    }

    /** Fails fast when OpenCV reports that training did not succeed, instead of saving an untrained model. */
    private static void ensureTrained(boolean trained, String model) {
        if (!trained) {
            throw new IllegalStateException(model + " training failed");
        }
    }
}
3 Java OpenCV 全部测试代码
package com.xu.opencv;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;
import org.opencv.ml.Boost;
import org.opencv.ml.DTrees;
import org.opencv.ml.EM;
import org.opencv.ml.KNearest;
import org.opencv.ml.LogisticRegression;
import org.opencv.ml.Ml;
import org.opencv.ml.NormalBayesClassifier;
import org.opencv.ml.RTrees;
import org.opencv.ml.SVM;
import org.opencv.ml.SVMSGD;
import org.opencv.ml.TrainData;
/**
* @author Administrator
*/
public class ML {
static {
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
public static void main(String[] args) throws Exception {
ANN_MLP();
}
/**
* ANN_MLP 人工神经网络-多层感知器
*
* @throws Exception
*/
public static void ANN_MLP() throws Exception {
ANN_MLP ann = ANN_MLP.create();
/**
* 输入层:对应着每个图片的像素,所以是28*28
* 隐含层:两个,神经元个数分别为 512 和 256
* 输出层:和训练的标签对应,神经元为10个,即手写数字 0123456789
*/
int[] layer = {28 * 28, 512, 256, 10};
Mat layerSizes = new Mat(1, layer.length, CvType.CV_32FC1);
for (int i = 0; i < layer.length; i++) {
layerSizes.put(0, i, layer[i]);
}
ann.setLayerSizes(layerSizes);
ann.setBackpropWeightScale(0.1);
ann.setBackpropMomentumScale(0.1);
ann.setTrainMethod(ANN_MLP.BACKPROP);
ann.setActivationFunction(ANN_MLP.SIGMOID_SYM);
ann.setTermCriteria(new TermCriteria(TermCriteria.MAX_ITER, 100000, 0.000001));
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\train.txt", CvType.CV_32FC1, CvType.CV_32FC1);
TrainData data = TrainData.create(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
ann.train(data);
ann.save("D:\\OneDrive\\桌面\\ann.xml");
Mat response = new Mat();
ann.predict(map.get("train"), response);
for (int i = 0; i < response.height(); i++) {
System.out.println(response.get(i, 0)[0]);
}
ann.clear();
}
public static Map<String, Mat> read(String path, int train_type, int label_type) throws Exception {
BufferedReader reader = new BufferedReader(new FileReader(path));
Mat train = new Mat(600, 28 * 28, train_type);
Mat label = new Mat(600, 10, label_type);
Map<String, Mat> map = new HashMap<>(2);
int index = 0;
String line = null;
while ((line = reader.readLine()) != null) {
int[] data = Arrays.asList(line.split("->")[1].split(",")).stream().mapToInt(Integer::parseInt).toArray();
for (int i = 0; i < 28 * 28; i++) {
train.put(index, i, (float) data[i]);
}
int col = Integer.parseInt(line.split("->")[0]);
label.put(index, col, (float) col);
index++;
if (index >= 600) {
break;
}
}
map.put("train", train);
map.put("label", label);
reader.close();
return map;
}
/**
* 决策树 Decision Tree
*
* @throws Exception
*/
public static void DTrees() throws Exception {
DTrees tree = DTrees.create();
tree.setMaxCategories(10); // 设置分类数为 0-9
tree.setMaxDepth(10);
tree.setMinSampleCount(10);
tree.setCVFolds(0);
tree.setUseSurrogates(false);
tree.setUse1SERule(false);
tree.setTruncatePrunedTree(false);
tree.setRegressionAccuracy(0);
tree.setPriors(new Mat());
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\train.txt");
tree.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
tree.save("D:\\OneDrive\\桌面\\DTrees.xml");
}
/**
* EM算法 Expectation - Maximization
*
* @throws Exception
*/
public static void EM() throws Exception {
EM em = EM.create();
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\data.txt");
em.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
em.save("D:\\OneDrive\\桌面\\em.xml");
}
/**
* K-邻近算法 K-Nearest Neighbour Classifier
*
* @throws Exception
*/
public static void KNN() throws Exception {
KNearest knn = KNearest.create();
knn.setDefaultK(5);
knn.setIsClassifier(true);
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\train.txt");
TrainData data = TrainData.create(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
knn.train(data);
knn.save("D:\\OneDrive\\桌面\\knn.xml");
}
/**
* 逻辑回归 logistic regression
*
* @throws Exception
*/
public static void LogisticRegression() throws Exception {
LogisticRegression logic = LogisticRegression.create();
logic.setLearningRate(0.00001);
logic.setIterations(100);
logic.setRegularization(LogisticRegression.REG_DISABLE);
logic.setTrainMethod(LogisticRegression.MINI_BATCH);
logic.setMiniBatchSize(1);
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\train.txt");
logic.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
logic.save("D:\\OneDrive\\桌面\\logic.xml");
}
/**
* 贝叶斯分类 Normal Bayes Classifier
*
* @throws Exception
*/
public static void NormalBayesClassifier() throws Exception {
NormalBayesClassifier nbc = NormalBayesClassifier.create();
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\data.txt");
nbc.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
nbc.save("D:\\OneDrive\\桌面\\nbc.xml");
}
/**
* 随机森林 Random Forest
*
* @throws Exception
*/
public static void RTrees() throws Exception {
RTrees tree = RTrees.create();
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\data.txt");
tree.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
tree.save("D:\\OneDrive\\桌面\\RTrees.xml");
}
/**
* 支持向量机 Support Vector Machines
*
* @throws Exception
*/
public static void SVM() throws Exception {
SVM svm = SVM.create();
svm.setC(1);
svm.setP(0);
svm.setNu(0);
svm.setCoef0(0);
svm.setGamma(1);
svm.setDegree(0);
svm.setType(SVM.C_SVC);
svm.setKernel(SVM.LINEAR);
svm.setTermCriteria(new TermCriteria(TermCriteria.EPS + TermCriteria.MAX_ITER, 1000, 0));
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\data.txt");
svm.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
svm.save("D:\\OneDrive\\桌面\\svm.xml");
}
/**
* SVMSGD 随机梯度下降SVM分类器
*
* @throws Exception
*/
public static void SVMSGD() throws Exception {
SVMSGD svmsgd = SVMSGD.create();
Map<String, Mat> map = read("D:\\OneDrive\\桌面\\data.txt");
svmsgd.train(map.get("train"), Ml.ROW_SAMPLE, map.get("label"));
svmsgd.save("D:\\OneDrive\\桌面\\svmsgd.xml");
}
public static Map<String, Mat> read(String path) throws Exception {
BufferedReader reader = new BufferedReader(new FileReader(path));
String line = null;
Mat train = new Mat(60000, 28 * 28, CvType.CV_32FC1);
Mat label = new Mat(1, 60000, CvType.CV_32SC1);
Map<String, Mat> map = new HashMap<>(2);
int index = 0;
while ((line = reader.readLine()) != null) {
int[] data = Arrays.asList(line.split("->")[1].split(",")).stream().mapToInt(Integer::parseInt).toArray();
for (int i = 0; i < 28 * 28; i++) {
train.put(index, i, data[i]);
}
label.put(index, 0, Integer.parseInt(line.split("->")[0]));
index++;
}
map.put("train", train);
map.put("label", label);
reader.close();
return map;
}
}
更多推荐
所有评论(0)