基于weka手工实现逻辑斯谛回归(Logistic回归)

news/2024/7/3 14:56:03

一、logistic回归模型

逻辑斯谛回归模型其实是一种分类模型,这里实现的是参考李航的《统计机器学习》以及周志华的《机器学习》两本教材来整理实现的。

假定我们的输入为 x x x x x x 可以是多个维度的,我们想要根据 x x x 去预测 y y y y ∈ { 0 , 1 } y\in \{0,1\} y{0,1}。逻辑斯谛的模型如下:

p ( Y = 1 ∣ x ) = e x p ( w ⋅ x ) 1 + e x p ( w ⋅ x ) (1) p(Y=1|x)=\frac{exp(w\cdot x)}{1+exp(w\cdot x)}\tag{1} p(Y=1∣x)=1+exp(wx)exp(wx)(1)

其中的参数 w w w就是我们要进行学习的,注意:它是包含了权重系数和偏置(bias)b的。在书写程序时,这样表示更加简洁。

二、极大似然法参数估计

参数 w w w是我们需要学习的,我们采用极大似然法估计模型参数。

设:

P ( Y = 1 ∣ x ) = π ( x ) , P ( Y = 0 ∣ x ) = 1 − π ( x ) (2) P(Y=1|x)=\pi(x),\quad P(Y=0|x)=1-\pi(x)\tag{2} P(Y=1∣x)=π(x),P(Y=0∣x)=1π(x)(2)

似然函数为:

∏ i = 1 N [ π ( x i ) ] y i [ 1 − π ( x i ) ] 1 − y i (3) \prod_{i=1}^N[\pi(x_i)]^{y_i}[1-\pi(x_i)]^{1-y_i} \tag{3} i=1N[π(xi)]yi[1π(xi)]1yi(3)

因为这种指数的形式不利于求导我们需要将它们转化为对数的形式,如下:

L ( w ) = ∑ i = 1 N [ y i l o g π ( x i ) + ( 1 − y i ) l o g ( 1 − π ( x i ) ) ] = ∑ i = 1 N [ y i l o g ( π ( x i ) 1 − π ( x i ) ) + l o g ( 1 − π ( x i ) ) ] = ∑ i = 1 N [ y i ( w ⋅ x i ) − l o g ( 1 + e x p ( w ⋅ x i ) ) ] (4) \begin{aligned} L(w)=&\sum_{i=1}^N[y_ilog\pi(x_i)+(1-y_i)log(1-\pi(x_i))] \\ =&\sum_{i=1}^N [y_ilog(\frac{\pi(x_i)}{1-\pi(x_i)})+log(1-\pi(x_i))]\\ =&\sum_{i=1}^{N}[y_i(w\cdot x_i)-log(1+exp(w\cdot x_i))] \end{aligned} \tag{4} L(w)===i=1N[yilogπ(xi)+(1yi)log(1π(xi))]i=1N[yilog(1π(xi)π(xi))+log(1π(xi))]i=1N[yi(wxi)log(1+exp(wxi))](4)

L ( w ) L(w) L(w)求极大值,得到 w w w的估计值。

三、梯度下降法求解似然函数

梯度下降法是求极小值的,而我们想要得到的是 L ( w ) L(w) L(w)的最大值,因此,我们取 L ( w ) L(w) L(w)的相反数,即:

arg min ⁡ w − L ( w ) (5) \argmin_{w}-L(w) \tag{5} wargminL(w)(5)

L ( w ) L(w) L(w)关于 w w w求导,如下:

( − L ( w ) ) ′ = − ∑ i = 1 N [ ( y i ⋅ x i ) − e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) ⋅ x i ] = − ∑ i = 1 N [ ( y i − e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) ) ⋅ x i ] = ∑ i = 1 N [ ( e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) − y i ) ⋅ x i ] (6) \begin{aligned} (-L(w))'=&-\sum_{i=1}^N[(y_i\cdot x_i)-\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)}\cdot x_i]\\ =&-\sum_{i=1}^N[(y_i-\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)})\cdot x_i]\\ =&\sum_{i=1}^N[(\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)}-y_i)\cdot x_i] \end{aligned} \tag{6} (L(w))===i=1N[(yixi)1+exp(wx)exp(wxi)xi]i=1N[(yi1+exp(wx)exp(wxi))xi]i=1N[(1+exp(wx)exp(wxi)yi)xi](6)

然后我们就得到了参数 w w w的更新公式,如下:

