Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions src/main/proto/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import "gopkg.in/bblfsh/sdk.v1/uast/generated.proto";

// Feature Extractor Service
service FeatureExtractor {
// Extract runs multiple extractors on the same UAST
rpc Extract (ExtractRequest) returns (FeaturesReply) {}

// Extract identifiers weighted set
rpc Identifiers (IdentifiersRequest) returns (FeaturesReply) {}
// Extract literals weighted set
Expand All @@ -17,35 +20,59 @@ service FeatureExtractor {
rpc Graphlet (GraphletRequest) returns (FeaturesReply) {}
}

// Configuration for the identifiers extractor. Carried by IdentifiersRequest
// (as `options`) and by ExtractRequest (as `identifiers`).
message IdentifiersOptions {
// Forwarded to the extractor as `docfreq_threshold`.
// NOTE(review): presumably a minimum document frequency - confirm against
// the sourced.ml extractor documentation.
int32 docfreqThreshold = 1;
// Forwarded to the extractor as `weight`. The Python server substitutes 1
// when this is 0 (which is also the value of an unset proto3 scalar).
int32 weight = 2;
// Forwarded to the extractor as `split_stem`.
bool splitStem = 3;
}

// Configuration for the literals extractor. Carried by LiteralsRequest
// (as `options`) and by ExtractRequest (as `literals`).
message LiteralsOptions {
// Forwarded to the extractor as `docfreq_threshold`.
int32 docfreqThreshold = 1;
// Forwarded to the extractor as `weight`. The Python server substitutes 1
// when this is 0 (which is also the value of an unset proto3 scalar).
int32 weight = 2;
}

// Configuration for the uast2seq extractor. Carried by Uast2seqRequest
// (as `options`) and by ExtractRequest (as `uast2seq`).
message Uast2seqOptions {
// Forwarded to the extractor as `docfreq_threshold`.
int32 docfreqThreshold = 1;
// Forwarded to the extractor as `weight`. The Python server substitutes 1
// when this is 0 (which is also the value of an unset proto3 scalar).
int32 weight = 2;
// Forwarded to the extractor as `stride`. The Python server substitutes 1
// when this is 0 or unset.
int32 stride = 3;
// Forwarded to the extractor as `seq_len` (converted to a list). When no
// values are given the Python server passes 5 instead.
repeated int32 seqLen = 4;
}

// Configuration for the graphlet extractor. Carried by GraphletRequest
// (as `options`) and by ExtractRequest (as `graphlet`).
message GraphletOptions {
// Forwarded to the extractor as `docfreq_threshold`.
int32 docfreqThreshold = 1;
// Forwarded to the extractor as `weight`. The Python server substitutes 1
// when this is 0 (which is also the value of an unset proto3 scalar).
int32 weight = 2;
}

// Request for the Extract RPC: one UAST plus the options of every extractor
// that should run on it. The Python server runs an extractor only when its
// options field is set (checked with HasField) and concatenates the resulting
// features in the order: identifiers, literals, uast2seq, graphlet.
message ExtractRequest {
// UAST node that every selected extractor is run against.
gopkg.in.bblfsh.sdk.v1.uast.Node uast = 1;
// When set, the identifiers extractor runs with these options.
IdentifiersOptions identifiers = 2;
// When set, the literals extractor runs with these options.
LiteralsOptions literals = 3;
// When set, the uast2seq extractor runs with these options.
Uast2seqOptions uast2seq = 4;
// When set, the graphlet extractor runs with these options.
GraphletOptions graphlet = 5;
}

// The identifiers request message containing extractor configuration and uast.
message IdentifiersRequest {
gopkg.in.bblfsh.sdk.v1.uast.Node uast = 1;
int32 docfreqThreshold = 2;
int32 weight = 3;
bool splitStem = 4;
IdentifiersOptions options = 2;
}

// The literals request message containing extractor configuration and uast.
message LiteralsRequest {
gopkg.in.bblfsh.sdk.v1.uast.Node uast = 1;
int32 docfreqThreshold = 2;
int32 weight = 3;
LiteralsOptions options = 2;
}

// The uast2seq request message containing extractor configuration and uast.
message Uast2seqRequest {
gopkg.in.bblfsh.sdk.v1.uast.Node uast = 1;
int32 docfreqThreshold = 2;
int32 weight = 3;
int32 stride = 4;
repeated int32 seqLen = 5;
Uast2seqOptions options = 2;
}

// The graphlet request message containing extractor configuration and uast.
message GraphletRequest {
gopkg.in.bblfsh.sdk.v1.uast.Node uast = 1;
int32 docfreqThreshold = 2;
int32 weight = 3;
GraphletOptions options = 2;
}

message Feature {
Expand Down
371 changes: 306 additions & 65 deletions src/main/python/feature-extractor/pb/service_pb2.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions src/main/python/feature-extractor/pb/service_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ def __init__(self, channel):
Args:
channel: A grpc.Channel.
"""
self.Extract = channel.unary_unary(
'/tech.sourced.featurext.generated.FeatureExtractor/Extract',
request_serializer=service__pb2.ExtractRequest.SerializeToString,
response_deserializer=service__pb2.FeaturesReply.FromString,
)
self.Identifiers = channel.unary_unary(
'/tech.sourced.featurext.generated.FeatureExtractor/Identifiers',
request_serializer=service__pb2.IdentifiersRequest.SerializeToString,
Expand All @@ -40,6 +45,13 @@ class FeatureExtractorServicer(object):
"""Feature Extractor Service
"""

def Extract(self, request, context):
"""Extract allows to run multiple extractors on the same uast
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Identifiers(self, request, context):
"""Extract identifiers weighted set
"""
Expand Down Expand Up @@ -71,6 +83,11 @@ def Graphlet(self, request, context):

def add_FeatureExtractorServicer_to_server(servicer, server):
rpc_method_handlers = {
'Extract': grpc.unary_unary_rpc_method_handler(
servicer.Extract,
request_deserializer=service__pb2.ExtractRequest.FromString,
response_serializer=service__pb2.FeaturesReply.SerializeToString,
),
'Identifiers': grpc.unary_unary_rpc_method_handler(
servicer.Identifiers,
request_deserializer=service__pb2.IdentifiersRequest.FromString,
Expand Down
88 changes: 58 additions & 30 deletions src/main/python/feature-extractor/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import time

import grpc
from sourced.ml.extractors import IdentifiersBagExtractor, LiteralsBagExtractor, UastSeqBagExtractor, GraphletBagExtractor
from sourced.ml.extractors import IdentifiersBagExtractor, LiteralsBagExtractor, \
UastSeqBagExtractor, GraphletBagExtractor

import pb.service_pb2 as service_pb2
import pb.service_pb2_grpc as service_pb2_grpc
Expand All @@ -20,53 +21,80 @@
class Service(service_pb2_grpc.FeatureExtractorServicer):
"""Feature Extractor Service"""

extractors_names = ["identifiers", "literals", "uast2seq", "graphlet"]

def Extract(self, request, context):
""" Extract features using multiple extractors """

extractors = []

for name in self.extractors_names:
if request.HasField(name):
options = getattr(request, name, None)
if options is None:
continue
constructor = getattr(self, "_%s_extractor" % name)
extractors.append(constructor(options))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor functions do not check for options to be None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand the generated gRPC code, it can't be None if request.HasField returns true. But I added the check just in case something changes. Thanks!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok, didn't know that! But yup, better be sure by still adding the check 👍


features = []

for ex in extractors:
features.extend(_features_iter_to_list(ex.extract(request.uast)))

return service_pb2.FeaturesReply(features=features)

def Identifiers(self, request, context):
"""Extract identifiers weighted set"""

extractor = IdentifiersBagExtractor(
docfreq_threshold=request.docfreqThreshold,
split_stem=request.splitStem,
weight=request.weight or 1)

return self._create_response(extractor.extract(request.uast))
it = self._identifiers_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))

def Literals(self, request, context):
"""Extract literals weighted set"""

extractor = LiteralsBagExtractor(
docfreq_threshold=request.docfreqThreshold,
weight=request.weight or 1)

return self._create_response(extractor.extract(request.uast))
it = self._literals_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))

