博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3
阅读量:7232 次
发布时间:2019-06-29

本文共 1673 字,大约阅读时间需要 5 分钟。

Spark MLlib Deep Learning Deep Belief Network (深度学习-深度信念网络)2.3

第二章Deep Belief Network (深度信念网络)

3实例

3.1 測试数据

依照上例数据,或者新建图片识别数据。

3.2 DBN实例

//****************2(读取固定样本:来源于经典优化算法測试函数Sphere Model***********//

    //2 读取样本数据

    Logger.getRootLogger.setLevel(Level.WARN)

    valdata_path ="/user/huangmeiling/deeplearn/data1"

    valexamples =sc.textFile(data_path).cache()

    valtrain_d1 =examples.map { line =>

      valf1 = line.split("\t")

      valf =f1.map(f =>f.toDouble)

      valid =f(0)

      valy = Array(f(1))

      valx =f.slice(2,f.length)

      (id, new BDM(1,y.length,y),new BDM(1,x.length,x))

    }

    valtrain_d =train_d1.map(f => (f._2, f._3))

    valopts = Array(100.0,20.0,0.0) 

    //3 设置训练參数,建立DBN模型

    valDBNmodel =new DBN().

      setSize(Array(5, 7)).

      setLayer(2).

      setMomentum(0.1).

      setAlpha(1.0).

      DBNtrain(train_d, opts) 

    //4 DBN模型转化为NN模型

    valmynn =DBNmodel.dbnunfoldtonn(1)

    valnnopts = Array(100.0,50.0,0.0)

    valnumExamples =train_d.count()

    println(s"numExamples = $numExamples.")

    println(mynn._2)

    for (i <-0 tomynn._1.length -1) {

      print(mynn._1(i) +"\t")

    }

    println()

    println("mynn_W1")

    valtmpw1 =mynn._3(0)

    for (i <-0 totmpw1.rows -1) {

      for (j <-0 totmpw1.cols -1) {

        print(tmpw1(i,j) +"\t")

      }

      println()

    }

    valNNmodel =new NeuralNet().

      setSize(mynn._1).

      setLayer(mynn._2).

      setActivation_function("sigm").

      setOutput_function("sigm").

      setInitW(mynn._3).

      NNtrain(train_d, nnopts) 

    //5 NN模型測试

    valNNforecast =NNmodel.predict(train_d)

    valNNerror =NNmodel.Loss(NNforecast)

    println(s"NNerror = $NNerror.")

    valprintf1 =NNforecast.map(f => (f.label.data(0), f.predict_label.data(0))).take(200)

    println("预測结果——实际值:预測值:误差")

    for (i <-0 untilprintf1.length)

      println(printf1(i)._1 +"\t" +printf1(i)._2 +"\t" + (printf1(i)._2 -printf1(i)._1)) 

转载请注明出处:

你可能感兴趣的文章
设计模式开篇 - 简单工厂模式
查看>>
Spring MVC 注解和XML的区别
查看>>
利用Swoole实现PHP+websocket直播,即使通讯代码,及linux下swoole安装基本配置
查看>>
Elastic学习第一天遇到的问题以及添加的一些操作
查看>>
Python lambda介绍
查看>>
BSON与JSON的区别
查看>>
文件系统存储数据,与数据库系统存储数据的差别
查看>>
linux之awk
查看>>
第九章 接口
查看>>
XCode4.2.1 使用NavigationController实现View切换
查看>>
如何让NSURLConnection在子线程中运行
查看>>
es6-Generator
查看>>
Python3.6单例模式报错TypeError: object() takes no parameters的解决方法
查看>>
HTML常用标记(选择性,不全)
查看>>
用一辈子去领悟的22条生活真谛
查看>>
1968: [Ahoi2005]COMMON 约数研究
查看>>
discuz 启用html code 显示问题
查看>>
A1027. Colors in Mars (20)
查看>>
[SRM568]DisjointSemicircles
查看>>
9个很有发展潜力的PHP开源项目
查看>>