Add NSFW detection

This commit is contained in:
Martin Herkt
2017-10-27 05:22:11 +02:00
parent def5d9802f
commit 7bbeb2d144
7 changed files with 3666 additions and 10 deletions

View File

@@ -44,6 +44,13 @@ app.config["FHOST_MIME_BLACKLIST"] = [
app.config["FHOST_UPLOAD_BLACKLIST"] = "tornodes.txt"
app.config["NSFW_DETECT"] = True
app.config["NSFW_THRESHOLD"] = 0.7
if app.config["NSFW_DETECT"]:
from nsfw_detect import NSFWDetector
nsfw = NSFWDetector()
try:
mimedetect = Magic(mime=True, mime_encoding=False)
except:
@@ -72,6 +79,9 @@ class URL(db.Model):
def getname(self):
return su.enbase(self.id, 1)
def geturl(self):
return url_for("get", path=self.getname(), _external=True) + "\n"
class File(db.Model):
id = db.Column(db.Integer, primary_key = True)
sha256 = db.Column(db.String, unique = True)
@@ -79,23 +89,29 @@ class File(db.Model):
mime = db.Column(db.UnicodeText)
addr = db.Column(db.UnicodeText)
removed = db.Column(db.Boolean, default=False)
nsfw_score = db.Column(db.Float)
def __init__(self, sha256, ext, mime, addr):
def __init__(self, sha256, ext, mime, addr, nsfw_score):
self.sha256 = sha256
self.ext = ext
self.mime = mime
self.addr = addr
self.nsfw_score = nsfw_score
def getname(self):
return u"{0}{1}".format(su.enbase(self.id, 1), self.ext)
def geturl(self):
n = self.getname()
if self.nsfw_score and self.nsfw_score > app.config["NSFW_THRESHOLD"]:
return url_for("get", path=n, _external=True, _anchor="nsfw") + "\n"
else:
return url_for("get", path=n, _external=True) + "\n"
def getpath(fn):
return os.path.join(app.config["FHOST_STORAGE_PATH"], fn)
def geturl(p):
return url_for("get", path=p, _external=True) + "\n"
def fhost_url(scheme=None):
if not scheme:
return url_for(".fhost", _external=True).rstrip("/")
@@ -115,13 +131,13 @@ def shorten(url):
existing = URL.query.filter_by(url=url).first()
if existing:
return geturl(existing.getname())
return existing.geturl()
else:
u = URL(url)
db.session.add(u)
db.session.commit()
return geturl(u.getname())
return u.geturl()
def in_upload_bl(addr):
if os.path.isfile(app.config["FHOST_UPLOAD_BLACKLIST"]):
@@ -152,11 +168,15 @@ def store_file(f, addr):
with open(epath, "wb") as of:
of.write(data)
if existing.nsfw_score == None:
if app.config["NSFW_DETECT"]:
existing.nsfw_score = nsfw.detect(epath)
os.utime(epath, None)
existing.addr = addr
db.session.commit()
return geturl(existing.getname())
return existing.geturl()
else:
guessmime = mimedetect.from_buffer(data)
@@ -186,14 +206,21 @@ def store_file(f, addr):
if not ext:
ext = ".bin"
with open(getpath(digest), "wb") as of:
spath = getpath(digest)
with open(spath, "wb") as of:
of.write(data)
sf = File(digest, ext, mime, addr)
if app.config["NSFW_DETECT"]:
nsfw_score = nsfw.detect(spath)
else:
nsfw_score = None
sf = File(digest, ext, mime, addr, nsfw_score)
db.session.add(sf)
db.session.commit()
return geturl(sf.getname())
return sf.geturl()
def store_url(url, addr):
if is_fhost_url(url):
@@ -438,6 +465,37 @@ def queryaddr(a):
for f in res:
query(su.enbase(f.id, 1))
def nsfw_detect(f):
try:
open(f["path"], 'r').close()
f["nsfw_score"] = nsfw.detect(f["path"])
return f
except:
return None
@manager.command
def update_nsfw():
if not app.config["NSFW_DETECT"]:
print("NSFW detection is disabled in app config")
return 1
from multiprocessing import Pool
import tqdm
res = File.query.filter_by(nsfw_score=None, removed=False)
with Pool() as p:
results = []
work = [{ "path" : getpath(f.sha256), "id" : f.id} for f in res]
for r in tqdm.tqdm(p.imap_unordered(nsfw_detect, work), total=len(work)):
if r:
results.append({"id": r["id"], "nsfw_score" : r["nsfw_score"]})
db.session.bulk_update_mappings(File, results)
db.session.commit()
@manager.command
def querybl():
if os.path.isfile(app.config["FHOST_UPLOAD_BLACKLIST"]):