C# BP神经网络 类与实例(二) 完善版

2013年01月27日 14:18:05 苏内容
  标签: C#/BP/神经网络
阅读:7328

在<<C# BP神经网络 类与实例>>这篇文章中,笔者转载了网友提供的一个类。非常感谢这位网友的无私奉献,使笔者得以快速地完成了软件的一个功能。

貌似我找到的是网友对原作改动过的版本,使得该类存在着一些不足,笔者对其进行了相应改进。

1. 未保存训练样本的结果,使得每次使用时都要重新训练数据,耗费时间:

虽然类中已经提供了保存w、v、b1、b2数组的方法,但其构造函数却有问题,同时没有保存in_rate inNum HideNum outNum 这些下面会用到的数据。

对此笔者进行了以下改进:

新建了一个无参的构造函数用来预测新样本时实例化对象;

提供saveParas、readParas方法用于保存和读取这些参数;

提供initial方法,用来建立预测过程中用到的动态数组。

 

2. 类中用伪随机数来初始化w、v数组,使得每次建立新对象时w、v会不同。笔者对该算法没有深入研究,不清楚为什么这样设计。

笔者把

R = new Random(); 

这行代码进行了修改

R = new Random(32);  //加了一个参数,使产生的伪随机序列相同

下面是我修改后的类和示例:

 

using System;
using System.IO;
using System.Text;
namespace BpANNet
{
    /// <summary>
    /// A three-layer (input / hidden / output) feed-forward network with logistic-sigmoid
    /// units, trained sample-by-sample with back-propagation.
    /// </summary>
    /// <remarks>
    /// Two ways to obtain a usable instance:
    /// 1. Training: <c>new BpNet(p, t)</c>, then call <c>train(p, t)</c> repeatedly until
    ///    <c>e</c> is small enough, then persist with <c>saveMatrix</c> / <c>saveParas</c>.
    /// 2. Prediction: <c>new BpNet()</c>, <c>readParas</c>, <c>initial</c>, then
    ///    <c>readMatrixW</c> / <c>readMatrixB</c> for w, v, b1, b2, and finally <c>sim</c>.
    /// All numbers are written to and parsed from files with the invariant culture so that a
    /// file saved on one machine can be loaded on another regardless of regional settings.
    /// </remarks>
    public class BpNet
    {
        // Culture used for every number that goes to or comes from a parameter file.
        private static readonly System.Globalization.CultureInfo fileCulture =
            System.Globalization.CultureInfo.InvariantCulture;

        // Separators accepted between numbers when reading a parameter file.
        private static readonly char[] tokenSeparators = new char[] { ' ', '\t' };

        public int inNum;      // number of input nodes
        int hideNum;           // number of hidden nodes
        public int outNum;     // number of output nodes
        public int sampleNum;  // number of training samples (rows of p seen by the constructor)

        Random R;              // seeded generator used to initialise w and v
        double[] x;            // normalised input vector
        double[] x1;           // hidden-layer outputs (after sigmoid)
        double[] x2;           // output-layer outputs (after sigmoid)

        double[] o1;           // hidden-layer weighted sums (before bias and sigmoid)
        double[] o2;           // output-layer weighted sums (before bias and sigmoid)
        public double[,] w;    // input -> hidden weights, indexed [input, hidden]
        public double[,] v;    // hidden -> output weights, indexed [hidden, output]
        public double[,] dw;   // same shape as w; allocated but not used by train()
        public double[,] dv;   // same shape as v; allocated but not used by train()

        public double rate;    // learning rate
        public double[] b1;    // hidden-layer biases (added to o1 before the sigmoid)
        public double[] b2;    // output-layer biases (added to o2 before the sigmoid)
        public double[] db1;   // same shape as b1; allocated but not used by train()
        public double[] db2;   // same shape as b2; allocated but not used by train()

        double[] pp;           // hidden-layer error terms (deltas)
        double[] qq;           // output-layer error terms (deltas)
        double[] yd;           // normalised target ("teacher") vector
        public double e;       // square root of the summed squared output error of the last train() pass
        double in_rate;        // normalisation divisor: largest absolute value found in p and t

        /// <summary>
        /// Empirical rule for the hidden-layer size given m inputs and n outputs.
        /// </summary>
        /// <param name="m">Number of input nodes.</param>
        /// <param name="n">Number of output nodes.</param>
        /// <returns>The heuristic value rounded to the nearest integer.</returns>
        public int computeHideNum(int m, int n)
        {
            double s = Math.Sqrt(0.43 * m * n + 0.12 * n * n + 2.54 * m + 0.77 * n + 0.35) + 0.51;

            // Convert.ToInt32 already rounds to the nearest integer, so the fractional
            // remainder can never exceed 0.5; the former "round up" branch was unreachable.
            return Convert.ToInt32(s);
        }

