10 changes: 5 additions & 5 deletions src/csharp/CSharpTemplate.py
@@ -31,11 +31,11 @@ def parseCSharp(code):


if __name__ == '__main__':
print parseCSharp("public Boolean SomeValue { get { return someValue; } set { someValue = value; } }")
print parseCSharp("Console.WriteLine('cat'); int mouse = 5; int cat = 0.4; int cow = 'c'; int moo = \"mouse\"; ")
print parseCSharp("int i = 4; // i is assigned the literal value of '4' \n int j = i // j is assigned the value of i. Since i is a variable, //it can change and is not a 'literal'")
print(parseCSharp("public Boolean SomeValue { get { return someValue; } set { someValue = value; } }"))
print(parseCSharp("Console.WriteLine('cat'); int mouse = 5; int cat = 0.4; int cow = 'c'; int moo = \"mouse\"; "))
print(parseCSharp("int i = 4; // i is assigned the literal value of '4' \n int j = i // j is assigned the value of i. Since i is a variable, //it can change and is not a 'literal'"))
try:
print parseCSharp('string `fixed = Regex.Replace(input, "\s*()","$1");');
print(parseCSharp('string `fixed = Regex.Replace(input, "\s*()","$1");'));
except:
print "Error"
print("Error")
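
The CSharpTemplate edits above replace Python 2 print statements with print() calls. As a side note, a module that still has to run on Python 2.7 can use the same call syntax via a __future__ import; a minimal sketch, not part of this patch:

# Hedged sketch only: with this import, the print() form used above also works on Python 2.7.
from __future__ import print_function

print("runs the same under Python 2.7 and Python 3")  # the statement form would now be a SyntaxError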

3 changes: 1 addition & 2 deletions src/model/buildData.lua
@@ -86,8 +86,7 @@ function main()
cmd:text()
opt = cmd:parse(arg)
local working_dir = os.getenv("CODENN_WORK")

local vocabFile = io.open(working_dir .. "/vocab." .. opt.language, 'r')
local vocabFile = io.open(working_dir .. "/vocab." .. opt.language, 'r')
local vocab = JSON:decode(vocabFile:read())
vocabFile:close()
torch.save(working_dir .. '/vocab.data.' .. opt.language , vocab)
25 changes: 14 additions & 11 deletions src/model/buildData.py
@@ -13,11 +13,11 @@
END = 4

def tokenizeNL(nl):
nl = nl.strip().decode('utf-8').encode('ascii', 'replace')
nl = nl.strip()
return re.findall(r"[\w]+|[^\s\w]", nl)

def tokenizeCode(code, lang):
code = code.strip().decode('utf-8').encode('ascii', 'replace')
code = code.strip()
typedCode = None
if lang == "sql":
query = SqlTemplate(code, regex=True)
@@ -43,9 +43,12 @@ def buildVocab(filename, code_unk_threshold, nl_unk_threshold, lang):
tokens = collections.Counter()

for line in open(filename, "r"):
qid, rid, nl, code, weight = line.strip().split('\t')
tokens.update(tokenizeCode(code, lang))
words.update(tokenizeNL(nl))
if len(line.strip().split('\t')) == 5:
qid, rid, nl, code, weight = line.strip().split('\t')
tokens.update(tokenizeCode(code, lang))
words.update(tokenizeNL(nl))
#tokens.update(tokenizeCode(code, lang))
#words.update(tokenizeNL(nl))

token_count = END + 1
nl_count = END + 1
@@ -83,10 +86,10 @@ def get_data(filename, vocab, dont_skip, max_code_length, max_nl_length):
dataset = []
skipped = 0
for line in open(filename, 'r'):

qid, rid, nl, code, wt = line.strip().split('\t')
codeToks = tokenizeCode(code, vocab["lang"])
nlToks = tokenizeNL(nl)
if len(line.strip().split('\t'))==5:
qid, rid, nl, code, wt = line.strip().split('\t')
codeToks = tokenizeCode(code, vocab["lang"])
nlToks = tokenizeNL(nl)

datasetEntry = {"id": rid, "code": code, "code_sizes": len(codeToks), "code_num":[], "nl_num":[]}

@@ -108,8 +111,8 @@ def get_data(filename, vocab, dont_skip, max_code_length, max_nl_length):
else:
skipped += 1

print 'Total size = ' + str(len(dataset))
print 'Total skipped = ' + str(skipped)
print('Total size = ' + str(len(dataset)))
print('Total skipped = ' + str(skipped))

f = open(os.environ["CODENN_WORK"] + '/' + os.path.basename(filename) + "." + lang, 'w')
f.write(json.dumps(dataset))
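
Both buildData.py hunks wrap the tab-split in a length check, so rows that do not have exactly five tab-separated fields are skipped instead of raising a ValueError during unpacking. A rough sketch of the same guard as a standalone helper; the helper name and the returned tuple are illustrative, not code from this patch:

# Illustrative sketch (not from the patch): yield only well-formed five-field rows.
def iter_valid_rows(path):
    skipped = 0
    with open(path, "r") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) == 5:
                yield tuple(fields)          # (qid, rid, nl, code, weight)
            else:
                skipped += 1                 # malformed row, ignore it
    print("Total skipped = " + str(skipped))
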
4 changes: 2 additions & 2 deletions src/model/buildData.sh
@@ -15,8 +15,8 @@ SQL_UNK_THRESHOLD=3
CSHARP_UNK_THRESHOLD=2
NL_UNK_THRESHOLD=2

python buildData.py sql $MAX_CODE_LENGTH $MAX_NL_LENGTH $SQL_UNK_THRESHOLD $NL_UNK_THRESHOLD
python buildData.py csharp $MAX_CODE_LENGTH $MAX_NL_LENGTH $CSHARP_UNK_THRESHOLD $NL_UNK_THRESHOLD
python3 buildData.py sql $MAX_CODE_LENGTH $MAX_NL_LENGTH $SQL_UNK_THRESHOLD $NL_UNK_THRESHOLD
python3 buildData.py csharp $MAX_CODE_LENGTH $MAX_NL_LENGTH $CSHARP_UNK_THRESHOLD $NL_UNK_THRESHOLD


th buildData.lua -language sql -max_code_length $MAX_CODE_LENGTH -max_nl_length $MAX_NL_LENGTH -batch_size $BATCH_SIZE
16 changes: 8 additions & 8 deletions src/sql/SqlTemplate.py
@@ -2,7 +2,7 @@
from sql.ParseTypes import *
import pdb
import re
from regexp_tokenizer import tokenizeRegex
from .regexp_tokenizer import tokenizeRegex

class SqlTemplate:

@@ -62,9 +62,9 @@ def renameIdentifiers(self, tok):
tok.value = "CODE_HEX"

def __hash__(self):
return hash(tuple([str(x) for x in self.tokensWithBlanks]))
return hash(tuple([str(x) for x in self.tokensWithBlanks]))

def __init__(self, sql, regex=False, rename=True):
def __init__(self, sql, regex=False, rename=True):

self.sql = SqlTemplate.sanitizeSql(sql)

@@ -123,7 +123,7 @@ def identifySubQueries(self, tokenList):

for tok in tokenList.tokens:
if isinstance(tok, sqlparse.sql.TokenList):
subQuery = self.identifySubQueries(tok)
subQuery = self.identifySubQueries(tok)
if (subQuery and isinstance(tok, sqlparse.sql.Parenthesis)):
tok.ptype = SUBQUERY
elif str(tok) == "select":
@@ -156,7 +156,7 @@ def identifyLiterals(self, tokenList):
tok.ptype = WILDCARD
elif (tok.ttype in blankTokens or isinstance(tok, blankTokenTypes[0])):
tok.ptype = COLUMN

def identifyFunctions(self, tokenList):
for tok in tokenList.tokens:
if (isinstance(tok, sqlparse.sql.Function)):
@@ -173,8 +173,8 @@ def identifyTables(self, tokenList):
if tokenList.ptype == SUBQUERY:
self.tableStack.append(False)

for i in xrange(len(tokenList.tokens)):
prevtok = tokenList.tokens[i - 1] # Possible bug but unlikely
for i in range(len(tokenList.tokens)):
prevtok = tokenList.tokens[i - 1] # Possible bug but unlikely
tok = tokenList.tokens[i]

if (str(tok) == "." and tok.ttype == sqlparse.tokens.Punctuation and prevtok.ptype == COLUMN):
@@ -187,7 +187,7 @@ def identifyTables(self, tokenList):
self.tableStack[-1] = False

if isinstance(tok, sqlparse.sql.TokenList):
self.identifyTables(tok)
self.identifyTables(tok)

elif (tok.ptype == COLUMN):
if self.tableStack[-1]:
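
In identifyTables the loop index now comes from range(), which in Python 3 is the lazy counterpart of the removed xrange(). The body keeps the existing tokens[i - 1] lookup; a small illustration, with a made-up token list, of the wrap-around behaviour that the inline "Possible bug but unlikely" comment refers to:

# Illustration only (token list is hypothetical, not sqlparse output).
tokens = ["SELECT", "a", "FROM", "t"]
for i in range(len(tokens)):      # range() here plays the role of Python 2's xrange()
    prevtok = tokens[i - 1]       # when i == 0 this is tokens[-1], i.e. the LAST token
    if i == 0:
        assert prevtok == "t"     # wraps around instead of failing
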
4 changes: 2 additions & 2 deletions src/sql/regexp_tokenizer.py
@@ -17,5 +17,5 @@ def tokenizeRegex(s):
return results

if __name__ == '__main__':
print tokenizeRegex("^discount[^(]*\\([0-9]+\\%\\)$")
print tokenizeRegex("'helloworld'")
print(tokenizeRegex("^discount[^(]*\\([0-9]+\\%\\)$"))
print(tokenizeRegex("'helloworld'"))
26 changes: 16 additions & 10 deletions src/sqlparse/sqlparse/engine/grouping.py
@@ -8,7 +8,7 @@
try:
next
except NameError: # Python < 2.6
next = lambda i: i.next()
next = lambda i: i.__next__()


def _group_left_right(tlist, ttype, value, cls,
@@ -173,14 +173,22 @@ def _consume_cycle(tl, i):
for t in tl.tokens[i:]:
# Don't take whitespaces into account.
if t.ttype is T.Whitespace:
yield t
try:
yield t
except StopIteration:
return
continue
if next(x)(t):
yield t
try:
yield t
except StopIteration:
return
else:
if isinstance(t, sql.Comment) and t.is_multiline():
yield t
raise StopIteration
try:
yield t
except StopIteration:
return

def _next_token(tl, i):
# chooses the next token. if two tokens are found then the
@@ -202,16 +210,14 @@ def _next_token(tl, i):
return t2

# bottom up approach: group subgroups first
[group_identifier(sgroup) for sgroup in tlist.get_sublists()
if not isinstance(sgroup, sql.Identifier)]
[group_identifier(sgroup) for sgroup in tlist.get_sublists() if not isinstance(sgroup, sql.Identifier)]

# real processing
idx = 0
token = _next_token(tlist, idx)
while token:
identifier_tokens = [token] + list(
_consume_cycle(tlist,
tlist.token_index(token) + 1))
k = _consume_cycle(tlist,tlist.token_index(token)+1)
identifier_tokens = [token] + list(_consume_cycle(tlist,tlist.token_index(token) + 1))
# remove trailing whitespace
if identifier_tokens and identifier_tokens[-1].ttype is T.Whitespace:
identifier_tokens = identifier_tokens[:-1]
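
The _consume_cycle generator used to stop itself with raise StopIteration; under PEP 479 that raise is no longer a clean stop and, from Python 3.7 on, surfaces as a RuntimeError, which is what the try/except and return edits above are working around. A short background sketch of that behaviour, not code from sqlparse:

# Sketch only: why 'raise StopIteration' had to go from generator bodies.
def stops_badly():
    yield 1
    raise StopIteration          # PEP 479: becomes RuntimeError in Python 3.7+

def stops_cleanly():
    yield 1
    return                       # plain return ends a generator early

print(list(stops_cleanly()))     # [1]
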
4 changes: 2 additions & 2 deletions src/sqlparse/sqlparse/filters.py
@@ -154,7 +154,7 @@ def process(self, stack, stream):
f.close()

# There was a problem loading the include file
except IOError, err:
except IOError as err:
# Raise the exception to the interpreter
if self.raiseexceptions:
raise
@@ -171,7 +171,7 @@ def process(self, stack, stream):
self.raiseexceptions)

# Max recursion limit reached
except ValueError, err:
except ValueError as err:
# Raise the exception to the interpreter
if self.raiseexceptions:
raise
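
The filters.py hunks only need the exception-binding syntax updated. A minimal sketch of the Python 3 form (the file path is made up):

# Sketch only: Python 3 accepts just the 'as' form for naming the caught exception.
try:
    open("/nonexistent/include.sql")
except IOError as err:           # 'except IOError, err:' is a SyntaxError in Python 3
    print("could not load include file:", err)
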
20 changes: 9 additions & 11 deletions src/sqlparse/sqlparse/lexer.py
@@ -17,7 +17,7 @@

from sqlparse import tokens
from sqlparse.keywords import KEYWORDS, KEYWORDS_COMMON
from cStringIO import StringIO
from io import StringIO


class include(str):
@@ -81,7 +81,7 @@ def _process_state(cls, unprocessed, processed, state):

try:
rex = re.compile(tdef[0], rflags).match
except Exception, err:
except Exception as err:
raise ValueError(("uncompilable regex %r in state"
" %r of %r: %s"
% (tdef[0], state, cls, err)))
@@ -135,7 +135,7 @@ def process_tokendef(cls):
cls._tmpname = 0
processed = cls._all_tokens[cls.__name__] = {}
#tokendefs = tokendefs or cls.tokens[name]
for state in cls.tokens.keys():
for state in list(cls.tokens.keys()):
cls._process_state(cls.tokens, processed, state)
return processed

@@ -152,9 +152,7 @@ def __call__(cls, *args, **kwds):
return type.__call__(cls, *args, **kwds)


class Lexer(object):

__metaclass__ = LexerMeta
class Lexer(object, metaclass=LexerMeta):

encoding = 'utf-8'
stripall = False
@@ -235,8 +233,8 @@ def _decode(self, text):
if self.encoding == 'guess':
try:
text = text.decode('utf-8')
if text.startswith(u'\ufeff'):
text = text[len(u'\ufeff'):]
if text.startswith('\ufeff'):
text = text[len('\ufeff'):]
except UnicodeDecodeError:
text = text.decode('latin1')
else:
@@ -258,13 +256,13 @@ def get_tokens(self, text, unfiltered=False):
Also preprocess the text, i.e. expand tabs and strip it if
wanted and applies registered filters.
"""
if isinstance(text, basestring):
if isinstance(text,str):
if self.stripall:
text = text.strip()
elif self.stripnl:
text = text.strip('\n')

if sys.version_info[0] < 3 and isinstance(text, unicode):
if sys.version_info[0] < 3 and isinstance(text, str):
text = StringIO(text.encode('utf-8'))
self.encoding = 'utf-8'
else:
@@ -342,7 +340,7 @@ def get_tokens_unprocessed(self, stream, stack=('root',)):
pos += 1
statestack = ['root']
statetokens = tokendefs['root']
yield pos, tokens.Text, u'\n'
yield pos, tokens.Text, '\n'
continue
yield pos, tokens.Error, text[pos]
pos += 1
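
The lexer.py changes collect several 2-to-3 conversions: io.StringIO replaces cStringIO, str stands in for basestring and unicode, the u'' prefixes are dropped, and the lexer's metaclass moves from a class attribute into the class header. A reduced sketch of that last change; Meta and Demo are placeholder names, not the sqlparse classes:

# Sketch only: Python 2 set '__metaclass__ = Meta' inside the class body;
# Python 3 passes it as a keyword in the class header, as the patch does.
class Meta(type):
    pass

class Demo(object, metaclass=Meta):
    pass

print(type(Demo) is Meta)        # True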