def Uast2seq(self, request, context):
"""Extract uast2seq weighted set"""

seq_len = list(request.seqLen) if request.seqLen else None

extractor = UastSeqBagExtractor(
docfreq_threshold=request.docfreqThreshold,
weight=request.weight or 1,
stride=request.stride or 1,
seq_len=seq_len or 5)

return self._create_response(extractor.extract(request.uast))
it = self._uast2seq_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))

def Graphlet(self, request, context):
"""Extract graphlet weighted set"""

extractor = GraphletBagExtractor(
docfreq_threshold=request.docfreqThreshold,
weight=request.weight or 1)
it = self._graphlet_extractor(request.options).extract(request.uast)
return service_pb2.FeaturesReply(features=_features_iter_to_list(it))

return self._create_response(extractor.extract(request.uast))
def _identifiers_extractor(self, options):
    """Build an IdentifiersBagExtractor from an IdentifiersOptions message"""
    # An unset proto3 int reads as 0, so a zero weight falls back to 1.
    effective_weight = options.weight or 1
    kwargs = {
        "docfreq_threshold": options.docfreqThreshold,
        "split_stem": options.splitStem,
        "weight": effective_weight,
    }
    return IdentifiersBagExtractor(**kwargs)

def _create_response(self, f_iter):
features = [
service_pb2.Feature(name=f[0], weight=f[1]) for f in f_iter
]
def _literals_extractor(self, options):
    """Build a LiteralsBagExtractor from a LiteralsOptions message"""
    # An unset proto3 int reads as 0, so a zero weight falls back to 1.
    effective_weight = options.weight or 1
    threshold = options.docfreqThreshold
    return LiteralsBagExtractor(
        docfreq_threshold=threshold, weight=effective_weight)

