Commit 0304cc2

Merge pull request #31 from victor23k/bert-text-classification-example
Add text classification example with distilbert
2 parents 417aeb3 + b801b7d commit 0304cc2

3 files changed

Lines changed: 63 additions & 0 deletions

File tree

examples/distilbert/README.md

Lines changed: 9 additions & 0 deletions
# DistilBert exported to ONNX with HuggingFace transformers

### Running

Run `python export.py` to create the ONNX model for distilbert/distilbert-base-uncased-finetuned-sst-2-english, then `mix run` the `distilbert_classification.exs` script.

### Labels

When exporting the model from Hugging Face Transformers to ONNX, a `config.json` file is written to the chosen directory. This file contains the id-to-label mappings, which you can read directly to assign a label to the input, as shown in `distilbert_classification.exs`.
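The id-to-label lookup described above can be sketched in a few lines of Python. The inline `config_json` string below is a hypothetical stand-in for the real `config.json` the export writes; SST-2 fine-tuned checkpoints conventionally use a two-class NEGATIVE/POSITIVE mapping of this shape.

```python
import json

# Hypothetical stand-in for the config.json written by the ONNX export;
# the real file contains an "id2label" object of the same shape.
config_json = '{"id2label": {"0": "NEGATIVE", "1": "POSITIVE"}}'

id2label = json.loads(config_json)["id2label"]

def id_to_label(class_id):
    # JSON object keys are strings, so the integer class id is stringified first.
    return id2label[str(class_id)]

print(id_to_label(1))  # POSITIVE
```

This mirrors what `Inference.id_to_label/1` does in the Elixir script, where the stringification step is the same because `Jason.decode` also yields string keys.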
examples/distilbert/distilbert_classification.exs

Lines changed: 34 additions & 0 deletions
defmodule Inference do
  # Look up the human-readable label for a class id in the exported config.json.
  def id_to_label(id) do
    {:ok, config_json} = File.read("./models/distilbert-onnx/config.json")
    {:ok, %{"id2label" => id2label}} = Jason.decode(config_json)
    Map.get(id2label, to_string(id))
  end

  def run() do
    model = Ortex.load("./models/distilbert-onnx/model.onnx")

    text =
      "the movie had a lot of nuance and interesting artistic choices, would like to see more support in the industry for these types of productions"

    {:ok, tokenizer} = Tokenizers.Tokenizer.from_file("./models/distilbert-onnx/tokenizer.json")
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)

    # Batch of one: token ids and attention mask as 2D tensors.
    input = Nx.tensor([Tokenizers.Encoding.get_ids(encoding)])
    mask = Nx.tensor([Tokenizers.Encoding.get_attention_mask(encoding)])

    {output} = Ortex.run(model, {input, mask})

    IO.inspect(output)

    # Pick the highest-scoring class and map it to its label.
    IO.inspect(
      output
      |> Nx.backend_transfer()
      |> Nx.argmax()
      |> Nx.to_number()
      |> id_to_label()
    )
  end
end

Inference.run()
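The script selects the winning class with `Nx.argmax`, which discards how confident the model was. If a confidence score is also wanted, a softmax over the two logits gives class probabilities. A minimal pure-Python sketch (the logit values are made up for illustration, not taken from a real run):

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for a two-class (NEGATIVE, POSITIVE) head.
logits = [-2.1, 2.4]
probs = softmax(logits)

# Same decision as argmax on the raw logits, but with a probability attached.
label_index = probs.index(max(probs))
print(label_index, round(max(probs), 3))
```

The same transformation could be done in the Elixir pipeline with `Nx.exp/1` and `Nx.sum/1` (or `Axon.Activations.softmax/1`) before the `Nx.argmax` step.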

examples/distilbert/export.py

Lines changed: 20 additions & 0 deletions
"""
### Install dependencies:

$ pip install transformers
$ pip install optimum
$ pip install "transformers[onnx]"
"""

from transformers import DistilBertTokenizer
from optimum.onnxruntime import ORTModelForSequenceClassification

save_directory = "./models/distilbert-onnx/"

# Download the fine-tuned SST-2 checkpoint and export it to ONNX.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = ORTModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", export=True)
print(model)

# Write model.onnx, config.json, and the tokenizer files to save_directory.
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