        /// <summary>
        /// Creates a network sized for the given training data and initialises the weights
        /// from a fixed-seed random sequence (so every run starts from the same weights).
        /// </summary>
        /// <param name="p">Training inputs, one sample per row, one input node per column.</param>
        /// <param name="t">Training targets, one sample per row, one output node per column.</param>
        /// <exception cref="ArgumentNullException">p or t is null.</exception>
        public BpNet(double[,] p, double[,] t)
        {
            if (p == null)
            {
                throw new ArgumentNullException("p");
            }
            if (t == null)
            {
                throw new ArgumentNullException("t");
            }

            // Fixed seed: the pseudo-random weight initialisation is reproducible.
            R = new Random(32);

            inNum = p.GetLength(1);
            outNum = t.GetLength(1);
            hideNum = computeHideNum(inNum, outNum);
            sampleNum = p.GetLength(0);

            // Diagnostic output only. The constructor no longer waits for a key press:
            // blocking on console input made the class unusable from non-interactive code.
            Console.WriteLine("输入节点数目: " + inNum);
            Console.WriteLine("隐层节点数目:" + hideNum);
            Console.WriteLine("输出层节点数目:" + outNum);

            initial();

            // Weights start uniformly in [-0.5, 0.5); w is drawn before v so the random
            // sequence is consumed in a fixed order.
            for (int i = 0; i < inNum; i++)
            {
                for (int j = 0; j < hideNum; j++)
                {
                    w[i, j] = (R.NextDouble() * 2 - 1.0) / 2;
                }
            }
            for (int j = 0; j < hideNum; j++)
            {
                for (int k = 0; k < outNum; k++)
                {
                    v[j, k] = (R.NextDouble() * 2 - 1.0) / 2;
                }
            }

            rate = 0.8;
            e = 0.0;
            in_rate = 1.0;
        }

        /// <summary>
        /// Runs one training pass (epoch) over the first <c>sampleNum</c> rows of p and t,
        /// updating w, v, b1 and b2 after every sample, and stores the resulting error in
        /// <c>e</c>.
        /// </summary>
        /// <param name="p">Training inputs; needs at least sampleNum rows and inNum columns.</param>
        /// <param name="t">Training targets; needs at least sampleNum rows and outNum columns.</param>
        /// <exception cref="ArgumentNullException">p or t is null.</exception>
        /// <exception cref="ArgumentException">p or t is smaller than the network expects.</exception>
        /// <exception cref="InvalidOperationException">The working arrays were never allocated.</exception>
        public void train(double[,] p, double[,] t)
        {
            if (p == null)
            {
                throw new ArgumentNullException("p");
            }
            if (t == null)
            {
                throw new ArgumentNullException("t");
            }
            if (sampleNum > 0)
            {
                // Fail before touching any weight instead of part-way through the pass.
                requireBuffers();
                if (p.GetLength(0) < sampleNum || p.GetLength(1) < inNum)
                {
                    throw new ArgumentException("p must have at least sampleNum rows and inNum columns.", "p");
                }
                if (t.GetLength(0) < sampleNum || t.GetLength(1) < outNum)
                {
                    throw new ArgumentException("t must have at least sampleNum rows and outNum columns.", "t");
                }
            }

            e = 0.0;

            // Largest absolute value over inputs and targets; used as the common divisor.
            double pMax = 0.0;
            for (int isamp = 0; isamp < sampleNum; isamp++)
            {
                for (int i = 0; i < inNum; i++)
                {
                    pMax = Math.Max(pMax, Math.Abs(p[isamp, i]));
                }
                for (int k = 0; k < outNum; k++)
                {
                    pMax = Math.Max(pMax, Math.Abs(t[isamp, k]));
                }
            }
            if (sampleNum > 0)
            {
                // All-zero data would give a divisor of 0 and fill the network with NaN.
                in_rate = (pMax > 0.0) ? pMax : 1.0;
            }

            for (int isamp = 0; isamp < sampleNum; isamp++)
            {
                // Normalise the current sample.
                for (int i = 0; i < inNum; i++)
                {
                    x[i] = p[isamp, i] / in_rate;
                }
                for (int k = 0; k < outNum; k++)
                {
                    yd[k] = t[isamp, k] / in_rate;
                }

                forward();

                // Output-layer deltas and accumulated squared error.
                for (int k = 0; k < outNum; k++)
                {
                    double diff = yd[k] - x2[k];
                    qq[k] = diff * x2[k] * (1.0 - x2[k]);
                    e += diff * diff;
                }

                // Hidden-layer deltas. These must be computed with the v that produced the
                // forward pass, i.e. BEFORE v is updated below.
                for (int j = 0; j < hideNum; j++)
                {
                    double back = 0.0;
                    for (int k = 0; k < outNum; k++)
                    {
                        back += qq[k] * v[j, k];
                    }
                    pp[j] = back * x1[j] * (1 - x1[j]);
                }

                // Apply the updates for this sample.
                for (int k = 0; k < outNum; k++)
                {
                    for (int j = 0; j < hideNum; j++)
                    {
                        v[j, k] += rate * qq[k] * x1[j];
                    }
                    b2[k] += rate * qq[k];
                }
                for (int j = 0; j < hideNum; j++)
                {
                    for (int i = 0; i < inNum; i++)
                    {
                        w[i, j] += rate * pp[j] * x[i];
                    }
                    b1[j] += rate * pp[j];
                }
            }

            e = Math.Sqrt(e);
        }

