AutoKeras 作成したモデルを用いての画像分類

AutoKerasで学習したモデルの保存と読み込みを試してみました。
AutoKeras 作成したモデルの保存と呼び出し、モデルの内容確認

今回はモデルを読み込んで画像の分類を試してみます。

テストデータの判定



作成したモデルを読み込み分類してみます。


  1. import tensorflow as tf
  2. from tensorflow.keras.datasets import cifar10
  3. import numpy as np
  4. import tkinter as tk
  5. from PIL import Image, ImageTk
  6. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  7. # 出力したモデルの読み込み
  8. model = tf.keras.models.load_model('model.h5')
  9. # 最初のテスト画像を判定
  10. predicted_y = model.predict(x_test[1])
  11. print(predicted_y)
  12. # 確信度が最大値のインデックス表示
  13. print(np.argmax(predicted_y))
  14. # 判定した画像を表示
  15. root = tk.Tk()
  16. root.geometry('200x200')
  17. root.title('cifar10 image')
  18. canvas = tk.Canvas(
  19.     root,
  20.     width=128,
  21.     height=128
  22. )
  23. canvas.place(x=0, y=0)
  24. #PILでjpgを使用
  25. img = Image.fromarray(np.uint8(x_test[1]))
  26. img = img.resize((128, 128)) # 画像を拡大表示
  27. tk_img = ImageTk.PhotoImage(img)
  28. canvas.create_image(
  29.     0,
  30.     0,
  31.     image=tk_img,
  32.     anchor=tk.NW
  33. )
  34. root.mainloop()




実行結果


[[3.22687876e-04 8.15596506e-02 1.15753856e-07 3.03708219e-07
2.27200161e-08 5.22791144e-09 1.87353741e-08 4.26373497e-08
9.17966664e-01 1.50418127e-04]]
8



a33_01.png

分類は「8(ship:船)」です。
正しく認識できているようです。




画像ファイルを読み込んでの判定



用意されたデータセットからではなく、画像ファイルを読み込んで判定を行ってみます。
まず、先程表示した船の画像をファイルに保存。
Python mnistデータセットの内容を画像ファイルに出力する(pillow使用)


  1. import numpy as np
  2. from PIL import Image
  3. from tensorflow.keras.datasets import cifar10
  4. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  5. img = Image.fromarray(np.uint8(x_test[1]))
  6. img.save('ship.png')



以下の画像が出力できました。

a33_02.png

この画像を読み込んで判定を行ってみます。


  1. import tensorflow as tf
  2. from PIL import Image
  3. import numpy as np
  4. img = np.array(Image.open('ship.png'))
  5. # 出力したモデルの読み込み
  6. model = tf.keras.models.load_model('model.h5')
  7. # テスト画像を判定
  8. predicted_y = model.predict(img)
  9. print(predicted_y)
  10. # 確信度が最大値のインデックス表示
  11. print(np.argmax(predicted_y))



実行結果


[[3.22687876e-04 8.15596506e-02 1.15753856e-07 3.03708219e-07
2.27200161e-08 5.22791144e-09 1.87353741e-08 4.26373497e-08
9.17966664e-01 1.50418127e-04]]
8



先ほどと同じく、「8(ship:船)」という分類になりました。




【参考URL】

Python, NumPyで画像処理(読み込み、演算、保存)
関連記事

プロフィール

Author:symfo
blog形式だと探しにくいので、まとめサイト作成中です。
Symfoware まとめ

PR




検索フォーム

月別アーカイブ