w ′ = w − l r ⋅ ( − L ( w ) ′ ) = w − l r ⋅ ( ∑ i = 1 N [ ( e x p ( w ⋅ x i ) 1 + e x p ( w ⋅ x ) − y i ) ⋅ x i ] ) (7) \begin{aligned} w'=&w-lr\cdot(-L(w)')\\ =&w-lr\cdot(\sum_{i=1}^N[(\frac{exp(w\cdot x_i)}{1+exp(w\cdot x)}-y_i)\cdot x_i]) \end{aligned} \tag{7} w==wlr(L(w))wlr(i=1N[(1+exp(wx)exp(wxi)yi)xi])(7)

关于优化方法的选择,最开始是选择西瓜书上提供的牛顿法来实现的,牛顿法的好处是,可以获得较快的收敛速度,但是坏处是,当海森矩阵为奇异矩阵时,会出现无法求解的情况。

因此,可以采用拟牛顿法进行优化,在解决这个问题的同时,也可以很快的收敛。

但是,自己对拟牛顿法并不熟悉,而梯度下降法虽然收敛可能较慢,但是实现起来较为简单,因此这里采用了梯度下降法来优化似然函数。

四、基于weka的代码实现

package weka.classifiers.myf;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Standardize;

import java.util.Arrays;

/**
 * @author YFMan
 * @Description 自定义的 Logistic 回归分类器
 * @Date 2023/6/13 11:02
 */
public class myLogistic extends Classifier {
    // 用于存储 线性回归 系数 的数组
    private double[] m_Coefficients;

    // 类别索引
    private int m_ClassIndex;

    // 牛顿法的迭代次数
    private int m_MaxIterations = 1000;

    // 属性数量
    private int m_numAttributes;

    // 系数数量
    private int m_numCoefficients;

    // 梯度下降步长
    private double m_lr = 1e-4;

    // 标准化数据的过滤器
    public static final int FILTER_STANDARDIZE = 1;

    // 用于标准化数据的过滤器
    protected Filter m_StandardizeFilter = null;

    // 用于将 normal 转为 binary 的过滤器
    protected Filter m_NormalToBinaryFilter = null;


    /*
     * @Author YFMan
     * @Description 采用牛顿法来训练 logistic 回归模型
     * @Date 2023/5/9 22:08
     * @Param [data] 训练数据
     * @return void
     **/
    public void buildClassifier(Instances data) throws Exception {
        // 设置类别索引
        m_ClassIndex = data.classIndex();

        // 设置属性数量
        m_numAttributes = data.numAttributes();

        // 系数数量 = 输入属性数量 + 1(截距参数b)
        m_numCoefficients = m_numAttributes;

        // 初始化 系数数组
        m_Coefficients = new double[m_numCoefficients];
        Arrays.fill(m_Coefficients, 0);

        // 将输入数据进行标准化
        m_StandardizeFilter = new Standardize();
        m_StandardizeFilter.setInputFormat(data);
        data = Filter.useFilter(data, m_StandardizeFilter);

        // 将类别属性转为二值属性
        m_NormalToBinaryFilter = new NominalToBinary();
        m_NormalToBinaryFilter.setInputFormat(data);
        data = Filter.useFilter(data, m_NormalToBinaryFilter);

        // 梯度下降法
        for(int curPerformIteration = 0; curPerformIteration < m_MaxIterations;curPerformIteration++){

            double[] deltaM_Coefficients = new double[m_numCoefficients];
            // 计算 l(w) 的一阶导数
            for(int i = 0;i<data.numInstances();i++){

                double yi = data.instance(i).value(m_ClassIndex);
                double wxi = 0;
                int column = 0;
                for(int j=0;j<m_numAttributes;j++){
                    if(j!=m_ClassIndex){
                        wxi += m_Coefficients[column] * data.instance(i).value(j);
                        column++;
                    }
                }
                // 加上截距参数 b
                wxi += m_Coefficients[column];
                double pi1 = Math.exp(wxi) / (1 + Math.exp(wxi));
                for(int k=0;k<m_numCoefficients - 1;k++){
                    deltaM_Coefficients[k] += m_lr * (pi1 - yi) * data.instance(i).value(k);
                }
                // 这里计算 bias b 对应的更新量
                deltaM_Coefficients[m_numCoefficients - 1] += m_lr * (pi1 - yi);
            }

            // 进行参数更新
            for(int k=0;k<m_numCoefficients;k++){
                m_Coefficients[k] -= deltaM_Coefficients[k];
            }

            // 如果参数更新量小于阈值,则停止迭代
            double delta = 0;
            for(int k=0;k<m_numCoefficients;k++){
                delta += deltaM_Coefficients[k] * deltaM_Coefficients[k];
            }
            if(delta < 1e-6){
                break;
            }

        }
    }