        /// <summary>Adds dw to w element by element (matrix overload).</summary>
        /// <param name="w">Matrix that is modified in place.</param>
        /// <param name="dw">Increments; must be at least as large as w.</param>
        public void adjustWV(double[,] w, double[,] dw)
        {
            int rows = w.GetLength(0);
            int cols = w.GetLength(1);
            for (int i = 0; i < rows; i++)
            {
                for (int j = 0; j < cols; j++)
                {
                    w[i, j] += dw[i, j];
                }
            }
        }

        /// <summary>Adds dw to w element by element (vector overload).</summary>
        /// <param name="w">Vector that is modified in place.</param>
        /// <param name="dw">Increments; must be at least as long as w.</param>
        public void adjustWV(double[] w, double[] dw)
        {
            for (int i = 0; i < w.Length; i++)
            {
                w[i] += dw[i];
            }
        }

        /// <summary>
        /// Creates an empty network for prediction. Call <c>readParas</c>, <c>initial</c> and
        /// the <c>readMatrix*</c> methods before calling <c>sim</c>.
        /// </summary>
        public BpNet()
        {
        }

        /// <summary>
        /// Runs the network on one input vector and returns the de-normalised outputs.
        /// </summary>
        /// <param name="psim">Raw (un-normalised) input values; at least inNum elements.</param>
        /// <returns>
        /// A new array of outNum values. It is not shared with the network, so results of
        /// earlier calls stay valid after later calls.
        /// </returns>
        /// <exception cref="ArgumentNullException">psim is null.</exception>
        /// <exception cref="ArgumentException">psim has fewer than inNum elements.</exception>
        /// <exception cref="InvalidOperationException">The working arrays were never allocated.</exception>
        public double[] sim(double[] psim)
        {
            if (psim == null)
            {
                throw new ArgumentNullException("psim");
            }
            requireBuffers();
            if (psim.Length < inNum)
            {
                throw new ArgumentException("psim must contain at least inNum values.", "psim");
            }

            for (int i = 0; i < inNum; i++)
            {
                x[i] = psim[i] / in_rate;
            }

            forward();

            // Scale back to the original data range in a fresh array.
            double[] result = new double[outNum];
            for (int k = 0; k < outNum; k++)
            {
                result[k] = in_rate * x2[k];
            }
            return result;
        }

        /// <summary>
        /// Writes a matrix (w or v) to a text file: one row per line, values separated by
        /// a space, invariant culture, round-trippable precision.
        /// </summary>
        /// <param name="w">Matrix to write.</param>
        /// <param name="filename">Destination path; an existing file is overwritten.</param>
        public void saveMatrix(double[,] w, string filename)
        {
            // "using" closes the file even if a write throws.
            using (StreamWriter sw = new StreamWriter(filename))
            {
                for (int i = 0; i < w.GetLength(0); i++)
                {
                    for (int j = 0; j < w.GetLength(1); j++)
                    {
                        sw.Write(formatDouble(w[i, j]) + " ");
                    }
                    sw.WriteLine();
                }
            }
        }