return service_pb2.FeaturesReply(features=features)
def _uast2seq_extractor(self, options):
    """Build a UastSeqBagExtractor from a Uast2seqOptions message"""
    # An empty repeated field yields an empty list, which falls back to 5;
    # zero (unset) weight and stride both fall back to 1.
    lengths = list(options.seqLen) or 5
    effective_weight = options.weight or 1
    effective_stride = options.stride or 1

    return UastSeqBagExtractor(
        docfreq_threshold=options.docfreqThreshold,
        weight=effective_weight,
        stride=effective_stride,
        seq_len=lengths)

def _graphlet_extractor(self, options):
    """Build a GraphletBagExtractor from a GraphletOptions message"""
    # An unset proto3 int reads as 0, so a zero weight falls back to 1.
    effective_weight = options.weight or 1
    threshold = options.docfreqThreshold
    return GraphletBagExtractor(
        docfreq_threshold=threshold, weight=effective_weight)


def _features_iter_to_list(f_iter):
    """Convert an iterable of indexable (name, weight) items to Feature messages"""
    converted = []
    for item in f_iter:
        # Index 0 is the feature name, index 1 its weight.
        converted.append(service_pb2.Feature(name=item[0], weight=item[1]))
    return converted


def serve(port, workers):
Expand Down
48 changes: 37 additions & 11 deletions src/main/python/feature-extractor/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,46 @@ def setUp(self):
def tearDown(self):
self.server.stop(0)

def test_Extract(self):
    """Extract runs the identifiers and literals extractors together"""
    identifiers_opts = service_pb2.IdentifiersOptions(
        docfreqThreshold=5, splitStem=False)
    literals_opts = service_pb2.LiteralsOptions(docfreqThreshold=5)
    request = service_pb2.ExtractRequest(
        uast=self.uast,
        identifiers=identifiers_opts,
        literals=literals_opts,
    )

    features = self.stub.Extract(request).features

    # 49 identifier features come first, followed by 16 literal features.
    self.assertEqual(len(features), 49 + 16)
    first = features[0]
    self.assertEqual(first.name, 'i.sys')
    self.assertEqual(first.weight, 1)
    self.assertEqual(features[49].name, 'l.3b286224b098296c')

def test_Identifiers(self):
response = self.stub.Identifiers(
service_pb2.IdentifiersRequest(
docfreqThreshold=5, splitStem=False, uast=self.uast))
uast=self.uast,
options=service_pb2.IdentifiersOptions(
docfreqThreshold=5, splitStem=False)))

self.assertEqual(len(response.features), 49)
self.assertEqual(response.features[0].name, 'i.sys')
self.assertEqual(response.features[0].weight, 1)

def test_Literals(self):
response = self.stub.Literals(
service_pb2.LiteralsRequest(docfreqThreshold=5, uast=self.uast))
service_pb2.LiteralsRequest(
uast=self.uast,
options=service_pb2.LiteralsOptions(docfreqThreshold=5)))

self.assertEqual(len(response.features), 16)
self.assertEqual(response.features[0].name, 'l.3b286224b098296c')
self.assertEqual(response.features[0].weight, 1)

def test_Uast2seq(self):
response = self.stub.Uast2seq(
service_pb2.Uast2seqRequest(docfreqThreshold=5, uast=self.uast))
service_pb2.Uast2seqRequest(
uast=self.uast,
options=service_pb2.Uast2seqOptions(docfreqThreshold=5)))

self.assertEqual(len(response.features), 207)
self.assertEqual(response.features[0].name,
Expand All @@ -62,7 +82,9 @@ def test_Uast2seq(self):

def test_Graphlet(self):
response = self.stub.Graphlet(
service_pb2.GraphletRequest(docfreqThreshold=5, uast=self.uast))
service_pb2.GraphletRequest(
uast=self.uast,
options=service_pb2.GraphletOptions(docfreqThreshold=5)))

self.assertEqual(len(response.features), 106)
self.assertEqual(response.features[1].name,
Expand All @@ -72,29 +94,33 @@ def test_Graphlet(self):
def test_with_weight(self):
response = self.stub.Identifiers(
service_pb2.IdentifiersRequest(
docfreqThreshold=5, splitStem=False, uast=self.uast, weight=2))
uast=self.uast,
options=service_pb2.IdentifiersOptions(
docfreqThreshold=5, splitStem=False, weight=2)))

self.assertEqual(response.features[0].weight, 2)

response = self.stub.Literals(
service_pb2.LiteralsRequest(
docfreqThreshold=5, uast=self.uast, weight=2))
uast=self.uast,
options=service_pb2.LiteralsOptions(
docfreqThreshold=5, weight=2)))

self.assertEqual(response.features[0].weight, 2)

response = self.stub.Uast2seq(
service_pb2.Uast2seqRequest(
docfreqThreshold=5,
uast=self.uast,
weight=2,
stride=2,
seqLen=[1]))
options=service_pb2.Uast2seqOptions(
docfreqThreshold=5, weight=2, stride=2, seqLen=[1])))

self.assertEqual(response.features[0].weight, 6)

response = self.stub.Graphlet(
service_pb2.GraphletRequest(
docfreqThreshold=5, uast=self.uast, weight=2))
uast=self.uast,
options=service_pb2.GraphletOptions(
docfreqThreshold=5, weight=2)))

self.assertEqual(response.features[0].weight, 2)

Expand Down
Loading