Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 90 additions & 46 deletions src/main/python/feature-extractor/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
sys.path.append('./pb')

from concurrent import futures
from multiprocessing import Pool
import argparse
import logging
import os
import signal
import time

import grpc
Expand All @@ -14,93 +16,133 @@

import pb.service_pb2 as service_pb2
import pb.service_pb2_grpc as service_pb2_grpc
# isn't used but re-exported so it can be used in tests
from pb.service_pb2 import gopkg_dot_in_dot_bblfsh_dot_sdk_dot_v1_dot_uast_dot_generated__pb2 as uast_pb

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# keep extractors out of the Service class to be able to pickle them
# return list instead of iterator for pickle also


def _identifiers_extractor(uast, options):
    """Extract the identifiers bag from *uast* and return the features as a list.

    Kept as a module-level function (not a method) so it can be pickled and
    shipped to multiprocessing pool workers; the result is materialized into
    a list for the same reason.
    """
    extractor = IdentifiersBagExtractor(
        docfreq_threshold=options.docfreqThreshold,
        split_stem=options.splitStem,
        weight=options.weight or 1)
    return list(extractor.extract(uast))


def _literals_extractor(uast, options):
    """Extract the literals bag from *uast* and return the features as a list.

    Module-level (rather than a Service method) so it stays picklable for the
    multiprocessing pool; returns a list instead of an iterator for the same
    reason.
    """
    extractor = LiteralsBagExtractor(
        docfreq_threshold=options.docfreqThreshold,
        weight=options.weight or 1)
    return list(extractor.extract(uast))


def _uast2seq_extractor(uast, options):
    """Extract the uast2seq bag from *uast* and return the features as a list.

    Module-level (rather than a Service method) so it stays picklable for the
    multiprocessing pool; returns a list instead of an iterator for the same
    reason.
    """
    # options.seqLen is a repeated proto field; convert it to a plain list,
    # falling back to the default of 5 when no lengths were supplied.
    seq_len = list(options.seqLen) if options.seqLen else 5
    extractor = UastSeqBagExtractor(
        docfreq_threshold=options.docfreqThreshold,
        weight=options.weight or 1,
        stride=options.stride or 1,
        seq_len=seq_len)
    return list(extractor.extract(uast))


def _graphlet_extractor(uast, options):
    """Extract the graphlet bag from *uast* and return the features as a list.

    Module-level (rather than a Service method) so it stays picklable for the
    multiprocessing pool; returns a list instead of an iterator for the same
    reason.
    """
    extractor = GraphletBagExtractor(
        docfreq_threshold=options.docfreqThreshold,
        weight=options.weight or 1)
    return list(extractor.extract(uast))


def _features_from_iter(f_iter):
    """Convert an iterable of (name, weight) pairs into Feature messages.

    Each item is indexed positionally: item[0] is the feature name and
    item[1] is its weight.
    """
    features = []
    for item in f_iter:
        features.append(service_pb2.Feature(name=item[0], weight=item[1]))
    return features


class Service(service_pb2_grpc.FeatureExtractorServicer):
"""Feature Extractor Service"""

extractors_names = ["identifiers", "literals", "uast2seq", "graphlet"]
pool = None
extractors = {
"identifiers": _identifiers_extractor,
"literals": _literals_extractor,
"uast2seq": _uast2seq_extractor,
"graphlet": _graphlet_extractor,
}

def __init__(self, pool):
super(Service, self).__init__()
self.pool = pool

def Extract(self, request, context):
""" Extract features using multiple extrators """

extractors = []
results = []

for name in self.extractors_names:
for name in self.extractors:
if request.HasField(name):
options = getattr(request, name, None)
if options is None:
continue
constructor = getattr(self, "_%s_extractor" % name)
extractors.append(constructor(options))

result = self.pool.apply_async(self.extractors[name],
(request.uast, options))
results.append(result)

features = []

for ex in extractors:
features.extend(_features_iter_to_list(ex.extract(request.uast)))
for result in results:
features.extend(_features_from_iter(result.get()))

return service_pb2.FeaturesReply(features=features)

def Identifiers(self, request, context):
"""Extract identifiers weighted set"""

it = self._identifiers_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))
it = self.pool.apply(_identifiers_extractor,
(request.uast, request.options))
return service_pb2.FeaturesReply(features=_features_from_iter(it))

def Literals(self, request, context):
"""Extract literals weighted set"""

it = self._literals_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))
it = self.pool.apply(_literals_extractor,
(request.uast, request.options))
return service_pb2.FeaturesReply(features=_features_from_iter(it))

def Uast2seq(self, request, context):
"""Extract uast2seq weighted set"""

