Initial release (see README.md for more information)
#!/usr/bin/env python3

# -*- coding: utf-8 -*-

import os
# disable TensorFlow output
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel, Field
import json
import time
import argparse
import simple_cache
import ipaddress
import pathlib
import config_parser
from typing import Any, Callable, Dict
import threading
import sys
import logging

def existing_file(path: str) -> pathlib.Path:
    '''
    Checks whether a path is a file and returns it if it exists.

    :param path: The path
    :type path: str

    :raises argparse.ArgumentTypeError: If the path is a directory or does not exist.

    :return: The path
    :rtype: pathlib.Path
    '''

    if os.path.isdir(path):
        raise argparse.ArgumentTypeError(f'\'{path}\': Is a directory')
    if not os.path.isfile(path):
        raise argparse.ArgumentTypeError(f'\'{path}\': No such file or directory')
    return pathlib.Path(path)

# parse arguments

argument_parser = argparse.ArgumentParser(description='The API for the Word Language Detector AI.')
argument_parser.add_argument('-c', '--config', type=existing_file, default=pathlib.Path(os.path.dirname(__file__) + '/../api.conf'), help='The path to the configuration file')
argument_parser.add_argument('-H', '--host', type=ipaddress.ip_address, default=None, help='The host to listen on')
argument_parser.add_argument('-p', '--port', type=int, default=None, help='The port to listen on')
argument_parser.add_argument('-d', '--model-dir', type=pathlib.Path, default=None, help='The directory with the model. It has to include the following files: \'label_encoder.json\', \'language_detector.keras\', \'max_len\' and \'tokenizer.json\'')
args = argument_parser.parse_args()
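
# e.g. (illustrative invocation; the script name and paths may differ):
#   python3 api.py --host 0.0.0.0 --port 8000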

# load default configuration
with open(os.path.dirname(__file__) + '/../default_configuration.conf') as f:
    default_configuration = config_parser.ini.INIConfiguration.from_string(f.read())

# load configuration
with open(args.config) as f:
    configuration = config_parser.ini.INIConfiguration.from_string(f.read(), default=default_configuration, ignore_errors=True)
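
# An illustrative api.conf, using the section and option names read below
# (the values here are examples, not necessarily the shipped defaults):
#
#   [server]
#   host = 0.0.0.0
#   port = 8000
#
#   [model]
#   modelDirectory = ./model
#   lazyImport = true
#
#   [cache]
#   enableCache = true
#   maxSize = 256
#   ttl = 120
#
#   [logging]
#   logfile = ./api.log
#   loglevel = INFO
#
#   [docs]
#   swagger = /docs
#   redoc = /redoc
#   openapi = /openapi.json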

# fall back to console-only logging if the logfile is unset or points to a directory
logfile = configuration.logging.logfile
if not logfile or os.path.isdir(logfile):
    logfile = None

# translate the configured level name into a logging level,
# falling back to the default configuration and finally to INFO
LOGLEVEL = logging._nameToLevel.get(
    configuration.logging.loglevel.upper(),
    logging._nameToLevel.get(default_configuration.logging.loglevel.upper(), logging.INFO)
)

# set up logging
LOGGING_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
logging.basicConfig(level=LOGLEVEL, format=LOGGING_FORMAT)
logger = logging.getLogger('api')
logger.propagate = False
logger.handlers = list()

formatter = logging.Formatter(LOGGING_FORMAT)
if logfile is not None:
    logger_file_handler = logging.FileHandler(logfile)
    logger_file_handler.setFormatter(formatter)
    logger.addHandler(logger_file_handler)

logger.info('Logging initialized')

class Language:
    def __init__(self, langcode: str, repr: str):
        '''
        A language.

        :param langcode: The language code according to ISO 639.
        :type langcode: str
        :param repr: The representation of the language (the language name).
        :type repr: str
        '''

        self.langcode = langcode
        self.repr = repr

    def __repr__(self) -> str:
        return self.repr

# create language objects
lang_en = Language('en', 'English')
lang_de = Language('de', 'German')

class WordLanguageDetector:
    def __init__(self, model_path: str, tokenizer_path: str, label_encoder_path: str, max_len_path: str):
        '''
        The detector for the language of a word.

        :param model_path: The path to the model file.
        :type model_path: str
        :param tokenizer_path: The path to the tokenizer file.
        :type tokenizer_path: str
        :param label_encoder_path: The path to the label encoder file.
        :type label_encoder_path: str
        :param max_len_path: The path to the token maximum length file.
        :type max_len_path: str
        '''

        (self.model,
         self.label_encoder,
         self.tokenizer,
         self.max_len) = self._load(
            model_path=model_path,
            tokenizer_path=tokenizer_path,
            label_encoder_path=label_encoder_path,
            max_len_path=max_len_path
        )

    def _load(self, model_path: str, tokenizer_path: str, label_encoder_path: str, max_len_path: str) -> tuple:
        '''
        Loads the model, tokenizer, label encoder and token maximum length.

        :param model_path: The path to the model file.
        :type model_path: str
        :param tokenizer_path: The path to the tokenizer file.
        :type tokenizer_path: str
        :param label_encoder_path: The path to the label encoder file.
        :type label_encoder_path: str
        :param max_len_path: The path to the token maximum length file.
        :type max_len_path: str

        :return: The model as a tuple (model, label encoder, tokenizer, token maximum length)
        :rtype: tuple
        '''
        # restore the label encoder classes from JSON
        label_encoder = LabelEncoder()
        with open(label_encoder_path, 'r', encoding='utf-8') as f:
            label_encoder.classes_ = json.load(f)

        # restore the tokenizer from JSON
        with open(tokenizer_path, 'r', encoding='utf-8') as f:
            tokenizer = tokenizer_from_json(f.read())

        model = load_model(model_path)

        # the maximum token length is stored as raw big-endian bytes
        with open(max_len_path, 'rb') as f:
            max_len = int.from_bytes(f.read(), byteorder='big')

        return model, label_encoder, tokenizer, max_len

    def predict_language(self, word: str) -> Language:
        '''
        Predicts the language of a word.

        :param word: The word
        :type word: str

        :return: The language of the word
        :rtype: Language
        '''
        word = word.lower()

        # tokenize and pad the word to the length the model was trained on
        seq = self.tokenizer.texts_to_sequences([word])

        padded = pad_sequences(seq, maxlen=self.max_len, padding='post')

        # the model outputs a single sigmoid score: > 0.5 means English, otherwise German
        pred = self.model.predict(padded)[0][0]

        if pred > 0.5:
            return lang_en
        else:
            return lang_de

class Word(BaseModel):
    word: str
    enableCache: bool = True

class Words(BaseModel):
    words: list[Word] = Field(..., description='The words to be detected')

class WordInfo(BaseModel):
    language: str

class LanguagesResponse(BaseModel):
    words: Dict[str, WordInfo]

class DisabledCache:
    def __init__(self, generate_value_func: Callable, max_size: int = 256, ttl: int = 120) -> None:
        '''
        The cache class for disabling the cache, but compatible with
        simple_cache.Cache.
        It always calls the function to generate a value when a value is
        retrieved from the cache. No data is cached, and the methods for
        storing data in the cache or deleting data from it do nothing.
        '''
        self.generate_value_func = generate_value_func
        self._lock = threading.Lock()

    def _set(self, key: str, value: Any) -> None: pass

    def _delete(self, key: str) -> None: pass

    def _get(self, key: str) -> Any:
        return self.generate_value_func(key)

    def _clean_item(self, key: str) -> None: pass

    def _clear_expired(self) -> None: pass

    def _clear(self) -> None: pass

    # expose the same public interface as simple_cache.Cache
    __setitem__ = set = _set
    __delitem__ = delete = _delete
    __getitem__ = get = _get
    clean_item = _clean_item
    clear_expired = _clear_expired
    clear = _clear

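# A minimal sketch of how the null-object cache above behaves (hypothetical
# generator function):
#
#   c = DisabledCache(generate_value_func=lambda key: key.upper())
#   c['a'] = 'ignored'   # stored nowhere
#   c.get('a')           # recomputed on every access -> 'A'
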
def load_ai(model_dir: str) -> None:
    '''
    Loads the word language detector AI.

    :param model_dir: The directory with the model.
    :type model_dir: str
    '''

    global tokenizer_from_json, pad_sequences, load_model, LabelEncoder, word_language_detector, ai_loaded

    # import modules (deferred here because importing TensorFlow is slow)
    logger.info('Importing AI modules ...')
    imports_start_time = time.time() * 1000
    from tensorflow.keras.preprocessing.text import tokenizer_from_json
    from tensorflow.keras.preprocessing.sequence import pad_sequences
    from tensorflow.keras.models import load_model
    from sklearn.preprocessing import LabelEncoder
    logger.info(f'AI modules imports completed in {time.time() * 1000 - imports_start_time:.2f} ms')

    # load model
    word_language_detector = WordLanguageDetector(
        model_path=f'{model_dir}/language_detector.keras',
        tokenizer_path=f'{model_dir}/tokenizer.json',
        label_encoder_path=f'{model_dir}/label_encoder.json',
        max_len_path=f'{model_dir}/max_len'
    )
    ai_loaded = True


def validate_host(host: str) -> bool:
    '''
    Checks whether the host is a valid IP address, e.g. ``0.0.0.0``.

    :param host: The host
    :type host: str

    :return: ``True`` if the host is valid, otherwise ``False``.
    :rtype: bool
    '''

    try:
        ipaddress.ip_address(host)
        return True
    except ValueError:
        return False

def predict_language_wrapper(word: str) -> Language:
    # passed to the cache as generate_value_func, so it is called on a cache miss
    if not isinstance(cache, DisabledCache):
        logger.info(f'Language of the word \'{word}\' not found in cache, generating it')
    return word_language_detector.predict_language(word)


ai_loaded = False

# host
host = str(args.host) if args.host is not None else configuration.server.host
if not validate_host(host):
    logger.error(f'Error in configuration: [server]host: \'{host}\' is not a valid host, using default value.')
    host = default_configuration.server.host


# port
port = str(args.port) if args.port is not None else configuration.server.port
if not port.isdigit():
    logger.error(f'Error in configuration: [server]port: \'{port}\' is not an integer, using default value.')
    port = default_configuration.server.port
else:
    if not (0 <= int(port) <= 65535):
        logger.error(f'Error in configuration: [server]port: {port} is not between 0 and 65535, using default value.')
        port = default_configuration.server.port
port = int(port)


# model directory
model_dir = str(args.model_dir) if args.model_dir is not None else configuration.model.modelDirectory.strip()
if not model_dir:
    logger.critical('No model directory specified. Please set the option modelDirectory in the [model] section to the model directory.')
    sys.exit(1)
if not os.path.exists(model_dir):
    logger.critical(f'Error in configuration: [model]modelDirectory: {model_dir}: no such file or directory')
    sys.exit(1)
if not os.path.isdir(model_dir):
    logger.critical(f'Error in configuration: [model]modelDirectory: {model_dir}: not a directory')
    sys.exit(1)

# AI modules lazy import
ai_modules_lazy_import = configuration.model.lazyImport
if ai_modules_lazy_import.lower() not in ('true', 'false'):
    logger.error(f'Error in configuration: [model]lazyImport: \'{ai_modules_lazy_import}\' is not a boolean, using default value.')
    ai_modules_lazy_import = default_configuration.model.lazyImport

ai_modules_lazy_import = ai_modules_lazy_import.lower() == 'true'


# Initialize cache

# maximum size of the cache
cache_max_size = configuration.cache.maxSize
if not cache_max_size.isdigit():
    logger.error(f'Error in configuration: [cache]maxSize: \'{cache_max_size}\' is not an integer, using default value.')
    cache_max_size = default_configuration.cache.maxSize
cache_max_size = int(cache_max_size)

# TTL
cache_ttl = configuration.cache.ttl
if not cache_ttl.isdigit():
    logger.error(f'Error in configuration: [cache]ttl: \'{cache_ttl}\' is not an integer, using default value.')
    cache_ttl = default_configuration.cache.ttl
cache_ttl = int(cache_ttl)

cache_enabled = configuration.cache.enableCache
if cache_enabled.lower() not in ('true', 'false'):
    logger.error(f'Error in configuration: [cache]enableCache: \'{cache_enabled}\' is not a boolean, using default value.')
    cache_enabled = default_configuration.cache.enableCache

cache_enabled = cache_enabled.lower() == 'true'

if cache_enabled:
    # cache is enabled
    cache = simple_cache.Cache(
        generate_value_func=predict_language_wrapper,
        max_size=cache_max_size,
        ttl=cache_ttl
    )
else:
    # cache is disabled
    cache = DisabledCache(generate_value_func=predict_language_wrapper)


# Initialize FastAPI
app = FastAPI(
    docs_url=configuration.docs.swagger,
    redoc_url=configuration.docs.redoc,
    openapi_url=configuration.docs.openapi
)

if not ai_modules_lazy_import:
    load_ai(model_dir)


# predict word language
@app.post('/prediction', response_model=LanguagesResponse)
async def predict_language(words: Words):
    # with lazy imports enabled, the model is loaded on the first request
    if not ai_loaded:
        load_ai(model_dir)

    return {
        'words': {
            w.word: {
                'language': (cache.get(w.word) if w.enableCache else word_language_detector.predict_language(w.word)).langcode
            } for w in words.words
        }
    }

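# A usage sketch for the endpoint above (host and port depend on your
# configuration; 127.0.0.1:8000 is just an example):
#
#   curl -X POST http://127.0.0.1:8000/prediction \
#       -H 'Content-Type: application/json' \
#       -d '{"words": [{"word": "hello"}, {"word": "hallo", "enableCache": false}]}'
#
# The response is shaped like:
#
#   {"words": {"hello": {"language": "en"}, "hallo": {"language": "de"}}}
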
# list all supported languages
@app.get('/languages')
async def get_languages():
    return {
        'languages': [
            'en',
            'de'
        ]
    }

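# e.g. `curl http://127.0.0.1:8000/languages` (example host/port) returns:
#   {"languages": ["en", "de"]}

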
# run Uvicorn
if __name__ == '__main__':
    uvicorn.run(
        app,
        host=host,
        port=port,
        server_header=False,
        log_config={
            'version': 1,
            'disable_existing_loggers': True,
            'handlers': {
                # log to the configured logfile, or to stderr if no logfile is set
                'uvicorn_handler': {
                    'class': 'logging.FileHandler',
                    'filename': logfile,
                    'formatter': 'uvicorn_formatter',
                } if logfile is not None else {
                    'class': 'logging.StreamHandler',
                    'formatter': 'uvicorn_formatter',
                },
            },
            'formatters': {
                'uvicorn_formatter': {
                    'format': LOGGING_FORMAT
                },
            },
            'loggers': {
                'uvicorn': {
                    'handlers': ['uvicorn_handler'],
                    'level': LOGLEVEL,
                    'propagate': False
                },
            },
        }
    )