Skip to content

Replace stream with cycles #7182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.util.Assert;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

/**
* Exports the authentication {@link Configuration}
Expand Down Expand Up @@ -153,10 +152,7 @@ private <T> T lazyBean(Class<T> interfaceName) {
}
String beanName;
if (beanNamesForType.length > 1) {
List<String> primaryBeanNames = Arrays.stream(beanNamesForType)
.filter(i -> applicationContext instanceof ConfigurableApplicationContext)
.filter(n -> ((ConfigurableApplicationContext) applicationContext).getBeanFactory().getBeanDefinition(n).isPrimary())
.collect(Collectors.toList());
List<String> primaryBeanNames = getPrimaryBeanNames(beanNamesForType);

Assert.isTrue(primaryBeanNames.size() != 0, () -> "Found " + beanNamesForType.length
+ " beans for type " + interfaceName + ", but none marked as primary");
Expand All @@ -175,6 +171,20 @@ private <T> T lazyBean(Class<T> interfaceName) {
return (T) proxyFactory.getObject();
}

private List<String> getPrimaryBeanNames(String[] beanNamesForType) {
List<String> list = new ArrayList<>();
if (!(applicationContext instanceof ConfigurableApplicationContext)) {
return Collections.emptyList();
}
for (String beanName : beanNamesForType) {
if (((ConfigurableApplicationContext) applicationContext).getBeanFactory()
.getBeanDefinition(beanName).isPrimary()) {
list.add(beanName);
}
}
return list;
}

private AuthenticationManager getAuthenticationManagerBean() {
return lazyBean(AuthenticationManager.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

/**
* A {@link ReactiveAuthorizationManager} that determines if the current user is
Expand Down Expand Up @@ -109,9 +108,14 @@ public static <T> AuthorityReactiveAuthorizationManager<T> hasAnyRole(String...
Assert.notNull(role, "role cannot be null");
}

return hasAnyAuthority(Stream.of(roles)
.map(r -> "ROLE_" + r)
.toArray(String[]::new)
);
return hasAnyAuthority(toNamedRolesArray(roles));
}

private static String[] toNamedRolesArray(String... roles) {
String[] result = new String[roles.length];
for (int i=0; i < roles.length; i++) {
result[i] = "ROLE_" + roles[i];
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@

package org.springframework.security.converter;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.Base64;
import java.util.stream.Collectors;

import org.springframework.core.convert.converter.Converter;
Expand Down Expand Up @@ -66,10 +66,13 @@ public static Converter<InputStream, RSAPrivateKey> pkcs8() {
Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(PKCS8_PEM_HEADER),
"Key is not in PEM-encoded PKCS#8 format, " +
"please check that the header begins with -----" + PKCS8_PEM_HEADER + "-----");
String base64Encoded = lines.stream()
.filter(RsaKeyConverters::isNotPkcs8Wrapper)
.collect(Collectors.joining());
byte[] pkcs8 = Base64.getDecoder().decode(base64Encoded);
StringBuilder base64Encoded = new StringBuilder();
for (String line : lines) {
if (RsaKeyConverters.isNotPkcs8Wrapper(line)) {
base64Encoded.append(line);
}
}
byte[] pkcs8 = Base64.getDecoder().decode(base64Encoded.toString());

try {
return (RSAPrivateKey) keyFactory.generatePrivate(
Expand Down Expand Up @@ -97,10 +100,13 @@ public static Converter<InputStream, RSAPublicKey> x509() {
Assert.isTrue(!lines.isEmpty() && lines.get(0).startsWith(X509_PEM_HEADER),
"Key is not in PEM-encoded X.509 format, " +
"please check that the header begins with -----" + X509_PEM_HEADER + "-----");
String base64Encoded = lines.stream()
.filter(RsaKeyConverters::isNotX509Wrapper)
.collect(Collectors.joining());
byte[] x509 = Base64.getDecoder().decode(base64Encoded);
StringBuilder base64Encoded = new StringBuilder();
for (String line : lines) {
if (RsaKeyConverters.isNotX509Wrapper(line)) {
base64Encoded.append(line);
}
}
byte[] x509 = Base64.getDecoder().decode(base64Encoded.toString());

try {
return (RSAPublicKey) keyFactory.generatePublic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.util.Assert;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -56,7 +55,10 @@ public MapReactiveUserDetailsService(UserDetails... users) {
*/
public MapReactiveUserDetailsService(Collection<UserDetails> users) {
Assert.notEmpty(users, "users cannot be null or empty");
this.users = users.stream().collect(Collectors.toConcurrentMap( u -> getKey(u.getUsername()), Function.identity()));
this.users = new ConcurrentHashMap<>();
for (UserDetails user : users) {
this.users.put(getKey(user.getUsername()), user);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ public void constructorEmptyUsers() {
new MapReactiveUserDetailsService(users);
}

@Test
public void constructorCaseIntensiveKey() {
UserDetails userDetails = User.withUsername("USER").password("password").roles("USER").build();
MapReactiveUserDetailsService userDetailsService = new MapReactiveUserDetailsService(userDetails);
assertThat(userDetailsService.findByUsername("user").block()).isEqualTo(userDetails);
}

@Test
public void findByUsernameWhenFoundThenReturns() {
assertThat((users.findByUsername(USER_DETAILS.getUsername()).block())).isEqualTo(USER_DETAILS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
* An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates
Expand Down Expand Up @@ -64,10 +63,12 @@ public DelegatingOAuth2AuthorizedClientProvider(List<OAuth2AuthorizedClientProvi
@Nullable
public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
Assert.notNull(context, "context cannot be null");
return this.authorizedClientProviders.stream()
.map(authorizedClientProvider -> authorizedClientProvider.authorize(context))
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
for (OAuth2AuthorizedClientProvider authorizedClientProvider : authorizedClientProviders) {
OAuth2AuthorizedClient oauth2AuthorizedClient = authorizedClientProvider.authorize(context);
if (oauth2AuthorizedClient != null) {
return oauth2AuthorizedClient;
}
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.List;
import java.util.LinkedHashMap;
import java.util.ArrayList;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/**
* A builder that builds a {@link DelegatingOAuth2AuthorizedClientProvider} composed of
Expand Down Expand Up @@ -286,10 +286,10 @@ public OAuth2AuthorizedClientProvider build() {
* @return the {@link DelegatingOAuth2AuthorizedClientProvider}
*/
public OAuth2AuthorizedClientProvider build() {
List<OAuth2AuthorizedClientProvider> authorizedClientProviders =
this.builders.values().stream()
.map(Builder::build)
.collect(Collectors.toList());
List<OAuth2AuthorizedClientProvider> authorizedClientProviders = new ArrayList<>();
for (Builder builder : this.builders.values()) {
authorizedClientProviders.add(builder.build());
}
return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
* An {@link OAuth2TokenValidator} responsible for
Expand Down Expand Up @@ -137,11 +136,8 @@ public void setClockSkew(Duration clockSkew) {
}

private static OAuth2Error invalidIdToken(Map<String, Object> invalidClaims) {
String claimsDetail = invalidClaims.entrySet().stream()
.map(it -> it.getKey() + " (" + it.getValue() + ")")
.collect(Collectors.joining(", "));
return new OAuth2Error("invalid_id_token",
"The ID Token contains invalid claims: " + claimsDetail,
"The ID Token contains invalid claims: " + invalidClaims,
"https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;
import java.util.stream.Collector;

import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.toConcurrentMap;
import java.util.concurrent.ConcurrentHashMap;

/**
* A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory.
Expand Down Expand Up @@ -62,9 +57,19 @@ public InMemoryClientRegistrationRepository(List<ClientRegistration> registratio

private static Map<String, ClientRegistration> createRegistrationsMap(List<ClientRegistration> registrations) {
Assert.notEmpty(registrations, "registrations cannot be empty");
Collector<ClientRegistration, ?, ConcurrentMap<String, ClientRegistration>> collector =
toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity());
return registrations.stream().collect(collectingAndThen(collector, Collections::unmodifiableMap));
return toUnmodifiableConcurrentMap(registrations);
}

private static Map<String, ClientRegistration> toUnmodifiableConcurrentMap(List<ClientRegistration> registrations) {
ConcurrentHashMap<String, ClientRegistration> result = new ConcurrentHashMap<>();
for (ClientRegistration registration : registrations) {
if (result.containsKey(registration.getRegistrationId())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this check necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jzheaux
When we use toConcurrentMap from Collectors, it uses uniqKeysMapAccumulator

 /**
     * {@code BiConsumer<Map, T>} that accumulates (key, value) pairs
     * extracted from elements into the map, throwing {@code IllegalStateException}
     * if duplicate keys are encountered.
     *
     * @param keyMapper a function that maps an element into a key
     * @param valueMapper a function that maps an element into a value
     * @param <T> type of elements
     * @param <K> type of map keys
     * @param <V> type of map values
     * @return an accumulating consumer
     */
    private static <T, K, V>
    BiConsumer<Map<K, V>, T> uniqKeysMapAccumulator(Function<? super T, ? extends K> keyMapper,
                                                    Function<? super T, ? extends V> valueMapper) {
        return (map, element) -> {
            K k = keyMapper.apply(element);
            V v = Objects.requireNonNull(valueMapper.apply(element));
            V u = map.putIfAbsent(k, v);
            if (u != null) throw duplicateKeyException(k, u, v);
        };
    }

I also add check to save logic with throwing an exception, if we find duplicate key. Maybe, I understand something wrong and logic with client Registration can just rewrite value, how do you think? Can we break something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Collisions aren't a concern in this case but thank you for keeping an eye on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, let's go ahead and remove this extra check and exception. Collisions aren't a concern here, @kostya05983

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oks, also I need a remove test for it, because there is a test which checks duplicates.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kostya05983 for some reason I'd missed that there is a test that relies on this functionality. I apologize, but because of that, we should add the duplicate checks as well as the test back in. My mistake.

throw new IllegalStateException(String.format("Duplicate key %s",
registration.getRegistrationId()));
}
result.put(registration.getRegistrationId(), registration);
}
return Collections.unmodifiableMap(result);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.util.Assert;

Expand Down Expand Up @@ -61,11 +60,9 @@ public InMemoryReactiveClientRegistrationRepository(ClientRegistration... regist
*/
public InMemoryReactiveClientRegistrationRepository(List<ClientRegistration> registrations) {
Assert.notEmpty(registrations, "registrations cannot be null or empty");
this.clientIdToClientRegistration = registrations.stream()
.collect(Collectors.toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity()));
this.clientIdToClientRegistration = toConcurrentMap(registrations);
}


@Override
public Mono<ClientRegistration> findByRegistrationId(String registrationId) {
return Mono.justOrEmpty(this.clientIdToClientRegistration.get(registrationId));
Expand All @@ -80,4 +77,12 @@ public Mono<ClientRegistration> findByRegistrationId(String registrationId) {
public Iterator<ClientRegistration> iterator() {
return this.clientIdToClientRegistration.values().iterator();
}

private ConcurrentHashMap<String, ClientRegistration> toConcurrentMap(List<ClientRegistration> registrations) {
ConcurrentHashMap<String, ClientRegistration> result = new ConcurrentHashMap<>();
for (ClientRegistration registration : registrations) {
result.put(registration.getRegistrationId(), registration);
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,16 @@ public void validateWhenMissingClaimsThenHasErrors() {
.allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
}

@Test
public void validateFormatError() {
this.claims.remove(IdTokenClaimNames.SUB);
this.claims.remove(IdTokenClaimNames.AUD);
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch(msg -> msg.equals("The ID Token contains invalid claims: {sub=null, aud=null}"));
}

private Collection<OAuth2Error> validateIdToken() {
Jwt idToken = new Jwt("token123", this.issuedAt, this.expiresAt, this.headers, this.claims);
OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ public void constructorListClientRegistrationWhenEmptyThenIllegalArgumentExcepti
new InMemoryClientRegistrationRepository(registrations);
}

@Test(expected = IllegalStateException.class)
public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() {
List<ClientRegistration> registrations = Arrays.asList(this.registration, this.registration);
new InMemoryClientRegistrationRepository(registrations);
}

@Test(expected = IllegalArgumentException.class)
public void constructorMapClientRegistrationWhenNullThenIllegalArgumentException() {
new InMemoryClientRegistrationRepository((Map<String, ClientRegistration>) null);
Expand All @@ -67,6 +61,12 @@ public void constructorMapClientRegistrationWhenEmptyMapThenRepositoryIsEmpty()
assertThat(clients).isEmpty();
}

@Test(expected = IllegalStateException.class)
public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() {
List<ClientRegistration> registrations = Arrays.asList(this.registration, this.registration);
new InMemoryClientRegistrationRepository(registrations);
}

@Test
public void findByRegistrationIdWhenFoundThenFound() {
String id = this.registration.getRegistrationId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.ArrayList;

/**
* @author Joe Grandja
Expand Down Expand Up @@ -64,10 +63,13 @@ public Object convert(Object source, TypeDescriptor sourceType, TypeDescriptor t
}
}
if (source instanceof Collection) {
return ((Collection<?>) source).stream()
.filter(Objects::nonNull)
.map(Objects::toString)
.collect(Collectors.toList());
Collection<String> results = new ArrayList<>();
for (Object object : ((Collection<?>) source)) {
if (object != null) {
results.add(object.toString());
}
}
return results;
}
return Collections.singletonList(source.toString());
}
Expand Down
Loading