        /// <summary>
        /// Writes a vector (b1 or b2) to a text file as a single space-separated line.
        /// </summary>
        /// <param name="b">Vector to write.</param>
        /// <param name="filename">Destination path; an existing file is overwritten.</param>
        public void saveMatrix(double[] b, string filename)
        {
            using (StreamWriter sw = new StreamWriter(filename))
            {
                for (int i = 0; i < b.Length; i++)
                {
                    sw.Write(formatDouble(b[i]) + " ");
                }
            }
        }

        /// <summary>
        /// Writes "inNum hideNum outNum in_rate" as one line so that a prediction-only
        /// instance can be rebuilt later. Failures are reported on the console, not thrown.
        /// </summary>
        /// <param name="filename">Destination path; an existing file is overwritten.</param>
        public void saveParas(string filename)
        {
            try
            {
                using (StreamWriter sw = new StreamWriter(filename))
                {
                    sw.WriteLine(inNum.ToString(fileCulture) + " "
                        + hideNum.ToString(fileCulture) + " "
                        + outNum.ToString(fileCulture) + " "
                        + formatDouble(in_rate));
                }
            }
            catch (Exception ex)
            {
                // Best effort by design: tell the user what went wrong and carry on.
                Console.WriteLine("The file could not be written:");
                Console.WriteLine(ex.Message);
            }
        }

