add "tf.keras" in __init__
add "tf.keras" in __init__
This commit is contained in:
parent
919d0a9312
commit
58781c8e49
@ -5,10 +5,30 @@
|
||||
# @function :
|
||||
|
||||
|
||||
# from keras_textclassification.text_classification_api import train
|
||||
#
|
||||
# res = "假道士敷衍超渡,鬼魂一家感觉受到了屈辱,现出真身捉弄他"
|
||||
# mention = "道士"
|
||||
# offset = 1
|
||||
# print(res[1])
|
||||
# print(res[1+1])
|
||||
import sys
|
||||
import os
|
||||
path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(path_root)
|
||||
from distutils.util import strtobool
|
||||
|
||||
|
||||
# gpu/tf日志的环境, 默认CPU
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "-1") # "0,1"
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
# 自动化(默认AUTO), 即定义是不是高自由度, 如"CUSTOM"可以高度自定义, 网络架构embedding/graph/loss等均可高度自定义
|
||||
# 默认使用keras_textclassification.keras
|
||||
|
||||
tf_keras = os.environ.get("TF_KERAS", "0")
|
||||
print(tf_keras)
|
||||
is_tf_keras = strtobool(tf_keras)
|
||||
|
||||
|
||||
if is_tf_keras:
|
||||
import tensorflow as tf
|
||||
# Python Import机制备忘-模块搜索路径(sys.path)、嵌套Import、package Import
|
||||
sys.modules["keras"] = tf.keras
|
||||
|
||||
|
||||
__version__ = "0.2.0"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user