View Javadoc
1   package org.oxerr.spring.cache.redis.scored;
2   
3   import java.time.Duration;
4   import java.util.Optional;
5   import java.util.function.Function;
6   
7   import org.springframework.data.redis.cache.CacheStatistics;
8   import org.springframework.data.redis.cache.CacheStatisticsCollector;
9   import org.springframework.data.redis.cache.RedisCacheWriter;
10  import org.springframework.data.redis.connection.RedisConnection;
11  import org.springframework.data.redis.connection.RedisConnectionFactory;
12  import org.springframework.data.redis.connection.RedisZSetCommands.Limit;
13  import org.springframework.data.redis.connection.RedisZSetCommands.Range;
14  import org.springframework.data.redis.connection.RedisZSetCommands.ZAddArgs;
15  import org.springframework.lang.NonNull;
16  import org.springframework.lang.Nullable;
17  import org.springframework.util.Assert;
18  
19  /**
20   * {@link RedisCacheWriter} implementation using sorted set as back-end store.
21   */
22  public class ScoredRedisCacheWriter implements RedisCacheWriter {
23  
24  	private static final String MUST_NOT_BE_NULL = " must not be null!";
25  
26  	private static final String NAME_NOT_NULL = "Name" + MUST_NOT_BE_NULL;
27  	private static final String KEY_NOT_NULL = "Key" + MUST_NOT_BE_NULL;
28  	private static final String VALUE_NOT_NULL = "Value" + MUST_NOT_BE_NULL;
29  
30  	private static final String OK = "OK";
31  
32  	private static final Double DEFAULT_SCORE = Double.valueOf(0d);
33  
34  	private final RedisConnectionFactory connectionFactory;
35  	private final CacheStatisticsCollector statistics;
36  	private final ScoreHolder scoreHolder;
37  	private final RedisCacheWriter cacheWriter;
38  
39  	public ScoredRedisCacheWriter(
40  		@NonNull RedisConnectionFactory connectionFactory
41  	) {
42  		this(connectionFactory, CacheStatisticsCollector.none());
43  	}
44  
45  	public ScoredRedisCacheWriter(
46  		@NonNull RedisConnectionFactory connectionFactory,
47  		@NonNull CacheStatisticsCollector cacheStatisticsCollector
48  	) {
49  		this(
50  			connectionFactory,
51  			cacheStatisticsCollector,
52  			new InheritableThreadLocalScoreHolder()
53  		);
54  	}
55  
56  	public ScoredRedisCacheWriter(
57  		@NonNull RedisConnectionFactory connectionFactory,
58  		@NonNull CacheStatisticsCollector cacheStatisticsCollector,
59  		@NonNull ScoreHolder scoreHolder
60  	) {
61  		this(
62  			connectionFactory,
63  			cacheStatisticsCollector,
64  			scoreHolder,
65  			RedisCacheWriter
66  				.nonLockingRedisCacheWriter(connectionFactory)
67  				.withStatisticsCollector(cacheStatisticsCollector)
68  		);
69  	}
70  
71  	public ScoredRedisCacheWriter(
72  		@NonNull RedisConnectionFactory connectionFactory,
73  		@NonNull CacheStatisticsCollector cacheStatisticsCollector,
74  		@NonNull ScoreHolder scoreHolder,
75  		@NonNull RedisCacheWriter cacheWriter
76  	) {
77  		this.connectionFactory = connectionFactory;
78  		this.statistics = cacheStatisticsCollector;
79  		this.scoreHolder = scoreHolder;
80  		this.cacheWriter = cacheWriter;
81  	}
82  
83  	@Override
84  	public CacheStatistics getCacheStatistics(String cacheName) {
85  		return statistics.getCacheStatistics(cacheName);
86  	}
87  
88  	@Override
89  	public void put(String name, byte[] key, byte[] value, Duration ttl) {
90  		Assert.notNull(name, NAME_NOT_NULL);
91  		Assert.notNull(key, KEY_NOT_NULL);
92  		Assert.notNull(value, VALUE_NOT_NULL);
93  
94  		final double score = getScore();
95  
96  		final long millis = ttl.toMillis();
97  		final Range range = Range.range().lt(score);
98  
99  		execute(name, connection -> {
100 
101 			connection.zAdd(key, score, value);
102 			connection.zRemRangeByScore(key, range);
103 
104 			if (shouldExpireWithin(ttl)) {
105 				connection.pExpire(key, millis);
106 			}
107 
108 			return OK;
109 		});
110 
111 		statistics.incPuts(name);
112 	}
113 
114 	@Override
115 	public byte[] get(String name, byte[] key) {
116 		Assert.notNull(name, NAME_NOT_NULL);
117 		Assert.notNull(key, KEY_NOT_NULL);
118 
119 		final Range range = Range.unbounded();
120 		final Limit limit = Limit.limit().count(1);
121 
122 		byte[] result = execute(name, connection -> connection.zRevRangeByScore(key, range, limit))
123 			.stream()
124 			.findFirst()
125 			.orElse(null);
126 
127 		statistics.incGets(name);
128 
129 		if (result != null) {
130 			statistics.incHits(name);
131 		} else {
132 			statistics.incMisses(name);
133 		}
134 
135 		return result;
136 	}
137 
138 	@Override
139 	public byte[] putIfAbsent(String name, byte[] key, byte[] value, Duration ttl) {
140 		Assert.notNull(name, NAME_NOT_NULL);
141 		Assert.notNull(key, KEY_NOT_NULL);
142 		Assert.notNull(value, VALUE_NOT_NULL);
143 
144 		final double score = getScore();
145 
146 		final long millis = ttl.toMillis();
147 		final Range range = Range.range().lt(score);
148 		final Limit limit = Limit.limit().count(1);
149 
150 		return execute(name, connection -> {
151 			final Boolean added = connection.zAdd(key, score, value, ZAddArgs.ifNotExists());
152 			connection.zRemRangeByScore(key, range);
153 
154 			if (shouldExpireWithin(ttl)) {
155 				connection.pExpire(key, millis);
156 			}
157 
158 			if (Boolean.TRUE.equals(added)) {
159 				statistics.incPuts(name);
160 				return null;
161 			}
162 
163 			return connection.zRevRangeByScore(key, range, limit)
164 				.stream()
165 				.findFirst()
166 				.orElse(null);
167 		});
168 	}
169 
170 	@Override
171 	public void remove(String name, byte[] key) {
172 		this.cacheWriter.remove(name, key);
173 	}
174 
175 	@Override
176 	public void clean(String name, byte[] pattern) {
177 		this.cacheWriter.clean(name, pattern);
178 	}
179 
180 	@Override
181 	public void clearStatistics(String name) {
182 		statistics.reset(name);
183 	}
184 
185 	@Override
186 	public ScoredRedisCacheWriter withStatisticsCollector(CacheStatisticsCollector cacheStatisticsCollector) {
187 		return new ScoredRedisCacheWriter(connectionFactory, cacheStatisticsCollector);
188 	}
189 
190 	private <T> T execute(String name, Function<RedisConnection, T> callback) {
191 		try (RedisConnection connection = connectionFactory.getConnection()) {
192 			return callback.apply(connection);
193 		}
194 	}
195 
196 	private static boolean shouldExpireWithin(@Nullable Duration ttl) {
197 		return ttl != null && !ttl.isZero() && !ttl.isNegative();
198 	}
199 
200 	private double getScore() {
201 		return Optional.ofNullable(this.scoreHolder.get()).orElse(DEFAULT_SCORE);
202 	}
203 
204 }