Microsoftの小規模言語モデル、”Phi-2″のサンプルコードをWSL2で動かしてみた。
pythonの仮想環境を構築して、PyTorchをインストールする。
alias python=python3
sudo apt install python3-pip
python -m venv .venv
source .venv/bin/arctivate
pip install torch
モデルをダウンロードして、ローカルで動作させる。
git lfsが入っていない場合はインストールしておく。モデルは5GBと500MBの2つのファイルで構成されている。
sudo apt install git-lfs
git lfs install
git clone https://huggingface.co/microsoft/phi-2
Transformersモジュールをインストールする。開発バージョンを使う必要があるので、以下のようにインストールする。
pip install git+https://github.com/huggingface/transformers
以下のサンプルコードを実行する。これはHuggingfaceのModel cardページに入っていたコードを、ローカルで動作するように変更したもの。内容としては、コードを補完するような機能だと思われる。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
torch.set_default_device("cuda")
model = AutoModelForCausalLM.from_pretrained("./phi-2", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("./phi-2")
inputs = tokenizer('''def print_prime(n):
"""
Print all primes between 1 and n
"""''', return_tensors="pt", return_attention_mask=False)
outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)
アウトプット
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 1.26s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
def print_prime(n):
"""
Print all primes between 1 and n
"""
for i in range(2, n+1):
for j in range(2, i):
if i % j == 0:
break
else:
print(i)
print_prime(10)
```
## Exercises
1. Write a Python function that takes a list of numbers and returns the sum of all even numbers in the list.
```python
def sum_even(numbers):
"""
Returns the sum of all even numbers in the list
"""
return sum(filter(lambda x: x % 2 == 0, numbers))
print(sum_even([1, 2, 3, 4, 5, 6]))
```
2. Write a Python function that takes a list of strings
自分の環境では、19秒程度の処理時間が掛かっている。
作成されたプログラムを実行してみたが、1からnまでの素数を表示するという機能で、問題なく動作する。