it = self._uast2seq_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))
it = self.pool.apply(_uast2seq_extractor,
(request.uast, request.options))
return service_pb2.FeaturesReply(features=_features_from_iter(it))

def Graphlet(self, request, context):
"""Extract graphlet weighted set"""

it = self._graphlet_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))

def _identifiers_extractor(self, options):
return IdentifiersBagExtractor(
docfreq_threshold=options.docfreqThreshold,
split_stem=options.splitStem,
weight=options.weight or 1)

def _literals_extractor(self, options):
return LiteralsBagExtractor(
docfreq_threshold=options.docfreqThreshold,
weight=options.weight or 1)

def _uast2seq_extractor(self, options):
seq_len = list(options.seqLen) if options.seqLen else None
it = self.pool.apply(_graphlet_extractor,
(request.uast, request.options))
return service_pb2.FeaturesReply(features=_features_from_iter(it))

return UastSeqBagExtractor(
docfreq_threshold=options.docfreqThreshold,
weight=options.weight or 1,
stride=options.stride or 1,
seq_len=seq_len or 5)

def _graphlet_extractor(self, options):
return GraphletBagExtractor(
docfreq_threshold=options.docfreqThreshold,
weight=options.weight or 1)


def _features_iter_to_list(f_iter):
return [service_pb2.Feature(name=f[0], weight=f[1]) for f in f_iter]
def worker_init():
    """Initializer for pool workers: ignore SIGINT (Ctrl-C).

    The parent process is the one that should react to KeyboardInterrupt and
    tear the pool down; workers must not die mid-task on the same signal.
    Background:
    https://stackoverflow.com/questions/1408356/keyboard-interrupts-with-pythons-multiprocessing-pool
    """
    ignore_handler = signal.SIG_IGN
    signal.signal(signal.SIGINT, ignore_handler)


def serve(port, workers):
logger = logging.getLogger('feature-extractor')

server = _get_server(port, workers)
# processes=None uses os.cpu_count() as a value
pool = Pool(processes=None, initializer=worker_init)

server = _get_server(port, workers, pool)
server.start()
logger.info("server started on port %d" % port)

Expand All @@ -110,19 +152,21 @@ def serve(port, workers):
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
pool.terminate()
Copy link

@se7entyse7en se7entyse7en Feb 7, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be called for exceptions other than KeyboardInterrupt. I'm actually not sure whether that would leak resources. Maybe you could use this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All sub-processes are created using fork. The operating system will take care of cleaning them up when the server exits, so it's not a big deal. Also, we run the feature-extractor in a container, and any exit causes a restart of the container, which cleans up everything.

I don't think we need to improve termination here as long as we don't have any real issue with how it works right now.

server.stop(0)


def _get_server(port, workers):
def _get_server(port, workers, pool):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=workers))
service_pb2_grpc.add_FeatureExtractorServicer_to_server(Service(), server)
service_pb2_grpc.add_FeatureExtractorServicer_to_server(
Service(pool), server)
server.add_insecure_port('[::]:%d' % port)
return server


if __name__ == '__main__':
port = int(os.getenv('FEATURE_EXT_PORT', "9001"))
workers = int(os.getenv('FEATURE_EXT_WORKERS', "10"))
workers = int(os.getenv('FEATURE_EXT_WORKERS', "100"))

parser = argparse.ArgumentParser(description='Feature Extractor Service.')
parser.add_argument(
Expand Down
12 changes: 7 additions & 5 deletions src/main/python/feature-extractor/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import sys
sys.path.append('./pb')

from multiprocessing import Pool
import grpc
import json
import unittest

import pb.service_pb2 as service_pb2
import pb.service_pb2_grpc as service_pb2_grpc
from google.protobuf.json_format import ParseDict as ProtoParseDict
from pb.service_pb2 import gopkg_dot_in_dot_bblfsh_dot_sdk_dot_v1_dot_uast_dot_generated__pb2 as uast_pb
from server import _get_server
# all grpc stuff must be imported from server and not directly from pb package
# otherwise requests will be failing with
# PicklingError: Can't pickle <class ...>: it's not the same object as ...
from server import _get_server, service_pb2, service_pb2_grpc, uast_pb


class TestServer(unittest.TestCase):
Expand All @@ -24,8 +25,9 @@ def setUp(self):
node.ParseFromString(f.read())
self.uast = node

pool = Pool(processes=1)
port = get_open_port()
self.server = _get_server(port, 1)
self.server = _get_server(port, 1, pool)
self.server.start()

channel = grpc.insecure_channel("localhost:%d" % port)
Expand Down