JAX是一個用於高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計。本書詳解JAX框架深度學習的相關知識,配套示例源碼、PPT課件、資料集和開發環境。本書共分為13章,內容包括JAX從零開始,一學就會的線性回歸、多層感知機與自動微分器,深度學習的理論基礎,XLA與JAX一般特性,JAX的特性,JAX的一些細節,JAX中的卷積,JAX與TensorFlow的比較與交互,遵循JAX函數基本規則下的自訂函數,JAX中的包。給出3個實戰案例:使用ResNet完成CIFAR100資料集分類,有趣的詞嵌入,生成對抗網路(GAN)。本書適合JAX框架初學者、深度學習初學者以及深度學習從業人員,也適合作為高等院校和培訓機構人工智慧相關專業的師生教學參考書。
第1章 JAX從零開始
1.1 JAX來了
1.1.1 JAX是什麼
1.1.2 為什麼是JAX
1.2 JAX的安裝與使用
1.2.1 Windows Subsystem for Linux的安裝
1.2.2 JAX的安裝和驗證
1.2.3 PyCharm的下載與安裝
1.2.4 使用PyCharm和JAX
1.2.5 JAX的Python代碼小練習:計算SeLU函數
1.3 JAX實戰——MNIST手寫體的識別
1.3.1 步:準備數據集
1.3.2 第二步:模型的設計
1.3.3 第三步:模型的訓練
1.4 本章小結
第2章 一學就會的線性回點、一多層感知機與自動微分器
2.1 多層感知機
2.1.1 全連接層——多層感知機的隱藏層
2.1.2 使用JAX實現一個全連接層
2.1.3 多功能的全連接函數
2.2 JAX實戰——鶯尾花分類
2.2.1 鶯尾花數據準備與分析
2.2.2 模型分析——採用線性回歸實戰鶯尾花分類
2.2.3 基於JAX的線性回歸模型的編寫
2.2.4 多層感知機與神經網路
2.2.5 基於JAX的啟動函數、softmax函數與交叉熵函數
2.2.6 基於多層感知機的鶯尾花分類實戰
2.3 自動微分器
2.3.1 什麼是微分器
2.3.2 JAX中的自動微分
2.4 本章小結
第3章 深度學習的理論那礎
3.1 BP神經網路簡介
3.2 BP神經網路兩個基礎演算法詳解
3.2.1 小二乘法詳解
3.2.2 道士下山的故事——梯度下降演算法
3.2.3 小二乘法的梯度下降演算法以及JAX實現
3.3 回饋神經網路反向傳播演算法介紹
3.3.1 深度學習基礎
3.3.2 鏈式求導法則
3.3.3 回饋神經網路原理與公式推導
3.3.4 回饋神經網路原理的啟動函數
3.3.5 回饋神經網路原理的Python實現
3.4 本章小結
第4章 XLA與JAX一般特性
4.1 JAX 與XLA
4.1.1 XLA如何運行
4.1.2 XLA如何工作
4.2 JAX一般特性
4.2.1 利用JIT加快程式運行