    /*
     * @Author YFMan
     * @Description // 分类实例
     * @Date 2023/6/16 11:17
     * @Param [instance]
     * @return double[]
     **/
    public double[] distributionForInstance(Instance instance) throws Exception {

        // 将输入数据进行标准化
        m_StandardizeFilter.input(instance);
        instance = m_StandardizeFilter.output();

        // 将输入属性二值化
        m_NormalToBinaryFilter.input(instance);
        instance = m_NormalToBinaryFilter.output();

        double[] result = new double[2];
        result[0] = 0;
        result[1] = 0;
        int column = 0;
        for(int i=0;i<m_numAttributes;i++){
            if(m_ClassIndex != i){
                result[0] += instance.value(i) * m_Coefficients[column];
                column++;
            }
        }
        result[0] += m_Coefficients[column];
        result[0] = 1 / (1 + Math.exp(result[0]));

        result[1] = 1 - result[0];

        return result;
    }

    /*
     * @Author YFMan
     * @Description 主函数 生成一个线性回归函数预测器
     * @Date 2023/5/9 22:35
     * @Param [argv]
     * @return void
     **/
    public static void main(String[] argv) {
        runClassifier(new myLogistic(), argv);
    }
}

http://www.niftyadmin.cn/n/4260300.html

相关文章

美阿拉斯加州发生5.4级地震 暂无人员伤亡及灾情报告

中新网1月14日电 据美国地质勘探局(U.S. Geological Survey)发布消息称&#xff0c;当地时间13日&#xff0c;美国阿拉斯加州安克雷奇(Anchorage)附近地区发生5.4级地震&#xff0c;目前尚未传出伤亡或其他灾情。资料图&#xff1a;当地时间2018年11月30日&#xff0c;美国阿拉…

实验十五、帧中继交换机的配置

实验十五、帧中继交换机的配置 一、 实验目的 1. 掌握FRAM-RELAY SWITCH 的配置 2. 理解DLCI、LMI 等概念 二、 应用环境 假设在银行系统里&#xff0c;总行和各分理处需要进行通讯&#xff0c;而分理处之间不需要通讯&#xff0c;帧中继是最 好的选择 三、 实验设备 1. DCR-17…

[HNOI2016]最小公倍数

题目描述 给定一张N个顶点M条边的无向图(顶点编号为1,2,…,n)&#xff0c;每条边上带有权值。所有权值都可以分解成2^a*3^b的形式。现在有q个询问&#xff0c;每次询问给定四个参数u、v、a和b&#xff0c;请你求出是否存在一条顶点u到v之间的路径&#xff0c;使得路径依次经过的…

再续 asp.net 域名欺骗式开发之泛解析域名

前言&#xff1a; 在很久前&#xff0c;曾发布过一篇&#xff1a;asp.net 域名欺骗式开发有不少新新人类表示对此文不屑&#xff0c;觉得太基础&#xff0c;他们早懂了&#xff0c;懂了就懂了&#xff0c;毕竟还有人还没有懂的。今天再续文&#xff0c;讲解域名欺骗式开发的进阶…

Dave Python 练习十五 -- 面向对象编程

#encodingutf-8 ### *************** 面向对象编程 ******************** #*********** Part 1: 面向对象编程 *********************** #面向对象编程踩上了进化的步伐&#xff0c;增强了结构化编程&#xff0c;实现了数据与动作的融合&#xff1a;数据层和逻 #辑层现在由一个…

机器学习入门-数据过采样(上采样)1. SMOTE

from imblearn.over_sampling import SMOTE # 导入 overstamp SMOTE(random_state0) # 对训练集的数据进行上采样&#xff0c;测试集的数据不需要SMOTE_train_x, SMOTE_train_y overstamp.fit_sample(train_x, train_y) 由于数据分布的不均衡&#xff0c;因此对数据进行上采…

全民娱乐 手机电视将成为3G手机最主要应用

全民娱乐手机电视将成为3G手机最主要应用<?xml:namespace prefix o ns "urn:schemas-microsoft-com:office:office" />究竟什么才是3G时代的主要应用&#xff1f;对于这个问题一直也是仁者见仁智者见智&#xff0c;大家心中都有自己的一杆秤。自己的喜好自然…

数据结构的定义和简介

1. 概述数据结构定义:我们如何把现实中大量而复杂的问题以特定的数据类型和特定的存储结构保存到主存储器(内存)中,以及在此基础上为实现某个功能(如元素的CURD、排序等)而执行的相应操作&#xff0c;这个相应的操作也叫算法。数据结构 元素 元素的关系算法 对数据结构的操作…