        /// <summary>
        /// Reads back the line written by <c>saveParas</c> and sets inNum, hideNum, outNum
        /// and in_rate. An empty file leaves the fields unchanged. Failures are reported on
        /// the console, not thrown.
        /// </summary>
        /// <param name="filename">Path of the parameter file.</param>
        public void readParas(string filename)
        {
            try
            {
                using (StreamReader sr = new StreamReader(filename))
                {
                    string line = sr.ReadLine();
                    if (line != null)
                    {
                        string[] parts = line.Split(tokenSeparators, StringSplitOptions.RemoveEmptyEntries);
                        if (parts.Length < 4)
                        {
                            throw new FormatException("Expected 4 values (inNum hideNum outNum in_rate) but found " + parts.Length + ".");
                        }

                        // Parse everything first so a bad token cannot leave the object half-updated.
                        int inputs = int.Parse(parts[0], fileCulture);
                        int hidden = int.Parse(parts[1], fileCulture);
                        int outputs = int.Parse(parts[2], fileCulture);
                        double scale = parseDouble(parts[3]);

                        this.inNum = inputs;
                        this.hideNum = hidden;
                        this.outNum = outputs;
                        this.in_rate = scale;
                    }
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine("The file could not be read:");
                Console.WriteLine(ex.Message);
            }
        }

        /// <summary>
        /// Allocates every working array from the current inNum / hideNum / outNum. All
        /// elements start at zero; existing weights and biases are discarded.
        /// </summary>
        public void initial()
        {
            x = new double[inNum];
            x1 = new double[hideNum];
            x2 = new double[outNum];

            o1 = new double[hideNum];
            o2 = new double[outNum];

            w = new double[inNum, hideNum];
            v = new double[hideNum, outNum];
            dw = new double[inNum, hideNum];
            dv = new double[hideNum, outNum];

            b1 = new double[hideNum];
            b2 = new double[outNum];
            db1 = new double[hideNum];
            db2 = new double[outNum];

            pp = new double[hideNum];
            qq = new double[outNum];
            yd = new double[outNum];
        }

        /// <summary>
        /// Fills an already-allocated matrix (w or v) from a file written by
        /// <c>saveMatrix</c>. Blank lines are skipped. Failures, including a file with more
        /// rows or columns than the matrix, are reported on the console, not thrown.
        /// </summary>
        /// <param name="w">Destination matrix; modified in place.</param>
        /// <param name="filename">Path of the matrix file.</param>
        public void readMatrixW(double[,] w, string filename)
        {
            try
            {
                using (StreamReader sr = new StreamReader(filename))
                {
                    string line;
                    int row = 0;
                    while ((line = sr.ReadLine()) != null)
                    {
                        string[] tokens = line.Split(tokenSeparators, StringSplitOptions.RemoveEmptyEntries);
                        if (tokens.Length == 0)
                        {
                            continue; // e.g. a trailing empty line
                        }
                        if (row >= w.GetLength(0) || tokens.Length > w.GetLength(1))
                        {
                            throw new FormatException("The file holds more values than the "
                                + w.GetLength(0) + "x" + w.GetLength(1) + " matrix can take.");
                        }
                        for (int j = 0; j < tokens.Length; j++)
                        {
                            w[row, j] = parseDouble(tokens[j]);
                        }
                        row++;
                    }
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine("The file could not be read:");
                Console.WriteLine(ex.Message);
            }
        }

        /// <summary>
        /// Fills an already-allocated vector (b1 or b2) from a file written by
        /// <c>saveMatrix</c>. Values may be spread over several lines; they are stored in
        /// reading order. Failures are reported on the console, not thrown.
        /// </summary>
        /// <param name="b">Destination vector; modified in place.</param>
        /// <param name="filename">Path of the vector file.</param>
        public void readMatrixB(double[] b, string filename)
        {
            try
            {
                using (StreamReader sr = new StreamReader(filename))
                {
                    string line;
                    int next = 0; // keeps counting across lines so later lines do not overwrite earlier ones
                    while ((line = sr.ReadLine()) != null)
                    {
                        string[] tokens = line.Split(tokenSeparators, StringSplitOptions.RemoveEmptyEntries);
                        for (int j = 0; j < tokens.Length; j++)
                        {
                            if (next >= b.Length)
                            {
                                throw new FormatException("The file holds more than the expected " + b.Length + " values.");
                            }
                            b[next] = parseDouble(tokens[j]);
                            next++;
                        }
                    }
                }
            }
            catch (Exception ex)
            {
                Console.WriteLine("The file could not be read:");
                Console.WriteLine(ex.Message);
            }
        }

        // Forward pass: reads x, w, v, b1, b2 and fills o1, x1, o2, x2.
        private void forward()
        {
            for (int j = 0; j < hideNum; j++)
            {
                o1[j] = 0.0;
                for (int i = 0; i < inNum; i++)
                {
                    o1[j] += w[i, j] * x[i];
                }
                x1[j] = 1.0 / (1.0 + Math.Exp(-o1[j] - b1[j]));
            }

            for (int k = 0; k < outNum; k++)
            {
                o2[k] = 0.0;
                for (int j = 0; j < hideNum; j++)
                {
                    o2[k] += v[j, k] * x1[j];
                }
                x2[k] = 1.0 / (1.0 + Math.Exp(-o2[k] - b2[k]));
            }
        }

        // Throws a descriptive error when the arrays used by train()/sim() are missing.
        private void requireBuffers()
        {
            if (x == null || x1 == null || x2 == null || o1 == null || o2 == null
                || w == null || v == null || b1 == null || b2 == null
                || pp == null || qq == null || yd == null)
            {
                throw new InvalidOperationException(
                    "The network is not initialised: construct it with training data, or call readParas() and initial() first.");
            }
        }

        // Culture-independent text form of a double that parses back to the same value.
        private static string formatDouble(double value)
        {
            return value.ToString("G17", fileCulture);
        }

        // Culture-independent parse; thousands separators are rejected so that a
        // comma-decimal file fails loudly instead of being silently misread.
        private static double parseDouble(string token)
        {
            return double.Parse(token, System.Globalization.NumberStyles.Float, fileCulture);
        }
    }//end bpnet
} //end namespace

示例:

//主调用程序
using System;

namespace BpANNet
{
    /// <summary>
    /// Console demo: trains a BpNet on a small sliding-window data set, saves the trained
    /// parameters to text files, then reloads them into a fresh instance and predicts.
    /// </summary>
    class Class1
    {
        /// <summary>
        /// Application entry point: train, persist, then run the prediction demo.
        /// </summary>
        [STAThread]
        static void Main(string[] args)
        {
            // Alternative (two-class) data set kept for reference:
            //      double [,] p1=new double[,]{{0.05,0.02},{0.09,0.11},{0.12,0.20},{0.15,0.22},{0.20,0.25},{0.75,0.75},{0.80,0.83},{0.82,0.80},{0.90,0.89},{0.95,0.89},{0.09,0.04},{0.1,0.1},{0.14,0.21},{0.18,0.24},{0.22,0.28},{0.77,0.78},{0.79,0.81},{0.84,0.82},{0.94,0.93},{0.98,0.99}};
            //      double [,] t1=new double[,]{{1,0},{1,0},{1,0},{1,0},{1,0},{0,1},{0,1},{0,1},{0,1},{0,1},{1,0},{1,0},{1,0},{1,0},{1,0},{0,1},{0,1},{0,1},{0,1},{0,1}};
            double[,] p1 = new double[,]
            {
                { 0.1399, 0.1467, 0.1567, 0.1595, 0.1588, 0.1622 },
                { 0.1467, 0.1567, 0.1595, 0.1588, 0.1622, 0.1611 },
                { 0.1567, 0.1595, 0.1588, 0.1622, 0.1611, 0.1615 },
                { 0.1595, 0.1588, 0.1622, 0.1611, 0.1615, 0.1685 },
                { 0.1588, 0.1622, 0.1611, 0.1615, 0.1685, 0.1789 }
            };
            // NOTE(review): t1 has 6 rows while p1 has 5; the network takes its sample count
            // from p1, so the last target row is never read. Confirm whether the targets are
            // meant to be shifted by one (next value after each window).
            double[,] t1 = new double[,]
            {
                { 0.1622 }, { 0.1611 }, { 0.1615 }, { 0.1685 }, { 0.1789 }, { 0.1790 }
            };

            BpNet net = new BpNet(p1, t1);

            // Train until the error is small enough or the epoch budget is used up.
            //       net.rate=0.95-(0.95-0.3)*epochs/50000;   (optional learning-rate decay)
            int epochs = 0;
            bool keepTraining = true;
            while (keepTraining)
            {
                epochs++;
                net.train(p1, t1);
                keepTraining = net.e > 0.001 && epochs < 50000;
            }

            Console.Write("第 " + epochs + "次学习: ");
            Console.WriteLine(" 均方差为 " + net.e);

            // Persist everything a prediction-only instance needs.
            net.saveMatrix(net.w, "w.txt");
            net.saveMatrix(net.v, "v.txt");
            net.saveMatrix(net.b1, "b1.txt");
            net.saveMatrix(net.b2, "b2.txt");
            net.saveParas("para.txt");

            // Predict new samples from the saved files.
            pretect();
        }

