Skip to content

Compose literals for argument values in docstring #2668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,73 @@ class cpp_function : public function {
return new detail::function_record();
}

template<typename ContainerHandle, typename StringifierUnary>
static std::string join(ContainerHandle container, StringifierUnary&& f, std::string sep = ", ") {
std::string joined;
for (auto element : container) {
joined += f(element) + sep;
}
if (!joined.empty()) {
joined.erase(joined.size() - sep.size());
}
return joined;
}

// Generate a literal expression for default function argument values
static std::string compose_literal(pybind11::handle h) {
auto typehandle = type::handle_of(h);
if (detail::get_internals().registered_types_py.count(Py_TYPE(h.ptr())) > 0) {
if (hasattr(typehandle, "__members__") && hasattr(h, "name")) {
// Bound enum type, can be fully represented
auto descr = typehandle.attr("__module__").cast<std::string>();
descr += "." + typehandle.attr("__qualname__").cast<std::string>();
descr += "." + h.attr("name").cast<std::string>();
return descr;
}

// Use ellipsis expression instead of repr to ensure syntactic validity
return "...";
}

if (isinstance<dict>(h)) {
std::string literal = "{";
literal += join(
reinterpret_borrow<dict>(h),
[](const std::pair<handle, handle>& v) { return compose_literal(v.first) + ": " + compose_literal(v.second); }
);
literal += "}";
return literal;
}

if (isinstance<list>(h)) {
std::string literal = "[";
literal += join(reinterpret_borrow<list>(h), &compose_literal);
literal += "]";
return literal;
}

if (isinstance<tuple>(h)) {
std::string literal = "(";
literal += join(reinterpret_borrow<tuple>(h), &compose_literal);
literal += ")";
return literal;
}

if (isinstance<set>(h)) {
auto v = reinterpret_borrow<set>(h);
if (v.empty()) {
return "set()";
}
std::string literal = "{";
literal += join(v, &compose_literal);
literal += "}";
return literal;
}

// All other types should be terminal and well-represented by repr
return repr(h).cast<std::string>();
}

/// Special internal constructor for functors, lambda functions, etc.
template <typename Func, typename Return, typename... Args, typename... Extra>
void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) {
Expand Down Expand Up @@ -235,8 +302,10 @@ class cpp_function : public function {
a.name = strdup(a.name);
if (a.descr)
a.descr = strdup(a.descr);
else if (a.value)
a.descr = strdup(repr(a.value).cast<std::string>().c_str());
else if (a.value) {
std::string literal = compose_literal(a.value);
a.descr = strdup(literal.c_str());
}
}

rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__");
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ set(PYBIND11_TEST_FILES
test_copy_move.cpp
test_custom_type_casters.cpp
test_docstring_options.cpp
test_docstring_function_signature.cpp
test_eigen.cpp
test_enum.cpp
test_eval.cpp
Expand Down
23 changes: 23 additions & 0 deletions tests/test_docstring_function_signature.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
tests/test_docstring_options.cpp -- generation of docstrings function signatures

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/

#include "pybind11_tests.h"
#include "pybind11/stl.h"

enum class Color {Red};

TEST_SUBMODULE(docstring_function_signature, m) {
// test_docstring_function_signatures
pybind11::enum_<Color> (m, "Color").value("Red", Color::Red);
m.def("a", [](Color) {}, pybind11::arg("a") = Color::Red);
m.def("b", [](int) {}, pybind11::arg("a") = 1);
m.def("c", [](std::vector<int>) {}, pybind11::arg("a") = std::vector<int> {{1, 2, 3, 4}});
m.def("d", [](UserType) {}, pybind11::arg("a") = UserType {});
m.def("e", [](std::pair<UserType, int>) {}, pybind11::arg("a") = std::make_pair<UserType, int>(UserType(), 4));
m.def("f", [](std::vector<Color>) {}, pybind11::arg("a") = std::vector<Color> {Color::Red});
m.def("g", [](std::tuple<int, Color, double>) {}, pybind11::arg("a") = std::make_tuple(4, Color::Red, 1.9));
}
41 changes: 41 additions & 0 deletions tests/test_docstring_function_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
from pybind11_tests import docstring_function_signature as m
import sys


def test_docstring_function_signature():
def syntactically_valid(sig):
try:
complete_fnsig = "def " + sig + ": pass"
ast.parse(complete_fnsig)
return True
except SyntaxError:
return False

pass

methods = ["a", "b", "c", "d", "e", "f", "g"]
root_module = "pybind11_tests"
module = "{}.{}".format(root_module, "docstring_function_signature")
expected_signatures = [
"a(a: {0}.Color = {0}.Color.Red) -> None".format(module),
"b(a: int = 1) -> None",
"c(a: List[int] = [1, 2, 3, 4]) -> None",
"d(a: {}.UserType = ...) -> None".format(root_module),
"e(a: Tuple[{}.UserType, int] = (..., 4)) -> None".format(root_module),
"f(a: List[{0}.Color] = [{0}.Color.Red]) -> None".format(module),
"g(a: Tuple[int, {0}.Color, float] = (4, {0}.Color.Red, 1.9)) -> None".format(
module
),
]

for method, signature in zip(methods, expected_signatures):
docstring = getattr(m, method).__doc__.strip("\n")
assert docstring == signature

if sys.version_info.major >= 3 and sys.version_info.minor >= 5:
import ast

for method in methods:
docstring = getattr(m, method).__doc__.strip("\n")
assert syntactically_valid(docstring)