        /// <summary>
        /// Rebuilds a network from the files written by Main and prints its output for each
        /// row of a small test set, then waits for Enter.
        /// </summary>
        public static void pretect()
        {
            Console.WriteLine("预测开始...");

            BpNet net = new BpNet();
            net.readParas("para.txt");
            net.initial();
            net.readMatrixW(net.w, "w.txt");
            net.readMatrixW(net.v, "v.txt");
            net.readMatrixB(net.b1, "b1.txt");
            net.readMatrixB(net.b2, "b2.txt");

            // Alternative (two-class) test set kept for reference:
            //      double [,] p2=new double[,]{{0.05,0.02},{0.09,0.11},{0.12,0.20},{0.15,0.22},{0.20,0.25},{0.75,0.75},{0.80,0.83},{0.82,0.80},{0.90,0.89},{0.95,0.89},{0.09,0.04},{0.1,0.1},{0.14,0.21},{0.18,0.24},{0.22,0.28},{0.77,0.78},{0.79,0.81},{0.84,0.82},{0.94,0.93},{0.98,0.99}};
            double[,] p2 = new double[,]
            {
                { 0.1399, 0.1467, 0.1567, 0.1595, 0.1588, 0.1622 },
                { 0.1622, 0.1611, 0.1615, 0.1685, 0.1789, 0.1790 }
            };

            int inputCount = net.inNum;
            int rowCount = p2.GetLength(0);

            for (int row = 0; row < rowCount; row++)
            {
                // Copy the current row into a 1-D vector for sim().
                double[] sample = new double[inputCount];
                for (int col = 0; col < inputCount; col++)
                {
                    sample[col] = p2[row, col];
                }

                double[] predicted = net.sim(sample);
                foreach (double value in predicted)
                {
                    Console.WriteLine("预测数据" + row.ToString() + ": " + value + " ");
                }
            }

            Console.ReadLine();
        }
    }
}

4楼 tfg1025 2013-07-01 15:08发表 [回复]

错误很多,估计也是从网上直接拷贝过来的,在进行权值更新的时候,dw dv都没有初始化,学习率也没有用到 怎么更新?
3楼 sunyixiao 2013-01-18 16:57发表 [回复]

哥们,我很抱歉的通知你,你的程序在逻辑上有一点问题!我不知道你是否熟悉神经网络。你的程序在更新权值的时候出现了问题。你不能先更新v,再计算w的误差,因为w的误差是在v的情形下产生的。所以你应该先把v的更新保存在一个地方,然后计算w的误差并更新。之后再更新v
2楼 liuhh1 2012-08-21 16:48发表 [回复]

pp应该为隐层误差;qq是输出层误差。注释稍微改一下就行。程序很好 类里面的dw dv没有用到 含dw dv的那段代码可以不要
1楼 yannian1990 2012-06-07 16:53发表 [回复] [引用] [举报]

不错,很好的修改

扩展阅读